<a href="https://colab.research.google.com/github/AyaElsawyElghaysh/ML_Spark_Tasks/blob/main/practical_spark_session_ITI.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [71]:
!pip install pyspark py4j

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


# **Labs 1 and 2 PySpark:**

In these labs we will use the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022 for educational, non-commercial purposes under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.


The csv file that we will be using in this lab is **PatientInfo**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the case of infection

**infected_by**
the ID of who infected the patient


**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date of being confirmed

**released_date**
the date of being released

**deceased_date**
the date of being deceased

**state**
isolated / released / deceased

### Import pyspark and check its version

In [72]:
import pyspark
import pyspark.sql.functions as fn
from pyspark.sql.functions import *
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType

# Check the installed PySpark version
print(pyspark.__version__)

### Import and create SparkSession

In [73]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [74]:
spark=SparkSession.builder.getOrCreate()

### Load the PatientInfo.csv file and show the first 5 rows

In [75]:
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [76]:
df=spark.read.csv("/content/drive/MyDrive/Spark_data/data/PatientInfo.csv",header=True,inferSchema=True)

### Display the schema of the dataset

In [77]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)



### Display the statistical summary

In [78]:
df.summary().show()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               690|    5165|
|   mean|2.8636345618679576E9|  null|null|      null|    null|          null|                null|2.2845944015643125E9|1.6772572523506988E7|              null|    null|
| stddev| 2.074210725277473E9|  null|null|      null|    null|          null|                null|1.5265072953383324E9| 3.093097580985502E8|              n

### Using the state column.
### How many people survived (released), and how many didn't survive (isolated/deceased)?

In [79]:
df.select("state").show()

+--------+
|   state|
+--------+
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|deceased|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
+--------+
only showing top 20 rows



In [80]:
df.groupby(col("state")).count().show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



In [81]:
survived_released=df.select(col("state")).\
          where(col("state")=="released").count()
survived_released

2929

In [82]:
# Patients still in isolation (not counted as survived in this lab)
isolated_count=df.select(col("state")).\
          where(col("state")=="isolated").count()
isolated_count

2158

In [83]:
# Patients who died
deceased_count=df.select(col("state")).\
          where(col("state")=="deceased").count()
deceased_count

78
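
The three counts above answer the question once they are combined. A minimal plain-Python sketch, assuming (as the heading does) that only `released` counts as survived; the numbers are copied from the `groupby` output above:

```python
# Counts copied from the groupby("state") output above.
state_counts = {"isolated": 2158, "released": 2929, "deceased": 78}

# Assumption taken from the heading: only "released" counts as survived.
survived = state_counts["released"]
not_survived = state_counts["isolated"] + state_counts["deceased"]

print(survived, not_survived)
```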

### Display the number of null values in each column

In [84]:
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()

+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|patient_id| sex| age|country|province|city|infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|state|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|         0|1122|1380|      0|       0|  94|           919|       3819|          4374|              4475|             3|         3578|         5099|    0|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
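
The null-counting expression relies on two behaviours: `when` without `otherwise` yields null when its condition is false, and `count` ignores nulls. A plain-Python sketch of the same logic on made-up rows (the rows and the helper `count_nulls` are illustrative and not part of the lab):

```python
# Toy rows standing in for a DataFrame; None plays the role of SQL null.
rows = [
    {"sex": "male", "age": "50s", "city": None},
    {"sex": None, "age": "30s", "city": "Jongno-gu"},
    {"sex": None, "age": None, "city": "Gangseo-gu"},
]

def count_nulls(rows, column):
    # when(isnull(c), c): a non-null marker where the value is null, otherwise None.
    markers = [column if row[column] is None else None for row in rows]
    # count(...): only the non-null markers are counted.
    return sum(marker is not None for marker in markers)

null_counts = {c: count_nulls(rows, c) for c in rows[0]}
print(null_counts)
```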



## Data preprocessing

### Fill the nulls in the deceased_date with the released_date. 
- You can use <b>coalesce</b> function
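
`coalesce` returns its first non-null argument, so the argument order decides which column wins when both hold a value. A plain-Python sketch of that rule (the helper and the sample values are illustrative):

```python
def coalesce(*values):
    """Return the first value that is not None, or None if all are None."""
    for value in values:
        if value is not None:
            return value
    return None

# Filling a missing deceased_date from released_date: deceased_date goes first.
filled = coalesce(None, "2020-02-05")
# When deceased_date is present it is kept, whatever released_date holds.
kept = coalesce("2020-02-23", "2020-03-01")
# Both missing: the result stays None.
still_missing = coalesce(None, None)
print(filled, kept, still_missing)
```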

In [85]:
df.select("deceased_date").show() 

+-------------+
|deceased_date|
+-------------+
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
|         null|
+-------------+
only showing top 20 rows



In [86]:
df.select("released_date").show()

+-------------------+
|      released_date|
+-------------------+
|2020-02-05 00:00:00|
|2020-03-02 00:00:00|
|2020-02-19 00:00:00|
|2020-02-15 00:00:00|
|2020-02-24 00:00:00|
|2020-02-19 00:00:00|
|2020-02-10 00:00:00|
|2020-02-24 00:00:00|
|2020-02-21 00:00:00|
|2020-02-29 00:00:00|
|2020-02-29 00:00:00|
|2020-02-27 00:00:00|
|               null|
|2020-03-12 00:00:00|
|               null|
|2020-03-11 00:00:00|
|2020-03-01 00:00:00|
|               null|
|2020-03-08 00:00:00|
|               null|
+-------------------+
only showing top 20 rows



In [87]:
df=df.withColumn("deceased_date",coalesce(df.deceased_date,df.released_date))

In [88]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)



### Add a column named no_days, the difference between the deceased_date and the confirmed_date, then show the top 5 rows and print the schema.
- <b>Hint: You need to typecast these columns as date first.</b>
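
The day difference can be sketched in plain Python with the standard `datetime` module. The helper `days_between` is hypothetical; it mirrors the argument order of `datediff(end, start)` and returns `None` when either date is missing, which is how the null `no_days` values arise. The two timestamps are those of the first patient in the dataset.

```python
from datetime import datetime

def days_between(end, start):
    """Whole days from start to end; None if either date is missing."""
    if end is None or start is None:
        return None
    return (end - start).days

# Timestamps of the first patient: confirmed_date and released_date.
confirmed = datetime(2020, 1, 23, 0, 0, 0)
released = datetime(2020, 2, 5, 0, 0, 0)

# Typecast to date first, then subtract.
no_days = days_between(released.date(), confirmed.date())
missing = days_between(None, confirmed.date())
print(no_days, missing)
```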

In [89]:

df.select("confirmed_date","deceased_date").show()

+-------------------+-------------------+
|     confirmed_date|      deceased_date|
+-------------------+-------------------+
|2020-01-23 00:00:00|2020-02-05 00:00:00|
|2020-01-30 00:00:00|2020-03-02 00:00:00|
|2020-01-30 00:00:00|2020-02-19 00:00:00|
|2020-01-30 00:00:00|2020-02-15 00:00:00|
|2020-01-31 00:00:00|2020-02-24 00:00:00|
|2020-01-31 00:00:00|2020-02-19 00:00:00|
|2020-01-31 00:00:00|2020-02-10 00:00:00|
|2020-02-02 00:00:00|2020-02-24 00:00:00|
|2020-02-05 00:00:00|2020-02-21 00:00:00|
|2020-02-05 00:00:00|2020-02-29 00:00:00|
|2020-02-06 00:00:00|2020-02-29 00:00:00|
|2020-02-07 00:00:00|2020-02-27 00:00:00|
|2020-02-16 00:00:00|               null|
|2020-02-16 00:00:00|2020-03-12 00:00:00|
|2020-02-19 00:00:00|               null|
|2020-02-19 00:00:00|2020-03-11 00:00:00|
|2020-02-20 00:00:00|2020-03-01 00:00:00|
|2020-02-20 00:00:00|               null|
|2020-02-20 00:00:00|2020-03-08 00:00:00|
|2020-02-20 00:00:00|               null|
+-------------------+-------------

In [90]:
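# Cast the dates. Both columns were already inferred as timestamps, so no format pattern is needed
# (if one were needed, months are "MM"; "mm" means minutes).
df=df.withColumn("confirmed_date",to_timestamp(col("confirmed_date")))
df=df.withColumn("deceased_date",to_timestamp(col("deceased_date")))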

In [91]:
#difference between confirmed_date,deceased_date
df=df.withColumn("no_days",datediff(col("deceased_date"),col("confirmed_date")))

In [92]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)
 |-- no_days: integer (nullable = true)



### Add an is_male column: true if the patient is male, otherwise false

In [93]:
df.select("sex").show()

+------+
|   sex|
+------+
|  male|
|  male|
|  male|
|  male|
|female|
|female|
|  male|
|  male|
|  male|
|female|
|female|
|  male|
|  male|
|female|
|  male|
|  male|
|  male|
|  male|
|female|
|female|
+------+
only showing top 20 rows



In [94]:
df=df.withColumn("is_male",when(col("sex")=="male",1).when(col("sex")=="female",0))

In [95]:
df.select("is_male").show()

+-------+
|is_male|
+-------+
|      1|
|      1|
|      1|
|      1|
|      0|
|      0|
|      1|
|      1|
|      1|
|      0|
|      0|
|      1|
|      1|
|      0|
|      1|
|      1|
|      1|
|      1|
|      0|
|      0|
+-------+
only showing top 20 rows



### Add an is_dead column: true if the patient state is not released, otherwise false

- Use a <b>UDF</b> to perform this task.
- However, a UDF is recommended only when no built-in function can do the required operation.
- A UDF is slower than the built-in functions.
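
For comparison, the row-level logic that a UDF would wrap can be sketched as an ordinary Python function. The name `is_dead_from_state` is hypothetical; in Spark it would be registered with `udf(...)` and an integer return type, while `when(...).otherwise(...)` expresses the same rule with built-in functions only.

```python
def is_dead_from_state(state):
    """Return 1 when the state is anything other than 'released', else 0."""
    # Spark passes null to a Python UDF as None. Treating a missing state as 0
    # is an assumption; the state column has no nulls in this dataset.
    if state is None:
        return 0
    return 1 if state != "released" else 0

labels = [is_dead_from_state(s) for s in ["released", "isolated", "deceased", None]]
print(labels)
```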

In [96]:
df.select("state").show()

+--------+
|   state|
+--------+
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
|deceased|
|released|
|released|
|released|
|released|
|released|
|released|
|released|
+--------+
only showing top 20 rows



In [97]:
df=df.withColumn("is_dead",when(col("state")!="released",1).otherwise(0))
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)
 |-- no_days: integer (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)



In [98]:
df.select("is_dead","state").show()

+-------+--------+
|is_dead|   state|
+-------+--------+
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      1|deceased|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
|      0|released|
+-------+--------+
only showing top 20 rows



### Change the age bins from 0s, 10s, 20s, etc. to 0, 10, 20

In [99]:
# def removeChar(str):
#   if str==None:
#     return None
#   return int(str.replace("s",""))

In [100]:
# from pyspark.sql.types import IntegerType
# convertUDF = udf(removeChar,IntegerType())

In [101]:
df=df.withColumn("age2",translate(col("age"),"s",""))

In [102]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)
 |-- no_days: integer (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)
 |-- age2: string (nullable = true)



In [103]:
df.select("age2").show()

+----+
|age2|
+----+
|  50|
|  30|
|  50|
|  20|
|  20|
|  50|
|  20|
|  20|
|  30|
|  60|
|  50|
|  20|
|  80|
|  60|
|  70|
|  70|
|  70|
|  20|
|  70|
|  70|
+----+
only showing top 20 rows



### Cast age and no_days to Double

In [104]:
df.select("age2","no_days").show()

+----+-------+
|age2|no_days|
+----+-------+
|  50|     13|
|  30|     32|
|  50|     20|
|  20|     16|
|  20|     24|
|  50|     19|
|  20|     10|
|  20|     22|
|  30|     16|
|  60|     24|
|  50|     23|
|  20|     20|
|  80|   null|
|  60|     25|
|  70|   null|
|  70|     21|
|  70|     10|
|  20|   null|
|  70|     17|
|  70|   null|
+----+-------+
only showing top 20 rows



In [105]:
df=df.withColumn("age2",fn.col('age2').cast(DoubleType()))
df=df.withColumn("no_days",fn.col('no_days').cast(DoubleType()))

In [106]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+-------+-------+-------+----+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|      deceased_date|   state|no_days|is_male|is_dead|age2|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------------+--------+-------+-------+-------+----+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|2020-02-05 00:00:00|released|   13.0|      1|      0|50.0|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-3

In [107]:
df.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: timestamp (nullable = true)
 |-- released_date: timestamp (nullable = true)
 |-- deceased_date: timestamp (nullable = true)
 |-- state: string (nullable = true)
 |-- no_days: double (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)
 |-- age2: double (nullable = true)



### Drop the columns
["patient_id","sex","infected_by","contact_number","released_date","state",
"symptom_onset_date","confirmed_date","deceased_date","country","no_days",
"city","infection_case"]

In [108]:
selected_col=["patient_id","sex","infected_by","contact_number","released_date","state", "symptom_onset_date","confirmed_date","deceased_date","country","no_days", "city","infection_case"]
df=df.drop(*selected_col)

In [109]:
df.printSchema()

root
 |-- age: string (nullable = true)
 |-- province: string (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)
 |-- age2: double (nullable = true)



### Recount the number of nulls now

In [110]:
df.select([count(when(isnull(c), c)).alias(c) for c in df.columns]).show()

+----+--------+-------+-------+----+
| age|province|is_male|is_dead|age2|
+----+--------+-------+-------+----+
|1380|       0|   1122|      0|1380|
+----+--------+-------+-------+----+



In [112]:
df=df.filter(col("is_male").isNotNull())

In [48]:
df=df.drop(col("age"))

In [49]:
df.printSchema()

root
 |-- province: string (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)
 |-- age2: double (nullable = true)



## Now do the same but using SQL select statement

### From the original Patient DataFrame, Create a temporary view (table).

In [152]:
data=spark.read.csv("/content/drive/MyDrive/Spark_data/data/PatientInfo.csv",header=True,inferSchema=True)

### Use SELECT statement to select all columns from the dataframe and show the output.

In [156]:
data.createOrReplaceTempView("data_view")

### *Using SQL commands*, limit the output to only 5 rows 

In [53]:
spark.sql("select * from data_view limit 5").show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|     confirmed_date|      released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+-------------------+-------------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|2020-01-23 00:00:00|2020-02-05 00:00:00|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|2020-01-30 00:00:00|2020-03-02 00:00:00|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|

### Select the count of males and females in the dataset

In [120]:
spark.sql("select count(sex) as cnt_male  from data_view where  sex='male' ").show()

+--------+
|cnt_male|
+--------+
|    1825|
+--------+



In [119]:
spark.sql("select count(sex)  as cnt_female from data_view where  sex='female' ").show()

+----------+
|cnt_female|
+----------+
|      2218|
+----------+
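
Both counts can also come from a single `GROUP BY` query. The sketch below uses Python's built-in `sqlite3` module on a few made-up rows so that it runs without a Spark session; the table contents are illustrative, and SQLite's dialect is an assumption that only approximates Spark SQL.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE data_view (sex TEXT)")
conn.executemany(
    "INSERT INTO data_view (sex) VALUES (?)",
    [("male",), ("female",), ("female",), (None,), ("male",), ("female",)],
)

# Rows with a null sex are filtered out so that no null group appears.
query = """
    SELECT sex, COUNT(sex) AS cnt
    FROM data_view
    WHERE sex IS NOT NULL
    GROUP BY sex
    ORDER BY sex
"""
counts = conn.execute(query).fetchall()
conn.close()
print(counts)
```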



### How many people did survive, and how many didn't?

In [124]:
spark.sql("select count(state) as survived  from data_view where state=='released' ").show()

+--------+
|survived|
+--------+
|    2929|
+--------+



In [128]:
spark.sql("select count(state) as dead  from data_view where (state=='isolated' or state== 'deceased') ").show()

+----+
|dead|
+----+
|2236|
+----+



### Now, let's perform some preprocessing using SQL:
1. Convert *age* column to double after removing the 's' at the end -- *hint: check SUBSTRING method*
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe
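
The substring-then-cast step can be tried without Spark. The sketch below runs a similar query through Python's built-in `sqlite3` module; SQLite's `substr` and `REAL` stand in for Spark SQL's `substring` and `double`, and apart from the first row the data is made up.

```python
import sqlite3

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE data_view (sex TEXT, age TEXT, province TEXT, state TEXT)")
conn.executemany(
    "INSERT INTO data_view VALUES (?, ?, ?, ?)",
    [
        ("male", "50s", "Seoul", "released"),
        ("female", "0s", "Busan", "isolated"),
        (None, None, "Seoul", "deceased"),
    ],
)

# Drop the trailing 's' by keeping all but the last character, then cast.
query = """
    SELECT sex,
           CAST(substr(age, 1, length(age) - 1) AS REAL) AS age,
           province,
           state
    FROM data_view
    ORDER BY rowid
"""
result = conn.execute(query).fetchall()
conn.close()
print(result)
```

A null age stays null through both `substr` and the cast, so missing ages are preserved rather than turned into 0.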

In [166]:
df_sql=spark.sql(" select sex,CAST(substring(age,1,length(age)-1) as double) as age ,province,state  from data_view")

In [167]:
df_sql.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows



## Machine Learning 
### Create a pipeline model to predict is_dead and evaluate the performance.
- Use <b>StringIndexer</b> to transform <b>string</b> data type to indices.
- Use <b>OneHotEncoder</b> to deal with categorical values.
- Use <b>Imputer</b> to fill missing data with mean.
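
What the three transformers compute can be sketched in plain Python. This is an illustration on made-up values, assuming Spark's defaults (indices assigned by descending frequency, `dropLast=True` in the encoder, mean computed over non-missing values); the helper names are hypothetical and this is not Spark's implementation.

```python
from collections import Counter

def fit_string_index(values):
    """Map each label to an index, most frequent label first (ties: alphabetical)."""
    freq = Counter(v for v in values if v is not None)
    ordered = sorted(freq, key=lambda label: (-freq[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

def one_hot(index, num_categories, drop_last=True):
    """Encode an index as a 0/1 list; with drop_last the final category is all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    return [1.0 if i == int(index) else 0.0 for i in range(size)]

def impute_mean(values):
    """Replace None with the mean of the observed values."""
    observed = [v for v in values if v is not None]
    mean = sum(observed) / len(observed)
    return [mean if v is None else v for v in values]

# Made-up sample values.
provinces = ["Seoul", "Busan", "Seoul", "Daegu", "Seoul", "Busan"]
ages = [50.0, None, 30.0, 20.0, None, 60.0]

index_map = fit_string_index(provinces)
encoded = [one_hot(index_map[p], len(index_map)) for p in provinces]
imputed = impute_mean(ages)
print(index_map, encoded[0], encoded[3], imputed)
```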

In [None]:
traindf,testdf=df.randomSplit([.8,.2],seed=42)

In [None]:
traindf.show()

+--------+-------+-------+----+
|province|is_male|is_dead|age2|
+--------+-------+-------+----+
|   Busan|      0|      0| 0.0|
|   Busan|      0|      0| 0.0|
|   Busan|      0|      0|10.0|
|   Busan|      0|      0|10.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|20.0|
|   Busan|      0|      0|30.0|
|   Busan|      0|      0|30.0|
|   Busan|      0|      0|30.0|
+--------+-------+-------+----+
only showing top 20 rows



In [None]:
dtypes=df.dtypes


In [None]:
dtypes

[('province', 'string'),
 ('is_male', 'int'),
 ('is_dead', 'int'),
 ('age2', 'double')]

In [None]:
CATE_TYPE=[f for (f,d) in dtypes if d=="string"]
CATE_TYPE

['province']

In [None]:
NUM_TYPE=[f for (f,d) in dtypes if d !="string" and f !="is_dead"]
NUM_TYPE

['is_male', 'age2']

In [None]:
inx_out_cat_col=[f+"_index" for (f,d) in dtypes if d=="string"]

In [None]:
OHE_out_cat_col=[f+"OHE" for (f,d) in dtypes if d=="string"]

In [None]:
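# Only the province index and one-hot columns are assembled; is_male and age2 are left out of the features.
vector_col_input=inx_out_cat_col+OHE_out_cat_col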

In [None]:
from pyspark.ml.feature import  StringIndexer,VectorAssembler,OneHotEncoder,Imputer
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.pipeline import Pipeline

In [None]:
SRITNG_INDEXER=StringIndexer(inputCols=CATE_TYPE,outputCols=inx_out_cat_col,handleInvalid="skip")
ONE_HOT_ENCODIN=OneHotEncoder(inputCols=inx_out_cat_col,outputCols=OHE_out_cat_col )
VECTOR_ASSEMBLER=VectorAssembler(inputCols=vector_col_input,outputCol="features")
imputer = Imputer(inputCols=['age2'], outputCols=['imputed_Age'],strategy="mean")

In [None]:
rf=RandomForestClassifier(numTrees=3, maxDepth=2, labelCol="is_dead", seed=42,predictionCol="prediction")

In [None]:
# The imputer runs after the assembler, so imputed_Age is computed but never fed to the model.
myStages=[SRITNG_INDEXER,ONE_HOT_ENCODIN,VECTOR_ASSEMBLER,imputer,rf]
pl=Pipeline(stages=myStages)
pl_model=pl.fit(traindf)

In [None]:
pred_df=pl_model.transform(testdf)

In [None]:
pred_df.printSchema()

root
 |-- province: string (nullable = true)
 |-- is_male: integer (nullable = true)
 |-- is_dead: integer (nullable = false)
 |-- age2: double (nullable = true)
 |-- province_index: double (nullable = false)
 |-- provinceOHE: vector (nullable = true)
 |-- features: vector (nullable = true)
 |-- imputed_Age: double (nullable = true)
 |-- rawPrediction: vector (nullable = true)
 |-- probability: vector (nullable = true)
 |-- prediction: double (nullable = false)



In [None]:
pred_df.select("age2","province","is_male","prediction","is_dead").show()

+----+--------+-------+----------+-------+
|age2|province|is_male|prediction|is_dead|
+----+--------+-------+----------+-------+
|10.0|   Busan|      0|       0.0|      0|
|20.0|   Busan|      0|       0.0|      0|
|20.0|   Busan|      0|       0.0|      0|
|20.0|   Busan|      0|       0.0|      0|
|20.0|   Busan|      0|       0.0|      0|
|30.0|   Busan|      0|       0.0|      0|
|30.0|   Busan|      0|       0.0|      0|
|40.0|   Busan|      0|       0.0|      0|
|50.0|   Busan|      0|       0.0|      0|
|50.0|   Busan|      0|       0.0|      0|
|60.0|   Busan|      0|       0.0|      0|
|60.0|   Busan|      0|       0.0|      0|
|60.0|   Busan|      0|       0.0|      0|
|70.0|   Busan|      0|       0.0|      0|
|30.0|   Busan|      0|       0.0|      1|
|80.0|   Busan|      0|       0.0|      1|
|10.0|   Busan|      1|       0.0|      0|
|20.0|   Busan|      1|       0.0|      0|
|50.0|   Busan|      1|       0.0|      0|
|60.0|   Busan|      1|       0.0|      0|
+----+-----

In [None]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator,BinaryClassificationEvaluator

In [None]:
evaluatorMulti = MulticlassClassificationEvaluator(labelCol="is_dead", predictionCol="prediction")
evaluator = BinaryClassificationEvaluator(labelCol="is_dead", rawPredictionCol="rawPrediction", metricName='areaUnderROC')

In [None]:
predictionAndTarget =pred_df.select("prediction","is_dead")

In [None]:
acc = evaluatorMulti.evaluate(predictionAndTarget,{evaluatorMulti.metricName:"accuracy"})
f1 = evaluatorMulti.evaluate(predictionAndTarget,{evaluatorMulti.metricName:"f1"})

In [None]:
acc

0.8893333333333333

In [None]:
f1

0.8881771352190513
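
The two metrics can be reproduced by hand on a small example. The sketch below assumes that the evaluator's `f1` metric is the per-class F1 weighted by each class's share of the labels, which is how the Spark documentation describes it; the labels and predictions are made up and are not taken from the lab's test set.

```python
from collections import Counter

def accuracy(labels, predictions):
    """Fraction of predictions that match the label."""
    correct = sum(l == p for l, p in zip(labels, predictions))
    return correct / len(labels)

def weighted_f1(labels, predictions):
    """Per-class F1 averaged with weights equal to each class's share of the labels."""
    support = Counter(labels)
    total = 0.0
    for cls, n in support.items():
        true_pos = sum(l == cls and p == cls for l, p in zip(labels, predictions))
        false_pos = sum(l != cls and p == cls for l, p in zip(labels, predictions))
        false_neg = n - true_pos
        precision = true_pos / (true_pos + false_pos) if true_pos + false_pos else 0.0
        recall = true_pos / (true_pos + false_neg) if true_pos + false_neg else 0.0
        denom = precision + recall
        f1_cls = 2 * precision * recall / denom if denom else 0.0
        total += f1_cls * n / len(labels)
    return total

# Made-up labels and predictions.
y_true = [0, 0, 0, 0, 1, 1, 1, 0]
y_pred = [0, 0, 0, 1, 1, 1, 0, 0]
acc_toy = accuracy(y_true, y_pred)
f1_toy = weighted_f1(y_true, y_pred)
print(acc_toy, f1_toy)
```

With a rare positive class, as in this dataset, accuracy and weighted F1 can both look high even when few deceased patients are identified, so per-class recall is worth checking alongside them.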