In [None]:
#%load_ext blackcellmagic

In [None]:
import inspect
import os
import pickle
import sys
from random import randint, randrange

import cv2
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torchvision
from torch.autograd import Variable
from tqdm import tqdm_notebook as tqdm
# Report whether PyTorch can see a CUDA device (later cells move models/tensors to it if so)
cuda_available = torch.cuda.is_available()
print("CUDA available: {}".format(cuda_available))

In [None]:
# Project code: model architectures, dataset helpers, loss functions and the training solver
from models.DSCLRCN_OldContext import DSCLRCN
from models.CoSADUV import CoSADUV
from models.CoSADUV_NoTemporal import CoSADUV_NoTemporal
from util.data_utils import get_SALICON_datasets, get_video_datasets
from util import loss_functions
from util.solver import Solver

### Data options ###
# Dataset root: "Dataset/SALICON" or "Dataset/UAV123"
dataset_root_dir = "Dataset/UAV123"
# Mean image file, which must be located at dataset_root_dir/mean_image_name
mean_image_name = "mean_image.npy"
# (height, width) - original: (480, 640), reimplementation: (96, 128)
img_size = (480, 640)
# Length of the sequence loaded from each video (only used for video datasets)
duration = 300

### Training options ###
# Images processed before each backpropagation step. Recommended: 20
batchsize = 20
# Images processed at a time on the GPU.
# Recommended: 4 for 480x640 with >12GB GPU memory, 2 with <12GB
minibatchsize = 1
# Recommended: 10 (epoch_number =~ batchsize/2)
epoch_number = 5
# 'SGD' or 'Adam' (recommended: Adam); lr 1e-2 if SGD, 1e-4 if Adam
optim_str = "SGD"
optim_args = {"lr": 1e-2}
# Loss functions come from util.loss_functions (use loss_functions.NAME). Names referenced
# in this notebook include NSS_loss, NSS_alt, CE_MAE_loss, CE_loss, MAE_loss, DoM,
# PCC_loss and KLDiv_loss. Recommended training loss: NSS_loss
loss_func = loss_functions.NSS_alt
test_loss_func = loss_functions.CE_MAE_loss

### Prepare optimiser ###
# A batch of `batchsize` images is processed as `num_minibatches` GPU passes of
# `minibatchsize` images each, so the two must divide evenly.
if batchsize % minibatchsize:
    # Raise rather than exit(): exit() inside a Jupyter kernel terminates the kernel.
    raise ValueError(
        "Error, batchsize % minibatchsize must equal 0 ({} % {} != 0).".format(
            batchsize, minibatchsize
        )
    )
# Integer division is exact here thanks to the check above; keep the count an int.
num_minibatches = batchsize // minibatchsize

# Scale the lr down as smaller minibatches are used
optim_args["lr"] /= num_minibatches

# Explicit lookup so a typo in optim_str fails loudly instead of silently using Adam
_optimisers = {"SGD": torch.optim.SGD, "Adam": torch.optim.Adam}
if optim_str not in _optimisers:
    raise ValueError(
        "optim_str must be 'SGD' or 'Adam', got {!r}".format(optim_str)
    )
optim = _optimisers[optim_str]

### Prepare datasets and loaders ###
# Each *_loader is a list of DataLoaders: SALICON gets a single-element list, while
# get_video_datasets presumably returns one loader per video (later cells index the
# loaders by video id) - TODO confirm against util.data_utils.


def _make_image_loader(dataset):
    """Shuffled DataLoader over a still-image dataset using the notebook's minibatch size."""
    return torch.utils.data.DataLoader(
        dataset,
        batch_size=minibatchsize,
        shuffle=True,
        num_workers=8,
        pin_memory=True,
    )


if "SALICON" in dataset_root_dir:
    train_data, val_data, test_data, mean_image = get_SALICON_datasets(
        dataset_root_dir, mean_image_name, img_size
    )
    train_loader = [_make_image_loader(train_data)]
    val_loader = [_make_image_loader(val_data)]
    # Load test loader using val_data as SALICON does not give GT for its test set
    test_loader = [_make_image_loader(val_data)]
elif "UAV123" in dataset_root_dir:
    train_loader, val_loader, test_loader, mean_image = get_video_datasets(
        dataset_root_dir,
        mean_image_name,
        duration=duration,
        img_size=img_size,
        shuffle=False,
        loader_settings={
            "batch_size": minibatchsize,
            "num_workers": 8,
            "pin_memory": False,
        },
    )
else:
    # Fail here rather than with a NameError on the first use of a loader
    raise ValueError(
        "Unsupported dataset_root_dir {!r}: expected a path containing "
        "'SALICON' or 'UAV123'".format(dataset_root_dir)
    )

