In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch as ch
import xgboost as xgb
from sklearn.metrics import accuracy_score, average_precision_score, matthews_corrcoef, roc_auc_score
from tqdm import tqdm

from prediction.short_term_outcome_prediction.timeseries_decomposition import prepare_aggregate_dataset
from prediction.utils.scoring import precision, recall, specificity
from prediction.utils.utils import aggregate_features_over_time

In [None]:
# Locations of the three artefacts this notebook reads: the saved data splits,
# the CSV describing the selected XGBoost configuration, and the serialized model.
# NOTE(review): these are absolute, machine-specific paths.
preprocessing_dir = '/Users/jk1/temp/opsum_end/preprocessing/gsu_Extraction_20220815_prepro_08062024_083500'
checkpoint_dir = '/Users/jk1/Downloads/checkpoints_short_opsum_xgb_20240925_161559'

data_path = (
    f'{preprocessing_dir}/early_neurological_deterioration_train_data_splits/'
    'train_data_splits_early_neurological_deterioration_ts0.8_rs42_ns5.pth'
)
config_path = f'{checkpoint_dir}/xgb_best_model.csv'
model_path = f'{checkpoint_dir}/xgb_20240925_161559_cv_4.model'

In [None]:
# Load the configuration table; later cells read row 0 of its 'CV',
# 'learning_rate', 'max_depth', 'n_estimators', 'reg_lambda' and 'alpha' columns.
best_config = pd.read_csv(config_path)

In [None]:
# Show which cross-validation fold the best configuration refers to.
best_config.loc[0, 'CV']

In [None]:
# Load the saved data splits (indexed by fold further below).
# NOTE(review): torch.load unpickles the file — only use it on trusted files.
splits = ch.load(data_path)

In [None]:
# Pick the fold named in the best configuration and pass it through the
# aggregation helper with the settings used for this evaluation.
selected_fold = splits[best_config['CV'][0]]
aggregation_options = dict(rescale=True, target_time_to_outcome=6, mask_after_first_positive=True)
preprocessed_split = prepare_aggregate_dataset(selected_fold, **aggregation_options)

In [None]:
# Unpack the prepared split into its four parts (train/validation inputs and targets).
full_X_train, full_X_val, y_train, y_val = preprocessed_split

In [None]:
# Build the classifier with the hyper-parameters stored in row 0 of the
# best-configuration table; the fitted model is loaded into it in the next cell.
xgb_hyperparameters = {
    name: best_config[name][0]
    for name in ('learning_rate', 'max_depth', 'n_estimators', 'reg_lambda', 'alpha')
}
trained_xgb = xgb.XGBClassifier(**xgb_hyperparameters)


In [None]:
# Restore the saved model from model_path into the classifier created above.
trained_xgb.load_model(model_path)

### Evaluation as in the grid search

In [None]:
def _evaluate_split(model, X, y, threshold=0.5):
    """Score `model` on one data split.

    Parameters
    ----------
    model : fitted classifier exposing `predict_proba`.
    X : inputs passed to `model.predict_proba`.
    y : ground-truth labels for `X`.
    threshold : probabilities strictly above this value are predicted positive.

    Returns
    -------
    dict with the positive-class probabilities ('y_prob'), the thresholded
    predictions ('y_pred') and the metrics 'accuracy', 'precision',
    'sensitivity', 'auc', 'mcc' and 'specificity'.
    """
    y_prob = model.predict_proba(X)[:, 1].astype('float32')
    y_pred = np.where(y_prob > threshold, 1, 0).astype('float32')
    return {
        'y_prob': y_prob,
        'y_pred': y_pred,
        'accuracy': accuracy_score(y, y_pred),
        # precision is deliberately fed the predictions cast with astype(float),
        # exactly as in the original evaluation
        'precision': precision(y, y_pred.astype(float)).numpy(),
        'sensitivity': recall(y, y_pred).numpy(),
        'auc': roc_auc_score(y, y_prob),
        'mcc': matthews_corrcoef(y, y_pred),
        'specificity': specificity(y, y_pred).numpy(),
    }


