# Decision Trees in Apache Spark
![image](https://a248.e.akamai.net/secure.meetupstatic.com/photos/event/b/0/4/1/600_436125121.jpeg)
This notebook was put together by Joseph Kambourakis for the [Boston Apache Spark User Group](https://www.meetup.com/Boston-Apache-Spark-User-Group/). If you have any questions, please email me at <joseph.kambourakis@ibm.com>.

## Load Libraries and Initialize the Spark Session
Spark 2.0 introduced `SparkSession`, a unified entry point that we'll use here. Previously, Spark used a separate `SQLContext` for DataFrame work. `LabeledPoint` is the data format required by the RDD-based MLlib supervised learning algorithms, including `DecisionTree`.

In [1]:
from pyspark.sql import SparkSession
from pyspark.mllib.tree import DecisionTree, DecisionTreeModel
from pyspark.mllib.regression import LabeledPoint
from pyspark.ml.feature import OneHotEncoder, StringIndexer

# Create (or reuse) the notebook's Spark session
SC = SparkSession.builder.getOrCreate()

## Load Data
The dataset we are working with is the South African Heart Disease Dataset from [The Elements of Statistical Learning](http://statweb.stanford.edu/~tibs/ElemStatLearn/). The data is in a CSV file hosted on GitHub, which we'll download to the local directory. Data Science Experience provides a local filestore with 5GB of space. The `!` prefix lets you run shell commands from within a notebook cell.

In [2]:
!wget https://raw.githubusercontent.com/JosephKambourakisIBM/Meetup/master/SAheart.csv
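# header=True takes column names from the first line; without inferSchema every column is read as a string
rawdata = SC.read.csv('SAheart.csv', header=True)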

--2016-10-31 14:00:04--  https://raw.githubusercontent.com/JosephKambourakisIBM/Meetup/master/SAheart.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 151.101.48.133
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|151.101.48.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 21499 (21K) [text/plain]
Saving to: 'SAheart.csv.4'


2016-10-31 14:00:04 (37.4 MB/s) - 'SAheart.csv.4' saved [21499/21499]



## Examine the Data
![image of Pandas](http://pandas.pydata.org/_static/pandas_logo.png) <br>
We'll use the [toPandas](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.toPandas) method to display the DataFrame as a table. It collects the whole DataFrame to the driver, which is fine for a dataset this small. Spark's default output is much messier; to see it, try `rawdata.collect()`.


In [3]:
rawdata.toPandas()

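|   | sbp | tobacco | ldl | adiposity | famhist | typea | obesity | alcohol | age | chd |
|---|-----|---------|-----|-----------|---------|-------|---------|---------|-----|-----|
| 0 | 160 | 12 | 5.73 | 23.11 | Present | 49 | 25.3 | 97.2 | 52 | 1 |
| 1 | 144 | 0.01 | 4.41 | 28.61 | Absent | 55 | 28.87 | 2.06 | 63 | 1 |
| 2 | 118 | 0.08 | 3.48 | 32.28 | Present | 52 | 29.14 | 3.81 | 46 | 0 |
| 3 | 170 | 7.5 | 6.41 | 38.03 | Present | 51 | 31.99 | 24.26 | 58 | 1 |
| 4 | 134 | 13.6 | 3.5 | 27.78 | Present | 60 | 25.99 | 57.34 | 49 | 1 |
| 5 | 132 | 6.2 | 6.47 | 36.21 | Present | 62 | 30.77 | 14.14 | 45 | 0 |
| 6 | 142 | 4.05 | 3.38 | 16.2 | Absent | 59 | 20.81 | 2.62 | 38 | 0 |
| 7 | 114 | 4.08 | 4.59 | 14.6 | Present | 62 | 23.11 | 6.72 | 58 | 1 |
| 8 | 114 | 0 | 3.83 | 19.4 | Present | 49 | 24.86 | 2.49 | 29 | 0 |
| 9 | 132 | 0 | 5.8 | 30.96 | Present | 69 | 30.11 | 0 | 53 | 1 |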


Let's examine the column names and schema of our data.  

In [4]:
print(rawdata.columns)
print('\n')
rawdata.printSchema()

['sbp', 'tobacco', 'ldl', 'adiposity', 'famhist', 'typea', 'obesity', 'alcohol', 'age', 'chd']



root
 |-- sbp: string (nullable = true)
 |-- tobacco: string (nullable = true)
 |-- ldl: string (nullable = true)
 |-- adiposity: string (nullable = true)
 |-- famhist: string (nullable = true)
 |-- typea: string (nullable = true)
 |-- obesity: string (nullable = true)
 |-- alcohol: string (nullable = true)
 |-- age: string (nullable = true)
 |-- chd: string (nullable = true)



## Clean Data
Because we read the CSV without inferring a schema, every column is a string, so we'll need to convert each one to the appropriate numeric type. The [withColumn](http://spark.apache.org/docs/latest/api/python/pyspark.sql.html#pyspark.sql.DataFrame.withColumn) method lets us replace each column with a cast copy of itself. We'll also convert the categorical family history variable, `famhist`, into a numeric index column, `famhistIndex`, using `StringIndexer`.
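`StringIndexer` assigns each distinct string a numeric index. By default the most frequent label gets index 0.0, and in the output below `Absent` is mapped to 0.0 and `Present` to 1.0. A minimal pure-Python sketch of that frequency ordering, assuming ties are broken alphabetically; it illustrates the idea and is not Spark's code:

```python
from collections import Counter


def string_index(values):
    """Map each distinct string to a float index, most frequent first.

    Sketch of StringIndexer's default frequency ordering; breaking ties
    alphabetically is an assumption made for this illustration.
    """
    counts = Counter(values)
    # Sort by descending count, then alphabetically so ties are deterministic
    ordered = sorted(counts, key=lambda label: (-counts[label], label))
    mapping = {label: float(i) for i, label in enumerate(ordered)}
    return [mapping[v] for v in values], mapping


famhist = ["Present", "Absent", "Present", "Absent", "Absent"]  # toy values
indices, mapping = string_index(famhist)
print(mapping)  # {'Absent': 0.0, 'Present': 1.0}
print(indices)  # [1.0, 0.0, 1.0, 0.0, 0.0]
```

With only two levels the resulting index is effectively a 0/1 indicator, which is why the tree can later split on `famhistIndex <= 0.0`.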

In [5]:
cleaneddata = rawdata.withColumn("sbp", rawdata["sbp"].cast('float'))\
.withColumn("tobacco", rawdata["tobacco"].cast('float'))\
.withColumn("ldl", rawdata["ldl"].cast('float'))\
.withColumn("adiposity", rawdata["adiposity"].cast('float'))\
.withColumn("typea", rawdata["typea"].cast('int'))\
.withColumn("obesity", rawdata["obesity"].cast('float'))\
.withColumn("alcohol", rawdata["alcohol"].cast('float'))\
.withColumn("age", rawdata["age"].cast('int'))\
.withColumn("chd", rawdata["chd"].cast('float')) #Our model requires the dependent variable to be a float
cleaneddata.printSchema()

#Clean the categorical variable family history
stringIndexer = StringIndexer(inputCol="famhist", outputCol="famhistIndex")
model = stringIndexer.fit(cleaneddata)
indexed = model.transform(cleaneddata)
indexed.show(5)
indexed.printSchema()

root
 |-- sbp: float (nullable = true)
 |-- tobacco: float (nullable = true)
 |-- ldl: float (nullable = true)
 |-- adiposity: float (nullable = true)
 |-- famhist: string (nullable = true)
 |-- typea: integer (nullable = true)
 |-- obesity: float (nullable = true)
 |-- alcohol: float (nullable = true)
 |-- age: integer (nullable = true)
 |-- chd: float (nullable = true)

+-----+-------+----+---------+-------+-----+-------+-------+---+---+------------+
|  sbp|tobacco| ldl|adiposity|famhist|typea|obesity|alcohol|age|chd|famhistIndex|
+-----+-------+----+---------+-------+-----+-------+-------+---+---+------------+
|160.0|   12.0|5.73|    23.11|Present|   49|   25.3|   97.2| 52|1.0|         1.0|
|144.0|   0.01|4.41|    28.61| Absent|   55|  28.87|   2.06| 63|1.0|         0.0|
|118.0|   0.08|3.48|    32.28|Present|   52|  29.14|   3.81| 46|0.0|         1.0|
|170.0|    7.5|6.41|    38.03|Present|   51|  31.99|  24.26| 58|1.0|         1.0|
|134.0|   13.6| 3.5|    27.78|Present|   60|  25.99

## More data cleaning
The RDD-based `DecisionTree` API works on RDDs rather than DataFrames, so we'll convert our DataFrame with `.rdd`. We also select `famhistIndex` in place of the string column `famhist`, and put the label `chd` first so it is easy to split off.

In [6]:
data2 = indexed.select('chd', "sbp", "tobacco", 'ldl', 'adiposity', 'typea', 'obesity', 'alcohol', 'age', 'famhistIndex').rdd

The MLlib `DecisionTree` classifier expects an RDD of `LabeledPoint`s, so we map each row to a label (the first column, `chd`) and a feature vector (the remaining nine columns).
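Because the label is split off first, the model will refer to features by their 0-based position among the remaining nine columns. A plain-Python sketch of the same row-to-point transformation, using a hypothetical `Point` namedtuple as a stand-in for MLlib's `LabeledPoint`:

```python
from collections import namedtuple

# Hypothetical stand-in for pyspark.mllib.regression.LabeledPoint
Point = namedtuple("Point", ["label", "features"])

# Same column order as the select() in In [6]: label first, then features
COLUMNS = ["chd", "sbp", "tobacco", "ldl", "adiposity", "typea",
           "obesity", "alcohol", "age", "famhistIndex"]
FEATURES = COLUMNS[1:]


def to_point(row):
    """Split a row tuple into a label and a list of float features."""
    return Point(float(row[0]), [float(v) for v in row[1:]])


# First row of the data as shown in the famhistIndex output above
row = (1.0, 160.0, 12.0, 5.73, 23.11, 49, 25.3, 97.2, 52, 1.0)
p = to_point(row)
print(p.label)                     # 1.0
print(FEATURES[7], p.features[7])  # age 52.0
```

So "feature 7" in the model output means `age`, not the eighth column of the original DataFrame.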

In [7]:
# Label is the first column; the other nine columns form a flat feature vector
data3 = data2.map(lambda x: LabeledPoint(x[0], list(x[1:])))
data3.first()

LabeledPoint(1.0, [160.0,12.0,5.73000001907,23.1100006104,49.0,25.2999992371,97.1999969482,52.0,1.0])

## Build Model
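The cell below trains with `impurity='gini'`. Gini impurity for a node whose class proportions are p_k is 1 - Σ p_k², and a split is chosen to lower the size-weighted impurity of its two children. A minimal pure-Python sketch of that criterion on toy data (an illustration of the formula, not Spark's implementation):

```python
from collections import Counter


def gini(labels):
    """Gini impurity 1 - sum(p_k^2) of a list of class labels."""
    n = len(labels)
    if n == 0:
        return 0.0
    return 1.0 - sum((c / n) ** 2 for c in Counter(labels).values())


def split_impurity(values, labels, threshold):
    """Size-weighted Gini of the two children of the test value <= threshold."""
    left = [y for x, y in zip(values, labels) if x <= threshold]
    right = [y for x, y in zip(values, labels) if x > threshold]
    n = len(labels)
    return len(left) / n * gini(left) + len(right) / n * gini(right)


# Toy data (not from SAheart): ages and chd labels
ages = [25, 30, 45, 52, 60, 63]
chd = [0, 0, 0, 1, 1, 1]
print(gini(chd))                      # 0.5
print(split_impurity(ages, chd, 49))  # 0.0, a perfect split
print(split_impurity(ages, chd, 30))  # 0.25
```

The tree-growing step evaluates candidate thresholds like these for every feature and keeps the one with the lowest weighted impurity.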

In [8]:
# Train a DecisionTree model.
#  Empty categoricalFeaturesInfo indicates all features are continuous.
model = DecisionTree.trainClassifier(data3, numClasses=2, categoricalFeaturesInfo={}, impurity='gini', maxDepth=3, maxBins=5)
print(model.toDebugString())

DecisionTreeModel classifier of depth 3 with 15 nodes
  If (feature 7 <= 49.0)
   If (feature 7 <= 28.0)
    If (feature 6 <= 11.829999923706055)
     Predict: 0.0
    Else (feature 6 > 11.829999923706055)
     Predict: 0.0
   Else (feature 7 > 28.0)
    If (feature 2 <= 3.9200000762939453)
     Predict: 0.0
    Else (feature 2 > 3.9200000762939453)
     Predict: 0.0
  Else (feature 7 > 49.0)
   If (feature 8 <= 0.0)
    If (feature 1 <= 6.170000076293945)
     Predict: 0.0
    Else (feature 1 > 6.170000076293945)
     Predict: 1.0
   Else (feature 8 > 0.0)
    If (feature 2 <= 4.889999866485596)
     Predict: 1.0
    Else (feature 2 > 4.889999866485596)
     Predict: 1.0



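Not exactly an easy output to understand. The debug string refers to features by their 0-based position in the `LabeledPoint` vector, so let's replace those positions with column names and the predicted classes with readable labels.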

In [9]:
import re

printedmodel = model.toDebugString()
# Feature order of the LabeledPoint vectors built in In [7] (label chd excluded).
# toDebugString() numbers features from 0: 'feature 0' is sbp, 'feature 8' is famhistIndex.
feature_names = ['sbp', 'tobacco', 'ldl', 'adiposity', 'typea',
                 'obesity', 'alcohol', 'age', 'famhistIndex']
newprint = printedmodel.replace('Predict: 0.0', 'Predict: Healthy')\
.replace('Predict: 1.0', 'Predict: Heart Disease')
# Swap every 'feature N' for the column name at position N
newprint = re.sub(r'feature (\d+)', lambda m: feature_names[int(m.group(1))], newprint)
print(newprint)

DecisionTreeModel classifier of depth 3 with 15 nodes
  If (age <= 49.0)
   If (age <= 28.0)
    If (alcohol <= 11.829999923706055)
     Predict: Healthy
    Else (alcohol > 11.829999923706055)
     Predict: Healthy
   Else (age > 28.0)
    If (ldl <= 3.9200000762939453)
     Predict: Healthy
    Else (ldl > 3.9200000762939453)
     Predict: Healthy
  Else (age > 49.0)
   If (famhistIndex <= 0.0)
    If (tobacco <= 6.170000076293945)
     Predict: Healthy
    Else (tobacco > 6.170000076293945)
     Predict: Heart Disease
   Else (famhistIndex > 0.0)
    If (ldl <= 4.889999866485596)
     Predict: Heart Disease
    Else (ldl > 4.889999866485596)
     Predict: Heart Disease
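A fitted tree is just a set of nested threshold tests. As a sketch, here are the depth-3 rules from the In [8] debug string transcribed by hand into plain Python, using the vector order from In [7] (0 = sbp, 1 = tobacco, 2 = ldl, 3 = adiposity, 4 = typea, 5 = obesity, 6 = alcohol, 7 = age, 8 = famhistIndex). The function name `predict_chd` is hypothetical, and subtrees whose leaves all predict the same class are collapsed:

```python
def predict_chd(f):
    """Hand transcription of the depth-3 tree printed in In [8].

    f is a 9-element feature list in the In [7] order. Returns 1.0 for
    heart disease and 0.0 for healthy.
    """
    age, famhist, tobacco = f[7], f[8], f[1]
    if age <= 49.0:
        # All four leaves under age <= 49 predict 0.0 (Healthy)
        return 0.0
    if famhist <= 0.0:  # famhistIndex 0.0 is Absent, per the In [5] output
        return 0.0 if tobacco <= 6.170000076293945 else 1.0
    # famhistIndex > 0.0: both leaves under the ldl test predict 1.0
    return 1.0


first_row = [160.0, 12.0, 5.73, 23.11, 49.0, 25.3, 97.2, 52.0, 1.0]
print(predict_chd(first_row))  # 1.0
```

Several splits (for example on `alcohol` and `ldl`) end in two leaves with the same prediction: they make the child nodes purer without changing the predicted class.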



## Check model accuracy
Checking the model accuracy is a simple counting exercise: we count the mistakes (rows where the prediction differs from the label) and divide by the total number of rows. This gives the error rate on the training data, which is 1 minus the accuracy.
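The same computation in plain Python on toy lists, as a sketch. For reference, the 0.25974025974 printed below equals 120/462, that is, 120 misclassified rows out of the 462 in the SAheart data:

```python
def error_rate(labels, predictions):
    """Fraction of (label, prediction) pairs that disagree."""
    pairs = list(zip(labels, predictions))
    mistakes = sum(1 for label, pred in pairs if label != pred)
    return mistakes / float(len(pairs))


labels = [1.0, 1.0, 0.0, 1.0]       # toy labels
predictions = [1.0, 0.0, 0.0, 1.0]  # toy predictions
err = error_rate(labels, predictions)
print(err)      # 0.25
print(1 - err)  # 0.75, i.e. the accuracy
```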

In [10]:
pred = model.predict(data3.map(lambda x: x.features))
labsandpred = data3.map(lambda lp: lp.label).zip(pred)
# Python 3 lambdas cannot unpack tuples, so index the (label, prediction) pair
error = labsandpred.filter(lambda vp: vp[0] != vp[1]).count() / float(data3.count())
print(error)

0.25974025974


Be careful of overfitting. Here we train a much deeper tree, with 15 bins and a depth of 13 (`maxBins=15`, `maxDepth=13`), and again measure its error on the training data.
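A quick capacity check explains why depth matters. A binary tree of depth d has at most 2^d leaves and 2^(d+1) - 1 nodes in total, which matches the depth-3, 15-node tree above. At depth 13 the bound is 8192 leaves, far more than the 462 training rows, so the tree has enough room to carve out nearly every row:

```python
def max_leaves(depth):
    """Upper bound on leaves in a binary tree of the given depth."""
    return 2 ** depth


def max_nodes(depth):
    """Upper bound on total nodes (internal plus leaves) at the given depth."""
    return 2 ** (depth + 1) - 1


n_rows = 462  # rows in the SAheart data
for depth in (3, 13):
    print(depth, max_leaves(depth), max_nodes(depth), max_leaves(depth) > n_rows)
# 3 8 15 False
# 13 8192 16383 True
```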

In [11]:
model1 = DecisionTree.trainClassifier(data3, numClasses=2, categoricalFeaturesInfo={}, impurity='gini', maxDepth=13, maxBins=15)
pred1 = model1.predict(data3.map(lambda x: x.features))
labsandpred1 = data3.map(lambda lp: lp.label).zip(pred1)
error1 = labsandpred1.filter(lambda vp: vp[0] != vp[1]).count() / float(data3.count())
print(error1)

0.0
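An error of 0.0 is measured on the same rows the deep tree was fit on, so it says little about new patients. A common remedy is to hold out a test set; on the RDD, `data3.randomSplit([0.7, 0.3], seed=42)` returns two RDDs of roughly those proportions. The pure-Python sketch below shows the idea of a seeded hold-out split; the helper name, the 70/30 ratio, and the seed are arbitrary illustrative choices:

```python
import random


def train_test_split(rows, test_fraction=0.3, seed=42):
    """Shuffle a copy of rows with a fixed seed and cut off a test set."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_test = int(round(len(shuffled) * test_fraction))
    return shuffled[n_test:], shuffled[:n_test]


rows = list(range(462))  # stand-ins for the 462 labeled rows
train, test = train_test_split(rows)
print(len(train), len(test))  # 323 139
```

Training on `train` and computing the error on `test`, as in In [10], gives an estimate that is not inflated by memorization.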
