In [None]:
# To download dataset
import urllib.request
import os

# Remote location of the Telco churn CSV and the local file name it is saved under.
url = "https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv"
filename = "telco_customer_churn.csv"

# urlretrieve writes straight into its target path, so an interrupted transfer
# would leave a truncated CSV under `filename`. Download to a scratch name and
# only move it into place once the transfer has completed.
partial_filename = filename + ".part"

try:
    urllib.request.urlretrieve(url, partial_filename)
    os.replace(partial_filename, filename)
    print(f"Dataset downloaded successfully as {filename}")
except Exception as e:
    # Best-effort cell: report the problem instead of raising, but do not
    # leave a half-written scratch file behind.
    if os.path.exists(partial_filename):
        os.remove(partial_filename)
    print(f"Download failed: {e}")

Dataset downloaded successfully as telco_customer_churn.csv


In [1]:
# Customer Churn Prediction using PySpark ML - FINAL VERSION
# This version handles missing dataset files automatically

# Import required libraries
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
from pyspark.ml.feature import StringIndexer, VectorAssembler, StandardScaler
from pyspark.ml.classification import LogisticRegression, RandomForestClassifier
from pyspark.ml.evaluation import BinaryClassificationEvaluator, MulticlassClassificationEvaluator
from pyspark.ml import Pipeline
from functools import reduce
import pandas as pd

# Build (or fetch the already-running) Spark session used by the rest of this
# cell, with the adaptive query execution setting turned on.
session_builder = SparkSession.builder
session_builder = session_builder.appName("CustomerChurnPrediction")
session_builder = session_builder.config("spark.sql.adaptive.enabled", "true")
spark = session_builder.getOrCreate()

print("Spark session initialized successfully!")

# DATASET LOADING WITH MULTIPLE FALLBACK OPTIONS
# Three sources are tried in order; the first one that works defines `df`:
#   1. the CSV already present in the working directory,
#   2. a fresh download of that CSV,
#   3. a synthetic dataset generated in-process (demonstration only).
print("Attempting to load dataset...")

try:
    # Method 1: Try to load from local file
    df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
    print("✓ Dataset loaded from local file")
