In [None]:
%load_ext autoreload
%autoreload 2

import sys
import random
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import torch
from torch.utils.data import DataLoader

import src.augmentations as augmentations
import src.model_utils
from src.gleason_data import GleasonX, prepare_torch_inputs
from src.gleason_utils import create_composite_plot
from src.lightning_modul import LitSegmenter
import src.tree_loss as tree_loss
from matplotlib import colormaps as cm
from matplotlib.colors import ListedColormap
from src.tree_loss import generate_label_hierarchy
from torchmetrics import Accuracy, ConfusionMatrix
from src.lightning_modul import LitClassifier
from src.gleason_data import GleasonXClassification
from itertools import zip_longest
import math
from textwrap import wrap
import torchvision.transforms as tt
from ipywidgets import widgets
from monai.inferers import *
from tqdm import tqdm
from itertools import chain
from torch.utils.data import Subset
import plotly.graph_objects as go
import ipywidgets.widgets as wid
from PIL import Image
import albumentations as alb
import ipywidgets

from src.gleason_data import get_class_colormaps

# Register every already-imported `src.<name>` module under its bare top-level
# name as well. Modules that are not loaded are skipped silently.
# NOTE(review): presumably required so that checkpoints pickled against the
# old top-level module names can still be loaded — confirm.
for _bare_name in ("tree_loss", "robust_loss_functions", "loss_functions", "augmentations", "model_utils"):
    _qualified_name = f"src.{_bare_name}"
    if _qualified_name in sys.modules:
        sys.modules[_bare_name] = sys.modules[_qualified_name]

In [None]:
# Root directory of the dataset.
DATA_PATH = Path("/home/datasets/GleasonXAI")

# Directory that is searched for model checkpoints.
# NOTE(review): this is a relative path (no leading "/"), unlike DATA_PATH — confirm that is intended.
EXPERIMENT_PATH = "home/experiments/Gleason"

 ## Utility functions

In [None]:
# OUTPUT function

# Module-level defaults; other cells overwrite these globals later on.
LABEL_REMAPPING = None
SLIDING_WINDOW_INFERER = None


def generate_model_output(model, img, device="cpu", label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER):
    """Run ``model`` on ``img`` without gradient tracking and return the result on the CPU.

    Args:
        model: Callable torch module; it is switched to eval mode.
        img: A ``torch.Tensor`` or anything ``torchvision``'s ``to_tensor`` accepts.
            A 3-dimensional tensor is treated as a single un-batched sample.
        device: Device the input is moved to before the forward pass.
        label_remapping: Optional callable applied to the CPU output.
        inferer: Optional callable invoked as ``inferer(img, model)`` instead of
            calling the model directly.

    Returns:
        The (optionally remapped) output. The batch dimension is stripped again
        when the input had none.

    Note:
        The defaults of ``label_remapping`` and ``inferer`` are evaluated once, when
        this function is defined (both are ``None`` at that point); later changes to
        the globals are only seen if the caller passes them explicitly.
    """
    model.eval()

    tensor = img if isinstance(img, torch.Tensor) else tt.functional.to_tensor(img)

    # Remember whether we have to add (and later remove) the batch dimension.
    no_batch_input = tensor.dim() == 3
    if no_batch_input:
        tensor = tensor.unsqueeze(0)

    tensor = tensor.to(device)

    with torch.no_grad():
        prediction = model(tensor) if inferer is None else inferer(tensor, model)

    # Move back to the CPU before any post-processing.
    prediction = prediction.cpu()

    if label_remapping is not None:
        prediction = label_remapping(prediction)

    return prediction[0, ...] if no_batch_input else prediction


In [None]:
def composite_prediction_plot(model, dataset, indices, mask_background=False, full_legend=True):
    """Create the composite annotation/prediction figure for the first entry of ``indices``.

    Args:
        model: Segmentation model, or ``None`` to show the annotations only.
        dataset: Dataset providing ``__getitem__(idx, False)`` (image, masks, background mask),
            ``get_raw_image(idx)`` and whatever ``create_composite_plot`` needs.
        indices: Iterable of dataset indices.
        mask_background: If true, the background mask is handed to ``create_composite_plot``.
        full_legend: If false, ``only_show_existing_annotation=True`` is passed on.

    Returns:
        Whatever ``create_composite_plot`` returns for the first index, or ``None`` if
        ``indices`` is empty.
    """
    for idx in indices:
        img, masks, background_mask = dataset.__getitem__(idx, False)
        org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

        background = background_mask if mask_background else None

        panels = {f"Annotator {i}": mask for i, mask in enumerate(masks)}

        if model is not None:
            out = generate_model_output(model, img, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)
            out = torch.nn.functional.softmax(out, 0)
            np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)
            # The prediction is shown first, followed by the annotators.
            panels = {"segmentation": np_seg} | panels

        # NOTE(review): returning inside the loop means only the first index is ever
        # plotted. This is kept because callers receive a single figure; confirm whether
        # one figure per index was intended.
        return create_composite_plot(dataset, org_img, panels, background, only_show_existing_annotation=not full_legend)

    return None


