In [38]:
import json
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torchtuples as tt
from lifelines import KaplanMeierFitter
from lifelines.utils import concordance_index
from pycox.models import CoxPH
from sksurv.metrics import brier_score

from pycox_XL.evaluation import EvalSurv
from pycox_XL.models import DeepHitSingle

# Random seeds, set once at the top so any stochastic step (e.g. network
# weight initialisation) is reproducible on Restart & Run All.
NUMPY_SEED = 1234
TORCH_SEED = 123

np.random.seed(NUMPY_SEED)
torch.manual_seed(TORCH_SEED)
torch.cuda.manual_seed(TORCH_SEED)  # ignored when CUDA is unavailable


In [39]:
def import_dataset_external(in_filename, norm_mode):
    """Load an external cohort CSV and prepare raw and centred predictor frames.

    Parameters
    ----------
    in_filename : str, path-like or file-like
        Comma-separated file passed to ``pd.read_csv``. Must contain the
        columns ``smoke, sbp, tc, hdlc, age_enter, sex, base_dm, who_cvd,
        month_fp_whocvd``; any other columns are ignored.
    norm_mode : {'train', 'test'}
        Which means are subtracted from the continuous predictors
        (``sbp, tc, hdlc, age_enter``): the hard-coded training means
        (``'train'``) or the means of this file (``'test'``).

    Returns
    -------
    df_test : DataFrame
        Raw (uncentred) predictors and outcome; ``smoke`` is replaced by the
        boolean indicators ``smoke_1`` and ``smoke_2``.
    rdf_test : DataFrame
        Same columns and order as ``df_test``, with the continuous predictors
        centred.
    mean_train : Series
        Hard-coded training means of the continuous predictors.
    mean_test : Series
        Means of the continuous predictors in this file.

    Raises
    ------
    ValueError
        If ``norm_mode`` is not ``'train'`` or ``'test'``.
    """
    if norm_mode not in ('train', 'test'):
        raise ValueError(f"norm_mode must be 'train' or 'test', got {norm_mode!r}")

    continuous = ['sbp', 'tc', 'hdlc', 'age_enter']
    order = continuous + ['smoke_1', 'smoke_2', 'sex', 'base_dm', 'who_cvd', 'month_fp_whocvd']

    df = pd.read_csv(in_filename, sep=',')
    # Zero follow-up times are replaced by a small positive value (0.001).
    df['month_fp_whocvd'] = df['month_fp_whocvd'].mask(df['month_fp_whocvd'] == 0, 0.001)
    df = df.reset_index(drop=True)

    # Explicit smoking indicators (levels presumably 0/1/2 with 0 as reference).
    # Unlike get_dummies(drop_first=True) this always yields both columns with
    # a fixed meaning, even when a level is absent or smoke is float-coded.
    df_test = (
        df[['smoke', 'sbp', 'tc', 'hdlc', 'age_enter', 'sex', 'base_dm', 'who_cvd', 'month_fp_whocvd']]
        .assign(smoke_1=df['smoke'] == 1, smoke_2=df['smoke'] == 2)
        .drop(columns='smoke')
        .loc[:, order]
    )

    mean_train = pd.Series([137.768444, 5.761835, 1.460055, 56.598483], index=continuous)
    mean_test = df_test[continuous].mean()

    # Centre only the continuous predictors; indicator and outcome columns stay as-is.
    centre = mean_train if norm_mode == 'train' else mean_test
    rdf_test = df_test.assign(**{col: df_test[col] - centre[col] for col in continuous})

    return df_test, rdf_test, mean_train, mean_test