In [None]:
# Train a model
# Set `train = True` to train a new CoSADUV model with the settings above. The trained
# model then becomes models[0] / model_names[0]; with `train = False` both lists start
# empty and are filled by the model-loading cell below.
train = False
models = []
model_names = []
if train:
    # Attempt to train a model using the original image sizes
    model = CoSADUV(input_dim=img_size, local_feats_net="Seg")

    # Echo the run configuration so the notebook output records it
    print("Starting train on model with settings:")
    print("### Dataset settings ###")
    print("Dataset: {}".format(dataset_root_dir.split("/")[-1]))
    print("Image size: ({}h, {}w)".format(img_size[0], img_size[1]))
    print("Sequence duration: {}".format(duration))
    print("")
    print("### Training settings ###")
    print("Batch size: {}".format(batchsize))
    print("Minibatch size: {}".format(minibatchsize))
    print("Epochs: {}".format(epoch_number))
    print("")
    print("### Optimiser settings ###")
    print("Optimiser: {}".format(optim_str))
    # optim_args["lr"] was divided by num_minibatches in the settings cell; undo that for display
    print("Effective lr: {}".format(str(optim_args["lr"] * num_minibatches)))
    print("Actual lr: {}".format(str(optim_args["lr"])))
    print("Loss function: {}".format(loss_func.__name__))
    print("\n")

    # Create a solver with the options given above and appropriate location
    # NOTE(review): `location` is not defined anywhere in this notebook, so this call
    # raises NameError when train is True. Define it before enabling training; its
    # accepted values are set by util.solver.Solver (not visible here) - TODO confirm.
    solver = Solver(
        optim=optim, optim_args=optim_args, loss_func=loss_func, location=location
    )
    # Start training. num_minibatches is passed so the solver can (presumably) accumulate
    # that many minibatch passes per optimiser step - confirm in util.solver.
    solver.train(
        model,
        train_loader,
        val_loader,
        num_epochs=epoch_number,
        num_minibatches=num_minibatches,
        log_nth=50,
        filename_args={
            "batchsize": batchsize,
            "epoch_number": epoch_number,
            "optim": optim_str,
        },
    )

    # Register the trained model under the same name pattern the save cell uses for its file
    models.append(model)
    model_names.append(
        "model_{}_{}_batch{}_epoch{}".format(type(model).__name__, loss_func.__name__, batchsize, epoch_number)
    )
    model.eval()

In [None]:
# Saving the freshly trained model and its solver (which holds the loss histories).
# Both file names are derived from the architecture, loss, batch size and epoch count.
if train:
    name_fields = (type(model).__name__, loss_func.__name__, batchsize, epoch_number)
    model.save("trained_models/model_{}_{}_batch{}_epoch{}".format(*name_fields))

    solver_path = "trained_models/solver_{}_{}_batch{}_epoch{}.pkl".format(*name_fields)
    with open(solver_path, "wb") as solver_file:
        pickle.dump(solver, solver_file, pickle.HIGHEST_PROTOCOL)

In [None]:
# Plotting the training and validation loss histories recorded by the solver.
if train:
    fig, (train_ax, val_ax) = plt.subplots(2, 1)

    train_ax.plot(solver.train_loss_history, "o")
    train_ax.set(title="Train Loss", xlabel="History entry", ylabel="Loss")

    val_ax.plot(solver.val_loss_history, "-o")
    val_ax.set(title="Val Loss", xlabel="History entry", ylabel="Loss")

    # Stop the lower title from overlapping the upper plot's x-axis label
    fig.tight_layout()
    plt.show()

In [None]:
# Helpers for loading trained models (paths are relative to `model_dir`).
# load_model_from_checkpoint rebuilds the architecture named in the checkpoint path
# (DSCLRCN, CoSADUV_NoTemporal or CoSADUV) and loads the saved state_dict into it;
# load_model loads a whole pickled model. Both require the model classes to be
# imported, and both return the model in eval mode (on the GPU when one is available).


def _torch_load(path, map_location=None):
    """Call torch.load, allowing full unpickling on PyTorch versions that restrict it.

    torch>=2.6 defaults to weights_only=True, which rejects whole pickled models (and
    checkpoint dicts holding non-tensor objects). These files come from our own training
    runs, so full unpickling is requested wherever the argument exists.
    SECURITY: unpickling can execute arbitrary code - only load trusted files.
    """
    load_kwargs = {"map_location": map_location}
    if "weights_only" in inspect.signature(torch.load).parameters:
        load_kwargs["weights_only"] = False
    return torch.load(path, **load_kwargs)


def load_model_from_checkpoint(model_name, model_dir="trained_models"):
    """Rebuild a model from the checkpoint `<model_dir>/<model_name>.pth`.

    The architecture is chosen from the substring found in `model_name`, built with
    the global `img_size`, and the checkpoint's "state_dict" is loaded non-strictly.

    Returns:
        The model in eval mode (moved to CUDA when available), or None (after printing
        an error) when `model_name` contains no known architecture name.
    """
    # Choose the architecture before touching the file. "CoSADUV_NoTemporal" must be
    # tested before "CoSADUV" because the latter is a substring of the former.
    if "DSCLRCN" in model_name:
        model = DSCLRCN(input_dim=img_size, local_feats_net="Seg")
    elif "CoSADUV_NoTemporal" in model_name:
        model = CoSADUV_NoTemporal(input_dim=img_size, local_feats_net="Seg")
    elif "CoSADUV" in model_name:
        model = CoSADUV(input_dim=img_size, local_feats_net="Seg")
    else:
        print("Error: no model name found in filename: {}".format(model_name))
        return None

    filename = os.path.join(model_dir, model_name + ".pth")
    # Load GPU-saved checkpoints onto the CPU when no GPU is available
    checkpoint = _torch_load(
        filename, map_location=None if torch.cuda.is_available() else "cpu"
    )

    # strict=False ignores extra parameters ('.num_batches_tracked')
    # that are added on NCC due to a different pytorch version
    model.load_state_dict(checkpoint["state_dict"], strict=False)

    print(
        "=> loaded model checkpoint '{}' (trained for {} epochs)\n   with architecture {}".format(
            model_name, checkpoint["epoch"], type(model).__name__
        )
    )

    if torch.cuda.is_available():
        model = model.cuda()
        print("   loaded to cuda")
    model.eval()
    return model


