# ML: Customer Churn Prediction

Predicts customers at risk of churning using LightGBM classification.

## Model Overview
- **Target**: Binary churn (no purchase in the `CHURN_WINDOW_DAYS` days before the snapshot date; default 90)
- **Algorithm**: LightGBM Classifier
- **Features**: Behavioral (purchase patterns) + Demographic
- **Validation**: 5-fold cross-validation
- **Target Performance**: AUC-ROC > 0.75, Precision > 0.6

## Data Flow
```
Silver (fact_receipts, dim_customers) --> Feature Engineering --> Model Training --> Gold (gold_churn_predictions)
```

## Usage
Schedule this notebook to run **weekly** via Fabric pipeline to refresh predictions.

In [None]:
from pyspark.sql import functions as F
from pyspark.sql.window import Window
from pyspark.sql.utils import AnalysisException
from datetime import datetime, timedelta, timezone
import os
import pandas as pd
import numpy as np

# ML libraries
import lightgbm as lgb
from sklearn.model_selection import cross_val_score, StratifiedKFold
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score, confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

# Suppress every Python warning for the rest of the session so notebook
# output stays readable.
# NOTE(review): the blanket 'ignore' filter also hides deprecation and
# numerical warnings raised by the libraries imported above; consider a
# narrower filter if those matter.
import warnings
warnings.filterwarnings('ignore')

In [None]:
# =============================================================================
# PARAMETERS
# =============================================================================

def get_env(var_name, default=None):
    """Look up an environment variable, falling back to ``default``.

    Args:
        var_name: Name of the environment variable to read.
        default: Value returned when the variable is not set.

    Returns:
        The variable's string value, or ``default`` when it is unset.
    """
    if var_name in os.environ:
        return os.environ[var_name]
    return default

# Database names for the Silver (input) and Gold (output) layers; each can be
# overridden through an environment variable of the same name.
SILVER_DB = get_env("SILVER_DB", default="ag")
GOLD_DB = get_env("GOLD_DB", default="au")

# Day-count settings, read as strings from the environment and parsed to int:
#   CHURN_WINDOW_DAYS   - days without a purchase that define churn
#   FEATURE_WINDOW_DAYS - days of history used for feature engineering
#   LABEL_OFFSET_DAYS   - most recent days excluded for label stability
CHURN_WINDOW_DAYS = int(get_env("CHURN_WINDOW_DAYS", default="90"))
FEATURE_WINDOW_DAYS = int(get_env("FEATURE_WINDOW_DAYS", default="180"))
LABEL_OFFSET_DAYS = int(get_env("LABEL_OFFSET_DAYS", default="7"))

# Model parameters
RANDOM_STATE = 42
CV_FOLDS = 5

# Echo the effective configuration so each run records what it used.
print(f"Configuration:")
for _setting_name, _setting_value in (
    ("SILVER_DB", SILVER_DB),
    ("GOLD_DB", GOLD_DB),
    ("CHURN_WINDOW_DAYS", CHURN_WINDOW_DAYS),
    ("FEATURE_WINDOW_DAYS", FEATURE_WINDOW_DAYS),
    ("LABEL_OFFSET_DAYS", LABEL_OFFSET_DAYS),
):
    print(f"  {_setting_name}={_setting_value}")

In [None]:
# =============================================================================
# HELPER FUNCTIONS
# =============================================================================

def ensure_database(name):
    """Create the database ``name`` through Spark SQL if it is missing.

    Args:
        name: Database name, interpolated directly into the SQL statement.
    """
    create_stmt = f"CREATE DATABASE IF NOT EXISTS {name}"
    spark.sql(create_stmt)

def read_silver(table_name):
    """Return the table ``table_name`` from the Silver database.

    Args:
        table_name: Unqualified table name inside ``SILVER_DB``.

    Returns:
        Whatever ``spark.table`` returns for ``SILVER_DB.table_name``.
    """
    qualified_name = f"{SILVER_DB}.{table_name}"
    return spark.table(qualified_name)

def save_gold(df, table_name):
    """Overwrite a Gold-layer Delta table with ``df`` and print its row count.

    Args:
        df: DataFrame to persist.
        table_name: Unqualified table name inside ``GOLD_DB``.
    """
    full_name = f"{GOLD_DB}.{table_name}"
    delta_writer = df.write.format("delta").mode("overwrite")
    delta_writer.saveAsTable(full_name)
    # The count is taken from ``df`` itself, after the write has finished.
    written_rows = df.count()
    print(f"  {full_name}: {written_rows} rows")

