# Time-series modelling

# Imports

In [None]:
import random

import numpy as np
import plotly.graph_objects as go
import torch
import torch.nn as nn
import torch.optim as optim
from drift_detection.baseline_models.temporal.pytorch.optimizer import Optimizer
from drift_detection.baseline_models.temporal.pytorch.utils import (
    get_data,
    get_device,
    get_temporal_model,
)
from sklearn import metrics

from cyclops.models.util import metrics_binary
from cyclops.process.column_names import EVENT_NAME
from cyclops.utils.file import load_pickle
from use_cases.common.util import get_use_case_params

# Choose dataset and use-case

In [None]:
# Dataset and use-case identifiers handed to get_use_case_params below.
DATASET = "mimiciv"
USE_CASE = "mortality_decompensation"

# Parameter object for this dataset/use-case pair; its TAB_VEC_COMB attribute
# is used further down as the path prefix of the pickled data splits.
use_case_params = get_use_case_params(DATASET, USE_CASE)
# input() blocks until the user presses Enter, so execution pauses here and
# shows which parameter object was loaded before anything else runs.
input(f"WARNING: LOADING CONSTANTS FROM {use_case_params}")

# Configuration

In [None]:
# Whether to use the combined data (tabular + temporal)
# or simply the temporal data
use_comb = True

# Batch size given to the training and validation loaders.
batch_size = 64
# Model hyperparameters; these are forwarded to the model through model_params.
output_dim = 1
hidden_dim = 64
layer_dim = 2
dropout = 0.2
# Number of epochs passed to opt.train.
n_epochs = 256
# Settings for the Adagrad optimizer.
learning_rate = 2e-3
weight_decay = 1e-6
# Also forwarded to the model through model_params.
last_timestep_only = False

# Data

In [None]:
def prep(vec):
    """Turn ``vec.data`` into a NaN-free array with axis 2 moved to the front.

    Steps, applied to ``vec.data``:

    1. drop the leading axis, which must have length 1 (``np.squeeze`` raises
       ``ValueError`` otherwise);
    2. move axis 2 of the squeezed array to position 0;
    3. replace NaN with 0.0 and infinities with large finite values
       (``np.nan_to_num`` defaults).

    Parameters
    ----------
    vec
        Any object exposing a ``data`` attribute that NumPy can read as an
        array with at least four dimensions.

    Returns
    -------
    numpy.ndarray
        The reordered, NaN-free copy of the data.
    """
    without_lead = np.squeeze(vec.data, axis=0)
    reordered = np.moveaxis(without_lead, source=2, destination=0)
    return np.nan_to_num(reordered)

In [None]:
# Every split lives under the same TAB_VEC_COMB prefix; only the file-name
# prefix ("comb" vs. "temp") depends on use_comb. Files are loaded in the
# order train X, train y, val X, val y, test X, test y.
(
    X_train_vec,
    y_train_vec,
    X_val_vec,
    y_val_vec,
    X_test_vec,
    y_test_vec,
) = (
    load_pickle(
        use_case_params.TAB_VEC_COMB
        + f"{'comb' if use_comb else 'temp'}_{split}_{target}"
    )
    for split in ("train", "val", "test")
    for target in ("X", "y")
)

# Convert each loaded object's data into a model-ready array via prep.
X_train, y_train, X_val, y_val, X_test, y_test = (
    prep(vectorized.data)
    for vectorized in (
        X_train_vec,
        y_train_vec,
        X_val_vec,
        y_val_vec,
        X_test_vec,
        y_test_vec,
    )
)

In [None]:
# Display the training inputs' index for the EVENT_NAME axis
# (presumably the event/feature names -- confirm against the vectorized API).
X_train_vec.get_index(EVENT_NAME)

In [None]:
# Shapes of the prepared training inputs and labels.
X_train.shape, y_train.shape

In [None]:
# Fraction of training label entries equal to 1; the denominator counts every
# entry, whatever its value.
(y_train == 1).sum() / y_train.size

In [None]:
# Distinct training label values and how many times each occurs.
np.unique(y_train, return_counts=True)

In [None]:
# Shapes of the prepared validation inputs and labels.
X_val.shape, y_val.shape

In [None]:
# Fraction of validation label entries equal to 1; the denominator counts
# every entry, whatever its value.
(y_val == 1).sum() / y_val.size

In [None]:
# Distinct validation label values and how many times each occurs.
np.unique(y_val, return_counts=True)

In [None]:
# Shapes of the prepared test inputs and labels.
X_test.shape, y_test.shape

In [None]:
# Fraction of test label entries equal to 1; the denominator counts every
# entry, whatever its value.
(y_test == 1).sum() / y_test.size

In [None]:
# Distinct test label values and how many times each occurs.
np.unique(y_test, return_counts=True)

