In [10]:
import pyspark


In [11]:
from pyspark.sql import SparkSession

In [12]:
# Build the Spark session for this notebook; getOrCreate() hands back an
# already-running session if there is one instead of starting a second.
spark = (
    SparkSession.builder
    .appName("Titanic Data")
    .master("local[*]")
    # Driver networking is pinned to the loopback interface.
    .config("spark.driver.host", "localhost")
    .config("spark.driver.bindAddress", "127.0.0.1")
    .config("spark.sql.warehouse.dir", "file:///tmp/spark-warehouse")
    # NOTE(review): JVM flag presumably required by the local JDK version — confirm.
    .config("spark.driver.extraJavaOptions", "-Djava.security.manager=allow")
    .getOrCreate()
)

# Bare expression so the notebook renders the session summary.
spark

25/03/10 10:25:10 WARN SparkSession: Using an existing Spark session; only runtime SQL configurations will take effect.


In [13]:
# Load the Titanic training set, taking column names from the first CSV line.
# No schema is supplied or inferred here; numeric columns are cast later.
df = spark.read.option("header", "true").csv("Data/titanic/train.csv")

# Preview the first five rows.
df.show(n=5)

+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex|Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male| 22|    1|    0|       A/5 21171|   7.25| NULL|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female| 38|    1|    0|        PC 17599|71.2833|  C85|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female| 26|    0|    0|STON/O2. 3101282|  7.925| NULL|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female| 35|    1|    0|          113803|   53.1| C123|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male| 35|    0|    0|          373450|   8.05| NULL|       S|
+-----------+--------+------+--------------------+------+---+-----+-----+----------------+-------+-----+--------+

In [14]:
from pyspark.sql.functions import col

# Keep only the label and the candidate feature columns, in this order.
# The flag marks the columns that are cast to float; Sex and Embarked are
# passed through unchanged.
dataset = df.select(
    *(
        col(name).cast("float") if as_float else col(name)
        for name, as_float in [
            ("Survived", True),
            ("Pclass", True),
            ("Sex", False),
            ("Age", True),
            ("Fare", True),
            ("Embarked", False),
        ]
    )
)

dataset.show(n=4)

+--------+------+------+----+-------+--------+
|Survived|Pclass|   Sex| Age|   Fare|Embarked|
+--------+------+------+----+-------+--------+
|     0.0|   3.0|  male|22.0|   7.25|       S|
|     1.0|   1.0|female|38.0|71.2833|       C|
|     1.0|   3.0|female|26.0|  7.925|       S|
|     1.0|   1.0|female|35.0|   53.1|       S|
+--------+------+------+----+-------+--------+
only showing top 4 rows



In [15]:
from pyspark.sql.functions import isnull, when, count, col

# Count the missing values in every column. The second argument to when() is
# the column *name* as a plain string, which is used as a literal non-null
# marker: count() then tallies exactly the rows where the column is null.
dataset.select(
    [
        count(when(col(name).isNull(), name)).alias(name)
        for name in dataset.columns
    ]
).show()

+--------+------+---+---+----+--------+
|Survived|Pclass|Sex|Age|Fare|Embarked|
+--------+------+---+---+----+--------+
|       0|     0|  0|177|   0|       2|
+--------+------+---+---+----+--------+



In [16]:
# Turn '?' placeholders into nulls, then discard every row that has a null in
# any column.
dataset = dataset.na.replace('?', None).na.drop(how='any')

In [17]:
# Re-run the per-column null count on the cleaned frame. The column name is
# passed to when() as a literal non-null marker, so count() tallies null rows.
dataset.select(
    [
        count(when(col(name).isNull(), name)).alias(name)
        for name in dataset.columns
    ]
).show()

+--------+------+---+---+----+--------+
|Survived|Pclass|Sex|Age|Fare|Embarked|
+--------+------+---+---+----+--------+
|       0|     0|  0|  0|   0|       0|
+--------+------+---+---+----+--------+



In [18]:
from pyspark.ml.feature import StringIndexer

# Encode the string column 'Sex' as a numeric index column 'Gender'.
# The indexer is fitted on, and then applied to, the same frame.
dataset = (
    StringIndexer(inputCol='Sex', outputCol='Gender', handleInvalid='keep')
    .fit(dataset)
    .transform(dataset)
)

