In [None]:
%load_ext autoreload
%autoreload 2

# Saliency Prediction

The task is to build a model that predicts human fixation points in a visual scene.
(This should not be confused with saliency detection, which extracts the regions of a scene that are relevant to a given task.)

We use the CAT2000 dataset, which, according to its authors, was assembled to support better saliency prediction models. It extended the datasets available at the time with higher-resolution images, a wider range of categories, and more data overall. The collection consists of 4000 images, split 2000/2000 into training and test sets and divided into 20 equally sized categories such as Action, Social, Line drawings, and Objects. The images are accompanied by fixation maps, which show which parts of each image the observers looked at. The dataset is described in https://arxiv.org/pdf/1505.03581.pdf

![alt text](cat2000.png "Cat2000 Sample")

As the model, we implemented the network presented in the paper https://arxiv.org/pdf/1609.01064.pdf
It is a fully convolutional neural network that combines features from several layers to predict saliency maps. The model also uses a prior, which is a way of encoding regularities in visual perception. In this model the prior is learned entirely from data and is combined, at the last stage, with the features extracted from the image.

![alt text](mlnet.png "MLNet Scheme")
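
The learned prior described above can be illustrated with a small NumPy sketch. This is not the code of `MLNet`: it only assumes that a coarse prior of shape (6, 8) is upsampled by a factor of 10 and multiplied element-wise into a (60, 80) saliency map. The function name `apply_learned_prior` is hypothetical, and nearest-neighbour upsampling is swapped in here for brevity; the original model may use a smoother interpolation.

```python
import numpy as np


def apply_learned_prior(saliency, prior, factor=10):
    """Upsample a coarse prior and multiply it into a saliency map.

    Hypothetical helper for illustration; not the notebook's MLNet code.
    """
    # Nearest-neighbour upsampling: each prior cell becomes a factor x factor block.
    upsampled = np.kron(prior, np.ones((factor, factor)))
    if upsampled.shape != saliency.shape:
        raise ValueError("upsampled prior must match the saliency map shape")
    return saliency * upsampled


rng = np.random.default_rng(0)
saliency = rng.random((60, 80))  # stand-in for the raw network output
prior = np.ones((6, 8))          # during training this would be a learned parameter
prior[0, :] = 0.5                # example: down-weight the top band of the image
out = apply_learned_prior(saliency, prior)
print(out.shape)
```

Because the prior is much coarser than the map, it can only encode broad spatial tendencies (such as a bias towards the image centre), not image-specific detail.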

In our experiment the input size is (480, 640) and the output size is (60, 80), which allowed us to set batch_size = 16 on a GeForce GTX 1060 Ti 6GB graphics card. We trained for 100 epochs. We drew one image from each category and excluded it from training and validation, so that after each epoch we could check what the predicted saliency map looks like for that image. The course of training and the final comparison can be viewed below.
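
The per-category hold-out described above could be done along these lines. This is a minimal sketch, not the logic inside `Cat2000Loader`: the function `split_holdout`, the seed, and the file names are all hypothetical.

```python
import random


def split_holdout(paths_by_category, seed=0):
    """Pick one image per category to hold out; return (held_out, remaining)."""
    rng = random.Random(seed)  # fixed seed so the same images are held out every run
    held_out, remaining = {}, {}
    for category, paths in sorted(paths_by_category.items()):
        pick = rng.choice(paths)
        held_out[category] = pick
        remaining[category] = [p for p in paths if p != pick]
    return held_out, remaining


# Hypothetical file names, for illustration only.
demo = {
    "Action": ["a1.jpg", "a2.jpg", "a3.jpg"],
    "Social": ["s1.jpg", "s2.jpg", "s3.jpg"],
}
held_out, remaining = split_holdout(demo)
print(held_out)
```

Fixing the seed keeps the held-out images identical across runs, so the saliency maps saved after each epoch remain comparable.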

In [None]:
from data import Cat2000Loader
from model import MLNet
from loss import ModMSELoss
from trainer import Trainer

In [None]:
# Create loaders
import torch
from torchvision.transforms import Compose, Resize, ToTensor

# ratio = 1080/1920
# width = 320
# image_size = (int(width*ratio) , width)
image_size = (480, 640)
fix_map_size = (image_size[0] // 8, image_size[1] // 8)
prior_size = (fix_map_size[0] // 10, fix_map_size[1] // 10)

transform1 = Compose([Resize(image_size), ToTensor()])
transform2 = Compose([Resize(fix_map_size), ToTensor()])

loaders = Cat2000Loader('cat2000', batch_size=16, transform=transform1, target_transform=transform2)

In [None]:
# Prepare model
model = MLNet(prior_size)

criterion = ModMSELoss(*fix_map_size)

optimizer = torch.optim.SGD(model.parameters(), lr=1e-3, weight_decay=0.0005, momentum=0.9, nesterov=True)

In [None]:
# Optionally resume from a previously saved checkpoint
path = 'models/2021-09-16 05:52:44.712130_.basic_model'
model.load_state_dict(torch.load(path))

In [None]:
# Train
trainer = Trainer(model, criterion, optimizer, loaders)

trainer.run_trainer(100)

The graph shows that the loss on the validation set has a downward trend, so further training (with better computing resources) could improve the final result. The loss function is ported directly from the Keras implementation published by the paper's authors at https://github.com/marcellacornia/mlnet.
The plot starts around the fifth epoch because the model initially returned an all-zero tensor for some examples, which caused problems when computing the loss.
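
To see why an all-zero prediction is a problem, consider a NumPy sketch of the kind of loss involved. The notebook's `ModMSELoss` is not shown here, so the form below is an assumption based on the loss described in the paper: the prediction is divided by its maximum, and the squared error is weighted by 1 / (alpha - target) with alpha = 1.1. The `eps` guard is a hypothetical addition, not something the notebook is known to do.

```python
import numpy as np


def mod_mse(pred, target, alpha=1.1, eps=0.0):
    """Normalised, ground-truth-weighted MSE for a single saliency map.

    Assumed form of the loss; `eps` is a hypothetical stabiliser.
    """
    normalised = pred / (pred.max() + eps)  # 0 / 0 when pred is all zeros and eps == 0
    return np.mean((normalised - target) ** 2 / (alpha - target))


target = np.zeros((60, 80))
target[30, 40] = 1.0              # a single fixation point
zero_pred = np.zeros((60, 80))    # the degenerate early-training output

with np.errstate(invalid="ignore"):
    unstable = mod_mse(zero_pred, target)       # evaluates to NaN
stable = mod_mse(zero_pred, target, eps=1e-8)   # finite value

print(unstable, stable)
```

With `eps=0` the normalisation step computes 0 / 0, so the loss is undefined for those examples, which is consistent with the early epochs being left out of the plot.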

In [None]:
import datetime

full_path = 'models/' + str(datetime.datetime.now()) + '_' + '.basic_model'
torch.save(model.state_dict(), full_path)

### Display

We noticed that early in training the model highlights distinct image features, much as in the saliency detection task. Over time the output becomes blurrier as it moves towards the fixation points, although it sometimes still focuses entirely on picking out prominent features in the image.

In [None]:
import os
import glob
import ipyplot
import numpy as np  # needed for the epoch labels in the plotting cell
from PIL import Image

In [None]:
root = lambda x: f'cat2000/selected/{x}/'
for category in os.listdir(root('images')):
    original_image = Image.open(glob.glob(os.path.join(root('images'), category, '*.jpg'))[0])
    fix_map = Image.open(glob.glob(os.path.join(root('maps'), category, '*.jpg'))[0])
    print(f'CATEGORY: {category}')
    ipyplot.plot_images([original_image, fix_map], ['original image', 'map fixation'], img_width=500)
    pred = [Image.open(path) for path in glob.glob(os.path.join(root('maps'), category, 'outputs', '*.jpeg'))]
    ipyplot.plot_images([img for i, img in enumerate(pred) if i%10 == 9], 10*np.arange(10)+9, img_width=300)

In [None]:
from utils import display_evolution

In [None]:
display_evolution()