In [1]:
from collections import OrderedDict
import os
from pathlib import Path
import shutil

from imageio.v3 import imread, imwrite
from PIL import Image
import pysaliency
from pysaliency.baseline_utils import BaselineModel, CrossvalidatedBaselineModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

from tqdm import tqdm


from deepgaze_pytorch.layers import (
    Conv2dMultiInput,
    LayerNorm,
    LayerNormMultiInput,
    Bias,
    FlexibleScanpathHistoryEncoding
)

from deepgaze_pytorch.modules import DeepGazeIII, FeatureExtractor
from deepgaze_pytorch.features.densenet import RGBDenseNet201
from deepgaze_pytorch.data import ImageDataset, ImageDatasetSampler, FixationDataset, FixationMaskTransform
from deepgaze_pytorch.training import _train


In [2]:
def build_saliency_network(input_channels):
    return nn.Sequential(OrderedDict([
        ('layernorm0', LayerNorm(input_channels)),
        ('conv0', nn.Conv2d(input_channels, 8, (1, 1), bias=False)),
        ('bias0', Bias(8)),
        ('softplus0', nn.Softplus()),

        ('layernorm1', LayerNorm(8)),
        ('conv1', nn.Conv2d(8, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),

        ('layernorm2', LayerNorm(16)),
        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
        ('bias2', Bias(1)),
        ('softplus2', nn.Softplus()),
    ]))


def build_scanpath_network():
    return nn.Sequential(OrderedDict([
        ('encoding0', FlexibleScanpathHistoryEncoding(in_fixations=4, channels_per_fixation=3, out_channels=128, kernel_size=[1, 1], bias=True)),
        ('softplus0', nn.Softplus()),

        ('layernorm1', LayerNorm(128)),
        ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),
    ]))


def build_fixation_selection_network(scanpath_features=16):
    return nn.Sequential(OrderedDict([
        ('layernorm0', LayerNormMultiInput([1, scanpath_features])),
        ('conv0', Conv2dMultiInput([1, scanpath_features], 128, (1, 1), bias=False)),
        ('bias0', Bias(128)),
        ('softplus0', nn.Softplus()),

        ('layernorm1', LayerNorm(128)),
        ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),

        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
    ]))

In [3]:
def prepare_spatial_dataset(stimuli, fixations, centerbias, batch_size, path=None):
    if path is not None:
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)
    else:
        lmdb_path = None

    dataset = ImageDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=ImageDatasetSampler(dataset, batch_size=batch_size),
        pin_memory=False,
        num_workers=0,
    )

    return loader

In [4]:
def prepare_scanpath_dataset(stimuli, fixations, centerbias, batch_size, path=None):
    if path is not None:
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)
    else:
        lmdb_path = None

    dataset = FixationDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        included_fixations=[-1, -2, -3, -4],
        allow_missing_fixations=True,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )

    loader = torch.utils.data.DataLoader(
        dataset,
        batch_sampler=ImageDatasetSampler(dataset, batch_size=batch_size),
        pin_memory=False,
        num_workers=0,
    )

    return loader

In [5]:
dataset_directory = Path('pysaliency_datasets')
train_directory = Path('train_deepgaze3')

In [6]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if device.type == 'cuda':
    print('Using GPU')

Using GPU


# Pretraining on SALICON

In [7]:
import os
import pickle
# Assume dataset_directory is defined appropriately, e.g.:
# dataset_directory = '/path/to/your/datasets'
# Ensure the directory exists or create it if it doesn't
# os.makedirs(dataset_directory, exist_ok=True) # Might be needed if dataset_directory is just for caching

# --- Load Data ---
print(f"Loading SALICON data from {dataset_directory}...")
SALICON_train_stimuli, SALICON_train_fixations = pysaliency.get_SALICON_train(location=dataset_directory)
SALICON_val_stimuli, SALICON_val_fixations = pysaliency.get_SALICON_val(location=dataset_directory)
print("SALICON data loaded.")

# --- Define Model ---
# parameters taken from an early fit for MIT1003. Since SALICON has many more fixations, the bandwidth won't be too small
print("Initializing BaselineModel...")
SALICON_centerbias = BaselineModel(stimuli=SALICON_train_stimuli, fixations=SALICON_train_fixations, bandwidth=0.0217, eps=2e-13, caching=False)
print("BaselineModel initialized.")

# --- Define cache file paths ---
train_ll_cache_file = os.path.join(dataset_directory, 'salicon_baseline_train_ll.pkl')
val_ll_cache_file = os.path.join(dataset_directory, 'salicon_baseline_val_ll.pkl')

