In [None]:
%load_ext autoreload
%autoreload 2

import random
from pathlib import Path
from typing import Literal
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch.utils.data import DataLoader

import src.augmentations as augmentations
import src.model_utils
from src.gleason_data import GleasonX, prepare_torch_inputs
from src.gleason_utils import create_composite_plot
from src.lightning_modul import LitSegmenter
import src.tree_loss as tree_loss
from matplotlib import colormaps as cm
from matplotlib.colors import ListedColormap
from src.tree_loss import generate_label_hierarchy
from torchmetrics import Accuracy, ConfusionMatrix
from src.lightning_modul import LitClassifier
from src.gleason_data import GleasonXClassification
from itertools import zip_longest
import math
from textwrap import wrap
import torchvision.transforms as tt
from ipywidgets import widgets
from monai.inferers import *
from tqdm import tqdm
from itertools import chain
from torch.utils.data import Subset
import plotly.graph_objects as go
import os
import ipywidgets.widgets as wid
from PIL import Image
import albumentations as alb
import sklearn
import ipywidgets as wid

import sys
import sklearn.metrics
from monai.inferers import sliding_window_inference
import torchvision

In [None]:
# Expose the already-imported ``src.<name>`` modules under their bare names too.
# NOTE(review): presumably needed because the checkpoints were pickled while
# these modules were importable without the ``src.`` prefix - confirm.
# Modules that have not been imported are silently skipped.
sys.modules.update(
    {
        bare_name: sys.modules[f"src.{bare_name}"]
        for bare_name in (
            "tree_loss",
            "robust_loss_functions",
            "loss_functions",
            "augmentations",
            "model_utils",
        )
        if f"src.{bare_name}" in sys.modules
    }
)

In [None]:
# Root directory handed to the GleasonX dataset classes.
DATA_PATH = Path("/home/datasets/GleasonXAI")

# Directories that are searched recursively for classifier checkpoints.
model_base_path = [
    Path("~/experiments/Gleason/GleasonClassification").expanduser(),
]

In [None]:
# Batch size and worker count passed to the DataLoaders created in create_model_dataset().
BATCH_SIZE = 16
NUM_WORKERS = 8

# Forwarded as ``label_level`` to the GleasonX / GleasonXClassification datasets.
label_level = 1

def find_best_models(base_paths):
    """Map every experiment version found below ``base_paths`` to one checkpoint.

    Each ``*.ckpt`` file is interpreted relative to the *parent* of its base
    directory as
    ``[<super folders>/]<experiment>/<version>/<checkpoint dir>/<file>.ckpt``.
    The key of the returned dict is ``"<super folder>/<experiment>/<version>"``
    where ``<super folder>`` is

    * the folders between the base directory and the experiment, if any,
    * the base directory's own name if the experiment sits directly below it,
    * empty if the base directory itself is the experiment folder.

    If a version holds several checkpoints, one whose file name contains
    ``"best_model"`` takes precedence; otherwise the first checkpoint in sorted
    path order is kept.

    Args:
        base_paths: Iterable of directories (``str`` or ``Path``) that are
            searched recursively.

    Returns:
        dict[str, Path]: experiment key -> checkpoint path. Checkpoints that
        are not nested deeply enough to match the layout above are skipped.
    """
    exp_dict = {}
    for base_path in base_paths:
        base_path = Path(base_path)  # plain strings are accepted as well

        # Sorted so the choice among several non-"best_model" checkpoints is reproducible.
        for model_path in sorted(base_path.glob("**/*.ckpt")):
            parts = model_path.relative_to(base_path.parent).parts
            if len(parts) < 4:
                # Too shallow for <experiment>/<version>/<checkpoint dir>/<file>.
                continue
            *experiment_super_folders, experiment_name, version_number, _checkpoint_dir, model_name = parts

            if len(experiment_super_folders) > 1:
                # Drop the leading base-directory name, keep the nested folders.
                experiment_super_folder = str(Path(*experiment_super_folders[1:]))
            elif experiment_super_folders:
                experiment_super_folder = experiment_super_folders[0]
            else:
                experiment_super_folder = ""

            exp_name = f"{experiment_super_folder}/{experiment_name}/{version_number}"

            # Only a "best_model" checkpoint may replace an already registered one.
            if exp_name in exp_dict and "best_model" not in model_name:
                continue
            exp_dict[exp_name] = model_path
    return exp_dict

