In [1]:
from pyspark.sql import SparkSession
from pyspark.ml import Pipeline
import pyspark.sql.functions as F
from pyspark.sql.types import DoubleType, StringType, StructType, StructField
from pyspark.ml.feature import StringIndexer, VectorAssembler, QuantileDiscretizer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark import SparkContext
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors, VectorUDT
import pandas as pd
import numpy as np
from sklearn import metrics 
import matplotlib.pyplot as plt
import seaborn as sns

In [2]:
# Create (or reuse) the Spark session for this notebook.
spark = (
    SparkSession.builder
    .appName("Titanic data ")
    .getOrCreate()
)


In [3]:
# Load the Titanic CSV: first row is the header, column types are inferred.
titanic_csv_path = "/FileStore/tables/Titanic1.csv"
data = (
    spark.read
    .option("header", "true")
    .option("inferSchema", "true")
    .csv(titanic_csv_path)
)
display(data)

last,first,gender,age,Pclass,fare,embarked,survived
Braund,Mr. Owen Harris,M,22.0,3,7.25,Southampton,no
Cumings,Mrs. John Bradley (Florence Briggs Thayer),F,38.0,1,71.2833,Cherbourg,yes
Heikkinen,Miss Laina,F,26.0,3,7.925,Southampton,yes
Futrelle,Mrs. Jacques Heath (Lily May Peel),F,35.0,1,53.1,Southampton,yes
Allen,Mr. William Henry,M,35.0,3,8.05,Southampton,no
Moran,Mr. James,M,,3,8.4583,Queenstown,no
McCarthy,Mr. Timothy J,M,54.0,1,51.8625,Southampton,no
Palsson,Master Gosta Leonard,M,2.0,3,21.075,Southampton,no
Johnson,Mrs. Oscar W (Elisabeth Vilhelmina Berg),F,27.0,3,11.1333,Southampton,yes
Nasser,Mrs. Nicholas (Adele Achem),F,14.0,2,30.0708,Cherbourg,yes


In [4]:
# Number of rows loaded from the CSV (shown as the cell output).
data.count()

In [5]:
# Register the DataFrame as a temporary SQL view named "data" so it can be
# queried with spark.sql(...), as the mean-age query further down does.
# NOTE: the view is bound to the DataFrame as it is at this point; later
# `data = ...` reassignments are not reflected in it unless this is re-run.
data.createOrReplaceTempView("data")

In [6]:
# Find columns that contain nulls, as a list of (column_name, null_count).
# All null counts are computed in ONE aggregation pass. The previous version
# ran a separate `count()` Spark job for every column.
# F.when(...) yields null for non-null input, and F.count skips nulls, so each
# aggregate counts exactly the null rows of its column.
null_count_row = data.select(
    [F.count(F.when(F.col(c).isNull(), 1)).alias(c) for c in data.columns]
).first()
empty_columns = [
    (col_name, empty_values)
    for col_name, empty_values in zip(data.columns, null_count_row)
    if empty_values > 0
]
print(empty_columns)

In [7]:
# Mean passenger age over the `data` temp view, rounded to an integer.
avg_age_row = spark.sql("SELECT AVG(age) FROM data").first()
round(avg_age_row[0])


In [8]:
# Impute missing ages with the rounded mean age. The mean is computed here
# rather than hard-coded (previously the literal 30), so the fill value stays
# correct if the input file changes.
# NOTE(review): if every age were null the mean would be None and round()
# would raise; the loaded data has ages, so no fallback is added.
mean_age = round(data.agg({"age": "avg"}).first()[0])
data = data.fillna(mean_age, subset=['age'])


In [9]:
# Encode the string columns gender / embarked / survived as numeric indices in
# new `<column>_new` columns (StringIndexer's default ordering is by descending
# frequency, so the most common value maps to 0.0).
# The indexers are handed to the Pipeline *unfitted*, so a single
# `pipeline.fit` learns every mapping. Previously each indexer was fit eagerly
# and the Pipeline merely wrapped the already-fitted models.
labels = [
    StringIndexer(inputCol=column, outputCol=column + "_new")
    for column in ["gender", "embarked", "survived"]
]