except Exception:
    # `except Exception` instead of a bare `except:` so that KeyboardInterrupt
    # and SystemExit still stop the cell instead of triggering the fallbacks.
    print("Local file not found. Trying alternative methods...")

    try:
        # Method 2: Try to download from URL
        import urllib.request
        print("Downloading dataset from URL...")
        url = "https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv"
        urllib.request.urlretrieve(url, "telco_customer_churn.csv")
        df = spark.read.csv("telco_customer_churn.csv", header=True, inferSchema=True)
        print("✓ Dataset downloaded and loaded successfully")
    except Exception as e:
        print("Download failed. Creating realistic sample dataset for demonstration...")
        # Surface the reason so a reader can tell a network failure from a parse failure.
        print(f"  Reason: {type(e).__name__}: {e}")

        # This method is only if the other two fail.
        # Method 3: Create sample dataset with realistic patterns
        from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType
        import builtins
        import random

        # `from pyspark.sql.functions import *` at the top of this cell rebinds
        # `round` to Spark's column function, which expects a Column (or column
        # name), not a Python float. Bind the builtin explicitly for the scalar
        # rounding done while generating rows.
        py_round = builtins.round

        # Define schema
        schema = StructType([
            StructField("customerID", StringType(), True),
            StructField("gender", StringType(), True),
            StructField("SeniorCitizen", IntegerType(), True),
            StructField("Partner", StringType(), True),
            StructField("Dependents", StringType(), True),
            StructField("tenure", IntegerType(), True),
            StructField("PhoneService", StringType(), True),
            StructField("MultipleLines", StringType(), True),
            StructField("InternetService", StringType(), True),
            StructField("OnlineSecurity", StringType(), True),
            StructField("OnlineBackup", StringType(), True),
            StructField("DeviceProtection", StringType(), True),
            StructField("TechSupport", StringType(), True),
            StructField("StreamingTV", StringType(), True),
            StructField("StreamingMovies", StringType(), True),
            StructField("Contract", StringType(), True),
            StructField("PaperlessBilling", StringType(), True),
            StructField("PaymentMethod", StringType(), True),
            StructField("MonthlyCharges", DoubleType(), True),
            # Kept as a string so blank (" ") entries can be represented; the
            # preprocessing step converts this column to double.
            StructField("TotalCharges", StringType(), True),
            StructField("Churn", StringType(), True)
        ])

        # Create realistic sample data
        print("Generating sample data with realistic churn patterns...")
        sample_data = []
        contracts = ["Month-to-month", "One year", "Two year"]
        payments = ["Electronic check", "Mailed check", "Bank transfer (automatic)", "Credit card (automatic)"]
        internet_services = ["DSL", "Fiber optic", "No"]
        yes_no = ["Yes", "No"]
        yes_no_no_service = ["Yes", "No", "No internet service"]

        random.seed(42)  # For reproducible results

        for i in range(7043):  # Match original dataset size
            tenure = random.randint(1, 72)
            monthly_charges = py_round(random.uniform(18.25, 118.75), 2)
            contract = random.choice(contracts)

            # Create realistic churn patterns based on contract type
            if contract == "Month-to-month":
                churn_prob = 0.427
            elif contract == "One year":
                churn_prob = 0.113
            else:  # Two year
                churn_prob = 0.028

            churn = "Yes" if random.random() < churn_prob else "No"

            # Generate total charges (some missing values like real dataset)
            if random.random() < 0.0015:  # 0.15% missing like real data
                total_charges = " "
            else:
                total_charges = str(py_round(tenure * monthly_charges + random.uniform(-100, 500), 2))

            row = (
                f"7590-VHVEG-{i:04d}",
                random.choice(["Male", "Female"]),
                # 16% senior citizens. (Previously the 16% branch then picked
                # between 0 and 1, which yielded only about 8% seniors.)
                1 if random.random() < 0.16 else 0,
                random.choice(yes_no),
                random.choice(yes_no),
                tenure,
                random.choice(yes_no),
                random.choice(["No", "Yes", "No phone service"]),
                random.choice(internet_services),
                random.choice(yes_no_no_service),
                random.choice(yes_no_no_service),
                random.choice(yes_no_no_service),
                random.choice(yes_no_no_service),
                random.choice(yes_no_no_service),
                random.choice(yes_no_no_service),
                contract,
                random.choice(yes_no),
                random.choice(payments),
                monthly_charges,
                total_charges,
                churn
            )
            sample_data.append(row)

        df = spark.createDataFrame(sample_data, schema)
        print("✓ Sample dataset created successfully (7,043 records)")

# Print a short profile of the loaded frame: size, schema and a sample.
section_rule = "=" * 60
print("\n" + section_rule)
print("DATASET INFORMATION")
print(section_rule)

row_total = df.count()
column_total = len(df.columns)
print(f"Dataset Shape: {row_total} rows, {column_total} columns")

print("\nDataset Schema:")
df.printSchema()

print("\nFirst 5 rows:")
df.show(5)

# For every column, count the rows whose value is null or exactly one space.
print("\nMissing Values Analysis:")
missing_exprs = []
for column_name in df.columns:
    value = col(column_name)
    looks_missing = value.isNull() | (value == " ")
    # `when` without `otherwise` yields null for non-matching rows, and
    # `count` skips nulls, so this counts only the matching rows.
    missing_exprs.append(count(when(looks_missing, column_name)).alias(column_name))
missing_counts = df.select(missing_exprs)
missing_counts.show()

# Data Preprocessing
print("\n" + "="*60)
print("DATA PREPROCESSING")
print("="*60)