def silver_exists(table_name):
    """Report whether ``table_name`` can be resolved in the Silver database.

    Args:
        table_name: Unqualified table name inside ``SILVER_DB``.

    Returns:
        True when ``spark.table`` resolves the table; False when it raises
        ``AnalysisException``.
    """
    qualified_name = f"{SILVER_DB}.{table_name}"
    try:
        spark.table(qualified_name)
    except AnalysisException:
        return False
    return True

# Make sure the Gold database exists before any later cell writes to it.
ensure_database(GOLD_DB)

In [None]:
# =============================================================================
# DATA VALIDATION
# =============================================================================

print("=" * 60)
print("VALIDATING DATA SOURCES")
print("=" * 60)

# Stop at the first Silver table that cannot be resolved; nothing downstream
# can run without both of them.
required_tables = ["fact_receipts", "dim_customers"]
for table in required_tables:
    if silver_exists(table):
        print(f"  {table}: OK")
        continue
    raise ValueError(f"Required table {SILVER_DB}.{table} not found!")

print("\nAll required tables exist.\n")

In [None]:
# =============================================================================
# DETERMINE ANALYSIS DATES
# =============================================================================

print("=" * 60)
print("DETERMINING ANALYSIS DATES")
print("=" * 60)

# The newest event timestamp in the receipts table anchors every other date.
_max_event_row = read_silver("fact_receipts").agg(F.max("event_ts")).collect()[0]
max_event_ts = _max_event_row[0]

# An empty table aggregates to NULL, which arrives here as None.
if max_event_ts is None:
    raise ValueError("No transaction data found in fact_receipts!")

latest_date = max_event_ts.date()

# Key dates, all derived by stepping back from the latest transaction date:
#   snapshot_date      - latest date minus the label offset
#   feature_start_date - start of the feature window ending at the snapshot
#   churn_cutoff_date  - customers with no purchase after this date are churned
snapshot_date = latest_date - timedelta(days=LABEL_OFFSET_DAYS)
feature_start_date = snapshot_date - timedelta(days=FEATURE_WINDOW_DAYS)
churn_cutoff_date = snapshot_date - timedelta(days=CHURN_WINDOW_DAYS)

for _date_line in (
    f"  Latest transaction date: {latest_date}",
    f"  Snapshot date (for predictions): {snapshot_date}",
    f"  Feature window: {feature_start_date} to {snapshot_date}",
    f"  Churn cutoff (last purchase before): {churn_cutoff_date}",
):
    print(_date_line)
print()

In [None]:
# =============================================================================
# FEATURE ENGINEERING
# =============================================================================

print("=" * 60)
print("FEATURE ENGINEERING")
print("=" * 60)

# Receipts up to the snapshot date, trimmed to the columns the features need.
# NOTE(review): snapshot_date is a Python date while event_ts looks like a
# timestamp; confirm how Spark coerces this comparison. If the date becomes
# midnight, receipts later on the snapshot day are excluded.
print("Loading transaction data...")
_receipt_columns = ["customer_id", "store_id", "event_ts", "total_amount", "payment_method"]
receipts_df = (
    read_silver("fact_receipts")
    .where(F.col("event_ts") <= F.lit(snapshot_date))
    .select(*_receipt_columns)
)

# Customer attributes used for the demographic features.
print("Loading customer dimension...")
_customer_columns = ["customer_id", "segment", "loyalty_status", "signup_date", "geography_id"]
customers_df = read_silver("dim_customers").select(*_customer_columns)

