In [1]:
from src.jdt_losses import SoftCorrectDICEMetric
import os
import numpy as np
import torch
import matplotlib.pyplot as plt

from PIL import Image

from src.gleason_data import GleasonX
from pathlib import Path
import matplotlib.pyplot as plt
import torch

from torchmetrics import Dice

%load_ext autoreload
data_path = Path(os.environ["DATASET_LOCATION"] / "GleasonXAI")
assert data_path.exists()
from src.augmentations import basic_transforms_val_test_colorpreserving, normalize_only_transform

In [None]:
def make_gleason_dataset(split):
    """Build a GleasonX dataset for `split` using the settings shared by every split in this notebook.

    Keeping the configuration in one place prevents the train/val/test/all datasets from
    silently drifting apart when one of the copies is edited.
    """
    return GleasonX(
        data_path,
        split=split,
        scaling="MicronsCalibrated",
        transforms=normalize_only_transform,
        label_level=1,
        create_seg_masks=True,
        # Fresh dict per call so no two datasets share (and could mutate) the same object.
        tissue_mask_kwargs={"open": False, "close": False, "flood": False},
        drawing_order="grade_frame_order",
        explanation_file="final_filtered_explanations_df.csv",
        data_split=(0.7, 0.15, 0.15),
    )


data_test = make_gleason_dataset("test")
data_train = make_gleason_dataset("train")
data_val = make_gleason_dataset("val")

data_all = make_gleason_dataset("all")

Compute, for each split, how many slides contain each explanation.

In [None]:
import pandas as pd

train = list(data_all.train_slides)
val = list(data_all.val_slides)
test = list(data_all.test_slides)


def count_slides_per_explanation(df, slides):
    """Count, for each explanation, how many of `slides` contain it at least once.

    Args:
        df: annotation table with "TMA_identifier" and "explanations" columns.
        slides: iterable of TMA identifiers belonging to one split.

    Returns:
        Series indexed by explanation with the number of slides in `slides` containing it.
    """
    in_split = df[df["TMA_identifier"].isin(slides)]
    # Every (slide, explanation) group that exists has size > 0, so summing these flags
    # counts distinct slides per explanation rather than annotation rows.
    present = in_split.groupby(["TMA_identifier", "explanations"]).size() > 0
    return present.groupby(level="explanations").sum()


# One column per split; explanations absent from a split count as 0.
explanations_per_split = pd.DataFrame({
    split_name: count_slides_per_explanation(data_all.df, split_slides)
    for split_name, split_slides in {"train": train, "val": val, "test": test}.items()
}).fillna(0).astype(int)
explanations_per_split

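# Compute a new dataset split seed

Search over random seeds for the train/val/test split whose per-class label distributions are most similar across the three splits.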

In [None]:
from tqdm import tqdm

def get_label_counts_per_slide(data):
    """Collect per-slide, per-class label mass for every sample in `data`.

    Each ``data[i]`` is unpacked as ``(_, label, background)``. ``label`` is indexed as a
    (num_classes, H, W) tensor and ``background`` as an (H, W) boolean mask
    (presumably soft labels / tissue mask — TODO confirm against GleasonX.__getitem__).

    Returns:
        Three lists holding one (num_classes,) tensor per slide:
        - labels_all: label mass summed over all pixels,
        - labels_fg_only: label mass summed over foreground (non-background) pixels,
        - labels_unique_max_only: label mass over foreground pixels whose maximum is
          attained by exactly one class (no tie).
    """
    labels_all = []
    labels_fg_only = []
    labels_unique_max_only = []

    for i in tqdm(range(len(data))):
        _, label, background = data[i]
        foreground = ~background

        # A pixel has a unique maximum when at most one class equals the per-pixel max.
        label_max = label.max(dim=0).values
        unique_max = (label == label_max.unsqueeze(0)).sum(dim=0) <= 1

        labels_all.append(label.sum(dim=(1, 2)))
        # Sum over pixels only (dim=1) so the result stays a per-class vector.
        labels_fg_only.append(label[:, foreground].sum(dim=1))
        labels_unique_max_only.append(label[:, foreground & unique_max].sum(dim=1))
    return labels_all, labels_fg_only, labels_unique_max_only


labels, labels_fg, labels_unique_max = get_label_counts_per_slide(data_all)

In [None]:
from torch.utils.data import random_split

num_trials = 100000
slides = torch.arange(len(labels))
split = (0.7, 0.15, 0.15)


def dist_comp(idcs):
    """Normalized per-class label mass summed over the slides with indices `idcs`."""
    t = torch.stack([labels[i] for i in idcs]).sum(dim=0)
    return t / t.sum()


def l1(a, b):
    """Half the L1 distance between two class-distribution vectors."""
    return ((a - b).abs() / 2).sum()


# Any real loss is finite, so the first trial always becomes the initial best.
best_loss = float("inf")
best_seed = -1
for seed in tqdm(range(num_trials)):
    train_idx, val_idx, test_idx = [
        torch.stack(list(subset))
        for subset in random_split(slides, split, torch.Generator().manual_seed(seed))
    ]

    # Class index 0 is excluded from the comparison (presumably background — TODO confirm).
    train_dist = dist_comp(train_idx)[1:]
    val_dist = dist_comp(val_idx)[1:]
    test_dist = dist_comp(test_idx)[1:]

    loss = l1(train_dist, val_dist) + l1(val_dist, test_dist) + l1(train_dist, test_dist)

    if loss < best_loss:
        print(loss, seed)
        best_loss = loss
        best_seed = seed

print(f"best seed: {best_seed} (loss {float(best_loss):.4f})")

In [None]:
from tqdm import tqdm


def get_label_counts(data):
    """Sum per-class label mass over every sample in `data`.

    Each ``data[i]`` is unpacked as ``(_, label, background)``; ``label`` is indexed as a
    (num_classes, H, W) tensor and ``background`` as an (H, W) boolean mask.

    Returns:
        Three (num_classes,) tensors:
        - labels_all: label mass over all pixels,
        - labels_fg_only: label mass over foreground (non-background) pixels,
        - labels_unique_max_only: label mass over foreground pixels whose maximum is
          attained by exactly one class.
    """
    num_classes = data[0][1].shape[0]

    labels_all = torch.zeros(num_classes)
    labels_fg_only = torch.zeros(num_classes)
    labels_unique_max_only = torch.zeros(num_classes)

    for i in tqdm(range(len(data))):
        _, label, background = data[i]
        foreground = ~background

        label_max = torch.max(label, dim=0)[0]
        unique_max = torch.sum(label == label_max.unsqueeze(0), dim=0) <= 1

        labels_all += label.sum(dim=(1, 2))
        # Sum over pixels only so every class keeps its own count; an additional
        # torch.sum here would add the grand total to every class instead.
        labels_fg_only += label[:, foreground].sum(dim=1)
        labels_unique_max_only += label[:, foreground & unique_max].sum(dim=1)
    return labels_all, labels_fg_only, labels_unique_max_only


labels_train, label_fg_only_train, labels_unique_max_only_train = get_label_counts(data_train)
labels_val, label_fg_only_val, labels_unique_max_only_val = get_label_counts(data_val)
labels_test, label_fg_only_test, labels_unique_max_only_test = get_label_counts(data_test)
# Take the class count from the data rather than hard-coding it.
num_classes = labels_train.shape[0]

# Class-frequency / majority-vote statistics (`get_freq_max`) are computed in the
# drawing-order comparison section further below.

In [None]:
classes_named = data_train.classes_named
bar_width = 0.1
class_positions = np.arange(len(labels_train))

split_label_mass = (
    ("Train", labels_train),
    ("Val", labels_val),
    ("Test", labels_test),
)

fig, ax = plt.subplots()
# One group of bars per class, centred on its tick, each split shown as its class proportion.
for offset, (split_name, split_labels) in zip((-1, 0, 1), split_label_mass):
    ax.bar(class_positions + offset * bar_width, split_labels / split_labels.sum(),
           label=split_name, width=bar_width)

ax.legend()
ax.set_xticks(class_positions)
ax.set_xticklabels([name[:20] for name in classes_named], rotation=45, ha="right", rotation_mode="anchor")
ax.set_xlabel("Grouped Explanation")
ax.set_ylabel("Proportion")
ax.set_title("Label distribution per split");

# Compare multiple drawing orders

This was used when deciding which drawing order to use.
The outputs will differ from the original ones, because other things have changed since then (new dataframe, MicronsCalibrated resizing, the new seed from above, etc.).

**IMPORTANT:** These cells will crash. The GleasonX data class raises an error for any `drawing_order` other than `"grade_frame_order"`, which is the setting we settled on. Remove that check first if you want to rerun this comparison.

In [None]:

def get_labels_bg(data, drawing_order=None):
    """Load every sample of `data` using the requested drawing order.

    Returns three parallel lists (one entry per sample): the label tensors, the
    background masks, and boolean maps that are True where the per-pixel maximum
    is reached by at most one class.
    """
    labels, bgs, unique_maxs = [], [], []
    for sample_idx in tqdm(range(len(data))):
        _, label, background = data.__getitem__(sample_idx, drawing_order=drawing_order)

        per_pixel_max = label.max(dim=0)[0]
        n_at_max = (label == per_pixel_max.unsqueeze(0)).sum(dim=0)

        labels.append(label)
        bgs.append(background)
        unique_maxs.append(~(n_at_max > 1))

    return labels, bgs, unique_maxs

# IMPORTANT: This is older. Back then I used the old data to make these comparisons. It does not matter too much if you just want to visualize them.
# Test split, loaded with the same settings as above.
data = GleasonX(data_path, split="test", scaling="MicronsCalibrated", transforms=normalize_only_transform, label_level=1, create_seg_masks=True,
                    tissue_mask_kwargs={"open": False, "close": False, "flood": False}, drawing_order="grade_frame_order", explanation_file="final_filtered_explanations_df.csv", data_split=(0.7, 0.15, 0.15))

labels, bgs, unique_maxs = get_labels_bg(data)
# NOTE(review): per the note above, GleasonX currently rejects drawing orders other than
# "grade_frame_order", so these two calls are expected to raise unless that check is lifted.
labels2, bgs2, unique_maxs2 = get_labels_bg(data, "custom_order")
labels3, bgs3, unique_maxs3 = get_labels_bg(data, "frame_order")

In [20]:
# Fallback when the other drawing orders cannot be loaded: reuse the default-order
# results for slots 2 and 3 so the comparison plot below still runs.
labels2 = labels3 = labels
bgs2 = bgs3 = bgs
unique_maxs2 = unique_maxs3 = unique_maxs

In [None]:
import matplotlib.pyplot as plt
import numpy as np

num_classes = 10


def get_freq_max(labels, bgs, unique_maxs):
    """Summarize class frequencies over a list of per-slide label tensors.

    Args:
        labels: list of (C, H, W) label tensors. Slides may differ in H/W, so they are
            passed as a list rather than one stacked tensor.
        bgs: list of (H, W) boolean background masks, parallel to `labels`.
        unique_maxs: list of (H, W) boolean maps marking pixels whose maximum is attained
            by a single class, parallel to `labels`.

    Returns:
        label_freq: (C,) share of the total label mass per class (sums to 1).
        max_freq: (C,) number of foreground, unique-maximum pixels whose argmax is each class.

    Raises:
        ValueError: if `labels` is empty.
    """
    if len(labels) == 0:
        raise ValueError("get_freq_max needs at least one label tensor")

    label_freq = torch.stack([label.sum(dim=(1, 2)) for label in labels]).sum(dim=0)
    label_freq = label_freq / label_freq.sum()
    # Take the class count from the data so bincount always matches label_freq's length.
    n_classes = label_freq.shape[0]

    max_freq = torch.zeros_like(label_freq)
    for label, bg, um in zip(labels, bgs, unique_maxs):
        unique_foreground = torch.logical_and(~bg, um)
        winners = label[:, unique_foreground].argmax(dim=0).reshape(-1)
        max_freq += torch.bincount(winners, minlength=n_classes)

    return label_freq, max_freq


label_freq, max_freq = get_freq_max(labels, bgs, unique_maxs)
label_freq2, max_freq2 = get_freq_max(labels2, bgs2, unique_maxs2)
label_freq3, max_freq3 = get_freq_max(labels3, bgs3, unique_maxs3)

classes_named = data.classes_named

# Soft-label probability mass per class for each drawing order. All series share the
# same legend suffix so it is clear they show the same quantity.
drawing_order_freqs = (
    ("Current", label_freq),
    ("Custom Order", label_freq2),
    ("Creation Order", label_freq3),
)

bar_width = 0.1
class_positions = np.arange(len(label_freq))

fig, ax = plt.subplots()
for offset, (order_name, freq) in zip((-1, 0, 1), drawing_order_freqs):
    ax.bar(class_positions + offset * bar_width, freq, width=bar_width,
           label=f"{order_name}: Soft-Label Probability Mass")

ax.legend()
ax.set_xticks(class_positions)
ax.set_xticklabels([name[:20] for name in classes_named], rotation=45, ha="right", rotation_mode="anchor")
ax.set_xlabel("Grouped Explanation")
ax.set_ylabel("Proportion")
ax.set_title("Label distribution per drawing order");

# Inverse-frequency class weights from the majority-vote counts (kept for reference):
# inv_weights = 1/(max_freq/torch.sum(max_freq))
# inv_weights /= torch.sum(inv_weights)
# print(inv_weights.tolist())

In [None]:
# Assign each annotation group to a source dataset id: groups < 3 -> 0, < 10 -> 1, otherwise 2.
# NOTE(review): group boundaries are taken as given — confirm against the data description.
data.df["dataset"] = data.df["group"].apply(lambda g: 0 if g < 3 else 1 if g < 10 else 2)

# First position of each slide in data.used_slides (same as list.index), built once for O(1) lookups.
slide_position = {}
for position, slide_name in enumerate(data.used_slides):
    slide_position.setdefault(slide_name, position)

for dataset, frame in data.df.groupby("dataset"):
    dataset_slides = set(frame["TMA_identifier"].unique())
    idcs = sorted(slide_position[name] for name in dataset_slides if name in slide_position)
    print(f"dataset {dataset}: {idcs}")


In [23]:
# Module-level Dice metrics over 10 classes: micro-averaged first, then macro-averaged.
DICE, DICE_mac = (Dice(num_classes=10, average=avg_mode) for avg_mode in ("micro", "macro"))

In [None]:
def get_DICE_aggreement(data, drawing_order=False, num_classes=10):
    """Print micro- and macro-averaged Dice agreement between all pairs of annotators.

    For every sample, each unordered pair of annotator masks (i, j) with i < j is fed to
    the Dice metrics (mask i as prediction, mask j as target). Scores accumulate over the
    whole dataset and are printed once at the end.

    Args:
        data: dataset whose ``__getitem__(i, drawing_order=..., prepare_torch=False)``
            returns ``(_, annotator_masks, background)``.
        drawing_order: forwarded to ``data.__getitem__`` and used as the printed prefix.
        num_classes: number of classes for the Dice metrics (default 10, as before).
    """
    dice_micro = Dice(num_classes=num_classes, average="micro")
    dice_macro = Dice(num_classes=num_classes, average="macro")

    # Stream samples instead of loading every label first; the accumulated result is the same.
    for sample_idx in range(len(data)):
        _, annotator_masks, _ = data.__getitem__(sample_idx, drawing_order=drawing_order, prepare_torch=False)
        num_annos = len(annotator_masks)
        if num_annos < 2:
            continue  # no annotator pairs to compare

        # Stack into one array first: torch.tensor on a list of numpy arrays is very slow.
        masks = torch.as_tensor(np.stack(annotator_masks))

        for i in range(num_annos):
            for j in range(i + 1, num_annos):
                dice_micro.update(masks[i], masks[j])
                dice_macro.update(masks[i], masks[j])

    print(f"{drawing_order}: ", dice_micro.compute(), dice_macro.compute())

get_DICE_aggreement(data, "grade_frame_order")
#get_DICE_aggreement(data, "frame_order")
#get_DICE_aggreement(data, "custom_order")

In [27]:
# Distinct explanations per slide, in order of first appearance within that slide's rows.
explanations_per_slide = data.df.groupby("TMA_identifier")["explanations"].unique()

# ranks[e, r] = number of slides in which explanation number e is the r-th distinct explanation.
# Size the matrix from the data (but at least the previous 10x10) so it cannot overflow.
n_explanations = max(10, max(data.exp_number_mapping.values()) + 1)
max_rank = max(10, max((len(exps) for exps in explanations_per_slide), default=0))
ranks = np.zeros((n_explanations, max_rank))

for slide_explanations in explanations_per_slide:
    for rank, exp in enumerate(slide_explanations):
        ranks[data.exp_number_mapping[exp], rank] += 1
    

In [None]:
from src.gleason_utils import create_composite_plot
from src import augmentations
import numpy as np

import ipywidgets as wid

# Only the grade_frame_order masks are shown; uncomment the other drawing orders to compare them.
# The slider bounds of ipywidgets.interact are inclusive, so stop at the last valid index.
@wid.interact(idx=(0, len(data) - 1))
def plot_compare_data(idx):
    """Show all annotator masks of sample `idx` (grade_frame_order) in one composite plot."""
    #_, masks1, bg = data.__getitem__(idx, False, drawing_order="custom_order")
    #_, masks2, bg = data.__getitem__(idx, False, drawing_order="classic")
    #_, masks3, bg = data.__getitem__(idx, False, drawing_order="frame_order")
    _, masks4, bg = data.__getitem__(idx, False, drawing_order="grade_frame_order")

    ref_mask = masks4[0]
    # Empty mask used as a spacer panel in the composite plot.
    masks = {"1": np.zeros_like(ref_mask)}
    #masks = masks|{f"Custom Order {i}":mask for i, mask in enumerate(masks1)}
    #if len(masks1) == 3:
    #    masks = masks | {"2": np.zeros_like(ref_mask)}

    #masks = masks | {f"Classic Order {i}": mask for i, mask in enumerate(masks2)}
    #if len(masks1) == 3:
    #    masks = masks | {"3": np.zeros_like(ref_mask)}
    #masks = masks | {f"Frame Order {i}": mask for i, mask in enumerate(masks3)}
    masks = masks | {f"Grade Frame Order {i}": mask for i, mask in enumerate(masks4)}

    _ = create_composite_plot(data, None, masks, None, label_level=1, only_show_existing_annotation=True)
#plot_compare_data(10)

In [None]:
import cv2
from matplotlib.colors import ListedColormap

def plot_mask(data, mask, label_level, background, save_path):
    """Render a class-index mask with the dataset's colormap, either to a file or to a new figure.

    Args:
        data: dataset providing ``colormap``, ``num_classes`` and ``label_level``.
        mask: 2-D array of class indices. It is not modified.
        label_level: defaults to ``data.label_level`` when None; not otherwise used here.
        background: optional boolean background mask. When given, class indices are shifted
            by one and background pixels are set to 0, which is drawn in black. The mask is
            resized to ``mask``'s size with nearest-neighbour interpolation.
        save_path: if not None, the mask is written as a palette ("P" mode) PNG to this path.

    Returns:
        ``(fig, ax)`` when ``save_path`` is None, otherwise None.
    """
    if label_level is None:
        label_level = data.label_level

    colormap = data.colormap
    num_class_to_vis = data.num_classes

    if background is not None:
        # `mask + 1` creates a new array, so the caller's mask is never shifted in place.
        mask = mask + 1
        height, width = mask.shape
        # cv2.resize takes dsize as (width, height) — the reverse of numpy's (rows, cols) shape.
        bg_resized = cv2.resize(background.astype(np.uint8), (width, height), interpolation=cv2.INTER_NEAREST_EXACT)
        mask[bg_resized.astype(bool)] = 0

        # Prepend opaque black for the background class 0.
        colormap = ListedColormap(np.concatenate([np.array([[0., 0., 0., 1.]]), data.colormap.colors]))
        num_class_to_vis = data.num_classes + 1

    if save_path is not None:
        # np.asarray handles colormaps whose colors are stored as a list rather than an array.
        color_palette = (np.asarray(colormap.colors)[:, :3] * 255).astype(np.uint8)

        pil_img = Image.fromarray(mask.astype(np.uint8), mode='P')
        pil_img.putpalette(color_palette.ravel().tolist())
        pil_img.save(save_path)
        return None

    f, ax = plt.subplots(1, 1)
    ax.set_axis_off()
    ax.imshow(mask.astype(int), alpha=0.8, cmap=colormap, vmin=0, vmax=num_class_to_vis, interpolation_stage="rgba")
    return f, ax

In [None]:
N_EXPORT_SAMPLES = 10  # number of dataset samples to export as figures

save_path_base = Path("figures/background_graphic/")
save_path_img = save_path_base/"imgs"
save_path_nobg = save_path_base/"nobg"
save_path_wbg = save_path_base/"wbg"
save_path_bg = save_path_base/"bg"

for save_dir in (save_path_img, save_path_nobg, save_path_wbg, save_path_bg):
    save_dir.mkdir(exist_ok=True, parents=True)

# Guard against datasets with fewer samples than requested.
for idx in range(min(N_EXPORT_SAMPLES, len(data))):
    _, masks, bg = data.__getitem__(idx, False)
    img = augmentations.basic_transforms_val_test_colorpreserving(image=np.array(data.get_raw_image(idx)))["image"]
    Image.fromarray(img, mode="RGB").save(save_path_img/f"{idx}.png")

    # Background mask as a 1-bit image (white = background).
    bg_img = Image.fromarray(bg.astype(np.uint8)*255, mode='L').convert("1")
    bg_img.save(save_path_bg/f"{idx}.png")

    for a, m in enumerate(masks):
        # Pass copies: plot_mask may shift the mask in place when a background is supplied,
        # which would otherwise corrupt `m` for any later use.
        plot_mask(data, m.copy(), 1, None, save_path_nobg/f"{idx}_{a}.png")
        plot_mask(data, m.copy(), 1, bg, save_path_wbg/f"{idx}_{a}.png")