In [None]:
import os
import numpy as np
from sklearn.calibration import calibration_curve
import matplotlib.pyplot as plt

In [None]:
# Input locations for the trained weights and the preprocessed data.
# The defaults are the original local paths; each one can be overridden through an environment
# variable so the notebook can run on another machine without editing this cell:
#   OPSUM_MODEL_WEIGHTS_PATH, OPSUM_FEATURES_PATH, OPSUM_LABELS_PATH
# NOTE: the weights file name encodes the model configuration and is parsed further down,
# so an overriding weights path must keep the same '_'-separated naming scheme.
model_weights_path = os.environ.get(
    'OPSUM_MODEL_WEIGHTS_PATH',
    '/Users/jk1/temp/opsum_prediction_output/LSTM_72h_testing/3M_mRS02/2023_01_02_1057/test_LSTM_sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3/sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3.hdf5'
)
features_path = os.environ.get(
    'OPSUM_FEATURES_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_features_01012023_233050.csv'
)
labels_path = os.environ.get(
    'OPSUM_LABELS_PATH',
    '/Users/jk1/temp/opsum_prepro_output/gsu_prepro_01012023_233050/preprocessed_outcomes_01012023_233050.csv'
)

In [None]:
# Outcome label to evaluate. It is passed to the data-formatting helpers when the dataset is loaded,
# and the same label appears as one of the '_'-separated fields of the model weights file name.
outcome = '3M mRS 0-2'

In [None]:
# Recover the training configuration from the weights file name. The name is expected to be a
# '_'-separated list of fields in this order:
#   activation_batch_data_dropout_layers_masking_optimizer_outcome_units_cvfold
# e.g. 'sigmoid_all_balanced_0.2_2_True_RMSprop_3M mRS 0-2_16_3'
model_name = os.path.basename(model_weights_path).split('.hdf5')[0]
name_fields = model_name.split('_')

model_config = {
    'activation': name_fields[0],
    'batch': name_fields[1],
    'data': name_fields[2],
    'dropout': float(name_fields[3]),
    'layers': int(name_fields[4]),
    # Parse the flag into a real boolean: the raw token is a string, and any non-empty string
    # (including 'False') is truthy, so an unparsed 'False' would be treated as enabled by any
    # truthiness check downstream.
    'masking': name_fields[5] == 'True',
    'optimizer': name_fields[6],
    # name_fields[7] is the outcome label; it is not stored in the config
    'units': int(name_fields[8]),
    'cv_fold': int(name_fields[9])
}

# define constants
seed = 42  # random_state used for the train/test split
test_size = 0.20  # fraction of patients held out as the test set

## Initial calibration curve


Load data

In [None]:
from sklearn.model_selection import train_test_split
from prediction.utils.utils import check_data
from prediction.outcome_prediction.data_loading.data_formatting import format_to_2d_table_with_time, \
    link_patient_id_to_outcome, features_to_numpy

# Read the preprocessed features and outcomes for the selected outcome
X, y = format_to_2d_table_with_time(
    feature_df_path=features_path,
    outcome_df_path=labels_path,
    outcome=outcome,
)

# Input dimensions handed to the model builder later on
n_time_steps = X['relative_sample_date_hourly_cat'].max() + 1
n_channels = len(X['sample_label'].unique())

# Sanity check of the loaded feature table
check_data(X)

# SPLITTING DATA
# The split is made on patient id rather than admission id: in the rare case of several
# admissions for one patient, splitting by admission could leak data between TRAIN and TEST.
# Here pid = unique patient_id.

# Collapse to one outcome per patient so that no patient appears twice in the split
all_pids_with_outcome = link_patient_id_to_outcome(y, outcome)
pid_train, pid_test, y_pid_train, y_pid_test = train_test_split(
    all_pids_with_outcome.patient_id.tolist(),
    all_pids_with_outcome.outcome.tolist(),
    stratify=all_pids_with_outcome.outcome.tolist(),
    test_size=test_size,
    random_state=seed,
)

# Keep only the rows belonging to test patients
test_X_df = X[X.patient_id.isin(pid_test)]
test_y_df = y[y.patient_id.isin(pid_test)]

feature_columns = ['case_admission_id', 'relative_sample_date_hourly_cat', 'sample_label', 'value']
test_X_np = features_to_numpy(test_X_df, feature_columns)

# Look up the outcome of every case, following the case order of the feature array
# (index [:, 0, 0, 0] is read as the case_admission_id of each sample)
test_outcomes = []
for cid in test_X_np[:, 0, 0, 0]:
    case_rows = test_y_df[test_y_df.case_admission_id == cid]
    test_outcomes.append(case_rows.outcome.values[0])
test_y_np = np.array(test_outcomes).astype('float32')

# Keep only the last entry of the final axis (presumably the 'value' column, given the column
# order requested above), dropping the id / time-step / label entries
test_X_np = test_X_np[:, :, :, -1].astype('float32')

Load model

In [None]:
from prediction.outcome_prediction.LSTM.LSTM import lstm_generator
from prediction.utils.scoring import precision, recall, matthews

# Rebuild the network with the hyperparameters decoded from the weights file name
architecture = {
    'x_time_shape': n_time_steps,
    'x_channels_shape': n_channels,
    'masking': model_config['masking'],
    'n_units': model_config['units'],
    'activation': model_config['activation'],
    'dropout': model_config['dropout'],
    'n_layers': model_config['layers'],
}
model = lstm_generator(**architecture)

# Compile with the configured optimizer and the project's scoring metrics
model.compile(
    loss='binary_crossentropy',
    optimizer=model_config['optimizer'],
    metrics=['accuracy', precision, recall, matthews],
)

# Restore the trained weights into the freshly built model
model.load_weights(model_weights_path)

Make predictions

In [None]:
# Run the loaded model on the held-out test features to obtain its predictions
# (presumably predicted probabilities of the positive outcome — confirm against the model's output layer)
y_pred_test = model.predict(test_X_np)

In [None]:
from prediction.outcome_prediction.LSTM.calibration.calibration_visualisation_tools import plot_calibration_curve

# Draw the calibration curve of the test-set predictions against the true labels, using 10 bins
plot_calibration_curve(test_y_np, y_pred_test, n_bins=10)

# Title the current axes and render the figure
plt.gca().set_title("Calibration curve")
plt.show()