In [1]:
from collections import OrderedDict
import os
from pathlib import Path
import pickle
import shutil

from imageio.v3 import imread, imwrite
from PIL import Image
import pysaliency
from pysaliency.baseline_utils import BaselineModel, CrossvalidatedBaselineModel
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils import model_zoo

from tqdm import tqdm


from deepgaze_pytorch.layers import (
    Conv2dMultiInput,
    LayerNorm,
    LayerNormMultiInput,
    Bias,
    FlexibleScanpathHistoryEncoding
)

from deepgaze_pytorch.modules import DeepGazeIII, FeatureExtractor
from deepgaze_pytorch.features.densenet import RGBDenseNet201
from deepgaze_pytorch.data import ImageDataset, ImageDatasetSampler, FixationDataset, FixationMaskTransform
from deepgaze_pytorch.training import _train


In [2]:
def build_saliency_network(input_channels):
    """Build the saliency readout head.

    Three identical stages of LayerNorm -> 1x1 Conv2d (no bias) -> Bias -> Softplus,
    narrowing the channel count from `input_channels` to 8, then 16, then 1.
    Sub-modules are named `layernorm{i}`, `conv{i}`, `bias{i}`, `softplus{i}` for
    stage i, so state_dict keys stay stable for checkpoint loading.
    """
    channel_widths = [input_channels, 8, 16, 1]
    named_layers = []
    for stage, (width_in, width_out) in enumerate(zip(channel_widths[:-1], channel_widths[1:])):
        named_layers += [
            (f'layernorm{stage}', LayerNorm(width_in)),
            (f'conv{stage}', nn.Conv2d(width_in, width_out, (1, 1), bias=False)),
            (f'bias{stage}', Bias(width_out)),
            (f'softplus{stage}', nn.Softplus()),
        ]
    return nn.Sequential(OrderedDict(named_layers))


def build_scanpath_network():
    """Build the scanpath-history encoder.

    A FlexibleScanpathHistoryEncoding over 4 fixations (3 channels each) produces
    128 channels, followed by Softplus and a LayerNorm -> 1x1 Conv2d -> Bias ->
    Softplus stage that reduces them to 16 output channels.
    """
    hidden_width = 128
    output_width = 16

    history_encoding = FlexibleScanpathHistoryEncoding(
        in_fixations=4,
        channels_per_fixation=3,
        out_channels=hidden_width,
        kernel_size=[1, 1],
        bias=True,
    )

    named_layers = [
        ('encoding0', history_encoding),
        ('softplus0', nn.Softplus()),
    ]
    named_layers.append(('layernorm1', LayerNorm(hidden_width)))
    named_layers.append(('conv1', nn.Conv2d(hidden_width, output_width, (1, 1), bias=False)))
    named_layers.append(('bias1', Bias(output_width)))
    named_layers.append(('softplus1', nn.Softplus()))

    return nn.Sequential(OrderedDict(named_layers))


def build_fixation_selection_network(scanpath_features=16):
    """Build the head that combines saliency and scanpath features into one output map.

    Input is two tensors: the 1-channel saliency output and `scanpath_features`
    scanpath channels (0 for a purely spatial model). They are normalised and
    fused by a multi-input 1x1 conv into 128 channels, reduced to 16, and finally
    to a single channel by a bias-free 1x1 conv with no trailing nonlinearity.
    """
    input_widths = [1, scanpath_features]

    fusion_stage = [
        ('layernorm0', LayerNormMultiInput(input_widths)),
        ('conv0', Conv2dMultiInput(input_widths, 128, (1, 1), bias=False)),
        ('bias0', Bias(128)),
        ('softplus0', nn.Softplus()),
    ]
    hidden_stage = [
        ('layernorm1', LayerNorm(128)),
        ('conv1', nn.Conv2d(128, 16, (1, 1), bias=False)),
        ('bias1', Bias(16)),
        ('softplus1', nn.Softplus()),
    ]
    output_stage = [
        ('conv2', nn.Conv2d(16, 1, (1, 1), bias=False)),
    ]

    return nn.Sequential(OrderedDict(fusion_stage + hidden_stage + output_stage))

