In [None]:
%reload_ext autoreload
%autoreload 2

import pickle
from pathlib import Path
from collections import defaultdict

import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import pandas as pd
from sklearn.metrics import ConfusionMatrixDisplay

from koafusion.various import calc_metrics, calc_metrics_bootstrap

In [None]:
# ACT:SET_PATH - replace the placeholder string below with the project root.
# The placeholder is kept as a string (as for the metadata path) so that the
# cell parses; a bare `DIR_PROJ = #...` is a SyntaxError.
DIR_PROJ = Path("ACT:SET_PATH")
# Root of the per-experiment result folders referenced by the experiment list.
DIR_RESULTS = Path(DIR_PROJ, "results/")
# Destination for the figures written by the plotting cells.
DIR_OUT = Path(DIR_RESULTS, "temporary/")

In [None]:
# Path of the aggregated metadata CSV (placeholder must be set before running).
DIR_META = Path("ACT:SET_PATH/meta_agg.csv")

# The file is read with a two-row header; only the columns grouped under the
# top-level "-" label are kept.
df_meta = pd.read_csv(DIR_META, header=[0, 1]).loc[:, "-"]

# Quick look at the first rows and at the remaining column names.
display(df_meta.head())
display(df_meta.columns)

In [None]:
def read_cache(fn, presel=None):
    """Load a pickled evaluation cache and return it as a DataFrame.

    Args:
        fn: Path of the pickle file to read.
        presel: Optional key. When given, only the entry stored under this
            key of the unpickled object is converted.

    Returns:
        pandas.DataFrame built with `DataFrame.from_dict` from the (selected)
        cache dictionary.
    """
    # NOTE(review): pickle.load can execute arbitrary code - only point this
    # at cache files produced by trusted evaluation runs.
    with open(fn, "rb") as handle:
        cache = pickle.load(handle)

    cache = cache if presel is None else cache[presel]

    # Store predict_proba as a plain list of its elements so that each
    # element ends up in its own DataFrame cell.
    if "predict_proba" in cache:
        cache["predict_proba"] = list(cache["predict_proba"])

    # Fields noted by the original author:
    #     {"AGE": float, "P02SEX": str, 'P01BMI': float, 'XRKL': int, 'exam_knee_id': str,
    #      'predict_proba': np.array, 'predict': np.array, 'target': np.array}
    return pd.DataFrame.from_dict(cache)

In [None]:
# Experiment list

# Experiment key -> name of its results folder under DIR_RESULTS.
# Each "EXPERIMENT_ID" is a placeholder to be replaced per experiment.
_EXPERIMENT_IDS = {
    # (input, target, model_descr)
    ("DESS_sag", "tiulpin2019_prog", "2D+FC_p+"): "EXPERIMENT_ID",
    ("DESS_sag", "tiulpin2019_prog", "2D+FC_p-"): "EXPERIMENT_ID",

    ("DESS_sag", "tiulpin2019_prog", "2D+LSTM_p+"): "EXPERIMENT_ID",
    ("DESS_sag", "tiulpin2019_prog", "2D+LSTM_p-"): "EXPERIMENT_ID",

    ("DESS_sag", "tiulpin2019_prog", "2D+TRF_p-(sag)"): "EXPERIMENT_ID",
    ("DESS_sag", "tiulpin2019_prog", "2D+TRF_p+(sag)"): "EXPERIMENT_ID",
    ("DESS_cor", "tiulpin2019_prog", "2D+TRF_p+(cor)"): "EXPERIMENT_ID",
    ("DESS_ax", "tiulpin2019_prog", "2D+TRF_p+(ax)"): "EXPERIMENT_ID",

    ("DESS_cor", "tiulpin2019_prog", "2D+TRF_p+(cor)_adj"): "EXPERIMENT_ID",
    ("DESS_ax", "tiulpin2019_prog", "2D+TRF_p+(ax)_adj"): "EXPERIMENT_ID",

    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s-_p+"): "EXPERIMENT_ID",
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s+_p+"): "EXPERIMENT_ID",
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s-_p-"): "EXPERIMENT_ID",
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s+_p-"): "EXPERIMENT_ID",

    ("DESS", "tiulpin2019_prog", "(2+1)D"): "EXPERIMENT_ID",
    ("DESS", "tiulpin2019_prog", "3D_ResNetXt50_l1s=2"): "EXPERIMENT_ID",
    ("DESS", "tiulpin2019_prog", "3D_ResNetXt50_l1s=1"): "EXPERIMENT_ID",
    ("DESS", "tiulpin2019_prog", "3D_ShuffleNet"): "EXPERIMENT_ID",

    ("XR", "tiulpin2019_prog", "XR"): "EXPERIMENT_ID",
}

