In [None]:
import sys
import pandas as pd
import numpy as np
import os
from functools import reduce
import datetime
import pickle
import matplotlib.pyplot as plt
from sklearn.metrics import (
    auc,
    accuracy_score,
    confusion_matrix,
    roc_auc_score,
    roc_curve,
    precision_recall_curve,
    average_precision_score,
)

sys.path.append("../..")

from drift_detector.explainer import Explainer
from drift_detector.experiments import *
from gemini.constants import *
from drift_detector.plotter import errorfill, plot_roc, plot_pr, linestyles, markers, colors, brightness, colorscale
from gemini.utils import *
from baseline_models.static.utils import run_model

# Config Parameters #

In [None]:
# Root directory of the experiment; passed to the data loaders and used as the
# prefix for cached drift results and pickled models.
PATH = "/mnt/nfs/project/delirium/drift_exp/JULY-04-2022/"
DATASET = "gemini"

# Data preparation settings.
TIMESTEPS = 6
AGGREGATION_TYPE = "time_flatten"
SCALE = True

# Drift-test settings.
SAMPLES = [10, 20, 50, 100, 200, 500]  # numbers of test samples evaluated
RANDOM_RUNS = 5
SIGN_LEVEL = 0.05  # drawn as the horizontal threshold line on the drift plot
CALC_ACC = True

In [None]:
# ["NoRed"] --> ["MMD", "LK","Context-Aware MMD", "Spot-the-diff"]
# ["SRP", "PCA", "kPCA", "Isomap", "BBSDs_untrained_FFNN", "BBSDs_untrained_LSTM", "BBSDs_trained_LSTM"] --> ["MMD", "LK", "Classifier", "Spot-the-diff"]

# Query Data

In [None]:
# Load the three objects every later cell works from, using the experiment
# root directory PATH.
# NOTE(review): presumably admin_data is the administrative table and x / y are
# the features and outcomes -- confirm against get_gemini_data.
admin_data, x, y = get_gemini_data(PATH)

# Input Parameters #

In [None]:
# Supported shift experiments: name -> (experiment labels, saved LSTM path
# relative to the working directory).
_SHIFT_CONFIGS = {
    "covid": (
        ["pre-covid", "covid"],
        "../../saved_models/covid_lstm.pt",
    ),
    "seasonal_summer": (
        ["seasonal_summer_baseline", "seasonal_summer"],
        "../../saved_models/seasonal_summer_lstm.pt",
    ),
    "seasonal_winter": (
        ["seasonal_winter_baseline", "seasonal_winter"],
        "../../saved_models/seasonal_winter_lstm.pt",
    ),
    "hosp_type_academic": (
        ["hosp_type_academic_baseline", "hosp_type_academic"],
        "../../saved_models/hosp_type_academic_lstm.pt",
    ),
    "hosp_type_community": (
        ["hosp_type_community_baseline", "hosp_type_community"],
        "../../saved_models/hosp_type_community_lstm.pt",
    ),
}

# One of: covid, seasonal_summer, seasonal_winter, hosp_type_academic,
# hosp_type_community
SHIFT = input("Select experiment: ").strip()
OUTCOME = input("Select outcome variable: ").strip()  # mortality

# Fail fast on a typo instead of hitting a NameError (fresh kernel) or silently
# reusing EXPERIMENTS / MODEL_PATH from a previous run of this cell.
if SHIFT not in _SHIFT_CONFIGS:
    raise ValueError(
        "Unknown experiment {!r}; choose one of: {}".format(
            SHIFT, ", ".join(_SHIFT_CONFIGS)
        )
    )

_experiments, _relative_model_path = _SHIFT_CONFIGS[SHIFT]
EXPERIMENTS = list(_experiments)
MODEL_PATH = os.path.join(os.getcwd(), _relative_model_path)

# The same hospital list (in this order) is used for every experiment; the
# order matters because it is embedded in cache and model file names.
HOSPITAL = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW", "PMH"]

# Drift Tests #

In [None]:
# Two-sample tests run after each dimensionality-reduction technique.
MD_TESTS = ["Univariate", "MMD", "LK", "Spot-the-diff"]