# Validation split
val_eval = _evaluate_split(trained_xgb, full_X_val, y_val)
model_y_val = val_eval['y_prob']
model_y_pred_val = val_eval['y_pred']
model_acc_val = val_eval['accuracy']
model_precision_val = val_eval['precision']
model_sn_val = val_eval['sensitivity']
model_auc_val = val_eval['auc']
model_mcc_val = val_eval['mcc']
model_sp_val = val_eval['specificity']

# Training split
train_eval = _evaluate_split(trained_xgb, full_X_train, y_train)
model_y_train = train_eval['y_prob']
model_y_pred_train = train_eval['y_pred']
model_acc_train = train_eval['accuracy']
model_precision_train = train_eval['precision']
model_sn_train = train_eval['sensitivity']
model_auc_train = train_eval['auc']
model_mcc_train = train_eval['mcc']
model_sp_train = train_eval['specificity']

In [None]:
# Collect the metrics in a table: one row per metric, one column per split.
metric_rows = {
    'accuracy': (model_acc_train, model_acc_val),
    'precision': (model_precision_train, model_precision_val),
    'sensitivity': (model_sn_train, model_sn_val),
    'specificity': (model_sp_train, model_sp_val),
    'auc': (model_auc_train, model_auc_val),
    'mcc': (model_mcc_train, model_mcc_val),
}
results = pd.DataFrame(
    {
        'train': [train_value for train_value, _ in metric_rows.values()],
        'val': [val_value for _, val_value in metric_rows.values()],
    },
    index=list(metric_rows),
)

In [None]:
# Display the train/validation metrics table.
results

### More realistic evaluation

In [None]:
# Unpack the same fold again, this time without passing it through
# prepare_aggregate_dataset.
raw_X_train, raw_X_val, raw_y_train, raw_y_val = splits[best_config['CV'][0]]

In [None]:
# Inspect the dimensions of the raw training array.
raw_X_train.shape

In [None]:
# Number of time steps evaluated below: bound of the prefix-prediction loop and
# length of the label series for patients without an event.
# NOTE(review): presumably equals raw_X_val.shape[1] — confirm.
n_time_steps = 72

In [None]:
# Lead time, in time steps: the label series built below places its single
# positive label this many steps before the recorded event time.
eval_n_time_steps_before_event = 6

In [None]:
# Re-run the model on growing prefixes of every validation series: at step `ts`
# only the first `ts + 1` time steps are visible, and the prediction for the
# most recent visible step is stored.
# NOTE(review): the grid-search-style evaluation above prepares its inputs with
# rescale=True; no rescaling is applied on this path — confirm the features are
# on the scale the model expects.
pred_over_ts = []
n_patients = raw_X_val.shape[0]
for ts in tqdm(range(n_time_steps)):
    modified_time_steps = ts + 1

    # visible prefix of the series, last entry of the final axis
    x_data = raw_X_val[:, :modified_time_steps, :, -1].astype('float32')
    # aggregate features
    x_data, _ = aggregate_features_over_time(x_data, np.array([None]), moving_average=False)

    y_pred = trained_xgb.predict_proba(x_data)[:, 1].astype('float32')

    # arrange as (n_patients, visible time steps) and keep the last column only
    y_pred = y_pred.reshape(n_patients, -1)[:, -1]

    pred_over_ts.append(np.squeeze(y_pred))

In [None]:
# Assemble the per-step predictions into a (n_patients, n_time_steps) matrix.
# Stacking along axis 1 keeps the result 2-D even with a single patient or a
# single time step, where squeeze-and-transpose would silently drop an axis and
# break the `pred_over_ts_np[pts_idx, ts]` indexing used further below.
pred_over_ts_np = np.stack([np.atleast_1d(step_pred) for step_pred in pred_over_ts], axis=1)

In [None]:
# Check the dimensions of the prediction matrix.
pred_over_ts_np.shape

