In [10]:
from pyspark import SparkContext, SparkConf
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.tree import DecisionTree
import numpy as np
import pandas as pd

In [21]:
# Location of the past-hires CSV, relative to the notebook's working directory.
# The same path is read again further down via sc.textFile.
filepath = "../../datasets/MLCourse/PastHires.csv"
# Load with pandas only to eyeball the columns and a few rows.
df = pd.read_csv(filepath)
df.head()

Unnamed: 0,Years Experience,Employed?,Previous employers,Level of Education,Top-tier school,Interned,Hired
0,10,Y,4,BS,N,N,Y
1,0,N,0,BS,Y,Y,Y
2,7,N,6,BS,N,N,N
3,2,Y,1,MS,Y,N,Y
4,20,N,2,PhD,Y,N,N


In [None]:
# Single-threaded local Spark, labelled "SparkDecisionTree" in the Spark UI.
conf = SparkConf().setMaster("local").setAppName("SparkDecisionTree")
# getOrCreate makes this cell safe to re-run: calling SparkContext(conf=conf)
# a second time in the same kernel raises because only one context may be
# active. Note that when a context already exists it is returned as-is and
# `conf` is not re-applied.
sc = SparkContext.getOrCreate(conf=conf)

In [20]:
def binary(YN):
    """Encode a yes/no flag as an integer.

    Returns 1 only when ``YN`` equals the exact string "Y"; every other
    value (including "N", lowercase "y", or an empty string) yields 0.
    """
    return 1 if YN == "Y" else 0

def map_education(degree):
    """Encode a degree label as an integer code.

    "BS" -> 1, "MS" -> 2, "PhD" -> 3; any other value maps to 0.
    Matching is exact and case-sensitive.
    """
    known_degrees = (("BS", 1), ("MS", 2), ("PhD", 3))
    for label, code in known_degrees:
        if degree == label:
            return code
    # Unrecognised or missing degree.
    return 0
    
def create_labeled_point(fields):
    """Convert one split CSV row into a LabeledPoint for MLlib.

    ``fields`` is indexed at positions 0-6. Positions 0 and 2 are parsed
    with ``int``; positions 1, 4 and 5 go through ``binary``; position 3
    goes through ``map_education``. Position 6 is the label, also encoded
    with ``binary``.

    Returns a LabeledPoint whose features are the six encoded values in
    positional order (0 through 5).
    """
    features = np.array([
        int(fields[0]),            # years of experience
        binary(fields[1]),         # currently employed flag
        int(fields[2]),            # number of previous employers
        map_education(fields[3]),  # education level code
        binary(fields[4]),         # top-tier school flag
        binary(fields[5]),         # interned flag
    ])
    label = binary(fields[6])      # hired flag
    return LabeledPoint(label, features)

In [23]:
# Read the CSV as an RDD of text lines, then drop the header row by
# filtering out every line equal to the first one.
all_lines = sc.textFile(filepath)
header = all_lines.first()
raw_data = all_lines.filter(lambda line: line != header)

In [27]:
# Split each text line on commas into a list of string fields
# (plain str.split, so quoted commas are not handled).
csv_data = raw_data.map(lambda x: x.split(","))
# Encode every row as a LabeledPoint (label from the last field used, features from the rest).
training_data = csv_data.map(create_labeled_point)

In [28]:
# One hand-built candidate, using the same feature order that
# create_labeled_point produces:
# [years experience, employed, previous employers, education, top-tier school, interned]
candidate_features = np.array([10, 1, 3, 1, 0, 0])
test_candidate = [candidate_features]
test_data = sc.parallelize(test_candidate)

In [29]:
# Feature index -> number of categories for the categorical features:
# indices 1, 4 and 5 are the 0/1 flags, index 3 is the education code (0-3).
# Indices 0 and 2 are not listed here.
categorical_arity = {1: 2, 3: 4, 4: 2, 5: 2}

model = DecisionTree.trainClassifier(
    training_data,
    numClasses=2,
    categoricalFeaturesInfo=categorical_arity,
    impurity="gini",
    maxDepth=5,
    maxBins=32,
)

In [31]:
# Run the trained tree over the test candidates; the result is collected in the next cell.
prediction = model.predict(test_data)

In [32]:
print("hire prediction:")
# collect() materialises the predictions locally so they can be iterated;
# one predicted label is printed per test candidate.
results = prediction.collect()
for result in results:
    print(result)

hire prediction:
1.0


In [33]:
print("learned classification tree model")
# toDebugString() renders the learned tree's split rules as indented text;
# "feature N" refers to the position in the feature vector built by create_labeled_point.
print(model.toDebugString())

learned classification tree model
DecisionTreeModel classifier of depth 4 with 9 nodes
  If (feature 1 in {0.0})
   If (feature 5 in {0.0})
    If (feature 0 <= 0.5)
     If (feature 3 in {1.0})
      Predict: 0.0
     Else (feature 3 not in {1.0})
      Predict: 1.0
    Else (feature 0 > 0.5)
     Predict: 0.0
   Else (feature 5 not in {0.0})
    Predict: 1.0
  Else (feature 1 not in {0.0})
   Predict: 1.0

