In [86]:
from data import TiagerDataset, TrainDataset, TestDataset
from torch.utils.data import DataLoader
import torch
from torch.utils.data import random_split, ConcatDataset
import segmentation_models_pytorch as smp
import numpy as np
import os
from train import train_model
from loss_functions import BCE_Dice_Loss
import math
from tqdm import tqdm

In [103]:
BATCH_SIZE = 1


FOLD_DIR = "/home/u1910100/Documents/Tiger_Data/tissue_segmentation/patches/512/"
test_fold = 5
fold_dir = os.path.join(FOLD_DIR, f"fold_{test_fold}")
dataset = TiagerDataset(fold_dir)
# datasets = []
# for i in range(0, 5):
#     fold_dir = os.path.join(FOLD_DIR, f"fold_{i+1}")
#     dataset = TiagerDataset(fold_dir)
#     datasets.append(dataset)


test_dataset = TestDataset(dataset)
validation_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(len(test_dataset))

3152


In [None]:
def imagenet_denormalise(img):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_normal = (img * std) + mean
    return img_normal


def mask_to_rgb(mask):
    rgb_mask = np.zeros(shape=(mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    rgb_mask[mask == 1] = [255, 0, 0]
    rgb_mask[mask == 2] = [0, 255, 0]
    return rgb_mask

In [None]:
batch = next(iter(validation_loader))
img, mask = batch["img"][0], batch["mask"][0][0]
# print(img.size())
# print(mask.size())

import matplotlib.pyplot as plt

img = batch["img"][0].numpy(force=True)
img = np.moveaxis(img, 0, 2)
img = imagenet_denormalise(img)

plt.imshow(img)
plt.show()

rgb_mask = np.zeros(shape=(mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
rgb_mask[mask == 1] = [255, 0, 0]
rgb_mask[mask == 2] = [0, 255, 0]

plt.imshow(rgb_mask)
plt.show()

In [104]:
model = smp.Unet(
    encoder_name="efficientnet-b0",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=None,  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,  # model output channels (number of classes in your dataset)
)

model.load_state_dict(
    torch.load("/home/u1910100/GitHub/TIAger-Torch/runs/tissue/fold_5/model_58.pth")
)

model.to("cuda")
model.eval()

Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding

In [105]:
sum_IOU = []
sum_F1 = []

for batch in tqdm(validation_loader):
    # batch = next(iter(validation_loader))
    imgs, masks = batch["img"], batch["mask"]

    # img = batch['img'][0].numpy(force=True)
    # img = np.moveaxis(img, 0, 2)
    # img_nromal = imagenet_denormalise(img)

    # rgb_mask = mask_to_rgb(masks[0][0])

    # fig, axs = plt.subplots(1, 3, figsize=(10, 10))
    # plt.subplot(1,3,1)
    # axs[0].imshow(img_nromal)
    # axs[0].title.set_text("Image")

    # plt.subplot(1,3,2)
    # axs[1].imshow(rgb_mask)
    # axs[1].title.set_text("Ground Truth")

    imgs = imgs.to("cuda").float()
    masks = masks.to("cuda").long()
    with torch.no_grad():
        out = model(imgs)
        out = torch.nn.functional.softmax(out, dim=1)
        out = torch.argmax(out, dim=1, keepdim=True)
    # print(out.size())

    # out_mask = out[0][0].cpu().detach()
    # print(out_mask.size())
    # labels = out_mask.numpy()

    # rgb_labels = mask_to_rgb(labels)

    tp, fp, fn, tn = smp.metrics.get_stats(out, masks, mode="multiclass", num_classes=3)
    iou_score = smp.metrics.iou_score(
        tp, fp, fn, tn, reduction="micro", zero_division=0
    )
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="micro", zero_division=0)

    iou_score = iou_score.cpu().detach().item()
    f1_score = f1_score.cpu().detach().item()

    if math.isnan(iou_score):
        iou_score = 1
    if math.isnan(f1_score):
        f1_score = 1

    # axs[2].imshow(rgb_labels)
    # axs[2].title.set_text("Prediction")

    # for ax in fig.axes:
    #     ax.axis("off")
    # plt.show()

    # print(f"IOU (Jaccard): {iou_score}")
    # print(f"F1 (Dice): {f1_score}")

    sum_IOU.append(iou_score)
    sum_F1.append(f1_score)


# sum_IOU = np.array(sum_IOU)
# sum_F1 = np.array(sum_F1)
print("-------------")
print("Avg Jaccard ", np.mean(sum_IOU), np.std(sum_IOU))
print("Avg Dice ", np.mean(sum_F1), np.std(sum_F1))

100%|██████████| 3152/3152 [01:03<00:00, 49.57it/s]

-------------
Avg Jaccard  0.6944902769794538 0.23533202326788236
Avg Dice  0.7929415944869143 0.196055166152175