def split_var(df, method, **kwargs):
    """Split a prepared cohort frame into a float32 feature matrix and targets.

    Parameters
    ----------
    df : DataFrame
        Must contain ``who_cvd`` (event indicator) and ``month_fp_whocvd``
        (follow-up time); every other column is treated as a feature.
    method : {'discrete', 'continuous'}
        ``'continuous'``: targets are ``(durations, events)`` numpy arrays.
        ``'discrete'``: targets are ``labtrans.transform(durations, events)``.
    **kwargs
        ``labtrans`` -- optional label transformer for ``method='discrete'``.
        When omitted, a notebook-level global named ``labtrans`` is used
        (the original behaviour).

    Returns
    -------
    x : numpy.ndarray of float32
    y : tuple of arrays, or whatever ``labtrans.transform`` returns

    Raises
    ------
    ValueError
        If ``method`` is not ``'discrete'`` or ``'continuous'``.
    """
    if method not in ('discrete', 'continuous'):
        raise ValueError(f"method must be 'discrete' or 'continuous', got {method!r}")

    x = df.drop(['who_cvd', 'month_fp_whocvd'], axis=1).values.astype('float32')
    durations, events = df['month_fp_whocvd'].values, df['who_cvd'].values

    if method == 'continuous':
        return x, (durations, events)

    transformer = kwargs.get('labtrans')
    if transformer is None:
        # Backward-compatible fallback: the notebook-level global `labtrans`.
        transformer = labtrans
    return x, transformer.transform(durations, events)

## 1. Import the dataset and calculate 10-year risk with the original models

In [40]:
# Load the external cohort. With norm_mode='train' the continuous predictors in
# rdf_test are centred on the hard-coded training means; df_test keeps raw values.
df_test, rdf_test, mean_train, mean_test = import_dataset_external(
    in_filename='./example_data.csv',
    norm_mode='train',
)

In [42]:
# Preview the centred frame that is fed to the models.
rdf_test.head(n=10)

Unnamed: 0,sbp,tc,hdlc,age_enter,smoke_1,smoke_2,sex,base_dm,who_cvd,month_fp_whocvd
0,5.231556,-3.041835,1.159945,-7.441878,False,False,0,0,0,101.70685
1,35.231556,-0.511835,0.229945,-6.090612,False,False,0,0,0,92.41096
2,6.231556,-1.811835,-0.150055,5.371264,False,False,0,0,0,44.860274
3,-10.768444,0.348165,-0.850055,8.237246,False,False,0,1,1,51.838356
4,13.231556,-2.571835,-0.910055,-9.403548,False,False,0,0,0,89.14247
5,20.231556,-1.291835,-0.460055,-2.695677,False,False,0,1,0,116.90959
6,51.231556,-2.901835,0.289945,21.295699,False,True,1,0,0,95.2274
7,-8.768444,-2.551835,0.539945,3.384816,False,False,0,1,0,115.2274
8,-1.768444,-0.071835,0.089945,0.849977,False,True,0,0,0,85.34247
9,-17.768444,-3.061835,0.649945,-1.158373,False,True,1,0,0,86.580821


### 1.1 DeepSurv model

In [43]:
# Features and (durations, events) targets of the external cohort.
x_test, y_test = split_var(rdf_test, method='continuous')
duration_test, events_test = y_test

# Saved DeepSurv model. For reference, the trained network is an MLP with
# hidden layers [7, 3], batch norm, dropout 0.1 and no output bias; the full
# module is unpickled below, so it is not rebuilt here.
out_path = os.path.join(os.getcwd(), "Oct_hpo_best_deepsurv")
in_features = x_test.shape[1]
lr_rate = 0.002224842384072191
weight_decay = 0.00034917421065908957
optimizer_name = 'Adam'
momentum = 0.6451658205690902

# NOTE(security): torch.load unpickles arbitrary objects -- only load trusted
# checkpoints. map_location='cpu' lets a GPU-saved checkpoint load on a
# CPU-only machine (the pycox/torchtuples model presumably moves the net to its
# device on construction -- confirm for your version). On torch >= 2.6 the
# default weights_only=True rejects a pickled nn.Module; pass
# weights_only=False there if loading fails.
surv_net = torch.load(os.path.join(out_path, 'deepsurv_inc_net.pt'), map_location='cpu')

if optimizer_name == 'Adam':
    surv_optimizer = tt.optim.Adam(lr=lr_rate, weight_decay=weight_decay)
else:
    surv_optimizer = tt.optim.SGD(lr=lr_rate, weight_decay=weight_decay, momentum=momentum)
