In [None]:
# import sys
# sys.path.append("../..")
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from drift_detector.plotter import colorscale
from gemini.utils import import_dataset_hospital

## Functions ##

In [None]:
def summary_stats(X, y, label):
    """Describe features separately for controls and cases.

    ``X`` and ``y`` are joined column-wise, the rows are split on the value
    of the ``label`` column (``0`` -> controls, ``1`` -> cases), and
    ``DataFrame.describe()`` is run on each subset.

    Parameters
    ----------
    X : pandas.DataFrame
        Feature columns.
    y : pandas.DataFrame
        Frame holding the ``label`` column, row-aligned with ``X``.
    label : str
        Name of the column in ``y`` used to split the rows.

    Returns
    -------
    pandas.DataFrame
        One row per described column; the columns form a two-level index of
        (``"controls"`` / ``"cases"``, statistic name).
    """
    combined = pd.concat([X, y], axis=1)
    group_names = ["controls", "cases"]
    group_values = [0, 1]
    described = [
        combined.loc[combined[label] == value].describe() for value in group_values
    ]
    stacked = pd.concat(described, axis=0, keys=group_names)
    return stacked.T

###  Load data ###

In [None]:
# Column name given to the target frame and used by summary_stats to split
# rows into controls (0) and cases (1).
# NOTE(review): the label is "mortality" while the loader below is asked for
# "los" -- confirm these are meant to differ.
LABEL = "mortality"

# NOTE(review): the three names below look like placeholders for objects that
# are defined outside this notebook. As written the cell cannot run:
#   * HOSPITAL_ID is an empty dict, so the "SMH" / "UHNTG" lookups raise KeyError;
#   * features is an empty list, so the DataFrames get no column names;
#   * callable() called with no argument raises TypeError.
# Replace them with the real imports / definitions.
HOSPITAL_ID = dict()
features = []
plot_admin = callable()

print("Loading data...")
# The loader's result is unpacked into three (X, y) pairs plus orig_dims.
# NOTE(review): presumably source-train, source-validation and target splits,
# with the first hospital list as source and the second as target -- confirm.
(X_s_tr, y_s_tr), (X_s_val, y_s_val), (X_t, y_t), orig_dims = import_dataset_hospital(
    "los", [HOSPITAL_ID["SMH"]], [HOSPITAL_ID["UHNTG"]]
)
# Wrap the first split in DataFrames so it can be summarised and plotted.
X = pd.DataFrame(X_s_tr, columns=features)
y = pd.DataFrame(y_s_tr, columns=[LABEL])
stats_2019_tr = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

### Define Shift Detector and Parameters ###

In [None]:
# NOTE(review): placeholder -- callable() with no argument raises TypeError.
# Bind ShiftDetector to the real detector class before running this cell.
ShiftDetector = callable()

# Settings handed to ShiftDetector below.
# NOTE(review): presumably significance level, dimensionality-reduction
# technique and two-sample test name -- confirm against ShiftDetector.
sign_level = 0.05
dr_technique = "PCA"
md_test = "MMD"
red_model = None
# NOTE(review): "datset" looks like a typo for "dataset"; the name is kept
# because it is passed to ShiftDetector below.
datset = "test experiment"
samples = [10, 20, 50, 100, 200, 500, 1000, 2000]
# Index 5 selects 500; later cells slice the tested arrays to this many rows.
sample = samples[5]
# Rebinds orig_dims (previously returned by the loader) to the trailing
# dimensions of the training array.
orig_dims = X_s_tr.shape[1:]
print("Building shift detector...")
sd = ShiftDetector(dr_technique, md_test, sign_level, red_model, sample, datset)

#### Drift from 2019 to 2019 ####

In [None]:
# Run the detector with the first `sample` rows of the validation split as the
# tested data.
# NOTE(review): the training split is passed for both the first and the second
# (X, y) pair here, whereas the other cells pass the validation split second.
# Presumably deliberate because validation is the tested data -- confirm.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_tr, y_s_tr, X_s_val[:sample, :], y_s_val[:sample], orig_dims
)
print(p_val, dist)

# Summarise and plot the full validation split (not only the first `sample` rows).
X = pd.DataFrame(X_s_val, columns=features)
y = pd.DataFrame(y_s_val, columns=[LABEL])
stats_2019_val = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Drift from 2019 to 2020 ####

In [None]:
# Run the detector with the first `sample` rows of X_t / y_t as the tested data,
# passing the training and validation splits as the other two (X, y) pairs.
# Note that red_model is rebound to the value returned by the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_t[:sample, :], y_t[:sample], orig_dims
)
print(p_val, dist)

