In [None]:
import os, sys
# Make the project root importable (presumably so the `src.*` imports below resolve).
# NOTE: "__file__" is a string literal here (a notebook has no __file__), so
# os.path.realpath("__file__") is "<cwd>/__file__" and its dirname is the cwd.
# `currentdir` is therefore the working directory; it is not used in the visible cells.
currentdir = os.path.dirname(os.path.realpath("__file__"))
# Three dirname() calls on "<cwd>/__file__" -> the directory two levels above the cwd.
parentdir = os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath("__file__"))))
sys.path.append(parentdir)

from joblib import Parallel, delayed
import numpy as np
import pickle 
import pandas as pd
import torch
import matplotlib.pyplot as plt
import torchmetrics
import seaborn as sns
import glob
from tqdm import tqdm
from copy import deepcopy
from src.transformer.metrics import AUL, CorrectedMCC, CorrectedBAcc, CorrectedF1
from src.analysis.metrics.cartesian import cartesian_jit

from torchmetrics import MatthewsCorrCoef, F1Score, Accuracy
from sklearn.utils import resample
from scipy.stats import median_abs_deviation as mad
import time
import cmcrameri.cm as cmc
import matplotlib.colors as clr
from matplotlib.ticker import MultipleLocator, LinearLocator, AutoMinorLocator
# Base matplotlib look: dotted grid, thin gray axes frame.
plt.rcParams["grid.linestyle"] =  ":"
plt.rcParams["axes.edgecolor"] = "gray"
plt.rcParams["axes.linewidth"] = 0.7

plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Arial"
#sns.set_context("notebook", font_scale=1.2)
# Imported for its side effect of registering the SciencePlots styles (e.g. "nature").
import scienceplots
# Applying a style sheet updates rcParams, so it may override some values set above.
plt.style.use(["nature"])

In [None]:
# Font sizes (points) shared by all figures in this notebook.
SMALL_SIZE, MEDIUM_SIZE, BIGGER_SIZE = 6, 6.4, 7

# Render all text through LaTeX (so a literal '%' in labels must be written '\%')
# and embed TrueType (Type 42) fonts in PDF output.
plt.rcParams["text.usetex"] = True
plt.rcParams["pdf.fonttype"] = 42

plt.rcParams.update({
    "font.size": SMALL_SIZE,          # default text size
    "axes.titlesize": SMALL_SIZE,     # axes titles
    "axes.labelsize": MEDIUM_SIZE,    # x / y axis labels
    "xtick.labelsize": SMALL_SIZE,    # x tick labels
    "ytick.labelsize": SMALL_SIZE,    # y tick labels
    "legend.fontsize": SMALL_SIZE,    # legend text
    "figure.titlesize": BIGGER_SIZE,  # figure suptitle
})

# Re-assert the sans-serif font after the style sheet applied in the previous cell.
plt.rcParams["font.family"] = "sans-serif"
plt.rcParams["font.sans-serif"] = "Arial"

# Categorical Crameri "batlow" colormap; displayed as the cell output.
cmap = cmc.batlowS
cmap

In [None]:
# Model version whose predictions are analysed; plots and caches go to a
# per-version folder. (Version "6.3" was also used at some point.)
v = "4.02"
save_path = r"../analysis/plots/{}/".format(v)

In [None]:
def load_stats(path):
    """Load the saved prediction arrays of one model run.

    ``path`` is a directory prefix that must already end with a path
    separator: the file names are appended by plain string concatenation.

    Returns:
        dict mapping "prb", "trg" and "id" to the arrays read from
        ``prb.npy``, ``trg.npy`` and ``id.npy`` respectively.
    """
    arrays = {}
    for key, filename in (("prb", "prb.npy"), ("trg", "trg.npy"), ("id", "id.npy")):
        with open(path + filename, "rb") as handle:
            arrays[key] = np.load(handle)
    return arrays
def metrics_parallel(metrics, preds, targs):
    """Evaluate the callable ``metrics`` on ``(preds, targs)`` and return ``.numpy()`` of the result.

    A plain top-level function, presumably so joblib can ship it to worker processes.
    """
    value = metrics(preds, targs)
    return value.numpy()
def aul(prb_p, prb_a):
    """Area Under the Lift (AUL).

    For each score ``p`` in ``prb_p`` (scores of the positive samples), count
    the scores in ``prb_a`` (scores of all samples) that are strictly lower,
    with ties counting one half. Normalise the total by
    ``len(prb_p) * len(prb_a)``.

    The previous loop added the tie term after the ``for`` loop, so only ties
    of the *last* positive score were counted. This version counts ties for
    every positive. It binary-searches a sorted copy of ``prb_a`` instead of
    doing an O(n_pos * n) double loop.

    Args:
        prb_p: 1-D array-like of positive-sample scores.
        prb_a: 1-D array-like of scores of all samples.

    Returns:
        np.float64 AUL (NaN when either input is empty); callers use ``.item()``.
    """
    prb_p = np.asarray(prb_p, dtype=np.float64).ravel()
    prb_a = np.sort(np.asarray(prb_a, dtype=np.float64).ravel())
    n_pos, n = prb_p.shape[0], prb_a.shape[0]
    if n_pos == 0 or n == 0:
        return np.float64(np.nan)
    # For every positive score: number of scores strictly below / at-or-below it.
    below = np.searchsorted(prb_a, prb_p, side="left")
    at_or_below = np.searchsorted(prb_a, prb_p, side="right")
    score = below.sum() + 0.5 * (at_or_below - below).sum()
    return np.float64(score / (float(n_pos) * float(n)))

def return_aul(preds, targs):
    """AUL of ``preds`` where the positives are the entries with ``targs == 1``; returns ``.item()`` of it."""
    positive_mask = targs == 1
    positive_scores = preds[positive_mask]
    return aul(positive_scores, preds).item()
def return_mcc(preds, targs):
    """Plain (uncorrected) Matthews correlation coefficient as a Python scalar.

    NOTE(review): ``MatthewsCorrCoef(num_classes=2)`` predates the ``task=``
    argument required by newer torchmetrics releases — TODO check the pinned version.
    """
    mcc = MatthewsCorrCoef(num_classes=2)
    return mcc(preds, targs).numpy().item()

def return_acc(preds, targs):
    """Macro-averaged two-class accuracy (mean per-class recall) as a Python scalar.

    Uses the older torchmetrics ``multiclass=True`` constructor signature.
    """
    accuracy = Accuracy(num_classes=2, average="macro", multiclass=True)
    return accuracy(preds, targs).numpy().item()

def return_f1(preds, targs):
    """Macro-averaged two-class F1 score as a Python scalar.

    Uses the older torchmetrics ``multiclass=True`` constructor signature.
    """
    f1 = F1Score(num_classes=2, average="macro", multiclass=True)
    return f1(preds, targs).numpy().item()

