# Create figures for the paper

- `test.py` must be run beforehand.
- The `DATASET_LOCATION` environment variable must be set beforehand.
- Model weights need to be available.

In [None]:
%load_ext autoreload
%autoreload 2
import os

import torch
import torchmetrics
from torchmetrics import ConfusionMatrix

from tqdm import tqdm
from matplotlib import pyplot as plt
import pandas as pd

import matplotlib.patches as mpatches

from torchmetrics import Dice

from pathlib import Path
from src.gleason_data import GleasonX
from src.augmentations import basic_transforms_val_test_scaling512, normalize_only_transform
from src.model_utils import SoftDICEMetric
from src.jdt_losses import SoftCorrectDICEMetric

from src.tree_loss import generate_label_hierarchy
from  src.gleason_utils import classes_ll1_shortform as class_names

from src.jdt_losses import SoftDICECorrectAccuSemiMetric
from src.jdt_losses import JDTLoss

In [None]:
# VISUALIZATIONS

from itertools import zip_longest
import math
from textwrap import wrap
from typing import Literal

import seaborn as sns

import src.augmentations as augmentations
import numpy as np
from src.gleason_utils import create_composite_plot

import ipywidgets as wid

from matplotlib.patches import Rectangle
from matplotlib.colors import LinearSegmentedColormap
from matplotlib.colors import ListedColormap
from matplotlib import colormaps as cm

In [None]:
# Dataset root (taken from the DATASET_LOCATION environment variable) and the
# output folder for figures, created if missing.
base_path = Path(os.environ["DATASET_LOCATION"]).joinpath("GleasonXAI")
fig_dir = Path("./figures")
fig_dir.mkdir(parents=True, exist_ok=True)

In [None]:
def create_ensemble_plot(dataset, idx, ensemble_predictions, individual_predictions, show_ensemble_preds = 0):
    """Plot the ensemble segmentation of one sample next to its annotations.

    Optionally prepends the argmax segmentations of ``show_ensemble_preds``
    randomly chosen individual ensemble members.

    Args:
        dataset: dataset exposing ``__getitem__(idx, False)`` and ``get_raw_image``.
        idx: sample index.
        ensemble_predictions: per-sample class scores; ``ensemble_predictions[idx]``
            is reduced with ``argmax(dim=0)`` (class axis assumed first).
        individual_predictions: per-member predictions indexed ``[member, idx]``.
        show_ensemble_preds: number of individual members to show (0 = none).
            Capped at the number of available members.

    Returns:
        The value returned by ``create_composite_plot`` (presumably the figure).
    """
    _, masks, background_mask = dataset.__getitem__(idx, False)
    org_img = augmentations.basic_transforms_val_test_colorpreserving(
        image=np.array(dataset.get_raw_image(idx)))["image"]
    np_seg = np.array(ensemble_predictions[idx].argmax(dim=0)).astype(np.uint8)

    masks = {"segmentation": np_seg} | {f"Annotator {i}": mask.astype(np.uint8) for i, mask in enumerate(masks)}

    # Sample the shown members from *all* ensemble members, not just the
    # first ``show_ensemble_preds`` of them.
    num_members = len(individual_predictions)
    rnd_ensemble_subset = np.random.permutation(num_members)[:show_ensemble_preds]
    sub_ensemble_preds = [
        np.array(individual_predictions[i, idx].argmax(dim=0)).astype(np.uint8)
        for i in rnd_ensemble_subset
    ]

    if sub_ensemble_preds:
        masks = {f"ensemble_pred_{i}": member_seg for i, member_seg in enumerate(sub_ensemble_preds)} | masks

    return create_composite_plot(dataset, org_img, masks, background_mask.astype(np.uint8),
                                 label_level=1, only_show_existing_annotation=True)


def composite_prediction_plot(predictions, dataset, indices, mask_background=False, full_legend=True, *, return_all=False):
    """Create one composite plot (segmentation + annotations) per sample index.

    Every index in ``indices`` is plotted. Previously an early ``return``
    inside the loop meant that only the first index was ever plotted.

    Args:
        predictions: per-sample class scores; ``predictions[idx]`` is
            softmax-normalised over dim 0 (assumed class axis) and argmaxed.
        dataset: dataset exposing ``__getitem__(idx, False)`` and ``get_raw_image``.
        indices: sample indices to plot.
        mask_background: pass the sample's background mask to the plot.
        full_legend: show all classes in the legend (otherwise only existing ones).
        return_all: if True, return the list of all figures.

    Returns:
        If ``return_all`` is True, the list of figures, one per index.
        Otherwise the figure of the first index (the original return value),
        or None when ``indices`` is empty.
    """
    figures = []
    for idx in indices:
        _, masks, background_mask = dataset.__getitem__(idx, False)
        org_img = augmentations.basic_transforms_val_test_colorpreserving(
            image=np.array(dataset.get_raw_image(idx)))["image"]

        background = background_mask if mask_background else None

        out = torch.nn.functional.softmax(predictions[idx], 0)
        np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

        plot_masks = {"segmentation": np_seg} | {f"Annotator {i}": mask for i, mask in enumerate(masks)}
        figures.append(create_composite_plot(dataset, org_img, plot_masks, background,
                                             only_show_existing_annotation=not full_legend))

    if return_all:
        return figures
    return figures[0] if figures else None


