In [None]:
from matplotlib import pyplot as plt
from dataset import CamyleonDataset
from pathlib import Path
import random
import numpy as np
# from augmentations import Augmentations
# from masking import MaskingGenerator
import torch
from model import Model, init_model
from run_inference import infer_slide, remove_noise
from monai.data import WSIReader
import cv2
%matplotlib inline
%load_ext autoreload
%autoreload 2

In [None]:
# Location of the preprocessed HDF5 file handed to the dataset below.
# NOTE(review): absolute, machine-specific path — adjust it when running elsewhere.
preprocessed_data_file = Path("/home/espenbfo/Documents/projects/dinov2_wsi/camelyon.hdf5")

# Forwarded to CamyleonDataset via `sizes=` and reused by the patch-retrieval cell.
# NOTE(review): presumably patch edge lengths in pixels — confirm against dataset.py.
sizes = (96, 768, 1024)

# Build the dataset with is_train=False and print its `files` attribute.
dataset = CamyleonDataset(preprocessed_data_file, is_train=False, sizes=sizes)
print(dataset.files)

In [None]:
# Pull item 2 from the dataset; it unpacks as ((first, second, <ignored>), label).
(image1, image2, _), label = dataset[2]

fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(40, 20))
ax1, ax2 = axes

# Move axis 0 to the end so both arrays can be handed to imshow.
image1, image2 = (np.moveaxis(view.numpy(), 0, 2) for view in (image1, image2))

# Zero out rows and columns 16*5 and 16*9 of the second image, then keep the
# square between those two offsets as image3.
cell = 16
lo, hi = cell * 5, cell * 9
for offset in (lo, hi):
    image2[offset] = 0
for offset in (lo, hi):
    image2[:, offset] = 0
image3 = image2[lo:hi, lo:hi]

print(label)
print(image1.shape)
ax1.imshow(image1)
ax2.imshow(image2)


In [None]:
# Take the first (index, key) entry registered for `label` and fetch that patch
# once per entry in `sizes`.
label = 2
index, key = dataset.label_to_index[label][0]
# Pass `label` instead of repeating the literal 2 so the lookup above and the
# retrieval below cannot drift apart when the label is changed.
images = dataset.retrieve_patch_with_label(label, index, key, sizes=sizes)

# squeeze=False keeps `axes` two-dimensional even when `sizes` has one entry;
# without it plt.subplots would return a bare Axes and the loop would fail.
fig, axes = plt.subplots(1, len(sizes), figsize=(10*len(sizes), 10), squeeze=False)
print(len(images))
for ax, image in zip(axes[0], images):
    # Move axis 0 to the end before handing the array to imshow.
    ax.imshow(np.moveaxis(image.numpy(), 0, 2))

In [None]:
# Each entry starts with an input size; the remaining fields are interpreted by
# init_model. NOTE(review): presumably (size, backbone name, optional checkpoint
# path) — confirm in model.py.
SIZES_AND_BACKBONES = (
    (96, "phikon", None),
    (288, "normal", "weights/a100_full_87499.pth")
    )

# Sizes only; also passed to infer_slide in a later cell.
SIZES = [entry[0] for entry in SIZES_AND_BACKBONES]

# Single source of truth for the checkpoint name, e.g. "weights96-288.pt".
# Previously this string was built twice and FILENAME itself was never used.
FILENAME = f"weights{'-'.join(map(str, SIZES))}.pt"

# NOTE(review): the first argument is presumably the number of output classes — confirm.
model = init_model(2, SIZES_AND_BACKBONES)

state_dict = torch.load(FILENAME)
model.load_state_dict(state_dict)

In [None]:
# Index into dataset.files of the slide to run inference on.
slide_id = 5  # 7
slide_path = dataset.files["images"][slide_id]

# infer_slide returns two values: the prediction map and a second map named
# `ignored`. NOTE(review): the positional 32 is presumably a batch size — confirm
# in run_inference.py.
inference, ignored = infer_slide(
    slide_path,
    32,
    model.to("cuda"),
    distance_per_sample=96,
    sizes=SIZES,
)

In [None]:
# Run remove_noise on the prediction map and compare it with the raw map and
# the `ignored` map returned by infer_slide.
noise_constant = 1
inference_no_noise = remove_noise(inference, noise_constant)

fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(20, 10))
ax1, ax2, ax3 = axes

# Both prediction panels share the fixed colour range [0, 1].
for panel, heatmap in ((ax1, inference), (ax2, inference_no_noise)):
    panel.imshow(heatmap, vmin=0, vmax=1)
ax3.imshow(ignored)

In [None]:
# Look up the mask and image entries for the selected slide and echo both.
masked_file, image_file = (dataset.files[kind][slide_id] for kind in ("masks", "images"))
print(masked_file)
print(image_file)

