In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.model_selection import KFold
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryAUPRC

from dataset import DRSegmentationDataset
from discriminator import Discriminator
from unet import UNet

In [None]:
# Root directory of the preprocessed segmentation data. The default is the
# original author's location; set DR_SEGMENTATION_DATA_DIR to run elsewhere.
DATA_ROOT = os.environ.get(
    "DR_SEGMENTATION_DATA_DIR",
    "/home/wilk/diabetic_retinopathy/datasets/processed_segmentation_dataset",
)

# Train / test splits live in sibling sub-directories of DATA_ROOT.
train_dataset = DRSegmentationDataset(os.path.join(DATA_ROOT, "train_set"))
test_dataset = DRSegmentationDataset(os.path.join(DATA_ROOT, "test_set"))

In [None]:
# Evaluation loader: one sample per batch, served in dataset order
# (DataLoader's default shuffle=False).
test_dataloader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=1)

In [None]:
# Rebuild the network architecture and restore the saved segmentation weights.
# NOTE(review): UNet(3, 5) presumably means 3 input channels (RGB) and 5 output
# mask channels - confirm against unet.py.
loaded_model = UNet(3, 5)
# map_location="cpu": nothing in this notebook moves the model or inputs to a
# GPU, so force tensors onto the CPU; this also lets a checkpoint saved from a
# GPU run load on a CPU-only machine.
# weights_only=True: the file is consumed as a plain state_dict, so restrict
# unpickling to tensors and standard containers.
loaded_model.load_state_dict(
    torch.load("segmentation_generator.pth", map_location="cpu", weights_only=True)
)


In [None]:
# Mean binary cross-entropy of the restored model over the test set
# (one sample per batch, so this is the per-sample mean of BCELoss's mean).
loaded_model.eval()  # evaluation mode (affects layers such as dropout/batch-norm, if present)
loss = torch.nn.BCELoss()
test_loss = 0

n_test_batches = len(test_dataloader)
if n_test_batches == 0:
    raise ValueError("test_dataloader is empty - check the test_set directory")

with torch.no_grad():
    for test_batch in test_dataloader:
        input_tensor = test_batch[0]
        target_tensor = test_batch[1]

        # NOTE(review): BCELoss expects probabilities in [0, 1]; this assumes the
        # UNet output is already sigmoid-activated - confirm against unet.py.
        val_output = loaded_model(input_tensor)

        # BCELoss needs a floating-point target with the prediction's dtype;
        # the cast is a no-op when the dataset already yields matching floats.
        loss_value = loss(val_output, target_tensor.to(val_output.dtype))
        test_loss += loss_value.item()

mean_test_loss = test_loss / n_test_batches
print("Mean test loss:", mean_test_loss)

In [None]:
# Visual check: for every test sample from `first_sample_index` onward, show the
# input image, the ground-truth mask and the thresholded prediction for one
# output channel side by side.
first_sample_index = 6  # earlier samples are skipped
lesion_index = 4        # mask channel to display (needs >= 5 channels, cf. UNet(3, 5))
threshold = 0.5         # probability cut-off for a positive pixel

loaded_model.eval()  # keep this cell correct even if run on its own

with torch.no_grad():  # inference only - no autograd graph needed
    for test_batch_id, test_batch in enumerate(test_dataloader):
        if test_batch_id < first_sample_index:
            continue
        input_tensor = test_batch[0]
        target_tensor = test_batch[1]

        test_output = loaded_model(input_tensor) > threshold

        # Drop only the batch dimension (batch_size=1); a bare squeeze() would
        # also collapse any other size-1 dimension.
        target_tensor = target_tensor[0]
        test_output = test_output[0]

        # Panel titles: input image / ground-truth mask / predicted mask.
        panels = (
            (input_tensor[0].cpu().permute(1, 2, 0), "Wejściowy obraz"),
            (target_tensor[lesion_index].cpu(), "Poprawna maska"),
            (test_output[lesion_index].cpu(), "Maska predykcji"),
        )

        fig = plt.figure(figsize=(20, 10))
        # NOTE(review): top=2 puts the axes above the figure's top edge; this
        # presumably relies on the inline backend's tight bounding box - confirm.
        fig.subplots_adjust(bottom=0.1, right=0.8, top=2)
        for panel_position, (image, title) in enumerate(panels, start=1):
            ax = fig.add_subplot(1, 3, panel_position)
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        # Render now and release the figure so figures do not accumulate
        # across loop iterations.
        plt.show()
        plt.close(fig)