def _stratified_bootstrap_indices(targs_np, n_bootstraps):
    """List of stratified bootstrap index arrays over ``range(len(targs_np))``.

    Resample ``n`` uses ``random_state=n`` exactly as before, so the resamples
    (and hence previously reported intervals) are unchanged. Resamples that do
    not contain both classes are skipped. With stratification this presumably
    only happens when ``targs_np`` itself has fewer than two classes, in which
    case the list is empty.
    """
    ids = np.arange(0, targs_np.shape[0], 1)
    idx = []
    for n in range(n_bootstraps):
        i = resample(ids, stratify=targs_np, random_state=n)
        if len(np.unique(targs_np[i])) < 2:
            continue
        idx.append(i)
    return idx


def _summarize(point_estimate, scores, ci):
    """Build ``{"mean", "lower", "upper"}`` from a point estimate and bootstrap scores.

    ``lower``/``upper`` are the ``ci/2`` and ``1 - ci/2`` quantiles of
    ``scores``. They are NaN when there is no valid resample, e.g. a subgroup
    without positives; previously ``np.quantile`` raised on the empty array.
    """
    scores = np.asarray(scores)
    if scores.size == 0:
        return {"mean": point_estimate, "lower": np.nan, "upper": np.nan}
    return {"mean": point_estimate,
            "lower": np.quantile(scores, ci / 2),
            "upper": np.quantile(scores, 1 - ci / 2)}


def _bootstrap_corrected(metric_cls, preds, targs, n_bootstraps, ci, alpha, beta, only_scores, n_jobs):
    """Shared body of bootstrap_mcc / bootstrap_acc / bootstrap_f1.

    ``metric_cls`` is one of the project's Corrected* metric classes, always
    built with ``threshold=0.5`` and ``average="micro"``.
    """
    idx = _stratified_bootstrap_indices(targs.numpy(), n_bootstraps)

    def new_metric():
        # A fresh metric object per evaluation, as the original code did.
        return metric_cls(alpha=alpha, beta=beta, threshold=0.5, average="micro")

    executor = Parallel(n_jobs=n_jobs)
    tasks = (delayed(metrics_parallel)(new_metric(), preds[i], targs[i]) for i in idx)
    scores = np.array(executor(tasks))
    if only_scores:
        return scores
    # "mean" is the point estimate on the full (non-resampled) dataset.
    return _summarize(new_metric()(preds, targs).numpy().item(), scores, ci)


def bootstrap_mcc(preds, targs, seed: int = 2021, n_bootstraps: int = 1000, ci: float = 0.05, alpha: float = 0.025, beta=1.0, only_scores = False, n_jobs: int = 7):
    """Corrected MCC with a stratified percentile-bootstrap confidence interval.

    Args:
        preds: torch tensor of predicted probabilities.
        targs: torch long tensor of 0/1 targets.
        seed: unused (resample ``n`` always uses ``random_state=n``); kept for
            interface compatibility.
        n_bootstraps: number of resamples drawn.
        ci: 1 - confidence level (0.05 -> 95% interval).
        alpha, beta: forwarded to ``CorrectedMCC``.
        only_scores: return the raw array of bootstrap scores instead.
        n_jobs: joblib worker count.

    Returns:
        ``{"mean", "lower", "upper"}`` (bounds NaN when no resample contains
        both classes), or the score array when ``only_scores``.
    """
    return _bootstrap_corrected(CorrectedMCC, preds, targs, n_bootstraps, ci, alpha, beta, only_scores, n_jobs)


def bootstrap_acc(preds, targs, seed: int = 2021, n_bootstraps: int = 1000, ci: float = 0.05, alpha: float = 0.025, beta=1.0, only_scores = False, n_jobs: int = 7):
    """Corrected balanced accuracy with a stratified bootstrap CI (arguments as in ``bootstrap_mcc``)."""
    return _bootstrap_corrected(CorrectedBAcc, preds, targs, n_bootstraps, ci, alpha, beta, only_scores, n_jobs)


def bootstrap_f1(preds, targs, seed: int = 2021, n_bootstraps: int = 1000, ci: float = 0.05, alpha: float = 0.025, beta=1.0, only_scores = False, n_jobs: int = 7):
    """Corrected F1 score with a stratified bootstrap CI (arguments as in ``bootstrap_mcc``)."""
    return _bootstrap_corrected(CorrectedF1, preds, targs, n_bootstraps, ci, alpha, beta, only_scores, n_jobs)


def bootstrap_aul(preds, targs, seed: int = 2021, ci: float = 0.05, n_bootstraps: int = 1000, only_scores = False, n_jobs: int = 7):
    """AUL with a stratified percentile-bootstrap CI.

    Unlike the other bootstrap_* helpers this takes NumPy arrays (``preds``
    scores, ``targs`` 0/1 targets). Note the parameter order: ``ci`` comes
    before ``n_bootstraps``. ``seed`` is unused (kept for compatibility).

    Returns:
        ``{"mean", "lower", "upper"}`` or, with ``only_scores``, the score array.
    """
    idx = _stratified_bootstrap_indices(np.asarray(targs), n_bootstraps)
    executor = Parallel(n_jobs=n_jobs)
    tasks = (delayed(return_aul)(preds[i], targs[i]) for i in idx)
    scores = np.array(executor(tasks))
    if only_scores:
        return scores
    # "mean" is the point estimate on the full dataset.
    return _summarize(return_aul(preds, targs), scores, ci)


def return_stats(path: str):
    """Load one model's predictions from ``path`` and attach bootstrap metrics.

    Adds "aul", "mcc", "acc" and "f1" to the dict returned by ``load_stats``.
    Each is a ``{"mean", "lower", "upper"}`` dict, which is what the summary
    and plot cells index (``data[key]["aul"]["mean"]``...). The AUL used to be
    stored as a bare float: its bootstrap call had been commented out, which
    broke those cells for freshly computed results.

    Prints the elapsed time and result of each metric.

    NOTE(review): a later cell redefines ``return_stats`` as a plotting
    function with a different signature. Re-run this cell before calling it
    again once that cell has executed.
    """
    x = load_stats(path)
    prb = torch.from_numpy(x["prb"])
    trg = torch.from_numpy(x["trg"]).long()
    corrected_kwargs = {"alpha": 0.025, "beta": 1.0}

    # (result key, printed name, deferred computation)
    steps = (
        ("aul", "AUL", lambda: bootstrap_aul(preds=x["prb"], targs=x["trg"], n_bootstraps=5000)),
        ("mcc", "MCC", lambda: bootstrap_mcc(preds=prb, targs=trg, n_bootstraps=5000, **corrected_kwargs)),
        ("acc", "ACC", lambda: bootstrap_acc(preds=prb, targs=trg, n_bootstraps=5000, **corrected_kwargs)),
        ("f1", "F1", lambda: bootstrap_f1(preds=prb, targs=trg, n_bootstraps=5000, **corrected_kwargs)),
    )
    for key, name, compute in steps:
        start = time.time()
        x[key] = compute()
        print("%s is done: %.2f s" % (name, time.time() - start))
        print(x[key])
    return x

