In [None]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torchvision
from torch.autograd import Variable
import torch.nn as nn
import pickle
from random import randint
import sys
from tqdm import tqdm_notebook as tqdm
import cv2

In [None]:
from util.data_utils import get_SALICON_datasets
from util.data_utils import get_direct_datasets

# Alternative: pre-transformed 128x96 data used by the reimplementation
# train_data, val_data, test_data, mean_image = get_SALICON_datasets('Dataset/Transformed')

# Raw dataset, loaded directly at the chosen resolution
dataset_root_dir = 'Dataset/Raw_Dataset'
mean_image_name = 'mean_image.npy'
img_size = (480, 640)  # (height, width); original paper: 480x640, reimplementation: 96x128

(train_data, val_data, test_data, mean_image) = get_direct_datasets(
    dataset_root_dir,
    mean_image_name,
    img_size,
)

from util.loss_functions import NSS_loss, NSS_loss_2, PCCLoss_torch

# Parallel registries filled by the cells below: models[i] is named model_names[i]
models, model_names = [], []

In [None]:
# Set to True to train a new model in this session; when False the later cells
# only load previously saved models.
train = False
from models.DSCLRCN_PyTorch import DSCLRCN
from util.solver import Solver

batchsize = 20 # Recommended: 20. Determines how many images are processed before backpropagation is done
minibatchsize = 2 # Recommended: 2 for 480x640. Determines how many images are processed in parallel on the GPU at once
epoch_number = 10 # Recommended: 10 (epoch_number =~ batchsize/2)
optim_str = 'SGD' # 'SGD' or 'Adam' Recommended: Adam
optim_args = {'lr': 1e-2} # 1e-2 if SGD, 1e-4 if Adam
loss_func = NSS_loss # NSS_loss or torch.nn.KLDivLoss() Recommended: NSS_loss

# A batch is processed as `num_minibatches` minibatches, so it must split evenly.
# Fail loudly instead of printing and continuing with a fractional minibatch count.
if batchsize % minibatchsize:
    raise ValueError("batchsize % minibatchsize must equal 0 ({} % {} != 0).".format(batchsize, minibatchsize))
num_minibatches = batchsize // minibatchsize  # integer count (true division would give a float)
# Scale the lr down as smaller minibatches are used.
# NOTE(review): presumably the Solver accumulates gradients over num_minibatches steps - confirm.
optim_args['lr'] /= num_minibatches

optim = torch.optim.SGD if optim_str == 'SGD' else torch.optim.Adam

train_loader = torch.utils.data.DataLoader(train_data, batch_size=minibatchsize, shuffle=True, num_workers=8, pin_memory=True)

val_loader = torch.utils.data.DataLoader(val_data, batch_size=minibatchsize, shuffle=True, num_workers=8, pin_memory=True)

if train:
    # Attempt to train a model using the original image sizes
    model = DSCLRCN(input_dim=img_size, local_feats_net='Seg')
    # Set solver as torch.optim.SGD and lr as 1e-2, or torch.optim.Adam and lr 1e-4
    solver = Solver(optim=optim, optim_args=optim_args, loss_func=loss_func, location='jupyter')
    solver.train(model, train_loader, val_loader, num_epochs=epoch_number, num_minibatches=num_minibatches, log_nth=50,
        filename_args={'batchsize' : batchsize, 'epoch_number' : epoch_number, 'optim' : optim_str}
    )

    models.append(model)
    model_names.append('model_{}_lr2_batch{}_epoch{}'.format(optim_str, batchsize, epoch_number))

In [None]:
# Persist the model trained in this session together with its solver
# (the solver object is pickled as-is). Skipped when `train` is False.
if train:
    model_path = 'trained_models/model_{}_lr2_batch{}_epoch{}'.format(optim_str, batchsize, epoch_number)
    solver_path = 'trained_models/solver_{}_lr2_batch{}_epoch{}.pkl'.format(optim_str, batchsize, epoch_number)
    model.save(model_path)
    with open(solver_path, 'wb') as solver_file:
        pickle.dump(solver, solver_file, protocol=pickle.HIGHEST_PROTOCOL)

In [None]:
# Loss histories recorded by the solver, plotted on two stacked axes
# (only available when a model was trained in this session).
if train:
    fig, (ax_train, ax_val) = plt.subplots(2, 1)
    ax_train.plot(solver.train_loss_history, 'o')
    ax_train.set_title('Train Loss')
    ax_val.plot(solver.val_loss_history, '-o')
    ax_val.set_title('Val Loss')
    plt.show()

In [None]:
# Loading a model from the saved state that produced the lowest validation loss during training:

from models.DSCLRCN_PyTorch import DSCLRCN # Requires the model class be loaded

