In [0]:
from pyspark.sql import SparkSession
from pyspark.ml.classification import RandomForestClassifier, GBTClassifier
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.sql.types import IntegerType

In [0]:
#1st Step: make session and load the data
spark = SparkSession.builder.appName('churn').getOrCreate()

sparkDF = spark.read.csv('/FileStore/tables/cuschurn.csv' ,header=True, inferSchema=True)
sparkDF.na.drop().printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [0]:
sparkDF.groupBy('OnlineSecurity').count().show()

+-------------------+-----+
|     OnlineSecurity|count|
+-------------------+-----+
|                 No| 3498|
|                Yes| 2019|
|No internet service| 1526|
+-------------------+-----+



In [0]:
#2nd Step: indexing categories   
indexer = StringIndexer(inputCols=['gender','Partner','Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity' , 'OnlineBackup' , 'DeviceProtection', 'TechSupport' , 'StreamingTV', 'StreamingMovies', 'Contract' , 'PaperlessBilling', 'PaymentMethod' , 'Churn'] ,outputCols=['idxgender' , 'idxPartner', 'idxDependents' , 'idxPhoneService' , 'idxMultipleLines' , 'idxInternetService', 'idxOnlineSecurity' ,'idxOnlineBackup', 'idxDeviceProtection', 'idxTechSupport', 'idxStreamingTV', 'idxStreamingMovies' ,'idxContract', 'idxPaperlessBilling', 'idxPaymentMethod', 'idxChurn'])
indexed_data = indexer.fit(sparkDF).transform(sparkDF)

In [0]:
#convert (string) numbers to int numbers
indexed_data = indexed_data.withColumn('tenure', indexed_data['tenure'].cast(IntegerType()))
indexed_data = indexed_data.withColumn('MonthlyCharges', indexed_data['MonthlyCharges'].cast(IntegerType()))
indexed_data = indexed_data.withColumn('TotalCharges', indexed_data['TotalCharges'].cast(IntegerType()))

In [0]:
indexed_data.groupBy('idxPartner').count().show()

+----------+-----+
|idxPartner|count|
+----------+-----+
|       0.0| 3641|
|       1.0| 3402|
+----------+-----+



In [0]:
#just visualizing the data 
pre_data = indexed_data.select(['tenure', 'MonthlyCharges', 'TotalCharges', 'idxgender',
 'idxPartner',
 'idxDependents',
 'idxPhoneService',
 'idxMultipleLines',
 'idxInternetService',
 'idxOnlineSecurity',
 'idxOnlineBackup',
 'idxDeviceProtection',
 'idxTechSupport',
 'idxStreamingTV',
 'idxStreamingMovies',
 'idxContract',
 'idxPaperlessBilling',
 'idxPaymentMethod',
 'idxChurn'])

pre_data.display()

tenure,MonthlyCharges,TotalCharges,idxgender,idxPartner,idxDependents,idxPhoneService,idxMultipleLines,idxInternetService,idxOnlineSecurity,idxOnlineBackup,idxDeviceProtection,idxTechSupport,idxStreamingTV,idxStreamingMovies,idxContract,idxPaperlessBilling,idxPaymentMethod,idxChurn
1,29,29.0,1.0,1.0,0.0,1.0,2.0,1.0,0.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0
34,56,1889.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,1.0,0.0,0.0,0.0,2.0,1.0,1.0,0.0
2,53,108.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0
45,42,1840.0,0.0,0.0,0.0,1.0,2.0,1.0,1.0,0.0,1.0,1.0,0.0,0.0,2.0,1.0,2.0,0.0
2,70,151.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0
8,99,820.0,1.0,0.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,0.0,1.0,1.0,0.0,0.0,0.0,1.0
22,89,1949.0,0.0,0.0,1.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,3.0,0.0
10,29,301.0,1.0,0.0,0.0,1.0,2.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0
28,104,3046.0,1.0,1.0,0.0,0.0,1.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0,0.0,0.0,0.0,1.0
62,56,3487.0,0.0,0.0,1.0,0.0,0.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,2.0,1.0,2.0,0.0


In [0]:
#4th Step: prepare the data for model
assembler = VectorAssembler(outputCol= 'features' ,  handleInvalid="keep")