def create_single_class_acti_maps(predictions, dataset, idx, plot_mode: Literal["heatmap", "contourf", "contour", "thresholded"] = "contourf", thresholds: list[float] = None, strip_background=False, plot=True):

    img, masks, background_mask = dataset.__getitem__(idx, False)
    # org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

    out = predictions[idx] 
    out = torch.nn.functional.softmax(out, 0)

    np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

    org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape))

    colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes

    if strip_background:

        for mask in masks:
            mask += 1
            mask[background_mask] = 0

        np_seg += 1
        np_seg[background_mask] = 0

        out[:, torch.tensor(background_mask).bool()] = 0.0

        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1

    f, axes = plt.subplots(2, 3+math.ceil(dataset.num_classes/2), sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
    f.tight_layout()

    axes[0, 0].imshow(org_img)
    axes[0, 0].set_title("Image", size=7)
    axes[0, 0].set_axis_off()

    # axes[1, 0].imshow(img)
    axes[1, 0].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
    axes[1, 0].set_title("Segmentation", size=7)
    axes[1, 0].set_axis_off()

    for sub_ax, mask in zip_longest(list(axes[:, 1:3].flatten()), masks):
        sub_ax.set_axis_off()

        if mask is not None:
            sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1,  interpolation_stage="rgba")
            sub_ax.set_title("Annotation", size=7)

    for i in range(dataset.num_classes):
        active_axis = axes[:, 3:].flatten()[i]

        class_out = out[i, :].detach().numpy()

        if strip_background:
            class_out[background_mask] = 0

        temp_colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors[i].reshape(1, -1)]))

        match plot_mode:
            case "heatmap": active_axis.matshow(class_out, cmap=cm["Grays"].reversed(), vmin=out[out != 0.0].min(), vmax=out.max())
            case "multilabel": active_axis.matshow(class_out >= 0.32, cmap=temp_colormap)
            case "contour": active_axis.contour(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "contourf": active_axis.contourf(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "thresholded": active_axis.matshow(class_out > thresholds[i], cmap=cm["Grays"].reversed())

        if i == 0:
            title = "Benign"
        else:
            exp = dataset.explanations[i-1]
            classes_named = ["benign tissue", "3 - individual glands", "3 - compressed glands", "4 - poorly formed glands",
                             "4 - cribriform glands", "4 - glomeruloid glands", "5 - group of tumor cells", "5 - single cells", "5 - cords", "5 - comedenocrosis"]
            exp_short = classes_named[i]
            title = "\n".join(wrap(str(dataset.exp_grade_mapping[exp]) + ": " + exp_short, width=20))

        active_axis.set_title(title, size=7)
        active_axis.set_axis_off()
    
    plt.subplots_adjust(wspace=0.1, hspace=0.3)

    if plot:
        plt.show()
    else:
        return f, axes


def create_simple_seg_anno_plot(predictions, dataset, idx, plot_mode: Literal["heatmap", "contourf", "contour", "thresholded"] = "contourf", thresholds: list[float] | None = None, strip_background=False, plot=True):
    """Plot image, argmax segmentation and up to three annotations in one row.

    Args:
        predictions: per-sample class scores; ``predictions[idx]`` is
            softmax-normalised over dim 0 (assumed class axis) and argmaxed.
        dataset: dataset exposing ``__getitem__(idx, False)``, ``get_raw_image``,
            ``colormap`` and ``num_classes``.
        idx: sample index.
        plot_mode, thresholds: accepted for signature compatibility with
            ``create_single_class_acti_maps``; not used here.
        strip_background: shift labels by one and draw background pixels as
            label 0 (black).
        plot: if True call ``plt.show()``, otherwise return the figure.

    Returns:
        ``(figure, axes)`` when ``plot`` is False, otherwise None.
    """
    _, masks, background_mask = dataset.__getitem__(idx, False)

    out = torch.nn.functional.softmax(predictions[idx], 0)
    np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

    # PIL's resize expects (width, height) while numpy shapes are (rows, cols);
    # assumes get_raw_image returns a PIL image.
    org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape[::-1]))

    colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes

    if strip_background:
        # Reserve label 0 for background without modifying the dataset's arrays.
        masks = [mask + 1 for mask in masks]
        for mask in masks:
            mask[background_mask] = 0

        np_seg += 1
        np_seg[background_mask] = 0

        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1

    f, axes = plt.subplots(1, 5, sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
    f.tight_layout()
    axes[0].imshow(org_img)
    axes[0].set_title("Image", size=7)
    axes[0].set_axis_off()

    axes[1].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
    axes[1].set_title("Segmentation", size=7)
    axes[1].set_axis_off()

    for sub_ax, mask in zip_longest(axes[2:], masks):
        if sub_ax is None:  # more annotations than free axes
            continue
        sub_ax.set_axis_off()

        if mask is not None:
            sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
            sub_ax.set_title("Annotation", size=7)
    plt.tight_layout()

    if plot:
        plt.show()
    else:
        return f, axes


def create_multi_seg_anno_plot(predictions, dataset, idcs, strip_background=False, legend=False, class_names=None):
    """Plot one row per sample: image, argmax segmentation and up to three annotations.

    Args:
        predictions: per-sample class scores; ``predictions[idx]`` is
            softmax-normalised over dim 0 (assumed class axis) and argmaxed.
        dataset: dataset exposing ``__getitem__(idx, False)``, ``get_raw_image``,
            ``colormap``, ``num_classes``, ``classes_named`` and
            ``classes_number_mapping``.
        idcs: sample indices, one figure row each.
        strip_background: shift labels by one and draw background pixels as
            label 0 (black).
        legend: add a figure legend for every class that occurs in a plotted
            segmentation or annotation.
        class_names: legend labels parallel to ``dataset.classes_named``
            (defaults to it).

    Also prints the names of all encountered classes.
    """
    if class_names is None:
        class_names = dataset.classes_named
    num_plots = len(idcs)

    # squeeze=False keeps a 2-D axes array, so a single index still yields one row.
    f, top_axes = plt.subplots(num_plots, 5, sharex=False, sharey=False, constrained_layout=False,
                               figsize=(10, 2 * num_plots), squeeze=False)
    f.tight_layout()

    # With strip_background, label 0 is background and class k is drawn as k + 1.
    offset = 1 if strip_background else 0
    if strip_background:
        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
    else:
        colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes + offset

    encountered_classes = set()

    for axes, idx in zip(top_axes, idcs):
        _, masks, background_mask = dataset.__getitem__(idx, False)

        out = torch.nn.functional.softmax(predictions[idx], 0)
        np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

        # PIL's resize expects (width, height) while numpy shapes are (rows, cols);
        # assumes get_raw_image returns a PIL image.
        org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape[::-1]))

        if strip_background:
            # New arrays, so the dataset's annotation arrays are not modified in place.
            masks = [mask + 1 for mask in masks]
            for mask in masks:
                mask[background_mask] = 0

            np_seg += 1
            np_seg[background_mask] = 0

        encountered_classes |= set(np.unique(np_seg))
        for mask in masks:
            encountered_classes |= set(np.unique(mask))

        axes[0].imshow(org_img)
        axes[0].set_axis_off()

        axes[1].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
        axes[1].set_axis_off()

        for sub_ax, mask in zip_longest(axes[2:], masks):
            if sub_ax is None:  # more annotations than free axes
                continue
            sub_ax.set_axis_off()

            if mask is not None:
                sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")

    if legend:
        def _shorten(name):
            # Same truncation rule for both branches: cut at 60 chars with an ellipsis.
            return name if len(name) < 60 else name[:60] + "..."

        legend_handles = []
        if strip_background:
            legend_handles.append(mpatches.Patch(color=np.array([0., 0., 0., 1.]), label="Background"))
        legend_handles += [
            mpatches.Patch(color=colormap(dataset.classes_number_mapping[cls] + offset), label=_shorten(cls_renamed))
            for cls, cls_renamed in zip(dataset.classes_named, class_names)
            if dataset.classes_number_mapping[cls] + offset in encountered_classes
        ]

        f.legend(handles=legend_handles, loc="center left", fontsize=12, bbox_to_anchor=[1.0, 0.5])

    # Undo the label offset: only stripped-background labels are shifted by one.
    print([dataset.classes_named[e - offset] for e in encountered_classes if e >= offset])
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.03, hspace=0.05)

In [None]:
# Model identifiers that `get_model_settings` can resolve to run directories.
available_models = [
    # GleasonBackgroundMasking runs (label level 1 and 0).
    "ll1/CE",
    "ll1/Tree",
    "ll1/SDB",
    "ll0/CE",
    "ll0/SDB",
    # Ensembles combining two GleasonBackgroundMasking run groups.
    "ens/ll1/Tree",
    "ens/ll1/SDB",
    # GleasonFinal runs.
    "final/ll0/CE",
    "final/ll1/CE",
    "final/ll0/SDB",
    "final/ll1/SDB",
    "final/ll1/OHCE",
    "final/ll1/SDBML",
    "final/ll0/DICE",
    "final/ll1/DICE",
    # GleasonFinal2 runs.
    "final2/ll1/SDB",
    "final2/ll1/CE",
]

In [None]:

_GBM_LL1 = "GleasonBackgroundMasking/label_level1/HoleMask"
_GBM_LL0 = "GleasonBackgroundMasking/label_level0/HoleMask"
_SEEDS = (1, 2, 3)

# selected_model -> list of (run directory prefix, run name, seed indices).
# Each entry expands to "<prefix>/<run name>-<seed>/version_0/".
_MODEL_RUNS = {
    "ens/ll1/SDB": [(_GBM_LL1, "SoftDICEBalancedNoZoomCont", (1, 2, 3, 4)), (_GBM_LL1, "SoftDICEBalancedNoZoom", _SEEDS)],
    "ens/ll1/Tree": [(_GBM_LL1, "FinalTreeLossNoZoomCont", (1, 2, 3, 4)), (_GBM_LL1, "NoZoomFinalTreeLoss", _SEEDS)],
    "ll1/CE": [(_GBM_LL1, "NoZoomFinalCE", _SEEDS)],
    "ll1/Tree": [(_GBM_LL1, "NoZoomFinalTreeLoss", _SEEDS)],
    "ll1/SDB": [(_GBM_LL1, "SoftDICEBalancedNoZoom", _SEEDS)],
    "ll0/CE": [(_GBM_LL0, "NoZoomFinalCE", _SEEDS)],
    "ll0/SDB": [(_GBM_LL0, "SoftDICEBalancedNoZoom", _SEEDS)],
    "final/ll0/CE": [("GleasonFinal/label_level0", "CE", _SEEDS)],
    "final/ll1/CE": [("GleasonFinal/label_level1", "CE", _SEEDS)],
    "final/ll0/SDB": [("GleasonFinal/label_level0", "SoftDiceBalanced", _SEEDS)],
    "final/ll1/SDB": [("GleasonFinal/label_level1", "SoftDiceBalanced", _SEEDS)],
    "final/ll0/OHCE": [("GleasonFinal/label_level0", "OH_CE", _SEEDS)],
    "final/ll1/OHCE": [("GleasonFinal/label_level1", "OH_CE", _SEEDS)],
    "final/ll1/SDBML": [("GleasonFinal/label_level1", "SoftDiceBalancedMultiLevel", _SEEDS)],
    "final/ll0/DICE": [("GleasonFinal/label_level0", "DICE", _SEEDS)],
    "final/ll1/DICE": [("GleasonFinal/label_level1", "DICE", _SEEDS)],
    "final2/ll1/SDB": [("GleasonFinal2/label_level1", "SoftDiceBalanced", _SEEDS)],
    "final2/ll1/CE": [("GleasonFinal2/label_level1", "CE", _SEEDS)],
}


def _model_run_paths(selected_model):
    """Return the run directories (relative to ``base_path``) of ``selected_model``.

    Raises:
        RuntimeError: ``selected_model`` is not a known model identifier.
    """
    try:
        runs = _MODEL_RUNS[selected_model]
    except KeyError:
        raise RuntimeError(f"Unknown model: {selected_model!r}") from None
    return [Path(f"{prefix}/{name}-{i}/version_0/") for prefix, name, seeds in runs for i in seeds]