def contains_in_sequence(sample, min_, max_):
    """Return True when at least one token of ``sample`` lies in the closed range [min_, max_]."""
    in_range = (sample >= min_) & (sample <= max_)
    return bool(np.any(in_range))

In [None]:
# Quick look at the size of the "rnn" prediction array. This cell sits above the
# cell that creates `data` (and the one that fills "rnn"), so the bare expression
# raised NameError on a fresh kernel; only inspect once it has been loaded.
if "data" in globals() and "rnn" in data:
    print(data["rnn"]["prb"].shape)

In [None]:
# Model name -> loaded predictions and metric results; filled by the cells below.
data = dict()

In [None]:
%%time
# Load the l2v predictions for model version `v` and compute their metrics via return_stats.
# NOTE(review): in this r-string every "\\" stays two characters, so the path has
# doubled separators — TODO build it with os.path.join.
data["l2v"] = return_stats(r"...\\predictions\\v15\\cls\\eos_l2v\\%s\\"%v)

In [None]:
# "rnn" model: results directory eos_rnn, version 1.0.
data["rnn"] = return_stats(path=r"...\\predictions\\v15\\cls\\eos_rnn\\1.0\\")

In [None]:
# "nn" model: tabular results directory eos_tab, version 1.0.
data["nn"] = return_stats(path=r"...\\predictions\\v15\\tcls\\eos_tab\\1.0\\")

In [None]:
# "logistic" model: tabular results directory eos_tab, version 3.1.
data["logistic"] = return_stats(path=r"...\\predictions\\v15\\tcls\\eos_tab\\3.1\\")

In [None]:
# "table" model: tabular results directory eos_tab, version 3.2.
data["table"] = return_stats(path=r"...\\predictions\\v15\\tcls\\eos_tab\\3.2\\")

In [None]:
# Random baseline: uniform scores in [0, 1) for the same targets as the "rnn" model.
simple_baselines = True
if simple_baselines:
    np.random.seed(0)
    rnd_prb = np.random.uniform(size=data["rnn"]["trg"].shape[0])
    rnd_trg = data["rnn"]["trg"]
    data["rnd"] = {"prb": rnd_prb, "trg": rnd_trg}
    # Bootstrapped AUL, so data["rnd"]["aul"] has the same {"mean", "lower", "upper"}
    # structure the summary/plot cells index (it used to be a bare float).
    data["rnd"]["aul"] = bootstrap_aul(preds=rnd_prb, targs=rnd_trg, n_bootstraps=1000)
    rnd_prb_t = torch.from_numpy(rnd_prb)
    rnd_trg_t = torch.from_numpy(rnd_trg).long()
    for metric_key, bootstrap_fn in (("mcc", bootstrap_mcc), ("acc", bootstrap_acc), ("f1", bootstrap_f1)):
        data["rnd"][metric_key] = bootstrap_fn(preds=rnd_prb_t, targs=rnd_trg_t,
                                               n_bootstraps=1000, alpha=0.025, beta=1.0)


In [None]:
# "mjr" baseline: scores uniform in [0, 0.5), i.e. every score is below 0.5, so
# the thresholded metrics see an all-negative prediction. Point estimates only.
np.random.seed(0)
mjr_prb = np.random.uniform(high=0.5, size=data["rnn"]["trg"].shape[0])
mjr_trg = data["rnn"]["trg"]
data["mjr"] = {"prb": mjr_prb, "trg": mjr_trg}


def _point_only(value):
    """Wrap a point estimate in the {"mean", "lower", "upper"} metric structure (no CI)."""
    return {"mean": value, "lower": np.nan, "upper": np.nan}


# AUL wrapped like the other metrics (it was a bare float, which broke the
# summary/plot cells that index data[key]["aul"]["mean"]).
data["mjr"]["aul"] = _point_only(return_aul(preds=mjr_prb, targs=mjr_trg))
mjr_prb_t = torch.from_numpy(mjr_prb)
mjr_trg_t = torch.from_numpy(mjr_trg).long()
# NOTE(review): these use the plain torchmetrics metrics, not the Corrected* variants.
data["mjr"]["mcc"] = _point_only(return_mcc(preds=mjr_prb_t, targs=mjr_trg_t))
data["mjr"]["acc"] = _point_only(return_acc(preds=mjr_prb_t, targs=mjr_trg_t))
data["mjr"]["f1"] = _point_only(return_f1(preds=mjr_prb_t, targs=mjr_trg_t))

In [None]:
# Metric cache. The original forced a reload by raising the undefined name `Error`
# inside a bare try/except; the same choice is made explicit here.
RECOMPUTE_METRICS = False  # True: overwrite metric.pkl with the results computed above
metric_cache_path = save_path + "metric.pkl"
if RECOMPUTE_METRICS:
    with open(metric_cache_path, "wb") as f:
        pickle.dump(data, f)
else:
    # NOTE: pickle.load can execute arbitrary code; only load trusted cache files.
    with open(metric_cache_path, "rb") as f:
        data = pickle.load(f)

In [None]:
# Text summary: for each metric, every model's point estimate and CI bounds.
for metric_key, heading in (("aul", "AUL"), ("mcc", "MCC"), ("acc", "ACC"), ("f1", "F1")):
    print(heading)
    for model_name, model_stats in data.items():
        summary = model_stats[metric_key]
        print("\t%s: %.3f [%.3f, %.3f]" % (model_name, summary["mean"],
                                            summary["lower"],
                                            summary["upper"]))

In [None]:
# MCC bar chart: corrected MCC per model (point estimate on the full data) with
# asymmetric bootstrap-CI error bars.
mean = list()
quantiles = list()
for key in data.keys():
    if key == "mjr":
        # The "mjr" baseline has no CI (its lower/upper are NaN); it is plotted
        # as 0 with no error bar, regardless of its stored mean.
        mean.append(0)
        quantiles.append([0, 0])
    else:
        mean.append(data[key]["mcc"]["mean"])
        # Error-bar extents = distances from the point estimate to each CI bound.
        quantiles.append([np.abs(mean[-1] - data[key]["mcc"]["lower"]),  np.abs(mean[-1] - data[key]["mcc"]["upper"])])

fig, ax = plt.subplots(figsize=(5,7))
# yerr has shape (2, n_models): row 0 = downward extent, row 1 = upward extent.
plt.bar(x = data.keys(), height = mean, width=0.5,
        yerr=np.array(quantiles).T, capsize=4, 
        edgecolor="none", facecolor="silver", ecolor="black")