# Assumes the model uses models.DSCLRCN_PyTorch2 architecture. If not, this method will fail
def load_model_from_checkpoint(model_name):
    filename = "trained_models/" + model_name + ".pth"
    if torch.cuda.is_available():
        checkpoint = torch.load(filename)
    else:
        # Load GPU model on CPU
        checkpoint = torch.load(filename, map_location='cpu')
    start_epoch = checkpoint['epoch']
    best_accuracy = checkpoint['best_accuracy']
    
    model = DSCLRCN(input_dim=img_size, local_feats_net='Seg')
    model.load_state_dict(checkpoint['state_dict'], strict=False) # Ignore extra parameters ('.num_batches_tracked' that are added on NCC due to different pytorch version)

            
    print("=> loaded model checkpoint '{}' (trained for {} epochs)".format(model_name, checkpoint['epoch']))

    if torch.cuda.is_available():
        model = model.cuda()
    return model

def load_model(model_name):
    """Load a fully pickled model object from ``trained_models/<model_name>``.

    The file is deserialized onto the CPU first (so GPU-saved models also load on
    CPU-only machines) and the result is moved to the GPU if one is available.
    torch.load unpickles arbitrary objects - only load trusted files.
    """
    path = "trained_models/" + model_name
    loaded = torch.load(path, map_location='cpu')
    print("=> loaded model_1 '{}'".format(model_name))
    return loaded.cuda() if torch.cuda.is_available() else loaded

In [None]:
# Loading some pretrained models to test them on the images.
# If a model was trained in this session, keep only that one (the first entry)
# and append the saved models after it; otherwise start from empty lists.
model_names = [model_names[0]] if train else []
models = [models[0]] if train else []

# Best checkpoint of the SGD / batch 20 / 10 epoch run
model_names.append('best_model_SGD_lr2_batch20_epoch10')

# model_0: the best model so far
# model_names.append('model_SGD_lr2_batch20_epoch10')

# model_1: best contender for model_0
# model_names.append('best_model_SGD_lr2_batch20_epoch10')

# model_2 and on: other models

# Column width used when printing the results table later on
max_name_len = max(len(name) for name in model_names)

# Load every name that is not already backed by an in-memory model. Names
# containing "best_model" go through the checkpoint loader; all others are
# loaded as fully pickled models.
names_to_load = model_names[1:] if train else model_names
for name in names_to_load:
    loader = load_model_from_checkpoint if "best_model" in name else load_model
    models.append(loader(name))

print()
print("Loaded all specified models")

In [None]:
# Testing the different models on a random image from the test set and one from the val set.
# Every prediction gets the same Gaussian-blur post-processing (as test_model does),
# so the NSS printed here is computed the same way as in the full evaluation.

# Choose which two models to compare (indices into `models` / `model_names`)
m_index_1 = 0
m_index_2 = 0
compare_two = m_index_2 != m_index_1


def _predict_saliency(net, image):
    """Return the blurred 2-D numpy saliency map `net` predicts for one CHW image tensor.

    Runs in eval mode without autograd (pure inference) and restores the model's
    previous train/eval mode afterwards. The blur uses
    sigma = 0.035 * min(map height, map width) and an odd kernel of about 4*sigma.
    """
    batch = image.contiguous().view(1, *image.size())  # add a batch dimension
    if torch.cuda.is_available():
        batch = batch.cuda()
    was_training = net.training
    net.eval()
    try:
        with torch.no_grad():
            sal_map = net(batch).detach().cpu().squeeze().numpy()
    finally:
        net.train(was_training)
    sigma = 0.035 * min(sal_map.shape)
    kernel_size = int(4 * sigma)
    kernel_size += 1 if kernel_size % 2 == 0 else 0  # cv2.GaussianBlur needs an odd kernel size
    return cv2.GaussianBlur(sal_map, (kernel_size, kernel_size), sigma)


def _to_display_image(image):
    """Convert a mean-subtracted CHW image tensor back to an HWC numpy image for imshow."""
    return image.permute(1, 2, 0).cpu().numpy() + mean_image


# Pick a random test image and validation image
test_image_id = randint(0, len(test_data) - 1)
val_image_id = randint(0, len(val_data) - 1)
print("Test image: {}, Val image: {}".format(test_image_id, val_image_id))

# Load the images; labels become numpy arrays for plotting
x, y = test_data[test_image_id]
x_val, y_val = val_data[val_image_id]
y = y.numpy()
y_val = y_val.numpy()

# Get the original (before pre-processing) images to be displayed
original = _to_display_image(x)
original_val = _to_display_image(x_val)

# Predictions of the first (and optionally the second) model
x_sal_nmp = _predict_saliency(models[m_index_1], x)
x_val_sal_nmp = _predict_saliency(models[m_index_1], x_val)
if compare_two:
    x_2_sal_nmp = _predict_saliency(models[m_index_2], x)
    x_2_val_sal_nmp = _predict_saliency(models[m_index_2], x_val)

# Plot the output
plt.figure(figsize=(24, 16))

##### Testing set image #####
plt.subplot(3, 4, 1); plt.title('Original')
plt.imshow(original, vmin=0, vmax=1)
plt.subplot(3, 4, 2); plt.title('Ground Truth')
plt.imshow(y, cmap='gray', vmin=0, vmax=1)
plt.subplot(3, 4, 3)
plt.imshow(x_sal_nmp, cmap='gray'); plt.title(model_names[m_index_1])
if compare_two:
    plt.subplot(3, 4, 4)
    plt.imshow(x_2_sal_nmp, cmap='gray'); plt.title(model_names[m_index_2])