pipeline = Pipeline(stages=labels)
# Keep the fitted model: its stages hold the learned label <-> index mappings.
label_model = pipeline.fit(data)
data = label_model.transform(data)

data.show()

In [10]:
# Remove the original string columns (superseded by their *_new encodings),
# plus the name and fare columns, none of which are used as model features.
columns_to_drop = ['gender', 'last', 'first', 'fare', 'embarked', 'survived']
data = data.drop(*columns_to_drop)

In [11]:
# Survival with respect to gender.
# NOTE(review): the exported output is the raw table; the gender/survival chart
# was presumably configured in the notebook UI and cannot be reproduced from
# code alone. The explicit breakdown is the survived/gender groupBy cell below.
display(data)

age,Pclass,gender_new,embarked_new,survived_new
22.0,3,0.0,0.0,0.0
38.0,1,1.0,1.0,1.0
26.0,3,1.0,0.0,1.0
35.0,1,1.0,0.0,1.0
35.0,3,0.0,0.0,0.0
30.0,3,0.0,2.0,0.0
54.0,1,0.0,0.0,0.0
2.0,3,0.0,0.0,0.0
27.0,3,1.0,0.0,1.0
14.0,2,1.0,1.0,1.0


In [12]:
# Passenger counts for each (survived, gender) combination. groupBy returns
# groups in no particular order, so sort them for stable, readable output.
(data.groupBy('survived_new', 'gender_new')
     .count()
     .orderBy('survived_new', 'gender_new')
     .show())

In [13]:
# Survival with respect to Pclass.
# NOTE(review): the exported output is the raw table; any class/survival chart
# lives in the notebook UI. The explicit breakdown is the Pclass/survived/gender
# groupBy cell below.
display(data)

age,Pclass,gender_new,embarked_new,survived_new
22.0,3,0.0,0.0,0.0
38.0,1,1.0,1.0,1.0
26.0,3,1.0,0.0,1.0
35.0,1,1.0,0.0,1.0
35.0,3,0.0,0.0,0.0
30.0,3,0.0,2.0,0.0
54.0,1,0.0,0.0,0.0
2.0,3,0.0,0.0,0.0
27.0,3,1.0,0.0,1.0
14.0,2,1.0,1.0,1.0


In [14]:
# Passenger counts per (class, survived, gender) combination, sorted so the
# row order is stable between runs (groupBy itself guarantees no order).
(data.groupBy('Pclass', 'survived_new', 'gender_new')
     .count()
     .orderBy('Pclass', 'survived_new', 'gender_new')
     .show())

In [15]:
# Survival with respect to age.
# NOTE(review): the exported output is the raw table; any age/survival chart
# lives in the notebook UI and is not reproducible from code alone.
display(data)

age,Pclass,gender_new,embarked_new,survived_new
22.0,3,0.0,0.0,0.0
38.0,1,1.0,1.0,1.0
26.0,3,1.0,0.0,1.0
35.0,1,1.0,0.0,1.0
35.0,3,0.0,0.0,0.0
30.0,3,0.0,2.0,0.0
54.0,1,0.0,0.0,0.0
2.0,3,0.0,0.0,0.0
27.0,3,1.0,0.0,1.0
14.0,2,1.0,1.0,1.0


In [16]:
# Survival with respect to port of embarkation.
# NOTE(review): the exported output is the raw table; any embarked/survival
# chart lives in the notebook UI. The explicit breakdown is the
# embarked/survived groupBy cell below.
display(data)

age,Pclass,gender_new,embarked_new,survived_new
22.0,3,0.0,0.0,0.0
38.0,1,1.0,1.0,1.0
26.0,3,1.0,0.0,1.0
35.0,1,1.0,0.0,1.0
35.0,3,0.0,0.0,0.0
30.0,3,0.0,2.0,0.0
54.0,1,0.0,0.0,0.0
2.0,3,0.0,0.0,0.0
27.0,3,1.0,0.0,1.0
14.0,2,1.0,1.0,1.0