# One checkpoint per experiment version; the keys populate the model dropdown.
model_dict = find_best_models(model_base_path)

# Transform pipelines selectable from the menu, keyed by their dropdown label.
transforms_dict = {
    "1024": augmentations.basic_transforms_val_test_scaling1024,
    "2048": augmentations.basic_transforms_val_test_scaling2048,
    "512": augmentations.basic_transforms_val_test_scaling512,
    "center": augmentations.basic_transforms_val_test,
    "random": augmentations.basic_random_crop_val_test,
    "efficient_net_random": augmentations.effb4_random_crop_val_test,
}

# Notebook-wide state. Everything starts as None and is filled in by the
# "Update" button of the menu built in create_model_dataset().
classification_model = None
dataset = None
dataset_seg = None
dataloader = None
dataloader_seg = None
num_classes = None
split = None
transforms_scaling = None
model_path = None
label_file = None
scaling = None
transform = None

# Use the first GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


def create_model_dataset():
    """Build the menu that (re)loads the classifier and the datasets.

    The menu holds one dropdown per setting (checkpoint, split, scaling,
    transform, label file), a status label and an "Update" button. Pressing
    the button loads the selected checkpoint and rebuilds datasets and
    dataloaders, publishing them through the module-level globals
    ``classification_model``, ``model_path``, ``dataset``, ``dataset_seg``,
    ``dataloader``, ``dataloader_seg``, ``split``, ``num_classes``,
    ``transform``, ``transforms_scaling``, ``scaling`` and ``label_file``.

    Nothing is loaded when the menu is created, so those globals keep their
    previous values until the button has been pressed at least once.

    Returns:
        ipywidgets.VBox: the assembled menu.
    """

    def c_d(options, value=None, desc=""):
        """Create a dropdown; ``value`` defaults to the first option if there is one."""
        if value is None and options:
            value = options[0]
        return wid.Dropdown(options=options, value=value, description=desc)

    model_dropdown = c_d(sorted(model_dict.keys()), desc="model")
    split_dropdown = c_d(["train", "val", "test", "all"], "val", "split")
    scaling_dropdown = c_d(["original", "1024", "2048"], desc="Scaling")
    transform_dropdown = c_d(list(transforms_dict.keys()), desc="Transform")
    label_file_dropdown = c_d(["label_remapping.json", "label_remapping_coarser.json"], desc="label file")

    # NOTE(review): whether ``color`` is honoured (here and in the ``style``
    # dicts below) depends on the installed ipywidgets version - confirm.
    status_label = wid.Label(value="Status: up-to-date", color="green")

    def create_model(*args):
        """Load the selected checkpoint into the global ``classification_model``."""
        global classification_model, model_path

        if model_dropdown.value is None:
            raise RuntimeError("No checkpoint selected: no *.ckpt file was found below model_base_path.")

        selected_path = model_dict[model_dropdown.value]
        loaded_model = LitClassifier.load_from_checkpoint(selected_path)
        # The notebook only runs inference: move the network to the device the
        # inputs are sent to and switch it to evaluation mode.
        classification_model = loaded_model.to(device).eval()
        # Published last so the path never points at a checkpoint that failed to load.
        model_path = selected_path

    def create_dataset(*args):
        """Rebuild datasets and dataloaders from the current dropdown values."""
        global dataset, dataset_seg, dataloader, dataloader_seg, split, num_classes
        global transform, transforms_scaling, scaling, label_file

        transform = transform_dropdown.value
        scaling = scaling_dropdown.value
        label_file = label_file_dropdown.value
        split = split_dropdown.value
        transforms_scaling = transforms_dict[transform]

        dataset = GleasonXClassification(DATA_PATH, split, scaling=scaling, transforms=transforms_scaling,
                                         label_level=label_level, label_hierarchy_file=label_file)

        dataset_seg = GleasonX(DATA_PATH, split=split, scaling=scaling,
                               transforms=transforms_scaling, label_level=label_level, label_hierarchy_file=label_file)

        dataloader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)
        dataloader_seg = DataLoader(dataset=dataset_seg, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS, pin_memory=False)

        num_classes = dataset.num_classes

    def obs_wid(wids, observable):
        """Register ``observable`` for value changes of every widget in ``wids``."""
        for w in wids:
            w.observe(observable, names="value")

    def on_button_click(*args):
        """Reload model and data, reflecting progress and failure in the status label."""
        status_label.value = "Status: Running"
        status_label.style = {"color": "yellow"}
        try:
            create_model()
            create_dataset()
        except Exception:
            # Do not leave the label stuck on "Running" when loading fails.
            status_label.value = "Status: failed"
            status_label.style = {"color": "red"}
            raise

        status_label.value = "Status: up-to-date"
        status_label.style = {"color": "green"}

    def on_change(change):
        """Mark the loaded state as stale once any selection changes."""
        status_label.value = "Status: needs update"
        status_label.style = {"color": "red"}

    button = wid.Button(description="Update")
    button.on_click(on_button_click)

    dropdowns = [model_dropdown, split_dropdown, scaling_dropdown, transform_dropdown, label_file_dropdown]
    obs_wid(dropdowns, on_change)

    menu = wid.VBox([*dropdowns, status_label, button])

    return menu