# plt.errorbar(x = list(data.keys()), y = mean, yerr = np.array(quantiles).T, fmt="o",
#                      capsize=5, ecolor="dimgray", ms=3.5,
#                      elinewidth=2, mfc="black", mec="black")
ax.set_ylabel("MCC Score")
ax.set_xlabel("Model")
#ax.set_title("Mortality Prediction: Corrected Matthews Correlation Coefficient with 95%-CI")
ax.set_ylim( -0.1 , 0.5)
# Dotted reference line at MCC = 0.
ax.axhline(0.0, color="gray", linewidth=0.5, linestyle= ":")
#ax.tick_params(which="both", width=2, length =2)
ax.tick_params(axis= "y", which="major", width=1, length = 6, direction="out", color="gray")
ax.tick_params(axis= "y", which="minor", width=1, length =3, direction="out", color="gray")
# Hide x ticks (categorical model names only).
ax.tick_params(axis= "x", which="both", width=0, length =0)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
sns.despine()
plt.savefig(save_path + "mcc_mortality.svg", format="svg")
plt.show()

In [None]:
# Bar chart: corrected accuracy per model (point estimate on the full data) with
# asymmetric bootstrap-CI error bars.
mean = list()
quantiles = list()
for key in data.keys():
    summary = data[key]["acc"]
    mean.append(summary["mean"])
    # Error-bar extents = distances from the point estimate to each CI bound.
    quantiles.append([np.abs(mean[-1] - summary["lower"]),  np.abs(mean[-1] - summary["upper"])])

fig, ax = plt.subplots(figsize=(5,7))
plt.bar(x = list(data.keys()), height = mean, width=0.5,
        yerr=np.array(quantiles).T, capsize=4,
        edgecolor="none", facecolor="silver", ecolor="black")
ax.set_ylabel("Accuracy")
ax.set_xlabel("Model")
# text.usetex is on: a bare '%' starts a LaTeX comment and swallows the closing
# braces matplotlib appends, so TeX fails. Escape it.
ax.set_title(r"Mortality Prediction: Corrected Accuracy with 95\%-CI")
ax.tick_params(axis= "y", which="major", width=1, length = 6, direction="out", color="gray")
ax.tick_params(axis= "y", which="minor", width=1, length =3, direction="out", color="gray")
ax.tick_params(axis= "x", which="both", width=0, length =0)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
sns.despine()
plt.savefig(save_path + "acc_mortality.svg", format="svg")
plt.show()

In [None]:
# Bar chart: corrected F1 score per model (point estimate on the full data) with
# asymmetric bootstrap-CI error bars.
mean = list()
quantiles = list()
for key in data.keys():
    summary = data[key]["f1"]
    mean.append(summary["mean"])
    # Error-bar extents = distances from the point estimate to each CI bound.
    quantiles.append([np.abs(mean[-1] - summary["lower"]),  np.abs(mean[-1] - summary["upper"])])

fig, ax = plt.subplots(figsize=(5,7))
plt.bar(x = list(data.keys()), height = mean, width=0.5,
        yerr=np.array(quantiles).T, capsize=4,
        edgecolor="none", facecolor="silver", ecolor="black")
ax.set_ylabel("F1-Score")
ax.set_xlabel("Model")
# text.usetex is on: an unescaped '%' would comment out the rest of the TeX source.
ax.set_title(r"Mortality Prediction: Corrected F1-Score with 95\%-CI")
ax.tick_params(axis= "y", which="major", width=1, length = 6, direction="out", color="gray")
ax.tick_params(axis= "y", which="minor", width=1, length =3, direction="out", color="gray")
ax.tick_params(axis= "x", which="both", width=0, length =0)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
sns.despine()
plt.savefig(save_path + "f1_mortality.svg", format="svg")
plt.show()

In [None]:
# Bar chart: AUL per model with asymmetric bootstrap-CI error bars.
mean = list()
quantiles = list()
for key in data.keys():
    summary = data[key]["aul"]
    if not isinstance(summary, dict):
        # Some baselines stored a bare AUL float; treat it as a point estimate without CI.
        summary = {"mean": summary, "lower": np.nan, "upper": np.nan}
    mean.append(summary["mean"])
    # Error-bar extents = distances from the point estimate to each CI bound.
    quantiles.append([np.abs(mean[-1] - summary["lower"]),  np.abs(mean[-1] - summary["upper"])])

fig, ax = plt.subplots(figsize=(5,7))
plt.bar(x = list(data.keys()), height = mean,  width=0.5,
        yerr=np.array(quantiles).T, capsize=4,
        edgecolor="none", facecolor="silver", ecolor="black")
ax.set_ylabel("AUL Score")
ax.set_xlabel("Model")
# text.usetex is on: an unescaped '%' would comment out the rest of the TeX source.
ax.set_title(r"Mortality Prediction: Area Under the Lift with 95\%-CI")
ax.tick_params(axis= "y", which="major", width=1, length = 6, direction="out", color="gray")
ax.tick_params(axis= "y", which="minor", width=1, length =3, direction="out", color="gray")
ax.tick_params(axis= "x", which="both", width=0, length =0)

ax.yaxis.set_major_locator(MultipleLocator(0.1))
ax.yaxis.set_minor_locator(AutoMinorLocator(5))
sns.despine()

plt.savefig(save_path + "aul_mortality.svg", format="svg")
plt.show()

# 2. Breakdown by Unitary Groups

In [None]:
# Birth-date cohorts. Period k spans 1 Jan of years[k] to 31 Dec of years[k+1], so
# consecutive periods overlap by one calendar year. After the reverse, index 0 is
# the youngest cohort and is checked first, so an overlapping year goes to the
# younger cohort.
years = [1951, 1956, 1961, 1966, 1971, 1976, 1982]
# Explicit day-first format: without it "31/12/YYYY" is parsed by inference
# (pandas may warn and fall back to day-first).
period = [(pd.to_datetime("01/01/%s" % a, format="%d/%m/%Y"),
           pd.to_datetime("31/12/%s" % b, format="%d/%m/%Y"))
          for a, b in zip(years[:-1], years[1:])]
period.reverse()


def discrete_age(x):
    """Index of the first period in ``period`` whose closed range contains date ``x``.

    Returns None when ``x`` falls outside every period.
    """
    for i, (start, end) in enumerate(period):
        if start <= x <= end:
            return i
    return None

In [None]:
# Per-person sequence lengths for the l2v run (presumably a pickled dict: person id -> length).
with open(r"...\\predictions\\v15\\cls\\eos_l2v\\%s\\" %v + "seqlen.pkl", "rb") as f:
    counts = pickle.load(f)

counts = (
    pd.DataFrame.from_dict(counts, orient="index")
    .rename_axis("PERSON_ID")
    .rename(columns={0: "SEQLEN"})
)
# Cap lengths at 2560 (NOTE(review): presumably the model's maximum sequence length).
counts["SEQLEN"] = counts["SEQLEN"].apply(lambda length: np.minimum(length, 2560))


In [None]:
# Per-person number of health records for the l2v run (presumably a pickled dict: person id -> count).
with open(r"...\\predictions\\v15\\cls\\eos_l2v\\%s\\" %v + "has_health.pkl", "rb") as f:
    has_health = pickle.load(f)

