# 07 â€” Survival Analysis
**Author:** Ebenezer Adjartey

Covers: Kaplan-Meier estimator, log-rank test, Cox PH model, parametric survival models (Weibull, exponential), competing risks, time-varying covariates.

In [None]:
import os
import numpy as np
import pandas as pd
from lifelines import KaplanMeierFitter, CoxPHFitter, WeibullFitter
from lifelines.statistics import logrank_test, multivariate_logrank_test
from lifelines import WeibullAFTFitter
import matplotlib.pyplot as plt
import seaborn as sns
import warnings
# Notebook-wide setup: fixed RNG seed, plot theme, and warning suppression.
np.random.seed(42)                  # seeds NumPy's global RNG used by the cells below
sns.set_theme(style='whitegrid')
# NOTE(review): this silences every warning category for the whole session,
# including any emitted by the model fits further down.
warnings.simplefilter('ignore')
print('Libraries loaded. (install lifelines: pip install lifelines)')

## 1. Generate Survival Data

In [None]:
# Simulate right-censored survival data with a beneficial treatment effect.
n = 300

# Latent (uncensored) event times ~ Weibull(shape=1.5, scale=5)
true_time = np.random.weibull(1.5, n) * 5
# Independent censoring times ~ Uniform(0, 10)
cens_time = np.random.uniform(0, 10, n)

# Covariates. age and stage are drawn independently of the event times,
# so only `treated` has a true effect on survival in this simulation.
age     = np.random.normal(55, 12, n)
treated = np.random.binomial(1, 0.5, n)
stage   = np.random.choice(['Early','Late'], n, p=[0.6, 0.4])

# Treatment delays the event: extend the *latent* event time of treated
# subjects by an Exp(1) amount. This must happen before censoring is applied;
# adding it to the already-censored time would also lengthen censored
# follow-up and push observed times past the censoring window.
true_time = true_time + treated * np.random.exponential(1, n)

# Apply right-censoring: we observe whichever comes first.
observed_time  = np.minimum(true_time, cens_time)
event_occurred = (true_time <= cens_time).astype(int)   # 1 = event, 0 = censored

df = pd.DataFrame({'time':observed_time, 'event':event_occurred,
                   'age':age, 'treated':treated, 'stage':stage})
print(f'N={n}, events={event_occurred.sum()}, censored={n-event_occurred.sum()}')
print(f'Censoring rate: {(1-event_occurred.mean())*100:.1f}%')
print(df.head())

## 2. Kaplan-Meier Estimator

In [None]:
# Kaplan-Meier estimate for the full cohort (reused later for comparison plots).
kmf_all = KaplanMeierFitter()
kmf_all.fit(durations=df['time'], event_observed=df['event'], label='Overall')

fig, axes = plt.subplots(1, 2, figsize=(13, 5))
ax_overall, ax_by_arm = axes

# Left panel: overall survival curve with its confidence band.
kmf_all.plot_survival_function(ax=ax_overall, ci_show=True)
ax_overall.set_title('Kaplan-Meier: Overall Survival')
ax_overall.set_xlabel('Time')
ax_overall.set_ylabel('Survival Probability')

# Right panel: one curve per treatment arm.
for trt, subset in df.groupby('treated'):
    kmf = KaplanMeierFitter()
    kmf.fit(durations=subset['time'], event_observed=subset['event'],
            label=f'Treated={trt}')
    kmf.plot_survival_function(ax=ax_by_arm, ci_show=True)
ax_by_arm.set_title('KM Curves by Treatment')
ax_by_arm.set_xlabel('Time')

# Point summaries read off the overall curve.
median_overall = kmf_all.median_survival_time_
surv_at_5 = kmf_all.survival_function_at_times(5).values[0]
print(f'Median survival time (overall): {median_overall:.3f}')
print(f'S(5): {surv_at_5:.4f}')

plt.tight_layout()
# Make sure the output folder exists before writing the figure.
os.makedirs('07_survival_analysis', exist_ok=True)
plt.savefig('07_survival_analysis/km_curves.png', dpi=100, bbox_inches='tight')
plt.show()
print('Saved.')

