# Synthetic Risk Score Validation Notebook

This notebook provides a **rigorous validation** of a **synthetic survival risk score**.

The goal is to evaluate how well a single feature (e.g., a Cox-based linear predictor)
predicts time-to-event outcomes.

## Assumptions
- The dataset is in a CSV file with at least the following columns:
  - `OS_YEARS`: survival time in years.
  - `OS_STATUS`: event indicator (1 = event, 0 = censored).
  - `RiskScore`: continuous synthetic risk score (higher = worse prognosis).

---
### What this notebook will compute
1. Descriptive analysis of the risk score.
2. Kaplan–Meier curves by risk quartile + log-rank test.
3. Univariate Cox proportional hazards model for the risk score.
4. Bootstrap confidence intervals for the C-index.
5. K-fold cross-validated C-index, AUC(t), and Integrated Brier Score (IBS)
   for:
   - a baseline model (covariate-free Kaplan–Meier, no RiskScore),
   - a Cox model with RiskScore.
6. Permutation test of the in-sample Cox C-index (risk score shuffled across patients).
7. Calibration plot at a chosen time horizon.
8. Summary + deltas (with - without RiskScore).


In [None]:
# ===============================
# 0. Imports & configuration
# ===============================
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
from lifelines import CoxPHFitter, KaplanMeierFitter
from lifelines.statistics import multivariate_logrank_test
from lifelines.utils import concordance_index

from sksurv.util import Surv
from sksurv.metrics import (
    concordance_index_censored,
    cumulative_dynamic_auc,
    integrated_brier_score,
)
from sksurv.nonparametric import kaplan_meier_estimator

from sklearn.model_selection import KFold
import plotly.graph_objects as go

# Matplotlib defaults (the interactive figures below are drawn with plotly).
plt.rcParams.update({'figure.figsize': (7, 5), 'axes.grid': True})

# Input file and column names.
DATA_PATH = '../../data/train_enhanced.csv'
TIME_COL = 'OS_YEARS'    # survival time in years
EVENT_COL = 'OS_STATUS'  # 1 = event, 0 = censored
RISK_COL = 'RiskScore'   # higher = worse prognosis

# Evaluation horizons, in the same unit as TIME_COL.
AUC_TIMES = np.array([1.0, 2.0, 3.0])  # horizons for time-dependent AUC
CALIBRATION_TIME = 2.0                  # horizon for the calibration plot

# Resampling settings.
RANDOM_STATE = 42
N_BOOTSTRAP = 1000  # bootstrap replicates for the C-index interval
N_SPLITS_CV = 5     # folds for K-fold cross-validation

# Seed NumPy's global RNG for reproducibility.
np.random.seed(RANDOM_STATE)

## 1. Load and inspect data

In [None]:
# Keep only the survival time, event flag and risk score; drop incomplete rows.
needed_cols = [TIME_COL, EVENT_COL, RISK_COL]
df = pd.read_csv(DATA_PATH)[needed_cols].copy().dropna()
# Event indicator as integer 0/1 for the survival fitters.
df[EVENT_COL] = df[EVENT_COL].astype(int)
print('Data shape:', df.shape)
df.head()

In [None]:
# Summary statistics, then an interactive histogram of the risk score.
risk_values = df[RISK_COL]
print(risk_values.describe())

fig = go.Figure()
fig.add_trace(go.Histogram(x=risk_values, nbinsx=30, name='Risk score'))
fig.update_layout(
    title='Distribution of synthetic risk score',
    xaxis_title='Risk score',
    yaxis_title='Count',
    template='simple_white',
    bargap=0.05,
)
fig.show()

In [None]:
# Sort by risk score and assign quartiles (Q4 = highest score = worst prognosis).
df = df.sort_values(RISK_COL).reset_index(drop=True)
df['risk_quartile'] = pd.qcut(df[RISK_COL], q=4,
                              labels=['Q1 (lowest)', 'Q2', 'Q3', 'Q4 (highest)'])

kmf = KaplanMeierFitter()
fig = go.Figure()

for q in df['risk_quartile'].cat.categories:
    mask = df['risk_quartile'] == q
    kmf.fit(df.loc[mask, TIME_COL], df.loc[mask, EVENT_COL], label=str(q))
    sf = kmf.survival_function_
    t = sf.index.values
    s = sf.iloc[:, 0].values
    # A Kaplan–Meier estimate is a step function: draw it with
    # horizontal-then-vertical steps, not straight lines between event times.
    fig.add_trace(go.Scatter(x=t, y=s, mode='lines', line_shape='hv', name=str(q)))

fig.update_layout(title='KM survival by risk quartile',
                  xaxis_title='Time', yaxis_title='Survival probability',
                  template='simple_white')
fig.show()