def load_model(model_name, model_dir="trained_models"):
    """Load a whole pickled model from `<model_dir>/<model_name>` and put it in eval mode.

    The model is loaded onto the CPU first and moved to CUDA when available.
    """
    model = _torch_load(os.path.join(model_dir, model_name), map_location="cpu")
    print("=> loaded model '{}'".format(model_name))
    if torch.cuda.is_available():
        model = model.cuda()
        print("   loaded to cuda")
    model.eval()
    return model

In [None]:
# Loading some pretrained models to test them on the images.
# Start from clean lists: when a model was just trained, keep only it (entry 0).
if train:
    # Keep the trained model; slicing also copes with an empty list
    models = models[:1]
    model_names = model_names[:1]
else:
    models = []
    model_names = []
    
# Catalogue of trained models. Entries are paths relative to trained_models/;
# uncomment the ones to evaluate. Names containing "best_model" are loaded with
# load_model_from_checkpoint (".pth" is appended), all others with load_model.
# Labels after "-" in the folder names record scores noted at training time.

# TODO
# CoSADUV_NoTemporal on UAV123 with DoM loss func, Adam 1e-2 lr
# CoSADUV_NoTemporal on UAV123 with NSS loss func, Adam 1e-2 lr
# Above, plus transfer learning on EyeTrackUAV, same settings

# TEST ALL MODELS (QUANTITATIVE with CE_MAE, DoM, NSS_alt (split into +ve and -ve images))
# COLLECT SOME QUALITATIVE RESULTS FROM best DSCLRCN, CoSADUV_NoTemporal, CoSADUV models
# Also collect some qualitative results of arbitrary good/bad results




# DSCLRCN, UAV123, DoM loss func - Old Context
# model_names.append("DSCLRCN/UAV123/Homebrew 3.58last 2.56best testing/model_DSCLRCN_homebrew_batch20_epoch5")

# DSCLRCN models
## Trained on SALICON
### NSS_loss
# model_names.append("DSCLRCN/SALICON NSS -1.62NSS val best and last/best_model_DSCLRCN_NSS_loss_batch20_epoch5")
## Trained on UAV123
### NSS_alt loss func
# model_names.append("DSCLRCN/UAV123 NSS_alt 1.38last 3.15best testing/best_model_DSCLRCN_NSS_alt_batch20_epoch5")


# CoSADUV_NoTemporal models
## Trained on UAV123
### DoM loss func
model_names.append("CoSADUV_NoTemporal/DoM SGD 0.01lr - 3.16 NSS_alt/best_model_CoSADUV_NoTemporal_DoM_batch20_epoch6")
### NSS_alt loss func
# model_names.append("CoSADUV_NoTemporal/NSS_alt Adam lr 1e-4 - 1.36/best_model_CoSADUV_NoTemporal_NSS_alt_batch20_epoch5")
### CE_MAE loss func
# model_names.append("CoSADUV_NoTemporal/best_model_CoSADUV_NoTemporal_CE_MAE_loss_batch20_epoch10")


# CoSADUV models (CoSADUV2)
## Trained on UAV123
### NSS_alt loss func
#### 1 Frame backpropagation
#### Kernel size 1
# model_names.append("CoSADUV/NSS_alt Adam 0.001lr 1frame backprop size1 kernel -2train -0.7val 1epoch/best_model_CoSADUV_NSS_alt_batch20_epoch5")
#### Kernel size 3
# model_names.append("CoSADUV/NSS_alt Adam 0.01lr 1frame backprop size3 kernel/best_model_CoSADUV_NSS_alt_batch20_epoch5")
#### 2 Frame backpropagation
#### Kernel size 3
# model_names.append("CoSADUV/NSS_alt Adam 0.01lr 2frame backprop size3 kernel - 6.56 NSS_alt val/best_model_CoSADUV_NSS_alt_batch20_epoch5")
### DoM loss func
# Only very poor results achieved
### CE_MAE loss func
# Only very poor results achieved

# CoSADUV_NoTemporal models with transfer learning on EyeTrackUAV
# model_names.append("best_model_CoSADUV_NoTemporal_DoM_batch20_epoch5")
# model_names.append("best_model_CoSADUV_NoTemporal_NSS_alt_batch20_epoch5")

# Load the models specified above. A freshly trained model is already models[0],
# so only the remaining names need loading in that case.
names_to_load = model_names[1:] if train else model_names

for name in names_to_load:
    if "best_model" in name:
        models.append(load_model_from_checkpoint(name))
    else:
        models.append(load_model(name))

print()
print("Loaded all specified models")

In [None]:
# Testing the different models on one frame from the test set and one from the val set:

# Choose which two models to test (indices into `models`)
m_index_1 = 0
m_index_2 = 0

# Fixed video/frame indices so the example is reproducible.
# Random alternative: randrange(0, len(test_loader)) etc.
test_video_id = 2
val_video_id = 7
test_frame_id = 48
val_frame_id = 95

print("Test video index: {}, val video index: {}".format(test_video_id, val_video_id))
print("Test frame index: {}, val frame index: {}".format(test_frame_id, val_frame_id))


def _load_frame(video_loaders, video_id, frame_id, batch_size):
    """Return the (input, target) tensors of frame `frame_id` of video `video_id`.

    The frame is located by its minibatch index and its offset within that minibatch.
    Raises IndexError if the video is too short, rather than silently leaving stale
    values from a previously run cell in place.
    """
    batch_index, offset = divmod(frame_id, batch_size)
    for i, (inputs, targets) in enumerate(video_loaders[video_id]):
        if i == batch_index:
            return inputs[offset], targets[offset]
    raise IndexError("Video {} has no frame {}".format(video_id, frame_id))


