
# Retention and Survival Analysis for A/B Testing

This notebook is a **playbook for experiments on user retention and time-to-event metrics**.

Instead of only looking at one-shot outcomes (like `converted`), we consider **when** users churn
or perform an action. We cover:

1. Simulating a **churn / retention** experiment with censoring.  
2. **D+1, D+7, D+30** retention as Bernoulli metrics.  
3. **Kaplan–Meier curves** to compare survival / retention over time.  
4. The **log-rank test** for equality of survival curves.  
5. A brief **Cox proportional hazards** model for covariate adjustment (if `lifelines` is available).

All code is typed, documented, and meant to be adapted to real experiment data.


## 0) Setup

In [None]:

from __future__ import annotations

from dataclasses import dataclass
from typing import Tuple, Dict

import math
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

plt.rcParams["figure.figsize"] = (7, 4.5)
plt.rcParams["axes.grid"] = True

# Optional: Cox model via lifelines (if installed)
try:
    from lifelines import CoxPHFitter  # type: ignore
except Exception as e:  # pragma: no cover
    CoxPHFitter = None
    print("lifelines not available; Cox model section will be skipped:", e)



## 1) Simulated retention / churn experiment

We simulate an experiment with two arms:

- `group ∈ {control, treatment}`.  
- Users have a continuous **time-to-churn** (in days) drawn from an exponential distribution.  
- We apply **administrative censoring** at a maximum follow-up time `T_max` (e.g. 60 days).  
- We also generate a pre-period covariate `pre_activity` which affects the churn rate.

Notation:

- \(T_i\): true time to churn for user \(i\).  
- \(C_i\): censoring time (here, fixed at `T_max`).  
- Observed time: \(Y_i = \min(T_i, C_i)\).  
- Event indicator: \(\delta_i = 1\{T_i \le C_i\}\).

In practice, \(T_i\) would be the time from experiment start to churn, and \(C_i\)
could be the experiment end or last observation time.
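The notation above can be illustrated with per-user censoring times. The following is a minimal sketch, assuming staggered entry into an experiment that ends on day 60; the helper name `observe_with_censoring` and the toy numbers are illustrative and not part of the notebook's pipeline.

```python
import numpy as np


def observe_with_censoring(true_time, censor_time):
    """Return (Y, delta) with Y = min(T, C) and delta = 1{T <= C}."""
    true_time = np.asarray(true_time, dtype=float)
    censor_time = np.asarray(censor_time, dtype=float)
    observed = np.minimum(true_time, censor_time)
    event = (true_time <= censor_time).astype(int)
    return observed, event


# Hypothetical staggered entry: the experiment ends on day 60,
# so a user who enters on day 45 can be followed for at most 15 days.
entry_day = np.array([0.0, 10.0, 45.0, 55.0])
censor_time = 60.0 - entry_day                   # C_i
true_time = np.array([12.0, 80.0, 15.0, 3.0])    # T_i (unknown in real data)

y, delta = observe_with_censoring(true_time, censor_time)
print(y)      # observed times Y_i
print(delta)  # event indicators delta_i
```

The second user has not churned by the end of follow-up, so only the censoring time is observed and the event indicator is 0.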


In [None]:

def simulate_retention_experiment(
    n: int = 20_000,
    baseline_hazard_control: float = 1.0 / 30.0,  # ~30-day average lifetime
    hazard_ratio_treatment: float = 0.8,          # treatment reduces hazard (better retention)
    T_max: float = 60.0,
    seed: int | None = 123,
) -> pd.DataFrame:
    """Simulate a retention / churn experiment with exponential hazards and censoring.

    Parameters
    ----------
    n : int
        Number of users.
    baseline_hazard_control : float
        Baseline hazard rate (lambda) in the control arm.
    hazard_ratio_treatment : float
        Hazard ratio for treatment vs control (< 1 => better retention).
    T_max : float
        Administrative censoring time (maximum follow-up in days).
    seed : int | None
        Random seed.

    Returns
    -------
    DataFrame
        Columns:
        - user_id : int
        - group : 'control' or 'treatment'
        - pre_activity : float (user-level covariate)
        - time : float (observed time, in days)
        - event : int (1 if churn observed, 0 if censored)
    """
    rng = np.random.default_rng(seed)

    user_id = np.arange(n)
    group_flag = rng.binomial(1, 0.5, size=n)
    group = np.where(group_flag == 0, "control", "treatment")

    # Pre-period activity: higher values => more engaged, lower hazard
    pre_activity = rng.normal(loc=0.0, scale=1.0, size=n)

    # Individual hazards: baseline modified by treatment and pre_activity
    # Control hazard:
    lambda_control = baseline_hazard_control * np.exp(-0.4 * pre_activity)
    # Treatment hazard: apply HR globally but keep covariate effect
    lambda_treat = hazard_ratio_treatment * baseline_hazard_control * np.exp(-0.4 * pre_activity)

    hazard = np.where(group_flag == 0, lambda_control, lambda_treat)

    # Exponential time-to-event: T = -log(U)/lambda
    U = rng.uniform(size=n)
    true_time = -np.log(U) / hazard

    # Administrative censoring at T_max
    observed_time = np.minimum(true_time, T_max)
    event = (true_time <= T_max).astype(int)

    df = pd.DataFrame(
        {
            "user_id": user_id,
            "group": group,
            "pre_activity": pre_activity.astype(float),
            "time": observed_time.astype(float),
            "event": event.astype(int),
        }
    )
    return df


df = simulate_retention_experiment()
df.head()


In [None]:

df.groupby("group")[["time", "event"]].agg(["mean", "std", "sum", "count"])



## 2) D+1, D+7, D+30 retention as Bernoulli metrics

Many product teams track retention as **binary indicators**:

- D+1 retention: user is still active at day 1.  
- D+7 retention: user is active at day 7.  
- D+30 retention: user is active at day 30.

Under our churn model, a user is “active at day D” if they **have not churned by day D**.

Given (time, event):

- If `time > D`, the user is retained at D (regardless of event status at later times).  
- If `time ≤ D` and `event = 1`, the user churned at or before D (not retained).  
- If `time ≤ D` and `event = 0`, the user was censored at or before D, so their status at D is unknown;
  here we treat them as **not retained** for simplicity, but in practice you may exclude them
  or handle them carefully. In this simulation censoring happens only at `T_max = 60`, so this
  case never arises for D ≤ 30.
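The three rules above can also be coded so that users with unknown status are marked explicitly instead of being counted as churned. This is a minimal sketch; the function name `retention_with_unknown` and the toy data are assumptions made for illustration.

```python
import numpy as np


def retention_with_unknown(time, event, day):
    """Return 1.0 (retained), 0.0 (churned by `day`) or NaN (status unknown)."""
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    out = np.full(time.shape, np.nan)
    out[time > day] = 1.0                      # still under observation after `day`
    out[(time <= day) & (event == 1)] = 0.0    # churn observed at or before `day`
    # Remaining entries (time <= day, event == 0) stay NaN: censored too early to know.
    return out


time = np.array([3.0, 10.0, 5.0, 40.0])
event = np.array([1, 0, 0, 1])

r7 = retention_with_unknown(time, event, day=7)
rate = np.nanmean(r7)  # D+7 retention among users with known status
print(r7, rate)
```

Dropping the unknown users is itself an approximation. When many users are censored before the horizon, the Kaplan–Meier estimate is the usual way to make use of their partial follow-up.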


In [None]:

def add_binary_retention(
    df_in: pd.DataFrame,
    days: Tuple[int, int, int] = (1, 7, 30),
) -> pd.DataFrame:
    """Add D+1, D+7, D+30 retention indicators to a survival dataset.

    For simplicity, users censored at or before day D are treated as not retained at D.
    Adapt this logic if you prefer to drop such users.

    Parameters
    ----------
    df_in : DataFrame
        Input with columns time (float) and event (0/1).
    days : tuple of int
        Days at which to compute retention indicators.

    Returns
    -------
    DataFrame
        Copy of df_in with one new binary column retain_D{d} per entry d in days
        (retain_D1, retain_D7, retain_D30 by default).
    """
    df = df_in.copy()
    for d in days:
        col = f"retain_D{d}"
        df[col] = (df["time"] > float(d)).astype(int)
    return df


df = add_binary_retention(df)
df.head()


In [None]:

ret_cols = ["retain_D1", "retain_D7", "retain_D30"]
retention_summary = (
    df.groupby("group")[ret_cols]
      .agg(["mean", "count"])
)
retention_summary



These D+1 / D+7 / D+30 retention rates can be analyzed with the usual **two-sample
proportion tests** or Bayesian Beta–Binomial models, just like any other Bernoulli metric.
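As one possible version of the two-sample proportion test mentioned above, here is a sketch of the pooled z-test using only the standard library. The function name and the counts in the example are hypothetical; in the notebook the counts would come from the sums and sizes of a `retain_D*` column per group.

```python
import math


def two_proportion_ztest(x_a, n_a, x_b, n_b):
    """Pooled two-sample z-test for H0: p_a == p_b.

    Returns (p_b - p_a, z, two-sided p-value).
    """
    if n_a <= 0 or n_b <= 0:
        raise ValueError("Both groups need at least one user.")
    p_a, p_b = x_a / n_a, x_b / n_b
    pooled = (x_a + x_b) / (n_a + n_b)
    se = math.sqrt(pooled * (1.0 - pooled) * (1.0 / n_a + 1.0 / n_b))
    if se == 0.0:
        raise ValueError("Standard error is zero; all outcomes are identical.")
    z = (p_b - p_a) / se
    # Two-sided normal p-value: 2 * (1 - Phi(|z|)) = erfc(|z| / sqrt(2))
    p_value = math.erfc(abs(z) / math.sqrt(2.0))
    return p_b - p_a, z, p_value


# Hypothetical counts: 780 of 1000 control users retained vs 820 of 1000 treated users
diff, z, p = two_proportion_ztest(x_a=780, n_a=1000, x_b=820, n_b=1000)
print(f"diff={diff:.3f}, z={z:.3f}, p={p:.4f}")
```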

However, this discards the **time-to-event structure** and censoring. For a more complete
view of retention dynamics, we now use **Kaplan–Meier curves** and **log-rank tests**.



## 3) Kaplan–Meier survival curves

The **Kaplan–Meier (KM) estimator** is a non-parametric estimate of the survival function:

\[
\hat S(t) = \prod_{t_j \le t} \left(1 - \frac{d_j}{n_j}\right),
\]

where:

- \(t_j\) are distinct event times,  
- \(d_j\) is the number of events at \(t_j\),  
- \(n_j\) is the number at risk just before \(t_j\).
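To make the product formula concrete, the following sketch applies it step by step to five toy observations (the data are made up for illustration). Each event time multiplies the running survival by \(1 - d_j / n_j\); censored observations only shrink later risk sets.

```python
import numpy as np

# Toy data: five users, two of them censored (event == 0)
time = np.array([2.0, 3.0, 3.0, 5.0, 8.0])
event = np.array([1, 1, 0, 1, 0])

event_times = np.unique(time[event == 1])
surv = 1.0
km_steps = []
for t_j in event_times:
    n_j = int(np.sum(time >= t_j))                      # at risk just before t_j
    d_j = int(np.sum((time == t_j) & (event == 1)))     # events at t_j
    surv *= 1.0 - d_j / n_j
    km_steps.append((float(t_j), n_j, d_j, surv))

for t_j, n_j, d_j, s in km_steps:
    print(f"t={t_j:.0f}  n_j={n_j}  d_j={d_j}  S(t)={s:.3f}")
```

The risk sets are 5, 4 and 2, so the estimate steps through roughly 0.8, 0.6 and 0.3. The user censored at day 3 still counts as at risk at day 3 but contributes no event.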

We implement a simple KM estimator and compare curves between control and treatment.


In [None]:

@dataclass(frozen=True)
class KMCurve:
    """Kaplan–Meier survival curve.

    Attributes
    ----------
    time : np.ndarray
        Sorted unique event times where survival changes.
    survival : np.ndarray
        Estimated survival probabilities at those times.
    n_at_risk : np.ndarray
        Number of individuals at risk just before each time.
    n_events : np.ndarray
        Number of events at each time.
    """
    time: np.ndarray
    survival: np.ndarray
    n_at_risk: np.ndarray
    n_events: np.ndarray


def kaplan_meier(
    time: np.ndarray,
    event: np.ndarray,
) -> KMCurve:
    """Compute the Kaplan–Meier survival curve for a single group.

    Parameters
    ----------
    time : np.ndarray
        Observed times (event or censoring).
    event : np.ndarray
        Event indicators (1 = event, 0 = censored).

    Returns
    -------
    KMCurve
        Stepwise survival curve.
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    if time.shape != event.shape:
        raise ValueError("time and event must have the same shape.")
    if time.size == 0:
        raise ValueError("Empty input.")

    # Sort by time
    order = np.argsort(time)
    t_sorted = time[order]
    e_sorted = event[order]

    unique_times = np.unique(t_sorted[e_sorted == 1])  # event times only
    n_times = unique_times.size

    if n_times == 0:
        # No events: survival is 1 for all times
        return KMCurve(
            time=np.array([], dtype=float),
            survival=np.array([], dtype=float),
            n_at_risk=np.array([], dtype=float),
            n_events=np.array([], dtype=float),
        )

    survival = []
    n_at_risk = []
    n_events = []

    S = 1.0
    for t_j in unique_times:
        # number at risk just before t_j (observed time >= t_j)
        at_risk = np.sum(t_sorted >= t_j)
        # events at t_j; at least 1 because t_j is an observed event time
        d_j = np.sum((t_sorted == t_j) & (e_sorted == 1))

        # KM update: multiply by the conditional survival probability at t_j
        S *= (1.0 - d_j / at_risk)

        survival.append(S)
        n_at_risk.append(float(at_risk))
        n_events.append(float(d_j))

    return KMCurve(
        time=unique_times,
        survival=np.asarray(survival, dtype=float),
        n_at_risk=np.asarray(n_at_risk, dtype=float),
        n_events=np.asarray(n_events, dtype=float),
    )


In [None]:

# Compute KM curves by group
km_curves: Dict[str, KMCurve] = {}
for g in ["control", "treatment"]:
    sub = df[df["group"] == g]
    km_curves[g] = kaplan_meier(
        time=sub["time"].to_numpy(),
        event=sub["event"].to_numpy(),
    )

plt.figure()
for g, curve in km_curves.items():
    if curve.time.size == 0:
        continue
    # Stepwise plot: start at S(0) = 1, then survival holds until next event time
    plt.step(
        np.concatenate(([0.0], curve.time)),
        np.concatenate(([1.0], curve.survival)),
        where="post",
        label=g,
    )
plt.xlabel("time (days)")
plt.ylabel("survival / retention S(t)")
plt.title("Kaplan–Meier retention curves")
plt.legend()
plt.tight_layout()
plt.show()



The curves show **retention over time** for each arm:

- Higher curves mean better retention (lower hazard / churn).  
- Visual separation suggests a treatment effect on retention, but we need a formal test.
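To judge visual separation, pointwise confidence bands help. The sketch below uses Greenwood's variance formula, \(\widehat{\mathrm{Var}}[\hat S(t)] = \hat S(t)^2 \sum_{t_j \le t} d_j / (n_j (n_j - d_j))\), with a plain normal interval clipped to [0, 1]. The function name `greenwood_bands` is an assumption, and log-log transformed intervals are often preferred near 0 or 1.

```python
import numpy as np


def greenwood_bands(survival, n_at_risk, n_events, z=1.96):
    """Pointwise confidence band for a KM curve via Greenwood's formula."""
    s = np.asarray(survival, dtype=float)
    n = np.asarray(n_at_risk, dtype=float)
    d = np.asarray(n_events, dtype=float)

    denom = n * (n - d)
    # When everyone at risk has the event (n == d) the term is undefined;
    # survival is 0 there, so a zero increment leaves the standard error at 0.
    terms = np.divide(d, denom, out=np.zeros_like(d), where=denom > 0)
    se = s * np.sqrt(np.cumsum(terms))

    lower = np.clip(s - z * se, 0.0, 1.0)
    upper = np.clip(s + z * se, 0.0, 1.0)
    return lower, upper


# Toy KM curve: survival 0.8, 0.6, 0.3 with risk sets of size 5, 4, 2
lower, upper = greenwood_bands([0.8, 0.6, 0.3], [5, 4, 2], [1, 1, 1])
print(np.round(lower, 3), np.round(upper, 3))
```

With the notebook's `KMCurve` objects, the inputs would be `curve.survival`, `curve.n_at_risk` and `curve.n_events`, and the band could be drawn with `plt.fill_between(curve.time, lower, upper, step="post", alpha=0.2)`.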



## 4) Log-rank test for equality of survival curves

The **log-rank test** compares survival between two groups under the null hypothesis:

\[
H_0: S_1(t) = S_2(t) \quad \text{for all } t.
\]

At each event time (across both groups), we compute:

- \(n_j\): total at risk just before time \(t_j\).  
- \(d_j\): total events at \(t_j\).  
- \(n_{1j}, d_{1j}\): at risk and events in group 1.  

The log-rank statistic is based on:

\[
O_1 - E_1 = \sum_j (d_{1j} - E_{1j}), \quad
E_{1j} = d_j \frac{n_{1j}}{n_j},
\]

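and its variance \(V = \sum_j V_{1j}\), where each term is the hypergeometric variance

\[
V_{1j} = \frac{n_{1j}\,(n_j - n_{1j})\,d_j\,(n_j - d_j)}{n_j^2\,(n_j - 1)}.
\]

Under \(H_0\),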

\[
Z = \frac{O_1 - E_1}{\sqrt{V}} \approx \mathcal{N}(0,1),
\]

so \(Z^2\) is approximately \(\chi^2_1\). We return a two-sided p-value based on \(Z\).
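As a small worked example of these quantities, the sketch below computes the contribution of a single event time to \(O_1 - E_1\) and to the variance, using the usual hypergeometric variance term. The helper name `logrank_contribution` and the counts are made up for illustration.

```python
def logrank_contribution(n_j, d_j, n_1j, d_1j):
    """Return (d_1j - E_1j, V_1j) for group 1 at a single event time."""
    if n_j <= 1:
        raise ValueError("Need at least two users at risk.")
    expected = d_j * n_1j / n_j
    variance = (
        n_1j * (n_j - n_1j) * d_j * (n_j - d_j)
        / (n_j ** 2 * (n_j - 1))
    )
    return d_1j - expected, variance


# 10 users at risk (6 in group 1); 2 events at this time, both in group 1
o_minus_e, v = logrank_contribution(n_j=10, d_j=2, n_1j=6, d_1j=2)
print(o_minus_e, v)  # about 0.8 and 0.427
```

Under the null, group 1 would be expected to contribute 1.2 of the 2 events, so observing 2 gives a positive contribution. Summing such terms over all event times gives the numerator and the variance of the z-statistic.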


In [None]:

def logrank_test_two_groups(
    time: np.ndarray,
    event: np.ndarray,
    group: np.ndarray,
    group_labels: Tuple[str, str] = ("control", "treatment"),
) -> Tuple[float, float]:
    """Perform a log-rank test for equality of survival in two groups.

    Parameters
    ----------
    time : np.ndarray
        Observed times (event or censoring).
    event : np.ndarray
        Event indicators (1 = event, 0 = censored).
    group : np.ndarray
        Group labels (must contain exactly the two group_labels).
    group_labels : tuple of str
        Names of the two groups, e.g. ("control", "treatment").

    Returns
    -------
    z_stat : float
        Log-rank z-statistic (signed): positive when the first group in
        group_labels has more events than expected under H0.
    p_value : float
        Two-sided p-value based on N(0,1).
    """
    time = np.asarray(time, dtype=float)
    event = np.asarray(event, dtype=int)
    group = np.asarray(group)

    g1, g2 = group_labels

    # Sort by time
    order = np.argsort(time)
    t_sorted = time[order]
    e_sorted = event[order]
    g_sorted = group[order]

    # Unique event times across both groups
    unique_times = np.unique(t_sorted[e_sorted == 1])

    O1_minus_E1 = 0.0
    V1 = 0.0

    for t_j in unique_times:
        # At risk just before t_j
        at_risk = t_sorted >= t_j
        n_j = float(np.sum(at_risk))
        if n_j <= 1.0:
            continue

        # Events at t_j
        at_time = (t_sorted == t_j)
        d_j = float(np.sum(at_time & (e_sorted == 1)))
        if d_j == 0.0:
            continue

        # Group 1 at risk and events
        risk_g1 = at_risk & (g_sorted == g1)
        n_1j = float(np.sum(risk_g1))
        d_1j = float(np.sum(at_time & (e_sorted == 1) & (g_sorted == g1)))

        if n_1j == 0.0 or n_1j == n_j:
            continue

        # Expected events in group 1 under H0
        E_1j = d_j * (n_1j / n_j)

        # Variance contribution
        V_1j = (
            (n_1j * (n_j - n_1j) * d_j * (n_j - d_j))
            / (n_j ** 2 * (n_j - 1.0))
        )

        O1_minus_E1 += (d_1j - E_1j)
        V1 += V_1j

    if V1 <= 0.0:
        raise ValueError("Log-rank variance is zero; check data.")

    z = O1_minus_E1 / math.sqrt(V1)

    # two-sided p-value via normal approximation: 2 * (1 - Phi(|z|)).
    # erfc keeps precision for large |z|, where 1 - cdf would round to 0.
    p = math.erfc(abs(z) / math.sqrt(2.0))
    return float(z), float(p)


z_lr, p_lr = logrank_test_two_groups(
    time=df["time"].to_numpy(),
    event=df["event"].to_numpy(),
    group=df["group"].to_numpy(),
    group_labels=("control", "treatment"),
)
z_lr, p_lr



The log-rank test gives a **global comparison** of survival curves between control and treatment.

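- Large |z| and small p-value ⇒ evidence that retention differs between the arms.  
- With `group_labels=("control", "treatment")`, a positive z means control had more churn events than expected under \(H_0\), i.e. treatment retains better.  
- It is the standard two-group test in clinical trials and is most powerful when the hazards are proportional; it can lose power when the survival curves cross.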



## 5) Cox proportional hazards model (optional)

The **Cox proportional hazards model** is a semi-parametric regression model for survival data:

\[
h(t \mid X) = h_0(t) \exp(\beta^\top X),
\]

where:

- \(h_0(t)\) is an unspecified baseline hazard,  
- \(X\) are covariates (e.g. treatment arm, pre-activity),  
- \(\beta\) are log hazard ratios.

This model allows us to:

- Adjust for covariates like `pre_activity`, country, device, etc.  
- Estimate a **hazard ratio for treatment** while controlling for these factors.

Below we fit a Cox model if `lifelines` is available in the environment.


In [None]:

if CoxPHFitter is None:
    print("lifelines not available; skipping Cox model fit.")
else:
    df_cox = df.copy()
    df_cox["treat_flag"] = (df_cox["group"] == "treatment").astype(int)

    # Keep only the columns we need for CoxPHFitter
    df_cox_input = df_cox[["time", "event", "treat_flag", "pre_activity"]]

    cph = CoxPHFitter()
    cph.fit(df_cox_input, duration_col="time", event_col="event")
    cph.print_summary()



Interpretation of the Cox model:

- The coefficient for `treat_flag` corresponds to a **log hazard ratio** between treatment and control,  
  *conditional* on `pre_activity`.  
- `exp(coef)` is the estimated hazard ratio:
  - < 1 ⇒ treatment reduces churn (improves retention).  
  - > 1 ⇒ treatment increases churn (worse retention).

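This is a natural extension of log-rank: when treatment is the only covariate, the Cox score
test coincides with the log-rank test (up to the handling of tied event times). Adding prognostic
covariates adjusts for chance imbalances between arms and usually gives a more precise estimate.
The adjusted hazard ratio is conditional on those covariates and need not equal the unadjusted one.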
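If `lifelines` is not installed, a single-covariate hazard ratio can still be estimated by maximizing the Cox partial likelihood directly. The following is a rough sketch, not a replacement for `CoxPHFitter`: it assumes one numeric covariate, uses Breslow-style risk sets for ties, and runs plain Newton-Raphson. The function name `cox_single_covariate` and the simulated data are illustrative.

```python
import numpy as np


def cox_single_covariate(time, event, x, n_iter=25, tol=1e-9):
    """Fit beta in h(t | x) = h0(t) * exp(beta * x). Returns (beta, std_error)."""
    order = np.argsort(time, kind="stable")
    t = np.asarray(time, dtype=float)[order]
    e = np.asarray(event, dtype=int)[order]
    x = np.asarray(x, dtype=float)[order]

    # First index of each risk set {k : t_k >= t_i} in the sorted arrays
    start = np.searchsorted(t, t, side="left")
    is_event = e == 1

    beta, info = 0.0, float("nan")
    for _ in range(n_iter):
        w = np.exp(beta * x)
        # Reverse cumulative sums give risk-set totals for every start index
        s0 = np.cumsum(w[::-1])[::-1][start]
        s1 = np.cumsum((w * x)[::-1])[::-1][start]
        s2 = np.cumsum((w * x * x)[::-1])[::-1][start]
        mean_x = s1 / s0
        score = np.sum((x - mean_x)[is_event])            # d loglik / d beta
        info = np.sum((s2 / s0 - mean_x ** 2)[is_event])  # -d2 loglik / d beta2
        if info <= 0.0:
            raise ValueError("Zero information; covariate may be constant.")
        step = score / info
        beta += step
        if abs(step) < tol:
            break
    return float(beta), float(1.0 / np.sqrt(info))


# Simulated check: exponential churn times, true hazard ratio 0.5, censoring at day 60
rng = np.random.default_rng(0)
n = 4000
x = rng.integers(0, 2, size=n).astype(float)
true_hr = 0.5
t_true = rng.exponential(scale=1.0 / (0.05 * true_hr ** x))
t_obs = np.minimum(t_true, 60.0)
ev = (t_true <= 60.0).astype(int)

beta_hat, se = cox_single_covariate(t_obs, ev, x)
print(f"estimated HR = {np.exp(beta_hat):.3f} (true {true_hr})")
```

At `beta = 0` and with no tied event times, `score / sqrt(info)` in the first iteration equals the log-rank z-statistic for the group coded `x = 1`, which is the connection to the log-rank test noted above.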



## 6) Practical notes for real retention experiments

When applying this to real data:

1. **Data preparation**
   - Define a clear **start date** for each user (e.g. experiment exposure).  
   - Compute time to churn or last activity, and an event indicator (churn vs censored).  
   - Consider multiple churn definitions (e.g. 14 days inactive, 30 days inactive).

2. **KM & log-rank first**
   - Plot Kaplan–Meier curves by arm, with confidence bands if possible.  
   - Run a log-rank test as the primary comparison of retention.

3. **Cox model for adjustment**
   - Include treatment and key covariates (`pre_activity`, geography, device).  
   - Focus on the hazard ratio for treatment with its confidence interval.

4. **Connect to D+1 / D+7 / D+30 metrics**
   - Use the same survival data to compute binary retention at selected horizons.  
   - This makes it easy to communicate results to non-technical stakeholders
     while still using proper survival methods under the hood.

5. **Guardrails and decisions**
   - Combine retention with other metrics (revenue, engagement, complaints) as guardrails.  
   - Use the same multi-metric decision patterns you already use for conversion and revenue.
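The data-preparation step in item 1 might look like the sketch below. It assumes a hypothetical per-user table with columns `user_id`, `exposure_date` and `last_active_date`, and one possible churn convention: a user counts as churned if they have been inactive for at least `inactivity_days` at the analysis date, with time measured from exposure to last activity. Real pipelines will likely need a different definition.

```python
import pandas as pd


def build_survival_table(users, analysis_date, inactivity_days=14):
    """Turn per-user dates into (time, event) columns for survival analysis."""
    exposure = pd.to_datetime(users["exposure_date"])
    last_active = pd.to_datetime(users["last_active_date"])
    analysis_date = pd.Timestamp(analysis_date)

    # Churn is declared once a user has been inactive for `inactivity_days`
    churned = (analysis_date - last_active).dt.days >= inactivity_days

    out = users[["user_id"]].copy()
    # Days from exposure to last activity: churn time if churned, censoring time otherwise
    out["time"] = (last_active - exposure).dt.days.astype(float).clip(lower=0.0)
    out["event"] = churned.astype(int)
    return out


# Hypothetical input data
users = pd.DataFrame(
    {
        "user_id": [1, 2, 3],
        "exposure_date": ["2024-01-01", "2024-01-05", "2024-01-10"],
        "last_active_date": ["2024-01-20", "2024-02-28", "2024-01-10"],
    }
)
table = build_survival_table(users, analysis_date="2024-03-01")
print(table)
```

Changing `inactivity_days` to 30 gives the alternative churn definition from item 1, and the resulting `time` and `event` columns have the same layout as the simulated data used throughout this notebook.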

This notebook gives you the core building blocks for **“do users stick around?”** experiments
with solid survival analysis foundations.