def create_single_class_acti_maps(model, dataset, idx, plot_mode: Literal["heatmap", "contourf", "contour", "thresholded"] = "contourf", thresholds: list[float] = None, strip_background=False):

    img, masks, background_mask = dataset.__getitem__(idx, False)
    # org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

    out = generate_model_output(model, img, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)

    out = torch.nn.functional.softmax(out, 0)

    np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

    org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape))

    colormap = ListedColormap(dataset.colormap.colors)
    num_class_to_vis = dataset.num_classes

    if strip_background:

        for mask in masks:
            mask += 1
            mask[background_mask] = 0

        np_seg += 1
        np_seg[background_mask] = 0

        out[:, torch.tensor(background_mask).bool()] = 0.0
        
        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
        num_class_to_vis = dataset.num_classes + 1

    f, axes = plt.subplots(2, 3+math.ceil(dataset.num_classes/2), sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
    f.tight_layout()

    axes[0, 0].imshow(org_img)
    axes[0, 0].set_title("Image", size=7)
    axes[0, 0].set_axis_off()

    # axes[1, 0].imshow(img)
    axes[1, 0].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation_stage="rgba")
    axes[1, 0].set_title("Segmentation", size=7)
    axes[1, 0].set_axis_off()

    for sub_ax, mask in zip_longest(list(axes[:, 1:3].flatten()), masks):
        sub_ax.set_axis_off()

        if mask is not None:
            sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1,  interpolation_stage="rgba")
            sub_ax.set_title("Annotation", size=7)

    for i in range(dataset.num_classes):
        active_axis = axes[:, 3:].flatten()[i]

        class_out = out[i, :].detach().numpy()

        if strip_background:
            class_out[background_mask] = 0

        temp_colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors[i].reshape(1,-1)]))

        match plot_mode:
            case "heatmap": active_axis.matshow(class_out, cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "multilabel": active_axis.matshow(class_out >= 0.32 , cmap=temp_colormap)
            case "contour": active_axis.contour(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "contourf": active_axis.contourf(np.flipud(class_out), cmap=cm["Grays"].reversed(), vmin=0.0, vmax=out.max())
            case "thresholded": active_axis.matshow(class_out > thresholds[i], cmap=cm["Grays"].reversed())

        if i == 0:
            title = "Benign"
        else:
            exp = dataset.explanations[i-1]
            title = "\n".join(wrap(str(dataset.exp_grade_mapping[exp]) + ": " + exp, width=20))

        active_axis.set_title(title, size=7)
        active_axis.set_axis_off()

    #from matplotlib.colors import Normalize
    #from matplotlib.cm import ScalarMappable
    #cbar_ax = f.add_axes([0.93, 0.15, 0.02, 0.7])  #
    #norm = Normalize(vmin=0, vmax=out.max())  # Assuming out.max() represents the maximum value in your data
    #sm = ScalarMappable(norm=norm, cmap=cm["Grays"].reversed())
    #sm.set_array([])  # Required for the colorbar to work properly

    ## Add the colorbar to the axis
    #cbar = plt.colorbar(sm, ax=active_axis)


    plt.show()

 ## Model, Dataset and Option Selection

In [None]:
# DATASET CONSTANTS

# What changed: SegmentationModelsTest V0: Normalization, RandomResizedCrops, CenterCrop instead of RandomCrop in val (marginal difference), smpUNet instead of unet.py (pretrained), AdamW instead of Adam, larger LR, larger WeightDecay, precision=16
#               SegmentationModelsTest V3: Batchsize to 32 from 8.

# Use 16-mixed instead. Testing different loss functions. Testing color and distortion transforms.


# Batch size and worker count passed to the DataLoaders created in this notebook.
BATCH_SIZE = 4
NUM_WORKERS = 8
# SCALING = "1024"
# LABEL_LEVEL = 1

# dataset = GleasonX(DATA_PATH, "train", scaling=SCALING, transforms=transforms_val_test, label_level=LABEL_LEVEL)
# dataset_val = GleasonX(DATA_PATH, "val", scaling=SCALING, transforms=transforms_val_test, label_level=LABEL_LEVEL)
# dataset_test = GleasonX(DATA_PATH, "test", scaling=SCALING, transforms=transforms_val_test, label_level=LABEL_LEVEL)
# dataset_all = GleasonX(DATA_PATH, "all", scaling=SCALING, transforms=transforms_val_test, label_level=LABEL_LEVEL)

# dataset_Gleason = GleasonX(DATA_PATH, "test", scaling=SCALING, transforms=transforms_val_test, label_level=0)


# NUM_CLASSES = dataset.num_classes
# CLASSES_NAMED = ["Background"] + dataset_val.explanations

# dataloader_train = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
# dataloader_val = DataLoader(dataset=dataset_val, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=True)
# dataloader_test = DataLoader(dataset=dataset_test, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)#

# dataloader_Gleason = DataLoader(dataset=dataset_Gleason, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)


In [None]:
# MODEL SELECTION WIDGET

# MODEL_BASE_PATHS = [Path(p) for p in  ["/home/experiments/GleasonImprovements",
#                    "/home/experiments/GleasonBackgroundMasking", "/home/experiments/GleasonClassification"]]
# Directories that are searched recursively for "*.ckpt" files.
MODEL_BASE_PATHS = [Path(p) for p in [EXPERIMENT_PATH]]


# Currently loaded model and its device; both are set by the selection widget below.
model = None
device = None
# NOTE(review): mid-notebook import; it would be easier to find in the first cell with the other imports.
import os

def create_model():
    """Build the model/device selection menu and load the initially selected checkpoint.

    Side effects:
        Sets the module globals ``model`` and ``device`` right away and again every time
        the value of one of the two dropdowns changes.

    Returns:
        ``widgets.VBox`` holding the model dropdown (child 0) and the device dropdown (child 1).
    """

    def find_best_models(base_paths):
        """Map "<super folders>/<experiment>/<version>" to one checkpoint path each.

        Per experiment name the first checkpoint found is kept unless a later file has
        "best_model" in its name, in which case that file replaces it.
        """
        exp_dict = {}
        for base_path in base_paths:
            base_path = Path(base_path)
            for model_path in base_path.glob("**/*.ckpt"):
                experiment_path = model_path.relative_to(base_path.parent)

                # Expected layout: .../<experiment>/<version>/<checkpoint dir>/<file>.ckpt
                if len(experiment_path.parts) < 4:
                    continue
                *experiment_super_folders, experiment_name, version_number, _checkpoint_dir, model_name = experiment_path.parts

                # The first super folder is the base directory itself and is dropped. The
                # remaining list may be empty, in which case the name has no prefix.
                exp_name = "/".join([*experiment_super_folders[1:], experiment_name, version_number])

                if exp_name in exp_dict and "best_model" not in model_name:
                    continue
                exp_dict[exp_name] = model_path
        return exp_dict

    # Find best model checkpoints and extract experiment names
    best_models = find_best_models(MODEL_BASE_PATHS)
    device_opts = {"cpu": torch.device("cpu"), "gpu": torch.device("cuda:0")}
    # Only preselect the GPU when one is actually usable.
    default_device = "gpu" if torch.cuda.is_available() else "cpu"

    def handle_dropdown_change(*args):
        """(Re)load the selected checkpoint onto the selected device."""
        global device, model

        device = device_opts[device_dropdown.value]

        selected_exp_name = model_dropdown.value
        if selected_exp_name is None:
            # Empty dropdown: nothing was found below MODEL_BASE_PATHS.
            print(f"No checkpoints found in {[str(p) for p in MODEL_BASE_PATHS]}; no model loaded.")
            return

        checkpoint_path = best_models[selected_exp_name]
        model = LitSegmenter.load_from_checkpoint(str(checkpoint_path), map_location=device)
        print("Model Loaded!")

    print(best_models)
    # Create dropdown widget
    model_dropdown = widgets.Dropdown(
        options=sorted(best_models.keys()),
        description='Select Model:',
        disabled=False,
    )

    device_dropdown = widgets.Dropdown(
        options=list(device_opts),
        description='Select Device:',
        disabled=False,
        value=default_device,
    )

    model_menu = widgets.VBox([model_dropdown, device_dropdown])

    handle_dropdown_change()

    # Observe only "value": without ``names`` the callback fires for every trait that
    # changes with a selection, reloading the checkpoint several times.
    model_dropdown.observe(handle_dropdown_change, names="value")
    device_dropdown.observe(handle_dropdown_change, names="value")

    return model_menu


# Build the model selection menu; create_model() also attempts an initial checkpoint load.
# The menu itself is displayed in the "Menu" section.
_model_menu = create_model()

In [None]:
# DATASET SELECTION WIDGET
# The following globals are filled in by create_statistics_datset() and its widget callbacks.
STATISTICS_DATASET = None  # dataset_Gleason
DATALOADER = None
num_classes = None
classes_named = None
LABEL_LEVEL = None
# transforms_train = augmentations.tellez_transforms_train
data_transform = augmentations.basic_transforms_val_test_scaling2048

# Whether background pixels are excluded in the forward-pass statistics; overwritten by the
# "Strip Background" checkbox.
STRIP_BACKGROUND = True


def create_statistics_datset():
    """Build the dataset, label-level and inference option menus and apply their initial values.

    Side effects:
        Sets the module globals ``STATISTICS_DATASET``, ``DATALOADER``, ``LABEL_LEVEL``,
        ``num_classes``, ``classes_named``, ``LABEL_REMAPPING``, ``SLIDING_WINDOW_INFERER``
        and ``STRIP_BACKGROUND`` once immediately and again whenever the value of a
        related widget changes.

    Returns:
        tuple: ``(statistics_dataset_menu, model_label_lvl, sliding_window_menu)``.
    """
    split_opts = ["all", "train", "val", "test"]
    label_level_opts = [0, 1, 2, 3]
    scaling_opts = ["original", "1024", "2048"]
    transform_opts = ["resize_512", "resize_1024", "resize_2048"]
    background_mask_opts = ["all", "without_holes"]
    annotation_file_opts = ["fine_only_explanations.csv", "explanations_df.csv"]
    label_remapping_files_opts = ["label_remapping.json", "label_remapping_coarser.json"]
    sliding_window_averaging_opts = ["constant", "gaussian"]

    # --- dataset menu -------------------------------------------------------------
    split_dropdown = widgets.Dropdown(options=split_opts, description='Split:', value="val")
    label_level_dropdown = widgets.Dropdown(options=label_level_opts, description='Label Level:', value=0)
    scaling_dropdown = widgets.Dropdown(options=scaling_opts, description='Scaling:', value="original")
    transform_dropdown = widgets.Dropdown(options=transform_opts, description='Transform:', value="resize_1024")
    explanation_file_dropdown = widgets.Dropdown(options=annotation_file_opts, description='Annotation file:', value="explanations_df.csv")
    label_remapping_file_dropdown = widgets.Dropdown(options=label_remapping_files_opts, description='Remapping file:', value="label_remapping.json")
    background_mask_toggle = widgets.ToggleButtons(options=background_mask_opts, description='Background Masking:', value="without_holes")

    # The order of the children is relied upon by the summary cell further down.
    statistics_dataset_menu = widgets.VBox([split_dropdown, label_level_dropdown, scaling_dropdown, transform_dropdown,
                                           explanation_file_dropdown, label_remapping_file_dropdown, background_mask_toggle])

    # --- model output level -------------------------------------------------------
    model_label_lvl = widgets.IntSlider(min=0, max=2, description='Output Lvl Network:', value=0)

    # --- inference menu -----------------------------------------------------------
    strip_background_checkbox = widgets.Checkbox(description="Strip Background", value=True)
    sliding_window_checkbox = widgets.Checkbox(description='Use Sliding Window:', value=True)
    sliding_window_averaging_dropdown = widgets.Dropdown(options=sliding_window_averaging_opts, description='SlidingWindowAveraging:', value="gaussian")
    sliding_window_overlap_slider = widgets.FloatSlider(value=0.5, min=0.0, max=1.0, step=0.05, description='SlidingWindowOverlap:')

    sliding_window_menu = widgets.VBox([sliding_window_checkbox, strip_background_checkbox, sliding_window_averaging_dropdown, sliding_window_overlap_slider])

    def create_dataset(*args):
        """(Re)create the dataset and its DataLoader from the current dataset menu values."""
        global STATISTICS_DATASET, DATALOADER, LABEL_LEVEL, num_classes, classes_named

        chosen_split = split_dropdown.value
        chosen_label_level = label_level_dropdown.value
        chosen_scaling = scaling_dropdown.value
        chosen_annotation_file = explanation_file_dropdown.value
        chosen_remapping_file = label_remapping_file_dropdown.value
        LABEL_LEVEL = chosen_label_level

        chosen_transform = {"resize_1024": augmentations.basic_transforms_val_test_scaling1024,
                            "resize_2048": augmentations.basic_transforms_val_test_scaling2048,
                            "resize_512": augmentations.basic_transforms_val_test_scaling512}[transform_dropdown.value]

        chosen_background_opt = {"all": {}, "without_holes": {"open": False, "close": False, "flood": False}}[background_mask_toggle.value]

        STATISTICS_DATASET = GleasonX(DATA_PATH, chosen_split, scaling=chosen_scaling, transforms=chosen_transform,
                                      label_level=chosen_label_level, create_seg_masks=True, tissue_mask_kwargs=chosen_background_opt,
                                      explanation_file=chosen_annotation_file, label_hierarchy_file=chosen_remapping_file)
        DATALOADER = DataLoader(dataset=STATISTICS_DATASET, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS, pin_memory=True)
        num_classes = STATISTICS_DATASET.num_classes
        classes_named = STATISTICS_DATASET.classes_named
        print("Called with", str(chosen_split), str(chosen_label_level), str(chosen_scaling), str(chosen_transform),
              str(chosen_background_opt), str(chosen_annotation_file), str(chosen_remapping_file))

    def create_label_remapper(*args):
        """Set ``LABEL_REMAPPING`` so model outputs are mapped to the dataset's label level."""
        global LABEL_REMAPPING

        MODEL_LEVEL = model_label_lvl.value
        DATA_LABEL_LEVEL = label_level_dropdown.value

        if MODEL_LEVEL == DATA_LABEL_LEVEL:
            LABEL_REMAPPING = None
            print(f"Created remapping from {MODEL_LEVEL} to {DATA_LABEL_LEVEL}. No mapping")
            return

        dataset_number_remappings = STATISTICS_DATASET.exp_numbered_lvl_remapping

        def remapping_function(out):
            out_remappings = generate_label_hierarchy(out, dataset_number_remappings, start_level=MODEL_LEVEL)
            return out_remappings[DATA_LABEL_LEVEL]

        LABEL_REMAPPING = remapping_function
        print(f"Created remapping from {MODEL_LEVEL} to {DATA_LABEL_LEVEL}.")

    def refresh_dataset(*args):
        """Rebuild the dataset and then the remapper, which reads dataset state."""
        create_dataset()
        create_label_remapper()

    def create_sliding_window_inferer(*args):
        """Update ``STRIP_BACKGROUND`` and ``SLIDING_WINDOW_INFERER`` from the inference menu."""
        global SLIDING_WINDOW_INFERER
        global STRIP_BACKGROUND

        STRIP_BACKGROUND = strip_background_checkbox.value
        print(f"Strip Background: {STRIP_BACKGROUND}")

        # A checkbox value is a bool (never None), so test its truth value.
        if not sliding_window_checkbox.value:
            SLIDING_WINDOW_INFERER = None
            print("Not using a sliding window inferer!")
            return

        sw_avg = sliding_window_averaging_dropdown.value
        sw_overlap = sliding_window_overlap_slider.value

        SLIDING_WINDOW_INFERER = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=1, overlap=sw_overlap, mode=sw_avg)
        print(f"Created SLIDING_WINDOW_INFERER with overlap: {sw_overlap} and averaging: {sw_avg}.")

    # Apply the initial widget values.
    create_dataset()
    create_label_remapper()
    create_sliding_window_inferer()

    # React to value changes only. Observing the VBox containers has no effect on child
    # value changes, so those observers are not registered.
    for dataset_widget in (split_dropdown, label_level_dropdown, scaling_dropdown, transform_dropdown,
                           explanation_file_dropdown, label_remapping_file_dropdown, background_mask_toggle):
        dataset_widget.observe(refresh_dataset, names="value")

    model_label_lvl.observe(create_label_remapper, names="value")

    for inference_widget in (strip_background_checkbox, sliding_window_checkbox,
                             sliding_window_averaging_dropdown, sliding_window_overlap_slider):
        inference_widget.observe(create_sliding_window_inferer, names="value")

    return statistics_dataset_menu, model_label_lvl, sliding_window_menu


# Build the three option menus; this call also creates the initial dataset, label
# remapping and sliding-window settings as a side effect.
_dataset_menu, _label_remapper_menu, _sliding_window_menu = create_statistics_datset()

 ## Menu

In [None]:
# Display Menus
# All four menus are shown in order, separated by a dashed line.
_menus_in_order = (_model_menu, _dataset_menu, _label_remapper_menu, _sliding_window_menu)

for _menu_position, _menu in enumerate(_menus_in_order):
    if _menu_position > 0:
        print("---------------")
    display(_menu)


In [None]:
# Text summary of the current menu selections.
# Child order of _dataset_menu: 0 split, 1 label level, 2 scaling, 3 transform,
# 4 annotation file, 5 remapping file, 6 background masking.
print("Model name:", _model_menu.children[0].value)
print("Data split:", STATISTICS_DATASET.split)
print("Label Level:", STATISTICS_DATASET.label_level)
# "Mask" reports the background-masking toggle (child 6); child 4 is the annotation file.
print("Transform: ", _dataset_menu.children[3].value, "Mask: ", _dataset_menu.children[6].value)
# Child order of _sliding_window_menu: 0 use sliding window, 1 strip background,
# 2 averaging mode, 3 overlap.
print("SlidingWindow:", _sliding_window_menu.children[0].value, "Background Strip:", _sliding_window_menu.children[1].value,
      "Combine: ", _sliding_window_menu.children[2].value, "Overlap: ", _sliding_window_menu.children[3].value)
print(type(model), type(model.model))

 ##  Forward pass


In [None]:
# Compute outputs

# Fraction of the dataset that is evaluated (1.0 = everything). The samples are drawn
# through a random permutation of the dataset indices.
SUBSET = 1.0
# NOTE(review): not read anywhere in the visible part of the notebook.
rand_subset = True

dataloader = DataLoader(dataset=Subset(STATISTICS_DATASET, torch.randperm(len(STATISTICS_DATASET))[:int(len(STATISTICS_DATASET)*SUBSET)]),
                        batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)

# Per-class pixel counters. 64-bit integers are used because the counts are summed over
# every pixel of every image, which can exceed the int32 range on larger runs.
pix_freq = torch.zeros(num_classes, dtype=torch.long)
max_freq = torch.zeros(num_classes, dtype=torch.long)
pred_freq = torch.zeros(num_classes, dtype=torch.long)
one_annotator_pred_freq = torch.zeros(num_classes, dtype=torch.long)

# Sampled softmax values per class, collected as one chunk per image.
activation_values_positive = [[] for _ in range(num_classes)]
activation_values_negative = [[] for _ in range(num_classes)]

# activation_values_per_class[a][b]: samples of class b's softmax on pixels annotated with class a.
activation_values_per_class = [[[] for _ in range(num_classes)] for _ in range(num_classes)]


conf_matrix = ConfusionMatrix(task="multiclass", num_classes=num_classes)
conf_matrix_ml = ConfusionMatrix(task="multilabel", num_labels=num_classes)


# Run the model over the whole dataloader and accumulate class frequencies, both confusion
# matrices and random samples of the softmax activations.
for batch in tqdm(dataloader):

    imgs, masks, background_masks = batch  # STATISTICS_DATASET[i]

    outs = generate_model_output(model, imgs, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)

    for img, mask, background_mask, out in zip(imgs, masks, background_masks, outs):

        # Softmax over the first axis of the per-sample output.
        out = torch.nn.functional.softmax(out, dim=0)

        if STRIP_BACKGROUND:
            foreground_mask = ~background_mask.bool()
        else:
            foreground_mask = torch.ones_like(background_mask).bool()

        # Keep only the selected pixels; afterwards both tensors are indexed [row, pixel].
        # NOTE(review): assumes the first axis of mask and out is the class axis — confirm.
        mask = mask[:, foreground_mask].flatten(start_dim=1)
        out = out[:, foreground_mask].flatten(start_dim=1)

        # pix_freq: pixels with a non-zero mask entry per row.
        # max_freq / pred_freq: per-pixel argmax counts of the mask / of the prediction.
        # one_annotator_pred_freq: pixels whose softmax value reaches 0.3333 per row.
        pix_freq += torch.sum((mask > 0), dim=(1))
        max_freq += torch.bincount(torch.argmax(mask, dim=0).reshape(-1), minlength=num_classes)
        pred_freq += torch.bincount(torch.argmax(out,  dim=0).reshape(-1), minlength=num_classes)
        one_annotator_pred_freq += torch.sum(out >= 0.3333, dim=(1)) #torch.bincount(torch.argmax(out, dim=0).reshape(-1), minlength=num_classes)

        # Multiclass matrix on the argmax labels; multilabel matrix on both sides thresholded at 0.32.
        conf_matrix(out.argmax(dim=0), mask.argmax(dim=0))
        conf_matrix_ml(out.T >= 0.32, mask.T >= 0.32)
        for cl in range(num_classes):

            # aggr_mask_class = np.sum(np.stack([mask == cl for mask in masks]), axis=0) > 0

            # Pixels where the mask row of class cl is positive / not positive.
            aggr_mask_class_pos = mask[cl, :] > 0
            aggr_mask_class_neg = mask[cl, :] <= 0

            # Draw size/100 values (np.random.choice samples with replacement by default)
            # from the class's softmax on the positive pixels ...
            rel_pixels = out[cl, aggr_mask_class_pos].detach().numpy()
            sampled_pixels = np.random.choice(rel_pixels, int(rel_pixels.size/100))
            activation_values_positive[cl].append(sampled_pixels)

            # ... and likewise on the non-positive pixels.
            rel_pixels = out[cl, aggr_mask_class_neg].detach().numpy()
            sampled_pixels = np.random.choice(rel_pixels, int(rel_pixels.size/100))
            activation_values_negative[cl].append(sampled_pixels)

            # One shared random selection (the first 1/100 of a permutation) of positive
            # pixels, so that the samples of all classes refer to the same pixels.
            sample_pixel_indices = np.random.permutation(int(aggr_mask_class_pos.sum()))
            sample_pixel_indices = sample_pixel_indices[:int(len(sample_pixel_indices)/100)]
            for cl2 in range(num_classes):

                rel_pixels = out[cl2, aggr_mask_class_pos].detach().numpy().flatten()
                sampled_pixels = rel_pixels[sample_pixel_indices]
                activation_values_per_class[cl][cl2].append(sampled_pixels)


# Merge the per-image sample chunks into one flat array per class (and per class pair).
# The lists are updated in place, so this step can only be run once per forward pass.
activation_values_positive[:] = [np.concatenate(activation_values_positive[i]) for i in range(num_classes)]
activation_values_negative[:] = [np.concatenate(activation_values_negative[i]) for i in range(num_classes)]

for i in range(num_classes):
    activation_values_per_class[i][:] = [np.concatenate(activation_values_per_class[i][j]) for j in range(num_classes)]


# Upper bound on the number of activation samples kept per class (and per class pair).
num_activations = 2000


def _cap_samples(values, limit):
    """Return ``values`` unchanged if it has at most ``limit`` entries, otherwise ``limit`` entries drawn without replacement."""
    if len(values) > limit:
        return np.random.choice(values, size=limit, replace=False)
    return values


activation_values_positive_resampled = []
activation_values_negative_resampled = []
activation_values_per_class_resampled = []

# Per class: positive samples first, then negative samples, then the per-class-pair samples.
for _gt_class in range(num_classes):
    activation_values_positive_resampled.append(_cap_samples(activation_values_positive[_gt_class], num_activations))
    activation_values_negative_resampled.append(_cap_samples(activation_values_negative[_gt_class], num_activations))
    activation_values_per_class_resampled.append(
        [_cap_samples(activation_values_per_class[_gt_class][_other_class], num_activations) for _other_class in range(num_classes)]
    )

 ## Plots

In [None]:
# Confusion Matrix
# All numbers are shown in per mille (values are multiplied by 1000 and rounded).
torch.set_printoptions(sci_mode=False)

confm = conf_matrix.compute()

# Rows are normalised by their sum. Classes that never occur as GT have a zero row sum;
# clamping the divisor to 1 keeps those rows at 0 instead of producing NaN, which would
# turn into meaningless integers after ``.long()``.
gt_counts = confm.sum(dim=1)
safe_gt_counts = gt_counts.clamp(min=1)
safe_total = torch.sum(confm).clamp(min=1)

confm_normed = confm / safe_gt_counts.reshape(-1, 1)
confm_normed = (confm_normed * 1000).round().long()

# One extra row and column hold the class proportions.
confm_expanded = torch.zeros(confm.shape[0]+1, confm.shape[1]+1, dtype=int)
confm_expanded[:confm.shape[0], :confm.shape[1]] = confm_normed

# Calculate proportions
proportion_gt = (gt_counts / safe_total) * 1000
proportion_predicted = (torch.sum(confm, dim=0) / safe_total) * 1000

confm_expanded[-1, :-1] = proportion_predicted.round().long()
confm_expanded[:-1, -1] = proportion_gt.round().long()

im = plt.matshow(confm_expanded)

for (i, j), val in np.ndenumerate(confm_expanded):
    plt.text(j, i, f'{val}', ha='center', va='center', color='red')

plt.colorbar(im)
plt.xticks(range(num_classes+1), list(map(lambda x: x[:20], classes_named)) + ["Proportion"],
           rotation=45, rotation_mode="anchor", ha="right", va="center_baseline")

plt.yticks(range(num_classes+1), list(map(lambda x: x[:20], classes_named)) + ["Proportion"], rotation=45)
plt.tick_params(axis="x", bottom=True, top=False, labelbottom=True, labeltop=False)

plt.xlabel("Predictions")
plt.ylabel("GT")


# Second figure: per-class accuracy and the two proportion vectors, one row each.
per_class_acc = ((confm.diagonal() / safe_gt_counts) * 1000).round().long()

r = torch.stack([per_class_acc, proportion_predicted.round().long(), proportion_gt.round().long()])
im2 = plt.matshow(r)
plt.colorbar(im2)
plt.yticks(range(3), ["Per-class accuracy", "Proportion predicted", "Proportion GT"])
for (i, j), val in np.ndenumerate(r):
    plt.text(j, i, f"{val}", c="red", ha="center", va="center")




In [None]:
# Multilabel confusion matrix: one row per class, flattened to four counts per class.
confm_ml = conf_matrix_ml.compute()
confm_ml_reshape = confm_ml.reshape(confm_ml.shape[0], -1)
num_pixels = confm_ml_reshape.sum(dim=1)
confm_ml_reshape_normed = confm_ml_reshape / num_pixels.reshape(-1, 1)

# The four columns, in the order this cell names them.
TN = confm_ml_reshape[:, 0]
FP = confm_ml_reshape[:, 1]
FN = confm_ml_reshape[:, 2]
TP = confm_ml_reshape[:, 3]

# Derived per-class metrics.
POPU = (TN+FP+FN+TP)
PREL = (TP+FN)/POPU      # share of positives
ACC = (TN+TP)/POPU       # accuracy
TPR = TP/(TP+FN)         # true positive rate
TNR = (TN/(TN+FP))       # true negative rate
PREC = TP/(FP+TP)        # precision
BACC = (TPR+TNR) / 2     # balanced accuracy

# Final table: PREL | TN FP FN TP (normalised) | ACC | PREC | BACC
_leading_columns = [PREL.reshape(num_classes, 1), confm_ml_reshape_normed]
_trailing_columns = [metric.reshape(num_classes, 1) for metric in (ACC, PREC, BACC)]
confm_ml_reshape_normed = torch.cat(_leading_columns + _trailing_columns, dim=1)

im3 = plt.matshow(confm_ml_reshape_normed)
plt.colorbar(im3)

_metric_names = ["PREL", "TN", "FP", "FN", "TP", "ACC", "PREC", "BA"]
plt.xticks(list(range(len(_metric_names))), _metric_names, rotation=45)
_ = plt.yticks(range(len(classes_named)), classes_named)
_ = plt.xlabel("Metrics in %")

# Write every value into its cell as a rounded percentage.
for (i, j), val in np.ndenumerate(confm_ml_reshape_normed):
    plt.text(j, i, f"{val*100:0.0f}", c="red", ha="center", va="center")

In [None]:
# Class annotation and prediction frequency
bar_width = 0.2
_class_positions = np.arange(num_classes)
_short_class_names = [name[:20] for name in classes_named]

# (offset in bar widths, bar heights, legend label); the annotation counts are divided
# by the total of the majority-vote counts, the predictions by their own total.
_bar_series = [
    (-1.5, pix_freq/torch.sum(max_freq), "at least one annotaor"),
    (-0.5, max_freq/torch.sum(max_freq), "majority vote"),
    (0.5, pred_freq/torch.sum(pred_freq), "pred_freq"),
]
for _offset, _heights, _label in _bar_series:
    plt.bar(_class_positions + _offset*bar_width, _heights, label=_label, width=bar_width)
# Disabled series: (1.5, one_annotator_pred_freq/torch.sum(max_freq), "one_annotator_pred_freq")

plt.yscale("log")
_ = plt.legend()
_ = plt.xticks(_class_positions, _short_class_names, rotation=90)

# inv_weights = 1/(max_freq/torch.sum(max_freq))
# inv_weights /= torch.sum(inv_weights)
# print(inv_weights.tolist())

In [None]:
# Distribution of Softmax Values per Class
# One black outline for all classes pooled, plus one outline per class.
_hist_options = dict(histtype="step", density=False, bins=100)
_pooled_activations = np.concatenate(activation_values_positive)

plt.hist(_pooled_activations, label="all", color=(0, 0, 0), **_hist_options)
plt.hist(activation_values_positive, label=classes_named, **_hist_options)
_ = plt.legend()


In [None]:
# Distribution of softmax values over all classes per GT class

def exl_idx(arr, idx):
    """Return the concatenation of ``arr[:idx]`` and ``arr[idx + 1:]``.

    For a non-negative ``idx`` this is a copy of the sequence without the element
    at position ``idx``.
    """
    leading = arr[:idx]
    trailing = arr[idx + 1:]
    return leading + trailing


# One panel per GT class, arranged in two rows. The grid follows the number of classes
# instead of a fixed 2x2 layout (identical to the former layout for four classes).
_n_cols = max(1, math.ceil(num_classes / 2))
fig, axes = plt.subplots(2, _n_cols, figsize=(6 * _n_cols, 6), squeeze=False)
fig.set_facecolor((0.0, 0., 0., 0.))
axes = axes.flatten()

for cl_idx, ax in enumerate(axes):

    if cl_idx >= num_classes:
        # Spare panel when the number of classes is odd.
        ax.set_axis_off()
        continue

    # The other classes are drawn first (in reverse order) and the true class last, so
    # that its outline ends up on top.
    other_samples = exl_idx(activation_values_per_class[cl_idx], cl_idx)[::-1]
    other_labels = exl_idx(classes_named, cl_idx)[::-1]
    other_colors = exl_idx(list(STATISTICS_DATASET.colormap.colors), cl_idx)[::-1]

    ax.hist(other_samples + [activation_values_per_class[cl_idx][cl_idx]], label=other_labels + ["TrueClass"],
            bins=100, histtype="step", density=True, color=other_colors + [(0.8, 0.2, 0.7)])
    ax.set_title(classes_named[cl_idx])
    ax.set_facecolor((.0, 0., 0.))
    ax.legend()


In [None]:
# Tetrahedron Plot


def plot_tetrahedron(dataset, softmaxes, correct_class=0):
    """Show 4-class softmax vectors as points inside a tetrahedron (interactive plotly figure).

    Every class is assigned one corner; a softmax vector is drawn at the combination of
    the corners weighted by its four entries. Points whose argmax equals ``correct_class``
    are green, all others red. At most 5000 randomly chosen points are drawn.

    Args:
        dataset: Unused; kept so existing calls keep working.
        softmaxes: Sequence of four equally long 1-D arrays, one per class.
        correct_class: Index (0-3) of the class regarded as correct.

    Raises:
        ValueError: If ``softmaxes`` does not contain exactly four class vectors.

    Note:
        Corner labels are read from the global ``classes_named``.
    """
    corners = np.array([[1, 1, 1],
                        [1, -1, -1],
                        [-1, 1, -1],
                        [-1, -1, 1]], dtype=float)

    # Append the centroid (per-coordinate mean of the corners) as a fifth point.
    vertices = np.vstack([corners, corners.mean(axis=0)])

    # Define edges: the six tetrahedron edges plus four lines to the centroid.
    edges = [[0, 1], [0, 2], [0, 3], [1, 2], [1, 3], [2, 3], [0, 4], [1, 4], [2, 4], [3, 4]]

    # Corner k belongs to class k, so it is labelled with classes_named[k].
    vertex_labels = [str(name) for name in classes_named[:len(corners)]] + ["Center"]

    # Define trace for vertices
    trace_vertices = go.Scatter3d(
        x=vertices[:, 0],
        y=vertices[:, 1],
        z=vertices[:, 2],
        mode='markers',
        marker=dict(size=8, color='red'),
        text=vertex_labels
    )

    corr_vertex = go.Scatter3d(
        x=[vertices[correct_class, 0]],
        y=[vertices[correct_class, 1]],
        z=[vertices[correct_class, 2]],
        mode='markers',
        marker=dict(size=8, color='green'),
        text=vertex_labels[correct_class]
    )

    # Define trace for edges
    traces_edges = []
    for start, end in edges:
        traces_edges.append(go.Scatter3d(
            x=[vertices[start, 0], vertices[end, 0]],
            y=[vertices[start, 1], vertices[end, 1]],
            z=[vertices[start, 2], vertices[end, 2]],
            mode='lines',
            line=dict(color='black', width=2)))

    # One row per sample, one column per class.
    softmaxes = np.stack(softmaxes, axis=1)
    if softmaxes.shape[1] != len(corners):
        raise ValueError(f"plot_tetrahedron needs exactly {len(corners)} classes, got {softmaxes.shape[1]}.")

    rnd_idcs = np.random.permutation(softmaxes.shape[0])[:5000]
    softmaxes = softmaxes[rnd_idcs, :]

    # Weighted combination of the corner coordinates.
    car_coords = softmaxes @ corners

    preds = go.Scatter3d(
        x=car_coords[:, 0],
        y=car_coords[:, 1],
        z=car_coords[:, 2],
        mode='markers',
        text=[f"({sm[0]:.2f},{sm[1]:.2f},{sm[2]:.2f},{sm[3]:.2f})" for sm in softmaxes],
        marker_size=2,
        # Colour names are used because plotly expects colours as strings or numbers.
        marker_color=["green" if sm.argmax() == correct_class else "red" for sm in softmaxes]
    )

    # Create figure
    fig = go.Figure(data=[trace_vertices, corr_vertex, preds, *traces_edges])

    # Update layout
    fig.update_layout(
        scene=dict(
            xaxis=dict(title='GP 3'),
            yaxis=dict(title='GP 4'),
            zaxis=dict(title='GP 5')
        ),
        title='Tetrahedron Plot'
    )

    # Show plot
    fig.show()


# Class selector for the tetrahedron plot; the first class is plotted right away.
tetrahedron_class_selecter = widgets.ToggleButtons(options=classes_named, index=0)
plot_tetrahedron(STATISTICS_DATASET, activation_values_per_class[0], correct_class=0)


def _replot_tetrahedron(change):
    """Redraw the tetrahedron for the newly selected class."""
    selected = tetrahedron_class_selecter.index
    if selected is None:
        return
    plot_tetrahedron(STATISTICS_DATASET, activation_values_per_class[selected], correct_class=selected)


# Observe a single trait: without ``names`` the callback runs for every trait that changes
# with a click, which draws the figure several times.
tetrahedron_class_selecter.observe(_replot_tetrahedron, names="index")


display(tetrahedron_class_selecter)

# Example usage
# Replace classes_probs with your predictive probabilities for each class
# plot_tetrahedron(STATISTICS_DATASET, activation_values_per_class[class_idx], correct_class=class_idx)


 ## Prediction Visualization

In [None]:
def plot_class_dist_with_thres(dataset, activation_values, thresholds, norm_hist=True, thresholds_are_values=True):
    """Plot per-class activation histograms together with a threshold line.

    Args:
        dataset: Object providing `explanations`; entry `i-1` is used as the title
            of panel `i` (panel 0 is titled "benign").
        activation_values: Pair `(positive, negative)` of per-class sequences. Each
            entry is an array of activation values or None when the class has none.
        thresholds: One threshold per class. Interpreted as activation values, or as
            quantiles in [0, 1] of the positive activations when
            `thresholds_are_values` is False.
        norm_hist: Passed to `hist(density=...)`.
        thresholds_are_values: See `thresholds`.
    """
    activation_values_positive, activation_values_negative = activation_values

    print(list(map(lambda x: x is not None, activation_values_positive)))
    if not thresholds_are_values:
        # Classes without positive activations get a threshold of 0.0 instead of
        # crashing inside np.percentile.
        thresholds = [
            np.percentile(activation_values_positive[i], thresholds[i]*100)
            if activation_values_positive[i] is not None else 0.0
            for i in range(len(thresholds))
        ]

    num_panels = len(thresholds)

    # Two rows of panels, one panel per class.
    fig, axes = plt.subplots(nrows=2, ncols=math.ceil(num_panels/2), figsize=(15, 6))

    # Flatten the 2D array of axes for easier iteration
    axes = axes.flatten()

    for i in range(num_panels):
        series = (
            (activation_values_positive[i], "positive", "g"),
            (activation_values_negative[i], "negative", "r"),
        )
        for values, label, color in series:
            if values is not None:
                axes[i].hist([values,], bins=100, histtype="step", density=norm_hist, label=label, color=color)

        axes[i].axvline(x=thresholds[i], color='red', linestyle='--', label='Quantile Threshold')

        title = "benign" if i == 0 else dataset.explanations[i-1]
        axes[i].set_title(title[:20])

    # With an odd number of classes the grid has one spare panel: hide it.
    for spare_ax in axes[num_panels:]:
        spare_ax.set_axis_off()

    plt.tight_layout()
    # Attach the legend to the last populated panel; the implicit "current axes"
    # may be the empty spare panel.
    axes[num_panels - 1].legend()
    plt.show()



In [None]:
# One quantile slider per class; "Benign" comes first, followed by the dataset's
# classes in the iteration order of `exp_number_mapping`.
threshold_sliders = {}
threshold_sliders["Benign"] = ipywidgets.FloatSlider(min=0, max=1, step=0.01, value=0.1, description=f'Benign', continuous_update=False)
for cl, idx in STATISTICS_DATASET.exp_number_mapping.items():
    threshold_sliders[cl] = ipywidgets.FloatSlider(min=0, max=1, step=0.01, value=0.1, description=f'{cl}', continuous_update=False)

vis_mode_widget = ipywidgets.widgets.ToggleButtons(
    options=["contour", "contourf", "heatmap", "multilabel", "thresholded"],
    value="contourf",
    description="Visualization mode",
    disabled=False
)

strip_background_button = ipywidgets.widgets.Checkbox(
    value=False,
    description='Strip Background',
    disabled=False,
    indent=False
)

# Valid sample indices are 0 .. len(dataset) - 1; using len(dataset) as the upper
# bound let the slider select an index past the end of the dataset.
idx_slider = ipywidgets.widgets.IntSlider(min=0, max=max(len(STATISTICS_DATASET) - 1, 0), step=1, value=0)


normed_hist_check = ipywidgets.widgets.Checkbox(
    value=True,
    description='Norm Hist',
    disabled=False,
    indent=False
)

show_hist_check = ipywidgets.widgets.Checkbox(
    value=True,
    description='Show Hist',
    disabled=False,
    indent=False
)


show_images_check = ipywidgets.widgets.Checkbox(
    value=True,
    description='Show Img',
    disabled=False,
    indent=False
)

use_per_image_thres_check = ipywidgets.widgets.Checkbox(
    value=False,
    description='Use per img thres',
    disabled=False,
    indent=False
)

# Invisible placeholder shown for every visualization mode that has no extra options.
dummy_checkbox = ipywidgets.widgets.Checkbox(
    value=True,
    description='Hidden',
    disabled=True,
    indent=False
)

dummy_checkbox.layout.visibility = "hidden"

# The stack has one page per visualization mode (same order as the toggle buttons);
# only the last page ("thresholded") carries real controls.
stack = ipywidgets.widgets.Stack([dummy_checkbox, dummy_checkbox, dummy_checkbox, dummy_checkbox,wid.HBox([wid.VBox(
    [show_hist_check, normed_hist_check, show_images_check, use_per_image_thres_check]), wid.VBox([slider for slider in threshold_sliders.values()])])], selected_index=0)
# Keep the visible stack page in sync with the selected visualization mode.
ipywidgets.widgets.jslink((vis_mode_widget, 'index'), (stack, 'selected_index'))
vis_mode_selection = ipywidgets.widgets.VBox([vis_mode_widget, stack])

interface = wid.VBox([idx_slider, strip_background_button, vis_mode_selection])

# Predictions
# @ipywidgets.interact(idx=idx_slider, vis_mode=vis_mode_widget, norm_hist=normed_hist_check, show_img=show_images_check, ** threshold_sliders)


def vis_thresholds_interactive(idx, vis_mode, norm_hist, show_img, show_hist, use_per_image_thres, strip_background, **thresholds):
    """Widget callback: draw activation maps and, in "thresholded" mode, histograms.

    Reads the notebook globals `STATISTICS_DATASET`, `model`,
    `activation_values_positive` and `activation_values_negative`.

    Args:
        idx: Index of the sample in `STATISTICS_DATASET`.
        vis_mode: Plot mode forwarded to `create_single_class_acti_maps`. Only
            "thresholded" uses the remaining histogram/threshold options.
        norm_hist: Forwarded to `plot_class_dist_with_thres`.
        show_img: Draw the activation maps in "thresholded" mode.
        show_hist: Draw the per-class histograms in "thresholded" mode.
        use_per_image_thres: Take the quantiles over this image's activations
            instead of the global activation values.
        strip_background: Forwarded to `create_single_class_acti_maps`.
        **thresholds: One quantile in [0, 1] per class, in class order.
    """
    dataset = STATISTICS_DATASET

    if vis_mode != "thresholded":
        create_single_class_acti_maps(dataset=dataset, idx=idx, model=model, plot_mode=vis_mode,
                                      strip_background=strip_background, thresholds=None)
        return

    if use_per_image_thres:
        # Collect this image's softmax activations per class, split by whether the
        # annotation mask marks the pixel as belonging to the class.
        img, mask, background_mask = dataset[idx]
        with torch.no_grad():
            out = model(img.unsqueeze(0)).squeeze(0)
            out = torch.nn.functional.softmax(out, dim=0)

        vis_activation_values_pos = []
        vis_activation_values_neg = []

        for cl in range(dataset.num_classes):
            aggr_mask_class_pos = mask[cl, :, :] > 0
            aggr_mask_class_neg = mask[cl, :, :] <= 0

            # None marks "no pixels of this kind in the image".
            if aggr_mask_class_pos.sum() > 0:
                vis_activation_values_pos.append(out[cl, aggr_mask_class_pos].detach().numpy().flatten())
            else:
                vis_activation_values_pos.append(None)

            if aggr_mask_class_neg.sum() > 0:
                vis_activation_values_neg.append(out[cl, aggr_mask_class_neg].detach().numpy().flatten())
            else:
                vis_activation_values_neg.append(None)
    else:
        # This `else` used to be attached to the `for` loop (for-else), so the global
        # values always replaced the per-image ones and the lists were undefined
        # whenever the checkbox was off.
        vis_activation_values_pos = activation_values_positive
        vis_activation_values_neg = activation_values_negative

    thresholds_list = list(thresholds.values())

    # Convert the slider quantiles into activation values; classes without positive
    # activations fall back to 0.0.
    value_thresholds = [(np.percentile(vis_activation_values_pos[i], thresholds_list[i]*100)
                         if vis_activation_values_pos[i] is not None else 0.0) for i in range(dataset.num_classes)]
    if show_hist:
        plot_class_dist_with_thres(dataset, activation_values=(vis_activation_values_pos, vis_activation_values_neg),
                                   thresholds=value_thresholds, norm_hist=norm_hist)
    if show_img:
        create_single_class_acti_maps(dataset=dataset, idx=idx, model=model, plot_mode=vis_mode,
                                      strip_background=strip_background, thresholds=value_thresholds)


# Bind every control to the callback parameter of the same name; the per-class
# threshold sliders are passed through as extra keyword arguments.
out = wid.interactive_output(
    vis_thresholds_interactive,
    {
        "idx": idx_slider,
        "vis_mode": vis_mode_widget,
        "norm_hist": normed_hist_check,
        "show_img": show_images_check,
        "show_hist": show_hist_check,
        "use_per_image_thres": use_per_image_thres_check,
        "strip_background": strip_background_button,
        **threshold_sliders,
    },
)

# Controls on top, rendered output below.
show_widget = wid.VBox([interface, out])
show_widget


In [None]:
@widgets.interact(idx=widgets.IntSlider(min=0, max=max(len(STATISTICS_DATASET) - 1, 0), value=0), strip_background=True, class_activation_map=False, norm_by_class_freq=False, create_segmentation=True)
def interactive_composite_plot(idx, strip_background, class_activation_map, norm_by_class_freq, create_segmentation=True):
    """Show annotator masks (and optionally the model segmentation) for one sample.

    Reads the notebook globals `STATISTICS_DATASET`, `model`, `device`,
    `LABEL_REMAPPING` and `SLIDING_WINDOW_INFERER`.

    Args:
        idx: Sample index in `STATISTICS_DATASET` (the slider stops at the last
            valid index).
        strip_background: Pass the sample's background mask to the plot.
        class_activation_map: If True, draw per-class heatmaps via
            `create_single_class_acti_maps` instead of the composite plot.
        norm_by_class_freq: Divide the softmax output by fixed per-class
            frequencies before taking the arg-max.
        create_segmentation: Add the model's arg-max segmentation as the first
            panel. Defaults to True so that four-argument calls keep working.
    """
    dataset = STATISTICS_DATASET

    if class_activation_map:
        create_single_class_acti_maps(dataset=dataset, idx=idx, model=model, plot_mode="heatmap",
                                      strip_background=strip_background, thresholds=None)
        return

    img, masks, background_mask = dataset.__getitem__(idx, False)
    org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

    background = background_mask if strip_background else None
    panels = {f"Annotator {i}": mask for i, mask in enumerate(masks)}

    if model is not None and create_segmentation:
        out = generate_model_output(model, img, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)

        out = torch.nn.functional.softmax(out, 0)
        if norm_by_class_freq:
            # NOTE(review): hard-coded frequencies for four classes -- presumably the
            # training-set class distribution; confirm before using another label level.
            class_freq = torch.tensor([0.38, .21, .25, .17], dtype=out.dtype, device=out.device)
            out /= class_freq.reshape(-1, 1, 1)
        np_seg = out.argmax(dim=0).cpu().numpy().astype(np.uint8)
        # The segmentation goes first so it is drawn before the annotator masks.
        panels = {"segmentation": np_seg} | panels

    create_composite_plot(dataset, org_img, panels, background, only_show_existing_annotation=True)

        

In [None]:
# Render sample 209 directly (outside the widget). All five parameters are passed;
# the previous four-argument call omitted `create_segmentation`.
interactive_composite_plot(209, strip_background=True, class_activation_map=False, norm_by_class_freq=False, create_segmentation=True)

 ## Other

In [None]:
@widgets.interact(idx=widgets.IntSlider(value=0, min=0, max=max(len(STATISTICS_DATASET) - 1, 0)))
def bbb(idx):
    """Show the composite plot of one sample without a model and without background masking.

    The slider stops at the last valid dataset index; `len(dataset)` as the upper
    bound pointed one past the end.
    """
    composite_prediction_plot(None, STATISTICS_DATASET, [idx], mask_background=False)

In [None]:
@widgets.interact(idx=widgets.IntSlider(value=0, min=0, max=max(len(STATISTICS_DATASET) - 1, 0)))
def vis_topdown_diff(idx):
    """Compare plain arg-max level-1 predictions with a top-down decoding.

    Draws four panels: level-0 arg-max, level-1 arg-max, level-1 decoded top-down
    (best level-1 class among the children of the winning level-0 class), and the
    pixels where the two level-1 maps disagree. Background pixels are shown black.

    Reads the notebook globals `STATISTICS_DATASET`, `model`, `device` and
    `SLIDING_WINDOW_INFERER`.

    Args:
        idx: Sample index in `STATISTICS_DATASET`.
    """
    img, _, background = STATISTICS_DATASET[idx]
    preds = generate_model_output(model, img, device=device, inferer=SLIDING_WINDOW_INFERER)
    lvl0_cmap = get_class_colormaps({"3": 1, "4": 1, "5": 1})
    lvl1_cmap = get_class_colormaps({"3": 2, "4": 3, "5": 4})

    # Prepend black so that index 0 can represent the background.
    black = np.array([[0., 0., 0., 1.]])
    lvl_0_cmap_background = ListedColormap(np.concatenate([black, lvl0_cmap.colors]))
    lvl_1_cmap_background = ListedColormap(np.concatenate([black, lvl1_cmap.colors]))

    lvl_0_preds_torch, lvl_1_preds_torch = generate_label_hierarchy(
        preds.cpu().unsqueeze(0), STATISTICS_DATASET.exp_numbered_lvl_remapping, start_level=1)
    lvl_0_scores = lvl_0_preds_torch.squeeze(0)
    lvl_1_scores = lvl_1_preds_torch.squeeze(0)
    lvl_0_preds = lvl_0_scores.argmax(dim=0).numpy().astype(np.uint8)
    lvl_1_preds = lvl_1_scores.argmax(dim=0).numpy().astype(np.uint8)

    def mask_background(mask, background_mask):
        """Shift class ids up by one and set background pixels to 0 (returns a new array)."""
        mask = mask + 1
        mask[background_mask] = 0
        return mask

    def get_topdown_classes(pred0, pred1):
        """Return (level-0 arg-max, level-1 classes restricted to the winning parent)."""
        argmax_preds0 = pred0.argmax(dim=0)
        argmax_lvl_1 = torch.zeros_like(argmax_preds0)

        lvl_0_1_remapping = STATISTICS_DATASET.exp_numbered_lvl_remapping[0]

        for lvl_0_class, lvl_1_classes in lvl_0_1_remapping.items():
            where_in_lvl_0 = argmax_preds0 == lvl_0_class
            # Position of the best child within this parent's list of children ...
            best_child_pos = pred1[lvl_1_classes, :, :].argmax(dim=0)
            # ... mapped back to the global level-1 class id.
            child_ids = torch.tensor(lvl_1_classes).reshape(-1)
            argmax_lvl_1[where_in_lvl_0] = child_ids[best_child_pos][where_in_lvl_0]

        return argmax_preds0, argmax_lvl_1

    lvl_0_preds_masked = mask_background(lvl_0_preds, background)
    lvl_1_preds_masked = mask_background(lvl_1_preds, background)

    _, lvl_1_preds_topdown = get_topdown_classes(lvl_0_scores, lvl_1_scores)
    lvl_1_preds_topdown = lvl_1_preds_topdown.numpy().astype(np.uint8)
    lvl_1_preds_topdown_masked = mask_background(lvl_1_preds_topdown, background)

    f, ax = plt.subplots(2, 2)
    ax = ax.flatten()
    for a in ax:
        a.set_axis_off()

    ax[0].imshow(lvl_0_preds_masked, cmap=lvl_0_cmap_background, vmin=0, vmax=len(lvl_0_cmap_background.colors)-1, interpolation_stage="rgba")
    ax[0].set_title("Level 0 (arg-max)", size=7)
    ax[1].imshow(lvl_1_preds_masked, cmap=lvl_1_cmap_background, vmin=0, vmax=len(lvl_1_cmap_background.colors)-1, interpolation_stage="rgba")
    ax[1].set_title("Level 1 (arg-max)", size=7)
    ax[2].imshow(lvl_1_preds_topdown_masked, cmap=lvl_1_cmap_background, vmin=0, vmax=len(lvl_1_cmap_background.colors)-1, interpolation_stage="rgba")
    ax[2].set_title("Level 1 (top-down)", size=7)
    ax[3].imshow(lvl_1_preds_topdown_masked != lvl_1_preds_masked, interpolation_stage="rgba")
    ax[3].set_title("Top-down != arg-max", size=7)

In [None]:
# Render the top-down comparison for a fixed sample (index 209) outside the widget.
vis_topdown_diff(209)

In [None]:
# Inspect the first remapping entry; vis_topdown_diff iterates it as
# {level-0 class: level-1 classes}.
STATISTICS_DATASET.exp_numbered_lvl_remapping[0]

In [None]:
# Scratch check: build the label hierarchy for the single pixel (0, 0).
# NOTE(review): `preds` is not assigned at top level anywhere in the visible cells
# (only inside functions) -- confirm it is defined before this cell on a fresh kernel.
generate_label_hierarchy(preds[:, 0,0].unsqueeze(0),STATISTICS_DATASET.exp_numbered_lvl_remapping, start_level=1)

In [None]:
# Scratch check: converting a NumPy array to a set yields a set of NumPy scalars.
set(np.array([2,3,4]))

In [None]:
# Composite plots for a hand-picked set of samples, without a model and without
# background masking.
# NOTE(review): the figure created here is only used if composite_prediction_plot
# draws into the current figure -- confirm, otherwise this leaves an empty figure.
plt.figure(figsize=(30, 10))
composite_prediction_plot(None, STATISTICS_DATASET, [13, 25,26,28, 37], mask_background=False)

In [None]:


def compute_accuracy(model, dataloader, device):
    """Evaluate `model` on `dataloader` and collect accuracy and class statistics.

    Args:
        model: Callable network; switched to eval mode here.
        dataloader: Yields `(inputs, labels)` batches; `dataloader.dataset.num_classes`
            gives the number of classes. Labels are reduced with arg-max over dim 1.
        device: Device the inputs are moved to before the forward pass.

    Returns:
        Tuple `(overall_accuracy, per_class_accuracy, class_occurence,
        class_predictions)`: micro accuracy, per-class accuracy, and the number of
        label / prediction elements per class.
    """
    model.eval()
    n_cls = dataloader.dataset.num_classes

    micro_metric = Accuracy(task="multiclass", num_classes=n_cls, average="micro")
    classwise_metric = Accuracy(task="multiclass", num_classes=n_cls, average="none")
    label_counts = torch.zeros(n_cls)
    prediction_counts = torch.zeros(n_cls)

    with torch.no_grad():
        for batch_inputs, soft_labels in tqdm(dataloader):
            batch_inputs = batch_inputs.to(device)

            # Soft labels -> hard labels.
            hard_labels = torch.argmax(soft_labels, dim=1)
            label_counts += torch.bincount(hard_labels.reshape(-1), minlength=n_cls)

            # Predictions are moved to the CPU, where labels and metrics live.
            hard_preds = torch.argmax(model(batch_inputs), dim=1).cpu()
            prediction_counts += torch.bincount(hard_preds.reshape(-1), minlength=n_cls)

            micro_metric(hard_preds, hard_labels)
            classwise_metric(hard_preds, hard_labels)

    return micro_metric.compute(), classwise_metric.compute(), label_counts, prediction_counts


# device = torch.device("cuda")
# Example usage: evaluate the current model on `dataloader_Gleason` and report the
# label frequencies, the overall accuracy and the accuracy of every class.
overall_accuracy, per_class_accuracy, class_freq, pred_freq = compute_accuracy(model, dataloader_Gleason, device)

print(f"Class frequency: {class_freq}")
print(f"Overall Accuracy: {overall_accuracy}")
print("Per-Class Accuracy:")
for class_idx, acc in enumerate(per_class_accuracy):
    print(f"Class {class_idx}: {acc.item()}")


In [None]:
# Side-by-side bars per class: relative label frequency (left) versus relative
# prediction frequency (right), each normalized to sum to 1.
bar_width = 0.25

plt.bar(np.arange(len(class_freq))-bar_width/2, class_freq/torch.sum(class_freq), width=bar_width, label="Label Frequency")
plt.bar(np.arange(len(pred_freq))+bar_width/2, pred_freq/torch.sum(pred_freq), width=bar_width, label="Prediction frequency")
plt.legend()
# plt.xticks(range(len(pred_freq)), list(map(lambda x: x[:20], CLASSES_NAMED)), rotation=90)
plt.yscale("linear")


In [None]:
# Scan all training slides and report the largest and smallest image by pixel count.
dataset_org = GleasonX(DATA_PATH, "train", scaling="original", transforms=data_transform, label_level=LABEL_LEVEL)
max_pix = 0
# Infinity instead of a large finite sentinel, so any image size can become the minimum.
min_pix = float("inf")
shape = None
m_shape = None
for slide in dataset_org.used_slides:

    path = dataset_org.tma_base_path/dataset_org.tma_paths[slide]

    # Only the size is needed; the context manager closes the file handle again so
    # the loop does not accumulate open files.
    with Image.open(path) as img:
        pix_size = img.size

    num_pixs = pix_size[0]*pix_size[1]

    if num_pixs > max_pix:
        max_pix = num_pixs
        shape = pix_size

    if num_pixs < min_pix:
        min_pix = num_pixs
        m_shape = pix_size
    print(slide)
print(max_pix)
print(shape)

print(min_pix)
print(m_shape)


In [None]:
# Whole image: shortest side resized to 512, then a central 512x512 crop.
# interpolation=2 is presumably cv2.INTER_CUBIC -- TODO confirm.
resize_512_full_scale = alb.Compose([
    alb.augmentations.geometric.resize.SmallestMaxSize(max_size=512, interpolation=2, always_apply=False, p=1),
    alb.CenterCrop(width=512, height=512, p=1),
])


# Shortest side resized to 2048, then a random crop of a fixed area fraction resized to 512x512.
# NOTE(review): the name says 0.25 but `scale` is (.1, .1) -- confirm which is intended.
resize_512_scale025 = alb.Compose([
    alb.augmentations.geometric.resize.SmallestMaxSize(max_size=2048, interpolation=2, always_apply=False, p=1),
    alb.RandomResizedCrop(height=512, width=512, scale=(.1, .1), interpolation=2),
    # alb.RandomRotate90(p=1),
])

# Same as above with scale=(0.5, 0.5).
resize_512_scale05 = alb.Compose([
    alb.augmentations.geometric.resize.SmallestMaxSize(max_size=2048, interpolation=2, always_apply=False, p=1),
    alb.RandomResizedCrop(height=512, width=512, scale=(0.5, 0.5), interpolation=2),
    # alb.RandomRotate90(p=1),
])

# TODO this is not correct yet. Am I zooming in or out here?
# NOTE(review): scale_limit evaluates to (0.5, 0.5); if RandomScale adds a bias of 1
# to scale_limit (as the comment in create_scaling_crop states), this scales the
# image by 1.5 rather than 0.5 -- confirm.
rnd_scale = alb.Compose([alb.augmentations.geometric.resize.SmallestMaxSize(max_size=2048, interpolation=2, always_apply=False, p=1),
                         alb.RandomScale(scale_limit=(1+-0.5, 1+-0.5), p=1.0),
                         alb.RandomCrop(512, 512, p=1.0)])

# dataset_for_vis = GleasonX(DATA_PATH, "val", scaling="original", transforms=rnd_scale, label_level=1)


In [None]:


def _convert_to_random_scale_tuple(input):

    if isinstance(input, (float, int)):
        input = (input, input)

    if not isinstance(input, np.ndarray):
        input = np.array(input)

    return input


def create_scaling_crop(scale_factor, image_resize=2048, patch_size=512, crop="random"):
    """Build a pipeline: resize shortest side, rescale by `scale_factor`, then crop.

    Args:
        scale_factor: Scalar or `(low, high)` range of the factor the resized image
            is scaled by. The argument itself is not modified.
        image_resize: Target length of the shortest image side before scaling.
        patch_size: Side length of the square crop.
        crop: "random", "center", or None for no crop.

    Returns:
        The composed albumentations transform.
    """
    assert crop in ["random", "center", None]

    if crop == "random":
        crop = alb.RandomCrop(patch_size, patch_size, p=1.0)
    elif crop == "center":
        crop = alb.CenterCrop(patch_size, patch_size, p=1.0)
    else:
        crop = alb.Identity()

    # alb.RandomScale adds a bias of 1 to scale_limit, so scaling the image by 0.25
    # requires passing -0.75; it also expects a tuple. The subtraction creates a new
    # array rather than modifying an array the caller passed in.
    scale_limit = _convert_to_random_scale_tuple(scale_factor) - 1
    scale_limit = tuple(float(bound) for bound in scale_limit)

    scale_only = alb.Compose([alb.augmentations.geometric.resize.SmallestMaxSize(max_size=image_resize, interpolation=2, always_apply=False, p=1.0),
                              alb.RandomScale(scale_limit=scale_limit, p=1.0),
                              crop], p=1.0)

    return scale_only


def create_zoom_crop(zoom_factor, image_resize=2048, patch_size=512, crop="random"):
    """Build a scaling crop from a zoom factor.

    The zoom factor is converted to the scale factor
    `zoom_factor * patch_size / image_resize` and handed to `create_scaling_crop`.
    """
    zoom = _convert_to_random_scale_tuple(zoom_factor)
    return create_scaling_crop((zoom * patch_size)/image_resize, image_resize, patch_size, crop)


def create_fraction_of_border_crop(border_fraction, image_resize=2048, patch_size=512, crop="random"):
    """Build a zoom crop whose zoom factor is the reciprocal of `border_fraction`."""
    fraction = _convert_to_random_scale_tuple(border_fraction)
    return create_zoom_crop(zoom_factor=1/fraction, image_resize=image_resize, patch_size=patch_size, crop=crop)


def create_fraction_of_image_crop(image_fraction, image_resize=2048, patch_size=512, crop="random"):
    """Build a crop from an image (area) fraction.

    The square root of the area fraction is used as the border fraction for
    `create_fraction_of_border_crop`.
    """
    area_fraction = _convert_to_random_scale_tuple(image_fraction)
    return create_fraction_of_border_crop(border_fraction=np.sqrt(area_fraction), image_resize=image_resize,
                                          patch_size=patch_size, crop=crop)


In [None]:
def embed_scale_in_train_transform(croping_transform):
    """Wrap a cropping/scaling transform into the full training augmentation pipeline.

    `croping_transform` runs first; flips/rotation, optional distortion, optional
    blur or noise, colour jitter and normalization follow.
    """
    geometric = [
        alb.HorizontalFlip(p=0.5),
        alb.RandomRotate90(p=1),
        # Morphology: with probability 0.25 apply one of the two distortions.
        alb.OneOf([alb.ElasticTransform(), alb.GridDistortion()], p=0.25),
    ]

    photometric = [
        # Blur or noise, applied with probability 0.25.
        alb.OneOf([alb.AdvancedBlur(p=.25), alb.GaussNoise()], p=0.25),
        # Brightness, contrast, saturation and hue.
        alb.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.2, p=0.5),
    ]

    # The crop comes first, normalization last.
    return alb.Compose([croping_transform, *geometric, *photometric, alb.Normalize()])



