# **Lab 1: PySpark**

In this lab we will use the "[[NeurIPS 2020] Data Science for COVID-19 (DS4C)](https://www.kaggle.com/datasets/kimjihoo/coronavirusdataset?select=PatientInfo.csv)" dataset, retrieved from [Kaggle](https://www.kaggle.com/) on 1/6/2022 for educational, non-commercial purposes. License: [CC BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/).


The CSV file that we will use in this lab is **PatientInfo.csv**.

## PatientInfo.csv

**patient_id**
the ID of the patient

**sex**
the sex of the patient

**age**
the age of the patient

**country**
the country of the patient

**province**
the province of the patient

**city**
the city of the patient

**infection_case**
the case of infection

**infected_by**
the ID of who infected the patient


**contact_number**
the number of contacts with people

**symptom_onset_date**
the date of symptom onset

**confirmed_date**
the date of being confirmed

**released_date**
the date of being released

**deceased_date**
the date of being deceased

**state**
isolated / released / deceased

### Import PySpark and check its version

In [1]:
import findspark
import os

os.environ["SPARK_HOME"] = "C:\\spark\\spark-3.5.5-bin-hadoop3"
findspark.init()

In [2]:
# version of pyspark    
from pyspark import __version__
print(__version__)

3.5.5


In [3]:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
from pyspark.sql.types import *

# Create a local Spark session
spark = SparkSession.builder \
    .appName("lab1") \
    .master("local[*]") \
    .getOrCreate()

### Load the PatientInfo.csv file and show the first 5 rows

In [4]:
from IPython.display import display, HTML
display(HTML("<style>pre { white-space: pre !important; }</style>"))

In [5]:
# Raw string so the backslashes in the Windows path are not read as escape sequences
df = spark.read.csv(r"..\Data\PatientInfo.csv", header=True, inferSchema=True)
df.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|         NULL|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|         NULL|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              NULL|    2020-01-30|   202

In [6]:
# show() and printSchema() print by themselves and return None, so they are not wrapped in print()
df.describe().show()
print("-"*50)
df.printSchema()

+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|summary|          patient_id|   sex| age|   country|province|          city|      infection_case|         infected_by|      contact_number|symptom_onset_date|   state|
+-------+--------------------+------+----+----------+--------+--------------+--------------------+--------------------+--------------------+------------------+--------+
|  count|                5165|  4043|3785|      5165|    5165|          5071|                4246|                1346|                 791|               690|    5165|
|   mean|2.8636345618679576E9|  NULL|NULL|      NULL|    NULL|          NULL|                NULL|2.2845944015643125E9|1.6772572523506988E7|              NULL|    NULL|
| stddev| 2.074210725277473E9|  NULL|NULL|      NULL|    NULL|          NULL|                NULL|1.5265072953383324E9| 3.093097580985502E8|              N

In [7]:
df.describe()

DataFrame[summary: string, patient_id: string, sex: string, age: string, country: string, province: string, city: string, infection_case: string, infected_by: string, contact_number: string, symptom_onset_date: string, state: string]

In [8]:
df.groupBy("state").count().show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



In [9]:
for c in df.columns:
  print(c,df.filter(df[c].isNull()).count())
  print("-"*50)

patient_id 0
--------------------------------------------------
sex 1122
--------------------------------------------------
age 1380
--------------------------------------------------
country 0
--------------------------------------------------
province 0
--------------------------------------------------
city 94
--------------------------------------------------
infection_case 919
--------------------------------------------------
infected_by 3819
--------------------------------------------------
contact_number 4374
--------------------------------------------------
symptom_onset_date 4475
--------------------------------------------------
confirmed_date 3
--------------------------------------------------
released_date 3578
--------------------------------------------------
deceased_date 5099
--------------------------------------------------
state 0
--------------------------------------------------


## Data preprocessing

In [10]:
# Where deceased_date is missing, fall back to released_date so closed cases have an end date
df_deceased = df.withColumn("deceased_date", F.coalesce(df['deceased_date'], df['released_date']))
df_deceased.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              NULL|    2020-01-30|
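
`F.coalesce` returns, for each row, the first of its arguments that is not null. The following is a minimal plain-Python sketch of that rule, not Spark code; the helper `coalesce` and the second and third sample rows are hypothetical, while the first row mirrors patient 1000000001 above.

```python
from datetime import date

def coalesce(*values):
    """Return the first argument that is not None, or None if all are None."""
    for value in values:
        if value is not None:
            return value
    return None

# (deceased_date, released_date) pairs
rows = [
    (None, date(2020, 2, 5)),    # released patient: falls back to released_date
    (date(2020, 2, 23), None),   # hypothetical deceased patient: keeps deceased_date
    (None, None),                # hypothetical isolated patient: stays missing
]
filled = [coalesce(deceased, released) for deceased, released in rows]
print(filled)
```

After this step `deceased_date` holds the date on which the case was closed, whether by release or by death, so its name should be read with that in mind.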

### Add a column named no_days, the difference in days between deceased_date and confirmed_date, then show the top 5 rows. Print the schema.
- <b>Hint: you need to typecast these columns as date first.</b>

In [11]:
df_deceased.printSchema()

root
 |-- patient_id: long (nullable = true)
 |-- sex: string (nullable = true)
 |-- age: string (nullable = true)
 |-- country: string (nullable = true)
 |-- province: string (nullable = true)
 |-- city: string (nullable = true)
 |-- infection_case: string (nullable = true)
 |-- infected_by: string (nullable = true)
 |-- contact_number: string (nullable = true)
 |-- symptom_onset_date: string (nullable = true)
 |-- confirmed_date: date (nullable = true)
 |-- released_date: date (nullable = true)
 |-- deceased_date: date (nullable = true)
 |-- state: string (nullable = true)



In [12]:
df_with_difference = df_deceased.withColumn(
    "no_days",
    F.datediff(df_deceased['deceased_date'], df_deceased['confirmed_date']),
)
df_with_difference.show(5)

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|
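
`F.datediff(end, start)` counts the days from `start` to `end` and yields null when either date is missing. The two values in the output above can be checked by hand with the standard library; this is only a sketch of the arithmetic, and the helper name `datediff` is borrowed for illustration.

```python
from datetime import date

def datediff(end, start):
    """Days from start to end; None if either date is missing."""
    if end is None or start is None:
        return None
    return (end - start).days

# Dates of the first two patients in the output above
no_days_1 = datediff(date(2020, 2, 5), date(2020, 1, 23))
no_days_2 = datediff(date(2020, 3, 2), date(2020, 1, 30))  # 2020 is a leap year
no_days_missing = datediff(None, date(2020, 1, 30))
print(no_days_1, no_days_2, no_days_missing)  # 13 32 None
```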

### Remove rows where the sex column is null.
### Add an is_male column: True if the patient is male, otherwise (female) False

In [13]:
df4=df_with_difference.filter(df_with_difference['sex'].isNotNull())
df4.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|         

In [14]:
df_male=df4.withColumn("is_male",F.when(df4['sex']=='male','True').otherwise('Female'))
df_male.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   True|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   True|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|co
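
Note that the cell above writes the strings `'True'` and `'Female'`, so `is_male` is a string column rather than a boolean one. The short plain-Python sketch below shows why that matters once values reach Python code: any non-empty string is truthy. In PySpark, the comparison `df4['sex'] == 'male'` is itself a boolean column expression and could be used directly if a true boolean flag is wanted.

```python
# Values as stored by the cell above, keyed by the original sex value
stored = {"male": "True", "female": "Female"}

# A plain truth test cannot tell the two strings apart
truthy = {sex: bool(flag) for sex, flag in stored.items()}
print(truthy)

# A real boolean flag comes from the comparison itself
as_bool = {sex: sex == "male" for sex in stored}
print(as_bool)
```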

### Add an is_dead column: True if the patient's state is not released, otherwise False

- Use a <b>UDF</b> to perform this task.
- However, UDFs are not recommended unless there is no built-in function that can do the required operation.
- UDFs are slower than built-in functions.

In [15]:
import pyspark.sql.functions as F

In [16]:
def is_state(state):
  # True for every state other than 'released' (isolated or deceased)
  return state != 'released'

In [17]:
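# Declare the return type explicitly; without it F.udf defaults to StringType
is_state_udf = F.udf(is_state, BooleanType())
df_dead = df_male.withColumn("is_dead", is_state_udf(df_male['state']))
df_dead.show()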

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   True|  false|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   True|  false|
|1000000003|  m

### Change the age bins from 0s, 10s, 20s, etc. to 0, 10, 20, etc.

In [18]:
df_with_bins=df_dead.withColumn("age",F.regexp_replace(df_dead['age'],'s',''))
df_with_bins.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male| 50|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   True|  false|
|1000000002|  male| 30|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   True|  false|
|1000000003|  m
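
The pattern `'s'` used above removes every `s` in the value, which is fine for strings shaped like `50s`. A slightly stricter variant anchors the pattern at the end of the string. The sketch below uses Python's `re` module to show the idea; the helper `strip_decade_suffix` and the sample values are illustrative assumptions, not taken from the notebook.

```python
import re

def strip_decade_suffix(age_bin):
    """Turn '50s' into '50'; missing values stay missing."""
    if age_bin is None:
        return None
    return re.sub(r"s$", "", age_bin)

samples = ["0s", "50s", "100s", None]
cleaned = [strip_decade_suffix(a) for a in samples]
print(cleaned)  # ['0', '50', '100', None]
```

The same anchored pattern, `'s$'`, can be passed to `F.regexp_replace`.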

### Cast age and no_days to Double

In [19]:
df_change=df_with_bins.withColumn("age",df_with_bins['age'].cast("double"))
df_change.show()

+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex| age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50.0|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|     13|   True|  false|
|1000000002|  male|30.0|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|     32|   True|  false|
|100000000

In [20]:
# Reference the column on the DataFrame being transformed (df_change)
df_change_days = df_change.withColumn("no_days", df_change['no_days'].cast("double"))
df_change_days.show()

+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|patient_id|   sex| age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|no_days|is_male|is_dead|
+----------+------+----+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+-------+-------+-------+
|1000000001|  male|50.0|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|   2020-02-05|released|   13.0|   True|  false|
|1000000002|  male|30.0|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|   2020-03-02|released|   32.0|   True|  false|
|100000000

### Drop the columns
["patient_id","sex","infected_by","contact_number","released_date","state",
"symptom_onset_date","confirmed_date","deceased_date","country","no_days",
"city","infection_case"]

In [21]:
drop_columns=["patient_id","sex","infected_by","contact_number","released_date","state",
"symptom_onset_date","confirmed_date","deceased_date","country","no_days",
"city","infection_case"]

In [22]:
df_dropped_columns=df_change_days.drop(*drop_columns)
df_dropped_columns.show()

+----+--------+-------+-------+
| age|province|is_male|is_dead|
+----+--------+-------+-------+
|50.0|   Seoul|   True|  false|
|30.0|   Seoul|   True|  false|
|50.0|   Seoul|   True|  false|
|20.0|   Seoul|   True|  false|
|20.0|   Seoul| Female|  false|
|50.0|   Seoul| Female|  false|
|20.0|   Seoul|   True|  false|
|20.0|   Seoul|   True|  false|
|30.0|   Seoul|   True|  false|
|60.0|   Seoul| Female|  false|
|50.0|   Seoul| Female|  false|
|20.0|   Seoul|   True|  false|
|80.0|   Seoul|   True|   true|
|60.0|   Seoul| Female|  false|
|70.0|   Seoul|   True|  false|
|70.0|   Seoul|   True|  false|
|70.0|   Seoul|   True|  false|
|20.0|   Seoul|   True|  false|
|70.0|   Seoul| Female|  false|
|70.0|   Seoul| Female|  false|
+----+--------+-------+-------+
only showing top 20 rows



### Recount the number of nulls now

In [23]:
for col in df_dropped_columns.columns:
  print(col,df_dropped_columns.filter(df_dropped_columns[col].isNull()).count())

age 261
province 0
is_male 0
is_dead 0


## Now do the same using SQL SELECT statements

### From the original patient DataFrame, create a temporary view (table).

In [24]:
view_name = "patient"
df.createOrReplaceTempView(view_name)

### Use SELECT statement to select all columns from the dataframe and show the output.

In [25]:
# Query the temporary view with SQL instead of the DataFrame API
data_selected = spark.sql("select * from patient")
data_selected.show()

+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|        city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+------------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul|  Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|         NULL|released|
|1000000002|  male|30s|  Korea|   Seoul| Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|         NULL|released|
|1000000003|  male|50s|  Korea|   Seoul|   Jongno-gu|contact with patient| 2002000001|            17|              NULL|    2020-01-30|

### *Using SQL commands*, limit the output to only 5 rows

In [26]:
spark.sql("select * from patient limit 5").show()

+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|patient_id|   sex|age|country|province|       city|      infection_case|infected_by|contact_number|symptom_onset_date|confirmed_date|released_date|deceased_date|   state|
+----------+------+---+-------+--------+-----------+--------------------+-----------+--------------+------------------+--------------+-------------+-------------+--------+
|1000000001|  male|50s|  Korea|   Seoul| Gangseo-gu|     overseas inflow|       NULL|            75|        2020-01-22|    2020-01-23|   2020-02-05|         NULL|released|
|1000000002|  male|30s|  Korea|   Seoul|Jungnang-gu|     overseas inflow|       NULL|            31|              NULL|    2020-01-30|   2020-03-02|         NULL|released|
|1000000003|  male|50s|  Korea|   Seoul|  Jongno-gu|contact with patient| 2002000001|            17|              NULL|    2020-01-30|   202

### Select the count of males and females in the dataset

In [27]:
# count(*) also counts the rows where sex is NULL
data_count_male_and_female = spark.sql("select sex, count(*) as count from patient group by sex")
data_count_male_and_female.show()

+------+-----+
|   sex|count|
+------+-----+
|  NULL| 1122|
|female| 2218|
|  male| 1825|
+------+-----+



### How many people survived, and how many didn't?

In [28]:
data_survived_1= spark.sql("select state,count(state) as count from patient group by state")
data_survived_1.show()

+--------+-----+
|   state|count|
+--------+-----+
|isolated| 2158|
|released| 2929|
|deceased|   78|
+--------+-----+



### Now, let's perform some preprocessing using SQL:
1. Convert *age* column to double after removing the 's' at the end -- *hint: check SUBSTRING method*
2. Select only the following columns: `['sex', 'age', 'province', 'state']`
3. Store the result of the query in a new dataframe

In [29]:
age_remove_s=spark.sql("select sex ,cast(substring(age,0,2) as double) as age, province, state from patient")
age_remove_s.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows
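
One caveat about `substring(age, 0, 2)`: it keeps the first two characters, which matches "remove the trailing s" only for two-digit bins. The plain-Python sketch below emulates the behaviour under two assumptions: Spark treats start position 0 like position 1, and a failed cast to double yields NULL (the default when ANSI mode is off). The helpers `spark_like_substring` and `to_double_or_none` are hypothetical stand-ins, and `100s` is an illustrative value.

```python
def spark_like_substring(text, pos, length):
    """Emulate substring(text, pos, length) for pos >= 0 (1-based, 0 treated as 1)."""
    start = max(pos - 1, 0)
    return text[start:start + length]

def to_double_or_none(text):
    """Emulate a non-ANSI cast to double: unparsable strings become None."""
    try:
        return float(text)
    except ValueError:
        return None

by_substring = {}
by_suffix = {}
for age_bin in ["0s", "50s", "100s"]:
    by_substring[age_bin] = to_double_or_none(spark_like_substring(age_bin, 0, 2))
    by_suffix[age_bin] = to_double_or_none(age_bin[:-1])  # drop only the last character

print(by_substring)  # {'0s': None, '50s': 50.0, '100s': 10.0}
print(by_suffix)     # {'0s': 0.0, '50s': 50.0, '100s': 100.0}
```

Under these assumptions the `0s` bin turns into NULL and a three-digit bin would be truncated. Expressions such as `substring(age, 1, length(age) - 1)` or `regexp_replace(age, 's$', '')` would avoid both effects.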



In [30]:
data_clean = age_remove_s.select("sex","age","province","state")
data_clean.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows



In [39]:
# Convert to Pandas and save
pandas_df = data_clean.toPandas()
pandas_df.to_csv(r"C:\Users\aliay\OneDrive\Desktop\iTi\31 - Big Data\Mariem\Practical-PySpark\Data\cleaned_data.csv", index=False)

In [41]:
data_clean.distinct().show()

+------+----+-----------------+--------+
|   sex| age|         province|   state|
+------+----+-----------------+--------+
|female|40.0|     Jeollabuk-do|released|
|  male|50.0|            Ulsan|released|
|  male|30.0|      Gyeonggi-do|released|
|female|10.0|      Gyeonggi-do|isolated|
|female|50.0|     Jeollanam-do|released|
|  male|10.0|            Ulsan|released|
|  male|30.0|Chungcheongbuk-do|isolated|
|  male|30.0|     Jeollanam-do|isolated|
|female|10.0| Gyeongsangbuk-do|released|
|female|50.0| Gyeongsangbuk-do|isolated|
|female|30.0|          Jeju-do|released|
|  male|50.0|          Incheon|isolated|
|  male|20.0|          Incheon|released|
|female|60.0|          Daejeon|isolated|
|  male|80.0| Gyeongsangbuk-do|released|
|female|80.0|            Busan|deceased|
|female|70.0|            Daegu|isolated|
|  male|40.0|            Busan|released|
|female|80.0|Chungcheongnam-do|isolated|
|female|20.0|Chungcheongbuk-do|released|
+------+----+-----------------+--------+
only showing top

## Machine Learning

In [42]:
# Load the cleaned file with Spark (raw string keeps the backslashes literal)

df = spark.read.csv(r"..\Data\cleaned_data.csv", header=True, inferSchema=True)
df.printSchema()

root
 |-- sex: string (nullable = true)
 |-- age: double (nullable = true)
 |-- province: string (nullable = true)
 |-- state: string (nullable = true)



In [43]:
df.show()

+------+----+--------+--------+
|   sex| age|province|   state|
+------+----+--------+--------+
|  male|50.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|  male|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|20.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|30.0|   Seoul|released|
|female|60.0|   Seoul|released|
|female|50.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|  male|80.0|   Seoul|deceased|
|female|60.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|70.0|   Seoul|released|
|  male|20.0|   Seoul|released|
|female|70.0|   Seoul|released|
|female|70.0|   Seoul|released|
+------+----+--------+--------+
only showing top 20 rows



In [44]:
df.dtypes

[('sex', 'string'),
 ('age', 'double'),
 ('province', 'string'),
 ('state', 'string')]

In [45]:
df.describe().show()    

+-------+------+-----------------+--------+--------+
|summary|   sex|              age|province|   state|
+-------+------+-----------------+--------+--------+
|  count|  4043|             3719|    5165|    5165|
|   mean|  NULL|41.05942457649906|    NULL|    NULL|
| stddev|  NULL|19.62275925520363|    NULL|    NULL|
|    min|female|             10.0|   Busan|deceased|
|    max|  male|             90.0|   Ulsan|released|
+-------+------+-----------------+--------+--------+



In [46]:
str_columns = [f[0] for f in df.dtypes if f[1] == 'string'] 
str_columns 

['sex', 'province', 'state']

In [47]:
categorical_columns = [str_column + "_Index" for str_column in str_columns  ]
categorical_columns

['sex_Index', 'province_Index', 'state_Index']

In [48]:
ONEHOT_COLUMNS = [str_column + "_OHE" for str_column in str_columns]
ONEHOT_COLUMNS

['sex_OHE', 'province_OHE', 'state_OHE']

In [49]:
from pyspark.ml.feature import StringIndexer
indexer = StringIndexer(inputCols=str_columns, outputCols=categorical_columns, handleInvalid="keep")  
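
By default `StringIndexer` orders labels by descending frequency, so the most common label gets index 0.0, and with `handleInvalid="keep"` unseen or null values go into one extra bucket after the real labels. The sketch below imitates the ordering rule in plain Python. The helper `fit_string_index` is hypothetical, and the counts come from the `groupBy("state")` output earlier in the lab, whereas the pipeline itself is fitted on the training split only.

```python
from collections import Counter

def fit_string_index(values):
    """Map each distinct label to an index, most frequent first (ties alphabetical)."""
    counts = Counter(v for v in values if v is not None)
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    return {label: float(i) for i, label in enumerate(ordered)}

states = ["released"] * 2929 + ["isolated"] * 2158 + ["deceased"] * 78
state_index = fit_string_index(states)
print(state_index)  # {'released': 0.0, 'isolated': 1.0, 'deceased': 2.0}

# With handleInvalid="keep", nulls and unseen labels would get the next index
invalid_bucket = float(len(state_index))
print(invalid_bucket)  # 3.0
```

The indices are codes for categories, not measurements, which is worth keeping in mind when `state_Index` is later used as a regression label.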

In [50]:
from pyspark.ml.feature import OneHotEncoder
encoder = OneHotEncoder(inputCols=categorical_columns, outputCols=ONEHOT_COLUMNS)
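
`OneHotEncoder` drops the last category by default (`dropLast=True`), so a column with `n` index values becomes a vector of length `n - 1` and the last index is encoded as all zeros. A minimal dense sketch of that rule follows; the helper `one_hot` is hypothetical, and the category count of 3 for `sex` assumes two labels plus the extra bucket created by `handleInvalid="keep"`.

```python
def one_hot(index, num_categories, drop_last=True):
    """Dense one-hot vector; with drop_last the final category is all zeros."""
    size = num_categories - 1 if drop_last else num_categories
    vector = [0.0] * size
    if index < size:
        vector[int(index)] = 1.0
    return vector

first_label = one_hot(0, 3)
second_label = one_hot(1, 3)
null_bucket = one_hot(2, 3)   # the extra "keep" bucket
print(first_label, second_label, null_bucket)
# [1.0, 0.0] [0.0, 1.0] [0.0, 0.0]
```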

In [51]:
numeric_columns = [f[0] for f in df.dtypes if f[1] != 'string'] 
numeric_columns

['age']

In [52]:
# Impute numeric columns
from pyspark.ml.feature import Imputer
imputer = Imputer(
    inputCols=numeric_columns,
    outputCols=numeric_columns,
)
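
With no `strategy` argument, `Imputer` replaces missing values with the column mean learned during `fit`, and the same learned value is reused when transforming other data. The plain-Python sketch below illustrates this with made-up ages; `fit_mean` and `impute` are hypothetical helpers.

```python
def fit_mean(values):
    """Mean of the non-missing values, learned from the training data."""
    present = [v for v in values if v is not None]
    return sum(present) / len(present)

def impute(values, fill):
    """Replace missing entries with the learned fill value."""
    return [fill if v is None else v for v in values]

train_ages = [50.0, 30.0, None, 40.0]   # toy training column
mean_age = fit_mean(train_ages)

test_ages = [None, 70.0]                # toy test column
imputed_test = impute(test_ages, mean_age)
print(mean_age, imputed_test)  # 40.0 [40.0, 70.0]
```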

In [53]:
features = numeric_columns + ONEHOT_COLUMNS 
features

['age', 'sex_OHE', 'province_OHE', 'state_OHE']

In [54]:
# Remove the one-hot encoded label (state_OHE) from the feature list
features = [f for f in features if f != 'state_OHE']
features

['age', 'sex_OHE', 'province_OHE']

In [55]:
from pyspark.ml.feature import VectorAssembler

vecAssmbler = VectorAssembler(inputCols=features, outputCol="features")
vecAssmbler

VectorAssembler_555130b66db6

In [56]:
# train test split  

train, test = df.randomSplit([0.8, 0.2], seed=42)
print(f"train: {train.count()}, test: {test.count()}")

train: 4166, test: 999


In [63]:
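from pyspark.ml.regression import LinearRegression

# Note: state_Index is a class index (0, 1, 2), so this regression treats the classes as points on a numeric scale
LR = LinearRegression(featuresCol="features", labelCol="state_Index", predictionCol="crew_prediction")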

In [64]:
from pyspark.ml import Pipeline 
pip = Pipeline(stages=[indexer, encoder, imputer, vecAssmbler, LR])

In [65]:
pipe_model = pip.fit(train) 

In [68]:
pred_train = pipe_model.transform(train)    
pred_train.show()

+----+-----------------+-----------+--------+---------+--------------+-----------+---------+--------------+-------------+--------------------+------------------+
| sex|              age|   province|   state|sex_Index|province_Index|state_Index|  sex_OHE|  province_OHE|    state_OHE|            features|   crew_prediction|
+----+-----------------+-----------+--------+---------+--------------+-----------+---------+--------------+-------------+--------------------+------------------+
|NULL|40.83025210084033|      Busan|isolated|      2.0|           5.0|        1.0|(2,[],[])|(17,[5],[1.0])|(3,[1],[1.0])|(20,[0,8],[40.830...| 0.591367287433862|
|NULL|40.83025210084033|      Busan|isolated|      2.0|           5.0|        1.0|(2,[],[])|(17,[5],[1.0])|(3,[1],[1.0])|(20,[0,8],[40.830...| 0.591367287433862|
|NULL|40.83025210084033|Gyeonggi-do|isolated|      2.0|           2.0|        1.0|(2,[],[])|(17,[2],[1.0])|(3,[1],[1.0])|(20,[0,5],[40.830...|1.1315206897413983|
|NULL|40.83025210084033|Gyeo
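
The `features` column is printed in sparse form `(size, [indices], [values])`. The sketch below expands the first row above into a dense list, assuming the assembler lays out its inputs in order: one slot for `age`, two for `sex_OHE` and seventeen for `province_OHE`. The second stored value is truncated in the output and is assumed to be 1.0, as one-hot encoding would produce; `to_dense` is a hypothetical helper.

```python
def to_dense(size, indices, values):
    """Expand a sparse (size, indices, values) triple into a dense list."""
    dense = [0.0] * size
    for i, v in zip(indices, values):
        dense[i] = v
    return dense

AGE_SLOTS, SEX_SLOTS, PROVINCE_SLOTS = 1, 2, 17
province_offset = AGE_SLOTS + SEX_SLOTS   # province block starts at position 3

# First row: imputed age, sex in the null bucket (all zeros), province_Index 5.0
row = to_dense(AGE_SLOTS + SEX_SLOTS + PROVINCE_SLOTS,
               [0, province_offset + 5],
               [40.83025210084033, 1.0])
print(len(row), row[0], row[1:3], row[8])  # 20 40.83025210084033 [0.0, 0.0] 1.0
```

This matches the `(20,[0,8],...)` shown for the Busan rows, where `province_Index` is 5.0.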

In [70]:
# Evaluation (RMSE and R2)
from pyspark.ml.evaluation import RegressionEvaluator
evaluator = RegressionEvaluator(labelCol="state_Index", predictionCol="crew_prediction", metricName="rmse")
rmse = evaluator.evaluate(pred_train)
print(f"RMSE: {rmse}")

RMSE: 0.38271045710291673


In [71]:
evaluator = RegressionEvaluator(labelCol="state_Index", predictionCol="crew_prediction", metricName="r2")
r2 = evaluator.evaluate(pred_train)
print(f"R2: {r2}")  

R2: 0.47198131993434334
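
For reference, RMSE is the square root of the mean squared error, and R2 is one minus the ratio of the residual sum of squares to the total sum of squares around the label mean. The sketch below computes both on a few made-up numbers; it illustrates the formulas only and does not reproduce the values reported above.

```python
from math import sqrt

def rmse(labels, predictions):
    """Root mean squared error."""
    squared = [(y - p) ** 2 for y, p in zip(labels, predictions)]
    return sqrt(sum(squared) / len(squared))

def r2(labels, predictions):
    """Coefficient of determination: 1 - SS_res / SS_tot."""
    mean_label = sum(labels) / len(labels)
    ss_res = sum((y - p) ** 2 for y, p in zip(labels, predictions))
    ss_tot = sum((y - mean_label) ** 2 for y in labels)
    return 1.0 - ss_res / ss_tot

toy_labels = [0.0, 1.0, 1.0, 2.0]
toy_predictions = [0.5, 1.0, 1.0, 1.5]
toy_rmse = rmse(toy_labels, toy_predictions)
toy_r2 = r2(toy_labels, toy_predictions)
print(toy_rmse, toy_r2)
```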


In [72]:
pred_test = pipe_model.transform(test)    
pred_test.show()

+----+-----------------+-----------+--------+---------+--------------+-----------+---------+--------------+-------------+--------------------+------------------+
| sex|              age|   province|   state|sex_Index|province_Index|state_Index|  sex_OHE|  province_OHE|    state_OHE|            features|   crew_prediction|
+----+-----------------+-----------+--------+---------+--------------+-----------+---------+--------------+-------------+--------------------+------------------+
|NULL|40.83025210084033|Gyeonggi-do|isolated|      2.0|           2.0|        1.0|(2,[],[])|(17,[2],[1.0])|(3,[1],[1.0])|(20,[0,5],[40.830...|1.1315206897413983|
|NULL|40.83025210084033|Gyeonggi-do|isolated|      2.0|           2.0|        1.0|(2,[],[])|(17,[2],[1.0])|(3,[1],[1.0])|(20,[0,5],[40.830...|1.1315206897413983|
|NULL|40.83025210084033|Gyeonggi-do|isolated|      2.0|           2.0|        1.0|(2,[],[])|(17,[2],[1.0])|(3,[1],[1.0])|(20,[0,5],[40.830...|1.1315206897413983|
|NULL|40.83025210084033|Gyeo

In [73]:
evaluator = RegressionEvaluator(labelCol="state_Index", predictionCol="crew_prediction", metricName="rmse")
rmse = evaluator.evaluate(pred_test)
print(f"RMSE: {rmse}")


RMSE: 0.37663675162519783


In [74]:
evaluator = RegressionEvaluator(labelCol="state_Index", predictionCol="crew_prediction", metricName="r2")
r2 = evaluator.evaluate(pred_test)
print(f"R2: {r2}")

R2: 0.48923913858721135
