## Train temporal models for mortality risk prediction

## Imports

In [None]:
import os
import random

# import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import plotly.graph_objects as go
import seaborn as sns

# import sys
import torch
import torch.nn as nn
import torch.optim as optim
from drift_detection.baseline_models.temporal.pytorch.optimizer import Optimizer
from drift_detection.baseline_models.temporal.pytorch.utils import (
    get_data,
    get_device,
    get_temporal_model,
    print_metrics_binary,
)
from drift_detection.gemini.utils import prep
from sklearn import metrics

# from cyclops.processors.column_names import EVENT_NAME
from cyclops.utils.file import load_pickle
from use_cases.common.util import get_use_case_params

## Load train/val/test inputs and labels

In [None]:
# Dataset / use-case pair handed to get_use_case_params to pick the constants.
DATASET = "gemini"
USE_CASE = "mortality"

use_case_params = get_use_case_params(DATASET, USE_CASE)
# input() blocks until Enter is pressed, so the run pauses here and shows which
# constants object was loaded before any data is read.
input(f"WARNING: LOADING CONSTANTS FROM {use_case_params}")

In [None]:
# Every pickled split lives under the same path prefix.
vec_prefix = use_case_params.TAB_VEC_COMB

# Load the vectorized inputs (X) and labels (y) for each split.
X_train_vec = load_pickle(f"{vec_prefix}comb_train_X")
y_train_vec = load_pickle(f"{vec_prefix}comb_train_y")
X_val_vec = load_pickle(f"{vec_prefix}comb_val_X")
y_val_vec = load_pickle(f"{vec_prefix}comb_val_y")
X_test_vec = load_pickle(f"{vec_prefix}comb_test_X")
y_test_vec = load_pickle(f"{vec_prefix}comb_test_y")

# Run the raw arrays of every split through prep(), inputs before labels.
X_train, y_train = prep(X_train_vec.data), prep(y_train_vec.data)
X_val, y_val = prep(X_val_vec.data), prep(y_val_vec.data)
X_test, y_test = prep(X_test_vec.data), prep(y_test_vec.data)

In [None]:
# Count how often each label value occurs in every split. Only the test
# split's unique values are kept, and they index the printed table.
train_counts = np.unique(y_train, return_counts=True)[1]
val_counts = np.unique(y_val, return_counts=True)[1]
unique, test_counts = np.unique(y_test, return_counts=True)

counts_by_split = {"Train": train_counts, "Val": val_counts, "Test": test_counts}
print(pd.DataFrame(counts_by_split, index=unique))

In [None]:
# Mini-batch size shared by the training and validation loaders.
batch_size = 64

# Wrap each split's inputs and labels in a dataset object.
train_dataset = get_data(X_train, y_train)
val_dataset = get_data(X_val, y_val)
test_dataset = get_data(X_test, y_test)

# Only the training loader shuffles; the test loader yields one sample at a
# time in dataset order.
train_loader = train_dataset.to_loader(batch_size, shuffle=True)
val_loader = val_dataset.to_loader(batch_size)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1)

## Model and training configuration

In [None]:
# Model and training hyperparameters.
output_dim = 1
batch_size = 64
# Feature and timestep counts are read from axes 2 and 1 of X_train.
input_dim = X_train.shape[2]
timesteps = X_train.shape[1]
hidden_dim = 64
layer_dim = 2
dropout = 0.2
n_epochs = 256
learning_rate = 2e-3
weight_decay = 1e-6
last_timestep_only = False

device = get_device()

# The model consumes the prepared arrays unchanged; the *_inputs aliases are
# the names the later cells use.
X_train_inputs = X_train
X_val_inputs = X_val
X_test_inputs = X_test

train_dataset = get_data(X_train_inputs, y_train)
train_loader = train_dataset.to_loader(batch_size, shuffle=True)

val_dataset = get_data(X_val_inputs, y_val)
val_loader = val_dataset.to_loader(batch_size)

