# **Предсказание оттока клиентов (PySpark, MlLib)**

In [61]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, trim, when, count, isnan, isnull
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler, StandardScaler
from pyspark.ml import Pipeline
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier, GBTClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator


In [2]:
spark = SparkSession.builder.appName('ChurnPrediction').getOrCreate()

In [10]:
df = spark.read.csv('WA_Fn-UseC_-Telco-Customer-Churn.csv', header=True, inferSchema=True)

In [11]:
df.show()

+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|   MultipleLines|InternetService|     OnlineSecurity|       OnlineBackup|   DeviceProtection|        TechSupport|        StreamingTV|    StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+----------------+---------------+-------------------+-------------------+-------------------+-------------------+-------------------+-------------------+--------------+----------------+--------------------+--------------+------------+-----+
|7590-VHVEG|Female|            0|    Yes|        No|     1|  

In [45]:
TARGET = "Churn"
ID_COL = "customerID"

In [12]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)



In [22]:
df.describe().show()

+-------+----------+------+------------------+-------+----------+------------------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+------------------+------------------+-----+
|summary|customerID|gender|     SeniorCitizen|Partner|Dependents|            tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|      Contract|PaperlessBilling|       PaymentMethod|    MonthlyCharges|      TotalCharges|Churn|
+-------+----------+------+------------------+-------+----------+------------------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------------+----------------+--------------------+------------------+------------------+-----+
|  count|      7043|  7043|              7043|   7043|      7043|     

In [27]:
print('rows: ', df.count(), "| columns: ", len(df.columns))

rows:  7043 | columns:  21


In [28]:
df.groupBy("Churn").count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|   No| 5174|
|  Yes| 1869|
+-----+-----+



In [30]:
nulls = df.select([count(when(isnan(c) | isnull(c), c)).alias(c) for c in df.columns])
nulls.show()

+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|customerID|gender|SeniorCitizen|Partner|Dependents|tenure|PhoneService|MultipleLines|InternetService|OnlineSecurity|OnlineBackup|DeviceProtection|TechSupport|StreamingTV|StreamingMovies|Contract|PaperlessBilling|PaymentMethod|MonthlyCharges|TotalCharges|Churn|
+----------+------+-------------+-------+----------+------+------------+-------------+---------------+--------------+------------+----------------+-----------+-----------+---------------+--------+----------------+-------------+--------------+------------+-----+
|         0|     0|            0|      0|         0|     0|           0|            0|              0|             0|           0|               0|          0|          0|              0|       0|               0| 

In [31]:
df.createOrReplaceTempView("telco")
spark.sql("""
SELECT Contract, InternetService, Churn, COUNT(*) as n
FROM telco
GROUP BY Contract, InternetService, Churn
ORDER BY n DESC
""").show(20, truncate=False)

+--------------+---------------+-----+----+
|Contract      |InternetService|Churn|n   |
+--------------+---------------+-----+----+
|Month-to-month|Fiber optic    |Yes  |1162|
|Month-to-month|Fiber optic    |No   |966 |
|Month-to-month|DSL            |No   |829 |
|Two year      |No             |No   |633 |
|Two year      |DSL            |No   |616 |
|One year      |DSL            |No   |517 |
|One year      |Fiber optic    |No   |435 |
|Month-to-month|No             |No   |425 |
|Two year      |Fiber optic    |No   |398 |
|Month-to-month|DSL            |Yes  |394 |
|One year      |No             |No   |355 |
|One year      |Fiber optic    |Yes  |104 |
|Month-to-month|No             |Yes  |99  |
|One year      |DSL            |Yes  |53  |
|Two year      |Fiber optic    |Yes  |31  |
|Two year      |DSL            |Yes  |12  |
|One year      |No             |Yes  |9   |
|Two year      |No             |Yes  |5   |
+--------------+---------------+-----+----+



In [39]:
df.dtypes

[('customerID', 'string'),
 ('gender', 'string'),
 ('SeniorCitizen', 'int'),
 ('Partner', 'string'),
 ('Dependents', 'string'),
 ('tenure', 'int'),
 ('PhoneService', 'string'),
 ('MultipleLines', 'string'),
 ('InternetService', 'string'),
 ('OnlineSecurity', 'string'),
 ('OnlineBackup', 'string'),
 ('DeviceProtection', 'string'),
 ('TechSupport', 'string'),
 ('StreamingTV', 'string'),
 ('StreamingMovies', 'string'),
 ('Contract', 'string'),
 ('PaperlessBilling', 'string'),
 ('PaymentMethod', 'string'),
 ('MonthlyCharges', 'double'),
 ('TotalCharges', 'string'),
 ('Churn', 'string')]

