In [None]:
import datetime
import os
import pickle
import random
from datetime import date

import matplotlib.pyplot as plt
import numpy as np
from baseline_models.temporal.pytorch.utils import get_device, load_ckp
from drift_detector.detector import Detector
from drift_detector.reductor import Reductor
from drift_detector.rolling_window import RollingWindow
from drift_detector.tester import TSTester
from drift_detector.utils import get_serving_data, get_temporal_model
from gemini.query import get_gemini_data
from gemini.utils import get_label, import_dataset_hospital, normalize, process, scale
from matplotlib.colors import ListedColormap

## Config parameters

In [None]:
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/"
TIMESTEPS = 6
AGGREGATION_TYPE = "time"
HOSPITALS = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW", "PMH", "SBK"]
# Hospital groupings for the hospital-type shift experiments (each site listed once).
ACADEMIC = ["MSH", "PMH", "SMH", "UHNTW", "UHNTG", "SBK"]
COMMUNITY = ["THPC", "THPM"]
OUTCOME = "mortality"
THRESHOLD = 0.05  # p-value threshold used by the detector and drawn on the p-value plot
NUM_TIMESTEPS = TIMESTEPS  # alias kept for compatibility; single source of truth is TIMESTEPS
STAT_WINDOW = 30
LOOKUP_WINDOW = 0
STRIDE = 1

# Supported experiments: source -> target hospital groups for each shift.
EXPERIMENTS = {
    "hosp_type_academic": {
        "source": ACADEMIC,
        "target": COMMUNITY,
        "shift_type": "hospital_type",
    },
    "hosp_type_community": {
        "source": COMMUNITY,
        "target": ACADEMIC,
        "shift_type": "hospital_type",
    },
}

SHIFT = input("Select experiment: ").strip()  # hospital_type
# Fail fast on a typo instead of leaving exp_params undefined and crashing much later.
if SHIFT not in EXPERIMENTS:
    raise ValueError(
        f"Unknown experiment {SHIFT!r}; expected one of {sorted(EXPERIMENTS)}"
    )
exp_params = EXPERIMENTS[SHIFT]
MODEL_PATH = os.path.join(PATH, "saved_models", SHIFT + "_lstm.pt")

## Get data

In [None]:
# Load the GEMINI tables from PATH: the admin table plus the `x`/`y` frames that are later
# split by import_dataset_hospital (presumably features and labels — confirm in gemini.query).
(admin_data, x, y) = get_gemini_data(PATH)

## Get prediction model

In [None]:
# LSTM hyper-parameters. These must match the architecture stored in the checkpoint at
# MODEL_PATH, otherwise load_ckp cannot restore the weights.
output_dim = 1
batch_size = 64  # NOTE(review): not referenced elsewhere in this notebook
input_dim = 108  # NOTE(review): presumably features per timestep — confirm against len(feats)
timesteps = TIMESTEPS  # reuse the config value rather than a second hard-coded 6
hidden_dim = 64
layer_dim = 2
dropout = 0.2
last_timestep_only = False

device = get_device()

model_params = {
    "device": device,
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

# Give a clear error if no checkpoint exists for the selected experiment.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(f"Model checkpoint not found: {MODEL_PATH}")

model = get_temporal_model("lstm", model_params).to(device)
# Presumably returns (restored model, optimizer, epoch count). Of these, only `optimizer`
# is used below (it is handed to RollingWindow).
model, optimizer, n_epochs = load_ckp(MODEL_PATH, model)

## Rolling-window drift detection settings

In [None]:
# Shift-detection settings for the rolling-window experiment.
DR_TECHNIQUE = "BBSDs_trained_LSTM"  # dimensionality-reduction method given to Reductor
MD_TEST = "mmd"  # two-sample test given to TSTester
SAMPLE = 1000  # passed to RollingWindow.drift (presumably the per-window sample size)
# NOTE(review): CONTEXT_TYPE / PROJ_TYPE are not referenced elsewhere in this notebook.
CONTEXT_TYPE, PROJ_TYPE = "lstm", "lstm"
# Date range of the serving data streams requested from get_serving_data.
START_DATE, END_DATE = date(2019, 1, 1), date(2020, 8, 1)

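## Hospital type experiment over time

Fit a shift detector on the validation split, then slide a window over the held-out serving data streams to track drift statistics and model performance over time.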

In [None]:
# Set constant reference distribution. Seed both the stdlib and NumPy global RNGs so any
# sampling done inside the drift-detection helpers is reproducible across runs.
random.seed(1)
np.random.seed(1)
print("Query data %s ..." % SHIFT)

# Keep the admin table returned by import_dataset_hospital under its own name so this
# cell does not overwrite the `admin_data` loaded in "Get data" and can be re-run safely.
(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data_exp = import_dataset_hospital(
    admin_data,
    x,
    y,
    SHIFT,
    OUTCOME,
    HOSPITALS,
)

# Normalize data
X_tr_normalized = normalize(admin_data_exp, X_tr, AGGREGATION_TYPE, TIMESTEPS)
X_val_normalized = normalize(admin_data_exp, X_val, AGGREGATION_TYPE, TIMESTEPS)
X_t_normalized = normalize(admin_data_exp, X_t, AGGREGATION_TYPE, TIMESTEPS)

# Get labels. NOTE(review): these replace the y_* returned by import_dataset_hospital above.
y_tr = get_label(admin_data_exp, X_tr, OUTCOME)
y_val = get_label(admin_data_exp, X_val, OUTCOME)
y_t = get_label(admin_data_exp, X_t, OUTCOME)

# Scale data. NOTE(review): each split is scaled by a separate call — confirm scale() does
# not refit its statistics per split.
X_tr_scaled = scale(X_tr_normalized)
X_val_scaled = scale(X_val_normalized)
X_t_scaled = scale(X_t_normalized)

# Process data
X_tr_final = process(X_tr_scaled, AGGREGATION_TYPE, TIMESTEPS)
X_val_final = process(X_val_scaled, AGGREGATION_TYPE, TIMESTEPS)
X_t_final = process(X_t_scaled, AGGREGATION_TYPE, TIMESTEPS)

# Encounters used for training/validation are excluded from the serving streams.
train_ids = list(X_tr_normalized.index.get_level_values(0).unique())
val_ids = list(X_val_normalized.index.get_level_values(0).unique())
exclude_ids = train_ids + val_ids

print("Get target data streams...")
data_streams = get_serving_data(
    x,
    y,
    admin_data_exp,
    START_DATE,
    END_DATE,
    stride=1,
    window=1,
    ids_to_exclude=exclude_ids,
    encounter_id="encounter_id",
    admit_timestamp="admit_timestamp",
)

print("Get Shift Reductor...")
reductor = Reductor(
    dr_method=DR_TECHNIQUE,
    model_path=MODEL_PATH,
    n_features=len(feats),
    var_ret=0.8,
)

print("Get Shift Tester...")
tester = TSTester(
    tester_method=MD_TEST,
)

print("Get Shift Detector...")
detector = Detector(
    reductor=reductor,
    tester=tester,
    p_val_threshold=THRESHOLD,
)

# The validation split is used as the reference distribution.
detector.fit(X_val_final)

print("Get Rolling Window...")

rolling_window = RollingWindow(shift_detector=detector, optimizer=optimizer)

drift_metrics = rolling_window.drift(
    data_streams,
    SAMPLE,
    STAT_WINDOW,
    LOOKUP_WINDOW,
    STRIDE,
)

performance_metrics = rolling_window.performance(
    data_streams,
    STAT_WINDOW,
    LOOKUP_WINDOW,
    STRIDE,
)

# Shift every stream date forward by LOOKUP_WINDOW + STAT_WINDOW days (presumably so each
# point is labelled by the end of its window), then drop the last STAT_WINDOW entries.
# Slicing by an explicit length keeps everything when STAT_WINDOW == 0; a bare [:-0]
# would return an empty list.
window_offset = datetime.timedelta(days=LOOKUP_WINDOW + STAT_WINDOW)
window_timestamps = [
    (datetime.datetime.strptime(day, "%Y-%m-%d") + window_offset).strftime("%Y-%m-%d")
    for day in data_streams["timestamps"]
]
results = {
    "timestamps": window_timestamps[: max(len(window_timestamps) - STAT_WINDOW, 0)],
}
results.update(drift_metrics)
results.update(performance_metrics)

# `results` is a plain dict (it has no .to_pickle), so serialize it with pickle.
# pandas.read_pickle can also read the resulting file.
results_path = os.path.join(
    PATH, SHIFT + "_" + DR_TECHNIQUE + "_" + MD_TEST + "_results.pkl"
)
with open(results_path, "wb") as results_file:
    pickle.dump(results, results_file)

In [None]:
# One row per panel: (results key, y-axis label, line colour, fixed reference line or
# None to draw the series mean instead).
PANELS = [
    ("p_val", "P-Values", "red", THRESHOLD),
    ("distance", "Distance", "red", None),
    ("auroc", "AUROC", "blue", None),
    ("auprc", "AUPRC", "blue", None),
    ("prec1", "PPV", "blue", None),
    ("rec1", "Sensitivity", "blue", None),
]
TICK_EVERY = 28  # show one x tick label in every 28 so the date axis stays legible

timestamps = results["timestamps"]
# Row vector of detection flags, shaded behind every panel. ListedColormap maps the low
# value to lightgrey and the high value to red (assumes shift_detected is boolean / 0-1).
sig_drift = np.array(results["shift_detected"])[np.newaxis]
cmap = ListedColormap(["lightgrey", "red"])

fig, axes = plt.subplots(len(PANELS), 1, figsize=(18, 12))
fig.suptitle(f"{SHIFT}: {DR_TECHNIQUE} + {MD_TEST}", fontsize=18)
for ax, (key, ylabel, color, ref_line) in zip(axes, PANELS):
    values = results[key]
    ax.plot(timestamps, values, ".-", color=color, linewidth=0.5, markersize=2)
    ax.set_xlim(timestamps[0], timestamps[-1])
    ax.set_ylabel(ylabel, fontsize=16)
    ax.axhline(
        y=np.mean(values) if ref_line is None else ref_line,
        color="dimgrey",
        linestyle="--",
    )
    if ax is axes[-1]:
        # Timestamps are "%Y-%m-%d" strings, so the x axis shows dates.
        ax.set_xlabel("Date", fontsize=16)
        ax.tick_params(axis="x", labelrotation=45)
    else:
        ax.tick_params(axis="x", labelbottom=False)
    # Shade after the reference line is drawn so get_ylim() already covers it.
    ax.pcolorfast(ax.get_xlim(), ax.get_ylim(), sig_drift, cmap=cmap, alpha=0.4)

for index, label in enumerate(axes[-1].xaxis.get_ticklabels()):
    if index % TICK_EVERY != 0:
        label.set_visible(False)

plt.show()