In [None]:
# Sanity check: prep() runs np.nan_to_num on every split, so no NaN should
# remain in any of the arrays fed to the model.
assert np.isnan(X_train).sum() == 0
assert np.isnan(y_train).sum() == 0
assert np.isnan(X_val).sum() == 0
assert np.isnan(y_val).sum() == 0
assert np.isnan(X_test).sum() == 0
assert np.isnan(y_test).sum() == 0

In [None]:
# Wrap each split with the project's get_data helper. Training batches are
# shuffled; the validation loader uses to_loader's default ordering.
train_dataset = get_data(X_train, y_train)
train_loader = train_dataset.to_loader(batch_size, shuffle=True)

val_dataset = get_data(X_val, y_val)
val_loader = val_dataset.to_loader(batch_size)

# The test loader is built directly with torch's DataLoader: one sample per
# batch, in dataset order (matches batch_size=1 passed to opt.evaluate later).
test_dataset = get_data(X_test, y_test)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# After prep(), axis 2 of the inputs is taken as the feature count and axis 1
# as the number of timesteps.
n_features = X_train.shape[2]
timesteps = X_train.shape[1]

# Model

In [None]:
# Pick the torch device through the project helper; the bare name displays it.
device = get_device()
device

In [None]:
# Keyword settings handed to get_temporal_model; values come from the
# configuration cell plus the feature count derived from the training data.
model_params = {
    "device": device,
    "input_dim": n_features,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

In [None]:
# Build the "lstm" temporal model from model_params and move it to the
# selected device; the bare name displays the model.
model = get_temporal_model("lstm", model_params).to(device)
model

# Training and validation

In [None]:
# Ratio of 0-labelled to 1-labelled entries in the training labels, shown for
# reference.
# NOTE(review): the Optimizer below is constructed with
# reweight_positive="mini-batch" rather than this value -- confirm that the
# computed ratio is meant to be informational only.
reweight_positive = (y_train == 0).sum() / (y_train == 1).sum()
reweight_positive

In [None]:
# reduction="none" keeps the per-element losses instead of averaging them.
loss_fn = nn.BCEWithLogitsLoss(reduction="none")
optimizer = optim.Adagrad(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
# Halve the learning rate every 128 scheduler steps
# (presumably one step per epoch -- confirm in Optimizer.train).
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)
# The loss takes logits, so a separate sigmoid is supplied to the Optimizer
# (presumably to turn model outputs into probabilities -- confirm).
activation = nn.Sigmoid()
# Project-level training wrapper bundling model, loss, optimizer and scheduler.
opt = Optimizer(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    activation=activation,
    lr_scheduler=lr_scheduler,
    reweight_positive="mini-batch",
)

## Training

In [None]:
# Train for n_epochs using the training and validation loaders, then plot the
# losses recorded by the Optimizer wrapper.
opt.train(
    train_loader,
    val_loader,
    n_epochs=n_epochs,
)
opt.plot_losses()

## Evaluation

In [None]:
# Evaluate on the test loader; the result is unpacked as true labels,
# predicted values and predicted labels.
y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(
    test_loader, batch_size=1, n_features=n_features, timesteps=timesteps
)

# Drop every entry whose true label is -1 (the value plotted later as
# "post discharge label"). y_test_labels must be filtered last because it
# defines the mask applied to the other two arrays.
y_pred_values = y_pred_values[y_test_labels != -1]
y_pred_labels = y_pred_labels[y_test_labels != -1]
y_test_labels = y_test_labels[y_test_labels != -1]

# Raw-count confusion matrix of true labels vs. predicted labels.
confusion_matrix = metrics.confusion_matrix(y_test_labels, y_pred_labels)
print(confusion_matrix)

# Reported precision/recall are the unweighted means of the two per-class
# values returned by metrics_binary.
pred_metrics = metrics_binary(y_test_labels, y_pred_values, y_pred_labels)
prec = (pred_metrics["prec0"] + pred_metrics["prec1"]) / 2
rec = (pred_metrics["rec0"] + pred_metrics["rec1"]) / 2
print(f"Precision: {prec}")
print(f"Recall: {rec}")

## Plot confusion matrix

In [None]:
def plot_confusion_matrix(confusion_matrix, class_names):
    """Display a row-normalized confusion matrix as a Plotly heatmap.

    Each row of ``confusion_matrix`` is divided by its own sum before
    plotting, so every cell holds the share of that row's total. A row that
    sums to zero therefore yields NaN cells.

    Parameters
    ----------
    confusion_matrix
        Two-dimensional NumPy array of counts; rows are plotted on the
        "Real value" axis and columns on the "Predicted value" axis.
    class_names
        Tick labels used for both axes.

    Returns
    -------
    None
        The figure is shown via ``fig.show()`` and not returned.
    """
    # Column vector of row totals so the division broadcasts across each row.
    row_totals = confusion_matrix.sum(axis=1)[:, np.newaxis]
    normalized = confusion_matrix.astype("float") / row_totals

    heatmap = go.Heatmap(
        z=normalized,
        x=class_names,
        y=class_names,
        hoverongaps=False,
        colorscale="Greens",
    )
    fig = go.Figure(
        data=heatmap,
        layout={
            "title": "Confusion Matrix",
            "xaxis": {"title": "Predicted value"},
            "yaxis": {"title": "Real value"},
        },
    )
    fig.update_layout(height=512, width=1024)
    fig.show()


# Plot the test-set confusion matrix computed in the evaluation cell, using
# human-readable names for the two classes.
plot_confusion_matrix(
    confusion_matrix, ["low risk of mortality", "high risk of mortality"]
)

In [None]:
# Re-run evaluation with flatten=False; the results are indexed below as
# [:, sample_idx], so they are expected to keep a per-sample axis at
# position 1. This overwrites the filtered arrays from the evaluation cell.
y_test_labels, y_pred_values, y_pred_labels = opt.evaluate(
    test_loader, batch_size=1, n_features=n_features, timesteps=timesteps, flatten=False
)

## Visualize model outputs and labels

In [None]:
def plot_risk_mortality(predictions, labels=None):
    """Build a figure of model confidence per prediction hour.

    Confidences are drawn as bars at 24, 48, ..., 144 hours. When ``labels``
    is given, a row of markers below the bars shows the label at each hour:
    green at every hour as the base, overlaid with red where the label is 1
    and grey where it is -1.

    Parameters
    ----------
    predictions
        Sequence of model confidences, one per prediction hour.
    labels : optional
        Per-hour labels (1 = high risk of mortality, -1 = post discharge,
        anything else keeps the green "low risk" marker). A NumPy array or a
        plain list/tuple is accepted. When ``None`` (the default) no label
        markers are drawn; previously this default raised ``TypeError``.

    Returns
    -------
    plotly.graph_objects.Figure
        The assembled figure; the caller decides whether to show it.
    """
    prediction_hours = list(range(24, 168, 24))
    label_h = -0.2  # y-position of the label marker row, below the bars

    def hours_where(flags):
        # Prediction hours whose corresponding flag is truthy.
        return [prediction_hours[i] for i, flag in enumerate(flags) if flag]

    def label_trace(hours, name, marker_color, line_color):
        # One row of label markers drawn at height label_h.
        return go.Scatter(
            mode="markers",
            x=hours,
            y=[label_h] * len(hours),
            line=dict(color=line_color),
            name=name,
            marker=dict(
                color=marker_color, size=20, line=dict(color="Black", width=2)
            ),
        )

    traces = []
    if labels is not None:
        # A plain list/tuple compared with == yields a single bool rather
        # than an element-wise mask, so convert those to an array first.
        if isinstance(labels, (list, tuple)):
            labels = np.asarray(labels)
        is_mortality = labels == 1
        after_discharge = labels == -1
        traces.extend(
            [
                label_trace(
                    prediction_hours, "low risk of mortality label", "Green", "Black"
                ),
                label_trace(
                    hours_where(is_mortality),
                    "high risk of mortality label",
                    "Red",
                    "Red",
                ),
                label_trace(
                    hours_where(after_discharge),
                    "post discharge label",
                    "Grey",
                    "Grey",
                ),
            ]
        )
    traces.append(
        go.Bar(
            x=prediction_hours,
            y=predictions,
            marker_color="Red",
            name="model confidence",
        )
    )

    fig = go.Figure(data=traces)
    # Extend the y-range below zero so the label marker row stays visible.
    fig.update_yaxes(range=[label_h, 1])
    fig.update_xaxes(tickvals=prediction_hours)
    fig.update_xaxes(showline=True, linewidth=2, linecolor="black")

    # Horizontal reference line at a confidence of 0.5.
    fig.add_hline(y=0.5)

    fig.update_layout(
        title="Model output visualization",
        autosize=False,
        xaxis_title="No. of hours after admission",
        yaxis_title="Model confidence",
    )

    return fig


# Despite its name, mortality_cases holds every index along axis 1 of
# y_test_labels, so the sample is drawn uniformly from all of them.
mortality_cases = list(range(y_test_labels.shape[1]))
sample_idx = random.choice(mortality_cases)
# Plot the chosen sample's predicted values against its labels.
fig = plot_risk_mortality(
    y_pred_values[:, sample_idx].squeeze(), y_test_labels[:, sample_idx].squeeze()
)
fig.show()