In [None]:
import sys

sys.path.append("../")
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_detection_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    load_json_annotation,
)
from tqdm import tqdm
from monkey.model.utils import (
    get_patch_F1_score,
    get_patch_F1_score_batch,
    get_multiclass_patch_F1_score_batch,
)
from skimage.morphology import remove_small_objects

In [None]:
def erode_mask(mask, size=3, iterations=1):
    """Erode a mask with an elliptical structuring element.

    Args:
        mask: either a 4-D array, in which case every 2-D plane ``mask[i, j]``
            (the last two axes) is eroded independently, or any other array,
            which is passed straight to ``cv2.erode``. The cells below pass a
            uint8 array produced from a model output batch - presumably
            (batch, channel, H, W); TODO confirm for other callers.
        size: side of the ``size x size`` box the elliptical kernel is
            inscribed in.
        iterations: number of erosion passes (defaults to 1, the previous
            fixed behaviour).

    Returns:
        A new eroded array with the same shape as ``mask``. The input array is
        never modified.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))
    if mask.ndim != 4:
        return cv2.erode(mask, kernel, iterations=iterations)

    # Write into a copy so the caller's array is left untouched; previously a
    # 4-D input was overwritten in place while other ranks got a new array.
    eroded = mask.copy()
    for i, j in np.ndindex(mask.shape[:2]):
        eroded[i, j] = cv2.erode(mask[i, j], kernel, iterations=iterations)
    return eroded


def morphological_post_processing(mask, size=3):
    """Apply a morphological opening followed by a closing to a mask.

    Both operations use an elliptical kernel inscribed in a ``size x size``
    box: the opening (``cv2.MORPH_OPEN``) removes specks smaller than the
    kernel, and the subsequent closing (``cv2.MORPH_CLOSE``) fills small gaps.

    Args:
        mask: either a 4-D array, in which case every 2-D plane ``mask[i, j]``
            (the last two axes) is processed independently, or any other array,
            which is passed straight to ``cv2.morphologyEx``.
        size: side of the box the elliptical kernel is inscribed in.

    Returns:
        A new array with the same shape as ``mask``. The input array is never
        modified.
    """
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (size, size))

    def _open_then_close(plane):
        # Opening first, then closing - the order matters for the result.
        opened = cv2.morphologyEx(plane, cv2.MORPH_OPEN, kernel)
        return cv2.morphologyEx(opened, cv2.MORPH_CLOSE, kernel)

    if mask.ndim != 4:
        return _open_then_close(mask)

    # Write into a copy so the caller's array is left untouched (a 4-D input
    # used to be overwritten in place).
    cleaned = mask.copy()
    for i, j in np.ndindex(mask.shape[:2]):
        cleaned[i, j] = _open_then_close(mask[i, j])
    return cleaned


def filter_objects_by_size(label_image, min_size=0, max_size=None):
    """Keep only the objects whose pixel area lies in ``[min_size, max_size)``.

    Objects are identified exactly as ``skimage.morphology.remove_small_objects``
    identifies them. A bool array is connected-component labelled first. An
    integer array is treated as an existing label image, in which every pixel
    sharing a label value belongs to one object. NOTE: a 0/1 integer mask is
    therefore treated as a single object - pass a bool mask or a labelled image.

    Args:
        label_image: bool mask or non-negative integer label image.
        min_size: objects with fewer pixels than this are removed.
        max_size: if given, objects with at least this many pixels are removed
            as well.

    Returns:
        An array with the dtype of ``label_image`` in which every out-of-range
        object is set to 0 / False.
    """
    small_removed = remove_small_objects(label_image, min_size)
    if max_size is None:
        return small_removed

    # Objects of at least `max_size` pixels - i.e. the ones that are too large.
    too_large = remove_small_objects(small_removed, max_size)
    # Blank them out of `small_removed`, not out of the original input:
    # subtracting from `label_image` brought back the objects below `min_size`
    # (and raised a TypeError for bool masks).
    in_range = small_removed.copy()
    in_range[too_large.astype(bool)] = 0
    return in_range

Overall Detection (binary)

In [None]:
# Binary (single-output) detection model, restored from a fold checkpoint.
# Alternative architecture that was tried (needs segmentation_models_pytorch):
#   model = smp.Unet(encoder_name="mit_b5", encoder_weights=None,
#                    decoder_attention_type="scse", in_channels=3, classes=1)
model = get_efficientunet_b0_MBConv(pretrained=False)

val_fold = 5

run_dir = "/home/u1910100/Documents/Monkey/runs/detection/efficientunetb0_seg"
checkpoint_path = f"{run_dir}/fold_{val_fold}/epoch_75.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
# eval() and to() both return the module itself.
model = model.eval().to("cuda")

# Set to True to evaluate against the NuClick-derived masks instead.
nuclick_mask = False

patches_dir = "/home/u1910100/Documents/Monkey/patches_256"
IOconfig = TrainingIOConfig(dataset_dir=patches_dir, save_dir="./")
if nuclick_mask:
    IOconfig.set_mask_dir(f"{patches_dir}/annotations/nuclick_hovernext")

# Un-augmented loaders for the chosen validation fold.
# NOTE(review): the visualisation cell below passes module="detection" while
# this call relies on the default `module` - confirm they select the same data.
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=32,
    disk_radius=11,
    do_augmentation=False,
    use_nuclick_masks=nuclick_mask,
)

In [None]:
# Sweep the binarisation threshold of the binary model and keep the value with
# the best mean per-batch F1 on the validation fold.
thresholds = [0.3, 0.5, 0.7]  # was [0, 3, 0.5, 0.7]: "0, 3" was a typo for 0.3
best_thresh = thresholds[0]
best_F1 = 0

for thresh in thresholds:
    print(f"threshold {thresh}")
    sum_F1 = []
    sum_precision = []
    sum_recall = []
    for data in tqdm(val_loader):
        images = data["image"].cuda().float()
        gt_masks = data["mask"].cuda().float()

        with torch.no_grad():
            out = torch.sigmoid(model(images))

        # Binarise at the threshold under test. This used to be a hard-coded
        # 0.5, which made every iteration of the sweep produce the same result.
        out_mask = (out > thresh).numpy(force=True).astype(np.uint8)
        out_mask = erode_mask(out_mask, 3)
        out_mask = morphological_post_processing(out_mask, 3)

        # NOTE(review): the binary model is scored with the *multiclass* batch
        # metric, on a numpy prediction but CUDA-tensor GT/probabilities;
        # confirm this is the intended helper and argument types.
        metrics = get_multiclass_patch_F1_score_batch(out_mask, gt_masks, out)
        sum_F1.append(metrics["F1"])
        sum_precision.append(metrics["Precision"])
        sum_recall.append(metrics["Recall"])

    # The metric helper may report None for a batch (presumably when the score
    # is undefined); exclude those before averaging.
    sum_F1 = [x for x in sum_F1 if x is not None]
    sum_precision = [x for x in sum_precision if x is not None]
    sum_recall = [x for x in sum_recall if x is not None]

    mean_F1 = np.mean(sum_F1)
    print("Avg F1 ", mean_F1)
    print("Median F1 ", np.median(sum_F1))
    print("Avg Precision ", np.mean(sum_precision))
    print("Avg Recall ", np.mean(sum_recall))

    if mean_F1 > best_F1:
        best_F1 = mean_F1
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")

Visualize Prediction

In [None]:
# Reload the validation fold one patch per batch so each figure shows one patch.
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
    module="detection",
    use_nuclick_masks=nuclick_mask,
)

thresh = 0.5
# The loop stops once `counter` exceeds this value, i.e. after max_batches + 1 figures.
max_batches = 20

counter = 0
for data in val_loader:
    images = data["image"].cuda().float()
    gt_masks = data["mask"]

    with torch.no_grad():
        out = torch.sigmoid(model(images))

    # Use the configured threshold (this was a literal 0.5 that ignored `thresh`).
    out_mask = (out > thresh).numpy(force=True).astype(np.uint8)
    out_mask = erode_mask(out_mask, 3)
    out_mask = morphological_post_processing(out_mask, 3)
    lymphocyte_pred = out_mask[0, 0, :, :]
    # Alternative post-processing that was tried on the raw probabilities:
    # erode_mask(pred, 7) followed by filter_objects_by_size(pred, 300, 60000).

    image_np = images.numpy(force=True)[0]
    image_np = imagenet_denormalise(np.moveaxis(image_np, 0, 2))
    gt_mask_np = gt_masks.numpy(force=True)[0, 0]

    panels = [
        (image_np, None, "Image"),
        (gt_mask_np, "gray", "Ground Truth"),
        (lymphocyte_pred, "gray", "Pred"),
    ]
    # One row of three panels, so the figure should be wide rather than tall.
    fig, axs = plt.subplots(1, len(panels), figsize=(18, 6))
    for ax, (img, cmap, title) in zip(axs, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

    counter += 1
    if counter > max_batches:
        break

Multiclass Detection

In [None]:
# Multiclass detection model with three output channels, restored from a fold
# checkpoint.
model = get_efficientunet_b0_MBConv(pretrained=False, out_channels=3)
val_fold = 1

# Whether channel 0 is background (the evaluation cells then read channels 1/2
# and decode with softmax + argmax) or the classes start at channel 0.
has_background_channel = True

run_dir = "/home/u1910100/Documents/Monkey/runs/cell_multiclass_det/efficientunetb0_seg_3_channel"
checkpoint_path = f"{run_dir}/fold_{val_fold}/epoch_10.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
# eval() and to() both return the module itself.
model = model.eval().to("cuda")

IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
    save_dir="./",
)

# Un-augmented multiclass loaders for the chosen validation fold.
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=32,
    disk_radius=11,
    do_augmentation=False,
    module="multiclass_detection",
    include_background_channel=has_background_channel,
)

In [None]:
# Sweep thresholds for the multiclass model and track the best threshold per
# class by mean per-batch F1. With a background channel the prediction is a
# per-pixel argmax, so `thresh` only affects the sigmoid (no-background) path.
thresholds = [0.3]  # e.g. [0.3, 0.5, 0.7] for a full sweep
best_thresh_lymph = thresholds[0]
best_thresh_mono = thresholds[0]
best_F1_lymph = 0.0
best_F1_mono = 0.0

# Channel of each class, in both the model output and the GT mask.
lymph_ch, mono_ch = (1, 2) if has_background_channel else (0, 1)
metric_keys = ("F1", "Precision", "Recall")

for thresh in thresholds:
    print(f"threshold {thresh}")
    lymph_scores = {key: [] for key in metric_keys}
    mono_scores = {key: [] for key in metric_keys}
    for data in tqdm(val_loader):
        images = data["image"].cuda().float()
        gt_masks = data["mask"].cuda().float()

        with torch.no_grad():
            logits = model(images)
            if has_background_channel:
                out = torch.softmax(logits, dim=1)
                # One-hot encode the per-pixel argmax so every pixel is
                # assigned to exactly one channel.
                out_pred = torch.argmax(out, dim=1)
                mask_pred_binary = torch.zeros_like(out).scatter_(
                    1, out_pred.unsqueeze(1), 1.0
                )
                lymphocyte_pred = mask_pred_binary[:, lymph_ch]
                monocyte_pred = mask_pred_binary[:, mono_ch]
            else:
                out = torch.sigmoid(logits)
                # Same strict comparison for both classes (the original used
                # ">" for lymphocytes but ">=" for monocytes).
                lymphocyte_pred = (out[:, lymph_ch] > thresh).float()
                monocyte_pred = (out[:, mono_ch] > thresh).float()

        lymph_metrics = get_patch_F1_score_batch(
            lymphocyte_pred, gt_masks[:, lymph_ch], out[:, lymph_ch]
        )
        mono_metrics = get_patch_F1_score_batch(
            monocyte_pred, gt_masks[:, mono_ch], out[:, mono_ch]
        )
        for key in metric_keys:
            lymph_scores[key].append(lymph_metrics[key])
            mono_scores[key].append(mono_metrics[key])

    # The metric helper may report None for a batch (presumably when the score
    # is undefined); exclude those before averaging.
    lymph_means = {
        key: np.mean([v for v in values if v is not None])
        for key, values in lymph_scores.items()
    }
    mono_means = {
        key: np.mean([v for v in values if v is not None])
        for key, values in mono_scores.items()
    }

    for key in metric_keys:
        print(f"Lymph {key} ", lymph_means[key])
    for key in metric_keys:
        print(f"Mono {key} ", mono_means[key])

    if lymph_means["F1"] > best_F1_lymph:
        best_F1_lymph = lymph_means["F1"]
        best_thresh_lymph = thresh
    if mono_means["F1"] > best_F1_mono:
        best_F1_mono = mono_means["F1"]
        best_thresh_mono = thresh

print(f"best lymph threshold: {best_thresh_lymph}")
print(f"best mono threshold: {best_thresh_mono}")
print(f"best lymph F1: {best_F1_lymph}")
print(f"best mono F1: {best_F1_mono}")

Visualize Multiclass Prediction

In [None]:
# Reload the validation fold one patch per batch so each figure shows one patch.
train_loader, val_loader = get_detection_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
    module="multiclass_detection",
    include_background_channel=has_background_channel,
)

# Per-class thresholds; only used when there is no background channel
# (otherwise the prediction is a per-pixel argmax).
lymph_thresh = 0.3
mono_thresh = 0.3
# The loop stops once `counter` exceeds this value, i.e. after max_batches + 1 figures.
max_batches = 5
# Channel of each class, in both the model output and the GT mask.
lymph_ch, mono_ch = (1, 2) if has_background_channel else (0, 1)

counter = 0
for data in val_loader:
    images = data["image"].cuda().float()
    gt_masks = data["mask"]

    with torch.no_grad():
        logits = model(images)
        if has_background_channel:
            out = torch.softmax(logits, dim=1)
            # One-hot encode the per-pixel argmax so every pixel is assigned
            # to exactly one channel.
            out_pred = torch.argmax(out, dim=1)
            mask_pred_binary = torch.zeros_like(out).scatter_(
                1, out_pred.unsqueeze(1), 1.0
            )
            lymphocyte_pred = mask_pred_binary[:, lymph_ch]
            monocyte_pred = mask_pred_binary[:, mono_ch]
        else:
            out = torch.sigmoid(logits)
            # Use this cell's own thresholds. The original read `thresh`, a
            # leftover loop variable from the sweep cell, which is undefined on
            # a fresh kernel.
            lymphocyte_pred = (out[:, lymph_ch] > lymph_thresh).float()
            monocyte_pred = (out[:, mono_ch] > mono_thresh).float()

    lymphocyte_pred = lymphocyte_pred.numpy(force=True)[0]
    monocyte_pred = monocyte_pred.numpy(force=True)[0]

    gt_mask_np = gt_masks.numpy(force=True)[0]
    gt_lymph = gt_mask_np[lymph_ch]
    gt_mono = gt_mask_np[mono_ch]

    image_np = images.numpy(force=True)[0]
    image_np = imagenet_denormalise(np.moveaxis(image_np, 0, 2))

    panels = [
        (image_np, None, "Image"),
        (gt_lymph, "gray", "Ground Truth Lymph"),
        (gt_mono, "gray", "Ground Truth Mono"),
        (lymphocyte_pred, "gray", "Pred Lymph"),
        (monocyte_pred, "gray", "Pred Mono"),
    ]
    # One row of five panels, so the figure should be wide rather than square.
    fig, axs = plt.subplots(1, len(panels), figsize=(20, 4))
    for ax, (img, cmap, title) in zip(axs, panels):
        ax.imshow(img, cmap=cmap)
        ax.set_title(title)
        ax.axis("off")
    plt.show()
    plt.close(fig)  # avoid accumulating open figures across the loop

    counter += 1
    if counter > max_batches:
        break