In [1]:
# 1. Mount Google Drive (optional for persistent storage)
from google.colab import drive
drive.mount('/content/drive')


Mounted at /content/drive


In [13]:
!pip install pyspark




In [14]:
from pyspark.sql import SparkSession
spark = SparkSession.builder.appName("ChurnPrediction").getOrCreate()


In [15]:
!wget -O Telco-Customer-Churn.csv https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv


--2025-09-04 06:04:59--  https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv
Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.108.133, 185.199.109.133, 185.199.110.133, ...
Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.108.133|:443... connected.
HTTP request sent, awaiting response... 200 OK
Length: 970457 (948K) [text/plain]
Saving to: ‘Telco-Customer-Churn.csv’


2025-09-04 06:05:00 (15.7 MB/s) - ‘Telco-Customer-Churn.csv’ saved [970457/970457]



In [16]:
df = spark.read.csv("Telco-Customer-Churn.csv", header=True, inferSchema=True)
df.printSchema()
df.show(5)


root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: string (nullable = true)
 |-- PaymentMethod: string (nullable = true)
 |-- MonthlyCharges: double (nullable = true)
 |-- TotalCharges: string (nullable = true)
 |-- Churn: string (nullable = true)

+----------+------+-------------+-------+----------+------+------------+---------

In [17]:
from pyspark.sql.functions import col, when

# Drop customerID (not a useful feature)
df = df.drop("customerID")

# Handle blank strings in TotalCharges
df = df.withColumn("TotalCharges", when(col("TotalCharges") == " ", None).otherwise(col("TotalCharges")))

# Drop rows with nulls
df = df.dropna()

# Cast numeric columns
numeric_cols = ["tenure", "MonthlyCharges", "TotalCharges"]
for c in numeric_cols:
    df = df.withColumn(c, col(c).cast("float"))

# Convert Churn to label (Yes=1, No=0)
df = df.withColumn("label", (col("Churn") == "Yes").cast("integer"))


In [18]:
from pyspark.ml.feature import StringIndexer, OneHotEncoder, VectorAssembler

# Identify categorical columns (exclude Churn and label)
categorical_cols = [f for f, dt in df.dtypes if dt == "string" and f not in ["Churn"]]

# Index + OneHotEncode categorical columns
indexers = [StringIndexer(inputCol=col, outputCol=col+"_idx", handleInvalid="keep") for col in categorical_cols]
encoders = [OneHotEncoder(inputCol=col+"_idx", outputCol=col+"_vec") for col in categorical_cols]

# Assemble features
assembler_inputs = [c+"_vec" for c in categorical_cols] + numeric_cols
assembler = VectorAssembler(inputCols=assembler_inputs, outputCol="features", handleInvalid="skip")


In [19]:
train_df, test_df = df.randomSplit([0.7, 0.3], seed=42)


In [20]:
from pyspark.ml.classification import LogisticRegression
from pyspark.ml import Pipeline

lr = LogisticRegression(featuresCol="features", labelCol="label")
pipeline = Pipeline(stages=indexers + encoders + [assembler, lr])

model = pipeline.fit(train_df)


In [21]:
from pyspark.ml.evaluation import BinaryClassificationEvaluator

predictions = model.transform(test_df)
evaluator = BinaryClassificationEvaluator(labelCol="label", metricName="areaUnderROC")
auc = evaluator.evaluate(predictions)
print("AUC =", auc)

predictions.select("label", "prediction", "probability").show(10)


AUC = 0.8579851408873723
+-----+----------+--------------------+
|label|prediction|         probability|
+-----+----------+--------------------+
|    0|       1.0|[0.41547042607106...|
|    1|       1.0|[0.39826206653239...|
|    0|       0.0|[0.67573264111355...|
|    1|       1.0|[0.42671144425049...|
|    0|       0.0|[0.58610948461966...|
|    1|       0.0|[0.50694100806546...|
|    0|       0.0|[0.57458457230221...|
|    1|       0.0|[0.56777611017638...|
|    1|       0.0|[0.62692253030414...|
|    0|       0.0|[0.63673693152155...|
+-----+----------+--------------------+
only showing top 10 rows



In [22]:
model.save("/content/churn_model_spark")
