# PySpark RDDs 

On Databricks, on creating a cluster, the SparkSession is created internally. SparkConf, SparkContext or SQLContext, don't have to be called explicitly as they are encapsulated inside SparkSession, accessible through a variabled called spark. 


In [None]:
spark

If this is run through Jupyter notebook or as python wheel file, it needs to be initialised, conventionally like this. Here on databricks, this is not required. 

In [None]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName('Intro').getOrCreate() # app name can be changed, ofc.  


## File handling

### Json

In [None]:
df = spark.read.json("/FileStore/tables/oscar.json",multiLine=True)

In [None]:
df.printSchema()

root
 |-- awards: array (nullable = true)
 |    |-- element: string (containsNull = true)
 |-- director: string (nullable = true)
 |-- genre: string (nullable = true)
 |-- title: string (nullable = true)
 |-- year: long (nullable = true)



In [None]:
df.show()

+--------------------+--------------------+--------------+------------------+----+
|              awards|            director|         genre|             title|year|
+--------------------+--------------------+--------------+------------------+----+
|[Best Picture, Be...|        Bong Joon-ho|Drama/Thriller|          Parasite|2019|
|[Best Picture, Be...|  Guillermo del Toro| Drama/Fantasy|The Shape of Water|2017|
|[Best Picture, Be...|       Barry Jenkins|         Drama|         Moonlight|2016|
|[Best Picture, Be...|        Tom McCarthy|         Drama|         Spotlight|2015|
|[Best Picture, Be...|Alejandro G. Iñár...|  Drama/Comedy|           Birdman|2014|
+--------------------+--------------------+--------------+------------------+----+



In [None]:
df.columns

Out[13]: ['awards', 'director', 'genre', 'title', 'year']

In [None]:
df.describe().show()

+-------+--------------------+--------------+------------------+------------------+
|summary|            director|         genre|             title|              year|
+-------+--------------------+--------------+------------------+------------------+
|  count|                   5|             5|                 5|                 5|
|   mean|                null|          null|              null|            2016.2|
| stddev|                null|          null|              null|1.9235384061671172|
|    min|Alejandro G. Iñár...|         Drama|           Birdman|              2014|
|    max|        Tom McCarthy|Drama/Thriller|The Shape of Water|              2019|
+-------+--------------------+--------------+------------------+------------------+



### Understanding Spark Schemas

In [None]:
from pyspark.sql.types import StructField, StructType, StringType, IntegerType, ArrayType

In [None]:
df_schema = StructType([
    StructField("title", StringType(), True),
    StructField("genre", StringType(), True),
    StructField("director", StringType(), True),
    StructField("year", IntegerType(), True),
    StructField("awards", ArrayType(StringType(), True), True)
])

In [None]:
df = spark.read.json("/FileStore/tables/oscar.json",multiLine=True, schema=df_schema)
df.printSchema()

root
 |-- title: string (nullable = true)
 |-- genre: string (nullable = true)
 |-- director: string (nullable = true)
 |-- year: integer (nullable = true)
 |-- awards: array (nullable = true)
 |    |-- element: string (containsNull = true)



### CSV File handling

In [None]:
df = spark.read.csv("/FileStore/tables/titanic.csv", inferSchema=True,  header=True)


In [None]:
df.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



## Exploratory analysis on RDD

In [None]:
df.head(5)

Out[8]: [Row(PassengerId=1, Survived=0, Pclass=3, Name='Braund, Mr. Owen Harris', Sex='male', Age=22.0, SibSp=1, Parch=0, Ticket='A/5 21171', Fare=7.25, Cabin=None, Embarked='S'),
 Row(PassengerId=2, Survived=1, Pclass=1, Name='Cumings, Mrs. John Bradley (Florence Briggs Thayer)', Sex='female', Age=38.0, SibSp=1, Parch=0, Ticket='PC 17599', Fare=71.2833, Cabin='C85', Embarked='C'),
 Row(PassengerId=3, Survived=1, Pclass=3, Name='Heikkinen, Miss. Laina', Sex='female', Age=26.0, SibSp=0, Parch=0, Ticket='STON/O2. 3101282', Fare=7.925, Cabin=None, Embarked='S'),
 Row(PassengerId=4, Survived=1, Pclass=1, Name='Futrelle, Mrs. Jacques Heath (Lily May Peel)', Sex='female', Age=35.0, SibSp=1, Parch=0, Ticket='113803', Fare=53.1, Cabin='C123', Embarked='S'),
 Row(PassengerId=5, Survived=0, Pclass=3, Name='Allen, Mr. William Henry', Sex='male', Age=35.0, SibSp=0, Parch=0, Ticket='373450', Fare=8.05, Cabin=None, Embarked='S')]

In [None]:
print((df.count(), len(df.columns)))

(891, 13)


In [None]:
df.select("Age").show(3)

+----+
| Age|
+----+
|22.0|
|38.0|
|26.0|
+----+
only showing top 3 rows



In [None]:
df.createOrReplaceTempView('titanicdb')

In [None]:
spark.sql("select * from titanicdb limit 5").show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|
+-----------+--------+------+--------------------+------+----+-----+-----+------

In [None]:
%sql
select * from titanicdb limit 5

PassengerId,Survived,Pclass,Name,Sex,Age,SibSp,Parch,Ticket,Fare,Cabin,Embarked
1,0,3,"Braund, Mr. Owen Harris",male,22.0,1,0,A/5 21171,7.25,,S
2,1,1,"Cumings, Mrs. John Bradley (Florence Briggs Thayer)",female,38.0,1,0,PC 17599,71.2833,C85,C
3,1,3,"Heikkinen, Miss. Laina",female,26.0,0,0,STON/O2. 3101282,7.925,,S
4,1,1,"Futrelle, Mrs. Jacques Heath (Lily May Peel)",female,35.0,1,0,113803,53.1,C123,S
5,0,3,"Allen, Mr. William Henry",male,35.0,0,0,373450,8.05,,S


In [None]:
df.describe().show() # df.describe('Age').show() 

+-------+-----------------+-------------------+------------------+--------------------+------+------------------+------------------+-------------------+------------------+-----------------+-----+--------+
|summary|      PassengerId|           Survived|            Pclass|                Name|   Sex|               Age|             SibSp|              Parch|            Ticket|             Fare|Cabin|Embarked|
+-------+-----------------+-------------------+------------------+--------------------+------+------------------+------------------+-------------------+------------------+-----------------+-----+--------+
|  count|              891|                891|               891|                 891|   891|               714|               891|                891|               891|              891|  204|     889|
|   mean|            446.0| 0.3838383838383838| 2.308641975308642|                null|  null| 29.69911764705882|0.5230078563411896|0.38159371492704824|260318.54916792738| 32.20420

In [None]:
df.filter("Survived==1").select(["Name","Sex", "Age","Pclass","Fare"]).show()

+--------------------+------+----+------+--------+
|                Name|   Sex| Age|Pclass|    Fare|
+--------------------+------+----+------+--------+
|Cumings, Mrs. Joh...|female|38.0|     1| 71.2833|
|Heikkinen, Miss. ...|female|26.0|     3|   7.925|
|Futrelle, Mrs. Ja...|female|35.0|     1|    53.1|
|Johnson, Mrs. Osc...|female|27.0|     3| 11.1333|
|Nasser, Mrs. Nich...|female|14.0|     2| 30.0708|
|Sandstrom, Miss. ...|female| 4.0|     3|    16.7|
|Bonnell, Miss. El...|female|58.0|     1|   26.55|
|Hewlett, Mrs. (Ma...|female|55.0|     2|    16.0|
|Williams, Mr. Cha...|  male|null|     2|    13.0|
|Masselmani, Mrs. ...|female|null|     3|   7.225|
|Beesley, Mr. Lawr...|  male|34.0|     2|    13.0|
|"McGowan, Miss. A...|female|15.0|     3|  8.0292|
|Sloper, Mr. Willi...|  male|28.0|     1|    35.5|
|Asplund, Mrs. Car...|female|38.0|     3| 31.3875|
|"O'Dwyer, Miss. E...|female|null|     3|  7.8792|
|Spencer, Mrs. Wil...|female|null|     1|146.5208|
|Glynn, Miss. Mary...|female|nu

In [None]:
df.filter((df['Age']>50) & (df["SibSp"]==0) & (df['Parch']==0)).count()

Out[41]: 43

In [None]:
df.withColumn("family",(df['SibSp']+df['Parch'])).show()

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|family|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|     1|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|     1|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|     0|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|     1|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|     0|
|          6|       0|  

In [None]:
df.columns

Out[50]: ['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked']

In [None]:
df = df.withColumn("family",(df['SibSp']+df['Parch']))
df.columns

Out[171]: ['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked',
 'family']

In [None]:
df=df.withColumnRenamed('family','Family')
df.columns

Out[172]: ['PassengerId',
 'Survived',
 'Pclass',
 'Name',
 'Sex',
 'Age',
 'SibSp',
 'Parch',
 'Ticket',
 'Fare',
 'Cabin',
 'Embarked',
 'Family']

In [None]:
df.groupBy('Survived').sum().show()

+--------+----------------+-------------+-----------+--------+----------+----------+------------------+-----------+
|Survived|sum(PassengerId)|sum(Survived)|sum(Pclass)|sum(Age)|sum(SibSp)|sum(Parch)|         sum(Fare)|sum(Family)|
+--------+----------------+-------------+-----------+--------+----------+----------+------------------+-----------+
|       1|          151974|          342|        667| 8219.67|       162|       159|16551.229399999997|        321|
|       0|          245412|            0|       1390| 12985.5|       304|       181|12142.719899999987|        485|
+--------+----------------+-------------+-----------+--------+----------+----------+------------------+-----------+



In [None]:
df.groupBy('Survived').sum().select(['Survived', 'sum(Age)']).show()

+--------+--------+
|Survived|sum(Age)|
+--------+--------+
|       1| 8219.67|
|       0| 12985.5|
+--------+--------+



In [None]:
df.agg({'Survived':'sum'}).show()

+-------------+
|sum(Survived)|
+-------------+
|          342|
+-------------+



In [None]:
df.agg({'Survived':'sum'}).collect()[0][0]


Out[70]: 342

In [None]:
df.groupBy('Survived').agg({'Pclass':'mean','Survived':'count'}).show()

+--------+------------------+---------------+
|Survived|       avg(Pclass)|count(Survived)|
+--------+------------------+---------------+
|       1|1.9502923976608186|            342|
|       0|2.5318761384335153|            549|
+--------+------------------+---------------+



In [None]:
df.sort(df.Age.desc()).show(10)

+-----------+--------+------+--------------------+----+----+-----+-----+----------+-------+-----+--------+------+
|PassengerId|Survived|Pclass|                Name| Sex| Age|SibSp|Parch|    Ticket|   Fare|Cabin|Embarked|Family|
+-----------+--------+------+--------------------+----+----+-----+-----+----------+-------+-----+--------+------+
|        631|       1|     1|Barkworth, Mr. Al...|male|80.0|    0|    0|     27042|   30.0|  A23|       S|     0|
|        852|       0|     3| Svensson, Mr. Johan|male|74.0|    0|    0|    347060|  7.775| null|       S|     0|
|         97|       0|     1|Goldschmidt, Mr. ...|male|71.0|    0|    0|  PC 17754|34.6542|   A5|       C|     0|
|        494|       0|     1|Artagaveytia, Mr....|male|71.0|    0|    0|  PC 17609|49.5042| null|       C|     0|
|        117|       0|     3|Connors, Mr. Patrick|male|70.5|    0|    0|    370369|   7.75| null|       Q|     0|
|        673|       0|     2|Mitchell, Mr. Hen...|male|70.0|    0|    0|C.A. 24580|   10

### Joins

In [None]:
df2 = spark.read.csv("/FileStore/tables/titanic.csv", inferSchema=True,  header=True)


In [None]:
df2=df2.limit(10)

In [None]:
print((df2.count(), len(df2.columns)))

(10, 12)


In [None]:
df2.join(df,df.Name==df2.Name, how='inner').show() # left_outer,right_outer count 

+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|Family|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|     1|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.

### Missing value imputation

In [None]:
df.count()

Out[107]: 891

In [None]:
from pyspark.sql.functions import when, count, col

In [None]:
df.select([col(c) for c in df.columns]).show()


+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|Family|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25| null|       S|     1|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|  C85|       C|     1|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925| null|       S|     0|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1| C123|       S|     1|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05| null|       S|     0|
|          6|       0|  

In [None]:
df.select([when(col(c).isNull(), col(c)) for c in df.columns]).show()


+----------------------------------------------------+----------------------------------------------+------------------------------------------+--------------------------------------+------------------------------------+------------------------------------+----------------------------------------+----------------------------------------+------------------------------------------+--------------------------------------+----------------------------------------+----------------------------------------------+------------------------------------------+
|CASE WHEN (PassengerId IS NULL) THEN PassengerId END|CASE WHEN (Survived IS NULL) THEN Survived END|CASE WHEN (Pclass IS NULL) THEN Pclass END|CASE WHEN (Name IS NULL) THEN Name END|CASE WHEN (Sex IS NULL) THEN Sex END|CASE WHEN (Age IS NULL) THEN Age END|CASE WHEN (SibSp IS NULL) THEN SibSp END|CASE WHEN (Parch IS NULL) THEN Parch END|CASE WHEN (Ticket IS NULL) THEN Ticket END|CASE WHEN (Fare IS NULL) THEN Fare END|CASE WHEN (Cabin IS NULL) 

In [None]:

df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]).show()


+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|Family|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|  687|       2|     0|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+------+



In [None]:
df = df.drop('PassengerId','Name','Ticket','Cabin','Family')

In [None]:
df.columns

Out[139]: ['Survived', 'Pclass', 'Sex', 'Age', 'SibSp', 'Parch', 'Fare', 'Embarked']

In [None]:
from pyspark.ml.feature import Imputer
imputer = Imputer(
    inputCols=['Age'], #specifying the input column names
    outputCols=['Age'], #specifying the output column names
    strategy="mean"                  # or "median" if you want to use the median value
)
df = imputer.fit(df).transform(df)

In [None]:
df.where(df.Age.isNull()).count()

Out[153]: 0

In [None]:
df.groupBy('Embarked').count().show()


+--------+-----+
|Embarked|count|
+--------+-----+
|       Q|   77|
|    null|    2|
|       C|  168|
|       S|  644|
+--------+-----+



In [None]:
df = df.fillna(value='S', subset='Embarked')

In [None]:
df.groupBy('Embarked').count().show()


+--------+-----+
|Embarked|count|
+--------+-----+
|       Q|   77|
|       C|  168|
|       S|  646|
+--------+-----+



### Change datatype 

In [None]:
df.printSchema()

root
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Embarked: string (nullable = false)



In [None]:
from pyspark.sql.types import IntegerType
df = df.withColumn('Age', col('Age').cast(IntegerType())) # 'int'
df.show(2)

+--------+------+------+---+-----+-----+-------+--------+
|Survived|Pclass|   Sex|Age|SibSp|Parch|   Fare|Embarked|
+--------+------+------+---+-----+-----+-------+--------+
|       0|     3|  male| 22|    1|    0|   7.25|       S|
|       1|     1|female| 38|    1|    0|71.2833|       C|
+--------+------+------+---+-----+-----+-------+--------+
only showing top 2 rows



## Model Building 

In [None]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder

### Label Encoding

In [None]:
s_indexer = StringIndexer(inputCols=['Sex','Embarked'],outputCols=['Sex_index','Embarked_index'])
df = s_indexer.fit(df).transform(df)
df = df.drop('Sex','Embarked')

In [None]:
df.show()

+--------+------+---+-----+-----+-------+---------+--------------+
|Survived|Pclass|Age|SibSp|Parch|   Fare|Sex_index|Embarked_index|
+--------+------+---+-----+-----+-------+---------+--------------+
|       0|     3| 22|    1|    0|   7.25|      0.0|           0.0|
|       1|     1| 38|    1|    0|71.2833|      1.0|           1.0|
|       1|     3| 26|    0|    0|  7.925|      1.0|           0.0|
|       1|     1| 35|    1|    0|   53.1|      1.0|           0.0|
|       0|     3| 35|    0|    0|   8.05|      0.0|           0.0|
|       0|     3| 29|    0|    0| 8.4583|      0.0|           2.0|
|       0|     1| 54|    0|    0|51.8625|      0.0|           0.0|
|       0|     3|  2|    3|    1| 21.075|      0.0|           0.0|
|       1|     3| 27|    0|    2|11.1333|      1.0|           0.0|
|       1|     2| 14|    1|    0|30.0708|      1.0|           1.0|
|       1|     3|  4|    1|    1|   16.7|      1.0|           0.0|
|       1|     1| 58|    0|    0|  26.55|      1.0|           

### OneHotEncoding

In [None]:
oh_encoder = OneHotEncoder(inputCol='Embarked_index',outputCol='EmbarkedVec')
df = oh_encoder.fit(df).transform(df)

Out[204]: DataFrame[Survived: int, Pclass: int, Age: int, SibSp: int, Parch: int, Fare: double, Sex_index: double, Embarked: vector]

In [None]:
df = df.drop('Embarked_index')

In [None]:
df.show()

+--------+------+---+-----+-----+-------+---------+-------------+
|Survived|Pclass|Age|SibSp|Parch|   Fare|Sex_index|  EmbarkedVec|
+--------+------+---+-----+-----+-------+---------+-------------+
|       0|     3| 22|    1|    0|   7.25|      0.0|(2,[0],[1.0])|
|       1|     1| 38|    1|    0|71.2833|      1.0|(2,[1],[1.0])|
|       1|     3| 26|    0|    0|  7.925|      1.0|(2,[0],[1.0])|
|       1|     1| 35|    1|    0|   53.1|      1.0|(2,[0],[1.0])|
|       0|     3| 35|    0|    0|   8.05|      0.0|(2,[0],[1.0])|
|       0|     3| 29|    0|    0| 8.4583|      0.0|    (2,[],[])|
|       0|     1| 54|    0|    0|51.8625|      0.0|(2,[0],[1.0])|
|       0|     3|  2|    3|    1| 21.075|      0.0|(2,[0],[1.0])|
|       1|     3| 27|    0|    2|11.1333|      1.0|(2,[0],[1.0])|
|       1|     2| 14|    1|    0|30.0708|      1.0|(2,[1],[1.0])|
|       1|     3|  4|    1|    1|   16.7|      1.0|(2,[0],[1.0])|
|       1|     1| 58|    0|    0|  26.55|      1.0|(2,[0],[1.0])|
|       0|

### Vector creation

Create a single vector of all independant variables

In [None]:
df.columns

Out[209]: ['Survived',
 'Pclass',
 'Age',
 'SibSp',
 'Parch',
 'Fare',
 'Sex_index',
 'EmbarkedVec']

In [None]:
from pyspark.ml.feature import VectorAssembler
df_assembler = VectorAssembler(inputCols=df.columns[1:], outputCol='features')
df_vec = df_assembler.transform(df)
df_vec.printSchema()

root
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Age: integer (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Sex_index: double (nullable = false)
 |-- EmbarkedVec: vector (nullable = true)
 |-- features: vector (nullable = true)



In [None]:
df_vec = df_vec.select('features','Survived')
df_vec.show()

+--------------------+--------+
|            features|Survived|
+--------------------+--------+
|[3.0,22.0,1.0,0.0...|       0|
|[1.0,38.0,1.0,0.0...|       1|
|[3.0,26.0,0.0,0.0...|       1|
|[1.0,35.0,1.0,0.0...|       1|
|(8,[0,1,4,6],[3.0...|       0|
|(8,[0,1,4],[3.0,2...|       0|
|(8,[0,1,4,6],[1.0...|       0|
|[3.0,2.0,3.0,1.0,...|       0|
|[3.0,27.0,0.0,2.0...|       1|
|[2.0,14.0,1.0,0.0...|       1|
|[3.0,4.0,1.0,1.0,...|       1|
|[1.0,58.0,0.0,0.0...|       1|
|(8,[0,1,4,6],[3.0...|       0|
|[3.0,39.0,1.0,5.0...|       0|
|[3.0,14.0,0.0,0.0...|       0|
|[2.0,55.0,0.0,0.0...|       1|
|[3.0,2.0,4.0,1.0,...|       0|
|(8,[0,1,4,6],[2.0...|       1|
|[3.0,31.0,1.0,0.0...|       0|
|[3.0,29.0,0.0,0.0...|       1|
+--------------------+--------+
only showing top 20 rows



### Train Test split 

In [None]:
train_df,test_df = df_vec.randomSplit([0.75,0.25])

In [None]:
train_df.groupby('survived').per().show()

+--------+-----+
|survived|count|
+--------+-----+
|       1|  259|
|       0|  399|
+--------+-----+



In [None]:
test_df.groupby('survived').count().show()


+--------+-----+
|survived|count|
+--------+-----+
|       1|   83|
|       0|  150|
+--------+-----+



### Scaling

Computing summary statistics, scale data to have unit standard deviation and/or zero mean. Can be applied on Vector column. 

In [None]:
from pyspark.ml.feature import StandardScaler

sscaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False)

scalerModel = sscaler.fit(train_df)

train_df = scalerModel.transform(train_df)
train_df = train_df.drop('features').withColumnRenamed('scaledFeatures','features')
train_df.show()

+--------------------+--------+
|      scaledFeatures|Survived|
+--------------------+--------+
|(8,[0,1,2,4],[1.1...|       0|
|(8,[0,1,2,4],[3.5...|       0|
|(8,[0,1,2,4],[3.5...|       1|
|(8,[0,1,4],[2.383...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       1|
+--------------------+--------+
only showing top 20 rows



In [None]:
train_df.show()

+--------------------+--------+
|      scaledFeatures|Survived|
+--------------------+--------+
|(8,[0,1,2,4],[1.1...|       0|
|(8,[0,1,2,4],[3.5...|       0|
|(8,[0,1,2,4],[3.5...|       1|
|(8,[0,1,4],[2.383...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       1|
+--------------------+--------+
only showing top 20 rows



In [None]:
test_df = scalerModel.transform(test_df)
test_df = test_df.drop('features').withColumnRenamed('scaledFeatures','features')
test_df.show()

+--------------------+--------+
|      scaledFeatures|Survived|
+--------------------+--------+
|(8,[0,1,2,4],[3.5...|       0|
|(8,[0,1,2,4],[3.5...|       0|
|(8,[0,1,2,4],[3.5...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       1|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4],[3.574...|       0|
|(8,[0,1,4,5],[3.5...|       1|
|(8,[0,1,4,5],[3.5...|       1|
|(8,[0,1,4,5],[3.5...|       0|
|(8,[0,1,4,5],[3.5...|       1|
|(8,[0,1,4,5],[3.5...|       1|
|(8,[0,1,4,5],[3.5...|       0|
|(8,[0,1,4,6],[1.1...|       0|
|(8,[0,1,4,6],[1.1...|       0|
|(8,[0,1,4,6],[1.1...|       1|
|(8,[0,1,4,6],[1.1...|       0|
+--------------------+--------+
only showing top 20 rows



## Model 

### Training

In [None]:
from pyspark.ml.classification import LogisticRegression

In [None]:
lr_model = LogisticRegression(labelCol="Survived")
lr_model = lr_model.fit(train_df)

In [None]:
lr_model_summary = lr_model.summary
lr_model_summary.predictions.show()

+--------------------+--------+--------------------+--------------------+----------+
|            features|Survived|       rawPrediction|         probability|prediction|
+--------------------+--------+--------------------+--------------------+----------+
|(8,[0,1,2,4],[1.1...|     0.0|[0.57556710846881...|[0.64004676147655...|       0.0|
|(8,[0,1,2,4],[3.5...|     0.0|[2.07185293333972...|[0.88813718193298...|       0.0|
|(8,[0,1,2,4],[3.5...|     1.0|[2.32983919982898...|[0.91131834218173...|       0.0|
|(8,[0,1,4],[2.383...|     0.0|[1.70685383440628...|[0.84642776667122...|       0.0|
|(8,[0,1,4],[3.574...|     0.0|[1.47716648728009...|[0.81414421381632...|       0.0|
|(8,[0,1,4],[3.574...|     0.0|[1.54247647780634...|[0.82382444387736...|       0.0|
|(8,[0,1,4],[3.574...|     0.0|[1.67817144274552...|[0.84266224814588...|       0.0|
|(8,[0,1,4],[3.574...|     0.0|[1.81617764792071...|[0.86010684267618...|       0.0|
|(8,[0,1,4],[3.574...|     0.0|[1.81593999291999...|[0.8600782448

### Inferencing

In [None]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator,MulticlassClassificationEvaluator

In [None]:
lr_model_summary.predictions.describe().show()

+-------+-------------------+-------------------+
|summary|           Survived|         prediction|
+-------+-------------------+-------------------+
|  count|                658|                658|
|   mean|0.39361702127659576|0.36322188449848025|
| stddev|0.48892325937786185|0.48129387005540825|
|    min|                0.0|                0.0|
|    max|                1.0|                1.0|
+-------+-------------------+-------------------+



In [None]:
lr_preds=lr_model.transform(test_df)
lr_preds.show()

+--------------------+--------+--------------------+--------------------+----------+
|            features|Survived|       rawPrediction|         probability|prediction|
+--------------------+--------+--------------------+--------------------+----------+
|(8,[0,1,2,4],[3.5...|       0|[2.09193827963831...|[0.89011714947554...|       0.0|
|(8,[0,1,2,4],[3.5...|       0|[2.07185293333972...|[0.88813718193298...|       0.0|
|(8,[0,1,2,4],[3.5...|       0|[2.04943509521291...|[0.88589052578873...|       0.0|
|(8,[0,1,4],[3.574...|       0|[1.81386666685046...|[0.85982854688398...|       0.0|
|(8,[0,1,4],[3.574...|       0|[1.81386666685046...|[0.85982854688398...|       0.0|
|(8,[0,1,4],[3.574...|       0|[1.81386666685046...|[0.85982854688398...|       0.0|
|(8,[0,1,4],[3.574...|       1|[1.81386666685046...|[0.85982854688398...|       0.0|
|(8,[0,1,4],[3.574...|       0|[1.91565421799764...|[0.87165303831932...|       0.0|
|(8,[0,1,4],[3.574...|       0|[2.18708768772347...|[0.8990839735

In [None]:
eval = BinaryClassificationEvaluator(rawPredictionCol = "prediction", labelCol = "Survived")
auc = eval.evaluate(lr_preds)
print(auc)

0.8261847389558233


In [None]:
# Accuracy, Precision, and Recall
multi_evaluator = MulticlassClassificationEvaluator(labelCol="Survived", predictionCol="prediction")
accuracy = multi_evaluator.evaluate(lr_preds, {multi_evaluator.metricName: "accuracy"})
precision = multi_evaluator.evaluate(lr_preds, {multi_evaluator.metricName: "weightedPrecision"})
recall = multi_evaluator.evaluate(lr_preds, {multi_evaluator.metricName: "weightedRecall"})

print(f"AUC-ROC: {auc:.4f}")
print(f"Accuracy: {accuracy:.4f}")
print(f"Precision: {precision:.4f}")
print(f"Recall: {recall:.4f}")

AUC-ROC: 0.8262
Accuracy: 0.8455
Precision: 0.8442
Recall: 0.8455
