### Compare performance of local vs global models

In [None]:
# sys.path.append("../..")
import datetime
import os
import random

# import sys
from datetime import date

from baseline_models.temporal.pytorch.utils import get_device, load_ckp

# from drift_detector.rolling_window import RollingWindow
from drift_detector.utils import get_serving_data, get_temporal_model
from gemini.query import get_gemini_data
from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale

In [None]:
# Root directory of the experiment; the model checkpoint is resolved beneath it.
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/"

# Data / outcome settings.
HOSPITALS = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW", "PMH", "SBK"]
OUTCOME = "mortality"
AGGREGATION_TYPE = "time"
TIMESTEPS = 6
NUM_TIMESTEPS = 6

# Drift-detection settings.
THRESHOLD = 0.05
STAT_WINDOW = 30
LOOKUP_WINDOW = 0
STRIDE = 1

# The experiment name is typed in interactively; it also names the checkpoint file.
SHIFT = input("Select experiment: ")  # hospital_type
MODEL_PATH = os.path.join(PATH, "saved_models", f"{SHIFT}_lstm.pt")

# Source/target definition for each experiment that has one.  Date-based
# experiments use [start, end] pairs of dates; seasonal ones use month numbers.
_EXPERIMENT_PARAMS = {
    "simulated_deployment": {
        "source": [datetime.date(2015, 1, 1), datetime.date(2019, 1, 1)],
        "target": [datetime.date(2019, 1, 1), datetime.date(2020, 8, 1)],
        "shift_type": "source_target",
    },
    "covid": {
        "source": [datetime.date(2019, 1, 1), datetime.date(2020, 2, 1)],
        "target": [datetime.date(2020, 3, 1), datetime.date(2020, 8, 1)],
        "shift_type": "time",
    },
    "seasonal_summer": {
        "source": [1, 2, 3, 4, 5, 10, 11, 12],
        "target": [6, 7, 8, 9],
        "shift_type": "month",
    },
    "seasonal_winter": {
        "source": [3, 4, 5, 6, 7, 8, 9, 10],
        "target": [11, 12, 1, 2],
        "shift_type": "month",
    },
}

# `exp_params` is only bound for the experiments listed above; any other
# experiment name leaves it undefined.
if SHIFT in _EXPERIMENT_PARAMS:
    exp_params = _EXPERIMENT_PARAMS[SHIFT]

## Get data

In [None]:
# Call get_gemini_data with the experiment root PATH and unpack its three
# results into `admin_data`, `x` and `y`.
# NOTE(review): the concrete types of these objects are not visible here —
# confirm against gemini.query before relying on them.
admin_data, x, y = get_gemini_data(PATH)

In [None]:
# Seed Python's `random` module before the datasets are imported.
random.seed(1)

# Train / validation / test pairs for the selected experiment; `admin_data`
# is rebound to the value returned by the import helper.
(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(
    admin_data,
    x,
    y,
    SHIFT,
    OUTCOME,
    HOSPITALS,
)

# Stage 1: normalize each split (train, validation, test — in that order).
X_tr_normalized, X_val_normalized, X_t_normalized = (
    normalize(admin_data, features, AGGREGATION_TYPE, TIMESTEPS)
    for features in (X_tr, X_val, X_t)
)

# Labels are re-derived from `admin_data` unless aggregation is "time".
if AGGREGATION_TYPE != "time":
    y_tr, y_val, y_t = (
        get_label(admin_data, features, OUTCOME) for features in (X_tr, X_val, X_t)
    )

# Stage 2: scale each normalized split.
X_tr_scaled, X_val_scaled, X_t_scaled = (
    scale(normalized)
    for normalized in (X_tr_normalized, X_val_normalized, X_t_normalized)
)

# Stage 3: final processing of each scaled split.
X_tr_final, X_val_final, X_t_final = (
    process(scaled, AGGREGATION_TYPE, TIMESTEPS)
    for scaled in (X_tr_scaled, X_val_scaled, X_t_scaled)
)

## Create data streams

In [None]:
# Date range handed to get_serving_data when the data streams are built.
START_DATE = date(year=2019, month=1, day=1)
END_DATE = date(year=2020, month=8, day=1)

In [None]:
print("Get target data streams...")

# Keyword options for get_serving_data: stride and window values plus the
# names used for the encounter id and admission timestamp.
_serving_options = {
    "stride": 1,
    "window": 1,
    "encounter_id": "encounter_id",
    "admit_timestamp": "admit_timestamp",
}

# Build the data streams over START_DATE..END_DATE from the raw x / y /
# admin_data objects.
data_streams = get_serving_data(
    x, y, admin_data, START_DATE, END_DATE, **_serving_options
)

In [None]:
# Device selected by the project's get_device helper.
device = get_device()

# Settings passed to the temporal model.
input_dim, hidden_dim, layer_dim, output_dim = 108, 64, 2, 1
dropout = 0.2
last_timestep_only = False

# Parameter bundle consumed by get_temporal_model.
model_params = dict(
    device=device,
    input_dim=input_dim,
    hidden_dim=hidden_dim,
    layer_dim=layer_dim,
    output_dim=output_dim,
    dropout_prob=dropout,
    last_timestep_only=last_timestep_only,
)

# Instantiate the "lstm" model and move it onto the chosen device.
model = get_temporal_model("lstm", model_params)
model = model.to(device)

# Load the checkpoint stored at MODEL_PATH; load_ckp hands back three values,
# unpacked here as the model, an optimizer and an epoch count.
# NOTE(review): presumably this restores saved weights into `model` — confirm
# against baseline_models.temporal.pytorch.utils.load_ckp.
model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)