In [70]:
from dsm import datasets, DeepRecurrentSurvivalMachines

In [71]:
import numpy as np

In [72]:
x, t, e = datasets.load_dataset('PBC', sequential=True)

In [74]:
x, t, e = np.array(x), np.array(t), np.array(e)

In [75]:
times = np.quantile(unrollt(t), [0.25, .5, 0.75])

In [76]:
times

array([2.1054649 , 4.57507392, 7.15967583])

In [77]:
folds = np.array((list(range(4))*1000)[:len(x)])

In [78]:
def unrollx(data):
     return np.vstack([dat for dat in data])

In [83]:
def unrollt(data):
     return np.concatenate([dat for dat in data])

In [84]:
unrollt(t)


array([ 1.0951703 ,  0.56948856, 14.15233819, ...,  2.92136677,
        1.86726536,  1.04588764])

In [100]:
from lifelines import CoxPHFitter

In [101]:
import pandas as pd

In [102]:
def convert_to_data_frame(x, t, e):

  df = pd.DataFrame(data=x, columns=['X' + str(i) for i in range(x.shape[1])])
  df['T'] = pd.DataFrame(data=t.reshape(-1, 1), columns=['T'])
  df['E'] = pd.DataFrame(data=e.reshape(-1, 1), columns=['E'])

  return df

In [117]:
from sksurv.metrics import concordance_index_ipcw, brier_score

In [126]:
cis = []
brs = []

for fold in set(folds):  
    
    x_tr, t_tr, e_tr = x[folds!=fold], t[folds!=fold], e[folds!=fold]
    x_te, t_te, e_te = x[folds==fold], t[folds==fold], e[folds==fold]
    
    x_tr = unrollx(x_tr)
    t_tr = unrollt(t_tr)
    e_tr = unrollt(e_tr)

    x_te = unrollx(x_te)
    t_te = unrollt(t_te)
    e_te = unrollt(e_te)
    
    df_tr = convert_to_data_frame(x_tr, t_tr, e_tr)

    model = CoxPHFitter(penalizer=1e-3).fit(df_tr, duration_col='T', event_col='E')
    
    preds = model.predict_survival_function(x_te, times).T.values
    
    et_tr = np.array([(e_tr[i], t_tr[i]) for i in range(len(e_tr))],
                 dtype=[('e', bool), ('t', int)])
    et_te = np.array([(e_te[i], t_te[i]) for i in range(len(e_te))],
                 dtype=[('e', bool), ('t', int)])
    
    print (preds.shape)
    
    cis_ = []
    for i in range(len(times)):
        cis_.append(concordance_index_ipcw(et_tr, et_te, 1-preds[:,i], times[i])[0])
    cis.append(cis_)
    
    brs.append(brier_score(et_tr, et_te, preds, times )[1])


(478, 3)
(505, 3)
(457, 3)
(505, 3)


In [127]:
cis

[[0.8622349652479977, 0.8496178660113483, 0.8359660027234909],
 [0.9012252214624602, 0.8471243525992452, 0.7944140040816455],
 [0.8467573397188136, 0.8212682235979585, 0.7283492496826471],
 [0.896278572134772, 0.8628305222741215, 0.7505466410210001]]

In [114]:
model.predict_survival_function(x_te, times).T.values

array([[0.60925486, 0.23441728, 0.04069175],
       [0.55294913, 0.17648165, 0.0217467 ],
       [0.74530162, 0.42290859, 0.14965458],
       ...,
       [0.9551488 , 0.87429316, 0.7434154 ],
       [0.74180704, 0.41712966, 0.14517828],
       [0.3992932 , 0.06804004, 0.00265338]])

In [91]:
[risk.y[3] for risk in out_risk]

[0.3126132634036184,
 0.46647488732207254,
 0.5624185126209059,
 0.4059503413491215,
 0.4415347567963127,
 0.2670742533214408,
 0.06359725383083349,
 0.9349522484067815,
 0.9486941636172256,
 0.9522407447171799,
 0.9103188179868816,
 0.8824124615716761,
 0.7801756137064368,
 0.752123603797489,
 0.5288001066166764,
 0.7803207243582894,
 0.46932603202272666,
 0.9289505993097185,
 0.9355269013394406,
 0.9575783226069599,
 0.9381858563916292,
 0.9595487184435972,
 0.9310981298806621,
 0.9326698824044939,
 0.9011583356828196,
 0.9229446679118422,
 0.8835080122599095,
 0.900997315998892,
 0.8925830090019774,
 0.8831209397700821,
 0.8005351813156036,
 0.7216661520824456,
 0.7599841537178933,
 0.00015928860304218303,
 0.7890039749760369,
 0.9350205039206175,
 0.9072527188714308,
 0.9755350533464282,
 0.9857568595759039,
 0.9678514283741088,
 0.9640084861264626,
 0.7937728360920354,
 0.7358192205521833,
 0.8280615261984754,
 0.6155861711521265,
 0.765769361328416,
 0.06586886769950233,
 0.16902