# --- Compute or Load Train Baseline Log Likelihood ---
try:
    # Attempt to load from cache
    with open(train_ll_cache_file, 'rb') as f:
        train_baseline_log_likelihood = pickle.load(f)
    print(f"Loaded cached train baseline log likelihood from: {train_ll_cache_file}")
except (FileNotFoundError, EOFError, pickle.UnpicklingError) as e:
    # Compute if cache doesn't exist or is invalid
    print(f"Cache not found or invalid ({e}). Computing train baseline log likelihood...")
    train_baseline_log_likelihood = SALICON_centerbias.information_gain(
        SALICON_train_stimuli,
        SALICON_train_fixations,
        verbose=True,
        average='image'
    )
    print(f"Computation finished. Train LL = {train_baseline_log_likelihood}")
    # Save the result
    try:
        os.makedirs(os.path.dirname(train_ll_cache_file), exist_ok=True) # Ensure directory exists
        with open(train_ll_cache_file, 'wb') as f:
            pickle.dump(train_baseline_log_likelihood, f)
        print(f"Saved train baseline log likelihood to: {train_ll_cache_file}")
    except Exception as save_e:
        print(f"Error saving cache file {train_ll_cache_file}: {save_e}")


# --- Compute or Load Validation Baseline Log Likelihood ---
try:
    # Attempt to load from cache
    with open(val_ll_cache_file, 'rb') as f:
        val_baseline_log_likelihood = pickle.load(f)
    print(f"Loaded cached validation baseline log likelihood from: {val_ll_cache_file}")
except (FileNotFoundError, EOFError, pickle.UnpicklingError) as e:
    # Compute if cache doesn't exist or is invalid
    print(f"Cache not found or invalid ({e}). Computing validation baseline log likelihood...")
    val_baseline_log_likelihood = SALICON_centerbias.information_gain(
        SALICON_val_stimuli,
        SALICON_val_fixations,
        verbose=True,
        average='image'
    )
    print(f"Computation finished. Validation LL = {val_baseline_log_likelihood}")
    # Save the result
    try:
        os.makedirs(os.path.dirname(val_ll_cache_file), exist_ok=True) # Ensure directory exists
        with open(val_ll_cache_file, 'wb') as f:
            pickle.dump(val_baseline_log_likelihood, f)
        print(f"Saved validation baseline log likelihood to: {val_ll_cache_file}")
    except Exception as save_e:
        print(f"Error saving cache file {val_ll_cache_file}: {save_e}")


# --- Final Output ---
print("-" * 30)
print(f"Final Train Baseline Log Likelihood: {train_baseline_log_likelihood}")
print(f"Final Validation Baseline Log Likelihood: {val_baseline_log_likelihood}")
print("-" * 30)

Loading SALICON data from pysaliency_datasets...
SALICON data loaded.
Initializing BaselineModel...
BaselineModel initialized.
Loaded cached train baseline log likelihood from: pysaliency_datasets/salicon_baseline_train_ll.pkl
Loaded cached validation baseline log likelihood from: pysaliency_datasets/salicon_baseline_val_ll.pkl
------------------------------
Final Train Baseline Log Likelihood: 0.46408017115279737
Final Validation Baseline Log Likelihood: 0.4291592320821603
------------------------------


In [8]:
model = DeepGazeIII(
    features=FeatureExtractor(RGBDenseNet201(), [
            '1.features.denseblock4.denselayer32.norm1',
            '1.features.denseblock4.denselayer32.conv1',
            '1.features.denseblock4.denselayer31.conv2',
        ]),
    saliency_network=build_saliency_network(2048),
    scanpath_network=None,
    fixation_selection_network=build_fixation_selection_network(scanpath_features=0),
    downsample=1.5,
    readout_factor=4,
    saliency_map_factor=4,
    included_fixations=[],
)

model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[15, 30, 45, 60, 75, 90, 105, 120])

Using cache found in /home/mmorello/.cache/torch/hub/pytorch_vision_v0.6.0


In [9]:
train_loader = prepare_spatial_dataset(SALICON_train_stimuli, SALICON_train_fixations, SALICON_centerbias, batch_size=32, path=train_directory / 'lmdb_cache' / 'SALICON_train')
validation_loader = prepare_spatial_dataset(SALICON_val_stimuli, SALICON_val_fixations, SALICON_centerbias, batch_size=32, path=train_directory / 'lmdb_cache' / 'SALICON_val')

Valid LMDB found at train_deepgaze3/lmdb_cache/SALICON_train with 10000 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 68992354/68992355 [00:45<00:00, 1527942.02it/s]