surv_model = CoxPH(surv_net, surv_optimizer)
# Optimizer state is restored for completeness; the prediction below does not use it.
surv_model.optimizer.load_state_dict(
    torch.load(os.path.join(out_path, 'deepsurv_inc_optimizer.pt'), map_location='cpu'))

# 10-year risk = 1 - S0(120) ** exp(g(x)), where g(x) is the network output and
# surv_s0_120 the baseline survival at 120 months.
surv_s0_120 = 0.95648398
pred = surv_model.predict(x_test)
surv_risk = 1 - surv_s0_120 ** np.exp(pred[:, 0])

### 1.2 DeepHit model

In [44]:
# Discretisation grid (months) for DeepHit. It has to match the grid used in
# training, so it is kept verbatim: 12-monthly, 6-monthly around the 120-month
# horizon, and ending at 178.
labtrans = DeepHitSingle.label_transform(np.array([0.001, 12, 24, 36, 48, 60, 72, 84, 96, 108, 114, 120, 126, 132, 144, 156, 178]))
x_test, y_test = split_var(rdf_test, method='discrete', labtrans=labtrans)
# Follow-up and events on the original time scale (y_test holds the
# label-transformed targets), computed here so this cell is self-contained.
duration_test = rdf_test['month_fp_whocvd'].values
events_test = rdf_test['who_cvd'].values

# Saved DeepHit model. For reference, the trained network is an MLP with
# hidden layers [14, 14, 7], batch norm, dropout 0.5 and no output bias; the
# full module is unpickled below, so it is not rebuilt here.
out_path = os.path.join(os.getcwd(), "Oct_hpo_best_deephit")
lr_rate = 0.0003645214326065641
weight_decay = 0.0029598782837134477
optimizer_name = 'Adam'
momentum = 0
# Loss weights passed to DeepHitSingle: alpha = alpha_beta * prob, beta = alpha_beta * (1 - prob).
alpha_beta = 1.0
prob = 0.2

# NOTE(security): torch.load unpickles arbitrary objects -- trusted files only.
# map_location='cpu' allows loading GPU-saved checkpoints without CUDA.
hit_net = torch.load(os.path.join(out_path, 'deephit_single_inc_net.pt'), map_location='cpu')

if optimizer_name == 'Adam':
    hit_optimizer = tt.optim.Adam(lr=lr_rate, weight_decay=weight_decay)
else:
    hit_optimizer = tt.optim.SGD(lr=lr_rate, weight_decay=weight_decay, momentum=momentum)
hit_model = DeepHitSingle(hit_net, hit_optimizer, alpha=alpha_beta * prob, beta=alpha_beta * (1 - prob),
                          sigma=10, duration_index=labtrans.cuts)
hit_model.optimizer.load_state_dict(
    torch.load(os.path.join(out_path, 'deephit_single_inc_optimizer.pt'), map_location='cpu'))

# Predicted survival curves; 10-year risk = 1 - S(120), clipped away from 0 and 1
# (presumably so the log(-log(1 - risk)) recalibration stays finite).
surv = hit_model.predict_surv_df(x_test)
ev = EvalSurv(surv, duration_test, events_test, censor_surv='km')
hit_risk = np.clip(1 - ev.surv_at_times(120).squeeze(), 1e-10, 1 - 1e-10)


### 1.3 Cox models (sex-specific)

In [45]:
# Sex-specific Cox models with age interaction terms.
interaction_cols = ['sbp', 'tc', 'hdlc', 'age_enter', 'base_dm', 'smoke_1', 'smoke_2']

# Build the age_<col> interaction terms in a separate frame so rdf_test does not
# accumulate them; only lp_men, lp_women and cox_risk are written back below.
cox_features = rdf_test.assign(
    **{f'age_{col}': rdf_test['age_enter'] * rdf_test[col] for col in interaction_cols}
)