# Behavioral Features: purchase patterns inside the feature window.
# One row per customer that has at least one receipt between
# feature_start_date and snapshot_date; customers with no receipts in the
# window are absent from this frame.
print("\nEngineering behavioral features...")
behavioral_features = (
    receipts_df
    .filter(
        (F.col("event_ts") >= F.lit(feature_start_date)) &
        (F.col("event_ts") <= F.lit(snapshot_date))
    )
    .groupBy("customer_id")
    .agg(
        # Frequency metrics
        F.count("*").alias("purchase_count"),
        F.countDistinct("store_id").alias("unique_stores"),

        # Monetary metrics
        F.sum("total_amount").alias("total_spend"),
        F.avg("total_amount").alias("avg_basket_value"),
        F.stddev("total_amount").alias("basket_std"),
        F.max("total_amount").alias("max_basket"),
        F.min("total_amount").alias("min_basket"),

        # Recency metric (only needed to derive days_since_last_purchase)
        F.max("event_ts").alias("last_purchase_date"),

        # Payment diversity
        F.countDistinct("payment_method").alias("payment_methods_used")
    )
    .withColumn(
        "days_since_last_purchase",
        F.datediff(F.lit(snapshot_date), F.to_date("last_purchase_date"))
    )
    .withColumn(
        # Average purchases per day over the whole feature window.
        "purchase_frequency",
        F.col("purchase_count") / F.lit(FEATURE_WINDOW_DAYS)
    )
    .withColumn(
        # Coefficient of variation of basket size. Falls back to 0.0 when the
        # standard deviation is NULL or when the average basket is zero: the
        # zero-average case would otherwise divide by zero, which yields NULL
        # with ANSI mode off and raises an error with ANSI mode on.
        "basket_consistency",
        F.when(
            F.col("basket_std").isNull() | (F.col("avg_basket_value") == 0),
            F.lit(0.0)
        ).otherwise(F.col("basket_std") / F.col("avg_basket_value"))
    )
    .drop("last_purchase_date")
)

print(f"  Behavioral features created for {behavioral_features.count()} customers")

# Demographic Features: one row per customer in the customer dimension.
print("\nEngineering demographic features...")

# Days between signup and the snapshot date.
_tenure_days = F.datediff(F.lit(snapshot_date), F.col("signup_date"))

# 1 for the Premium and VIP segments, 0 for everything else.
_premium_flag = F.when(F.col("segment").isin(["Premium", "VIP"]), 1).otherwise(0)

# Ordinal loyalty tier: Gold=2, Silver=1, anything else=0.
_loyalty_tier = (
    F.when(F.col("loyalty_status") == "Gold", 2)
    .when(F.col("loyalty_status") == "Silver", 1)
    .otherwise(0)
)

demographic_features = customers_df.select(
    "customer_id",
    _tenure_days.alias("customer_tenure_days"),
    _premium_flag.alias("is_premium_segment"),
    _loyalty_tier.alias("is_loyal"),
    "geography_id",
)

print(f"  Demographic features created for {demographic_features.count()} customers")

# Combine all features: every customer from the dimension is kept, and
# behavioral columns are attached where the customer bought in the window.
print("\nCombining features...")

# Defaults for customers with no purchases in the window. Recency is set one
# day beyond the window so it exceeds any value observed inside it.
_no_purchase_defaults = {
    "purchase_count": 0,
    "unique_stores": 0,
    "total_spend": 0.0,
    "avg_basket_value": 0.0,
    "basket_std": 0.0,
    "max_basket": 0.0,
    "min_basket": 0.0,
    "days_since_last_purchase": FEATURE_WINDOW_DAYS + 1,
    "payment_methods_used": 0,
    "purchase_frequency": 0.0,
    "basket_consistency": 0.0
}

_joined_features = demographic_features.join(
    behavioral_features, on="customer_id", how="left"
)
features_df = _joined_features.fillna(_no_purchase_defaults)

print(f"  Combined features for {features_df.count()} customers")
print()

In [None]:
# =============================================================================
# CREATE TARGET VARIABLE (CHURN LABEL)
# =============================================================================

print("="*60)
print("CREATING CHURN LABELS")
print("="*60)

# Define churn: customers who haven't purchased since churn_cutoff_date.
# NOTE: despite its name, `churned_customers` holds the customers who DID
# purchase after the cutoff (the active ones). The name is kept for
# compatibility; churn is assigned below to everyone NOT in this set.
# NOTE(review): this label window (churn_cutoff_date to snapshot_date) lies
# inside the period the behavioral features are computed from, so features
# such as days_since_last_purchase largely determine is_churned. Consider
# labelling from purchases made AFTER the snapshot date instead.
churned_customers = (
    receipts_df
    .filter(F.col("event_ts") > F.lit(churn_cutoff_date))
    .select("customer_id")
    .distinct()
)

# Left-join the activity marker; customers without a match are churned.
dataset_df = (
    features_df
    .join(
        churned_customers.withColumn("is_active", F.lit(1)),
        on="customer_id",
        how="left"
    )
    .withColumn("is_churned", F.when(F.col("is_active").isNull(), 1).otherwise(0))
    .drop("is_active")
)

# Check class balance. A label value with no rows is missing from the grouped
# result, so look counts up with a default instead of indexing a list.
churn_stats = dataset_df.groupBy("is_churned").count().collect()
counts_by_label = {row["is_churned"]: row["count"] for row in churn_stats}
total_customers = sum(counts_by_label.values())
churned_count = counts_by_label.get(1, 0)