In [None]:
# Construct one label series per validation patient (same order as raw_X_val):
#   - no recorded event: n_time_steps zeros
#   - event earlier than the detection window allows: empty series (patient ignored)
#   - otherwise: zeros, then a single 1 placed eval_n_time_steps_before_event
#     steps before the event, after which the series stops
y_val_list = []
for cid in raw_X_val[:, 0, 0, 0]:
    event_ts_values = raw_y_val[raw_y_val.case_admission_id == cid].relative_sample_date_hourly_cat.values
    if len(event_ts_values) == 0:
        cid_y = np.zeros(n_time_steps)
    else:
        # Reduce to a scalar before comparing/converting: using a 1-element array
        # in `if` / `int()` is deprecated in recent NumPy, and several rows for one
        # patient would raise. The earliest event is used.
        cid_event_ts = event_ts_values.min()
        if cid_event_ts < (eval_n_time_steps_before_event + 1):
            # if the event occurs before a detection window, ignore the patient
            cid_y = np.array([])
        else:
            # let y be 0s until 6 hours before the event then stop the series
            cid_y = np.zeros(int(cid_event_ts) - eval_n_time_steps_before_event)
            cid_y[-1] = 1

    y_val_list.append(cid_y)

In [None]:
# compute roc scores for each time step
roc_scores = []
for ts in range(n_time_steps):
    # patients whose label series is still running at this time step
    pts_idx = [i for i, y in enumerate(y_val_list) if len(y) > ts]
    y_true = np.array([y_val_list[i][ts] for i in pts_idx])
    y_pred = pred_over_ts_np[pts_idx, ts]
    # ROC AUC needs both classes; `< 2` also covers a time step with no patients left
    if len(np.unique(y_true)) < 2:
        roc_scores.append(np.nan)
    else:
        roc_scores.append(roc_auc_score(y_true, y_pred))



In [None]:
# Median ROC AUC over the time steps, skipping the NaN entries.
np.nanmedian(roc_scores)


In [None]:
# plot roc scores over time
fig, ax = plt.subplots()
sns.scatterplot(x=range(1, n_time_steps + 1), y=roc_scores, ax=ax)
ax.set_title('ROC AUC over time')
ax.set_xlabel('Time step')
ax.set_ylabel('ROC AUC')
plt.show()

In [None]:
# compute auprc scores for each time step
auprc_scores = []
for ts in range(n_time_steps):
    # patients whose label series is still running at this time step
    pts_idx = [i for i, y in enumerate(y_val_list) if len(y) > ts]
    y_true = np.array([y_val_list[i][ts] for i in pts_idx])
    y_pred = pred_over_ts_np[pts_idx, ts]
    # the score needs both classes; `< 2` also covers a time step with no patients left
    if len(np.unique(y_true)) < 2:
        auprc_scores.append(np.nan)
    else:
        auprc_scores.append(average_precision_score(y_true, y_pred))
np.nanmedian(auprc_scores)

In [None]:
# compute MCC scores for each time step (predictions thresholded at 0.5)
mcc_scores = []
for ts in range(n_time_steps):
    # patients whose label series is still running at this time step
    pts_idx = [i for i, y in enumerate(y_val_list) if len(y) > ts]
    y_true = np.array([y_val_list[i][ts] for i in pts_idx])
    y_pred = pred_over_ts_np[pts_idx, ts]
    # skip time steps without both classes; `< 2` also covers an empty cohort
    if len(np.unique(y_true)) < 2:
        mcc_scores.append(np.nan)
    else:
        mcc_scores.append(matthews_corrcoef(y_true, np.where(y_pred > 0.5, 1, 0)))

In [None]:
# Median MCC over the time steps, skipping the NaN entries.
np.nanmedian(mcc_scores)

### Plot the prediction over time for each patient

In [None]:
# for every patient plot the prediction over time
time_axis = range(1, pred_over_ts_np.shape[1] + 1)
for i in range(pred_over_ts_np.shape[0]):
    cid = raw_X_val[i, 0, 0, 0]
    patient_pred = pred_over_ts_np[i, :]

    fig, ax = plt.subplots()
    # colour the points by whether they exceed the 0.5 decision threshold
    sns.scatterplot(x=time_axis, y=patient_pred, hue=patient_pred > 0.5, ax=ax)
    ax.set_ylim(0, 1)
    ax.set_title(f'Prediction over time for patient {cid}')
    ax.set_xlabel('Time step')
    ax.set_ylabel('Predicted probability')

    # axvline takes a scalar x, so draw one red line per recorded event;
    # patients without an event simply get no line
    event_ts_values = raw_y_val[raw_y_val.case_admission_id == cid].relative_sample_date_hourly_cat.values
    for event_ts in event_ts_values:
        ax.axvline(x=event_ts, color='red')

    plt.show()
    # release the figure so one-figure-per-patient does not accumulate in memory
    plt.close(fig)