# Project 2: Customer Churn Prediction — Build Notebook

This notebook reproduces the exploratory workflow as it stood before the project was modularized. Each section pastes the working code from `src/` so the entire workflow can be run sequentially inside the notebook. The Streamlit app (`app.py`) is intentionally excluded.


## Step 0 — Environment Imports & Paths


In [None]:

import os
from pathlib import Path

import numpy as np
import pandas as pd


## Step 1 — Data Preparation Code


In [None]:
# --- data_prep.py ---
# src/data_prep.py
from __future__ import annotations
import os
import io
import sys
import textwrap
import joblib
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split

# Public CSV location of the Telco Customer Churn dataset; default source for load_raw_data().
DATA_URL = "https://raw.githubusercontent.com/IBM/telco-customer-churn-on-icp4d/master/data/Telco-Customer-Churn.csv"
# Directory that receives the train/val/test CSV splits written by this cell.
PROCESSED_DIR = "data/processed"

# ---- Helpers: teaching notes in comments ----
def load_raw_data(url: str = DATA_URL) -> pd.DataFrame:
    """
    Read the Telco Customer Churn CSV and return it untouched.

    `url` defaults to the public DATA_URL, so nothing has to be checked into
    the repo; any path or URL accepted by `pd.read_csv` works as well.
    """
    return pd.read_csv(url)

def clean_total_charges(df: pd.DataFrame) -> pd.DataFrame:
    """
    Return a copy of `df` whose TotalCharges column is numeric and gap-free.

    Unparseable entries (e.g. blank strings) become NaN and are then repaired:
      * tenure == 0  -> 0.0 (nothing has been billed yet)
      * tenure > 0   -> MonthlyCharges * tenure (business-driven estimate)
      * anything still missing -> median of the column (last resort)
    The input frame is not modified.
    """
    cleaned = df.copy()
    cleaned["TotalCharges"] = pd.to_numeric(cleaned["TotalCharges"], errors="coerce")

    # Brand-new customers: no billing history, so the total is zero.
    new_customer_gap = cleaned["TotalCharges"].isna() & (cleaned["tenure"] == 0)
    cleaned.loc[new_customer_gap, "TotalCharges"] = 0.0

    # Existing customers: rebuild the total from the monthly fee and the tenure.
    billed_gap = cleaned["TotalCharges"].isna() & (cleaned["tenure"] > 0)
    estimate = cleaned.loc[billed_gap, "MonthlyCharges"] * cleaned.loc[billed_gap, "tenure"]
    cleaned.loc[billed_gap, "TotalCharges"] = estimate

    # Whatever is left (e.g. tenure itself missing) gets the column median.
    leftovers = cleaned["TotalCharges"].isna()
    if leftovers.any():
        fallback = cleaned["TotalCharges"].median()
        cleaned["TotalCharges"] = cleaned["TotalCharges"].fillna(fallback)

    return cleaned

def engineer_features(df: pd.DataFrame) -> pd.DataFrame:
    """
    Create simple, explainable features requested by the brief:
      - tenure_bucket / tenure_bucket_ord
      - services_count, streaming_services, support_services
      - monthly_to_total_ratio, charges_per_month_of_tenure
      - internet_no_tech_support, senior_fiber_optic (flags)
      - is_auto_pay, is_long_contract (flags)
      - ChurnFlag: target mapping Churn {No,Yes} -> {0,1}

    All text comparisons go through one normaliser (string cast, strip,
    lower-case) so padded or differently-cased labels are treated alike, and a
    missing InternetService value is NOT counted as "has internet".

    Returns a new DataFrame; the input is left untouched.
    """
    df = df.copy()

    def _norm(col: pd.Series) -> pd.Series:
        # Single normalisation rule for every categorical comparison below.
        return col.astype(str).str.strip().str.lower()

    def _count_yes(cols: list) -> pd.Series:
        # Number of columns in `cols` whose normalised value is exactly "yes".
        return df[cols].apply(lambda col: _norm(col).eq("yes")).astype(int).sum(axis=1)

    # tenure_bucket: right-closed bins; -0.1 keeps tenure == 0 inside the first bin
    bins = [-0.1, 6, 12, 24, np.inf]
    labels = ["0-6m", "6-12m", "12-24m", "24m+"]
    df["tenure_bucket"] = pd.cut(df["tenure"], bins=bins, labels=labels)

    # Any InternetService value other than "no" (DSL / Fiber optic) means internet is on;
    # missing values are treated as "no internet" rather than silently counted.
    internet = _norm(df["InternetService"])
    has_internet = df["InternetService"].notna() & internet.ne("no")

    # services_count: every "Yes" service plus one for having internet at all
    service_cols_yesno = [
        "PhoneService", "MultipleLines", "OnlineSecurity", "OnlineBackup",
        "DeviceProtection", "TechSupport", "StreamingTV", "StreamingMovies"
    ]
    df["services_count"] = _count_yes(service_cols_yesno) + has_internet.astype(int)

    # monthly_to_total_ratio: actual total vs. tenure * monthly fee
    denom = (df["tenure"] * df["MonthlyCharges"]).replace(0, 1)  # avoid divide-by-zero
    df["monthly_to_total_ratio"] = df["TotalCharges"] / denom

    # flag: internet but no tech support
    df["internet_no_tech_support"] = (
        has_internet & _norm(df["TechSupport"]).eq("no")
    ).astype(int)

    # --- Additional engineered features to help the models ---
    # Ordinal encoding of tenure bucket for linear models (unbucketed rows -> 0).
    # Cast to object first so the mapping yields plain numbers, not a categorical.
    tenure_ord = {"0-6m": 0, "6-12m": 1, "12-24m": 2, "24m+": 3}
    df["tenure_bucket_ord"] = (
        df["tenure_bucket"].astype(object).map(tenure_ord).fillna(0).astype(int)
    )

    # Auto-pay indicator
    autopay_methods = {"bank transfer (automatic)", "credit card (automatic)"}
    df["is_auto_pay"] = _norm(df["PaymentMethod"]).isin(autopay_methods).astype(int)

    # Long contract indicator (one or two year contracts)
    df["is_long_contract"] = df["Contract"].str.contains("year", case=False, na=False).astype(int)

    # Streaming / support service counts
    df["streaming_services"] = _count_yes(["StreamingTV", "StreamingMovies"])
    df["support_services"] = _count_yes(
        ["OnlineSecurity", "OnlineBackup", "DeviceProtection", "TechSupport"]
    )

    # Senior citizens on fiber optic internet (known high-risk segment)
    df["senior_fiber_optic"] = (
        (df["SeniorCitizen"] == 1) & internet.eq("fiber optic")
    ).astype(int)

    # Charges per tenure month (helps capture early high spenders)
    df["charges_per_month_of_tenure"] = df["TotalCharges"] / df["tenure"].replace(0, 1)

    # Target mapping
    df["ChurnFlag"] = _norm(df["Churn"]).eq("yes").astype(int)

    return df

