In [1]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.feature import StringIndexer, VectorIndexer
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.linalg import Vectors
from pyspark.ml.feature import VectorAssembler
from pyspark.sql import SparkSession

In [2]:
# Create (or reuse) the SparkSession that every later cell reads from.
# The former .config("spark.some.config.option", "some-value") entry was the
# placeholder from the Spark documentation example and configured nothing real,
# so it has been dropped. Parentheses replace the fragile backslash continuations.
spark = (
    SparkSession.builder
    .appName("Employee Attrition using GradientBoostedTree")
    .getOrCreate()
)

In [3]:
# Read the attrition CSV: the first line is the header and column types are
# inferred from the data rather than declared.
csv_reader = spark.read.format('csv').options(header=True, inferSchema=True)
df = csv_reader.load('attrition-db.csv')

# Preview the loaded rows (show() prints the first 20 by default).
df.show()

+---+---------+-----------------+---------+--------------------+----------------+---------+--------------+-------------+--------------+-----------------------+------+----------+--------------+--------+--------------------+---------------+-------------+-------------+-----------+------------------+------+--------+-----------------+-----------------+------------------------+-------------+----------------+-----------------+---------------------+---------------+--------------+------------------+-----------------------+--------------------+
|Age|Attrition|   BusinessTravel|DailyRate|          Department|DistanceFromHome|Education|EducationField|EmployeeCount|EmployeeNumber|EnvironmentSatisfaction|Gender|HourlyRate|JobInvolvement|JobLevel|             JobRole|JobSatisfaction|MaritalStatus|MonthlyIncome|MonthlyRate|NumCompaniesWorked|Over18|OverTime|PercentSalaryHike|PerformanceRating|RelationshipSatisfaction|StandardHours|StockOptionLevel|TotalWorkingYears|TrainingTimesLastYear|WorkLifeBalanc

In [4]:
# Print the column names and the types Spark inferred while reading the CSV.
df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Attrition: string (nullable = true)
 |-- BusinessTravel: string (nullable = true)
 |-- DailyRate: integer (nullable = true)
 |-- Department: string (nullable = true)
 |-- DistanceFromHome: integer (nullable = true)
 |-- Education: integer (nullable = true)
 |-- EducationField: string (nullable = true)
 |-- EmployeeCount: integer (nullable = true)
 |-- EmployeeNumber: integer (nullable = true)
 |-- EnvironmentSatisfaction: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- HourlyRate: integer (nullable = true)
 |-- JobInvolvement: integer (nullable = true)
 |-- JobLevel: integer (nullable = true)
 |-- JobRole: string (nullable = true)
 |-- JobSatisfaction: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- MonthlyIncome: integer (nullable = true)
 |-- MonthlyRate: integer (nullable = true)
 |-- NumCompaniesWorked: integer (nullable = true)
 |-- Over18: string (nullable = true)
 |-- OverTime: string 

### Index the categorical feature BusinessTravel as a numeric column
### Fit on the whole dataset so that every category value is included in the index

In [5]:
# BusinessTravel is an input feature (not the prediction label): map its string
# values to numeric indices stored in a new intBusinessTravel column.
# Note: the name labelIndexer is rebound further down to the Attrition indexer.
businessTravelStage = StringIndexer(
    inputCol="BusinessTravel",
    outputCol="intBusinessTravel",
)
labelIndexer = businessTravelStage.fit(df)
Employee_df = labelIndexer.transform(df)


In [6]:
# Print the schema of Employee_df after the BusinessTravel indexer has been
# applied; the transform's output column is intBusinessTravel.
Employee_df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Attrition: string (nullable = true)
 |-- BusinessTravel: string (nullable = true)
 |-- DailyRate: integer (nullable = true)
 |-- Department: string (nullable = true)
 |-- DistanceFromHome: integer (nullable = true)
 |-- Education: integer (nullable = true)
 |-- EducationField: string (nullable = true)
 |-- EmployeeCount: integer (nullable = true)
 |-- EmployeeNumber: integer (nullable = true)
 |-- EnvironmentSatisfaction: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- HourlyRate: integer (nullable = true)
 |-- JobInvolvement: integer (nullable = true)
 |-- JobLevel: integer (nullable = true)
 |-- JobRole: string (nullable = true)
 |-- JobSatisfaction: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- MonthlyIncome: integer (nullable = true)
 |-- MonthlyRate: integer (nullable = true)
 |-- NumCompaniesWorked: integer (nullable = true)
 |-- Over18: string (nullable = true)
 |-- OverTime: string 

In [7]:
# Index the Department strings into a numeric intDepartment column.
departmentStage = StringIndexer(inputCol="Department", outputCol="intDepartment")
li = departmentStage.fit(Employee_df)
Employee_df = li.transform(Employee_df)


