In [1]:
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import math
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    load_json_annotation,
)
from tqdm import tqdm
import segmentation_models_pytorch as smp
import skimage
from evaluation.evaluate import match_coordinates

  warn(
  @numba.jit()

  @numba.jit()

  @numba.jit()

  @numba.jit()



In [2]:
def get_cell_centers(cell_mask):
    mask_label = skimage.measure.label(cell_mask)
    stats = skimage.measure.regionprops(mask_label)
    centers = []
    for region in stats:
        centroid = region["centroid"]
        centers.append(centroid)
    return centers


SPACING_LEVEL0 = 0.24199951445730394


def evaluate_cell_predictions(gt_centers, pred_centers):
    tp_x_coords = []
    tp_y_coords = []

    # print(f"Total gt cells: {len(gt_centers)}")
    # print(f"Total pred cells: {len(pred_centers)}")
    tp = 0
    fp = 0
    fn = 0

    result_prob = [1.0 for i in range(len(pred_centers))]
    (
        tp,
        fn,
        fp,
        tp_probs,
        fp_probs,
    ) = match_coordinates(
        gt_centers,
        pred_centers,
        result_prob,
        int(7.5 / SPACING_LEVEL0),
    )

    try:
        precision = tp / (tp + fp)
    except ZeroDivisionError:
        precision = 1
    try:
        recall = tp / (tp + fn)
    except ZeroDivisionError:
        recall = 1
    try:
        f1 = (2 * precision * recall) / (precision + recall)
    except ZeroDivisionError:
        f1 = 1

    # print(f"True Positives: {tp}")
    # print(f"False Positives: {fp}")
    # print(f"False Negatives: {fn}")
    # print(f"f1 = {f1}")
    # print(f"precision = {precision}")
    # print(f"recall = {recall}")

    return tp_x_coords, tp_y_coords, f1, precision, recall


def erode_mask(mask):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    mask = cv2.erode(mask, kernel, iterations=1)
    return mask

In [5]:
# model = get_efficientunet_b0_MBConv(pretrained=False)
model = smp.Unet(
    encoder_name="mit_b5",
    encoder_weights=None,
    decoder_attention_type="scse",
    in_channels=3,
    classes=1,
)

val_fold = 3

checkpoint_path = f"/home/u1910100/Documents/Monkey/runs/MiTB5Unet/fold_{val_fold}/epoch_75.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
model.to("cuda")

IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
    save_dir=f"./",
)

# Get dataloaders for task
train_loader, val_loader = get_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
)

[5425, 10368]
train patches: 15793
test patches: 4362


In [6]:
thresholds = [0.3, 0.5, 0.7, 0.9]
# thresholds = [0.9]
best_thresh = thresholds[0]
best_F1 = 0

for thresh in thresholds:
    print(f"threshold {thresh}")
    counter = 0
    sum_F1 = []
    sum_precison = []
    sum_recall = []
    for data in tqdm(val_loader):
        file_ids = data["id"]

        annotation_dict = load_json_annotation(file_ids[0], IOconfig)
        images = data["image"].cuda().float()
        gt_masks = data["mask"].cuda().float()

        image_np = images[0]
        image_np = image_np.cpu().numpy()
        gt_mask_np = gt_masks[0]
        gt_mask_np = gt_mask_np.cpu().numpy()

        image_np = np.moveaxis(image_np, 0, 2)
        image_np = imagenet_denormalise(image_np)

        # fig, axs = plt.subplots(1, 4, figsize=(10,10))
        # axs[0].imshow(image_np)
        # axs[0].title.set_text("Image")
        # axs[1].imshow(gt_mask_np[0], cmap='gray')
        # axs[1].title.set_text("Ground Truth")

        with torch.no_grad():
            out = model(images)
            out = torch.sigmoid(out)

        out = out.cpu().detach().numpy()[0][0]

        # axs[3].imshow(out, cmap='jet')

        out_mask = np.where(out >= thresh, 1, 0).astype(np.uint8)
        # out_mask = erode_mask(out_mask)

        out_mask = skimage.morphology.remove_small_objects(
            ar=out_mask, min_size=32
        )

        pred_centers = get_cell_centers(out_mask)
        true_centers = get_cell_centers(gt_mask_np[0])
        xs, ys, f1, precision, recall = evaluate_cell_predictions(
            true_centers, pred_centers
        )

        # axs[2].imshow(out_mask, cmap="gray")
        # axs[2].title.set_text("Prediction")

        # axs[0].scatter(ys, xs, alpha=0.7)
        # axs[1].scatter(ys, xs, alpha=0.5)
        # axs[2].scatter(ys, xs, alpha=0.5)

        # for ax in fig.axes:
        #     ax.axis("off")
        # plt.show()

        sum_F1.append(f1)
        sum_precison.append(precision)
        sum_recall.append(recall)

        # counter +=1
        # if counter > 20:
        #     break

    print("Avg F1 ", np.mean(sum_F1))
    print("Avg Precision ", np.mean(sum_precison))
    print("Avg Recall ", np.mean(sum_recall))

    if np.mean(sum_F1) > best_F1:
        best_F1 = np.mean(sum_F1)
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")

threshold 0.3


  out_mask = skimage.morphology.remove_small_objects(

100%|██████████| 4362/4362 [01:43<00:00, 41.95it/s]


Avg F1  0.7980072834525814
Avg Precision  0.7950164731238576
Avg Recall  0.8170303902353496
threshold 0.5


100%|██████████| 4362/4362 [01:43<00:00, 42.10it/s]


Avg F1  0.7998625105586832
Avg Precision  0.8040467246045137
Avg Recall  0.8111176252940094
threshold 0.7


100%|██████████| 4362/4362 [01:43<00:00, 42.00it/s]


Avg F1  0.8027994492485725
Avg Precision  0.814152642183286
Avg Recall  0.8055671614262554
threshold 0.9


100%|██████████| 4362/4362 [01:44<00:00, 41.80it/s]

Avg F1  0.8076487814893856
Avg Precision  0.8411933154144587
Avg Recall  0.7894643897558292
best threshold: 0.9
best F1: 0.8076487814893856