def compute_expected_tenure_and_clv(df: pd.DataFrame) -> pd.DataFrame:
    """
    Add ExpectedTenure, CLV and high_charge_to_clv_ratio to a copy of `df`.

    ExpectedTenure is a documented baseline assumption per contract type:
      Month-to-month -> 6 months, One year -> 12 months, Two year -> 24 months
    (unknown contract labels fall back to 6). CLV = MonthlyCharges * ExpectedTenure.
    The ratio flag is 1 when MonthlyCharges / CLV is at or above its median.
    """
    enriched = df.copy()

    months_by_contract = {"month-to-month": 6, "one year": 12, "two year": 24}
    contract_key = enriched["Contract"].str.strip().str.lower()
    # Unmapped (unexpected) labels come back as NaN and get the shortest horizon.
    enriched["ExpectedTenure"] = contract_key.map(months_by_contract).fillna(6)

    enriched["CLV"] = enriched["MonthlyCharges"] * enriched["ExpectedTenure"]

    # A zero CLV is swapped for 1 so the division below never hits zero.
    safe_clv = enriched["CLV"].replace(0, 1)
    charge_share = enriched["MonthlyCharges"] / safe_clv
    at_or_above_median = charge_share >= charge_share.median()
    enriched["high_charge_to_clv_ratio"] = at_or_above_median.astype(int)

    return enriched

def select_model_columns(df: pd.DataFrame):
    """
    Split `df` into a feature frame X and the ChurnFlag target y.

    Categorical columns stay as raw values here; one-hot encoding happens
    later inside the model pipeline. Feature columns that are absent from
    `df` are skipped silently, and both return values are copies.
    """
    numeric_features = (
        "tenure", "MonthlyCharges", "TotalCharges", "CLV", "services_count",
        "monthly_to_total_ratio", "internet_no_tech_support",
        "tenure_bucket_ord", "is_auto_pay", "is_long_contract",
        "streaming_services", "support_services", "senior_fiber_optic",
        "charges_per_month_of_tenure", "high_charge_to_clv_ratio",
    )
    categorical_features = (
        "gender", "SeniorCitizen", "Partner", "Dependents",
        "PhoneService", "MultipleLines", "InternetService",
        "OnlineSecurity", "OnlineBackup", "DeviceProtection",
        "TechSupport", "StreamingTV", "StreamingMovies",
        "Contract", "PaperlessBilling", "PaymentMethod",
        "tenure_bucket",
    )

    # Defensive pruning: keep the declared order, drop what the frame lacks.
    available = set(df.columns)
    selected = [
        name
        for name in (*numeric_features, *categorical_features)
        if name in available
    ]

    features = df[selected].copy()
    labels = df["ChurnFlag"].copy()
    return features, labels

def stratified_splits_and_save(X: pd.DataFrame, y: pd.Series, out_dir: str = PROCESSED_DIR, seed: int = 42):
    """
    Split X/y 60/20/20 (train/val/test), stratified on y, and write one CSV per
    split into `out_dir` (created if missing). Each CSV holds the features plus
    a trailing ChurnFlag column. Returns nothing; prints the written paths.
    """
    os.makedirs(out_dir, exist_ok=True)

    # First carve off 40%, then halve that remainder into validation and test.
    X_train, X_rest, y_train, y_rest = train_test_split(
        X, y, test_size=0.4, random_state=seed, stratify=y
    )
    X_val, X_test, y_val, y_test = train_test_split(
        X_rest, y_rest, test_size=0.5, random_state=seed, stratify=y_rest
    )

    partitions = {
        "train": (X_train, y_train),
        "val": (X_val, y_val),
        "test": (X_test, y_test),
    }

    written = []
    for split_name, (features, target) in partitions.items():
        destination = os.path.join(out_dir, f"{split_name}.csv")
        # .values sidesteps index alignment: the labels are attached positionally.
        features.assign(ChurnFlag=target.values).to_csv(destination, index=False)
        written.append(destination)

    for prefix, destination in zip(("Saved: ", "       ", "       "), written):
        print(f"{prefix}{destination}")

def main():
    """Run the data-prep workflow end to end: load, clean, engineer, add CLV, split and save."""
    print("Loading raw data…")
    frame = load_raw_data()
    print(f"Rows: {len(frame):,}")

    # Each stage takes a DataFrame and returns a new one; run them in order.
    stages = (
        ("Cleaning TotalCharges…", clean_total_charges),
        ("Engineering features…", engineer_features),
        ("Computing ExpectedTenure + CLV…", compute_expected_tenure_and_clv),
    )
    for banner, stage in stages:
        print(banner)
        frame = stage(frame)

    print("Preparing splits…")
    features, target = select_model_columns(frame)
    stratified_splits_and_save(features, target)

    print(textwrap.dedent("""
    ✅ Data prep complete.
       - 60/20/20 splits saved to data/processed/
       - Features include: tenure_bucket, services_count, monthly_to_total_ratio, internet_no_tech_support, CLV
       Next: CLV quartiles + churn rate by quartile + charts.
    """))

# if __name__ == '__main__':
#     main()


In [None]:
# Keep a handle on the data-prep entry point: a later cell defines another `main`.
data_prep_main = main

## Step 2 — CLV Analysis Code


In [None]:
# --- clv_analysis.py ---
# src/clv_analysis.py
from __future__ import annotations
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

PROCESSED_DIR = "data/processed"  # where the train/val/test CSV splits are read from
FIG_DIR = "figures"  # output folder for the saved plots
os.makedirs(FIG_DIR, exist_ok=True)  # side effect on cell execution: creates the folder if missing

# -----------------------------
# Helper: load the processed splits
# -----------------------------
def load_splits(processed_dir: str = PROCESSED_DIR):
    """
    Read train.csv, val.csv and test.csv from `processed_dir`.

    Returns the three frames as a (train, val, test) tuple. The analysis itself
    only uses TRAIN, so nothing in the narrative is tuned against val/test.
    """
    train, val, test = (
        pd.read_csv(os.path.join(processed_dir, f"{split}.csv"))
        for split in ("train", "val", "test")
    )
    return train, val, test

