In [22]:
import sys

sys.path.insert(0, "..")

from pathlib import Path

import matplotlib.pyplot as plt
import torch
from torch.utils.data import DataLoader
from torchvision.io.image import ImageReadMode, read_image, write_jpeg
from torchvision.transforms.functional import resize
from torchvision.transforms.v2 import Compose
from torchvision.utils import draw_keypoints, draw_segmentation_masks

from src.dataloader.transform import ImagenetNormalize, ToNormalized
from src.model.model import UNETNetwork, UNETNetworkModi
from src.utils.dirichlet_utils import (
    combined_dirichlet,
    convert_belief_mass_to_prediction,
)

In [49]:
# Load both checkpoints, move them to the GPU and switch them to inference mode.
# NOTE(review): the baseline path is a TIGER run while the proposed path is a
# UKMTILS run — confirm this pairing is intended.
MODEL_BASELINE_PATH = "/mnt/storage/projects/semi-he/data/model/20241221_185129_tiger_baseline_500_bs16_pseudo/499_tiger_baseline_500_bs16_pseudo_model.pt"
MODEL_PROPOSED_PATH = "/mnt/storage/projects/semi-he/data/model/20250109_184230_ukmtils_proposed_epoch500_bs16_pseudo/499_ukmtils_proposed_epoch500_bs16_pseudo_model.pt"

loaded_models = []
for model_class, checkpoint_path in (
    (UNETNetwork, MODEL_BASELINE_PATH),
    (UNETNetworkModi, MODEL_PROPOSED_PATH),
):
    net = model_class(number_class=3)
    net.load_state_dict(torch.load(checkpoint_path))
    # nn.Module.cuda() and .eval() both return the module itself.
    loaded_models.append(net.cuda().eval())
baseline_model, proposed_model = loaded_models
preprocessor = Compose([ToNormalized(), ImagenetNormalize()])

In [65]:
def generate_baseline(
    prediction: torch.Tensor,
    original_size,
):
    """Convert baseline model logits into a per-pixel class-index map.

    Args:
        prediction: Raw model output with classes on dim 1 (assumed
            ``(B, C, H, W)`` — TODO confirm against ``UNETNetwork``).
        original_size: Target ``(H, W)`` passed to torchvision ``resize``.

    Returns:
        ``argmax`` over dim 1 of the resized sigmoid scores, i.e. the index of
        the winning class for every pixel.
    """
    # Sigmoid is monotonic, so it does not change which class wins at a pixel;
    # the resize below therefore interpolates probabilities rather than logits.
    probabilities = torch.sigmoid(prediction)
    upsampled = resize(probabilities, original_size)
    return torch.argmax(upsampled, dim=1)


def generate_dirichlet(
    prediction: torch.Tensor,
    original_size,
):
    """Fuse the proposed model's two outputs into a per-pixel class-index map.

    Args:
        prediction: Indexable output of the proposed model; ``prediction[0]``
            and ``prediction[1]`` are the two outputs that get fused.
        original_size: Target ``(H, W)`` passed to torchvision ``resize``.

    Returns:
        ``argmax`` over dim 1 of the resized fused prediction.
    """
    # ReLU clamps both outputs to be non-negative before fusion (presumably as
    # Dirichlet evidence — see src.utils.dirichlet_utils for the exact contract).
    first_output = prediction[0].relu()
    second_output = prediction[1].relu()
    belief, uncertainty = combined_dirichlet(first_output, second_output)
    fused = convert_belief_mass_to_prediction(belief, uncertainty)
    return resize(fused, original_size).argmax(1)