# Handle missing values in TotalCharges
# The column may hold blank entries. Treat null and any whitespace-only value
# (not just exactly one space) as 0.0, so they do not survive the cast as nulls
# in a column that is later used as a numeric model feature. A non-blank value
# that is not numeric still casts to null.
print("Handling missing values in TotalCharges...")
total_charges_raw = col("TotalCharges")
total_charges_is_blank = (
    total_charges_raw.isNull()
    | (trim(total_charges_raw.cast("string")) == "")
)
df = df.withColumn("TotalCharges",
                   when(total_charges_is_blank, 0.0)
                   .otherwise(total_charges_raw.cast("double")))

# Convert Churn to binary
# "Yes" maps to 1; every other value (including null) maps to 0.
print("Converting Churn to binary format...")
df = df.withColumn("ChurnLabel", when(col("Churn") == "Yes", 1).otherwise(0))

# Feature Engineering
print("\n" + "="*60)
print("FEATURE ENGINEERING")
print("="*60)

# Create tenure groups
# Buckets are evaluated top to bottom, so each tenure lands in the first
# bucket whose upper bound it does not exceed.
print("Creating tenure groups...")
df = df.withColumn("TenureGroup",
                   when(col("tenure") <= 12, "0-12 months")
                   .when(col("tenure") <= 24, "12-24 months")
                   .when(col("tenure") <= 48, "24-48 months")
                   .otherwise("48+ months"))

# Service columns for counting
service_cols = ["PhoneService", "MultipleLines", "InternetService",
                "OnlineSecurity", "OnlineBackup", "DeviceProtection",
                "TechSupport", "StreamingTV", "StreamingMovies"]


def _service_is_active(service_name):
    """Return a 0/1 Column flagging whether the customer has `service_name`.

    InternetService stores the connection type ("DSL", "Fiber optic") or "No"
    rather than "Yes"/"No", so it is active whenever it is not "No". Every
    other service column is active only when it equals "Yes".
    """
    if service_name == "InternetService":
        return when(col(service_name) != "No", 1).otherwise(0)
    return when(col(service_name) == "Yes", 1).otherwise(0)


# Count active services using proper PySpark syntax
print("Calculating total services per customer...")
service_conditions = [_service_is_active(c) for c in service_cols]
df = df.withColumn("TotalServices", reduce(lambda a, b: a + b, service_conditions))

# Create monthly charges per service ratio
# With zero active services the ratio falls back to the monthly charge itself,
# which avoids dividing by zero.
print("Creating charges per service ratio...")
df = df.withColumn("ChargesPerService",
                   when(col("TotalServices") > 0, col("MonthlyCharges") / col("TotalServices"))
                   .otherwise(col("MonthlyCharges")))

# Exploratory Data Analysis: churn counts overall, then churn rate per
# contract type and per tenure bucket.
eda_rule = "=" * 60
print("\n" + eda_rule)
print("EXPLORATORY DATA ANALYSIS")
print(eda_rule)

print("Churn Distribution:")
churn_counts = df.groupBy("Churn").count()
churn_dist = churn_counts.orderBy("Churn")
churn_dist.show()

print("Churn Rate by Contract Type:")
# Highest churn rate first.
contract_analysis = (
    df.groupBy("Contract")
    .agg(
        count("*").alias("total_customers"),
        sum("ChurnLabel").alias("churned_customers"),
        avg("ChurnLabel").alias("churn_rate"),
    )
    .orderBy(desc("churn_rate"))
)
contract_analysis.show()

print("Churn Rate by Tenure Group:")
# Lowest churn rate first (default ascending order).
tenure_analysis = (
    df.groupBy("TenureGroup")
    .agg(
        count("*").alias("total_customers"),
        avg("ChurnLabel").alias("churn_rate"),
    )
    .orderBy("churn_rate")
)
tenure_analysis.show()

# Prepare features for ML: declare the feature columns, build the shared
# preprocessing stages, and split the data.
ml_rule = "=" * 60
print("\n" + ml_rule)
print("MACHINE LEARNING PIPELINE")
print(ml_rule)

