In [1]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.preprocessing import StandardScaler
from sklearn.model_selection import train_test_split
from sklearn_pandas import DataFrameMapper
from sklearn.metrics import mean_squared_error, roc_auc_score
from scipy.integrate import trapz
import torch
import torchtuples as tt
from pycox.models import CoxCC
from pycox.evaluation import EvalSurv
from survival_evaluation import d_calibration, l1
import statistics
from pycox.models import LogisticHazard, PMF, DeepHitSingle, CoxPH, MTLR
from sksurv.metrics import cumulative_dynamic_auc

In [2]:
path = './datacsv/battery_signature_features.csv'
D = pd.read_csv(path)

x_cols = D.columns[-15:].tolist() 
battery_battery_name = ['sheet_cycle']
event_col = ['event']  
time_col = ['time']   

D = D[x_cols + event_col + time_col]

d = D.copy()
cols_standardize = d.columns.values.tolist()
n_exp = 5

if len(cols_standardize) > 4: 
    cols_standardize.pop(1)
    cols_standardize.pop(4)

CI = []
IBS = []
L1_hinge = []
AUC_scores = []

In [3]:
def train_val_test_stratified_split(df, stratify_colname='y', frac_train=0.6, frac_val=0.15, frac_test=0.25, random_state=None):
    if frac_train + frac_val + frac_test != 1.0:
        raise ValueError('fractions %f, %f, %f do not add up to 1.0' % (frac_train, frac_val, frac_test))
    if stratify_colname not in df.columns:
        raise ValueError('%s is not a column in the dataframe' % (stratify_colname))
    X = df # Contains all columns.
    y = df[[stratify_colname]] # Dataframe of just the column on which to stratify.
    df_train, df_temp, y_train, y_temp = train_test_split(X, y, stratify=y, test_size=(1.0 - frac_train), random_state=random_state)
    relative_frac_test = frac_test / (frac_val + frac_test)
    if relative_frac_test == 1.0:
        df_val, df_test, y_val, y_test = [], df_temp, [], y_temp
    else:
        df_val, df_test, y_val, y_test = train_test_split(df_temp, y_temp, stratify=y_temp, test_size=relative_frac_test, random_state=random_state)
    assert len(df) == len(df_train) + len(df_val) + len(df_test)
    return df_train, df_val, df_test

In [4]:
CI = []
IBS = []
L1_hinge = []
L1_margin = []
for i in range(n_exp):
    df_train, df_val, df_test = train_val_test_stratified_split(d, 'event', frac_train=0.8, frac_val=0.05, frac_test=0.15, random_state=10)
    standardize = [([col], StandardScaler()) for col in cols_standardize]
    leave = []
    x_mapper = DataFrameMapper(standardize + leave)
    x_train = x_mapper.fit_transform(df_train).astype('float32')
    x_val = x_mapper.transform(df_val).astype('float32')
    x_test = x_mapper.transform(df_test).astype('float32')

    get_target = lambda df: (df['time'].values, df['event'].values)
    num_durations = 10
    labtrans = MTLR.label_transform(num_durations)
    y_train = labtrans.fit_transform(*get_target(df_train))
    y_val = labtrans.transform(*get_target(df_val))

    #train = (x_train, y_train)
    val = (x_val, y_val)
    durations_test, events_test = get_target(df_test)

    in_features = x_train.shape[1]
    out_features = labtrans.out_features
    num_nodes = [32, 32]
    batch_norm = True
    dropout = 0.1

    net = tt.practical.MLPVanilla(in_features, num_nodes, out_features, batch_norm, dropout)


    model = MTLR(net,tt.optim.Adam, duration_index=labtrans.cuts)

    batch_size = 256
    lr_finder = model.lr_finder(x_train, y_train, batch_size, tolerance=6)
    #_ = lr_finder.plot()
    lr_finder.get_best_lr()
    model.optimizer.set_lr(0.01)

    epochs = 300
    callbacks = [tt.cb.EarlyStopping()]
    model.fit(x_train, y_train, batch_size, epochs, callbacks, val_data=val, verbose=0)

    surv = model.interpolate(10).predict_surv_df(x_test)
    print(surv)

    surv_df = pd.DataFrame(surv)

    surv_df.index.name = 'time'
    surv_df.columns.name = 'survival_function'

    surv_df.to_csv("./datacsv/test-data/discharge_MTLR.csv", index=True)

    #surv.iloc[:, :5].plot(drawstyle='steps-post')
    #plt.ylabel('S(t | x)')
    #_ = plt.xlabel('Time')

    surv = model.interpolate(10).predict_surv_df(x_test)
    survival_predictions = pd.Series(trapz(surv.values.T, surv.index), index=df_test.index)
    l1_hinge = l1(df_test.time, df_test.event, survival_predictions, l1_type = 'hinge')
    l1_margin = l1(df_test.time, df_test.event, survival_predictions, df_train.time, df_train.event, l1_type = 'margin')
    #surv.iloc[:, :5].plot(drawstyle='steps-post')
    #plt.ylabel('S(t | x)')
    #_ = plt.xlabel('Time')

    ev = EvalSurv(surv, durations_test, events_test, censor_surv='km')
    c_index = ev.concordance_td('antolini')
    time_grid = np.linspace(durations_test.min(), durations_test.max(), 100)
    #ev.brier_score(time_grid).plot()
    brier = ev.integrated_brier_score(time_grid)
    #plt.ylabel('Brier score')
    #_ = plt.xlabel('Time')

    quantiles = np.sort(df_test['time'].unique())

    labels_train = np.array([(e, t) for e, t in zip(df_train['event'], df_train['time'])], dtype=[('event', 'bool'), ('time', 'float')])
    labels_test = np.array([(e, t) for e, t in zip(df_test['event'], df_test['time'])], dtype=[('event', 'bool'), ('time', 'float')])

    time_grid_train = np.unique(df_train['time'])
    
    auc_scores = []
    for eval_time in quantiles: 
        try:
            interp_time_index = np.argmin(np.abs(eval_time - surv.index.values))
            surv_values_at_eval_time = surv.iloc[interp_time_index].values
            estimated_risks = 1 - surv_values_at_eval_time


            if np.min(estimated_risks) == np.max(estimated_risks):
                continue  

            auc = cumulative_dynamic_auc(labels_train, labels_test, estimated_risks, times=[eval_time])[0][0]
            if not np.isnan(auc) and not np.isinf(auc):
                auc_scores.append(auc)
        except Exception as e:
            print(f"AUC calculation failed: {e}, eval_time={eval_time}")


    AUC_scores.append(np.mean(auc_scores) if auc_scores else 0.5) 

    CI.append(c_index)
    IBS.append(brier)
    L1_hinge.append(l1_hinge)
    L1_margin.append(l1_margin)