# Log-rank test for any difference in survival across the four quartiles.
lr = multivariate_logrank_test(df[TIME_COL], df['risk_quartile'], df[EVENT_COL])
print(lr)

In [None]:
# Univariate Cox PH model on the full dataset. Its C-index is evaluated on the
# same data it was fitted to, so it is an in-sample figure.
cox_cols = [TIME_COL, EVENT_COL, RISK_COL]
cph = CoxPHFitter()
cph.fit(df[cox_cols], duration_col=TIME_COL, event_col=EVENT_COL)
cph.print_summary()
c_index_in_sample = cph.concordance_index_
print("In-sample C-index =", c_index_in_sample)

In [None]:
from lifelines.utils import concordance_index

def bootstrap_c_index(df, n_boot=1000, seed=42):
    """Percentile-bootstrap the C-index of the raw risk score.

    Each replicate resamples the rows of ``df`` with replacement. It scores
    ``-RiskScore`` with lifelines' ``concordance_index``, because lifelines
    treats larger scores as longer survival, while a larger RiskScore means
    a worse prognosis.

    Parameters
    ----------
    df : DataFrame containing the TIME_COL, EVENT_COL and RISK_COL columns.
    n_boot : number of bootstrap replicates to draw.
    seed : seed of the replicate RNG (the default 42 reproduces earlier runs).

    Returns
    -------
    (mean, array([low, high]), values)
        Mean of the valid replicate C-indices, their 2.5th/97.5th
        percentiles, and the array of valid replicate values.

    Raises
    ------
    ValueError
        If no replicate contained a comparable pair.
    """
    rng = np.random.RandomState(seed)
    n = len(df)
    cvals = []
    for _ in range(n_boot):
        # Draw indices before the try so a skipped replicate still consumes
        # the same RNG state; valid replicates match an uninterrupted run.
        sample = df.iloc[rng.randint(0, n, size=n)]
        try:
            c = concordance_index(sample[TIME_COL], -sample[RISK_COL], sample[EVENT_COL])
        except ZeroDivisionError:
            # lifelines raises this when a resample has no admissible pairs
            # (e.g. no events were drawn); skip it instead of aborting.
            continue
        cvals.append(c)
    cvals = np.array(cvals)
    if cvals.size == 0:
        raise ValueError('no bootstrap replicate contained a comparable pair')
    return cvals.mean(), np.percentile(cvals, [2.5, 97.5]), cvals

# Bootstrap the C-index and report the mean with a 95% percentile interval.
mean_c, (low, high), boot = bootstrap_c_index(df, N_BOOTSTRAP)
print(f'Bootstrap C-index: mean={mean_c:.4f}, '
      f'95% percentile CI=[{low:.4f}, {high:.4f}] ({len(boot)} replicates)')

fig = go.Figure()
fig.add_trace(go.Histogram(x=boot, nbinsx=30))
fig.update_layout(title='Bootstrap C-index', xaxis_title='C-index',
                  yaxis_title='Count', template='simple_white')
fig.show()

In [None]:
# ===== 5. Cross-validation =====
# Out-of-fold comparison of a univariate Cox model on RiskScore against a
# covariate-free Kaplan–Meier baseline: C-index, IBS and AUC(t) per fold.
y_struct = Surv.from_dataframe(event=EVENT_COL, time=TIME_COL, data=df)
kf = KFold(n_splits=N_SPLITS_CV, shuffle=True, random_state=RANDOM_STATE)
# Surv.from_dataframe orders the structured-array fields as (event, time).
event_field, time_field = y_struct.dtype.names


def _in_bounds(y_tr, y_te):
    """Mask of test samples whose time is strictly below the largest training time.

    The sksurv IBS/AUC metrics weight test samples with a censoring
    distribution fitted on ``y_tr``. Restricting the test set to the
    training follow-up keeps those weights defined.
    """
    return y_te[time_field] < y_tr[time_field].max()


def _ibs_or_nan(y_tr, y_te, surv_matrix, grid):
    """Integrated Brier Score over the evaluable part of a test fold, else NaN.

    ``surv_matrix`` has one row per test sample and one column per ``grid``
    time. Only grid points inside [min, max) of the in-bounds test times are
    kept, and at least two are required to integrate.
    """
    mask = _in_bounds(y_tr, y_te)
    if not mask.any():
        return np.nan
    y_in = y_te[mask]
    t_lo, t_hi = y_in[time_field].min(), y_in[time_field].max()
    keep = (grid >= t_lo) & (grid < t_hi)
    if keep.sum() < 2:
        return np.nan
    return integrated_brier_score(y_tr, y_in, surv_matrix[mask][:, keep], grid[keep])