has_health = (
    pd.DataFrame.from_dict(has_health, orient="index")
    .rename_axis("PERSON_ID")
    .rename(columns={0: "HAS_HEALTH"})
)


In [None]:
# One row per person: l2v prediction joined with population data, sequence
# length and health-record count, plus derived grouping columns.
ppl = pd.read_csv(".../processed/populations/survival/population/result.csv").set_index("PERSON_ID")
ppl["EVENT_FINAL_DATE"] = pd.to_datetime(ppl["EVENT_FINAL_DATE"], format="%Y-%m-%d")
result = pd.DataFrame({"PERSON_ID": data["l2v"]["id"].astype(int),
                        "PRED": data["l2v"]["prb"]}).set_index("PERSON_ID")
# NOTE(review): each dropna() drops rows with a NaN in *any* column, not only in the joined ones.
result = result.join(ppl, how="left").dropna()
result = result.join(counts, how="left").dropna()
result = result.join(has_health, how="left").dropna()

# Health-record counts -> quantile bin codes 0..4 (raw counts kept in HAS_HEALTH_RAW).
result["HAS_HEALTH_RAW"] = result["HAS_HEALTH"].values
hh_q = np.quantile(result["HAS_HEALTH"].values, [0, 0.25, 0.5, 0.75, 0.9,  1.])
hh_q[1] = 1  # presumably keeps the first edges distinct when the lower quantiles are 0
# NOTE(review): pd.cut bins are right-closed, so with include_lowest bin 0 is [0, 1]
# and contains people with one record; the '[0]', '[1,5)' ... labels used later
# read as left-closed. Confirm the labels against hh_q.
result["HAS_HEALTH"] = pd.cut(result["HAS_HEALTH"], bins = hh_q,
                               include_lowest=True, labels=False)
# Sequence lengths -> tercile bin codes 0..2 (same right-closed caveat for the labels).
ec_q = np.quantile(result["SEQLEN"].values, [0, 0.33, 0.66, 1.])
result["SEQLEN"] = pd.cut(result["SEQLEN"], bins = ec_q,
                               include_lowest=True, labels=False)
result["BIRTHDAY"] = pd.to_datetime(result["BIRTHDAY"], format="%Y-%m-%d")
result["AGE_GROUP"] = result["BIRTHDAY"].apply(lambda x: discrete_age(x))
# Unlabeled = not marked deceased but the final event date falls before 2020-12-31.
# Vectorised (was a row-wise apply(axis=1), which is slow and returns a DataFrame
# rather than a Series when `result` is empty).
result["UNLABELED"] = (result["TARGET"] == 0) & (
    result["EVENT_FINAL_DATE"] < pd.to_datetime("2020-12-31", format="%Y-%m-%d"))
result["YEAR"] = result["EVENT_FINAL_DATE"].dt.year

In [None]:
# Median (+/- one MAD) number of health records per age group.
# Width limit defined here too: this cell ran before the cell that defines it,
# which raised NameError on a fresh kernel.
max_width_in_mm = 180
max_width_in_inches = max_width_in_mm / 25.4
age_group_names = ["(34,39]", "(39, 44]", "(44, 49]", "(49, 54]", "(54, 59]",  "(59, 64]"]
df_ = result.groupby(["AGE_GROUP"])["HAS_HEALTH_RAW"].agg(["median", mad])
# Label bars by the AGE_GROUP codes actually present, so a missing group cannot
# shift every label or mismatch the number of bars.
labels_ = [age_group_names[int(code)] for code in df_.index]
figsize = (max_width_in_inches, max_width_in_inches/2)
fig = plt.figure(figsize= figsize)
# The agg column for `mad` is named after the function (median_abs_deviation).
plt.bar(labels_, df_["median"], yerr= df_["median_abs_deviation"], capsize=5, width=0.5, label = "median with +/- one Median Abs. Dev.")
plt.xlabel("Age Group")
plt.ylabel("Median Number of Health Records")
plt.title("Number of Health Records per Age Group")
plt.tick_params(which="both", right=False, top=False)
plt.tick_params(which="minor", bottom=False)

plt.legend()
sns.despine()
plt.tight_layout()
plt.savefig(save_path + "health_per_age.pdf", format="pdf")

In [None]:
def bootstrap_simple(x, n_samples = 1000, rng=None):
    """Bootstrap the mean of a numeric (typically 0/1) array.

    Args:
        x: 1-D array; with 0/1 hits of the deceased, the mean is the recall.
        n_samples: number of bootstrap resamples.
        rng: optional object with a NumPy-style ``choice`` method (e.g.
            ``np.random.RandomState(0)``). Defaults to the global ``np.random``
            state, as before, so seed it for reproducible results.

    Returns:
        ``(means, N)``: the list of ``n_samples`` resampled means and ``N = len(x)``.
        For empty ``x`` the means are NaN (previously computed as 0/0 with
        RuntimeWarnings).
    """
    x = np.asarray(x)
    N = x.shape[0]
    if N == 0:
        return [np.nan] * n_samples, N
    rng = np.random if rng is None else rng
    out = [x[rng.choice(N, size=N, replace=True)].sum() / N for _ in range(n_samples)]
    return out, N
        
        
# For each death year: bootstrap recall of the thresholded (> 0.5) l2v
# predictions among the people who actually died that year.
py_label = [2016, 2017, 2018, 2019]
py_recall, py_n_tp, py_n_total = [], [], []
for year in py_label:
    deceased = result[(result["YEAR"] == year) & (result["TARGET"] == 1)]
    hits = (deceased["PRED"].values > 0.5).astype(int)
    boot_means, n_deceased = bootstrap_simple(hits)
    py_recall.append((np.mean(boot_means), np.std(boot_means)))
    py_n_tp.append(hits.sum())
    py_n_total.append(n_deceased)

In [None]:
# Deceased people per death year: left = bootstrapped recall (mean +/- std),
# right = counts of all deceased vs. those flagged by the model.
max_width_in_mm = 180
max_width_in_inches = max_width_in_mm / 25.4
fig, ax = plt.subplots(1,2, figsize=(max_width_in_inches, max_width_in_inches/2))
mu = [m for m, _ in py_recall]
std = [s for _, s in py_recall]
# Vertical bars: the bootstrap std is a y error. It was passed as xerr, which drew
# horizontal whiskers `std` years wide.
ax[0].bar(x=py_label, height=mu, yerr=std, width=0.5, label="recall value")
ax[0].set_xlabel("Year of death")
ax[0].set_ylabel("Fraction of correctly identified deceased people")
ax[0].set_title("Recall for deceased people")
ax[0].set_ylim([0, 1.])
ax[0].tick_params(axis= "x", which="major", width= 1, length = 6, direction="out", color="gray")
ax[0].tick_params(axis= "y", which="minor", width= 1, length = 3, direction="out", color="gray")
ax[0].yaxis.set_major_locator(MultipleLocator(0.1))
ax[0].yaxis.set_minor_locator(AutoMinorLocator(5))
ax[0].xaxis.set_major_locator(MultipleLocator(1))
ax[0].legend()
ax[0].tick_params(which="both", top=False, right=False)
ax[0].tick_params(which="minor", bottom=False)

