# Testing the trained U-Net models

## Module imports

In [None]:
#import matplotlib.pyplot as plt
import torch
from torch.utils.data import Dataset, DataLoader
import torch.nn as nn
from tqdm import tqdm # progress bar
from torch.optim import Adam
import numpy as np
from skimage import metrics

# custom modules
from unet import StentDataset, UNet

## Loading the trained model

In [None]:
# Checkpoint written by the training run that is evaluated below.
model_save_path = "./weights/iter_24000.pt"
model = UNet(in_channels=1, out_channels=1)
# Cast the network to float64. load_state_dict copies the checkpoint values
# into these existing parameters, so they end up as float64 as well.
model.double()
# Load the trained weights. Without this, the evaluation below would score a
# randomly initialised network. map_location="cpu" keeps the tensors on the
# CPU, where the evaluation loop converts outputs with `.numpy()`.
# NOTE(review): assumes the checkpoint stores a plain state_dict (not a dict
# wrapping one, e.g. with optimizer state) — confirm against the training code.
model.load_state_dict(torch.load(model_save_path, map_location="cpu"))
model.eval()  # Switching to evaluation mode

## Loading test data

In [None]:
batch_size = 2
# Number of pixels cropped from each border of the last two axes (presumably
# height and width) before the metrics are computed.
# NOTE(review): the reason for the value 94 is not visible here — confirm.
delta = 94

test_dataset = StentDataset(input_path="data/dataset/test/x",
                            target_path="data/dataset/test/y")
# Derive the image count from the dataset instead of hard-coding it, so any
# average taken over the test set stays correct if the set changes size.
n_test_images = len(test_dataset)
# Evaluation needs no shuffling; a fixed order makes runs reproducible.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

## Testing the model (PSNR and SSIM on the test set)

In [None]:
# Data range passed to both scikit-image metrics. Recent scikit-image
# releases require it explicitly for floating-point images (SSIM raises
# otherwise), and using one value keeps PSNR and SSIM consistent.
# NOTE(review): assumes the images are normalised to [0, 1] — confirm against
# StentDataset and adjust if they are not.
data_range = 1.0

psnr_sum = 0.0
ssim_sum = 0.0
n_scored = 0
# Inference only: disable autograd bookkeeping to save memory and time.
with torch.no_grad():
    for inputs, targets in tqdm(test_loader):
        outputs = model(inputs)
        # Keep channel 0 and crop `delta` pixels from each border of the last
        # two axes, producing NumPy arrays for scikit-image.
        outputs = outputs.detach().numpy()[:, 0, delta:-delta, delta:-delta]
        targets = targets.detach().numpy()[:, 0, delta:-delta, delta:-delta]
        # Iterate over the images actually present: the final batch can hold
        # fewer than `batch_size` images.
        for output, target in zip(outputs, targets):
            psnr_sum += metrics.peak_signal_noise_ratio(
                target, output, data_range=data_range)
            ssim_sum += metrics.structural_similarity(
                target, output, data_range=data_range)
            n_scored += 1

if n_scored == 0:
    raise RuntimeError("The test loader yielded no images; cannot compute metrics.")
# Mean metrics over every image that was actually scored.
PSNR = psnr_sum / n_scored
SSIM = ssim_sum / n_scored
print(f"PSNR: {PSNR:.4f} dB | SSIM: {SSIM:.4f} (over {n_scored} images)")