In [None]:
import os

# import sys
# sys.path.append("..")
import pickle

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from baseline_models.static.utils import run_model

# from gemini.constants import *
from gemini.utils import import_dataset_hospital, run_shift_experiment
from sklearn.metrics import auc, average_precision_score, roc_curve

## Parameters ##

In [None]:
# Root directory for cached pickles, trained models and CSV tables. Override it
# with the DRIFT_EXP_PATH environment variable to run outside the shared mount.
PATH = os.environ.get("DRIFT_EXP_PATH", "/mnt/nfs/project/delirium/drift_exp")
DATASET = "gemini"

# Drift-detection settings, all forwarded to run_shift_experiment.
# One result (mean / std p-value) is stored per entry of SAMPLES.
SAMPLES = [10, 20, 50, 100, 200, 500, 1000]
RANDOM_RUNS = 5
SIGN_LEVEL = 0.05
CALC_ACC = True
DR_TECHNIQUES = ["NoRed", "SRP", "PCA", "kPCA", "Isomap", "BBSDs_FFNN"]
MD_TESTS = ["LSDD", "MMD", "LK", "Classifier"]

# Experiment grid: every combination gets its own model / drift experiment.
SHIFTS = ["pre-covid", "covid", "summer", "winter", "seasonal"]
OUTCOMES = ["length_of_stay_in_er", "mortality_in_hospital"]
HOSPITALS = ["SMH", "MSH", "THPC", "THPM", "UHNTG", "UHNTW"]
MODELS = ["lr", "rf", "xgb", "mlp", "gp"]
# Forwarded to import_dataset_hospital / run_shift_experiment.
NA_CUTOFF = 0.6

## Model Fitting ##

In [None]:
# Load cached model scores, or start from scratch. Both arrays are indexed
# [shift, outcome, hospital, model, split] with split 0 = validation and
# split 1 = test; -1 marks a combination that has not been scored yet.
def _load_or_sentinel(pkl_name, shape):
    """Return the array pickled at PATH/pkl_name, or an all -1 array of `shape`.

    Only load pickles this notebook wrote itself: unpickling can execute code.
    """
    cache_file = PATH + "/" + pkl_name
    if not os.path.exists(cache_file):
        return -1 * np.ones(shape)
    with open(cache_file, "rb") as handle:
        return pickle.load(handle)


_score_shape = (len(SHIFTS), len(OUTCOMES), len(HOSPITALS), len(MODELS), 2)
shift_auc = _load_or_sentinel("shift_auc.pkl", _score_shape)
shift_pr = _load_or_sentinel("shift_pr.pkl", _score_shape)

In [None]:
def _load_or_fit_model(model_path, model_name, X_train, y_train, X_val, y_val):
    """Return the model pickled at `model_path`, fitting and caching it if absent.

    Fitting is delegated to ``run_model(model_name, X_train, y_train, X_val, y_val)``.
    Only load pickles this notebook wrote itself: unpickling can execute code.
    """
    if os.path.exists(model_path):
        with open(model_path, "rb") as f:
            return pickle.load(f)
    model = run_model(model_name, X_train, y_train, X_val, y_val)
    with open(model_path, "wb") as f:
        pickle.dump(model, f)
    return model


def _score(model, X, y):
    """Return ``(roc_auc, average_precision)`` of ``model.predict_proba(X)[:, 1]`` vs. ``y``."""
    y_prob = model.predict_proba(X)[:, 1]
    fpr, tpr, _ = roc_curve(y, y_prob, pos_label=1)
    return auc(fpr, tpr), average_precision_score(y, y_prob)


# Seasonal shifts are re-scored on every run even when cached scores exist
# (their trained models are still reused from the model pickles).
ALWAYS_RERUN_SHIFTS = ["summer", "winter", "seasonal"]