In [19]:
# Encode the string column 'Embarked' as a numeric index column 'Boarded'.
# The indexer is fitted on, and then applied to, the same frame.
dataset = (
    StringIndexer(inputCol='Embarked', outputCol='Boarded', handleInvalid='keep')
    .fit(dataset)
    .transform(dataset)
)

In [20]:
# Preview the frame with the two new index columns alongside their sources.
dataset.show(n=4)

+--------+------+------+----+-------+--------+------+-------+
|Survived|Pclass|   Sex| Age|   Fare|Embarked|Gender|Boarded|
+--------+------+------+----+-------+--------+------+-------+
|     0.0|   3.0|  male|22.0|   7.25|       S|   0.0|    0.0|
|     1.0|   1.0|female|38.0|71.2833|       C|   1.0|    1.0|
|     1.0|   3.0|female|26.0|  7.925|       S|   1.0|    0.0|
|     1.0|   1.0|female|35.0|   53.1|       S|   1.0|    0.0|
+--------+------+------+----+-------+--------+------+-------+
only showing top 4 rows



In [21]:
# The original string columns are no longer needed once their indexed
# counterparts (Gender, Boarded) exist.
dataset = dataset.drop('Sex', 'Embarked')

In [22]:
# Confirm that only numeric columns remain.
dataset.show(n=4)

+--------+------+----+-------+------+-------+
|Survived|Pclass| Age|   Fare|Gender|Boarded|
+--------+------+----+-------+------+-------+
|     0.0|   3.0|22.0|   7.25|   0.0|    0.0|
|     1.0|   1.0|38.0|71.2833|   1.0|    1.0|
|     1.0|   3.0|26.0|  7.925|   1.0|    0.0|
|     1.0|   1.0|35.0|   53.1|   1.0|    0.0|
+--------+------+----+-------+------+-------+
only showing top 4 rows



In [23]:
from pyspark.ml.feature import VectorAssembler

# Columns that are packed, in this order, into the single 'features' vector.
required_features = [
    'Pclass',
    'Age',
    'Fare',
    'Gender',
    'Boarded',
]
assembler = VectorAssembler(
    outputCol='features',
    inputCols=required_features,
)
# Adds the 'features' column; the existing columns are kept.
transformed_data = assembler.transform(dataset)

In [24]:
# Preview the assembled feature vectors next to their source columns.
transformed_data.show(n=4)

+--------+------+----+-------+------+-------+--------------------+
|Survived|Pclass| Age|   Fare|Gender|Boarded|            features|
+--------+------+----+-------+------+-------+--------------------+
|     0.0|   3.0|22.0|   7.25|   0.0|    0.0|[3.0,22.0,7.25,0....|
|     1.0|   1.0|38.0|71.2833|   1.0|    1.0|[1.0,38.0,71.2833...|
|     1.0|   3.0|26.0|  7.925|   1.0|    0.0|[3.0,26.0,7.92500...|
|     1.0|   1.0|35.0|   53.1|   1.0|    0.0|[1.0,35.0,53.0999...|
+--------+------+----+-------+------+-------+--------------------+
only showing top 4 rows



In [25]:
# Split the rows roughly 80/20 into training and hold-out sets.
# randomSplit samples rows at random, so the seed is fixed: without it every
# run produces a different split, and the sample counts and the accuracy
# reported further down cannot be reproduced.
(training_data, test_data) = transformed_data.randomSplit([0.8, 0.2], seed=42)
print("Number of train samples: " + str(training_data.count()))
print("Number of test samples: " + str(test_data.count()))

Number of train samples: 583
Number of test samples: 129


In [26]:
from pyspark.ml.classification import RandomForestClassifier

# Random forest on the assembled feature vector, predicting 'Survived'.
# The forest is built with random sampling, so the seed is set explicitly to
# make the fitted model, and therefore the predictions, repeatable between runs.
rf = RandomForestClassifier(labelCol='Survived',
                            featuresCol='features',
                            maxDepth=5,
                            seed=42)
model = rf.fit(training_data)
# Score the held-out rows; transform() appends the model's output columns,
# including the 'prediction' column.
predictions = model.transform(test_data)

In [27]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

# Compare the 'prediction' column against the 'Survived' label on the
# predictions frame and report the accuracy metric.
evaluator = MulticlassClassificationEvaluator(
    metricName='accuracy',
    labelCol='Survived',
    predictionCol='prediction',
)
accuracy = evaluator.evaluate(predictions)
print("Accuracy = %g" % (accuracy,))

Accuracy = 0.829457
