In [1]:
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    load_json_annotation,
)
from tqdm import tqdm
import segmentation_models_pytorch as smp
from monkey.model.utils import get_patch_F1_score

  warn(
  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()



In [2]:
def erode_mask(mask):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    mask = cv2.erode(mask, kernel, iterations=1)
    return mask

In [3]:
model = get_efficientunet_b0_MBConv(pretrained=False)
# model = smp.Unet(
#     encoder_name="mit_b5",
#     encoder_weights=None,
#     decoder_attention_type="scse",
#     in_channels=3,
#     classes=1,
# )

val_fold = 1

checkpoint_path = f"/home/u1910100/Documents/Monkey/runs/efficientunetb0/fold_{val_fold}/epoch_100.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
model.to("cuda")

IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
    save_dir=f"./",
)

# Get dataloaders for task
train_loader, val_loader = get_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
)

[4889, 10051]
train patches: 14940
test patches: 5215


In [4]:
thresholds = [0.3, 0.5, 0.7, 0.9]
# thresholds = [0.9]
best_thresh = thresholds[0]
best_F1 = 0

for thresh in thresholds:
    print(f"threshold {thresh}")
    counter = 0
    sum_F1 = []
    sum_precison = []
    sum_recall = []
    for data in tqdm(val_loader):
        file_ids = data["id"]
        # print(file_ids)
        annotation_dict = load_json_annotation(file_ids[0], IOconfig)
        lymphocyte_coords = annotation_dict["lymphocytes"]
        monocyte_coords = annotation_dict["monocytes"]
        overall_coords = []
        overall_coords.extend(lymphocyte_coords)
        overall_coords.extend(monocyte_coords)

        images = data["image"].cuda().float()
        gt_masks = data["mask"]

        image_np = images[0]
        image_np = image_np.cpu().numpy()
        gt_mask_np = gt_masks[0]
        gt_mask_np = gt_mask_np.cpu().numpy()

        image_np = np.moveaxis(image_np, 0, 2)
        image_np = imagenet_denormalise(image_np)

        # fig, axs = plt.subplots(1, 4, figsize=(10,10))
        # axs[0].imshow(image_np)
        # axs[0].title.set_text("Image")
        # axs[1].imshow(gt_mask_np[0], cmap='gray')
        # axs[1].title.set_text("Ground Truth")

        with torch.no_grad():
            out = model(images)
            out = torch.sigmoid(out)
            # out = torch.relu(out)

        out = out.cpu().detach().numpy()[0][0]

        # axs[3].imshow(out, cmap='jet')

        out_mask = np.where(out >= thresh, 1, 0)
        # # out_mask = erode_mask(out_mask)

        out_mask = skimage.morphology.remove_small_objects(
            ar=out_mask, min_size=32
        )
        out_mask = out_mask.astype(np.uint8)

        metrics = get_patch_F1_score(out_mask, gt_mask_np[0], out)
        f1, precision, recall = (
            metrics["F1"],
            metrics["Precision"],
            metrics["Recall"],
        )

        # axs[2].imshow(out_mask, cmap="gray")
        # axs[2].title.set_text("Prediction")

        # for ax in fig.axes:
        #     ax.axis("off")
        # plt.show()

        sum_F1.append(f1)
        sum_precison.append(precision)
        sum_recall.append(recall)

        # counter +=1
        # if counter > 20:
        #     break

    sum_F1 = [x for x in sum_F1 if x is not None]
    sum_precison = [x for x in sum_precison if x is not None]
    sum_recall = [x for x in sum_recall if x is not None]

    print("Avg F1 ", np.mean(sum_F1))
    print("Median F1 ", np.median(sum_F1))
    print("Avg Precision ", np.mean(sum_precison))
    print("Avg Recall ", np.mean(sum_recall))

    if np.mean(sum_F1) > best_F1:
        best_F1 = np.mean(sum_F1)
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")

threshold 0.3


  out_mask = skimage.morphology.remove_small_objects(

100%|██████████| 5215/5215 [01:08<00:00, 76.36it/s]


Avg F1  0.38854861171068533
Median F1  0.46153846153846156
Avg Precision  0.4318193696339317
Avg Recall  0.4268606708233473
threshold 0.5


100%|██████████| 5215/5215 [01:06<00:00, 78.11it/s]


Avg F1  0.383871991152216
Median F1  0.45161290322580644
Avg Precision  0.43562698485630075
Avg Recall  0.41355695917167185
threshold 0.7


100%|██████████| 5215/5215 [01:07<00:00, 77.25it/s]


Avg F1  0.376449108022764
Median F1  0.4444444444444444
Avg Precision  0.43699779074247375
Avg Recall  0.3978741489193381
threshold 0.9


100%|██████████| 5215/5215 [01:07<00:00, 77.44it/s]

Avg F1  0.3605555780744407
Median F1  0.4
Avg Precision  0.4432272956782587
Avg Recall  0.3664550338757106
best threshold: 0.3
best F1: 0.38854861171068533