# Load the images
x, y = _load_frame(test_loader, test_video_id, test_frame_id, minibatchsize)
x_val, y_val = _load_frame(val_loader, val_video_id, val_frame_id, minibatchsize)


# Rebuild displayable images: CHW -> HWC, then add the mean image back
mean_tensor = torch.from_numpy(mean_image)
original = x.permute(1, 2, 0) + mean_tensor
original_val = x_val.permute(1, 2, 0) + mean_tensor

# Add a leading batch dimension. x_2 / x_2_val (fed to the second model) are views of
# the same tensors, not copies.
x = x.contiguous().unsqueeze(0)
x_2 = x[:]
x_val = x_val.contiguous().unsqueeze(0)
x_2_val = x_val[:]
if torch.cuda.is_available():
    x, x_val = x.cuda(), x_val.cuda()
    x_2, x_2_val = x_2.cuda(), x_2_val.cuda()

# Targets are only needed as numpy arrays from here on
y, y_val = y.numpy(), y_val.numpy()



##### Run the selected model(s) on the test and val frames #####


def _predict(model, batch):
    """Run `model` on `batch` without tracking gradients; return a squeezed numpy array."""
    # Inference only: building an autograd graph just wastes GPU memory
    with torch.no_grad():
        prediction = model(batch)
    return prediction.cpu().squeeze().numpy()


def _reset_temporal_state(model):
    """Clear and detach the stored state of a temporal model; no-op for other models."""
    if model.temporal:
        model.clear_temporal_state()
        model.detach_temporal_state()


##### First model #####
first_model = models[m_index_1]
x_sal_nmp = _predict(first_model, x)
# Reset so the val frame is not conditioned on the test frame
_reset_temporal_state(first_model)
x_val_sal_nmp = _predict(first_model, x_val)

if m_index_2 != m_index_1:
    ##### Second model #####
    second_model = models[m_index_2]
    x_2_sal_nmp = _predict(second_model, x_2)
    _reset_temporal_state(second_model)
    x_2_val_sal_nmp = _predict(second_model, x_2_val)

    
# Plot inputs, ground truth and the raw (un-normalised) predictions on a fixed
# [0, 1] display range. Row 1: test frame; row 2: val frame.
plt.figure(figsize=(24,16))

# (subplot slot, image, colour map, title)
panels = [
    (1, original, None, 'Original'),
    (2, y, 'gray', 'Ground Truth'),
    (3, x_sal_nmp, 'gray', model_names[m_index_1]),
    (5, original_val, None, 'Original Val'),
    (6, y_val, 'gray', 'Ground Truth Val'),
    (7, x_val_sal_nmp, 'gray', model_names[m_index_1]),
]
if m_index_2 != m_index_1:
    panels.append((4, x_2_sal_nmp, 'gray', model_names[m_index_2]))
    panels.append((8, x_2_val_sal_nmp, 'gray', model_names[m_index_2]))

for slot, image, cmap, title in sorted(panels, key=lambda panel: panel[0]):
    plt.subplot(3, 4, slot)
    plt.imshow(image, cmap=cmap, vmin=0, vmax=1)
    plt.title(title)
    
# Score the predictions with several loss functions: the first model on both frames,
# the second model (if different) on the val frame only.
print(model_names[m_index_1])
x_val_sal_torch = torch.from_numpy(x_val_sal_nmp).unsqueeze(0)
x_sal_torch = torch.from_numpy(x_sal_nmp).unsqueeze(0)
y_val_torch = torch.from_numpy(y_val).unsqueeze(0)
y_torch = torch.from_numpy(y).unsqueeze(0)


def _print_scores(model_index, labelled_losses, prediction, target):
    """For each (label template, loss fn) pair, print the label and the loss value."""
    for label, loss_fn in labelled_losses:
        print(label.format(model_index), loss_fn(prediction, target).item())


_print_scores(
    m_index_1,
    (
        ("[{}] NSS_loss Test:", loss_functions.NSS_loss),
        ("[{}] CE_MAE   Test:", loss_functions.CE_MAE_loss),
        ("[{}] CE       Test:", loss_functions.CE_loss),
        ("[{}] MAE      Test:", loss_functions.MAE_loss),
        ("[{}] DoM      Test:", loss_functions.DoM),
    ),
    x_sal_torch,
    y_torch,
)

print()

_print_scores(
    m_index_1,
    (
        ("[{}] NSS_loss Validation:", loss_functions.NSS_loss),
        ("[{}] CE_MAE   Validation:", loss_functions.CE_MAE_loss),
        ("[{}] CE       Validation:", loss_functions.CE_loss),
        ("[{}] MAE      Validation:", loss_functions.MAE_loss),
        ("[{}] DoM      Validation:", loss_functions.DoM),
    ),
    x_val_sal_torch,
    y_val_torch,
)

if m_index_2 != m_index_1:
    print(model_names[m_index_2])
    x_2_val_sal_torch = torch.from_numpy(x_2_val_sal_nmp).unsqueeze(0)
    _print_scores(
        m_index_2,
        (
            ("[{}] CE_MAE Validation:", loss_functions.CE_MAE_loss),
            ("[{}] CE Validation:", loss_functions.CE_loss),
            ("[{}] MAE Validation:", loss_functions.MAE_loss),
            ("[{}] NSS_loss Validation:", loss_functions.NSS_loss),
            ("[{}] DoM Validation:", loss_functions.DoM),
            ("[{}] NSS_alt Validation:", loss_functions.NSS_alt),
        ),
        x_2_val_sal_torch,
        y_val_torch,
    )
    