In [3]:
def prepare_spatial_dataset(stimuli, fixations, centerbias, batch_size, path=None, num_workers=None):
    """Build a DataLoader over whole images for training the spatial (saliency-only) model.

    Parameters
    ----------
    stimuli, fixations
        pysaliency stimuli and fixations of the split to load.
    centerbias
        Model passed to `ImageDataset` as `centerbias_model`.
    batch_size : int
        Batch size handed to `ImageDatasetSampler`, which forms the batches.
    path : str or pathlib.Path, optional
        Directory for the dataset's LMDB cache; created if missing. No LMDB cache
        is used when None.
    num_workers : int, optional
        Number of DataLoader worker processes. Defaults to `os.cpu_count()`,
        or 0 (load in the main process) if the CPU count cannot be determined.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    lmdb_path = None
    if path is not None:
        path = Path(path)  # also accept plain strings
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)

    dataset = ImageDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )

    if num_workers is None:
        # os.cpu_count() returns None when the count is undeterminable.
        num_workers = os.cpu_count() or 0

    # persistent_workers / prefetch_factor are only valid with worker processes;
    # DataLoader raises ValueError if they are set while num_workers == 0.
    worker_options = {}
    if num_workers > 0:
        worker_options = {'persistent_workers': True, 'prefetch_factor': 2}

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=ImageDatasetSampler(dataset, batch_size=batch_size),
        pin_memory=True,
        num_workers=num_workers,
        **worker_options,
    )

In [4]:
def prepare_scanpath_dataset(stimuli, fixations, centerbias, batch_size, path=None, num_workers=None):
    """Build a DataLoader over fixations with scanpath history for training the scanpath model.

    The dataset is a `FixationDataset` with `included_fixations=[-1, -2, -3, -4]`
    and `allow_missing_fixations=True` (presumably: condition on up to the last
    four fixations and tolerate shorter histories — confirm in deepgaze_pytorch.data).

    Parameters
    ----------
    stimuli, fixations
        pysaliency stimuli and fixations of the split to load.
    centerbias
        Model passed to `FixationDataset` as `centerbias_model`.
    batch_size : int
        Batch size handed to `ImageDatasetSampler`, which forms the batches.
    path : str or pathlib.Path, optional
        Directory for the dataset's LMDB cache; created if missing. No LMDB cache
        is used when None.
    num_workers : int, optional
        Number of DataLoader worker processes. Defaults to `os.cpu_count()`,
        or 0 (load in the main process) if the CPU count cannot be determined.

    Returns
    -------
    torch.utils.data.DataLoader
    """
    lmdb_path = None
    if path is not None:
        path = Path(path)  # also accept plain strings
        path.mkdir(parents=True, exist_ok=True)
        lmdb_path = str(path)

    dataset = FixationDataset(
        stimuli=stimuli,
        fixations=fixations,
        centerbias_model=centerbias,
        included_fixations=[-1, -2, -3, -4],
        allow_missing_fixations=True,
        transform=FixationMaskTransform(sparse=False),
        average='image',
        lmdb_path=lmdb_path,
    )

    if num_workers is None:
        # os.cpu_count() returns None when the count is undeterminable.
        num_workers = os.cpu_count() or 0

    # persistent_workers / prefetch_factor are only valid with worker processes;
    # DataLoader raises ValueError if they are set while num_workers == 0.
    worker_options = {}
    if num_workers > 0:
        worker_options = {'persistent_workers': True, 'prefetch_factor': 2}

    return torch.utils.data.DataLoader(
        dataset,
        batch_sampler=ImageDatasetSampler(dataset, batch_size=batch_size),
        pin_memory=True,
        num_workers=num_workers,
        **worker_options,
    )

In [5]:
# `dataset_directory`: where pysaliency stores the downloaded datasets (and where the
# SALICON baseline caches go). `train_directory`: LMDB caches, converted MIT1003
# stimuli and all training checkpoints.
dataset_directory, train_directory = Path('pysaliency_datasets'), Path('train_deepgaze3')

In [6]:
# Train on the first CUDA device when one is visible, otherwise fall back to the CPU.
use_gpu = torch.cuda.is_available()
device = torch.device('cuda') if use_gpu else torch.device('cpu')
if use_gpu:
    print('Using GPU')

Using GPU


# Pretraining on SALICON

In [7]:
# Baseline (center-bias) information gains for the SALICON train/validation splits.
# Despite the `*_log_likelihood` variable names, the values are
# `BaselineModel.information_gain(..., average='image')`; they are what `_train`
# receives as the baseline reference. Computing them over SALICON is slow, so each
# value is cached as a pickle in `dataset_directory`.
#
# NOTE(review): the cache is keyed only by file name. If the stimuli, fixations or
# the BaselineModel parameters below change, delete the .pkl files by hand,
# otherwise the stale values are loaded silently.


def load_or_compute_baseline_information_gain(centerbias_model, stimuli, fixations, cache_file, split_name):
    """Return `centerbias_model.information_gain(stimuli, fixations, average='image')`, cached on disk.

    `cache_file` is unpickled if it exists and is readable. If it is missing, empty
    or not a valid pickle, the value is recomputed and written back. The write goes
    to a temporary file that is then renamed, so an interrupted write cannot leave
    a truncated cache behind. A failed write is reported but does not abort.
    `split_name` is only used in log messages.

    Only point this at cache files you created yourself: unpickling can run
    arbitrary code.
    """
    cache_file = Path(cache_file)
    try:
        with open(cache_file, 'rb') as f:
            value = pickle.load(f)
        print(f"Loaded cached {split_name} baseline log likelihood from: {cache_file}")
        return value
    except (FileNotFoundError, EOFError, pickle.UnpicklingError) as e:
        print(f"Cache not found or invalid ({e}). Computing {split_name} baseline log likelihood...")

    value = centerbias_model.information_gain(stimuli, fixations, verbose=True, average='image')
    print(f"Computation finished. {split_name.capitalize()} LL = {value}")

    try:
        cache_file.parent.mkdir(parents=True, exist_ok=True)
        tmp_file = cache_file.with_name(cache_file.name + '.tmp')
        with open(tmp_file, 'wb') as f:
            pickle.dump(value, f)
        tmp_file.replace(cache_file)  # atomic rename on the same filesystem
        print(f"Saved {split_name} baseline log likelihood to: {cache_file}")
    except (OSError, pickle.PicklingError) as save_e:
        print(f"Error saving cache file {cache_file}: {save_e}")
    return value


# --- Load Data ---
print(f"Loading SALICON data from {dataset_directory}...")
SALICON_train_stimuli, SALICON_train_fixations = pysaliency.get_SALICON_train(location=dataset_directory)
SALICON_val_stimuli, SALICON_val_fixations = pysaliency.get_SALICON_val(location=dataset_directory)
print("SALICON data loaded.")

# --- Define Model ---
# parameters taken from an early fit for MIT1003. Since SALICON has many more fixations, the bandwidth won't be too small
print("Initializing BaselineModel...")
SALICON_centerbias = BaselineModel(stimuli=SALICON_train_stimuli, fixations=SALICON_train_fixations, bandwidth=0.0217, eps=2e-13, caching=False)
print("BaselineModel initialized.")

# --- Compute or load baseline information gains ---
train_ll_cache_file = Path(dataset_directory) / 'salicon_baseline_train_ll.pkl'
val_ll_cache_file = Path(dataset_directory) / 'salicon_baseline_val_ll.pkl'

train_baseline_log_likelihood = load_or_compute_baseline_information_gain(
    SALICON_centerbias, SALICON_train_stimuli, SALICON_train_fixations, train_ll_cache_file, 'train'
)
val_baseline_log_likelihood = load_or_compute_baseline_information_gain(
    SALICON_centerbias, SALICON_val_stimuli, SALICON_val_fixations, val_ll_cache_file, 'validation'
)

# --- Final Output ---
print("-" * 30)
print(f"Final Train Baseline Log Likelihood: {train_baseline_log_likelihood}")
print(f"Final Validation Baseline Log Likelihood: {val_baseline_log_likelihood}")
print("-" * 30)

Loading SALICON data from pysaliency_datasets...
SALICON data loaded.
Initializing BaselineModel...
BaselineModel initialized.
Loaded cached train baseline log likelihood from: pysaliency_datasets/salicon_baseline_train_ll.pkl
Loaded cached validation baseline log likelihood from: pysaliency_datasets/salicon_baseline_val_ll.pkl
------------------------------
Final Train Baseline Log Likelihood: 0.46408017115279726
Final Validation Baseline Log Likelihood: 0.4291592320821601
------------------------------


In [8]:
# Spatial-only DeepGaze III for SALICON pretraining: no scanpath network and no
# fixation history, so the fixation-selection head sees 0 scanpath channels.
densenet_readout_layers = [
    '1.features.denseblock4.denselayer32.norm1',
    '1.features.denseblock4.denselayer32.conv1',
    '1.features.denseblock4.denselayer31.conv2',
]

model = DeepGazeIII(
    features=FeatureExtractor(RGBDenseNet201(), densenet_readout_layers),
    saliency_network=build_saliency_network(2048),
    scanpath_network=None,
    fixation_selection_network=build_fixation_selection_network(scanpath_features=0),
    downsample=1.5,
    readout_factor=4,
    saliency_map_factor=4,
    included_fixations=[],
).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Milestones 15, 30, ..., 120; MultiStepLR's default gamma=0.1 divides the lr by 10 at each.
lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=list(range(15, 121, 15)))

Using cache found in /home/mirko/.cache/torch/hub/pytorch_vision_v0.6.0


In [9]:
salicon_lmdb_root = train_directory / 'lmdb_cache'
train_loader = prepare_spatial_dataset(
    SALICON_train_stimuli, SALICON_train_fixations, SALICON_centerbias,
    batch_size=32, path=salicon_lmdb_root / 'SALICON_train',
)
validation_loader = prepare_spatial_dataset(
    SALICON_val_stimuli, SALICON_val_fixations, SALICON_centerbias,
    batch_size=32, path=salicon_lmdb_root / 'SALICON_val',
)

Valid LMDB found at train_deepgaze3/lmdb_cache/SALICON_train with 10000 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 68992354/68992355 [00:16<00:00, 4292110.64it/s]


Valid LMDB found at train_deepgaze3/lmdb_cache/SALICON_val with 5000 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 38846997/38846998 [00:08<00:00, 4375400.80it/s]


In [10]:
# Pretrain the spatial model on SALICON. Checkpoints go to this directory; the
# MIT1003 stage later starts from its `final.pth`.
pretraining_directory = train_directory / 'pretraining'
_train(
    pretraining_directory,
    model,
    train_loader, train_baseline_log_likelihood,
    validation_loader, val_baseline_log_likelihood,
    optimizer, lr_scheduler,
    minimum_learning_rate=1e-7,
    device=device,
    validation_metrics=['LL', 'IG', 'NSS'],
)


Training Already finished


# Preparing the MIT1003 dataset

In [11]:
# MIT1003 with each scanpath's initial (forced) fixation included; invalid initial
# fixations are replaced, as requested by `replace_initial_invalid_fixations=True`.
mit_stimuli_orig, mit_scanpaths_orig = pysaliency.external_datasets.mit.get_mit1003_with_initial_fixation(
    location=dataset_directory,
    replace_initial_invalid_fixations=True,
)

In [12]:
def convert_stimulus(input_image):
    """Resize an image array to one of the two fixed MIT1003 training resolutions.

    Images that are wider than tall become 768x1024 (height x width); all others,
    square ones included, become 1024x768. Resampling is bilinear via Pillow.
    Returns a new numpy array.
    """
    height, width = input_image.shape[0], input_image.shape[1]
    if height < width:
        target_height, target_width = 768, 1024
    else:
        target_height, target_width = 1024, 768

    # Pillow's resize expects (width, height).
    resized = Image.fromarray(input_image).resize((target_width, target_height), Image.BILINEAR)
    return np.array(resized)

def convert_fixations(stimuli, fixations):
    """Rescale fixation coordinates to match the resizing done by `convert_stimulus`.

    For every stimulus index n, the x/y coordinates and the x_hist/y_hist fixation
    histories of the fixations on that stimulus are multiplied by the same factors
    that `convert_stimulus` applies to the image (landscape -> 768x1024, otherwise
    1024x768). The scaling is applied to `fixations.copy()`, which is returned.
    """
    new_fixations = fixations.copy()
    for n in tqdm(range(len(stimuli))):
        # Read only the image size. The previous `stimuli.stimuli[n]` decoded every
        # full image (and, for cached stimuli, kept it in memory) just for its shape.
        shape = stimuli.shapes[n]
        height, width = shape[0], shape[1]
        if height < width:
            new_height, new_width = 768, 1024
        else:
            new_height, new_width = 1024, 768
        x_factor = new_width / width
        y_factor = new_height / height

        inds = new_fixations.n == n
        new_fixations.x[inds] *= x_factor
        new_fixations.y[inds] *= y_factor
        new_fixations.x_hist[inds] *= x_factor
        new_fixations.y_hist[inds] *= y_factor

    return new_fixations

def convert_fixation_trains(stimuli, fixations):
    """Rescale scanpaths to match the resizing done by `convert_stimulus`.

    Each scanpath row of `train_xs`/`train_ys` is multiplied by the x/y factors
    of its stimulus (`train_ns`): landscape images map to 768x1024, all others
    to 1024x768. Times, stimulus indices and subjects are copied unchanged.
    Every entry of `fixations.__attributes__` except 'subjects' and
    'scanpath_index' is copied into `attributes`.

    Returns a new `pysaliency.FixationTrains`.
    """
    scaled_xs = fixations.train_xs.copy()
    scaled_ys = fixations.train_ys.copy()

    for row in tqdm(range(len(scaled_xs))):
        stimulus_shape = stimuli.shapes[fixations.train_ns[row]]
        height, width = stimulus_shape[0], stimulus_shape[1]
        target_height, target_width = (768, 1024) if height < width else (1024, 768)

        scaled_xs[row] *= target_width / width
        scaled_ys[row] *= target_height / height

    skipped_attributes = ['subjects', 'scanpath_index']
    copied_attributes = {}
    for key in fixations.__attributes__:
        if key in skipped_attributes:
            continue
        copied_attributes[key] = getattr(fixations, key).copy()

    return pysaliency.FixationTrains(
        train_xs=scaled_xs,
        train_ys=scaled_ys,
        train_ts=fixations.train_ts.copy(),
        train_ns=fixations.train_ns.copy(),
        train_subjects=fixations.train_subjects.copy(),
        attributes=copied_attributes,
    )



def convert_stimuli(stimuli, new_location: Path):
    """Write resized copies of all stimuli to `new_location / 'stimuli'`.

    Each image is resized with `convert_stimulus`. Images whose shape changes are
    written with `imwrite`; images already at the target shape are copied
    byte-for-byte so they are not re-encoded. Output files keep their basename, so
    basenames must be unique across `stimuli` or later files overwrite earlier ones.

    Parameters
    ----------
    stimuli : pysaliency.FileStimuli
    new_location : pathlib.Path or str
        Output root; the `stimuli` subdirectory is created if missing.

    Returns
    -------
    pysaliency.FileStimuli over the new files, in the original order.
    """
    assert isinstance(stimuli, pysaliency.FileStimuli)
    new_stimuli_location = Path(new_location) / 'stimuli'
    new_stimuli_location.mkdir(parents=True, exist_ok=True)

    new_filenames = []
    for filename in tqdm(stimuli.filenames):
        stimulus = imread(filename)
        new_stimulus = convert_stimulus(stimulus)

        new_filename = new_stimuli_location / os.path.basename(filename)
        # Compare shapes, not element counts: a resize can change the dimensions
        # while keeping the pixel count (e.g. 512x1536 -> 768x1024), and such an
        # image must still be written at its new size rather than copied.
        if new_stimulus.shape != stimulus.shape:
            imwrite(new_filename, new_stimulus)
        else:
            shutil.copy(filename, new_filename)
        new_filenames.append(new_filename)

    return pysaliency.FileStimuli(new_filenames)

# Bring MIT1003 to the two fixed resolutions: rescale the scanpaths first, then
# write the resized images below `train_directory`.
twosize_directory = train_directory / 'MIT1003_twosize'
mit_scanpaths_twosize = convert_fixation_trains(mit_stimuli_orig, mit_scanpaths_orig)
mit_stimuli_twosize = convert_stimuli(mit_stimuli_orig, twosize_directory)

100%|██████████| 15045/15045 [00:00<00:00, 568810.82it/s]
  new_fixations = pysaliency.FixationTrains(
100%|██████████| 1003/1003 [00:04<00:00, 224.76it/s]


In [13]:
# remove the initial forced fixation from the training data, it's only used for conditioning
# `lengths > 0` keeps only fixations that have at least one preceding fixation
# (presumably `lengths` is the history length — confirm in pysaliency).
has_preceding_fixation = mit_scanpaths_twosize.lengths > 0
mit_fixations_twosize = mit_scanpaths_twosize[has_preceding_fixation]

In [14]:
# parameters optimized on MIT1003 for maximum leave-one-image-out crossvalidation log-likelihood
# The fitted parameters are stored as base-10 exponents.
mit1003_log10_bandwidth = -1.6667673342543432
mit1003_log10_eps = -14.884189168516073

MIT1003_centerbias = CrossvalidatedBaselineModel(
    mit_stimuli_twosize,
    mit_fixations_twosize,
    bandwidth=10 ** mit1003_log10_bandwidth,
    eps=10 ** mit1003_log10_eps,
    caching=False,
)

In [None]:
# Ten-fold crossvalidation on MIT1003. For every fold:
#   1. finetune the SALICON-pretrained spatial model,
#   2. train the scanpath model with the first two saliency-network stages frozen,
#   3. finetune the complete scanpath model with a 100x smaller learning rate.
# Each stage starts (`startwith=`) from the previous stage's `final.pth`.
#
# NOTE(review): this cell reassigns the globals `model`, `optimizer`, `lr_scheduler`,
# `train_loader`, `validation_loader` and `train_/val_baseline_log_likelihood` that
# the SALICON cells above set. Re-run those cells before re-running pretraining.

MIT1003_CROSSVAL_FOLDS = 10

# DenseNet-201 activations read out by every DeepGaze III model in this notebook.
DENSENET201_READOUT_LAYERS = [
    '1.features.denseblock4.denselayer32.norm1',
    '1.features.denseblock4.denselayer32.conv1',
    '1.features.denseblock4.denselayer31.conv2',
]

# Saliency-network parameters kept fixed while the scanpath heads are trained (stage 2).
frozen_scopes = [
    "saliency_network.layernorm0",
    "saliency_network.conv0",
    "saliency_network.bias0",
    "saliency_network.layernorm1",
    "saliency_network.conv1",
    "saliency_network.bias1",
]


def build_mit1003_model(use_scanpath):
    """Build the DeepGaze III architecture used for MIT1003 finetuning (on the CPU).

    use_scanpath=False: spatial model, with no scanpath network and no fixation
    history (0 scanpath channels into the fixation-selection head).
    use_scanpath=True: adds `build_scanpath_network()`, 16 scanpath channels, and
    conditions on fixations [-1, -2, -3, -4].
    Sub-modules are constructed in the same order as the keyword arguments.
    """
    return DeepGazeIII(
        features=FeatureExtractor(RGBDenseNet201(), list(DENSENET201_READOUT_LAYERS)),
        saliency_network=build_saliency_network(2048),
        scanpath_network=build_scanpath_network() if use_scanpath else None,
        fixation_selection_network=build_fixation_selection_network(
            scanpath_features=16 if use_scanpath else 0
        ),
        downsample=2,
        readout_factor=4,
        saliency_map_factor=4,
        included_fixations=[-1, -2, -3, -4] if use_scanpath else [],
    )


def freeze_parameters(model, scopes):
    """Set requires_grad=False on every parameter whose name starts with one of `scopes`.

    Logs each frozen parameter, grouped by scope in the order given.
    """
    for scope in scopes:
        for parameter_name, parameter in model.named_parameters():
            if parameter_name.startswith(scope):
                print("Fixating parameter", parameter_name)
                parameter.requires_grad = False


lmdb_root = train_directory / 'lmdb_cache'

for crossval_fold in range(MIT1003_CROSSVAL_FOLDS):
    fold_name = f'crossval-{MIT1003_CROSSVAL_FOLDS}-{crossval_fold}'

    MIT1003_stimuli_train, MIT1003_fixations_train = pysaliency.dataset_config.train_split(mit_stimuli_twosize, mit_fixations_twosize, crossval_folds=MIT1003_CROSSVAL_FOLDS, fold_no=crossval_fold)
    MIT1003_stimuli_val, MIT1003_fixations_val = pysaliency.dataset_config.validation_split(mit_stimuli_twosize, mit_fixations_twosize, crossval_folds=MIT1003_CROSSVAL_FOLDS, fold_no=crossval_fold)

    train_baseline_log_likelihood = MIT1003_centerbias.information_gain(MIT1003_stimuli_train, MIT1003_fixations_train, verbose=True, average='image')
    val_baseline_log_likelihood = MIT1003_centerbias.information_gain(MIT1003_stimuli_val, MIT1003_fixations_val, verbose=True, average='image')

    # --- Stage 1: finetune the spatial model on MIT1003 ---
    model = build_mit1003_model(use_scanpath=False)
    model = model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    train_loader = prepare_spatial_dataset(MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias, batch_size=32, path=lmdb_root / f'MIT1003_train_spatial_{crossval_fold}')
    validation_loader = prepare_spatial_dataset(MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias, batch_size=32, path=lmdb_root / f'MIT1003_val_spatial_{crossval_fold}')

    _train(train_directory / 'MIT1003_spatial' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'pretraining' / 'final.pth',
    )

    # --- Stage 2: scanpath model with a partially frozen saliency network ---
    train_loader = prepare_scanpath_dataset(MIT1003_stimuli_train, MIT1003_fixations_train, MIT1003_centerbias, batch_size=32, path=lmdb_root / f'MIT1003_train_scanpath_{crossval_fold}')
    validation_loader = prepare_scanpath_dataset(MIT1003_stimuli_val, MIT1003_fixations_val, MIT1003_centerbias, batch_size=32, path=lmdb_root / f'MIT1003_val_scanpath_{crossval_fold}')

    model = build_mit1003_model(use_scanpath=True)
    model = model.to(device)
    freeze_parameters(model, frozen_scopes)

    # The optimizer gets all parameters (frozen ones never receive gradients).
    # Keep it that way: existing checkpoints store optimizer state for this layout.
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30, 31, 32, 33, 34, 35])

    _train(train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_spatial' / fold_name / 'final.pth',
        validation_metrics=['LL', 'IG', 'NSS'],
    )

    # --- Stage 3: finetune the full scanpath model ---
    model = build_mit1003_model(use_scanpath=True)
    model = model.to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr=0.00001)
    lr_scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[3, 6, 9, 12, 15, 18, 21, 24])

    _train(train_directory / 'MIT1003_scanpath' / fold_name,
        model,
        train_loader, train_baseline_log_likelihood,
        validation_loader, val_baseline_log_likelihood,
        optimizer, lr_scheduler,
        minimum_learning_rate=1e-7,
        device=device,
        startwith=train_directory / 'MIT1003_scanpath_partially_frozen_saliency_network' / fold_name / 'final.pth',
        validation_metrics=['LL', 'IG', 'NSS']
    )


Using random shuffles for crossvalidation
Using random shuffles for crossvalidation


  0%|          | 0/808 [00:00<?, ?it/s]

100%|██████████| 808/808 [00:56<00:00, 14.43it/s]
100%|██████████| 94/94 [00:06<00:00, 15.48it/s]
Using cache found in /home/mirko/.cache/torch/hub/pytorch_vision_v0.6.0


Valid LMDB found at train_deepgaze3/lmdb_cache/MIT1003_train_spatial_0 with 808 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 83717/83718 [00:00<00:00, 4111934.66it/s]


Valid LMDB found at train_deepgaze3/lmdb_cache/MIT1003_val_spatial_0 with 94 items. Skipping generation.
Populating fixations cache


100%|█████████▉| 9925/9926 [00:00<00:00, 4379638.84it/s]
Using cache found in /home/mirko/.cache/torch/hub/pytorch_vision_v0.6.0


Training Already finished
Valid LMDB found at train_deepgaze3/lmdb_cache/MIT1003_train_scanpath_0 with 808 items. Skipping generation.
Valid LMDB found at train_deepgaze3/lmdb_cache/MIT1003_val_scanpath_0 with 94 items. Skipping generation.
Fixating parameter saliency_network.layernorm0.weight
Fixating parameter saliency_network.layernorm0.bias
Fixating parameter saliency_network.conv0.weight
Fixating parameter saliency_network.bias0.bias
Fixating parameter saliency_network.layernorm1.weight
Fixating parameter saliency_network.layernorm1.bias
Fixating parameter saliency_network.conv1.weight
Fixating parameter saliency_network.bias1.bias
Using device cuda
Restoring from train_deepgaze3/MIT1003_spatial/crossval-10-0/final.pth
validation metrics ['LL', 'IG', 'NSS']
Found old checkpoint train_deepgaze3/MIT1003_scanpath_partially_frozen_saliency_network/crossval-10-0/step-0002.pth
Restoring from train_deepgaze3/MIT1003_scanpath_partially_frozen_saliency_network/crossval-10-0/step-0002.pth
S

-2.41191: 100%|██████████| 2617/2617 [12:40<00:00,  3.44it/s]


DEBUG: Evaluating metrics: ['LL', 'IG', 'NSS']


Validating LL 2.44499: 100%|██████████| 312/312 [01:24<00:00,  3.70it/s]


   epoch                   timestamp  learning_rate      loss  validation_LL  \
2      2  2025-04-16 20:18:54.640846          0.001 -2.387792       2.432705   
3      3  2025-04-17 14:37:21.435345          0.001 -2.411915       2.444985   

   validation_IG  validation_NSS  
2        1.52096       10.171130  
3        1.53324        9.647873  
validation_LL     3
validation_IG     3
validation_NSS    2
dtype: int64
removing train_deepgaze3/MIT1003_scanpath_partially_frozen_saliency_network/crossval-10-0/step-0002.pth


-2.42710: 100%|██████████| 2617/2617 [12:34<00:00,  3.47it/s]


DEBUG: Evaluating metrics: ['LL', 'IG', 'NSS']


Validating LL 2.45608: 100%|██████████| 312/312 [01:22<00:00,  3.78it/s]


   epoch                   timestamp  learning_rate      loss  validation_LL  \
3      3  2025-04-17 14:37:21.435345          0.001 -2.411915       2.444985   
4      4  2025-04-17 14:51:18.859947          0.001 -2.427100       2.456082   

   validation_IG  validation_NSS  
3       1.533240        9.647873  
4       1.544337       10.737778  
validation_LL     4
validation_IG     4
validation_NSS    4
dtype: int64
removing train_deepgaze3/MIT1003_scanpath_partially_frozen_saliency_network/crossval-10-0/step-0003.pth


-2.43784: 100%|██████████| 2617/2617 [12:33<00:00,  3.47it/s]


DEBUG: Evaluating metrics: ['LL', 'IG', 'NSS']


Validating LL 2.45803: 100%|██████████| 312/312 [01:22<00:00,  3.79it/s]


   epoch                   timestamp  learning_rate      loss  validation_LL  \
4      4  2025-04-17 14:51:18.859947          0.001 -2.427100       2.456082   
5      5  2025-04-17 15:05:15.302295          0.001 -2.437837       2.458034   

   validation_IG  validation_NSS  
4       1.544337       10.737778  
5       1.546289        9.821991  
validation_LL     5
validation_IG     5
validation_NSS    4
dtype: int64
removing train_deepgaze3/MIT1003_scanpath_partially_frozen_saliency_network/crossval-10-0/step-0004.pth


-2.45815:   6%|▌         | 158/2617 [00:48<11:47,  3.48it/s]