# Columns that go through StringIndexer before assembly.
categorical_cols = [
    "gender", "SeniorCitizen", "Partner", "Dependents",
    "PhoneService", "MultipleLines", "InternetService",
    "OnlineSecurity", "OnlineBackup", "DeviceProtection",
    "TechSupport", "StreamingTV", "StreamingMovies",
    "Contract", "PaperlessBilling", "PaymentMethod", "TenureGroup",
]

# Columns fed to the assembler as-is.
numerical_cols = [
    "tenure", "MonthlyCharges", "TotalCharges",
    "TotalServices", "ChargesPerService",
]

print(f"Categorical features: {len(categorical_cols)}")
print(f"Numerical features: {len(numerical_cols)}")

# One indexer per categorical column, writing to "<column>_indexed".
print("Creating string indexers for categorical variables...")
indexed_names = [f"{cat_name}_indexed" for cat_name in categorical_cols]
indexers = [
    StringIndexer(inputCol=source_name, outputCol=target_name, handleInvalid="skip")
    for source_name, target_name in zip(categorical_cols, indexed_names)
]

# Feature vector layout: indexed categoricals first, then the numerical columns.
feature_cols = indexed_names + numerical_cols
assembler = VectorAssembler(inputCols=feature_cols, outputCol="features")

# Scaled copy of the assembled vector.
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures")

# 80/20 split with a fixed seed.
print("Splitting data into training and testing sets...")
train_df, test_df = df.randomSplit([0.8, 0.2], seed=42)

train_rows = train_df.count()
print(f"Training set size: {train_rows}")
test_rows = test_df.count()
print(f"Test set size: {test_rows}")

# Model Development: configure both classifiers, wrap each in a pipeline that
# shares the same preprocessing stages, fit on the training split and score
# the test split.
training_rule = "=" * 60
print("\n" + training_rule)
print("MODEL TRAINING")
print(training_rule)

# Both models read the scaled vector and predict the binary churn label.
common_model_args = {"featuresCol": "scaledFeatures", "labelCol": "ChurnLabel"}

# Model 1: Logistic Regression
print("Setting up Logistic Regression model...")
lr = LogisticRegression(maxIter=100, regParam=0.01, **common_model_args)

# Model 2: Random Forest
print("Setting up Random Forest model...")
rf = RandomForestClassifier(numTrees=50, maxDepth=10, seed=42, **common_model_args)

# Indexers -> assembler -> scaler -> classifier
preprocessing_stages = indexers + [assembler, scaler]
lr_pipeline = Pipeline(stages=preprocessing_stages + [lr])
rf_pipeline = Pipeline(stages=preprocessing_stages + [rf])

# Fit each pipeline on the training split.
print("Training Logistic Regression model...")
lr_model = lr_pipeline.fit(train_df)

print("Training Random Forest model...")
rf_model = rf_pipeline.fit(train_df)

# Score the held-out split with both fitted pipelines.
print("Making predictions on test set...")
lr_predictions = lr_model.transform(test_df)
rf_predictions = rf_model.transform(test_df)

# Model Evaluation: compute the same five metrics for both models and print
# them side by side.
evaluation_rule = "=" * 60
print("\n" + evaluation_rule)
print("MODEL EVALUATION")
print(evaluation_rule)

# Evaluators are created once; the metric is chosen per call via a param map.
binary_evaluator = BinaryClassificationEvaluator(
    labelCol="ChurnLabel", rawPredictionCol="rawPrediction")
multi_evaluator = MulticlassClassificationEvaluator(
    labelCol="ChurnLabel", predictionCol="prediction")