##### Validation set image #####
plt.subplot(3, 4, 5); plt.title('Original Val')
plt.imshow(original_val, vmin=0, vmax=1)
plt.subplot(3, 4, 6); plt.title('Ground Truth Val')
plt.imshow(y_val, cmap='gray', vmin=0, vmax=1)
plt.subplot(3, 4, 7)
plt.imshow(x_val_sal_nmp, cmap='gray'); plt.title(model_names[m_index_1])
if compare_two:
    plt.subplot(3, 4, 8)
    plt.imshow(x_2_val_sal_nmp, cmap='gray'); plt.title(model_names[m_index_2])

# NSS loss on the validation image for each compared model
val_maps = [(m_index_1, x_val_sal_nmp)]
if compare_two:
    val_maps.append((m_index_2, x_2_val_sal_nmp))
for idx, val_map in val_maps:
    print(model_names[idx])
    print("[{}] Validation:".format(idx), NSS_loss(val_map, y_val))

##### Heat map of GT and prediction on Validation set image (first model) #####
# Dimmed RGB image (weight alpha) + ground truth in the red channel + prediction
# in the blue channel, built with vectorized numpy ops instead of per-pixel loops.
alpha = 0.2
pred_max = np.max(x_val_sal_nmp)
# Normalize a copy of the prediction to a maximum of 1 (x_val_sal_nmp itself stays
# untouched); guard against a non-positive maximum to avoid division by zero.
pred = x_val_sal_nmp / pred_max if pred_max > 0 else np.zeros_like(x_val_sal_nmp)
heat_map = alpha * original_val
heat_map[..., 0] += (1 - alpha) * y_val
heat_map[..., 2] += (1 - alpha) * pred
heat_map = np.clip(heat_map, 0, 1)  # imshow would clip out-of-range RGB floats anyway (with a warning)

plt.subplot(3, 4, 10)
plt.imshow(heat_map, vmin=0, vmax=1); plt.title('Heat map: gt=red, pred=blue')

# plt.savefig('ResExamples/example_test_'+str(test_image_id)+'_val_'+str(val_image_id)+'.png')
plt.show()


In [None]:
# Define a function for testing a model
# Output is resized to the size of the data_source
def test_model(model, data_source, loss_fn=NSS_loss):
    test_loader = torch.utils.data.DataLoader(data_source, batch_size=minibatchsize, shuffle=True, num_workers=8)
    loss_sum = 0
    
    for data in tqdm(test_loader):
        inputs, labels = data
        if torch.cuda.is_available():
            inputs = inputs.cuda()
            labels = labels.cuda()

        # Produce the output
        outputs = model(inputs).squeeze(1)
        # Move the output to the CPU so we can process it using numpy
        outputs = outputs.cpu().data.numpy()
        
        # If outputs contains a single image, make it shape (1, x, y) instead of (x, y)
        if len(outputs.shape) == 2:
            outputs = np.expand_dims(outputs, 0)

        # Resize the images to input size
        outputs = np.array([cv2.resize(output, (labels.shape[2], labels.shape[1])) for output in outputs])

        # Apply a Gaussian filter to blur the saliency maps
        sigma = 0.035*min(labels.shape[1], labels.shape[2])
        kernel_size = int(4*sigma)
        # make sure the kernel size is odd
        kernel_size += 1 if kernel_size % 2 == 0 else 0
        
        outputs = np.array([cv2.GaussianBlur(output, (kernel_size, kernel_size), sigma) for output in outputs])
        
        outputs = torch.from_numpy(outputs)
        
        if torch.cuda.is_available():
            outputs = outputs.cuda()
            labels  = labels.cuda()
        
        loss_sum += loss_fn(outputs, labels).item()
        
    return loss_sum

In [None]:
# Obtaining NSS values on the validation set for every loaded model.
# Validation data is used because the test data has no ground truths
# (this was also done in the original DSCLRCN paper).
loss_fn = NSS_loss_2

# To evaluate at another resolution, build a second validation dataset
# (get_direct_datasets returns four values):
# img_size_2 = (480, 640) # height, width - original: 480, 640, reimplementation: 96, 128
# _, val_data_2, _, _ = get_direct_datasets(dataset_root_dir, mean_image_name, img_size_2)

# The comprehension variable is local, so the trained `model` used by the save
# cell is not overwritten by this loop.
test_losses = [test_model(eval_model, val_data, loss_fn=loss_fn) for eval_model in tqdm(models)]

# Print out the result
print('Normalized Scanpath Saliency on Validation set:')

# NOTE(review): the sign flip assumes loss_fn returns negated NSS, and dividing by
# len(val_data) assumes test_model's result is a sum of per-image values - confirm both.
for i, (name, loss) in enumerate(zip(model_names, test_losses)):
    print(('[{}] {:' + str(max_name_len) + '} : {:6f}').format(i, name, -loss / len(val_data)))
