In [12]:
import os

import pyspark
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.classification import LogisticRegressionWithSGD
from pyspark.mllib.tree import DecisionTree

In [13]:
# Start from a fresh SparkContext. When the kernel was launched through the
# pyspark shell (or this cell is re-run) `sc` already exists and must be
# stopped first; on a plain fresh kernel `sc` is not bound yet, so the
# NameError is ignored instead of aborting "Restart & Run All".
try:
    sc.stop()
except NameError:
    pass
sc = pyspark.SparkContext()

In [14]:
# Location of the Titanic CSV. Set the TITANIC_CSV environment variable to run
# the notebook on another machine; the default is the original local path.
TITANIC_CSV_PATH = os.environ.get(
    "TITANIC_CSV",
    "/Users/Amardeep/Documents/Semester_2/Dic/Lab_5/Titanic.csv",
)
# Each element of raw_rdd is one line of the file as a string.
raw_rdd = sc.textFile(TITANIC_CSV_PATH)

In [15]:
# Number of lines in the file. The header line is still part of raw_rdd here,
# so this is one more than the number of data rows.
raw_rdd.count()

1317

In [16]:
# Peek at the first five raw lines to see the column names and the quoting.
raw_rdd.take(5)

['"","class","age","sex","survived"',
 '"1","1st class","adults","man","yes"',
 '"2","1st class","adults","man","yes"',
 '"3","1st class","adults","man","yes"',
 '"4","1st class","adults","man","yes"']

In [17]:
# The first line of the file carries the column names; capture it so it can be
# filtered out of the data.
header = raw_rdd.first()


def _is_not_header(line):
    """Return True for every line that differs from the captured header."""
    return line != header


data_rdd = raw_rdd.filter(_is_not_header)

In [18]:
# Five rows drawn without replacement, with a fixed seed.
data_rdd.takeSample(withReplacement=False, num=5, seed=0)

['"159","1st class","adults","man","no"',
 '"256","1st class","adults","women","yes"',
 '"1204","3rd class","adults","women","no"',
 '"758","3rd class","adults","man","no"',
 '"730","3rd class","adults","man","no"']

In [25]:
def row_to_labeled_point(line):
    """Convert one CSV data row into a ``LabeledPoint``.

    The row is split on commas and surrounding double quotes are stripped,
    giving five fields: passenger id, class, age, sex and survived.

    Features, in order:
        0 -- class index: leading digit of the class field minus one
             ('1st class' -> 0, '2nd class' -> 1, '3rd class' -> 2)
        1 -- 1 for 'adults', 0 for 'child'
        2 -- 1 for 'women', 0 for 'man'
    The label is 1 when survived is 'yes' and 0 when it is 'no'.

    Args:
        line: one data line of the CSV (not the header line).

    Returns:
        LabeledPoint holding the label and the three encoded features.

    Raises:
        RuntimeError: if class, age, sex or survived holds an unexpected value.
        ValueError: if the row does not split into exactly five fields.
    """
    _passenger_id, ticket_class, age, sex, survived = [
        field.strip('"') for field in line.split(',')
    ]

    # The class is validated together with the other fields so that a bad row
    # always fails with the same RuntimeError (rather than a ValueError or
    # IndexError from the int() conversion below) and the class index can only
    # be 0, 1 or 2. A tuple is used on purpose: '' in '123' would be True.
    if (ticket_class[:1] not in ('1', '2', '3') or
        age not in ['adults', 'child'] or
        sex not in ['man', 'women'] or
        survived not in ['yes', 'no']):
        raise RuntimeError('unknown value')

    class_index = int(ticket_class[0]) - 1

    features = [
        class_index,
        (1 if age == 'adults' else 0),
        (1 if sex == 'women' else 0)
    ]
    return LabeledPoint(1 if survived == 'yes' else 0, features)

In [26]:
# Encode every data row as a LabeledPoint (label = survived, 3 features).
labeled_points_rdd = data_rdd.map(row_to_labeled_point)

In [27]:
# Five encoded rows drawn without replacement, with a fixed seed.
labeled_points_rdd.takeSample(withReplacement=False, num=5, seed=0)