def _score_predictions(predictions):
    """Return (AUC-ROC, accuracy, weighted precision, weighted recall, F1) for `predictions`."""
    auc = binary_evaluator.evaluate(
        predictions, {binary_evaluator.metricName: "areaUnderROC"})
    accuracy, precision, recall, f1 = (
        multi_evaluator.evaluate(predictions, {multi_evaluator.metricName: metric})
        for metric in ("accuracy", "weightedPrecision", "weightedRecall", "f1")
    )
    return auc, accuracy, precision, recall, f1


# Logistic Regression Results
print("Evaluating Logistic Regression...")
lr_auc, lr_accuracy, lr_precision, lr_recall, lr_f1 = _score_predictions(lr_predictions)

# Random Forest Results
print("Evaluating Random Forest...")
rf_auc, rf_accuracy, rf_precision, rf_recall, rf_f1 = _score_predictions(rf_predictions)

# Print Results
# "Precision" and "Recall" rows show the weighted variants computed above.
comparison_rule = "=" * 70
print("\n" + comparison_rule)
print("MODEL PERFORMANCE COMPARISON")
print(comparison_rule)
print(f"{'Metric':<20} {'Logistic Regression':<25} {'Random Forest':<20}")
print("-" * 70)
comparison_rows = (
    ("AUC-ROC", lr_auc, rf_auc),
    ("Accuracy", lr_accuracy, rf_accuracy),
    ("Precision", lr_precision, rf_precision),
    ("Recall", lr_recall, rf_recall),
    ("F1-Score", lr_f1, rf_f1),
)
for metric_label, lr_value, rf_value in comparison_rows:
    print(f"{metric_label:<20} {lr_value:<25.4f} {rf_value:<20.4f}")

# Feature Importance Analysis: rank the assembled feature columns by the
# importance the fitted Random Forest assigned to each vector position.
importance_rule = "=" * 60
print("\n" + importance_rule)
print("FEATURE IMPORTANCE ANALYSIS")
print(importance_rule)

# The classifier is the last stage of the fitted pipeline.
fitted_forest = rf_model.stages[-1]
feature_importance = fitted_forest.featureImportances
feature_names = feature_cols

# Pair each feature name with the importance at the same vector index.
importance_list = []
for position, feature_name in enumerate(feature_names):
    importance_list.append((feature_name, float(feature_importance[position])))

# Largest importance first.
importance_sorted = list(importance_list)
importance_sorted.sort(key=lambda pair: pair[1], reverse=True)

print("Top 10 Most Important Features (Random Forest):")
print(f"{'Feature':<25} {'Importance':<15}")
print("-" * 40)
top_ten = importance_sorted[:10]
for feature, importance in top_ten:
    print(f"{feature:<25} {importance:<15.4f}")

# Business Impact Analysis: churn rate per contract/payment segment, then the
# revenue tied to customers who churned.
impact_rule = "=" * 60
print("\n" + impact_rule)
print("BUSINESS IMPACT ANALYSIS")
print(impact_rule)

# Segments with more than 50 customers, highest churn rate first.
print("High-Risk Customer Segments:")
segment_stats = df.groupBy("Contract", "PaymentMethod").agg(
    count("*").alias("customers"),
    avg("ChurnLabel").alias("churn_rate"),
    avg("MonthlyCharges").alias("avg_monthly_charges"),
)
high_risk_segments = (
    segment_stats
    .filter(col("customers") > 50)
    .orderBy(desc("churn_rate"))
)

high_risk_segments.show(10)

# Revenue analysis, restricted to churned customers.
churned_customers = df.filter(col("ChurnLabel") == 1)
total_customers = df.count()
churned_count = churned_customers.count()

revenue_rows = churned_customers.agg(
    sum("MonthlyCharges").alias("monthly_revenue_lost"),
    avg("MonthlyCharges").alias("avg_customer_value"),
    sum("TotalCharges").alias("total_revenue_lost"),
).collect()
# An aggregation without grouping returns a single row.
revenue_metrics = revenue_rows[0]

monthly_revenue_lost = revenue_metrics['monthly_revenue_lost']
churn_rate_pct = (churned_count / total_customers) * 100

