# Customer Churn Prediction (Subscription SaaS)

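End-to-end churn modeling: train and compare classifiers, calibrate the predicted probabilities, choose a decision threshold, and surface the main churn drivers.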


In [None]:
import os
from pathlib import Path

def find_project_root(start: Path, marker: str = "01_customer_churn_prediction_saas") -> Path:
    """Return the closest directory named *marker* at or above *start*.

    The resolved *start* is checked first, then each of its ancestors from the
    nearest outwards. When no directory along that chain carries the marker
    name, the resolved *start* itself is returned.
    """
    origin = start.resolve()
    candidates = (origin, *origin.parents)
    return next((folder for folder in candidates if folder.name == marker), origin)

# Locate the project directory starting from the current working directory and
# switch into it, so the relative paths used later in the notebook (data/,
# models/, reports/) are resolved against it. If no directory on the way up
# matches the marker name, find_project_root returns the resolved starting
# directory and the working directory effectively stays where it was.
ROOT = find_project_root(Path.cwd())
os.chdir(ROOT)
print("Project root:", ROOT)


In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path
from sklearn.model_selection import train_test_split
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import roc_auc_score, precision_recall_curve, classification_report, confusion_matrix
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.inspection import permutation_importance
from sklearn.calibration import CalibratedClassifierCV
import joblib


In [None]:
# Location of the customer table, relative to the current working directory.
data_path = Path("data") / "churn_customers.csv"

# Build the CSV on first use by calling the project's dataset script; the
# import is deferred so it only happens when the file is actually missing.
if not data_path.exists():
    from data.make_dataset import main as make_data

    make_data(out_path=str(data_path))

df = pd.read_csv(data_path)
df.head()


In [None]:
# Split the table into the label and the predictors; the customer identifier
# is dropped together with the label so it is not used as a feature.
target = "churned"
y = df[target]
X = df.drop(columns=[target, "customer_id"])

# Every predictor that is not listed as categorical is treated as numeric.
cat_cols = ["contract", "payment_method"]
num_cols = [col for col in X.columns if col not in cat_cols]

# Numeric columns are standardized; categorical columns are one-hot encoded,
# with categories unseen during fitting ignored at transform time.
numeric_steps = Pipeline(steps=[("scaler", StandardScaler())])
categorical_encoder = OneHotEncoder(handle_unknown="ignore")
preprocess = ColumnTransformer(
    transformers=[
        ("num", numeric_steps, num_cols),
        ("cat", categorical_encoder, cat_cols),
    ]
)

# 80/20 split, stratified on the label, with a fixed seed.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, stratify=y, random_state=42
)

logit_clf = LogisticRegression(class_weight="balanced", max_iter=2000)
rf_clf = RandomForestClassifier(
    n_estimators=400,
    min_samples_leaf=6,
    class_weight="balanced_subsample",
    random_state=42,
    n_jobs=-1,
)

# Both pipelines are built around the same `preprocess` object.
logit = Pipeline(steps=[("prep", preprocess), ("clf", logit_clf)])
rf = Pipeline(steps=[("prep", preprocess), ("clf", rf_clf)])

# Fit in the same order as before: logistic regression first, then the forest.
for pipeline in (logit, rf):
    pipeline.fit(X_train, y_train)

def auc(model):
    """Return the ROC AUC of *model* on the notebook's held-out test split.

    Scores are taken from the second column of ``model.predict_proba(X_test)``
    and compared against ``y_test``.
    """
    positive_scores = model.predict_proba(X_test)[:, 1]
    return roc_auc_score(y_test, positive_scores)

# Report the held-out AUC of each fitted pipeline, rounded to four decimals.
for heading, fitted in (("AUC - Logistic:", logit), ("AUC - RandomForest:", rf)):
    print(heading, round(auc(fitted), 4))


In [None]:
# Calibrate probabilities and choose a threshold using F1 as a starting point
base_model = rf
cal = CalibratedClassifierCV(base_model, method="isotonic", cv=3)
cal.fit(X_train, y_train)

p_test = cal.predict_proba(X_test)[:, 1]
print("Calibrated AUC:", round(roc_auc_score(y_test, p_test), 4))

# NOTE(review): the threshold is selected on the same test split that is used
# for the report below, so the reported precision/recall/F1 are optimistic.
# A separate validation split would give an unbiased estimate.
precision, recall, thr = precision_recall_curve(y_test, p_test)
f1 = (2 * precision * recall) / (precision + recall + 1e-9)

# precision/recall carry one more entry than thr: element i describes the
# predictions with score >= thr[i], and the final (precision=1, recall=0)
# entry has no threshold. The best F1 must therefore be searched among the
# first len(thr) entries and mapped to thr[best_i] -- not thr[best_i - 1],
# which is a lower cutoff whose F1 is not the one being reported.
best_i = int(np.nanargmax(f1[:-1]))
best_thr = float(thr[best_i])
print("Best F1 threshold:", round(best_thr, 4), "| F1:", round(float(f1[best_i]), 4))

# Same ">=" convention as precision_recall_curve, so the report below matches
# the F1 printed above.
y_pred = (p_test >= best_thr).astype(int)
print(classification_report(y_test, y_pred, digits=3))
confusion_matrix(y_test, y_pred)


In [None]:
# Driver insights (permutation importance on RF base estimator)
# Importances are measured on the already-preprocessed matrix, so every one-hot
# column is permuted and reported on its own rather than per original feature.
prep = base_model.named_steps["prep"]
clf = base_model.named_steps["clf"]
X_test_pre = prep.transform(X_test)
# ColumnTransformer can return a scipy sparse matrix when the one-hot block
# makes the stacked output sparse enough; permutation_importance requires a
# dense array, so densify in that case.
if hasattr(X_test_pre, "toarray"):
    X_test_pre = X_test_pre.toarray()

res = permutation_importance(
    clf, X_test_pre, y_test,
    n_repeats=7, random_state=42, n_jobs=-1, scoring="roc_auc",
)

# Column order of the transformed matrix follows the transformer order in the
# preprocessor: numeric columns first, then the one-hot columns. The encoder
# names are derived from cat_cols (the list the encoder was fitted with)
# instead of a second hard-coded copy that could drift out of sync.
ohe = prep.named_transformers_["cat"]
cat_names = list(ohe.get_feature_names_out(cat_cols))
feature_names = num_cols + cat_names

imp = pd.DataFrame({
    "feature": feature_names,
    "importance_mean": res.importances_mean,
    "importance_std": res.importances_std
}).sort_values("importance_mean", ascending=False)

# Reverse the top rows so the most important feature is drawn at the top of
# the horizontal bar chart.
top = imp.head(12).iloc[::-1]
plt.figure(figsize=(8, 5))
plt.barh(top["feature"], top["importance_mean"])
plt.title("Top Permutation Importances (AUC impact)")
plt.xlabel("Mean importance")
plt.tight_layout()

# Save before show(): the figure is written while it is still the current one.
Path("reports").mkdir(exist_ok=True)
plt.savefig("reports/feature_importance.png", dpi=200, bbox_inches="tight")
plt.show()

imp.head(12)


In [None]:
# Export artifacts
# Make sure both output folders exist before anything is written to them.
for out_dir in ("models", "reports"):
    Path(out_dir).mkdir(exist_ok=True)

# Bundle the calibrated model with the chosen threshold and the input column
# order so they are stored together in one file.
artifact = {"model": cal, "threshold": best_thr, "features": list(X.columns)}
joblib.dump(artifact, "models/churn_model.joblib")

calibrated_auc = roc_auc_score(y_test, p_test)
metrics = {
    "auc_calibrated": float(calibrated_auc),
    "best_f1_threshold": float(best_thr),
}
metrics_json = pd.Series(metrics).to_json()
Path("reports/metrics.json").write_text(metrics_json, encoding="utf-8")

print("Saved models/churn_model.joblib and reports/*")


## Recommendations
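- Score customers weekly and prioritize the highest-risk ones; set the cutoff by how many accounts the retention team can act on (a capacity-based threshold) rather than by the F1-optimal threshold alone.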
- Map actions to drivers:
  - High support tickets → proactive support escalation
  - Low adoption / incomplete onboarding → guided onboarding + in-app nudges
  - Late payments → billing outreach + payment method incentives
  - Month-to-month + high charges → annual plan offer / A/B discount