# Coefficients: one sheet per sex, indexed by predictor name (first column).
# Presumably one coefficient column per sheet -- .iloc[:, 0] always yields a
# Series (squeeze() would return a DataFrame or a scalar for other shapes).
coef_df_men = pd.read_excel('coefficients.xlsx', sheet_name='coef_cox_men', index_col=0)
coef_df_women = pd.read_excel('coefficients.xlsx', sheet_name='coef_cox_women', index_col=0)
coef_cox_men = coef_df_men.iloc[:, 0]
coef_cox_women = coef_df_women.iloc[:, 0]

# Baseline survival probabilities at 120 months.
s0_cox_men = 0.9707463
s0_cox_women = 0.9847183

# Linear predictor = sum over predictors of coefficient * (centred) value.
rdf_test['lp_men'] = cox_features[coef_cox_men.index].mul(coef_cox_men).sum(axis=1)
rdf_test['lp_women'] = cox_features[coef_cox_women.index].mul(coef_cox_women).sum(axis=1)

# 10-year risk = 1 - S0(120) ** exp(lp), using the men's model where sex == 1
# and the women's model otherwise.
rdf_test['cox_risk'] = np.where(
    rdf_test['sex'] == 1,
    1 - s0_cox_men ** np.exp(rdf_test['lp_men']),
    1 - s0_cox_women ** np.exp(rdf_test['lp_women']),
)

## 2. Model recalibration

In [46]:
# One table holding each model's original 10-year risk, the raw (uncentred)
# predictors and the observed outcome for every record of the external file.
results_test = pd.DataFrame(
    {
        **{col: rdf_test[col] for col in ("lp_men", "lp_women", "cox_risk")},
        "surv_risk": surv_risk,
        "hit_risk": hit_risk,
        **{
            new_name: df_test[source_name]
            for new_name, source_name in (
                ("sex", "sex"),
                ("age", "age_enter"),
                ("sbp", "sbp"),
                ("tc", "tc"),
                ("hdlc", "hdlc"),
                ("base_dm", "base_dm"),
                ("smoke_1", "smoke_1"),
                ("smoke_2", "smoke_2"),
            )
        },
        "time": duration_test,
        "outcome": events_test,
        "index": df_test.index,
    }
)

### 2.1 Calculate the rescaling factor

In [47]:
# Recalibrate the models to the target population on the complementary
# log-log scale, cloglog(p) = ln(-ln(1 - p)).
#
# Assumptions of this worked example:
# * the models are designed for people aged 40-74 without prior CVD, with an
#   average 10-year CVD risk of 4% in the derivation population;
# * in the target population, local annual surveillance reports about 8 new
#   CVD diagnoses per 1,000 people aged 40+ per year.
#
# With annual incidence P = 0.8% over T = 1 year:
#   r     = -ln(1 - P) / T    ~= 0.008   (annual hazard rate)
#   theta = 1 - exp(-10 * r)  ~= 7.7%    (expected 10-year risk)
# Rescaling factor:
#   factor = cloglog(4%) - cloglog(7.7%) ~= -0.67
# Each predicted risk is then shifted as
#   risk_recal = 1 - exp(-exp(cloglog(risk) - factor))
# e.g. an original risk of 10% becomes about 18.7%.

DERIVATION_10Y_RISK = 0.04
TARGET_10Y_RISK = 0.077  # theta above, rounded to 7.7%


def cloglog(p):
    """Complementary log-log transform ln(-ln(1 - p))."""
    return np.log(-np.log(1 - p))


def recalibrate_risk(risk, factor):
    """Shift `risk` by `-factor` on the cloglog scale and map it back to a probability."""
    return 1 - np.exp(-np.exp(cloglog(risk) - factor))


rescale_factor = cloglog(DERIVATION_10Y_RISK) - cloglog(TARGET_10Y_RISK)
print(rescale_factor)  # approximately -0.67

# Recalibrated risks for every model, stored as <risk column>_recal.
for risk_col in ('cox_risk', 'surv_risk', 'hit_risk'):
    results_test[f'{risk_col}_recal'] = recalibrate_risk(results_test[risk_col], rescale_factor)


-0.6743799332420188