model_params = {
    "device": device,
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

# Name of the data split whose saved checkpoint should be loaded before
# training. Leave as None to start from freshly initialised weights.
split_type = None
model = get_temporal_model("lstm", model_params).to(device)

# NOTE: this changes the process working directory relative to the current
# one, so re-running the cell moves it again.
os.chdir(os.path.join(os.getcwd(), "../../saved_models"))

# Build the checkpoint path only when a split is named: concatenating None
# with a string raises TypeError, which previously made this cell fail.
model_path = (
    os.path.join(os.getcwd(), split_type + "_lstm.pt")
    if split_type is not None
    else None
)
if model_path is not None:
    # map_location lets a checkpoint saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_path, map_location=device))

## Training and validation

In [None]:
# reduction="none" keeps one loss value per element instead of averaging, so
# any reduction is left to Optimizer (presumably so padded -1 labels can be
# masked out — TODO confirm against Optimizer.train).
loss_fn = nn.BCEWithLogitsLoss(reduction="none")
optimizer = optim.Adagrad(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
# Halve the learning rate after every 128 scheduler steps.
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)
# NOTE(review): BCEWithLogitsLoss already applies a sigmoid internally; confirm
# that Optimizer uses `activation` only for predictions, not before the loss.
activation = nn.Sigmoid()
# Bundle model, loss, optimizer, activation and scheduler into the project's
# training helper.
opt = Optimizer(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    activation=activation,
    lr_scheduler=lr_scheduler,
)
# Train for n_epochs using the training and validation loaders.
opt.train(
    train_loader,
    val_loader,
    batch_size=batch_size,
    n_epochs=n_epochs,
    n_features=input_dim,
    timesteps=timesteps,
)
opt.plot_losses()

## Validation metrics

In [None]:
# Evaluate the validation split one sample at a time, in dataset order.
val_evaluate_loader = torch.utils.data.DataLoader(
    val_dataset, shuffle=False, batch_size=1
)
y_val_labels, y_val_pred_values, y_val_pred_labels = opt.evaluate(
    val_evaluate_loader, batch_size=1, n_features=input_dim, timesteps=timesteps
)

# Drop every position whose label is -1 before scoring.
val_keep = y_val_labels != -1
y_val_pred_values = y_val_pred_values[val_keep]
y_val_pred_labels = y_val_pred_labels[val_keep]
y_val_labels = y_val_labels[val_keep]

confusion_matrix = metrics.confusion_matrix(y_val_labels, y_val_pred_labels)
print(confusion_matrix)

pred_metrics = print_metrics_binary(y_val_labels, y_val_pred_values, y_val_pred_labels)
# Unweighted mean of the two per-class precision and recall scores.
prec_sum = pred_metrics["prec0"] + pred_metrics["prec1"]
rec_sum = pred_metrics["rec0"] + pred_metrics["rec1"]
prec = prec_sum / 2
rec = rec_sum / 2
print(f"Precision: {prec}")
print(f"Recall: {rec}")


def plot_pretty_confusion_matrix(
    confusion_matrix, title="Confusion Matrix for Test Set"
):
    """Draw a 2x2 confusion matrix as an annotated red/green heatmap.

    Parameters
    ----------
    confusion_matrix : array-like of shape (2, 2)
        Counts to print in the cells. Rows are shown as actual values and
        columns as predicted values, the layout produced by
        ``sklearn.metrics.confusion_matrix``.
    title : str, optional
        Figure title. Defaults to the previous hard-coded test-set title so
        existing calls are unaffected.

    """
    sns.set(style="white")
    _, ax = plt.subplots(figsize=(9, 6))
    sns.heatmap(
        # The identity matrix only drives the colours (diagonal vs
        # off-diagonal); the numbers shown come from `annot`.
        np.eye(2),
        annot=confusion_matrix,
        fmt="g",
        annot_kws={"size": 50},
        cmap=sns.color_palette(["tomato", "palegreen"], as_cmap=True),
        cbar=False,
        yticklabels=["False", "True"],
        xticklabels=["False", "True"],
        ax=ax,
    )
    ax.xaxis.tick_top()
    ax.xaxis.set_label_position("top")
    ax.tick_params(labelsize=20, length=0)

    ax.set_title(title, size=24, pad=20)
    ax.set_xlabel("Predicted Values", size=20)
    ax.set_ylabel("Actual Values", size=20)

    # The heatmap annotations are created row by row, so ax.texts is ordered
    # (actual F, pred F), (actual F, pred T), (actual T, pred F),
    # (actual T, pred T). Actual-False/predicted-True is a false positive;
    # the two off-diagonal captions used to be swapped.
    additional_texts = [
        "(True Negative)",
        "(False Positive)",
        "(False Negative)",
        "(True Positive)",
    ]
    # zip stops after the four captions, so the texts appended inside the loop
    # are never revisited.
    for text_elt, additional_text in zip(ax.texts, additional_texts):
        ax.text(
            *text_elt.get_position(),
            "\n" + additional_text,
            color=text_elt.get_color(),
            ha="center",
            va="top",
            size=24,
        )
    plt.tight_layout()
    plt.show()


# `confusion_matrix` holds the validation-split counts at this point, so the
# figure is titled accordingly.
plot_pretty_confusion_matrix(
    confusion_matrix, title="Confusion Matrix for Validation Set"
)

## Testing metrics

In [None]:
# Evaluate the test split one sample at a time, in dataset order.
test_dataset = get_data(X_test_inputs, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset, shuffle=False, batch_size=1)
y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(
    test_loader, batch_size=1, n_features=input_dim, timesteps=timesteps
)

# Drop every position whose label is -1 before scoring.
test_keep = y_test_labels != -1
y_pred_values = y_pred_values[test_keep]
y_pred_labels = y_pred_labels[test_keep]
y_test_labels = y_test_labels[test_keep]

confusion_matrix = metrics.confusion_matrix(y_test_labels, y_pred_labels)
print(confusion_matrix)

pred_metrics = print_metrics_binary(y_test_labels, y_pred_values, y_pred_labels)
# Unweighted mean of the two per-class precision and recall scores.
prec_sum = pred_metrics["prec0"] + pred_metrics["prec1"]
rec_sum = pred_metrics["rec0"] + pred_metrics["rec1"]
prec = prec_sum / 2
rec = rec_sum / 2
print(f"Precision: {prec}")
print(f"Recall: {rec}")

## Plot confusion matrix

In [None]:
def plot_confusion_matrix(confusion_matrix, class_names):
    """Show a row-normalized confusion matrix as a Plotly heatmap.

    Parameters
    ----------
    confusion_matrix : numpy.ndarray
        Square matrix of counts; each row is divided by its own sum.
    class_names : list of str
        Tick labels used on both axes.

    """
    # Column vector of per-row totals so the division broadcasts across rows.
    row_totals = confusion_matrix.sum(axis=1)[:, np.newaxis]
    normalized = confusion_matrix.astype("float") / row_totals

    heatmap = go.Heatmap(
        z=normalized,
        x=class_names,
        y=class_names,
        hoverongaps=False,
        colorscale="Greens",
    )
    fig = go.Figure(
        data=heatmap,
        layout={
            "title": "Confusion Matrix",
            "xaxis": {"title": "Predicted value"},
            "yaxis": {"title": "Real value"},
        },
    )
    fig.update_layout(height=512, width=1024)
    fig.show()


risk_class_names = ["low risk of mortality", "high risk of mortality"]
plot_confusion_matrix(confusion_matrix, risk_class_names)

## Compute AUROC across timesteps

In [None]:
# Re-evaluate without flattening so that axis 1 of the outputs indexes timesteps.
y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(
    test_loader, batch_size=1, n_features=input_dim, timesteps=timesteps, flatten=False
)

num_timesteps = y_pred_labels.shape[1]
auroc_timesteps = []
for i in range(num_timesteps):
    labels = y_test_labels[:, i]
    pred_vals = y_pred_values[:, i]
    preds = y_pred_labels[:, i]
    # Score only the positions whose label at this timestep is not -1.
    keep = labels != -1
    pred_vals = pred_vals[keep]
    preds = preds[keep]
    labels = labels[keep]
    pred_metrics = print_metrics_binary(labels, pred_vals, preds, verbose=False)
    auroc_timesteps.append(pred_metrics["auroc"])


# One x position per evaluated timestep, 24 hours apart. This reproduces the
# previous fixed list (24..144) when there are six timesteps and keeps x and y
# the same length otherwise.
prediction_hours = [24 * (step + 1) for step in range(num_timesteps)]
fig = go.Figure(
    # The bars are AUROC scores; the trace used to be named "model confidence".
    data=[go.Bar(x=prediction_hours, y=auroc_timesteps, name="AUROC")]
)

fig.update_xaxes(tickvals=prediction_hours)
# Zoom the y axis to the observed range so differences between bars are visible.
fig.update_yaxes(range=[min(auroc_timesteps) - 0.05, max(auroc_timesteps) + 0.05])

fig.update_layout(
    title="AUROC split by no. of hours after admission",
    autosize=False,
    xaxis_title="No. of hours after admission",
    yaxis_title="AUROC",
)
fig.show()

## WIP: Compute accuracy across lead times

In [None]:
# BASE_DATA_PATH = "/mnt/nfs/project/delirium/drift_exp/risk_of_mortality"

# combined_events = load_dataframe(os.path.join(BASE_DATA_PATH, "combined_events"))
# timestep_end_timestamps = load_dataframe(os.path.join(BASE_DATA_PATH,
# "aggmeta_end_ts"))

# mortality_events = combined_events.loc[combined_events["event_name"] == "death"]

# y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(
#     test_loader, batch_size=1, n_features=input_dim, timesteps=timesteps,
# flatten=False
# )
# train_val_test_ids = load_dataframe(os.path.join(BASE_DATA_PATH,
# "train_val_test_ids"))
# test_ids = train_val_test_ids["test"].dropna()

# num_timesteps = y_pred_labels.shape[1]
# acc_timesteps = []
# for timestep in range(num_timesteps):
#     labels = y_test_labels[:, timestep]
#     pred_vals = y_pred_values[:, timestep]
#     preds = y_pred_labels[:, timestep]

#     is_correct_timestep = []
#     for enc_id in test_ids:
#         timestep_end_timestamp = timestep_end_timestamps.loc[enc_id, timestep]
#         mortality_timestamp = mortality_events.loc[mortality_events["encounter_id"]
#         == enc_id]["discharge_timestamp"]
#         lead_time = mortality_timestamp - timestep_end_timestamp
#         print(timestep_end_timestamp, mortality_timestamp)
#         if (lead_time > pd.to_timedelta(0, unit="h")).all():
#             label_ = labels[test_ids.index(enc_id)]
#             pred_ = preds[test_ids.index(enc_id)]

#             if label_ == 1:
#                 if label_ == pred_:
#                     is_correct_timestep.append(1)
#                 else:
#                     is_correct_timestep.append(0)

#     acc_timesteps.append(sum(is_correct_timestep) / len(is_correct_timestep))

## Visualize model outputs and labels

In [None]:
def plot_risk_mortality(predictions, labels=None):
    """Plot one sample's per-timestep model outputs, optionally with its labels.

    Parameters
    ----------
    predictions : sequence of float
        Model outputs drawn as red bars. The x positions are fixed at
        24, 48, ..., 144 hours, so six values are expected
        (NOTE(review): confirm against the number of timesteps).
    labels : array-like, optional
        Per-timestep labels compared elementwise: 1 is drawn as a red
        "high risk" marker, -1 as a grey "post discharge" marker, and every
        hour gets a green "low risk" marker underneath. When None, no label
        markers are drawn. Previously the None default raised TypeError.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure; the caller is responsible for showing it.

    """
    prediction_hours = list(range(24, 168, 24))
    # Label markers sit on a row just below the bars.
    label_h = -0.2

    traces = []
    if labels is not None:
        is_mortality = labels == 1
        after_discharge = labels == -1
        mortality_hours = [
            prediction_hours[i] for i, v in enumerate(is_mortality) if v
        ]
        discharge_hours = [
            prediction_hours[i] for i, v in enumerate(after_discharge) if v
        ]
        traces.extend(
            [
                # Green markers at every hour; the red and grey traces below
                # are drawn on top of them where they apply.
                go.Scatter(
                    mode="markers",
                    x=prediction_hours,
                    y=[label_h] * len(prediction_hours),
                    line=dict(color="Black"),
                    name="low risk of mortality label",
                    marker=dict(
                        color="Green", size=20, line=dict(color="Black", width=2)
                    ),
                ),
                go.Scatter(
                    mode="markers",
                    x=mortality_hours,
                    y=[label_h] * len(mortality_hours),
                    line=dict(color="Red"),
                    name="high risk of mortality label",
                    marker=dict(
                        color="Red", size=20, line=dict(color="Black", width=2)
                    ),
                ),
                go.Scatter(
                    mode="markers",
                    x=discharge_hours,
                    y=[label_h] * len(discharge_hours),
                    line=dict(color="Grey"),
                    name="post discharge label",
                    marker=dict(
                        color="Grey", size=20, line=dict(color="Black", width=2)
                    ),
                ),
            ]
        )
    traces.append(
        go.Bar(
            x=prediction_hours,
            y=predictions,
            marker_color="Red",
            name="model confidence",
        )
    )

    fig = go.Figure(data=traces)
    fig.update_yaxes(range=[label_h, 1])
    fig.update_xaxes(tickvals=prediction_hours)
    fig.update_xaxes(showline=True, linewidth=2, linecolor="black")

    # Horizontal reference line at 0.5.
    fig.add_hline(y=0.5)

    fig.update_layout(
        title="Model output visualization",
        autosize=False,
        xaxis_title="No. of hours after admission",
        yaxis_title="Model confidence",
    )

    return fig


# NOTE(review): despite its name, this list holds every index of y_test_labels
# (the loop variable v is never tested), so the sample below is drawn from all
# test rows. Confirm whether a mortality-only filter was intended.
mortality_cases = [idx for idx, v in enumerate(y_test_labels)]
# Pick one row at random and plot its model outputs against its labels.
sample_idx = random.choice(mortality_cases)
fig = plot_risk_mortality(
    y_pred_values[sample_idx].squeeze(), y_test_labels[sample_idx]
)
fig.show()

## Journal of some experiments


<table>
    <thead>
        <tr>
            <th>Split</th>
            <th>Model</th>
            <th>AUROC</th>
        </tr>
    </thead>
    <tbody>
        <tr>
            <td rowspan=4>Random</td>
        </tr>
        <tr>
            <td>LSTM</td>
            <td><b>0.8005</b></td>
        </tr>
          <tr style="border-bottom:1px solid black">
            <td colspan="100%"></td>
          </tr>
          <tr> ... </tr>
    </tbody>
</table>