# The available reducers, the context/representation types and the cache file
# prefix all depend on whether the time dimension is kept ("time") or not.
if AGGREGATION_TYPE != "time":
    DR_TECHNIQUES = [
        "NoRed",
        "SRP",
        "PCA",
        "kPCA",
        "Isomap",
        "BBSDs_untrained_FFNN",
    ]
    CONTEXT_TYPE, REPRESENTATION = "ffnn", "rf"
    DRIFT_PATH = PATH + "_".join(
        [AGGREGATION_TYPE, CONTEXT_TYPE, REPRESENTATION, SHIFT, "_".join(HOSPITAL), ""]
    )
else:
    DR_TECHNIQUES = [
        "NoRed",
        "SRP",
        "PCA",
        "kPCA",
        "Isomap",
        "BBSDs_untrained_FFNN",
        "BBSDs_untrained_LSTM",
        "BBSDs_trained_LSTM",
    ]
    CONTEXT_TYPE, REPRESENTATION = "rnn", "rnn"
    # The representation is not part of the file prefix in this branch.
    DRIFT_PATH = PATH + "_".join(
        [AGGREGATION_TYPE, CONTEXT_TYPE, SHIFT, "_".join(HOSPITAL), ""]
    )

In [None]:
# Display the file-name prefix under which the drift results are cached.
DRIFT_PATH

In [None]:
# Run shift experiments.
#
# Results are kept in four arrays indexed by
# (experiment, DR technique, two-sample test, sample size). A value of -1 marks
# a combination that has not been computed yet, so a partially filled cache can
# be resumed.
_RESULT_SHAPE = (len(EXPERIMENTS), len(DR_TECHNIQUES), len(MD_TESTS), len(SAMPLES))


def _load_or_init_results(stat):
    """Return the ``(pval, dist)`` result arrays for ``stat`` ("mean" or "std").

    If ``<DRIFT_PATH><stat>_dr_md_pval.pkl`` exists, both the p-value and the
    distance arrays are loaded from disk; otherwise both start filled with the
    -1 "not computed" sentinel.

    Raises:
        ValueError: if a cached array does not match the current experiment grid.
    """
    pval_file = DRIFT_PATH + stat + "_dr_md_pval.pkl"
    dist_file = DRIFT_PATH + stat + "_dr_md_dist.pkl"
    if not os.path.exists(pval_file):
        return np.full(_RESULT_SHAPE, -1.0), np.full(_RESULT_SHAPE, -1.0)

    # NOTE: pickle can execute arbitrary code -- only load caches you wrote yourself.
    with open(pval_file, "rb") as f:
        pval = pickle.load(f)
    with open(dist_file, "rb") as f:
        dist = pickle.load(f)
    for cached_file, cached in ((pval_file, pval), (dist_file, dist)):
        # A stale cache (e.g. written with a different list of techniques)
        # would otherwise be indexed with the wrong meaning.
        if np.shape(cached) != _RESULT_SHAPE:
            raise ValueError(
                "Cached results in {} have shape {}, expected {}".format(
                    cached_file, np.shape(cached), _RESULT_SHAPE
                )
            )
    return pval, dist


mean_dr_md_pval, mean_dr_md_dist = _load_or_init_results("mean")
std_dr_md_pval, std_dr_md_dist = _load_or_init_results("std")