## 3. Log-Rank Test

In [None]:
# Two-group log-rank test: does survival differ between treatment arms?
trt0 = df[df['treated']==0]; trt1 = df[df['treated']==1]
lr = logrank_test(trt0['time'], trt1['time'],
                  event_observed_A=trt0['event'], event_observed_B=trt1['event'])
print(f'Log-rank test: statistic={lr.test_statistic:.4f}, p={lr.p_value:.4f}')
print('Verdict:', 'Survival curves differ' if lr.p_value < 0.05 else 'No significant difference')

# Multi-group log-rank (by stage).
# The event indicator must be passed as `event_observed`; any other keyword is
# absorbed by **kwargs and the test then treats every row as an observed event,
# i.e. censoring is silently ignored.
mlr = multivariate_logrank_test(df['time'], df['stage'], event_observed=df['event'])
print(f'\nMultivariate log-rank (by stage): p={mlr.p_value:.4f}')

## 4. Cox Proportional Hazards Model

In [None]:
# Fit a Cox proportional hazards model on age, treatment and (categorical) stage.
cph = CoxPHFitter()
cph.fit(
    df,
    formula='age + treated + C(stage)',
    duration_col='time',
    event_col='event',
)
print(cph.summary)

# Hazard ratios are the exponentiated coefficients.
hazard_ratios = cph.hazard_ratios_.round(4)
print('\nHazard Ratios (exp(coef)):')
print(hazard_ratios)

# Check the proportional hazards assumption; plots are suppressed, so only
# the textual report is produced.
print('\nProportional Hazards Test (Schoenfeld residuals):')
cph.check_assumptions(df, show_plots=False, p_value_threshold=0.05)

## 5. Parametric Survival Model (Weibull AFT)

In [None]:
# Weibull accelerated-failure-time regression with the same covariates as the Cox model.
waf = WeibullAFTFitter()
waf.fit(
    df,
    formula='age + treated + C(stage)',
    duration_col='time',
    event_col='event',
)
print(waf.summary)

# Predicted median survival for one hypothetical subject:
# treated, early stage, 55 years old.
new_obs = pd.DataFrame([{'age': 55, 'treated': 1, 'stage': 'Early'}])
pred_median = waf.predict_median(new_obs)
median_value = pred_median.values[0]
print(f'\nPredicted median survival (treated, early, age=55): {median_value:.3f}')

## 6. Baseline Hazard & Survival Function

In [None]:
# Univariate (no covariates) Weibull fit to the whole cohort.
wf = WeibullFitter()
wf.fit(durations=df['time'], event_observed=df['event'])
print(wf.summary)

shape_rho, scale_lambda = wf.rho_, wf.lambda_
print(f'Weibull rho (shape): {shape_rho:.4f}  lambda (scale): {scale_lambda:.4f}')

# Overlay the parametric curve on the non-parametric Kaplan-Meier estimate
# (kmf_all comes from the Kaplan-Meier cell) to judge the Weibull fit visually.
fig, ax = plt.subplots(figsize=(8, 5))
for fitted, curve_label in ((wf, 'Weibull fit'), (kmf_all, 'Kaplan-Meier')):
    fitted.plot_survival_function(ax=ax, label=curve_label)
ax.set_title('Weibull vs Kaplan-Meier Survival Function')

plt.tight_layout()
plt.savefig('07_survival_analysis/weibull_vs_km.png', dpi=100, bbox_inches='tight')
plt.show()
print('Saved.')

## Key Takeaways

- **KM estimator**: non-parametric; handles censoring; plots survival probability
- **Log-rank test**: non-parametric test for equal survival curves
- **Cox PH model**: semi-parametric; hazard ratio = exp(beta)
- **HR < 1**: protective factor; **HR > 1**: risk factor
- Always check proportional hazards assumption
- **Weibull**: flexible parametric model; shape>1 = increasing hazard
