# Telco Customer Churn – Survival Analysis (Time-to-Churn)

This notebook is an **extended survival analysis project** built on the IBM
**Telco Customer Churn** dataset.

Instead of only predicting **whether** a customer churns, we focus on
**when** churn happens. We treat churn as a **time-to-event** problem and use
survival analysis.

---

## Objectives

1. Reframe churn as a **time-to-event** problem using `tenure` and `Churn`.
2. Build and interpret **Kaplan–Meier survival curves**:
   - Overall survival.
   - By contract type and other segments.
   - Log-rank tests for differences between groups.
3. Fit and interpret a **Cox Proportional Hazards model**:
   - Feature engineering for survival.
   - Hazard ratios and their business meaning.
4. Check **model assumptions and diagnostics**:
   - Proportional hazards assumption.
   - Concordance index (discrimination).
5. Create **scenario analyses**:
   - Predicted survival curves for example customers.
   - Survival by risk segments.
6. (Optional) Explore an alternative **Aalen Additive** model.

This notebook is self-contained. It loads the Telco dataset directly
from a CSV file.


## 1. Imports and configuration

We use:

- `pandas`, `numpy` for data handling.
- `matplotlib`, `seaborn` for visualisation.
- `lifelines` for survival analysis:
  - `KaplanMeierFitter`
  - `CoxPHFitter`
  - `AalenAdditiveFitter` (optional)
  - log-rank tests and concordance index.

We also set a random seed and define the expected data path.


In [None]:
from __future__ import annotations

from pathlib import Path
from typing import List, Dict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

from lifelines import KaplanMeierFitter, CoxPHFitter, AalenAdditiveFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines.utils import concordance_index

sns.set(style="whitegrid")
plt.rcParams["figure.figsize"] = (8, 5)

RANDOM_STATE: int = 42
np.random.seed(RANDOM_STATE)

DATA_PATH: Path = Path("data") / "WA_Fn-UseC_-Telco-Customer-Churn.csv"

if not DATA_PATH.exists():
    raise FileNotFoundError(
        f"Data file not found at {DATA_PATH.resolve()}. "
        "Please download the Telco churn CSV and place it under the 'data/' directory."
    )


## 2. Data loading and basic cleaning

We reuse the Telco dataset, but now we care especially about:

- `tenure` – months the customer has stayed.
- `Churn` – `Yes` if the customer churned, `No` if still active.

We do minimal but explicit cleaning:

1. Convert `TotalCharges` to numeric.
2. Drop rows with missing `TotalCharges`.
3. Drop duplicate `customerID`s.


In [None]:
def load_telco_data(path: Path) -> pd.DataFrame:
    """Load the Telco customer churn dataset from CSV.

    Args:
        path: Path to the CSV file.

    Returns:
        DataFrame with Telco churn data.

    Raises:
        FileNotFoundError: If the file does not exist.
        ValueError: If the DataFrame is empty.
    """
    if not path.exists():
        raise FileNotFoundError(f"File not found: {path!s}")

    df: pd.DataFrame = pd.read_csv(path)
    if df.empty:
        raise ValueError(f"Loaded DataFrame is empty: {path!s}")
    return df


def clean_telco_data(raw_df: pd.DataFrame) -> pd.DataFrame:
    """Clean Telco data: fix types, handle missing, drop duplicates.

    Steps:
    - Convert `TotalCharges` to numeric (invalid entries -> NaN).
    - Drop rows with missing `TotalCharges`.
    - Drop duplicate `customerID` rows.

    Args:
        raw_df: Raw Telco DataFrame.

    Returns:
        Cleaned DataFrame.
    """
    df = raw_df.copy()

    if "TotalCharges" not in df.columns:
        raise ValueError("Expected 'TotalCharges' column not found.")

    # Convert TotalCharges to numeric
    df["TotalCharges"] = pd.to_numeric(df["TotalCharges"], errors="coerce")

    # Show missing values
    missing = df.isna().sum()
    print("Missing values per column (non-zero only):")
    display(missing[missing > 0])

    before_rows: int = df.shape[0]
    df = df.dropna(subset=["TotalCharges"])
    after_rows: int = df.shape[0]
    print(f"Dropped {before_rows - after_rows} rows with missing TotalCharges.")

    # Drop duplicate customers
    before_rows = df.shape[0]
    df = df.drop_duplicates(subset=["customerID"])
    after_rows = df.shape[0]
    print(f"Dropped {before_rows - after_rows} duplicate customerID rows.")

    return df.reset_index(drop=True)


raw_df: pd.DataFrame = load_telco_data(DATA_PATH)
telco_df: pd.DataFrame = clean_telco_data(raw_df)

display(telco_df.head())
display(telco_df[["tenure", "Churn"]].describe(include="all"))


### Section summary

We now have a clean Telco dataset with valid `TotalCharges` and unique
`customerID`s. The columns `tenure` and `Churn` are ready to be used as the
core survival variables.

Next we construct the explicit **survival dataset**.


## 3. Constructing the survival dataset

For survival analysis, we need at least two columns:

- **Duration**: how long each subject is observed.
  - Here: `tenure` (months since joining).
- **Event indicator**: whether the event of interest happened.
  - Here: `churn_event` = 1 if the customer churned (`Churn == "Yes"`),
    0 otherwise (`No` → censored observation).

We will:

1. Map `Churn` from `Yes` / `No` to `1` / `0`.
2. Keep additional features for later modelling.
3. Summarise the censoring structure.


In [None]:
# Sanity check of essential columns
required_cols: List[str] = ["customerID", "tenure", "Churn"]
missing_required = [c for c in required_cols if c not in telco_df.columns]
if missing_required:
    raise KeyError(f"Missing required columns in Telco data: {missing_required}")

surv_df = telco_df.copy()

# Event indicator: 1 = churned, 0 = censored (still active)
surv_df["churn_event"] = surv_df["Churn"].map({"No": 0, "Yes": 1}).astype(int)

# Basic summary of events vs censored
event_counts = surv_df["churn_event"].value_counts().rename(index={0: "censored", 1: "event"})
print("Churn event vs censored counts:")
print(event_counts)

print("\nTenure distribution by churn_event:")
display(surv_df.groupby("churn_event")["tenure"].describe())


### Quick visual: tenure vs churn

Before jumping into Kaplan–Meier, we quickly visualise the tenure distribution
by churn status.


In [None]:
sns.kdeplot(data=surv_df, x="tenure", hue="churn_event", common_norm=False, fill=True, alpha=0.4)
plt.title("Tenure distribution by churn event (0=censored, 1=churn)")
plt.xlabel("Tenure (months)")
plt.show()


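We see that churned customers typically have **shorter tenure**. For censored
customers, however, `tenure` is only a lower bound on their eventual lifetime,
so this raw comparison is biased. Survival analysis accounts for censoring and
quantifies the pattern over time.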


## 4. Global Kaplan–Meier survival curve

We now estimate the **overall survival function** using the
`KaplanMeierFitter`.

- The curve starts at 1 (100% of customers at time 0).
- It steps down at churn times.
- At each tenure `t`, the curve value is:
  > Estimated probability that a randomly chosen customer is still
  > active (has not churned) by time `t`.

We also compute the **median survival time**.
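To see where the steps in the curve come from, the product-limit calculation can be traced by hand on a toy sample. This is a minimal sketch: the five durations and event flags are made up, and `km_by_hand` is a hypothetical helper, not part of `lifelines`.

```python
def km_by_hand(durations, events):
    """Kaplan-Meier product-limit estimate as {event_time: survival_probability}."""
    survival = 1.0
    curve = {}
    for t in sorted(set(durations)):
        # Customers still under observation just before time t
        at_risk = sum(1 for d in durations if d >= t)
        # Churn events observed exactly at time t
        churned = sum(1 for d, e in zip(durations, events) if d == t and e == 1)
        if churned > 0:
            # The curve only steps down at times with at least one event
            survival *= 1.0 - churned / at_risk
            curve[t] = survival
    return curve


toy_tenure = [1, 2, 2, 3, 4]  # months (made-up values)
toy_event = [1, 1, 0, 1, 0]   # 1 = churned, 0 = censored

toy_curve = km_by_hand(toy_tenure, toy_event)
for month, prob in toy_curve.items():
    print(f"S({month}) = {prob:.2f}")
```

Censored customers leave the at-risk count without causing a step, which is how the estimator uses their partial information.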


In [None]:
kmf = KaplanMeierFitter()

T = surv_df["tenure"]  # durations
E = surv_df["churn_event"]  # events

kmf.fit(durations=T, event_observed=E, label="All customers")

kmf.plot_survival_function()
plt.title("Kaplan–Meier survival curve – All customers")
plt.xlabel("Tenure (months)")
plt.ylabel("Survival probability (still a customer)")
plt.show()

print("Median survival time (months):", kmf.median_survival_time_)


**Interpretation:**

- The curve shows how quickly the customer base decays over time.
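- The median survival time tells us roughly when **50% of customers have
  churned**. If the curve never drops to 0.5 within the observed tenures, the
  median is not reached and `lifelines` reports it as infinite (`inf`).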

Next, we compare survival curves across important **segments**.


## 5. Survival by contract type (+ log-rank test)

Contract type is a key driver of churn. We compare survival curves for
levels of `Contract`:

- `Month-to-month`
- `One year`
- `Two year`

We then use a **log-rank test** to statistically assess whether survival
curves differ.
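As a rough illustration of what the test measures, the two-group log-rank statistic can be computed by hand: at each event time it compares the observed churn count in one group with the count expected if both groups shared the same hazard. The sketch below is a simplified two-group version on made-up data (the notebook itself uses the multi-group test from `lifelines`); `logrank_two_groups` is a hypothetical helper.

```python
import math


def logrank_two_groups(times_a, events_a, times_b, events_b):
    """Two-group log-rank test; returns (chi-square statistic, p-value) with 1 df."""
    all_times = list(times_a) + list(times_b)
    all_events = list(events_a) + list(events_b)
    event_times = sorted({t for t, e in zip(all_times, all_events) if e == 1})

    obs_minus_exp = 0.0
    variance = 0.0
    for t in event_times:
        n_a = sum(1 for x in times_a if x >= t)  # at risk in group A
        n_b = sum(1 for x in times_b if x >= t)  # at risk in group B
        d_a = sum(1 for x, e in zip(times_a, events_a) if x == t and e == 1)
        d_b = sum(1 for x, e in zip(times_b, events_b) if x == t and e == 1)
        n, d = n_a + n_b, d_a + d_b
        # Observed minus expected events in group A under the null hypothesis
        obs_minus_exp += d_a - d * n_a / n
        if n > 1:
            variance += d * (n_a / n) * (n_b / n) * (n - d) / (n - 1)

    chi2 = obs_minus_exp ** 2 / variance
    # Survival function of a chi-square distribution with 1 degree of freedom
    p_value = math.erfc(math.sqrt(chi2 / 2.0))
    return chi2, p_value


# Made-up example: group A churns early, group B churns late
chi2, p_value = logrank_two_groups(
    [1, 2, 3, 4, 5], [1, 1, 1, 1, 1],
    [6, 7, 8, 9, 10], [1, 1, 1, 1, 1],
)
print(f"chi2 = {chi2:.2f}, p = {p_value:.4f}")
```

A small p-value says the groups' churn timing differs more than chance would explain; it does not say how large the difference is, which is why the curves are plotted as well.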


In [None]:
if "Contract" not in surv_df.columns:
    raise KeyError("Expected 'Contract' column not found in data.")

kmf_contract = KaplanMeierFitter()
contract_types = surv_df["Contract"].unique()

plt.figure(figsize=(8, 5))
for c in contract_types:
    mask = surv_df["Contract"] == c
    T_c = surv_df.loc[mask, "tenure"]
    E_c = surv_df.loc[mask, "churn_event"]

    kmf_contract.fit(T_c, event_observed=E_c, label=str(c))
    kmf_contract.plot_survival_function()

plt.title("Survival curves by contract type")
plt.xlabel("Tenure (months)")
plt.ylabel("Survival probability")
plt.legend(title="Contract")
plt.show()

# Global log-rank / multivariate test across all contract groups
multi_lr = multivariate_logrank_test(
    event_durations=surv_df["tenure"],
    groups=surv_df["Contract"],
    event_observed=surv_df["churn_event"],
)
print("Multivariate log-rank test across Contract groups:")
print(multi_lr.summary)


Contract type usually shows **strongly different survival patterns**:

- `Month-to-month` → steeper decline, shorter survival.
- `One year`, `Two year` → flatter curves, longer retention.

The log-rank test p-value indicates whether these differences are
statistically significant (typically they are).


## 6. Survival by other segments

We briefly look at survival by:

- `InternetService` (e.g. DSL vs Fiber optic vs No).
- `PaymentMethod` (e.g. electronic check vs automatic methods).

This is mainly descriptive, to spot interesting patterns.


In [None]:
def plot_km_by_group(
    df: pd.DataFrame,
    group_col: str,
    duration_col: str = "tenure",
    event_col: str = "churn_event",
    max_groups: int = 5,
) -> None:
    """Plot Kaplan–Meier survival curves by values of a grouping column.

    Args:
        df: Survival DataFrame.
        group_col: Column defining groups.
        duration_col: Duration column name.
        event_col: Event indicator column name.
        max_groups: Maximum number of distinct groups to plot.
    """
    if group_col not in df.columns:
        raise KeyError(f"Column {group_col!r} not found in DataFrame.")

    kmf_local = KaplanMeierFitter()

    groups = df[group_col].value_counts().index[:max_groups]
    plt.figure(figsize=(8, 5))
    for g in groups:
        mask = df[group_col] == g
        T_g = df.loc[mask, duration_col]
        E_g = df.loc[mask, event_col]
        kmf_local.fit(T_g, event_observed=E_g, label=str(g))
        kmf_local.plot_survival_function()

    plt.title(f"Survival by {group_col}")
    plt.xlabel("Tenure (months)")
    plt.ylabel("Survival probability")
    plt.legend(title=group_col)
    plt.show()


for col in ["InternetService", "PaymentMethod"]:
    if col in surv_df.columns:
        plot_km_by_group(surv_df, col, max_groups=4)


These segment-based curves can highlight **high-risk segments**, e.g. a
particular internet service or payment method with visibly lower survival.

Next we move to a **multivariate model**: the Cox proportional hazards model.


## 7. Cox Proportional Hazards model

Kaplan–Meier curves describe survival but do not handle multiple features
simultaneously. For that, we use the **Cox Proportional Hazards (PH) model**.

Cox model:

- Models the **hazard** (instantaneous risk of churn) as a function of features.
- Outputs **log hazard ratios** (`coef`) and **hazard ratios** (`exp(coef)`).
- Interpretable: each `exp(coef)` is the multiplicative effect on the hazard of
  a one-unit increase in that feature, holding the other features fixed.
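For reference, the standard form of the model can be written as below. The notation is introduced here for illustration: $h_0(t)$ is an unspecified baseline hazard shared by all customers, $x$ is the covariate vector, $\beta$ the coefficient vector, and $e_j$ the unit vector for feature $j$.

$$
h(t \mid x) = h_0(t)\,\exp\!\left(\beta^\top x\right),
\qquad
\mathrm{HR}_j = \frac{h(t \mid x + e_j)}{h(t \mid x)} = \exp(\beta_j)
$$

Because $h_0(t)$ cancels in the ratio, the hazard ratio does not depend on $t$. That is the proportional hazards assumption.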

### 7.1 Feature engineering for Cox

`lifelines.CoxPHFitter` expects a purely numeric DataFrame.
We will:

1. Choose a set of features that are both meaningful and not too many.
2. Convert Yes/No columns to 0/1.
3. One-hot encode categorical variables with `get_dummies`.


In [None]:
# Feature set for Cox model (you can adjust this list)
cox_features_raw: List[str] = [
    "tenure",          # duration column (passed as duration_col, not used as a covariate)
    "churn_event",     # event indicator
    "SeniorCitizen",
    "Partner",
    "Dependents",
    "PaperlessBilling",
    "Contract",
    "PaymentMethod",
    "MonthlyCharges",
    "TotalCharges",    # caution: roughly tenure x MonthlyCharges, so it partly encodes the duration
]

missing_cox_feats = [c for c in cox_features_raw if c not in surv_df.columns]
if missing_cox_feats:
    raise KeyError(f"Missing expected columns for Cox model: {missing_cox_feats}")

cox_df = surv_df[cox_features_raw].copy()

# Map Yes/No style columns to 0/1
bool_like_cols: List[str] = ["Partner", "Dependents", "PaperlessBilling"]
for col in bool_like_cols:
    if col in cox_df.columns:
        cox_df[col] = cox_df[col].map({"No": 0, "Yes": 1}).astype(int)

# One-hot encode categorical variables (dropping baseline).
# dtype=float keeps the dummies numeric; recent pandas versions default to bool.
cox_df_encoded = pd.get_dummies(
    cox_df,
    columns=["Contract", "PaymentMethod"],
    drop_first=True,
    dtype=float,
)

print("Cox design matrix shape:", cox_df_encoded.shape)
cox_df_encoded.head()


### 7.2 Train / test split for Cox model

Although survival models are often fit on the full dataset, it is useful
for evaluation to keep a **train / test split**.

We will:

- Split rows into train and test sets.
- Fit Cox on train.
- Evaluate concordance index (c-index) on both.


In [None]:
from sklearn.model_selection import train_test_split

# We keep the encoded dataset but separate duration/event from covariates

train_df, test_df = train_test_split(
    cox_df_encoded,
    test_size=0.3,
    random_state=RANDOM_STATE,
    stratify=cox_df_encoded["churn_event"],
)

print("Train shape:", train_df.shape, "Test shape:", test_df.shape)


### 7.3 Fit Cox model and inspect summary

We now fit `CoxPHFitter` using:

- `duration_col = 'tenure'`
- `event_col = 'churn_event'`

We then print the summary with hazard ratios.


In [None]:
cph = CoxPHFitter()

cph.fit(train_df, duration_col="tenure", event_col="churn_event")

# Text summary
cph.print_summary()  # includes exp(coef), p-values, CIs


In [None]:
# Extract and visualise hazard ratios
summary_df = cph.summary.copy()
summary_df["hazard_ratio"] = summary_df["exp(coef)"]

summary_sorted = summary_df.sort_values("hazard_ratio", ascending=False)

display(summary_sorted[["hazard_ratio", "p", "exp(coef) lower 95%", "exp(coef) upper 95%"]])

plt.figure(figsize=(8, 6))
plt.barh(summary_sorted.index, summary_sorted["hazard_ratio"])
plt.axvline(1.0, linestyle="--", color="black")
plt.xlabel("Hazard ratio (exp(coef))")
plt.title("Cox model hazard ratios")
plt.gca().invert_yaxis()
plt.tight_layout()
plt.show()


**Interpreting hazard ratios:**

- Hazard ratio **> 1** → higher values of the feature are associated with
  **higher churn risk** (customers churn earlier).
- Hazard ratio **< 1** → higher values of the feature are associated with
  **lower churn risk** (customers stay longer).

Examples (exact values depend on your run):

- `Contract_Two year` with hazard ratio << 1 means two-year contracts
  **strongly reduce** churn risk vs the baseline contract.
- A hazard ratio slightly above 1 for `MonthlyCharges` suggests that each
  additional unit of monthly fee modestly increases churn risk, controlling
  for the other variables.
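The arithmetic behind these readings is short enough to sketch. The coefficients below are hypothetical values chosen for illustration, not results from this notebook, and `hazard_ratio` is a helper defined only for this example.

```python
import math

# Hypothetical coefficients, for illustration only (not output of the fitted model)
coef_two_year = -1.2          # log hazard ratio of a 0/1 indicator
coef_monthly_charges = 0.01   # log hazard ratio per +1 unit of MonthlyCharges


def hazard_ratio(coef, delta=1.0):
    """Multiplicative change in hazard for a `delta`-unit increase in a covariate."""
    return math.exp(coef * delta)


hr_two_year = hazard_ratio(coef_two_year)
hr_plus_10 = hazard_ratio(coef_monthly_charges, delta=10)

print(f"Two-year vs baseline:  HR = {hr_two_year:.2f} ({1 - hr_two_year:.0%} lower hazard)")
print(f"+10 in MonthlyCharges: HR = {hr_plus_10:.3f} ({hr_plus_10 - 1:.1%} higher hazard)")
```

For continuous features the per-unit hazard ratio often looks negligible, so it usually helps to report the effect of a meaningful change (here, 10 units) instead.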


## 8. Model performance: concordance index

The **concordance index (c-index)** measures how well the model ranks
customers by risk:

- `1.0` – perfect ranking.
- `0.5` – random ranking.

We compute c-index on train and test sets using the Cox risk scores.
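To make the definition concrete, the sketch below counts concordant pairs directly on a made-up four-customer sample. It is a simplified version of Harrell's c-index that skips pairs with tied durations, so it will not always match `lifelines.utils.concordance_index` exactly; `harrell_c_index` is a hypothetical helper.

```python
def harrell_c_index(durations, risk_scores, events):
    """Share of comparable pairs in which the earlier churner has the higher risk score."""
    concordant = 0.0
    comparable = 0
    n = len(durations)
    for i in range(n):
        for j in range(n):
            # A pair is comparable when customer i churned strictly before
            # customer j's observed time (j may be censored later on).
            if events[i] == 1 and durations[i] < durations[j]:
                comparable += 1
                if risk_scores[i] > risk_scores[j]:
                    concordant += 1.0
                elif risk_scores[i] == risk_scores[j]:
                    concordant += 0.5  # tied scores count as half
    return concordant / comparable


toy_tenure = [2, 4, 6, 8]         # months (made-up values)
toy_event = [1, 1, 0, 1]          # 1 = churned, 0 = censored
toy_risk = [0.9, 0.7, 0.3, 0.1]   # higher = predicted to churn sooner

c_perfect = harrell_c_index(toy_tenure, toy_risk, toy_event)
c_reversed = harrell_c_index(toy_tenure, toy_risk[::-1], toy_event)
print(f"c-index (good ranking): {c_perfect:.2f}")
print(f"c-index (reversed ranking): {c_reversed:.2f}")
```

Here higher scores mean higher risk, whereas `concordance_index` expects higher scores to mean longer survival. That is why the notebook negates the partial hazards before passing them in.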


In [None]:
# Predicted partial hazards (risk scores)
train_risk_scores = cph.predict_partial_hazard(train_df).values.ravel()
test_risk_scores = cph.predict_partial_hazard(test_df).values.ravel()

c_index_train = concordance_index(
    train_df["tenure"], -train_risk_scores, train_df["churn_event"]
)
# Note: we negate scores because higher risk -> shorter survival.

c_index_test = concordance_index(
    test_df["tenure"], -test_risk_scores, test_df["churn_event"]
)

print(f"C-index (train): {c_index_train:.3f}")
print(f"C-index (test):  {c_index_test:.3f}")


A c-index of roughly 0.7–0.8 is commonly regarded as decent discrimination:
the model ranks customers by churn risk reasonably well. A test value well
below the train value would point to overfitting.


## 9. Proportional hazards assumption checks

The Cox model assumes that **hazard ratios are constant over time**
(proportional hazards). We can roughly check this with
`cph.check_assumptions`.

This function prints diagnostics and can generate residual plots
(Schoenfeld residuals) to see if the effect of a covariate changes over time.

> Note: this is more of a *diagnostic guide* than a hard rule; use judgement.


In [None]:
# This will print diagnostics and, if `show_plots=True`, open plots.
# Run it interactively in your environment to inspect results.

cph.check_assumptions(train_df, show_plots=False)


If the diagnostics highlight strong violations for a variable, you can:

- Add an interaction between that variable and time (or a function of time).
- Stratify the model by that variable.
- Use time-varying effects or an alternative survival model.


## 10. Predicted survival curves for example customers

To make the model tangible, we construct a few **synthetic customers** and
plot their predicted survival curves.

Example profiles:

- **Customer A** – Month-to-month, electronic check, higher monthly charges.
- **Customer B** – Two-year contract, automatic bank transfer, lower charges.

We use the fitted Cox model to predict their survival functions over time.
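Under a Cox model, a customer's curve is the baseline survival raised to the power of that customer's relative risk: `S(t | x) = S0(t) ** exp(beta . x)`. The sketch below illustrates this relation with made-up baseline values and made-up log risk scores; `lifelines` performs the equivalent computation internally with its own estimated baseline, and `survival_for_profile` is a hypothetical helper.

```python
import math

# Hypothetical baseline survival S0(t) at a few tenures (made-up numbers)
baseline_survival = {12: 0.90, 24: 0.80, 48: 0.65}


def survival_for_profile(baseline, log_partial_hazard):
    """Apply the Cox relation S(t | x) = S0(t) ** exp(beta . x) at each time point."""
    multiplier = math.exp(log_partial_hazard)
    return {t: s0 ** multiplier for t, s0 in baseline.items()}


# Hypothetical log risk scores for a riskier and a safer profile
high_risk = survival_for_profile(baseline_survival, log_partial_hazard=0.7)
low_risk = survival_for_profile(baseline_survival, log_partial_hazard=-0.7)

for t in baseline_survival:
    print(f"t={t:>2}: high risk {high_risk[t]:.2f} | "
          f"baseline {baseline_survival[t]:.2f} | low risk {low_risk[t]:.2f}")
```

All curves share the baseline's shape and differ only by the exponent, so two customers' predicted curves from the same Cox model never cross.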


In [None]:
# Build base profile using medians for numeric and common values for binary
base_profile: Dict[str, float] = {
    "SeniorCitizen": float(surv_df["SeniorCitizen"].median()),
    "Partner": 0.0,
    "Dependents": 0.0,
    "PaperlessBilling": 1.0,
    "MonthlyCharges": float(surv_df["MonthlyCharges"].median()),
    "TotalCharges": float(surv_df["TotalCharges"].median()),
}

# Columns used in the Cox design matrix (excluding duration/event)
cox_feature_cols: List[str] = [col for col in cox_df_encoded.columns if col not in ["tenure", "churn_event"]]


def make_cox_design_row(
    contract: str,
    payment_method: str,
    monthly_charges: float,
) -> pd.DataFrame:
    """Create a single-row Cox design matrix for a hypothetical customer.

    Args:
        contract: Contract type as in original data.
        payment_method: PaymentMethod as in original data.
        monthly_charges: MonthlyCharges value to set.

    Returns:
        1-row DataFrame with columns matching cox_feature_cols.
    """
    # Start with zeros for design matrix
    data: Dict[str, float] = {col: 0.0 for col in cox_feature_cols}

    # Insert base numeric/binary features
    tmp = base_profile.copy()
    tmp["MonthlyCharges"] = monthly_charges

    for col in data.keys():
        if col in tmp:
            data[col] = float(tmp[col])

    # Handle one-hot encoded Contract_* columns
    for col in cox_feature_cols:
        if col.startswith("Contract_"):
            data[col] = 0.0
    target_contract_col = f"Contract_{contract}"
    if target_contract_col in data:
        data[target_contract_col] = 1.0

    # Handle one-hot encoded PaymentMethod_* columns
    for col in cox_feature_cols:
        if col.startswith("PaymentMethod_"):
            data[col] = 0.0
    target_pm_col = f"PaymentMethod_{payment_method}"
    if target_pm_col in data:
        data[target_pm_col] = 1.0

    return pd.DataFrame([data])


# Define two example customers
example_A = make_cox_design_row(
    contract="Month-to-month",
    payment_method="Electronic check",
    monthly_charges=float(surv_df["MonthlyCharges"].quantile(0.75)),
)

example_B = make_cox_design_row(
    contract="Two year",
    payment_method="Bank transfer (automatic)",
    monthly_charges=float(surv_df["MonthlyCharges"].quantile(0.25)),
)

# Time grid for prediction
max_tenure = float(surv_df["tenure"].max())
timeline = np.linspace(1, max_tenure, 60)

surv_A = cph.predict_survival_function(example_A, times=timeline)
surv_B = cph.predict_survival_function(example_B, times=timeline)

plt.figure(figsize=(8, 5))
# predict_survival_function returns one column per input row; take that column
plt.plot(timeline, surv_A.iloc[:, 0], label="A: M2M, high charges")
plt.plot(timeline, surv_B.iloc[:, 0], label="B: 2-year, low charges")
plt.xlabel("Tenure (months)")
plt.ylabel("Predicted survival probability")
plt.title("Predicted survival curves for example customers")
plt.legend()
plt.show()


In practice, these types of curves help answer questions like:

- *“If I move this customer from month-to-month to a two-year contract,
   how much longer do I expect them to stay?”*
- *“How does lowering monthly charges for high-risk customers change their
   expected survival?”*


## 11. Survival by model-based risk segments

We can combine survival analysis with **risk scores** from the Cox model.

Steps:

1. Use the Cox model to compute risk scores (partial hazards) for all customers.
2. Split customers into **risk quantiles** (e.g. low / medium / high risk).
3. Plot Kaplan–Meier curves by risk segment.

If the model is useful, high-risk segments should have lower survival.


In [None]:
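# Compute risk scores for the full encoded dataset (train + test rows, so the
# separation between segments is somewhat optimistic for the training portion)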
full_risk_scores = cph.predict_partial_hazard(cox_df_encoded).values.ravel()

risk_df = cox_df_encoded[["tenure", "churn_event"]].copy()
risk_df["risk_score"] = full_risk_scores

# Define quantile-based risk segments
risk_df["risk_segment"] = pd.qcut(
    risk_df["risk_score"],
    q=3,
    labels=["low_risk", "medium_risk", "high_risk"],
)

risk_df["risk_segment"].value_counts()


In [None]:
# Plot KM curves by risk segment
kmf_risk = KaplanMeierFitter()

plt.figure(figsize=(8, 5))
for seg in ["low_risk", "medium_risk", "high_risk"]:
    mask = risk_df["risk_segment"] == seg
    T_seg = risk_df.loc[mask, "tenure"]
    E_seg = risk_df.loc[mask, "churn_event"]

    kmf_risk.fit(T_seg, event_observed=E_seg, label=seg)
    kmf_risk.plot_survival_function()

plt.title("Survival curves by model-based risk segment")
plt.xlabel("Tenure (months)")
plt.ylabel("Survival probability")
plt.legend(title="Risk segment")
plt.show()


If the model is doing a good job, you should see:

- **High-risk** segment with the **lowest survival**.
- **Low-risk** segment with the **highest survival**.

This directly links the model output (risk score) to time-based retention.


## 12. Optional: Aalen Additive model

The **Aalen Additive model** is an alternative to Cox that models the hazard
as a **sum** of time-varying coefficients times the covariates.

It is more flexible in some situations, and can show how the effect of
features **changes over time**.
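In symbols, with notation introduced here for illustration ($b_j(t)$ is the time-varying coefficient of feature $x_j$, and $p$ is the number of features):

$$
h(t \mid x) = b_0(t) + b_1(t)\,x_1 + \dots + b_p(t)\,x_p,
\qquad
B_j(t) = \int_0^t b_j(s)\,ds
$$

Estimation targets the cumulative coefficients $B_j(t)$ rather than $b_j(t)$ itself, so the effect of a feature at a given tenure is read from the slope of its cumulative curve.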

We fit a small Aalen model on a simplified feature set just to illustrate.


In [None]:
# Simplified design for Aalen (to keep things readable)
aalen_features_raw: List[str] = [
    "tenure",
    "churn_event",
    "SeniorCitizen",
    "MonthlyCharges",
    "TotalCharges",
]

missing_aalen = [c for c in aalen_features_raw if c not in surv_df.columns]
if not missing_aalen:
    aalen_df = surv_df[aalen_features_raw].copy()

    aalen_fitter = AalenAdditiveFitter(fit_intercept=True)
    aalen_fitter.fit(
        aalen_df,
        duration_col="tenure",
        event_col="churn_event",
    )

    aalen_fitter.plot()
    plt.title("Aalen Additive model – cumulative coefficients over time")
    plt.show()
else:
    print("Skipping Aalen model; missing columns:", missing_aalen)


The Aalen plot shows **cumulative** coefficients, so the slope of each curve is
what matters: a curve for `SeniorCitizen` or `MonthlyCharges` that steepens or
flattens indicates an effect on churn risk that strengthens or weakens as
tenure increases.

For production applications, you would typically choose **one core model**
(Cox or a more advanced alternative) and spend more time validating it.


## 13. Summary and next steps

In this expanded survival analysis notebook we:

1. Reframed Telco churn as a **time-to-event** problem.
2. Built **Kaplan–Meier survival curves**:
   - Overall.
   - By contract and other segments.
   - Assessed group differences via log-rank tests.
3. Trained a **Cox Proportional Hazards model**:
   - Engineered suitable covariates.
   - Interpreted hazard ratios in business language.
   - Evaluated model discrimination via **concordance index**.
   - Checked proportional hazards assumptions.
4. Created **scenario-based survival curves** for hypothetical customers.
5. Combined model risk scores with survival curves to build **risk-based
   survival segments**.
6. Briefly explored an **Aalen Additive model** as an alternative.

### How this complements classification-based churn models

- Classification:
  - "Will this customer churn in the next X months?"
- Survival:
  - "How long until this customer is likely to churn?"
  - "How does churn risk evolve over time?"

Together, they support:

- **Retention timing** – when to intervene for each segment.
- **Contract / pricing design** – how different contracts change expected
  lifetime.
- **Scenario analysis** – compare policies via predicted survival.

### Possible extensions

- Include more features (services, add-ons, internet type) into the Cox model.
- Use **time-varying covariates** (e.g. evolving usage or billing behaviour).
- Explore parametric survival models (Weibull, log-logistic, etc.).
- Integrate survival predictions into a **lifetime value (LTV)** model by
  combining expected survival with monthly revenue.
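As a rough sketch of the LTV idea in the last bullet, expected revenue can be approximated by weighting each month's revenue by the probability that the customer is still active. The survival probabilities, the revenue figure, and the helper `expected_ltv` below are all hypothetical; in practice the probabilities would come from a fitted survival model.

```python
def expected_ltv(monthly_survival, monthly_revenue, monthly_discount_rate=0.0):
    """Revenue per month weighted by the probability of still being a customer."""
    total = 0.0
    for month, surv_prob in enumerate(monthly_survival, start=1):
        discount = (1.0 + monthly_discount_rate) ** month
        total += surv_prob * monthly_revenue / discount
    return total


# Hypothetical survival probabilities S(1), ..., S(6) for one customer (made-up values)
toy_survival = [0.95, 0.90, 0.86, 0.82, 0.79, 0.76]

# Expected number of active months within the 6-month horizon
expected_months = sum(toy_survival)
ltv = expected_ltv(toy_survival, monthly_revenue=70.0)
ltv_discounted = expected_ltv(toy_survival, monthly_revenue=70.0, monthly_discount_rate=0.01)

print(f"Expected active months: {expected_months:.2f}")
print(f"Undiscounted LTV: {ltv:.2f}")
print(f"Discounted LTV:   {ltv_discounted:.2f}")
```

The result is only as long-sighted as the horizon of the survival curve: truncating at the maximum observed tenure gives a restricted estimate, not a full lifetime value.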
