**Jupyter environment for testing the NIX-Nets**<br>
Adapt the notebook to your dataset and model. To use a different model, import the corresponding NIX model and instantiate it with the correct parameters. Also make sure to update the model's file path.

Dataset 2 with the 'realfake' method is preselected. To use 'fakefake' instead, replace "realfake" with "fakefake" in the call to CombinedDataset. For Dataset 1 or the inpainting dataset, use ImageDataset with the relevant data path and chosen method instead. (For reference, these datasets are also used in various NIX-Net training runs.)

In [None]:
from sklearn.metrics import jaccard_score
from torch import no_grad, sigmoid, load, float32
from torchvision.ops.focal_loss import sigmoid_focal_loss
from torch.utils.data import DataLoader
from model_definitions.nix_85 import NIX #Import other NIX-Net here
from dataset import ImageDataset, collate_fn, CombinedDataset
from tqdm import tqdm
from torch import sigmoid
import torch
import matplotlib.pyplot as plt
import numpy as np

In [None]:
def load_model(path, device):
    """Create a NIX network, restore saved weights into it and ready it for inference.

    Args:
        path: File containing a ``state_dict`` that ``torch.load`` can read.
        device: Device the weights are mapped to and the model is moved to.

    Returns:
        The restored network on ``device``, switched to evaluation mode.
    """
    # NOTE(review): the two 512 arguments are presumably the input height and
    # width expected by this NIX variant -- confirm against model_definitions.
    network = NIX(512, 512)
    weights = load(path, map_location=device)
    network.load_state_dict(weights)
    network = network.to(device)
    # Evaluation mode: disables dropout / uses running batch-norm statistics.
    network.eval()
    return network


def val(model, dataloader, device):
    """Evaluate ``model`` on ``dataloader`` and print/return mean focal loss and IoU.

    Both metrics are averaged per batch: every batch (including a possibly
    smaller final one) contributes equally to the mean.

    Args:
        model: Network called as ``model(x, r)``. Its output is treated as
            logits and must match the shape of the target mask ``y``
            (``sigmoid_focal_loss`` requires equal shapes).
        dataloader: Iterable yielding ``(x, r, y)`` batches.
        device: Device each batch is moved to before the forward pass.

    Returns:
        Tuple ``(mean_loss, mean_iou)``.

    Raises:
        ValueError: If ``dataloader`` yields no batches.
    """
    total_loss = 0.0
    total_iou = 0.0
    num_batches = 0
    with no_grad():
        for x, r, y in tqdm(dataloader):
            x = x.to(device, dtype=float32)
            r = r.to(device, dtype=float32)
            y = y.to(device, dtype=float32)
            logits = model(x, r)
            total_loss += sigmoid_focal_loss(logits, y, reduction="mean").item()
            # Logits -> probabilities -> hard 0/1 prediction.
            pred = (sigmoid(logits) > 0.5).int()
            # The target is already a [0, 1] mask (it is the focal-loss target),
            # so threshold it directly. Passing it through sigmoid first would
            # turn every value > 0 into a positive label.
            target = (y > 0.5).int()
            # NOTE(review): if a batch has no positive pixels in either target or
            # prediction, jaccard_score warns and scores it 0 by default.
            total_iou += jaccard_score(target.flatten().cpu().numpy(), pred.flatten().cpu().numpy())
            # Count batches ourselves so loaders without __len__ also work.
            num_batches += 1
    if num_batches == 0:
        raise ValueError("dataloader yielded no batches")
    mean_loss = total_loss / num_batches
    mean_iou = total_iou / num_batches
    print('Val Loss: %.3f | IoU: %.3f' % (mean_loss, mean_iou))
    return mean_loss, mean_iou


# Use the first GPU when CUDA is available; otherwise fall back to the CPU so
# the notebook (model loading via map_location, .to(device)) still runs.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

In [None]:
# Locations of the test data and the trained weights -- adjust to your setup.
# The two test directories are merged by CombinedDataset below.
PATH_TEST1, PATH_TEST2 = (
    "/home/adlerpqt/data/test",
    "/home/guenthhx/MPDL2/data/test",
)

# Checkpoint (state_dict) restored by load_model().
PATH_MODEL = "/home/adlerpqt/MPDL_Project_2/models/dataset_2/nix_85_large_dataset_realfake/nix_85_0807_large_dataset_realfake.pth"

In [None]:
# Evaluation data. To use 'fakefake', replace "realfake" below; for Dataset 1
# or the inpainting data, build an ImageDataset instead.
test_data = CombinedDataset(PATH_TEST1, PATH_TEST2, "realfake")
test_dataloader = DataLoader(
    test_data,
    batch_size=8,
    shuffle=False,  # fixed order so repeated evaluations are comparable
    num_workers=4,
    collate_fn=collate_fn,
)

In [None]:
# Restore the trained network and report loss / IoU on the test split.
# The (loss, iou) tuple returned by val() is shown as the cell output.
nix = load_model(PATH_MODEL, device=device)
val(model=nix, dataloader=test_dataloader, device=device)

In [None]:
from itertools import islice

# Visual sanity check: show a few random test samples next to their
# ground-truth mask and the model's thresholded prediction.
NUM_EXAMPLES = 25

# NOTE: this rebinds the global test_dataloader to a shuffled, batch_size=1
# loader; re-run the dataset cell before calling val() again.
test_dataloader = DataLoader(test_data, batch_size=1, num_workers=0, shuffle=True, collate_fn=collate_fn)

# Take the first NUM_EXAMPLES batches of a single shuffled pass. Calling
# next(iter(...)) on every step rebuilt the iterator (reshuffling the whole
# dataset each time) and could show the same sample twice; islice also stops
# cleanly if the dataset has fewer than NUM_EXAMPLES samples.
for x, r, y in islice(test_dataloader, NUM_EXAMPLES):
    x, r, y = x.to(device, dtype=float32), r.to(device, dtype=float32), y.to(device, dtype=float32)
    with no_grad():
        pred = nix(x, r)

    # Drop the batch axis and move the leading (presumably channel) axis to the
    # end, since imshow expects height x width x channels.
    x = np.transpose(np.squeeze(x.cpu().numpy()), (1, 2, 0))

    y = np.squeeze(y.cpu().numpy())

    # Logits -> binary mask using a 0.5 probability threshold.
    pred = np.squeeze((sigmoid(pred) > 0.5).int().cpu().numpy())

    fig, axes = plt.subplots(1, 3, figsize=(20, 20))
    for ax, image, title in zip(axes, (x, y, pred), ('Image', 'Mask', 'Prediction')):
        ax.imshow(image)
        ax.set_title(title)
    fig.tight_layout()
    # Render now and free the figure; otherwise every figure stays open until
    # the cell finishes (matplotlib warns once more than 20 are open).
    plt.show()
    plt.close(fig)