In [8]:
# Print the schema of Employee_df after the Department indexer has been
# applied; the transform's output column is intDepartment.
Employee_df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Attrition: string (nullable = true)
 |-- BusinessTravel: string (nullable = true)
 |-- DailyRate: integer (nullable = true)
 |-- Department: string (nullable = true)
 |-- DistanceFromHome: integer (nullable = true)
 |-- Education: integer (nullable = true)
 |-- EducationField: string (nullable = true)
 |-- EmployeeCount: integer (nullable = true)
 |-- EmployeeNumber: integer (nullable = true)
 |-- EnvironmentSatisfaction: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- HourlyRate: integer (nullable = true)
 |-- JobInvolvement: integer (nullable = true)
 |-- JobLevel: integer (nullable = true)
 |-- JobRole: string (nullable = true)
 |-- JobSatisfaction: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- MonthlyIncome: integer (nullable = true)
 |-- MonthlyRate: integer (nullable = true)
 |-- NumCompaniesWorked: integer (nullable = true)
 |-- Over18: string (nullable = true)
 |-- OverTime: string 

### Converting the remaining string columns to numeric index columns for vectorizing

In [9]:
# Index the EducationField strings into a numeric intEducationField column.
educationFieldStage = StringIndexer(
    inputCol="EducationField",
    outputCol="intEducationField",
)
li1 = educationFieldStage.fit(Employee_df)
Employee_df = li1.transform(Employee_df)


In [10]:
# Index the remaining string-typed columns. Each (input, output) pair is fitted
# on the current Employee_df and applied immediately, so the frame gains one new
# numeric column per iteration, in the order listed.
string_column_pairs = [
    ("Gender", "intGender"),
    ("JobRole", "intJobRole"),
    ("MaritalStatus", "intMaritalStatus"),
    ("Over18", "intOver18"),
    ("OverTime", "intOverTime"),
]

fitted_indexers = []
for source_column, indexed_column in string_column_pairs:
    indexer_model = StringIndexer(
        inputCol=source_column,
        outputCol=indexed_column,
    ).fit(Employee_df)
    Employee_df = indexer_model.transform(Employee_df)
    fitted_indexers.append(indexer_model)

# Keep each fitted model reachable under its established name.
li2, li3, li4, li5, li6 = fitted_indexers

# Confirm the new int* columns are present.
Employee_df.printSchema()

root
 |-- Age: integer (nullable = true)
 |-- Attrition: string (nullable = true)
 |-- BusinessTravel: string (nullable = true)
 |-- DailyRate: integer (nullable = true)
 |-- Department: string (nullable = true)
 |-- DistanceFromHome: integer (nullable = true)
 |-- Education: integer (nullable = true)
 |-- EducationField: string (nullable = true)
 |-- EmployeeCount: integer (nullable = true)
 |-- EmployeeNumber: integer (nullable = true)
 |-- EnvironmentSatisfaction: integer (nullable = true)
 |-- Gender: string (nullable = true)
 |-- HourlyRate: integer (nullable = true)
 |-- JobInvolvement: integer (nullable = true)
 |-- JobLevel: integer (nullable = true)
 |-- JobRole: string (nullable = true)
 |-- JobSatisfaction: integer (nullable = true)
 |-- MaritalStatus: string (nullable = true)
 |-- MonthlyIncome: integer (nullable = true)
 |-- MonthlyRate: integer (nullable = true)
 |-- NumCompaniesWorked: integer (nullable = true)
 |-- Over18: string (nullable = true)
 |-- OverTime: string 

### Vectorizing the column values

In [12]:
# Columns packed, in this exact order, into the single "features" vector.
# String columns are represented by their int* indexed counterparts.
# NOTE(review): EmployeeNumber looks like a row identifier, and EmployeeCount,
# StandardHours and intOver18 show a single value in the sample rows displayed
# below — confirm whether these should really be model features.
feature_columns = [
    "Age",
    "intBusinessTravel",
    "DailyRate",
    "intDepartment",
    "DistanceFromHome",
    "Education",
    "intEducationField",
    "EmployeeCount",
    "EmployeeNumber",
    "EnvironmentSatisfaction",
    "intGender",
    "HourlyRate",
    "JobInvolvement",
    "JobLevel",
    "intJobRole",
    "JobSatisfaction",
    "intMaritalStatus",
    "MonthlyIncome",
    "MonthlyRate",
    "NumCompaniesWorked",
    "intOver18",
    "intOverTime",
    "PercentSalaryHike",
    "PerformanceRating",
    "RelationshipSatisfaction",
    "StandardHours",
    "StockOptionLevel",
    "TotalWorkingYears",
    "TrainingTimesLastYear",
    "WorkLifeBalance",
    "YearsAtCompany",
    "YearsInCurrentRole",
    "YearsSinceLastPromotion",
    "YearsWithCurrManager",
]

assembler = VectorAssembler(inputCols=feature_columns, outputCol="features")

In [13]:
# Append the assembled "features" vector column to the indexed frame.
output = assembler.transform(Employee_df)

# Display the feature vectors in full (no column truncation) as a sanity check.
features_only = output.select("features")
features_only.show(truncate=False)

