In [None]:
!apt-get install openjdk-11-jdk-headless -qq
!wget -q https://dlcdn.apache.org/spark/spark-3.2.1/spark-3.2.1-bin-hadoop3.2.tgz
!tar -xf spark-3.2.1-bin-hadoop3.2.tgz
!pip install -q findspark

import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-11-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.2.1-bin-hadoop3.2"

import findspark
findspark.init()
from pyspark.sql import SparkSession
spark = SparkSession.builder.master("local[*]").getOrCreate()

In [None]:
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.regression import LinearRegression

# Load the Titanic CSV: the first row supplies column names and Spark samples the
# data to infer each column's type (see the printed schema below).
data = (
    spark.read
    .option("header", True)
    .option("inferSchema", True)
    .csv('titanic.csv')
)
data.printSchema()

root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)



In [None]:
# Keep the label (Survived) and the candidate predictors; PassengerId, Name, Ticket,
# Cabin and Embarked are dropped from here on.
data = data.select('Survived', 'Pclass', 'Gender', 'Age', 'SibSp', 'Parch', 'Fare')

In [None]:
from pyspark.ml.feature import Imputer
# Replace missing ages with the column mean. The result goes into a new AgeImputed
# column; the original Age column is kept alongside it.
# NOTE(review): the mean is fitted on the full dataset before the train/test split,
# so test rows influence the imputed values — consider fitting on training data only.
imputer = (
    Imputer()
    .setInputCols(['Age'])
    .setOutputCols(['AgeImputed'])
    .setStrategy('mean')
)
imputer_model = imputer.fit(data)
data = imputer_model.transform(data)

In [None]:
from pyspark.ml.feature import StringIndexer
# Encode the string Gender column as a numeric index (GenderIndexed) so it can be
# used as a model feature; the original Gender column is kept.
gender_indexer = (
    StringIndexer()
    .setInputCol('Gender')
    .setOutputCol('GenderIndexed')
)
gender_indexer_model = gender_indexer.fit(data)
data = gender_indexer_model.transform(data)

In [None]:
# Confirm the preprocessing added AgeImputed and GenderIndexed as numeric columns.
data.printSchema()

root
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Fare: double (nullable = true)
 |-- AgeImputed: double (nullable = true)
 |-- GenderIndexed: double (nullable = false)



In [None]:
from pyspark.ml.feature import VectorAssembler
# Pack the numeric predictors (using the imputed age and indexed gender) into one
# vector column named 'features', then keep only that vector and the label.
feature_columns = ['Pclass', 'SibSp', 'Parch', 'Fare', 'AgeImputed', 'GenderIndexed']
assembler = VectorAssembler(outputCol='features', inputCols=feature_columns)
output = assembler.transform(data)
finalized_data = output.select(["features", "Survived"])
finalized_data.show()


+--------------------+--------+
|            features|Survived|
+--------------------+--------+
|[3.0,1.0,0.0,7.25...|       0|
|[1.0,1.0,0.0,71.2...|       1|
|[3.0,0.0,0.0,7.92...|       1|
|[1.0,1.0,0.0,53.1...|       1|
|[3.0,0.0,0.0,8.05...|       0|
|[3.0,0.0,0.0,8.45...|       0|
|[1.0,0.0,0.0,51.8...|       0|
|[3.0,3.0,1.0,21.0...|       0|
|[3.0,0.0,2.0,11.1...|       1|
|[2.0,1.0,0.0,30.0...|       1|
|[3.0,1.0,1.0,16.7...|       1|
|[1.0,0.0,0.0,26.5...|       1|
|[3.0,0.0,0.0,8.05...|       0|
|[3.0,1.0,5.0,31.2...|       0|
|[3.0,0.0,0.0,7.85...|       0|
|[2.0,0.0,0.0,16.0...|       1|
|[3.0,4.0,1.0,29.1...|       0|
|[2.0,0.0,0.0,13.0...|       1|
|[3.0,1.0,0.0,18.0...|       0|
|[3.0,0.0,0.0,7.22...|       1|
+--------------------+--------+
only showing top 20 rows



In [None]:
# Train a random forest on 80% of the rows and evaluate it on the held-out 20%.
# Explicit seeds make the split and the forest reproducible across kernel restarts:
# randomSplit otherwise draws a fresh random seed on every call, and the forest's
# default seed is derived from a Python string hash, which changes between sessions.
SEED = 42
train_data, test_data = finalized_data.randomSplit([0.8, 0.2], seed=SEED)

from pyspark.ml.classification import RandomForestClassifier
algo = RandomForestClassifier(featuresCol='features', labelCol='Survived', seed=SEED)
model = algo.fit(train_data)
# `pred` is an evaluation summary; per-row predictions live in pred.predictions.
pred = model.evaluate(test_data)
pred.predictions.show()
print(f"Test accuracy: {pred.accuracy:.3f}")



+--------------------+--------+--------------------+--------------------+----------+
|            features|Survived|       rawPrediction|         probability|prediction|
+--------------------+--------+--------------------+--------------------+----------+
|(6,[0,4],[2.0,29....|       0|[18.1999737252064...|[0.90999868626032...|       0.0|
|(6,[0,4],[2.0,29....|       0|[18.1999737252064...|[0.90999868626032...|       0.0|
|[1.0,0.0,0.0,25.9...|       1|[2.69790881746876...|[0.13489544087343...|       1.0|
|[1.0,0.0,0.0,26.0...|       0|[15.7385889970113...|[0.78692944985056...|       0.0|
|[1.0,0.0,0.0,26.5...|       1|[10.7747876544857...|[0.53873938272428...|       0.0|
|[1.0,0.0,0.0,26.5...|       1|[10.7747876544857...|[0.53873938272428...|       0.0|
|[1.0,0.0,0.0,26.5...|       1|[11.2698237204520...|[0.56349118602260...|       0.0|
|[1.0,0.0,0.0,27.7...|       0|[10.7162843939333...|[0.53581421969666...|       0.0|
|[1.0,0.0,0.0,27.7...|       0|[11.0153192159475...|[0.5507659607