In [None]:
%load_ext autoreload
%autoreload 2


import torch
from src.jdt_losses import JDTLoss
from torchmetrics import Dice

from pathlib import Path
from src.gleason_data import GleasonX
from src.augmentations import basic_transforms_val_test_scaling512, normalize_only_transform
import torchmetrics
from src.model_utils import SoftDICEMetric
from src.jdt_losses import SoftCorrectDICEMetric
from src.jdt_losses import SoftDICECorrectAccuSemiMetric


In [None]:
import os

# Root directories of experiment runs and of the GleasonXAI dataset. Machine-specific, so they can be
# overridden through environment variables; the defaults are the original hard-coded locations.
base_path = Path(os.environ.get("GLEASON_EXPERIMENT_DIR", "/home/experiments/Gleason"))
data_path = Path(os.environ.get("GLEASON_DATA_DIR", "/home/datasets/GleasonXAI/"))

In [None]:
def _seeded_runs(run_prefix: str, seeds) -> list[Path]:
    """Return the `version_0` run directories of one training configuration for the given seeds."""
    return [Path(f"{run_prefix}-{seed}/version_0/") for seed in seeds]


_MASKED_LL1 = "GleasonBackgroundMasking/label_level1/HoleMask"
_MASKED_LL0 = "GleasonBackgroundMasking/label_level0/HoleMask"

# Model identifier -> run directories (relative to `base_path`) whose predictions are evaluated/averaged.
MODEL_RUNS = {
    "ll1/CE": _seeded_runs(f"{_MASKED_LL1}/NoZoomFinalCE", [1, 2, 3]),
    "ll1/Tree": _seeded_runs(f"{_MASKED_LL1}/NoZoomFinalTreeLoss", [1, 2, 3]),
    "ll1/SDB": _seeded_runs(f"{_MASKED_LL1}/SoftDICEBalancedNoZoom", [1, 2, 3]),
    "ll0/CE": _seeded_runs(f"{_MASKED_LL0}/NoZoomFinalCE", [1, 2, 3]),
    "ll0/SDB": _seeded_runs(f"{_MASKED_LL0}/SoftDICEBalancedNoZoom", [1, 2, 3]),
    "ens/ll1/Tree": (_seeded_runs(f"{_MASKED_LL1}/FinalTreeLossNoZoomCont", [1, 2, 3, 4])
                     + _seeded_runs(f"{_MASKED_LL1}/NoZoomFinalTreeLoss", [1, 2, 3])),
    "ens/ll1/SDB": (_seeded_runs(f"{_MASKED_LL1}/SoftDICEBalancedNoZoomCont", [1, 2, 3, 4])
                    + _seeded_runs(f"{_MASKED_LL1}/SoftDICEBalancedNoZoom", [1, 2, 3])),
    "final/ll0/CE": _seeded_runs("GleasonFinal/label_level0/CE", [1, 2, 3]),
    "final/ll1/CE": _seeded_runs("GleasonFinal/label_level1/CE", [1, 2, 3]),
}
available_models = list(MODEL_RUNS)

selected_model = "final/ll1/CE"
remap_ll0 = False  # evaluate a label-level-1 model after collapsing its outputs to label level 0

if selected_model not in MODEL_RUNS:
    raise RuntimeError(f"Unknown model {selected_model!r}; choose one of {available_models}")

label_level = 1 if "ll1" in selected_model else 0
eval_on = "test"
save_mem = True  # ensembles: keep only a running sum of member predictions in memory

num_classes = 10 if (label_level == 1 and not remap_ll0) else 4

if remap_ll0:
    assert "ll1" in selected_model, "remap_ll0 requires a label-level-1 model"
    label_level = 0

ens_mode = "ens" in selected_model

# Which dataset preprocessing the model family was trained with (key into `data_options`).
restructured_data = "final" if "final" in selected_model else "org"

model_paths = MODEL_RUNS[selected_model]

missing_runs = [path for path in model_paths if not (base_path / path).exists()]
assert not missing_runs, f"Missing run directories under {base_path}: {missing_runs}"

In [None]:

def load_model_probabilities(run_dir: Path) -> torch.Tensor:
    """Load one run's saved predictions for split `eval_on` as per-pixel class probabilities.

    Each image in the saved object is softmaxed over its class axis (a leading singleton
    dimension is dropped first) and the images are stacked along a new first axis.
    NOTE(review): assumes the file stores logits and all images share one spatial size —
    later cells reduce over `preds_ensemble` as a single 4-D tensor, so confirm.
    """
    raw_preds = torch.load(base_path / run_dir / "preds" / f"pred_{eval_on}.pt")
    return torch.stack([torch.nn.functional.softmax(img_pred.float().detach().squeeze(0), dim=0) for img_pred in raw_preds])


if ens_mode and save_mem:
    # Running sum over ALL ensemble members so only one member is held in memory at a time.
    preds = None
    preds_ensemble = load_model_probabilities(model_paths[0])
    for m_path in model_paths[1:]:
        preds_ensemble += load_model_probabilities(m_path)
    preds_ensemble /= len(model_paths)
else:
    # Stacked per-model probabilities (first axis = model) so later cells can use
    # `preds.shape[0]` / `preds[i, idx]`; the ensemble is their mean.
    preds = torch.stack([load_model_probabilities(m_path) for m_path in model_paths])
    preds_ensemble = preds.mean(dim=0)