In [35]:
drop_cols = ["customerID", "Churn"]

In [40]:
num_cols = [c for c, t in df.dtypes if c not in drop_cols and t in ('int', 'bigint', 'float', 'double')]
cat_cols = [c for c, t in df.dtypes if c not in drop_cols and t == 'string']

In [43]:
print("num_cols:", num_cols)
print("cat_cols:", cat_cols)

num_cols: ['SeniorCitizen', 'tenure', 'MonthlyCharges']
cat_cols: ['gender', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod', 'TotalCharges']


In [46]:
label_indexer = StringIndexer(inputCol=TARGET, outputCol='label', handleInvalid='keep')

In [52]:
cat_indexers =[StringIndexer(inputCol=c, outputCol=f"{c}_idx", handleInvalid='keep') for c in cat_cols]

In [53]:
encoder = OneHotEncoder(
    inputCols=[f"{c}_idx" for c in cat_cols],
    outputCols=[f"{c}_ohe" for c in cat_cols]
)

In [54]:
assembler = VectorAssembler(
    inputCols=num_cols + [f"{c}_ohe" for c in cat_cols],
    outputCol="features",
    handleInvalid="skip"
)

In [55]:
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")

In [56]:
base_stages = [label_indexer] + cat_indexers + [encoder, assembler]

In [57]:
train, test = df.randomSplit([0.8, 0.2])

In [60]:
print("train:", train.count(), " test:", test.count())

train: 5592  test: 1451


In [64]:
lr = LogisticRegression(featuresCol='scaledFeatures', labelCol='label', maxIter=50)
lr_pipeline = Pipeline(stages=base_stages + [scaler, lr])

In [66]:
rf = RandomForestClassifier(featuresCol='features', labelCol='label', maxDepth=8, numTrees=200, seed=42)
rf_pipeline = Pipeline(stages=base_stages + [rf])

In [68]:
gbt = GBTClassifier(featuresCol="features", labelCol="label", maxIter=80, maxDepth=5, seed=42)
gbt_pipeline = Pipeline(stages=base_stages + [gbt])

In [71]:
models = {"LogisticRegression": lr_pipeline.fit(train),
         "RandomForest": rf_pipeline.fit(train),
         "GBT": gbt_pipeline.fit(train)}

In [72]:
bin_eval = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="rawPrediction", metricName="areaUnderROC")
multi_eval_acc = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
multi_eval_f1 = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="f1")
multi_eval_prec = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedPrecision")
multi_eval_rec = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="weightedRecall")


In [89]:
metrics = []
for name, model in models.items():
  pred = model.transform(test)
  # auc = bin_eval.evaluate(pred)
  acc = multi_eval_acc.evaluate(pred)
  f1  = multi_eval_f1.evaluate(pred)
  pr  = multi_eval_prec.evaluate(pred)
  rc  = multi_eval_rec.evaluate(pred)
  metrics.append((name, acc, f1, pr, rc))
  print(f"{name:18s} | ACC={acc:.4f} | F1={f1:.4f} | P={pr:.4f} | R={rc:.4f}")


LogisticRegression | ACC=0.7574 | F1=0.7561 | P=0.7548 | R=0.7574
RandomForest       | ACC=0.7464 | F1=0.6380 | P=0.5571 | R=0.7464
GBT                | ACC=0.7981 | F1=0.7917 | P=0.7887 | R=0.7981


In [88]:
best_name, *_ = sorted(metrics, key=lambda x: x[1], reverse=True)[0]
best_pred = models[best_name].transform(test)
print("\nBest model:", best_name)
best_pred.groupBy("label", "prediction").count().orderBy("label", "prediction").show()


Best model: GBT
+-----+----------+-----+
|label|prediction|count|
+-----+----------+-----+
|  0.0|       0.0|  967|
|  0.0|       1.0|  116|
|  1.0|       0.0|  177|
|  1.0|       1.0|  191|
+-----+----------+-----+



In [90]:
df.write.mode("overwrite").parquet("data/cleaned_telco.parquet")