# py_n_total counts everyone who died in the year (targets, not predictions);
# py_n_tp counts those with PRED > 0.5. Both are counts, not fractions.
ax[1].bar(x=[y - 0.17 for y in py_label], height=py_n_total, width=0.3, label="all deceased")
ax[1].bar(x=[y + 0.17 for y in py_label], height=py_n_tp, width=0.3, label = "correctly predicted deceased")
ax[1].set_xlabel("Year of death")
ax[1].set_ylabel("Number of deceased people")
ax[1].set_title("Predictions for deceased people")
ax[1].tick_params(axis= "x", which="major", width= 1, length = 6, direction="out", color="gray")
ax[1].tick_params(axis= "y", which="minor", width= 1, length = 3, direction="out", color="gray")
ax[1].xaxis.set_major_locator(MultipleLocator(1))
ax[1].yaxis.set_minor_locator(AutoMinorLocator(5))
ax[1].tick_params(which="both", top=False, right=False)
ax[1].tick_params(which="minor", bottom=False)
ax[1].legend()
sns.despine()
plt.tight_layout()
plt.savefig(save_path + "yearly_mortality_performance.pdf", format="pdf")
plt.show()

In [None]:
# Export the l2v predictions with readable subgroup labels as a CSV.
dfm = result.copy()[["TARGET", "PRED", "RES_ORIGIN", "GENDER", "SEQLEN", "HAS_HEALTH", "AGE_GROUP", "UNLABELED"]]
dfm["TARGET"] = dfm["TARGET"].astype(bool)
# NOTE(review): threshold here is >= 0.5, while the yearly recall cell uses > 0.5.
dfm["PRED"] = (dfm["PRED"] >=0.5).astype(bool)
dfm = dfm.rename(columns={"TARGET": "true_label (deceased)", 
                    "RES_ORIGIN": "res_status",
                    "PRED": "l2v_pred", 
                    "GENDER": "sex",
                    "SEQLEN": "num_events",
                    "HAS_HEALTH": "num_health_events",
                    "AGE_GROUP": "age",
                    "UNLABELED": "unlabeled" })
# Bin codes -> human-readable labels. These labels are hard-coded and must match
# the quantile edges computed when `result` was built.
dfm["age"] = dfm["age"].replace({0:'(34,39]', 1:'(39, 44]', 2:'(44, 49]', 3:'(49, 54]', 4:'(54, 59]', 5:'(59, 64]'})
dfm["num_events"] = dfm["num_events"].replace({0:'[48, 921)', 1:'[921, 1176)', 2:'[1176, 2560]'})
# NOTE(review): bin 4 is labelled '>18' here but '>=19' in `groups` and in the plots.
dfm["num_health_events"] = dfm["num_health_events"].replace({0:'[0]', 1:'[1,5)', 2:'[5,11)', 3: '[11,19)', 4:'>18'})
# Shuffle rows with the global RNG seeded at 42 and drop the PERSON_ID index.
np.random.seed(42)
dfm = dfm.sample(frac=1).reset_index(drop=True)
dfm.to_csv(save_path + "l2v_mortality_predictions_with_groups.csv")
dfm.head()

In [None]:
# Boolean masks over `result` for every unitary (single-attribute) subgroup,
# keyed by the label used in the plots. Bin codes map to labels by position.
_age_labels = ["(34,39]", "(39, 44]", "(44, 49]", "(49, 54]", "(54, 59]", "(59, 64]"]
_seqlen_labels = ["[48, 921)", "[921, 1176)", "[1176, 2560]"]
_health_labels = ["[0]", "[1,5)", "[5,11)", "[11,19)", ">=19"]
groups = {
    "male": result["GENDER"] == "M",
    "female": result["GENDER"] == "F",
    **{label: result["AGE_GROUP"] == code for code, label in enumerate(_age_labels)},
    "DK": result["RES_ORIGIN"] == "DK",
    "NON-DK": result["RES_ORIGIN"] == "NON_DK",
    **{label: result["SEQLEN"] == code for code, label in enumerate(_seqlen_labels)},
    **{label: result["HAS_HEALTH"] == code for code, label in enumerate(_health_labels)},
}

In [None]:
# Display the subgroup names defined in `groups`.
groups.keys()

In [None]:
# Scratch check: corrected MCC of the l2v predictions over the whole `result` table.
# NOTE(review): `condition` (HAS_HEALTH bin 4) is computed but never applied, and
# `alpha` (share of UNLABELED rows) is computed while the metric below uses the
# fixed alpha=0.025 — confirm which was intended.
condition = result["HAS_HEALTH"] == 4
preds = result["PRED"].values
targs = result["TARGET"].values
alpha = result["UNLABELED"].mean()
metric = CorrectedMCC(alpha = 0.025, beta= 1., threshold = 0.5, average="micro")
metric(torch.Tensor(preds),torch.IntTensor(targs))

In [None]:
# Same corrected-MCC evaluation as the previous cell, on the current global preds/targs.
metric = CorrectedMCC(threshold=0.5, average="micro", alpha=0.025, beta=1.)
metric(torch.Tensor(preds), torch.IntTensor(targs))

In [None]:
# Per unitary subgroup: bootstrapped corrected MCC, median/MAD of the predicted
# probability, and total / deceased / unlabeled counts.
unitary_mcc, unitary_prb, unitary_cnt = {}, {}, {}

for k, condition in groups.items():
    members = result[condition]
    preds = members["PRED"].values
    targs = members["TARGET"].values
    unitary_mcc[k] = bootstrap_mcc(torch.from_numpy(preds).float(), torch.from_numpy(targs).long())
    unitary_prb[k] = {"median": np.median(preds), "mad": mad(preds)}
    unitary_cnt[k] = {"total": preds.shape[0], "positive": np.sum(targs),
                      "unlabeled": members["UNLABELED"].values.sum()}
    print("Done:", k)

In [None]:
def _style_subgroup_axis(axis):
    """Shared panel styling: gray outward x ticks, hidden y ticks, none on top/right."""
    axis.tick_params(axis="x", which="major", width=1, length=6, direction="out", color="gray")
    axis.tick_params(axis="x", which="minor", width=1, length=3, direction="out", color="gray")
    axis.tick_params(axis="y", which="both", width=0, length=0)
    axis.tick_params(which="both", top=False, right=False)


