In [1]:
import h5py as h5
import matplotlib.pyplot as plt
import os
import pandas as pd
import numpy as np
from tqdm import tqdm
import pickle
import sys
sys.path.append(os.path.expanduser("~/TMPredictor/survival_tm/auton-survival"))
from auton_survival.preprocessing import Scaler
import optuna
from sklearn.model_selection import ParameterGrid
#sys.path.append('/projects/EKOLEMEN/survival_tm/train_models/auton-survival')
sys.path.append(os.path.expanduser("~/TMPredictor/survival_tm/auton-survival"))
from sklearn.model_selection import ParameterGrid
from auton_survival.estimators import SurvivalModel
from auton_survival.metrics import survival_regression_metric
from auton_survival.models.dsm import DeepSurvivalMachines
from sksurv.metrics import concordance_index_ipcw, brier_score, cumulative_dynamic_auc

  from .autonotebook import tqdm as notebook_tqdm


In [55]:
with open('data/x_train_future_normed.pkl', 'rb') as f:
    x_train_df = pickle.load(f)
with open('data/x_valid_future_normed.pkl', 'rb') as f:
    x_valid_df = pickle.load(f)
with open('data/outcomes_train_df.pkl', 'rb') as f:
    outcomes_train_df = pickle.load(f)
with open('data/outcomes_valid_df.pkl', 'rb') as f:
    outcomes_valid_df = pickle.load(f)

In [58]:
param_grid = {'k' : [3],
              'iters': [10],
              'distribution' : ['LogNormal'],
              'learning_rate' : [ 1e-3 ],
              'batch_size' : [10000],
              'layers' : [[100, 60, 175, 225, 120]]
             }

params = ParameterGrid(param_grid)
models=[]
for param in params:
    model = SurvivalModel(model='dsm', 
                      iters=param['iters'], 
                      k=param['k'], 
                      layers=param['layers'], 
                      distribution=param['distribution'],
                      learning_rate=param['learning_rate'], 
                      batch_size=param['batch_size']
                    )
    '''model = SurvivalModel(model='dsm', 
                          iters=param['iters'], 
                          k=param['k'], 
                          layers=param['layers'], 
                          distribution=param['distribution'],
                          learning_rate=param['learning_rate'], 
                          batch_size=param['batch_size']
                        )'''
    _, train_loss, val_loss, param = model.fit(x_train_df, outcomes_train_df)

    models.append([model, train_loss, val_loss, param])


 13%|█▎        | 1252/10000 [02:43<19:04,  7.65it/s]  
 60%|██████    | 6/10 [06:37<04:15, 63.98s/it]

In [32]:
with open('new_hyperparams.pkl', 'rb') as file:
    loaded_models = pickle.load(file)

In [39]:
arr = []
for model in loaded_models:
    nan_indices = np.argwhere(np.isnan(model[1]))
    first_nan_index = nan_indices[0,0] if nan_indices.size else None
    arr.append(first_nan_index)
print(arr)

[None, None]


In [None]:
prediction_times = [20, 50, 100, 200]
model = loaded_models[1][0]
out_survival = model.predict_survival(x_train_df, prediction_times)
peaks = find_peaks_in_data(np.array(outcomes_train_df['time']))
for i in range(0, 10):
    peak_number = 500 + i
    start_index = peaks[peak_number]
    end_index = peaks[peak_number + 1]
    times = np.arange(0, (end_index - start_index)*20, 20)
    plt.plot(times, out_survival[start_index:end_index,2], label='Survival in 100ms')
    plt.plot(times, out_survival[start_index:end_index,3], label='Survival in 200ms')

    if (np.array(outcomes_train_df['event'])[start_index]==1):
        plt.title('YES TM')
    else:
        plt.title('NO TM')
    plt.xlabel('Time / ms')
    plt.ylabel('Survival probability')
    plt.ylim(0.5, 1)
    plt.legend()
    plt.show()

In [43]:
# make WTC curve

# one issue with this metric is if we never predict TMs then we're always perfect. 
def fpr_auc(model, normed_x, normed_t, normed_e, prediction_times, threshold=0.7):
    out_survival = model.predict_survival(normed_x, prediction_times)
    fprs = []
    for i, time in enumerate(prediction_times):
        survival_prediction = out_survival[:,i]
        survival_prediction = (survival_prediction < threshold).astype(int)
        false_positives = np.logical_and(survival_prediction == 1, normed_e == 0)
        true_negatives = np.logical_and(survival_prediction == 0, normed_e == 0)
        fpr = false_positives.sum() / (false_positives.sum() + true_negatives.sum())
        fprs.append(fpr)

    auc = np.trapz(fprs, prediction_times)
    return auc, fprs, prediction_times

def find_peaks_in_data(data):
    peaks = []
    for i in range(1, len(data) - 1):
        if data[i-1] < data[i] > data[i+1]:
            peaks.append(i)
    return peaks

def fnr_auc(model, normed_x, normed_t, normed_e, prediction_times, threshold=0.7):
    out_survival = model.predict_survival(normed_x, prediction_times)
    fnrs = []
    shot_indices = find_peaks_in_data(normed_t)
    for i, time in enumerate(prediction_times):
        tm_prediction_per_shot = []
        # 1 means correct TM prediction, 0 means unpredicted TM, -1 means no TM in shot
        # a TM is predicted when the survival prediction is 0 at any point in the shot. Check if better results when TM is consecutive 0s
        for j, shot_index in enumerate(shot_indices):
            survival_prediction = out_survival[:,i]
            survival_prediction = (survival_prediction < threshold).astype(int)
            tm = (0 in survival_prediction[shot_indices[j]:shot_indices[j+1]])
            if normed_e[shot_index] == 1 and tm:
                tm_prediction_per_shot.append(1)
            elif normed_e[shot_index] == 1 and not tm:
                tm_prediction_per_shot.append(0)
            else:
                tm_prediction_per_shot.append(-1)
        fnr = tm_prediction_per_shot.count(0) / (tm_prediction_per_shot.count(1) + tm_prediction_per_shot.count(0))
        fnrs.append(fnr)
    auc = np.trapz(fnrs, prediction_times)
    return auc, fnrs, prediction_times

In [86]:
threshold = 0.7
fpr_auc, fprs, prediction_times = fpr_auc(models[0][3], x_test_df, t_test, e_test, prediction_times, threshold=threshold)
fnr_auc, fnrs, prediction_times = fnr_auc(models[0][3], x_test_df, t_test, e_test, prediction_times, threshold=threshold)

  return survival_predictions.sort_index(axis=0).interpolate().interpolate(method='bfill').T[times].values
  return survival_predictions.sort_index(axis=0).interpolate().interpolate(method='bfill').T[times].values


IndexError: list index out of range