In [None]:
import os
# Make the project checkout the working directory for the rest of the notebook
# (presumably so the project-local packages imported below resolve — confirm).
project_root = '/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy'
os.chdir(project_root)

In [None]:
import torch
from segmentation.dataset import DRSegmentationDataset
from segmentation.unet import UNet
import matplotlib.pyplot as plt
from grading_model.grading_model import GradingModel
from matplotlib.lines import Line2D
import numpy as np

In [None]:
# Use the GPU when PyTorch reports one as available; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Dataset built from the `test_set` directory of the processed segmentation data.
test_set_dir = "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/datasets/processed_segmentation_dataset/test_set"
test_dataset = DRSegmentationDataset(test_set_dir)

In [None]:
# Serve the test set one sample per batch; the remaining DataLoader options keep
# their library defaults.
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1)

In [None]:
def _restore_weights(model, checkpoint_path, *, weights_only):
    """Move ``model`` to ``device`` and fill it with the state dict stored at ``checkpoint_path``.

    Returns whatever ``load_state_dict`` returns (the report of missing /
    unexpected keys), so the last call in the cell is still displayed.
    """
    model.to(device)
    state_dict = torch.load(checkpoint_path, weights_only=weights_only, map_location=device)
    return model.load_state_dict(state_dict)


# Segmentation network, constructed as UNet(3, 5)
# (presumably 3 input channels and 5 output channels — confirm against segmentation/unet.py).
segmentation_model = UNet(3, 5)
_restore_weights(
    segmentation_model,
    "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/models/segmentation/pretrained/unet_pretrained.pth",
    weights_only=True,
)

# Grading network restored from the "with masks" checkpoint.
grading_model = GradingModel()
_restore_weights(
    grading_model,
    "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/models/classification/grading_model_with_masks_23-07-25_14-08.pth",
    weights_only=True,
)

# Second grading network restored from the "pretrained" checkpoint.
# NOTE(review): this is the only checkpoint loaded with weights_only=False, which
# lets torch.load unpickle arbitrary Python objects. Only safe for trusted files;
# switch to weights_only=True if the checkpoint loads that way.
grading_model_pretrained = GradingModel()
_restore_weights(
    grading_model_pretrained,
    "/users/scratch1/s189737/collaborative-learning-diabetic-retinopathy/models/classification/grading_model_pretrained.pth",
    weights_only=False,
)


In [None]:
# Score the segmentation network on the test set with binary cross-entropy.
segmentation_model.eval()  # evaluation mode for the forward passes below
loss = torch.nn.BCELoss()
test_loss = 0

with torch.no_grad():  # no gradients are needed when only measuring the loss
    for images, targets in test_dataloader:
        prediction = segmentation_model(images.to(device))
        batch_loss = loss(prediction, targets.to(device))
        test_loss = test_loss + batch_loss.item()

# Average of the per-batch losses (the loader above uses batch_size=1).
mean_test_loss = test_loss / len(test_dataloader)
print("Mean test loss:", mean_test_loss)

In [None]:
lesion_index = 1            # channel shown from the target masks, attention maps and predicted masks
num_examples = 5            # number of test samples to visualise
attention_threshold = 0.6   # cut-off applied to the min-max normalised attention map


def _min_max_scale(values):
    """Linearly rescale ``values`` to [0, 1]; a constant array maps to all zeros."""
    lowest = values.min()
    span = values.max() - lowest
    if span == 0:
        # A constant input would otherwise give 0/0 -> NaN and garbage after the uint8 cast.
        return np.zeros_like(values)
    return (values - lowest) / span


# The networks are only used for inference here: switch all of them to evaluation
# mode so that mode-dependent layers (e.g. dropout / batch norm, if present) behave
# deterministically and no running statistics are updated by the forward passes.
segmentation_model.eval()
grading_model.eval()
grading_model_pretrained.eval()

for test_batch_id, (input_batch, target_batch) in enumerate(test_dataloader):
    if test_batch_id >= num_examples:
        break

    input_batch = input_batch.to(device)
    target_batch = target_batch.to(device)

    # No gradients are needed for plotting; skipping autograd saves memory.
    with torch.no_grad():
        masks = segmentation_model(input_batch)
        pretrained_logits, pretrained_f_low, pretrained_f_high, _ = grading_model_pretrained(input_batch)
        # The second grading model consumes the predicted masks plus the two feature
        # outputs of the pretrained grading model; only its attention maps are plotted.
        logits, _, _, attention_maps = grading_model(input_batch, masks, pretrained_f_low, pretrained_f_high)

    # Input image: channels-first -> channels-last, then stretched to 0..255 for display.
    reference_image = input_batch[0].cpu().numpy().transpose(1, 2, 0)
    reference_image = (_min_max_scale(reference_image) * 255).astype('uint8')

    # Ground-truth mask for the selected channel.
    target_image = target_batch[0, lesion_index].cpu().numpy()

    # Attention map: normalise, binarise at the threshold, scale to 0/255.
    attention_map = _min_max_scale(attention_maps[0, lesion_index].cpu().numpy())
    attention_map = (np.where(attention_map > attention_threshold, 1, 0) * 255).astype('uint8')

    # Predicted mask for the selected channel (assumes values in [0, 1] — TODO confirm).
    mask = (masks[0][lesion_index].cpu().numpy() * 255).astype('uint8')

    fig = plt.figure(figsize=(20, 10))
    fig.subplots_adjust(bottom=0.1, right=0.8, top=2)

    # One row of four panels; titles are user-facing strings and stay as they were.
    panels = [
        (reference_image, "Wejściowy obraz"),
        (target_image, "Maska poprawna"),
        (attention_map, "Mapa atencji"),
        (mask, "Maska generatora"),
    ]
    positions = []
    for column, (image, title) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 4, column)
        ax.imshow(image, cmap='gray')
        ax.set_title(title)
        ax.axis('off')
        positions.append(ax.get_position())

    # White vertical separator on the right edge of every panel except the last.
    for pos in positions[:-1]:
        fig.add_artist(Line2D([pos.x1, pos.x1], [pos.y0, pos.y1], color='white', linewidth=2))

    # Render now and release the figure so figures do not accumulate across iterations.
    plt.show()
    plt.close(fig)