# -----------------------------
# Helper: cut CLV into quartiles
# -----------------------------
def add_clv_quartile(df: pd.DataFrame, clv_col: str = "CLV") -> pd.DataFrame:
    """
    Return a copy of `df` with an ordered categorical column CLV_quartile.

    Rows are ranked into quartiles of `clv_col` and labelled in ascending order
    (Low < Medium < High < Premium). When heavy ties make some quartile edges
    coincide, the duplicate edges are dropped and only the first labels are
    used (e.g. Low/Medium for two bins). Passing the four fixed labels straight
    to `pd.qcut` would raise in that case, because the label count no longer
    matches the bin count. Rows with a missing CLV get a missing quartile.
    """
    out = df.copy()
    quartile_names = ["Low", "Medium", "High", "Premium"]

    # Ask for integer bin codes first so a reduced bin count cannot clash with labels.
    codes, edges = pd.qcut(
        out[clv_col], q=4, labels=False, retbins=True, duplicates="drop"
    )
    n_bins = max(len(edges) - 1, 0)

    # -1 is the categorical code for "missing".
    int_codes = codes.fillna(-1).astype(int).to_numpy()
    out["CLV_quartile"] = pd.Categorical.from_codes(
        int_codes, categories=quartile_names[:n_bins], ordered=True
    )
    return out

# -----------------------------
# Helper: churn rate by quartile
# -----------------------------
def churn_rate_by_quartile(df: pd.DataFrame, y_col: str = "ChurnFlag"):
    """
    Summarise churn per CLV quartile.

    Returns a DataFrame with one row per quartile that actually occurs in `df`
    and the columns CLV_quartile, size (row count) and churn_rate (mean of
    `y_col`, which is the churn share because the flag is 0/1).

    `observed=True` is passed explicitly: the default for categorical group
    keys differs between pandas versions, and unobserved quartiles would
    otherwise show up as size-0 / NaN rows.
    """
    summary = (
        df.groupby("CLV_quartile", observed=True)[y_col]
          .agg(size="count", churn_rate="mean")
          .reset_index()
          .sort_values("CLV_quartile")  # categorical keys sort Low→Premium
    )
    return summary

# -----------------------------
# Helper: bar plot
# -----------------------------
def plot_churn_rate(summary: pd.DataFrame, savepath: str):
    """
    Draw one bar per CLV quartile (height = churn rate in percent), annotate
    each bar with its value, and save the figure to `savepath` at 200 dpi.
    Plain Matplotlib with default colours; the figure is closed afterwards.
    """
    plt.figure(figsize=(6, 4))

    quartile_names = summary["CLV_quartile"].astype(str)
    churn_pct = summary["churn_rate"] * 100  # fraction -> percentage

    plt.bar(quartile_names, churn_pct)
    plt.title("Churn Rate by CLV Quartile")
    plt.xlabel("CLV Quartile")
    plt.ylabel("Churn Rate (%)")

    # Value label slightly above the top of each bar.
    for quartile, pct in zip(quartile_names, churn_pct):
        plt.text(quartile, pct + 0.5, f"{pct:.1f}%", ha="center", va="bottom", fontsize=9)

    plt.tight_layout()
    plt.savefig(savepath, dpi=200)
    plt.close()

# -----------------------------
# Helper: print narrative insights
# -----------------------------
def print_insights(summary: pd.DataFrame):
    """
    Print a few business-friendly bullets derived from the quartile summary.

    `summary` needs the columns CLV_quartile, size and churn_rate (a fraction
    in [0, 1]). Output: the Premium-vs-Low gap in percentage points (only when
    both quartiles exist), the highest-churn quartile, the two quartiles to
    prioritise, and size context per quartile. An empty summary prints a short
    notice instead of failing on the arg-max lookup.
    """
    print("\n===== Business Insights (auto-generated) =====")

    if summary.empty:
        print("- No quartile data available.")
        print("=============================================\n")
        return

    # Quick lookup by quartile label
    sdict = {row["CLV_quartile"]: row for _, row in summary.iterrows()}

    # 1) Overall ordering insight (Low vs Premium if both exist)
    if "Low" in sdict and "Premium" in sdict:
        low = sdict["Low"]["churn_rate"] * 100
        prem = sdict["Premium"]["churn_rate"] * 100
        diff = prem - low
        trend = "higher" if diff > 0 else "lower"
        print(f"- Premium vs Low: Premium churn is {abs(diff):.1f} pp {trend} than Low "
              f"({prem:.1f}% vs {low:.1f}%).")

    # 2) Highest-risk quartile
    worst = summary.iloc[summary["churn_rate"].argmax()]
    print(f"- Highest churn segment: {worst['CLV_quartile']} at {worst['churn_rate']*100:.1f}% churn.")

    # 3) Prioritization hint (top two)
    top2 = summary.sort_values("churn_rate", ascending=False).head(2)
    segs = ", ".join(f"{r['CLV_quartile']} ({r['churn_rate']*100:.1f}%)" for _, r in top2.iterrows())
    print(f"- Retention priority suggestion: focus on {segs}.")

    # 4) Size context
    total = summary["size"].sum()
    for _, r in summary.iterrows():
        pct = r["size"] / total * 100 if total else 0
        print(f"  • {r['CLV_quartile']}: {r['size']} customers ({pct:.1f}% of train), churn {r['churn_rate']*100:.1f}%")
    print("=============================================\n")

# -----------------------------
# Main
# -----------------------------
def main():
    """
    CLV analysis entry point: load the splits, bucket TRAIN into CLV quartiles,
    summarise churn per quartile, then persist the labelled frame, the bar
    chart and the printed insights.
    """
    print("Loading processed splits…")
    train, _val, _test = load_splits()

    # Fail early with a helpful hint if data prep has not produced these columns.
    required = {"CLV", "ChurnFlag"}
    missing = required.difference(train.columns)
    if missing:
        raise ValueError(f"Missing columns in train split: {missing}. "
                         f"Did you run src/data_prep.py successfully?")

    print("Adding CLV quartiles to TRAIN…")
    labelled = add_clv_quartile(train, clv_col="CLV")

    print("Computing churn rate by quartile…")
    summary = churn_rate_by_quartile(labelled, y_col="ChurnFlag")
    print(summary)

    # The quartile-labelled split is written out too (could be useful for the app).
    labelled_path = os.path.join(PROCESSED_DIR, "train_with_clv_quartile.csv")
    labelled.to_csv(labelled_path, index=False)
    print(f"Saved: {labelled_path}")

    chart_path = os.path.join(FIG_DIR, "churn_rate_by_clv_quartile.png")
    plot_churn_rate(summary, chart_path)
    print(f"Saved figure: {chart_path}")

    print_insights(summary)