In [None]:
# Dataset settings matching the preprocessing each model family was trained with
# (selected through `restructured_data` from the configuration cell).
data_options = {"org":{"scaling":"1024", "transforms":basic_transforms_val_test_scaling512, "label_level":label_level, "create_seg_masks":True, "tissue_mask_kwargs":{"open": False, "close":False, "flood":False}},
                "final": {"scaling": "MicronsCalibrated", "transforms": normalize_only_transform, "label_level": label_level, "create_seg_masks": True, "tissue_mask_kwargs": {"open": False, "close": False, "flood": False}, "drawing_order":"grade_frame_order", "explanation_file":"final_filtered_explanations_df.csv"}}

# Load the same split the predictions were generated for (`pred_{eval_on}.pt`), so that
# image i of `data` corresponds to prediction i.
data = GleasonX(data_path, split=eval_on, **data_options[restructured_data])

In [None]:
# Collect per-image soft labels and background masks and stack them into tensors: later cells
# index them with index lists (`labels[idcs]`), invert them (`~bgs`) and reduce over dims.
# NOTE(review): stacking assumes every image of the split has the same spatial size.
label_list = []
bg_list = []
for i in range(len(data)):
    _, label, background = data[i]
    label_list.append(label)
    bg_list.append(background)

labels = torch.stack(label_list)
bgs = torch.stack(bg_list)

In [None]:
from src.tree_loss import generate_label_hierarchy

def remapping_function(out):
    """Map class channels through `data.exp_numbered_lvl_remapping` and return the first hierarchy level.

    Delegates to `generate_label_hierarchy(..., start_level=1)` and keeps element 0 of its result
    (presumably the label-level-0 view — confirm in src.tree_loss).
    """
    out_remappings = generate_label_hierarchy(out, data.exp_numbered_lvl_remapping, start_level=1)
    return out_remappings[0]


if remap_ll0:
    print("REMAP!")
    preds_ensemble = remapping_function(preds_ensemble)
    # Labels need no remapping: `data` is built with label_level=0 when remap_ll0 is set.

    if preds is not None:
        # Iterate over models directly; works whether `preds` is a stacked tensor or a list.
        preds = torch.stack([remapping_function(model_preds) for model_preds in preds])


## VISUALIZATIONS

In [None]:
from itertools import zip_longest
import math
from textwrap import wrap
from typing import Literal

from matplotlib import pyplot as plt
import src.augmentations as augmentations
import numpy as np
from matplotlib.colors import ListedColormap
from src.gleason_utils import create_composite_plot
from matplotlib import colormaps as cm

import pandas as pd

In [None]:
def create_ensemble_plot(dataset, idx, ensemble_predictions, individual_predictions, show_ensemble_preds = 0, label_level=1):
    """Composite plot of the ensemble segmentation of image `idx` next to all annotator masks.

    Args:
        dataset: GleasonX-like dataset providing `__getitem__(idx, False)` and `get_raw_image`.
        idx: image index.
        ensemble_predictions: indexable by image; `ensemble_predictions[idx]` is (C, H, W).
        individual_predictions: per-member predictions indexed as `[member, idx]`; may be None
            when `show_ensemble_preds` is 0.
        show_ensemble_preds: number of randomly chosen ensemble members whose own argmax
            segmentation is added in front of the ensemble segmentation.
        label_level: forwarded to `create_composite_plot`.

    Returns:
        The figure returned by `create_composite_plot`.
    """
    _, annotator_masks, background_mask = dataset.__getitem__(idx, False)
    org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]
    np_seg = np.array(ensemble_predictions[idx].argmax(dim=0)).astype(np.uint8)

    masks = {"segmentation": np_seg} | {f"Annotator {i}": mask.astype(np.uint8) for i, mask in enumerate(annotator_masks)}

    if show_ensemble_preds > 0:
        # A random subset of all members (np.random.permutation(k) alone would only shuffle the first k).
        member_ids = np.random.permutation(len(individual_predictions))[:show_ensemble_preds]
        member_masks = {f"ensemble_pred_{i}": np.array(individual_predictions[member, idx].argmax(dim=0)).astype(np.uint8)
                        for i, member in enumerate(member_ids)}
        masks = member_masks | masks

    return create_composite_plot(dataset, org_img, masks, background_mask.astype(np.uint8), label_level=label_level, only_show_existing_annotation=True)


def composite_prediction_plot(predictions, dataset, indices, mask_background=False, full_legend=True, return_all=False):
    """Draw one composite plot (argmax segmentation + annotator masks) for every image index.

    Args:
        predictions: indexable by image; `predictions[idx]` is a (C, H, W) tensor of logits or
            probabilities (only its argmax is used).
        dataset: GleasonX-like dataset providing `__getitem__(idx, False)` and `get_raw_image`.
        indices: image indices to plot.
        mask_background: pass the background mask to `create_composite_plot`.
        full_legend: show legend entries for all classes, not only those present.
        return_all: return the list of all figures instead of only the first one.

    Returns:
        The first figure (None for empty `indices`), or the list of all figures if `return_all`.
    """
    figures = []
    for idx in indices:
        _, masks, background_mask = dataset.__getitem__(idx, False)
        org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

        background = background_mask if mask_background else None

        # argmax is invariant under softmax, so no normalisation is needed.
        np_seg = np.array(predictions[idx].argmax(dim=0)).astype(np.uint8)

        all_masks = {"segmentation": np_seg} | {f"Annotator {i}": mask for i, mask in enumerate(masks)}
        figures.append(create_composite_plot(dataset, org_img, all_masks, background, only_show_existing_annotation=not full_legend))

    if return_all:
        return figures
    return figures[0] if figures else None