if total_customers == 0:
    raise ValueError("No customers found to label; the feature dataset is empty!")

churn_rate = churned_count / total_customers

print(f"  Total customers: {total_customers}")
print(f"  Churned: {churned_count} ({churn_rate:.1%})")
print(f"  Active: {total_customers - churned_count} ({1 - churn_rate:.1%})")
if churned_count in (0, total_customers):
    # A single-class target cannot be used to train or score a classifier.
    print("  WARNING: only one class is present in the churn label.")
print()

In [None]:
# =============================================================================
# PREPARE TRAINING DATA
# =============================================================================

print("=" * 60)
print("PREPARING TRAINING DATA")
print("=" * 60)

# Bring the whole labelled dataset to the driver as a pandas DataFrame.
print("Converting to Pandas DataFrame...")
dataset_pd = dataset_df.toPandas()

# Model inputs, grouped by origin. customer_id and is_churned are left out
# on purpose. The concatenated order matches the column order fed to the model.
_demographic_cols = [
    "customer_tenure_days",
    "is_premium_segment",
    "is_loyal",
    "geography_id",
]
_behavioral_cols = [
    "purchase_count",
    "unique_stores",
    "total_spend",
    "avg_basket_value",
    "basket_std",
    "max_basket",
    "min_basket",
    "days_since_last_purchase",
    "payment_methods_used",
    "purchase_frequency",
    "basket_consistency",
]
feature_cols = _demographic_cols + _behavioral_cols

# Feature matrix and binary target.
y = dataset_pd["is_churned"]
X = dataset_pd[feature_cols]

print(f"  Training set shape: {X.shape}")
print(f"  Features: {len(feature_cols)}")
print(f"  Target distribution: {y.value_counts().to_dict()}")
print()

In [None]:
# =============================================================================
# TRAIN LIGHTGBM MODEL WITH CROSS-VALIDATION
# =============================================================================

print("=" * 60)
print("TRAINING LIGHTGBM MODEL")
print("=" * 60)

# Hyperparameters for the LightGBM classifier, kept in one mapping so they
# are easy to read and log.
_lgb_params = {
    "objective": 'binary',
    "metric": 'auc',
    "boosting_type": 'gbdt',
    "num_leaves": 31,
    "learning_rate": 0.05,
    "n_estimators": 200,
    "max_depth": 6,
    "min_child_samples": 20,
    "subsample": 0.8,
    "colsample_bytree": 0.8,
    "reg_alpha": 0.1,
    "reg_lambda": 0.1,
    "random_state": RANDOM_STATE,
    "n_jobs": -1,
    "verbose": -1,
}
lgb_model = lgb.LGBMClassifier(**_lgb_params)

# Stratified, shuffled folds with a fixed seed; AUC-ROC is scored per fold.
print(f"Running {CV_FOLDS}-fold cross-validation...")
cv = StratifiedKFold(n_splits=CV_FOLDS, shuffle=True, random_state=RANDOM_STATE)
cv_scores = cross_val_score(lgb_model, X, y, cv=cv, scoring='roc_auc', n_jobs=-1)

print(f"\nCross-Validation AUC-ROC Scores:")
for _fold_no, _fold_auc in enumerate(cv_scores, start=1):
    print(f"  Fold {_fold_no}: {_fold_auc:.4f}")
# The spread is reported as two standard deviations of the fold scores.
print(f"  Mean: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")

# The final model is fit on every row and used for scoring in later cells.
print("\nTraining final model on full dataset...")
lgb_model.fit(X, y)
print("  Model training complete.")
print()

In [None]:
# =============================================================================
# MODEL EVALUATION
# =============================================================================

# Local import: cross_val_predict provides out-of-fold predictions so the
# acceptance metrics are not computed on rows the scoring model was fit on.
from sklearn.model_selection import cross_val_predict

print("="*60)
print("MODEL EVALUATION")
print("="*60)

# Scores from the final model (fit on all of X). These are what later cells
# plot and save, but they are in-sample and therefore NOT used for the
# performance metrics below.
y_pred_proba = lgb_model.predict_proba(X)[:, 1]
y_pred = (y_pred_proba >= 0.5).astype(int)

