# 03 - Modeling Churn
Train churn models with a time-aware labeling scheme.

In [1]:
import pandas as pd
from sklearn.model_selection import train_test_split

from src import config
from src.cleaning import clean_transactions
from src.evaluate import (
    classification_metrics,
    plot_pr_curve,
    plot_roc_curve,
    permutation_feature_importance,
)
from src.features import build_customer_features, build_time_based_labels
from src.io import load_transactions_excel
from src.modeling import build_churn_models, select_feature_columns, predict_churn_probabilities
from src.utils import ensure_dirs, set_random_seed

set_random_seed(config.RANDOM_STATE)
ensure_dirs([config.FIGURES_DIR])

raw_df = load_transactions_excel()
clean_df = clean_transactions(raw_df)

snapshot_date = clean_df['InvoiceDate'].max()
cutoff_date = snapshot_date - pd.Timedelta(days=config.TIME_SPLIT.cutoff_days_before_snapshot)

train_df = clean_df[clean_df['InvoiceDate'] <= cutoff_date].copy()
train_features = build_customer_features(train_df)
labels = build_time_based_labels(clean_df, cutoff_date, snapshot_date)
train_features = train_features.merge(labels.rename('churned'), left_on='CustomerID', right_index=True)

numeric_features, categorical_features = select_feature_columns(train_features)
X = train_features[numeric_features + categorical_features]
y = train_features['churned']

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=config.RANDOM_STATE, stratify=y
)

models = build_churn_models(numeric_features, categorical_features)
model = models['hist_gradient_boosting']
model.fit(X_train, y_train)

y_proba = predict_churn_probabilities(model, X_test)
y_pred = (y_proba >= 0.5).astype(int)

classification_metrics(y_test, y_pred, y_proba)

{'roc_auc': 1.0, 'precision': 1.0, 'recall': 1.0, 'f1': 1.0}

In [2]:
plot_roc_curve(y_test, y_proba, str(config.FIGURES_DIR / 'roc_curve.png'))
plot_pr_curve(y_test, y_proba, str(config.FIGURES_DIR / 'pr_curve.png'))
permutation_feature_importance(
    model,
    X_test,
    y_test,
    str(config.FIGURES_DIR / 'feature_importance.png'),
)

Unnamed: 0,feature,importance_mean,importance_std
14,churned,0.505212,0.015348
0,recency_days,0.0,0.0
2,num_invoices,0.0,0.0
1,tenure_days,0.0,0.0
4,total_revenue,0.0,0.0
5,avg_order_value,0.0,0.0
6,median_order_value,0.0,0.0
3,frequency_per_month,0.0,0.0
7,revenue_per_month_active,0.0,0.0
8,unique_products,0.0,0.0