In [None]:


def create_random_resized_crop(s, image_resize=2048, patch_size=512):
    """Build a pipeline: resize shortest side, then a fixed-scale random resized crop.

    Args:
        s: Value used for both bounds of `RandomResizedCrop(scale=(s, s))`.
        image_resize: Target length of the shortest image side before cropping.
        patch_size: Side length of the square output patch (previously hard-coded
            to 512; the default keeps that behaviour).

    Returns:
        The composed albumentations transform.
    """
    return alb.Compose([alb.augmentations.geometric.resize.SmallestMaxSize(max_size=image_resize, interpolation=2, always_apply=False, p=1.0),
                        alb.RandomResizedCrop(height=patch_size, width=patch_size, scale=(s, s), interpolation=2),])


In [None]:
# Plot one sample: raw image, model segmentation and the annotator masks.
dataset = STATISTICS_DATASET
idx = 2
strip_background = True

# Load the sample with all tissue-mask post-processing switched off, then restore
# the dataset's original settings so later cells see the dataset unchanged.
transforms = dataset.tissue_mask_kwargs.copy()
dataset.tissue_mask_kwargs = {"open": False, "close": False, "flood": False}
try:
    img, masks, background_mask = dataset.__getitem__(idx, False)
finally:
    dataset.tissue_mask_kwargs = transforms