def convert_to_standard_color(
    argmax_prediction: torch.Tensor,
    original_image,
    mapping=None,
    alpha=0.2,
    colors=None,
):
    """Overlay a class-index map on an image using one fixed colour per class.

    Args:
        argmax_prediction: Class-index map; only ``argmax_prediction[0]`` is
            drawn (e.g. the first batch item of an ``argmax(1)`` output, or the
            single channel of a ``(1, H, W)`` label image).
        original_image: Image handed to ``draw_segmentation_masks``; its height
            and width must match ``argmax_prediction[0]``.
        mapping: Optional sequence of class values to draw, one per colour, in
            colour order. Defaults to ``0, 1, ..., len(colors) - 1``. Use it when
            labels are encoded differently from predictions (e.g. ``[1, 2, 255]``).
        alpha: Overlay opacity forwarded to ``draw_segmentation_masks``.
        colors: One RGB tuple per class; defaults to black, red, green.

    Returns:
        The blended image produced by ``draw_segmentation_masks``.

    Raises:
        ValueError: If ``mapping`` does not have exactly one entry per colour.
    """
    if colors is None:
        colors = [(0, 0, 0), (255, 0, 0), (0, 255, 0)]
    class_values = range(len(colors)) if mapping is None else mapping
    if len(class_values) != len(colors):
        raise ValueError(
            f"mapping has {len(class_values)} entries but {len(colors)} colors were given"
        )

    class_map = argmax_prediction[0]
    masks = torch.stack([class_map == value for value in class_values])
    # In this notebook the class maps come from CUDA models while the images are
    # read on the CPU; indexing a CPU image with a CUDA mask raises, so align devices.
    masks = masks.to(original_image.device)
    return draw_segmentation_masks(
        original_image,
        masks,
        alpha=alpha,
        colors=colors,
    )


def write_images(
    images,
    directory: Path,
    index,
    suffixes=("image", "proposed", "baseline", "label"),
):
    """Write one sample's panels to ``<directory>/<index>_<suffix>.jpg``.

    Args:
        images: Tensors accepted by ``write_jpeg``, in the same order as
            ``suffixes`` (by default: image, proposed, baseline, label).
        directory: Output directory, as a ``Path`` or a path string (assumed to
            exist already — the notebook cells create it before calling this).
        index: Sample identifier used as the filename prefix.
        suffixes: Filename suffix for each entry of ``images``.

    Raises:
        ValueError: If ``images`` and ``suffixes`` differ in length. Plain
            ``zip`` would otherwise silently skip the unmatched entries.
    """
    images = list(images)
    if len(images) != len(suffixes):
        raise ValueError(
            f"expected {len(suffixes)} images ({', '.join(suffixes)}), got {len(images)}"
        )
    directory = Path(directory)
    for image, suffix in zip(images, suffixes):
        write_jpeg(image, str(directory / f"{index}_{suffix}.jpg"))

# Image Visualisation for UKMTILS dataset

In [51]:
# Load both checkpoints so this cell can run on its own.
# NOTE(review): the baseline path is a TIGER run while the proposed path is a
# UKMTILS run (the Ocelot section pairs two Ocelot runs) — confirm this is intended.
MODEL_BASELINE_PATH = "/mnt/storage/projects/semi-he/data/model/20241221_185129_tiger_baseline_500_bs16_pseudo/499_tiger_baseline_500_bs16_pseudo_model.pt"
MODEL_PROPOSED_PATH = "/mnt/storage/projects/semi-he/data/model/20250109_184230_ukmtils_proposed_epoch500_bs16_pseudo/499_ukmtils_proposed_epoch500_bs16_pseudo_model.pt"

loaded_models = []
for model_class, checkpoint_path in (
    (UNETNetwork, MODEL_BASELINE_PATH),
    (UNETNetworkModi, MODEL_PROPOSED_PATH),
):
    net = model_class(number_class=3)
    net.load_state_dict(torch.load(checkpoint_path))
    loaded_models.append(net.cuda().eval())
baseline_model, proposed_model = loaded_models
preprocessor = Compose([ToNormalized(), ImagenetNormalize()])

# Dataset definition: imagesTs/<case>_0000.png -> labelsTs/<case>.png
test_image_dir = Path("/mnt/storage/Dataset130_ukmtils/imagesTs/")
# Sorted so the first `num_samples` images are the same on every run/machine;
# raw glob order depends on the filesystem.
image_paths = sorted(test_image_dir.glob("*.png"))

DIRECTORY_NAME = "UKMTILS_SAMPLES"
num_samples = 20
directory = Path(DIRECTORY_NAME)
directory.mkdir(exist_ok=True)