+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
|features                                                                                                                                                   |
+-----------------------------------------------------------------------------------------------------------------------------------------------------------+
|[41.0,0.0,1102.0,1.0,1.0,2.0,0.0,1.0,1.0,2.0,1.0,94.0,3.0,2.0,0.0,4.0,1.0,5993.0,19479.0,8.0,0.0,1.0,11.0,3.0,1.0,80.0,0.0,8.0,0.0,1.0,6.0,4.0,0.0,5.0]    |
|[49.0,1.0,279.0,0.0,8.0,1.0,0.0,1.0,2.0,3.0,0.0,61.0,2.0,2.0,1.0,2.0,0.0,5130.0,24907.0,1.0,0.0,0.0,23.0,4.0,4.0,80.0,1.0,10.0,3.0,3.0,10.0,7.0,1.0,7.0]   |
|[37.0,0.0,1373.0,0.0,2.0,2.0,4.0,1.0,4.0,4.0,0.0,92.0,2.0,1.0,2.0,3.0,1.0,2090.0,2396.0,6.0,0.0,1.0,15.0,3.0,2.0,80.0,0.0,7.0,3.0,3.0,0.0,0.0,0.0,0.0]     |
|[33.0,1.0,1392.0,0.0,3.0,4.0,0.0,1.0,5.0,4.0,1.0,56

### Index labels, adding metadata to the label column
### Fit on whole dataset to include all labels in index

In [14]:
# Attrition is the prediction target: fit an indexer that maps its string
# values to numeric labels in indexedAttrition. Fitting uses the full assembled
# frame so every label value receives an index.
attritionStage = StringIndexer(inputCol="Attrition", outputCol="indexedAttrition")
labelIndexer = attritionStage.fit(output)

### Automatically identify categorical features, and index them
### Set maxCategories so features with > 4 distinct values are treated as continuous

In [15]:
# Fit a VectorIndexer over the assembled vectors; maxCategories=4 is the
# threshold separating categorical from continuous features (see heading above).
featureIndexer = VectorIndexer(
    inputCol="features",
    outputCol="indexedFeatures",
    maxCategories=4,
).fit(output)


### Split the data into training and test sets (30% held out for testing)

In [16]:
# Hold out roughly 30% of the rows for testing.
# A fixed seed keeps the split — and therefore the reported accuracy — repeatable
# across kernel restarts (for the same input data and partitioning); without it
# every run evaluates on a different random test set.
SPLIT_SEED = 42
(trainingData, testData) = output.randomSplit([0.7, 0.3], seed=SPLIT_SEED)

### Train a GBT model

In [17]:

# Gradient-boosted tree classifier configured for 10 boosting iterations; it
# reads the indexed label and indexed feature columns produced by the indexers.
gbt = GBTClassifier(
    labelCol="indexedAttrition",
    featuresCol="indexedFeatures",
    maxIter=10,
)

### Chain indexers and GBT in a Pipeline

In [18]:

# Stage order matters: the label indexer, then the feature indexer, then the classifier.
pipeline_stages = [labelIndexer, featureIndexer, gbt]
pipeline = Pipeline(stages=pipeline_stages)

### Train model

In [19]:

# Fit the pipeline on the training split only; the result is the fitted
# pipeline model used for prediction below.
model = pipeline.fit(trainingData)

### Make predictions

In [20]:

# Run the fitted pipeline over the held-out test split.
predictions = model.transform(testData)

### Select example rows to display

In [21]:

# Show the first five test rows: predicted class, true indexed label, raw features.
preview_columns = ["prediction", "indexedAttrition", "features"]
predictions.select(*preview_columns).show(5)


+----------+----------------+--------------------+
|prediction|indexedAttrition|            features|
+----------+----------------+--------------------+
|       0.0|             0.0|[18.0,2.0,287.0,0...|
|       1.0|             0.0|[18.0,2.0,1124.0,...|
|       0.0|             0.0|[19.0,0.0,265.0,0...|
|       0.0|             0.0|[19.0,0.0,645.0,0...|
|       1.0|             1.0|[19.0,1.0,602.0,1...|
+----------+----------------+--------------------+
only showing top 5 rows



### Select (prediction, true label) and compute test error

In [24]:

# Score the test-set predictions by plain accuracy against the indexed label.
evaluator = MulticlassClassificationEvaluator(
    labelCol="indexedAttrition",
    predictionCol="prediction",
    metricName="accuracy",
)
accuracy = evaluator.evaluate(predictions)

# Test error is the complement of accuracy.
test_error = 1.0 - accuracy
print("Test Error = %g" % test_error)
print("Accuracy=%g" % accuracy)

Test Error = 0.161435
Accuracy=0.838565


In [25]:
# The trained GBT model is the last of the three fitted pipeline stages
# (label indexer, feature indexer, classifier).
gbtModel = model.stages[-1]
print(gbtModel)  # summary only

GBTClassificationModel (uid=GBTClassifier_4295b380c8f679608dc3) with 10 trees