# Read the mask at level 6 and keep the first channel of the returned array.
# NOTE(review): mode="Å" is not a standard colour mode; presumably it is there to
# sidestep the reader's RGB handling for a single-channel mask — confirm.
masked_reader = WSIReader(backend="tifffile")
masked_file = masked_reader.read(masked_file)
mask_data = masked_reader.get_data(masked_file, level=6, mode="Å")[0]
mask = mask_data[0]

# Read the slide image at level 7 and move axis 0 to the end for imshow.
image_reader = WSIReader(backend="cucim")
image_file = image_reader.read(image_file)
image_data = image_reader.get_data(image_file, level=7, mode="RGB")[0]
image = np.moveaxis(image_data, 0, 2)

fig, axes = plt.subplots(nrows=1, ncols=4, figsize=(20, 10))
ax1, ax2, ax3, ax4 = axes
# Unique values found in the first row of the mask.
print(np.unique(mask[0]))

# Draw all four panels first, then label them.
ax1.imshow(inference+1-ignored)
ax2.imshow(inference_no_noise+1-ignored)
ax3.imshow(mask)
ax4.imshow(image)

ax1.set_title("Predictions")
ax2.set_title("Predictions (noise removed)")
ax3.set_title("Target mask")
ax4.set_title("Slide image")

In [None]:
def balanced_accuracy(tp, fn, tn, fp):
    """Return the mean of the true-positive rate and the true-negative rate.

    Every argument is reduced with ``.sum()``, so each may be a boolean mask or
    an array of counts.

    Args:
        tp: true positives.
        fn: false negatives.
        tn: true negatives.
        fp: false positives.

    Returns:
        ``(tp / (tp + fn) + tn / (tn + fp)) / 2`` computed on the summed inputs.
    """
    n_tp, n_fn, n_tn, n_fp = tp.sum(), fn.sum(), tn.sum(), fp.sum()
    true_positive_rate = n_tp / (n_tp + n_fn)
    true_negative_rate = n_tn / (n_tn + n_fp)
    return (true_positive_rate + true_negative_rate) / 2
# Sweep a decision threshold over the prediction map and record precision,
# recall and F1 against the mask at every step.
max_threshold = 0.99
N_THRESHOLDS = 300

thresholds = np.linspace(1-max_threshold, max_threshold, N_THRESHOLDS)

precs = []
recs = []
f1_scores = []

# cv2.resize takes (width, height), hence the reversed mask shape.
resized = cv2.resize(inference, mask.shape[::-1])

# The ground truth does not depend on the threshold, so derive it once.
# Mask value 2 is the positive class; every other value counts as negative.
is_positive = mask==2
n_positives = int(is_positive.sum())

for threshold in thresholds:
    predicted_positive = resized>threshold

    n_true_positives = int((is_positive & predicted_positive).sum())
    n_predicted_positives = int(predicted_positive.sum())
    n_false_positives = n_predicted_positives - n_true_positives
    n_false_negatives = n_positives - n_true_positives

    # Precision / recall are undefined without predicted / actual positives.
    # Record NaN explicitly (matplotlib skips such points) instead of relying
    # on a 0/0 division and its runtime warning.
    precision = n_true_positives/n_predicted_positives if n_predicted_positives else float("nan")
    recall = n_true_positives/n_positives if n_positives else float("nan")

    # F1 = 2TP / (2TP + FP + FN) equals the harmonic mean of precision and
    # recall but stays finite when one of them is undefined. This keeps NaN out
    # of f1_scores, which matters because np.argmax would otherwise return the
    # position of a NaN entry as the "maximum".
    f1_denominator = 2*n_true_positives + n_false_positives + n_false_negatives
    f1 = 2*n_true_positives/f1_denominator if f1_denominator else 0.0

    precs.append(precision)
    recs.append(recall)
    f1_scores.append(f1)

max_f1_index = np.argmax(f1_scores)
max_f1 = f1_scores[max_f1_index]
print(f"Max f1: {max_f1:.3f} at threshold {thresholds[max_f1_index]}")
print(f"Precision and recall at max f1: {precs[max_f1_index]:.3f}, {recs[max_f1_index]:.3f}")
# Two curves share the axes: precision over recall, and F1 over the threshold.
plt.plot(recs, precs, label="Precision Recall curve")
plt.plot(thresholds, f1_scores, label="f1_scores")
plt.legend()
plt.xlim(0, 1)
plt.ylim(0, 1)

In [None]:
# Binarise the raw prediction map at a few fixed thresholds and show the
# results side by side, one titled panel per threshold.
threshold = (0.5, 0.8, 0.95, 0.98)

# squeeze=False keeps `axes` two-dimensional even for a single threshold.
fig, axes = plt.subplots(1, len(threshold), figsize=(5*len(threshold), 10), squeeze=False)
for ax, cutoff in zip(axes[0], threshold):
    ax.imshow(inference>cutoff)
    ax.set_title(f"inference > {cutoff}")


In [None]:
# Show the raw, unthresholded `inference` map returned by infer_slide.
plt.imshow(inference)