In [17]:
# Passenger count per encoded gender, sorted for a stable output order.
data.groupBy('gender_new').count().orderBy('gender_new').show()

In [18]:
# Passenger counts per (embarked, survived) combination, sorted so the row
# order is stable between runs.
(data.groupBy('embarked_new', 'survived_new')
     .count()
     .orderBy('embarked_new', 'survived_new')
     .show())

In [19]:
# Pack the model inputs into a single vector column named "features".
feature_cols = ["age", "Pclass", "gender_new", "embarked_new"]
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

X_assembler = assembler.transform(data)
X_assembler.show()

In [20]:
# 80/20 train/test split. Without an explicit seed, randomSplit picks a new
# random seed on every run, so the split (and every accuracy figure below)
# would change on each Restart & Run All.
SPLIT_SEED = 42
X_train, X_test = X_assembler.randomSplit([0.8, 0.2], seed=SPLIT_SEED)

In [21]:
# Random forest on the assembled features, predicting survived_new.
# An explicit seed makes the forest's random sampling reproducible across runs.
RF_SEED = 42
rf = RandomForestClassifier(featuresCol='features', labelCol='survived_new', seed=RF_SEED)
rfModel = rf.fit(X_train)
predictions = rfModel.transform(X_test)
predictions.select("prediction", "survived_new", "features").show()

In [22]:
# Number of test-set rows scored by the random forest (the 20% hold-out split).
predictions.count()

In [23]:
# Random-forest accuracy: fraction of test rows whose prediction equals survived_new.
evaluator = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="survived_new",
)
print("Accuracy :", evaluator.evaluate(predictions))

In [24]:
# Show the random-forest test predictions, including the rawPrediction and
# probability vectors.
# NOTE(review): any chart for this cell is configured in the notebook UI, not in code.
display(predictions)