assembler.setInputCols(['tenure', 'MonthlyCharges', 'idxgender','idxPartner',
 'idxPhoneService',
 'idxMultipleLines',
 'idxInternetService',
 'idxDeviceProtection',
 'idxContract',
 'idxStreamingMovies',
 'idxContract'])

output = assembler.transform(pre_data)
output.show()

+------+--------------+------------+---------+----------+-------------+---------------+----------------+------------------+-----------------+---------------+-------------------+--------------+--------------+------------------+-----------+-------------------+----------------+--------+--------------------+
|tenure|MonthlyCharges|TotalCharges|idxgender|idxPartner|idxDependents|idxPhoneService|idxMultipleLines|idxInternetService|idxOnlineSecurity|idxOnlineBackup|idxDeviceProtection|idxTechSupport|idxStreamingTV|idxStreamingMovies|idxContract|idxPaperlessBilling|idxPaymentMethod|idxChurn|            features|
+------+--------------+------------+---------+----------+-------------+---------------+----------------+------------------+-----------------+---------------+-------------------+--------------+--------------+------------------+-----------+-------------------+----------------+--------+--------------------+
|     1|            29|          29|      1.0|       1.0|          0.0|           

In [0]:
data = output.select('features','idxChurn')
data.show()


+--------------------+--------+
|            features|idxChurn|
+--------------------+--------+
|[1.0,29.0,1.0,1.0...|     0.0|
|(11,[0,1,6,7,8,10...|     0.0|
|(11,[0,1,6],[2.0,...|     1.0|
|[45.0,42.0,0.0,0....|     0.0|
|(11,[0,1,2],[2.0,...|     1.0|
|(11,[0,1,2,5,7,9]...|     1.0|
|(11,[0,1,5],[22.0...|     0.0|
|(11,[0,1,2,4,5,6]...|     0.0|
|[28.0,104.0,1.0,1...|     1.0|
|(11,[0,1,6,8,10],...|     0.0|
|(11,[0,1,3,6],[13...|     0.0|
|[16.0,18.0,0.0,0....|     0.0|
|[58.0,100.0,0.0,1...|     0.0|
|(11,[0,1,5,7,9],[...|     1.0|
|(11,[0,1,7,9],[25...|     0.0|
|[69.0,113.0,1.0,1...|     0.0|
|[52.0,20.0,1.0,0....|     0.0|
|[71.0,106.0,0.0,0...|     0.0|
|(11,[0,1,2,3,6,7]...|     1.0|
|(11,[0,1,2,7,9],[...|     0.0|
+--------------------+--------+
only showing top 20 rows



In [0]:
#5th Step: split data
train, test = data.randomSplit([0.7,0.3])

In [0]:
#6th Step: make model object
rfc = RandomForestClassifier(numTrees=20, maxDepth=7, labelCol='idxChurn')
gbt = GBTClassifier(labelCol='idxChurn',  maxDepth=7)

In [0]:
#7th Step fit the rfc model 
rfc_model = rfc.fit(train)
train_summary = rfc_model.summary
train_summary.predictions.describe().show()

+-------+-------------------+-------------------+
|summary|           idxChurn|         prediction|
+-------+-------------------+-------------------+
|  count|               4982|               4982|
|   mean| 0.2611401043757527|0.17402649538338016|
| stddev|0.43930022382273104| 0.3791703208360239|
|    min|                0.0|                0.0|
|    max|                1.0|                1.0|
+-------+-------------------+-------------------+



In [0]:
#8th Step: make prediction
predictions = rfc_model.transform(test)
predictions.select('idxChurn' , 'prediction').display()

idxChurn,prediction
0.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0


In [0]:
#9th Step: evaluate results
evaluator = BinaryClassificationEvaluator(rawPredictionCol='prediction', labelCol='idxChurn')

auc = evaluator.evaluate(predictions)
auc

Out[40]: 0.687250596681226

In [0]:
#just to compare result of gbt and random Forest
gbt_model = gbt.fit(train)
gbt_prediction = gbt_model.transform(test)
gbt_prediction.select('idxChurn' , 'prediction').display()

idxChurn,prediction
0.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0
1.0,1.0


In [0]:
auc = evaluator.evaluate(gbt_preiction)
auc

Out[42]: 0.7013250804222522