In [71]:
import numpy as np
import matplotlib.pyplot as plt
from monkey.model.efficientunetb0.architecture import (
    get_efficientunet_b0_MBConv,
)
import skimage
import cv2
import math
import torch
from monkey.config import TrainingIOConfig
from monkey.data.dataset import get_dataloaders
from monkey.data.data_utils import (
    imagenet_denormalise,
    load_json_annotation,
)
from tqdm import tqdm
import segmentation_models_pytorch as smp
import skimage
from evaluation.evaluate import match_coordinates

In [73]:
def get_cell_centers(cell_mask):
    mask_label = skimage.measure.label(cell_mask)
    stats = skimage.measure.regionprops(mask_label)
    centers = []
    for region in stats:
        centroid = region["centroid"]
        centers.append(centroid)
    return centers


def point_to_box(x, y, size):
    """Convert centerpoint to bounding box of fixed size"""
    return np.array([x - size, y - size, x + size, y + size])


def check_point_in_box(x, y, box):
    cond_1 = x >= box[0] and x <= box[2]
    cond_2 = y >= box[1] and y <= box[3]
    return cond_1 and cond_2


def check_point_in_circle(x, y, center_x, center_y, radius):
    dist = math.sqrt((center_x - x) ** 2 + (center_y - y) ** 2)
    # print(dist)
    return dist <= radius


SPACING_LEVEL0 = 0.24199951445730394


def evaluate_cell_predictions(gt_centers, pred_centers):
    tp_x_coords = []
    tp_y_coords = []

    # print(f"Total gt cells: {len(gt_centers)}")
    # print(f"Total pred cells: {len(pred_centers)}")
    tp = 0
    fp = 0
    fn = 0

    result_prob = [1.0 for i in range(len(pred_centers))]
    (
        tp,
        fn,
        fp,
        tp_probs,
        fp_probs,
    ) = match_coordinates(
        gt_centers,
        pred_centers,
        result_prob,
        int(7.5 / SPACING_LEVEL0),
    )

    try:
        precision = tp / (tp + fp)
    except ZeroDivisionError:
        precision = 1
    try:
        recall = tp / (tp + fn)
    except ZeroDivisionError:
        recall = 1
    try:
        f1 = (2 * precision * recall) / (precision + recall)
    except ZeroDivisionError:
        f1 = 1

    # print(f"True Positives: {tp}")
    # print(f"False Positives: {fp}")
    # print(f"False Negatives: {fn}")
    # print(f"f1 = {f1}")
    # print(f"precision = {precision}")
    # print(f"recall = {recall}")

    return tp_x_coords, tp_y_coords, f1, precision, recall


def erode_mask(mask):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (3, 3))
    mask = cv2.erode(mask, kernel, iterations=1)
    return mask

In [83]:
model = get_efficientunet_b0_MBConv(pretrained=False)
# model = smp.Unet(
#     encoder_name="mit_b0",
#     encoder_weights=None,
#     decoder_attention_type="scse",
#     in_channels=3,
#     classes=1,
# )

val_fold = 4

checkpoint_path = f"/home/u1910100/Documents/Monkey/runs/efficientunetb0/fold_{val_fold}/epoch_100.pth"
checkpoint = torch.load(checkpoint_path)
model.load_state_dict(checkpoint["model"])
model.to("cuda")

EfficientUnet_MBConv(
  (encoder): Sequential(
    (0): Conv2dNormActivation(
      (0): Conv2d(3, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), bias=False)
      (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): SiLU(inplace=True)
    )
    (1): Sequential(
      (0): MBConv(
        (block): Sequential(
          (0): Conv2dNormActivation(
            (0): Conv2d(32, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), groups=32, bias=False)
            (1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
            (2): SiLU(inplace=True)
          )
          (1): SqueezeExcitation(
            (avgpool): AdaptiveAvgPool2d(output_size=1)
            (fc1): Conv2d(32, 8, kernel_size=(1, 1), stride=(1, 1))
            (fc2): Conv2d(8, 32, kernel_size=(1, 1), stride=(1, 1))
            (activation): SiLU(inplace=True)
            (scale_activation): Sigmoid()
          )
          (2): Conv2dNorm

In [84]:
IOconfig = TrainingIOConfig(
    dataset_dir="/home/u1910100/Documents/Monkey/patches_256",
    save_dir=f"./",
)

# Get dataloaders for task
train_loader, val_loader = get_dataloaders(
    IOconfig,
    val_fold=val_fold,
    task=1,
    batch_size=1,
    disk_radius=11,
    do_augmentation=False,
)

[4512, 10745]
train patches: 15257
test patches: 4898


In [85]:
# thresholds = [0.3, 0.5, 0.7, 0.9]
thresholds = [0.9]
best_thresh = thresholds[0]
best_F1 = 0

for thresh in thresholds:
    print(f"threshold {thresh}")
    counter = 0
    sum_F1 = []
    sum_precison = []
    sum_recall = []
    for data in tqdm(val_loader, leave=False):
        file_ids = data["id"]

        annotation_dict = load_json_annotation(file_ids[0], IOconfig)
        images = data["image"].cuda().float()
        gt_masks = data["mask"].cuda().float()

        image_np = images[0]
        image_np = image_np.cpu().numpy()
        gt_mask_np = gt_masks[0]
        gt_mask_np = gt_mask_np.cpu().numpy()

        image_np = np.moveaxis(image_np, 0, 2)
        image_np = imagenet_denormalise(image_np)

        # fig, axs = plt.subplots(1, 4, figsize=(10,10))
        # axs[0].imshow(image_np)
        # axs[0].title.set_text("Image")
        # axs[1].imshow(gt_mask_np[0], cmap='gray')
        # axs[1].title.set_text("Ground Truth")

        with torch.no_grad():
            out = model(images)
            out = torch.sigmoid(out)

        out = out.cpu().detach().numpy()[0][0]

        # axs[3].imshow(out, cmap='jet')

        out_mask = np.where(out >= thresh, 1, 0).astype(np.uint8)
        # out_mask = erode_mask(out_mask)

        out_mask = skimage.morphology.remove_small_objects(
            ar=out_mask, min_size=32
        )

        pred_centers = get_cell_centers(out_mask)
        true_centers = get_cell_centers(gt_mask_np[0])
        xs, ys, f1, precision, recall = evaluate_cell_predictions(
            true_centers, pred_centers
        )

        # axs[2].imshow(out_mask, cmap="gray")
        # axs[2].title.set_text("Prediction")

        # axs[0].scatter(ys, xs, alpha=0.7)
        # axs[1].scatter(ys, xs, alpha=0.5)
        # axs[2].scatter(ys, xs, alpha=0.5)

        # for ax in fig.axes:
        #     ax.axis("off")
        # plt.show()

        sum_F1.append(f1)
        sum_precison.append(precision)
        sum_recall.append(recall)

        # counter +=1
        # if counter > 20:
        #     break

    print("Avg F1 ", np.mean(sum_F1))
    print("Avg Precision ", np.mean(sum_precison))
    print("Avg Recall ", np.mean(sum_recall))

    if np.mean(sum_F1) > best_F1:
        best_F1 = np.mean(sum_F1)
        best_thresh = thresh

print(f"best threshold: {best_thresh}")
print(f"best F1: {best_F1}")

threshold 0.9


  out_mask = skimage.morphology.remove_small_objects(

                                                   

Avg F1  0.8012652902950135
Avg Precision  0.8355989993061247
Avg Recall  0.7644101807592207
best threshold: 0.9
best F1: 0.8012652902950135