# if __name__ == '__main__':
#     main()


In [None]:
# Alias the CLV-analysis entry point so it stays reachable if `main` is rebound later.
clv_main = main

## Step 3 — Model Training Code


In [None]:
# --- model_train.py ---
# src/model_train.py
# Train LogReg, RandomForest, and XGBoost with light tuning + imbalance care.
# Select a recall-first operating threshold on validation, then freeze for test.
# Saves:
#   models/logreg_pipeline.pkl              (best model = usually LOGREG here)
#   models/logreg_metrics.json              (chosen/best model metrics)
#   models/metrics_all.json                 (per-model val/test)
#   models/roc_curves_test.json             (for the ROC overlay in the app)

import inspect
import os
import json
import warnings
from dataclasses import dataclass
from typing import Dict, Tuple

import joblib
import numpy as np
import pandas as pd
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.pipeline import Pipeline
from sklearn.metrics import (
    precision_recall_fscore_support,
    roc_auc_score,
    roc_curve,
    confusion_matrix,
)
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier

# Optional: XGBoost (skip gracefully if not present)
try:
    from xgboost import XGBClassifier
    _HAS_XGB = True
except Exception as e:
    # Any import-time failure (not only ImportError) disables the XGBoost branch.
    _HAS_XGB = False
    _XGB_ERR = str(e)  # only defined on failure; shown when the XGBoost step is skipped

# NOTE: this silences every warning for the whole process, not just library noise.
warnings.filterwarnings("ignore")

MODELS_DIR = "models"  # output folder for pickled pipelines and metric JSON files
DATA_DIR = "data/processed"  # input folder holding train/val/test CSVs
os.makedirs(MODELS_DIR, exist_ok=True)

VAL_RECALL_MIN = 0.60  # recall-first policy on validation
THRESH_GRID = np.round(np.linspace(0.30, 0.80, 51), 3)  # 0.30 → 0.80 (step ~0.01)
RANDOM_STATE = 42  # shared seed for all model factories


# ------------------------- Data Loading -------------------------
def _load_splits() -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
    """
    Load train/val/test CSVs from DATA_DIR and return them in that order.

    Raises FileNotFoundError (with a hint to run data prep) if any of the
    three files is missing; nothing is read in that case.
    """
    split_paths = [
        os.path.join(DATA_DIR, f"{split}.csv") for split in ("train", "val", "test")
    ]
    if any(not os.path.exists(path) for path in split_paths):
        raise FileNotFoundError(
            f"Processed splits not found in {DATA_DIR}. Run: python -m src.data_prep"
        )
    train, val, test = (pd.read_csv(path) for path in split_paths)
    return train, val, test


# ------------------------- Features -------------------------
def _feature_lists(df: pd.DataFrame) -> Tuple[list, list]:
    """
    Return (numeric_columns, categorical_columns) to feed the preprocessor.

    Both lists follow a fixed candidate order and contain only the candidates
    that are actually present in `df`.
    """
    # Numeric columns: raw charges/tenure plus the engineered features.
    numeric_candidates = (
        "tenure",
        "MonthlyCharges",
        "TotalCharges",
        "CLV",
        "services_count",
        "monthly_to_total_ratio",
        "internet_no_tech_support",
        "tenure_bucket_ord",
        "is_auto_pay",
        "is_long_contract",
        "streaming_services",
        "support_services",
        "senior_fiber_optic",
        "charges_per_month_of_tenure",
        "high_charge_to_clv_ratio",
    )
    # Categorical columns; SeniorCitizen is a 0/1 flag treated as a category.
    categorical_candidates = (
        "gender",
        "SeniorCitizen",
        "Partner",
        "Dependents",
        "PhoneService",
        "MultipleLines",
        "InternetService",
        "OnlineSecurity",
        "OnlineBackup",
        "DeviceProtection",
        "TechSupport",
        "StreamingTV",
        "StreamingMovies",
        "Contract",
        "PaperlessBilling",
        "PaymentMethod",
        "tenure_bucket",
    )

    def _present(candidates):
        # Keep only the candidates that exist as columns, in declared order.
        return [name for name in candidates if name in df.columns]

    return _present(numeric_candidates), _present(categorical_candidates)


def _preprocessor(num_cols: list, cat_cols: list) -> ColumnTransformer:
    """
    Build the (unfitted) preprocessing step shared by all models.

    Numeric columns are standardised, categorical columns are one-hot encoded
    (categories unseen at fit time are ignored at transform time), and every
    other column is dropped.
    """
    # scikit-learn 1.2 introduced `sparse_output` and 1.4 removed the old `sparse`
    # keyword. Inspect the public constructor signature to pick the one this
    # installation accepts, rather than reading `__code__.co_varnames`, which also
    # lists local variables and breaks if `__init__` is ever wrapped.
    ohe_params = inspect.signature(OneHotEncoder).parameters
    sparse_kw = "sparse_output" if "sparse_output" in ohe_params else "sparse"
    cat_encoder = OneHotEncoder(handle_unknown="ignore", **{sparse_kw: True})

    num_pipe = Pipeline([("scaler", StandardScaler())])
    cat_pipe = Pipeline([("ohe", cat_encoder)])
    pre = ColumnTransformer(
        transformers=[
            ("num", num_pipe, num_cols),
            ("cat", cat_pipe, cat_cols),
        ],
        remainder="drop",
        sparse_threshold=0.3,
    )
    return pre


# ------------------------- Models -------------------------
def make_logreg() -> LogisticRegression:
    """Return an unfitted L2 logistic regression; balanced class weights offset the churn minority."""
    settings = {
        "penalty": "l2",
        "C": 1.0,
        "solver": "lbfgs",
        "max_iter": 1000,
        "class_weight": "balanced",
        "n_jobs": None,
        "random_state": RANDOM_STATE,
    }
    return LogisticRegression(**settings)


def make_rf() -> RandomForestClassifier:
    """
    Return an unfitted, lightly tuned random forest. Class weights are
    recomputed per bootstrap sample (balanced_subsample) to handle imbalance.
    """
    settings = {
        "n_estimators": 600,
        "max_depth": 10,
        "min_samples_leaf": 20,
        "class_weight": "balanced_subsample",
        "n_jobs": -1,
        "random_state": RANDOM_STATE,
    }
    return RandomForestClassifier(**settings)