# Lower-case loop variables so the globals SHIFT / MD_TEST are not rebound here.
for si, shift in enumerate(EXPERIMENTS):
    for di, dr_technique in enumerate(DR_TECHNIQUES):
        for mi, md_test in enumerate(MD_TESTS):
            # Skip combinations that are already fully computed.
            if not np.any(mean_dr_md_pval[si, di, mi, :] == -1):
                continue

            print(
                "{} | {} | {} | {}".format(
                    shift, HOSPITAL, dr_technique, md_test
                )
            )

            # Choose the representation per test WITHOUT touching the global:
            # previously the "LK" override leaked into every later test and
            # reducer, so results depended on execution order.
            representation = REPRESENTATION
            if AGGREGATION_TYPE == "time_flatten":
                if md_test == "Classifier":
                    representation = "rf"
                elif md_test == "LK":
                    representation = "ffnn"

            mean_p_vals, std_p_vals, mean_dist, std_dist = run_shift_experiment(
                shift=shift,
                admin_data=admin_data,
                x=x,
                y=y,
                outcome=OUTCOME,
                hospital=HOSPITAL,
                path=PATH,
                aggregation_type=AGGREGATION_TYPE,
                scale=SCALE,
                dr_technique=dr_technique,
                model_path=MODEL_PATH,
                md_test=md_test,
                context_type=CONTEXT_TYPE,
                representation=representation,
                samples=SAMPLES,
                dataset=DATASET,
                sign_level=SIGN_LEVEL,
                timesteps=TIMESTEPS,
                random_runs=RANDOM_RUNS,
                calc_acc=CALC_ACC,
            )

            mean_dr_md_pval[si, di, mi, :] = mean_p_vals
            std_dr_md_pval[si, di, mi, :] = std_p_vals
            mean_dr_md_dist[si, di, mi, :] = mean_dist
            std_dr_md_dist[si, di, mi, :] = std_dist



In [None]:
# Persist the drift results first, so nothing computed is lost if the
# interactive plotting below fails. The files are always rewritten: writing
# only when no file existed meant that results added while resuming from a
# partial cache were never saved.
for _suffix, _results in (
    ("mean_dr_md_pval.pkl", mean_dr_md_pval),
    ("mean_dr_md_dist.pkl", mean_dr_md_dist),
    ("std_dr_md_pval.pkl", std_dr_md_pval),
    ("std_dr_md_dist.pkl", std_dr_md_dist),
):
    with open(DRIFT_PATH + _suffix, "wb") as f:
        pickle.dump(_results, f)

DIM_RED = input("Select Pre-Processing: ")
MD_TEST = input("Select Two-Sample Testing: ")
PLOT_METRIC = input("Plot: ")

# Metric name -> (mean array, std array) to plot.
_plot_sources = {
    "Distance": (mean_dr_md_dist, std_dr_md_dist),
    "P-Value": (mean_dr_md_pval, std_dr_md_pval),
}

# Reject typos explicitly instead of silently drawing an empty figure.
if DIM_RED not in DR_TECHNIQUES:
    raise ValueError(
        "Unknown pre-processing {!r}; choose one of {}".format(DIM_RED, DR_TECHNIQUES)
    )
if MD_TEST not in MD_TESTS:
    raise ValueError(
        "Unknown two-sample test {!r}; choose one of {}".format(MD_TEST, MD_TESTS)
    )
if PLOT_METRIC not in _plot_sources:
    raise ValueError(
        "Unknown plot metric {!r}; choose one of {}".format(
            PLOT_METRIC, list(_plot_sources)
        )
    )

_mean_values, _std_values = _plot_sources[PLOT_METRIC]
di = DR_TECHNIQUES.index(DIM_RED)
mi = MD_TESTS.index(MD_TEST)

fig = plt.figure(figsize=(8, 6))
# One curve per experiment for the selected technique / test combination.
for si, shift in enumerate(EXPERIMENTS):
    errorfill(
        np.array(SAMPLES),
        _mean_values[si, di, mi, :],
        _std_values[si, di, mi, :],
        fmt=linestyles[si] + markers[si],
        color=colorscale(colors[si], brightness[si]),
        label="_".join([shift, DIM_RED, MD_TEST]),
    )
plt.xlabel("Number of samples from test data")
plt.ylabel(PLOT_METRIC)
if PLOT_METRIC == "P-Value":
    # The significance threshold is only meaningful on a p-value axis.
    plt.axhline(y=SIGN_LEVEL, color="k")
plt.legend()
plt.show()

# Build Model #

In [None]:
# NOTE(review): this overrides the experiment selected interactively above, and
# MODEL_PATH below replaces the LSTM path set there -- re-running the drift
# cells after this one will use these values. Confirm this is intended.
SHIFT = "hosp_type_community"
scale = True

