In [29]:
from monkey.data.data_utils import extract_dotmaps
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import math
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    extract_dotmaps,
    load_json_annotation,
)
from evaluation.evaluation import mask_based_evaluate
from tqdm import tqdm

In [30]:
def get_cell_centers(cell_mask):
    mask_label = skimage.measure.label(cell_mask)
    stats = skimage.measure.regionprops(mask_label)
    centers = []
    for region in stats:
        centroid = region["centroid"]
        centers.append(centroid)
    return centers


def point_to_box(x, y, size):
    """Convert centerpoint to bounding box of fixed size"""
    return np.array([x - size, y - size, x + size, y + size])


def check_point_in_box(x, y, box):
    cond_1 = x >= box[0] and x <= box[2]
    cond_2 = y >= box[1] and y <= box[3]
    return cond_1 and cond_2


def check_point_in_circle(x, y, center_x, center_y, radius):
    dist = math.sqrt((center_x - x) ** 2 + (center_y - y) ** 2)
    # print(dist)
    return dist <= radius


def evaluate_cell_predictions(gt_centers, pred_centers):
    tp_x_coords = []
    tp_y_coords = []

    # print(f"Total gt cells: {len(gt_centers)}")
    # print(f"Total pred cells: {len(pred_centers)}")
    tp = 0
    fp = 0
    fn = 0

    for j, pred_center in enumerate(pred_centers):
        for i, true_center in enumerate(gt_centers):
            # true_box = point_to_box(true_center[0], true_center[1], 8)
            # if check_point_in_box(pred_center[0], pred_center[1], true_box):
            if check_point_in_circle(
                pred_center[0],
                pred_center[1],
                true_center[0],
                true_center[1],
                8,
            ):
                tp_x_coords.append(pred_center[0])
                tp_y_coords.append(pred_center[1])
                tp += 1
                del gt_centers[i]
                del pred_centers[j]
                break

    for i, true_center in enumerate(gt_centers):
        for j, pred_center in enumerate(pred_centers):
            # true_box = point_to_box(true_center[0], true_center[1], 8)
            # if check_point_in_box(pred_center[0], pred_center[1], true_box):
            if check_point_in_circle(
                pred_center[0],
                pred_center[1],
                true_center[0],
                true_center[1],
                8,
            ):
                tp_x_coords.append(pred_center[0])
                tp_y_coords.append(pred_center[1])
                tp += 1
                del gt_centers[i]
                del pred_centers[j]
                break

    fn = len(gt_centers)
    fp = len(pred_centers)

    try:
        precision = tp / (tp + fp)
    except ZeroDivisionError:
        precision = 1
    try:
        recall = tp / (tp + fn)
    except ZeroDivisionError:
        recall = 1
    try:
        f1 = (2 * precision * recall) / (precision + recall)
    except ZeroDivisionError:
        f1 = 1

    # print(f"True Positives: {tp}")
    # print(f"False Positives: {fp}")
    # print(f"False Negatives: {fn}")
    # print(f"f1 = {f1}")
    # print(f"precision = {precision}")
    # print(f"recall = {recall}")

    return tp_x_coords, tp_y_coords, f1, precision, recall


def erode_mask(mask):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    mask = cv2.erode(mask, kernel, iterations=1)
    return mask

In [None]:
model = get_efficientunet_b0_MBConv(pretrained=False)
val_fold = 1

checkpoint_path = (
    "/home/u1910100/Documents/Monkey/runs/base/fold_1/epoch_100.pth"
)
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
model.to("cuda")

In [24]:
IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
    save_dir=f"./",
)

# Get dataloaders for task
train_loader, val_loader = get_dataloaders(
    IOconfig,
    val_fold=1,
    task=1,
    batch_size=1,
    disk_radius=13,
    do_augmentation=False,
)

[3951, 7731]


In [31]:
counter = 0
sum_F1 = []
sum_precison = []
sum_recall = []
for data in tqdm(val_loader, leave=False):
    file_ids = data["id"]

    annotation_dict = load_json_annotation(file_ids[0], IOconfig)
    images = data["image"].cuda().float()
    gt_masks = data["mask"].cuda().float()

    image_np = images[0]
    image_np = image_np.cpu().numpy()
    gt_mask_np = gt_masks[0]
    gt_mask_np = gt_mask_np.cpu().numpy()

    gt_mask_np = erode_mask(gt_mask_np)

    image_np = np.moveaxis(image_np, 0, 2)
    image_np = imagenet_denormalise(image_np)
    # fig, axs = plt.subplots(1, 4, figsize=(10,10))
    # axs[0].imshow(image_np)
    # axs[0].title.set_text("Image")
    # axs[1].imshow(gt_mask_np[0], cmap='gray')
    # axs[1].title.set_text("Ground Truth")
    # plt.show()

    with torch.no_grad():
        out = model(images)
        out = torch.sigmoid(out)

    out = out.cpu().detach().numpy()[0][0]
    out_mask = np.where(out >= 0.9, 1, 0).astype(np.uint8)

    pred_centers = get_cell_centers(out_mask)
    true_centers = get_cell_centers(gt_mask_np[0])
    xs, ys, f1, precision, recall = evaluate_cell_predictions(
        true_centers, pred_centers
    )

    # axs[2].imshow(out_mask, cmap="gray")
    # axs[2].title.set_text("Prediction")

    # axs[3].imshow(image_np)

    # axs[0].scatter(ys, xs, alpha=0.7)
    # axs[1].scatter(ys, xs, alpha=0.5)
    # axs[2].scatter(ys, xs, alpha=0.5)

    # for ax in fig.axes:
    #     ax.axis("off")
    # plt.show()

    sum_F1.append(f1)
    sum_precison.append(precision)
    sum_recall.append(recall)

    # counter +=1
    # if counter > 20:
    #     break

print("Avg F1 ", np.mean(sum_F1))
print("Avg Precision ", np.mean(sum_precison))
print("Avg Recall ", np.mean(sum_recall))

                                                   

Avg F1  0.38525558154748907
Avg Precision  0.3569386047369017
Avg Recall  0.6754249875600906