def make_xgb(scale_pos_weight: float):
    """
    Return an unfitted XGBClassifier with robust default hyper-parameters.

    `scale_pos_weight` is forwarded unchanged (the caller passes the
    negative/positive class ratio). Early stopping is configured by the
    caller at fit time, not here.
    """
    booster_settings = {
        "n_estimators": 800,
        "learning_rate": 0.05,
        "max_depth": 5,
        "min_child_weight": 4,
        "subsample": 0.8,
        "colsample_bytree": 0.8,
        "reg_lambda": 2.0,
        "gamma": 0.0,
        "objective": "binary:logistic",
        "eval_metric": "auc",
        "scale_pos_weight": scale_pos_weight,
        "random_state": RANDOM_STATE,
        "tree_method": "hist",
        "n_jobs": -1,
        "verbosity": 0,
    }
    return XGBClassifier(**booster_settings)


# ------------------------- Metrics & Thresholds -------------------------
@dataclass
class EvalResult:
    """Container for one model's evaluation output.

    NOTE(review): nothing in the visible code instantiates this class; the
    field meanings below are inferred from their names — confirm before use.
    """
    metrics_val: Dict  # presumably the metrics dict computed on the validation split
    metrics_test: Dict  # presumably the metrics dict computed on the test split
    thr: float  # presumably the decision threshold chosen on validation
    roc_test: Dict[str, list]  # presumably ROC points, e.g. {"fpr": [...], "tpr": [...]}


def _bin_metrics(y_true, y_prob, thr) -> Dict:
    """
    Binary classification metrics at decision threshold `thr`.

    Probabilities >= `thr` are predicted as class 1. Returns a JSON-friendly
    dict with precision, recall, f1, auc, the four confusion-matrix counts
    (tn/fp/fn/tp) and the threshold that was used.
    """
    y_pred = (np.asarray(y_prob) >= thr).astype(int)
    p, r, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    auc = roc_auc_score(y_true, y_prob)
    # Pin the label set so the matrix is always 2x2; without it a split where only
    # one class appears yields a 1x1 matrix and the 4-way unpack below fails.
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred, labels=[0, 1]).ravel()
    return {
        "precision": float(p),
        "recall": float(r),
        "f1": float(f1),
        "auc": float(auc),
        "tn": int(tn), "fp": int(fp), "fn": int(fn), "tp": int(tp),
        "used_threshold": float(thr),
    }


def _choose_threshold_recall_first(y_true, y_prob, recall_min=VAL_RECALL_MIN):
    """
    Pick a decision threshold from THRESH_GRID with a recall-first policy.

    Among thresholds whose recall is at least `recall_min`, the one with the
    highest F1 wins (earliest grid entry on ties). If none reaches the recall
    target, the best-F1 threshold over the whole grid is used instead.
    Returns (threshold, metrics_dict); (0.50, None) if the grid is empty.
    """
    scored = [(thr, _bin_metrics(y_true, y_prob, thr)) for thr in THRESH_GRID]

    meets_recall = [pair for pair in scored if pair[1]["recall"] >= recall_min]
    # Fall back to the full grid when no threshold reaches the target recall.
    pool = meets_recall if meets_recall else scored

    # max() keeps the first maximal element, so ties go to the lowest grid position.
    winner = max(pool, key=lambda pair: pair[1]["f1"], default=None)
    if winner is None:
        return 0.50, None
    return winner


def _roc_points(y_true, y_prob) -> Tuple[list, list]:
    """Return the ROC curve as two plain-float lists (fpr, tpr) so it can be dumped to JSON."""
    false_pos_rate, true_pos_rate, _thresholds = roc_curve(y_true, y_prob)
    fpr_list = [float(value) for value in false_pos_rate]
    tpr_list = [float(value) for value in true_pos_rate]
    return fpr_list, tpr_list