def get_model_settings(selected_model, remap_ll0):
    """Resolve a model identifier into prediction files and the matching test data.

    Args:
        selected_model: one of the keys of ``_MODEL_RUNS`` (e.g. ``"final/ll1/SDB"``).
        remap_ll0: evaluate an ll1 model at label level 0; only valid for ll1 models.

    Returns:
        A tuple with these elements, in order:
        - ``model_paths``: run directories, relative to ``base_path``.
        - ``preds_paths``: absolute-or-``base_path``-prefixed ``preds/pred_test.pt`` files.
        - ``label_level``.
        - ``num_classes``: 10 for ll1 without remapping, else 4.
        - ``data``: the ``GleasonX`` test split.
        - ``labels``, ``bgs``: the per-sample label / background tensors of ``data``.
        - ``remapping_function``: applies ``generate_label_hierarchy`` and returns
          its first level (presumably LL0).

    Raises:
        AssertionError: ``remap_ll0`` for a non-ll1 model, or a run directory or
            prediction file is missing.
        RuntimeError: unknown ``selected_model``.
    """
    label_level = 1 if "ll1" in selected_model else 0
    eval_on = "test"

    num_classes = 10 if (label_level == 1 and not remap_ll0) else 4

    if remap_ll0:
        assert "ll1" in selected_model
        label_level = 0

    # "final2/..." also contains "final", so both use the calibrated data setup.
    data_opts = "final" if "final" in selected_model else "org"

    model_paths = _model_run_paths(selected_model)

    preds_paths = []
    for path in model_paths:
        run_dir = base_path / path
        pred_file = run_dir / "preds" / f"pred_{eval_on}.pt"
        assert run_dir.exists(), f"Could not find {run_dir}"
        assert pred_file.exists(), pred_file
        preds_paths.append(pred_file)

    no_tissue_postprocessing = {"open": False, "close": False, "flood": False}
    data_options = {
        "org": {"scaling": "1024", "transforms": basic_transforms_val_test_scaling512, "label_level": label_level,
                "create_seg_masks": True, "tissue_mask_kwargs": dict(no_tissue_postprocessing)},
        "final": {"scaling": "MicronsCalibrated", "transforms": normalize_only_transform, "label_level": label_level,
                  "create_seg_masks": True, "tissue_mask_kwargs": dict(no_tissue_postprocessing),
                  "drawing_order": "grade_frame_order", "explanation_file": "final_filtered_explanations_df.csv",
                  "data_split": (0.7, 0.15, 0.15)},
    }

    data = GleasonX(base_path, split="test", **data_options[data_opts])

    labels = []
    bgs = []
    for i in tqdm(range(len(data))):
        _, label, background = data[i]
        labels.append(label)
        bgs.append(background)

    def remapping_function(out):
        """Remap ll1 outputs with ``generate_label_hierarchy`` and return its first level."""
        out_remappings = generate_label_hierarchy(out, data.exp_numbered_lvl_remapping, start_level=1)
        return out_remappings[0]

    return model_paths, preds_paths, label_level, num_classes, data, labels, bgs, remapping_function

In [None]:
from torchmetrics import Accuracy
from src.model_utils import L1CalibrationMetric