def safe_stat(data):
    return round(statistics.mean(data), 3), round(statistics.stdev(data), 3) if len(data) > 1 else (0.0, 0.0)
auc_mean, auc_std = safe_stat(AUC_scores) 

print('CI:', round(statistics.mean(CI), 3), round(statistics.stdev(CI), 3))
print('IBS:', round(statistics.mean(IBS), 3), round(statistics.stdev(IBS), 3))
print('L1_hinge:', round(statistics.mean(L1_hinge), 3), round(statistics.stdev(L1_hinge), 3))
print('L1_margin:', round(statistics.mean(L1_margin), 3), round(statistics.stdev(L1_margin), 3))
print(f'AUC: {auc_mean} ± {auc_std}')

print('d_calibration_p_value:', round(d_calibration(df_test.event, surv.iloc[6])['p_value'], 3))
print('d_calibration_bin_proportions:')
for i in d_calibration(df_test.event, surv.iloc[6])['bin_proportions']:
    print(i)
print('d_calibration_censored_contributions:')
for i in d_calibration(df_test.event, surv.iloc[6])['censored_contributions']:
    print(i)
print('d_calibration_uncensored_contributions:')
for i in d_calibration(df_test.event, surv.iloc[6])['uncensored_contributions']:
    print(i)

                 0         1         2         3         4         5   \
0.0000     0.999973  0.999963  0.999989  0.999998  0.999996  0.999983   
41.0026    0.999972  0.999957  0.999987  0.999998  0.999996  0.999980   
82.0052    0.999970  0.999952  0.999986  0.999998  0.999995  0.999977   
123.0078   0.999968  0.999946  0.999985  0.999998  0.999994  0.999974   
164.0104   0.999967  0.999941  0.999983  0.999998  0.999994  0.999971   
...             ...       ...       ...       ...       ...       ...   
3526.2236  0.000822  0.384707  0.000006  0.400387  0.000002  0.397833   
3567.2262  0.000619  0.289633  0.000006  0.300522  0.000002  0.299150   
3608.2288  0.000417  0.194559  0.000006  0.200656  0.000002  0.200466   
3649.2314  0.000215  0.099485  0.000006  0.100791  0.000002  0.101783   
3690.2340  0.000013  0.004411  0.000006  0.000926  0.000002  0.003100   

                 6         7         8         9   ...        86        87  \
0.0000     0.999966  0.999838  0.999954  0.99

  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
                 0         1         2         3         4         5   \