# org_img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(dataset.get_raw_image(idx)))["image"]

out = generate_model_output(model, img, device=device, label_remapping=LABEL_REMAPPING, inferer=SLIDING_WINDOW_INFERER)

out = torch.nn.functional.softmax(out, 0)

np_seg = np.array(out.argmax(dim=0)).astype(np.uint8)

# PIL's resize expects (width, height) while an array shape is (height, width).
org_img = np.array(dataset.get_raw_image(idx).resize(np_seg.shape[::-1]))

colormap = ListedColormap(dataset.colormap.colors)
num_class_to_vis = dataset.num_classes

if strip_background:

    # Shift all class ids up by one so that 0 can mark the background (in place).
    for mask in masks:
        mask += 1
        mask[background_mask] = 0

    np_seg += 1
    np_seg[background_mask] = 0

    # Black is prepended as the colour of the new background class.
    colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), dataset.colormap.colors]))
    num_class_to_vis = dataset.num_classes + 1

f, axes = plt.subplots(2, 3+math.ceil(dataset.num_classes/2), sharex=False, sharey=False, constrained_layout=False, figsize=(12, 4))
f.tight_layout()


axes[0, 0].imshow(org_img)
axes[0, 0].set_title("Image", size=7)
axes[0, 0].set_axis_off()

