### Clinical drift detection

In [None]:
import sys
import pandas as pd
import numpy as np
import os
from functools import reduce
import datetime
import pickle
import matplotlib.pyplot as plt

sys.path.append("../..")

from gemini.constants import *
from gemini.utils import *
from drift_detector.plotter import plot_drift_samples_pval, errorfill, plot_roc, plot_pr, linestyles, markers, colors, brightness, colorscale
from drift_detector.utils import scale
from drift_detector.detector import Detector
from drift_detector.reductor import Reductor
from drift_detector.tester import TSTester, DCTester
from drift_detector.experimenter import Experimenter
from drift_detector.synthetic_applicator import ClinicalShiftApplicator, apply_predefined_shift

## Config Parameters

In [None]:
# Feature aggregation scheme: the drift-test cell treats "time" specially and
# every other value (such as this one) with its non-"time" configuration.
AGGREGATION_TYPE = "time_flatten"
# NOTE(review): TIMESTEPS is not referenced in any visible cell — confirm use.
TIMESTEPS = 6
# Root directory of the saved drift-experiment data. The trailing slash matters:
# DRIFT_PATH is later built by plain string concatenation onto PATH.
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/"

## Query Data

In [None]:
# Load the experiment data stored under PATH; the loader returns three objects,
# presumably admin data, features and labels — confirm against get_gemini_data.
(admin_data, x, y) = get_gemini_data(PATH)

## Input Parameters

In [None]:
# Valid experiments: covid, seasonal_summer, seasonal_winter,
# hosp_type_academic, hosp_type_community.
SHIFT = input("Select experiment: ").strip()
OUTCOME = input("Select outcome variable: ").strip()  # e.g. mortality

# (baseline/target experiment names, saved LSTM checkpoint) per shift experiment.
SHIFT_CONFIGS = {
    "covid": (
        ("pre-covid", "covid"),
        "../../saved_models/covid_lstm.pt",
    ),
    "seasonal_summer": (
        ("seasonal_summer_baseline", "seasonal_summer"),
        "../../saved_models/seasonal_summer_lstm.pt",
    ),
    "seasonal_winter": (
        ("seasonal_winter_baseline", "seasonal_winter"),
        "../../saved_models/seasonal_winter_lstm.pt",
    ),
    "hosp_type_academic": (
        ("hosp_type_academic_baseline", "hosp_type_academic"),
        "../../saved_models/hosp_type_academic_lstm.pt",
    ),
    "hosp_type_community": (
        ("hosp_type_community_baseline", "hosp_type_community"),
        "../../saved_models/hosp_type_community_lstm.pt",
    ),
}

# Fail fast with a clear message instead of an unbound-name error (or a stale
# value left over from a previous notebook run) further down.
if SHIFT not in SHIFT_CONFIGS:
    raise ValueError(
        f"Unknown experiment {SHIFT!r}; expected one of {sorted(SHIFT_CONFIGS)}"
    )

_experiments, _model_rel_path = SHIFT_CONFIGS[SHIFT]
# Copy into a fresh list so later mutation of EXPERIMENTS cannot alter the table.
EXPERIMENTS = list(_experiments)
MODEL_PATH = os.path.join(os.getcwd(), _model_rel_path)
# The same hospital list is used for every experiment.
HOSPITAL = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW", "PMH"]

## Drift Tests

In [None]:
# Pick the dimensionality-reduction candidates, the context/representation
# labels and the drift-results path prefix from the aggregation type.
# "time" selects the RNN setup (which adds LSTM-based BBSD variants); any other
# aggregation selects the FFNN context with the "rf" representation.
if AGGREGATION_TYPE != "time":
    CONTEXT_TYPE = "ffnn"
    REPRESENTATION = "rf"
    DR_TECHNIQUES = ["NoRed", "SRP", "PCA", "kPCA", "Isomap", "BBSDs_untrained_FFNN"]
    # The trailing empty element makes the prefix end with "_".
    DRIFT_PATH = PATH + "_".join(
        [AGGREGATION_TYPE, CONTEXT_TYPE, REPRESENTATION, SHIFT, "_".join(HOSPITAL), ""]
    )
else:
    CONTEXT_TYPE = "rnn"
    REPRESENTATION = "rnn"
    DR_TECHNIQUES = [
        "NoRed",
        "SRP",
        "PCA",
        "kPCA",
        "Isomap",
        "BBSDs_untrained_FFNN",
        "BBSDs_untrained_LSTM",
        "BBSDs_trained_LSTM",
    ]
    # REPRESENTATION is left out of the name here (it equals CONTEXT_TYPE).
    DRIFT_PATH = PATH + "_".join(
        [AGGREGATION_TYPE, CONTEXT_TYPE, SHIFT, "_".join(HOSPITAL), ""]
    )

# Candidate drift-test methods.
MD_TESTS = ["Univariate", "MMD", "LK", "Spot-the-diff"]

## Reductor

In [None]:
# Presumably one of DR_TECHNIQUES from the drift-test cell; not validated here.
DR_TECHNIQUE = input("Select dimensionality reduction technique: ")

# Dimensionality reducer for the chosen technique, given the experiment's saved
# model checkpoint. var_ret=0.8 presumably means fraction of variance retained
# — confirm against Reductor.
reductor = Reductor(model_path=MODEL_PATH, dr_method=DR_TECHNIQUE, var_ret=0.8)

## Tester

In [None]:
# Presumably one of MD_TESTS from the drift-test cell; not validated here.
MD_TEST = input("Select test method: ")

# NOTE(review): MD_TESTS offers "Spot-the-diff" and DCTester is imported but
# never used — confirm whether that option should be routed to DCTester
# rather than TSTester.
tester = TSTester(tester_method=MD_TEST)

## Detector

In [None]:
# Wire the reducer and the tester together; 0.05 is the p-value threshold
# handed to the detector.
detector = Detector(tester=tester, reductor=reductor, p_val_threshold=0.05)

## ClinicalShiftApplicator

In [None]:
# Clinical shift applicator configured with shift_type "time".
shiftapplicator = ClinicalShiftApplicator(shift_type="time")

## Build Experimenter

In [None]:
# Experimenter combining the shift applicator with the configured detector.
experimenter = Experimenter(shiftapplicator=shiftapplicator, detector=detector)

In [None]:
# Run drift detection once per shift and collect the results keyed by shift.
# NOTE(review): `shifts`, `X_val_final` and `X_t_final_shifted` are not defined
# in any visible cell — TODO confirm where they come from.
# NOTE(review): the arguments to detect_shift_samples do not depend on `shift`,
# so every entry is computed from the same data; presumably the target data
# should be shifted per iteration (e.g. via shiftapplicator) — confirm.
shift_results = {}
for shift in shifts:
    results = experimenter.detect_shift_samples(
        X_val_final,
        X_t_final_shifted,
    )
    shift_results[shift] = results

## Plot