def get_all_metrics_from_list(preds, labels, bgs):
    """Compute segmentation and calibration metrics over a list of images.

    Args:
        preds: per-image class-probability tensors indexed ``[class, ...pixels]``
            (the class count is read from ``preds[0].shape[0]``).
        labels: per-image soft-label tensors with the same layout as ``preds``.
        bgs: per-image background masks. They are converted with ``.bool()``
            before inversion, so non-boolean masks are inverted logically
            rather than bitwise.

    Images without any foreground pixel are skipped; the per-image L1 term
    would otherwise divide by zero and turn the average into NaN.

    Returns:
        A dict with these keys:
        - ``mDICED``.
        - ``L1``: half L1 distance between prediction and label, averaged over
          foreground pixels and images.
        - ``L1Compare``.
        - ``DICEmacro_unique_max``, ``DICE_unique_max``, ``Acc``, ``Bacc``: hard
          metrics that only use foreground pixels whose label maximum is unique.
    """
    num_classes = preds[0].shape[0]

    d_mac = Dice(num_classes=num_classes, average="macro")
    d_mic = Dice(num_classes=num_classes, average="micro")
    b_acc = Accuracy(task="multiclass", num_classes=num_classes, average="macro")
    acc = Accuracy(task="multiclass", num_classes=num_classes, average="micro")
    hard_metrics = (d_mac, d_mic, b_acc, acc)

    half_l1 = []
    L1 = L1CalibrationMetric()
    mDICED = SoftDICECorrectAccuSemiMetric()

    for i in tqdm(range(len(preds))):
        pred = preds[i]
        label = labels[i]
        fg = ~bgs[i].bool()

        n_fg = fg.sum()
        if n_fg == 0:
            continue

        half_l1.append(((pred[:, fg] - label[:, fg]).abs().sum(dim=1) / 2) / n_fg)

        # Hard metrics only on foreground pixels whose label maximum is unique.
        label_max = torch.max(label, dim=0)[0]
        duplicated_max = torch.sum(label == label_max.unsqueeze(0), dim=0) > 1
        unique_max_fg = torch.logical_and(fg, ~duplicated_max)

        pred_maj = pred[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
        label_maj = label[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
        for metric in hard_metrics:
            metric.update(pred_maj, label_maj)

        L1.update(pred[:, fg].unsqueeze(0), label[:, fg].unsqueeze(0))
        mDICED.update(pred[:, fg].unsqueeze(0), label[:, fg].unsqueeze(0))

    l1 = torch.stack(half_l1).mean(dim=0).sum()

    return {"mDICED": mDICED.compute().item(), "L1": l1.item(), "L1Compare": L1.compute().item(),
            "DICEmacro_unique_max": d_mac.compute().item(), "DICE_unique_max": d_mic.compute().item(),
            "Acc": acc.compute().item(), "Bacc": b_acc.compute().item()}

In [None]:
def compute_model_metrics(preds_path, labels, bgs, remap=False):
    """Compute metrics for every run and aggregate them across runs.

    Args:
        preds_path: prediction files, one per run; each is loaded from
            ``base_path / p``. That join is a no-op when ``p`` is absolute.
            NOTE(review): if ``base_path`` is relative and ``p`` already
            includes it, the prefix is duplicated — confirm paths are absolute.
        labels: per-image labels, passed to ``get_all_metrics_from_list``.
        bgs: per-image background masks, passed to ``get_all_metrics_from_list``.
        remap: ``False`` for no remapping, or a callable applied to each
            softmaxed prediction with a leading batch dimension added
            (e.g. the ``remapping_function`` returned by ``get_model_settings``).

    Returns:
        DataFrame with the ``mean`` and ``std`` of every metric over runs.

    Raises:
        TypeError: ``remap`` is truthy but not callable.
    """
    if remap and not callable(remap):
        raise TypeError(f"remap must be False or a callable, got {remap!r}")

    mets = []

    for p_path in preds_path:
        print(f"Loading {p_path.parents[1]}")
        preds = torch.load(base_path / p_path)
        print("Softmax")
        preds = [torch.nn.functional.softmax(img_pred.float().squeeze(0), dim=0) for img_pred in preds]

        if remap:
            print("Remapping to LL0")
            preds = [remap(pred.unsqueeze(0)).squeeze(0) for pred in preds]

        print("Computing metrics")
        mets.append(get_all_metrics_from_list(preds, labels, bgs))
        # Release this run's predictions before the next run is loaded.
        del preds

    print("Results")
    met_df = pd.DataFrame(mets).aggregate(["mean", "std"])
    print(met_df)

    return met_df

def get_save_path(model_name, remapped):
    """Return the CSV path under which metrics of ``model_name`` are cached.

    Slashes in the model name become underscores; remapped results get a
    ``_remaped`` suffix (spelling kept so existing result files are found).
    """
    stem = "_".join(model_name.split("/"))
    if remapped:
        stem += "_remaped"
    return Path("./results/metrics_final") / f"{stem}.csv"

In [None]:
# Evaluate each selected model on the test split and cache its aggregated metrics
# as CSV. ll1 models are additionally evaluated after remapping their outputs to
# label level 0. Use e.g. `"final2/" in s` as the filter to test all final2 models.
models_to_test = [s for s in available_models if "final2/ll1/SDB" in s]

for selected_model in models_to_test:

    print(f"selected_model: {selected_model}")
    assert selected_model in available_models

    remap_settings = (False, True) if "ll1" in selected_model else (False,)
    for remap_ll0 in remap_settings:
        # Remapped predictions are evaluated at label level 0 (see get_model_settings).
        print("Remap to LL0" if remap_ll0 else "No Remap")

        save_path = get_save_path(selected_model, remap_ll0)
        if save_path.exists():
            print("Skipping due to existing save_file")
            continue

        print("Loading model_paths and data")
        model_paths, preds_paths, label_level, num_classes, data, labels, bgs, remapping_function = get_model_settings(selected_model, remap_ll0)
        df = compute_model_metrics(preds_paths, labels, bgs, remap=remapping_function if remap_ll0 else False)
        save_path.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(save_path)

In [None]:
# Reload the test-split settings of the final2 SoftDiceBalanced (ll1) run group,
# without LL0 remapping, for the ensemble / per-run analyses below.
a = get_model_settings("final2/ll1/SDB", False)
(_, p_paths, _, _, data, labels, bgs, remapping_function) = a

In [None]:
# Build the ensemble of all runs in `p_paths` by averaging their logits, then
# convert it to per-pixel class probabilities.
REMAP_TO_LL0 = False

preds = [img_pred.float() for img_pred in torch.load(base_path / p_paths[0])]
for p_path in p_paths[1:]:
    member_preds = torch.load(base_path / p_path)
    for i, member_pred in enumerate(member_preds):
        preds[i] += member_pred.float()
    del member_preds  # free memory before loading the next run

# Average rather than sum: softmax of a sum of N logit maps equals softmax of the
# mean with temperature 1/N, i.e. it over-sharpens the ensemble probabilities.
preds = [img_pred / len(p_paths) for img_pred in preds]

print("Softmax")
preds = [torch.nn.functional.softmax(img_pred.squeeze(0), dim=0) for img_pred in preds]

if REMAP_TO_LL0:
    print("Remapping to LL0")
    preds = [remapping_function(pred.unsqueeze(0)).squeeze(0) for pred in preds]
preds_ensemble = preds

In [None]:
# Compute outputs: class-frequency statistics and one confusion matrix pooled
# over all runs in `p_paths`. The confusion matrix only uses foreground pixels
# whose label maximum is unique.
STRIP_BACKGROUND = True


num_classes = 10

# int64 accumulators: torch.bincount returns int64, and totals over all images
# and runs can exceed the int32 range.
pix_freq = torch.zeros(num_classes, dtype=torch.int64)
max_freq = torch.zeros(num_classes, dtype=torch.int64)
pred_freq = torch.zeros(num_classes, dtype=torch.int64)
ml_pred_freq = torch.zeros(num_classes, dtype=torch.int64)  # not filled in this cell

one_annotator_pred_freq = torch.zeros(num_classes, dtype=torch.int64)

conf_matrix = ConfusionMatrix(task="multiclass", num_classes=num_classes)

for p_path in p_paths:

    preds = torch.load(base_path / p_path)
    preds = [torch.nn.functional.softmax(img_pred.float().squeeze(0), dim=0) for img_pred in preds]

    for i in tqdm(range(len(data))):

        out = preds[i]

        background_mask = bgs[i]
        mask = labels[i]
        if STRIP_BACKGROUND:
            foreground_mask = ~background_mask.bool()
        else:
            foreground_mask = torch.ones_like(background_mask).bool()

        # Pixels where exactly one class holds the maximum label value.
        label_max = torch.max(mask, dim=0)[0]
        duplicated_max = torch.sum(mask == label_max.unsqueeze(0), dim=0) > 1
        unique_max_fg = torch.logical_and(foreground_mask, ~duplicated_max)

        mask_unique = mask[:, unique_max_fg]
        out_unique = out[:, unique_max_fg]

        mask = mask[:, foreground_mask].flatten(start_dim=1)
        out = out[:, foreground_mask].flatten(start_dim=1)

        pix_freq += torch.sum(mask > 0, dim=1)
        max_freq += torch.bincount(torch.argmax(mask_unique, dim=0).reshape(-1), minlength=num_classes)
        pred_freq += torch.bincount(torch.argmax(out, dim=0).reshape(-1), minlength=num_classes)
        # Pixels whose predicted class probability is >= 0.33 (presumably "one of three annotators").
        one_annotator_pred_freq += torch.sum(out >= 0.33, dim=1)

        # update() only accumulates state; calling the metric (forward) would also
        # compute a per-image matrix that is discarded here.
        conf_matrix.update(out_unique.argmax(dim=0), mask_unique.argmax(dim=0))


In [None]:
# Compute outputs: one confusion matrix per run in `p_paths` (foreground pixels
# with a unique label maximum only). The frequency counters are accumulated over
# all runs, so label-derived counts include every image once per run.
STRIP_BACKGROUND = True

num_classes = 10

# int64 accumulators: torch.bincount returns int64, and totals over all images
# and runs can exceed the int32 range.
pix_freq = torch.zeros(num_classes, dtype=torch.int64)
max_freq = torch.zeros(num_classes, dtype=torch.int64)
pred_freq = torch.zeros(num_classes, dtype=torch.int64)
ml_pred_freq = torch.zeros(num_classes, dtype=torch.int64)  # not filled in this cell

one_annotator_pred_freq = torch.zeros(num_classes, dtype=torch.int64)

conf_matrices = [ConfusionMatrix(task="multiclass", num_classes=num_classes) for _ in range(len(p_paths))]

for run_idx, p_path in enumerate(p_paths):

    # Reset per run; after the loop these hold the values of the last run.
    count_unique_max = 0.0
    count_foreground = 0.0
    count_pixels = 0.0

    conf_matrix = conf_matrices[run_idx]

    preds = torch.load(base_path / p_path)
    preds = [torch.nn.functional.softmax(img_pred.float().squeeze(0), dim=0) for img_pred in preds]

    # Separate image index so the run index is not shadowed.
    for img_idx in tqdm(range(len(data))):

        out = preds[img_idx]

        background_mask = bgs[img_idx]
        mask = labels[img_idx]
        if STRIP_BACKGROUND:
            foreground_mask = ~background_mask.bool()
        else:
            foreground_mask = torch.ones_like(background_mask).bool()

        # Pixels where exactly one class holds the maximum label value.
        label_max = torch.max(mask, dim=0)[0]
        duplicated_max = torch.sum(mask == label_max.unsqueeze(0), dim=0) > 1
        unique_max_fg = torch.logical_and(foreground_mask, ~duplicated_max)

        count_unique_max += unique_max_fg.sum()
        count_foreground += foreground_mask.sum()
        count_pixels += foreground_mask.numel()

        mask_unique = mask[:, unique_max_fg]
        out_unique = out[:, unique_max_fg]

        mask = mask[:, foreground_mask].flatten(start_dim=1)
        out = out[:, foreground_mask].flatten(start_dim=1)

        pix_freq += torch.sum(mask > 0, dim=1)
        max_freq += torch.bincount(torch.argmax(mask_unique, dim=0).reshape(-1), minlength=num_classes)
        pred_freq += torch.bincount(torch.argmax(out, dim=0).reshape(-1), minlength=num_classes)
        # Pixels whose predicted class probability is >= 0.33 (presumably "one of three annotators").
        one_annotator_pred_freq += torch.sum(out >= 0.33, dim=1)

        # update() only accumulates state (no discarded per-image compute).
        conf_matrix.update(out_unique.argmax(dim=0), mask_unique.argmax(dim=0))

In [None]:
# Evaluate the number of foreground and unique_max pixels on the whole dataset
# ("all" split). The cell computes:
# - per class, how many foreground pixels were marked by 0..NUM_ANNOTATORS annotators
#   (label_counts);
# - how many foreground pixels have a unique label maximum (count_unique_max).

data_all = GleasonX(
    base_path,
    split="all",
    **{"scaling": "MicronsCalibrated", "transforms": normalize_only_transform, "label_level": 1,
       "create_seg_masks": True, "tissue_mask_kwargs": {"open": False, "close": False, "flood": False},
       "drawing_order": "grade_frame_order", "explanation_file": "final_filtered_explanations_df.csv",
       "data_split": (0.7, 0.15, 0.15)},
)

# Soft labels are multiplied by this to recover per-pixel annotator counts.
# Assumes three annotators per image — TODO confirm for every image.
NUM_ANNOTATORS = 3

count_unique_max = 0.0
count_foreground = 0.0
count_pixels = 0.0
count_agg_annotators = torch.zeros(NUM_ANNOTATORS + 1)

label_counts = torch.zeros(data_all.num_classes, NUM_ANNOTATORS + 1)
STRIP_BACKGROUND = True

num_classes = 10

for i in tqdm(range(len(data_all))):

    _, mask, background_mask = data_all[i]

    if STRIP_BACKGROUND:
        foreground_mask = ~background_mask.bool()
    else:
        foreground_mask = torch.ones_like(background_mask).bool()

    # Round before the integer cast: a fraction k/3 times 3 may come out
    # marginally below k in floating point, and int() would truncate it to k - 1.
    mask = (mask * NUM_ANNOTATORS).round().int()
    mask_fg = mask[:, foreground_mask].flatten(start_dim=1)

    # Highest annotator count over all classes at each foreground pixel.
    count_agg_annotators += torch.bincount(mask_fg.max(dim=0)[0], minlength=NUM_ANNOTATORS + 1)
    for c in range(data_all.num_classes):
        label_counts[c, :] += torch.bincount(mask_fg[c].flatten(), minlength=NUM_ANNOTATORS + 1)

    label_max = torch.max(mask, dim=0)[0]
    duplicated_max = torch.sum(mask == label_max.unsqueeze(0), dim=0) > 1
    unique_max_fg = torch.logical_and(foreground_mask, ~duplicated_max)

    count_unique_max += unique_max_fg.sum()
    count_foreground += foreground_mask.sum()
    count_pixels += foreground_mask.numel()

count_unique_max/count_foreground

ll0: [0.0000, 0.0246, 0.3623, 0.6130]
ll1: [0.0000, 0.1359, 0.4107, 0.4535]
ll2: [0.0000, 0.3224, 0.3780, 0.2996]

In [None]:
# pixel_agreement: per class, how the annotated pixels split by number of
# agreeing annotators (each row sums to 1).
# percentages: per class, the fraction of foreground pixels marked by at least
# one annotator.
pixel_agreement = label_counts[:, 1:] / label_counts[:, 1:].sum(dim=1, keepdim=True)
percentages = label_counts[:, 1:].sum(dim=1) / label_counts[:, :].sum(dim=1)

for row, name in enumerate(class_names):
    annotated_share = label_counts[row, 1:].sum() / label_counts[row, :].sum()
    agreement_pct = list(pixel_agreement[row].numpy() * 100)
    print(f"{name}: {annotated_share:0.2f}, {agreement_pct}")


In [None]:
# Figure with two panels, saved to figures/pixelagreement.svg:
# (a) heatmap of how many annotators agree on the annotated pixels of each explanation;
# (b) bar chart of the share of foreground pixels annotated with each explanation.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 8), gridspec_kw={'width_ratios': [1, 1], 'hspace': -4.0}, sharey=True)

fontsize_axis = 22
fontsize = 20
fontsize_label = 20
fontsize_bar = 20

# "Blues" colormap without its darkest quarter (presumably to keep the black text readable).
cmap = ListedColormap(cm["Blues"](np.linspace(0, 1, 1000))[:-250])

# Panel a: agreement heatmap with percentage annotations.
im = ax1.matshow(pixel_agreement * 100, cmap=cmap, aspect='auto')

n_rows, n_cols = pixel_agreement.shape
ax1.set_xticks(np.arange(n_cols))
ax1.set_xticklabels([1, 2, 3], fontsize=fontsize)
ax1.set_xlabel("Number of agreeing annotators", fontsize=fontsize_axis)
ax1.set_ylabel("Explanation", fontsize=fontsize_axis)
ax1.set_yticks(np.arange(len(class_names)))
ax1.set_yticklabels(class_names, fontsize=fontsize)
for row in range(n_rows):
    for col in range(n_cols):
        ax1.text(col, row, f"{pixel_agreement[row, col]*100:.2f}%",
                 ha="center", va="center", color="black", fontsize=fontsize_label)

for side in ("top", "right", "left", "bottom"):
    ax1.spines[side].set_visible(False)
for side in ("top", "right"):
    ax2.spines[side].set_visible(False)

# Panel b: share of foreground pixels per explanation.
ax2.barh(class_names, percentages, color='lightskyblue')

for row, frac in enumerate(percentages.tolist()):
    # Hand-tuned label positions for the short bars of this figure.
    if row in (2, 5, 7):
        x, ha = frac / 2 + 0.02, "left"
    elif row == 9:
        x, ha = frac / 2 + 0.03, "left"
    else:
        x, ha = max(frac / 2, 0.02), "center"
    ax2.text(x, row, f"{frac*100:.2f}", ha=ha, va="center", fontsize=fontsize_bar)

ax2.set_xlabel("% of foreground pixels", fontsize=fontsize_axis)
ax2.set_xticks([0.1, 0.2, 0.3, 0.4, 0.5], ["10%", "20%", "30%", "40%", "50%"], fontsize=fontsize_label)
ax2.invert_yaxis()
ax2.get_yaxis().set_visible(False)

# The y axis is shared, so these calls interact with the ax2 inversion above;
# their order is kept as-is.
ax1.invert_yaxis()
ax1.xaxis.set_ticks_position('bottom')
ax1.text(-1.0, -1, "a)", size=24)
ax2.text(-0.1, -1, "b)", size=24)

plt.savefig("figures/pixelagreement.svg", dpi=500)

In [None]:
# Per-class total of label_counts (summed over all agreement columns), relative to
# the number of foreground pixels.
label_counts.sum(dim=1)/count_foreground

In [None]:
# Stack the per-model confusion matrices -> (n_models, n_classes, n_classes) and
# row-normalise each (torchmetrics puts targets on the rows, predictions on the columns).
ccmm = torch.stack([c.compute() for c in conf_matrices], dim=0).float()
# keepdim=True avoids hard-coding the number of models (3) and classes (10).
ccmm = ccmm / ccmm.sum(dim=2, keepdim=True)
# Diagonal of the row-normalised matrices = per-class recall: mean and std across models.
print(ccmm.mean(dim=0).diag())
print(ccmm.std(dim=0).diag())

tensor([6.8426e-01, 7.3834e-01, 7.5832e-05, 4.4400e-01, 7.2964e-01, 0.0000e+00,
        6.7376e-01, 0.0000e+00, 7.2274e-01, 1.0153e-03])
tensor([0.0169, 0.0266, 0.0001, 0.0118, 0.0505, 0.0000, 0.0270, 0.0000, 0.0410,
        0.0018])

In [None]:
# Sanity check: shape of the per-model row sums (expected (n_models, n_classes) — confirm).
ccmm.sum(dim=2).shape

In [None]:
# Confusion Matrix
# Mean (over models) row-normalised confusion matrix of the explanation classes.

torch.set_printoptions(sci_mode=False)
# Display names of the explanation classes used in the paper figure.
classes_named = ["benign tissue", "3 - individual glands", "3 - compressed glands", "4 - poorly formed glands",
                 "4 - cribriform glands", "4 - glomeruloid glands", "5 - group of tumor cells", "5 - single cells", "5 - cords", "5 - comedonecrosis"]
confm = ccmm.float().mean(dim=0)

# Re-normalise the rows and store them as integer per-mille values.
confm_normed = confm / confm.sum(dim=1).reshape(-1, 1)
confm_normed = (confm_normed * 1000).round().long()

confm_expanded = confm_normed
# Only the lighter 70 % of "Blues", so the black cell text stays readable.
original_cmap = cm["Blues"]
im = plt.matshow(confm_expanded, cmap=LinearSegmentedColormap.from_list("Blues_80", original_cmap(np.linspace(0, 0.7, 256))))

for (i, j), val in np.ndenumerate(confm_expanded):
    # Per mille -> percent with one decimal; exact zeros are printed as "0".
    s = "0" if val / 10 < 0.01 else f'{val/10:.1f}'
    plt.text(j, i, s, ha='center', va='center', color='black', size=11)

plt.xticks(range(num_classes), classes_named,
           rotation=45, rotation_mode="anchor", ha="right", va="center_baseline", fontsize=16)
plt.yticks(range(num_classes), classes_named, rotation=45, fontsize=16)
plt.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False)

plt.ylabel("Annotations", fontsize=22)
plt.xlabel("Predictions", fontsize=22)

# Grey outlines around the diagonal blocks that group the explanations by Gleason
# pattern: benign (0), pattern 3 (1-2), pattern 4 (3-5), pattern 5 (6-9).
# Each entry is (row, col, width, height) of the block, starting at its top-left cell.
diagonal_boxes = [
    (0, 0, 1, 1),
    (1, 1, 2, 2),
    (3, 3, 3, 3),
    (6, 6, 4, 4),
]

for x, y, width, height in diagonal_boxes:
    rect = Rectangle((y - 0.5, x - 0.5), width, height, fill=False, edgecolor='grey', linewidth=2)
    plt.gca().add_patch(rect)

plt.savefig(fig_dir / "confmatrix.svg", dpi=1000)

In [None]:
# Class annotation and prediction frequency
# Normalised per-class probability mass summed over all images: predictions (a), soft labels (b).
a = torch.stack([pred.sum(dim=(1, 2)) for pred in preds]).sum(dim=0)
a /= a.sum()
b = torch.stack([label.sum(dim=(1, 2)) for label in labels]).sum(dim=0)
b /= b.sum()

bar_width = 0.15
class_positions = np.arange(num_classes)
# (offset in bar widths, bar heights, legend label, colour); None = default colour cycle.
bar_series = [
    (-1.7, b, "soft label probability mass", None),
    (-0.7, a, "predicted probability mass", 'lightslategray'),
    (0.7, max_freq / torch.sum(max_freq), "annotator majority vote", 'skyblue'),
    (1.7, pred_freq / torch.sum(pred_freq), "prediction argmax", 'lightgray'),
]
for offset, heights, legend_label, bar_color in bar_series:
    plt.bar(class_positions + offset * bar_width, heights, label=legend_label, width=bar_width, color=bar_color)

_ = plt.legend(fontsize=14)
_ = plt.xticks(class_positions, [name[:24] for name in classes_named], rotation=45, ha="right", rotation_mode="anchor", fontsize=18)
_ = plt.yticks(fontsize=18)
_ = plt.xlabel("Explanation", fontsize=20)
_ = plt.ylabel("Proportion", fontsize=20)

plt.grid(axis='y', linestyle='--', linewidth=0.5, color='lightgray')

plt.savefig(fig_dir / "proportions.svg", dpi=1000)

# Agreement and Class Distribution

In [None]:
import os
from pathlib import Path

In [None]:
# Input and output directory of the dataset-statistics cells (both ./output).
in_path, out_path = Path("./output"), Path("./output")

In [None]:
creation_in_path = Path(os.environ["DATASET_LOCATION"]) / "GleasonXAI"
# All splits, label level 1, with segmentation masks; the tissue-mask
# morphology steps (open / close / flood) are switched off.
dataset = GleasonX(
    creation_in_path,
    split='all',
    scaling='MicronsCalibrated',
    label_level=1,
    create_seg_masks=True,
    drawing_order='grade_frame_order',
    explanation_file='final_filtered_explanations_df.csv',
    data_split=[0.7, 0.15, 0.15],
    tissue_mask_kwargs={'open': False, 'close': False, 'flood': False},
)

### Fleiss Kappa Boxplot

In [None]:
import random
from calculate_fleiss_kappa import calculate_kappa_per_group_and_label

In [None]:
%%capture
use_sub_expl = True
# Label granularity: None -> Gleason patterns, True -> sub-explanations, False -> explanations.
prefix = "pattern" if use_sub_expl is None else "sub-expl" if use_sub_expl is True else "expl"

# Compute the kappas only when the cached CSV is missing; the seed is fixed so a
# recomputation is reproducible (assuming the helper draws from `random` — TODO confirm).
if not (in_path / prefix / f"{prefix}_kappas.csv").exists():
    random.seed(42)
    calculate_kappa_per_group_and_label(dataset.df, in_path, use_sub_explanations=use_sub_expl)

In [None]:
kappas_df = pd.read_csv(in_path / prefix / f"{prefix}_kappas.csv")
renamed_labels = np.load(in_path / prefix / f"{prefix}_kappas_y-lables.npy")

# Transposed kappa table without its first row (presumably the CSV's index
# column — TODO confirm); one horizontal box per label.
kappa_rows = kappas_df.T[1:]

if use_sub_expl is None:
    figsize = (6, 2.6)
elif use_sub_expl == True:
    figsize = (5, 16)
else:
    figsize = (10, 9)
plt.figure(figsize=figsize)

# Layer 1: light-blue boxes without fliers.
ax_bp = sns.boxplot(data=kappa_rows, color='lightskyblue', showfliers=False, width=0.5, native_scale=False, orient='h')
plt.xlim(-0.3, 1.1)
plt.grid(axis='x', linestyle='--', linewidth=0.5, color='lightgray')
ax_bp.set_yticklabels(renamed_labels)
plt.xticks(rotation=45, ha='right', fontsize=18)
plt.yticks(fontsize=22)
# Layer 2: the individual kappa values as points.
sns.stripplot(data=kappa_rows, color='lightskyblue', jitter=False, linewidth=1, orient='h')
# Layer 3: box, caps and median hidden; only adds the mean as a white diamond on top.
mean_marker = {"marker": "D", "markerfacecolor": "white", "markeredgecolor": "gray"}
sns.boxplot(data=kappa_rows, showfliers=False, medianprops={'visible': False}, showbox=False, showcaps=False,
            width=0.5, native_scale=False, orient='h', showmeans=True, meanprops=mean_marker, zorder=10)

plt.tight_layout()
plt.savefig(fig_dir / f"{prefix}_boxplot_kappas.svg", dpi=1000)
plt.show()

### Rater Agreement

In [None]:
from calculate_dataset_characteristics import agreement_occurrence_per_class, plot_three_rater_agreement_occurrence

In [None]:
use_sub_expl = None
# Granularity flag -> (file prefix, y-axis label); anything else falls back to Gleason patterns.
_granularity = {False: ('expl', 'Explanation'), True: ('sub-expl', 'Sub-Explanation')}
prefix, ylbl = _granularity.get(use_sub_expl, ('grade', 'Gleason Pattern'))

In [None]:
#%%capture
# Build the rater-agreement table only when it is not cached yet
# (presumably plot_three_rater_agreement_occurrence writes the CSV — TODO confirm).
agreement_csv = in_path / f"{prefix}_rater_agreement.csv"
if not agreement_csv.exists():
    ca_dict = agreement_occurrence_per_class(dataset.df, use_sub_expl, True)
    plot_three_rater_agreement_occurrence(ca_dict, dataset.df, in_path, use_sub_expl)


In [None]:
rdf = pd.read_csv(in_path / f"{prefix}_rater_agreement.csv")
# Figure size and annotation font size depend on the label granularity.
if use_sub_expl == True:
    fig_size, annot_size = (7, 15), 14
elif use_sub_expl == False:
    fig_size, annot_size = (10, 9), 22
else:
    fig_size, annot_size = (6.3, 3), 22

fig, ax = plt.subplots(figsize=fig_size)
ax = sns.heatmap(rdf.T[1:], annot=True, fmt='g', cmap='Blues', square=False, cbar=False,
                 annot_kws={"size": annot_size}, linewidths=.5)
plt.xlabel('Number of Annotators', fontsize=22)
plt.ylabel(ylbl, fontsize=22)
plt.xticks(fontsize=22)
plt.yticks(fontsize=22)
ax.set_yticklabels(ax.get_yticklabels(), rotation='horizontal')
plt.tight_layout()
plt.savefig(fig_dir / f"{prefix}_agreement_occurences.png", dpi=1000)
plt.show()


### Image vs Annotator Grade

In [None]:
from calculate_dataset_characteristics import compare_tma_and_explanation_grade

In [None]:
%%capture
# Compute the image-vs-annotation grade table only if it is not on disk yet.
grade_csv = in_path / "image_vs_annotation_grade.csv"
if not grade_csv.exists():
    compare_tma_and_explanation_grade(dataset.df, in_path)

In [None]:
# Heatmap of the image-level score against the Gleason patterns of the annotations.
# .iloc[:, 1:] drops the first CSV column (presumably the saved index — TODO confirm).
cm_df = pd.read_csv(in_path / "image_vs_annotation_grade.csv").iloc[:, 1:]
text_size = 14
plt.figure(figsize=(3.5, 3))
sns.heatmap(cm_df, annot=True, fmt='g', cmap='Blues', cbar=False, square=True,
            annot_kws={"size": text_size}, linewidths=.5)
plt.xlabel('Annotation Gleason pattern', fontsize=text_size)
plt.ylabel('Image score', fontsize=text_size)
plt.xticks(fontsize=text_size)
plt.yticks(fontsize=text_size)
plt.tight_layout()
plt.savefig(fig_dir / "cm_class_comparison.png", dpi=1000)

# Tear the figure down so it is not rendered again in the notebook.
plt.clf()
plt.cla()
plt.close()

### Class distribution

In [None]:
from calculate_dataset_characteristics import get_class_cooccurrence_between_labelers

In [None]:
use_sub_expl = True
# Granularity flag -> (file prefix, y-axis label); anything else falls back to Gleason patterns.
_granularity = {False: ('expl', 'Explanation'), True: ('sub-expl', 'Sub-Explanation')}
prefix, ylbl = _granularity.get(use_sub_expl, ('grade', 'Gleason Pattern'))

In [None]:
%%capture
distr_in_path = in_path / "class_dist"
# Generate the class-distribution arrays for every granularity
# (True: sub-explanations, False: explanations, None: Gleason patterns) if missing.
if not (distr_in_path / "sub-expl_label.npy").exists():
    for granularity in (True, False, None):
        get_class_cooccurrence_between_labelers(dataset.df, granularity, True, in_path)

In [None]:
distr_in_path = in_path / "class_dist"
# Raise instead of assert: asserts are stripped under `python -O`.
if not distr_in_path.exists():
    raise FileNotFoundError(f"Class-distribution directory not found: {distr_in_path}")

grade_lbl = np.load(distr_in_path / "grade_label.npy")
grade_val = np.load(distr_in_path / "grade_value.npy")

grade_lbl = ["Gleason Pattern " + c_lbl for c_lbl in grade_lbl]

expl_lbl = np.load(distr_in_path / "sub-expl_label.npy")
expl_val = np.load(distr_in_path / "sub-expl_value.npy")

grouped_lbl = np.load(distr_in_path / "expl_label.npy")
grouped_val = np.load(distr_in_path / "expl_value.npy")

sns.set_theme(style="ticks")

ybar_fontsize = 19
single_fontsize = 22
bar_width = 0.5
bar_spacing = 0.1


def _annotate_bars(ax, bars, values, fontsize):
    """Write each bar's value just above the top edge of the bar."""
    for bar, value in zip(bars, values):
        ax.text(bar.get_x() + bar.get_width() / 2, value + 0.5, str(value),
                ha='center', va='bottom', color='black', fontsize=fontsize)


# Figure 1: Gleason-pattern counts (narrow left panel) next to the sub-explanation counts.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(18, 5), width_ratios=[0.09, 0.91])
for ax in (ax1, ax2):
    ax.grid(axis='y', linestyle='--', linewidth=0.5, color='lightgray')
    ax.tick_params(axis='y', labelsize=single_fontsize)
    ax.set_ylim((0, 900))

positions_grade = np.arange(len(grade_lbl)) * (bar_width + bar_spacing)
positions_expl = np.arange(len(expl_lbl)) * (bar_width + bar_spacing)

bars1 = ax1.bar(positions_grade, grade_val, width=0.3, color='#77b5d9')
ax1.set_xticks(positions_grade)
ax1.set_xticklabels(grade_lbl, fontsize=single_fontsize, rotation=45)

bars2 = ax2.bar(positions_expl, expl_val, width=0.3, color='#77b5d9')
ax2.set_xticks(positions_expl)
ax2.set_xticklabels(expl_lbl, fontsize=single_fontsize, rotation=45)

_annotate_bars(ax1, bars1, grade_val, ybar_fontsize)
_annotate_bars(ax2, bars2, expl_val, ybar_fontsize)

for ax in (ax1, ax2):
    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)

