In [1]:
from data import get_cell_dataloaders
import torch
from torch.utils.data import random_split, ConcatDataset
import segmentation_models_pytorch as smp
import numpy as np
import os
from train import train_model
from loss_functions import BCE_Dice_Loss
import math
from tqdm import tqdm
import matplotlib.pyplot as plt

In [17]:
# Root directory holding the per-fold patch data handed to the dataloader factory.
FOLD_DIR = "/home/u1910100/Documents/Tiger_Data/cell_detection/dilation/patches/128/"

# Fold that is held out and used for evaluation in this notebook.
test_fold = 5
print(f"Testing using fold {test_fold}")

# batch_size=1 so that every batch yielded by the loader is a single sample.
# Only `test_loader` is consumed in this cell; `train_loader` is kept because the
# factory returns both.
train_loader, test_loader = get_cell_dataloaders(
    FOLD_DIR,
    fold_num=test_fold,
    batch_size=1,
    phase="Test",
)

# Sanity check: number of samples the test sampler will draw.
print(len(test_loader.sampler))

Testing using fold 5
Train files: 17004
Test files: 3692
3692


In [11]:
def imagenet_denormalise(img):
    """Undo per-channel mean/std normalisation using the ImageNet statistics.

    Args:
        img: Array-like image whose trailing axis broadcasts against a
            length-3 vector (i.e. channels-last with 3 channels).

    Returns:
        ``img * std + mean`` computed channel-wise, as a new array; the input
        is not modified.
    """
    channel_mean = np.array([0.485, 0.456, 0.406])
    channel_std = np.array([0.229, 0.224, 0.225])
    return (img * channel_std) + channel_mean

In [18]:
# Rebuild the U-Net with the same architecture arguments as the checkpoint so
# that its state dict can be loaded into it.
model = smp.Unet(
    encoder_name="efficientnet-b0",  # encoder backbone
    encoder_weights=None,  # no pre-trained encoder weights; everything comes from the checkpoint below
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=1,  # model output channels (number of classes in your dataset)
)

# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from (avoids device-mismatch errors and a second copy of the weights on the
# GPU), and restrict unpickling to tensors/plain containers since only a
# state dict is expected here.
model.load_state_dict(
    torch.load(
        "/home/u1910100/GitHub/TIAger-Torch/runs/cell/fold_5/model_22.pth",
        map_location="cpu",
        weights_only=True,
    )
)

# Move the populated model to the GPU and switch to inference behaviour
# (eval mode) before running the evaluation loop.
model.to("cuda")
model.eval()

Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding

In [19]:
# Evaluate the loaded model on the test loader: for every probability
# threshold, compute IoU (Jaccard) and F1 (Dice) per batch and report the mean
# over all batches.
thresholds = [0.5]

for threshold in thresholds:
    # Per-batch scores collected for this threshold (lists of floats that are
    # averaged below; the loader above is built with batch_size=1).
    sum_IOU = []
    sum_F1 = []
    print(f"threshold = {threshold}")
    for batch in tqdm(test_loader):
        imgs = batch["img"].to("cuda").float()
        masks = batch["mask"].to("cuda").long()

        # Forward pass without autograd bookkeeping; sigmoid turns the raw
        # model output into probabilities that `threshold` is applied to.
        with torch.no_grad():
            out = torch.sigmoid(model(imgs))

        tp, fp, fn, tn = smp.metrics.get_stats(
            out, masks, mode="binary", threshold=threshold
        )
        iou_score = smp.metrics.iou_score(
            tp, fp, fn, tn, reduction="micro", zero_division=0
        ).item()
        f1_score = smp.metrics.f1_score(
            tp, fp, fn, tn, reduction="micro", zero_division=0
        ).item()

        # A NaN score is counted as a perfect score of 1.
        # NOTE(review): presumably this covers batches where both prediction
        # and ground truth are empty (0/0); whether NaN can still occur with
        # zero_division=0 depends on the installed smp version -- confirm.
        if math.isnan(iou_score):
            iou_score = 1
        if math.isnan(f1_score):
            f1_score = 1

        sum_IOU.append(iou_score)
        sum_F1.append(f1_score)

    print("-------------")
    print("Avg Jaccard ", np.mean(sum_IOU))
    print("Avg Dice ", np.mean(sum_F1))

threshold = 0.5


100%|██████████| 3692/3692 [00:18<00:00, 200.52it/s]

-------------
Avg Jaccard  0.5073793981009126
Avg Dice  0.5426765745500646