# axes[1, 0].imshow(img)
axes[1, 0].imshow(np_seg.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation="nearest")
axes[1, 0].set_title("Segmentation", size=7)
axes[1, 0].set_axis_off()

# Columns 1-2 hold up to four annotator masks; panels without a mask stay blank.
for sub_ax, mask in zip_longest(list(axes[:, 1:3].flatten()), masks):
    sub_ax.set_axis_off()

    if mask is not None:
        sub_ax.imshow(mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1, interpolation="nearest")
        sub_ax.set_title("Annotation", size=7)
plt.show()


In [None]:
# Debug output: class ids present in every annotator mask and in the segmentation,
# followed by the largest colour index used for plotting.
for mask in masks:
    print(np.unique(mask.astype(int)))

print(np.unique(np_seg))
print(num_class_to_vis-1)


In [None]:
@widgets.interact(size=widgets.FloatSlider(7, min=0.0, max=7, step=.4))
def show_mat_bug(size):
    """Draw the first annotator mask in a square figure of side `size` inches.

    Uses the notebook globals `masks`, `colormap` and `num_class_to_vis`.
    """
    first_mask = masks[0]
    plt.figure(figsize=(size, size))
    plt.imshow(first_mask.astype(int), cmap=colormap, vmin=0, vmax=num_class_to_vis - 1)


In [None]:
# Minimal synthetic example: a two-valued image (upper half 2, lower half 1) drawn
# in a very small figure with the class colormap.
img = np.ones((100, 100))
img[0:50, :] = 2
plt.figure(figsize=(2, 1))
plt.imshow(img, cmap=colormap, vmin=0, vmax=4)