plt.tight_layout()
plt.savefig(fig_dir / "class_dist_grade_expl.svg", dpi=1000)
plt.close(fig)

# Figure 2: explanation counts. The larger y-tick size is scoped to this figure with
# rc_context instead of being written into plt.rcParams for every later plot.
with plt.rc_context({'ytick.labelsize': single_fontsize}):
    plt.figure(figsize=(8, 7))
    plt.ylim((0, 900))
    plt.grid(axis='y', linestyle='--', linewidth=0.5, color='lightgray')

    positions_grouped = np.arange(len(grouped_lbl)) * (bar_width + bar_spacing)

    bars = plt.bar(positions_grouped, grouped_val, width=0.3, color='#77b5d9')

    plt.xticks(positions_grouped, grouped_lbl, rotation=45, ha='right', fontsize=single_fontsize)
    plt.yticks(fontsize=single_fontsize)

    ax = plt.gca()
    # Value labels just above each bar.
    _annotate_bars(ax, bars, grouped_val, ybar_fontsize)

    ax.spines['top'].set_visible(False)
    ax.spines['right'].set_visible(False)
    plt.tight_layout()
    plt.savefig(fig_dir / "class_dist_explanation.svg", dpi=1000)
# %%


# Debug Metrics

In [None]:
# Hard (argmax) Dice, macro- and micro-averaged; updated in the Tversky/Dice cell below.
d_mac = Dice(num_classes=num_classes, average="macro")
d_mic = Dice(num_classes=num_classes, average="micro")
# JDT loss whose get_image_class_matrix supplies the per-image class scores below.
l = JDTLoss(alpha=0.5, beta=0.5)

