In [1]:
from pyspark.sql import SparkSession

In [2]:
# Start (or reuse) the Spark session that every DataFrame in this notebook hangs off.
spark = (
    SparkSession.builder
    .appName('logregconsult')
    .getOrCreate()
)

In [3]:
# Load the customer churn CSV: the first row supplies the column names and
# Spark infers each column's type from the data.
data = spark.read.csv(
    '../data/customer_churn.csv',
    header=True,
    inferSchema=True,
)

In [4]:
# Print the inferred schema (column names, types, nullability) to check the CSV parsed as expected.
data.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)



In [5]:
# Preview the loaded rows as a sanity check.
# NOTE(review): the preview prints customer names and street addresses —
# consider clearing this output before sharing the notebook.
data.show()

+-------------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|              Names| Age|Total_Purchase|Account_Manager|Years|Num_Sites|       Onboard_date|            Location|             Company|Churn|
+-------------------+----+--------------+---------------+-----+---------+-------------------+--------------------+--------------------+-----+
|   Cameron Williams|42.0|       11066.8|              0| 7.22|      8.0|2013-08-30 07:00:40|10265 Elizabeth M...|          Harvey LLC|    1|
|      Kevin Mueller|41.0|      11916.22|              0|  6.5|     11.0|2013-08-13 00:38:46|6157 Frank Garden...|          Wilson PLC|    1|
|        Eric Lozano|38.0|      12884.75|              0| 6.67|     12.0|2016-06-29 06:20:07|1331 Keith Court ...|Miller, Johnson a...|    1|
|      Phillip White|42.0|       8010.76|              0| 6.71|     10.0|2014-04-22 12:43:12|13120 Daniel Moun...|           Smith Inc|    1|
|     

In [6]:
# Summary statistics (count, mean, stddev, ...) per column, to spot missing values and scale differences.
data.describe().show()

+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+--------------------+--------------------+-------------------+
|summary|        Names|              Age|   Total_Purchase|   Account_Manager|            Years|         Num_Sites|            Location|             Company|              Churn|
+-------+-------------+-----------------+-----------------+------------------+-----------------+------------------+--------------------+--------------------+-------------------+
|  count|          900|              900|              900|               900|              900|               900|                 900|                 900|                900|
|   mean|         null|41.81666666666667|10062.82403333334|0.4811111111111111| 5.27315555555555| 8.587777777777777|                null|                null|0.16666666666666666|
| stddev|         null|6.127560416916251|2408.644531858096|0.4999208935073339|1.274449013194616|1.764835592035

In [7]:
# List the column names so the feature columns can be picked for the assembler.
data.columns

['Names',
 'Age',
 'Total_Purchase',
 'Account_Manager',
 'Years',
 'Num_Sites',
 'Onboard_date',
 'Location',
 'Company',
 'Churn']

In [9]:
from pyspark.ml.feature import VectorAssembler

In [10]:
# Pack the numeric predictor columns into a single vector column named 'features'.
# Names, Onboard_date, Location and Company are deliberately left out.
assembler = VectorAssembler(
    outputCol='features',
    inputCols=[
        'Age',
        'Total_Purchase',
        'Account_Manager',
        'Years',
        'Num_Sites',
    ],
)

In [11]:
# Run the assembler over the raw data; the result carries the assembled 'features' vector column.
output = assembler.transform(data)

In [12]:
# Keep only the model inputs: the feature vector and the label.
# NOTE(review): the schema column is 'Churn' but it is selected here as 'churn';
# the displayed outputs show the resulting column named 'churn', which is the
# spelling LogisticRegression(labelCol='churn') uses. This presumably relies on
# Spark's case-insensitive column resolution — keep the spellings in sync.
final_data = output.select('features', 'churn')

In [13]:
# Split roughly 70/30 into training and held-out test sets.
# The fixed seed makes the split repeatable, so the fitted model and the
# evaluation numbers are comparable across Restart & Run All.
train_churn, test_churn = final_data.randomSplit([0.7, 0.3], seed=42)

In [14]:
from pyspark.ml.classification import LogisticRegression

In [15]:
# Logistic regression estimator using 'churn' as the label column;
# every other parameter is left at its default.
lr_churn = LogisticRegression(labelCol='churn')

In [16]:
# Fit on the training split only; the test split stays unseen until evaluation.
fitted_churn_model = lr_churn.fit(train_churn)

In [18]:
# Training summary of the fitted model; its `predictions` DataFrame is inspected next.
training_sum = fitted_churn_model.summary

In [21]:
# Compare the label and prediction distributions on the training data
# (the means show the actual vs. predicted churn rate).
training_sum.predictions.describe().show()

+-------+------------------+------------------+
|summary|             churn|        prediction|
+-------+------------------+------------------+
|  count|               652|               652|
|   mean|0.1656441717791411|0.1334355828220859|
| stddev| 0.372046339068832| 0.340305962213103|
|    min|               0.0|               0.0|
|    max|               1.0|               1.0|
+-------+------------------+------------------+



In [22]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [23]:
# Evaluate the fitted model on the held-out test split.
# NOTE(review): despite the name, this is the object returned by `evaluate`,
# not a DataFrame; the per-row results are read from its `.predictions`
# attribute in the next cell.
pred_and_labels = fitted_churn_model.evaluate(test_churn)

In [24]:
# Show per-row test results: features, true label, rawPrediction, probability and prediction.
pred_and_labels.predictions.show()

+--------------------+-----+--------------------+--------------------+----------+
|            features|churn|       rawPrediction|         probability|prediction|
+--------------------+-----+--------------------+--------------------+----------+
|[22.0,11254.38,1....|    0|[4.74723607980660...|[0.99139897849878...|       0.0|
|[27.0,8628.8,1.0,...|    0|[5.56039046930846...|[0.99616747092700...|       0.0|
|[28.0,9090.43,1.0...|    0|[1.35967929982531...|[0.79570757059096...|       0.0|
|[28.0,11128.95,1....|    0|[4.28797177882733...|[0.98645328354712...|       0.0|
|[29.0,11274.46,1....|    0|[4.62079003081958...|[0.99025096428144...|       0.0|
|[29.0,13255.05,1....|    0|[4.34165019001336...|[0.98715218158802...|       0.0|
|[30.0,6744.87,0.0...|    0|[3.70440428050473...|[0.97597645991122...|       0.0|
|[30.0,8874.83,0.0...|    0|[3.44701291109573...|[0.96914193430094...|       0.0|
|[30.0,10183.98,1....|    0|[2.86656498934921...|[0.94616865868120...|       0.0|
|[31.0,8688.21,0