# **Labs 1 and 2 PySpark:**

In these labs we will be using the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022 for educational, non-commercial purposes under the [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) license.


The csv file that we will be using in this lab is **PatientInfo**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the case of infection

**infected_by**
the ID of who infected the patient

**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date of being confirmed

**released_date**
the date of being released

**deceased_date**
the date of being deceased

**state**
isolated / released / deceased

### Install PySpark and check its version

In [1]:
%pip install pyspark py4j

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyspark
  Downloading pyspark-3.3.0.tar.gz (281.3 MB)
     |████████████████████████████████| 281.3 MB 49 kB/s
Collecting py4j
  Downloading py4j-0.10.9.5-py2.py3-none-any.whl (199 kB)
     |████████████████████████████████| 199 kB 50.3 MB/s
Building wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... done
  Created wheel for pyspark: filename=pyspark-3.3.0-py2.py3-none-any.whl size=281764026 sha256=f5178e339c478e6a9225b947ae391c103de9d2678604de093b2709296608bcd5
  Stored in directory: /root/.cache/pip/wheels/7a/8e/1b/f73a52650d2e5f337708d9f6a1750d451a7349a867f928b885
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.5 pyspark-3.3.0


### Import and create SparkSession

In [2]:
import pyspark
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('SparkSQL').getOrCreate()

### Load the PatientInfo.csv file

In [3]:
from IPython.display import display, HTML
# Stop the notebook from wrapping wide show() tables
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [5]:
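# "inferSchema" is left off, so every column is read as a string (as the
# schema below shows); later cells cast the columns they need explicitly.
# Spark silently ignores unrecognized option names, so a misspelled key
# (e.g. "infereSchema") behaves exactly like leaving the option off.
df = spark.read.format('csv')\
.option("header", "true")\
.option("inferSchema", "false")\
.load("PatientInfo - PatientInfo.csv")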

### Display the schema of the dataset

In [6]:
df.printSchema()

root
 |-- patient_id: string (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: string (nullable = true)
 |-- released_date: string (nullable = true)
 |-- deceased_date: string (nullable = true)
 |-- state: string (nullable = true)



### Show the first 5 rows and display the statistical summary

In [7]:
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

In [8]:
df.summary().show()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------------+-------------+-------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------------+-------------+-------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               689|          5162|         1587|           66|    5165|
|   mean|2.8636345618679576E9|  null|null|      null|    null|          null|                null|2.2845944015643125E9|1.6772572523506988E7|            

### Using the state column, how many people survived (released), and how many didn't (isolated/deceased)?

In [14]:
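# Import only the functions used below; a wildcard import would shadow
# Python built-ins such as sum, max, min, and round
from pyspark.sql.functions import coalesce, col, count, datediff, isnan, split, substring, udf, when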

In [15]:
df.select('state').where(df.state=='released').count()

2929

In [17]:
df.select('state').where((col('state')=='isolated')|(col('state')=='deceased')).count()

2236

### Display the number of null values in each column

In [18]:
df.select([count(when(isnan(c) | col(c).isNull(), c)).alias(c) for c in df.columns]
   ).show()

+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|patient_id| sex| age|country|province|city|infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|state|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+
|         0|1122|1380|      0|       0|  94|           919|       3819|          4374|              4476|             3|         3578|         5099|    0|
+----------+----+----+-------+--------+----+--------------+-----------+--------------+------------------+--------------+-------------+-------------+-----+



## Data preprocessing

### Fill the nulls in deceased_date with the values from released_date
- You can use the <b>coalesce</b> function.

In [20]:
tmp_df = df.withColumn('deceased_date', coalesce(df['deceased_date'], df['released_date']))

### Add a column named no_days that holds the difference between deceased_date and confirmed_date, then show the top 5 rows. Print the schema.
- <b>Hint: You need to cast these columns to dates first.</b>

In [21]:
tmp_df.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

In [22]:
from pyspark.sql.types import DateType, DoubleType, IntegerType


In [23]:
df2 =tmp_df.withColumn("deceased_date",col("deceased_date").cast(DateType()))\
                       .withColumn("confirmed_date",col("confirmed_date").cast(DateType()))


In [24]:
df3= df2.withColumn("day_off",col("deceased_date") -col("confirmed_date"))
df3.select("day_off").show(5)

+-----------------+
|          day_off|
+-----------------+
|INTERVAL '13' DAY|
|INTERVAL '32' DAY|
|INTERVAL '20' DAY|
|INTERVAL '16' DAY|
|INTERVAL '24' DAY|
+-----------------+
only showing top 5 rows



In [25]:
df3.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-----------------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|          day_off|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-----------------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|INTERVAL '13' DAY|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|INTERVAL '32' DAY|
|1000000003|  male|50s|  Korea|   Seoul|   Jo

In [26]:
df4= df2.withColumn("day_off",datediff(df2.deceased_date,df2.confirmed_date))


In [27]:
df4.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|day_off|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|         

### Add an is_male column that is true if the patient is male and false otherwise
- Rows where sex is null get null rather than false.

In [28]:
df5=df4.withColumn("is_male",df4.sex=='male')

In [29]:
df5.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|day_off|is_male|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|co

### Add an is_dead column that is true if the patient's state is not released, and false otherwise

- Use a <b>UDF</b> to perform this task.
- However, UDFs are not recommended when a built-in function can do the required operation.
- Python UDFs are slower than built-in functions, because every row has to be serialized and sent to a Python worker process.

In [30]:
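from pyspark.sql.types import BooleanType

# Declare the return type; without it, udf() defaults to StringType and the
# column would hold the strings 'true'/'false' instead of booleans
udf_fun = udf(lambda z: z != 'released', BooleanType())
df6 = df5.withColumn("is_dead", udf_fun(col('state')))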

In [31]:
df6.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|day_off|is_male|is_dead|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  m

### Convert the age bins from 0s, 10s, 20s, etc. to 0, 10, 20, etc.

In [32]:
# Keep the part before the 's' ('50s' -> '50')
df7 = df6.withColumn("age", split(df6['age'], 's').getItem(0))


In [33]:
df7.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|day_off|is_male|is_dead|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male| 50|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   true|  false|
|1000000002|  male| 30|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   true|  false|
|1000000003|  male| 

### Cast age and no_days to Double

In [34]:
df8=df7.withColumn("age",col("age").cast(DoubleType()))\
.withColumn("day_off",col("day_off").cast(DoubleType()))

### Drop the columns
`["patient_id", "sex", "infected_by", "contact_number", "released_date", "state", "symptom_onset_date", "confirmed_date", "deceased_date", "country", "no_days", "city", "infection_case"]`

In [35]:
cols = ("patient_id", "sex", "infected_by", "contact_number", "released_date",
        "state", "symptom_onset_date", "confirmed_date", "deceased_date",
        "country", "day_off", "city", "infection_case")

df9 = df8.drop(*cols)

In [36]:
df9.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|60.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|80.0|   Seoul|   true|   true|
|60.0|   Seoul|  false|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|70.0|   Seoul|  false|  false|
|70.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 20 rows



### Recount the number of nulls now

In [38]:
df9.select([count(when(col(c).isNull(), c)).alias(c) for c in df9.columns]
   ).show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|1380|       0|   1122|      0|
+----+--------+-------+-------+



## Now do the same using SQL SELECT statements

### From the original Patient DataFrame, create a temporary view (table).

In [46]:
df.createOrReplaceTempView('tableview')

### Use a SELECT statement to select all columns from the view and show the output.

In [47]:
spark.sql('select * from tableview').show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

### *Using SQL commands*, limit the output to only 5 rows 

In [48]:
spark.sql('select * from tableview limit 5').show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|   202

### Select the count of males and females in the dataset

In [66]:
# count of males
spark.sql("select Count(sex) from tableview where sex='male'").show()

+----------+
|count(sex)|
+----------+
|      1825|
+----------+



In [67]:
# count of females
spark.sql("select Count(sex) from tableview where sex='female'").show()

+----------+
|count(sex)|
+----------+
|      2218|
+----------+



### How many people did survive, and how many didn't?

In [68]:
# number who survived
spark.sql("select Count(state) from tableview where state='released'").show()

+------------+
|count(state)|
+------------+
|        2929|
+------------+



In [69]:
# number who did not survive
spark.sql("select Count(state) from tableview where state='isolated' or state='deceased' ").show()

+------------+
|count(state)|
+------------+
|        2236|
+------------+



### Now, let's perform some preprocessing using SQL:
1. Convert the *age* column to double after removing the 's' at the end (*hint: check the SUBSTRING function*)
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe

In [94]:
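from pyspark.sql.functions import expr

# Drop the trailing 's' ('50s' -> '50'); unlike substring(age, 1, 2), this also
# keeps three-digit bins such as '100s' intact
df_sql = df.withColumn('age', expr("substring(age, 1, length(age) - 1)"))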


In [95]:
df_sql.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male| 50|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       null|            75|        2020-01-22|    2020-01-23|   2020-02-05|         null|released|
|1000000002|  male| 30|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       null|            31|              null|    2020-01-30|   2020-03-02|         null|released|
|1000000003|  male| 50|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              null|    2020-01-30|

In [97]:
df_sql=df_sql.withColumn("age",df_sql.age.cast(DoubleType()))

In [98]:
df_sql.createOrReplaceTempView('tableview1')

In [103]:
df_new=spark.sql("select sex, age, province, state from tableview1")

In [104]:
df_new.show(5)

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
+------+----+--------+--------+
only showing top 5 rows



## Machine Learning 
### Create a pipeline model to predict is_dead and evaluate the performance.
- Use <b>StringIndexer</b> to transform <b>string</b> data type to indices.
- Use <b>OneHotEncoder</b> to deal with categorical values.
- Use <b>Imputer</b> to fill missing data with mean.

In [105]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder
from pyspark.ml.feature import VectorAssembler

In [106]:
df9.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|50.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|30.0|   Seoul|   true|  false|
|60.0|   Seoul|  false|  false|
|50.0|   Seoul|  false|  false|
|20.0|   Seoul|   true|  false|
|80.0|   Seoul|   true|   true|
|60.0|   Seoul|  false|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|70.0|   Seoul|   true|  false|
|20.0|   Seoul|   true|  false|
|70.0|   Seoul|  false|  false|
|70.0|   Seoul|  false|  false|
+----+--------+-------+-------+
only showing top 20 rows



Convert is_dead to numeric labels (1 for true, 0 otherwise).

In [107]:
udf_con = udf(lambda z: 1 if z == True else 0, IntegerType())
df10=df9.withColumn("is_dead",udf_con(col("is_dead")))

In [108]:
df10.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   true|      0|
|30.0|   Seoul|   true|      0|
|50.0|   Seoul|   true|      0|
|20.0|   Seoul|   true|      0|
|20.0|   Seoul|  false|      0|
|50.0|   Seoul|  false|      0|
|20.0|   Seoul|   true|      0|
|20.0|   Seoul|   true|      0|
|30.0|   Seoul|   true|      0|
|60.0|   Seoul|  false|      0|
|50.0|   Seoul|  false|      0|
|20.0|   Seoul|   true|      0|
|80.0|   Seoul|   true|      1|
|60.0|   Seoul|  false|      0|
|70.0|   Seoul|   true|      0|
|70.0|   Seoul|   true|      0|
|70.0|   Seoul|   true|      0|
|20.0|   Seoul|   true|      0|
|70.0|   Seoul|  false|      0|
|70.0|   Seoul|  false|      0|
+----+--------+-------+-------+
only showing top 20 rows



Drop all rows with missing data.

In [109]:
df11=df10.na.drop(how="any")

In [110]:
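# Encode is_male as 0/1. The pipeline below trains on df11, where is_male is
# still boolean; VectorAssembler accepts boolean columns directly.
df12 = df11.withColumn("is_male", udf_con(col("is_male")))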

In [111]:
df12.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|      1|      0|
|30.0|   Seoul|      1|      0|
|50.0|   Seoul|      1|      0|
|20.0|   Seoul|      1|      0|
|20.0|   Seoul|      0|      0|
|50.0|   Seoul|      0|      0|
|20.0|   Seoul|      1|      0|
|20.0|   Seoul|      1|      0|
|30.0|   Seoul|      1|      0|
|60.0|   Seoul|      0|      0|
|50.0|   Seoul|      0|      0|
|20.0|   Seoul|      1|      0|
|80.0|   Seoul|      1|      1|
|60.0|   Seoul|      0|      0|
|70.0|   Seoul|      1|      0|
|70.0|   Seoul|      1|      0|
|70.0|   Seoul|      1|      0|
|20.0|   Seoul|      1|      0|
|70.0|   Seoul|      0|      0|
|70.0|   Seoul|      0|      0|
+----+--------+-------+-------+
only showing top 20 rows



Check that no missing data remains.

In [112]:
df11.select([count(when(col(c).isNull(), c)).alias(c) for c in df11.columns]
   ).show()

+---+--------+-------+-------+
|age|province|is_male|is_dead|
+---+--------+-------+-------+
|  0|       0|      0|      0|
+---+--------+-------+-------+



Separate the categorical columns.

In [113]:
categoricalCols = [field for (field, dataType) in df9.dtypes
                   if dataType == 'string' and field != 'is_dead']
categoricalCols

['province']

In [114]:
indexOutputCols = [x + "_Index" for x in categoricalCols]
indexOutputCols

['province_Index']

In [115]:
oheOutputCols = [x + "_OHE" for x in categoricalCols]
oheOutputCols

['province_OHE']

Index and one-hot encode the categorical columns with StringIndexer and OneHotEncoder.

In [116]:
stringIndexer = StringIndexer(inputCols=categoricalCols,
                              outputCols=indexOutputCols,
                             handleInvalid='skip')

oheEncoder = OneHotEncoder(inputCols=indexOutputCols,
                          outputCols=oheOutputCols)

In [117]:
assemblerInputs = oheOutputCols + ['is_male','age']
assemblerInputs

['province_OHE', 'is_male', 'age']

In [118]:
vecAssembler = VectorAssembler(inputCols=assemblerInputs,outputCol='features')

Import the libraries.

In [119]:
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml import Pipeline

Build the model and the pipeline.

In [120]:
rf = RandomForestClassifier(labelCol='is_dead',featuresCol='features', numTrees=10)

In [121]:
pipeline = Pipeline(stages=[stringIndexer, oheEncoder, vecAssembler, rf ] )

Split the data into training and test sets.

In [122]:
(trainingData, testData) = df11.randomSplit([0.7, 0.3])

In [123]:
model = pipeline.fit(trainingData)

In [124]:
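predictions = model.transform(testData)

In [125]:
# Take labels and predictions from the same DataFrame: StringIndexer's
# handleInvalid='skip' can drop test rows, so testData may have more rows
preds = predictions.select('prediction')
y = predictions.select('is_dead')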

In [None]:
%pip install pandas

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [None]:
# The PyPI package is named scikit-learn; "sklearn" is a deprecated alias
%pip install scikit-learn

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/


In [126]:
preds_pd=preds.toPandas()
y_pd=y.toPandas()

In [127]:
from sklearn.metrics import f1_score

Evaluate with the F1 score.

In [129]:
f1_score(y_pd['is_dead'], preds_pd['prediction'])

0.8285714285714286