# Summarise and plot the full X_t / y_t split (not only the first `sample` rows).
X = pd.DataFrame(X_t, columns=features)
y = pd.DataFrame(y_t, columns=[LABEL])
stats_2020 = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

In [None]:
# Side-by-side comparison of the per-feature mean +/- std of cases and controls
# for each of the summarised datasets.
datasets = ["2019_tr", "2019_val", "2020"]
stats = pd.concat([stats_2019_tr, stats_2019_val, stats_2020], keys=datasets, axis=1)
fig = plt.figure(figsize=(20, 8))
brightness = [1.5, 1.25, 1.0, 0.75, 0.5]
colors = [
    "#2196f3",
    "#f44336",
    "#9c27b0",
    "#64dd17",
    "#009688",
    "#ff9800",
    "#795548",
    "#607d8b",
]
n = len(datasets)
w = 0.3  # horizontal spacing between datasets inside one feature group
# The first row of the summary is skipped everywhere below via [1:].
# NOTE(review): confirm which row this is meant to drop.
x = np.arange(0, len(stats.index.values[1:]))
for i, dataset in enumerate(datasets):
    # Spread the n datasets symmetrically around each integer feature slot.
    position = x + (w * (1 - n) / 2) + i * w
    # Cases and controls of one dataset share the marker colour; they are
    # told apart by the error-bar colour (red vs. black).
    marker_color = colorscale(colors[i], brightness[0])
    plt.errorbar(
        position,
        stats[dataset]["cases"]["mean"].values[1:],
        stats[dataset]["cases"]["std"].values[1:],
        fmt="o",
        ecolor="red",
        color=marker_color,
        label=f"{dataset} cases",
    )
    plt.errorbar(
        position + 0.15,
        stats[dataset]["controls"]["mean"].values[1:],
        stats[dataset]["controls"]["std"].values[1:],
        fmt="o",
        ecolor="black",
        color=marker_color,
        label=f"{dataset} controls",
    )
# Set the tick labels once, after the loop. The original re-issued this call on
# every iteration and only the last one took effect; `position` still holds the
# last dataset's offsets here, so the tick locations are unchanged.
plt.xticks(position - 0.2, stats.index.values[1:])
plt.tick_params(rotation=60)
plt.title("Feature mean and std by dataset (cases vs. controls)")
plt.legend()
plt.show()

#### Knockout Shift ####

