In [25]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.evaluation import BinaryClassificationEvaluator

In [3]:
spark = SparkSession.builder.appName('logreg').getOrCreate()

In [4]:
data = spark.read.csv('customer_churn.csv', header = True, inferSchema=True)

In [5]:
data.printSchema()

root
 |-- Names: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- Total_Purchase: double (nullable = true)
 |-- Account_Manager: integer (nullable = true)
 |-- Years: double (nullable = true)
 |-- Num_Sites: double (nullable = true)
 |-- Onboard_date: timestamp (nullable = true)
 |-- Location: string (nullable = true)
 |-- Company: string (nullable = true)
 |-- Churn: integer (nullable = true)



In [6]:
data.head()

Row(Names='Cameron Williams', Age=42.0, Total_Purchase=11066.8, Account_Manager=0, Years=7.22, Num_Sites=8.0, Onboard_date=datetime.datetime(2013, 8, 30, 7, 0, 40), Location='10265 Elizabeth Mission Barkerburgh, AK 89518', Company='Harvey LLC', Churn=1)

In [7]:
assembler = VectorAssembler(inputCols=['Age', 'Total_Purchase', 'Account_Manager', 'Years', 'Num_Sites'],outputCol='features')

In [8]:
output = assembler.transform(data)

In [9]:
output.head()

Row(Names='Cameron Williams', Age=42.0, Total_Purchase=11066.8, Account_Manager=0, Years=7.22, Num_Sites=8.0, Onboard_date=datetime.datetime(2013, 8, 30, 7, 0, 40), Location='10265 Elizabeth Mission Barkerburgh, AK 89518', Company='Harvey LLC', Churn=1, features=DenseVector([42.0, 11066.8, 0.0, 7.22, 8.0]))

In [10]:
train_test_data = output.select('features','churn')

In [12]:
XY_train, XY_test = train_test_data.randomSplit([0.7,0.3])

In [14]:
regr = LogisticRegression(labelCol='churn')

In [16]:
regr = regr.fit(XY_train)

In [17]:
history = regr.summary

In [22]:
history.predictions

DataFrame[features: vector, churn: double, rawPrediction: vector, probability: vector, prediction: double]

In [24]:
history.predictions.describe().show()

+-------+------------------+-------------------+
|summary|             churn|         prediction|
+-------+------------------+-------------------+
|  count|               616|                616|
|   mean|0.1672077922077922|0.13474025974025974|
| stddev| 0.373464547359324| 0.3417234141449658|
|    min|               0.0|                0.0|
|    max|               1.0|                1.0|
+-------+------------------+-------------------+



In [27]:
pred_and_labels = regr.evaluate(XY_test)

In [29]:
pred_and_labels.accuracy

0.8838028169014085