def create_single_class_acti_maps(predictions, dataset, idx, plot_mode: Literal["heatmap", "contourf", "contour", "thresholded"] = "contourf", thresholds: list[float] = None, strip_background=False, plot=True):

    img, masks, background_mask = dataset.__getitem__(idx, False)

    out = predictions[idx] #generate_model_output(model, img, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)

    out = torch.nn.functional.softmax(out, 0)

    np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

    org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape))

    colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes

    if strip_background:

        for mask in masks:
            mask += 1
            mask[background_mask] = 0

        np_seg += 1
        np_seg[background_mask] = 0

        out[:, torch.tensor(background_mask).bool()] = 0.0

        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1

    f, axes = plt.subplots(2, 3+math.ceil(dataset.num_classes/2), sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
    f.tight_layout()

    axes[0, 0].imshow(org_img)
    axes[0, 0].set_title("Image", size=7)
    axes[0, 0].set_axis_off()

    # axes[1, 0].imshow(img)
    axes[1, 0].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
    axes[1, 0].set_title("Segmentation", size=7)
    axes[1, 0].set_axis_off()

    for sub_ax, mask in zip_longest(list(axes[:, 1:3].flatten()), masks):
        sub_ax.set_axis_off()

        if mask is not None:
            sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1,  interpolation_stage="rgba")
            sub_ax.set_title("Annotation", size=7)

    for i in range(dataset.num_classes):
        active_axis = axes[:, 3:].flatten()[i]

        class_out = out[i, :].detach().numpy()

        if strip_background:
            class_out[background_mask] = 0

        temp_colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors[i].reshape(1, -1)]))

        match plot_mode:
            case "heatmap": active_axis.matshow(class_out, cmap=cm["Grays"].reversed(), vmin=out[out != 0.0].min(), vmax=out.max())
            case "multilabel": active_axis.matshow(class_out >= 0.32, cmap=temp_colormap)
            case "contour": active_axis.contour(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "contourf": active_axis.contourf(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "thresholded": active_axis.matshow(class_out > thresholds[i], cmap=cm["Grays"].reversed())

        if i == 0:
            title = "Benign"
        else:
            exp = dataset.explanations[i-1]
            classes_named = ["benign tissue", "individual glands", "compressed glands", "poorly formed glands",
                             "cribriform glands", "glomeruloid glands", "group of tumor cells", "single cells", "cords", "comedenocrosis"]
            exp_short = classes_named[i]
            title = "\n".join(wrap(str(dataset.exp_grade_mapping[exp]) + ": " + exp_short, width=20))

        active_axis.set_title(title, size=7)
        active_axis.set_axis_off()
    
    plt.subplots_adjust(wspace=0.1, hspace=0.3)

    if plot:
        plt.show()
    else:
        return f, axes


def create_simple_seg_anno_plot(predictions, dataset, idx, plot_mode: Literal["heatmap", "contourf", "contour", "thresholded"] = "contourf", thresholds: list[float] = None, strip_background=False, plot=True):
    """One-row figure for image `idx`: image, argmax segmentation and up to three annotator masks.

    Args:
        predictions: indexable by image; `predictions[idx]` is a (C, H, W) tensor of logits or
            probabilities (only its argmax is used).
        dataset: GleasonX-like dataset (`__getitem__(idx, False)`, `get_raw_image`, `colormap`, `num_classes`).
        idx: image index.
        plot_mode, thresholds, plot: accepted for signature compatibility with
            `create_single_class_acti_maps`; not used.
        strip_background: shift classes by one and paint background pixels as black class 0.
    """
    _, masks, background_mask = dataset.__getitem__(idx, False)

    # argmax is invariant under softmax, so no normalisation is needed.
    np_seg = np.array(predictions[idx].argmax(dim=0)).astype(np.uint8)

    # PIL's resize expects (width, height); numpy shapes are (height, width).
    org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape[::-1]))

    colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes

    if strip_background:
        bg = np.asarray(background_mask, dtype=bool)
        # New arrays instead of `mask += 1`, which would modify the dataset's masks in place.
        masks = [np.where(bg, 0, mask + 1) for mask in masks]

        np_seg += 1
        np_seg[bg] = 0

        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1

    f, axes = plt.subplots(1, 5, sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
    f.tight_layout()
    axes[0].imshow(org_img)
    axes[0].set_title("Image", size=7)
    axes[0].set_axis_off()

    axes[1].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
    axes[1].set_title("Segmentation", size=7)
    axes[1].set_axis_off()

    for sub_ax, mask in zip_longest(axes[2:], masks):
        if sub_ax is None:
            continue  # more annotators than panels
        sub_ax.set_axis_off()

        if mask is not None:
            sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1,  interpolation_stage="rgba")
            sub_ax.set_title("Annotation", size=7)
    plt.tight_layout()


