In [None]:
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
from sklearn.metrics import auc, roc_curve
from sklearn.model_selection import KFold
from torch.utils.tensorboard import SummaryWriter
from torcheval.metrics import BinaryAccuracy, BinaryAUROC, BinaryF1Score, BinaryAUPRC

from segmentation.dataset import DRSegmentationDataset
from segmentation.discriminator import Discriminator
from segmentation.unet import UNet

In [None]:
# Number of output channels requested from the UNet below (one mask channel per class).
# NOTE(review): presumably the lesion classes left once the optic disc is excluded
# from the dataset — confirm against DRSegmentationDataset.
NUM_CLASSES: int = 4

In [None]:
# Input locations. Fill these in or export the environment variables so the notebook
# can be re-run (Restart & Run All) without editing this cell. The empty-string
# defaults keep the original placeholder behaviour when the variables are unset.
test_dataset_path = os.environ.get('DR_TEST_DATASET_PATH', '')  # passed to DRSegmentationDataset
model_path = os.environ.get('DR_MODEL_PATH', '')  # UNet state_dict read with torch.load

In [None]:
# Test split; optic-disc masks are excluded (presumably so the target channels line
# up with NUM_CLASSES — confirm against DRSegmentationDataset).
test_dataset = DRSegmentationDataset(
    test_dataset_path,
    include_optic_disc=False,
)

In [None]:
# One image per batch, iterated in dataset order (shuffle=False is the DataLoader
# default, spelled out because later cells select images by batch index).
test_dataloader = torch.utils.data.DataLoader(test_dataset, batch_size=1, shuffle=False)

In [None]:
# Rebuild the architecture (positional args 3 and NUM_CLASSES; presumably input and
# output channels — confirm against UNet) and restore the trained weights on the CPU.
# weights_only=True limits unpickling to tensors and plain containers.
loaded_model = UNet(3, NUM_CLASSES)
loaded_model.load_state_dict(
    torch.load(model_path, map_location='cpu', weights_only=True)
)


In [None]:
# Mean binary cross-entropy of the trained model over the whole test split.
# NOTE(review): torch.nn.BCELoss expects probabilities in [0, 1], so this assumes
# UNet ends with a sigmoid; if it returns logits use BCEWithLogitsLoss instead.
loaded_model.eval()  # inference mode for dropout / batch-norm layers
loss = torch.nn.BCELoss()
test_loss = 0.0
n_test_samples = 0

with torch.no_grad():
    for test_batch in test_dataloader:
        input_tensor = test_batch[0]
        target_tensor = test_batch[1]

        test_output = loaded_model(input_tensor)

        # BCELoss averages within the batch; weight by batch size so a short final
        # batch does not skew the mean if batch_size is ever raised above 1.
        batch_size = input_tensor.shape[0]
        test_loss += loss(test_output, target_tensor).item() * batch_size
        n_test_samples += batch_size

if n_test_samples == 0:
    raise ValueError("Test dataloader is empty - check test_dataset_path.")

mean_test_loss = test_loss / n_test_samples
print("Mean test loss:", mean_test_loss)

In [None]:
# Collect per-pixel scores and ground-truth labels for the ROC analysis below.
# All class channels are flattened into one pool, so the curve is micro-averaged.
# NOTE(review): the first 6 test batches are skipped here and in the visualisation
# cell; the reason is not visible in this notebook — confirm it is intentional.
targets = []
outputs = []

loaded_model.eval()  # do not rely on an earlier cell having switched modes
with torch.no_grad():  # predictions only; avoid building an autograd graph per image
    for test_batch_id, test_batch in enumerate(test_dataloader):
        if test_batch_id < 6:
            continue
        input_tensor = test_batch[0]
        target_tensor = test_batch[1]

        test_output = loaded_model(input_tensor)
        outputs += test_output.flatten().cpu().tolist()
        targets += target_tensor.flatten().cpu().tolist()

# roc_curve needs discrete labels.
# NOTE(review): int() truncates toward zero, so any soft mask value in (0, 1) becomes 0.
# This assumes masks are strictly 0/1; if they can be interpolated, threshold explicitly.
targets = list(map(int, targets))

In [None]:
# Pixel-level ROC curve of the predicted scores against the ground-truth masks,
# pooled over every evaluated image and class channel.
fpr, tpr, thresholds = roc_curve(targets, outputs)
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots(figsize=(8, 6))
ax.plot(fpr, tpr, color="darkorange", label=f"ROC curve (AUC = {roc_auc:.3f})")
ax.plot([0, 1], [0, 1], color="navy", linestyle="--", label="Chance")
ax.set_xlabel("False Positive Rate")
ax.set_ylabel("True Positive Rate")
ax.set_title("ROC Curve")
ax.set_xlim(0, 1)
ax.set_ylim(0, 1.02)

# Annotate roughly 10 evenly spaced thresholds. Index 0 is skipped: roc_curve prepends
# an artificial threshold (inf in recent scikit-learn) so the curve starts at (0, 0).
annotation_step = max(1, len(thresholds) // 10)
for i in range(annotation_step, len(thresholds), annotation_step):
    ax.annotate(f"{thresholds[i]:.2f}", (fpr[i], tpr[i]))

ax.legend()
plt.show()

In [None]:
# Visual check: input image, ground-truth mask and predicted mask for one class
# channel, rendered as one figure per evaluated test image.
lesion_index = 0  # class channel to display

loaded_model.eval()
with torch.no_grad():  # display only; no autograd graph needed
    for test_batch_id, test_batch in enumerate(test_dataloader):
        # Same selection as the ROC cell: the first 6 batches are skipped.
        if test_batch_id < 6:
            continue
        input_tensor = test_batch[0]
        target_tensor = test_batch[1].squeeze()
        test_output = loaded_model(input_tensor).squeeze()

        fig, axes = plt.subplots(1, 3, figsize=(20, 10))
        panels = (
            # First image of the batch, channels moved last for imshow
            # (presumably (C, H, W) -> (H, W, C)).
            (input_tensor[0].cpu().permute(1, 2, 0), "Wejściowy obraz"),  # input image
            (target_tensor[lesion_index].cpu(), "Poprawna maska"),  # ground-truth mask
            (test_output[lesion_index].cpu(), "Maska predykcji"),  # predicted mask
        )
        for ax, (image, title) in zip(axes, panels):
            ax.imshow(image, cmap='gray')
            ax.set_title(title)
            ax.axis('off')

        # tight_layout replaces subplots_adjust(top=2), which pushed the axes
        # beyond the top edge of the figure.
        fig.tight_layout()
        plt.show()
        plt.close(fig)  # otherwise every figure stays open until the cell finishes