# ------------------------- Train/Eval -------------------------
def train_and_eval():
    """
    Train LogReg, RandomForest and (if installed) XGBoost, then evaluate and persist.

    For each model: fit on train, choose a recall-first threshold on validation,
    freeze it, and score the test split. Files written under MODELS_DIR:
      logreg_baseline_pipeline.pkl, rf_pipeline.pkl, xgb_pipeline.pkl (if available),
      logreg_pipeline.pkl   -> the WINNING pipeline (name kept because the app expects it),
      logreg_metrics.json   -> metrics of the winning model,
      metrics_all.json      -> per-model val/test metrics,
      roc_curves_test.json  -> per-model test ROC points.

    Differences from the earlier version of this function:
      * every pipeline gets its own clone of the preprocessor instead of sharing
        (and repeatedly refitting) one mutable instance;
      * model selection honours the recall constraint: models that reach
        VAL_RECALL_MIN on validation always rank above those that do not;
      * early stopping is set on the estimator when this xgboost version
        supports it, because xgboost >= 2 no longer accepts it in `fit`.
    """
    # Function-scope import: `clone` is only needed here.
    from sklearn.base import clone

    train, val, test = _load_splits()
    y_tr = train["ChurnFlag"].astype(int).values
    y_va = val["ChurnFlag"].astype(int).values
    y_te = test["ChurnFlag"].astype(int).values

    X_tr = train.drop(columns=["ChurnFlag"])
    X_va = val.drop(columns=["ChurnFlag"])
    X_te = test.drop(columns=["ChurnFlag"])

    num_cols, cat_cols = _feature_lists(train)
    pre = _preprocessor(num_cols, cat_cols)  # unfitted template; cloned per pipeline

    # Pos/neg counts for XGB scale_pos_weight
    pos = int(y_tr.sum())
    neg = int((y_tr == 0).sum())
    spw = neg / max(1, pos)

    results = {}
    roc_dict = {}
    pipes = {}

    def _evaluate(name, pipe):
        # Threshold is chosen on validation only, then frozen for the test metrics.
        prob_va = pipe.predict_proba(X_va)[:, 1]
        thr, met_val = _choose_threshold_recall_first(y_va, prob_va, VAL_RECALL_MIN)
        prob_te = pipe.predict_proba(X_te)[:, 1]
        met_test = _bin_metrics(y_te, prob_te, thr)
        fpr, tpr = _roc_points(y_te, prob_te)
        roc_dict[name] = {"fpr": fpr, "tpr": tpr}
        results[name] = {"val": met_val, "test": met_test, "thr": float(thr)}
        pipes[name] = pipe

    # ---------- LOGREG ----------
    pipe_logreg = Pipeline(steps=[("pre", clone(pre)), ("model", make_logreg())])
    pipe_logreg.fit(X_tr, y_tr)
    joblib.dump(pipe_logreg, os.path.join(MODELS_DIR, "logreg_baseline_pipeline.pkl"))
    _evaluate("logreg", pipe_logreg)

    # ---------- RANDOM FOREST ----------
    pipe_rf = Pipeline(steps=[("pre", clone(pre)), ("model", make_rf())])
    pipe_rf.fit(X_tr, y_tr)
    _evaluate("rf", pipe_rf)
    joblib.dump(pipe_rf, os.path.join(MODELS_DIR, "rf_pipeline.pkl"))

    # ---------- XGBOOST (optional) ----------
    if _HAS_XGB:
        xgb = make_xgb(spw)
        # Early stopping needs the validation matrix in model space, so the
        # preprocessor is fitted explicitly and the booster is fitted outside a Pipeline.
        pre_xgb = clone(pre)
        X_tr_pre = pre_xgb.fit_transform(X_tr)
        X_va_pre = pre_xgb.transform(X_va)

        fit_kwargs = {
            "X": X_tr_pre,
            "y": y_tr,
            "eval_set": [(X_va_pre, y_va)],
            "verbose": False,
        }
        if "early_stopping_rounds" in xgb.get_params():
            # Estimator-level setting; the only way to get early stopping once
            # `fit` stops accepting it. NOTE: refitting this estimator later
            # also requires an eval_set.
            xgb.set_params(early_stopping_rounds=50)
            xgb.fit(**fit_kwargs)
        else:
            # Older xgboost: pass early stopping through `fit`.
            fit_sig = inspect.signature(xgb.fit)
            try:
                if "callbacks" in fit_sig.parameters:
                    try:
                        from xgboost.callback import EarlyStopping
                        callbacks = [EarlyStopping(rounds=50, save_best=True)]
                        xgb.fit(**fit_kwargs, callbacks=callbacks)
                    except Exception:
                        # Best effort: train without early stopping rather than fail.
                        xgb.fit(**fit_kwargs)
                elif "early_stopping_rounds" in fit_sig.parameters:
                    xgb.fit(**fit_kwargs, early_stopping_rounds=50)
                else:
                    xgb.fit(**fit_kwargs)
            except TypeError:
                xgb.fit(**fit_kwargs)

        # Wrap the already-fitted parts so the app gets the same Pipeline interface.
        pipe_xgb = Pipeline(steps=[("pre", pre_xgb), ("model", xgb)])
        _evaluate("xgb", pipe_xgb)
        joblib.dump(pipe_xgb, os.path.join(MODELS_DIR, "xgb_pipeline.pkl"))
    else:
        print(f"⚠️ XGBoost not available, skipping. Error: {_XGB_ERR}")

    # ---------------- Choose best model by validation F1 under recall constraint ----------------
    # Rank first on "met the recall target", then on F1, so a model that only got
    # its threshold through the best-F1 fallback cannot beat one that met the target.
    def key_fn(k):
        val_metrics = results[k]["val"]
        return (val_metrics["recall"] >= VAL_RECALL_MIN, val_metrics["f1"])

    best_name = max(results, key=key_fn)
    best_thr = results[best_name]["thr"]
    best_pipe = pipes[best_name]

    # The app loads "logreg_pipeline.pkl"; whichever model won is stored under that name.
    joblib.dump(best_pipe, os.path.join(MODELS_DIR, "logreg_pipeline.pkl"))

    with open(os.path.join(MODELS_DIR, "logreg_metrics.json"), "w") as f:
        json.dump(
            {
                "val": results[best_name]["val"],
                "test": results[best_name]["test"],
                "used_threshold": best_thr,
                "model_name": best_name,
            },
            f,
            indent=2,
        )

    # Persist per-model metrics for the app comparison table
    with open(os.path.join(MODELS_DIR, "metrics_all.json"), "w") as f:
        json.dump(results, f, indent=2)

    # Persist ROC curves for the app overlay
    with open(os.path.join(MODELS_DIR, "roc_curves_test.json"), "w") as f:
        json.dump(roc_dict, f, indent=2)

    # Convenience prints
    print(f"\n=== BEST MODEL: {best_name.upper()} (test at thr={best_thr:.2f}) ===")
    print(results[best_name]["test"])


# Unlike the earlier cells (guards commented out), this guard is live: training
# runs whenever this code executes with __name__ == "__main__".
if __name__ == "__main__":
    train_and_eval()


## Step 4 — Global Interpretability Code


In [None]:
# --- interpretability.py ---
# src/interpretability.py
import os, json, joblib, numpy as np, pandas as pd, matplotlib.pyplot as plt
from pathlib import Path

# Output folder for figures; created on cell execution if it does not exist.
FIG_DIR = Path("figures"); FIG_DIR.mkdir(parents=True, exist_ok=True)
MODELS_DIR = Path("models")  # where the pickled pipelines are looked up
DATA_DIR = Path("data/processed")  # where the train/val/test CSVs are read from

def _load_splits():
    """Read train.csv, val.csv and test.csv from DATA_DIR; return them as (train, val, test)."""
    train, val, test = (
        pd.read_csv(DATA_DIR / f"{split}.csv") for split in ("train", "val", "test")
    )
    return train, val, test

def _safe_load(path):
    """Return the joblib-loaded object stored at `path`, or None when the file does not exist."""
    if not path.exists():
        return None
    return joblib.load(path)

