In [22]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import isnan
from pyspark.sql.types import StructType, StructField, IntegerType, StringType, DoubleType
from pyspark.ml.feature import OneHotEncoder, StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator

In [2]:
spark = SparkSession.builder.getOrCreate()

In [3]:
path = 'WA_Fn-UseC_-Telco-Customer-Churn.csv'
schema = [
    ("customerID", StringType()),
    ("gender", StringType()),
    ("SeniorCitizen", StringType()), #
    ("Partner", StringType()), #
    ("Dependents", StringType()), #
    ("tenure", IntegerType()),
    ("PhoneService", StringType()), #
    ("MultipleLines", StringType()),
    ("InternetService", StringType()),
    ("OnlineSecurity", StringType()),
    ("OnlineBackup", StringType()),
    ("DeviceProtection", StringType()),
    ("TechSupport", StringType()),
    ("StreamingTV", StringType()),
    ("StreamingMovies", StringType()),
    ("Contract", StringType()),
    ("PaperlessBilling", StringType()), #
    ("PaymentMethod", StringType()),
    ("MonthlyCharges", DoubleType()),
    ("TotalCharges", DoubleType()),
    ("Churn", StringType()) #
]
schemaST = StructType([StructField(c, t, True) for c, t in schema])
df = spark.read.csv(path, header=True, schema=schemaST)

print(f"Shape of datasets {(df.count(), len(df.columns))}")
# df.head(3)

Shape of datasets (7043, 21)


In [4]:
df.printSchema()

