In [1]:
from data import TiagerDataset, TrainDataset, TestDataset
from torch.utils.data import DataLoader
import torch
from torch.utils.data import random_split, ConcatDataset
import segmentation_models_pytorch as smp
import numpy as np
import os
from train import train_model
from loss_functions import BCE_Dice_Loss
import math
from tqdm import tqdm

import matplotlib.pyplot as plt
import cv2

In [2]:
BATCH_SIZE = 1


FOLD_DIR = "/media/u1910100/Extreme SSD/data/tiger/tissue_segmentation/patches/512"
test_fold = 1
fold_dir = os.path.join(FOLD_DIR, f"fold_{test_fold}")
dataset = TiagerDataset(fold_dir)
# datasets = []
# for i in range(0, 5):
#     fold_dir = os.path.join(FOLD_DIR, f"fold_{i+1}")
#     dataset = TiagerDataset(fold_dir)
#     datasets.append(dataset)


test_dataset = TestDataset(dataset)
validation_loader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=True)

print(len(test_dataset))

3525


In [3]:
def imagenet_denormalise(img):
    mean = np.array([0.485, 0.456, 0.406])
    std = np.array([0.229, 0.224, 0.225])
    img_normal = (img * std) + mean
    return img_normal


def mask_to_rgb(mask, classes="all"):
    rgb_mask = np.zeros(shape=(mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
    if classes == "all":
        rgb_mask[mask == 1] = [255, 0, 0]
        rgb_mask[mask == 2] = [0, 255, 0]
    elif classes == "stroma":
        rgb_mask[mask == 2] = [0, 255, 0]
    elif classes == "tumor":
        rgb_mask[mask == 1] = [255, 0, 0]
    return rgb_mask


def smooth_mask(mask):
    kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (8, 8))
    closing = cv2.morphologyEx(mask, cv2.MORPH_CLOSE, kernel)
    opening = cv2.morphologyEx(closing, cv2.MORPH_OPEN, kernel)
    return opening

In [None]:
# batch = next(iter(validation_loader))
# img, mask = batch["img"][0], batch["mask"][0][0]

# fig, axes = plt.subplots(1, 2, figsize=(10, 10))

# img = batch["img"][0].numpy(force=True)
# img = np.moveaxis(img, 0, 2)
# img = imagenet_denormalise(img)

# axes[0].imshow(img)

# rgb_mask = np.zeros(shape=(mask.shape[0], mask.shape[1], 3), dtype=np.uint8)
# rgb_mask[mask == 1] = [255, 0, 0]
# rgb_mask[mask == 2] = [0, 255, 0]

# axes[1].imshow(rgb_mask)
# plt.show()

In [4]:
model = smp.Unet(
    encoder_name="efficientnet-b0",  # choose encoder, e.g. mobilenet_v2 or efficientnet-b7
    encoder_weights=None,  # use `imagenet` pre-trained weights for encoder initialization
    in_channels=3,  # model input channels (1 for gray-scale images, 3 for RGB, etc.)
    classes=3,  # model output channels (number of classes in your dataset)
)

model.load_state_dict(
    torch.load("/home/u1910100/GitHub/TIAger-Torch/runs/tissue/weights/tissue_1.pth")
)

model.to("cuda")
model.eval()

Unet(
  (encoder): EfficientNetEncoder(
    (_conv_stem): Conv2dStaticSamePadding(
      3, 32, kernel_size=(3, 3), stride=(2, 2), bias=False
      (static_padding): ZeroPad2d((0, 1, 0, 1))
    )
    (_bn0): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
    (_blocks): ModuleList(
      (0): MBConvBlock(
        (_depthwise_conv): Conv2dStaticSamePadding(
          32, 32, kernel_size=(3, 3), stride=[1, 1], groups=32, bias=False
          (static_padding): ZeroPad2d((1, 1, 1, 1))
        )
        (_bn1): BatchNorm2d(32, eps=0.001, momentum=0.010000000000000009, affine=True, track_running_stats=True)
        (_se_reduce): Conv2dStaticSamePadding(
          32, 8, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_se_expand): Conv2dStaticSamePadding(
          8, 32, kernel_size=(1, 1), stride=(1, 1)
          (static_padding): Identity()
        )
        (_project_conv): Conv2dStaticSamePadding

In [13]:
sum_IOU = []
sum_F1 = []

# jaccard_loss_fn = smp.losses.JaccardLoss(mode="multiclass", from_logits=False)
# dice_loss_fn = smp.losses.DiceLoss(mode="multiclass", from_logits=False)
# ce_loss_fn = smp.losses.SoftCrossEntropyLoss(smooth_factor=0.1)
threshold = 0.9

# for i in range(0, 10):
for batch in tqdm(validation_loader):
    batch = next(iter(validation_loader))
    imgs, masks = batch["img"], batch["mask"]

    # img = batch["img"][0].numpy(force=True)
    # img = np.moveaxis(img, 0, 2)
    # img_nromal = imagenet_denormalise(img)
    # rgb_mask = mask_to_rgb(masks[0][0])
    # fig, axs = plt.subplots(1, 3, figsize=(10, 10))
    # axs[0].imshow(img_nromal)
    # axs[0].title.set_text("Image")
    # axs[1].imshow(rgb_mask)
    # axs[1].title.set_text("Ground Truth")

    imgs = imgs.to("cuda").float()
    with torch.no_grad():
        out = model(imgs)
        out = torch.nn.functional.softmax(out, dim=1)

    out = out.cpu().detach().numpy()
    stroma_mask = out[0][2]
    stroma_mask[stroma_mask < threshold] = 0
    stroma_mask[stroma_mask >= threshold] = 1
    # stroma_mask = smooth_mask(stroma_mask)
    tumor_mask = out[0][1]
    tumor_mask[tumor_mask < threshold] = 0
    tumor_mask[tumor_mask >= threshold] = 1
    # tumor_mask = smooth_mask(tumor_mask)

    combined_mask = np.zeros((512, 512), dtype=np.uint8)
    combined_mask[np.where(stroma_mask == 1)] = 2
    combined_mask[np.where(tumor_mask == 1)] = 1
    combined_mask = smooth_mask(combined_mask)
    # rgb_labels = mask_to_rgb(combined_mask)

    combined_mask = combined_mask[
        np.newaxis,
        np.newaxis,
        :,
        :,
    ]
    combined_mask = torch.from_numpy(combined_mask)

    tp, fp, fn, tn = smp.metrics.get_stats(
        combined_mask, masks, mode="multiclass", num_classes=3
    )
    iou_score = smp.metrics.iou_score(
        tp, fp, fn, tn, reduction="macro", zero_division=1
    )
    f1_score = smp.metrics.f1_score(tp, fp, fn, tn, reduction="macro", zero_division=1)

    iou_score = iou_score.cpu().detach().item()
    f1_score = f1_score.cpu().detach().item()

    if math.isnan(iou_score):
        iou_score = 1
    if math.isnan(f1_score):
        f1_score = 1

    # axs[2].imshow(rgb_labels)
    # axs[2].title.set_text("Prediction")

    # for ax in fig.axes:
    #     ax.axis("off")
    # plt.show()

    # print(f"IOU (Jaccard): {iou_score}")
    # print(f"F1 (Dice): {f1_score}")
    # print(f"Jaccard loss {jaccard_loss}")
    # print(f"Dice loss {dice_loss}")
    # print(f"CE loss {ce_loss}")

    sum_IOU.append(iou_score)
    sum_F1.append(f1_score)


# sum_IOU = np.array(sum_IOU)
# sum_F1 = np.array(sum_F1)
print("-------------")
print("Avg Jaccard ", np.mean(sum_IOU), np.std(sum_IOU))
print("Avg Dice ", np.mean(sum_F1), np.std(sum_F1))

100%|██████████| 3525/3525 [03:05<00:00, 18.96it/s]

-------------
Avg Jaccard  0.682157618155082 0.1951810801760356
Avg Dice  0.742592840441998 0.19388481980979286





In [None]:
# Evaluate stroma dice

sum_IOU = []
sum_F1 = []
# thresholds = [0.3, 0.5, 0.7, 0.9]
thresholds = [0.9]

for threshold in thresholds:
    print(f"threshold = {threshold}")
    for i in range(0, 10):
        batch = next(iter(validation_loader))
        # for batch in tqdm(validation_loader):
        imgs, masks = batch["img"], batch["mask"]

        img = batch["img"][0].numpy(force=True)
        img = np.moveaxis(img, 0, 2)
        img_nromal = imagenet_denormalise(img)
        rgb_mask = mask_to_rgb(masks[0][0], classes="stroma")
        fig, axs = plt.subplots(1, 3, figsize=(10, 10))
        axs[0].imshow(img_nromal)
        axs[0].title.set_text("Image")
        axs[1].imshow(rgb_mask)
        axs[1].title.set_text("Ground Truth")

        imgs = imgs.to("cuda").float()
        with torch.no_grad():
            out = model(imgs)
            out = torch.nn.functional.softmax(out, dim=1)

        out = out.cpu().detach().numpy()
        stroma_mask = out[0][2]
        stroma_mask[stroma_mask < threshold] = 0
        stroma_mask[stroma_mask >= threshold] = 1
        stroma_mask = smooth_mask(stroma_mask)

        axs[2].imshow(stroma_mask)
        axs[2].title.set_text("Prediction")
        for ax in fig.axes:
            ax.axis("off")
        plt.show()

        stroma_mask = stroma_mask[
            np.newaxis,
            np.newaxis,
            :,
            :,
        ]
        stroma_mask = stroma_mask.astype(np.uint8)
        stroma_mask = torch.from_numpy(stroma_mask)
        masks[masks != 2] = 0
        masks[masks == 2] = 1
        masks = masks.cpu().detach().numpy()
        masks[0][0] = smooth_mask(masks[0][0])
        masks = torch.from_numpy(masks)

        tp, fp, fn, tn = smp.metrics.get_stats(stroma_mask, masks, mode="binary")
        iou_score = smp.metrics.iou_score(
            tp, fp, fn, tn, reduction="macro", zero_division=1
        )
        f1_score = smp.metrics.f1_score(
            tp, fp, fn, tn, reduction="macro", zero_division=1
        )

        iou_score = iou_score.cpu().detach().item()
        f1_score = f1_score.cpu().detach().item()

        if math.isnan(iou_score):
            iou_score = 1
        if math.isnan(f1_score):
            f1_score = 1

        print(f"IOU (Jaccard): {iou_score}")
        print(f"F1 (Dice): {f1_score}")

        sum_IOU.append(iou_score)
        sum_F1.append(f1_score)

    # sum_IOU = np.array(sum_IOU)
    # sum_F1 = np.array(sum_F1)
    print("-------------")
    print("Avg Jaccard ", np.mean(sum_IOU), np.std(sum_IOU))
    print("Avg Dice ", np.mean(sum_F1), np.std(sum_F1))

In [None]:
# Evaluate tumor dice

sum_IOU = []
sum_F1 = []
# thresholds = [0.3, 0.5, 0.7, 0.9]
thresholds = [0.9]


for threshold in thresholds:
    print(f"threshold = {threshold}")

    for i in range(0, 1):
        batch = next(iter(validation_loader))
        # for batch in tqdm(validation_loader):
        imgs, masks = batch["img"], batch["mask"]

        masks[masks != 1] = 0
        masks[masks == 1] = 1
        masks = masks.cpu().detach().numpy()

        masks[0][0] = smooth_mask(masks[0][0])
        masks = torch.from_numpy(masks)

        img = batch["img"][0].numpy(force=True)
        img = np.moveaxis(img, 0, 2)
        img_nromal = imagenet_denormalise(img)
        rgb_mask = mask_to_rgb(masks[0][0], classes="tumor")
        fig, axs = plt.subplots(1, 3, figsize=(10, 10))
        axs[0].imshow(img_nromal)
        axs[0].title.set_text("Image")
        axs[1].imshow(rgb_mask)
        axs[1].title.set_text("Ground Truth")

        imgs = imgs.to("cuda").float()

        with torch.no_grad():
            out = model(imgs)
            out = torch.nn.functional.softmax(out, dim=1)

        out = out.cpu().detach().numpy()
        tumor_mask = out[0][1]
        tumor_mask[tumor_mask < threshold] = 0
        tumor_mask[tumor_mask >= threshold] = 1
        tumor_mask = smooth_mask(tumor_mask)

        axs[2].imshow(tumor_mask)
        axs[2].title.set_text("Prediction")
        for ax in fig.axes:
            ax.axis("off")
        plt.show()

        tumor_mask = tumor_mask[
            np.newaxis,
            np.newaxis,
            :,
            :,
        ]
        tumor_mask = tumor_mask.astype(np.uint8)
        tumor_mask = torch.from_numpy(tumor_mask)

        # print(np.count_nonzero(tumor_mask))
        # print(torch.count_nonzero(masks))
        tp, fp, fn, tn = smp.metrics.get_stats(tumor_mask, masks, mode="binary")
        iou_score = smp.metrics.iou_score(
            tp, fp, fn, tn, reduction="macro", zero_division=1
        )
        f1_score = smp.metrics.f1_score(
            tp, fp, fn, tn, reduction="macro", zero_division=1
        )

        iou_score = iou_score.cpu().detach().item()
        f1_score = f1_score.cpu().detach().item()

        if math.isnan(iou_score):
            iou_score = 1
        if math.isnan(f1_score):
            f1_score = 1

        print(f"IOU (Jaccard): {iou_score}")
        print(f"F1 (Dice): {f1_score}")

        sum_IOU.append(iou_score)
        sum_F1.append(f1_score)

    # sum_IOU = np.array(sum_IOU)
    # sum_F1 = np.array(sum_F1)
    print("-------------")
    print("Avg Jaccard ", np.mean(sum_IOU), np.std(sum_IOU))
    print("Avg Dice ", np.mean(sum_F1), np.std(sum_F1))