# Out-of-fold scores: each row is scored by a clone of the estimator trained
# on the other folds. The same `cv` splitter as the cross-validation run is
# reused. Pooled out-of-fold metrics can differ slightly from the per-fold mean.
oof_pred_proba = cross_val_predict(
    lgb_model, X, y, cv=cv, method="predict_proba", n_jobs=-1
)[:, 1]
oof_pred = (oof_pred_proba >= 0.5).astype(int)

# Calculate metrics on the out-of-fold predictions
auc_roc = roc_auc_score(y, oof_pred_proba)
precision = precision_score(y, oof_pred)
recall = recall_score(y, oof_pred)
f1 = f1_score(y, oof_pred)

print(f"Performance Metrics:")
print(f"  (computed on out-of-fold predictions from {CV_FOLDS}-fold CV)")
print(f"  AUC-ROC: {auc_roc:.4f}")
print(f"  Precision (at 0.5 threshold): {precision:.4f}")
print(f"  Recall: {recall:.4f}")
print(f"  F1-Score: {f1:.4f}")

# Check acceptance criteria
print(f"\nAcceptance Criteria:")
print(f"  AUC-ROC > 0.75: {'PASS' if auc_roc > 0.75 else 'FAIL'} ({auc_roc:.4f})")
print(f"  Precision > 0.6: {'PASS' if precision > 0.6 else 'FAIL'} ({precision:.4f})")

# Confusion matrix. labels=[0, 1] guarantees a 2x2 result even when one
# class never appears among the predictions.
cm = confusion_matrix(y, oof_pred, labels=[0, 1])
print(f"\nConfusion Matrix:")
print(f"  True Negatives: {cm[0, 0]}")
print(f"  False Positives: {cm[0, 1]}")
print(f"  False Negatives: {cm[1, 0]}")
print(f"  True Positives: {cm[1, 1]}")

# 80th-percentile cut of the final model's scores; the risk-distribution
# plot in a later cell draws this threshold over those same scores.
top_20_pct_threshold = np.percentile(y_pred_proba, 80)

# Top 20% risk analysis, measured on out-of-fold scores with their own
# 80th-percentile cut so the capture rate is not in-sample.
oof_top_20_pct_threshold = np.percentile(oof_pred_proba, 80)
top_20_pct_mask = oof_pred_proba >= oof_top_20_pct_threshold
actual_churners_captured = y[top_20_pct_mask].sum()
total_churners = y.sum()
# Guard the ratio: with no churners at all the rate is reported as 0.
capture_rate = actual_churners_captured / total_churners if total_churners else 0.0

print(f"\nTop 20% High-Risk Customers:")
print(f"  Threshold: {oof_top_20_pct_threshold:.4f}")
print(f"  Captured {actual_churners_captured} of {total_churners} churners ({capture_rate:.1%})")
print(f"  Target: 60%+ - {'PASS' if capture_rate > 0.6 else 'FAIL'}")
print()

In [None]:
# =============================================================================
# FEATURE IMPORTANCE ANALYSIS
# =============================================================================

print("=" * 60)
print("FEATURE IMPORTANCE")
print("=" * 60)

# Pair each feature name with the fitted model's importance score and rank
# them from most to least important.
_importance_table = pd.DataFrame({
    'feature': feature_cols,
    'importance': lgb_model.feature_importances_
})
feature_importance = _importance_table.sort_values('importance', ascending=False)
_top_features = feature_importance.head(10)

print("\nTop 10 Most Important Features:")
for _feature_name, _importance in zip(_top_features['feature'], _top_features['importance']):
    print(f"  {_feature_name:<30} {_importance:>10.1f}")

# Horizontal bar chart of the same top 10; the y-axis is inverted so the
# most important feature is at the top.
_fig, _ax = plt.subplots(figsize=(10, 6))
_ax.barh(_top_features['feature'], _top_features['importance'])
_ax.set_xlabel('Feature Importance')
_ax.set_title('Top 10 Features for Churn Prediction')
_ax.invert_yaxis()
_fig.tight_layout()
plt.show()

print()

In [None]:
# =============================================================================
# CHURN RISK DISTRIBUTION
# =============================================================================

print("=" * 60)
print("CHURN RISK DISTRIBUTION")
print("=" * 60)

# Bucket the churn probabilities into five equal-width bands. The lowest
# edge is inclusive, so a score of exactly 0 falls in 'Very Low'.
risk_bins = [0, 0.2, 0.4, 0.6, 0.8, 1.0]
risk_labels = ['Very Low', 'Low', 'Medium', 'High', 'Very High']
dataset_pd['risk_category'] = pd.cut(
    y_pred_proba,
    bins=risk_bins,
    labels=risk_labels,
    include_lowest=True,
)