[LabeledPoint(0.0, [0.0,1.0,0.0]),
 LabeledPoint(1.0, [0.0,1.0,1.0]),
 LabeledPoint(0.0, [2.0,1.0,1.0]),
 LabeledPoint(0.0, [2.0,1.0,0.0]),
 LabeledPoint(0.0, [2.0,1.0,0.0])]

In [45]:
# 75% / 25% train/test split, seeded with 0.
split_weights = [0.75, 0.25]
training_rdd, test_rdd = labeled_points_rdd.randomSplit(split_weights, seed=0)

In [46]:
# Sizes of the two splits; test_count is the denominator of the accuracy
# computed further down.
training_count, test_count = training_rdd.count(), test_rdd.count()

In [47]:
# Display the (training, test) sizes of the split.
training_count, test_count

(978, 338)

In [48]:
# All three features are categorical: feature index -> number of categories
# (0: class with 3 values, 1: adult flag, 2: woman flag).
categorical_arity = {0: 3, 1: 2, 2: 2}

# Binary decision-tree classifier (survived / did not survive).
model = DecisionTree.trainClassifier(
    training_rdd,
    numClasses=2,
    categoricalFeaturesInfo=categorical_arity,
)

In [49]:
# Predict on the test features only; the labels are joined back in afterwards.
test_features_rdd = test_rdd.map(lambda point: point.features)
predictions_rdd = model.predict(test_features_rdd)

In [50]:
# Pair each true label with the prediction at the same position:
# (label, prediction).
test_labels_rdd = test_rdd.map(lambda point: point.label)
truth_and_predictions_rdd = test_labels_rdd.zip(predictions_rdd)

In [51]:
# Accuracy = share of test rows whose prediction equals the true label.
correct_count = truth_and_predictions_rdd.filter(lambda pair: pair[0] == pair[1]).count()
accuracy = correct_count / float(test_count)
print('Accuracy =', accuracy)
# Dump the learned tree so the splits can be read off.
print(model.toDebugString())

Accuracy = 0.8047337278106509
DecisionTreeModel classifier of depth 4 with 21 nodes
  If (feature 2 in {0.0})
   If (feature 1 in {0.0})
    If (feature 0 in {0.0,1.0})
     Predict: 1.0
    Else (feature 0 not in {0.0,1.0})
     Predict: 0.0
   Else (feature 1 not in {0.0})
    If (feature 0 in {1.0})
     Predict: 0.0
    Else (feature 0 not in {1.0})
     If (feature 0 in {0.0})
      Predict: 0.0
     Else (feature 0 not in {0.0})
      Predict: 0.0
  Else (feature 2 not in {0.0})
   If (feature 0 in {2.0})
    If (feature 1 in {0.0})
     Predict: 0.0
    Else (feature 1 not in {0.0})
     Predict: 0.0
   Else (feature 0 not in {2.0})
    If (feature 0 in {1.0})
     If (feature 1 in {0.0})
      Predict: 1.0
     Else (feature 1 not in {0.0})
      Predict: 1.0
    Else (feature 0 not in {1.0})
     If (feature 1 in {0.0})
      Predict: 1.0
     Else (feature 1 not in {0.0})
      Predict: 1.0



In [52]:
# Train a logistic-regression classifier on the same training split.
# NOTE: this rebinds `model`, replacing the decision tree trained above, so
# re-running an earlier prediction cell after this one uses this model instead.
# NOTE(review): Spark reports LogisticRegressionWithSGD as deprecated since
# 2.0.0 (see the warning printed below this cell).
model = LogisticRegressionWithSGD.train(training_rdd)

  "Deprecated in 2.0.0. Use ml.classification.LogisticRegression or "


In [53]:
# Predict on the test features only; the labels are joined back in afterwards.
test_features_rdd = test_rdd.map(lambda point: point.features)
predictions_rdd = model.predict(test_features_rdd)

In [54]:
# Pair each true label with the prediction at the same position:
# (label, prediction).
test_labels_rdd = test_rdd.map(lambda point: point.label)
labels_and_predictions_rdd = test_labels_rdd.zip(predictions_rdd)

In [55]:
# Accuracy = share of test rows whose prediction equals the true label.
correct_count = labels_and_predictions_rdd.filter(lambda pair: pair[0] == pair[1]).count()
accuracy = correct_count / float(test_count)
print('Accuracy =', accuracy)

Accuracy = 0.7928994082840237
