# Bank Customer Churn – EDA and Baseline Models

This notebook is part of the **Modern Bank Churn** project.

Goals of this notebook:

1. Load and clean the **Bank Customer Churn** dataset (`Churn_Modelling.csv`).
2. Perform **Exploratory Data Analysis (EDA)** to understand churn patterns.
3. Build **baseline models** (dummy, logistic regression, random forest).
4. Establish a reference performance for more advanced models (LightGBM, Optuna, SHAP)
   in later notebooks.

The target variable is **`Exited`**:

- `0` – customer stayed.
- `1` – customer left (churn).


## 1. Imports and configuration

We import:

- `pandas`, `numpy` for data handling.
- `matplotlib`, `seaborn` for visualization.
- `scikit-learn` for preprocessing, modelling, and evaluation.

We also:

- Set a random seed for reproducibility.
- Define the expected path to the dataset.


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, List

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from sklearn.compose import ColumnTransformer
from sklearn.dummy import DummyClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import (
    accuracy_score,
    classification_report,
    confusion_matrix,
    roc_auc_score,
    RocCurveDisplay,
)
from sklearn.model_selection import train_test_split
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler
from sklearn.base import BaseEstimator

sns.set_theme(style="whitegrid")  # `sns.set` is an older alias of `set_theme`
plt.rcParams["figure.figsize"] = (8, 5)

RANDOM_STATE: int = 42
np.random.seed(RANDOM_STATE)

DATA_PATH: Path = Path("data") / "Churn_Modelling.csv"

if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Data file not found at {DATA_PATH.resolve()}. "
        "Please download the Bank Customer Churn CSV and place it under the 'data/' directory."
    )


## 2. Load and inspect the data

We now:

1. Load the CSV into a DataFrame.
2. Inspect the head, info, and basic statistics.
3. Verify that the target column `Exited` is present.

The raw dataset typically contains columns like:

- `RowNumber`, `CustomerId`, `Surname` (identifiers, not useful for prediction).
- `CreditScore`, `Geography`, `Gender`, `Age`, `Tenure`, `Balance`,
  `NumOfProducts`, `HasCrCard`, `IsActiveMember`, `EstimatedSalary`.
- `Exited` – target variable (0/1).
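
Before relying on these columns, it can help to check that they are actually present. The following is a minimal sketch, not part of the notebook's pipeline: `EXPECTED_COLUMNS` and `find_missing_columns` are hypothetical names introduced here, and the one-row frame is toy data.

```python
from typing import List

import pandas as pd

# Column names taken from the list above (identifier columns left out).
EXPECTED_COLUMNS: List[str] = [
    "CreditScore", "Geography", "Gender", "Age", "Tenure", "Balance",
    "NumOfProducts", "HasCrCard", "IsActiveMember", "EstimatedSalary", "Exited",
]


def find_missing_columns(frame: pd.DataFrame, expected: List[str]) -> List[str]:
    """Return the expected column names that are absent from `frame`, in order."""
    return [col for col in expected if col not in frame.columns]


# Toy frame with only three of the expected columns (illustrative only).
toy = pd.DataFrame({"CreditScore": [619], "Geography": ["France"], "Exited": [1]})
missing_cols = find_missing_columns(toy, EXPECTED_COLUMNS)
print("Missing columns:", missing_cols)
```

On the real data one would call `find_missing_columns(raw_df, EXPECTED_COLUMNS)` after loading and stop if the result is non-empty.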


In [None]:
def load_bank_churn_data(path: Path) -> pd.DataFrame:
    """Load the bank customer churn dataset from a CSV file.

    Args:
        path: Path to the CSV file.

    Returns:
        DataFrame containing the bank churn data.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the loaded DataFrame is empty.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path!s}")

    df: pd.DataFrame = pd.read_csv(path)

    if df.empty:
        raise ValueError(f"Loaded DataFrame is empty: {path!s}")

    return df


raw_df: pd.DataFrame = load_bank_churn_data(DATA_PATH)

display(raw_df.head())
raw_df.info()  # prints directly and returns None, so it is not wrapped in display()
display(raw_df.describe(include="all").T)


### Section summary

We loaded the bank churn dataset and inspected its structure.

Next we will clean it:

- Remove purely identifier columns.
- Confirm data types.
- Look for missing values.


## 3. Data cleaning

Steps:

1. Drop identifier columns (`RowNumber`, `CustomerId`, `Surname`) which do not
   carry predictive signal.
2. Check and report missing values.
3. Ensure the target `Exited` is present and binary (0/1).

We keep the cleaning simple and explicit so that it is easy to review and audit.
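
As an illustration of step 2, a missing-value report can show both the count and the share of missing entries per column. This is a sketch under assumptions: `missing_value_report` is a hypothetical helper and the frame below is toy data, not the churn dataset.

```python
import numpy as np
import pandas as pd


def missing_value_report(frame: pd.DataFrame) -> pd.DataFrame:
    """Count and percentage of missing values, for columns with at least one."""
    n_missing = frame.isna().sum()
    report = pd.DataFrame(
        {
            "n_missing": n_missing,
            "pct_missing": 100.0 * n_missing / len(frame),
        }
    )
    # Keep only affected columns, worst first.
    return report[report["n_missing"] > 0].sort_values("n_missing", ascending=False)


# Toy data with deliberately injected gaps (illustrative only).
toy = pd.DataFrame(
    {
        "Age": [42, np.nan, 39, 51],
        "Balance": [0.0, 83807.86, np.nan, np.nan],
        "Exited": [1, 0, 1, 0],
    }
)
report = missing_value_report(toy)
print(report)
```

An empty report means there is nothing to impute, which is the usual outcome for this dataset.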


In [None]:
def clean_bank_churn_data(raw_df: pd.DataFrame) -> pd.DataFrame:
    """Clean the bank customer churn dataset.

    - Drop identifier columns.
    - Check for missing values.
    - Ensure `Exited` is present.

    Args:
        raw_df: Raw bank churn DataFrame.

    Returns:
        Cleaned DataFrame.
    """
    df = raw_df.copy()

    # Drop known identifier columns if present
    id_cols: List[str] = ["RowNumber", "CustomerId", "Surname"]
    drop_cols: List[str] = [c for c in id_cols if c in df.columns]

    if drop_cols:
        df = df.drop(columns=drop_cols)
        print(f"Dropped identifier columns: {drop_cols}")
    else:
        print("No identifier columns to drop.")

    # Check missing values
    missing = df.isna().sum()
    print("Missing values per column (non-zero only):")
    display(missing[missing > 0])

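    # Basic target checks
    if "Exited" not in df.columns:
        raise ValueError("Target column 'Exited' not found in DataFrame.")

    unique_exited = sorted(df["Exited"].dropna().unique())
    print(f"Unique values in 'Exited': {unique_exited}")

    # The target must be binary (0/1); fail early otherwise
    if not set(unique_exited) <= {0, 1}:
        raise ValueError(f"'Exited' must contain only 0/1, got: {unique_exited}")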

    return df


df: pd.DataFrame = clean_bank_churn_data(raw_df)
display(df.head())


### Section summary

We:

- Dropped identifier columns.
- Confirmed that `Exited` exists.
- Checked missing values (usually none in this dataset).

Now we explore churn patterns through EDA.


## 4. Exploratory Data Analysis (EDA)

We start with:

1. **Churn rate and class balance** (distribution of `Exited`).
2. **Numerical features vs churn** (e.g. `Age`, `Balance`, `CreditScore`).
3. **Categorical features vs churn** (e.g. `Geography`, `Gender`, `IsActiveMember`).

The goal is to build intuition about what drives churn.
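
For a 0/1 target, the churn rate is simply the mean of the column, and the class shares follow from `value_counts(normalize=True)`. A minimal sketch on a toy series (the values are made up for illustration):

```python
import pandas as pd

# Toy target: 2 churners out of 10 customers (illustrative only).
toy_target = pd.Series([0, 0, 0, 0, 1, 0, 0, 1, 0, 0], name="Exited")

# Share of each class, indexed by label (0 = stayed, 1 = exited).
class_share = toy_target.value_counts(normalize=True).sort_index()

# The mean of a 0/1 column equals the share of 1s.
churn_rate = float(toy_target.mean())

print(class_share)
print(f"Churn rate: {churn_rate:.1%}")
```

The larger of the two shares is also the accuracy that a model predicting only the majority class would reach, which is useful to keep in mind when reading accuracy numbers later.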


In [None]:
def plot_churn_distribution(df: pd.DataFrame, target_col: str = "Exited") -> None:
    """Plot the distribution of the churn target variable.

    Args:
        df: Clean bank churn DataFrame.
        target_col: Name of the churn column.

    Raises:
        KeyError: If the target column is not in the DataFrame.
    """
    if target_col not in df.columns:
        raise KeyError(f"Column {target_col!r} not found in DataFrame.")

    counts = df[target_col].value_counts().sort_index()
    churn_rate = (counts.get(1, 0) / counts.sum()) * 100.0
    print(f"Churn rate (Exited=1): {churn_rate:.2f}%")

    ax = sns.countplot(data=df, x=target_col)
    ax.set_title("Churn distribution (Exited)")
    ax.bar_label(ax.containers[0])
    plt.show()


plot_churn_distribution(df)


### 4.1 Numerical features vs churn

We look at distributions of key numerical features split by churn status:

- `Age`
- `Balance`
- `CreditScore`

We use kernel density plots to compare the shapes.
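
Density plots show shape; a small table of group statistics can complement them with numbers. The sketch below is an optional addition, not part of the original analysis, and uses toy data in place of `df`.

```python
import pandas as pd

# Toy data standing in for the cleaned DataFrame (illustrative only).
toy = pd.DataFrame(
    {
        "Age": [30, 45, 52, 28, 61, 39],
        "Balance": [0.0, 120000.0, 98000.0, 0.0, 143000.0, 76000.0],
        "Exited": [0, 1, 1, 0, 1, 0],
    }
)

# Mean and median of each numeric feature, split by churn status.
summary = toy.groupby("Exited")[["Age", "Balance"]].agg(["mean", "median"])
print(summary)
```

The same call on the real data would be `df.groupby("Exited")[numeric_to_inspect].agg(["mean", "median"])`.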


In [None]:
numeric_to_inspect: List[str] = ["Age", "Balance", "CreditScore"]

for col in numeric_to_inspect:
    if col not in df.columns:
        raise KeyError(f"Expected numeric column {col!r} not found in DataFrame.")

fig, axes = plt.subplots(1, len(numeric_to_inspect), figsize=(18, 4))

for ax, col in zip(axes, numeric_to_inspect):
    sns.kdeplot(
        data=df,
        x=col,
        hue="Exited",
        common_norm=False,
        fill=True,
        alpha=0.5,
        ax=ax,
    )
    ax.set_title(f"{col} distribution by churn")

plt.tight_layout()
plt.show()


### 4.2 Categorical features vs churn

We examine churn rates for selected categorical features:

- `Geography`
- `Gender`
- `IsActiveMember`

We compute churn rate per category and visualise as bar charts.
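
Because `Exited` is 0/1, the per-category churn rate is the group mean of the target. The sketch below shows this compact formulation on toy data; it is an alternative illustration, not the helper used in the next cell.

```python
import pandas as pd

# Toy data (illustrative only).
toy = pd.DataFrame(
    {
        "Geography": [
            "France", "France", "Germany", "Germany",
            "Spain", "Spain", "Germany", "France",
        ],
        "Exited": [0, 0, 1, 0, 0, 1, 1, 0],
    }
)

# Named aggregation: row count, number of churners, and churn rate per category.
rates = (
    toy.groupby("Geography")["Exited"]
    .agg(Total="count", Exited="sum", ChurnRate="mean")
    .sort_values("ChurnRate", ascending=False)
)
print(rates)
```

Keeping `Total` next to `ChurnRate` matters: a high rate in a category with very few customers is weak evidence.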


In [None]:
def churn_rate_by_category(
    df: pd.DataFrame,
    category_col: str,
    target_col: str = "Exited",
) -> pd.DataFrame:
    """Compute churn rate for each category in a given column.

    Args:
        df: Clean bank churn DataFrame.
        category_col: Name of the categorical column.
        target_col: Target column, expected values 0/1.

    Returns:
        DataFrame with counts and churn rate per category.
    """
    for col in (category_col, target_col):
        if col not in df.columns:
            raise KeyError(f"Column {col!r} not found in DataFrame.")

    grouped = (
        df.groupby(category_col)[target_col]
        .value_counts()
        .unstack(fill_value=0)
        .rename(columns={0: "Stayed", 1: "Exited"})
    )
    grouped["Total"] = grouped["Stayed"] + grouped["Exited"]
    grouped["ChurnRate"] = grouped["Exited"] / grouped["Total"]
    return grouped.sort_values("ChurnRate", ascending=False)


cat_cols_to_inspect: List[str] = ["Geography", "Gender", "IsActiveMember"]

for col in cat_cols_to_inspect:
    if col not in df.columns:
        print(f"Skipping {col!r} (not found).")
        continue

    print(f"\n=== Churn rate by {col} ===")
    summary_df = churn_rate_by_category(df, col)
    display(summary_df)

    ax = sns.barplot(
        data=summary_df.reset_index(),
        x=col,
        y="ChurnRate",
        order=summary_df.index,
    )
    ax.set_title(f"Churn rate by {col}")
    ax.set_ylabel("Churn rate")
    ax.set_xlabel(col)
    ax.set_ylim(0, 1)
    ax.bar_label(ax.containers[0], fmt="%.2f")
    plt.xticks(rotation=30)
    plt.tight_layout()
    plt.show()


### Section summary

From EDA we usually observe patterns such as:

- Certain **geographies** having higher churn.
- Differences in churn by **activity status** (`IsActiveMember`).
- Older customers or customers with certain balance ranges being more likely to churn.

These qualitative insights will be complemented by quantitative models next.


## 5. Train–test split and preprocessing

We now prepare data for modelling:

1. Separate **features (X)** and **target (y)**.
2. Perform a train–test split with stratification on `Exited`.
3. Define a preprocessing pipeline:
   - Standard scaling for numeric features (this includes the 0/1 flags
     `HasCrCard` and `IsActiveMember`, which the code treats as numeric).
   - One-hot encoding for categorical features (`Geography`, `Gender`).

We use `ColumnTransformer` and `Pipeline` from scikit-learn to keep the process clean.


In [None]:
TARGET_COL: str = "Exited"

if TARGET_COL not in df.columns:
    raise KeyError(f"Target column {TARGET_COL!r} not found in DataFrame.")

X: pd.DataFrame = df.drop(columns=[TARGET_COL])
y: pd.Series = df[TARGET_COL].astype(int)

categorical_cols: List[str] = [c for c in ["Geography", "Gender"] if c in X.columns]
numeric_cols: List[str] = [c for c in X.columns if c not in categorical_cols]

print("Categorical columns:", categorical_cols)
print("Numeric columns:", numeric_cols)

X_train, X_test, y_train, y_test = train_test_split(
    X,
    y,
    test_size=0.2,
    stratify=y,
    random_state=RANDOM_STATE,
)

print("Train shape:", X_train.shape, "Test shape:", X_test.shape)

numeric_transformer = Pipeline(
    steps=[("scaler", StandardScaler())]
)
categorical_transformer = Pipeline(
    steps=[("encoder", OneHotEncoder(handle_unknown="ignore"))]
)

preprocessor = ColumnTransformer(
    transformers=[
        ("num", numeric_transformer, numeric_cols),
        ("cat", categorical_transformer, categorical_cols),
    ]
)


## 6. Baseline and classical models

We evaluate three models:

1. **DummyClassifier** (most frequent class) – baseline; its accuracy equals the
   majority-class share and its ROC-AUC is 0.5.
2. **Logistic Regression** – linear, interpretable model.
3. **Random Forest** – non-linear ensemble for tabular data.

We define a helper function `evaluate_classifier` to:

- Fit the model.
- Compute accuracy and ROC-AUC.
- Print a classification report.
- Show a confusion matrix and ROC curve.
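
Reporting ROC-AUC alongside accuracy matters on imbalanced data. The sketch below, on a toy 80/20 target, shows a majority-class predictor reaching high accuracy while its ROC-AUC stays at chance level. The data and variable names are illustrative assumptions.

```python
import numpy as np
from sklearn.dummy import DummyClassifier
from sklearn.metrics import accuracy_score, roc_auc_score

# Toy target: 8 customers stayed, 2 exited. Features are irrelevant for a dummy model.
y_toy = np.array([0] * 8 + [1] * 2)
X_toy = np.zeros((10, 1))

baseline = DummyClassifier(strategy="most_frequent").fit(X_toy, y_toy)

# Accuracy equals the majority-class share.
acc = accuracy_score(y_toy, baseline.predict(X_toy))

# The predicted probability of class 1 is constant, so the ranking is uninformative.
auc = roc_auc_score(y_toy, baseline.predict_proba(X_toy)[:, 1])

print(f"Accuracy: {acc:.2f}, ROC-AUC: {auc:.2f}")
```

Any model worth keeping should therefore beat the dummy on ROC-AUC, not just match it on accuracy.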


In [None]:
def evaluate_classifier(
    name: str,
    model: BaseEstimator,
    X_train: pd.DataFrame,
    X_test: pd.DataFrame,
    y_train: pd.Series,
    y_test: pd.Series,
) -> Dict[str, object]:
    """Fit a classifier, report diagnostics, and return summary metrics.

    Args:
        name: Name of the model (for printing).
        model: Unfitted scikit-learn estimator or pipeline.
        X_train: Training features.
        X_test: Test features.
        y_train: Training labels (0/1).
        y_test: Test labels (0/1).

    Returns:
        Dictionary with the model name, train and test accuracy, and test
        ROC-AUC (NaN if the model has no `predict_proba`).
    """
    print(f"\n===== {name} =====")
    model.fit(X_train, y_train)

    y_pred_train = model.predict(X_train)
    y_pred_test = model.predict(X_test)

    if hasattr(model, "predict_proba"):
        y_proba_test = model.predict_proba(X_test)[:, 1]
        roc_auc = roc_auc_score(y_test, y_proba_test)
    else:
        y_proba_test = None
        roc_auc = np.nan

    acc_train = accuracy_score(y_train, y_pred_train)
    acc_test = accuracy_score(y_test, y_pred_test)

    print(f"Train accuracy: {acc_train:.3f}")
    print(f"Test accuracy:  {acc_test:.3f}")
    if not np.isnan(roc_auc):
        print(f"Test ROC-AUC:  {roc_auc:.3f}")

    print("\nClassification report (test):")
    print(classification_report(y_test, y_pred_test, target_names=["Stayed", "Exited"]))

    cm = confusion_matrix(y_test, y_pred_test)
    sns.heatmap(
        cm,
        annot=True,
        fmt="d",
        cmap="Blues",
        xticklabels=["Pred stayed", "Pred exited"],
        yticklabels=["True stayed", "True exited"],
    )
    plt.title(f"Confusion matrix - {name}")
    plt.ylabel("True label")
    plt.xlabel("Predicted label")
    plt.show()

    if y_proba_test is not None:
        RocCurveDisplay.from_predictions(y_test, y_proba_test)
        plt.title(f"ROC curve - {name}")
        plt.show()

    return {
        "model": name,
        "train_accuracy": acc_train,
        "test_accuracy": acc_test,
        "roc_auc": float(roc_auc),  # stays NaN when no probabilities are available
    }


In [None]:
# 6.1 Dummy baseline
dummy_clf = Pipeline(
    steps=[
        ("preprocess", preprocessor),
        ("clf", DummyClassifier(strategy="most_frequent", random_state=RANDOM_STATE)),
    ]
)

dummy_metrics = evaluate_classifier(
    "Dummy (Most Frequent)", dummy_clf, X_train, X_test, y_train, y_test
)
dummy_metrics


In [None]:
# 6.2 Logistic Regression
log_reg_clf = Pipeline(
    steps=[
        ("preprocess", preprocessor),
        (
            "clf",
            LogisticRegression(
                max_iter=1000,
                random_state=RANDOM_STATE,
            ),
        ),
    ]
)

log_reg_metrics = evaluate_classifier(
    "Logistic Regression", log_reg_clf, X_train, X_test, y_train, y_test
)
log_reg_metrics


In [None]:
# 6.3 Random Forest
rf_clf = Pipeline(
    steps=[
        ("preprocess", preprocessor),
        (
            "clf",
            RandomForestClassifier(
                n_estimators=200,
                max_depth=None,
                min_samples_split=4,
                min_samples_leaf=2,
                random_state=RANDOM_STATE,
                n_jobs=-1,
            ),
        ),
    ]
)

rf_metrics = evaluate_classifier(
    "Random Forest", rf_clf, X_train, X_test, y_train, y_test
)
rf_metrics


## 7. Model comparison and conclusions

We summarise the performance of all three models to establish a reference.


In [None]:
results_df = pd.DataFrame([dummy_metrics, log_reg_metrics, rf_metrics])
display(results_df)


### Section summary

In this notebook we:

- Explored the bank churn dataset.
- Built three baseline models:
  - A dummy classifier (most frequent class).
  - Logistic Regression.
  - Random Forest.
- Compared their performance using accuracy and ROC-AUC.

These baselines give us a **benchmark**. In the next notebook we will use
**LightGBM + Optuna + SHAP** to build a stronger and more explainable churn model.