# Build the selection menu; as the cell's last expression it is rendered below the cell.
# The model and datasets are only loaded once its "Update" button has been pressed.
menu = create_model_dataset()
menu

In [None]:
preds = []
labels = []

# Number of full passes over the dataloader.
# NOTE(review): presumably >1 so that random-crop transforms contribute several
# different crops per image - confirm.
iterations = 5

# Inference only: make sure the network sits on `device` and is in eval mode.
classification_model.to(device).eval()

with torch.no_grad():
    for it in range(iterations):
        # Each batch is (images, labels, background masks); only the first two are used here.
        for imgs, masks, background_masks in tqdm(dataloader):
            outs = classification_model(imgs.to(device))
            # Move every batch to the CPU right away so outputs do not pile up in GPU memory.
            preds.append(outs.cpu())
            labels.append(masks)

preds = torch.cat(preds, dim=0)
# Raw model outputs are logits; `sigmoids` holds the per-class probabilities.
sigmoids = torch.sigmoid(preds)
labels = torch.cat(labels).cpu()

In [None]:
# Class names, used as tick labels further down, e.g. ["Benign", "3", "4", "5"].
classes = dataset.classes_named

# 2x2 grid, transposed so that iterating over `ax` yields one column of panels at a time.
f, ax = plt.subplots(nrows=2, ncols=2)
f.tight_layout()
ax = ax.transpose()

# Candidate decision thresholds, evenly spaced over [0, 1].
thresholds = torch.linspace(0, 1, 100)

# Two binarisations of the labels: any positive value vs. a value of at least 0.5.
# NOTE(review): `labels` presumably holds the per-class fraction of annotators - confirm.
at_least_once = labels.gt(0.0)
majority = labels.ge(0.5)


def get_accs_at_level(inputs, targets, threshold):
    """Return the accuracy along dim 0 obtained by thresholding ``inputs``.

    Args:
        inputs: Tensor of scores; an entry counts as positive when it is
            ``>= threshold``.
        targets: Tensor compared element-wise with the thresholded scores.
        threshold: Decision threshold applied to ``inputs``.

    Returns:
        Tensor holding the fraction of matching entries, averaged over dim 0.
    """
    predicted_positive = inputs >= threshold
    is_correct = predicted_positive == targets
    return is_correct.float().mean(dim=0)


# One column of panels per target definition: ROC curves on top, precision-recall below.
scores_np = sigmoids.numpy()

for a, targets, title in zip(ax, [at_least_once, majority], ["once", "majority"]):
    roc_ax, pr_ax = a
    targets_np = targets.numpy()

    for i, (cl, c) in enumerate(zip(dataset.classes_named, dataset.colormap.colors)):
        fpr, tpr, _ = sklearn.metrics.roc_curve(targets_np[:, i], scores_np[:, i])
        precision, recall, _ = sklearn.metrics.precision_recall_curve(targets_np[:, i], scores_np[:, i])

        # Both panels share one label so that either legend reports both areas.
        curve_label = cl + f" AUROC:{float(sklearn.metrics.auc(fpr, tpr)):.2f}, AUPRC:{float(sklearn.metrics.auc(recall, precision)):.2f}"
        roc_ax.plot(fpr, tpr, label=curve_label, color=c)
        pr_ax.plot(recall, precision, label=curve_label, color=c)

    roc_ax.set_xlabel("FPR")
    roc_ax.set_ylabel("TPR")
    roc_ax.set_title(title)

    pr_ax.set_xlabel("Recall")
    pr_ax.set_ylabel("Precision")
    pr_ax.set_title(title)

    # Dashed reference diagonals.
    roc_ax.plot([0, 1], [0, 1], color="k", linestyle="--", label="Random guess")
    pr_ax.plot([0, 1], [1, 0], color="k", linestyle="--")