root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: string (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: double (nullable = true)
 |-- Churn: string (nullable = true)



In [5]:
print("Null values:")
for c in df.columns:
    print(f"{c} {df.filter(df[c].isNull() | isnan(c)).count()}")


Null values:
customerID 0
gender 0
SeniorCitizen 0
Partner 0
Dependents 0
tenure 0
PhoneService 0
MultipleLines 0
InternetService 0
OnlineSecurity 0
OnlineBackup 0
DeviceProtection 0
TechSupport 0
StreamingTV 0
StreamingMovies 0
Contract 0
PaperlessBilling 0
PaymentMethod 0
MonthlyCharges 0
TotalCharges 11
Churn 0


In [6]:
df = df.filter(~(df.TotalCharges.isNull() | isnan(df.TotalCharges)))

In [7]:
# for seeing unique value
for c, t in schema:
    if t == StringType() and c != 'customerID':
        df.select(c).distinct().show()

+------+
|gender|
+------+
|Female|
|  Male|
+------+

+-------------+
|SeniorCitizen|
+-------------+
|            0|
|            1|
+-------------+

+-------+
|Partner|
+-------+
|     No|
|    Yes|
+-------+

+----------+
|Dependents|
+----------+
|        No|
|       Yes|
+----------+

+------------+
|PhoneService|
+------------+
|          No|
|         Yes|
+------------+

+----------------+
|   MultipleLines|
+----------------+
|No phone service|
|              No|
|             Yes|
+----------------+

+---------------+
|InternetService|
+---------------+
|    Fiber optic|
|             No|
|            DSL|
+---------------+

+-------------------+
|     OnlineSecurity|
+-------------------+
|                 No|
|                Yes|
|No internet service|
+-------------------+

+-------------------+
|       OnlineBackup|
+-------------------+
|                 No|
|                Yes|
|No internet service|
+-------------------+

+-------------------+
|   DeviceProtection|
+-

In [8]:
print(f"Shape of datasets {(df.count(), len(df.columns))}")

Shape of datasets (7032, 21)


In [9]:
df.groupBy('Churn').count().show()

+-----+-----+
|Churn|count|
+-----+-----+
|   No| 5163|
|  Yes| 1869|
+-----+-----+



In [10]:
numericalCols = []
categoricalCols = []
for c, t in schema:
    if t == StringType():
        categoricalCols.append(c)
        continue
    numericalCols.append(c)

categoricalCols.remove('customerID')
categoricalCols.remove('Churn')

print(numericalCols)
print(categoricalCols)

['tenure', 'MonthlyCharges', 'TotalCharges']
['gender', 'SeniorCitizen', 'Partner', 'Dependents', 'PhoneService', 'MultipleLines', 'InternetService', 'OnlineSecurity', 'OnlineBackup', 'DeviceProtection', 'TechSupport', 'StreamingTV', 'StreamingMovies', 'Contract', 'PaperlessBilling', 'PaymentMethod']


In [11]:
label_indexer = StringIndexer(inputCol="Churn", outputCol="label")
# df = label_indexer.fit(df).transform(df)

In [12]:
numAssembler = VectorAssembler(inputCols=numericalCols, outputCol="numFeat")
# assembled_df = numAssembler.transform(df)

scaler = StandardScaler(inputCol="numFeat", outputCol="scnumFeat")
# scalerModel = scaler.fit(assembled_df)
# scaledData = scalerModel.transform(assembled_df)

# scaledData.select("customerID", "numFeat", "scnumFeat").show(3, truncate=False)

In [13]:
indexers = [StringIndexer(inputCol=col, outputCol=col + "_index") for col in categoricalCols]
encoders = [OneHotEncoder(inputCol=col + "_index", outputCol=col + "_vec") for col in categoricalCols]

# for indexer, encoder in zip(indexers, encoders):
#     scaledData = indexer.fit(scaledData).transform(scaledData)
#     scaledData = encoder.fit(scaledData).transform(scaledData)

# scaledData.select(["customerID"] + [col + "_vec" for col in categoricalCols]).show(3, truncate=False)

In [14]:
final_features = ["scnumFeat"] + [col + "_vec" for col in categoricalCols]
assembler_final = VectorAssembler(inputCols=final_features, outputCol="final_features")
# final_df = assembler_final.transform(scaledData)

# final_df.select("customerID", "final_features", "label").show(3, truncate=False)

In [15]:
lr = LogisticRegression(featuresCol="final_features", labelCol="label")

In [16]:
pipeline = Pipeline(stages= 
    [label_indexer] + indexers + encoders + [
    numAssembler, scaler,
    assembler_final,
    lr
])

In [17]:
train, test = df.randomSplit([0.8, 0.2], seed=42)

In [18]:
model = pipeline.fit(train)

In [19]:
predictions = model.transform(test)

In [20]:
predictions.select("Churn", "prediction", "probability").show()

+-----+----------+--------------------+
|Churn|prediction|         probability|
+-----+----------+--------------------+
|  Yes|       1.0|[0.34707951978117...|
|   No|       0.0|[0.94576585173164...|
|   No|       0.0|[0.56621492187733...|
|   No|       0.0|[0.95308953097470...|
|  Yes|       1.0|[0.27346013042051...|
|   No|       0.0|[0.80495400505907...|
|   No|       0.0|[0.99881424688212...|
|   No|       0.0|[0.98741282954411...|
|   No|       0.0|[0.97969104456531...|
|   No|       0.0|[0.86758251195468...|
|   No|       0.0|[0.67452785740181...|
|   No|       0.0|[0.52186406854224...|
|   No|       0.0|[0.64041143478133...|
|   No|       0.0|[0.54274220977215...|
|   No|       0.0|[0.91376407295167...|
|   No|       1.0|[0.46815109925052...|
|  Yes|       1.0|[0.41505049582362...|
|   No|       0.0|[0.80513579429792...|
|  Yes|       1.0|[0.20421503319520...|
|  Yes|       0.0|[0.94962400269657...|
+-----+----------+--------------------+
only showing top 20 rows



In [23]:
evaluator = BinaryClassificationEvaluator(labelCol="label", rawPredictionCol="prediction", metricName="areaUnderROC")
roc_auc = evaluator.evaluate(predictions)
print(f"Area Under ROC: {roc_auc:.4f}")

Area Under ROC: 0.7210


In [24]:
accuracy_evaluator = MulticlassClassificationEvaluator(labelCol="label", predictionCol="prediction", metricName="accuracy")
accuracy = accuracy_evaluator.evaluate(predictions)
print(f"Accuracy: {accuracy:.4f}")

Accuracy: 0.8040