def logreg_importance():
    """
    Plot the 20 most important features of the pipeline stored as
    models/logreg_pipeline.pkl and save the chart to figures/.

    Linear models are ranked by |coef × std| (std taken over the preprocessed
    training matrix); models exposing `feature_importances_` use those values.
    Prints a notice and returns without plotting if the pipeline file is
    missing or the model offers neither attribute.
    """
    pipe = _safe_load(MODELS_DIR/"logreg_pipeline.pkl")
    if pipe is None:
        print("LogReg pipeline missing; skip.")
        return
    pre = pipe.named_steps["pre"]
    model = pipe.named_steps["model"]

    # Preprocessed training matrix (densified) gives the per-feature spread.
    train, _, _ = _load_splits()
    train_features = train.drop(columns=["ChurnFlag"])
    train_matrix = pre.transform(train_features)
    if hasattr(train_matrix, "toarray"):
        train_matrix = train_matrix.toarray()
    else:
        train_matrix = np.asarray(train_matrix)
    feat_names = pre.get_feature_names_out()

    if hasattr(model, "coef_"):
        spread = train_matrix.std(axis=0)
        weights = model.coef_.ravel()
        importance = np.abs(weights * spread)
        title = "LogReg global importance (|coef × std|)"
    elif hasattr(model, "feature_importances_"):
        importance = model.feature_importances_
        title = f"{model.__class__.__name__} feature importances"
    else:
        print(f"{model.__class__.__name__} lacks coefficients/feature_importances; skipping global plot.")
        return

    # Indices of the 20 largest importances, biggest first.
    ranking = np.argsort(importance)[::-1][:20]
    top_feats = feat_names[ranking]
    top_vals = importance[ranking]

    # Reverse positions and values together so the largest bar ends up on top.
    positions = range(len(top_vals))[::-1]
    plt.figure()
    plt.barh(positions, top_vals[::-1])
    plt.yticks(positions, top_feats[::-1])
    plt.title(title)
    plt.tight_layout()
    out = FIG_DIR/"logreg_global_importance.png"
    plt.savefig(out, dpi=200)
    plt.close()
    print(f"Saved {out}")

def shap_trees(which="rf"):
    """
    Save a global SHAP bar plot for the tree pipeline models/<which>_pipeline.pkl.

    Explains up to 300 sampled test rows (fixed seed) in the preprocessed
    feature space and writes figures/<which>_shap_global_bar.png. Prints a
    notice and returns if SHAP is not installed or the pipeline file is missing.

    SHAP versions differ in what `shap_values` returns for classifiers: a list
    with one array per class, or a single 3-D array (samples × features ×
    classes). Both are reduced to the last class before plotting.
    """
    try:
        import shap
    except Exception as e:
        print("SHAP not available; skip trees.", e); return
    model_path = MODELS_DIR/(f"{which}_pipeline.pkl")
    pipe = _safe_load(model_path)
    if pipe is None:
        print(f"{which.upper()} pipeline missing; skip."); return

    pre, model = pipe.named_steps["pre"], pipe.named_steps["model"]
    _, _, test = _load_splits()
    X = test.drop(columns=["ChurnFlag"])
    # sample to keep it snappy
    Xs = X.sample(min(300, len(X)), random_state=42)
    Xs_pre = pre.transform(Xs)
    # Densify: the preprocessor may return a sparse matrix.
    if hasattr(Xs_pre, "toarray"):
        Xs_pre = Xs_pre.toarray()
    else:
        Xs_pre = np.asarray(Xs_pre)
    explainer = shap.TreeExplainer(model)
    shap_values = explainer.shap_values(Xs_pre)
    if isinstance(shap_values, list) and len(shap_values) > 1:
        # One array per class: keep the last class.
        shap_values_use = shap_values[-1]
    elif isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
        # Single 3-D array: the class axis is last; keep the last class.
        shap_values_use = shap_values[:, :, -1]
    else:
        shap_values_use = shap_values
    # Global bar plot
    feat_names = pre.get_feature_names_out()
    plt.figure()
    shap.summary_plot(shap_values_use, features=Xs_pre, feature_names=feat_names, plot_type="bar", show=False)
    out = FIG_DIR/f"{which}_shap_global_bar.png"
    plt.tight_layout(); plt.savefig(out, dpi=200); plt.close()
    print(f"Saved {out}")

# Live guard: generates the global-importance figure and both SHAP plots when
# this code executes with __name__ == "__main__". Each call skips itself if its
# pipeline file is missing.
if __name__ == "__main__":
    logreg_importance()
    shap_trees("rf")
    shap_trees("xgb")


## Step 5 — Local Interpretability Code


In [None]:
# --- interpret.py ---
# src/interpret.py
from __future__ import annotations
import os
import math
import joblib
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from typing import List

MODELS_DIR = "models"  # folder holding the pickled pipelines and the CSV outputs
FIG_DIR = "figures"  # folder for generated plots
os.makedirs(FIG_DIR, exist_ok=True)  # side effect on cell execution: creates the folder if missing

def get_feature_names(preprocessor) -> List[str]:
    """
    Recover the final feature names after ColumnTransformer + OneHotEncoder.

    Uses `preprocessor.get_feature_names_out()` when available (sklearn>=1.0).
    If that call fails for any reason, the names are rebuilt from the fitted
    `transformers_` list: pipelines with an "ohe" step contribute the encoder's
    expanded names, every other transformer contributes its input column names,
    and transformers set to "drop" contribute nothing.
    """
    try:
        names = preprocessor.get_feature_names_out()
        return [str(name) for name in names]
    except Exception:
        # Deliberate best effort: older versions or unsupported transformers
        # fall through to the manual reconstruction below.
        pass

    names = []
    for _name, trans, cols in preprocessor.transformers_:
        # Dropped transformers (including a dropped remainder) emit no columns.
        if isinstance(trans, str) and trans == "drop":
            continue
        if hasattr(trans, "named_steps") and "ohe" in trans.named_steps:
            # categorical pipeline: the encoder knows the expanded names
            ohe = trans.named_steps["ohe"]
            names.extend(str(n) for n in ohe.get_feature_names_out(cols))
        else:
            # numeric pipeline: names are just the input columns
            names.extend(cols if isinstance(cols, list) else list(cols))
    return names