In [48]:
# Original and recalibrated risks side by side for the first records.
results_test.head(n=10)

Unnamed: 0,lp_men,lp_women,cox_risk,surv_risk,hit_risk,sex,age,sbp,tc,hdlc,base_dm,smoke_1,smoke_2,time,outcome,index,cox_risk_recal,surv_risk_recal,hit_risk_recal
0,-1.954437,-1.760865,0.002644,0.004268,0.007554,0,49.156605,143,2.72,2.62,0,False,False,101.70685,0,0,0.005182,0.008359,0.014773
1,-0.14361,-0.021335,0.014962,0.013747,0.012127,0,50.507871,173,5.25,1.69,0,False,False,92.41096,0,1,0.029155,0.026805,0.023665
2,0.327414,0.59358,0.027495,0.028957,0.02189,0,61.969747,144,3.95,1.31,0,False,False,44.860274,0,2,0.053254,0.056044,0.042512
3,1.571877,1.661113,0.077882,0.10541,0.103759,0,64.835729,127,6.11,0.61,1,False,False,51.838356,1,3,0.14713,0.196387,0.193473
4,-0.4793,-0.061498,0.014377,0.013631,0.01071,0,47.194935,151,3.19,0.55,0,False,False,89.14247,0,4,0.028024,0.026579,0.020913
5,1.027182,1.122829,0.046229,0.041152,0.040644,0,53.902806,158,4.47,1.0,1,False,False,116.90959,0,5,0.088718,0.079172,0.078214
6,2.342578,3.648041,0.265831,0.190275,0.149587,1,77.894182,189,2.86,1.75,0,False,True,95.2274,0,6,0.454766,0.339179,0.272427
7,0.291005,0.480742,0.024598,0.045267,0.019716,0,59.983299,129,3.21,2.0,1,False,False,115.2274,0,7,0.047709,0.086914,0.038331
8,0.714729,0.960322,0.039434,0.031787,0.042923,0,57.44846,136,5.69,1.55,0,False,True,85.34247,0,8,0.075931,0.061437,0.082508
9,-0.493822,-0.064127,0.017956,0.042008,0.019225,1,55.44011,120,2.7,2.11,0,False,True,86.580821,0,9,0.03494,0.080785,0.037386


In [None]:
# Calibration of the Cox model before and after recalibration. Records are split
# into quantile groups of predicted risk, and each group's mean predicted risk
# is compared with its Kaplan-Meier observed risk at the horizon.
N_RISK_GROUPS = 5
HORIZON_MONTHS = 120

# The recalibration is a monotone increasing transform of cox_risk, so grouping
# on cox_risk_recal gives the same groups as grouping on cox_risk; both
# predicted-risk columns can therefore be plotted against one set of observed
# risks. duplicates='drop' avoids a crash when quantile edges coincide.
results_test['risk_group'] = pd.qcut(
    results_test['cox_risk_recal'], N_RISK_GROUPS, labels=False, duplicates='drop'
)
by_group = results_test.groupby('risk_group')

kmf = KaplanMeierFitter()
observed_risks = pd.Series({
    group: 1 - kmf.fit(group_data['time'], event_observed=group_data['outcome'])
                  .survival_function_at_times(HORIZON_MONTHS).iloc[0]
    for group, group_data in by_group
})
mean_predicted_risks = by_group[['cox_risk', 'cox_risk_recal']].mean()

# Axis range sized to the data instead of a fixed 0.6.
upper = 1.05 * max(mean_predicted_risks.max().max(), observed_risks.max())

fig, ax = plt.subplots(figsize=(10, 6))
ax.plot([0, upper], [0, upper], 'k--', label='Ideal')
ax.scatter(mean_predicted_risks['cox_risk'], observed_risks, alpha=0.5, label='Cox model (original)')
ax.scatter(mean_predicted_risks['cox_risk_recal'], observed_risks, alpha=0.8, marker='s',
           label='Cox model (recalibrated)')
ax.set(xlabel='Predicted risk', ylabel='Observed risk (Kaplan-Meier)', title='Calibration plot')
ax.legend()
plt.show()