In [None]:
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_detection_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    load_json_annotation,
)
from tqdm import tqdm
from monkey.model.utils import get_patch_F1_score

In [None]:
def erode_mask(mask, size=3):
    """Erode a mask once using an elliptical structuring element.

    Args:
        mask: Array accepted by ``cv2.erode`` (e.g. a ``uint8`` binary mask).
        size: Width and height, in pixels, of the elliptical kernel.

    Returns:
        The eroded mask exactly as returned by ``cv2.erode``.
    """
    ellipse = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    return cv2.erode(mask, ellipse, iterations=1)

In [None]:
# Base directory holding the training runs and the extracted patches
# (machine-specific; adjust when running elsewhere).
MONKEY_ROOT = "/home/u1910100/Documents/Monkey"

# Single-output-channel (binary) detection model.
model = get_efficientunet_b0_MBConv(pretrained=False)
# model = smp.Unet(
#     encoder_name="mit_b5",
#     encoder_weights=None,
#     decoder_attention_type="scse",
#     in_channels=3,
#     classes=1,
# )

val_fold = 1

checkpoint_path = f"{MONKEY_ROOT}/runs/detection/efficientunetb0_seg_bm/fold_{val_fold}/epoch_16.pth"
# Deserialise onto the CPU so loading does not depend on the device the
# checkpoint was saved from; the weights move to the GPU with the model below.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to("cuda")
# The model is only used for inference here: switch layers such as
# BatchNorm/Dropout to evaluation behaviour. Left in training mode, any
# BatchNorm layer would normalise with the statistics of each (batch_size=1)
# batch instead of the learned running statistics.
model.eval()

IOconfig = TrainingIOConfig(
    dataset_dir=f"{MONKEY_ROOT}/patches_256",
    save_dir="./",
)

# Get dataloaders for task
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
)

In [None]:
# Sweep binarisation thresholds over the validation fold and keep the one with
# the highest mean per-patch F1.
#
# The network output does not depend on the threshold, so each patch is run
# through the model once and every threshold is scored against that same
# probability map (previously inference was repeated once per threshold).

# thresholds = [0.1]
thresholds = [0.5, 0.7, 0.9]
min_object_size = 15  # connected components smaller than this (pixels) are dropped
visualization = False

best_thresh = thresholds[0]
best_F1 = 0

# Guard against the model having been left in training mode by another cell.
model.eval()

# thresh -> metric name -> list of per-patch values.
per_thresh_metrics = {
    thresh: {"F1": [], "Precision": [], "Recall": []} for thresh in thresholds
}

for data in tqdm(val_loader):
    images = data["image"].cuda().float()
    # First sample of the batch (batch_size=1 for val_loader above).
    gt_mask_np = data["mask"][0].cpu().numpy()

    with torch.no_grad():
        prob = torch.sigmoid(model(images))
    # Probability map of the single output channel.
    prob_np = prob[0, 0].cpu().numpy()

    if visualization:
        image_np = imagenet_denormalise(
            np.moveaxis(images[0].cpu().numpy(), 0, 2)
        )

    for thresh in thresholds:
        pred_mask = skimage.morphology.remove_small_objects(
            prob_np >= thresh, min_size=min_object_size
        ).astype(np.uint8)

        metrics = get_patch_F1_score(pred_mask, gt_mask_np[0], prob_np)
        for key, values in per_thresh_metrics[thresh].items():
            values.append(metrics[key])

        if visualization:
            fig, axs = plt.subplots(1, 4, figsize=(10, 10))
            axs[0].imshow(image_np)
            axs[0].title.set_text("Image")
            axs[1].imshow(gt_mask_np[0], cmap="gray")
            axs[1].title.set_text("Ground Truth")
            axs[2].imshow(pred_mask, cmap="gray", alpha=0.6)
            axs[2].imshow(image_np, alpha=0.4)
            axs[2].title.set_text("Prediction")
            axs[3].imshow(prob_np, cmap="jet")
            axs[3].title.set_text("Probability")
            fig.suptitle(f"threshold {thresh}")
            for ax in fig.axes:
                ax.axis("off")
            plt.show()
            plt.close(fig)  # avoid accumulating figures over the whole fold

for thresh in thresholds:
    print(f"threshold {thresh}")
    # get_patch_F1_score may return None for a metric; those entries are skipped.
    sum_F1 = [x for x in per_thresh_metrics[thresh]["F1"] if x is not None]
    sum_precision = [
        x for x in per_thresh_metrics[thresh]["Precision"] if x is not None
    ]
    sum_recall = [x for x in per_thresh_metrics[thresh]["Recall"] if x is not None]

    if not sum_F1:
        print("No valid F1 scores for this threshold")
        continue

    mean_F1 = np.mean(sum_F1)
    print("Avg F1 ", mean_F1)
    print("Median F1 ", np.median(sum_F1))
    print("Avg Precision ", np.mean(sum_precision) if sum_precision else float("nan"))
    print("Avg Recall ", np.mean(sum_recall) if sum_recall else float("nan"))

    if mean_F1 > best_F1:
        best_F1 = mean_F1
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")

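## Multiclass Detection

Two-channel model: output channel 0 predicts lymphocytes and channel 1 predicts monocytes, matching the two ground-truth mask channels.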

In [None]:
# Base directory holding the training runs and the extracted patches
# (machine-specific; adjust when running elsewhere).
MONKEY_ROOT = "/home/u1910100/Documents/Monkey"

# Two-output-channel model (channel 0: lymphocytes, channel 1: monocytes).
model = get_efficientunet_b0_MBConv(pretrained=False, out_channels=2)
val_fold = 1

checkpoint_path = f"{MONKEY_ROOT}/runs/cell_multiclass_det/efficientunetb0/fold_{val_fold}/epoch_50.pth"
# Deserialise onto the CPU so loading does not depend on the device the
# checkpoint was saved from; the weights move to the GPU with the model below.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint["model"])
model.to("cuda")
# Inference only: switch layers such as BatchNorm/Dropout to evaluation
# behaviour instead of per-batch (batch_size=1) training behaviour.
model.eval()

IOconfig = TrainingIOConfig(
    dataset_dir=f"{MONKEY_ROOT}/patches_256",
    save_dir="./",
)

# Get dataloaders for task
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
    module='multiclass_detection'
)

In [None]:
# Sweep binarisation thresholds for the two-channel model and keep the one with
# the highest F1 averaged over both cell types.
#
# Each output channel is scored as an independent binary detection map against
# the matching ground-truth channel, using the same pipeline as the binary
# model (threshold -> remove small objects -> get_patch_F1_score). Inference is
# run once per patch; all thresholds reuse the same probability maps.

thresholds = [0.5, 0.7, 0.9]
class_channels = {"lymphocyte": 0, "monocyte": 1}  # output / GT channel per cell type
min_object_size = 15  # connected components smaller than this (pixels) are dropped
visualization = False
max_samples = None  # e.g. 20 for a quick visual check; None evaluates the whole fold

best_thresh = thresholds[0]
best_F1 = 0

# Guard against the model having been left in training mode by another cell.
model.eval()

# thresh -> cell type -> metric name -> list of per-patch values.
per_thresh_metrics = {
    thresh: {
        cell_type: {"F1": [], "Precision": [], "Recall": []}
        for cell_type in class_channels
    }
    for thresh in thresholds
}

for sample_idx, data in enumerate(tqdm(val_loader)):
    if max_samples is not None and sample_idx >= max_samples:
        break

    images = data["image"].cuda().float()
    # First sample of the batch (batch_size=1 for val_loader above).
    gt_mask_np = data["mask"][0].cpu().numpy()

    with torch.no_grad():
        # NOTE(review): channels are evaluated as independent binary maps, so a
        # per-channel sigmoid is used. The previous softmax over dim=1 forced
        # lymphocyte + monocyte probabilities to sum to 1 at every pixel, with no
        # way to predict background. Confirm this matches the training loss.
        prob = torch.sigmoid(model(images))
    prob_np = prob[0].cpu().numpy()  # (channels, H, W) for the single sample

    if visualization:
        image_np = imagenet_denormalise(
            np.moveaxis(images[0].cpu().numpy(), 0, 2)
        )
        fig, axs = plt.subplots(1, 5, figsize=(10, 10))
        axs[0].imshow(image_np)
        axs[0].title.set_text("Image")
        axs[1].imshow(gt_mask_np[class_channels["lymphocyte"]], cmap="gray")
        axs[1].title.set_text("GT Lymphocyte")
        axs[2].imshow(gt_mask_np[class_channels["monocyte"]], cmap="gray")
        axs[2].title.set_text("GT Monocyte")
        axs[3].imshow(prob_np[class_channels["lymphocyte"]], cmap="gray")
        axs[3].title.set_text("Lymphocyte Prob")
        axs[4].imshow(prob_np[class_channels["monocyte"]], cmap="gray")
        axs[4].title.set_text("Monocyte Prob")
        for ax in fig.axes:
            ax.axis("off")
        plt.show()
        plt.close(fig)  # avoid accumulating figures over the whole fold

    for thresh in thresholds:
        for cell_type, channel in class_channels.items():
            pred_mask = skimage.morphology.remove_small_objects(
                prob_np[channel] >= thresh, min_size=min_object_size
            ).astype(np.uint8)
            metrics = get_patch_F1_score(
                pred_mask, gt_mask_np[channel], prob_np[channel]
            )
            for key, values in per_thresh_metrics[thresh][cell_type].items():
                values.append(metrics[key])

for thresh in thresholds:
    print(f"threshold {thresh}")
    class_mean_F1 = []
    for cell_type in class_channels:
        scores = per_thresh_metrics[thresh][cell_type]
        # get_patch_F1_score may return None for a metric; those entries are skipped.
        sum_F1 = [x for x in scores["F1"] if x is not None]
        sum_precision = [x for x in scores["Precision"] if x is not None]
        sum_recall = [x for x in scores["Recall"] if x is not None]

        mean_F1 = np.mean(sum_F1) if sum_F1 else float("nan")
        print(f"[{cell_type}] Avg F1 ", mean_F1)
        print(f"[{cell_type}] Median F1 ", np.median(sum_F1) if sum_F1 else float("nan"))
        print(
            f"[{cell_type}] Avg Precision ",
            np.mean(sum_precision) if sum_precision else float("nan"),
        )
        print(
            f"[{cell_type}] Avg Recall ",
            np.mean(sum_recall) if sum_recall else float("nan"),
        )
        class_mean_F1.append(mean_F1)

    # nan if any class has no valid scores; nan never beats best_F1 below.
    mean_F1_over_classes = np.mean(class_mean_F1)
    print("Mean F1 over classes ", mean_F1_over_classes)

    if mean_F1_over_classes > best_F1:
        best_F1 = mean_F1_over_classes
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")