def _plot_median_prb(axis, labels, prb_dict, title, padding):
    """Panel: median predicted probability of death (+/- one MAD) per group."""
    mu = [prb_dict[label]["median"] for label in labels]
    spread = [prb_dict[label]["mad"] for label in labels]
    axis.errorbar(y=labels, x=mu, xerr=spread, fmt="o",
                  capsize=5, ecolor="dimgray", ms=3.5,
                  elinewidth=2, mfc="black", mec="black", label=r"median with +/- one Median Abs. Dev.")
    axis.set_xlabel("Median Probability of Death")
    axis.set_ylabel("%s" % title)
    axis.set_title("Median Probability of Death per %s" % title)
    _style_subgroup_axis(axis)
    axis.xaxis.set_major_locator(MultipleLocator(0.1))
    axis.xaxis.set_minor_locator(AutoMinorLocator(5))
    axis.set_ylim(-padding, len(labels) - 1 + padding)
    axis.set_xlim([0, 1.])
    axis.legend()


def _plot_deceased_counts(axis, labels, count_dict, title):
    """Panel: number of deceased people per group."""
    positives = [count_dict[label]["positive"] for label in labels]
    axis.barh(y=labels, height=0.5, width=positives, edgecolor="none", facecolor="silver",
              label="Number of Deceased")
    axis.set_xlabel("Number of Deceased")
    axis.set_ylabel("%s" % title)
    axis.set_title("Number of Deceased People per %s" % title)
    _style_subgroup_axis(axis)
    axis.xaxis.set_major_locator(MultipleLocator(500))
    axis.xaxis.set_minor_locator(AutoMinorLocator(5))
    axis.legend()


def _plot_population(axis, labels, count_dict, title):
    """Panel: population per group as overlaid bars (total, deceased+unlabeled, deceased)."""
    totals = [count_dict[label]["total"] for label in labels]
    positive_or_unlabeled = [count_dict[label]["positive"] + count_dict[label]["unlabeled"] for label in labels]
    positives = [count_dict[label]["positive"] for label in labels]
    # Bars start at 0 and are drawn back-to-front, so each visible segment is one category.
    axis.barh(y=labels, height=0.5, label="alive", width=totals, facecolor="silver", edgecolor="none")
    axis.barh(y=labels, height=0.5, label="unlabeled", width=positive_or_unlabeled, facecolor="orange", edgecolor="none")
    axis.barh(y=labels, height=0.5, label="deceased", width=positives, facecolor="dimgray", edgecolor="none")
    axis.set_xlabel('Number of People')
    axis.set_ylabel("%s" % title)
    axis.set_title("Number of People (stacked) per %s" % title)
    _style_subgroup_axis(axis)
    # At least 10100 on the x axis; coarser major ticks for large groups.
    xlim = np.max([axis.get_xlim()[1], 10100])
    axis.set_xlim([0, xlim])
    axis.xaxis.set_major_locator(MultipleLocator(5000 if xlim < 35000 else 15000))
    axis.xaxis.set_minor_locator(AutoMinorLocator(5))
    axis.legend()


def _plot_mcc(axis, labels, mcc_dict, title, padding, global_mean):
    """Panel: corrected MCC point estimate with bootstrap CI per group."""
    mu = [mcc_dict[label]["mean"] for label in labels]
    upper = [mcc_dict[label]["upper"] for label in labels]
    err = [[np.abs(m - mcc_dict[label]["lower"]), np.abs(m - mcc_dict[label]["upper"])]
           for m, label in zip(mu, labels)]
    # Raw string: same text as before ("\%" escapes the percent sign for LaTeX),
    # without Python's invalid-escape warning.
    axis.errorbar(y=labels, xerr=np.array(err).T, x=mu, fmt="o",
                  capsize=5, ecolor="dimgray", ms=3.5,
                  elinewidth=2, mfc="black", mec="black", label=r"mean with 95\%-CI (bootstrap)")
    axis.set_xlabel("MCC Score")
    axis.set_ylabel("%s" % title)
    axis.set_title("Corrected MCC per %s" % title)
    # Ignore NaN upper bounds (groups without a valid bootstrap); NaN axis limits raise.
    finite_upper = [u for u in upper if np.isfinite(u)]
    xlim = max(max(finite_upper, default=-np.inf) + 0.05, 0.8)
    axis.set_xlim([-0.1, xlim])
    axis.set_ylim([-padding, len(labels) - 1 + padding])
    axis.axvline(0.0, color="gray", linewidth=0.5, linestyle=":", label="center line")
    axis.axvline(global_mean, color="blue", linewidth=1, linestyle=":", label="global mean")
    _style_subgroup_axis(axis)
    axis.xaxis.set_major_locator(MultipleLocator(0.1))
    axis.xaxis.set_minor_locator(AutoMinorLocator(5))
    axis.legend()


def return_stats(labels, prb_dict, count_dict,  mcc_dict, title:str, figsize=(10,10), global_mean: float = 0.413):
    """Draw a 2x2 subgroup summary figure and return it.

    Panels: median predicted probability (+/- MAD), deceased counts,
    overlaid population counts, and corrected MCC with bootstrap CI.

    Args:
        labels: group names; keys into all three dicts, drawn bottom-to-top.
        prb_dict: label -> {"median", "mad"} of predicted probabilities.
        count_dict: label -> {"total", "positive", "unlabeled"} counts.
        mcc_dict: label -> {"mean", "lower", "upper"} corrected MCC.
        title: group description used in axis labels and titles.
        figsize: matplotlib figure size in inches.
        global_mean: x position of the "global mean" reference line in the MCC
            panel (previously hard-coded 0.413 — TODO pass data["l2v"]["mcc"]["mean"]).

    NOTE(review): this shadows the earlier data-loading ``return_stats``; cells
    that load predictions must run before this definition.
    """
    padding = 0.5
    fig, ax = plt.subplots(2, 2, figsize=figsize)
    _plot_median_prb(ax[0, 0], labels, prb_dict, title, padding)
    _plot_deceased_counts(ax[0, 1], labels, count_dict, title)
    _plot_population(ax[1, 0], labels, count_dict, title)
    _plot_mcc(ax[1, 1], labels, mcc_dict, title, padding, global_mean)
    return fig

In [None]:
# Maximum figure width (180 mm), converted to inches for matplotlib.
max_width_in_mm, max_width_in_inches = 180, 180 / 25.4

