In [54]:
from pyspark.sql import SparkSession
from pyspark.conf import SparkConf
from pyspark import SparkContext
from pyspark.sql import SQLContext
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.evaluation import BinaryClassificationMetrics
from pyspark.mllib.evaluation import MulticlassMetrics
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.mllib.util import MLUtils
from pyspark.ml.feature import VectorAssembler
from sklearn.metrics import confusion_matrix
from pyspark.sql.functions import when
import pandas as pd 

In [55]:
config = SparkConf()
spark = SparkSession.builder.master("local").appName("test").config(conf=config).getOrCreate()

In [56]:
# Preparing the train
train = spark.read.option("header", "true").csv(r"data\train.csv")
# filtre et supprime les donn√©es "NA"
train = train.filter(train.Age != "NA")
train = train.withColumn("Gender", when(train.Sex == "male","1").when(train.Sex == "female","2"))
# Supprime les colonnes inutiles
cols = ('SibSp', 'Parch', 'Fare', 'Ticket' ,'Cabin', 'Embarked', 'Name', 'Sex')
train = train.drop(*cols)

train.show(10)

+-----------+--------+------+---+------+
|PassengerId|Survived|Pclass|Age|Gender|
+-----------+--------+------+---+------+
|          1|       0|     3| 22|     1|
|          2|       1|     1| 38|     2|
|          3|       1|     3| 26|     2|
|          4|       1|     1| 35|     2|
|          5|       0|     3| 35|     1|
|          7|       0|     1| 54|     1|
|          8|       0|     3|  2|     1|
|          9|       1|     3| 27|     2|
|         10|       1|     2| 14|     2|
|         11|       1|     3|  4|     2|
+-----------+--------+------+---+------+
only showing top 10 rows



In [57]:
from pyspark.sql.types import IntegerType
from pyspark.sql.types import FloatType

train = train.withColumn("PassengerId", train["PassengerId"].cast('float'))
train = train.withColumn("Survived" ,train["Survived"].cast('float'))
train = train.withColumn("Pclass" ,train["Pclass"].cast('float'))
train = train.withColumn("Age" ,train["Age"].cast(('float')))
train = train.withColumn("Gender" ,train["Gender"].cast(('float')))
train.printSchema()
features = ['Pclass','Age', 'Gender']
va = VectorAssembler(inputCols = features, outputCol='features')
va_df = va.transform(train)
va_df.show(3)

root
 |-- PassengerId: float (nullable = true)
 |-- Survived: float (nullable = true)
 |-- Pclass: float (nullable = true)
 |-- Age: float (nullable = true)
 |-- Gender: float (nullable = true)

+-----------+--------+------+----+------+--------------+
|PassengerId|Survived|Pclass| Age|Gender|      features|
+-----------+--------+------+----+------+--------------+
|        1.0|     0.0|   3.0|22.0|   1.0|[3.0,22.0,1.0]|
|        2.0|     1.0|   1.0|38.0|   2.0|[1.0,38.0,2.0]|
|        3.0|     1.0|   3.0|26.0|   2.0|[3.0,26.0,2.0]|
+-----------+--------+------+----+------+--------------+
only showing top 3 rows



In [58]:
(train_, test) = va_df.randomSplit([0.8, 0.2])

In [59]:
from pyspark.ml.feature import StringIndexer

dtc = DecisionTreeClassifier(featuresCol="features", labelCol="Survived")

#indexer = StringIndexer().setInputCol("Survived").setOutputCol("label_idx").fit(train_)

prediction = dtc.fit(train_) #dtc.setLabelCol("label_idx").fit(dtc)

prediction

DecisionTreeClassificationModel: uid=DecisionTreeClassifier_1a82c1af751b, depth=5, numNodes=23, numClasses=2, numFeatures=3

In [60]:
pred = prediction.transform(test)
pred = pred.withColumnRenamed('Survived', 'label')
pred.show(3)

+-----------+-----+------+----+------+--------------+-------------+--------------------+----------+
|PassengerId|label|Pclass| Age|Gender|      features|rawPrediction|         probability|prediction|
+-----------+-----+------+----+------+--------------+-------------+--------------------+----------+
|       11.0|  1.0|   3.0| 4.0|   2.0| [3.0,4.0,2.0]|    [4.0,9.0]|[0.30769230769230...|       1.0|
|       12.0|  1.0|   1.0|58.0|   2.0|[1.0,58.0,2.0]|  [8.0,122.0]|[0.06153846153846...|       1.0|
|       17.0|  0.0|   3.0| 2.0|   1.0| [3.0,2.0,1.0]|    [2.0,2.0]|           [0.5,0.5]|       0.0|
+-----------+-----+------+----+------+--------------+-------------+--------------------+----------+
only showing top 3 rows



In [61]:
evaluator=MulticlassClassificationEvaluator(predictionCol="prediction")
acc = evaluator.evaluate(pred)

print("Prediction Accuracy: ", acc*100)

y_pred=pred.select("prediction").collect()
y_orig=pred.select("label").collect()

confusion_M = confusion_matrix(y_orig, y_pred)
print("Confusion Matrix:")
print(confusion_M)

Prediction Accuracy:  85.86322536304587
Confusion Matrix:
[[82  6]
 [13 36]]


In [62]:
from sklearn.metrics import classification_report, confusion_matrix
print(classification_report(y_orig, y_pred))

              precision    recall  f1-score   support

         0.0       0.86      0.93      0.90        88
         1.0       0.86      0.73      0.79        49

    accuracy                           0.86       137
   macro avg       0.86      0.83      0.84       137
weighted avg       0.86      0.86      0.86       137



In [63]:
#from dtreeviz import trees
#trees.dtreeviz(prediction, fancy=True)

In [64]:
#import graphviz
#print(graphviz.__version__)

#graphviz_tree = tree.export_graphviz(iowa_model, out_file=None, feature_names=features, filled=True)
#graphviz.Source(graphviz_tree, format="png") 

In [65]:
from io import StringIO
from IPython.display import Image  
import pydotplus
from sklearn.tree import export_graphviz

feature_cols = ['PassengerId', 'Survived','Pclass', 'Age', 'Gender']
train.show(1)

dot_data = StringIO()
export_graphviz(dtc.fit(train), out_file=dot_data, filled=True, rounded=True, special_characters=True, feature_names = feature_cols, class_names=['0','1'])
graph = pydotplus.graph_from_dot_data(dot_data.getvalue())  
graph.write_png('test-spark.png')
Image(graph.create_png())

+-----------+--------+------+----+------+
|PassengerId|Survived|Pclass| Age|Gender|
+-----------+--------+------+----+------+
|        1.0|     0.0|   3.0|22.0|   1.0|
+-----------+--------+------+----+------+
only showing top 1 row



IllegalArgumentException: features does not exist. Available: PassengerId, Survived, Pclass, Age, Gender

In [None]:
spark.stop()