# **Labs 1 and 2: PySpark**

In these labs we will be using the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022 for educational, non-commercial purposes under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.


The csv file that we will be using in this lab is **PatientInfo**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the source of infection (e.g. overseas inflow, contact with patient)

**infected_by**
the ID of who infected the patient


**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date the infection was confirmed

**released_date**
the date the patient was released

**deceased_date**
the date of death

**state**
isolated / released / deceased

### Import PySpark and check its version

In [48]:
import findspark
findspark.init()
import pyspark
print(pyspark.__version__)

3.4.0


### Import and create SparkSession

In [49]:
from pyspark.sql import SparkSession
# Create SparkSession
spark = SparkSession.builder.getOrCreate()

### Load the PatientInfo.csv file and show the first 5 rows

In [50]:
from IPython.display import display, HTML
# Keep wide DataFrame output on one line instead of wrapping it
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [51]:
sc = spark.sparkContext
# Read the CSV, using the first row as column names and letting Spark infer column types
df = spark.read.csv("PatientInfo.csv", inferSchema=True, header=True)
df.show(5)

### Display the schema of the dataset

In [52]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: date (nullable = true)
 |-- released_date: date (nullable = true)
 |-- deceased_date: date (nullable = true)
 |-- state: string (nullable = true)



### Display the statistical summary

In [53]:
df.summary().show()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               690|    5165|
|   mean|2.8636345618679576E9|  null|null|      null|    null|          null|                null|2.2845944015643125E9|1.6772572523506988E7|              null|    null|
| stddev| 2.074210725277473E9|  null|null|      null|    null|          null|                null|1.5265072953383324E9| 3.093097580985502E8|              n

### Using the state column, how many people survived (released) and how many did not (isolated/deceased)?

In [54]:
df.groupBy('state').count().show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+
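As a quick check, the three counts above sum to the 5165 rows reported by `df.summary()`, and their shares give a baseline for the `is_dead` label built later (every state other than `released`). The plain-Python sketch below simply re-derives these proportions from the printed counts; the numbers are copied from the output, not recomputed by Spark.

```python
# Counts copied from the groupBy('state') output above
state_counts = {"isolated": 2158, "released": 2929, "deceased": 78}

total = sum(state_counts.values())  # should equal the row count of PatientInfo.csv
shares = {state: n / total for state, n in state_counts.items()}

# Share of rows that will later be labelled is_dead = True (state != "released")
not_released_share = 1 - shares["released"]

for state, share in shares.items():
    print(f"{state:>9}: {share:.1%}")
print(f"not released: {not_released_share:.1%}")
```

This prints roughly 41.8% isolated, 56.7% released, 1.5% deceased and 43.3% not released, so a model that always predicts "released" would be right about 56.7% of the time on the full data, a useful reference point for the classifier later in the lab.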



### Display the number of null values in each column

In [55]:
from pyspark.sql.functions import col

null_counts = [df.where(col(c).isNull()).count() for c in df.columns]

for col_name, null_count in zip(df.columns, null_counts):
    print(f"{col_name}: {null_count}")

patient_id: 0
sex: 1122
age: 1380
country: 0
province: 0
city: 94
infection_case: 919
infected_by: 3819
contact_number: 4374
symptom_onset_date: 4475
confirmed_date: 3
released_date: 3578
deceased_date: 5099
state: 0
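Each iteration of the list comprehension above triggers a separate Spark job (one `count()` per column). A common alternative is a single aggregation such as `df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns])`, with `count` and `when` imported from `pyspark.sql.functions`, which scans the data once. To put the printed counts in perspective, the plain-Python sketch below converts them into percentages of the 5165 rows; the numbers are copied from the output above.

```python
# Null counts copied from the output above; 5165 is the total row count
TOTAL_ROWS = 5165
null_counts = {
    "patient_id": 0, "sex": 1122, "age": 1380, "country": 0, "province": 0,
    "city": 94, "infection_case": 919, "infected_by": 3819,
    "contact_number": 4374, "symptom_onset_date": 4475, "confirmed_date": 3,
    "released_date": 3578, "deceased_date": 5099, "state": 0,
}

# Percentage of missing values per column, most-missing first
missing_pct = {c: 100 * n / TOTAL_ROWS for c, n in null_counts.items()}
for c, pct in sorted(missing_pct.items(), key=lambda kv: kv[1], reverse=True):
    print(f"{c:>18}: {pct:5.1f}%")
```

Columns such as `deceased_date` (about 98.7% missing) and `symptom_onset_date` (about 86.6%) are mostly empty, which is why the preprocessing below either fills them or drops them.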


## Data preprocessing

### Fill the nulls in deceased_date with the values from released_date
- You can use the <b>coalesce</b> function.
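`coalesce` returns, row by row, the first of its arguments that is not null, and yields null only when every argument is null. The plain-Python sketch below mimics that rule for the two date columns, with `None` standing in for null and dates kept as strings for brevity. The helper name `coalesce_py` and the example rows are illustrative only (the first row mirrors patient 1000000001); a row where both dates are missing, such as a patient still in isolation, keeps a null `deceased_date`.

```python
def coalesce_py(*values):
    """Return the first argument that is not None, or None if all are None."""
    for v in values:
        if v is not None:
            return v
    return None

# Hypothetical (deceased_date, released_date) pairs; None means null
rows = [
    (None, "2020-02-05"),   # released patient: takes the release date
    ("2020-03-10", None),   # deceased patient: keeps the death date
    (None, None),           # still isolated: stays null
]
filled = [coalesce_py(deceased, released) for deceased, released in rows]
print(filled)
```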

In [56]:
from pyspark.sql.functions import coalesce

# Where deceased_date is null, take released_date instead
df = df.withColumn("deceased_date", coalesce(df["deceased_date"], df["released_date"]))
df.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

### Add a column named no_days that holds the number of days between confirmed_date and deceased_date, then show the top 5 rows and print the schema.
- <b>Hint: you need to cast these columns to date first.</b>

In [57]:
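from pyspark.sql.functions import datediff
# Both columns were already inferred as dates; the casts make that explicit
df = df.withColumn("deceased_date", df["deceased_date"].cast("date"))
df = df.withColumn("confirmed_date", df["confirmed_date"].cast("date"))
# Number of days from confirmation to release or death
df = df.withColumn("no_days", datediff("deceased_date", "confirmed_date"))
df.show(5)
df.printSchema()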

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|         
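`datediff(end, start)` returns the number of whole days from `start` to `end`. A plain-Python equivalent using `datetime.date` (the helper name `datediff_py` is illustrative) reproduces the `no_days` values of the first two rows shown above, 13 and 32; note that the second interval spans 29 February, since 2020 is a leap year.

```python
from datetime import date

def datediff_py(end, start):
    """Days from start to end, mirroring the argument order of Spark's datediff(end, start)."""
    return (end - start).days

# confirmed_date -> deceased_date (filled from released_date) for the first two rows
print(datediff_py(date(2020, 2, 5), date(2020, 1, 23)))  # patient 1000000001
print(datediff_py(date(2020, 3, 2), date(2020, 1, 30)))  # patient 1000000002
```

Swapping the arguments flips the sign, so the order matters.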

### Add an is_male column that is True if the patient is male, and False otherwise

In [58]:
from pyspark.sql.functions import when
# True for "male"; False otherwise, including rows where sex is null
df = df.withColumn("is_male", when(col("sex") == "male", True).otherwise(False))
df.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|co
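One subtlety: for a row where `sex` is null, `col("sex") == "male"` evaluates to null rather than False, and `when` treats a null condition as not satisfied, so the row falls through to `otherwise(False)`. The 1122 patients with unknown sex are therefore labelled `is_male = False`, the same as female patients. The plain-Python sketch below models this three-valued logic, with `None` standing in for SQL null; the function names are illustrative, not Spark APIs.

```python
def sql_eq(a, b):
    """SQL-style equality: comparing with NULL (None) yields NULL."""
    if a is None or b is None:
        return None
    return a == b

def when_otherwise(condition, value, default):
    """Mimic when(cond, value).otherwise(default): only a True condition matches."""
    return value if condition is True else default

for sex in ["male", "female", None]:
    print(sex, when_otherwise(sql_eq(sex, "male"), True, False))
```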

### Add an is_dead column that is True if the patient's state is not released, and False otherwise

- Use a <b>UDF</b> to perform this task.
- In general, UDFs are not recommended unless no built-in function can perform the required operation.
- UDFs are slower than built-in functions, because each row has to be serialized and passed to a Python worker.

In [59]:
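def is_not_released(state):
    # True for "isolated" or "deceased", False for "released"
    return state != 'released'

In [60]:
from pyspark.sql.functions import udf
from pyspark.sql.types import BooleanType
# Wrap the Python function as a Spark UDF that returns a boolean column
is_dead_udf = udf(is_not_released, BooleanType())
df = df.withColumn("is_dead", is_dead_udf("state"))
# Built-in equivalent, usually faster: df.withColumn("is_dead", col("state") != "released")
df.show()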

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  m

### Convert the age bins from 0s, 10s, 20s, etc. to 0, 10, 20, etc.

In [61]:
from pyspark.sql.functions import substring

df = df.withColumn("age", substring(col("age"),0,2))
df.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male| 50|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male| 30|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  m
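`substring(col("age"), 0, 2)` keeps the first two characters (Spark's `substring` is 1-based, and a start position of 0 behaves like 1). That works for two-digit bins such as "50s", but a one-digit bin "0s" stays "0s", which cannot be cast to a number, and a three-digit bin such as "100s" would become "10". This is consistent with the `age` null count rising from 1380 to 1446 after these steps. Stripping the trailing "s" instead, as the SQL version later in this lab does with `SUBSTRING(age, 1, length(age) - 1)`, avoids both problems. The plain-Python sketch below compares the two rules; the helper names are illustrative, and the cast helper assumes Spark's default non-ANSI behaviour of turning unparseable strings into null.

```python
def first_two_chars(age):
    """Mimics substring(age, 0, 2): keep the first two characters."""
    return None if age is None else age[:2]

def strip_trailing_s(age):
    """Mimics SUBSTRING(age, 1, length(age) - 1): drop the final character."""
    return None if age is None else age[:-1]

def to_double(text):
    """Mimics a non-ANSI cast to double: unparseable strings become None."""
    if text is None:
        return None
    try:
        return float(text)
    except ValueError:
        return None

for raw in ["50s", "0s", "100s", None]:
    print(raw, to_double(first_two_chars(raw)), to_double(strip_trailing_s(raw)))
```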

### Cast age and no_days to double

In [62]:
from pyspark.sql.functions import col

# Casting a null yields null, so no explicit null handling is needed
df = df.withColumn("age", col("age").cast("double"))
df = df.withColumn("no_days", col("no_days").cast("double"))
df.show()

+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex| age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50.0|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|   13.0|   true|  false|
|1000000002|  male|30.0|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|   32.0|   true|  false|
|100000000

### Drop the following columns
`["patient_id", "sex", "infected_by", "contact_number", "released_date", "state", "symptom_onset_date", "confirmed_date", "deceased_date", "country", "no_days", "city", "infection_case"]`

In [63]:
df = df.drop("patient_id", "sex", "infected_by", "contact_number", "released_date", "state",
             "symptom_onset_date", "confirmed_date", "deceased_date", "country", "no_days",
             "city", "infection_case")

In [64]:
df.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|60.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|80.0|   Seoul|   true|   true|
|60.0|   Seoul|  false|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|70.0|   Seoul|  false|  false|
|70.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 20 rows



### Recount the number of nulls now

In [65]:
from pyspark.sql.functions import col

null_counts = [df.where(col(c).isNull()).count() for c in df.columns]

for col_name, null_count in zip(df.columns, null_counts):
    print(f"{col_name}: {null_count}")

age: 1446
province: 0
is_male: 0
is_dead: 0


## Now do the same using SQL SELECT statements

### From the original PatientInfo DataFrame, create a temporary view (table).

In [66]:
sql_df = spark.read.csv("PatientInfo.csv", inferSchema=True, header=True)
sql_df.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

In [67]:
sql_df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: date (nullable = true)
 |-- released_date: date (nullable = true)
 |-- deceased_date: date (nullable = true)
 |-- state: string (nullable = true)



In [68]:
sql_df.createOrReplaceTempView("table")

### Use a SELECT statement to select all columns from the DataFrame and show the output.

In [69]:
spark.sql("""SELECT *
            FROM table
          """).show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

### *Using SQL commands*, limit the output to only 5 rows 

In [70]:
spark.sql("""SELECT *
            FROM table
            LIMIT 5
          """).show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

### Select the count of males and females in the dataset

In [71]:
spark.sql("""SELECT sex, COUNT(*) as count
            FROM table
            GROUP BY sex;
          """).show()

+------+-----+
|   sex|count|
+------+-----+
|  null| 1122|
|female| 2218|
|  male| 1825|
+------+-----+



### How many people survived, and how many did not?

In [72]:
spark.sql("""SELECT state, COUNT(*) as count
            FROM table
            GROUP BY state;
          """).show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



### Now, let's perform some preprocessing using SQL:
1. Convert the *age* column to double after removing the trailing 's' -- *hint: check the SUBSTRING function*
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe

In [73]:
# SUBSTRING is 1-based: keep everything except the trailing "s", then cast to double
sql_df = spark.sql("""SELECT sex,
                             CAST(SUBSTRING(age, 1, length(age) - 1) AS double) AS age,
                             province,
                             state
                      FROM table;
                   """)
sql_df.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows



## Machine Learning 
### Create a pipeline model to predict is_dead and evaluate its performance.
- Use <b>StringIndexer</b> to transform <b>string</b> columns into numeric indices.
- Use <b>OneHotEncoder</b> to encode the categorical indices.
- Use <b>Imputer</b> to fill missing data with the mean.

In [74]:
df.show(5)

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 5 rows



In [75]:
from pyspark.ml.feature import StringIndexer

indexer = StringIndexer(inputCol="province", outputCol="province_index")
indexed = indexer.fit(df).transform(df)

In [76]:
indexed.show(5)

+----+--------+-------+-------+--------------+
| age|province|is_male|is_dead|province_index|
+----+--------+-------+-------+--------------+
|50.0|   Seoul|   true|  false|           0.0|
|30.0|   Seoul|   true|  false|           0.0|
|50.0|   Seoul|   true|  false|           0.0|
|20.0|   Seoul|   true|  false|           0.0|
|20.0|   Seoul|  false|  false|           0.0|
+----+--------+-------+-------+--------------+
only showing top 5 rows



In [77]:
from pyspark.ml.feature import OneHotEncoder
encoder = OneHotEncoder(inputCol="province_index", outputCol="province_final")
final = encoder.fit(indexed).transform(indexed)

In [78]:
final.show(5)

+----+--------+-------+-------+--------------+--------------+
| age|province|is_male|is_dead|province_index|province_final|
+----+--------+-------+-------+--------------+--------------+
|50.0|   Seoul|   true|  false|           0.0|(16,[0],[1.0])|
|30.0|   Seoul|   true|  false|           0.0|(16,[0],[1.0])|
|50.0|   Seoul|   true|  false|           0.0|(16,[0],[1.0])|
|20.0|   Seoul|   true|  false|           0.0|(16,[0],[1.0])|
|20.0|   Seoul|  false|  false|           0.0|(16,[0],[1.0])|
+----+--------+-------+-------+--------------+--------------+
only showing top 5 rows
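By default, `StringIndexer` assigns index 0.0 to the most frequent label (here Seoul), and `OneHotEncoder` drops the last category (`dropLast=True`), so 17 provinces produce 16-element vectors. Sparse vectors print as `(size, [indices], [values])`: `(16,[0],[1.0])` is a 16-element vector with a 1.0 at position 0, and the all-zero `(16,[],[])` that appears later in the predictions marks the dropped, least frequent province. The plain-Python sketch below imitates both steps on a toy list; the helper names, the toy data and the alphabetical tie-break are assumptions of the sketch.

```python
from collections import Counter

def string_index(labels):
    """Map each label to an index: most frequent first, ties broken alphabetically."""
    counts = Counter(labels)
    ordered = sorted(counts, key=lambda lab: (-counts[lab], lab))
    return {lab: float(i) for i, lab in enumerate(ordered)}

def one_hot_sparse(index, num_categories, drop_last=True):
    """Return (size, indices, values) in the style of Spark's sparse vector printout."""
    size = num_categories - 1 if drop_last else num_categories
    idx = int(index)
    if idx < size:
        return (size, [idx], [1.0])
    return (size, [], [])  # the dropped last category is all zeros

# Toy province column (hypothetical, not the real distribution)
provinces = ["Seoul", "Seoul", "Seoul", "Busan", "Busan", "Jeju-do"]
mapping = string_index(provinces)
for name in ["Seoul", "Busan", "Jeju-do"]:
    print(name, mapping[name], one_hot_sparse(mapping[name], len(mapping)))
```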



In [79]:
from pyspark.ml.feature import Imputer
# Replace missing ages with the mean of the non-missing ages
imputer = Imputer(
    inputCols=["age"],
    outputCols=["age_imputed"],
    strategy='mean'
)
final_df = imputer.fit(final).transform(final)

In [80]:
from pyspark.sql.functions import col

null_counts = [final_df.where(col(c).isNull()).count() for c in final_df.columns]

for col_name, null_count in zip(final_df.columns, null_counts):
    print(f"{col_name}: {null_count}")

age: 1446
province: 0
is_male: 0
is_dead: 0
province_index: 0
province_final: 0
age_imputed: 0
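With `strategy='mean'`, `Imputer` computes the mean of the non-missing values of `age` when it is fitted and substitutes that constant for every missing value in the new `age_imputed` column; the original `age` column keeps its 1446 nulls, which is why it still shows up in the counts above. Note that the imputer here is fitted on the full dataset before the train/test split, so test rows also contribute to the mean. The plain-Python sketch below shows the arithmetic on made-up ages; the helper names are illustrative.

```python
from statistics import mean

def fit_mean_imputer(values):
    """Return the mean of the non-missing values (None = missing)."""
    return mean(v for v in values if v is not None)

def transform_mean_imputer(values, fill):
    """Replace each missing value with the fitted mean."""
    return [fill if v is None else v for v in values]

ages = [50.0, 30.0, None, 20.0, None]  # hypothetical ages with gaps
fill_value = fit_mean_imputer(ages)    # mean of 50, 30 and 20
ages_imputed = transform_mean_imputer(ages, fill_value)
print(fill_value, ages_imputed)
```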


In [81]:
final_df = final_df.drop("age","province","province_index")

In [82]:
final_df.show(5)

+-------+-------+--------------+-----------+
|is_male|is_dead|province_final|age_imputed|
+-------+-------+--------------+-----------+
|   true|  false|(16,[0],[1.0])|       50.0|
|   true|  false|(16,[0],[1.0])|       30.0|
|   true|  false|(16,[0],[1.0])|       50.0|
|   true|  false|(16,[0],[1.0])|       20.0|
|  false|  false|(16,[0],[1.0])|       20.0|
+-------+-------+--------------+-----------+
only showing top 5 rows



In [83]:
# Convert the boolean label and feature to 0/1 integers
final_df = final_df.withColumn("is_dead", when(col("is_dead"), 1).otherwise(0))
final_df = final_df.withColumn("is_male", when(col("is_male"), 1).otherwise(0))

In [84]:
final_df.show(5)

+-------+-------+--------------+-----------+
|is_male|is_dead|province_final|age_imputed|
+-------+-------+--------------+-----------+
|      1|      0|(16,[0],[1.0])|       50.0|
|      1|      0|(16,[0],[1.0])|       30.0|
|      1|      0|(16,[0],[1.0])|       50.0|
|      1|      0|(16,[0],[1.0])|       20.0|
|      0|      0|(16,[0],[1.0])|       20.0|
+-------+-------+--------------+-----------+
only showing top 5 rows



In [85]:
from pyspark.ml.feature import VectorAssembler
vecAssembler = VectorAssembler(inputCols=["age_imputed","is_male","province_final"],outputCol="features")

In [86]:
final_df = vecAssembler.transform(final_df)

In [87]:
final_df.show(5)

+-------+-------+--------------+-----------+--------------------+
|is_male|is_dead|province_final|age_imputed|            features|
+-------+-------+--------------+-----------+--------------------+
|      1|      0|(16,[0],[1.0])|       50.0|(18,[0,1,2],[50.0...|
|      1|      0|(16,[0],[1.0])|       30.0|(18,[0,1,2],[30.0...|
|      1|      0|(16,[0],[1.0])|       50.0|(18,[0,1,2],[50.0...|
|      1|      0|(16,[0],[1.0])|       20.0|(18,[0,1,2],[20.0...|
|      0|      0|(16,[0],[1.0])|       20.0|(18,[0,2],[20.0,1...|
+-------+-------+--------------+-----------+--------------------+
only showing top 5 rows
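`VectorAssembler` concatenates its input columns in the order given: `age_imputed` (position 0), `is_male` (position 1), then the 16 one-hot slots (positions 2 to 17), for 18 features in total. That layout is consistent with the truncated `(18,[0,1,2],[50.0...` shown above for a 50-year-old male in Seoul, while a female patient's row has no entry at position 1. The plain-Python sketch below reproduces the layout using the same `(size, indices, values)` triple; the `assemble` helper is hypothetical, and it always emits the sparse form, whereas Spark may choose a dense or sparse representation.

```python
def assemble(age, is_male, province_vec):
    """Concatenate [age, is_male] with a (size, indices, values) one-hot vector.

    Returns a sparse triple that keeps only non-zero entries, like Spark's printout.
    """
    size, idx, vals = province_vec
    dense_prefix = [float(age), float(is_male)]
    indices, values = [], []
    for i, v in enumerate(dense_prefix):
        if v != 0.0:
            indices.append(i)
            values.append(v)
    offset = len(dense_prefix)
    for i, v in zip(idx, vals):
        indices.append(offset + i)
        values.append(v)
    return (offset + size, indices, values)

seoul = (16, [0], [1.0])
print(assemble(50.0, 1, seoul))  # male patient in Seoul
print(assemble(20.0, 0, seoul))  # female patient in Seoul
```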



In [88]:
trainDF, testDF = final_df.randomSplit([.8,.2],seed=42)
print(f"There are {trainDF.count()} rows in the training set, and {testDF.count()} in the test set")

There are 4166 rows in the training set, and 999 in the test set


In [89]:
from pyspark.ml.classification import LogisticRegression
lr = LogisticRegression(featuresCol='features',labelCol='is_dead')

In [90]:
lrModel = lr.fit(trainDF)

In [91]:
test_predict = lrModel.transform(testDF)

In [92]:
test_predict.show(5)

+-------+-------+--------------+-----------+--------------------+--------------------+--------------------+----------+
|is_male|is_dead|province_final|age_imputed|            features|       rawPrediction|         probability|prediction|
+-------+-------+--------------+-----------+--------------------+--------------------+--------------------+----------+
|      0|      0|    (16,[],[])|       20.0|     (18,[0],[20.0])|[2.66967340404418...|[0.93521324594268...|       0.0|
|      0|      0|    (16,[],[])|       20.0|     (18,[0],[20.0])|[2.66967340404418...|[0.93521324594268...|       0.0|
|      0|      0|    (16,[],[])|       30.0|     (18,[0],[30.0])|[2.38492024998639...|[0.91567014561979...|       0.0|
|      0|      0|(16,[0],[1.0])|       10.0|(18,[0,2],[10.0,1...|[1.90909713111872...|[0.87091768125460...|       0.0|
|      0|      0|(16,[0],[1.0])|       10.0|(18,[0,2],[10.0,1...|[1.90909713111872...|[0.87091768125460...|       0.0|
+-------+-------+--------------+-----------+----
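For binary logistic regression, Spark's `rawPrediction` holds the pair `[-m, m]`, where `m` is the linear margin (coefficients times features plus intercept), and `probability` applies the logistic (sigmoid) function to each entry. The first component is therefore the probability of class 0 (`is_dead = 0`), and `prediction` is 1.0 only when the class-1 probability exceeds the default threshold of 0.5. The plain-Python check below applies the sigmoid to the truncated `rawPrediction` values printed above and compares the result with the printed probabilities.

```python
import math

def sigmoid(x):
    """Logistic function 1 / (1 + e^(-x))."""
    return 1.0 / (1.0 + math.exp(-x))

# First components of rawPrediction and probability, copied (truncated) from the output above
pairs = [
    (2.66967340404418, 0.93521324594268),
    (2.38492024998639, 0.91567014561979),
    (1.90909713111872, 0.87091768125460),
]
for raw0, prob0 in pairs:
    print(f"sigmoid({raw0}) = {sigmoid(raw0):.10f}  (printed: {prob0})")
```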

In [93]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator
# metricName defaults to "areaUnderROC"
evaluator = BinaryClassificationEvaluator(labelCol="is_dead", rawPredictionCol="prediction")
evaluations = evaluator.evaluate(test_predict)

In [94]:
evaluations

0.8206341120248405
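The evaluator above is given `rawPredictionCol="prediction"`, so it scores the hard 0/1 predictions rather than the continuous `rawPrediction` column (the evaluator's default). With only two distinct score values, the ROC curve has a single interior point, and its area reduces to the mean of the true-positive rate and the true-negative rate, i.e. balanced accuracy at the 0.5 threshold. The 0.82 above is best read as that quantity rather than as the usual threshold-free AUC. The plain-Python sketch below checks the identity on a made-up example, computing AUC as the fraction of (positive, negative) pairs ranked correctly with ties counted as one half; the labels, predictions and helper names are hypothetical.

```python
from itertools import product

def pairwise_auc(labels, scores):
    """AUC as P(score_pos > score_neg), counting ties as 1/2."""
    pos = [s for y, s in zip(labels, scores) if y == 1]
    neg = [s for y, s in zip(labels, scores) if y == 0]
    credit = 0.0
    for p, n in product(pos, neg):
        if p > n:
            credit += 1.0
        elif p == n:
            credit += 0.5
    return credit / (len(pos) * len(neg))

def balanced_accuracy(labels, preds):
    """Mean of the true-positive rate and the true-negative rate."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    tn = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 0)
    n_pos = sum(1 for y in labels if y == 1)
    n_neg = len(labels) - n_pos
    return (tp / n_pos + tn / n_neg) / 2

labels = [1, 1, 0, 0, 0]  # hypothetical is_dead values
preds = [1, 0, 0, 0, 1]   # hypothetical hard predictions
print(pairwise_auc(labels, preds), balanced_accuracy(labels, preds))
```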

In [100]:
# Same 0/1 conversion as above, applied to the preprocessed df for the pipeline
pipeline_df = df.withColumn("is_dead", when(col("is_dead"), 1).otherwise(0))
pipeline_df = pipeline_df.withColumn("is_male", when(col("is_male"), 1).otherwise(0))

In [96]:
from pyspark.ml import Pipeline
# Stages run in order: index -> one-hot encode -> impute -> assemble features -> classify
pipeline = Pipeline(stages=[indexer, encoder, imputer, vecAssembler, lr])

In [102]:
model = pipeline.fit(pipeline_df)