0.0000     0.999995  0.999957  0.999978  0.999998  0.999974  0.999983   
41.0026    0.999995  0.999954  0.999976  0.999998  0.999972  0.999981   
82.0052    0.999994  0.999951  0.999974  0.999998  0.999971  0.999980   
123.0078   0.999994  0.999948  0.999972  0.999998  0.999969  0.999979   
164.0104   0.999994  0.999945  0.999971  0.999998  0.999967  0.999977   
...             ...       ...       ...       ...       ...       ...   
3526.2236  0.001611  0.354162  0.000303  0.383504  0.000275  0.397646   
3567.2262  0.001211  0.265972  0.000302  0.287777  0.000275  0.298893   
3608.2288  0.000812  0.177782  0.000301  0.192050  0.000274  0.200141   
3649.2314  0.000412  0.089593  0.000301  0.096323  0.000274  0.101388   
3690.2340  0.000012  0.001403  0.000300  0.000596  0.000273  0.002636   

    

  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
                 0         1         2         3         4         5   \
0.0000     0.999990  0.999932  0.999979  1.000000  0.999977  0.999973   
41.0026    0.999989  0.999927  0.999977  0.999999  0.999974  0.999971   
82.0052    0.999988  0.999923  0.999975  0.999999  0.999972  0.999969   
123.0078   0.999988  0.999918  0.999973  0.999999  0.999970  0.999968   
164.0104   0.999987  0.999913  0.999971  0.999999  0.999967  0.999966   
...             ...       ...       ...       ...       ...       ...   
3526.2236  0.000549  0.383673  0.000070  0.400554  0.000067  0.398943   
3567.2262  0.000412  0.288128  0.000069  0.300660  0.000067  0.299585   
3608.2288  0.000275  0.192583  0.000069  0.200765  0.000067  0.200226   
3649.2314  0.000138  0.097038  0.000069  0.100870  0.000067  0.100868   
3690.2340  0.000001  0.001492  0.000068  0.000975  0.000067  0.001510   

    

  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
                 0         1         2         3         4         5   \
0.0000     0.999981  0.999990  0.999993  0.999972  0.999980  0.999991   
41.0026    0.999979  0.999989  0.999992  0.999968  0.999977  0.999989   
82.0052    0.999977  0.999988  0.999991  0.999965  0.999974  0.999988   
123.0078   0.999975  0.999986  0.999990  0.999962  0.999971  0.999987   
164.0104   0.999973  0.999985  0.999989  0.999959  0.999968  0.999986   
...             ...       ...       ...       ...       ...       ...   
3526.2236  0.001311  0.388396  0.000120  0.379285  0.000278  0.394837   
3567.2262  0.000985  0.291891  0.000117  0.284845  0.000273  0.296898   
3608.2288  0.000658  0.195386  0.000114  0.190405  0.000269  0.198959   
3649.2314  0.000332  0.098881  0.000111  0.095964  0.000265  0.101020   
3690.2340  0.000005  0.002375  0.000108  0.001524  0.000261  0.003081   

    

  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
                 0         1         2         3         4         5   \
0.0000     0.999998  0.999997  0.999993  0.999999  0.999987  0.999997   
41.0026    0.999997  0.999997  0.999992  0.999999  0.999985  0.999997   
82.0052    0.999997  0.999996  0.999991  0.999999  0.999984  0.999997   
123.0078   0.999997  0.999996  0.999991  0.999999  0.999982  0.999996   
164.0104   0.999997  0.999996  0.999990  0.999998  0.999980  0.999996   
...             ...       ...       ...       ...       ...       ...   
3526.2236  0.001309  0.398266  0.000140  0.400212  0.000330  0.399950   
3567.2262  0.000983  0.299129  0.000140  0.300260  0.000329  0.300340   
3608.2288  0.000657  0.199992  0.000139  0.200307  0.000328  0.200730   
3649.2314  0.000331  0.100855  0.000138  0.100355  0.000327  0.101120   
3690.2340  0.000005  0.001718  0.000138  0.000402  0.000327  0.001510   

    

  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]
  true_pos = cumsum_tp / cumsum_tp[-1]


AUC calculation failed: all times must be within follow-up time of test data: [2211.953; 3489.094[, eval_time=3489.094
CI: 0.955 0.013
IBS: 0.048 0.0
L1_hinge: 70.608 2.08
L1_margin: 88.067 2.593
AUC: 0.982 ± 0.009
d_calibration_p_value: 0.0
d_calibration_bin_proportions:
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.8218747776001692
d_calibration_censored_contributions:
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.01979169249534607
0.019791444266835843
d_calibration_uncensored_contributions:
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.0
0.8020833333333334