def main():
    """
    Export feature-level drivers for the saved logistic-regression pipeline.

    Steps:
      1) Load ``logreg_pipeline.pkl`` from MODELS_DIR and take its "pre" and "model" steps.
      2) Read the model's ``coef_`` (preferred) or ``feature_importances_``.
      3) Pair each value with its post-preprocessing feature name; for coefficients,
         also store exp(coef) as an odds ratio.
      4) Write the full table plus the 20 highest / 20 lowest rows as CSVs in MODELS_DIR.
      5) Save two horizontal bar charts in FIG_DIR.

    Raises
    ------
    ValueError
        If the model exposes neither ``coef_`` nor ``feature_importances_``, or if
        the number of values differs from the number of feature names.
    """
    # Load best pipeline (we saved logreg as best).
    # NOTE: joblib.load unpickles the file, so only point this at artifacts you trust.
    model_path = os.path.join(MODELS_DIR, "logreg_pipeline.pkl")
    pipe = joblib.load(model_path)

    pre = pipe.named_steps["pre"]
    model = pipe.named_steps["model"]

    # We handle linear models via coefficients; otherwise fall back to feature_importances.
    use_coeffs = hasattr(model, "coef_")
    use_importance = hasattr(model, "feature_importances_")

    if not use_coeffs and not use_importance:
        raise ValueError(
            f"{model.__class__.__name__} lacks coefficients or feature_importances_. "
            "Cannot generate interpretability artifacts."
        )

    # Get expanded feature names after preprocessing
    feat_names = get_feature_names(pre)

    # Flatten to 1-D so the values line up one-to-one with feat_names.
    if use_coeffs:
        values = model.coef_.ravel()
    else:
        values = model.feature_importances_.ravel()
    if len(values) != len(feat_names):
        raise ValueError(f"Mismatch: {len(values)} weights vs {len(feat_names)} feature names")

    # Build a tidy table
    df_imp = pd.DataFrame({
        "feature": feat_names,
        "value": values,
    })
    if use_coeffs:
        df_imp["odds_ratio"] = np.exp(df_imp["value"])
    else:
        df_imp["odds_ratio"] = np.nan  # not meaningful for tree models
    df_imp["abs_value"] = df_imp["value"].abs()

    # Sort for top drivers (signed value, so "negative" means the lowest values)
    sort_col = "value"
    top_positive = df_imp.sort_values(sort_col, ascending=False).head(20)
    top_negative = df_imp.sort_values(sort_col, ascending=True).head(20)

    # Save CSVs
    out_csv_all = os.path.join(MODELS_DIR, "logreg_feature_importance_full.csv")
    out_csv_pos = os.path.join(MODELS_DIR, "logreg_top_positive.csv")
    out_csv_neg = os.path.join(MODELS_DIR, "logreg_top_negative.csv")
    df_imp.to_csv(out_csv_all, index=False)
    top_positive.to_csv(out_csv_pos, index=False)
    top_negative.to_csv(out_csv_neg, index=False)

    print(f"Saved: {out_csv_all}")
    print(f"Saved: {out_csv_pos}")
    print(f"Saved: {out_csv_neg}")

    # --- PLOTS (horizontal bar charts) ---
    def plot_barh(df, title, savepath):
        """Draw one horizontal bar chart of `sort_col` per feature and save it to `savepath`."""
        plt.figure(figsize=(8, 6))
        # Ascending sort puts the largest value at the top of the chart.
        df_plot = df.sort_values(sort_col, ascending=True)
        labels = df_plot["feature"].str.replace("cat__ohe__", "", regex=False)
        plt.barh(labels, df_plot[sort_col])
        plt.xlabel("Coefficient" if use_coeffs else "Feature Importance")
        plt.title(title)
        if use_coeffs:
            # Annotate each bar with its odds ratio, placed just beyond the bar's end:
            # to the right of positive bars and to the left of negative ones, so the
            # text never sits on top of the bar.
            for y, (coef, odds) in enumerate(zip(df_plot[sort_col], df_plot["odds_ratio"])):
                if coef >= 0:
                    plt.text(coef + 0.01, y, f"{odds:.2f}", va="center", ha="left")
                else:
                    plt.text(coef - 0.01, y, f"{odds:.2f}", va="center", ha="right")
        plt.tight_layout()
        plt.savefig(savepath, dpi=200)
        plt.close()

    pos_title = "Top Positive Drivers of Churn (higher → more likely to churn)" if use_coeffs else "Top Features Increasing Predicted Risk"
    neg_title = "Top Negative Drivers of Churn (higher → less likely to churn)" if use_coeffs else "Top Features Decreasing Predicted Risk"
    pos_fig = os.path.join(FIG_DIR, "logreg_top_positive_odds.png")
    neg_fig = os.path.join(FIG_DIR, "logreg_top_negative_odds.png")
    plot_barh(top_positive, pos_title, pos_fig)
    plot_barh(top_negative, neg_title, neg_fig)

    # Report the paths actually written rather than hard-coded copies of them.
    print("Saved figures:")
    print(f" - {pos_fig}")
    print(f" - {neg_fig}")

# if __name__ == '__main__':
#     main()


In [None]:
# Keep a dedicated handle to the interpret.py entry point; the "Execute Step 5" cell calls it.
# NOTE(review): presumably aliased so a later `main` definition in the notebook cannot shadow it — confirm.
interpret_main = main

### Execute Step 1 — Data Preparation


In [None]:

# Run the data-preparation helpers in order; each call feeds the next.
raw_df = load_raw_data()  # reads the Telco churn CSV from DATA_URL
clean_df = clean_total_charges(raw_df)  # coerces TotalCharges to numeric and fills the blanks
feature_df = engineer_features(clean_df)
feature_df = compute_expected_tenure_and_clv(feature_df)
X, y = select_model_columns(feature_df)
# NOTE(review): presumably writes the train/val/test CSVs that Step 2 reads back — confirm.
stratified_splits_and_save(X, y)
print(feature_df.head())  # quick visual check of the engineered table


### Execute Step 2 — CLV Analysis


In [None]:

# data_prep defines PROCESSED_DIR as a plain string ("data/processed"), and `str / str`
# raises TypeError. Wrapping it in Path() makes the joins work whether the name is
# currently bound to a str or to a Path.
processed_root = Path(PROCESSED_DIR)
train_df = pd.read_csv(processed_root / "train.csv")
val_df = pd.read_csv(processed_root / "val.csv")
test_df = pd.read_csv(processed_root / "test.csv")

# The CLV quartile summary is computed on the training split only.
train_q = add_clv_quartile(train_df)
clv_summary = churn_rate_by_quartile(train_q)
print(clv_summary)


### Execute Step 3 — Model Training


In [None]:

# Run the Step 3 training entry point defined earlier in the notebook.
# NOTE(review): presumably fits the models and saves the pipelines that Steps 4-5 load
# (e.g. models/logreg_pipeline.pkl) — confirm.
train_and_eval()


### Execute Step 4 — Global Interpretability


In [None]:

# Global interpretability: logistic-regression importances, then SHAP bar plots for each tree model.
# NOTE(review): shap_trees builds its output path with `FIG_DIR / ...`, so FIG_DIR must be a
# pathlib.Path (not a str) when this cell runs.
logreg_importance()
shap_trees('rf')
shap_trees('xgb')


### Execute Step 5 — Local Interpretability


In [None]:

# Run the interpret.py entry point (`main`, aliased as interpret_main): it writes the
# logreg CSV tables under MODELS_DIR and the two bar charts under FIG_DIR.
interpret_main()