# Split into train / validation (source) / test (target) sets with seed 1.
# NOTE(review): admin_data is rebound to the value returned by the loader, so
# this cell may not be idempotent -- confirm against import_dataset_hospital.
(X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(
    admin_data, x, y, SHIFT, OUTCOME, HOSPITAL, 1, shuffle=True
)

aggregation_type = "time_flatten"
numerical_cols = get_numerical_cols(PATH)

# Normalize data
(X_tr_normalized, y_tr), (X_val_normalized, y_val), (X_t_normalized, y_t) = normalize_data(
    aggregation_type, admin_data, TIMESTEPS, X_tr, y_tr, X_val, y_val, X_t, y_t
)

# Scale data
if scale:
    X_tr_normalized, X_val_normalized, X_t_normalized = scale_data(
        numerical_cols, X_tr_normalized, X_val_normalized, X_t_normalized
    )

# Process data
X_tr_final, X_val_final, X_t_final = process_data(
    aggregation_type, TIMESTEPS, X_tr_normalized, X_val_normalized, X_t_normalized
)

MODEL_NAME = input("Select Model: ")
MODEL_PATH = (
    PATH
    + "_".join([SHIFT, OUTCOME, aggregation_type, "_".join(HOSPITAL), MODEL_NAME])
    + ".pkl"
)

# Reuse a previously fitted model when one is cached, otherwise fit and cache it.
# Context managers make sure the file handles are closed (and the dump flushed).
if os.path.exists(MODEL_PATH):
    # NOTE: pickle can execute arbitrary code -- only load models you trust.
    with open(MODEL_PATH, "rb") as f:
        optimised_model = pickle.load(f)
else:
    optimised_model = run_model(MODEL_NAME, X_tr_final, y_tr, X_val_final, y_val)
    with open(MODEL_PATH, "wb") as f:
        pickle.dump(optimised_model, f)

### Performance on Source Data ###

In [None]:
# Positive-class probabilities on the source (validation) split.
y_pred_prob = optimised_model.predict_proba(X_val_final)[:, 1]
fig, ax = plt.subplots(figsize=(14, 6), ncols=2, nrows=1)

# Left panel: ROC curve and the area under it.
fpr, tpr, thresholds = roc_curve(y_val, y_pred_prob, pos_label=1)
roc_auc = auc(fpr, tpr)
plot_roc(ax[0], fpr, tpr, roc_auc)

# Right panel: precision-recall curve and the area under it.
precision, recall, thresholds = precision_recall_curve(y_val, y_pred_prob)
auc_pr = auc(recall, precision)
plot_pr(ax[1], recall, precision, auc_pr)

### Performance on Target Data ###

In [None]:
# Positive-class probabilities on the target (test) split.
y_pred_prob = optimised_model.predict_proba(X_t_final)[:, 1]
fig, ax = plt.subplots(figsize=(14, 6), ncols=2, nrows=1)

# Left panel: ROC curve and the area under it.
fpr, tpr, thresholds = roc_curve(y_t, y_pred_prob, pos_label=1)
roc_auc = auc(fpr, tpr)
plot_roc(ax[0], fpr, tpr, roc_auc)

# Right panel: precision-recall curve and the area under it.
precision, recall, thresholds = precision_recall_curve(y_t, y_pred_prob)
auc_pr = auc(recall, precision)
plot_pr(ax[1], recall, precision, auc_pr)

In [None]:
# `random` was previously only available if one of the star imports happened to
# export it; import it explicitly so the seeding below cannot raise NameError.
import random


def _auroc_auprc(y_true, y_score):
    """Return ``(AUROC, AUPRC)`` for binary labels and positive-class scores.

    Both areas are computed with the trapezoidal ``auc`` over the ROC curve and
    the precision-recall curve respectively.
    """
    fpr, tpr, _ = roc_curve(y_true, y_score, pos_label=1)
    precision, recall, _ = precision_recall_curve(y_true, y_score)
    return auc(fpr, tpr), auc(recall, precision)


val_auroc = []
val_auprc = []
test_auroc = []
test_auprc = []
# NOTE(review): this rebinds the RANDOM_RUNS used by the drift experiments.
RANDOM_RUNS = 10
for i in range(RANDOM_RUNS):
    random.seed(i)
    print(i)
    # Re-split the data with seed i; the already fitted model is only evaluated.
    # NOTE(review): the model was fitted on a different split, so these
    # validation sets may overlap its training data -- confirm this is acceptable.
    (X_tr, y_tr), (X_val, y_val), (X_t, y_t), feats, admin_data = import_dataset_hospital(
        admin_data, x, y, SHIFT, OUTCOME, HOSPITAL, i, shuffle=True
    )

    aggregation_type = "time_flatten"
    numerical_cols = get_numerical_cols(PATH)

    # Normalize data
    (X_tr_normalized, y_tr), (X_val_normalized, y_val), (X_t_normalized, y_t) = normalize_data(
        aggregation_type, admin_data, TIMESTEPS, X_tr, y_tr, X_val, y_val, X_t, y_t
    )

    # Scale data
    if scale:
        X_tr_normalized, X_val_normalized, X_t_normalized = scale_data(
            numerical_cols, X_tr_normalized, X_val_normalized, X_t_normalized
        )
    # Process data
    X_tr_final, X_val_final, X_t_final = process_data(
        aggregation_type, TIMESTEPS, X_tr_normalized, X_val_normalized, X_t_normalized
    )

    # Source (validation) performance for this split.
    y_pred_prob = optimised_model.predict_proba(X_val_final)[:, 1]
    roc_auc, auc_pr = _auroc_auprc(y_val, y_pred_prob)
    val_auroc.append(roc_auc)
    val_auprc.append(auc_pr)

    # Target (test) performance for this split.
    y_pred_prob = optimised_model.predict_proba(X_t_final)[:, 1]
    roc_auc, auc_pr = _auroc_auprc(y_t, y_pred_prob)
    test_auroc.append(roc_auc)
    test_auprc.append(auc_pr)

In [None]:
import scipy.stats as st

# Mean and 95% Student-t confidence interval of each metric across the runs,
# printed in the order: val AUROC, val AUPRC, test AUROC, test AUPRC.
for _scores in (val_auroc, val_auprc, test_auroc, test_auprc):
    _centre = np.mean(_scores)
    _interval = st.t.interval(
        0.95, len(_scores) - 1, loc=_centre, scale=st.sem(_scores)
    )
    print(_centre, _interval)

## Performance by Sample Size ##

In [None]:
def unison_shuffled_copies(a, b):
    """Return copies of ``a`` and ``b`` reordered by one shared random permutation.

    The permutation is drawn from NumPy's global random state, so seeding it
    beforehand makes the result reproducible. Both inputs must have the same
    length and support indexing with an integer array.
    """
    size = len(a)
    assert size == len(b)
    order = np.random.permutation(size)
    shuffled_a = a[order]
    shuffled_b = b[order]
    return shuffled_a, shuffled_b

def _subsample_scores(X_shuffled, y_shuffled, n):
    """Return ``[AUROC, average precision]`` of the model on the first ``n`` rows."""
    y_true = y_shuffled[:n]
    y_prob = optimised_model.predict_proba(X_shuffled[:n])[:, 1]
    fpr, tpr, _ = roc_curve(y_true, y_prob, pos_label=1)
    return [auc(fpr, tpr), average_precision_score(y_true, y_prob)]


# NOTE(review): this rebinds the RANDOM_RUNS used by the drift experiments.
RANDOM_RUNS = 100
# Axes: (sample size, run, split [0 = validation, 1 = test], metric [0 = AUROC,
# 1 = average precision]). -1 marks entries that have not been filled.
samp_metrics = np.full((len(SAMPLES), RANDOM_RUNS, 2, 2), -1.0)
for si, sample in enumerate(SAMPLES):
    # Iterate over ALL runs: the previous `range(0, RANDOM_RUNS - 1)` left the
    # last run at the -1 sentinel, which then biased the mean and std below.
    for i in range(RANDOM_RUNS):
        np.random.seed(i)
        X_val_shuffled, y_val_shuffled = unison_shuffled_copies(X_val_final, y_val)
        X_test_shuffled, y_test_shuffled = unison_shuffled_copies(X_t_final, y_t)

        samp_metrics[si, i, 0, :] = _subsample_scores(
            X_val_shuffled, y_val_shuffled, sample
        )
        samp_metrics[si, i, 1, :] = _subsample_scores(
            X_test_shuffled, y_test_shuffled, sample
        )

# Aggregate over the run axis once, after every entry has been filled.
mean_samp_metrics = np.mean(samp_metrics, axis=1)
std_samp_metrics = np.std(samp_metrics, axis=1)

In [None]:
# AuROC versus number of samples, one curve for the baseline (validation) data
# and one for the shifted (test) data. The first sample size is left out.
fig = plt.figure(figsize=(8, 6))
for si, shift in enumerate(["baseline", SHIFT]):
    for mi, metric in enumerate(["AuROC", "Avg Pr"]):
        # Only the AuROC curves are drawn on this figure.
        if metric != "AuROC":
            continue
        errorfill(
            np.array(SAMPLES[1:]),
            mean_samp_metrics[1:, si, mi],
            std_samp_metrics[1:, si, mi],
            fmt=linestyles[mi] + markers[mi],
            color=colorscale(colors[si], brightness[si]),
            label=shift,
        )
plt.xlabel("Number of samples from test data")
plt.ylabel("AuROC")
plt.legend()
plt.show()

## Explain Difference in Model Predictions ## 

In [None]:
import itertools 

# One prefix per timestep ("T1_", "T2_", ...), derived from TIMESTEPS so the
# column names stay in step with the preprocessing instead of assuming six.
timesteps = ["T{}_".format(step) for step in range(1, TIMESTEPS + 1)]

# Timestep-major list of "<prefix><feature>" names.
# NOTE(review): `ts + feats` relies on feats supporting element-wise string
# concatenation (e.g. a pandas Index) -- confirm its type.
flattened_feats = [name for ts in timesteps for name in ts + feats]

In [None]:
# Features whose mean |SHAP| changes by no more than this are not plotted.
SHAP_DIFF_THRESHOLD = 0.01
# Position (after ranking and filtering) of the one bar the notebook removes.
# NOTE(review): no reason is given for dropping this entry -- confirm.
DROPPED_BAR_INDEX = 5

# NOTE(review): only `Explainer` is imported explicitly from
# drift_detector.explainer; `ShiftExplainer` must come from a star import -- confirm.
explainer = ShiftExplainer(optimised_model)
explainer.get_explainer()

X_val_df = pd.DataFrame(X_val_final, columns=flattened_feats)
val_shap_values = explainer.get_shap_values(X_val_df)
X_test_df = pd.DataFrame(X_t_final, columns=flattened_feats)
test_shap_values = explainer.get_shap_values(X_test_df)

# Per-feature change in mean absolute SHAP value from validation to test data.
diff = np.mean(np.abs(test_shap_values.values), axis=0) - np.mean(
    np.abs(val_shap_values.values), axis=0
)

# Rank from largest increase to largest decrease and keep only clear changes.
# Working on a list of pairs avoids the unpacking error `zip(*...)` raised when
# no feature passed the threshold.
ranked = [
    (delta, name)
    for delta, name in sorted(zip(diff, flattened_feats), reverse=True)
    if abs(delta) > SHAP_DIFF_THRESHOLD
]
# Guarded so that fewer surviving features no longer raise IndexError.
if len(ranked) > DROPPED_BAR_INDEX:
    del ranked[DROPPED_BAR_INDEX]


def _strip_timestep_prefix(name):
    """Remove the "T1_" ... "T6_" timestep markers from a flattened feature name."""
    for prefix in ("T1_", "T2_", "T3_", "T4_", "T5_", "T6_"):
        name = name.replace(prefix, "")
    return name


diff_sorted = [delta for delta, _ in ranked]
feats_sorted = [_strip_timestep_prefix(name) for _, name in ranked]

fig, ax = plt.subplots(figsize=(12, 12))
y_pos = np.arange(len(diff_sorted))
ax.barh(y_pos, diff_sorted, align="center")
ax.set_yticks(y_pos, labels=feats_sorted)
ax.invert_yaxis()  # labels read top-to-bottom
ax.set_xlabel("Mean Difference in Shap Value")
ax.set_title("Features")
plt.show()