In [None]:
from sklearn.metrics import jaccard_score
from torch import no_grad, sigmoid, load, float32
from torchvision.ops.focal_loss import sigmoid_focal_loss
from torch.utils.data import DataLoader
from nix import NIX
from dataset import ImageDataset, collate_fn
from tqdm import tqdm
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def load_model(path, device):
    """Build a ``NIX(512, 512)`` network, restore its weights and prepare it for inference.

    Args:
        path: checkpoint file holding a ``state_dict`` for ``NIX``.
        device: device the checkpoint tensors are mapped to and the model is moved to.

    Returns:
        The ``NIX`` module on ``device``, switched to eval mode.
    """
    # NOTE(review): the meaning of the two constructor arguments is defined in `nix.NIX`,
    # not here. torch.load unpickles the file, so only load trusted checkpoints.
    net = NIX(512, 512)
    weights = load(path, map_location=device)
    net.load_state_dict(weights)
    # Module.to() and Module.eval() both return the module itself.
    return net.to(device).eval()


def val(model, dataloader, device, threshold=0.5):
    """Evaluate ``model`` on ``dataloader``; report and return mean focal loss and mean IoU.

    Each batch yields ``(x, r, y)``. ``x`` and ``r`` are fed to ``model(x, r)`` and
    ``y`` is the target mask. Loss and IoU are computed per batch and averaged over
    the number of batches.

    Args:
        model: network called as ``model(x, r)``. Its output is treated as logits,
            because ``sigmoid_focal_loss`` applies the sigmoid internally.
        dataloader: iterable of ``(x, r, y)`` tensor triples.
        device: device the tensors are moved to before the forward pass.
        threshold: probability cut-off applied to ``sigmoid(output)`` to get the
            binary prediction used for IoU. Defaults to 0.5.

    Returns:
        tuple[float, float]: ``(mean_loss, mean_iou)``.

    Raises:
        ValueError: if ``dataloader`` yields no batches.

    Note:
        The model's train/eval mode is not changed here. Callers should put it in
        eval mode first (``load_model`` does).
    """
    total_loss = 0.0
    total_iou = 0.0
    n_batches = 0
    with no_grad():
        for x, r, y in tqdm(dataloader):
            x = x.to(device, dtype=float32)
            r = r.to(device, dtype=float32)
            y = y.to(device, dtype=float32)
            logits = model(x, r)
            total_loss += sigmoid_focal_loss(logits, y, reduction="mean").item()
            # The output is logits, so convert to probabilities before thresholding.
            # Thresholding raw logits at 0.5 would actually mean p > ~0.62.
            pred = (sigmoid(logits) > threshold).int()
            target = (y > 0.5).int()
            # NOTE(review): if both target and pred are empty, sklearn's default
            # (zero_division="warn") scores this batch as IoU 0. Confirm that is intended.
            total_iou += jaccard_score(target.flatten().cpu().numpy(), pred.flatten().cpu().numpy())
            n_batches += 1
    if n_batches == 0:
        raise ValueError("val() received a dataloader with no batches")
    mean_loss = total_loss / n_batches
    mean_iou = total_iou / n_batches
    print('Val Loss: %.3f | IoU: %.3f' % (mean_loss, mean_iou))
    return mean_loss, mean_iou

# Set device: MPS if available, otherwise CUDA, otherwise CPU.
# (The previous `torch.cuda_is_available()` does not exist and raised AttributeError
# on any machine without MPS; the correct API is `torch.cuda.is_available()`.)
if getattr(torch.backends, "mps", None) is not None and torch.backends.mps.is_available():
    device = "mps"
elif torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

device = torch.device(device)

In [None]:
# Locations of the test split and the trained checkpoint.
# Set the MPDL_PROJECT_ROOT environment variable to run from another checkout;
# the default reproduces the original paths.
import os

PROJECT_ROOT = os.environ.get("MPDL_PROJECT_ROOT", "/Users/pauladler/MPDL_Project_2")
PATH_TEST = f"{PROJECT_ROOT}/data/test"
PATH_MODEL = f"{PROJECT_ROOT}/models/nix_1706.pth"

In [None]:
# Evaluation data: single-sample batches in a fixed order (no shuffling),
# batched with the project's custom collate_fn.
test_data = ImageDataset(PATH_TEST)
test_dataloader = DataLoader(
    test_data,
    batch_size=1,
    shuffle=False,
    num_workers=0,
    collate_fn=collate_fn,
)

In [None]:
# Restore the trained checkpoint and score it on the test split.
nix = load_model(path=PATH_MODEL, device=device)
val(model=nix, dataloader=test_dataloader, device=device)

In [None]:
# Visual sanity check: plot input image, ground-truth mask and thresholded prediction
# for a handful of randomly drawn test samples.
N_EXAMPLES = 25  # how many samples to plot

# Use the same collate_fn as the evaluation loader, so batches have the structure `val`
# was run on and re-running the evaluation cell afterwards still behaves the same.
# Only the sample order differs (shuffle=True).
test_dataloader = DataLoader(test_data, batch_size=1, num_workers=0, shuffle=True, collate_fn=collate_fn)

# Draw every sample from ONE iterator. Calling next(iter(loader)) per sample rebuilt
# (and reshuffled) the iterator each time and could show the same sample repeatedly.
for sample_idx, (x, r, y) in zip(range(N_EXAMPLES), test_dataloader):
    x, r, y = x.to(device, dtype=float32), r.to(device, dtype=float32), y.to(device, dtype=float32)
    with no_grad():
        logits = nix(x, r)

    # einsum moves the first axis to the end (presumably channels-first -> channels-last for imshow).
    image = np.einsum('jkl->klj', np.squeeze(x.cpu().numpy()))
    mask = np.squeeze(y.cpu().numpy())
    # `val` scores this output with sigmoid_focal_loss, which takes logits,
    # so threshold the sigmoid probability rather than the raw output.
    pred = np.squeeze((sigmoid(logits) > 0.5).int().cpu().numpy())

    fig, axes = plt.subplots(1, 3, figsize=(20, 20))
    for ax, panel, title in zip(axes, (image, mask, pred), ('Image', 'Mask', 'Prediction')):
        ax.imshow(panel)
        ax.set_title(title)
    fig.tight_layout()
    plt.show()
    plt.close(fig)  # don't keep N_EXAMPLES figures open in memory