In [3]:
from auton_survival.datasets import load_dataset
from auton_survival.preprocessing import Preprocessor
from auton_survival.models.dcm import DeepCoxMixtures
from sklearn.preprocessing import StandardScaler
import pandas as pd

In [10]:
data = pd.read_csv('flchain.csv')
data.head()

Unnamed: 0.1,Unnamed: 0,age,sex,sample.yr,kappa,lambda,flc.grp,creatinine,mgus,futime,death
0,0,97.0,0.0,1997,5.7,4.86,10,1.7,0.0,85.0,1.0
1,1,92.0,0.0,2000,0.87,0.683,1,0.9,0.0,1281.0,1.0
2,2,94.0,0.0,1997,4.36,3.85,10,1.4,0.0,69.0,1.0
3,3,92.0,0.0,1996,2.42,2.22,9,1.0,0.0,115.0,1.0
4,4,93.0,0.0,1996,1.32,1.69,6,1.1,0.0,1039.0,1.0


In [12]:
data.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 6524 entries, 0 to 6523
Data columns (total 11 columns):
 #   Column      Non-Null Count  Dtype  
---  ------      --------------  -----  
 0   Unnamed: 0  6524 non-null   int64  
 1   age         6524 non-null   float64
 2   sex         6524 non-null   float64
 3   sample.yr   6524 non-null   int64  
 4   kappa       6524 non-null   float64
 5   lambda      6524 non-null   float64
 6   flc.grp     6524 non-null   int64  
 7   creatinine  6524 non-null   float64
 8   mgus        6524 non-null   float64
 9   futime      6524 non-null   float64
 10  death       6524 non-null   float64
dtypes: float64(8), int64(3)
memory usage: 560.8 KB


In [15]:
data.describe()

Unnamed: 0.1,Unnamed: 0,age,sex,sample.yr,kappa,lambda,flc.grp,creatinine,mgus,futime,death
count,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0,6524.0
mean,3261.5,65.057787,0.449418,1996.623237,1.451986,1.728203,5.53786,1.093516,0.014715,3647.502146,0.300736
std,1883.46091,10.682585,0.497473,1.416592,0.936699,1.074378,2.884023,0.416507,0.120418,1458.287949,0.458613
min,0.0,50.0,0.0,1995.0,0.01,0.0433,1.0,0.4,0.0,0.0,0.0
25%,1630.75,56.0,0.0,1996.0,0.96,1.21,3.0,0.9,0.0,2907.5,0.0
50%,3261.5,63.5,0.0,1996.0,1.28,1.52,6.0,1.0,0.0,4303.0,0.0
75%,4892.25,73.0,1.0,1997.0,1.7,1.95,8.0,1.2,0.0,4771.0,1.0
max,6523.0,101.0,1.0,2003.0,20.5,26.6,10.0,10.8,1.0,5166.0,1.0


In [13]:
outcomes = data.copy()
outcomes['event'] =  data['death']
outcomes['time'] = data['futime']
outcomes = outcomes[['event', 'time']]
outcomes.head()

Unnamed: 0,event,time
0,1.0,85.0
1,1.0,1281.0
2,1.0,69.0
3,1.0,115.0
4,1.0,1039.0


In [18]:
cat_feats = ['sex', 'mgus', 'flc.grp']
num_feats = ['age', 'sample.yr', 'kappa', 'lambda', 'creatinine']

features = data[cat_feats + num_feats]
features.head()

Unnamed: 0,sex,mgus,flc.grp,age,sample.yr,kappa,lambda,creatinine
0,0.0,0.0,10,97.0,1997,5.7,4.86,1.7
1,0.0,0.0,1,92.0,2000,0.87,0.683,0.9
2,0.0,0.0,10,94.0,1997,4.36,3.85,1.4
3,0.0,0.0,9,92.0,1996,2.42,2.22,1.0
4,0.0,0.0,6,93.0,1996,1.32,1.69,1.1


In [19]:
features = Preprocessor().fit_transform(features, cat_feats=cat_feats, num_feats=num_feats)
features.head()

Unnamed: 0,age,sample.yr,kappa,lambda,creatinine,sex_1.0,mgus_1.0,flc.grp_2,flc.grp_3,flc.grp_4,flc.grp_5,flc.grp_6,flc.grp_7,flc.grp_8,flc.grp_9,flc.grp_10
0,2.990349,0.265985,4.535439,2.91521,1.456232,0,0,0,0,0,0,0,0,0,0,1
1,2.522262,2.383906,-0.621364,-0.97292,-0.464653,0,0,0,0,0,0,0,0,0,0,0
2,2.709497,0.265985,3.104773,1.975059,0.7359,0,0,0,0,0,0,0,0,0,0,1
3,2.522262,-0.439989,1.033511,0.457785,-0.224542,0,0,0,0,0,0,0,0,0,1,0
4,2.61588,-0.439989,-0.140916,-0.035561,0.015568,0,0,0,0,0,0,1,0,0,0,0


In [21]:
import numpy as np

horizons = [0.25, 0.5, 0.75]
times = np.quantile(outcomes.time[outcomes.event==1], horizons).tolist()
times