for si, SHIFT in enumerate(SHIFTS):
    for oi, OUTCOME in enumerate(OUTCOMES):
        for hi, HOSPITAL in enumerate(HOSPITALS):
            for mi, MODEL in enumerate(MODELS):
                already_scored = not np.any(shift_auc[si, oi, hi, mi, :] == -1)
                if already_scored and SHIFT not in ALWAYS_RERUN_SHIFTS:
                    continue

                print("{} | {} | {} | {}".format(SHIFT, OUTCOME, HOSPITAL, MODEL))
                (
                    (X_train, y_train),
                    (X_val, y_val),
                    (X_test, y_test),
                    feats,
                    orig_dims,
                ) = import_dataset_hospital(
                    SHIFT, OUTCOME, [HOSPITAL], NA_CUTOFF, shuffle=True
                )

                # Models are cached inside PATH. Earlier versions concatenated
                # without a separator (".../drift_exppre-covid_..._lr.pkl"), so
                # models cached that way sit beside PATH and are not reused.
                model_path = os.path.join(
                    PATH, "_".join([SHIFT, OUTCOME, HOSPITAL, MODEL]) + ".pkl"
                )
                optimised_model = _load_or_fit_model(
                    model_path, MODEL, X_train, y_train, X_val, y_val
                )

                val_roc_auc, val_avg_pr = _score(optimised_model, X_val, y_val)
                test_roc_auc, test_avg_pr = _score(optimised_model, X_test, y_test)

                # split axis: 0 = validation, 1 = test
                shift_auc[si, oi, hi, mi, :] = [val_roc_auc, test_roc_auc]
                shift_pr[si, oi, hi, mi, :] = [val_avg_pr, test_avg_pr]

In [None]:
# Persist the score arrays so the next run can skip finished combinations.
# Always overwrite: the fitting cell may have filled -1 entries or re-scored the
# seasonal shifts, and skipping the write when the file exists would drop them.
for pkl_name, scores in (("shift_auc.pkl", shift_auc), ("shift_pr.pkl", shift_pr)):
    with open(PATH + "/" + pkl_name, "wb") as f:
        pickle.dump(scores, f)

## ROC AUC ##

In [None]:
# Tabulate ROC AUC: rows = (shift, split), columns = (outcome, hospital, model).
# Always rebuilt from shift_auc (cheap) instead of read back from the CSV, so the
# table cannot show stale results from a previous run; the CSV is an export only.
auc_file = PATH + "/driftexp_auc.csv"
# (shift, outcome, hospital, model, split) -> (shift, split, outcome, hospital, model)
auc_values = np.moveaxis(shift_auc, 4, 1).reshape(
    len(SHIFTS) * 2, len(OUTCOMES) * len(HOSPITALS) * len(MODELS)
)
all_auc = pd.DataFrame(
    auc_values,
    index=pd.MultiIndex.from_product([SHIFTS, ["VAL_ROC_AUC", "TEST_ROC_AUC"]]),
    columns=pd.MultiIndex.from_product([OUTCOMES, HOSPITALS, MODELS]),
)
all_auc.to_csv(auc_file, sep="\t")
all_auc.head(n=10)

## Average Precision ##

In [None]:
# Tabulate average precision: rows = (shift, split), columns = (outcome, hospital, model).
# Always rebuilt from shift_pr (cheap) instead of read back from the CSV, so the
# table cannot show stale results from a previous run; the CSV is an export only.
pr_file = PATH + "/driftexp_pr.csv"
# (shift, outcome, hospital, model, split) -> (shift, split, outcome, hospital, model)
pr_values = np.moveaxis(shift_pr, 4, 1).reshape(
    len(SHIFTS) * 2, len(OUTCOMES) * len(HOSPITALS) * len(MODELS)
)
pr = pd.DataFrame(
    pr_values,
    index=pd.MultiIndex.from_product([SHIFTS, ["VAL_AVG_PR", "TEST_AVG_PR"]]),
    columns=pd.MultiIndex.from_product([OUTCOMES, HOSPITALS, MODELS]),
)
pr.to_csv(pr_file, sep="\t")
pr.head(n=10)

## Drift Detection ##

In [None]:
# Run shift experiments. Both result arrays are indexed
# [shift, hospital, DR technique, MD test, sample size]; -1 marks a combination
# that has no result yet and will be (re)run.
drift_shape = (
    len(SHIFTS), len(HOSPITALS), len(DR_TECHNIQUES), len(MD_TESTS), len(SAMPLES)
)
mean_pkl = PATH + "/mean_dr_md.pkl"
std_pkl = PATH + "/std_dr_md.pkl"

# Only load pickles this notebook wrote itself: unpickling can execute code.
if os.path.exists(mean_pkl):
    with open(mean_pkl, "rb") as f:
        mean_dr_md = pickle.load(f)
else:
    mean_dr_md = np.full(drift_shape, -1.0)

if os.path.exists(std_pkl):
    with open(std_pkl, "rb") as f:
        std_dr_md = pickle.load(f)
else:
    std_dr_md = np.full(drift_shape, -1.0)

# The result arrays have no outcome axis, so drift experiments use one outcome.
# This cell used to pick up OUTCOME leaked from the model-fitting loop, which is
# always OUTCOMES[-1]; that choice is now explicit.
# TODO(review): confirm this is the intended outcome.
DRIFT_OUTCOME = OUTCOMES[-1]