# Lay out the panels first, then hang one legend below each column. The scores
# differ between the two target definitions, so each column needs its own.
f.tight_layout()
for a in ax:
    a[1].legend(loc="upper center", bbox_to_anchor=(0.5, -0.3), fontsize="x-small")

img_save_path = Path("./results") / model_path.relative_to(model_base_path[0]).parents[2]
img_save_path.mkdir(exist_ok=True, parents=True)
# bbox_inches="tight" keeps the legends, which sit outside the axes, inside the saved image.
f.savefig(img_save_path / "curves.png", bbox_inches="tight")

def find_optimal_thresholds(inputs, targets, thresholds):
    """Pick, per class, the candidate threshold that yields the highest accuracy.

    Args:
        inputs: Tensor of scores, thresholded with ``>=``.
        targets: Tensor the thresholded scores are compared with.
        thresholds: Indexable collection of candidate thresholds.

    Returns:
        tuple: ``(best_accuracy_per_class, best_thresholds)`` - a tensor with
        the highest accuracy reached for every class and a list with the
        corresponding thresholds as Python floats.
    """
    # Row k holds the per-class accuracies reached with thresholds[k].
    accuracy_table = torch.stack(
        [get_accs_at_level(inputs, targets, candidate) for candidate in thresholds]
    )

    best = accuracy_table.max(dim=0)
    chosen_thresholds = [float(thresholds[row]) for row in best.indices]

    return best.values, chosen_thresholds


# Per-class thresholds that maximise accuracy against the `majority` targets.
# They are chosen on the sigmoid probabilities, i.e. they live in [0, 1].
bestaccsclass, class_threshold = find_optimal_thresholds(sigmoids, majority, thresholds)

In [None]:
# Per-class AUROC, average precision and positive count, keyed by (label type, class name).
roc_class_scores = {}
prrec_class_scores = {}
class_size = {}

class_probabilities = sigmoids.numpy()
label_variants = (("once", at_least_once), ("majority", majority))

# NOTE(review): the slice skips class 0 and stops after index 8 - confirm this
# matches the label set in use.
for i, class_name in enumerate(dataset.classes_named[1:9], start=1):
    for label_type, label_tensor in label_variants:
        key = (label_type, class_name)
        class_targets = label_tensor[:, i]

        roc_class_scores[key] = sklearn.metrics.roc_auc_score(class_targets.numpy(), class_probabilities[:, i])
        prrec_class_scores[key] = sklearn.metrics.average_precision_score(class_targets.numpy(), class_probabilities[:, i])
        class_size[key] = class_targets.sum(axis=0).item()

# One row per (label type, class), one column per quantity.
result_df = pd.concat(
    [pd.Series(roc_class_scores), pd.Series(prrec_class_scores), pd.Series(class_size)],
    axis="columns",
    keys=["AUROC", "AUCPRC", "class_size"],
)

In [None]:
# Columns that together identify one row of the persisted result table.
RESULT_INDEX_COLUMNS = ["run", "split", "scaling", "label_level", "label_file", "label_type", "class"]
result_file = img_save_path / "result_frame.csv"

# Only tag and index the frame once, so that re-running this cell is safe.
if list(result_df.index.names) != RESULT_INDEX_COLUMNS:
    result_df = result_df.sort_index()
    result_df.index = result_df.index.set_names(["label_type", "class"])
    result_df = result_df.reset_index()
    # Stored as text so the key compares equal to what is read back from the CSV.
    result_df["run"] = str(model_path.relative_to(model_base_path[0]))
    result_df["split"] = split
    result_df["scaling"] = scaling
    result_df["label_level"] = label_level
    result_df["label_file"] = label_file
    result_df = result_df.set_index(keys=RESULT_INDEX_COLUMNS)

