In [2]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession


In [4]:
spark=SparkSession \
.builder \
.appName("Employee Attrition using classification decision trees") \
.config("spark.some.config.option","some-value") \
.getOrCreate()

In [5]:
df=spark.read.load('attrition-db.csv',format='csv',header=True,inferSchema=True)
df.show()

+---+---------+-----------------+---------+--------------------+----------------+---------+--------------+-------------+--------------+-----------------------+------+----------+--------------+--------+--------------------+---------------+-------------+-------------+-----------+------------------+------+--------+-----------------+-----------------+------------------------+-------------+----------------+-----------------+---------------------+---------------+--------------+------------------+-----------------------+--------------------+
|Age|Attrition|   BusinessTravel|DailyRate|          Department|DistanceFromHome|Education|EducationField|EmployeeCount|EmployeeNumber|EnvironmentSatisfaction|Gender|HourlyRate|JobInvolvement|JobLevel|             JobRole|JobSatisfaction|MaritalStatus|MonthlyIncome|MonthlyRate|NumCompaniesWorked|Over18|OverTime|PercentSalaryHike|PerformanceRating|RelationshipSatisfaction|StandardHours|StockOptionLevel|TotalWorkingYears|TrainingTimesLastYear|WorkLifeBalanc

In [24]:
df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Attrition: string (nullable = true)
 |-- BusinessTravel: string (nullable = true)
 |-- DailyRate: integer (nullable = true)
 |-- Department: string (nullable = true)
 |-- DistanceFromHome: integer (nullable = true)
 |-- Education: integer (nullable = true)
 |-- EducationField: string (nullable = true)
 |-- EmployeeCount: integer (nullable = true)
 |-- EmployeeNumber: integer (nullable = true)
 |-- EnvironmentSatisfaction: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- HourlyRate: integer (nullable = true)
 |-- JobInvolvement: integer (nullable = true)
 |-- JobLevel: integer (nullable = true)
 |-- JobRole: string (nullable = true)
 |-- JobSatisfaction: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- MonthlyIncome: integer (nullable = true)
 |-- MonthlyRate: integer (nullable = true)
 |-- NumCompaniesWorked: integer (nullable = true)
 |-- Over18: string (nullable = true)
 |-- OverTime: string 

### Vectorize the columns

In [12]:
assembler = VectorAssembler(
    inputCols=["Age","DailyRate","DistanceFromHome","Education","EmployeeCount","EmployeeNumber","EnvironmentSatisfaction","HourlyRate","JobInvolvement","JobLevel","JobSatisfaction","MonthlyIncome","MonthlyRate","NumCompaniesWorked","PercentSalaryHike","PerformanceRating","RelationshipSatisfaction","StandardHours","StockOptionLevel","TotalWorkingYears","TrainingTimesLastYear","WorkLifeBalance","YearsAtCompany","YearsInCurrentRole","YearsSinceLastPromotion","YearsWithCurrManager"],
    outputCol="features")


In [13]:
output = assembler.transform(df)
output.select("features").show(truncate=False)

+---------------------------------------------------------------------------------------------------------------------------+
|features                                                                                                                   |
+---------------------------------------------------------------------------------------------------------------------------+
|[41.0,1102.0,1.0,2.0,1.0,1.0,2.0,94.0,3.0,2.0,4.0,5993.0,19479.0,8.0,11.0,3.0,1.0,80.0,0.0,8.0,0.0,1.0,6.0,4.0,0.0,5.0]    |
|[49.0,279.0,8.0,1.0,1.0,2.0,3.0,61.0,2.0,2.0,2.0,5130.0,24907.0,1.0,23.0,4.0,4.0,80.0,1.0,10.0,3.0,3.0,10.0,7.0,1.0,7.0]   |
|[37.0,1373.0,2.0,2.0,1.0,4.0,4.0,92.0,2.0,1.0,3.0,2090.0,2396.0,6.0,15.0,3.0,2.0,80.0,0.0,7.0,3.0,3.0,0.0,0.0,0.0,0.0]     |
|[33.0,1392.0,3.0,4.0,1.0,5.0,4.0,56.0,3.0,1.0,3.0,2909.0,23159.0,1.0,11.0,3.0,3.0,80.0,0.0,8.0,3.0,3.0,8.0,7.0,3.0,0.0]    |
|[27.0,591.0,2.0,1.0,1.0,7.0,1.0,40.0,3.0,1.0,2.0,3468.0,16632.0,9.0,12.0,3.0,4.0,80.0,1.0,6.0,3.0,3.0,2.0,2.0,2.0,2.0

### Index labels, adding metadata to the label colum
### Fit on whole dataset to include all labels in index

In [14]:
labelIndexer = StringIndexer(inputCol="Attrition", outputCol="indexedAttrition").fit(output)

### Automatically identify categorical features, and index them
### Set maxCategories so features with > 4 distinct values are treated as continuous

In [15]:
featureIndexer =\
    VectorIndexer(inputCol="features", outputCol="indexedFeatures", maxCategories=4).fit(output)


### Split the data into training and test sets (30% held out for testing)

In [16]:
(trainingData, testData) = output.randomSplit([0.7, 0.3])


### Train a DecisionTree model

In [17]:
dt = DecisionTreeClassifier(labelCol="indexedAttrition", featuresCol="indexedFeatures")

### Chain indexers and tree in a Pipeline

In [18]:

pipeline = Pipeline(stages=[labelIndexer, featureIndexer, dt])

### Train model

In [19]:

model = pipeline.fit(trainingData)

### Make predictions

In [20]:

predictions = model.transform(testData)

### Select example rows to display

In [21]:

predictions.select("prediction", "Attrition", "features").show(5)

+----------+---------+--------------------+
|prediction|Attrition|            features|
+----------+---------+--------------------+
|       1.0|       No|[18.0,287.0,5.0,2...|
|       1.0|       No|[18.0,1431.0,14.0...|
|       1.0|       No|[18.0,812.0,10.0,...|
|       1.0|       No|[19.0,1181.0,3.0,...|
|       1.0|      Yes|[19.0,419.0,21.0,...|
+----------+---------+--------------------+
only showing top 5 rows



### Select (prediction, true label) and compute test error

In [25]:

evaluator = MulticlassClassificationEvaluator(
labelCol="indexedAttrition", predictionCol="prediction", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)

In [30]:
print("Test Error = %g " % (1.0 - accuracy))
print("Accuracy = %g " %accuracy)



Test Error = 0.185022 
Accuracy = 0.814978 


In [28]:


treeModel = model.stages[2]
# summary only
print(treeModel)


DecisionTreeClassificationModel (uid=DecisionTreeClassifier_45cd947c5f861d645bde) of depth 5 with 49 nodes