for si, SHIFT in enumerate(SHIFTS):
    for hi, HOSPITAL in enumerate(HOSPITALS):
        for di, DR_TECHNIQUE in enumerate(DR_TECHNIQUES):
            for mi, MD_TEST in enumerate(MD_TESTS):
                if not np.any(mean_dr_md[si, hi, di, mi, :] == -1):
                    continue
                run_label = "{} | {} | {} | {}".format(
                    SHIFT, HOSPITAL, DR_TECHNIQUE, MD_TEST
                )
                print(run_label)
                try:
                    mean_p_vals, std_p_vals = run_shift_experiment(
                        SHIFT,
                        DRIFT_OUTCOME,
                        HOSPITAL,
                        PATH,
                        DR_TECHNIQUE,
                        MD_TEST,
                        SAMPLES,
                        DATASET,
                        SIGN_LEVEL,
                        NA_CUTOFF,
                        RANDOM_RUNS,
                        calc_acc=CALC_ACC,
                    )
                    mean_dr_md[si, hi, di, mi, :] = mean_p_vals
                    std_dr_md[si, hi, di, mi, :] = std_p_vals
                except ValueError as err:
                    # Best effort: leave the -1 sentinel so this combination is
                    # retried on the next run, but report which run failed and why.
                    print("Value Error in {}: {}".format(run_label, err))

# Persist results; without this the pickles loaded above were never written.
for pkl_path, values in ((mean_pkl, mean_dr_md), (std_pkl, std_dr_md)):
    with open(pkl_path, "wb") as f:
        pickle.dump(values, f)

In [None]:
def _drift_table(values):
    """Flatten a [shift, hospital, DR, MD test, sample] array into a DataFrame.

    Rows are (Dataset, Hospital, Samples) and columns are (DR technique, MD test).
    """
    # (shift, hospital, DR, MD, sample) -> (shift, hospital, sample, DR, MD)
    flat = np.moveaxis(values, 4, 2).reshape(
        len(SHIFTS) * len(HOSPITALS) * len(SAMPLES), len(DR_TECHNIQUES) * len(MD_TESTS)
    )
    index = pd.MultiIndex.from_product(
        [SHIFTS, HOSPITALS, SAMPLES], names=["Dataset", "Hospital", "Samples"]
    )
    columns = pd.MultiIndex.from_product([DR_TECHNIQUES, MD_TESTS])
    return pd.DataFrame(flat, index=index, columns=columns)


# Always rebuilt from the result arrays (cheap) so the tables reflect the latest
# drift run; the CSVs are exports only. Previously the stds CSV was read into
# `means`, clobbering it and leaving `stds` undefined.
means_file = PATH + "/driftexp_means.csv"
stds_file = PATH + "/driftexp_stds.csv"
means = _drift_table(mean_dr_md)
stds = _drift_table(std_dr_md)
means.to_csv(means_file, sep="\t")
stds.to_csv(stds_file, sep="\t")

pd.set_option("display.max_rows", 500)
means.head(n=16)

## COVID Shifts ##

In [None]:
# Heatmap of the mean p-values for the COVID shifts. Cells still holding the -1
# "not run" sentinel are masked (NaN) so they do not stretch the colour scale.
idx = pd.IndexSlice
sns.set(font_scale=0.6)
covid_means = means.loc[idx[["pre-covid", "covid"], :, :], :]
fig, ax = plt.subplots(figsize=(12, 18))
sns.heatmap(covid_means.where(covid_means != -1), ax=ax)
ax.set_title("Mean p-values: COVID shifts", fontsize=10)
ax.set_xlabel("Drift Detection", fontsize=10)
ax.set_ylabel("Dataset", fontsize=10)
plt.show()

## Seasonal Shifts ##

In [None]:
# Heatmap of the mean p-values for the seasonal shifts. Cells still holding the
# -1 "not run" sentinel are masked (NaN) so they do not stretch the colour scale.
idx = pd.IndexSlice
sns.set(font_scale=0.6)
seasonal_means = means.loc[idx[["summer", "winter", "seasonal"], :, :], :]
fig, ax = plt.subplots(figsize=(12, 18))
sns.heatmap(seasonal_means.where(seasonal_means != -1), ax=ax)
ax.set_title("Mean p-values: seasonal shifts", fontsize=10)
ax.set_xlabel("Drift Detection", fontsize=10)
ax.set_ylabel("Dataset", fontsize=10)
plt.show()