In [None]:
# Subgroup summary by number-of-health-records bin.
# NOTE(review): with text.usetex, '>' outside math mode may render as an odd glyph — check '>=19'.
labels = ['[0]', '[1,5)', '[5,11)', '[11,19)', '>=19']
group_title = "Health Group"
figsize = (max_width_in_inches, max_width_in_inches / 2)
fig = return_stats(labels, unitary_prb, unitary_cnt, unitary_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_health_events.pdf", format="pdf")
plt.show()

In [None]:
# Subgroup summary by sequence-length tercile.
labels = ['[48, 921)', '[921, 1176)', '[1176, 2560]']
group_title = "Sequence Len Group"
figsize = (max_width_in_inches, max_width_in_inches / 2.5)
fig = return_stats(labels, unitary_prb, unitary_cnt, unitary_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_seqlen.pdf", format="pdf")
plt.show()

In [None]:
# Subgroup summary by age group.
labels = ['(34,39]', '(39, 44]', '(44, 49]', '(49, 54]', '(54, 59]', '(59, 64]']
group_title = "Age Group"
figsize=(max_width_in_inches, max_width_in_inches / 1.3)
return_stats(labels, unitary_prb, unitary_cnt, unitary_mcc, group_title, figsize=figsize)
# Remove top/right spines like every other subgroup figure (this call was missing here).
sns.despine()
plt.tight_layout()
plt.savefig(save_path + "mortality_age.pdf", format="pdf")
plt.show()

In [None]:
# Subgroup summary for sex and residence-origin groups.
labels = ["male", "female", "DK", "NON-DK"]
group_title = "Unitary (Subset) Groups"
figsize = (max_width_in_inches, max_width_in_inches / 2)
fig = return_stats(labels, unitary_prb, unitary_cnt, unitary_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_unitary.pdf", format="pdf")
plt.show()

In [None]:
# Subgroup summary for sex, age and residence-origin groups together.
labels = ['male', 'female', '(34,39]', '(39, 44]', '(44, 49]', '(49, 54]', '(54, 59]', '(59, 64]', 'DK', 'NON-DK']
group_title = "Unitary (All) Groups"
figsize = (max_width_in_inches, max_width_in_inches)
fig = return_stats(labels, unitary_prb, unitary_cnt, unitary_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_unitary_full.pdf", format="pdf")
plt.show()

In [None]:
# Age group x sex intersections (6 x 2 subgroups): masks, then the same
# statistics as for the unitary groups.
intersect_groups = {
    " ".join([age, sex]): (groups[age]) & (groups[sex])
    for age in ['(34,39]', '(39, 44]', '(44, 49]', '(49, 54]', '(54, 59]', '(59, 64]']
    for sex in ["male", "female"]
}

intersec_mcc, intersec_prb, intersec_cnt = {}, {}, {}

for k, condition in intersect_groups.items():
    members = result[condition]
    preds = members["PRED"].values
    targs = members["TARGET"].values
    intersec_mcc[k] = bootstrap_mcc(torch.from_numpy(preds).float(), torch.from_numpy(targs).long())
    intersec_prb[k] = {"median": np.median(preds), "mad": mad(preds)}
    intersec_cnt[k] = {"total": preds.shape[0], "positive": np.sum(targs),
                       "unlabeled": members["UNLABELED"].values.sum()}
    print("Done:", k)


In [None]:
# Subgroup summary for the age x sex intersections.
labels = list(intersect_groups)
group_title = "Intersectional (All) Groups"
figsize = (max_width_in_inches, max_width_in_inches)
fig = return_stats(labels, intersec_prb, intersec_cnt, intersec_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_intersect_age_sex.pdf", format="pdf")
plt.show()

In [None]:
# Health-record bin x sex intersections (5 x 2 subgroups).
intersect_groups = {
    " ".join([health, sex]): (groups[health]) & (groups[sex])
    for health in ['[0]', '[1,5)', '[5,11)', '[11,19)', '>=19']
    for sex in ["male", "female"]
}

intersec_mcc, intersec_prb, intersec_cnt = {}, {}, {}

for k, condition in intersect_groups.items():
    members = result[condition]
    preds = members["PRED"].values
    targs = members["TARGET"].values
    intersec_mcc[k] = bootstrap_mcc(torch.from_numpy(preds).float(), torch.from_numpy(targs).long())
    intersec_prb[k] = {"median": np.median(preds), "mad": mad(preds)}
    intersec_cnt[k] = {"total": preds.shape[0], "positive": np.sum(targs),
                       "unlabeled": members["UNLABELED"].values.sum()}
    print("Done:", k)

In [None]:
# Subgroup summary for the health-record x sex intersections.
labels = list(intersect_groups)
group_title = "Intersectional (Subset) Groups"
figsize = (max_width_in_inches, max_width_in_inches)
fig = return_stats(labels, intersec_prb, intersec_cnt, intersec_mcc, group_title, figsize=figsize)
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_intersect_sex_health.pdf", format="pdf")
plt.show()

In [None]:
# Sequence-length tercile x sex intersections (3 x 2 subgroups).
intersect_groups = {
    " ".join([length_bin, sex]): (groups[length_bin]) & (groups[sex])
    for length_bin in ['[48, 921)', '[921, 1176)', '[1176, 2560]']
    for sex in ["male", "female"]
}

intersec_mcc, intersec_prb, intersec_cnt = {}, {}, {}

for k, condition in intersect_groups.items():
    members = result[condition]
    preds = members["PRED"].values
    targs = members["TARGET"].values
    intersec_mcc[k] = bootstrap_mcc(torch.from_numpy(preds).float(), torch.from_numpy(targs).long())
    intersec_prb[k] = {"median": np.median(preds), "mad": mad(preds)}
    intersec_cnt[k] = {"total": preds.shape[0], "positive": np.sum(targs),
                       "unlabeled": members["UNLABELED"].values.sum()}
    print("Done:", k)
    

In [None]:
# Subgroup summary for the sequence-length x sex intersections.
labels = list(intersect_groups)
group_title = "Intersectional (Subset) Groups"
fig = return_stats(labels, intersec_prb, intersec_cnt, intersec_mcc, group_title, figsize=(11, 7))
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_intersect_sex_length.svg", format="svg")
plt.show()

## Age vs Number of Health Records

In [None]:
### AGE vs Health Records
# Age group x health-record bin intersections (6 x 5 subgroups).
intersect_groups = {
    " ".join([age, health]): (groups[age]) & (groups[health])
    for age in ['(34,39]', '(39, 44]', '(44, 49]', '(49, 54]', '(54, 59]', '(59, 64]']
    for health in ['[0]', '[1,5)', '[5,11)', '[11,19)', '>=19']
}

intersec_mcc, intersec_prb, intersec_cnt = {}, {}, {}

for k, condition in intersect_groups.items():
    members = result[condition]
    preds = members["PRED"].values
    targs = members["TARGET"].values
    intersec_mcc[k] = bootstrap_mcc(torch.from_numpy(preds).float(), torch.from_numpy(targs).long())
    intersec_prb[k] = {"median": np.median(preds), "mad": mad(preds)}
    intersec_cnt[k] = {"total": preds.shape[0], "positive": np.sum(targs),
                       "unlabeled": members["UNLABELED"].values.sum()}
    print("Done:", k)

In [None]:
# Subgroup summary for the age x health-record intersections.
labels = list(intersect_groups)
group_title = "Intersectional (Age vs Health Events) Groups"
fig = return_stats(labels, intersec_prb, intersec_cnt, intersec_mcc, group_title, figsize=(11, 7))
sns.despine(fig=fig)
fig.tight_layout()
fig.savefig(save_path + "mortality_intersect_age_health.svg", format="svg")
plt.show()