# Compute SoftDice C,D,I: each loss weights exactly one of the three components.
# The weights are one-hot; a non-zero mIoUD in l_I would mix the D term into the
# "I-only" loss.
l_C = JDTLoss(mIoUC=1.0, mIoUD=0.0, mIoUI=0.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")
l_D = JDTLoss(mIoUC=0.0, mIoUD=1.0, mIoUI=0.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")
l_I = JDTLoss(mIoUC=0.0, mIoUD=0.0, mIoUI=1.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")

In [None]:
# Per-class share of half the absolute difference between the ensemble prediction and
# the soft label, averaged over foreground pixels; summing over classes gives the mean
# total-variation distance (EMD with a unit ground distance between classes).
emd = []
for i, rel_pred in enumerate(preds_ensemble):
    fg_mask = ~bgs[i]
    abs_diff = (rel_pred[:, fg_mask] - labels[i][:, fg_mask]).abs()
    emd.append((abs_diff.sum(dim=1) / 2) / fg_mask.sum())

emd = torch.stack(emd).mean(dim=0)
print(emd)
emd.sum()

In [None]:
# Mean per-pixel KL(soft label || ensemble prediction) over foreground pixels,
# averaged over images.
kls = []

for i, rel_pred in enumerate(preds_ensemble):
    fg_mask = ~bgs[i]
    rel_pred = rel_pred[:, fg_mask]  # presumably (classes, foreground pixels)
    rel_label = labels[i][:, fg_mask]

    # Clamp before taking the log: an exact-zero predicted probability gives
    # log(0) = -inf, which can turn zero-target terms into 0 * inf = nan.
    log_pred = torch.log(rel_pred.clamp_min(torch.finfo(rel_pred.dtype).tiny))
    # kl_div's default reduction="mean" averages over classes *and* pixels, i.e. it
    # divides the KL by the number of classes; sum over classes, then average pixels.
    pointwise_kl = torch.nn.functional.kl_div(log_pred, rel_label, reduction="none")
    kls.append(pointwise_kl.sum(dim=0).mean())

torch.mean(torch.stack(kls))

In [None]:
tverskys = []
active_classes = []
per_image_dice_scores = []

for i, label in enumerate(labels):
    bg = bgs[i]
    fg = ~bg
    ens_pred = preds_ensemble[i]

    # Soft per-class scores over all foreground pixels.
    tversky, ac = l.get_image_class_matrix(ens_pred[:, fg].unsqueeze(0), label[:, fg].unsqueeze(0), prob_predictions=True)
    tverskys.append(tversky)
    active_classes.append(ac)

    # Keep only foreground pixels whose soft label has a single (non-tied) maximum,
    # so the argmax target is well defined.
    ties_at_max = (label == label.max(dim=0).values.unsqueeze(0)).sum(dim=0)
    unique_max_fg = fg & (ties_at_max <= 1)

    # Hard class maps computed once and shared by the three Dice computations.
    pred_cls = ens_pred[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
    target_cls = label[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
    per_image_dice_scores.append(torchmetrics.functional.dice(pred_cls, target_cls, average="none", num_classes=num_classes))
    d_mac.update(pred_cls, target_cls)
    d_mic.update(pred_cls, target_cls)

tverskys = torch.stack(tverskys).squeeze()
active_classes = torch.stack(active_classes).squeeze()
per_image_class_dice_scores = torch.stack(per_image_dice_scores).squeeze()

In [None]:
# Compare: mean over classes of the per-class Dice averaged over images (NaNs ignored),
# versus the accumulated macro and micro Dice over the same pixels.
(per_image_class_dice_scores.nanmean(dim=0).mean(),d_mac.compute(), d_mic.compute())

In [None]:
# Class-wise mean: average each class's per-image score over the images where the class
# is active, then average over the classes that are active at least once.
IoUCs = [
    tverskys[:, c][active_classes[:, c]].mean()
    for c in range(num_classes)
    if active_classes[:, c].sum() > 0
]

n_active_classes = (active_classes.sum(dim=0) > 0).sum()
mIoUC = torch.sum(torch.stack(IoUCs)) / n_active_classes
mIoUC

In [None]:
# Image-wise mean: average each image's scores over its active classes, then over images.
IoUIs = [tverskys[i, :][active_classes[i, :]].mean() for i in range(tverskys.shape[0])]

n_images = tverskys.shape[0]
mIoUI = torch.sum(torch.stack(IoUIs)) / n_images
mIoUI

In [None]:
# Two soft-Dice implementations accumulated over the foreground pixels of every image.
my_mac_dice = SoftDICEMetric(average="macro")
other_soft_dice = SoftCorrectDICEMetric(average=None)

for i, ens_pred in enumerate(preds_ensemble):
    keep = ~bgs[i]
    fg_pred = ens_pred[:, keep]
    fg_label = labels[i][:, keep]

    my_mac_dice.update(fg_pred, fg_label)
    # This metric is fed an explicit leading batch dimension.
    other_soft_dice.update(fg_pred.unsqueeze(0), fg_label.unsqueeze(0))

print(my_mac_dice.compute())
print(other_soft_dice.compute())

# Visualizations

In [None]:
# Per-class one-vs-rest confusion statistics and derived metrics, shown as a heatmap.
confm_ml = conf_matrix.compute()
# Flatten each class's 2x2 block to the columns (TN, FP, FN, TP).
confm_ml_reshape = confm_ml.reshape(confm_ml.shape[0], -1)
num_pixels = confm_ml_reshape.sum(dim=1)
confm_ml_reshape_normed = confm_ml_reshape / num_pixels.reshape(-1, 1)

TN, FP, FN, TP = confm_ml_reshape[:, 0], confm_ml_reshape[:, 1], confm_ml_reshape[:, 2], confm_ml_reshape[:, 3]

P = (TP+FN)
N = (TN+FP)

POPU = (TN+FP+FN+TP)
ACC = (TN+TP)/POPU
TPR = TP/(TP+FN)
TNR = (TN/(TN+FP))
PREC = TP/(FP+TP)
BACC = (TPR+TNR) / 2
DICE = (2*TP)/(2*TP+FP+FN)

PREL = P / POPU  # prevalence of the class

# Column order of the heatmap; the tick positions are derived from this list.
metric_names = ["PREL", "TP", "TN", "FP", "FN", "ACC", "PREC", "BA", "DICE"]
confm_ml_reshape_normed = torch.cat([PREL.reshape(num_classes, 1), confm_ml_reshape_normed[:, [3, 0, 1, 2]], ACC.reshape(
    num_classes, 1), PREC.reshape(num_classes, 1), BACC.reshape(num_classes, 1), DICE.reshape(num_classes, 1)], dim=1)

im3 = plt.matshow(confm_ml_reshape_normed)
plt.colorbar(im3)
plt.xticks(range(len(metric_names)), metric_names, rotation=45)
_ = plt.yticks(range(len(classes_named)), classes_named)
# The cell texts are the fractions scaled by 1000, i.e. per mille, not percent.
_ = plt.xlabel("Metrics in per mille (x1000)")
for (i, j), val in np.ndenumerate(confm_ml_reshape_normed):
    plt.text(j, i, f"{val*1000:0.0f}", c="red", ha="center", va="center")

In [None]:
@wid.interact(idx=wid.IntSlider(min=0, max=len(data)-1, value=0))
def show_worst_results(idx):
    """Interactively browse the single-class activation heatmaps of the ensemble
    prediction for dataset image ``idx`` (background stripped)."""
    create_single_class_acti_maps(
        predictions=preds_ensemble,
        dataset=data,
        idx=idx,
        plot_mode="heatmap",
        strip_background=True,
    )

In [None]:
# Hand-picked dataset indices, grouped into the categories their names describe.
good_idcs = [4, 16, 22, 25, 29, 30, 35, 52, 53, 54, 84, 90, 113, 114, 142]
interest_cases = [85, 79, 105]
bad_cases = [66, 83]
# Further candidates considered: 57, 108, 133, 151.
final_preds = [57, 108, 54, 4, 84, 85, 105, 66, 83]

# Segmentation/annotation overview of the final selection, exported as imgs_final_paper.png.
create_multi_seg_anno_plot(predictions=preds, dataset=data, idcs=final_preds, strip_background=True, legend=True, class_names=class_names)
plt.savefig(fig_dir / "imgs_final_paper.png", dpi=100)

In [None]:
# Dataset indices ordered from highest to lowest per-image score (IoUIs), so slider
# position 0 shows the best-scoring image despite the function's name.
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)


@wid.interact(idx=wid.IntSlider(min=0, max=len(data)-1, value=0))
def show_worst_results(idx):
    """Show the activation heatmaps of the image at position ``idx`` of the ranking."""
    ranked_idx = idcs[idx]
    create_single_class_acti_maps(predictions=preds_ensemble, dataset=data, idx=ranked_idx,
                                  plot_mode="heatmap", strip_background=True)


In [None]:
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)

import ipywidgets as wid


def show_worst_results(idx):
    """Render the activation maps of the image at position ``idx`` of the descending
    IoUI ranking and save them to ``fig_dir / "test.png"``."""
    create_single_class_acti_maps(predictions=preds_ensemble, dataset=data, idx=idcs[idx],
                                  plot_mode="heatmap", strip_background=True, plot=False)
    plt.savefig(fig_dir / "test.png", dpi=2000)


# Export the image at rank 20 of the ranking.
show_worst_results(20)

In [None]:
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)
# Hand-picked positions in the descending IoUI ranking (looked up through idcs, so
# these are not raw dataset indices).
imgs_good = np.array([3, 5, 6, 9, 10, 11, 14, 15, 16, 18, 20, 21, 23, 24, 34, 36, 45, 46, 92])
sm_plot = np.array([25, 46])

# German "sieht anders" = "looks different".
sieht_anders = np.array([49, 56, 58, 70, 63, 91, 17])

# German "ganz anders" = "completely different".
ganz_anders = np.array([62, 68, 73, 26])

create_multi_seg_anno_plot(predictions=preds_ensemble, dataset=data, idcs=[idcs[i] for i in sm_plot], strip_background=True)
# Save through fig_dir like the other figure cells (same ./figures directory).
plt.savefig(fig_dir / "imgs_good.png", dpi=100)

In [None]:
# Combine a subset of each hand-picked group (ranking positions) into one figure.
to_vis = np.concatenate([imgs_good[[2, 5, 7, 8, 16]], sieht_anders[[2, 6]], ganz_anders[[0, 3]]], axis=0)
create_multi_seg_anno_plot(predictions=preds_ensemble, dataset=data, idcs=[idcs[i] for i in to_vis], strip_background=True)
# Save through fig_dir like the other figure cells (same ./figures directory).
plt.savefig(fig_dir / "imgs_good_selected.png", dpi=100)

In [None]:
# Display which ranking positions were taken from imgs_good for the selected figure.
imgs_good[[2,5,7,8,16]]

In [None]:
from itertools import combinations

# Confusion matrices for (i) every pair of annotators and (ii) the model versus each
# annotator. Every pair is counted in both directions, so both matrices are symmetric.
cm_pathos = ConfusionMatrix(task="multiclass", num_classes=10)
cm_net = ConfusionMatrix(task="multiclass", num_classes=10)

for i in range(len(data)):
    (_, patho_preds, bg) = data.__getitem__(i, False)
    fg = ~bg  # computed once per image instead of once per pair

    pred = preds[i].squeeze(0).argmax(dim=0)[fg]

    # Dedicated loop names, so the class-frequency vectors `a` / `b` of the
    # proportions cell are not overwritten.
    for mask_x, mask_y in combinations(patho_preds, r=2):
        mask_x = torch.as_tensor(mask_x[fg])
        mask_y = torch.as_tensor(mask_y[fg])
        cm_pathos.update(mask_x, mask_y)
        cm_pathos.update(mask_y, mask_x)

    for patho_pred in patho_preds:
        patho_pred = torch.as_tensor(patho_pred[fg])
        cm_net.update(pred, patho_pred)
        cm_net.update(patho_pred, pred)




In [None]:
# Row-normalised annotator-vs-annotator confusion matrix, drawn in per mille.
cm_results1 = cm_pathos.compute().numpy()
pathos_row_totals = cm_results1.sum(axis=1, keepdims=True)
cm_results1 = cm_results1 / pathos_row_totals

sns.heatmap(cm_results1 * 1000, xticklabels=class_names, yticklabels=class_names,
            annot=True, fmt=".0f", cmap="Blues")

In [None]:
# Row-normalised model-vs-annotator confusion matrix, drawn in per mille.
cm_results2 = cm_net.compute().numpy()
net_row_totals = cm_results2.sum(axis=1, keepdims=True)
cm_results2 = cm_results2 / net_row_totals

sns.heatmap(cm_results2 * 1000, xticklabels=class_names, yticklabels=class_names,
            annot=True, fmt=".0f", cmap="Blues")

In [None]:
# Difference (per mille) between the model-vs-annotator (cm_results2) and the
# annotator-vs-annotator (cm_results1) row-normalised matrices; positive cells mean the
# model pairing has the higher rate. The colour range is fixed to [-200, 200].
sns.heatmap((cm_results2-cm_results1)*1000, xticklabels=class_names, yticklabels=class_names, fmt=".0f", cmap="bwr", annot=True, vmin=-200, vmax=200)

In [None]:
# Raw (unnormalised) annotator-vs-annotator confusion counts.
cm_pathos.compute().numpy()

In [None]:

# Row-normalised model-vs-annotator confusion matrix. The source is cm_net (as the name
# says), and the normalised matrix is what gets plotted rather than the raw counts.
cm_net_results = cm_net.compute().numpy()

cm_net_results = cm_net_results / cm_net_results.sum(axis=1).reshape(-1, 1)
plt.matshow(cm_net_results)
n_classes_plot = len(class_names)
plt.xticks(range(n_classes_plot), class_names, rotation=45, ha="left")
plt.yticks(range(n_classes_plot), class_names, rotation=45)