# Count customers per band; observed=True leaves out bands with no customers.
risk_distribution = dataset_pd.groupby('risk_category', observed=True).size()
_customer_total = len(dataset_pd)
print("\nRisk Category Distribution:")
for category, count in risk_distribution.items():
    pct = count / _customer_total * 100
    print(f"  {category:<12} {count:>6} ({pct:>5.1f}%)")

# Histogram of the raw scores with the 0.5 decision threshold and the
# top-20% cut-off marked as vertical lines.
_fig, _ax = plt.subplots(figsize=(10, 6))
_ax.hist(y_pred_proba, bins=50, alpha=0.7, edgecolor='black')
_ax.axvline(x=0.5, color='r', linestyle='--', label='Threshold (0.5)')
_ax.axvline(
    x=top_20_pct_threshold,
    color='orange',
    linestyle='--',
    label=f'Top 20% ({top_20_pct_threshold:.2f})',
)
_ax.set_xlabel('Churn Probability')
_ax.set_ylabel('Number of Customers')
_ax.set_title('Distribution of Churn Risk Scores')
_ax.legend()
_fig.tight_layout()
plt.show()

print()

In [None]:
# =============================================================================
# SAVE PREDICTIONS TO GOLD LAYER
# =============================================================================

print("=" * 60)
print("SAVING PREDICTIONS TO GOLD LAYER")
print("=" * 60)

# Unqualified name of the output table inside the Gold database.
_predictions_table = "gold_churn_predictions"

# One row per customer: score, thresholded prediction, risk band and actual
# label, plus run metadata that is constant for the whole run.
_prediction_columns = {
    'customer_id': dataset_pd['customer_id'],
    'churn_probability': y_pred_proba,
    'churn_prediction': y_pred,
    'risk_category': dataset_pd['risk_category'],
    'is_churned_actual': y,
    'prediction_date': pd.Timestamp(snapshot_date),
    'model_version': '1.0',
    'churn_window_days': CHURN_WINDOW_DAYS
}
predictions_pd = pd.DataFrame(_prediction_columns)

# Hand the pandas frame back to Spark and overwrite the Gold table.
predictions_spark = spark.createDataFrame(predictions_pd)
save_gold(predictions_spark, _predictions_table)

print("\nPredictions saved successfully.")
print(f"  Table: {GOLD_DB}.{_predictions_table}")
print(f"  Rows: {len(predictions_pd)}")
print()

In [None]:
# =============================================================================
# SUMMARY STATISTICS FOR BUSINESS USERS
# =============================================================================

print("=" * 60)
print("BUSINESS INSIGHTS SUMMARY")
print("=" * 60)

# Customers at or above the 0.7 and 0.8 probability marks. The printed
# labels say ">70%" and ">80%", but both filters are inclusive.
_churn_scores = predictions_pd['churn_probability']
high_risk_customers = predictions_pd[_churn_scores >= 0.7]
very_high_risk_customers = predictions_pd[_churn_scores >= 0.8]

_n_analyzed = len(predictions_pd)
_n_high_risk = len(high_risk_customers)
_n_very_high_risk = len(very_high_risk_customers)

print(f"\nActionable Insights:")
print(f"  Total customers analyzed: {_n_analyzed:,}")
print(f"  High risk (>70%): {_n_high_risk:,}")
print(f"  Very high risk (>80%): {_n_very_high_risk:,}")
print(f"  Recommended for retention campaigns: {_n_high_risk:,}")

# Revenue at risk = number of high-risk customers times the mean spend of
# ALL customers over the feature window, not the high-risk group's own spend.
avg_customer_value = dataset_pd['total_spend'].mean()
revenue_at_risk = _n_high_risk * avg_customer_value

print(f"\nRevenue Impact:")
print(f"  Avg customer value (last {FEATURE_WINDOW_DAYS} days): ${avg_customer_value:,.2f}")
print(f"  Estimated revenue at risk: ${revenue_at_risk:,.2f}")

print(f"\nNext Steps:")
for _step in (
    f"  1. Target high-risk customers with retention offers",
    f"  2. Analyze top churn drivers from feature importance",
    f"  3. Monitor effectiveness weekly",
    f"  4. Refresh predictions regularly (recommended: weekly)",
):
    print(_step)

print("\n" + "=" * 60)
print("CHURN PREDICTION MODEL COMPLETE")
print("=" * 60)