print("\nRevenue Impact Summary:")
print(f"Total Customers: {total_customers:,}")
print(f"Churned Customers: {churned_count:,}")
print(f"Churn Rate: {churn_rate_pct:.1f}%")
print(f"Monthly Revenue Lost: ${monthly_revenue_lost:,.2f}")
print(f"Annual Revenue at Risk: ${monthly_revenue_lost * 12:,.2f}")
print(f"Average Churned Customer Value: ${revenue_metrics['avg_customer_value']:.2f}")

# Model Performance in Business Terms: confusion-matrix counts for the Random
# Forest on the test split, and a rough savings estimate derived from them.
test_size = test_df.count()
rf_predictions_pd = rf_predictions.select("ChurnLabel", "prediction").toPandas()

# Boolean masks for each side of the confusion matrix.
actual_churned = rf_predictions_pd['ChurnLabel'] == 1
actual_stayed = rf_predictions_pd['ChurnLabel'] == 0
flagged_churn = rf_predictions_pd['prediction'] == 1
flagged_stay = rf_predictions_pd['prediction'] == 0

# Confusion matrix values: the number of rows where both masks are true.
true_positives = int((actual_churned & flagged_churn).sum())
false_negatives = int((actual_churned & flagged_stay).sum())
true_negatives = int((actual_stayed & flagged_stay).sum())
false_positives = int((actual_stayed & flagged_churn).sum())

print("\nModel Performance in Business Terms:")
print(f"Correctly Identified Churners: {true_positives}")
print(f"Missed Churners: {false_negatives}")
print(f"False Alarms: {false_positives}")
print(f"Correctly Identified Loyal Customers: {true_negatives}")

# Potential savings: correctly flagged churners times the average churned
# customer's monthly charge, scaled by an assumed retention success rate.
avg_monthly_value = revenue_metrics['avg_customer_value']
if true_positives > 0:
    # Assume 50% retention success
    potential_monthly_savings = true_positives * avg_monthly_value * 0.5
    potential_annual_savings = potential_monthly_savings * 12
    print("\nPotential Business Impact (50% retention success rate):")
    print(f"Monthly Revenue Savings: ${potential_monthly_savings:,.2f}")
    print(f"Annual Revenue Savings: ${potential_annual_savings:,.2f}")

# Closing banner with a checklist of the completed stages.
closing_rule = "=" * 60
print("\n" + closing_rule)
print("ANALYSIS COMPLETED SUCCESSFULLY!")
print(closing_rule)
completed_steps = (
    "Models trained and evaluated",
    "Feature importance analyzed",
    "Business impact calculated",
    "Results ready for reporting",
)
for completed_step in completed_steps:
    print(f"✓ {completed_step}")

# Clean up: shut the Spark session down; DataFrames created above cannot be
# evaluated after this point.
spark.stop()
print("\nSpark session stopped.")

Spark session initialized successfully!
Attempting to load dataset...
Local file not found. Trying alternative methods...
Downloading dataset from URL...
✓ Dataset downloaded and loaded successfully

DATASET INFORMATION
Dataset Shape: 7043 rows, 21 columns

Dataset Schema:
root
 |-- customerID: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- SeniorCitizen: integer (nullable = true)
 |-- Partner: string (nullable = true)
 |-- Dependents: string (nullable = true)
 |-- tenure: integer (nullable = true)
 |-- PhoneService: string (nullable = true)
 |-- MultipleLines: string (nullable = true)
 |-- InternetService: string (nullable = true)
 |-- OnlineSecurity: string (nullable = true)
 |-- OnlineBackup: string (nullable = true)
 |-- DeviceProtection: string (nullable = true)
 |-- TechSupport: string (nullable = true)
 |-- StreamingTV: string (nullable = true)
 |-- StreamingMovies: string (nullable = true)
 |-- Contract: string (nullable = true)
 |-- PaperlessBilling: strin