In [None]:
# NOTE(review): placeholder -- callable() with no argument raises TypeError.
# Bind import_dataset_year to the real loader before running this cell.
import_dataset_year = callable()
# Reload the data: X_s_tr / X_s_val / X_t, their labels and orig_dims are all
# rebound to this loader's output from here on.
(X_s_tr, y_s_tr), (X_s_val, y_s_val), (X_t, y_t), orig_dims = import_dataset_year(
    "los", "2020"
)
# NOTE(review): presumably "<shift type>_<intensity>" -- confirm against apply_shift.
shift = "ko_shift_0.5"
# NOTE(review): placeholder -- callable() with no argument raises TypeError.
# Bind apply_shift to the real function before running this cell.
apply_shift = callable()
X_ko, y_ko = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_ko[:sample, :], y_ko[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the full shifted data.
X = pd.DataFrame(X_ko, columns=features)
y = pd.DataFrame(y_ko, columns=[LABEL])
stats_ko = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Small Gaussian Noise Shift ####

In [None]:
# NOTE(review): this cell does not reload the data first, so it works on
# whatever X_s_tr / X_s_val currently hold. If apply_shift modifies its inputs
# in place, the effect of an earlier shift would carry over -- confirm.
shift = "small_gn_shift_0.1"
X_sgn, y_sgn = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_sgn[:sample, :], y_sgn[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the full shifted data.
X = pd.DataFrame(X_sgn, columns=features)
y = pd.DataFrame(y_sgn, columns=[LABEL])
stats_sgn = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Large Gaussian Noise Shift ####

In [None]:
# NOTE(review): this cell does not reload the data first, so it works on
# whatever X_s_tr / X_s_val currently hold. If apply_shift modifies its inputs
# in place, the effect of an earlier shift would carry over -- confirm.
shift = "large_gn_shift_1.0"
X_lgn, y_lgn = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_lgn[:sample, :], y_lgn[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the full shifted data.
X = pd.DataFrame(X_lgn, columns=features)
y = pd.DataFrame(y_lgn, columns=[LABEL])
stats_lgn = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Multiway Feature Association Shift ####

In [None]:
# Reload the data so the shift is applied to freshly loaded splits; this
# rebinds X_s_tr / X_s_val / X_t, their labels and orig_dims.
(X_s_tr, y_s_tr), (X_s_val, y_s_val), (X_t, y_t), orig_dims = import_dataset_year(
    "los", "2020"
)
# NOTE(review): presumably "<shift type>_<intensity>" -- confirm against apply_shift.
shift = "mfa_shift_0.5"
X_mfa, y_mfa = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_mfa[:sample, :], y_mfa[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the full shifted data.
X = pd.DataFrame(X_mfa, columns=features)
y = pd.DataFrame(y_mfa, columns=[LABEL])
stats_mfa = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Changepoint Shift ####

In [None]:
# Reload the data so the shift is applied to freshly loaded splits; this
# rebinds X_s_tr / X_s_val / X_t, their labels and orig_dims.
(X_s_tr, y_s_tr), (X_s_val, y_s_val), (X_t, y_t), orig_dims = import_dataset_year(
    "los", "2020"
)
# NOTE(review): presumably "<shift type>_<intensity>" -- confirm against apply_shift.
shift = "cp_shift_0.75"
X_cp, y_cp = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_cp[:sample, :], y_cp[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the changepoint-shifted data. The original built these
# frames from X_mfa / y_mfa (copy-paste from the previous cell), so stats_cp
# described the wrong shift.
X = pd.DataFrame(X_cp, columns=features)
y = pd.DataFrame(y_cp, columns=[LABEL])
stats_cp = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

#### Binary Shift ####

In [None]:
# NOTE(review): this cell does not reload the data first, so it works on
# whatever X_s_tr / X_s_val currently hold. If apply_shift modifies its inputs
# in place, the effect of an earlier shift would carry over -- confirm.
shift = "large_bn_shift_1.0"
X_bn, y_bn = apply_shift(X_s_tr, y_s_tr, X_s_val, y_s_val, shift)
# Test the first `sample` rows of the shifted data with the detector.
p_val, dist, red_dim, red_model, t1_acc, t2_acc = sd.detect_data_shift(
    X_s_tr, y_s_tr, X_s_val, y_s_val, X_bn[:sample, :], y_bn[:sample], orig_dims
)
print(p_val, dist)
# Summarise and plot the full shifted data.
X = pd.DataFrame(X_bn, columns=features)
y = pd.DataFrame(y_bn, columns=[LABEL])
stats_bn = summary_stats(X, y, LABEL)
plot_admin(X, y, LABEL)

In [None]:
# Side-by-side comparison of the per-feature mean +/- std of cases and controls
# for the 2019 validation summary and two synthetically shifted datasets.
# NOTE(review): stats_ko and stats_bn are computed from data loaded with
# import_dataset_year("los", "2020"), while stats_2019_val comes from the
# hospital loader -- confirm this is the intended baseline.
datasets = ["2019_val", "ko", "bn"]
stats = pd.concat([stats_2019_val, stats_ko, stats_bn], keys=datasets, axis=1)
fig = plt.figure(figsize=(20, 8))
brightness = [1.5, 1.25, 1.0, 0.75, 0.5]
colors = [
    "#2196f3",
    "#f44336",
    "#9c27b0",
    "#64dd17",
    "#009688",
    "#ff9800",
    "#795548",
    "#607d8b",
]
n = len(datasets)
w = 0.3  # horizontal spacing between datasets inside one feature group
# The first row of the summary is skipped everywhere below via [1:].
# NOTE(review): confirm which row this is meant to drop.
x = np.arange(0, len(stats.index.values[1:]))
for i, dataset in enumerate(datasets):
    # Spread the n datasets symmetrically around each integer feature slot.
    position = x + (w * (1 - n) / 2) + i * w
    # Cases and controls of one dataset share the marker colour; they are
    # told apart by the error-bar colour (red vs. black).
    marker_color = colorscale(colors[i], brightness[0])
    plt.errorbar(
        position,
        stats[dataset]["cases"]["mean"].values[1:],
        stats[dataset]["cases"]["std"].values[1:],
        fmt="o",
        ecolor="red",
        color=marker_color,
        label=f"{dataset} cases",
    )
    plt.errorbar(
        position + 0.15,
        stats[dataset]["controls"]["mean"].values[1:],
        stats[dataset]["controls"]["std"].values[1:],
        fmt="o",
        ecolor="black",
        color=marker_color,
        label=f"{dataset} controls",
    )
# Set the tick labels once, after the loop. The original re-issued this call on
# every iteration and only the last one took effect; `position` still holds the
# last dataset's offsets here, so the tick locations are unchanged.
plt.xticks(position - 0.2, stats.index.values[1:])
plt.tick_params(rotation=60)
plt.title("Feature mean and std by dataset (cases vs. controls)")
plt.legend()
plt.show()