### Clinical drift detection

In [None]:
import sys
import pandas as pd
import numpy as np
import os
from functools import reduce
import datetime
import pickle
import matplotlib.pyplot as plt

sys.path.append("../..")

from gemini.constants import *
from gemini.utils import *
from drift_detector.plotter import plot_drift_samples_pval, errorfill, plot_roc, plot_pr, linestyles, markers, colors, brightness, colorscale
from drift_detector.utils import scale
from drift_detector.detector import Detector
from drift_detector.reductor import Reductor
from drift_detector.tester import TSTester, DCTester
from drift_detector.experimenter import Experimenter
from drift_detector.clinical_applicator import ClinicalShiftApplicator

## Config parameters

In [None]:
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/"
HOSPITALS = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW", "PMH"]
TIMESTEPS = 6
AGGREGATION_TYPE = "time_flatten"
# Hospital groupings used by the hosp_type_* experiments below.
# NOTE(review): "SBK" is not in HOSPITALS, so it presumably never contributes
# any rows -- confirm whether it belongs in HOSPITALS or should be dropped here.
ACADEMIC = ["PMH", "SMH", "UHNTW", "UHNTG", "SBK"]
COMMUNITY = ["THPC", "THPM"]

# Only the "time" aggregation uses the recurrent (rnn) setup; every other
# aggregation type uses ffnn context / rf representation.
if AGGREGATION_TYPE == "time":
    CONTEXT_TYPE = "rnn"
    REPRESENTATION = "rnn"
else:
    CONTEXT_TYPE = "ffnn"
    REPRESENTATION = "rf"

# Strip so that accidental surrounding whitespace does not produce an
# unrecognised outcome/experiment name.
OUTCOME = input("Select outcome variable: ").strip()  # e.g. mortality
# One of: seasonal_summer, seasonal_winter, hosp_type_academic,
# hosp_type_community (covid only defines EXPERIMENTS, see below).
SHIFT = input("Select experiment: ").strip()
MODEL_PATH = "../../saved_models/" + SHIFT + "_lstm.pt"

# Source/target configuration for every supported experiment. "baseline"
# draws source and target from the same population; "experiment" draws the
# source from a disjoint population so that a shift should be detectable.
_EXPERIMENT_PARAMS = {
    "seasonal_summer": {
        "baseline": {"source": list(range(1, 13)), "target": [6, 7, 8, 9]},
        "experiment": {"source": [1, 2, 3, 4, 5, 10, 11, 12], "target": [6, 7, 8, 9]},
        "shift_type": "month",
    },
    "seasonal_winter": {
        "baseline": {"source": list(range(1, 13)), "target": [11, 12, 1, 2]},
        "experiment": {"source": [3, 4, 5, 6, 7, 8, 9, 10], "target": [11, 12, 1, 2]},
        "shift_type": "month",
    },
    "hosp_type_academic": {
        "baseline": {"source": ACADEMIC, "target": ACADEMIC},
        "experiment": {"source": COMMUNITY, "target": ACADEMIC},
        "shift_type": "hospital_type",
    },
    "hosp_type_community": {
        "baseline": {"source": COMMUNITY, "target": COMMUNITY},
        "experiment": {"source": ACADEMIC, "target": COMMUNITY},
        "shift_type": "hospital_type",
    },
}

if SHIFT == "covid":
    EXPERIMENTS = ["pre-covid", "covid"]

# Every later cell reads exp_params; fail here with a clear message instead
# of a NameError several cells later. "covid" has no source/target
# configuration yet, so it is rejected too.
if SHIFT not in _EXPERIMENT_PARAMS:
    raise ValueError(
        f"Unsupported experiment {SHIFT!r}; expected one of {sorted(_EXPERIMENT_PARAMS)}"
    )
exp_params = _EXPERIMENT_PARAMS[SHIFT]

## Query data

In [None]:
# Load the raw data and build the datasets for the selected SHIFT / OUTCOME
# (presumably train / validation / test splits -- confirm against
# import_dataset_hospital).
admin_data, x, y = get_gemini_data(PATH)