##### Heat maps of GT (red) and prediction (blue) over the validation image #####


def _heat_map_overlay(image, gt, pred, alpha=0.5):
    """Blend an RGB image with a ground-truth map (red) and a prediction map (blue).

    Returns alpha * image + (1 - alpha) * [gt, 0, pred_scaled] per pixel, where
    pred_scaled is `pred` divided by its maximum (left unscaled when the maximum is 0).
    Works on copies, so the caller's arrays are not modified. Vectorised with numpy
    instead of looping over every pixel in Python.
    Assumes image is (H, W, 3) and gt / pred are (H, W) - TODO confirm for new datasets.
    """
    image = np.asarray(image, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    peak = pred.max()
    if peak != 0:
        pred = pred / peak
    colour = np.stack([gt, np.zeros_like(gt), pred], axis=-1)
    return alpha * image + (1 - alpha) * colour


# First model. The original display array is copied by np.asarray, and x_val_sal_nmp
# is no longer normalised in place (it used to be, via a numpy view), so later cells
# that save it see the raw prediction.
plt.subplot(3, 4, 11)
plt.imshow(
    _heat_map_overlay(original_val.cpu().data.numpy(), y_val, x_val_sal_nmp),
    vmin=0,
    vmax=1,
)
plt.title('Heat map 1st model: gt=red, pred=blue')

if m_index_1 != m_index_2:
    ##### Heat map of GT and prediction on Validation set image (second model) #####
    plt.subplot(3, 4, 12)
    plt.imshow(
        _heat_map_overlay(original_val.cpu().data.numpy(), y_val, x_2_val_sal_nmp),
        vmin=0,
        vmax=1,
    )
    plt.title('Heat map 2nd model: gt=red, pred=blue')

# To save the figure as well, e.g.:
# plt.savefig('ResExamples/example_test_'+str(test_video_id)+'_val_'+str(val_video_id)+'.png')
plt.show()

# Leave each model used above (visited once, even if both indices match) with a clean
# temporal state for the next cell
for used_index in dict.fromkeys((m_index_1, m_index_2)):
    used_model = models[used_index]
    if used_model.temporal:
        used_model.clear_temporal_state()
        used_model.detach_temporal_state()


In [None]:
# Save the first model's test and val predictions (scaled by 255) as images.
# The two files need distinct names; otherwise the val image overwrites the test image.
results_dir = "C:/Users/simon/Downloads/Project Models/Results Data"
for split, saliency in (("test", x_sal_nmp), ("val", x_val_sal_nmp)):
    out_path = os.path.join(
        results_dir, "CoSADUV_NoTemporal_NSS_alt_artefact_{}.jpg".format(split)
    )
    # cv2.imwrite reports failure (e.g. a missing directory) by returning False
    if not cv2.imwrite(out_path, saliency * 255):
        raise OSError("Could not write {}".format(out_path))


In [None]:
# Save the two input frames (mean image added back) and their ground-truth maps, scaled to 0-255.
results_dir = "C:/Users/simon/Downloads/Project Models/Results Data"


def _to_bgr(image_tensor):
    """Scale a (H, W, 3) display tensor to 0-255 and swap channel order for OpenCV.

    The tensor is presumably RGB (matplotlib displays it correctly above), while
    cv2.imwrite expects BGR. Cast to float32 first because cvtColor does not accept
    float64 input.
    """
    pixels = (image_tensor.data.numpy() * 255).astype(np.float32)
    return cv2.cvtColor(pixels, cv2.COLOR_RGB2BGR)


outputs_to_save = (
    ("Input_image_1.jpg", _to_bgr(original)),
    ("GT_1.jpg", y * 255),
    ("Input_image_2.jpg", _to_bgr(original_val)),
    ("GT_2.jpg", y_val * 255),
)
for file_name, image in outputs_to_save:
    out_path = os.path.join(results_dir, file_name)
    # cv2.imwrite reports failure by returning False rather than raising
    if not cv2.imwrite(out_path, image):
        raise OSError("Could not write {}".format(out_path))

In [None]:
# Testing the models on chosen frames, this time feeding every earlier frame of each
# video through the first model first (so temporal models see the preceding sequence).

# Indices into `models` of the two models to compare
m_index_1 = 0
m_index_2 = 0

# Fixed example frames. Random alternatives:
#   randrange(0, len(test_loader)), randrange(0, len(val_loader)),
#   randrange(0, len(test_loader.__getitem__(test_video_id))), ...
test_video_id, val_video_id = 16, 11
test_frame_id, val_frame_id = 48, 95

print("Test video index: {}, val video index: {}".format(test_video_id, val_video_id))    
print("Test frame index: {}, val frame index: {}".format(test_frame_id, val_frame_id))

# Load specific images to use as examples: run the first model over each video up to
# the chosen frame and keep that frame's input, target and prediction.

# Every 10th val-frame prediction (min-max normalised) is also written to this folder.
snapshot_dir = "/tmp/pbqk24_tmp/Person_Example (person7, frame95)"


def _run_until_frame(model, video_loader, frame_id, on_frame=None):
    """Feed the frames of `video_loader` one at a time through `model` until `frame_id`.

    Each item of each minibatch is fed separately with batch size 1, so `frame_id`
    counts frames even when the loader batches several frames together. The loop runs
    under torch.no_grad(): temporal state is not detached here, so tracking gradients
    would chain the autograd graph across every frame.

    Args:
        on_frame: optional callback(frame_index, prediction) invoked for every frame.

    Returns:
        (input, target, prediction) of the chosen frame, each with a leading batch dim of 1.

    Raises:
        IndexError: if the video has fewer than frame_id + 1 frames.
    """
    frame_index = 0
    with torch.no_grad():
        for inputs, targets in video_loader:
            if torch.cuda.is_available():
                inputs, targets = inputs.cuda(), targets.cuda()
            for frame_input, frame_target in zip(inputs, targets):
                frame_input = frame_input.contiguous().unsqueeze(0)
                frame_target = frame_target.contiguous().unsqueeze(0)
                prediction = model(frame_input)
                if on_frame is not None:
                    on_frame(frame_index, prediction)
                if frame_index == frame_id:
                    return frame_input, frame_target, prediction
                frame_index += 1
    raise IndexError("Video has no frame {}".format(frame_id))


def _save_snapshot(frame_index, prediction):
    """Write every 10th prediction, min-max normalised to 0-255, into snapshot_dir."""
    if frame_index % 10:
        return
    # Work on a copy so the prediction tensor itself is never modified
    out = prediction.squeeze().cpu().numpy()
    out = out - out.min()
    peak = out.max()
    if peak > 0:
        out = out / peak
    path = os.path.join(
        snapshot_dir, "CoSADUV_NoTemporal_DoM_video_frame{}.jpg".format(frame_index)
    )
    # Snapshots are a side output: warn instead of aborting the example run
    if not cv2.imwrite(path, out * 255):
        print("Warning: could not write {}".format(path))


first_model = models[m_index_1]
x, y, x_sal = _run_until_frame(first_model, test_loader[test_video_id], test_frame_id)

# If model is temporal, reset the stored state before switching videos
if first_model.temporal:
    first_model.clear_temporal_state()
    first_model.detach_temporal_state()

x_val, y_val, x_val_sal = _run_until_frame(
    first_model, val_loader[val_video_id], val_frame_id, on_frame=_save_snapshot
)

# Get the original (before pre-processing) images to be displayed: take the single
# frame out of the batch, convert CHW -> HWC and add the mean image back.
mean_tensor = torch.from_numpy(mean_image)
original = x.cpu()[0].transpose(0, 1).transpose(1, 2) + mean_tensor
original_val = x_val.cpu()[0].transpose(0, 1).transpose(1, 2) + mean_tensor

# x / x_val already have a leading batch dimension of 1 (added while the frames were
# fed through the model), so they are passed to the second model unchanged; wrapping
# them in another view(1, ...) would add a second leading dimension.
x_2 = x
x_2_val = x_val
if torch.cuda.is_available():
    x, x_val = x.cuda(), x_val.cuda()
    x_2, x_2_val = x_2.cuda(), x_2_val.cuda()

# Drop the batch dimension of the targets and move them to numpy
y = y.cpu()[0].numpy()
y_val = y_val.cpu()[0].numpy()



##### First model #####
# x_sal / x_val_sal were produced while stepping through the videos above;
# here they are only moved to the CPU and converted to numpy.
x_sal = x_sal.cpu()
x_sal_nmp = x_sal.squeeze().data.numpy()

# If model is temporal, reset the stored state
first_model = models[m_index_1]
if first_model.temporal:
    first_model.clear_temporal_state()
    first_model.detach_temporal_state()

x_val_sal = x_val_sal.cpu()
x_val_sal_nmp = x_val_sal.squeeze().data.numpy()

if m_index_2 != m_index_1:
    ##### Second model: single frames, no temporal warm-up #####
    second_model = models[m_index_2]

    x_2_sal = second_model(Variable(x_2)).cpu()
    x_2_sal_nmp = x_2_sal.squeeze().data.numpy()

    if second_model.temporal:
        second_model.clear_temporal_state()
        second_model.detach_temporal_state()

    x_2_val_sal = second_model(Variable(x_2_val)).cpu()
    x_2_val_sal_nmp = x_2_val_sal.squeeze().data.numpy()

# Normalise the outputs to the [0, 1] range (min-max) to match the plots' vmin/vmax.
# The second model's outputs get the same treatment so both models are scored alike.


def _min_max_normalise(values):
    """Return a min-max normalised copy of `values`; a constant array maps to all zeros."""
    shifted = values - values.min()
    peak = shifted.max()
    return shifted / peak if peak > 0 else shifted


x_val_sal_nmp = _min_max_normalise(x_val_sal_nmp)
x_sal_nmp = _min_max_normalise(x_sal_nmp)
if m_index_2 != m_index_1:
    x_2_sal_nmp = _min_max_normalise(x_2_sal_nmp)
    x_2_val_sal_nmp = _min_max_normalise(x_2_val_sal_nmp)

# Plot inputs, ground truth and predictions on a fixed [0, 1] display range.
# Row 1: test frame; row 2: val frame.
plt.figure(figsize=(24,16))

# (subplot slot, image, colour map, title)
panels = [
    (1, original, None, 'Original'),
    (2, y, 'gray', 'Ground Truth'),
    (3, x_sal_nmp, 'gray', model_names[m_index_1]),
    (5, original_val, None, 'Original Val'),
    (6, y_val, 'gray', 'Ground Truth Val'),
    (7, x_val_sal_nmp, 'gray', model_names[m_index_1]),
]
if m_index_2 != m_index_1:
    panels.append((4, x_2_sal_nmp, 'gray', model_names[m_index_2]))
    panels.append((8, x_2_val_sal_nmp, 'gray', model_names[m_index_2]))

for slot, image, cmap, title in sorted(panels, key=lambda panel: panel[0]):
    plt.subplot(3, 4, slot)
    plt.imshow(image, cmap=cmap, vmin=0, vmax=1)
    plt.title(title)
    
# Score the predictions with several loss functions: the first model on both frames,
# the second model (if different) on the val frame only.
print(model_names[m_index_1])
x_val_sal_torch = torch.from_numpy(x_val_sal_nmp).unsqueeze(0)
x_sal_torch = torch.from_numpy(x_sal_nmp).unsqueeze(0)
y_val_torch = torch.from_numpy(y_val).unsqueeze(0)
y_torch = torch.from_numpy(y).unsqueeze(0)


def _print_scores(model_index, labelled_losses, prediction, target):
    """For each (label template, loss fn) pair, print the label and the loss value."""
    for label, loss_fn in labelled_losses:
        print(label.format(model_index), loss_fn(prediction, target).item())


_print_scores(
    m_index_1,
    (
        ("[{}] NSS_loss Test:", loss_functions.NSS_loss),
        ("[{}] CE_MAE   Test:", loss_functions.CE_MAE_loss),
        ("[{}] CE       Test:", loss_functions.CE_loss),
        ("[{}] MAE      Test:", loss_functions.MAE_loss),
        ("[{}] DoM      Test:", loss_functions.DoM),
    ),
    x_sal_torch,
    y_torch,
)

print()

# NSS_alt can be added to this list as ("[{}] NSS_alt Validation:", loss_functions.NSS_alt)
_print_scores(
    m_index_1,
    (
        ("[{}] NSS_loss Validation:", loss_functions.NSS_loss),
        ("[{}] CE_MAE   Validation:", loss_functions.CE_MAE_loss),
        ("[{}] CE       Validation:", loss_functions.CE_loss),
        ("[{}] MAE      Validation:", loss_functions.MAE_loss),
        ("[{}] DoM      Validation:", loss_functions.DoM),
    ),
    x_val_sal_torch,
    y_val_torch,
)

if m_index_2 != m_index_1:
    print(model_names[m_index_2])
    x_2_val_sal_torch = torch.from_numpy(x_2_val_sal_nmp).unsqueeze(0)
    _print_scores(
        m_index_2,
        (
            ("[{}] CE_MAE Validation:", loss_functions.CE_MAE_loss),
            ("[{}] CE Validation:", loss_functions.CE_loss),
            ("[{}] MAE Validation:", loss_functions.MAE_loss),
            ("[{}] NSS_loss Validation:", loss_functions.NSS_loss),
            ("[{}] DoM Validation:", loss_functions.DoM),
            ("[{}] NSS_alt Validation:", loss_functions.NSS_alt),
        ),
        x_2_val_sal_torch,
        y_val_torch,
    )
    
##### Heat maps of GT (red) and prediction (blue) over the validation image #####


def _heat_map_overlay(image, gt, pred, alpha=0.5):
    """Blend an RGB image with a ground-truth map (red) and a prediction map (blue).

    Returns alpha * image + (1 - alpha) * [gt, 0, pred_scaled] per pixel, where
    pred_scaled is `pred` divided by its maximum (left unscaled when the maximum is 0).
    Works on copies, so the caller's arrays are not modified. Vectorised with numpy
    instead of looping over every pixel in Python.
    Assumes image is (H, W, 3) and gt / pred are (H, W) - TODO confirm for new datasets.
    """
    image = np.asarray(image, dtype=np.float64)
    gt = np.asarray(gt, dtype=np.float64)
    pred = np.asarray(pred, dtype=np.float64)
    peak = pred.max()
    if peak != 0:
        pred = pred / peak
    colour = np.stack([gt, np.zeros_like(gt), pred], axis=-1)
    return alpha * image + (1 - alpha) * colour


# First model
plt.subplot(3, 4, 11)
plt.imshow(
    _heat_map_overlay(original_val.cpu().data.numpy(), y_val, x_val_sal_nmp),
    vmin=0,
    vmax=1,
)
plt.title('Heat map 1st model: gt=red, pred=blue')

if m_index_1 != m_index_2:
    ##### Heat map of GT and prediction on Validation set image (second model) #####
    plt.subplot(3, 4, 12)
    plt.imshow(
        _heat_map_overlay(original_val.cpu().data.numpy(), y_val, x_2_val_sal_nmp),
        vmin=0,
        vmax=1,
    )
    plt.title('Heat map 2nd model: gt=red, pred=blue')

# To save the figure as well, e.g.:
# plt.savefig('ResExamples/example_test_'+str(test_video_id)+'_val_'+str(val_video_id)+'.png')
plt.show()

# Leave each model used above (visited once, even if both indices match) with a clean
# temporal state for the next cell
for used_index in dict.fromkeys((m_index_1, m_index_2)):
    used_model = models[used_index]
    if used_model.temporal:
        used_model.clear_temporal_state()
        used_model.detach_temporal_state()


In [None]:
# Save the first model's predictions for the two example frames, scaled by 255.
# NOTE(review): the file names are hard-coded; update them to match the model loaded above.
example_outputs = (
    ("/tmp/pbqk24_tmp/Boat_Example (wakeboard8, frame48)/CoSADUV_1bp_1k_video.jpg", x_sal_nmp),
    ("/tmp/pbqk24_tmp/Person_Example (person7, frame95)/CoSADUV_1bp_1k_video.jpg", x_val_sal_nmp),
)
for out_path, saliency in example_outputs:
    # cv2.imwrite reports failure (e.g. a missing directory) by returning False
    if not cv2.imwrite(out_path, saliency * 255):
        raise OSError("Could not write {}".format(out_path))


In [None]:
# Save the first model's val-frame prediction (scaled by 255).
# Alternative: the test input, cv2.cvtColor(original.data.numpy()*255, cv2.COLOR_BGR2RGB)
out_path = "/tmp/pbqk24_tmp/image0.jpg"
# cv2.imwrite reports failure (e.g. a missing directory) by returning False
if not cv2.imwrite(out_path, x_val_sal_nmp*255):
    raise OSError("Could not write {}".format(out_path))

In [None]:
# Define a function for testing a model.
# Outputs are resized to the size of the labels before scoring.
def test_model(model, data_loader, loss_fns=None):
    """Evaluate `model` on every video in `data_loader` with several loss functions.

    Args:
        model: network to evaluate. It must expose `temporal`, and temporal models must
            also provide `clear_temporal_state()` and `detach_temporal_state()`.
        data_loader: iterable of per-video loaders yielding (inputs, labels) batches.
            Labels are assumed to be (batch, height, width) - TODO confirm.
        loss_fns: loss functions called as loss_fn(outputs, labels). Defaults to
            [loss_functions.MAE_loss].

    Returns:
        (loss_sums, loss_counts): lists aligned with `loss_fns`. Each plain entry is a
        running sum of per-batch losses and the number of batches. For
        loss_functions.NSS_alt the entry is a pair [positive, negative] accumulated per
        image. Images whose label sums to 0 ("negative" images) are scored with the
        prediction's standard deviation instead of NSS_alt.
    """
    if loss_fns is None:
        loss_fns = [loss_functions.MAE_loss]
    splits_by_sign = [loss_fn is loss_functions.NSS_alt for loss_fn in loss_fns]
    loss_sums = [[0, 0] if split else 0 for split in splits_by_sign]
    loss_counts = [[0, 0] if split else 0 for split in splits_by_sign]

    use_cuda = torch.cuda.is_available()
    # Inference only: no autograd graph is needed
    with torch.no_grad():
        for video_loader in tqdm(data_loader):
            # Each video starts from a fresh temporal state
            if model.temporal:
                model.clear_temporal_state()
            for inputs, labels in tqdm(video_loader):
                if use_cuda:
                    inputs = inputs.cuda()
                    labels = labels.cuda()

                # Produce the output
                outputs = model(inputs).squeeze(1)
                # If the model is temporal, detach its state
                if model.temporal:
                    model.detach_temporal_state()
                # Move the output to the CPU so it can be resized with OpenCV
                outputs = outputs.cpu().numpy()

                # A single image loses its batch dimension in squeeze(); restore it
                if outputs.ndim == 2:
                    outputs = np.expand_dims(outputs, 0)

                # Resize to the label size; cv2.resize takes (width, height)
                outputs = np.array(
                    [
                        cv2.resize(output, (labels.shape[2], labels.shape[1]))
                        for output in outputs
                    ]
                )
                outputs = torch.from_numpy(outputs)
                if use_cuda:
                    outputs = outputs.cuda()

                # Accumulate each loss. Separate index names keep the loss-function index
                # from being overwritten by the per-sample index.
                for fn_index, loss_fn in enumerate(loss_fns):
                    if splits_by_sign[fn_index]:
                        for sample in range(len(labels)):
                            if labels[sample].sum() == 0:
                                loss_sums[fn_index][1] += outputs[sample].std().item()
                                loss_counts[fn_index][1] += 1
                            else:
                                loss_sums[fn_index][0] += loss_fn(
                                    outputs[sample], labels[sample]
                                ).item()
                                loss_counts[fn_index][0] += 1
                    else:
                        loss_sums[fn_index] += loss_fn(outputs, labels).item()
                        loss_counts[fn_index] += 1

    return loss_sums, loss_counts

In [None]:
# Metrics reported for every model (test_model splits NSS_alt into +ve / -ve images)
loss_fns = [
    loss_functions.NSS_alt,
    loss_functions.CE_MAE_loss,
    loss_functions.CE_loss,
    loss_functions.MAE_loss,
    loss_functions.DoM,
]


In [None]:
# Obtaining loss values for every model in model_names.
# NOTE(review): this cell was labelled "test set", but val_loader is what it evaluates.
# Set eval_loader = test_loader if test-set numbers are wanted.
eval_loader = val_loader


def _mean_or_nan(total, count):
    """Average of an accumulated loss, or NaN when nothing contributed to it."""
    return total / count if count else float("nan")


for model_index, model_name in enumerate(tqdm(model_names)):
    tqdm.write("model name: {}".format(model_name))
    if "best_model" in model_name:
        model = load_model_from_checkpoint(model_name)
    else:
        model = load_model(model_name)

    test_losses, test_counts = test_model(model, eval_loader, loss_fns=loss_fns)

    # Print out the result (the format string needs a slot for the model name)
    tqdm.write("[{}] Model: {}".format(model_index, model_name))

    for fn_index, func in enumerate(loss_fns):
        if func is loss_functions.NSS_alt:
            # Either split can be empty (e.g. no all-zero targets), hence _mean_or_nan
            tqdm.write(
                "{:25} : {:6f}".format(
                    "NSS_alt (+ve imgs)",
                    _mean_or_nan(test_losses[fn_index][0], test_counts[fn_index][0]),
                )
            )
            tqdm.write(
                "{:25} : {:6f}".format(
                    "NSS_alt (-ve imgs)",
                    _mean_or_nan(test_losses[fn_index][1], test_counts[fn_index][1]),
                )
            )
        else:
            tqdm.write(
                "{:25} : {:6f}".format(
                    func.__name__,
                    _mean_or_nan(test_losses[fn_index], test_counts[fn_index]),
                )
            )
    del model