if result_file.exists():
    # Read the textual key columns as strings; otherwise values such as the
    # scaling "1024" would be parsed as numbers and never match the new rows.
    key_dtypes = {column: str for column in RESULT_INDEX_COLUMNS if column != "label_level"}
    old_frame = pd.read_csv(result_file, dtype=key_dtypes).set_index(keys=RESULT_INDEX_COLUMNS)

    # Keep the identifying index when appending, and let the rows of this run
    # replace earlier rows that carry the same key.
    merged_frame = pd.concat([old_frame, result_df], axis="index")
    merged_frame = merged_frame[~merged_frame.index.duplicated(keep="last")]
else:
    merged_frame = result_df

merged_frame.to_csv(result_file)

In [None]:
# Scratch example: outer-merging two small frames with disjoint indices on the shared column "B".
a = pd.DataFrame(
    {"A": list("abc"), "B": [0, 1, 2]},
    index=[0, 1, 2],
)
b = pd.DataFrame(
    {"A": list("def"), "B": [2, 4, 5]},
    index=[3, 4, 5],
)
a.merge(b, how="outer", on=["B"])


In [None]:
# Fraction of samples per class that are positive under each target definition ...
class_freq_maj = majority.numpy().mean(axis=0).tolist()
class_freq_ao = at_least_once.numpy().mean(axis=0).tolist()
# ... and the fraction predicted positive at the per-class threshold.
pred_freq = [(sigmoids.numpy()[:, i] >= thres).mean() for i, thres in enumerate(class_threshold)]

# Three bars per class, each 1/(2*num_plots) wide, centred on the class index.
num_plots = 3
bar_width = 1 / (2 * num_plots)
bar_positions = np.arange(num_classes)

_ = plt.bar(bar_positions - bar_width, class_freq_maj, width=bar_width, label="label: majority")
_ = plt.bar(bar_positions, class_freq_ao, width=bar_width, label="label: at least once")
_ = plt.bar(bar_positions + bar_width, pred_freq, width=bar_width, label="predicted")

# Long class names are shortened to 20 characters for the tick labels.
_ = plt.xticks(range(num_classes), [cl if len(cl) < 20 else cl[:20 - 3] + "..." for cl in classes], rotation=45, horizontalalignment="right")
_ = plt.ylabel("fraction of samples")
_ = plt.legend()

In [None]:
# Use the first GPU when there is one, otherwise fall back to the CPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")


class SplattingClassifier(LitClassifier):
    """Classifier whose per-class outputs are spread over the input's spatial grid.

    The parent's forward pass yields one value per sample and class; this
    subclass turns each value into a constant ``(h, w)`` map so the result has
    the layout of a dense, per-pixel prediction.
    """

    def forward(self, x, *args, **kwargs):
        """Run the classifier and broadcast every class output to an ``(h, w)`` map.

        Args:
            x: 4-D input batch; its last two dimensions set the output map size.
            *args, **kwargs: Passed through to the parent ``forward``.

        Returns:
            Tensor of shape ``(batch, out.shape[1], h, w)`` on the parent
            output's device and in torch's default dtype, where every map is
            filled with the corresponding parent output value.
        """
        out = super().forward(x, *args, **kwargs)

        b, _, h, w = x.shape
        out_large = torch.zeros((b, out.shape[1], h, w), device=out.device)

        # One broadcast assignment instead of a Python loop over samples and classes.
        # Assumes the parent output holds one value per (sample, class).
        out_large[...] = out.reshape(b, out.shape[1], 1, 1)
        return out_large


# Load the selected checkpoint once more, this time as a SplattingClassifier.
sc = SplattingClassifier.load_from_checkpoint(model_path).to(device)
sc.model = sc.model.to(device)
sc.eval()  # inference only

# Unscaled images with the 1024 validation transforms, using the label file
# selected in the menu so the classes match those of the classifier dataset.
segmentation_dataset = GleasonX(DATA_PATH, split=split, scaling="original", transforms=augmentations.basic_transforms_val_test_scaling1024,
                                label_level=label_level, label_hierarchy_file=label_file)