(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(
    admin_data, x, y, SHIFT, OUTCOME, HOSPITALS
)
# Drop every row whose first index level also appears in X_tr. The shift
# experiments later sample from `x`, so this keeps them off the data the
# reductor is fit on. (Level 0 is presumably the encounter id -- confirm.)
x = x.loc[~x.index.get_level_values(0).isin(X_tr.index.get_level_values(0))]

# Normalize training data
X_tr_normalized = normalize(AGGREGATION_TYPE, admin_data, TIMESTEPS, X_tr)
# Training labels for the outcome selected in the config cell, so the labels
# match the OUTCOME passed to import_dataset_hospital above.
y_tr = get_label(admin_data, X_tr, OUTCOME)
# Scale training data
X_tr_scaled = scale(X_tr_normalized)
# Process training data into the final model input
X_tr_final = process(AGGREGATION_TYPE, TIMESTEPS, X_tr_scaled)

## Reductor

In [None]:
 DR_TECHNIQUE = input("Select dimensionality reduction technique: ")

reductor = Reductor(
    dr_method = DR_TECHNIQUE,
    model_path = MODEL_PATH,
    var_ret = 0.8,
)
reductor.fit(X_tr_final)

## Tester

In [None]:
# Ask which two-sample test method TSTester should run.
MD_TEST = input("Select test method: ")
tester = TSTester(tester_method=MD_TEST)

## Detector

In [None]:
# Wire the fitted reductor and the tester together; 0.05 is passed to the
# detector as its p-value threshold.
detector = Detector(reductor=reductor, tester=tester, p_val_threshold=0.05)

## ClinicalShiftApplicator

In [None]:
# The applicator splits data along the axis configured for this experiment
# (exp_params['shift_type'], e.g. "month" or "hospital_type").
clinicalshiftapplicator = ClinicalShiftApplicator(shift_type=exp_params['shift_type'])

# The experimenter combines the detector, the shift applicator and the
# administrative data used to carry out each shift.
experimenter = Experimenter(
    detector=detector,
    clinicalshiftapplicator=clinicalshiftapplicator,
    admin_data=admin_data,
)

## Experimenter

In [None]:
# Run the baseline and the shifted experiment; results are keyed by
# experiment name for plotting.
shift_results = {}
for shift in ["baseline", "experiment"]:
    # Split the held-out data into source (validation) and target sets
    # according to this experiment's configuration.
    X_val, X_t = experimenter.apply_clinical_shift(
        x,
        source=exp_params[shift]['source'],
        target=exp_params[shift]['target'],
    )
    # Normalize data
    X_val_normalized = normalize(AGGREGATION_TYPE, admin_data, TIMESTEPS, X_val)
    X_t_normalized = normalize(AGGREGATION_TYPE, admin_data, TIMESTEPS, X_t)

    # Labels for the outcome selected in the config cell (not a hard-coded
    # "mortality"), so they match the OUTCOME used to build the datasets.
    y_val = get_label(admin_data, X_val, OUTCOME)
    y_t = get_label(admin_data, X_t, OUTCOME)

    # Scale data.
    # NOTE(review): each split is scaled by a separate call; if `scale` fits
    # its statistics per call, source and target are scaled differently from
    # each other and from training -- confirm this is intended.
    X_val_scaled = scale(X_val_normalized)
    X_t_scaled = scale(X_t_normalized)

    # Process data into the final model input
    X_val_final = process(AGGREGATION_TYPE, TIMESTEPS, X_val_scaled)
    X_t_final = process(AGGREGATION_TYPE, TIMESTEPS, X_t_scaled)

    results = experimenter.detect_shift_samples(
        X_val_final,
        X_t_final,
        synthetic=False,
    )
    shift_results[shift] = results

## Plot drift results

In [None]:
# Plot the collected drift results. 0.05 presumably marks the significance
# level (same value as the Detector's p_val_threshold) -- confirm in plotter.
_significance = 0.05
plot_drift_samples_pval(shift_results, _significance)