In [1]:
from pyspark.sql import SparkSession

In [2]:
spark = SparkSession.builder.appName('logregconsult').getOrCreate()

In [3]:
data = spark.read.csv('customer_churn.csv', inferSchema=True, header=True)

In [4]:
data.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: string (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)



In [5]:
data.describe().show()

+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+-------------------+--------------------+--------------------+-------------------+
|summary|        Names|              Age|   Total_Purchase|   Account_Manager|            Years|         Num_Sites|       Onboard_date|            Location|             Company|              Churn|
+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+-------------------+--------------------+--------------------+-------------------+
|  count|          900|              900|              900|               900|              900|               900|                900|                 900|                 900|                900|
|   mean|         null|41.81666666666667|10062.82403333334|0.4811111111111111| 5.27315555555555| 8.587777777777777|               null|                null|                null|0.16666666666666666|
| stddev| 

In [6]:
data.columns

['Names',
 'Age',
 'Total_Purchase',
 'Account_Manager',
 'Years',
 'Num_Sites',
 'Onboard_date',
 'Location',
 'Company',
 'Churn']

In [7]:
from pyspark.ml.feature import VectorAssembler

In [8]:
assembler = VectorAssembler(inputCols=['Age',
                             'Total_Purchase',
                             'Account_Manager',
                             'Years',
                             'Num_Sites'], outputCol='features')

In [9]:
output = assembler.transform(data)

In [10]:
final_data = output.select('features', 'churn')

In [11]:
train_churn, test_churn = final_data.randomSplit([0.70,0.30])

In [12]:
from pyspark.ml.classification import LogisticRegression

In [13]:
lr_churn = LogisticRegression(labelCol='churn')

In [14]:
fitted_churn_model = lr_churn.fit(train_churn)

In [15]:
training_sum = fitted_churn_model.summary

In [16]:
training_sum.predictions.describe().show()

+-------+-------------------+-------------------+
|summary|              churn|         prediction|
+-------+-------------------+-------------------+
|  count|                631|                631|
|   mean|0.16798732171156894|0.13153724247226625|
| stddev|0.37415162001472196| 0.3382551138172122|
|    min|                0.0|                0.0|
|    max|                1.0|                1.0|
+-------+-------------------+-------------------+



In [17]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [18]:
pred_and_label = fitted_churn_model.evaluate(test_churn)

In [19]:
pred_and_label.predictions.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[22.0,11254.38,1....|    0|[4.60379748139259...|[0.99008554427503...|       0.0|
|[25.0,9672.03,0.0...|    0|[4.73842260480198...|[0.99132349920308...|       0.0|
|[28.0,8670.98,0.0...|    0|[7.83800781135404...|[0.99960570144730...|       0.0|
|[28.0,11128.95,1....|    0|[4.18465865542426...|[0.98500099307448...|       0.0|
|[28.0,11204.23,0....|    0|[1.67273203035972...|[0.84193973189491...|       0.0|
|[28.0,11245.38,0....|    0|[3.84389996335901...|[0.97903883691993...|       0.0|
|[29.0,5900.78,1.0...|    0|[4.27178547554675...|[0.98623527084568...|       0.0|
|[29.0,8688.17,1.0...|    1|[2.77874866474302...|[0.94151658010286...|       0.0|
|[29.0,11274.46,1....|    0|[4.45914962221052...|[0.98856018399085...|       0.0|
|[30.0,8403.78,1

In [20]:
churn_eval = BinaryClassificationEvaluator(rawPredictionCol='prediction',
                                          labelCol='churn')

In [21]:
auc = churn_eval.evaluate(pred_and_label.predictions)
auc

0.7208585858585859