[902.25, 2084.0, 3245.0]

In [23]:
x, t, e = features.values, outcomes.time.values, outcomes.event.values

n = len(x)

tr_size = int(n * 0.70)
vl_size = int(n * 0.10)
te_size = int(n * 0.20)

x_train, x_test, x_val = x[:tr_size], x[-te_size:], x[tr_size:tr_size+vl_size]
t_train, t_test, t_val = t[:tr_size], t[-te_size:], t[tr_size:tr_size+vl_size]
e_train, e_test, e_val = e[:tr_size], e[-te_size:], e[tr_size:tr_size+vl_size]

t = outcomes["time"]
e = outcomes["event"]
quantiles = [0.25, 0.5, 0.75]
quantiles = np.quantile(t[e == 1], quantiles)
quantiles

array([ 902.25, 2084.  , 3245.  ])

In [24]:
# Initialize the Deep Cox Mixtures model
model = DeepCoxMixtures(
    k=6,                   # number of latent clusters
    layers=[100],          # hidden layer dimensions
    gamma=1.0,             # regularization term
    smoothing_factor=1e-4, # baseline smoothing
    use_activation=False,  # linear transformation
    random_seed=42
)

In [25]:
model.fit(x_train, t_train, e_train, iters=100, val_data=(x_val, t_val, e_val))

  probs = gates+np.log(event_probs)
  probs = gates+np.log(event_probs)
  return spl(ts)**risks
  s0ts = (-risks)*(spl(ts)**(risks-1))
 40%|████      | 40/100 [00:28<00:42,  1.42it/s]


<auton_survival.models.dcm.DeepCoxMixtures at 0x1c998384670>

In [33]:
out_survival = model.predict_survival(x_test, times)
out_survival

array([[0.97681016, 0.9482042 , 0.91350657],
       [0.9721724 , 0.9368561 , 0.8925423 ],
       [0.9728453 , 0.938365  , 0.8953506 ],
       ...,
       [0.9841021 , 0.96589816, 0.9462062 ],
       [0.976333  , 0.9468227 , 0.9109912 ],
       [0.98237723, 0.9616073 , 0.9382817 ]], dtype=float32)

In [36]:
out_risk = 1 - out_survival
out_risk

array([[0.02318984, 0.05179578, 0.08649343],
       [0.02782762, 0.06314391, 0.1074577 ],
       [0.02715468, 0.06163502, 0.10464942],
       ...,
       [0.01589793, 0.03410184, 0.05379379],
       [0.02366698, 0.0531773 , 0.08900881],
       [0.01762277, 0.03839272, 0.06171829]], dtype=float32)

In [35]:
latent_z = model.predict_latent_z(x_test)
latent_z

array([[0.03062083, 0.85112166, 0.02779981, 0.03823591, 0.02758756,
        0.02463439],
       [0.04905907, 0.7854898 , 0.03905129, 0.05333647, 0.03902204,
        0.03404137],
       [0.0473644 , 0.79333496, 0.03719844, 0.05361987, 0.03794697,
        0.0305353 ],
       ...,
       [0.01002624, 0.95190287, 0.00868921, 0.01303166, 0.00878667,
        0.00756337],
       [0.03833099, 0.84153056, 0.02949509, 0.04103643, 0.02785478,
        0.02175217],
       [0.01731054, 0.92655474, 0.01357036, 0.02018164, 0.01223912,
        0.01014367]], dtype=float32)

In [28]:
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

In [37]:
cis = []
brs = []

et_train = np.array([(e_train[i], t_train[i]) for i in range(len(e_train))],
                 dtype = [('e', bool), ('t', float)])
et_test = np.array([(e_test[i], t_test[i]) for i in range(len(e_test))],
                 dtype = [('e', bool), ('t', float)])
et_val = np.array([(e_val[i], t_val[i]) for i in range(len(e_val))],
                 dtype = [('e', bool), ('t', float)])

for i, _ in enumerate(times):
    cis.append(concordance_index_ipcw(et_train, et_test, out_risk[:, i], times[i])[0])
brs.append(brier_score(et_train, et_test, out_survival, times)[1])
roc_auc = []
for i, _ in enumerate(times):
    roc_auc.append(cumulative_dynamic_auc(et_train, et_test, out_risk[:, i], times[i])[0])
for horizon in enumerate(horizons):
    print(f"For {horizon[1]} quantile,")
    print("TD Concordance Index:", cis[horizon[0]])
    print("Brier Score:", brs[0][horizon[0]])
    print("ROC AUC ", roc_auc[horizon[0]][0], "\n")

For 0.25 quantile,
TD Concordance Index: 0.5791622903677868
Brier Score: 0.021187988636670076
ROC AUC  0.5824401917051168 

For 0.5 quantile,
TD Concordance Index: 0.617581564692493
Brier Score: 0.0344572222517535
ROC AUC  0.6220365808391632 

For 0.75 quantile,
TD Concordance Index: 0.5859263544448308
Brier Score: 0.050610393501503724
ROC AUC  0.5918498225074984 

