# OTT Subscription Churn (Synthetic + SHAP)

Generate a realistic synthetic OTT churn dataset, train logistic-regression and random-forest classifiers on it, and explain the random forest's predictions with SHAP.

**Contact:** Pablo Monteros — [GitHub](https://github.com/Pmonteros8) • [LinkedIn](https://www.linkedin.com/in/pmonteros/) • [Email](mailto:Pablo.monterosj@gmail.com)

In [None]:
import pandas as pd, numpy as np
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import OneHotEncoder
from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.metrics import classification_report, roc_auc_score
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
import shap

from projects.generate_ott_churn import generate_ott_churn
# Build the synthetic dataset; 8000 is passed straight to the generator
# (presumably the number of rows -- confirm against generate_ott_churn).
df = generate_ott_churn(8000)

# Target column first, then everything else as the feature matrix.
y = df["churned"]
X = df.drop(columns="churned")

# "plan" is the only column that gets one-hot encoded; every other feature
# column is forwarded to the estimator unchanged.
cat_cols = ["plan"]
num_cols = [col for col in X.columns if col not in cat_cols]

pre = ColumnTransformer(
    transformers=[
        ("cat", OneHotEncoder(handle_unknown="ignore"), cat_cols),
        ("num", "passthrough", num_cols),
    ]
)

# Candidate estimators, keyed by the short label used when reporting results.
models = dict(
    logreg=LogisticRegression(max_iter=1000),
    rf=RandomForestClassifier(n_estimators=300, random_state=42),
)

# Fit and score every candidate model on one shared, stratified hold-out split.
# The split depends only on X, y and the fixed random_state, so it is computed
# once here instead of being recomputed (identically) on every loop iteration.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=42, stratify=y
)

for name, model in models.items():
    # Preprocessing + estimator are fitted together on the training rows only.
    clf = Pipeline([("pre", pre), ("clf", model)])
    clf.fit(X_train, y_train)
    preds = clf.predict(X_test)
    # Score with the probability of the class at index 1 when the estimator
    # exposes probabilities; otherwise fall back to the hard predictions.
    proba = clf.predict_proba(X_test)[:, 1] if hasattr(clf, "predict_proba") else preds
    print(f"== {name.upper()} ==")
    print(classification_report(y_test, preds, digits=3))
    try:
        auc = roc_auc_score(y_test, proba)
    except ValueError as e:
        # roc_auc_score signals an undefined score (e.g. unusable labels or
        # scores) with ValueError; report it and keep evaluating the rest.
        print("AUC unavailable:", e)
    else:
        print("ROC AUC:", auc)

# SHAP on RF
# Refit a standalone random-forest pipeline on the same stratified split used
# for evaluation, then explain its predictions on the hold-out rows.
rf_clf = Pipeline([("pre", pre), ("clf", RandomForestClassifier(n_estimators=300, random_state=42))])
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.25, random_state=42, stratify=y)
rf_clf.fit(X_train, y_train)

# The forest was trained on the transformed matrix, so SHAP must see the same
# representation. ColumnTransformer may return a scipy sparse matrix when the
# one-hot block dominates; densify so SHAP and the plot get a plain array.
X_trans = rf_clf.named_steps["pre"].transform(X_test)
if hasattr(X_trans, "toarray"):
    X_trans = X_trans.toarray()

# Column labels of the transformed matrix (e.g. one-hot and passthrough names),
# so the plot shows readable feature names instead of positional indices.
feature_names = list(rf_clf.named_steps["pre"].get_feature_names_out())

explainer = shap.TreeExplainer(rf_clf.named_steps["clf"])
shap_values = explainer.shap_values(X_trans)

# Older SHAP releases return one (samples, features) array per class in a list;
# newer releases return a single (samples, features, classes) array. Indexing
# the latter with [1] would pick sample 1 rather than class 1, so handle both.
# Class index 1 matches the predict_proba(...)[:, 1] column used for scoring.
if isinstance(shap_values, list):
    churn_shap = shap_values[1]
elif np.ndim(shap_values) == 3:
    churn_shap = shap_values[:, :, 1]
else:
    churn_shap = shap_values

try:
    shap.summary_plot(churn_shap, X_trans, feature_names=feature_names, show=True)
except Exception as e:
    # Plotting is best-effort (e.g. no display backend); report and carry on.
    print("SHAP summary plot note:", e)