In [1]:
import datetime
import os
import random
import sys
from datetime import date, timedelta

import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from matplotlib.colors import ListedColormap
from sklearn.preprocessing import StandardScaler

sys.path.append("../..")

from utils.utils import *
from drift_detector.rolling_window import *
from baseline_models.temporal.pytorch.optimizer import Optimizer
from baseline_models.temporal.pytorch.utils import *
# NOTE(review): the two imports below appear to re-import the same names from the
# `drift_detection` package and shadow the pair above; keep only one pair.
from drift_detection.baseline_models.temporal.pytorch.optimizer import Optimizer
from drift_detection.baseline_models.temporal.pytorch.utils import *

2022-08-23 19:28:01,300 [1;37mINFO[0m cyclops.orm     - Database setup, ready to run queries!


In [2]:
# --- Data / experiment configuration ------------------------------------------
# Directory holding the aggregated parquet files read by get_gemini_data().
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022"
SHIFT = "covid"
outcome = "mortality"
aggregation_type = "time_flatten"
num_timesteps = 6
run = 1
# Hospitals passed to import_dataset_hospital().
hospital = [
    "SBK", "UHNTG", "THPC", "THPM",
    "UHNTW", "SMH", "MSH", "PMH",
]
# One rolling-window run per entry; each name prefixes that run's results pickle.
experiments = [
    "hosp_type_community_baseline",
    "hosp_type_community",
]

# --- Rolling-window settings -----------------------------------------------
threshold = 0.05  # p-values below this flag a window as drifted
stat_window = 30
lookup_window = 0
stride = 1

In [3]:
admin_data, x, y = get_gemini_data(PATH)

# Standardize each numerical column independently (zero mean, unit variance).
# NOTE(review): each scaler is fit on every row of `x`, i.e. before the
# train/val/test split done in the experiment loop — confirm this is intended.
numerical_cols = get_numerical_cols(PATH)
for column_name in numerical_cols:
    column_values = x[column_name].values.reshape(-1, 1)
    standardized = StandardScaler().fit(column_values).transform(column_values)
    x[column_name] = pd.Series(np.squeeze(standardized), index=x[column_name].index)

# NOTE(review): `X` is not referenced by any later cell in this notebook.
X = reshape_inputs(x, num_timesteps)

2022-08-23 19:28:01,755 [1;37mINFO[0m cyclops.utils.file - Loading dataframe to /mnt/nfs/project/delirium/drift_exp/JULY-04-2022/aggregated_events.parquet


Load data from aggregated events...


2022-08-23 19:28:02,211 [1;37mINFO[0m cyclops.utils.file - Loading dataframe to /mnt/nfs/project/delirium/drift_exp/JULY-04-2022/aggmeta_start_ts.parquet
2022-08-23 19:28:02,572 [1;37mINFO[0m cyclops.feature_handler - Loading features from file...
2022-08-23 19:28:02,574 [1;37mINFO[0m cyclops.feature_handler - Found file to load for static features...
2022-08-23 19:28:02,576 [1;37mINFO[0m cyclops.feature_handler - Successfully loaded static features from file...
2022-08-23 19:28:02,603 [1;37mINFO[0m cyclops.feature_handler - Found file to load for temporal features...


Load data from feature handler...


2022-08-23 19:28:05,731 [1;37mINFO[0m cyclops.feature_handler - Successfully loaded temporal features from file...


Load data from admin data...


2022-08-23 19:28:10,250 [1;37mINFO[0m cyclops.utils.file - Loading dataframe to /mnt/nfs/project/delirium/drift_exp/JULY-04-2022/aggmeta_end_ts.parquet
2022-08-23 19:28:25,466 [1;37mINFO[0m cyclops.feature_handler - Loading features from file...
2022-08-23 19:28:25,468 [1;37mINFO[0m cyclops.feature_handler - Found file to load for static features...
2022-08-23 19:28:25,471 [1;37mINFO[0m cyclops.feature_handler - Successfully loaded static features from file...
2022-08-23 19:28:25,500 [1;37mINFO[0m cyclops.feature_handler - Found file to load for temporal features...
2022-08-23 19:28:28,340 [1;37mINFO[0m cyclops.feature_handler - Successfully loaded temporal features from file...


## Get prediction model

In [4]:
# Hyper-parameters of the pretrained LSTM; they must match the checkpoint at `model_path`.
output_dim = 1
batch_size = 64  # not referenced by later cells in this notebook
input_dim = 108  # presumably the per-timestep feature count of the checkpoint — confirm
timesteps = num_timesteps  # single source of truth instead of a duplicated literal
hidden_dim = 64
layer_dim = 2
dropout = 0.2
n_epochs = 256  # not referenced by later cells in this notebook
learning_rate = 2e-3
weight_decay = 1e-6
last_timestep_only = False

device = get_device()

model_params = {
    "device": device,
    "input_dim": input_dim,
    "hidden_dim": hidden_dim,
    "layer_dim": layer_dim,
    "output_dim": output_dim,
    "dropout_prob": dropout,
    "last_timestep_only": last_timestep_only,
}

# NOTE(review): `model_path` is not defined anywhere in this notebook; set it
# (e.g. in the configuration cell) so the notebook survives Restart & Run All.
model = get_temporal_model("lstm", model_params).to(device)
# map_location lets a checkpoint saved on GPU load on a CPU-only machine (and vice versa).
model.load_state_dict(torch.load(model_path, map_location=device))
# Nothing in this notebook trains the model, so switch off dropout for inference.
model.eval()

loss_fn = nn.BCEWithLogitsLoss(reduction="none")
optimizer = optim.Adagrad(
    model.parameters(), lr=learning_rate, weight_decay=weight_decay
)
lr_scheduler = optim.lr_scheduler.StepLR(optimizer, step_size=128, gamma=0.5)
activation = nn.Sigmoid()
opt = Optimizer(
    model=model,
    loss_fn=loss_fn,
    optimizer=optimizer,
    activation=activation,
    lr_scheduler=lr_scheduler,
)

## Rolling window

In [5]:
# Settings handed to ShiftReductor / ShiftDetector.
dr_technique = "BBSDs_untrained_FFNN"
md_test = "MMD"
sign_level = 0.05
sample = 1000
threshold = 0.05  # p-values below this flag a window as drifted

# Dataset / model context.
dataset = "gemini"
context_type = "lstm"
representation = "rf"
scale = True  # apply scale_data() to the normalized splits

# Date range covered by the target data streams.
start_date, end_date = date(2018, 1, 1), date(2020, 8, 1)

In [None]:
# Run drift detection and performance tracking over rolling windows once per
# experiment, collecting one results DataFrame per experiment (same order as `experiments`).
experiment_results = []

for shift in experiments:

    ## Set constant reference distribution
    # Seed every RNG the pipeline may draw from so each experiment is reproducible.
    random.seed(1)
    np.random.seed(1)
    torch.manual_seed(1)

    # Pass the loop's `shift` (not the global `SHIFT`) so each experiment
    # queries its own split instead of all experiments sharing one.
    print("Query data %s ..." % shift)
    # Keep the returned admin data under its own name so the global
    # `admin_data` is not overwritten and fed into the next iteration.
    (
        (x_train, y_train),
        (x_val, y_val),
        (x_test, y_test),
        feats,
        exp_admin_data,
    ) = import_dataset_hospital(
        admin_data, x, y, shift, outcome, hospital, run, shuffle=True
    )

    print("Get source data...")
    # Normalize data
    (
        (X_tr_normalized, y_tr),
        (X_val_normalized, y_val),
        (X_t_normalized, y_t),
    ) = normalize_data(
        aggregation_type, exp_admin_data, num_timesteps,
        x_train, y_train, x_val, y_val, x_test, y_test,
    )
    # Scale data
    if scale:
        X_tr_normalized, X_val_normalized, X_t_normalized = scale_data(
            numerical_cols, X_tr_normalized, X_val_normalized, X_t_normalized
        )
    # Process data
    X_tr_final, X_val_final, X_t_final = process_data(
        aggregation_type, num_timesteps, X_tr_normalized, X_val_normalized, X_t_normalized
    )

    # Ids from the first index level of the validation split; excluded from the target streams.
    val_ids = list(X_val_normalized.index.get_level_values(0).unique())

    print("Get target data streams...")
    x_test_stream, y_test_stream, measure_dates_test = get_streams(
        x, y, exp_admin_data, start_date, end_date,
        stride=1, window=1, ids_to_exclude=val_ids,
    )

    # NOTE(review): `model_path` must be defined before this cell runs.
    print("Get Shift Reductor...")
    shift_reductor = ShiftReductor(
        X_tr_final, y_tr, dr_technique, dataset, var_ret=0.8, model_path=model_path,
    )
    print("Get Shift Detector...")
    shift_detector = ShiftDetector(
        dr_technique, md_test, sign_level, shift_reductor, sample, dataset, feats,
        model_path, context_type, representation,
    )

    print("Rolling Window...")
    dist_vals_test, pvals_test = rolling_window_drift(
        X_tr_final, x_test_stream, shift_detector, sample, stat_window,
        lookup_window, stride, num_timesteps, threshold, X_val_final,
    )
    performance_metrics = rolling_window_performance(
        x_test_stream, y_test_stream, opt, sample, stat_window,
        lookup_window, stride, num_timesteps, threshold, X_val_final,
    )

    # Offset each stream date by lookup_window + 2 * stat_window days.
    date_offset = datetime.timedelta(days=lookup_window + stat_window * 2)
    measure_dates_test = [
        (datetime.datetime.strptime(day, "%Y-%m-%d") + date_offset).strftime("%Y-%m-%d")
        for day in measure_dates_test
    ]

    # Skip the first window and truncate every column to the number of p-values.
    end = len(pvals_test)
    pvals = np.asarray(pvals_test)[1:end]
    results = pd.DataFrame(
        {
            "dates": measure_dates_test[1:end],
            "pval": pvals,
            "dist": dist_vals_test[1:end],
            # 1 where the window's p-value is below `threshold` (drift flagged), else 0.
            "detection": np.where(pvals < threshold, 1, 0),
        }
    )
    # NOTE(review): concat(axis=1) aligns on the index; this assumes
    # performance_metrics is indexed like `results` (0..end-2) — confirm.
    results = pd.concat([results, performance_metrics], axis=1)
    results.to_pickle(os.path.join(PATH, f"{shift}_{dr_technique}_{md_test}_results.pkl"))
    experiment_results.append(results)


In [None]:
# Compare the two experiments over time: experiments[0] in blue, experiments[1] in red.
# Shaded background bands mark windows where experiments[1]'s drift test fired.
baseline, shifted = experiment_results[0], experiment_results[1]
cmap = ListedColormap(["lightgrey", "red"])
detection_band = shifted["detection"].values[np.newaxis]
# Both experiments are drawn over the first experiment's date range.
date_range = (baseline["dates"].iloc[0], baseline["dates"].iloc[-1])
# Keep only every N-th date label on the bottom axis so it stays readable.
label_every = 28

# (column, y-axis label, height of the dashed reference line or None).
# AUROC/AUPRC references use the first experiment's mean; the distance
# reference is drawn only if a `dist_val_val` column is present.
panels = [
    ("pval", "P-Values", threshold),
    (
        "dist",
        "Distance",
        baseline["dist_val_val"].mean() if "dist_val_val" in baseline else None,
    ),
    ("auroc", "AUROC", baseline["auroc"].mean()),
    ("auprc", "AUPRC", baseline["auprc"].mean()),
]

fig, axes = plt.subplots(len(panels), 1, figsize=(22, 12))
for ax, (column, ylabel, ref_value) in zip(axes, panels):
    for frame, color, name in (
        (baseline, "blue", experiments[0]),
        (shifted, "red", experiments[1]),
    ):
        ax.plot(
            frame["dates"], frame[column], ".-",
            color=color, linewidth=0.5, markersize=2, label=name,
        )
    ax.set_xlim(*date_range)
    ax.set_ylabel(ylabel, fontsize=16)
    if ref_value is not None:
        ax.axhline(y=ref_value, color="dimgrey", linestyle="--")
    # Called after plotting so get_ylim() reflects the final autoscaled limits.
    ax.pcolorfast(ax.get_xlim(), ax.get_ylim(), detection_band, cmap=cmap, alpha=0.4)

# Only the bottom panel shows date labels.
for ax in axes[:-1]:
    ax.tick_params(axis="x", labelbottom=False)
bottom = axes[-1]
bottom.set_xlabel("Date", fontsize=16)
bottom.tick_params(axis="x", labelrotation=45)
for index, label in enumerate(bottom.xaxis.get_ticklabels()):
    label.set_visible(index % label_every == 0)
axes[0].legend(loc="upper right")

plt.show()