# Bank Customer Churn – Lifetime Value and Retention Prioritisation

This notebook combines:

- **Churn modelling** (probability a customer exits), and
- **Customer Lifetime Value (CLV)** approximation,

to build **priority segments for retention**.

We use the common **Bank Customer Churn** dataset (`Churn_Modelling.csv`) and:

1. Build a supervised churn model (classification) to estimate `P(Exited = 1)`.
2. Approximate customer value using simple proxies from the dataset.
3. Combine churn risk and value into **actionable segments**.
4. Discuss which customers should be prioritised for retention campaigns.

The goal is to move from *"Who will churn?"* to *"Which churners matter most?"*.


## 1. Imports and configuration

We use:

- `pandas`, `numpy` for data handling.
- `matplotlib`, `seaborn` for visualisation.
- `scikit-learn` for modelling and evaluation.

We assume the dataset is available at:

```text
data/Churn_Modelling.csv
```


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    RocCurveDisplay,
)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.base import BaseEstimator

sns.set_theme(style="whitegrid")  # set_theme is the current name; sns.set is an alias
plt.rcParams["figure.figsize"] = (8, 5)

RANDOM_STATE: int = 42
np.random.seed(RANDOM_STATE)

DATA_PATH: Path = Path("data") / "Churn_Modelling.csv"

if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Data file not found at {DATA_PATH.resolve()}. "
        "Please download the Bank Customer Churn CSV and place it under the 'data/' directory."
    )


## 2. Load and inspect the data

We load the bank churn dataset and take a first look.

Typical columns include:

- IDs: `RowNumber`, `CustomerId`, `Surname` (not predictive).
- Features: `CreditScore`, `Geography`, `Gender`, `Age`, `Tenure`, `Balance`,
  `NumOfProducts`, `HasCrCard`, `IsActiveMember`, `EstimatedSalary`.
- Target: `Exited` (1 = churned, 0 = stayed).


In [None]:
def load_bank_churn_data(path: Path) -> pd.DataFrame:
    """Load the bank customer churn dataset from a CSV file.

    Args:
        path: Path to the CSV file.

    Returns:
        DataFrame containing the bank churn data.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the loaded DataFrame is empty.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path!s}")

    df: pd.DataFrame = pd.read_csv(path)
    if df.empty:
        raise ValueError(f"Loaded DataFrame is empty: {path!s}")

    return df


raw_df: pd.DataFrame = load_bank_churn_data(DATA_PATH)

display(raw_df.head())
print("\nDataframe shape:", raw_df.shape)
print("Columns:", list(raw_df.columns))


## 3. Basic cleaning

We:

1. Drop identifier columns that do not carry predictive signal.
2. Confirm that `Exited` is present and binary.
3. Check missing values.


In [None]:
def clean_bank_churn_data(raw_df: pd.DataFrame) -> pd.DataFrame:
    """Clean the bank customer churn dataset.

    - Drop identifier columns.
    - Check for missing values.
    - Ensure `Exited` exists.

    Args:
        raw_df: Raw bank churn DataFrame.

    Returns:
        Cleaned DataFrame.
    """
    df = raw_df.copy()

    id_cols: List[str] = ["RowNumber", "CustomerId", "Surname"]
    drop_cols: List[str] = [c for c in id_cols if c in df.columns]
    if drop_cols:
        df = df.drop(columns=drop_cols)
        print(f"Dropped identifier columns: {drop_cols}")

    # Check missing values
    missing = df.isna().sum()
    print("Missing values per column (non-zero only):")
    display(missing[missing > 0])

    if "Exited" not in df.columns:
        raise ValueError("Target column 'Exited' not found in DataFrame.")

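    # Confirm the target is binary, then store it as integer 0/1
    unexpected = set(df["Exited"].dropna().unique()) - {0, 1}
    if unexpected:
        raise ValueError(f"'Exited' must be binary 0/1; found unexpected values: {unexpected}")
    df["Exited"] = df["Exited"].astype(int)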

    return df


df: pd.DataFrame = clean_bank_churn_data(raw_df)

display(df.head())
print("\nClass distribution (Exited):")
print(df["Exited"].value_counts(normalize=True).rename("proportion"))


### Section summary

We now have a clean dataset with predictive features and a binary
`Exited` column. The churn rate is typically around 20% in this dataset.

Next we perform a short EDA focused on **value-related** features.


## 4. Short EDA: churn and value proxies

We are particularly interested in:

- `EstimatedSalary` – rough proxy for potential revenue.
- `NumOfProducts` – engagement with the bank.
- `Balance` – deposit volume.

We quickly explore these features by churn status.
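
A numeric summary can complement the plots. The sketch below uses a small made-up DataFrame (`toy_df`, not the real dataset) with the same column names and compares the median of each value proxy between churned and retained customers. On the real data you would pass `df` instead.

```python
import pandas as pd

# Toy stand-in for the cleaned DataFrame (values are made up for illustration).
toy_df = pd.DataFrame(
    {
        "EstimatedSalary": [50_000.0, 120_000.0, 80_000.0, 150_000.0, 60_000.0, 90_000.0],
        "NumOfProducts": [1, 2, 1, 3, 2, 1],
        "Balance": [0.0, 125_000.0, 90_000.0, 140_000.0, 0.0, 110_000.0],
        "Exited": [0, 0, 0, 1, 0, 1],
    }
)

value_cols = ["EstimatedSalary", "NumOfProducts", "Balance"]

# Median of each value proxy, split by churn status (rows: Exited = 0 / 1).
median_by_churn = toy_df.groupby("Exited")[value_cols].median()
print(median_by_churn)
```

Medians are used here rather than means because `Balance` often has many zero values, which can pull the mean around.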


In [None]:
value_cols: List[str] = ["EstimatedSalary", "NumOfProducts", "Balance"]

for col in value_cols:
    if col not in df.columns:
        print(f"Skipping {col!r} – not found in DataFrame.")
        continue

    fig, ax = plt.subplots(1, 2, figsize=(12, 4))

    # Distribution by churn status
    sns.kdeplot(data=df, x=col, hue="Exited", common_norm=False, fill=True, alpha=0.4, ax=ax[0])
    ax[0].set_title(f"{col} distribution by Exited")

    # Boxplot
    sns.boxplot(data=df, x="Exited", y=col, ax=ax[1])
    ax[1].set_title(f"{col} by Exited (boxplot)")

    plt.tight_layout()
    plt.show()


These plots show how the value proxies are distributed for churned versus
retained customers. Our next step is to build a **churn model** that outputs
`P(Exited = 1)` for each customer.


## 5. Train–test split and preprocessing

We separate:

- Features `X` (all columns except `Exited`).
- Target `y` (`Exited`).

Then we define a preprocessing pipeline:

- Numeric features → `StandardScaler`.
- Categorical features (`Geography`, `Gender`) → `OneHotEncoder`.

We use `ColumnTransformer` inside a `Pipeline` to keep things clean and
reproducible.


In [None]:
TARGET_COL: str = "Exited"

if TARGET_COL not in df.columns:
    raise KeyError(f"Target column {TARGET_COL!r} not found in DataFrame.")

X: pd.DataFrame = df.drop(columns=[TARGET_COL])
y: pd.Series = df[TARGET_COL]

categorical_cols: List[str] = [c for c in ["Geography", "Gender"] if c in X.columns]
numeric_cols: List[str] = [c for c in X.columns if c not in categorical_cols]

print("Categorical columns:", categorical_cols)
print("Numeric columns:", numeric_cols)

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=RANDOM_STATE,
)

print("Train shape:", X_train.shape, "Test shape:", X_test.shape)

numeric_transformer = Pipeline(
    steps=[("scaler", StandardScaler())]
)

categorical_transformer = Pipeline(
    steps=[("encoder", OneHotEncoder(handle_unknown="ignore"))]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_cols),
        ("cat", categorical_transformer, categorical_cols),
    ]
)


## 6. Churn model: Random Forest

We train a **Random Forest classifier** as our churn model.

It is not the only choice (we could use LightGBM, XGBoost, etc.), but it is:

- Strong for tabular data.
- Robust and easy to use.

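We define a small evaluation helper and inspect accuracy and ROC-AUC. Because
only about 20% of customers churn, always predicting "stayed" already gives
roughly 80% accuracy, so ROC-AUC and the per-class report are the more
informative numbers.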
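
A quick way to calibrate these metrics is to compute them for a trivial baseline. The sketch below uses synthetic labels with a 20% positive rate (an assumption that mirrors the rough churn rate, not the real data) and a "model" that gives every customer the same prediction.

```python
import numpy as np
from sklearn.metrics import accuracy_score, roc_auc_score

# Synthetic labels with a 20% positive rate (made up for illustration).
y_true = np.array([1] * 20 + [0] * 80)

# A "model" that always predicts the majority class and gives every
# customer the same score.
y_pred_majority = np.zeros_like(y_true)
y_score_constant = np.full(y_true.shape, 0.2)

baseline_accuracy = accuracy_score(y_true, y_pred_majority)
baseline_auc = roc_auc_score(y_true, y_score_constant)

print(f"Majority-class accuracy: {baseline_accuracy:.2f}")
print(f"Constant-score ROC-AUC:  {baseline_auc:.2f}")
```

Any real model should be judged against these reference points: accuracy near 0.80 and ROC-AUC near 0.50 mean it has learned little about who churns.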


In [None]:
def evaluate_classifier(
    name: str,
    model: BaseEstimator,
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.Series,
    y_test: pd.Series,
) -> Dict[str, float]:
    """Fit and evaluate a classifier.

    Args:
        name: Model name.
        model: scikit-learn estimator or pipeline.
        X_train: Training features.
        X_test: Test features.
        y_train: Training labels.
        y_test: Test labels.

    Returns:
        Dictionary with key metrics on the test set.
    """
    print(f"\n===== {name} =====")

    model.fit(X_train, y_train)

    y_pred_test = model.predict(X_test)
    y_proba_test = model.predict_proba(X_test)[:, 1]

    acc = accuracy_score(y_test, y_pred_test)
    roc_auc = roc_auc_score(y_test, y_proba_test)

    print(f"Test accuracy: {acc:.3f}")
    print(f"Test ROC-AUC:  {roc_auc:.3f}")

    print("\nClassification report (test):")
    print(classification_report(y_test, y_pred_test, target_names=["Stayed", "Exited"]))

    cm = confusion_matrix(y_test, y_pred_test)
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Pred stayed", "Pred exited"],
        yticklabels=["True stayed", "True exited"],
    )
    plt.title(f"Confusion matrix - {name}")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.show()

    RocCurveDisplay.from_predictions(y_test, y_proba_test)
    plt.title(f"ROC curve - {name}")
    plt.show()

    return {"model": name, "test_accuracy": acc, "test_roc_auc": roc_auc}


rf_clf = Pipeline(
    steps=[
        ("preprocess", preprocessor),
        (
            "clf",
            RandomForestClassifier(
                n_estimators=300,
                max_depth=None,
                min_samples_split=4,
                min_samples_leaf=2,
                random_state=RANDOM_STATE,
                n_jobs=-1,
            ),
        ),
    ]
)

metrics_rf = evaluate_classifier("Random Forest", rf_clf, X_train, X_test, y_train, y_test)
metrics_rf


This Random Forest is our **churn probability engine**. Next, we will use it to
compute predicted churn probabilities for **all customers**.


## 7. Predict churn probability for all customers

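We refit the model on the **full dataset** (all rows) and score every customer.
Because the model scores the same rows it was trained on, these probabilities
are in-sample and tend to be more extreme (closer to 0 or 1) than scores for
unseen customers.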

These probabilities will feed our CLV approximation.
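
One possible alternative, not used in this notebook, is to score each customer with a model that did not see that customer during fitting. The sketch below does this with `cross_val_predict` on synthetic data from `make_classification`; the data, the smaller forest, and the five-fold split are all assumptions for illustration.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import StratifiedKFold, cross_val_predict

# Synthetic stand-in for (X, y); the notebook would pass its own pipeline and data.
X_demo, y_demo = make_classification(
    n_samples=500, n_features=8, weights=[0.8, 0.2], random_state=42
)

demo_model = RandomForestClassifier(n_estimators=50, min_samples_leaf=2, random_state=42)
cv = StratifiedKFold(n_splits=5, shuffle=True, random_state=42)

# Out-of-fold scores: each row is scored by a model fitted on the other folds.
oof_proba = cross_val_predict(
    demo_model, X_demo, y_demo, cv=cv, method="predict_proba"
)[:, 1]

# In-sample scores for comparison: fitted and scored on the same rows.
in_sample_proba = demo_model.fit(X_demo, y_demo).predict_proba(X_demo)[:, 1]

print("Out-of-fold 5/50/95th pct:", np.round(np.percentile(oof_proba, [5, 50, 95]), 3))
print("In-sample   5/50/95th pct:", np.round(np.percentile(in_sample_proba, [5, 50, 95]), 3))
```

Comparing the two summaries gives a feel for how much in-sample scoring sharpens the probabilities that later drive the lifetime estimate.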


In [None]:
# Fit on full data for scoring purposes
rf_clf_full = Pipeline(
    steps=[
        ("preprocess", preprocessor),
        (
            "clf",
            RandomForestClassifier(
                n_estimators=300,
                max_depth=None,
                min_samples_split=4,
                min_samples_leaf=2,
                random_state=RANDOM_STATE,
                n_jobs=-1,
            ),
        ),
    ]
)

rf_clf_full.fit(X, y)

churn_proba_all = rf_clf_full.predict_proba(X)[:, 1]

scored_df = df.copy()
scored_df["churn_proba"] = churn_proba_all

display(scored_df[["Exited", "churn_proba"]].head())

print("Churn probability summary:")
print(scored_df["churn_proba"].describe())


The column `churn_proba` is our estimate of **short-term exit probability**.
The lifetime calculation below treats it as a per-year probability, which is an
assumption. We now combine this with a **value proxy**.


## 8. Approximate customer value (simple CLV proxy)

We do **not** have actual revenue or margin per customer, so we create a
reasonable proxy using:

- `EstimatedSalary` – proxy for potential revenue.
- `NumOfProducts` – number of products held.

This is intentionally simple. In a real bank you would use:

- Net interest margin.
- Fees.
- Product-specific contribution.
- Time-varying behaviour.

Here we construct:

```text
annual_margin ≈ base_margin
                + w_products * NumOfProducts
                + w_salary * (EstimatedSalary_normalised)
```

Then:

```text
CLV_proxy ≈ annual_margin * expected_lifetime_years
```

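We approximate expected lifetime with a geometric survival model: if a customer
exits each year with constant probability `p`, the expected number of years
until exit is `1 / p`. We floor `p` at a small `epsilon` and cap the result at
`max_years`:

```text
expected_lifetime_years ≈ min(1 / max(churn_proba, epsilon), max_years)
```

This is a simplification, but it is enough for segmentation.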
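
To make the arithmetic concrete, here is a small worked example in plain Python. The weights, the 200k salary scale, and the 10-year cap are the values used in the code cells that follow; the example customer and the helper name `clv_proxy` are made up for illustration.

```python
# Constants copied from the notebook's margin and lifetime functions.
BASE_MARGIN = 100.0
W_PRODUCTS = 80.0
W_SALARY = 200.0
SALARY_SCALE = 200_000.0
MAX_YEARS = 10.0
EPS = 1e-4


def clv_proxy(num_products: int, est_salary: float, churn_p: float) -> float:
    """Annual margin proxy times capped geometric lifetime."""
    annual_margin = (
        BASE_MARGIN + W_PRODUCTS * num_products + W_SALARY * est_salary / SALARY_SCALE
    )
    lifetime_years = min(1.0 / max(churn_p, EPS), MAX_YEARS)
    return annual_margin * lifetime_years


# Two products, salary 100k: margin = 100 + 160 + 100 = 360.
print(clv_proxy(2, 100_000.0, 0.25))  # 360 * 4 years  -> 1440.0
print(clv_proxy(2, 100_000.0, 0.02))  # 1 / 0.02 = 50 years, capped at 10 -> 3600.0
```

The second call shows the effect of the cap: below a churn probability of 0.1, every customer gets the same 10-year lifetime, so differences in CLV come from the margin alone.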


In [None]:
def compute_annual_margin(row: pd.Series) -> float:
    """Compute a simple annual margin proxy for a customer.

    Uses NumOfProducts and EstimatedSalary (normalised).

    Args:
        row: Row from the scored DataFrame.

    Returns:
        Approximate annual margin value.
    """
    num_products: float = float(row.get("NumOfProducts", 0.0))
    est_salary: float = float(row.get("EstimatedSalary", 0.0))

    # Normalise salary roughly to [0, 1] scale using a simple heuristic
    salary_norm: float = est_salary / 200_000.0  # assumes salaries around 0–200k

    base_margin: float = 100.0  # baseline yearly margin
    w_products: float = 80.0    # extra margin per product
    w_salary: float = 200.0     # extra margin scaled by salary_norm

    annual_margin: float = base_margin + w_products * num_products + w_salary * salary_norm
    return float(annual_margin)


scored_df["annual_margin"] = scored_df.apply(compute_annual_margin, axis=1)

print("Annual margin summary:")
print(scored_df["annual_margin"].describe())

sns.histplot(scored_df["annual_margin"], bins=30, kde=True)
plt.title("Approximate annual margin distribution")
plt.xlabel("Annual margin (proxy)")
plt.show()


In [None]:
def compute_expected_lifetime_years(churn_p: float, max_years: float = 10.0) -> float:
    """Approximate expected remaining lifetime (years) from churn probability.

    We assume a simple geometric-like model with constant churn probability
    per year:

    E[L] ≈ 1 / p, capped at `max_years`.

    Args:
        churn_p: Estimated churn probability (0–1).
        max_years: Upper cap on expected lifetime.

    Returns:
        Expected lifetime in years.
    """
    eps: float = 1e-4
    p = max(float(churn_p), eps)
    expected_lifetime: float = 1.0 / p
    return float(min(expected_lifetime, max_years))


scored_df["expected_lifetime_years"] = scored_df["churn_proba"].apply(compute_expected_lifetime_years)

print("Expected lifetime (years) summary:")
print(scored_df["expected_lifetime_years"].describe())

sns.histplot(scored_df["expected_lifetime_years"], bins=30, kde=True)
plt.title("Approximate expected lifetime (years)")
plt.xlabel("Years")
plt.show()


In [None]:
# CLV proxy: annual margin * expected lifetime
scored_df["clv_proxy"] = scored_df["annual_margin"] * scored_df["expected_lifetime_years"]

print("CLV proxy summary:")
print(scored_df["clv_proxy"].describe())

sns.histplot(scored_df["clv_proxy"], bins=30, kde=True)
plt.title("CLV proxy distribution")
plt.xlabel("CLV proxy (arbitrary units)")
plt.show()


### Section summary

We now have, for every customer:

- `churn_proba` – model-based churn risk.
- `annual_margin` – rough yearly value proxy.
- `expected_lifetime_years` – rough time horizon.
- `clv_proxy` – `annual_margin × expected_lifetime_years` (value × time).

Next we build **priority segments** based on churn risk and CLV.
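
The row-wise `apply` calls in this section are easy to read but can be slow on large tables. As a sketch, the same three columns can be computed with vectorised pandas and NumPy operations. The frame `toy_scored` below is made up; the constants are the ones used in this section.

```python
import numpy as np
import pandas as pd

# Toy scored frame (made-up values) with the columns the notebook uses.
toy_scored = pd.DataFrame(
    {
        "NumOfProducts": [1, 2, 3],
        "EstimatedSalary": [50_000.0, 100_000.0, 200_000.0],
        "churn_proba": [0.5, 0.25, 0.0],
    }
)

# Same constants as the row-wise functions: base 100, 80 per product,
# 200 scaled by salary / 200k.
toy_scored["annual_margin"] = (
    100.0
    + 80.0 * toy_scored["NumOfProducts"]
    + 200.0 * toy_scored["EstimatedSalary"] / 200_000.0
)

# Floor the probability at 1e-4, invert, then cap at 10 years.
toy_scored["expected_lifetime_years"] = np.minimum(
    1.0 / toy_scored["churn_proba"].clip(lower=1e-4), 10.0
)

toy_scored["clv_proxy"] = toy_scored["annual_margin"] * toy_scored["expected_lifetime_years"]
print(toy_scored)
```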


## 9. Risk–value segmentation

A common prioritisation scheme is to divide customers into **four quadrants**:

- **High CLV, high churn risk** → top retention priority.
- **High CLV, low churn risk** → protect and grow.
- **Low CLV, high churn risk** → selective retention.
- **Low CLV, low churn risk** → monitor, limited investment.

We implement this by splitting:

- Churn probability at its median.
- CLV proxy at its median.
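
The median split can be sketched on a four-row toy frame (made-up values) to show how the labels are formed. Values equal to the median are treated as "high", matching the `>=` comparison used in the code cell that follows.

```python
import numpy as np
import pandas as pd

# Toy frame with made-up churn probabilities and CLV proxies.
toy_scored = pd.DataFrame(
    {
        "churn_proba": [0.05, 0.10, 0.40, 0.60],
        "clv_proxy": [3000.0, 900.0, 2500.0, 400.0],
    }
)

churn_cut = toy_scored["churn_proba"].median()  # midpoint of 0.10 and 0.40
clv_cut = toy_scored["clv_proxy"].median()      # midpoint of 900 and 2500

# Vectorised labels: at or above the cutpoint counts as "high".
value_label = np.where(toy_scored["clv_proxy"] >= clv_cut, "high_value", "low_value")
risk_label = np.where(toy_scored["churn_proba"] >= churn_cut, "high_risk", "low_risk")
toy_scored["segment"] = [f"{v}_{r}" for v, r in zip(value_label, risk_label)]

print(toy_scored)
```

Medians guarantee a roughly even split on each axis, but not four equally sized quadrants; the quadrant sizes depend on how risk and value are related.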


In [None]:
# Compute medians as cutpoints
churn_median: float = float(scored_df["churn_proba"].median())
clv_median: float = float(scored_df["clv_proxy"].median())

print(f"Median churn probability: {churn_median:.3f}")
print(f"Median CLV proxy:        {clv_median:.1f}")


def assign_risk_value_segment(row: pd.Series) -> str:
    """Assign a segment label based on churn probability and CLV proxy.

    Segments:
    - 'high_value_high_risk'
    - 'high_value_low_risk'
    - 'low_value_high_risk'
    - 'low_value_low_risk'

    Args:
        row: Row with 'churn_proba' and 'clv_proxy'.

    Returns:
        Segment label string.
    """
    churn_p: float = float(row["churn_proba"])
    clv: float = float(row["clv_proxy"])

    value_label = "high_value" if clv >= clv_median else "low_value"
    risk_label = "high_risk" if churn_p >= churn_median else "low_risk"

    return f"{value_label}_{risk_label}"


scored_df["segment"] = scored_df.apply(assign_risk_value_segment, axis=1)

segment_counts = scored_df["segment"].value_counts().rename("n_customers")
segment_mean_churn = scored_df.groupby("segment")["churn_proba"].mean().rename("avg_churn_proba")
segment_mean_clv = scored_df.groupby("segment")["clv_proxy"].mean().rename("avg_clv_proxy")

segment_summary = pd.concat([segment_counts, segment_mean_churn, segment_mean_clv], axis=1)
segment_summary = segment_summary.sort_index()

display(segment_summary)


In [None]:
# Visualise segments on risk vs value plane
plt.figure(figsize=(7, 6))

sns.scatterplot(
    data=scored_df.sample(min(2000, len(scored_df)), random_state=RANDOM_STATE),
    x="churn_proba",
    y="clv_proxy",
    hue="segment",
    alpha=0.6,
)

plt.axvline(churn_median, linestyle="--", color="black")
plt.axhline(clv_median, linestyle="--", color="black")
plt.xlabel("Churn probability")
plt.ylabel("CLV proxy")
plt.title("Risk–value segmentation")
plt.legend(bbox_to_anchor=(1.05, 1), loc="upper left")
plt.tight_layout()
plt.show()


The scatter plot shows customers coloured by segment in the
**churn probability vs CLV proxy** plane.

- Top-right quadrant → **high_value_high_risk**.
- Top-left quadrant → **high_value_low_risk**.
- Bottom-right quadrant → **low_value_high_risk**.
- Bottom-left quadrant → **low_value_low_risk**.


## 10. Profiling key segments

We now inspect the **high_value_high_risk** segment more closely:

- Size and share of total CLV proxy.
- Average age, products, balance, activity.

This is typically the group you most want to retain.


In [None]:
# Focus on high_value_high_risk segment
hv_hr_mask = scored_df["segment"] == "high_value_high_risk"

hv_hr_df = scored_df.loc[hv_hr_mask].copy()

print("High-value, high-risk customers:")
print("N:", hv_hr_df.shape[0])

share_of_customers = hv_hr_df.shape[0] / scored_df.shape[0]
share_of_clv = hv_hr_df["clv_proxy"].sum() / scored_df["clv_proxy"].sum()

print(f"Share of customers: {share_of_customers:.2%}")
print(f"Share of total CLV proxy: {share_of_clv:.2%}")

profile_cols: List[str] = ["Age", "NumOfProducts", "Balance", "EstimatedSalary", "IsActiveMember"]

if all(c in hv_hr_df.columns for c in profile_cols):
    summary_profile = hv_hr_df[profile_cols].describe().T
else:
    # Fall back to describing all numeric columns if any profile column is missing
    summary_profile = hv_hr_df.describe().T

display(summary_profile)


In [None]:
# Compare activity and geography across segments
if "IsActiveMember" in scored_df.columns:
    sns.barplot(
        data=scored_df,
        x="segment",
        y="IsActiveMember",
        estimator=np.mean,
    )
    plt.xticks(rotation=30)
    plt.ylabel("Mean IsActiveMember")
    plt.title("Average activity by segment")
    plt.tight_layout()
    plt.show()

if "Geography" in scored_df.columns:
    geo_segment = (
        scored_df.groupby(["segment", "Geography"])["clv_proxy"]
        .mean()
        .reset_index()
    )
    sns.catplot(
        data=geo_segment,
        x="segment",
        y="clv_proxy",
        hue="Geography",
        kind="bar",
        height=5,
        aspect=1.6,
    )
    plt.xticks(rotation=30)
    plt.ylabel("Average CLV proxy")
    plt.title("Average CLV proxy by segment and geography")
    plt.tight_layout()
    plt.show()


These profiles help answer questions such as:

- Which **geographies** concentrate high-value at-risk customers?
- Are high-value, high-risk customers less active (`IsActiveMember`)?
- How many products do they hold on average?

From here, the business can design **targeted strategies**.
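
The geography question can also be answered with a table rather than a chart. The sketch below builds a row-normalised cross-tabulation on a made-up frame (`toy_scored`); on the real data you would pass `scored_df["segment"]` and `scored_df["Geography"]`.

```python
import pandas as pd

# Toy frame with made-up segment and geography labels.
toy_scored = pd.DataFrame(
    {
        "segment": [
            "high_value_high_risk",
            "high_value_high_risk",
            "high_value_high_risk",
            "high_value_low_risk",
            "low_value_high_risk",
            "low_value_low_risk",
        ],
        "Geography": ["Germany", "Germany", "France", "France", "Spain", "France"],
    }
)

# Row-normalised: each row shows how one segment splits across geographies.
geo_share = pd.crosstab(
    toy_scored["segment"], toy_scored["Geography"], normalize="index"
)
print(geo_share.round(2))
```

Each row sums to 1, so a geography that is over-represented in `high_value_high_risk` relative to the other rows stands out directly.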


## 11. Business interpretation and strategies

Given this segmentation, a bank might:

### High-value, high-risk

- **Goal:** prevent churn for these customers.
- Possible actions:
  - Personalised outreach (relationship managers, calls).
  - Tailored offers: fee waivers, better terms, bundled products.
  - Service recovery for customers with low activity or recent issues.

### High-value, low-risk

- **Goal:** maintain satisfaction and grow value.
- Possible actions:
  - Cross-sell and up-sell (new products that fit profile).
  - Loyalty programmes and VIP treatment.

### Low-value, high-risk

- **Goal:** selective retention based on campaign cost.
- Possible actions:
  - Lower-cost digital campaigns.
  - A/B tests to see if certain offers have acceptable ROI.

### Low-value, low-risk

- **Goal:** efficient maintenance.
- Possible actions:
  - Minimal proactive contact.
  - Automated nudges and monitoring.

The exact actions depend on:

- Operational constraints (call centre capacity, budget).
- Regulatory context.
- Product portfolio and pricing.


## 12. Limitations and extensions

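This notebook deliberately uses **simple approximations**:

- Churn probability is modelled as a one-period risk.
- Churn probabilities in Section 7 are in-sample (the model scores rows it was
  trained on), so they are likely more extreme than out-of-sample scores.
- Expected lifetime uses a geometric approximation `1 / p`.
- CLV proxy uses `NumOfProducts` and `EstimatedSalary` with arbitrary weights.
- Because expected lifetime is derived from `1 / p`, the CLV proxy falls as
  churn risk rises by construction, which tends to leave the high-value,
  high-risk quadrant small.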

In a more advanced project you could:

1. Use **survival analysis** or hazard models for richer lifetime estimates.
2. Replace CLV proxy with a proper **LTV model** using:
   - Product-level margins.
   - Time-varying behaviour.
   - Discounting of future cash flows.
3. Optimise **thresholds and segment definitions** based on:
   - Campaign costs.
   - Expected uplift.
   - Capacity constraints.
4. Run **what-if simulations**:
   - "What if we reduce churn probability by X% for this segment?"
   - "What is the incremental CLV gained?"
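
A what-if simulation of this kind can be sketched with the same capped geometric lifetime used earlier. The helper names, the example customer, and the assumed 50% relative reduction are hypothetical; a real analysis would need measured campaign uplift rather than an assumed one.

```python
def expected_lifetime_years(churn_p: float, max_years: float = 10.0, eps: float = 1e-4) -> float:
    """Capped geometric lifetime, mirroring the notebook's approximation."""
    return min(1.0 / max(churn_p, eps), max_years)


def incremental_clv(annual_margin: float, churn_p: float, relative_reduction: float) -> float:
    """CLV proxy gained if churn probability falls by `relative_reduction` (0.2 = 20%)."""
    new_p = churn_p * (1.0 - relative_reduction)
    before = annual_margin * expected_lifetime_years(churn_p)
    after = annual_margin * expected_lifetime_years(new_p)
    return after - before


# Made-up customer: margin proxy 360, churn probability 0.5, campaign halves the risk.
gain = incremental_clv(annual_margin=360.0, churn_p=0.5, relative_reduction=0.5)
print(gain)  # lifetime goes from 2 to 4 years -> 720.0

# A customer already at the 10-year cap shows no gain under this proxy.
print(incremental_clv(annual_margin=360.0, churn_p=0.05, relative_reduction=0.5))  # -> 0.0
```

The second call illustrates a side effect of the cap: for customers with churn probability below 0.1, the proxy cannot register any benefit from reducing risk further.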

Even with simple assumptions, the combination of **churn risk** and **value**
provides a more realistic view of **which customers to focus on**.