Valid LMDB found at train_deepgaze3/lmdb_cache/SALICON_val with 5000 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 38846997/38846998 [00:24<00:00, 1590809.29it/s]


In [None]:
_train(train_directory / 'pretraining',
    model,
    train_loader, train_baseline_log_likelihood,
    validation_loader, val_baseline_log_likelihood,
    optimizer, lr_scheduler,
    minimum_learning_rate=1e-7,
    device=device,
) 


Using device cuda
validation metrics ['IG', 'LL', 'AUC', 'NSS']
Found old checkpoint train_deepgaze3/pretraining/step-0013.pth
Restoring from train_deepgaze3/pretraining/step-0013.pth
Setting step to 13
Continuing from step 13
    epoch                   timestamp  learning_rate      loss  validation_IG  \
0       0  2025-04-11 02:34:02.322526          0.001       NaN       0.006263   
1       1  2025-04-11 02:40:35.007449          0.001 -0.890663       0.342832   
2       2  2025-04-11 02:47:00.624549          0.001 -0.925385       0.347606   
3       3  2025-04-11 02:53:25.953809          0.001 -0.935251       0.347895   
4       4  2025-04-11 02:59:53.322699          0.001 -0.940744       0.341106   
5       5  2025-04-11 03:06:21.230717          0.001 -0.945974       0.357141   
6       6  2025-04-11 03:12:48.991831          0.001 -0.949205       0.347955   
7       7  2025-04-11 03:19:15.283275          0.001 -0.951454       0.356708   
8       8  2025-04-11 03:25:43.769879       

  mask = torch.sparse.IntTensor(torch.tensor(inds), torch.tensor(values), shape)