def create_multi_seg_anno_plot(predictions, dataset, idcs, strip_background=False):
    """Grid figure with one row per image: image, argmax segmentation and up to three annotations.

    Args:
        predictions: indexable by image; `predictions[idx]` is a (C, H, W) tensor of logits or
            probabilities (only its argmax is used).
        dataset: GleasonX-like dataset (`__getitem__(idx, False)`, `get_raw_image`, `colormap`, `num_classes`).
        idcs: image indices, one row each.
        strip_background: shift classes by one and paint background pixels as black class 0.
    """
    num_plots = len(idcs)

    # squeeze=False keeps the axes array 2-D, so a single requested image still works.
    f, top_axes = plt.subplots(num_plots, 5, sharex=False, sharey=False, constrained_layout=False, figsize=(10, 2*num_plots), squeeze=False)
    f.tight_layout()

    if strip_background:
        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1
    else:
        colormap = ListedColormap(dataset.colormap.colors)
        num_class_to_vis = dataset.num_classes

    for row_axes, idx in zip(top_axes, idcs):
        _, masks, background_mask = dataset.__getitem__(idx, False)

        # argmax is invariant under softmax, so no normalisation is needed.
        np_seg = np.array(predictions[idx].argmax(dim=0)).astype(np.uint8)

        # PIL's resize expects (width, height); numpy shapes are (height, width).
        org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape[::-1]))

        if strip_background:
            bg = np.asarray(background_mask, dtype=bool)
            # New arrays instead of `mask += 1`, which would modify the dataset's masks in place.
            masks = [np.where(bg, 0, mask + 1) for mask in masks]
            np_seg += 1
            np_seg[bg] = 0

        row_axes[0].imshow(org_img)
        row_axes[0].set_axis_off()

        row_axes[1].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
        row_axes[1].set_axis_off()

        for sub_ax, mask in zip_longest(row_axes[2:], masks):
            if sub_ax is None:
                continue  # more annotators than panels
            sub_ax.set_axis_off()

            if mask is not None:
                sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1,  interpolation_stage="rgba")
    plt.tight_layout()
    plt.subplots_adjust(wspace=0.03, hspace=0.05)