age,Pclass,gender_new,embarked_new,survived_new,features,rawPrediction,probability,prediction
1.0,2,0.0,1.0,1.0,"List(1, 4, List(), List(1.0, 2.0, 0.0, 1.0))","List(1, 2, List(), List(3.586914174380941, 16.41308582561906))","List(1, 2, List(), List(0.17934570871904704, 0.820654291280953))",1.0
2.0,1,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 1.0, 1.0, 0.0))","List(1, 2, List(), List(1.8682934147907233, 18.131706585209276))","List(1, 2, List(), List(0.09341467073953616, 0.9065853292604638))",1.0
2.0,3,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 3.0, 1.0, 0.0))","List(1, 2, List(), List(9.584954616343209, 10.415045383656791))","List(1, 2, List(), List(0.47924773081716043, 0.5207522691828396))",1.0
4.0,1,0.0,0.0,1.0,"List(1, 4, List(), List(4.0, 1.0, 0.0, 0.0))","List(1, 2, List(), List(3.4381749190904434, 16.561825080909557))","List(1, 2, List(), List(0.17190874595452216, 0.8280912540454779))",1.0
4.0,3,0.0,0.0,0.0,"List(1, 4, List(), List(4.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
4.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(4.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
5.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(5.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
6.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(6.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
7.0,2,1.0,0.0,1.0,"List(1, 4, List(), List(7.0, 2.0, 1.0, 0.0))","List(1, 2, List(), List(1.9016267481240567, 18.09837325187594))","List(1, 2, List(), List(0.09508133740620285, 0.9049186625937972))",1.0
9.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(9.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(16.590359410857886, 3.409640589142113))","List(1, 2, List(), List(0.8295179705428943, 0.17048202945710564))",0.0


In [25]:
# NOTE(review): identical to the preceding `display(predictions)` cell;
# presumably kept for a differently configured UI chart. Confirm or remove.
display(predictions)

age,Pclass,gender_new,embarked_new,survived_new,features,rawPrediction,probability,prediction
1.0,2,0.0,1.0,1.0,"List(1, 4, List(), List(1.0, 2.0, 0.0, 1.0))","List(1, 2, List(), List(3.586914174380941, 16.41308582561906))","List(1, 2, List(), List(0.17934570871904704, 0.820654291280953))",1.0
2.0,1,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 1.0, 1.0, 0.0))","List(1, 2, List(), List(1.8682934147907233, 18.131706585209276))","List(1, 2, List(), List(0.09341467073953616, 0.9065853292604638))",1.0
2.0,3,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 3.0, 1.0, 0.0))","List(1, 2, List(), List(9.584954616343209, 10.415045383656791))","List(1, 2, List(), List(0.47924773081716043, 0.5207522691828396))",1.0
4.0,1,0.0,0.0,1.0,"List(1, 4, List(), List(4.0, 1.0, 0.0, 0.0))","List(1, 2, List(), List(3.4381749190904434, 16.561825080909557))","List(1, 2, List(), List(0.17190874595452216, 0.8280912540454779))",1.0
4.0,3,0.0,0.0,0.0,"List(1, 4, List(), List(4.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
4.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(4.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
5.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(5.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
6.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(6.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
7.0,2,1.0,0.0,1.0,"List(1, 4, List(), List(7.0, 2.0, 1.0, 0.0))","List(1, 2, List(), List(1.9016267481240567, 18.09837325187594))","List(1, 2, List(), List(0.09508133740620285, 0.9049186625937972))",1.0
9.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(9.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(16.590359410857886, 3.409640589142113))","List(1, 2, List(), List(0.8295179705428943, 0.17048202945710564))",0.0


In [26]:
# NOTE(review): identical to the earlier `display(predictions)` cells;
# presumably kept for a differently configured UI chart. Confirm or remove.
display(predictions)

age,Pclass,gender_new,embarked_new,survived_new,features,rawPrediction,probability,prediction
1.0,2,0.0,1.0,1.0,"List(1, 4, List(), List(1.0, 2.0, 0.0, 1.0))","List(1, 2, List(), List(3.586914174380941, 16.41308582561906))","List(1, 2, List(), List(0.17934570871904704, 0.820654291280953))",1.0
2.0,1,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 1.0, 1.0, 0.0))","List(1, 2, List(), List(1.8682934147907233, 18.131706585209276))","List(1, 2, List(), List(0.09341467073953616, 0.9065853292604638))",1.0
2.0,3,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 3.0, 1.0, 0.0))","List(1, 2, List(), List(9.584954616343209, 10.415045383656791))","List(1, 2, List(), List(0.47924773081716043, 0.5207522691828396))",1.0
4.0,1,0.0,0.0,1.0,"List(1, 4, List(), List(4.0, 1.0, 0.0, 0.0))","List(1, 2, List(), List(3.4381749190904434, 16.561825080909557))","List(1, 2, List(), List(0.17190874595452216, 0.8280912540454779))",1.0
4.0,3,0.0,0.0,0.0,"List(1, 4, List(), List(4.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
4.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(4.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
5.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(5.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
6.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(6.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
7.0,2,1.0,0.0,1.0,"List(1, 4, List(), List(7.0, 2.0, 1.0, 0.0))","List(1, 2, List(), List(1.9016267481240567, 18.09837325187594))","List(1, 2, List(), List(0.09508133740620285, 0.9049186625937972))",1.0
9.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(9.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(16.590359410857886, 3.409640589142113))","List(1, 2, List(), List(0.8295179705428943, 0.17048202945710564))",0.0


In [27]:
# NOTE(review): identical to the earlier `display(predictions)` cells;
# presumably kept for a differently configured UI chart. Confirm or remove.
display(predictions)

age,Pclass,gender_new,embarked_new,survived_new,features,rawPrediction,probability,prediction
1.0,2,0.0,1.0,1.0,"List(1, 4, List(), List(1.0, 2.0, 0.0, 1.0))","List(1, 2, List(), List(3.586914174380941, 16.41308582561906))","List(1, 2, List(), List(0.17934570871904704, 0.820654291280953))",1.0
2.0,1,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 1.0, 1.0, 0.0))","List(1, 2, List(), List(1.8682934147907233, 18.131706585209276))","List(1, 2, List(), List(0.09341467073953616, 0.9065853292604638))",1.0
2.0,3,1.0,0.0,0.0,"List(1, 4, List(), List(2.0, 3.0, 1.0, 0.0))","List(1, 2, List(), List(9.584954616343209, 10.415045383656791))","List(1, 2, List(), List(0.47924773081716043, 0.5207522691828396))",1.0
4.0,1,0.0,0.0,1.0,"List(1, 4, List(), List(4.0, 1.0, 0.0, 0.0))","List(1, 2, List(), List(3.4381749190904434, 16.561825080909557))","List(1, 2, List(), List(0.17190874595452216, 0.8280912540454779))",1.0
4.0,3,0.0,0.0,0.0,"List(1, 4, List(), List(4.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
4.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(4.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
5.0,3,1.0,1.0,1.0,"List(1, 4, List(), List(5.0, 3.0, 1.0, 1.0))","List(1, 2, List(), List(5.688561440890413, 14.31143855910959))","List(1, 2, List(), List(0.2844280720445206, 0.7155719279554793))",1.0
6.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(6.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(10.158222140745652, 9.841777859254348))","List(1, 2, List(), List(0.5079111070372826, 0.4920888929627174))",0.0
7.0,2,1.0,0.0,1.0,"List(1, 4, List(), List(7.0, 2.0, 1.0, 0.0))","List(1, 2, List(), List(1.9016267481240567, 18.09837325187594))","List(1, 2, List(), List(0.09508133740620285, 0.9049186625937972))",1.0
9.0,3,0.0,0.0,1.0,"List(1, 4, List(), List(9.0, 3.0, 0.0, 0.0))","List(1, 2, List(), List(16.590359410857886, 3.409640589142113))","List(1, 2, List(), List(0.8295179705428943, 0.17048202945710564))",0.0


In [28]:
from pyspark.ml.classification import GBTClassifier

# Gradient-boosted trees (10 boosting iterations) on the same train/test split.
gbt = GBTClassifier(
    maxIter=10,
    featuresCol="features",
    labelCol="survived_new",
)
gbt_model = gbt.fit(X_train)

gbt_predictions = gbt_model.transform(X_test)
gbt_predictions.select(["prediction", "survived_new", "features"]).show()


In [29]:
# GBT accuracy on the test split.
evaluator1 = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="survived_new",
)
print("Accuracy :", evaluator1.evaluate(gbt_predictions))

In [30]:
from pyspark.ml.classification import NaiveBayes

# Naive Bayes with default settings.
# NOTE(review): the default (multinomial) model requires non-negative feature
# values; age, Pclass and the index columns appear to satisfy that.
nb = NaiveBayes(featuresCol="features", labelCol="survived_new")
nb_model = nb.fit(X_train)

nb_predictions = nb_model.transform(X_test)
nb_predictions.select(["prediction", "survived_new", "features"]).show()


In [31]:
# Naive Bayes accuracy on the test split.
evaluator2 = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="survived_new",
)
print("Accuracy :", evaluator2.evaluate(nb_predictions))

In [32]:
from pyspark.ml.classification import LinearSVC

# Linear SVM (binary classifier); survived_new takes only the values 0.0 / 1.0.
svm = LinearSVC(maxIter=10, labelCol="survived_new", featuresCol="features")
# Bug fix: this cell previously called `nb.fit(...)` and
# `nb_model.transform(...)`, so the "SVM" results were really the Naive Bayes
# results again. Fit and apply the SVM itself.
svm_model = svm.fit(X_train)
svm_predictions = svm_model.transform(X_test)
svm_predictions.select("prediction", "survived_new", "features").show()

In [33]:
# Linear SVM accuracy on the test split.
evaluator3 = MulticlassClassificationEvaluator(
    metricName="accuracy",
    predictionCol="prediction",
    labelCol="survived_new",
)
print("Accuracy :", evaluator3.evaluate(svm_predictions))