def _auc_at_times(y_tr, y_te, risk_scores):
    """Cumulative/dynamic AUC at each AUC_TIMES entry (NaN where not evaluable).

    ``risk_scores`` holds one "higher = riskier" score per test sample.
    """
    out = np.full(len(AUC_TIMES), np.nan)
    mask = _in_bounds(y_tr, y_te)
    if not mask.any():
        return out
    y_in = y_te[mask]
    tt = y_in[time_field]
    ok = (AUC_TIMES >= max(tt.min(), 1e-8)) & (AUC_TIMES < tt.max())
    if not ok.any():
        return out
    eval_times = np.unique(AUC_TIMES[ok])  # sorted, de-duplicated horizons
    # cumulative_dynamic_auc returns (auc_per_time, mean_auc), NOT (times, auc).
    auc_t, _ = cumulative_dynamic_auc(y_tr, y_in, risk_scores[mask], eval_times)
    for t0, a0 in zip(eval_times, auc_t):
        out[np.isclose(AUC_TIMES, t0)] = a0
    return out


cindex_with, ibs_with, auc_with = [], [], []
cindex_base, ibs_base, auc_base = [], [], []

for tr, te in kf.split(df):
    df_tr, df_te = df.iloc[tr], df.iloc[te]
    y_tr, y_te = y_struct[tr], y_struct[te]
    ev_te = y_te[event_field].astype(bool)
    t_te = y_te[time_field]

    # --- Cox model with RiskScore, fitted on the training fold only ---
    m = CoxPHFitter()
    m.fit(df_tr[[TIME_COL, EVENT_COL, RISK_COL]], duration_col=TIME_COL, event_col=EVENT_COL)
    risk = m.predict_partial_hazard(df_te[[RISK_COL]]).values.ravel()
    cindex_with.append(concordance_index_censored(ev_te, t_te, risk)[0])
    # Index = time grid, one column per test sample (same order as y_te).
    surv = m.predict_survival_function(df_te[[RISK_COL]])
    ibs_with.append(_ibs_or_nan(y_tr, y_te, surv.T.values, surv.index.values))
    auc_with.append(_auc_at_times(y_tr, y_te, risk))

    # --- Baseline: training-fold KM curve, identical for every test sample ---
    km_t, km_s = kaplan_meier_estimator(y_tr[event_field].astype(bool), y_tr[time_field])
    risk0 = np.zeros_like(t_te)  # constant score: no discrimination
    cindex_base.append(concordance_index_censored(ev_te, t_te, risk0)[0])
    ibs_base.append(_ibs_or_nan(y_tr, y_te, np.tile(km_s, (len(y_te), 1)), km_t))
    auc_base.append(_auc_at_times(y_tr, y_te, risk0))

cindex_with = np.array(cindex_with); ibs_with = np.array(ibs_with); auc_with = np.array(auc_with)
cindex_base = np.array(cindex_base); ibs_base = np.array(ibs_base); auc_base = np.array(auc_base)

print('WITH RiskScore C-index mean=', np.nanmean(cindex_with))
print('BASELINE C-index mean=', np.nanmean(cindex_base))
print('WITH RiskScore IBS mean=', np.nanmean(ibs_with))
print('BASELINE IBS mean=', np.nanmean(ibs_base))

In [None]:
# ===== 6. Permutation test =====
def permutation_test(df, n_perm=500, seed=42):
    """Null distribution of the in-sample Cox C-index with shuffled risk scores.

    For each permutation, RISK_COL is shuffled across rows, which breaks its
    link to (TIME_COL, EVENT_COL). A univariate Cox model is then refitted
    and its in-sample ``concordance_index_`` is recorded.

    Parameters
    ----------
    df : DataFrame containing TIME_COL, EVENT_COL and RISK_COL.
    n_perm : number of permutations.
    seed : seed of the permutation RNG (the default 42 reproduces earlier runs).

    Returns
    -------
    numpy array of length ``n_perm`` with the permuted C-indices.
    """
    rng = np.random.RandomState(seed)
    cols = [TIME_COL, EVENT_COL, RISK_COL]
    vals = []
    for _ in range(n_perm):
        # Copy only the columns the model uses; the fit is unchanged.
        shuffled = df[cols].copy()
        shuffled[RISK_COL] = rng.permutation(shuffled[RISK_COL].values)
        m = CoxPHFitter()
        m.fit(shuffled, duration_col=TIME_COL, event_col=EVENT_COL)
        vals.append(m.concordance_index_)
    return np.array(vals)

perm = permutation_test(df, 500)
true_cindex = c_index_in_sample
print(true_cindex, perm.mean())

# One-sided permutation p-value. The +1 correction counts the observed value
# as one permutation, so the p-value is never exactly 0.
p_value = (1 + np.sum(perm >= true_cindex)) / (1 + len(perm))
print('Permutation p-value =', p_value)