# Experiment key -> path of that experiment's pickled ensemble evaluation.
PATHS_EXPERIMS = {
    exp_key: Path(DIR_RESULTS, exp_id, "logs_eval", "eval_raw_ens.pkl")
    for exp_key, exp_id in _EXPERIMENT_IDS.items()
}

In [None]:
# Load the cached evaluation results of every listed experiment

expers_data = {}

for exp_key, path_cache in PATHS_EXPERIMS.items():
    # Caches whose input code mentions clinical variables hold several
    # models; only the "LR" entry is taken from those.
    wants_lr_only = "age,sex,BMI" in exp_key[0] or "clin" in exp_key[0]
    expers_data[exp_key] = read_cache(
        path_cache, presel="LR" if wants_lr_only else None)

In [None]:
# Calculate metrics

# Experiment key -> dict of metrics (single-pass and bootstrapped)
expers_mx_all = {}

for code_exp, df_exp in expers_data.items():
    print(code_exp)

    target = code_exp[1]  # second element of the key; not used further in this cell

    # Stack the per-row entries of the cached frame into dense arrays
    prog_target = np.asarray(
        [np.asarray(e) for e in df_exp["target"].tolist()]).ravel()
    prog_pred_proba = np.asarray(
        [np.asarray(e) for e in df_exp["predict_proba"].tolist()])

    exp_metrics = dict()
    expers_mx_all[code_exp] = exp_metrics

    # 1 shot metrics (including the curves)
    exp_metrics.update(
        calc_metrics(prog_target=prog_target,
                     prog_pred_proba=prog_pred_proba,
                     with_curves=True))

    # Bootstrapped metrics; merged second, so they replace any key that the
    # one-shot result also provides
    exp_metrics.update(
        calc_metrics_bootstrap(prog_target=prog_target,
                               prog_pred_proba=prog_pred_proba))

# One row per experiment, one column per metric
df_mx_all = pd.DataFrame.from_dict(expers_mx_all, orient="index")
display(df_mx_all)

In [None]:
# Model descriptor (third element of an experiment key) -> legend label
model_to_label = {
    "2D+FC_p+": "2D + FC",
    "2D+LSTM_p+": "2D + LSTM",
    "2D+TRF_p+(sag)": r"2D $\mathit{(sag)}$ + TRF",
    "multi_2D+TRF_s+_p+": r"2D$^{sh}$ + TRF",
    "3D_ShuffleNet": "3D ShuffleNet",
    "XR": "XR",
}

# Palette: n samples taken at evenly spaced positions of the Set1 colormap
n = 10
colors = plt.cm.Set1(np.linspace(0, 1, n))

# Model descriptor -> line color used in the figures
model_to_color = {
    "2D+TRF_p+(sag)": colors[0],
    "3D_ShuffleNet": colors[1],
    "XR": colors[4],
}

# Input data code (first element of an experiment key) -> name used in the
# summary tables; currently every code maps to itself
data_code_to_vars = {
    code: code
    for code in ("DESS", "DESS_sag", "DESS_cor", "DESS_ax", "DESS_multi", "XR")
}

In [None]:
# Figures. ROC and Precision-Recall curves

# Experiments to include in the figure and in the summary table
T_EXPERS_SEL = [
    ("DESS_sag", "tiulpin2019_prog", "2D+FC_p+"),
    ("DESS_sag", "tiulpin2019_prog", "2D+FC_p-"),

    ("DESS_sag", "tiulpin2019_prog", "2D+LSTM_p+"),
    ("DESS_sag", "tiulpin2019_prog", "2D+LSTM_p-"),

    ("DESS_sag", "tiulpin2019_prog", "2D+TRF_p-(sag)"),
    ("DESS_sag", "tiulpin2019_prog", "2D+TRF_p+(sag)"),
    ("DESS_cor", "tiulpin2019_prog", "2D+TRF_p+(cor)"),
    ("DESS_ax", "tiulpin2019_prog", "2D+TRF_p+(ax)"),
    ("DESS_cor", "tiulpin2019_prog", "2D+TRF_p+(cor)_adj"),
    ("DESS_ax", "tiulpin2019_prog", "2D+TRF_p+(ax)_adj"),

    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s-_p+"),
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s+_p+"),
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s-_p-"),
    ("DESS_multi", "tiulpin2019_prog", "multi_2D+TRF_s+_p-"),

    ("DESS", "tiulpin2019_prog", "(2+1)D"),
    ("DESS", "tiulpin2019_prog", "3D_ResNetXt50_l1s=2"),
    ("DESS", "tiulpin2019_prog", "3D_ResNetXt50_l1s=1"),
    ("DESS", "tiulpin2019_prog", "3D_ShuffleNet"),
]

matplotlib.rcParams.update({'font.size': 8})