-0.96172: 100%|██████████| 313/313 [02:29<00:00,  2.10it/s]
LL 0.79163: 100%|██████████| 157/157 [04:06<00:00,  1.57s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
13     13  2025-04-11 03:57:58.876248          0.001 -0.960608       0.365961   
14     14  2025-04-11 17:50:48.727639          0.001 -0.961724       0.362471   

    validation_LL  validation_AUC  validation_NSS  
13        0.79512        0.769521        1.506053  
14        0.79163        0.769253        1.488613  
validation_IG     13
validation_LL     13
validation_AUC    13
validation_NSS    10
dtype: int64


-0.96185: 100%|██████████| 313/313 [02:27<00:00,  2.12it/s]
LL 0.78885: 100%|██████████| 157/157 [04:05<00:00,  1.56s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
14     14  2025-04-11 17:50:48.727639          0.001 -0.961724       0.362471   
15     15  2025-04-11 17:57:22.490430          0.001 -0.961850       0.359691   

    validation_LL  validation_AUC  validation_NSS  
14        0.79163        0.769253        1.488613  
15        0.78885        0.769353        1.512500  
validation_IG     13
validation_LL     13
validation_AUC    13
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0014.pth


-0.96732: 100%|██████████| 313/313 [02:24<00:00,  2.17it/s]
LL 0.79201: 100%|██████████| 157/157 [04:06<00:00,  1.57s/it]


    epoch                   timestamp  learning_rate     loss  validation_IG  \
15     15  2025-04-11 17:57:22.490430         0.0010 -0.96185       0.359691   
16     16  2025-04-11 18:03:54.095652         0.0001 -0.96732       0.362849   

    validation_LL  validation_AUC  validation_NSS  
15       0.788850        0.769353        1.512500  
16       0.792008        0.769541        1.521637  
validation_IG     13
validation_LL     13
validation_AUC    16
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0015.pth


-0.96834: 100%|██████████| 313/313 [02:26<00:00,  2.14it/s]
LL 0.79198: 100%|██████████| 157/157 [04:07<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
16     16  2025-04-11 18:03:54.095652         0.0001 -0.967320       0.362849   
17     17  2025-04-11 18:10:29.177884         0.0001 -0.968336       0.362816   

    validation_LL  validation_AUC  validation_NSS  
16       0.792008        0.769541        1.521637  
17       0.791976        0.769509        1.527234  
validation_IG     13
validation_LL     13
validation_AUC    16
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0016.pth


-0.96859: 100%|██████████| 313/313 [02:25<00:00,  2.15it/s]
LL 0.79225: 100%|██████████| 157/157 [04:08<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
17     17  2025-04-11 18:10:29.177884         0.0001 -0.968336       0.362816   
18     18  2025-04-11 18:17:03.771528         0.0001 -0.968585       0.363092   

    validation_LL  validation_AUC  validation_NSS  
17       0.791976        0.769509        1.527234  
18       0.792251        0.769639        1.531810  
validation_IG     13
validation_LL     13
validation_AUC    18
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0017.pth


-0.96871: 100%|██████████| 313/313 [02:25<00:00,  2.15it/s]
LL 0.79030: 100%|██████████| 157/157 [04:07<00:00,  1.57s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
18     18  2025-04-11 18:17:03.771528         0.0001 -0.968585       0.363092   
19     19  2025-04-11 18:23:37.024975         0.0001 -0.968709       0.361139   

    validation_LL  validation_AUC  validation_NSS  
18       0.792251        0.769639        1.531810  
19       0.790298        0.769576        1.530436  
validation_IG     13
validation_LL     13
validation_AUC    18
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0018.pth


-0.96889: 100%|██████████| 313/313 [02:27<00:00,  2.13it/s]
LL 0.79198: 100%|██████████| 157/157 [04:06<00:00,  1.57s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
19     19  2025-04-11 18:23:37.024975         0.0001 -0.968709       0.361139   
20     20  2025-04-11 18:30:11.705705         0.0001 -0.968887       0.362823   

    validation_LL  validation_AUC  validation_NSS  
19       0.790298        0.769576        1.530436  
20       0.791983        0.769503        1.523016  
validation_IG     13
validation_LL     13
validation_AUC    18
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0019.pth


-0.96894: 100%|██████████| 313/313 [02:26<00:00,  2.14it/s]
LL 0.79097: 100%|██████████| 157/157 [04:08<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
20     20  2025-04-11 18:30:11.705705         0.0001 -0.968887       0.362823   
21     21  2025-04-11 18:36:47.504161         0.0001 -0.968940       0.361812   

    validation_LL  validation_AUC  validation_NSS  
20       0.791983        0.769503        1.523016  
21       0.790971        0.769561        1.525275  
validation_IG     13
validation_LL     13
validation_AUC    18
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0020.pth


-0.96905: 100%|██████████| 313/313 [02:25<00:00,  2.15it/s]
LL 0.79158: 100%|██████████| 157/157 [04:06<00:00,  1.57s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
21     21  2025-04-11 18:36:47.504161         0.0001 -0.968940       0.361812   
22     22  2025-04-11 18:43:20.463037         0.0001 -0.969054       0.362425   

    validation_LL  validation_AUC  validation_NSS  
21       0.790971        0.769561        1.525275  
22       0.791584        0.769655        1.530648  
validation_IG     13
validation_LL     13
validation_AUC    22
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0021.pth


-0.96921: 100%|██████████| 313/313 [02:26<00:00,  2.13it/s]
LL 0.79170: 100%|██████████| 157/157 [04:08<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
22     22  2025-04-11 18:43:20.463037         0.0001 -0.969054       0.362425   
23     23  2025-04-11 18:49:56.019169         0.0001 -0.969210       0.362544   

    validation_LL  validation_AUC  validation_NSS  
22       0.791584        0.769655        1.530648  
23       0.791703        0.769597        1.527686  
validation_IG     13
validation_LL     13
validation_AUC    22
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0022.pth


-0.96934: 100%|██████████| 313/313 [02:27<00:00,  2.12it/s]
LL 0.79215: 100%|██████████| 157/157 [04:10<00:00,  1.59s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
23     23  2025-04-11 18:49:56.019169         0.0001 -0.969210       0.362544   
24     24  2025-04-11 18:56:34.344944         0.0001 -0.969336       0.362989   

    validation_LL  validation_AUC  validation_NSS  
23       0.791703        0.769597        1.527686  
24       0.792148        0.769613        1.529305  
validation_IG     13
validation_LL     13
validation_AUC    22
validation_NSS    10
dtype: int64
removing train_deepgaze3/pretraining/step-0023.pth


-0.96939: 100%|██████████| 313/313 [02:27<00:00,  2.13it/s]
LL 0.79052: 100%|██████████| 157/157 [04:08<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
24     24  2025-04-11 18:56:34.344944         0.0001 -0.969336       0.362989   
25     25  2025-04-11 19:03:10.693880         0.0001 -0.969386       0.361363   

    validation_LL  validation_AUC  validation_NSS  
24       0.792148        0.769613        1.529305  
25       0.790522        0.769560        1.533791  
validation_IG     13
validation_LL     13
validation_AUC    22
validation_NSS    25
dtype: int64
removing train_deepgaze3/pretraining/step-0024.pth


-0.96947: 100%|██████████| 313/313 [02:27<00:00,  2.12it/s]
LL 0.79044: 100%|██████████| 157/157 [04:07<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
25     25  2025-04-11 19:03:10.693880         0.0001 -0.969386       0.361363   
26     26  2025-04-11 19:09:46.329325         0.0001 -0.969465       0.361281   

    validation_LL  validation_AUC  validation_NSS  
25       0.790522        0.769560        1.533791  
26       0.790440        0.769632        1.524122  
validation_IG     13
validation_LL     13
validation_AUC    22
validation_NSS    25
dtype: int64
removing train_deepgaze3/pretraining/step-0025.pth


-0.96960: 100%|██████████| 313/313 [02:29<00:00,  2.10it/s]
LL 0.79330: 100%|██████████| 157/157 [04:07<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
26     26  2025-04-11 19:09:46.329325         0.0001 -0.969465       0.361281   
27     27  2025-04-11 19:16:23.491679         0.0001 -0.969600       0.364141   

    validation_LL  validation_AUC  validation_NSS  
26        0.79044        0.769632        1.524122  
27        0.79330        0.769687        1.517840  
validation_IG     13
validation_LL     13
validation_AUC    27
validation_NSS    25
dtype: int64
removing train_deepgaze3/pretraining/step-0026.pth


-0.96968: 100%|██████████| 313/313 [02:27<00:00,  2.12it/s]
LL 0.79249: 100%|██████████| 157/157 [04:07<00:00,  1.58s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
27     27  2025-04-11 19:16:23.491679         0.0001 -0.969600       0.364141   
28     28  2025-04-11 19:22:59.416409         0.0001 -0.969684       0.363335   

    validation_LL  validation_AUC  validation_NSS  
27       0.793300        0.769687        1.517840  
28       0.792494        0.769711        1.521255  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    25
dtype: int64
removing train_deepgaze3/pretraining/step-0027.pth


-0.96975: 100%|██████████| 313/313 [02:24<00:00,  2.16it/s]
LL 0.78964: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
28     28  2025-04-11 19:22:59.416409         0.0001 -0.969684       0.363335   
29     29  2025-04-11 19:29:28.593544         0.0001 -0.969753       0.360484   

    validation_LL  validation_AUC  validation_NSS  
28       0.792494        0.769711        1.521255  
29       0.789643        0.769538        1.534606  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0028.pth


-0.96984: 100%|██████████| 313/313 [02:26<00:00,  2.14it/s]
LL 0.79269: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
29     29  2025-04-11 19:29:28.593544         0.0001 -0.969753       0.360484   
30     30  2025-04-11 19:35:59.781376         0.0001 -0.969838       0.363535   

    validation_LL  validation_AUC  validation_NSS  
29       0.789643        0.769538        1.534606  
30       0.792694        0.769566        1.528520  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0029.pth


-0.97063: 100%|██████████| 313/313 [02:24<00:00,  2.16it/s]
LL 0.79176: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
30     30  2025-04-11 19:35:59.781376        0.00010 -0.969838       0.363535   
31     31  2025-04-11 19:42:28.936428        0.00001 -0.970632       0.362601   

    validation_LL  validation_AUC  validation_NSS  
30       0.792694        0.769566        1.528520  
31       0.791760        0.769585        1.526607  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0030.pth


-0.97074: 100%|██████████| 313/313 [02:24<00:00,  2.17it/s]
LL 0.79193: 100%|██████████| 157/157 [04:02<00:00,  1.54s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
31     31  2025-04-11 19:42:28.936428        0.00001 -0.970632       0.362601   
32     32  2025-04-11 19:48:55.896010        0.00001 -0.970739       0.362766   

    validation_LL  validation_AUC  validation_NSS  
31       0.791760        0.769585        1.526607  
32       0.791925        0.769618        1.526016  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0031.pth


-0.97076: 100%|██████████| 313/313 [02:25<00:00,  2.16it/s]
LL 0.79314: 100%|██████████| 157/157 [04:02<00:00,  1.54s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
32     32  2025-04-11 19:48:55.896010        0.00001 -0.970739       0.362766   
33     33  2025-04-11 19:55:24.394128        0.00001 -0.970763       0.363981   

    validation_LL  validation_AUC  validation_NSS  
32       0.791925        0.769618        1.526016  
33       0.793140        0.769641        1.520833  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0032.pth


-0.97077: 100%|██████████| 313/313 [02:24<00:00,  2.17it/s]
LL 0.79242: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
33     33  2025-04-11 19:55:24.394128        0.00001 -0.970763       0.363981   
34     34  2025-04-11 20:01:53.080133        0.00001 -0.970773       0.363260   

    validation_LL  validation_AUC  validation_NSS  
33        0.79314        0.769641        1.520833  
34        0.79242        0.769619        1.522817  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0033.pth


-0.97079: 100%|██████████| 313/313 [02:24<00:00,  2.16it/s]
LL 0.79187: 100%|██████████| 157/157 [04:04<00:00,  1.56s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
34     34  2025-04-11 20:01:53.080133        0.00001 -0.970773       0.363260   
35     35  2025-04-11 20:08:23.324112        0.00001 -0.970792       0.362715   

    validation_LL  validation_AUC  validation_NSS  
34       0.792420        0.769619        1.522817  
35       0.791874        0.769643        1.529418  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0034.pth


-0.97080: 100%|██████████| 313/313 [02:25<00:00,  2.14it/s]
LL 0.79192: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
35     35  2025-04-11 20:08:23.324112        0.00001 -0.970792       0.362715   
36     36  2025-04-11 20:14:53.453881        0.00001 -0.970803       0.362763   

    validation_LL  validation_AUC  validation_NSS  
35       0.791874        0.769643        1.529418  
36       0.791922        0.769612        1.527990  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0035.pth


-0.97081: 100%|██████████| 313/313 [02:24<00:00,  2.17it/s]
LL 0.79220: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
36     36  2025-04-11 20:14:53.453881        0.00001 -0.970803       0.362763   
37     37  2025-04-11 20:21:22.199116        0.00001 -0.970814       0.363039   

    validation_LL  validation_AUC  validation_NSS  
36       0.791922        0.769612        1.527990  
37       0.792198        0.769657        1.527284  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0036.pth


-0.97082: 100%|██████████| 313/313 [02:25<00:00,  2.15it/s]
LL 0.79203: 100%|██████████| 157/157 [04:02<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
37     37  2025-04-11 20:21:22.199116        0.00001 -0.970814       0.363039   
38     38  2025-04-11 20:27:51.283426        0.00001 -0.970820       0.362870   

    validation_LL  validation_AUC  validation_NSS  
37       0.792198        0.769657        1.527284  
38       0.792030        0.769637        1.527977  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0037.pth


-0.97083: 100%|██████████| 313/313 [02:24<00:00,  2.16it/s]
LL 0.79161: 100%|██████████| 157/157 [04:03<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
38     38  2025-04-11 20:27:51.283426        0.00001 -0.970820       0.362870   
39     39  2025-04-11 20:34:20.504102        0.00001 -0.970828       0.362451   

    validation_LL  validation_AUC  validation_NSS  
38        0.79203        0.769637        1.527977  
39        0.79161        0.769616        1.529185  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0038.pth


-0.97084: 100%|██████████| 313/313 [02:25<00:00,  2.16it/s]
LL 0.79227: 100%|██████████| 157/157 [04:04<00:00,  1.56s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
39     39  2025-04-11 20:34:20.504102        0.00001 -0.970828       0.362451   
40     40  2025-04-11 20:40:50.807925        0.00001 -0.970839       0.363109   

    validation_LL  validation_AUC  validation_NSS  
39       0.791610        0.769616        1.529185  
40       0.792268        0.769642        1.527596  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0039.pth


-0.97085: 100%|██████████| 313/313 [02:25<00:00,  2.15it/s]
LL 0.79235: 100%|██████████| 157/157 [04:02<00:00,  1.55s/it]


    epoch                   timestamp  learning_rate      loss  validation_IG  \
40     40  2025-04-11 20:40:50.807925        0.00001 -0.970839       0.363109   
41     41  2025-04-11 20:47:20.315317        0.00001 -0.970850       0.363191   

    validation_LL  validation_AUC  validation_NSS  
40       0.792268        0.769642        1.527596  
41       0.792350        0.769648        1.524828  
validation_IG     13
validation_LL     13
validation_AUC    28
validation_NSS    29
dtype: int64
removing train_deepgaze3/pretraining/step-0040.pth


-0.97086: 100%|██████████| 313/313 [02:24<00:00,  2.16it/s]
LL 0.78712:  14%|█▍        | 22/157 [00:33<03:32,  1.57s/it]

# Preparing the MIT1003 dataset

In [None]:
mit_stimuli_orig, mit_scanpaths_orig = pysaliency.external_datasets.mit.get_mit1003_with_initial_fixation(location=dataset_directory, replace_initial_invalid_fixations=True)

In [None]:
def convert_stimulus(input_image):
    size = input_image.shape[0], input_image.shape[1]
    if size[0] < size[1]:
        new_size = 768, 1024
    else:
        new_size = 1024,768
    
    # pillow uses width, height
    new_size = tuple(list(new_size)[::-1])
    
    new_stimulus = np.array(Image.fromarray(input_image).resize(new_size, Image.BILINEAR))
    return new_stimulus

def convert_fixations(stimuli, fixations):
    new_fixations = fixations.copy()
    for n in tqdm(list(range(len(stimuli)))):
        stimulus = stimuli.stimuli[n]
        size = stimulus.shape[0], stimulus.shape[1]
        if size[0] < size[1]:
            new_size = 768, 1024
        else:
            new_size = 1024,768
        x_factor = new_size[1] / size[1]
        y_factor = new_size[0] / size[0]
        
        inds = new_fixations.n == n
        new_fixations.x[inds] *= x_factor
        new_fixations.y[inds] *= y_factor
        new_fixations.x_hist[inds] *= x_factor
        new_fixations.y_hist[inds] *= y_factor
    
    return new_fixations

def convert_fixation_trains(stimuli, fixations):
    train_xs = fixations.train_xs.copy()
    train_ys = fixations.train_ys.copy()
    
    for i in tqdm(range(len(train_xs))):
        n = fixations.train_ns[i]
        
        size = stimuli.shapes[n][0], stimuli.shapes[n][1]
        
        if size[0] < size[1]:
            new_size = 768, 1024
        else:
            new_size = 1024,768
        
        x_factor = new_size[1] / size[1]
        y_factor = new_size[0] / size[0]
        
        train_xs[i] *= x_factor
        train_ys[i] *= y_factor
        
    new_fixations = pysaliency.FixationTrains(
        train_xs = train_xs,
        train_ys = train_ys,
        train_ts = fixations.train_ts.copy(),
        train_ns = fixations.train_ns.copy(),
        train_subjects = fixations.train_subjects.copy(),
        attributes={key: getattr(fixations, key).copy() for key in fixations.__attributes__ if key not in ['subjects', 'scanpath_index']},
    )
    return new_fixations



def convert_stimuli(stimuli, new_location: Path):
    assert isinstance(stimuli, pysaliency.FileStimuli)
    new_stimuli_location = new_location / 'stimuli'
    new_stimuli_location.mkdir(parents=True, exist_ok=True)
    new_filenames = []
    for filename in tqdm(stimuli.filenames):
        stimulus = imread(filename)
        new_stimulus = convert_stimulus(stimulus)
        
        basename = os.path.basename(filename)
        new_filename = new_stimuli_location / basename
        if new_stimulus.size != stimulus.size:
            imwrite(new_filename, new_stimulus)
        else:
            #print("Keeping")
            shutil.copy(filename, new_filename)
        new_filenames.append(new_filename)
    return pysaliency.FileStimuli(new_filenames)

mit_scanpaths_twosize = convert_fixation_trains(mit_stimuli_orig, mit_scanpaths_orig)
mit_stimuli_twosize = convert_stimuli(mit_stimuli_orig, train_directory / 'MIT1003_twosize')

  0%|          | 0/15045 [00:00<?, ?it/s]

100%|██████████| 15045/15045 [00:00<00:00, 172014.85it/s]
  new_fixations = pysaliency.FixationTrains(
100%|██████████| 1003/1003 [00:17<00:00, 57.52it/s]


In [None]:
# remove the initial forced fixation from the training data, it's only used for conditioning
mit_fixations_twosize = mit_scanpaths_twosize[mit_scanpaths_twosize.lengths > 0]

In [None]:
# parameters optimized on MIT1003 for maximum leave-one-image-out crossvalidation log-likelihood
MIT1003_centerbias = CrossvalidatedBaselineModel(
    mit_stimuli_twosize,
    mit_fixations_twosize,
    bandwidth=10**-1.6667673342543432,
    eps=10**-14.884189168516073,
    caching=False,
)

In [None]:
for crossval_fold in range(10):
    MIT1003_stimuli_train, MIT1003_fixations_train = pysaliency.dataset_config.train_split(mit_stimuli_twosize, mit_fixations_twosize, crossval_folds=10, fold_no=crossval_fold)
    MIT1003_stimuli_val, MIT1003_fixations_val = pysaliency.dataset_config.validation_split(mit_stimuli_twosize, mit_fixations_twosize, crossval_folds=10, fold_no=crossval_fold)

    train_baseline_log_likelihood = MIT1003_centerbias.information_gain(MIT1003_stimuli_train, MIT1003_fixations_train, verbose=True, average='image')
    val_baseline_log_likelihood = MIT1003_centerbias.information_gain(MIT1003_stimuli_val, MIT1003_fixations_val, verbose=True, average='image')

    # finetune spatial model on MIT1003

    model = DeepGazeIII(
        features=FeatureExtractor(RGBDenseNet201(), [
                '1.features.denseblock4.denselayer32.norm1',
                '1.features.denseblock4.denselayer32.conv1',
                '1.features.denseblock4.denselayer31.conv2',
            ]),
        saliency_network=build_saliency_network(2048),
        scanpath_network=None,
        fixation_selection_network=build_fixation_selection_network(scanpath_features=0),
        downsample=2,
        readout_factor=4,
        saliency_map_factor=4,
        included_fixations=[],
    )

    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    train_loader = prepare_spatial_dataset(MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias, batch_size=4, path=train_directory / 'lmdb_cache' / f'MIT1003_train_spatial_{crossval_fold}')
    validation_loader = prepare_spatial_dataset(MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias, batch_size=4, path=train_directory / 'lmdb_cache' / f'MIT1003_val_spatial_{crossval_fold}')

    _train(train_directory / 'MIT1003_spatial' / f'crossval-10-{crossval_fold}',
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'pretraining' / 'final.pth',
    )


    # Train scanpath model

    train_loader = prepare_scanpath_dataset(MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias, batch_size=4, path=train_directory / 'lmdb_cache' / f'MIT1003_train_scanpath_{crossval_fold}')
    validation_loader = prepare_scanpath_dataset(MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias, batch_size=4, path=train_directory / 'lmdb_cache' / f'MIT1003_val_scanpath_{crossval_fold}')

    # first train with partially frozen saliency network


    model = DeepGazeIII(
        features=FeatureExtractor(RGBDenseNet201(), [
                '1.features.denseblock4.denselayer32.norm1',
                '1.features.denseblock4.denselayer32.conv1',
                '1.features.denseblock4.denselayer31.conv2',
            ]),
        saliency_network=build_saliency_network(2048),
        scanpath_network=build_scanpath_network(),
        fixation_selection_network=build_fixation_selection_network(scanpath_features=16),
        downsample=2,
        readout_factor=4,
        saliency_map_factor=4,
        included_fixations=[-1, -2, -3, -4],
    )
    model = model.to(device)

    frozen_scopes = [
        "saliency_network.layernorm0",
        "saliency_network.conv0",
        "saliency_network.bias0",
        "saliency_network.layernorm1",
        "saliency_network.conv1",
        "saliency_network.bias1",
    ]

    for scope in frozen_scopes:
        for parameter_name, parameter in model.named_parameters():
            if parameter_name.startswith(scope):
                print("Fixating parameter", parameter_name)
                parameter.requires_grad = False


    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 31, 32, 33, 34, 35])

    _train(train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / f'crossval-10-{crossval_fold}',
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_spatial' /  f'crossval-10-{crossval_fold}' / 'final.pth'
    )

    # Now finetune full scanpath model

    model = DeepGazeIII(
        features=FeatureExtractor(RGBDenseNet201(), [
                '1.features.denseblock4.denselayer32.norm1',
                '1.features.denseblock4.denselayer32.conv1',
                '1.features.denseblock4.denselayer31.conv2',
            ]),
        saliency_network=build_saliency_network(2048),
        scanpath_network=build_scanpath_network(),
        fixation_selection_network=build_fixation_selection_network(scanpath_features=16),
        downsample=2,
        readout_factor=4,
        saliency_map_factor=4,
        included_fixations=[-1, -2, -3, -4],
    )

    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    _train(train_directory / 'MIT1003_scanpath' / f'crossval-10-{crossval_fold}',
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / f'crossval-10-{crossval_fold}' / 'final.pth'
    )


Using random shuffles for crossvalidation
Using random shuffles for crossvalidation


100%|██████████| 808/808 [02:27<00:00,  5.49it/s]
100%|██████████| 94/94 [00:16<00:00,  5.84it/s]
Using cache found in /home/mmorello/.cache/torch/hub/pytorch_vision_v0.6.0


Generate LMDB to train_deepgaze3/lmdb_cache/MIT1003_train_spatial_0


100%|██████████| 808/808 [03:04<00:00,  4.39it/s]


Flushing database ...
Populating fixations cache


100%|█████████▉| 83717/83718 [00:00<00:00, 819458.17it/s]


Generate LMDB to train_deepgaze3/lmdb_cache/MIT1003_val_spatial_0


100%|██████████| 94/94 [00:15<00:00,  6.05it/s]


Flushing database ...
Populating fixations cache


100%|█████████▉| 9925/9926 [00:00<00:00, 600093.23it/s]


Using device cuda
Restoring from train_deepgaze3/pretraining/final.pth


FileNotFoundError: [Errno 2] No such file or directory: 'train_deepgaze3/pretraining/final.pth'