# Uplift Modeling for Churn Prediction

This notebook walks through an end-to-end uplift modeling workflow for churn: exploratory data analysis (EDA), feature insights, and preparation of inputs for the uplift models.

---

## **1. Setup**
Imports, a plotting helper, and project paths.


In [None]:
# Core libraries
from pathlib import Path

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import chi2_contingency, ttest_ind
from IPython.display import display  # used by print_table_overview; implicit inside Jupyter

# Embedding & ML (used in section 4 – feature engineering pipeline)
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

# Uplift modeling imports (used in model-selection section)
from sklearn.model_selection import StratifiedKFold

try:
    from causalml.inference.meta import BaseSRegressor, BaseTRegressor, BaseXRegressor
    from causalml.metrics import qini_auc_score, auuc_score
except Exception:  # pragma: no cover - handled at runtime
    BaseSRegressor = BaseTRegressor = BaseXRegressor = None
    qini_auc_score = auuc_score = None

try:
    from lightgbm import LGBMRegressor
    from xgboost import XGBRegressor
except Exception:  # pragma: no cover - handled at runtime
    LGBMRegressor = XGBRegressor = None

pd.set_option("display.max_columns", 200)

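def set_axes_clear(ax, x_axis_at_zero=False):
    """Hide top/right spines and draw a bold black left spine.

    If x_axis_at_zero is True, the bottom spine is hidden as well (the caller is
    expected to draw its own zero line, e.g. ax.axhline(0)); otherwise a bold
    black bottom spine is kept.
    """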
    for spine in ("top", "right"):
        ax.spines[spine].set_visible(False)
    if x_axis_at_zero:
        ax.spines["bottom"].set_visible(False)
    else:
        ax.spines["bottom"].set_visible(True)
        ax.spines["bottom"].set_color("black")
        ax.spines["bottom"].set_linewidth(1.2)
    ax.spines["left"].set_visible(True)
    ax.spines["left"].set_color("black")
    ax.spines["left"].set_linewidth(1.2)

# Paths: the notebook lives in src/; the project root is its parent,
# and files/ sits at the project root (not inside src/)
BASE_DIR = Path(".").resolve().parent
FILE_DIR = BASE_DIR / "files"  # all data files live under files/
TRAIN_DIR = FILE_DIR / "train"
TEST_DIR = FILE_DIR / "test"


## **2. Load data**
Load training and test CSVs, restrict train event data to the observation window (July 1–14, 2025, i.e. events strictly before the July 15 outreach), and print shapes.
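The window is half-open, `[OBS_START, OBS_END)`: an event at 23:59 on July 14 is kept, one at 00:00 on July 15 is dropped. The filter applied in the next cell could be wrapped as a small helper; `filter_window` below is a hypothetical name used only for illustration, and the toy rows are made up to exercise the two edges.

```python
import pandas as pd

def filter_window(df, date_col, start, end):
    """Keep rows with start <= df[date_col] < end (half-open window)."""
    mask = (df[date_col] >= start) & (df[date_col] < end)
    return df.loc[mask].copy()

# Toy events placed on both sides of each window edge
toy = pd.DataFrame({
    "member_id": [1, 1, 2, 2],
    "timestamp": pd.to_datetime([
        "2025-06-30 23:59", "2025-07-01 00:00",
        "2025-07-14 23:59", "2025-07-15 00:00",
    ]),
})
kept = filter_window(toy, "timestamp",
                     pd.Timestamp("2025-07-01"), pd.Timestamp("2025-07-15"))
print(kept["timestamp"].dt.strftime("%Y-%m-%d %H:%M").tolist())
```

Only the two in-window rows survive; the outreach-day row is excluded because the upper bound is strict.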


In [None]:
# Training data
churn_labels = pd.read_csv(TRAIN_DIR / "churn_labels.csv", parse_dates=["signup_date"])
app_usage = pd.read_csv(TRAIN_DIR / "app_usage.csv", parse_dates=["timestamp"])
web_visits = pd.read_csv(TRAIN_DIR / "web_visits.csv", parse_dates=["timestamp"])
claims = pd.read_csv(TRAIN_DIR / "claims.csv", parse_dates=["diagnosis_date"])

# Test data
test_members = pd.read_csv(TEST_DIR / "test_members.csv", parse_dates=["signup_date"])
test_app_usage = pd.read_csv(TEST_DIR / "test_app_usage.csv", parse_dates=["timestamp"])
test_web_visits = pd.read_csv(TEST_DIR / "test_web_visits.csv", parse_dates=["timestamp"])
test_claims = pd.read_csv(TEST_DIR / "test_claims.csv", parse_dates=["diagnosis_date"])

# Observation window: [July 1, July 15), 2025 (pre-outreach). Outreach = July 15; churn is measured after.
# Restrict train event data only; test data is not filtered (outreach has not occurred for test).
OBS_START = pd.Timestamp("2025-07-01")
OBS_END   = pd.Timestamp("2025-07-15")  # exclusive: keep events strictly before outreach

web_visits = web_visits[(web_visits["timestamp"] >= OBS_START) & (web_visits["timestamp"] < OBS_END)]
app_usage  = app_usage[(app_usage["timestamp"] >= OBS_START) & (app_usage["timestamp"] < OBS_END)]
claims     = claims[(claims["diagnosis_date"] >= OBS_START) & (claims["diagnosis_date"] < OBS_END)]

# Quick sanity check
for name, df in {
    "churn_labels": churn_labels,
    "app_usage": app_usage,
    "web_visits": web_visits,
    "claims": claims,
    "test_members": test_members,
    "test_app_usage": test_app_usage,
    "test_web_visits": test_web_visits,
    "test_claims": test_claims,
}.items():
    print(f"{name}: {df.shape}")

**What it means:** The printed shapes show how many rows and columns each table has after loading and (for train) after restricting to the observation window. Train event tables (app_usage, web_visits, claims) are limited to events before outreach (July 15, 2025).

**What it says about further analysis:** We have 10,000 train members and 10,000 test members; event volumes are large enough for aggregation. Next we explore structure and missingness, then build features from these tables.

## **3. EDA**

Exploratory data analysis: table structure, missingness, treatment balance, leakage checks, and uplift by engagement/claims/recency.

---

### **3.1 Raw data overview**
Summarize structure, dtypes, and sample rows for all 8 tables.


In [None]:
# ----------
# 3.1 Raw data overview: All datasets
# ----------

def print_table_overview(name, df):
    """Print structure, dtypes, numeric describe, date ranges, and head for a single table.

    Parameters
    ----------
    name : str
        Display name of the table.
    df : pandas.DataFrame
        The table to summarize.
    """
    print(f"\n{'='*60}")
    print(f"  {name}")
    print(f"{'='*60}")
    print(f"\n--- dtypes ---")
    print(df.dtypes.to_string())
    
    # Describe only meaningful numeric columns (exclude IDs and dates)
    numeric_cols = df.select_dtypes(include=[np.number]).columns.tolist()
    # Remove member_id and other ID columns
    numeric_cols = [col for col in numeric_cols if col not in ['member_id']]
    
    if len(numeric_cols) > 0:
        print(f"\n--- describe (numeric columns only) ---")
        print(df[numeric_cols].describe().to_string())
    
    # Show date ranges separately
    date_cols = df.select_dtypes(include=['datetime64']).columns.tolist()
    if len(date_cols) > 0:
        print(f"\n--- date ranges ---")
        for col in date_cols:
            print(f"  {col}: {df[col].min()} to {df[col].max()}")
    
    # Show object column summaries
    object_cols = df.select_dtypes(include=['object']).columns.tolist()
    if len(object_cols) > 0:
        print(f"\n--- object columns (unique counts) ---")
        for col in object_cols:
            print(f"  {col}: {df[col].nunique()} unique values")
    
    print(f"\n--- head(2) ---")
    display(df.head(2))

# Apply to all tables
all_tables = {
    "churn_labels": churn_labels,
    "app_usage": app_usage,
    "web_visits": web_visits,
    "claims": claims,
    "test_members": test_members,
    "test_app_usage": test_app_usage,
    "test_web_visits": test_web_visits,
    "test_claims": test_claims,
}

for name, df in all_tables.items():
    print_table_overview(name, df)


**What it means:** For each of the 8 tables we see column types, numeric summaries (where applicable), date ranges, unique counts for object columns, and a 2-row sample. This confirms churn_labels has churn/outreach, event tables have timestamps, and web_visits has url/title/description.

**What it says about further analysis:** We have a clear picture of schema and value ranges. Use this to design aggregates (e.g. counts per member from event tables) and to handle dates consistently. Next: column-specific checks (event_type, url, title, icd_code).

### **3.1b Column-specific checks**
Check `event_type`, `url`, `title`, and `icd_code` to guide feature engineering (e.g. drop constant columns, use title/description for content features).

In [None]:
# ----------
# 3.1b Column-specific checks
# Purpose: Check special columns (event_type, url, title, icd_code) for feature engineering decisions.
# What we test: value_counts for event_type, url, title, icd_code.
# What we do with this info:
#   - If event_type is constant -> drop it.
#   - URL/title variety -> potential content-categorization features.
#   - ICD distribution -> guides focus-ICD flag design.
# ----------

print("="*60)
print("  Column-specific checks")
print("="*60)
print("\napp_usage event_type value_counts:")
print(app_usage["event_type"].value_counts().to_string())
print(f"  -> {'CONSTANT — can drop' if app_usage['event_type'].nunique() == 1 else 'MULTIPLE VALUES — keep'}")

# url and title in web_visits: content variety
print(f"\nweb_visits url: {web_visits['url'].nunique()} unique values")
print("  Top-5 URLs:")
print(web_visits["url"].value_counts().head(5).to_string())
print(f"\nweb_visits title: {web_visits['title'].nunique()} unique values")
print("  Top-5 titles:")
print(web_visits["title"].value_counts().head(5).to_string())

# icd_code in claims
print(f"\nclaims icd_code: {claims['icd_code'].nunique()} unique values")
print("  Top-10 ICD codes:")
print(claims["icd_code"].value_counts().head(10).to_string())

**What it means:** `event_type` is constant ("session") so it can be dropped. URLs are very diverse; titles/descriptions have limited unique values and are good candidates for embedding-based relevance. ICD codes are well distributed for focus-ICD flags.
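A constant column such as `event_type` carries no signal. A minimal sketch of a generic check that drops every column with at most one distinct non-null value; `drop_constant_columns` and its `exclude` default are illustrative assumptions, not part of the notebook's pipeline.

```python
import pandas as pd

def drop_constant_columns(df, exclude=("member_id",)):
    """Drop columns with <= 1 distinct non-null value; return (new_df, dropped_names)."""
    constant = [c for c in df.columns
                if c not in exclude and df[c].nunique(dropna=True) <= 1]
    return df.drop(columns=constant), constant

# Toy app_usage-like table where event_type never varies
toy_app = pd.DataFrame({
    "member_id": [1, 2, 3],
    "event_type": ["session", "session", "session"],
    "timestamp": pd.to_datetime(["2025-07-01", "2025-07-02", "2025-07-03"]),
})
slim, dropped = drop_constant_columns(toy_app)
print(dropped, list(slim.columns))
```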

**What it says about further analysis:** Drop `event_type` in feature engineering. Use `title` and `description` (not URL) for health-related web classification via embeddings. Build focus-ICD binary/count features from the ICD distribution.
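One way to score health relevance with embeddings, sketched under assumptions: each page's text is `title + " " + description`, a few hypothetical anchor phrases describe health content, and a page counts as health-related when its maximum cosine similarity to any anchor clears a threshold (0.5 here is a placeholder to tune). In the notebook, `embed_fn` could be `SentenceTransformer("all-MiniLM-L6-v2").encode` (the model choice is an assumption); the toy bag-of-words embedder below only keeps the example offline and deterministic.

```python
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity

def health_relevance(texts, embed_fn, anchor_texts, threshold=0.5):
    """Score texts by max cosine similarity to anchor phrases.

    embed_fn: callable mapping a list of strings to a 2-D array (n, dim).
    Returns (scores, is_health) with is_health = scores >= threshold.
    """
    text_vecs = np.asarray(embed_fn(list(texts)))
    anchor_vecs = np.asarray(embed_fn(list(anchor_texts)))
    sims = cosine_similarity(text_vecs, anchor_vecs)  # shape (n_texts, n_anchors)
    scores = sims.max(axis=1)
    return scores, scores >= threshold

# Stand-in embedder (word counts over a tiny vocabulary), NOT a real embedding model
VOCAB = ["diabetes", "blood", "sugar", "heart", "football", "recipe"]
def toy_embed(texts):
    return np.array([[t.lower().split().count(w) for w in VOCAB] for t in texts],
                    dtype=float)

titles = ["Managing blood sugar with diabetes", "Weekend football scores"]
anchors = ["diabetes blood sugar", "heart health"]  # hypothetical anchor phrases
scores, is_health = health_relevance(titles, toy_embed, anchors, threshold=0.5)
print(scores.round(3), is_health.tolist())
```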

### **3.2 Missing values and member coverage**
Check column-level nulls and how many members have no rows in each activity source (web, app, claims).

In [None]:
# ----------
# 3.2 Missing values and member coverage
# Purpose: Find all missing data — both nulls within tables and members absent
#   from activity sources. This determines our imputation strategy.
# What we test:
#   Part A: .isnull().sum() per column for all 8 tables.
#   Part B: For train and test base sets, count members with zero rows in each
#           source (web, app, claims). Cross-source pattern.
# What we do with this info:
#   - If column nulls exist -> decide imputation method.
#   - Members absent from a source -> their aggregated features will be NaN
#     after join. Must zero-fill or handle explicitly.
#   - Cross-source patterns -> if absence is correlated, a single "inactive"
#     flag may suffice.
# Plots: 1 bar chart (% absent per source with annotations).
# ----------

# --- Part A: Column-level null check ---
print("="*60)
print("  Part A — Column-level null check")
print("="*60)

null_rows = []
for name, df in all_tables.items():
    nulls = df.isnull().sum()
    for col, cnt in nulls.items():
        if cnt > 0:
            null_rows.append({"table": name, "column": col, "null_count": cnt})

if null_rows:
    null_df = pd.DataFrame(null_rows)
    print(null_df.to_string(index=False))
else:
    print("\n✓ No null values found in any column of any table.\n")

# --- Part B: Member coverage across activity sources ---
print("="*60)
print("  Part B — Member coverage across sources")
print("="*60)

for split_name, base_df, src_tables in [
    ("TRAIN", churn_labels, {"web_visits": web_visits, "app_usage": app_usage, "claims": claims}),
    ("TEST",  test_members,  {"web_visits": test_web_visits, "app_usage": test_app_usage, "claims": test_claims}),
]:
    base_ids = set(base_df["member_id"].unique())
    n_base = len(base_ids)
    print(f"\n--- {split_name} (base members: {n_base}) ---")

    source_sets = {}
    coverage_rows = []
    for src_name, src_df in src_tables.items():
        present = set(src_df["member_id"].unique())
        source_sets[src_name] = present
        missing = len(base_ids - present)
        coverage_rows.append({
            "source": src_name,
            "members_present": len(present & base_ids),
            "members_absent": missing,
            "absent_pct": missing / n_base * 100,
        })

    cov = pd.DataFrame(coverage_rows)
    print(cov.to_string(index=False))

    # Cross-source pattern (train only)
    if split_name == "TRAIN":
        has_web = source_sets["web_visits"] & base_ids
        has_app = source_sets["app_usage"] & base_ids
        has_claims = source_sets["claims"] & base_ids
        no_web = base_ids - has_web
        no_app = base_ids - has_app
        no_claims = base_ids - has_claims

        patterns = {
            "missing web only":     len(no_web - no_app - no_claims),
            "missing app only":     len(no_app - no_web - no_claims),
            "missing claims only":  len(no_claims - no_web - no_app),
            "missing web+app":      len((no_web & no_app) - no_claims),
            "missing web+claims":   len((no_web & no_claims) - no_app),
            "missing app+claims":   len((no_app & no_claims) - no_web),
            "missing all 3":        len(no_web & no_app & no_claims),
            "present in all":       len(has_web & has_app & has_claims),
        }
        print("\nCross-source missingness patterns (train):")
        for pat, cnt in patterns.items():
            print(f"  {pat}: {cnt} ({cnt/n_base*100:.2f}%)")

        # --- Bar chart: % absent per source (train) ---
        fig, ax = plt.subplots(figsize=(7, 4))
        bars = ax.bar(cov["source"], cov["absent_pct"], color=sns.color_palette()[:3])
        ax.set_ylabel("% with zero activity", fontsize=11)
        ax.set_xlabel("Source table")
        ax.set_title("Members absent from each activity source (train)")
        for bar, row in zip(bars, cov.itertuples()):
            ax.annotate(f"{row.members_absent}\n({row.absent_pct:.2f}%)",
                        xy=(bar.get_x() + bar.get_width()/2, bar.get_height()),
                        ha="center", va="bottom", fontsize=9)
        set_axes_clear(ax, x_axis_at_zero=False)
        fig.subplots_adjust(left=0.22, right=0.96, top=0.92, bottom=0.12)
        plt.show()


**What it means:** No column-level nulls in any table. A small share of members have zero rows in each source (e.g. ~0.25% missing web, ~0.02% missing app). Almost all members appear in all three sources; absence is mostly per-source, not overlapping.

**What it says about further analysis:** Zero-fill aggregated features for members missing from a source. Optional: add binary flags like `has_web_visits` if we find that absence is predictive of churn or uplift. Next: test whether missingness is related to churn/outreach (3.3).
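The zero-fill-plus-flag pattern can be sketched as one helper: count events per member, left-join onto the full member base, record presence before filling, then fill absent members with 0. `aggregate_with_zero_fill` is a hypothetical name; the notebook's actual feature code may aggregate differently.

```python
import pandas as pd

def aggregate_with_zero_fill(base_ids, events, name):
    """Per-member event count joined onto base_ids, zero-filled, plus a has_<name> flag."""
    counts = (events.groupby("member_id").size()
              .rename(f"{name}_count").reset_index())
    out = pd.DataFrame({"member_id": base_ids}).merge(counts, on="member_id", how="left")
    # Presence must be captured before fillna, otherwise every member looks present
    out[f"has_{name}"] = out[f"{name}_count"].notna().astype(int)
    out[f"{name}_count"] = out[f"{name}_count"].fillna(0).astype(int)
    return out

base = pd.Series([1, 2, 3], name="member_id")
web = pd.DataFrame({"member_id": [1, 1, 3]})  # member 2 has no web visits
feat = aggregate_with_zero_fill(base, web, "web")
print(feat)
```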

### **3.3 Missingness mechanism**
Test whether absence from each activity source is associated with churn or outreach (Chi-square). If yes, missingness is informative and we should consider has_X flags.

In [None]:
# ----------
# 3.3 Missingness mechanism (MCAR / MAR / MNAR)
# Purpose: Understand *why* some members are absent from activity sources.
#   If absence correlates with churn/outreach, the missingness itself
#   is informative (MAR/MNAR) and we should create has_X flags as features.
#   If not (MCAR), simple zero-fill is fine.
# What we test:
#   1. Churn rate for present vs absent members, per source (Chi-square).
#   2. Outreach rate for present vs absent, per source (Chi-square).
#   3. Cross-source contingency table.
# Plots: 1 grouped bar chart — churn rate for "has activity" vs "no activity"
#   per source. Key visual to judge whether absence is informative.
# What we do with this info:
#   - If churn differs significantly by presence → create has_X binary flag.
#   - If not → zero-fill engagement features is sufficient.
# ----------

# Build has_X flags on train base
train_ids = churn_labels[["member_id", "churn", "outreach"]].copy()

web_ids = set(web_visits["member_id"].unique())
app_ids = set(app_usage["member_id"].unique())
claims_ids = set(claims["member_id"].unique())

train_ids["has_web"] = train_ids["member_id"].isin(web_ids).astype(int)
train_ids["has_app"] = train_ids["member_id"].isin(app_ids).astype(int)
train_ids["has_claims"] = train_ids["member_id"].isin(claims_ids).astype(int)

# --- 1 & 2: Chi-square tests ---
results = []
for source in ["has_web", "has_app", "has_claims"]:
    for target in ["churn", "outreach"]:
        ct = pd.crosstab(train_ids[source], train_ids[target])
        chi2, p, dof, expected = chi2_contingency(ct)
        rate_absent = train_ids.loc[train_ids[source] == 0, target].mean()
        rate_present = train_ids.loc[train_ids[source] == 1, target].mean()
        results.append({
            "source_flag": source,
            "target": target,
            "rate_absent (0)": round(rate_absent, 4) if pd.notna(rate_absent) else "N/A",
            "rate_present (1)": round(rate_present, 4),
            "chi2": round(chi2, 2),
            "p_value": f"{p:.4g}",
        })

results_df = pd.DataFrame(results)
print("="*70)
print("  Chi-square tests: churn/outreach rate by presence/absence")
print("="*70)
print(results_df.to_string(index=False))

# --- 3: Cross-source contingency ---
print("\n" + "="*70)
print("  Cross-source contingency (train)")
print("="*70)
cross = train_ids.groupby(["has_web", "has_app", "has_claims"]).size().reset_index(name="count")
print(cross.to_string(index=False))

# --- Grouped bar chart: churn rate present vs absent per source ---
chart_data = []
for source in ["has_web", "has_app", "has_claims"]:
    for val, label in [(1, "present"), (0, "absent")]:
        subset = train_ids[train_ids[source] == val]
        if len(subset) > 0:
            chart_data.append({
                "source": source.replace("has_", ""),
                "group": label,
                "churn_rate": subset["churn"].mean(),
                "n": len(subset),
            })

chart_df = pd.DataFrame(chart_data)
if len(chart_df) > 0:
    fig, ax = plt.subplots(figsize=(8, 5))
    sources = chart_df["source"].unique()
    x = np.arange(len(sources))
    width = 0.35

    present = chart_df[chart_df["group"] == "present"].set_index("source")
    absent = chart_df[chart_df["group"] == "absent"].set_index("source")

    bars1 = ax.bar(x - width/2,
                   [present.loc[s, "churn_rate"] if s in present.index else 0 for s in sources],
                   width, label="Has activity", color=sns.color_palette()[0])
    bars2 = ax.bar(x + width/2,
                   [absent.loc[s, "churn_rate"] if s in absent.index else 0 for s in sources],
                   width, label="No activity", color=sns.color_palette()[3])

    ax.set_ylabel("Churn rate")
    ax.set_xlabel("Activity source")
    ax.set_title("Churn rate: members with vs without activity per source")
    ax.set_xticks(x)
    ax.set_xticklabels(sources)
    ax.legend()
    set_axes_clear(ax, x_axis_at_zero=False)

    for bars in [bars1, bars2]:
        for bar in bars:
            h = bar.get_height()
            if h > 0:
                ax.annotate(f"{h:.3f}",
                            xy=(bar.get_x() + bar.get_width()/2, h),
                            ha="center", va="bottom", fontsize=9)
    plt.tight_layout()
    plt.show()


**What it means:** Chi-square p-values and the bar chart show whether churn (or outreach) rate differs between members who have activity in a source vs those who do not. The cross-tab shows how many members are missing from each combination of sources.

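**What it says about further analysis:** If the p-values are large (e.g. > 0.05), missingness is not strongly related to churn or outreach, and zero-fill is sufficient. With very few absent members (e.g. ~0.02% for app), some expected cell counts fall below 5 and the chi-square approximation becomes unreliable, so an exact test is the safer check for those sources.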
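A sketch of a test that falls back to Fisher's exact test when the chi-square approximation is doubtful (common rule of thumb: any expected count below 5). `presence_vs_target_test` is a hypothetical helper; it reuses `scipy.stats.chi2_contingency`, already imported in Setup, and adds `scipy.stats.fisher_exact`. The toy data is random.

```python
import numpy as np
import pandas as pd
from scipy.stats import chi2_contingency, fisher_exact

def presence_vs_target_test(flag, target):
    """2x2 test of a 0/1 has_X flag against a 0/1 target; picks Fisher or chi-square."""
    ct = pd.crosstab(flag, target).reindex(index=[0, 1], columns=[0, 1], fill_value=0)
    table = ct.to_numpy()
    rows, cols = table.sum(axis=1), table.sum(axis=0)
    if (rows == 0).any() or (cols == 0).any():
        return {"test": "undefined", "p_value": np.nan}  # one group is empty
    expected = np.outer(rows, cols) / table.sum()
    if (expected < 5).any():
        _, p = fisher_exact(table)
        return {"test": "fisher", "p_value": float(p)}
    _, p, _, _ = chi2_contingency(table)
    return {"test": "chi2", "p_value": float(p)}

rng = np.random.default_rng(0)
has_app = pd.Series([1] * 997 + [0] * 3)          # only 3 absent members
churn = pd.Series(rng.binomial(1, 0.2, size=1000))
print(presence_vs_target_test(has_app, churn))
```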

### **3.4 Labels and treatment balance**
Check churn rate, outreach rate, and group sizes; these affect metric choice and model design.

In [None]:
# ----------
# 3.4 Labels and treatment balance
# Purpose: Assess churn class imbalance and outreach/control balance.
#   These directly affect metric choice and model design.
# What we test:
#   - Overall churn rate, outreach rate, group sizes.
#   - Outreach × churn cross-tabulation.
# Plots:
#   1. sns.countplot for outreach (0/1).
#   2. sns.countplot for churn (0/1).
#   3. sns.barplot churn rate by outreach group.
# What we do with this info:
#   - ~20% churn → class imbalance → consider stratified sampling / proper metric.
#   - ~40% outreach → decent treatment group size for uplift estimation.
# ----------

churn_rate = churn_labels["churn"].mean()
outreach_rate = churn_labels["outreach"].mean()

summary_labels = churn_labels.groupby("outreach")["churn"].agg([
    ("members", "count"),
    ("churn_rate", "mean"),
])

print(f"Overall churn rate: {churn_rate:.3f}")
print(f"Outreach rate: {outreach_rate:.3f}")
print("\nOutreach x Churn cross-tabulation:")
cross_tab = pd.crosstab(churn_labels["outreach"], churn_labels["churn"],
                        margins=True, margins_name="Total")
print(cross_tab.to_string())
print("\nChurn rates by group:")
print(summary_labels.to_string())

plt.figure(figsize=(8, 5))
sns.countplot(data=churn_labels, x="outreach")
plt.title("Outreach vs. control counts", fontsize=14, fontweight='bold')
plt.xlabel("Outreach", fontsize=12)
plt.ylabel("Count", fontsize=12)
plt.tick_params(labelsize=11)
set_axes_clear(plt.gca(), x_axis_at_zero=False)
plt.tight_layout()
plt.show()

plt.figure(figsize=(8, 5))
sns.countplot(data=churn_labels, x="churn")
plt.title("Churn label counts", fontsize=14, fontweight='bold')
plt.xlabel("Churn", fontsize=12)
plt.ylabel("Count", fontsize=12)
plt.tick_params(labelsize=11)
set_axes_clear(plt.gca(), x_axis_at_zero=False)
plt.tight_layout()
plt.show()

plt.figure(figsize=(8, 5))
sns.barplot(data=summary_labels.reset_index(), x="outreach", y="churn_rate")
plt.title("Churn rate by outreach group", fontsize=14, fontweight='bold')
plt.xlabel("Outreach", fontsize=12)
plt.ylabel("Churn rate", fontsize=12)
plt.tick_params(labelsize=11)
set_axes_clear(plt.gca(), x_axis_at_zero=False)
plt.tight_layout()
plt.show()


**What it means:** We see the overall churn rate (~20%), the outreach rate (~40%), and how churn differs between treated and control members. The plots show group balance and raw churn by group.

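**What it says about further analysis:** The churn imbalance (~20% positive) calls for stratified splits, ideally stratifying on churn and outreach jointly so every fold keeps both the treatment mix and the churn rate. Uplift models should be scored with uplift metrics such as AUUC or Qini rather than classification accuracy. The treatment group is large enough to estimate uplift. Next: validate time windows and check for leakage (3.5).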
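`StratifiedKFold` (imported in Setup) stratifies on a single label, so one way to preserve both rates is to stratify on the combined `outreach`/`churn` cell. A minimal sketch; `joint_stratified_folds`, `n_splits=5`, and the seed are illustrative choices, and the toy data is random.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import StratifiedKFold

def joint_stratified_folds(df, n_splits=5, seed=42):
    """Yield (train_idx, valid_idx); folds preserve the outreach x churn cell proportions."""
    strata = df["outreach"].astype(str) + "_" + df["churn"].astype(str)  # 4 cells
    skf = StratifiedKFold(n_splits=n_splits, shuffle=True, random_state=seed)
    yield from skf.split(np.zeros(len(df)), strata)

rng = np.random.default_rng(0)
toy = pd.DataFrame({
    "outreach": rng.binomial(1, 0.4, size=2000),
    "churn": rng.binomial(1, 0.2, size=2000),
})
for train_idx, valid_idx in joint_stratified_folds(toy):
    v = toy.iloc[valid_idx]
    print(round(v["outreach"].mean(), 3), round(v["churn"].mean(), 3))
```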

### **3.5 Leakage & time-window validation**
Confirm event timestamps fall within the observation window and do not precede signup (no leakage).

In [None]:
# ----------
# 3.5 Leakage & time-window validation
# Purpose: Validate event timestamps vs signup dates and summarize observation windows.
# Event data is already restricted to [OBS_START, OBS_END) in Section 2 (Load data).
# ----------

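# Global observation end (max event date across sources); must precede outreach
obs_end = max(
    web_visits["timestamp"].max(),
    app_usage["timestamp"].max(),
    claims["diagnosis_date"].max(),
)
assert obs_end < OBS_END, f"Latest event {obs_end} is on/after outreach ({OBS_END})"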

# Summary of min/max timestamps per table
window_summary = pd.DataFrame([
    {
        "table": "web_visits",
        "min_ts": web_visits["timestamp"].min(),
        "max_ts": web_visits["timestamp"].max(),
        "rows": len(web_visits),
    },
    {
        "table": "app_usage",
        "min_ts": app_usage["timestamp"].min(),
        "max_ts": app_usage["timestamp"].max(),
        "rows": len(app_usage),
    },
    {
        "table": "claims",
        "min_ts": claims["diagnosis_date"].min(),
        "max_ts": claims["diagnosis_date"].max(),
        "rows": len(claims),
    },
])

# Count events that occur before signup_date (possible leakage or data issues)

def count_events_before_signup(events_df, date_col, labels_df):
    """Count event rows where event date is before the member's signup_date (potential leakage).

    Parameters
    ----------
    events_df : pandas.DataFrame
        Event table with member_id and date_col.
    date_col : str
        Name of the date column (e.g. 'timestamp' or 'diagnosis_date').
    labels_df : pandas.DataFrame
        Table with member_id and signup_date.

    Returns
    -------
    int
        Number of event rows where event date < signup_date.
    """
    merged = events_df[["member_id", date_col]].merge(
        labels_df[["member_id", "signup_date"]], on="member_id", how="left"
    )
    return (merged[date_col] < merged["signup_date"]).sum()

leakage_checks = pd.DataFrame([
    {
        "table": "web_visits",
        "events_before_signup": count_events_before_signup(web_visits, "timestamp", churn_labels),
    },
    {
        "table": "app_usage",
        "events_before_signup": count_events_before_signup(app_usage, "timestamp", churn_labels),
    },
    {
        "table": "claims",
        "events_before_signup": count_events_before_signup(claims, "diagnosis_date", churn_labels),
    },
])

display(window_summary)
display(leakage_checks)


**What it means:** The first table shows min/max timestamps per event table (all within July 1–14, 2025). The second table shows zero events before signup for web, app, and claims — no leakage.

**What it says about further analysis:** Observation window and signup logic are consistent. We can safely use these events for feature engineering. Next: temporal and engagement uplift (3.6, 3.7).
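The checks above report counts; a fail-fast variant could raise instead, so a later change to the loading cell cannot silently reintroduce post-outreach or pre-signup events. `assert_events_valid` is a hypothetical helper, and `Series.between(..., inclusive="left")` assumes pandas 1.3 or later.

```python
import pandas as pd

def assert_events_valid(events, date_col, labels, start, end):
    """Raise ValueError if any event is outside [start, end) or before the member's signup_date."""
    out_of_window = ~events[date_col].between(start, end, inclusive="left")
    merged = events[["member_id", date_col]].merge(
        labels[["member_id", "signup_date"]], on="member_id", how="left")
    before_signup = merged[date_col] < merged["signup_date"]
    if out_of_window.any() or before_signup.any():
        raise ValueError(f"{int(out_of_window.sum())} events outside window, "
                         f"{int(before_signup.sum())} events before signup")

START, END = pd.Timestamp("2025-07-01"), pd.Timestamp("2025-07-15")
labels_toy = pd.DataFrame({"member_id": [1, 2],
                           "signup_date": pd.to_datetime(["2025-01-01", "2025-07-10"])})
ok = pd.DataFrame({"member_id": [1, 2],
                   "timestamp": pd.to_datetime(["2025-07-02", "2025-07-12"])})
assert_events_valid(ok, "timestamp", labels_toy, START, END)  # passes silently

# Member 2 now has an event before their July 10 signup
bad = ok.assign(timestamp=pd.to_datetime(["2025-07-02", "2025-07-05"]))
try:
    assert_events_valid(bad, "timestamp", labels_toy, START, END)
except ValueError as err:
    print(err)
```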

### **3.6 Temporal features as uplift moderators**

Uplift = P(churn=1 | outreach=1, bin) − P(churn=1 | outreach=0, bin). Negative values mean members who received outreach churned less than those who did not.

Each bar shows uplift among members who had **at least one event** in that bin. The same member may appear in multiple bins.
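The uplift bars are point estimates, and some bins may be small. A normal-approximation (Wald) 95% interval for a difference of two proportions indicates which bins are distinguishable from zero. `uplift_with_ci` is a hypothetical helper that could sit alongside `compute_uplift`; the approximation assumes each arm of a bin has a reasonable number of members and churners. Because members can fall into several bins here, the intervals are valid per bin but should not be used to compare bins against each other.

```python
import numpy as np

def uplift_with_ci(churn_treated, churn_control, z=1.96):
    """Treated minus control churn rate, with a Wald (normal-approximation) interval."""
    t = np.asarray(churn_treated, dtype=float)
    c = np.asarray(churn_control, dtype=float)
    p_t, p_c = t.mean(), c.mean()
    se = np.sqrt(p_t * (1 - p_t) / len(t) + p_c * (1 - p_c) / len(c))
    uplift = p_t - p_c
    return uplift, uplift - z * se, uplift + z * se

# Toy bin: 18% churn among 500 treated vs 22% among 500 control
treated = np.r_[np.ones(90), np.zeros(410)]
control = np.r_[np.ones(110), np.zeros(390)]
u, lo, hi = uplift_with_ci(treated, control)
print(f"uplift={u:.3f}  95% CI=({lo:.3f}, {hi:.3f})")
```

For this toy bin the 4-point reduction has an interval that spans zero, so at this sample size it is not distinguishable from no effect.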


In [None]:
# 3.6 – Data preparation: extract temporal features from web + app events.

web_ev = web_visits[["member_id", "timestamp"]].copy()
web_ev["hour"] = web_ev["timestamp"].dt.hour
web_ev["dow"] = web_ev["timestamp"].dt.dayofweek  # 0=Mon … 6=Sun

app_ev = app_usage[["member_id", "timestamp"]].copy()
app_ev["hour"] = app_ev["timestamp"].dt.hour
app_ev["dow"] = app_ev["timestamp"].dt.dayofweek

events = pd.concat([
    web_ev[["member_id", "hour", "dow"]],
    app_ev[["member_id", "hour", "dow"]],
], ignore_index=True)

def time_bin(h):
    """Map hour (0-23) to a time-of-day label for aggregation.

    Parameters
    ----------
    h : int
        Hour of day (0-23).

    Returns
    -------
    str
        One of 'Early Morning', 'Morning', 'Afternoon', 'Evening'.
    """
    if h < 6:  return "Early Morning"
    if h < 12: return "Morning"
    if h < 18: return "Afternoon"
    return "Evening"

events["time_of_day"] = events["hour"].apply(time_bin)

DOW_NAMES = {0: "Mon", 1: "Tue", 2: "Wed", 3: "Thu", 4: "Fri", 5: "Sat", 6: "Sun"}
events["dow_name"] = events["dow"].map(DOW_NAMES)
events["is_weekend"] = events["dow"].isin([5, 6])

labels = churn_labels[["member_id", "churn", "outreach"]]

print(f"Events prepared: {len(events):,} rows from {events['member_id'].nunique():,} members")


In [None]:
# Reusable uplift helper – used by all uplift bar plots in 3.6 / 3.7 / 3.8.

def compute_uplift(member_ids):
    """Return (uplift, n_treated, n_control) for a set of member IDs."""
    df = labels[labels["member_id"].isin(member_ids)]
    tr = df[df["outreach"] == 1]["churn"]
    co = df[df["outreach"] == 0]["churn"]
    uplift = tr.mean() - co.mean() if len(tr) > 0 and len(co) > 0 else np.nan
    return uplift, len(tr), len(co)


def plot_uplift_bars(bin_names, uplifts, title, xlabel):
    """Simple bar plot: one bar per bin, y = uplift, horizontal zero line."""
    fig, ax = plt.subplots(figsize=(8, 4))
    ax.bar(range(len(bin_names)), uplifts, color="steelblue", edgecolor="black", alpha=0.85)
    ax.axhline(0, color="black", linestyle="-", linewidth=1.2)
    ax.set_xticks(range(len(bin_names)))
    ax.set_xticklabels(bin_names, rotation=20, ha="right")
    ax.set_xlabel(xlabel, fontsize=12)
    ax.set_ylabel("Uplift (churn-rate difference)", fontsize=12)
    ax.set_title(title, fontsize=14, fontweight="bold")
    ax.grid(axis="y", alpha=0.3)
    set_axes_clear(ax, x_axis_at_zero=True)
    plt.tight_layout()
    plt.show()


In [None]:
# 3.6a – Uplift by time of day

tod_order = ["Early Morning", "Morning", "Afternoon", "Evening"]
tod_uplift, tod_nt, tod_nc = [], [], []

for tod in tod_order:
    ids = events.loc[events["time_of_day"] == tod, "member_id"].unique()
    u, nt, nc = compute_uplift(ids)
    tod_uplift.append(u); tod_nt.append(nt); tod_nc.append(nc)

plot_uplift_bars(tod_order, tod_uplift,
                 title="Uplift by time of day",
                 xlabel="Time of day")

print(f"{'Group':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(tod_order, tod_uplift, tod_nt, tod_nc):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.6b – Uplift by day of week

dow_order = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
dow_uplift, dow_nt, dow_nc = [], [], []

for day in dow_order:
    ids = events.loc[events["dow_name"] == day, "member_id"].unique()
    u, nt, nc = compute_uplift(ids)
    dow_uplift.append(u); dow_nt.append(nt); dow_nc.append(nc)

plot_uplift_bars(dow_order, dow_uplift,
                 title="Uplift by day of week",
                 xlabel="Day of week")

print(f"{'Group':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(dow_order, dow_uplift, dow_nt, dow_nc):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.6c – Uplift by weekend vs weekday

wk_labels = ["Weekday", "Weekend"]
wk_uplift, wk_nt, wk_nc = [], [], []

for is_wknd in [False, True]:
    ids = events.loc[events["is_weekend"] == is_wknd, "member_id"].unique()
    u, nt, nc = compute_uplift(ids)
    wk_uplift.append(u); wk_nt.append(nt); wk_nc.append(nc)

plot_uplift_bars(wk_labels, wk_uplift,
                 title="Uplift by weekend vs weekday",
                 xlabel="Day type")

print(f"{'Group':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(wk_labels, wk_uplift, wk_nt, wk_nc):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


**What it means:** Uplift by time of day and by weekday vs weekend is similar across bins (all slightly negative). Outreach reduces churn a bit regardless of when members are active.

**What it says about further analysis:** Temporal features (time of day, day of week, weekend) do not strongly moderate uplift. Next: engagement and claims (3.7, 3.8).

### **3.7 Engagement features as uplift moderators**

**(a)** Distribution sanity checks — log-scaled histograms and quantile summaries.  
**(b)** Uplift by engagement quartile for each feature.
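Cells 3.7b to 3.7d repeat the same loop per feature. A sketch of a single vectorized helper, assuming the `eng` frame layout from 3.7a (one row per member with `outreach`, `churn`, and the feature); `uplift_by_quantile` is a hypothetical name. Because each member has one value of the feature, every member lands in exactly one bin, unlike the event-based bins in 3.6.

```python
import numpy as np
import pandas as pd

def uplift_by_quantile(df, feature, q=4, treat_col="outreach", outcome_col="churn"):
    """Quantile-bin `feature`; return per-bin uplift (treated - control) and arm sizes."""
    bins = pd.qcut(df[feature], q=q, duplicates="drop")
    grouped = (df.groupby([bins, treat_col], observed=True)[outcome_col]
               .agg(["mean", "size"]))
    rates = grouped["mean"].unstack(treat_col)   # columns: 0 (control), 1 (treated)
    sizes = grouped["size"].unstack(treat_col)
    return pd.DataFrame({
        "uplift": rates[1] - rates[0],
        "n_treated": sizes[1],
        "n_control": sizes[0],
    })

toy = pd.DataFrame({
    "web_visits_count": np.arange(8),
    "outreach": [0, 1] * 4,
    "churn":    [1, 0, 1, 0, 0, 0, 0, 0],
})
print(uplift_by_quantile(toy, "web_visits_count", q=2))
```

In the notebook this could be called as `uplift_by_quantile(eng, "web_visits_count")`, and its `uplift` column passed to `plot_uplift_bars`.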


In [None]:
# 3.7a – Distribution sanity checks (engagement)

web_per = web_visits.groupby("member_id").size().rename("web_visits_count").reset_index()
app_per = app_usage.groupby("member_id").size().rename("app_sessions_count").reset_index()
url_div = web_visits.groupby("member_id")["url"].nunique().rename("url_nunique").reset_index()

eng = churn_labels[["member_id", "churn", "outreach"]].merge(
    web_per, on="member_id", how="left"
).merge(
    app_per, on="member_id", how="left"
).merge(
    url_div, on="member_id", how="left"
)
for col in ["web_visits_count", "app_sessions_count", "url_nunique"]:
    eng[col] = eng[col].fillna(0)

# Quantile summaries
for feat in ["web_visits_count", "app_sessions_count", "url_nunique"]:
    print(f"\n{feat}:")
    print(eng[feat].quantile([0, .25, .5, .75, .9, .95, .99, 1]).to_string())

# Log-scaled histograms
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
for i, feat in enumerate(["web_visits_count", "app_sessions_count", "url_nunique"]):
    axes[i].hist(np.log1p(eng[feat]), bins=50, edgecolor="black", alpha=0.7)
    axes[i].set_xlabel(f"log(1 + {feat})", fontsize=11)
    axes[i].set_ylabel("Number of members", fontsize=11)
    axes[i].set_title(feat, fontsize=12, fontweight="bold")
    axes[i].grid(alpha=0.3)
    set_axes_clear(axes[i], x_axis_at_zero=False)
plt.tight_layout()
plt.show()


In [None]:
# 3.7b – Uplift by web_visits_count quartile

eng["web_q"] = pd.qcut(eng["web_visits_count"], q=4, duplicates="drop")
bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(eng["web_q"].dropna().unique()):
    ids = eng.loc[eng["web_q"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by web visits quartile",
                 xlabel="Web visits (quartile)")

print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.7c – Uplift by app_sessions_count quartile

eng["app_q"] = pd.qcut(eng["app_sessions_count"], q=4, duplicates="drop")
bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(eng["app_q"].dropna().unique()):
    ids = eng.loc[eng["app_q"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by app sessions quartile",
                 xlabel="App sessions (quartile)")

print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.7d – Uplift by URL diversity quartile

eng["url_q"] = pd.qcut(eng["url_nunique"], q=4, duplicates="drop")
bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(eng["url_q"].dropna().unique()):
    ids = eng.loc[eng["url_q"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by URL diversity quartile",
                 xlabel="Unique URLs (quartile)")

print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


**What it means:** Uplift varies across quartiles of the engagement features (web visits, app sessions, URL diversity); some bins show near-zero or slightly positive uplift.

**What it says about further analysis:** Engagement level may moderate uplift, so these counts are candidate features for the uplift model. Next: claims (3.8) and recency/tenure (3.9).

### **3.8 Claims features as uplift moderators**

**(a)** Distribution sanity checks: log-scaled histograms, quantile summaries, and focus-ICD prevalence.  
**(b)–(e)** Uplift by claims strata: quartile bins for claims count and distinct ICD codes, a binary has-focus-ICD split, and 0/1/2/3 strata for the number of distinct focus ICDs.
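The per-stratum computation behind these cells can be sketched as follows. `compute_uplift` is defined earlier in the notebook and its sign convention is not shown in this section, so this sketch *assumes* uplift = churn rate among controls (`outreach == 0`) minus churn rate among treated (`outreach == 1`), where positive means outreach is associated with lower churn. The helper `uplift_by_strata` and the toy data are hypothetical. The toy feature is zero-inflated on purpose, to show that `pd.qcut(..., duplicates="drop")` can return fewer than four bins.

```python
import numpy as np
import pandas as pd


def uplift_by_strata(df: pd.DataFrame, strata_col: str) -> pd.DataFrame:
    """Hypothetical helper: per-stratum uplift from a randomized outreach flag.

    Assumes 0/1 columns ``churn`` and ``outreach``. Uplift here is
    churn_rate(control) - churn_rate(treated).
    """
    rows = []
    for stratum, g in df.groupby(strata_col, observed=True):
        treated = g[g["outreach"] == 1]
        control = g[g["outreach"] == 0]
        uplift = (
            control["churn"].mean() - treated["churn"].mean()
            if len(treated) and len(control) else np.nan  # undefined if an arm is empty
        )
        rows.append({"stratum": str(stratum), "uplift": uplift,
                     "n_treated": len(treated), "n_control": len(control)})
    return pd.DataFrame(rows)


# Toy zero-inflated counts: the 25% and 50% quantile edges collapse onto 0,
# so duplicates="drop" merges them and only 3 bins are created.
toy = pd.DataFrame({
    "claims_count": [0, 0, 0, 0, 0, 0, 1, 2, 3, 5, 8, 13],
    "outreach":     [1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0],
    "churn":        [0, 1, 0, 1, 1, 1, 0, 0, 0, 1, 0, 0],
})
toy["claims_q"] = pd.qcut(toy["claims_count"], q=4, duplicates="drop")
print(toy["claims_q"].cat.categories.size)  # bins actually created
print(uplift_by_strata(toy, "claims_q"))
```

Because tied quantile edges are merged, the "quartile" plots for sparse count features may show fewer than four bars.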


In [None]:
# 3.8a – Distribution sanity checks (claims)

claims_per = claims.groupby("member_id").size().rename("claims_count").reset_index()
icd_nun = claims.groupby("member_id")["icd_code"].nunique().rename("icd_nunique").reset_index()

icd_focus_codes = ["E11.9", "I10", "Z71.3"]
claims_f = claims.copy()
claims_f["is_focus"] = claims_f["icd_code"].isin(icd_focus_codes)

# Binary flag: has any focus ICD
focus_any = claims_f.groupby("member_id")["is_focus"].any().rename("has_focus_icd").reset_index()

# Count of distinct focus ICDs per member (0, 1, 2, or 3)
focus_count = (
    claims_f[claims_f["is_focus"]]
    .groupby("member_id")["icd_code"]
    .nunique()
    .rename("focus_icd_count")
    .reset_index()
)

cl = churn_labels[["member_id", "churn", "outreach"]].merge(
    claims_per, on="member_id", how="left"
).merge(
    icd_nun, on="member_id", how="left"
).merge(
    focus_any, on="member_id", how="left"
).merge(
    focus_count, on="member_id", how="left"
)
cl["claims_count"] = cl["claims_count"].fillna(0)
cl["icd_nunique"] = cl["icd_nunique"].fillna(0)
cl["has_focus_icd"] = cl["has_focus_icd"].eq(True).astype(int)  # NaN (no claims) -> 0, avoids object-dtype fillna downcasting
cl["focus_icd_count"] = cl["focus_icd_count"].fillna(0).astype(int)

# Quantile summaries
for feat in ["claims_count", "icd_nunique"]:
    print(f"\n{feat}:")
    print(cl[feat].quantile([0, .25, .5, .75, .9, .95, .99, 1]).to_string())

print(f"\nFocus-ICD prevalence: {cl['has_focus_icd'].mean():.3f}")
print("\nFocus-ICD count distribution:")
print(cl["focus_icd_count"].value_counts().sort_index().to_string())

# Log-scaled histograms
fig, axes = plt.subplots(1, 2, figsize=(12, 4))
for i, feat in enumerate(["claims_count", "icd_nunique"]):
    axes[i].hist(np.log1p(cl[feat]), bins=50, edgecolor="black", alpha=0.7, color="green")
    axes[i].set_xlabel(f"log(1 + {feat})", fontsize=11)
    axes[i].set_ylabel("Number of members", fontsize=11)
    axes[i].set_title(feat, fontsize=12, fontweight="bold")
    axes[i].grid(alpha=0.3)
    set_axes_clear(axes[i], x_axis_at_zero=False)
plt.tight_layout()
plt.show()


In [None]:
# 3.8b – Uplift by claims_count quartile

cl["claims_q"] = pd.qcut(cl["claims_count"], q=4, duplicates="drop")
bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(cl["claims_q"].dropna().unique()):
    ids = cl.loc[cl["claims_q"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by claims count quartile",
                 xlabel="Claims count (quartile)")

print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.8c – Uplift by icd_nunique quartile

cl["icd_q"] = pd.qcut(cl["icd_nunique"], q=4, duplicates="drop")
bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(cl["icd_q"].dropna().unique()):
    ids = cl.loc[cl["icd_q"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by distinct ICD codes quartile",
                 xlabel="Distinct ICD codes (quartile)")

print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.8d – Uplift by has focus ICD (binary: No / Yes)

bin_names, uplifts, nts, ncs = [], [], [], []
for val, name in [(0, "No focus ICD"), (1, "Has focus ICD")]:
    ids = cl.loc[cl["has_focus_icd"] == val, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(name); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by has focus ICD",
                 xlabel="Focus ICD status")

print(f"{'Group':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.8e – Uplift by count of focus ICDs (0, 1, 2, 3)

bin_names, uplifts, nts, ncs = [], [], [], []
for cnt in [0, 1, 2, 3]:
    ids = cl.loc[cl["focus_icd_count"] == cnt, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(f"{cnt} focus ICD"); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by count of focus ICDs",
                 xlabel="Number of distinct focus ICD codes")

print(f"{'Group':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


**What it means:** Uplift varies across claims strata (claims-count and distinct-ICD quartiles, focus-ICD status, and focus-ICD count), so the outreach effect is not uniform across segments.

**What it says about further analysis:** Claims-based features are candidate uplift moderators for targeting. Next: recency and tenure (3.9).

### **3.9 Recency and tenure features as uplift moderators**

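Derived features: days since last web/app/claims/any activity, tenure in days, and a binary "recent activity within 7 days" flag, all measured against a single reference date (the latest timestamp observed across web, app, and claims).
All are computed by one reusable function, `build_recency_tenure`, so the same logic applies to train and test.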
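A toy run of the recency logic, assuming timestamps are already parsed; the event tables are hypothetical and only the web and app sources are included for brevity. It shows the three behaviours that matter downstream: one shared reference date, NaN recency for members with no events, and a 7-day flag that treats NaN as "not recent".

```python
import pandas as pd

# Hypothetical event tables with parsed timestamps.
web = pd.DataFrame({"member_id": [1, 1, 2],
                    "timestamp": pd.to_datetime(["2024-01-01", "2024-01-20", "2024-01-10"])})
app = pd.DataFrame({"member_id": [2, 3],
                    "timestamp": pd.to_datetime(["2024-01-28", "2024-01-05"])})
members = pd.DataFrame({"member_id": [1, 2, 3, 4]})  # member 4 has no events

# One reference date for every member: the latest event in any table.
ref_date = max(web["timestamp"].max(), app["timestamp"].max())

last_web = web.groupby("member_id")["timestamp"].max()
last_app = app.groupby("member_id")["timestamp"].max()
# Latest event per member across both sources.
last_any = pd.concat([last_web, last_app]).groupby(level=0).max()

feats = members.copy()
feats["days_since_last_web"] = (ref_date - feats["member_id"].map(last_web)).dt.days
feats["days_since_last_activity"] = (ref_date - feats["member_id"].map(last_any)).dt.days
# NaN <= 7 evaluates to False, so members with no activity get recent_any_7d = 0.
feats["recent_any_7d"] = (feats["days_since_last_activity"] <= 7).astype(int)
print(feats)
```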


In [None]:
# 3.9 – Build recency & tenure features (reusable for train and test)

def build_recency_tenure(members_df, web_df, app_df, claims_df):
    """
    Build member-level recency and tenure features.
    
    Parameters
    ----------
    members_df : DataFrame with at least member_id and signup_date columns.
    web_df     : DataFrame with member_id and timestamp.
    app_df     : DataFrame with member_id and timestamp.
    claims_df  : DataFrame with member_id and diagnosis_date.
    
    Returns
    -------
    (DataFrame, Timestamp)
        Feature DataFrame indexed by member_id with columns
        days_since_last_web, days_since_last_app, days_since_last_claim,
        days_since_last_activity, tenure_days, recent_any_7d;
        plus the reference date used to compute them.
    """
    # Parse date columns once so ref_date and per-member maxima share a dtype
    web_ts = pd.to_datetime(web_df["timestamp"], errors="coerce")
    app_ts = pd.to_datetime(app_df["timestamp"], errors="coerce")
    claim_ts = pd.to_datetime(claims_df["diagnosis_date"], errors="coerce")

    # Global reference date = max observed timestamp across all sources
    ref_date = max(web_ts.max(), app_ts.max(), claim_ts.max())

    # Per-member last timestamps (intermediates only; max() skips NaT)
    last_web = web_ts.groupby(web_df["member_id"]).max()
    last_app = app_ts.groupby(app_df["member_id"]).max()
    last_claim = claim_ts.groupby(claims_df["member_id"]).max()

    out = members_df[["member_id", "signup_date"]].copy()

    out["days_since_last_web"] = out["member_id"].map(last_web).pipe(lambda s: (ref_date - s).dt.days)
    out["days_since_last_app"] = out["member_id"].map(last_app).pipe(lambda s: (ref_date - s).dt.days)
    out["days_since_last_claim"] = out["member_id"].map(last_claim).pipe(lambda s: (ref_date - s).dt.days)

    # Last activity across all three sources
    last_any = (
        pd.concat([last_web.rename("ts"), last_app.rename("ts"), last_claim.rename("ts")])
        .groupby(level=0).max()
    )
    out["days_since_last_activity"] = out["member_id"].map(last_any).pipe(lambda s: (ref_date - s).dt.days)

    out["tenure_days"] = (ref_date - pd.to_datetime(out["signup_date"], errors="coerce")).dt.days

    out["recent_any_7d"] = (out["days_since_last_activity"].notna() & (out["days_since_last_activity"] <= 7)).astype(int)

    out = out.set_index("member_id").drop(columns=["signup_date"])
    return out, ref_date


recency = build_recency_tenure(churn_labels, web_visits, app_usage, claims)
recency_df, ref_date = recency

# Merge with labels for uplift computation
rec = churn_labels[["member_id", "churn", "outreach"]].merge(recency_df, left_on="member_id", right_index=True)

print(f"Reference date: {ref_date}")
print(f"Members: {len(rec)}")
print(rec[["days_since_last_web", "days_since_last_app", "days_since_last_claim",
           "days_since_last_activity", "tenure_days", "recent_any_7d"]].describe().round(1).to_string())


In [None]:
# 3.9a – Uplift by days_since_last_web quartile

feat = "days_since_last_web"
valid = rec.dropna(subset=[feat]).copy()
excluded = len(rec) - len(valid)

valid["_bin"] = pd.qcut(valid[feat], q=4, duplicates="drop")
if valid["_bin"].nunique() < 2:
    valid["_bin"] = pd.cut(valid[feat], bins=min(4, valid[feat].nunique()))

bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(valid["_bin"].dropna().unique()):
    ids = valid.loc[valid["_bin"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by days since last web activity",
                 xlabel="Days-since-last-web bin")

print(f"Excluded from plot (no web activity): {excluded} members (n too small for stable uplift).")
print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.9b – Uplift by days_since_last_app quartile

feat = "days_since_last_app"
valid = rec.dropna(subset=[feat]).copy()
excluded = len(rec) - len(valid)

valid["_bin"] = pd.qcut(valid[feat], q=4, duplicates="drop")
if valid["_bin"].nunique() < 2:
    valid["_bin"] = pd.cut(valid[feat], bins=min(4, valid[feat].nunique()))

bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(valid["_bin"].dropna().unique()):
    ids = valid.loc[valid["_bin"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by days since last app activity",
                 xlabel="Days-since-last-app bin")

print(f"Excluded from plot (no app activity): {excluded} members (n too small for stable uplift).")
print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.9c – Uplift by days_since_last_claim quartile

feat = "days_since_last_claim"
valid = rec.dropna(subset=[feat]).copy()
excluded = len(rec) - len(valid)

valid["_bin"] = pd.qcut(valid[feat], q=4, duplicates="drop")
if valid["_bin"].nunique() < 2:
    valid["_bin"] = pd.cut(valid[feat], bins=min(4, valid[feat].nunique()))

bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(valid["_bin"].dropna().unique()):
    ids = valid.loc[valid["_bin"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by days since last claim",
                 xlabel="Days-since-last-claim bin")

print(f"Excluded from plot (no claims): {excluded} members (n too small for stable uplift).")
print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.9d – Uplift by days_since_last_activity quartile

feat = "days_since_last_activity"
valid = rec.dropna(subset=[feat]).copy()
excluded = len(rec) - len(valid)

valid["_bin"] = pd.qcut(valid[feat], q=4, duplicates="drop")
if valid["_bin"].nunique() < 2:
    valid["_bin"] = pd.cut(valid[feat], bins=min(4, valid[feat].nunique()))

bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(valid["_bin"].dropna().unique()):
    ids = valid.loc[valid["_bin"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by days since last activity (any source)",
                 xlabel="Days-since-last-activity bin")

print(f"Excluded from plot (no activity at all): {excluded} members (n too small for stable uplift).")
print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


In [None]:
# 3.9e – Uplift by tenure_days quartile

feat = "tenure_days"
valid = rec.dropna(subset=[feat]).copy()
excluded = len(rec) - len(valid)

valid["_bin"] = pd.qcut(valid[feat], q=4, duplicates="drop")
if valid["_bin"].nunique() < 2:
    valid["_bin"] = pd.cut(valid[feat], bins=min(4, valid[feat].nunique()))

bin_names, uplifts, nts, ncs = [], [], [], []
for b in sorted(valid["_bin"].dropna().unique()):
    ids = valid.loc[valid["_bin"] == b, "member_id"]
    u, nt, nc = compute_uplift(ids)
    bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

plot_uplift_bars(bin_names, uplifts,
                 title="Uplift by tenure (days since signup)",
                 xlabel="Tenure bin (days)")

print(f"Excluded from plot (missing signup_date): {excluded} members (n too small for stable uplift).")
print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
    print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


**What it means:** Uplift by recency and tenure bins shows how outreach effect varies with how recently members were active and how long they have been members.

**What it says about further analysis:** Recency and tenure are strong candidates for the uplift model. EDA is complete; next steps are feature matrix construction and model training.

---

## **4. Feature Engineering Pipeline**

A modular, reference-date-based pipeline that converts raw event tables into a member-level feature matrix.  
The same functions serve **batch training** (`ref_date` = latest activity in the data) and **real-time scoring** (`ref_date` = the scoring date, or the latest activity in the scoring batch). `OBS_END` only restricts which events are loaded for training (Section 2); it is not used as the feature reference date.
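A minimal sketch of the point-in-time idea, using a hypothetical `count_and_recency` aggregator on toy data: the function is identical for both calls and only `ref_date` changes, so events after `ref_date` cannot leak into the features.

```python
import pandas as pd


def count_and_recency(events: pd.DataFrame, members: pd.DataFrame,
                      ref_date: pd.Timestamp) -> pd.DataFrame:
    """Hypothetical point-in-time aggregator: only events <= ref_date count."""
    ev = events[events["timestamp"] <= ref_date]
    agg = (ev.groupby("member_id")["timestamp"]
             .agg(events_count="count", last_ts="max")
             .reset_index())
    # Roster left join: members without events get count 0 and NaN recency.
    out = members[["member_id"]].merge(agg, on="member_id", how="left")
    out["events_count"] = out["events_count"].fillna(0).astype(int)
    out["days_since_last"] = (ref_date - out["last_ts"]).dt.days
    return out.drop(columns="last_ts")


events = pd.DataFrame({
    "member_id": [1, 1, 2],
    "timestamp": pd.to_datetime(["2024-03-01", "2024-03-20", "2024-03-10"]),
})
members = pd.DataFrame({"member_id": [1, 2]})

# Batch training: ref_date = latest activity in the data.
train_feats = count_and_recency(events, members, events["timestamp"].max())
# Same function, another ref_date: the 2024-03-20 event is now ignored.
asof_feats = count_and_recency(events, members, pd.Timestamp("2024-03-15"))
print(train_feats)
print(asof_feats)
```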

### **4.1 Configuration and helpers**
Constants, the WellCo brief loader, a helper that derives a per-dataset `ref_date` from event tables (to prevent cross-dataset leakage), and the embedding helpers `embed_wellco_brief` and `embed_visit_texts`.
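The comment on `SIMILARITY_THRESHOLD` in the cell below justifies 0.2 by a gap between irrelevant scores (max ≈ 0.13) and relevant scores (min ≈ 0.26). That check can be made mechanical. This is a minimal sketch, assuming you have similarity scores for a hand-labelled set of relevant and irrelevant visit texts; `gap_threshold` and the individual scores are hypothetical, and only the two extremes come from the notebook. Taking the midpoint is one reasonable choice, not necessarily how 0.2 was picked; any value inside the gap separates the groups equally well.

```python
import numpy as np


def gap_threshold(irrelevant_scores: np.ndarray, relevant_scores: np.ndarray) -> float:
    """Hypothetical helper: midpoint of the gap between two score groups.

    Raises ValueError when the groups overlap, i.e. no threshold separates them.
    """
    top_irrelevant = float(np.max(irrelevant_scores))
    bottom_relevant = float(np.min(relevant_scores))
    if top_irrelevant >= bottom_relevant:
        raise ValueError("Score distributions overlap; no clean gap exists.")
    return (top_irrelevant + bottom_relevant) / 2


# Hypothetical cosine scores; only the extremes (0.13, 0.26) come from the notebook.
irrelevant = np.array([0.02, 0.08, 0.13])
relevant = np.array([0.26, 0.41, 0.55])
threshold = gap_threshold(irrelevant, relevant)
print(round(threshold, 3))  # 0.195
```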

In [None]:
# ── 4.1  Configuration and helpers ──────────────────────────────────────────

# --- Pipeline constants (override these to reconfigure) ---------------------
WELLCO_BRIEF_PATH: Path = FILE_DIR / "wellco_client_brief.txt"
SIMILARITY_THRESHOLD: float = 0.2          # cosine-similarity cutoff for web relevance
# Chosen based on the clear gap between irrelevant scores (max ≈ 0.13) and
# relevant scores (min ≈ 0.26).  A threshold of 0.2 sits safely in that gap,
# retaining all WellCo-related content while excluding unrelated visits.
EMBED_MODEL_NAME: str = "all-MiniLM-L6-v2" # sentence-transformers model; swap as needed
FOCUS_ICD_CODES: list[str] = ["E11.9", "I10", "Z71.3"]  # WellCo clinical focus codes


def load_wellco_brief(path: Path | str = WELLCO_BRIEF_PATH) -> str:
    """Read the WellCo client brief from disk and return it as a single string.

    Parameters
    ----------
    path : Path or str
        File path to the brief text file (default: ``WELLCO_BRIEF_PATH``).

    Returns
    -------
    str
        Full text content of the brief.

    Assumptions
    -----------
    * The file exists and is UTF-8 encoded plain text.
    * The text is used as the reference document for semantic-similarity
      filtering of web visits (see 4.2).
    """
    return Path(path).read_text(encoding="utf-8")


def ref_date_from_tables(*dfs: pd.DataFrame) -> pd.Timestamp:
    """Derive a per-dataset reference date from the latest timestamp in event tables.

    This prevents cross-dataset leakage: train and test each compute their own
    ``ref_date`` from their own event tables, rather than sharing a global constant.

    Parameters
    ----------
    *dfs : pd.DataFrame
        One or more event DataFrames.  Each must contain either a ``timestamp``
        or a ``diagnosis_date`` column (or both).

    Returns
    -------
    pd.Timestamp
        The maximum observed timestamp across all supplied tables.

    Assumptions
    -----------
    * At least one DataFrame is provided.
    * Date columns are already parsed as ``datetime64``.
    """
    max_dates: list[pd.Timestamp] = []
    for df in dfs:
        if "timestamp" in df.columns:
            max_dates.append(df["timestamp"].max())
        if "diagnosis_date" in df.columns:
            max_dates.append(df["diagnosis_date"].max())
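    # Empty tables yield NaT; drop them so max() compares real timestamps only
    max_dates = [d for d in max_dates if pd.notna(d)]
    if not max_dates:
        raise ValueError("None of the supplied DataFrames contain a non-null 'timestamp' or 'diagnosis_date'.")
    return max(max_dates)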


def embed_wellco_brief(
    brief_text: str,
    model: SentenceTransformer,
) -> np.ndarray:
    """Embed the WellCo client brief into a single vector.

    Call this **once** at startup; reuse the returned embedding across all
    datasets and scoring calls.

    Parameters
    ----------
    brief_text : str
        Full text of the WellCo client brief.
    model : SentenceTransformer
        Pre-loaded sentence-transformers model.

    Returns
    -------
    np.ndarray
        Shape ``(1, embedding_dim)`` — the brief's embedding vector.
    """
    return model.encode([brief_text])


def embed_visit_texts(
    web_df: pd.DataFrame,
    model: SentenceTransformer,
) -> np.ndarray:
    """Embed the concatenated title + description of each web visit.

    **De-duplication optimisation:** Web-visit tables typically contain many
    rows that share the same (title, description) pair — e.g. 100k+ rows
    but only ~26 unique texts.  Rather than running the neural model on every
    row, this function:

    1. Builds one ``_text`` string per row:  ``(title + " " + description)``.
    2. Uses ``pd.factorize`` to map each row to an integer index
       (0, 1, 2, …) that identifies its *unique* text.
       - ``codes``: array of length ``n_rows``; ``codes[i]`` = which unique
         text row ``i`` belongs to.
       - ``uniques``: array of length ``n_unique``; the actual unique strings.
    3. Embeds only the ``n_unique`` texts with the model (e.g. 26 instead
       of 100k+).
    4. Uses NumPy array indexing (``unique_embeddings[codes]``) to expand
       back to one embedding per row — no Python loop, one vectorised step.

    The result is identical to embedding every row individually, but orders
    of magnitude faster when duplicates dominate.

    Reusable for both batch processing and real-time scoring of individual
    or small batches of visits.

    Parameters
    ----------
    web_df : pd.DataFrame
        Must contain ``title`` and ``description`` columns (NaN allowed).
    model : SentenceTransformer
        Pre-loaded sentence-transformers model (same one used for the brief).

    Returns
    -------
    np.ndarray
        Shape ``(len(web_df), embedding_dim)`` — one embedding per visit.

    Assumptions
    -----------
    * ``title`` and ``description`` may be NaN; they are filled with "".
    * The caller is responsible for any time-filtering before calling this.
    """
    # 1. Build one text string per row
    texts = (
        web_df["title"].fillna("") + " " + web_df["description"].fillna("")
    ).str.strip()

    # 2. Factorize: assign each row an integer id pointing to its unique text
    #    codes  – shape (n_rows,)   : codes[i] = index of row i's unique text
    #    uniques – shape (n_unique,) : the actual unique strings
    codes, uniques = pd.factorize(texts)

    # 3. Embed only the unique texts (e.g. 26 instead of 100k+)
    unique_embeddings = model.encode(uniques.tolist())  # shape (n_unique, dim)

    print(
        f"  embed_visit_texts: {len(uniques):,} unique texts embedded "
        f"(from {len(texts):,} rows)"
    )

    # 4. Map back: one array-indexing step gives every row its embedding
    #    visit_embeddings[i] = unique_embeddings[codes[i]]
    return unique_embeddings[codes]


print(f"Similarity threshold – {SIMILARITY_THRESHOLD}")
print(f"Embedding model      – {EMBED_MODEL_NAME}")
print(f"Focus ICD codes      – {FOCUS_ICD_CODES}")

In [None]:
# ── RUN ONCE: load embedding model & embed WellCo brief ────────────────────
# This cell is intentionally isolated so it runs exactly once per session.
# All downstream cells reuse `embed_model` and `wellco_embedding`.

brief_text = load_wellco_brief()
embed_model = SentenceTransformer(EMBED_MODEL_NAME)
wellco_embedding = embed_wellco_brief(brief_text, embed_model)  # shape (1, dim)

print(f"WellCo brief loaded     – {len(brief_text):,} characters")
print(f"Embedding model loaded  – {EMBED_MODEL_NAME}")
print(f"WellCo embedding shape  – {wellco_embedding.shape}")

### **4.2 Web relevance filtering (embeddings)**
Using the pre-computed `wellco_embedding` and `embed_model` (loaded once in the cell above), embed each web visit's text (title + description) and retain only visits whose cosine similarity to the brief is at least the configurable threshold (`SIMILARITY_THRESHOLD`). **Embeddings are used solely for relevance filtering and are never fed into the model as features.**
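The relevance test is plain cosine similarity against the single brief embedding. Below is a NumPy-only sketch of the computation that `cosine_similarity(visit_embeddings, wellco_embedding).flatten()` performs, using toy 2-d vectors in place of sentence-transformer embeddings; `cosine_filter` and the vectors are hypothetical. Unlike scikit-learn, this sketch assumes no zero-norm rows (a zero vector would divide by zero).

```python
import numpy as np


def cosine_filter(visit_emb: np.ndarray, ref_emb: np.ndarray,
                  threshold: float) -> np.ndarray:
    """Hypothetical helper: True where cos(visit, reference) >= threshold.

    visit_emb has shape (n, dim); ref_emb has shape (1, dim).
    """
    visit_unit = visit_emb / np.linalg.norm(visit_emb, axis=1, keepdims=True)
    ref_unit = ref_emb / np.linalg.norm(ref_emb, axis=1, keepdims=True)
    sims = (visit_unit @ ref_unit.T).ravel()  # shape (n,)
    return sims >= threshold


ref = np.array([[1.0, 0.0]])
visits = np.array([[2.0, 0.0],   # same direction: cos = 1.0
                   [1.0, 1.0],   # 45 degrees: cos ~ 0.707
                   [0.0, 3.0]])  # orthogonal: cos = 0.0
mask = cosine_filter(visits, ref, threshold=0.2)
print(mask)  # [ True  True False]
```

Cosine similarity ignores vector length (the first visit scores 1.0 despite norm 2), so the filter judges topical direction only.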

**Performance optimisation:** `embed_visit_texts` uses `pd.factorize` to identify unique (title, description) pairs and embeds only those (e.g. ~26 unique texts instead of 100k+ rows). Embeddings are then mapped back to all rows via a single NumPy array-indexing step, making the process orders of magnitude faster when duplicates dominate.
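The factorize-then-index trick can be checked without downloading a model. In this sketch `CountingEncoder` is a hypothetical stand-in that exposes only an `encode(list[str])` method (the one call `embed_visit_texts` relies on) and records how many texts reach it.

```python
import numpy as np
import pandas as pd


class CountingEncoder:
    """Hypothetical stand-in for a sentence-transformers model."""

    def __init__(self) -> None:
        self.n_encoded = 0

    def encode(self, texts: list[str]) -> np.ndarray:
        self.n_encoded += len(texts)
        # Toy 2-d "embedding": text length and number of spaces.
        return np.array([[len(t), t.count(" ")] for t in texts], dtype=float)


web = pd.DataFrame({
    "title": ["Diabetes tips", "Diabetes tips", None, "Diabetes tips"],
    "description": ["Eat well", "Eat well", "Sports news", "Eat well"],
})
texts = (web["title"].fillna("") + " " + web["description"].fillna("")).str.strip()

codes, uniques = pd.factorize(texts)       # codes: (n_rows,), uniques: (n_unique,)
enc = CountingEncoder()
unique_emb = enc.encode(uniques.tolist())  # only the unique texts hit the model
row_emb = unique_emb[codes]                # fancy indexing back to one row per visit

print(enc.n_encoded, row_emb.shape)  # 2 (4, 2)
```

The result matches encoding every row individually, while the encoder is called on 2 texts instead of 4.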

In [None]:
# ── 4.2  Web relevance filtering ────────────────────────────────────────────

def filter_wellco_relevant_visits(
    web_df: pd.DataFrame,
    wellco_embedding: np.ndarray,
    embed_model: SentenceTransformer,
    similarity_threshold: float = SIMILARITY_THRESHOLD,
    ref_date: pd.Timestamp | None = None,
) -> pd.DataFrame:
    """Return only web visits that are semantically relevant to the WellCo brief.

    The WellCo brief embedding and the embedding model are created once
    (in the "RUN ONCE" cell) and passed in here — no redundant loading.

    ``embed_visit_texts`` internally de-duplicates texts via ``pd.factorize``
    so the neural model only runs on the *unique* (title, description) pairs
    (e.g. ~26) rather than every row (e.g. 100k+), then maps embeddings back
    to all rows with one NumPy array-indexing step.

    Steps
    -----
    1. (Optional) Filter to ``timestamp <= ref_date`` if a ref_date is supplied.
    2. Embed visit texts via ``embed_visit_texts`` (title + description,
       with internal de-duplication for speed).
    3. Compute cosine similarity of each visit embedding to ``wellco_embedding``.
    4. Keep rows where similarity >= ``similarity_threshold``.

    Parameters
    ----------
    web_df : pd.DataFrame
        Raw web-visits table with columns ``member_id``, ``timestamp``,
        ``title``, ``description``, ``url``.
    wellco_embedding : np.ndarray
        Pre-computed embedding of the WellCo brief, shape ``(1, dim)``.
    embed_model : SentenceTransformer
        Pre-loaded sentence-transformers model (same one used to create
        ``wellco_embedding``).
    similarity_threshold : float
        Minimum cosine similarity to retain a visit (default from config).
    ref_date : pd.Timestamp or None
        If provided, only visits with ``timestamp <= ref_date`` are considered.

    Returns
    -------
    pd.DataFrame
        Subset of ``web_df`` (same schema) containing only WellCo-relevant visits.

    Assumptions
    -----------
    * ``title`` and ``description`` may contain NaN; they are filled with "".
    * Embeddings are used *only* for filtering — no embedding dimensions are
      added to the downstream feature matrix.
    """
    df = web_df.copy()

    # 1. Time filter
    if ref_date is not None:
        df = df[df["timestamp"] <= ref_date]

    if df.empty:
        return df

    # 2. Embed visit texts
    visit_embeddings = embed_visit_texts(df, embed_model)  # shape (n, dim)

    # 3. Cosine similarity → 1-D array
    similarities = cosine_similarity(visit_embeddings, wellco_embedding).flatten()

    # 4. Filter
    mask = similarities >= similarity_threshold
    relevant = df[mask].copy()

    print(
        f"Web relevance filter: {mask.sum():,} / {len(mask):,} visits retained "
        f"(threshold={similarity_threshold})"
    )
    return relevant

### **4.3 Per-source aggregation functions**
Each function takes the relevant event DataFrame(s) and an explicit `ref_date`, filters events to `<= ref_date`, and returns a **member-level** DataFrame keyed by `member_id`.

| Source | Features produced |
|---|---|
| Web (relevant visits) | `wellco_web_visits_count`, `days_since_last_wellco_web` |
| App | `app_sessions_count` |
| Claims | `icd_distinct_count`, `has_focus_icd`, `days_since_last_claim` |
| Lifecycle | `tenure_days` |
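All aggregators share one convention: aggregate events up to `ref_date`, then left-join onto the member roster so every member gets a row, with counts and flags filled with 0 and recency left as NaN. The toy run below mirrors the claims pattern on hypothetical rows to show the resulting output; it is an illustration, not a replacement for `agg_claims_features`.

```python
import pandas as pd

members = pd.DataFrame({"member_id": [1, 2, 3]})  # member 3 has no claims
claims = pd.DataFrame({
    "member_id": [1, 1, 2],
    "icd_code": ["I10", "E11.9", "J45"],
    "diagnosis_date": pd.to_datetime(["2024-05-01", "2024-05-20", "2024-04-30"]),
})
ref_date = pd.Timestamp("2024-05-31")
focus = ["E11.9", "I10", "Z71.3"]

c = claims[claims["diagnosis_date"] <= ref_date].assign(
    _is_focus=lambda d: d["icd_code"].isin(focus))
agg = (c.groupby("member_id")
        .agg(icd_distinct_count=("icd_code", "nunique"),
             has_focus_icd=("_is_focus", "any"),
             _last=("diagnosis_date", "max"))
        .reset_index())
agg["has_focus_icd"] = agg["has_focus_icd"].astype(int)
agg["days_since_last_claim"] = (ref_date - agg["_last"]).dt.days

# Roster left join: counts/flags -> 0, recency stays NaN for members with no claims.
out = members.merge(agg.drop(columns="_last"), on="member_id", how="left")
out["icd_distinct_count"] = out["icd_distinct_count"].fillna(0).astype(int)
out["has_focus_icd"] = out["has_focus_icd"].fillna(0).astype(int)
print(out)
```

Leaving recency as NaN (rather than imputing a large number) keeps "never active" distinguishable from "active long ago"; any imputation is left to the model stage.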

In [None]:
# ── 4.3a  Web aggregation ───────────────────────────────────────────────────

def agg_web_features(
    web_relevant_df: pd.DataFrame,
    members_df: pd.DataFrame,
    ref_date: pd.Timestamp,
) -> pd.DataFrame:
    """Aggregate WellCo-relevant web visits into member-level features.

    Parameters
    ----------
    web_relevant_df : pd.DataFrame
        Output of ``filter_wellco_relevant_visits`` — already filtered by
        relevance *and* ``timestamp <= ref_date``.
    members_df : pd.DataFrame
        Member roster (must contain ``member_id``); used as the left table so
        every member gets a row even if they have zero relevant visits.
    ref_date : pd.Timestamp
        Decision / reference date for computing recency.

    Returns
    -------
    pd.DataFrame
        Columns: ``member_id``, ``wellco_web_visits_count``,
        ``days_since_last_wellco_web``; ``wellco_web_unique_urls`` is added
        only when the TOGGLE lines below are uncommented.

    Assumptions
    -----------
    * ``web_relevant_df`` has columns ``member_id``, ``timestamp``, ``url``.
    * Members with no relevant visits get count = 0 (and unique_urls = 0 when
      enabled) and ``days_since_last_wellco_web = NaN`` (no visit to measure from).
    """
    # Redundant safety filter (already done in filter step)
    wdf = web_relevant_df[web_relevant_df["timestamp"] <= ref_date].copy()

    if wdf.empty:
        out = members_df[["member_id"]].copy()
        out["wellco_web_visits_count"] = 0
        # TOGGLE: uncomment next line to include URL feature in the matrix
        # out["wellco_web_unique_urls"] = 0
        out["days_since_last_wellco_web"] = np.nan
        return out

    agg = (
        wdf.groupby("member_id")
        .agg(
            wellco_web_visits_count=("timestamp", "count"),
            # TOGGLE: uncomment next line to include URL feature in the matrix
            # wellco_web_unique_urls=("url", "nunique"),
            _last_visit=("timestamp", "max"),
        )
        .reset_index()
    )
    agg["days_since_last_wellco_web"] = (ref_date - agg["_last_visit"]).dt.days
    agg.drop(columns="_last_visit", inplace=True)

    # Left join so every member gets a row
    out = members_df[["member_id"]].merge(agg, on="member_id", how="left")
    out["wellco_web_visits_count"] = out["wellco_web_visits_count"].fillna(0).astype(int)
    # TOGGLE: uncomment next line to include URL feature in the matrix
    # out["wellco_web_unique_urls"] = out["wellco_web_unique_urls"].fillna(0).astype(int)
    # days_since_last_wellco_web stays NaN for members with no relevant visits
    return out

In [None]:
# ── 4.3b  App aggregation ───────────────────────────────────────────────────

def agg_app_features(
    app_df: pd.DataFrame,
    members_df: pd.DataFrame,
    ref_date: pd.Timestamp,
) -> pd.DataFrame:
    """Count app sessions per member up to the reference date.

    Parameters
    ----------
    app_df : pd.DataFrame
        Raw app-usage table with columns ``member_id``, ``timestamp``.
    members_df : pd.DataFrame
        Member roster (must contain ``member_id``).
    ref_date : pd.Timestamp
        Decision / reference date; only sessions with ``timestamp <= ref_date``
        are counted.

    Returns
    -------
    pd.DataFrame
        Columns: ``member_id``, ``app_sessions_count``.

    Assumptions
    -----------
    * Each row in ``app_df`` represents one session.
    * Members with no sessions get ``app_sessions_count = 0``.
    """
    adf = app_df[app_df["timestamp"] <= ref_date].copy()

    counts = (
        adf.groupby("member_id")
        .size()
        .rename("app_sessions_count")
        .reset_index()
    )

    out = members_df[["member_id"]].merge(counts, on="member_id", how="left")
    out["app_sessions_count"] = out["app_sessions_count"].fillna(0).astype(int)
    return out

In [None]:
# ── 4.3c  Claims aggregation ────────────────────────────────────────────────

def agg_claims_features(
    claims_df: pd.DataFrame,
    members_df: pd.DataFrame,
    ref_date: pd.Timestamp,
    focus_icd_codes: list[str] = FOCUS_ICD_CODES,
) -> pd.DataFrame:
    """Aggregate claims into member-level diagnostic features.

    Parameters
    ----------
    claims_df : pd.DataFrame
        Raw claims table with columns ``member_id``, ``diagnosis_date``,
        ``icd_code``.
    members_df : pd.DataFrame
        Member roster (must contain ``member_id``).
    ref_date : pd.Timestamp
        Decision / reference date; only claims with
        ``diagnosis_date <= ref_date`` are included.
    focus_icd_codes : list[str]
        ICD-10 codes that define the WellCo clinical focus (default from config).

    Returns
    -------
    pd.DataFrame
        Columns: ``member_id``, ``icd_distinct_count``, ``has_focus_icd``,
        ``days_since_last_claim``.

    Assumptions
    -----------
    * ``icd_code`` is a string column.
    * ``has_focus_icd`` is binary (1 if the member has >= 1 claim with a focus
      ICD code, else 0).
    * Members with no claims get counts = 0, ``has_focus_icd = 0``, and
      ``days_since_last_claim = NaN``.
    """
    cdf = claims_df[claims_df["diagnosis_date"] <= ref_date].copy()

    if cdf.empty:
        out = members_df[["member_id"]].copy()
        out["icd_distinct_count"] = 0
        out["has_focus_icd"] = 0
        out["days_since_last_claim"] = np.nan
        return out

    # Flag focus ICD rows
    cdf["_is_focus"] = cdf["icd_code"].isin(focus_icd_codes)

    agg = (
        cdf.groupby("member_id")
        .agg(
            icd_distinct_count=("icd_code", "nunique"),
            has_focus_icd=("_is_focus", "any"),
            _last_claim=("diagnosis_date", "max"),
        )
        .reset_index()
    )
    agg["has_focus_icd"] = agg["has_focus_icd"].astype(int)
    agg["days_since_last_claim"] = (ref_date - agg["_last_claim"]).dt.days
    agg.drop(columns="_last_claim", inplace=True)

    out = members_df[["member_id"]].merge(agg, on="member_id", how="left")
    out["icd_distinct_count"] = out["icd_distinct_count"].fillna(0).astype(int)
    out["has_focus_icd"] = out["has_focus_icd"].fillna(0).astype(int)
    # days_since_last_claim stays NaN for members with no claims
    return out

In [None]:
# ── 4.3d  Lifecycle / tenure ────────────────────────────────────────────────

def agg_lifecycle_tenure(
    members_df: pd.DataFrame,
    ref_date: pd.Timestamp,
) -> pd.DataFrame:
    """Compute membership tenure in days as of the reference date.

    Parameters
    ----------
    members_df : pd.DataFrame
        Member roster with columns ``member_id`` and ``signup_date``.
    ref_date : pd.Timestamp
        Decision / reference date.

    Returns
    -------
    pd.DataFrame
        Columns: ``member_id``, ``tenure_days``.

    Assumptions
    -----------
    * ``signup_date`` is already parsed as ``datetime64``.
    * If ``signup_date`` is NaT the resulting ``tenure_days`` will be NaN.
    """
    out = members_df[["member_id"]].copy()
    out["tenure_days"] = (ref_date - members_df["signup_date"]).dt.days
    return out

### **4.4 Feature assembly**
`build_feature_matrix` orchestrates the full pipeline: relevance filtering, all four aggregations, and a single merge into one member-level table. Pass `include_labels=True` to attach `outreach` and `churn` columns for training.
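The "left-merge from the roster" guarantee can be illustrated on toy data. This is a minimal sketch that assumes two hypothetical per-source tables (`feat_web`, `feat_app`) covering only some members, plus one member who is not on the roster. Merging onto the roster keeps exactly one row per rostered member and ignores the extra one. The fill rules (counts to 0, recency left as NaN) mirror those documented in the aggregation functions.

```python
import numpy as np
import pandas as pd

# Toy roster: the source of truth for which members get a row.
members = pd.DataFrame({"member_id": [1, 2, 3]})

# Hypothetical per-source feature tables: member 3 has no web activity,
# member 2 has no app activity, and member 4 is not on the roster at all.
feat_web = pd.DataFrame({
    "member_id": [1, 2],
    "wellco_web_visits_count": [5, 1],
    "days_since_last_wellco_web": [2.0, 10.0],
})
feat_app = pd.DataFrame({"member_id": [1, 3, 4], "app_sessions_count": [7, 2, 9]})

# Left-merge from the roster: member 4 is dropped, members 2 and 3 are kept.
matrix = (
    members
    .merge(feat_web, on="member_id", how="left")
    .merge(feat_app, on="member_id", how="left")
)

# Fill rules mirroring the aggregation functions: counts -> 0, recency stays NaN.
matrix["wellco_web_visits_count"] = matrix["wellco_web_visits_count"].fillna(0).astype(int)
matrix["app_sessions_count"] = matrix["app_sessions_count"].fillna(0).astype(int)
print(matrix)
```

Because every merge is `how="left"` against the same roster, the row count can never grow or shrink, whatever the event tables contain.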

In [None]:
# ── 4.4  Feature assembly ───────────────────────────────────────────────────

def build_feature_matrix(
    members_df: pd.DataFrame,
    web_df: pd.DataFrame,
    app_df: pd.DataFrame,
    claims_df: pd.DataFrame,
    ref_date: pd.Timestamp,
    *,
    wellco_embedding: np.ndarray,
    embed_model: SentenceTransformer,
    similarity_threshold: float = SIMILARITY_THRESHOLD,
    focus_icd_codes: list[str] = FOCUS_ICD_CODES,
    include_labels: bool = False,
) -> pd.DataFrame:
    """Build the full member-level feature matrix from raw event tables.

    This is the single entry-point for both batch training and real-time
    scoring.  It calls the relevance filter and all four aggregation functions,
    then merges everything on ``member_id``.

    Parameters
    ----------
    members_df : pd.DataFrame
        Member roster.  Must contain ``member_id`` and ``signup_date``.
        For training data this is ``churn_labels`` (which also has ``outreach``
        and ``churn``).
    web_df, app_df, claims_df : pd.DataFrame
        Raw event tables (web visits, app sessions, claims).
    ref_date : pd.Timestamp
        Decision / reference date.  All event filters use ``<= ref_date``.
    wellco_embedding : np.ndarray
        Pre-computed WellCo brief embedding, shape ``(1, dim)``.
        Created once in the "RUN ONCE" cell.
    embed_model : SentenceTransformer
        Pre-loaded sentence-transformers model (same one used to create
        ``wellco_embedding``).
    similarity_threshold : float
        Cosine-similarity cutoff for web relevance (default from config).
    focus_icd_codes : list[str]
        ICD-10 focus codes (default from config).
    include_labels : bool
        If True and ``members_df`` contains ``outreach`` / ``churn``, append
        those columns to the output (for training only).

    Returns
    -------
    pd.DataFrame
        One row per member.  Feature columns: ``wellco_web_visits_count``,
        ``days_since_last_wellco_web``, ``app_sessions_count``,
        ``icd_distinct_count``, ``has_focus_icd``, ``days_since_last_claim``,
        ``tenure_days`` (7 features).
        Optionally ``outreach`` and ``churn`` if ``include_labels=True``.

    Assumptions
    -----------
    * ``members_df`` is the source of truth for the member list.
    * Event tables may have extra members (ignored) or fewer (filled with
      0 / NaN as documented in each aggregation function).
    """
    # 1. Relevance filter on web visits
    web_relevant = filter_wellco_relevant_visits(
        web_df,
        wellco_embedding=wellco_embedding,
        embed_model=embed_model,
        similarity_threshold=similarity_threshold,
        ref_date=ref_date,
    )

    # 2. Per-source aggregations
    feat_web = agg_web_features(web_relevant, members_df, ref_date)
    feat_app = agg_app_features(app_df, members_df, ref_date)
    feat_claims = agg_claims_features(claims_df, members_df, ref_date, focus_icd_codes)
    feat_life = agg_lifecycle_tenure(members_df, ref_date)

    # 3. Merge all on member_id (left from members so every member has a row)
    feature_matrix = (
        members_df[["member_id"]]
        .merge(feat_web, on="member_id", how="left")
        .merge(feat_app, on="member_id", how="left")
        .merge(feat_claims, on="member_id", how="left")
        .merge(feat_life, on="member_id", how="left")
    )

    # 4. Attach labels if requested
    if include_labels:
        label_cols = [c for c in ("outreach", "churn") if c in members_df.columns]
        if label_cols:
            feature_matrix = feature_matrix.merge(
                members_df[["member_id"] + label_cols], on="member_id", how="left"
            )

    return feature_matrix

### **4.5 Build train and test feature matrices**
**Train events** are restricted to those that occurred **strictly before OBS_END** (Section 2 filters `web_visits`, `app_usage`, and `claims` on `timestamp` / `diagnosis_date` < OBS_END), so the train feature matrix is built only from events up to that cut-off.

**Ref date:** For both train and test, `ref_date` is the **max timestamp** in that dataset's event tables (`ref_date_from_tables(...)`), so features are "as of" the latest activity in the data. This keeps the pipeline reproducible at any future decision time. OBS_END is not passed as the feature reference date; it only defines which events are included in the train tables.
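The two dates play different roles, which a toy example makes concrete. This is a minimal sketch under stated assumptions: `OBS_END` is a hypothetical cut-off, and `max_event_timestamp` is a hypothetical stand-in for `ref_date_from_tables` (whose actual implementation is not shown here). OBS_END decides which events survive. `ref_date` is then the latest surviving event.

```python
import pandas as pd

OBS_END = pd.Timestamp("2025-07-01")  # hypothetical observation cut-off

# Toy raw event tables, one timestamp column each (as in Section 2).
web = pd.DataFrame({"timestamp": pd.to_datetime(["2025-06-10", "2025-06-28", "2025-07-03"])})
claims = pd.DataFrame({"diagnosis_date": pd.to_datetime(["2025-06-15", "2025-06-30"])})

# Step 1: OBS_END decides which events enter the train tables (strict <).
web_train = web[web["timestamp"] < OBS_END]
claims_train = claims[claims["diagnosis_date"] < OBS_END]


def max_event_timestamp(*series: pd.Series) -> pd.Timestamp:
    """Hypothetical stand-in for ref_date_from_tables: latest event across tables."""
    return max(s.max() for s in series)


# Step 2: ref_date is the latest surviving event, not OBS_END itself.
ref_date = max_event_timestamp(web_train["timestamp"], claims_train["diagnosis_date"])
print(ref_date)  # 2025-06-30 00:00:00
```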

In [None]:
# ── 4.5  Build train & test feature matrices ────────────────────────────────

# --- Reference dates: max activity in each dataset (reproducible at any decision time) ---
# Train tables (web_visits, app_usage, claims) contain only events that occurred strictly before OBS_END (filtered in Section 2).
# ref_date is the max timestamp in those tables, not OBS_END itself.
ref_date_train = ref_date_from_tables(web_visits, app_usage, claims)
ref_date_test = ref_date_from_tables(test_web_visits, test_app_usage, test_claims)

print(f"ref_date_train = {ref_date_train}  (max activity in train tables)")
print(f"ref_date_test  = {ref_date_test}  (max activity in test tables)")

# --- Train ------------------------------------------------------------------
print("\n── Building TRAIN feature matrix ──")
train_features = build_feature_matrix(
    members_df=churn_labels,
    web_df=web_visits,
    app_df=app_usage,
    claims_df=claims,
    ref_date=ref_date_train,
    wellco_embedding=wellco_embedding,
    embed_model=embed_model,
    include_labels=True,
)

# --- Test -------------------------------------------------------------------
print("\n── Building TEST feature matrix ──")
test_features = build_feature_matrix(
    members_df=test_members,
    web_df=test_web_visits,
    app_df=test_app_usage,
    claims_df=test_claims,
    ref_date=ref_date_test,
    wellco_embedding=wellco_embedding,
    embed_model=embed_model,
    include_labels=False,
)

# --- Quick summary ----------------------------------------------------------
print(f"\nTrain features shape: {train_features.shape}")
print(f"Test  features shape: {test_features.shape}")
print(f"\nTrain columns: {list(train_features.columns)}")
print(f"Test  columns: {list(test_features.columns)}")
print(f"\nTrain head:\n{train_features.head()}")

### **4.6 Feature diagnostics (informational only)**
Inspect distributions and multicollinearity of the engineered features **on the training set**. This section is for review only — no features are automatically dropped or transformed.

In [None]:
# ── 4.6a  Per-feature distribution diagnostics ──────────────────────────────

# Feature columns only (exclude member_id and labels)
FEATURE_COLS = [
    "wellco_web_visits_count",
    # TOGGLE: uncomment next line to include URL feature in diagnostics
    # "wellco_web_unique_urls",
    "days_since_last_wellco_web",
    "app_sessions_count",
    "icd_distinct_count",
    "has_focus_icd",
    "days_since_last_claim",
    "tenure_days",
]

# Summary statistics
print("=" * 70)
print("FEATURE SUMMARY STATISTICS (train)")
print("=" * 70)
print(train_features[FEATURE_COLS].describe().T.to_string())

# Percentage of zeros and missing
print("\n" + "=" * 70)
print("ZEROS AND MISSING VALUES (train)")
print("=" * 70)
n = len(train_features)
for col in FEATURE_COLS:
    pct_zero = (train_features[col] == 0).sum() / n * 100
    pct_miss = train_features[col].isna().sum() / n * 100
    print(f"  {col:<35s}  zeros: {pct_zero:6.2f}%   missing: {pct_miss:6.2f}%")

In [None]:
# ── 4.6b  Histograms ────────────────────────────────────────────────────────
# Each plot: x = feature value (e.g. number of visits), y = how many members have that value.
FEATURE_XLABELS = {
    "wellco_web_visits_count": "Relevant web visits per member",
    "wellco_web_unique_urls": "Unique URLs (WellCo-relevant visits)",
    "days_since_last_wellco_web": "Days since last relevant web visit",
    "app_sessions_count": "App sessions per member",
    "icd_distinct_count": "Distinct ICD codes per member",
    "has_focus_icd": "Has focus ICD (0 = no, 1 = yes)",
    "days_since_last_claim": "Days since last claim",
    "tenure_days": "Tenure (days since signup)",
}

fig, axes = plt.subplots(2, 4, figsize=(18, 8))
axes = axes.flatten()

for i, col in enumerate(FEATURE_COLS):
    ax = axes[i]
    data = train_features[col].dropna()
    ax.hist(data, bins=40, edgecolor="white", alpha=0.8)
    ax.set_title(col, fontsize=10, fontweight="bold")
    ax.set_xlabel(FEATURE_XLABELS.get(col, col), fontsize=10)
    ax.set_ylabel("Number of members", fontsize=10)
    # Annotate with basic stats
    ax.axvline(data.median(), color="red", linestyle="--", linewidth=1, label=f"median={data.median():.1f}")
    ax.legend(fontsize=7)
    set_axes_clear(ax, x_axis_at_zero=False)

# Hide any unused subplots: 7 features leave the 8th cell of the 2x4 grid empty;
# with the URL toggle on, all 8 cells are used and nothing is hidden.
for ax in axes[len(FEATURE_COLS):]:
    ax.set_visible(False)

plt.suptitle("Feature distributions: how many members have each value (train set)", fontsize=14, fontweight="bold", y=1.02)
plt.tight_layout()
plt.show()


**What we see in these distribution plots (train set):** Each histogram shows how many members have each value of one feature.

- **wellco_web_visits_count:** Strongly right-skewed; most members fall in the 0–5 or 5–14 range, with a long tail up to 62 visits (a few heavy engagers).
- **wellco_web_unique_urls** (URL toggle on): Same shape as the visit count, because it is computed from the same rows (one row per URL). It is redundant, so we drop it from the matrix.
- **days_since_last_wellco_web:** Right-skewed; many members at 0–1 days, tail out to 13; ~2% missing (no relevant visit).
- **app_sessions_count:** Roughly symmetric and bell-shaped; median ~10, less skewed than the web and claims features.
- **icd_distinct_count:** Multi-modal, with peaks at 4, 5 and 6 codes.
- **has_focus_icd:** Almost all members are 1 (92%+), so this binary flag is nearly constant.
- **days_since_last_claim:** Right-skewed; peak at 1–2 days.
- **tenure_days:** Right-skewed; spread from 45 to 561 days.

For tree-based uplift models we use these features as-is; for linear models we would apply `log1p` to the skewed counts and keep the binary flag as 0/1.
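For the linear-model case mentioned above, here is a minimal sketch on toy values (not the notebook's data) of applying `log1p` to a right-skewed count while leaving a binary flag untouched. `log1p` maps 0 to 0, so zero-heavy counts stay well defined, and it compresses the long tail.

```python
import numpy as np
import pandas as pd

# Toy right-skewed count (most members low, one heavy engager) and a binary flag.
visits = pd.Series([0, 1, 2, 3, 5, 8, 14, 62], name="wellco_web_visits_count")
has_focus = pd.Series([1, 1, 0, 1, 1, 1, 1, 1], name="has_focus_icd")

# log1p(x) = log(1 + x): defined at 0 and shrinks the tail.
visits_log = np.log1p(visits)

# Binary flags stay 0/1; only the skewed count is transformed.
linear_ready = pd.DataFrame({"visits_log1p": visits_log, "has_focus_icd": has_focus})

print("skew before:", round(visits.skew(), 3), " after:", round(visits_log.skew(), 3))
print(linear_ready.round(3))
```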

In [None]:
# ── 4.6c  Multicollinearity diagnostic ──────────────────────────────────────

corr = train_features[FEATURE_COLS].corr()

plt.figure(figsize=(10, 8))
sns.heatmap(
    corr,
    annot=True,
    fmt=".2f",
    cmap="coolwarm",
    vmin=-1,
    vmax=1,
    square=True,
    linewidths=0.5,
)
plt.title("Feature Correlation Matrix (train set)", fontsize=14, fontweight="bold")
plt.xlabel("Feature", fontsize=12)
plt.ylabel("Feature", fontsize=12)
set_axes_clear(plt.gca(), x_axis_at_zero=False)
plt.tight_layout()
plt.show()

# Flag highly correlated pairs (|r| >= 0.8)
HIGH_CORR_THRESHOLD = 0.8
print(f"\nPairs with |correlation| >= {HIGH_CORR_THRESHOLD}:")
flagged = []
for i in range(len(FEATURE_COLS)):
    for j in range(i + 1, len(FEATURE_COLS)):
        r = corr.iloc[i, j]
        if abs(r) >= HIGH_CORR_THRESHOLD:
            flagged.append((FEATURE_COLS[i], FEATURE_COLS[j], r))
            print(f"  {FEATURE_COLS[i]}  ↔  {FEATURE_COLS[j]}  :  r = {r:.3f}")
if not flagged:
    print("  (none)")


**What we see in this heatmap:** The saved plot above is the correlation matrix of the 8 feature columns, i.e. it was produced with the URL toggle on (including `wellco_web_unique_urls`). The only pair with |r| ≥ 0.8 is **wellco_web_visits_count** and **wellco_web_unique_urls**, which are **perfectly correlated** (r = 1.0). Both are computed from the same WellCo-relevant rows, and in that subset no member visits the same URL twice, so each member's visit count equals their unique-URL count. All other pairs are moderate or weak (e.g. tenure vs. recency, app vs. web); no other near-perfect correlation appears. **Decision:** We drop `wellco_web_unique_urls` from the modeling feature matrix (7 features) to avoid multicollinearity. The diagnostics keep the 8-feature output to document that the URL feature was tried and why it was removed.
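The redundancy argument can be checked on a toy table of hypothetical visits. When no (member, URL) pair repeats, `size` and `nunique` agree for every member, so their Pearson correlation is exactly 1. A single repeated URL would break the tie.

```python
import numpy as np
import pandas as pd

# Hypothetical WellCo-relevant visits: within each member, every URL is distinct.
visits = pd.DataFrame({
    "member_id": [1, 1, 1, 2, 3, 3, 4, 4, 4, 4],
    "url": ["a", "b", "c", "a", "b", "d", "a", "b", "c", "d"],
})

per_member = visits.groupby("member_id").agg(
    wellco_web_visits_count=("url", "size"),
    wellco_web_unique_urls=("url", "nunique"),
)

# No repeated (member, url) pairs -> the two columns are identical -> r = 1.
r = per_member["wellco_web_visits_count"].corr(per_member["wellco_web_unique_urls"])
print(per_member)
print(f"r = {r:.3f}")
```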

#### **4.6d Uplift by WellCo-relevant web features**
In the EDA (Section 3.7) we checked uplift by **raw** web activity (all visits). Now that the pipeline keeps only WellCo-relevant visits, we check whether **these filtered features** carry uplift signal. We plot uplift by **three** web-related quantities: `wellco_web_visits_count`, `wellco_web_unique_urls` (diagnostics only; we tried it but removed it from the feature matrix, see 4.6c), and `days_since_last_wellco_web`. The saved plot outputs below were generated while the matrix still had the URL column; the feature matrix used for modeling now has 7 features (no URL column).
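The per-bin computation can be sketched without the notebook's helpers. This is a self-contained illustration on synthetic data. `uplift_by_quartile` is hypothetical, and it assumes `compute_uplift` returns churn rate in control minus churn rate in treated, which is the convention stated in the interpretation below. Because the synthetic outreach effect is a constant ~5 points for everyone, the per-bin uplifts should hover around 0.05, which is roughly what "no single strong moderator" looks like.

```python
import numpy as np
import pandas as pd


def uplift_by_quartile(df: pd.DataFrame, feature: str) -> pd.DataFrame:
    """Hypothetical helper: per-quartile uplift = churn_control - churn_treated.

    Assumes columns `feature`, `outreach` (0/1) and `churn` (0/1). Rows with a
    NaN feature are excluded, as in plot 3.
    """
    valid = df.dropna(subset=[feature]).copy()
    valid["_bin"] = pd.qcut(valid[feature], q=4, duplicates="drop")
    rows = []
    for b, g in valid.groupby("_bin", observed=True):
        treated, control = g[g["outreach"] == 1], g[g["outreach"] == 0]
        rows.append({
            "bin": str(b),
            "uplift": control["churn"].mean() - treated["churn"].mean(),
            "n_treated": len(treated),
            "n_control": len(control),
        })
    return pd.DataFrame(rows)


# Synthetic data: outreach lowers churn by 5 points for everyone.
rng = np.random.default_rng(0)
n = 4000
df = pd.DataFrame({
    "days_since_last_wellco_web": rng.exponential(3.0, n),
    "outreach": rng.integers(0, 2, n),
})
df["churn"] = (rng.random(n) < 0.20 - 0.05 * df["outreach"]).astype(int)
df.loc[:99, "days_since_last_wellco_web"] = np.nan  # 100 members with no relevant visit

print(uplift_by_quartile(df, "days_since_last_wellco_web"))
```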

In [None]:
# ── 4.6d  Uplift by WellCo-relevant web features ───────────────────────────
# Uses the existing `compute_uplift` and `plot_uplift_bars` from Section 3.6,
# and the `labels` DataFrame (= churn_labels[["member_id","churn","outreach"]]).
# compute_uplift looks labels up by member_id, so each bin only needs the IDs.


def plot_uplift_by_quartile(df, feature, title, exclude_nan=False):
    """Bin ``feature`` into quartiles, then plot and print the uplift per bin.

    If ``exclude_nan`` is True, members with a NaN feature value are left out
    of the binning and their count is reported separately.
    """
    data = df[["member_id", feature]]
    n_excluded = int(data[feature].isna().sum())
    if exclude_nan:
        data = data.dropna(subset=[feature])
    bins = pd.qcut(data[feature], q=4, duplicates="drop")

    bin_names, uplifts, nts, ncs = [], [], [], []
    for b in sorted(bins.dropna().unique()):
        ids = data.loc[bins == b, "member_id"]
        u, nt, nc = compute_uplift(ids)
        bin_names.append(str(b)); uplifts.append(u); nts.append(nt); ncs.append(nc)

    plot_uplift_bars(bin_names, uplifts, title=title, xlabel=f"{feature} (quartile)")

    if exclude_nan:
        print(f"Excluded from plot (no relevant web visits): {n_excluded} members.")
    print(f"{'Bin':<25} {'Uplift':>8} {'n_treated':>10} {'n_control':>10}")
    for name, u, nt, nc in zip(bin_names, uplifts, nts, ncs):
        print(f"{name:<25} {u:>8.4f} {nt:>10} {nc:>10}")


# ── 1. Uplift by wellco_web_visits_count (quartiles) ───────────────────────
plot_uplift_by_quartile(train_features, "wellco_web_visits_count",
                        "Uplift by WellCo-relevant web visits (count)")

# ── 2. Uplift by wellco_web_unique_urls (quartiles) ────────────────────────
# Only available when the URL toggle kept the column in the feature matrix.
if "wellco_web_unique_urls" in train_features.columns:
    plot_uplift_by_quartile(train_features, "wellco_web_unique_urls",
                            "Uplift by WellCo-relevant unique URLs")
else:
    print("wellco_web_unique_urls not in train_features (dropped, see 4.6c); skipping plot 2.")

# ── 3. Uplift by days_since_last_wellco_web (quartiles) ────────────────────
# Members with NaN (no relevant visits at all) are excluded from binning
# but reported separately.
plot_uplift_by_quartile(train_features, "days_since_last_wellco_web",
                        "Uplift by days since last WellCo-relevant web visit",
                        exclude_nan=True)

**What we see in these uplift plots:** Each bar chart bins members by quartiles of one WellCo-relevant web feature and shows the **uplift** (churn rate difference: control − treated) in that bin. **Plot 1 (wellco_web_visits_count):** Uplift is small and slightly negative across quartiles; more visits do not show a clearly stronger or weaker outreach effect here. **Plot 2 (wellco_web_unique_urls):** Mirrors the visit-count plot (same redundancy as in the correlation heatmap); we kept this plot to show we tried the URL feature. **Plot 3 (days_since_last_wellco_web):** Members with no relevant visit are excluded; across quartiles uplift is again modest. Overall, these filtered web features show some variation in uplift by segment but no single strong moderator; they remain useful as inputs to the uplift model rather than as standalone targeting rules.

### **4.7 Relevance filter sanity test**
The 26 unique (title, description) pairs in the web data split into two groups according to whether they match the WellCo brief (nutrition, exercise, sleep, stress, diabetes, hypertension, cardiometabolic health). This test uses a **new fixture** (different members 10–12 and a different mix of titles) to verify that the current `SIMILARITY_THRESHOLD` generalizes; expected counts come from the ground-truth sets `WELLCO_RELEVANT_TITLES` / `NOT_RELEVANT_TITLES`. If the test passes, the threshold correctly separates relevant from non-relevant visits on these new examples.
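The thresholding idea behind the filter can be shown with toy vectors. This is a minimal sketch under two assumptions: (1) the filter keeps a visit when its cosine similarity to the WellCo brief embedding is at or above the threshold (the exact comparison inside `filter_wellco_relevant_visits` is not shown in this section), and (2) the hand-made 3-d vectors stand in for sentence-transformer embeddings. `cosine_to_brief` and `THRESHOLD` are hypothetical names.

```python
import numpy as np
import pandas as pd


def cosine_to_brief(item_vecs: np.ndarray, brief_vec: np.ndarray) -> np.ndarray:
    """Cosine similarity of each row of item_vecs to one brief vector of shape (1, dim)."""
    item_norm = item_vecs / np.linalg.norm(item_vecs, axis=1, keepdims=True)
    brief_norm = brief_vec / np.linalg.norm(brief_vec, axis=1, keepdims=True)
    return (item_norm @ brief_norm.T).ravel()


# Hypothetical 3-d "embeddings": axis 0 ~ health, axis 1 ~ tech, axis 2 ~ sport.
brief = np.array([[1.0, 0.0, 0.0]])
visits = pd.DataFrame(
    {"title": ["Stress reduction", "Gadget roundup", "Match highlights", "Cardio workouts"]}
)
vecs = np.array([
    [0.9, 0.1, 0.1],
    [0.1, 0.95, 0.0],
    [0.05, 0.0, 1.0],
    [0.7, 0.0, 0.6],
])

THRESHOLD = 0.2  # hypothetical cut-off, assumed inclusive (>=)
visits["sim"] = cosine_to_brief(vecs, brief)
relevant = visits[visits["sim"] >= THRESHOLD]
print(visits.round(3))
print(relevant["title"].tolist())
```

The sanity test below checks the same property end to end: with the real embeddings, every retained title must come from the relevant set and none from the non-relevant set.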

In [None]:
# ── 4.7  Relevance filter sanity test ───────────────────────────────────────
#
# Ground-truth grouping of the 26 unique (title, description) pairs in the
# web-visits data, classified by whether the content aligns with the WellCo
# brief (nutrition, exercise, sleep, stress, diabetes, hypertension,
# cardiometabolic health).

WELLCO_RELEVANT_TITLES: set[str] = {
    "Diabetes management",
    "Hypertension basics",
    "Stress reduction",
    "Restorative sleep tips",
    "Healthy eating guide",
    "Aerobic exercise",
    "HbA1c targets",
    "Strength training basics",
    "Lowering blood pressure",
    "Sleep hygiene",
    "Mediterranean diet",
    "Cardio workouts",
    "Exercise routines",
    "Meditation guide",
    "Cardiometabolic health",
    "High-fiber meals",
    "Cholesterol friendly foods",
    "Weight management",
}  # 18 titles

NOT_RELEVANT_TITLES: set[str] = {
    "Gadget roundup",
    "Game reviews",
    "New releases",
    "Dog training",
    "Electric vehicles",
    "Budget planning",
    "Match highlights",
    "Top destinations",
}  # 8 titles

assert len(WELLCO_RELEVANT_TITLES) + len(NOT_RELEVANT_TITLES) == 26, \
    "Expected 26 unique titles total"

# ── Build a small fixture DataFrame (new examples to retest the threshold) ───
# Different members and a different mix of titles than before, so we can check
# that the current SIMILARITY_THRESHOLD generalizes (e.g. 0.2 works on unseen examples).
# Ground truth: expected counts come from WELLCO_RELEVANT_TITLES / NOT_RELEVANT_TITLES.

_test_rows = [
    # Member 10: 2 relevant, 1 not
    (10, "https://x.com/1", "Stress reduction",        "Mindfulness and wellness",                    "2025-08-01 10:00:00"),
    (10, "https://x.com/2", "Healthy eating guide",    "Nutrition and balanced diet",                 "2025-08-02 11:00:00"),
    (10, "https://x.com/3", "Gadget roundup",          "Smartphones and laptops news",               "2025-08-03 12:00:00"),
    # Member 11: 3 relevant, 1 not
    (11, "https://y.com/1", "Cardio workouts",         "Exercise and recovery",                       "2025-08-04 09:00:00"),
    (11, "https://y.com/2", "Meditation guide",        "Mindfulness and relaxation",                 "2025-08-05 10:00:00"),
    (11, "https://y.com/3", "Aerobic exercise",        "Cardio and endurance",                       "2025-08-06 11:00:00"),
    (11, "https://y.com/4", "New releases",            "Box office and trailers",                    "2025-08-07 14:00:00"),
    # Member 12: 0 relevant, 2 not
    (12, "https://z.com/1", "Match highlights",        "League standings and transfers",             "2025-08-08 08:00:00"),
    (12, "https://z.com/2", "Top destinations",       "City guides and itineraries",                "2025-08-09 16:00:00"),
]

test_web = pd.DataFrame(_test_rows, columns=["member_id", "url", "title", "description", "timestamp"])
test_web["timestamp"] = pd.to_datetime(test_web["timestamp"])

test_ref = pd.Timestamp("2025-08-15")

# ── Run the filter ──────────────────────────────────────────────────────────
filtered = filter_wellco_relevant_visits(
    test_web,
    wellco_embedding=wellco_embedding,
    embed_model=embed_model,
    similarity_threshold=SIMILARITY_THRESHOLD,
    ref_date=test_ref,
)

counts = filtered.groupby("member_id").size()
unique_urls = filtered.groupby("member_id")["url"].nunique()

# Expected by ground truth: Member 10 → 2 relevant, 2 unique URLs; 11 → 3, 3; 12 → 0
assert counts.get(10, 0) == 2, f"Member 10: expected 2 relevant visits, got {counts.get(10, 0)}"
assert counts.get(11, 0) == 3, f"Member 11: expected 3 relevant visits, got {counts.get(11, 0)}"
assert counts.get(12, 0) == 0, f"Member 12: expected 0 relevant visits, got {counts.get(12, 0)}"
assert unique_urls.get(10, 0) == 2, f"Member 10: expected 2 unique URLs, got {unique_urls.get(10, 0)}"
assert unique_urls.get(11, 0) == 3, f"Member 11: expected 3 unique URLs, got {unique_urls.get(11, 0)}"

# Verify no irrelevant titles leaked through
assert filtered["title"].isin(NOT_RELEVANT_TITLES).sum() == 0, \
    "Filter let through visits with non-relevant titles!"

# Verify all retained titles are from the relevant set
assert filtered["title"].isin(WELLCO_RELEVANT_TITLES).all(), \
    "Filter retained titles outside the expected relevant set!"

print("✓ All relevance-filter sanity checks passed.")
print(f"  Member 10: {counts.get(10, 0)} visits, {unique_urls.get(10, 0)} unique URLs")
print(f"  Member 11: {counts.get(11, 0)} visits, {unique_urls.get(11, 0)} unique URLs")
print(f"  Member 12: {counts.get(12, 0)} visits (correctly excluded)")

---

## **5. Model Selection — Uplift-Only Cross-Validation**

Evaluate 6 candidate uplift models (S-, T-, X-learner × LightGBM, XGBoost) using **stratified K-fold CV** that preserves treatment × churn balance. Models are compared with **uplift-only metrics** (AUUC, Qini, uplift@10 %, uplift@20 %) — no AUC-ROC, accuracy, or other predictive metrics. Segment stability across folds is also reported.

**NaN handling:** LightGBM and XGBoost handle NaN natively. CausalML meta-learners pass the feature table through to the base learner. **No imputation** is applied for fit / predict. Imputation is reserved for SHAP only (later section).

### **5.1 Configuration**
Feature list, cross-validation parameters, and base-learner factories. All hyperparameters are set here for reproducibility.

In [None]:
# ── 5.1  Configuration ──────────────────────────────────────────────────────

# Feature columns (reused from Section 4.6 — excludes member_id and labels)
FEATURE_COLS = [
    "wellco_web_visits_count",
    "days_since_last_wellco_web",
    "app_sessions_count",
    "icd_distinct_count",
    "has_focus_icd",
    "days_since_last_claim",
    "tenure_days",
]

# Cross-validation
N_SPLITS = 5
RANDOM_STATE = 42

# Evaluation cut-offs (fraction of population ranked by predicted uplift)
TOP_K_LIST = [0.10, 0.20]

# Number of points for cumulative uplift curve
N_CURVE_POINTS = 100


def make_lgbm():
    """Create a LightGBM regressor configured for shallow, regularised trees.

    ``class_weight='balanced'`` up-weights the minority class (churn = 1)
    so the base learner is aware of the ~20 % churn imbalance.
    """
    return LGBMRegressor(
        n_estimators=300,
        learning_rate=0.05,
        max_depth=4,
        num_leaves=31,
        min_child_samples=100,
        subsample=0.9,
        subsample_freq=1,  # LightGBM ignores subsample unless subsample_freq > 0
        colsample_bytree=0.9,
        reg_alpha=0.1,
        reg_lambda=1.0,
        class_weight="balanced",
        random_state=RANDOM_STATE,
        verbose=-1,
    )


def make_xgb(scale_pos_weight: float = 1.0):
    """Create an XGBoost regressor configured for shallow, regularised trees.

    Parameters
    ----------
    scale_pos_weight : float
        Ratio of negative to positive examples — compensates for churn
        class imbalance.  Computed from the training fold at runtime.
    """
    return XGBRegressor(
        n_estimators=300,
        learning_rate=0.05,
        max_depth=4,
        min_child_weight=5,
        subsample=0.9,
        colsample_bytree=0.9,
        reg_alpha=0.1,
        reg_lambda=1.0,
        scale_pos_weight=scale_pos_weight,
        objective="binary:logistic",
        eval_metric="logloss",
        random_state=RANDOM_STATE,
        verbosity=0,
        n_jobs=-1,
    )


# Candidate definitions: (display_name, meta-learner_key, base-learner_key)
CANDIDATE_DEFS = [
    ("S + LGBM", "S", "LGBM"),
    ("S + XGB",  "S", "XGB"),
    ("T + LGBM", "T", "LGBM"),
    ("T + XGB",  "T", "XGB"),
    ("X + LGBM", "X", "LGBM"),
    ("X + XGB",  "X", "XGB"),
]

print(f"Features ({len(FEATURE_COLS)}): {FEATURE_COLS}")
print(f"CV: {N_SPLITS}-fold, random_state={RANDOM_STATE}")
print(f"Candidates: {[c[0] for c in CANDIDATE_DEFS]}")

### **5.2 Data preparation**
Build the modelling arrays from `train_features`. `member_id` is **excluded** (asserted). The combined stratification variable `2 * treatment + churn` creates four cells so that `StratifiedKFold` preserves treatment × churn balance in every fold.
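The effect of stratifying on `2 * treatment + churn` can be checked on synthetic labels. This minimal sketch uses made-up rates (~40 % treated, ~20 % churn) and scikit-learn's `StratifiedKFold`. Note that `shuffle=True` is required for `random_state` to take effect. Each validation fold should reproduce the overall share of all four cells to within a fraction of a percent.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(42)
n = 2000
treatment = (rng.random(n) < 0.4).astype(int)  # synthetic: ~40 % treated
y = (rng.random(n) < 0.2).astype(int)          # synthetic: ~20 % churn

# Four cells: 0 = control/no-churn, 1 = control/churn, 2 = treated/no-churn, 3 = treated/churn.
stratify_col = 2 * treatment + y

skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)
overall = np.bincount(stratify_col, minlength=4) / n
max_dev = 0.0
# X is only a placeholder: the split depends on the stratification labels alone.
for fold, (tr_idx, va_idx) in enumerate(skf.split(np.zeros(n), stratify_col)):
    fold_share = np.bincount(stratify_col[va_idx], minlength=4) / len(va_idx)
    max_dev = max(max_dev, float(np.abs(fold_share - overall).max()))
    print(fold, np.round(fold_share, 3))
print("max deviation from overall cell shares:", round(max_dev, 4))
```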

In [None]:
# ── 5.2  Data preparation ───────────────────────────────────────────────────

X = train_features[FEATURE_COLS].copy()
y = train_features["churn"].astype(int).values
treatment = train_features["outreach"].astype(int).values

assert "member_id" not in X.columns, "member_id must NOT be a feature!"
assert X.shape[1] == len(FEATURE_COLS)

# Combined stratification: 4 groups (control-nochurn, control-churn,
#                                     treated-nochurn, treated-churn)
stratify_col = 2 * treatment + y

# Class-imbalance ratio for XGBoost (negative / positive)
_pos = y.sum()
_neg = len(y) - _pos
SCALE_POS_WEIGHT = _neg / max(_pos, 1)

print(f"X shape:           {X.shape}")
print(f"Churn rate:        {y.mean():.3f}")
print(f"Treatment rate:    {treatment.mean():.3f}")
print(f"scale_pos_weight:  {SCALE_POS_WEIGHT:.2f}")
print(f"\nStratify groups:   {dict(zip(*np.unique(stratify_col, return_counts=True)))}")
print(f"\nNaN counts:\n{X.isna().sum().to_string()}")

**What it means:** The feature matrix has 7 columns and 10,000 rows. Churn rate is ~20 %, treatment rate ~40 %. Two features (`days_since_last_wellco_web`, `days_since_last_claim`) have NaN for members with no activity in that source — these are handled natively by the tree-based base learners.

**What it says about further analysis:** The four stratification groups are large enough for 5-fold CV. The class imbalance ratio (~4:1) is passed to XGBoost via `scale_pos_weight` and to LGBM via `class_weight="balanced"`.
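The ~4:1 figure can be checked by hand. This minimal sketch uses toy labels with 20 % churn. It shows that XGBoost's `scale_pos_weight` (negatives / positives) and scikit-learn's `"balanced"` weighting, `n_samples / (n_classes * n_class_samples)`, put the same 4:1 relative weight on churners. LightGBM documents `class_weight` mainly for classifiers, so whether `LGBMRegressor` applies it to a 0/1 target in the same way is an assumption worth verifying against the installed version.

```python
import numpy as np
from sklearn.utils.class_weight import compute_sample_weight

y = np.array([1] * 200 + [0] * 800)  # toy labels: 20 % churn

# XGBoost-style ratio: negatives / positives.
scale_pos_weight = (y == 0).sum() / max((y == 1).sum(), 1)

# scikit-learn "balanced" per-sample weights.
w = compute_sample_weight("balanced", y)
w_pos, w_neg = w[y == 1][0], w[y == 0][0]

print(scale_pos_weight)  # 4.0
print(w_pos, w_neg)      # 2.5 0.625
print(w_pos / w_neg)     # 4.0
```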

### **5.3 Metric helper functions**
Custom metric functions used inside the CV loop. All metrics follow the convention **positive uplift = outreach reduces churn** (i.e. `churn_control − churn_treated`).

| Function | What it computes |
|---|---|
| `uplift_at_k` | Realised uplift in the top-*k* fraction ranked by predicted uplift |
| `uplift_curve` | Cumulative uplift at evenly spaced fractions of the population |
| `approx_auuc` | Area under the uplift curve (trapezoidal rule) |
| `assign_segments` | Four uplift segments based on predicted-uplift quartiles |
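A tiny worked example makes the sign convention concrete. The sketch below reimplements the top-*k* logic as a hypothetical `realised_uplift_top_k` so it runs standalone, and applies it to ten toy members. Taking everyone gives the plain difference in churn rates, 3/5 in control minus 1/5 in treated. A fraction containing only one arm returns NaN.

```python
import numpy as np

# Toy data: 10 members with given predicted uplift scores (higher = more benefit).
y = np.array([0, 1, 0, 1, 1, 0, 0, 1, 0, 0])                 # churn
t = np.array([1, 0, 1, 0, 1, 0, 1, 0, 1, 0])                 # outreach
s = np.array([.9, .8, .7, .6, .5, .4, .3, .2, .1, .0])       # predicted uplift


def realised_uplift_top_k(y, t, s, k):
    """Hypothetical re-implementation: churn_control - churn_treated in the top-k fraction."""
    n = max(1, int(len(s) * k))
    idx = np.argsort(-s)[:n]
    yk, tk = y[idx], t[idx]
    if (tk == 1).sum() == 0 or (tk == 0).sum() == 0:
        return np.nan  # one arm missing: uplift undefined
    return yk[tk == 0].mean() - yk[tk == 1].mean()


print(round(realised_uplift_top_k(y, t, s, 0.4), 3))  # top 4: control 2/2, treated 0/2 -> 1.0
print(round(realised_uplift_top_k(y, t, s, 1.0), 3))  # all: control 3/5, treated 1/5 -> 0.4
print(realised_uplift_top_k(y, t, s, 0.1))            # top 1 is treated only -> nan
```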

In [None]:
# ── 5.3  Metric helper functions ────────────────────────────────────────────


def uplift_at_k(
    y_true: np.ndarray,
    t_true: np.ndarray,
    uplift_scores: np.ndarray,
    k: float,
) -> float:
    """Realised uplift in the top-*k* fraction of the population.

    Parameters
    ----------
    y_true : array of int
        Observed churn labels (0/1).
    t_true : array of int
        Treatment indicator (1 = outreach, 0 = control).
    uplift_scores : array of float
        Predicted uplift (higher = more benefit from treatment).
    k : float in (0, 1]
        Fraction of the population to consider (e.g. 0.10 for top 10 %).

    Returns
    -------
    float
        ``churn_rate_control - churn_rate_treated`` in the top-k segment.
        Positive means outreach reduces churn in that segment.
    """
    n = max(1, int(len(uplift_scores) * k))
    idx = np.argsort(-uplift_scores)[:n]
    y_sub, t_sub = y_true[idx], t_true[idx]
    treated = t_sub == 1
    control = t_sub == 0
    if treated.sum() == 0 or control.sum() == 0:
        return np.nan
    return float(y_sub[control].mean() - y_sub[treated].mean())


def uplift_curve(
    y_true: np.ndarray,
    t_true: np.ndarray,
    uplift_scores: np.ndarray,
    n_points: int = 100,
):
    """Cumulative uplift curve evaluated at ``n_points`` fractions.

    Parameters
    ----------
    y_true, t_true, uplift_scores : arrays
        Same as ``uplift_at_k``.
    n_points : int
        Number of evenly spaced evaluation points from 1 % to 100 %.

    Returns
    -------
    ks : array of float
        Fraction values (0.01 ... 1.0).
    uplift_vals : array of float
        Realised uplift at each fraction.
    """
    order = np.argsort(-uplift_scores)
    y_sorted = y_true[order]
    t_sorted = t_true[order]
    ks = np.linspace(0.01, 1.0, n_points)
    vals = []
    for frac in ks:
        n = max(1, int(len(y_sorted) * frac))
        y_sub, t_sub = y_sorted[:n], t_sorted[:n]
        nt, nc = (t_sub == 1).sum(), (t_sub == 0).sum()
        if nt == 0 or nc == 0:
            vals.append(np.nan)
        else:
            vals.append(y_sub[t_sub == 0].mean() - y_sub[t_sub == 1].mean())
    return ks, np.array(vals)


def approx_auuc(ks: np.ndarray, uplift_vals: np.ndarray) -> float:
    """Area under the uplift curve (trapezoidal rule, NaN-safe).

    Parameters
    ----------
    ks : array of float
        Fraction values returned by ``uplift_curve``.
    uplift_vals : array of float
        Corresponding uplift values.

    Returns
    -------
    float
        Approximate AUUC. Returns ``np.nan`` if fewer than two values are non-NaN.
    """
    valid = ~np.isnan(uplift_vals)
    if valid.sum() < 2:
        return np.nan
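    # np.trapezoid is the NumPy >= 2.0 name; np.trapz is deprecated there but exists in older releases.
    trapezoid = getattr(np, "trapezoid", None) or np.trapz
    return float(trapezoid(uplift_vals[valid], ks[valid]))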


def assign_segments(uplift_scores: np.ndarray) -> np.ndarray:
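    """Assign four uplift segments based on predicted-uplift quartiles.

    The labels borrow the classic four-group uplift taxonomy, but here they are
    purely rank-based: the classic groups are defined by how a member would respond
    with and without treatment, which quartile cut-offs only approximate.
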
    Segments
    --------
    - **Persuadables** (top quartile): highest predicted treatment benefit.
    - **Sure Things** (Q50-Q75): moderate predicted benefit.
    - **Lost Causes** (Q25-Q50): low predicted benefit.
    - **Do-Not-Disturb** (bottom quartile): lowest / negative predicted benefit.

    Parameters
    ----------
    uplift_scores : array of float
        Predicted uplift per member.

    Returns
    -------
    array of str
        Segment labels, same length as ``uplift_scores``.
    """
    q25, q50, q75 = np.nanquantile(uplift_scores, [0.25, 0.50, 0.75])
    seg = np.empty(len(uplift_scores), dtype=object)
    seg[uplift_scores >= q75]                                   = "Persuadables"
    seg[(uplift_scores >= q50) & (uplift_scores < q75)]         = "Sure Things"
    seg[(uplift_scores >= q25) & (uplift_scores < q50)]         = "Lost Causes"
    seg[uplift_scores < q25]                                    = "Do-Not-Disturb"
    return seg


print("Metric helpers loaded: uplift_at_k, uplift_curve, approx_auuc, assign_segments")
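
A toy check of the `uplift_at_k` logic on ten invented members. The function body is copied from the helper above so the snippet runs on its own. With `k = 0.40` the top four scores hold two treated members who stayed and two control members who churned, so the realised churn reduction is 1.0; with `k = 0.10` the single top member has no control counterpart and the helper returns NaN, which is how a small `k` on a small fold can leave gaps in the metrics.

```python
import numpy as np

def uplift_at_k(y_true, t_true, uplift_scores, k):
    """Churn reduction (control rate minus treated rate) among the top-k fraction."""
    n = max(1, int(len(uplift_scores) * k))
    idx = np.argsort(-uplift_scores)[:n]          # highest predicted benefit first
    y_sub, t_sub = y_true[idx], t_true[idx]
    treated, control = t_sub == 1, t_sub == 0
    if treated.sum() == 0 or control.sum() == 0:  # both arms are needed to compare
        return np.nan
    return float(y_sub[control].mean() - y_sub[treated].mean())

# Invented data: 10 members, scores listed in descending order for readability.
scores = np.array([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1, 0.0])
treat  = np.array([1,   0,   1,   0,   1,   0,   1,   0,   1,   0])
churn  = np.array([0,   1,   0,   1,   1,   0,   1,   0,   0,   1])

top40 = uplift_at_k(churn, treat, scores, 0.40)   # 1.0: control churned, treated stayed
top10 = uplift_at_k(churn, treat, scores, 0.10)   # nan: only one treated member in segment
print(top40, top10)
```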

### **5.4 Stratified K-fold cross-validation**
For each of the 6 candidate models, run `N_SPLITS`-fold CV. In every fold:

1. **Fit** the meta-learner on the training split.
2. **Predict uplift** on the validation split.
3. **Compute** AUUC, Qini, uplift@10 %, and uplift@20 %, and assign four segments.
4. **Store** per-fold metrics and the cumulative uplift curve for later visualisation.
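
Steps 1 and 2 hinge on a sign convention. Meta-learners estimate the treatment effect as the predicted outcome under treatment minus the predicted outcome under control; with churn (1 = churned) as the outcome, a helpful outreach therefore gets a *negative* effect, while the metric helpers expect higher = more benefit. The sketch below reproduces that convention with a hand-rolled linear S-learner in NumPy (not CausalML; the data, the 0.10 effect size and the linear model are invented assumptions) and shows the negation needed before ranking.

```python
import numpy as np

rng = np.random.default_rng(0)
n = 20_000
x = rng.normal(size=n)                    # one synthetic feature
t = rng.integers(0, 2, size=n)            # random 50/50 outreach assignment
# Invented ground truth: outreach lowers churn probability by 0.10.
p_churn = np.clip(0.30 + 0.05 * x - 0.10 * t, 0.0, 1.0)
y = (rng.random(n) < p_churn).astype(float)   # 1 = churned

def design(x, t):
    """S-learner design matrix: treatment is just another column (plus an x*t term)."""
    return np.column_stack([np.ones_like(x), x, t, x * t])

# One linear model fitted on features + treatment (the S-learner idea).
beta, *_ = np.linalg.lstsq(design(x, t.astype(float)), y, rcond=None)

mu1 = design(x, np.ones(n)) @ beta    # predicted churn if everyone is contacted
mu0 = design(x, np.zeros(n)) @ beta   # predicted churn if nobody is contacted
tau = mu1 - mu0                       # meta-learner convention: treated minus control
benefit = -tau                        # churn reduction: positive = outreach helps

print(f"mean tau = {tau.mean():+.3f}, mean benefit = {benefit.mean():+.3f}")
```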

The loop collects everything in `cv_records` (one row per model × fold) and `cv_curves` / `cv_segments` for the diagnostic plots.
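
The loop passes `stratify_col` to `skf.split`; it is defined earlier in the notebook, so its exact contents are not visible here. One common choice, assumed in this sketch, is a joint treatment-by-outcome label, so every fold keeps both the treated/control split and the churn rate. The data are synthetic and the 0-3 encoding is hypothetical.

```python
import numpy as np
from sklearn.model_selection import StratifiedKFold

rng = np.random.default_rng(42)
n = 1_000
treatment = rng.integers(0, 2, size=n)              # synthetic 50/50 assignment
churn = (rng.random(n) < 0.2).astype(int)           # synthetic ~20 % churn rate

# Hypothetical joint label: 0 = control/stayed, 1 = control/churned,
# 2 = treated/stayed, 3 = treated/churned.
stratify_col = 2 * treatment + churn

X_dummy = np.zeros((n, 1))                          # the splitter ignores feature values
skf = StratifiedKFold(n_splits=5, shuffle=True, random_state=0)
overall = np.bincount(stratify_col, minlength=4) / n

gaps = []
for fold, (_, va_idx) in enumerate(skf.split(X_dummy, stratify_col), start=1):
    share = np.bincount(stratify_col[va_idx], minlength=4) / len(va_idx)
    gaps.append(np.abs(share - overall).max())      # worst cell-share deviation in this fold
    print(f"fold {fold}: cell shares {np.round(share, 3)}, max gap {gaps[-1]:.4f}")
```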

In [None]:
# ── 5.4  Stratified K-fold cross-validation ─────────────────────────────────


def build_model(meta_key: str, base_key: str, spw: float):
    """Instantiate a CausalML meta-learner with the requested base learner.

    Parameters
    ----------
    meta_key : str
        One of ``'S'``, ``'T'``, ``'X'``.
    base_key : str
        One of ``'LGBM'``, ``'XGB'``.
    spw : float
        ``scale_pos_weight`` for XGBoost (ignored for LGBM).

    Returns
    -------
    CausalML meta-learner instance.
    """
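    if BaseSRegressor is None:  # the causalml import failed in the Setup cell
        raise ImportError("causalml is required for build_model; install it and re-run the Setup cell.")
    learner = make_lgbm() if base_key == "LGBM" else make_xgb(spw)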
    meta_map = {"S": BaseSRegressor, "T": BaseTRegressor, "X": BaseXRegressor}
    return meta_map[meta_key](learner=learner)


# ---- Storage ----------------------------------------------------------------
skf = StratifiedKFold(n_splits=N_SPLITS, shuffle=True, random_state=RANDOM_STATE)

cv_records  = []     # one dict per (model, fold)
cv_curves   = {}     # model_name -> list of (ks, uplift_vals) per fold
cv_segments = {}     # model_name -> list of segment-share Series per fold

# ---- Loop -------------------------------------------------------------------
for name, meta_key, base_key in CANDIDATE_DEFS:
    cv_curves[name]   = []
    cv_segments[name] = []

    for fold_i, (tr_idx, va_idx) in enumerate(skf.split(X, stratify_col), start=1):
        X_tr, X_va = X.iloc[tr_idx], X.iloc[va_idx]
        y_tr, y_va = y[tr_idx], y[va_idx]
        t_tr, t_va = treatment[tr_idx], treatment[va_idx]

        # Fold-specific imbalance ratio
        spw = (y_tr == 0).sum() / max((y_tr == 1).sum(), 1)

        # Build & fit
        model = build_model(meta_key, base_key, spw)
        model.fit(X_tr, t_tr, y_tr)

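        # Predict uplift. CausalML returns treated-minus-control predicted churn,
        # so negate it: positive = outreach reduces churn (matches the metric helpers).
        tau = -np.asarray(model.predict(X_va)).reshape(-1)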

        # Metrics
        ks, uvals = uplift_curve(y_va, t_va, tau, n_points=N_CURVE_POINTS)
        auuc_val  = approx_auuc(ks, uvals)
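        try:
            # The (y_true, uplift, treatment) argument order matches scikit-uplift's
            # qini_auc_score, which treats y = 1 as the desired outcome; pass retention
            # (1 - churn) so that a higher Qini means better targeting with this tau.
            qini_val = float(qini_auc_score(1 - y_va, tau, t_va)) if qini_auc_score is not None else np.nan
        except Exception:
            qini_val = np.nan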
        u10 = uplift_at_k(y_va, t_va, tau, 0.10)
        u20 = uplift_at_k(y_va, t_va, tau, 0.20)

        # Segment stability
        seg       = assign_segments(tau)
        seg_share = pd.Series(seg).value_counts(normalize=True)

        # Store
        cv_records.append({
            "model": name, "fold": fold_i,
            "auuc": auuc_val, "qini": qini_val, "uplift@10%": u10, "uplift@20%": u20,
            "persuadables_pct": seg_share.get("Persuadables", 0),
            "sure_things_pct":  seg_share.get("Sure Things", 0),
            "lost_causes_pct":  seg_share.get("Lost Causes", 0),
            "dnd_pct":          seg_share.get("Do-Not-Disturb", 0),
        })
        cv_curves[name].append((ks, uvals))
        cv_segments[name].append(seg_share)

        print(f"[{name}] Fold {fold_i}: AUUC={auuc_val:+.5f}  Qini={qini_val:+.5f}  u@10%={u10:+.4f}  u@20%={u20:+.4f}")

print(f"\nCV complete — {len(cv_records)} records "
      f"({len(CANDIDATE_DEFS)} models x {N_SPLITS} folds).")

**What it means:** Each candidate model was trained and evaluated on 5 (`N_SPLITS`) non-overlapping validation folds. The printed lines show per-fold AUUC, Qini, and uplift@k so you can spot any fold that behaves very differently from the others (a sign of instability).

**What it says about further analysis:** The raw numbers are collected in `cv_records`. The next cells aggregate them into a comparison table and diagnostic plots.
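
Spotting an odd fold by eye works for a handful of models; a small check can make it explicit. The helper below is hypothetical (not part of the notebook): it compares each fold with the mean and sample standard deviation of the *other* folds, so one extreme fold cannot hide by inflating the spread. The 3.0 threshold is an arbitrary assumption, and with only four remaining folds the estimate is rough. In the notebook it could be fed, for example, one model's `auuc` column from `cv_df` sorted by fold.

```python
import numpy as np

def flag_unstable_folds(values, z_thresh=3.0):
    """Hypothetical helper: return 1-based fold numbers whose metric lies more than
    `z_thresh` sample standard deviations from the mean of the *other* folds."""
    v = np.asarray(values, dtype=float)
    flagged = []
    for i in range(len(v)):
        others = np.delete(v, i)
        others = others[~np.isnan(others)]
        if np.isnan(v[i]) or len(others) < 2:
            continue                          # not enough information to judge
        mu, sd = others.mean(), others.std(ddof=1)
        if sd == 0:
            if v[i] != mu:                    # all other folds identical, this one differs
                flagged.append(i + 1)
            continue
        if abs(v[i] - mu) / sd > z_thresh:
            flagged.append(i + 1)
    return flagged

# Invented per-fold AUUC values for one model.
stable = [0.010, 0.012, 0.011, 0.009, 0.010]
one_lucky_fold = [0.020, 0.021, 0.019, 0.020, 0.080]
print(flag_unstable_folds(stable))          # []
print(flag_unstable_folds(one_lucky_fold))  # [5]
```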

### **5.5 Results summary table**
Aggregate per-fold metrics into **mean ± std** per model. The table is sorted by descending mean AUUC. Use this table together with the diagnostic plots below to justify your model choice — balancing **performance** (higher AUUC / uplift@k) with **stability** (lower std across folds) and **segment consistency**.
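
The `std` in the table is the pandas sample standard deviation (`ddof=1`), which with only a handful of folds is noticeably larger than NumPy's default population standard deviation (`ddof=0`). The sketch below reproduces the named-aggregation pattern from the cell on invented fold values for two hypothetical models, A and B.

```python
import numpy as np
import pandas as pd

# Invented per-fold AUUC values for two hypothetical models.
cv_df = pd.DataFrame({
    "model": ["A"] * 3 + ["B"] * 3,
    "fold":  [1, 2, 3] * 2,
    "auuc":  [0.010, 0.020, 0.030, 0.015, 0.016, 0.017],
})

summary = (
    cv_df.groupby("model")
    .agg(auuc_mean=("auuc", "mean"), auuc_std=("auuc", "std"))  # std uses ddof=1
    .sort_values("auuc_mean", ascending=False)
)
print(summary)
# Population std for model A would be about 0.0082 instead of 0.0100.
print(np.std([0.010, 0.020, 0.030]))
```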

In [None]:
# ── 5.5  Results summary table ──────────────────────────────────────────────

cv_df = pd.DataFrame(cv_records)

summary = (
    cv_df
    .groupby("model")
    .agg(
        auuc_mean=("auuc", "mean"),
        auuc_std=("auuc", "std"),
        qini_mean=("qini", "mean"),
        qini_std=("qini", "std"),
        u10_mean=("uplift@10%", "mean"),
        u10_std=("uplift@10%", "std"),
        u20_mean=("uplift@20%", "mean"),
        u20_std=("uplift@20%", "std"),
        pers_mean=("persuadables_pct", "mean"),
        pers_std=("persuadables_pct", "std"),
    )
    .sort_values("auuc_mean", ascending=False)
)

# Pretty-print
print("=" * 90)
print("  MODEL SELECTION — CV RESULTS  (mean +/- std across folds, sorted by AUUC)")
print("=" * 90)
for model, row in summary.iterrows():
    print(f"\n  {model}")
    print(f"    AUUC:            {row.auuc_mean:+.5f} +/- {row.auuc_std:.5f}")
    print(f"    Qini:            {row.qini_mean:+.5f} +/- {row.qini_std:.5f}")
    print(f"    Uplift @10%:     {row.u10_mean:+.5f} +/- {row.u10_std:.5f}")
    print(f"    Uplift @20%:     {row.u20_mean:+.5f} +/- {row.u20_std:.5f}")
    print(f"    Persuadables %:  {row.pers_mean:.3f}  +/- {row.pers_std:.3f}")

display(summary.round(5))

### **5.6 Diagnostic visualisations**
Five diagnostic plot cells support the model-selection decision:

| Plot | Purpose |
|---|---|
| 5.6a AUUC and Qini bar charts | Compare mean AUUC and Qini with error bars (std) across models |
| 5.6b Uplift@k grouped bars | Compare uplift@10 % and uplift@20 % side by side |
| 5.6c Cumulative uplift curves | Show how uplift accumulates as you target more of the population |
| 5.6d Per-fold metric heatmap | Reveal fold-level variation for every model × metric |
| 5.6e Segment stability | Check whether Persuadables share is consistent across folds |
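
The Qini panel depends on an external metric that is set to NaN when the import or call fails. As a cross-check, the sketch below computes a Qini curve directly with NumPy using the usual definition (treated responses minus control responses rescaled to the treated count), written in churn terms as churners prevented. The function names are hypothetical, the toy data are invented, and the coefficient is left unnormalised; libraries normalise differently, so values are not directly comparable with CausalML or scikit-uplift output.

```python
import numpy as np

def qini_curve_churn(churn, treated, scores):
    """Qini curve for a churn outcome (1 = churned), ranking by predicted benefit.

    Q(n) = C_c(n) * N_t(n) / N_c(n) - C_t(n): churners prevented among the top-n
    members, with the control churn count rescaled to the treated-group size.
    """
    order = np.argsort(-np.asarray(scores), kind="stable")   # highest benefit first
    c = np.asarray(churn, dtype=float)[order]
    t = np.asarray(treated, dtype=float)[order]
    n_t, n_c = np.cumsum(t), np.cumsum(1 - t)                # group sizes in top n
    c_t, c_c = np.cumsum(c * t), np.cumsum(c * (1 - t))      # churners in top n
    with np.errstate(divide="ignore", invalid="ignore"):
        q = np.where(n_c > 0, c_c * n_t / n_c - c_t, 0.0)    # 0 until a control appears
    frac = np.arange(1, len(c) + 1) / len(c)
    return np.concatenate([[0.0], frac]), np.concatenate([[0.0], q])

def qini_coefficient(frac, q):
    """Unnormalised area between the Qini curve and the random-targeting line."""
    gap = q - frac * q[-1]                                          # curve minus straight line
    return float(np.sum((gap[1:] + gap[:-1]) / 2 * np.diff(frac)))  # trapezoid rule

scores = np.array([0.6, 0.5, 0.4, 0.3, 0.2, 0.1])   # invented predicted benefit
treat  = np.array([1,   0,   1,   0,   1,   0])
churn  = np.array([0,   1,   0,   0,   1,   0])

good = qini_coefficient(*qini_curve_churn(churn, treat, scores))
bad = qini_coefficient(*qini_curve_churn(churn, treat, -scores))   # reversed ranking
print(good, bad)
```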

In [None]:
# ── 5.6a  AUUC and Qini comparison (mean +/- std) ───────────────────────────

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 4))
models = summary.index.tolist()
x = np.arange(len(models))
pal = sns.color_palette("Set2", len(models))

# AUUC
ax1.bar(x, summary["auuc_mean"], yerr=summary["auuc_std"], capsize=5,
        color=pal, edgecolor="black", alpha=0.85)
ax1.set_xticks(x)
ax1.set_xticklabels(models, rotation=30, ha="right", fontsize=9)
ax1.set_ylabel("AUUC (higher = better)")
ax1.set_title("AUUC by Model (mean +/- std across CV folds)", fontsize=11, fontweight="bold")
ax1.axhline(0, color="grey", linewidth=0.8, linestyle="--")
ax1.grid(axis="y", alpha=0.3)

# Qini
ax2.bar(x, summary["qini_mean"], yerr=summary["qini_std"], capsize=5,
        color=pal, edgecolor="black", alpha=0.85)
ax2.set_xticks(x)
ax2.set_xticklabels(models, rotation=30, ha="right", fontsize=9)
ax2.set_ylabel("Qini coefficient (higher = better)")
ax2.set_title("Qini by Model (mean +/- std across CV folds)", fontsize=11, fontweight="bold")
ax2.axhline(0, color="grey", linewidth=0.8, linestyle="--")
ax2.grid(axis="y", alpha=0.3)

plt.tight_layout()
plt.show()

In [None]:
# ── 5.6b  Uplift@k grouped bar chart ──────────────────────────────────────

fig, ax = plt.subplots(figsize=(10, 4))
models = summary.index.tolist()
x = np.arange(len(models))
w = 0.35
ax.bar(x - w / 2, summary["u10_mean"], w, yerr=summary["u10_std"],
       capsize=4, label="Uplift @10 %", edgecolor="black", alpha=0.85)
ax.bar(x + w / 2, summary["u20_mean"], w, yerr=summary["u20_std"],
       capsize=4, label="Uplift @20 %", edgecolor="black", alpha=0.85)
ax.set_xticks(x)
ax.set_xticklabels(models, rotation=30, ha="right", fontsize=9)
ax.set_ylabel("Churn reduction (control - treated)")
ax.set_title("Uplift @k (mean +/- std across CV folds)",
             fontsize=12, fontweight="bold")
ax.axhline(0, color="grey", linewidth=0.8, linestyle="--")
ax.legend()
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ── 5.6c  Cumulative uplift curves (mean +/- std across folds) ─────────────

fig, ax = plt.subplots(figsize=(10, 6))
palette = sns.color_palette("tab10", len(cv_curves))

for (name, curve_list), colour in zip(cv_curves.items(), palette):
    if not curve_list:
        continue
    ks = curve_list[0][0]
    stack = np.vstack([u for _, u in curve_list])
    mean_c = np.nanmean(stack, axis=0)
    std_c  = np.nanstd(stack, axis=0)
    ax.plot(ks, mean_c, label=name, color=colour, linewidth=2)
    ax.fill_between(ks, mean_c - std_c, mean_c + std_c,
                    alpha=0.12, color=colour)

ax.axhline(0, color="grey", linewidth=0.8, linestyle="--")
ax.set_xlabel("Fraction of population targeted (ranked by predicted uplift)")
ax.set_ylabel("Cumulative uplift (control churn - treated churn)")
ax.set_title("Cumulative Uplift Curves — mean +/- 1 std across CV folds",
             fontsize=12, fontweight="bold")
ax.legend(loc="best", fontsize=8)
ax.grid(alpha=0.3)
plt.tight_layout()
plt.show()

In [None]:
# ── 5.6d  Per-fold metric heatmap ──────────────────────────────────────────

fig, axes = plt.subplots(1, 4, figsize=(24, 5))
fig.suptitle("Per-Fold Metric Values by Model",
             fontsize=13, fontweight="bold")

for ax, metric in zip(axes, ["auuc", "qini", "uplift@10%", "uplift@20%"]):
    pivot = cv_df.pivot(index="model", columns="fold", values=metric)
    sns.heatmap(pivot, annot=True, fmt=".4f", cmap="RdYlGn", center=0,
                ax=ax, cbar_kws={"shrink": 0.8})
    ax.set_title(metric.upper(), fontsize=11, fontweight="bold")
    ax.set_ylabel("")
    ax.set_xlabel("Fold")

plt.tight_layout()
plt.show()

In [None]:
# ── 5.6e  Segment stability (Persuadables share across folds) ──────────────

seg_cols = ["persuadables_pct", "sure_things_pct", "lost_causes_pct", "dnd_pct"]
seg_labels = ["Persuadables", "Sure Things", "Lost Causes", "Do-Not-Disturb"]

seg_summary = (
    cv_df
    .groupby("model")[seg_cols]
    .agg(["mean", "std"])
    .rename(columns=dict(zip(seg_cols, seg_labels)), level=0)  # readable segment names
    .round(4)
)

print("=" * 80)
print("  SEGMENT STABILITY (share of validation fold, mean +/- std)")
print("=" * 80)
display(seg_summary)

# Bar chart: Persuadables share per model
fig, ax = plt.subplots(figsize=(10, 4))
models = summary.index.tolist()
x = np.arange(len(models))
ax.bar(x, summary["pers_mean"], yerr=summary["pers_std"],
       capsize=5, color=sns.color_palette("Set2", len(models)),
       edgecolor="black", alpha=0.85)
ax.set_xticks(x)
ax.set_xticklabels(models, rotation=30, ha="right", fontsize=9)
ax.set_ylabel("Persuadables share")
ax.set_title("Persuadables Share Stability (mean +/- std across folds)",
             fontsize=12, fontweight="bold")
ax.grid(axis="y", alpha=0.3)
plt.tight_layout()
plt.show()

**What it means:** The five diagnostic plots above show:

- **5.6a (AUUC and Qini):** Which model has the highest area under the uplift curve and Qini coefficient, and how much each varies across folds.
- **5.6b (Uplift@k):** Whether the model provides real churn reduction when targeting the top 10 % and 20 % of members ranked by predicted uplift.
- **5.6c (Cumulative curves):** How uplift accumulates as you contact a larger share of the population — look for curves that stay above zero and separate clearly from others.
- **5.6d (Heatmap):** Fold-by-fold stability — a model with one very high fold and four near-zero folds is unreliable.
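- **5.6e (Segments):** Whether the Persuadables share is roughly constant across folds. Because segments are quartiles of predicted uplift, each share is ~25 % by construction unless many predictions are tied, so this check mainly catches ties or degenerate predictions; the realised treated-vs-control uplift inside the segment can still vary.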
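
The ~25 % shares hold when predicted uplift values are distinct. When many members share the same score, the `>=` cut-offs in `assign_segments` push every tie into the higher segment and the shares drift away from 25 %. The sketch copies the notebook's quartile rule so it runs standalone; the synthetic scores and the rounding used to create ties are invented for illustration, and `segment_shares` is a hypothetical helper.

```python
import numpy as np

def assign_segments(uplift_scores):
    """Same quartile rule as the notebook helper (copied so this runs standalone)."""
    q25, q50, q75 = np.nanquantile(uplift_scores, [0.25, 0.50, 0.75])
    seg = np.empty(len(uplift_scores), dtype=object)
    seg[uplift_scores >= q75] = "Persuadables"
    seg[(uplift_scores >= q50) & (uplift_scores < q75)] = "Sure Things"
    seg[(uplift_scores >= q25) & (uplift_scores < q50)] = "Lost Causes"
    seg[uplift_scores < q25] = "Do-Not-Disturb"
    return seg

def segment_shares(seg):
    """Share of members per segment label."""
    labels, counts = np.unique(seg.astype(str), return_counts=True)
    return {str(k): float(v) for k, v in zip(labels, counts / len(seg))}

rng = np.random.default_rng(1)
distinct = rng.normal(size=1_000)              # all scores distinct
tied = np.round(rng.normal(size=1_000))        # only a few distinct values, so heavy ties

distinct_shares = segment_shares(assign_segments(distinct))
tied_shares = segment_shares(assign_segments(tied))
print(distinct_shares)   # every share is exactly 0.25
print(tied_shares)       # shares move well away from 0.25
```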

**What it says about further analysis:** Choose the model that balances the highest mean AUUC / uplift@k with the lowest cross-fold variance and stable segment sizes. That model will be retrained on the full training set in Section 6 (final training), then scored on the test set.
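
One way to make "balance mean against variance" concrete, offered as an assumption rather than the notebook's rule, is to rank models by a conservative score such as mean AUUC minus one cross-fold standard deviation. The `pick_model` helper, the penalty value and the model names below are hypothetical; the demo frame only reuses the `auuc_mean` / `auuc_std` column names from the 5.5 summary.

```python
import pandas as pd

def pick_model(summary, penalty=1.0):
    """Hypothetical rule: rank by mean AUUC minus `penalty` x cross-fold std."""
    score = summary["auuc_mean"] - penalty * summary["auuc_std"]
    return score.sort_values(ascending=False)

# Invented CV summary with the same column names as `summary` in 5.5.
demo = pd.DataFrame(
    {"auuc_mean": [0.030, 0.026, 0.012], "auuc_std": [0.020, 0.004, 0.002]},
    index=["model_A", "model_B", "model_C"],
)
print(pick_model(demo).index[0])               # model_B: slightly lower mean, far steadier
print(pick_model(demo, penalty=0.0).index[0])  # model_A: ranking by mean alone
```

A larger `penalty` favours steadier models more strongly; with five folds the standard deviation is itself noisy, so the ranking is best read alongside the plots rather than as a final verdict.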