# Column-wise accumulator for the metrics summary table
summary = defaultdict(list)

target = "tiulpin2019_prog"

fig, axes = plt.subplots(ncols=2, figsize=(4.2, 1.9))
lw = 1.2

# Keep only the experiments named in the selection list
t_metrics = {k: v for k, v in expers_mx_all.items() if k in T_EXPERS_SEL}

for idx, (k, v) in enumerate(t_metrics.items()):
    model = k[2]
    # Not every selected model has an entry in the color/label mappings.
    # Fall back to a tab20 color and the raw model descriptor instead of
    # failing with a KeyError.
    t_color = model_to_color.get(model, plt.cm.tab20(idx % 20))
    t_label = model_to_label.get(model, model)

    axes[0].plot(*v["roc_curve"], color=t_color, lw=lw, label=t_label)
    # The stored pair is passed in reversed order - presumably it is kept as
    # (precision, recall) while the x axis is recall; TODO confirm.
    axes[1].plot(*v["pr_curve"][::-1], color=t_color, lw=lw, label=t_label)

    if idx == 0:
        # Reference lines, drawn once: the diagonal on the ROC panel and a
        # horizontal line at the prevalence on the precision-recall panel.
        axes[0].plot([0, 1], [0, 1], color='lightgray', lw=lw, linestyle='--')
        t = v["prevalence"]
        axes[1].plot([0, 1], [t, t], color='lightgray', lw=lw, linestyle='--')

    # "target" is recorded because the summary table pivots on it
    summary["target"].append(target)
    summary["data"].append(k[0])
    summary["model"].append(model)
    # Formatted as "first (second)" - presumably a bootstrap estimate and its
    # spread; verify against calc_metrics_bootstrap.
    summary["roc_auc"].append(
        f"{v['roc_auc'][0]:0.2f} ({v['roc_auc'][1]:0.2f})")
    summary["avg_precision"].append(
        f"{v['avg_precision'][0]:0.2f} ({v['avg_precision'][1]:0.2f})")
    summary["prevalence"].append(v["prevalence"])

for ax in axes:
    ax.set_xlim([-0.01, 1.01])
    ax.set_ylim([-0.01, 1.01])

axes[0].set_xlabel('False Positive Rate')
axes[0].set_ylabel('True Positive Rate')
axes[0].set_title('ROC curve')
axes[0].legend(loc="lower right")

axes[1].set_xlabel('Recall')
axes[1].set_ylabel('Precision')
axes[1].set_title('Precision-Recall curve')
axes[1].legend(loc="lower right")

path_out = Path(DIR_OUT, "roc_pr.png")
# savefig does not create missing directories
path_out.parent.mkdir(parents=True, exist_ok=True)

fig.tight_layout()
fig.savefig(path_out, dpi=300)
plt.close(fig)

In [None]:
# Table. Metrics summary

# Long-format frame: one row per experiment collected in `summary`
t = pd.DataFrame.from_dict(summary)

# One wide table per metric: rows are (data, model), columns are targets
for metric in ("prevalence", "avg_precision", "roc_auc"):
    table = t.pivot(
        index=["data", "model"],
        columns=["target"],
        values=[metric],
    )
    display(table)

In [None]:
# Plot confusion matrix

matplotlib.rcParams.update({'font.size': 8})

# Column-wise accumulator of the plotted confusion matrices
summary = defaultdict(list)

target = "tiulpin2019_prog"

# Keep only the experiments named in the selection list
t_metrics = {k: v for k, v in expers_mx_all.items() if k in T_EXPERS_SEL}

# savefig does not create missing directories
DIR_OUT.mkdir(parents=True, exist_ok=True)

for idx, (k, v) in enumerate(t_metrics.items()):
    # The index printed here is the one used in the output file name
    print(f"{idx}: {repr(k)}")
    print(v["cm"])

    fig, ax = plt.subplots(figsize=(2.4, 2.4))

    # The normalized matrix is drawn; the matrix under "cm" is only printed
    cm_disp = ConfusionMatrixDisplay(
        v["cm_norm"], display_labels=["no\nprog.", "slow\n(72-96m)", "fast\n(<72m)"])

    cm_disp.plot(include_values=True,
                 cmap="OrRd",
                 xticks_rotation="horizontal",
                 values_format=".2f",
                 ax=ax,
                 colorbar=False)

    summary["target"].append(target)
    summary["data"].append(data_code_to_vars[k[0]])
    summary["model"].append(k[2])
    summary["cm"].append(v["cm"])
    summary["cm_norm"].append(v["cm_norm"])

    fig.tight_layout()

    path_out = Path(DIR_OUT, f"cm_norm__{idx}.pdf")
    fig.savefig(path_out, dpi=300)
    # Close this specific figure so figures do not accumulate across the loop
    plt.close(fig)