# 512x512 windows with 90 % overlap, blended with Gaussian weights.
sw_inf = SlidingWindowInferer(roi_size=(512, 512), sw_batch_size=4, overlap=0.9, mode="gaussian")

In [None]:
import torchvision
@widgets.interact(idx=widgets.IntSlider(value=0, min=0, max=max(len(segmentation_dataset) - 1, 0)), apply_background=True)
def plot_class(idx, apply_background):
    """Show one image next to its masks and the classifier's predictions.

    Panels, filled column by column:
      * 0: the image,
      * 1: one vertical stripe per class, coloured if the whole-image
        prediction reaches that class's threshold,
      * 2-5: up to four of the masks returned by the dataset
        (NOTE(review): presumably one per annotator - confirm),
      * 6+: per class, the thresholded sliding-window prediction map.

    Args:
        idx: Index of the sample in ``segmentation_dataset``.
        apply_background: If true, the background is drawn in black in the
            mask panels and removed from the predicted class maps.
    """
    img, masks, background_mask = segmentation_dataset.get(idx, False)
    img_torch = torchvision.transforms.functional.to_tensor(img)
    batch = img_torch.unsqueeze(0).to(device)

    with torch.no_grad():
        # `class_threshold` was tuned on sigmoid probabilities, so the raw
        # model outputs are converted before they are compared with it.
        out = torch.sigmoid(classification_model(batch)).squeeze(0).cpu().numpy()
        out_seg = torch.sigmoid(sw_inf(batch, sc)).squeeze(0).cpu().numpy()

    n_classes = segmentation_dataset.num_classes
    _fig, axes = plt.subplots(2, 3 + math.ceil(n_classes / 2))
    axes = axes.T.flatten()

    axes[0].imshow(img)
    axes[0].set_axis_off()

    # Panel 1: the image width is split into one stripe per class.
    axes[1].set_axis_off()
    class_mask = np.zeros_like(masks[0])
    stripe_span = class_mask.shape[1]  # stripes run along the column axis
    print(out >= class_threshold)
    for i in range(len(out)):
        if out[i] >= class_threshold[i]:
            class_mask[:, i * stripe_span // len(out):(i + 1) * stripe_span // len(out)] = i + 1

    # Colormap with black prepended, so that value 0 can stand for "background / not predicted".
    class_cmap = ListedColormap(np.concatenate([np.array([[0, 0, 0, 1]]), segmentation_dataset.colormap.colors]))
    axes[1].imshow(class_mask, cmap=class_cmap, vmin=0, vmax=n_classes)

    for ax, mask in zip_longest(axes[2:6], masks):
        if ax is None:
            continue
        ax.set_axis_off()
        if mask is None:
            continue

        if apply_background:
            # Shift the labels on a copy so the array held by the dataset stays untouched.
            shifted_mask = mask + 1
            shifted_mask[background_mask] = 0
            ax.imshow(shifted_mask.astype(int), cmap=class_cmap, vmin=0, vmax=n_classes)
        else:
            ax.imshow(mask.astype(int), cmap=segmentation_dataset.colormap, vmin=0, vmax=n_classes - 1)

    for i, ax in enumerate(axes[6:]):
        ax.set_axis_off()
        if i >= n_classes:
            # Spare panel that exists when the number of classes is odd.
            continue

        tmp_cmap = ListedColormap(np.concatenate([np.array([[0, 0, 0, 1]]), [segmentation_dataset.colormap.colors[i]]]))

        pred_class_mask = out_seg[i] >= class_threshold[i]

        if apply_background:
            pred_class_mask[background_mask] = False

        ax.imshow(pred_class_mask, cmap=tmp_cmap, vmin=0, vmax=1)
        ax.set_title(segmentation_dataset.classes_named[i][:25], size=8, rotation=45, horizontalalignment="left")

    plt.show()

In [None]:
# Sanity check: sliding-window output of the splatting classifier for the first sample.
img, _, _ = segmentation_dataset.get(0, False)
img = torchvision.transforms.functional.to_tensor(img)
img = img.unsqueeze(0).to(device)

with torch.no_grad():
    out_splated = sw_inf(img, sc)

# Channel 1 of the only batch element, drawn with the reversed grey colormap.
plt.imshow(out_splated[0, 1].detach().cpu().numpy(), cmap=cm["Greys"].reversed())