In [None]:
import os
import sys
from pathlib import Path

from PIL import Image, ImageOps
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import torch
from torchmetrics.functional import accuracy

import pandas as pd
import seaborn as sns

# Make the project root (parent of the notebook folder) importable so `src.*` resolves.
if os.path.abspath("../") not in sys.path:
    sys.path.insert(0, os.path.abspath("../"))
from src.eval.visualizations_base import create_runs_from_folder, extract_results
from src.eval.plot_functions import plot_roc_curves, compute_metrics, build_attention_img, overlay_attention
from src.datamodules.datasets.histo import plot_coords

%load_ext autoreload
%autoreload 2

In [None]:
# Global matplotlib font sizes applied to every figure in this notebook.
rcParams.update(
    {
        "font.size": 14,
        "axes.labelsize": 14,
        "axes.titlesize": 14,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
        "legend.title_fontsize": 12,
    }
)

## Load Experiments

In [None]:
# Root folder containing all experiment runs. It is read from the environment so the
# notebook carries no machine-specific paths; fail early with an actionable message.
if "EXPERIMENT_LOCATION" not in os.environ:
    raise KeyError(
        "EXPERIMENT_LOCATION is not set; point it to the root folder of the experiment "
        "runs before running this notebook."
    )
RUN_DIR = Path(os.environ["EXPERIMENT_LOCATION"])
# Config file location handed to create_runs_from_folder (presumably relative to each run folder).
CONFIG_PATH = "logs/config.yaml"
# `report/` directory inside the parent of the current working directory.
# NOTE(review): intended for exported figures/tables, but the export paths in later
# cells hard-code "../report/..." instead of using this constant.
SAVE_PATH = Path.cwd().parent / "report"

In [None]:
# Load every run of each (experiment folder, display name) pair, flattened into one list
# in folder order. The numeric prefix in each display name fixes the plotting order.
runs = [
    run
    for folder, display_name in (
        ("slide-calib_512/mco-clam", "1 CLAM ImageNet"),
        ("slide-calib_512/mco-clam-ciga", "2 CLAM Ciga"),
        ("slide-calib_512/mco-transformer", "3 Transformer ImageNet"),
        ("slide-calib_512/mco-transformer-ciga", "4 Transformer Ciga"),
        ("slide-calib_512/mco-gnn", "5 GNN ImageNet"),
        ("slide-calib_512/mco-gnn-ciga", "6 GNN Ciga"),
    )
    for run in create_runs_from_folder(folder, run_dir=RUN_DIR, config_path=CONFIG_PATH, name=display_name)
]
print("Total number of runs: ", len(runs))

## Evaluate Performance

In [None]:
# AUROC and ECE on the `test_id` and `test_ood` splits, with plot colors grouped by method.
# To export, add figname="../report/figures/metrics-512.pdf" and/or
# tabname="../report/tables/metrics-512-ciga.tex" to the call below.
grid = compute_metrics(
    runs,
    metrics=["AUROC", "ECE"],
    split=["test_id", "test_ood"],
    custom_legend=True,
    plot_hue="method",
)

## Evaluate Calibration

In [None]:
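# ECE implementations: torchmetrics' reference version (TM_ECE, currently unused) and the project's own (Custom_ECE, used below for reliability diagrams)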
from torchmetrics.classification import CalibrationError as TM_ECE
from src.metrics.calibration_error import CalibrationError as Custom_ECE


def plot_runs(indices: list[int], n_bins: int = 10, runs_list=None):
    """Plot side-by-side reliability diagrams for selected runs on the ``test_id`` split.

    Each panel shows the reliability diagram drawn by the project's
    ``CalibrationError`` (imported as ``Custom_ECE``) for one run, computed from
    column 1 of the run's softmax output (presumably the positive class).

    Args:
        indices: Positions into ``runs_list`` of the runs to plot; one panel per
            index, left to right.
        n_bins: Number of confidence bins for the calibration error (default 10,
            matching the previous hard-coded value).
        runs_list: Sequence of runs to index into. Defaults to the notebook-level
            ``runs`` list.

    Returns:
        The created matplotlib ``Figure``, so callers can save it explicitly.

    Raises:
        ValueError: If ``indices`` is empty.
    """
    if runs_list is None:
        runs_list = runs
    if not indices:
        raise ValueError("plot_runs needs at least one run index")

    # squeeze=False always yields a 2-D array of Axes; without it, a single index
    # gives a bare Axes object and indexing it fails.
    fig, axs = plt.subplots(1, len(indices), figsize=(12, 3.2), squeeze=False)
    for i, (ax, run_idx) in enumerate(zip(axs[0], indices)):
        run = runs_list[run_idx]
        preds = run.test_id_preds
        Y_prob, labels = preds["softmax"], preds["label"]
        # A new metric instance per run: update() presumably accumulates state,
        # so reusing one instance would mix runs together.
        custom_ece = Custom_ECE(n_bins=n_bins, norm="l1")
        custom_ece.update(Y_prob[:, 1], labels)
        # Only the first panel shows the legend to avoid repeating it.
        custom_ece.plot_reliability_diagram(custom_ax=ax, title=run.name, show_legend=(i == 0))
    return fig


# One regular (presumably not temperature-scaled) run per model architecture.
# NOTE(review): these are positional indices into `runs`; they silently point at other
# runs if the experiment folders contain a different number of runs.
plot_runs(indices=[0, 11, 20])
# plt.savefig("../report/figures/reliability_imagenet.pdf", bbox_inches="tight")
# The corresponding temperature-scaled run for each model architecture:
# plot_runs([5, 15, 25])
# plt.savefig("../report/figures/reliability_ts.pdf", bbox_inches="tight")
plt.show()