fig = go.Figure()
fig.add_trace(go.Histogram(x=perm, nbinsx=30, opacity=0.75, name='Permuted'))
# A full-height vertical line, independent of the histogram's bar heights.
fig.add_vline(x=true_cindex, line_dash='dash', annotation_text='True')
fig.update_layout(title='Permutation null of the C-index', xaxis_title='C-index',
                  yaxis_title='Count', template='simple_white')
fig.show()

In [None]:
# ===== 7. Calibration =====
# Refit the univariate Cox model on all data. Within deciles of predicted
# survival at CALIBRATION_TIME, compare the mean prediction with the
# Kaplan–Meier estimate of observed survival.
m = CoxPHFitter()
m.fit(df[[TIME_COL, EVENT_COL, RISK_COL]], duration_col=TIME_COL, event_col=EVENT_COL)
surv = m.predict_survival_function(df[[RISK_COL]])  # rows: times, columns: patients (df order)
times = surv.index.values

# Index of the last prediction time <= CALIBRATION_TIME.
idx = np.searchsorted(times, CALIBRATION_TIME, side='right') - 1
if idx < 0:
    # Without this guard, idx == -1 would silently select the LAST time point.
    raise ValueError(
        f'CALIBRATION_TIME={CALIBRATION_TIME} is earlier than the first '
        f'prediction time ({times[0]:.3f}).'
    )
t_eff = times[idx]
pred = surv.iloc[idx].values  # predicted survival at t_eff, one per patient

n_bins = 10
q = np.quantile(pred, np.linspace(0, 1, n_bins + 1))
bid = np.digitize(pred, q[1:-1], right=True)  # bin ids 0 .. n_bins-1

bx = []; by = []
for b in range(n_bins):
    mask = bid == b
    if mask.sum() < 10:  # skip bins too small for a usable KM estimate
        continue
    bx.append(pred[mask].mean())

    t, s = kaplan_meier_estimator(df.loc[mask, EVENT_COL].values.astype(bool),
                                  df.loc[mask, TIME_COL].values)
    # Observed survival: KM value at the last KM time <= t_eff (1.0 if none).
    by.append(s[t <= t_eff][-1] if (t <= t_eff).any() else 1.0)

fig = go.Figure()
fig.add_trace(go.Scatter(x=bx, y=by, mode='markers+lines', name='Observed'))
fig.add_trace(go.Scatter(x=[0, 1], y=[0, 1], mode='lines', name='Perfect'))
fig.update_layout(title=f'Calibration at t={t_eff:.2f}',
                  xaxis_title='Predicted survival', yaxis_title='Observed survival (KM)',
                  template='simple_white')
fig.show()

In [None]:
# ===== 8. Summary + deltas =====
def summarize(results, label):
    """Collapse per-fold CV metrics into a pandas Series of NaN-ignoring means.

    ``results`` maps 'c_index' and 'ibs' to per-fold values, 'auc_per_time'
    to a folds x times table, and 'times_auc' to the evaluation times that
    label the table's columns. ``label`` is stored verbatim under 'label'.
    One 'auc_{t:.2f}_mean' entry is added per time when the AUC table is
    non-empty.
    """
    summary = {
        'label': label,
        'c_index_mean': float(np.nanmean(np.array(results['c_index']))),
        'ibs_mean': float(np.nanmean(np.array(results['ibs']))),
    }
    auc_rows = results['auc_per_time']
    eval_times = np.array(results['times_auc'], dtype=float)
    if len(auc_rows) > 0:
        auc_table = np.array(auc_rows)
        summary.update(
            (f'auc_{t:.2f}_mean', float(np.nanmean(auc_table[:, col])))
            for col, t in enumerate(eval_times)
        )
    return pd.Series(summary)

results_with = {'c_index': cindex_with, 'ibs': ibs_with, 'auc_per_time': auc_with, 'times_auc': AUC_TIMES}
results_no = {'c_index': cindex_base, 'ibs': ibs_base, 'auc_per_time': auc_base, 'times_auc': AUC_TIMES}

s_no = summarize(results_no, 'no_risk')
s_with = summarize(results_with, 'with_risk')

print(pd.concat([s_no, s_with], axis=1, keys=['no_risk', 'with_risk']))

# Deltas are (with RiskScore) - (without) for every shared '*_mean' metric:
# C-index, IBS and each AUC(t). Higher is better for C-index and AUC; LOWER
# is better for IBS, so a negative delta_ibs is an improvement.
metric_keys = [k for k in s_with.index if k.endswith('_mean') and k in s_no.index]
delta = pd.Series({f"delta_{k[:-len('_mean')]}": s_with[k] - s_no[k] for k in metric_keys})
print(delta)