for i, image_path in enumerate(image_paths[:num_samples]):
    # Strip only a trailing "_0000" from the file stem (a whole-path
    # str.replace would also eat "_0000" inside a case id) and swap the
    # images* directory for the matching labels* directory.
    stem = image_path.stem
    case_id = stem[: -len("_0000")] if stem.endswith("_0000") else stem
    label_dir = image_path.parent.with_name(
        image_path.parent.name.replace("images", "labels", 1)
    )
    label_path = label_dir / f"{case_id}{image_path.suffix}"

    # RGB mode guarantees 3 channels (draw_segmentation_masks rejects RGBA/grey).
    image = read_image(str(image_path), ImageReadMode.RGB).unsqueeze(0)
    label = read_image(str(label_path), ImageReadMode.GRAY).unsqueeze(0)

    # The preprocessor takes (image, label); only the normalised image is used.
    processed_image, _ = preprocessor(image.cuda(), label.cuda())
    processed_image = resize(processed_image, [1024, 1024])
    with torch.no_grad():
        baseline_output = baseline_model(processed_image)
        proposed_output = proposed_model(processed_image)

    original_size = image.shape[2:]
    # Bring the class maps to the CPU so they match `image` when drawn.
    baseline_classes = generate_baseline(baseline_output, original_size).cpu()
    proposed_classes = generate_dirichlet(proposed_output, original_size).cpu()

    # Order must match write_images: image, proposed, baseline, label.
    output_images = [
        image[0],
        convert_to_standard_color(proposed_classes, image[0]),
        convert_to_standard_color(baseline_classes, image[0]),
        convert_to_standard_color(label[0], image[0]),
    ]
    write_images(output_images, directory, i)

# Image Visualisation for Ocelot dataset

In [66]:
# Load the Ocelot baseline and proposed checkpoints.
MODEL_BASELINE_PATH = "/mnt/storage/projects/semi-he/data/model/20250110_012619_ocelot_baseline_epoch500_bs16_pseudo/499_ocelot_baseline_epoch500_bs16_pseudo_model.pt"
MODEL_PROPOSED_PATH = "/mnt/storage/projects/semi-he/data/model/20250110_225854_ocelot_proposed_epoch500_bs16_pseudo/499_ocelot_proposed_epoch500_bs16_pseudo_model.pt"

loaded_models = []
for model_class, checkpoint_path in (
    (UNETNetwork, MODEL_BASELINE_PATH),
    (UNETNetworkModi, MODEL_PROPOSED_PATH),
):
    net = model_class(number_class=3)
    net.load_state_dict(torch.load(checkpoint_path))
    loaded_models.append(net.cuda().eval())
baseline_model, proposed_model = loaded_models
preprocessor = Compose([ToNormalized(), ImagenetNormalize()])

# Dataset definition: images/test/tissue/<name>.jpg -> annotations/test/tissue/<name>.png
test_image_dir = Path("/mnt/storage/ocelot2023_v1.0.1/images/test/tissue/")
# Sorted so the first `num_samples` images are the same on every run/machine;
# raw glob order depends on the filesystem.
image_paths = sorted(test_image_dir.glob("*.jpg"))

DIRECTORY_NAME = "OCELOT"
num_samples = 20
overlay_alpha = 0.3
# Label values drawn with the same colours as prediction classes 0, 1, 2.
label_mapping = [1, 2, 255]
directory = Path(DIRECTORY_NAME)
directory.mkdir(exist_ok=True)

for i, image_path in enumerate(image_paths[:num_samples]):
    # Swap only the "images" path component for "annotations" and the file
    # extension for .png (substring replaces could hit other parts of the path).
    label_path = Path(
        *("annotations" if part == "images" else part for part in image_path.parts)
    ).with_suffix(".png")

    # RGB mode guarantees 3 channels (draw_segmentation_masks rejects other layouts).
    image = read_image(str(image_path), ImageReadMode.RGB).unsqueeze(0)
    label = read_image(str(label_path), ImageReadMode.GRAY).unsqueeze(0)

    # The preprocessor takes (image, label); only the normalised image is used.
    processed_image, _ = preprocessor(image.cuda(), label.cuda())
    processed_image = resize(processed_image, [1024, 1024])
    with torch.no_grad():
        baseline_output = baseline_model(processed_image)
        proposed_output = proposed_model(processed_image)

    original_size = image.shape[2:]
    # Bring the class maps to the CPU so they match `image` when drawn.
    baseline_classes = generate_baseline(baseline_output, original_size).cpu()
    proposed_classes = generate_dirichlet(proposed_output, original_size).cpu()

    # Order must match write_images: image, proposed, baseline, label.
    output_images = [
        image[0],
        convert_to_standard_color(proposed_classes, image[0], alpha=overlay_alpha),
        convert_to_standard_color(baseline_classes, image[0], alpha=overlay_alpha),
        convert_to_standard_color(
            label[0], image[0], label_mapping, alpha=overlay_alpha
        ),
    ]
    write_images(output_images, directory, i)