In [None]:
def get_all_metrics(preds, labels, bgs, n_classes=None):
    """Score per-image class probabilities against soft labels on foreground pixels.

    Args:
        preds: sequence (stacked tensor or list) of per-image (C, H, W) class probabilities.
        labels: sequence of per-image (C, H, W) soft labels aligned with `preds`.
        bgs: sequence of per-image (H, W) boolean background masks aligned with `preds`.
        n_classes: class count for the hard Dice metrics; defaults to the channel count of `preds`.

    Returns:
        dict with
        - "mDICED": SoftDICECorrectAccuSemiMetric on foreground pixels,
        - "L1": half the per-pixel L1 distance between prediction and label summed over classes,
          averaged over foreground pixels and images,
        - "DICEmacro" / "DICE": hard macro / micro Dice on foreground pixels whose soft label
          has a unique maximum.

    Raises:
        ValueError: empty input or misaligned lengths.
    """
    if not len(preds) == len(labels) == len(bgs):
        raise ValueError(f"Length mismatch: {len(preds)} predictions, {len(labels)} labels, {len(bgs)} masks")
    if len(preds) == 0:
        raise ValueError("get_all_metrics needs at least one image")
    if n_classes is None:
        n_classes = preds[0].shape[0]

    d_mac = Dice(num_classes=n_classes, average="macro")
    d_mic = Dice(num_classes=n_classes, average="micro")
    mDICED = SoftDICECorrectAccuSemiMetric()
    emd = []

    for pred, label, bg in zip(preds, labels, bgs):
        fg = ~bg

        # Per-class half L1 distance, normalised by the number of foreground pixels.
        emd.append(((pred[:, fg] - label[:, fg]).abs().sum(dim=1) / 2) / fg.sum())

        # Hard Dice only where the soft label's maximum is attained by exactly one class.
        label_max = torch.max(label, dim=0)[0]
        unique_max = torch.sum(label == label_max.unsqueeze(0), dim=0) <= 1
        unique_max_fg = torch.logical_and(fg, unique_max)

        hard_pred = pred[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
        hard_label = label[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
        d_mac.update(hard_pred, hard_label)
        d_mic.update(hard_pred, hard_label)

        mDICED.update(pred[:, fg].unsqueeze(0), label[:, fg].unsqueeze(0))

    # Mean over images, then sum of the per-class terms.
    emd = torch.stack(emd).mean(dim=0).sum()

    return {"mDICED": mDICED.compute().item(), "L1": emd.item(), "DICEmacro": d_mac.compute().item(), "DICE": d_mic.compute().item()}

print("Model:", selected_model, "Ens; ", ens_mode, "Label Level: ", label_level, "Remaped:", remap_ll0, "num_classes: ", num_classes)

# True: score the ensemble prediction. False: score every model separately
# (needs per-model predictions, i.e. `preds` must not be None).
eval_ens = True
if eval_ens:
    mets = [get_all_metrics(preds_ensemble, labels, bgs)]
else:
    if preds is None:
        raise RuntimeError("Per-model evaluation needs per-model predictions; reload them with save_mem=False")
    mets = [get_all_metrics(model_preds, labels, bgs) for model_preds in preds]
print(mets)


In [None]:
# Mean and standard deviation of every metric across the evaluated runs.
metrics_table = pd.DataFrame(mets)
metrics_table.agg(["mean", "std"])

In [None]:
# Assign each annotation group to a source dataset: groups < 3 -> 0, 3..9 -> 1, >= 10 -> 2.
data.df["dataset"] = data.df["group"].apply(lambda x: 0 if x < 3 else 1 if x < 10 else 2)

# Score every model separately per source dataset; when per-model predictions were not kept
# (save_mem ensemble), score the ensemble instead.
split_sources = [preds_ensemble] if preds is None else list(preds)

for dataset_id, frame in data.df.groupby("dataset"):
    slides = set(frame["TMA_identifier"].unique())
    split_idcs = sorted(data.used_slides.index(slide_name) for slide_name in slides.intersection(data.used_slides))

    print(f"dataset {dataset_id}")
    if not split_idcs:
        print("  no evaluated images")
        continue

    # Plain Python indexing works for both list- and tensor-valued labels/masks.
    labels_split = [labels[j] for j in split_idcs]
    bgs_split = [bgs[j] for j in split_idcs]

    split_mets = [get_all_metrics(torch.stack([model_preds[j] for j in split_idcs]), labels_split, bgs_split)
                  for model_preds in split_sources]
    print(pd.DataFrame(split_mets).aggregate(["mean", "std"]))

# Debug Metrics

In [None]:
# Hard Dice accumulators (macro / micro over classes) and the default JDT loss for the debug cells.
d_mac = Dice(num_classes=num_classes, average="macro")
d_mic = Dice(num_classes=num_classes, average="micro")
l = JDTLoss(alpha=0.5, beta=0.5)

# Compute SoftDice C,D,I: each loss weights exactly one averaging scheme (mIoUC, mIoUD or mIoUI) with 1.0.
l_C = JDTLoss(mIoUC=1.0, mIoUD=0.0, mIoUI=0.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")
l_D = JDTLoss(mIoUC=0.0, mIoUD=1.0, mIoUI=0.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")
l_I = JDTLoss(mIoUC=0.0, mIoUD=0.0, mIoUI=1.0, alpha=0.5, beta=0.5, active_classes_mode_soft="ALL")

In [None]:
# Per-class half L1 distance on foreground pixels, per image, then averaged over images.
per_image_l1 = torch.stack([
    ((preds_ensemble[k][:, ~bgs[k]] - labels[k][:, ~bgs[k]]).abs().sum(dim=1) / 2) / (~bgs[k]).sum()
    for k in range(preds_ensemble.shape[0])
])
emd = per_image_l1.mean(dim=0)
print(emd)
emd.sum()

In [None]:
# Mean per-pixel KL(label || prediction) on foreground pixels, per image, then averaged over images.
kls = []

for i in range(preds_ensemble.shape[0]):
    fg = ~bgs[i]
    fg_pred = preds_ensemble[i][:, fg]
    fg_label = labels[i][:, fg]

    n_fg = fg_pred.shape[1]
    if n_fg == 0:
        continue  # nothing to compare on images without foreground

    # Clamp avoids log(0) = -inf (and NaN in kl_div) where the softmax underflowed to 0.
    log_pred = fg_pred.clamp_min(1e-12).log()
    # reduction="sum" / n_fg gives the per-pixel KL; reduction="mean" would also divide by the class count.
    kls.append(torch.nn.functional.kl_div(log_pred, fg_label, reduction="sum") / n_fg)

torch.stack(kls).mean()

In [None]:
# Per-image Tversky matrices from the JDT loss, plus hard Dice on foreground pixels whose
# soft label has a unique maximum. The Dice accumulators are (re)created here so that
# re-running this cell does not double count.
d_mac = Dice(num_classes=num_classes, average="macro")
d_mic = Dice(num_classes=num_classes, average="micro")

tverskys = []
active_classes = []
per_image_dice_scores = []

for i in range(len(labels)):
    bg = bgs[i]
    label = labels[i]
    pred = preds_ensemble[i]
    fg = ~bg

    tversky, ac = l.get_image_class_matrix(pred[:, fg].unsqueeze(0), label[:, fg].unsqueeze(0), prob_predictions=True)
    tverskys.append(tversky)
    active_classes.append(ac)

    # Keep foreground pixels whose soft-label maximum is attained by exactly one class.
    label_max = torch.max(label, dim=0)[0]
    unique_max = torch.sum(label == label_max.unsqueeze(0), dim=0) <= 1
    unique_max_fg = torch.logical_and(fg, unique_max)

    hard_pred = pred[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
    hard_label = label[:, unique_max_fg].unsqueeze(0).argmax(dim=1)
    per_image_dice_scores.append(torchmetrics.functional.dice(hard_pred, hard_label, average="none", num_classes=num_classes))
    d_mac.update(hard_pred, hard_label)
    d_mic.update(hard_pred, hard_label)

tverskys = torch.stack(tverskys).squeeze()
active_classes = torch.stack(active_classes).squeeze()
per_image_class_dice_scores = torch.stack(per_image_dice_scores).squeeze()

In [None]:
# Mean of per-class Dice (absent classes ignored via nanmean) vs. accumulated macro / micro Dice.
mean_class_dice = per_image_class_dice_scores.nanmean(dim=0).mean()
(mean_class_dice, d_mac.compute(), d_mic.compute())

In [None]:
# mIoU-C: average each class over the images where it is active, then over classes active anywhere.
IoUCs = [
    tverskys[:, cls][active_classes[:, cls]].mean()
    for cls in range(num_classes)
    if active_classes[:, cls].sum() > 0
]

classes_ever_active = (active_classes.sum(dim=0) > 0).sum()
mIoUC = torch.stack(IoUCs).sum() / classes_ever_active
mIoUC

In [None]:
# mIoU-I: average each image over its active classes, then over all images.
IoUIs = [tverskys[img, :][active_classes[img, :]].mean() for img in range(tverskys.shape[0])]

mIoUI = torch.stack(IoUIs).sum() / tverskys.shape[0]
mIoUI

In [None]:
# Compare two soft-Dice implementations on foreground pixels only.
my_mac_dice = SoftDICEMetric(average="macro")
other_soft_dice = SoftCorrectDICEMetric(average=None)

for img in range(preds_ensemble.shape[0]):
    keep = ~bgs[img]
    fg_pred = preds_ensemble[img][:, keep]
    fg_label = labels[img][:, keep]

    my_mac_dice.update(fg_pred, fg_label)
    other_soft_dice.update(fg_pred.unsqueeze(0), fg_label.unsqueeze(0))

print(my_mac_dice.compute())
print(other_soft_dice.compute())

In [None]:
from src.jdt_losses import SoftDICECorrectAccuSemiMetric
# Compare soft-Dice mIoU-D when background pixels are removed beforehand ("striped")
# vs. passed in with a keep_mask ("masked"), for the original and the corrected metric.
striped_mIoUD = SoftCorrectDICEMetric(average="mIoUD")
corrected_striped_mIoUD = SoftDICECorrectAccuSemiMetric()
corrected_masked_mIoUD = SoftDICECorrectAccuSemiMetric()
masked_scores = []


def _new_mIoUD_loss():
    """Return a fresh JDTLoss that scores only the soft-Dice mIoU-D term (alpha = beta = 0.5)."""
    return JDTLoss(mIoUD=1.0, mIoUC=0.0, mIoUI=0.0, active_classes_mode_soft="ALL", alpha=0.5, beta=0.5)


for img in range(preds_ensemble.shape[0]):
    keep = ~bgs[img]
    full_pred = preds_ensemble[img].unsqueeze(0)
    full_label = labels[img].unsqueeze(0)
    fg_pred = preds_ensemble[img][:, keep].unsqueeze(0)
    fg_label = labels[img][:, keep].unsqueeze(0)

    striped_mIoUD.update(fg_pred, fg_label)
    masked_scores.append(1 - _new_mIoUD_loss()(full_pred, full_label, keep_mask=keep, prob_predictions=True))

    corrected_striped_mIoUD.update(fg_pred, fg_label)
    corrected_masked_mIoUD.update(full_pred, full_label, keep_mask=keep)

masked_mIoUD = torch.tensor(masked_scores).mean()


print(striped_mIoUD.compute())
print(masked_mIoUD)
print(corrected_striped_mIoUD.compute())
print(corrected_masked_mIoUD.compute())
print(1 - _new_mIoUD_loss()(preds_ensemble, labels, keep_mask=~bgs, prob_predictions=True))

In [None]:
# Pool the foreground pixels of all images into one wide "super image" (pixels concatenated along dim 1).
fg_label_chunks = []
fg_pred_chunks = []
for img in range(preds_ensemble.shape[0]):
    keep = ~bgs[img]
    fg_label_chunks.append(labels[img][:, keep])
    fg_pred_chunks.append(preds_ensemble[img][:, keep])

super_l_pred = torch.cat(fg_pred_chunks, dim=1)
super_l_label = torch.cat(fg_label_chunks, dim=1)

In [None]:
# Soft Dice of the pooled foreground "super image" (single-item batch).
other_soft_dice = SoftCorrectDICEMetric(average=None)
pooled_pred, pooled_label = super_l_pred.unsqueeze(0), super_l_label.unsqueeze(0)
other_soft_dice(pooled_pred, pooled_label)

In [None]:
# Soft-Dice mIoU-D (1 - loss) of the pooled foreground "super image".
pooled_mIoUD_loss = JDTLoss(mIoUD=1.0, mIoUC=0.0, mIoUI=0.0, active_classes_mode_soft="ALL", alpha=0.5, beta=0.5)
1 - pooled_mIoUD_loss(super_l_pred.unsqueeze(0), super_l_label.unsqueeze(0), prob_predictions=True)

In [None]:
from src.jdt_losses import JDTLoss

# Soft-Dice mIoU-D (1 - loss) on the whole batch, with background excluded through keep_mask.
batch_mIoUD_loss = JDTLoss(mIoUD=1.0, mIoUC=0.0, mIoUI=0.0, active_classes_mode_soft="ALL", alpha=0.5, beta=0.5)
1 - batch_mIoUD_loss(preds_ensemble, labels, keep_mask=~bgs, prob_predictions=True)

# Visualizations

In [None]:
from torchmetrics import ConfusionMatrix
from tqdm import tqdm
from matplotlib.colors import LinearSegmentedColormap

import ipywidgets as wid

from src.jdt_losses import SoftCorrectDICEMetric

In [None]:
# Output directory for all saved figures (relative to the notebook's working directory).
img_save_path = Path("./figures")
img_save_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Compute outputs: class frequencies and confusion matrices accumulated over all evaluated images.

STRIP_BACKGROUND = True  # count only foreground pixels
ML_THRESHOLD = 0.33  # a class counts as present when its probability / label mass reaches this value

# `num_classes` comes from the configuration cell (not overridden here, so label level 0 and
# remapped models work); it has to match the prediction channels.
assert preds_ensemble.shape[1] == num_classes, (tuple(preds_ensemble.shape), num_classes)
assert len(preds_ensemble) == len(labels) == len(bgs), "predictions, labels and masks are misaligned"

# int64 counters: pixel counts summed over a whole split can exceed the int32 range.
pix_freq = torch.zeros(num_classes, dtype=torch.long)
max_freq = torch.zeros(num_classes, dtype=torch.long)
pred_freq = torch.zeros(num_classes, dtype=torch.long)
one_annotator_pred_freq = torch.zeros(num_classes, dtype=torch.long)

conf_matrix = ConfusionMatrix(task="multiclass", num_classes=num_classes)
conf_matrix_ml = ConfusionMatrix(task="multilabel", num_labels=num_classes)

for i in tqdm(range(len(preds_ensemble))):

    out = preds_ensemble[i]

    background_mask = bgs[i]
    mask = labels[i]
    if STRIP_BACKGROUND:
        foreground_mask = ~background_mask.bool()
    else:
        foreground_mask = torch.ones_like(background_mask).bool()

    # Pixels whose soft-label maximum is attained by exactly one class.
    label_max = torch.max(mask, dim=0)[0]
    unique_max = torch.sum(mask == label_max.unsqueeze(0), dim=0) <= 1
    unique_max_fg = torch.logical_and(foreground_mask, unique_max)

    mask_unique = mask[:, unique_max_fg]
    out_unique = out[:, unique_max_fg]

    mask = mask[:, foreground_mask].flatten(start_dim=1)
    out = out[:, foreground_mask].flatten(start_dim=1)

    pix_freq += torch.sum(mask > 0, dim=1)
    max_freq += torch.bincount(torch.argmax(mask_unique, dim=0).reshape(-1), minlength=num_classes)
    pred_freq += torch.bincount(torch.argmax(out, dim=0).reshape(-1), minlength=num_classes)
    one_annotator_pred_freq += torch.sum(out >= ML_THRESHOLD, dim=1)

    # update() only accumulates; calling the metric object would also compute a discarded per-image value.
    conf_matrix.update(out_unique.argmax(dim=0), mask_unique.argmax(dim=0))
    conf_matrix_ml.update(out.T >= ML_THRESHOLD, mask.T >= ML_THRESHOLD)

In [None]:
torch.set_printoptions(sci_mode=False)

# Display names for the ten label-level-1 classes; other class counts fall back to the dataset's names.
LL1_CLASS_NAMES = ["benign tissue", "individual glands", "compressed glands", "poorly formed glands", "cribriform glands", "glomeruloid glands", "group of tumor cells", "single cells", "cords", "comedenocrosis"]
classes_named = LL1_CLASS_NAMES if num_classes == len(LL1_CLASS_NAMES) else data.classes_named

confm = conf_matrix.compute()

# Row-normalise (per annotated class) and show per mille. Classes without annotated pixels get a
# zero row instead of NaN, which `.long()` would turn into a garbage integer.
row_totals = confm.sum(dim=1).reshape(-1, 1).clamp_min(1)
confm_normed = ((confm / row_totals) * 1000).round().long()

# Light part (0-70%) of "Blues" so the black cell annotations stay readable.
blues_70 = LinearSegmentedColormap.from_list("Blues_80", cm["Blues"](np.linspace(0, 0.7, 256)))
plt.matshow(confm_normed, cmap=blues_70)

for (i, j), val in np.ndenumerate(confm_normed):
    plt.text(j, i, f'{val}', ha='center', va='center', color='black')

plt.xticks(range(num_classes), list(classes_named), rotation=45, rotation_mode="anchor", ha="right", va="center_baseline")
plt.yticks(range(num_classes), list(classes_named), rotation=45)
plt.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False)

plt.ylabel("Annotations")
plt.xlabel("Predictions")

plt.savefig(img_save_path / "conf_mat.png", dpi=1000)

In [None]:
# Class annotation and prediction frequency: soft probability mass (labels vs. ensemble) next to
# argmax pixel shares (annotation majority vote vs. prediction).

pred_mass = preds_ensemble.sum(dim=(0, 2, 3))
pred_mass /= pred_mass.sum()

label_mass = labels.sum(dim=(0, 2, 3))
label_mass /= label_mass.sum()

bar_width = 0.15
positions = np.arange(num_classes)
bar_series = [
    (-1.7, label_mass, "Soft-Label Probability Mass"),
    (-0.7, pred_mass, "Predictive Probability Mass"),
    (0.7, max_freq / torch.sum(max_freq), "Anno.: Majority Vote"),
    (1.7, pred_freq / torch.sum(pred_freq), "Pred.: Argmax"),
]
for offset, heights, series_label in bar_series:
    plt.bar(positions + offset * bar_width, heights, label=series_label, width=bar_width)

_ = plt.legend()
_ = plt.xticks(positions, [name[:20] for name in classes_named], rotation=45, ha="right", rotation_mode="anchor")
_ = plt.xlabel("Grouped Explanation")
_ = plt.ylabel("Proportion")

plt.savefig(img_save_path / "prob_mass.png", dpi=1000)


In [None]:
# Per-class one-vs-rest statistics from the thresholded multilabel confusion matrix.
confm_ml = conf_matrix_ml.compute()
# One 2x2 matrix per class, flattened to columns [TN, FP, FN, TP].
confm_ml_reshape = confm_ml.reshape(confm_ml.shape[0], -1)
num_pixels = confm_ml_reshape.sum(dim=1)
confm_ml_reshape_normed = confm_ml_reshape / num_pixels.reshape(-1, 1)

TN, FP, FN, TP = confm_ml_reshape[:, 0], confm_ml_reshape[:, 1], confm_ml_reshape[:, 2], confm_ml_reshape[:, 3]

POPU = TN + FP + FN + TP
ACC = (TN + TP) / POPU
TPR = TP / (TP + FN)
TNR = TN / (TN + FP)
PREC = TP / (FP + TP)
BACC = (TPR + TNR) / 2
DICE = (2 * TP) / (2 * TP + FP + FN)
PREL = (TP + FN) / POPU  # prevalence of the class

# Columns: PREL, TP, TN, FP, FN (as fractions of all pixels), ACC, PREC, BA, DICE.
confm_ml_reshape_normed = torch.cat([PREL.reshape(num_classes, 1), confm_ml_reshape_normed[:, [3, 0, 1, 2]], ACC.reshape(
    num_classes, 1), PREC.reshape(num_classes, 1), BACC.reshape(num_classes, 1), DICE.reshape(num_classes, 1)], dim=1)

im3 = plt.matshow(confm_ml_reshape_normed)
plt.colorbar(im3)
plt.xticks([0, 1, 2, 3, 4, 5, 6, 7, 8], ["PREL", "TP", "TN", "FP", "FN", "ACC", "PREC", "BA", "DICE"], rotation=45)
_ = plt.yticks(range(len(classes_named)), classes_named)
# Cell annotations are value * 1000, i.e. per mille (not percent).
_ = plt.xlabel("Metrics in ‰")
for (i, j), val in np.ndenumerate(confm_ml_reshape_normed):
    plt.text(j, i, f"{val*1000:0.0f}", c="red", ha="center", va="center")

In [None]:
# Image indices ordered by per-image score IoUI, highest first.
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)

# Slider position 0 shows the best-scoring image; the range follows the ranking, not the dataset.
@wid.interact(idx=wid.IntSlider(min=0, max=len(idcs)-1, value=0))
def browse_results_by_score(idx):
    create_single_class_acti_maps(predictions=preds_ensemble, dataset=data, idx=int(idcs[idx]), plot_mode="heatmap", strip_background=True)


In [None]:
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)


def save_acti_map(rank, filename="test.png", dpi=2000):
    """Save the class-activation figure of the image at position `rank` of the IoUI ranking (0 = best)."""
    f, _ = create_single_class_acti_maps(predictions=preds_ensemble, dataset=data, idx=int(idcs[rank]), plot_mode="heatmap", strip_background=True, plot=False)
    f.savefig(img_save_path / filename, dpi=dpi)


save_acti_map(20)

In [None]:
vals, idcs = torch.sort(torch.stack(IoUIs), descending=True)

# Hand-picked positions in the IoUI ranking `idcs`, chosen by visual inspection.
imgs_good = np.array([3, 5, 6, 9, 10, 11, 14, 15, 16, 18, 20, 21, 23, 24, 34, 36, 45, 46, 92])
sm_plot = np.array([25, 46])

sieht_anders = np.array([49, 56, 58, 70, 63, 91, 17])  # German: "looks different"

ganz_anders = np.array([62, 68, 73, 26])  # German: "completely different"

picked = [idcs[rank] for rank in sm_plot]
create_multi_seg_anno_plot(predictions=preds_ensemble, dataset=data, idcs=picked, strip_background=True)
plt.savefig(img_save_path / "imgs_good.png", dpi=500)

In [None]:
# Publication figure: five good cases, two "sieht anders" and two "ganz anders" cases
# (all given as positions in the IoUI ranking `idcs`).
good_part = imgs_good[[2, 5, 7, 8, 16]]
different_part = sieht_anders[[2, 6]]
very_different_part = ganz_anders[[0, 3]]
to_vis = np.concatenate([good_part, different_part, very_different_part], axis=0)

selected_images = [idcs[rank] for rank in to_vis]
create_multi_seg_anno_plot(predictions=preds_ensemble, dataset=data, idcs=selected_images, strip_background=True)
plt.savefig(img_save_path / "imgs_good_selected.png", dpi=500)