# How to train the Baseline Models for the SENSORIUM+ track

### This notebook will show how to
- instantiate the dataloaders for the Sensorium+ track
- instantiate a PyTorch model
- instantiate a trainer function
- train two baselines for this competition track
- save the model weights (the model weights can already be found in './model_checkpoints/pretrained/')

### Imports

In [16]:
from nnfabrik.builder import get_data, get_model, get_trainer
import torch



import numpy as np
import pandas as pd



import matplotlib.pyplot as plt
import seaborn as sns



import warnings



warnings.filterwarnings('ignore')

### Instantiate DataLoader for Sensorium+

The only difference to the Sensorium track is that here we include the behavioral variables and the eye position
by setting `include_behavior=True` and `include_eye_position=True`.
This will append the behavioral variables to the input images, and the eye position will be passed to
the shifter network of the model.


In [17]:
import os

# The data and checkpoint paths used in this notebook are relative, so the
# working directory is moved up to the repo root: the outermost ancestor
# directory that is named exactly 'sensorium'.
# Whole path components are compared (not substrings), so folder names that
# merely contain the word, e.g. 'sensorium_2023' or 'my_sensorium', are no
# longer cut in half into a path that may not exist.
cwd_parts = os.getcwd().split(os.sep)
if 'sensorium' in cwd_parts:
    # list.index returns the first hit, i.e. the outermost matching folder.
    root_depth = cwd_parts.index('sensorium') + 1
    current_path = os.sep.join(cwd_parts[:root_depth])
else:
    raise FileNotFoundError(
        'This needs to be run from within the sensorium folder')
os.chdir(current_path)

In [18]:
# Build the dataloaders for the SENSORIUM+ dataset.
# Commented-out alternative path:
# 'notebooks/data/static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip'
filenames = ['notebooks/data/IM_prezipped/LPE11086/2023_12_16/']

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = dict(
    paths=filenames,
    normalize=True,
    include_behavior=True,       # behavioral variables are included
    include_eye_position=True,   # eye position is included
    batch_size=128,
    scale=0.25,
)

dataloaders = get_data(dataset_fn, dataset_config)

# Instantiate State of the Art Model (SOTA)

Because the behavioral variables are available, we instantiate the shifter network
by setting `'shifter': True` in the model configuration.

In [19]:
# Name of the model-building function and the hyper-parameters passed to it
# through get_model.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=True,  # request the shifter network
)

# Build the model for the loaded dataloaders with a fixed seed.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

Because the behavioral variables are available, we instantiate the shifter network
by setting `'shifter': True` in the model configuration.

In [20]:
# NOTE(review): this cell repeats the preceding model cell verbatim (same
# model_fn, same model_config, same seed); consider removing one of the two.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=True,  # request the shifter network
)

# Build the model for the loaded dataloaders with a fixed seed.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

## Configure Trainer

In [21]:
# Name of the training function and the settings passed to it through
# get_trainer.
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = dict(
    max_iter=200,
    verbose=False,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)

trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

# Run model training

In [22]:
# Run the trainer on the current model with a fixed seed; its three return
# values are unpacked into the names below.
(validation_score,
 trainer_output,
 state_dict) = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 33/33 [01:52<00:00,  3.40s/it]
Epoch 2: 100%|██████████| 33/33 [00:03<00:00,  9.24it/s]
Epoch 3: 100%|██████████| 33/33 [00:04<00:00,  8.10it/s]
Epoch 4: 100%|██████████| 33/33 [00:04<00:00,  7.69it/s]
Epoch 5: 100%|██████████| 33/33 [00:03<00:00,  8.57it/s]
Epoch 6: 100%|██████████| 33/33 [00:03<00:00,  8.32it/s]
Epoch 7: 100%|██████████| 33/33 [00:03<00:00,  8.71it/s]
Epoch 8: 100%|██████████| 33/33 [00:03<00:00,  8.50it/s]
Epoch 9: 100%|██████████| 33/33 [00:04<00:00,  7.92it/s]
Epoch 10: 100%|██████████| 33/33 [00:03<00:00,  8.85it/s]
Epoch 11: 100%|██████████| 33/33 [00:03<00:00,  9.06it/s]
Epoch 12: 100%|██████████| 33/33 [00:03<00:00,  8.52it/s]
Epoch 13: 100%|██████████| 33/33 [00:03<00:00,  8.39it/s]
Epoch 14: 100%|██████████| 33/33 [00:03<00:00,  8.43it/s]
Epoch 15: 100%|██████████| 33/33 [00:03<00:00,  8.78it/s]
Epoch 16: 100%|██████████| 33/33 [00:03<00:00,  8.46it/s]
Epoch 17: 100%|██████████| 33/33 [00:07<00:00,  4.44it/s]
Epoch 18: 100%|████████

In [23]:
# Display the first value returned by the trainer call as the cell output.
validation_score

0.24911124

## Save model checkpoints

In [24]:
# Write the current model weights to disk.
# Commented-out original target: './model_checkpoints/sensorium_p_sota_model.pth'
sota_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_p_sota_model.pth'
# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before writing.
os.makedirs(os.path.dirname(sota_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), sota_checkpoint_path)

## Load Model Checkpoints

In [25]:
# Read the weights back from the checkpoint file and load them into `model`.
# Commented-out original source: "./model_checkpoints/pretrained/sensorium_p_sota_model.pth"
# map_location="cpu" deserialises the tensors on the CPU, so a checkpoint that
# holds CUDA tensors can also be read on a machine without a GPU;
# load_state_dict then copies the values into the model's own parameters.
model.load_state_dict(torch.load(
    "notebooks/model_tutorial/model_checkpoints/IM_p_sota_model.pth",
    map_location="cpu"))

<All keys matched successfully>

---

# Train a simple LN model

In [26]:
# This configuration removes all nonlinearities from the CNN, which
# essentially yields an LN model: a linear core + readout, followed by a
# non-linearity.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=3,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    depth_separable=True,
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    linear=True,
    shifter=True,
)

# Build the LN model for the same dataloaders, again with a fixed seed.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

In [27]:
# Run the same trainer on the LN model; this rebinds the three result names.
(validation_score,
 trainer_output,
 state_dict) = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 33/33 [00:07<00:00,  4.57it/s]
Epoch 2: 100%|██████████| 33/33 [00:07<00:00,  4.71it/s]
Epoch 3: 100%|██████████| 33/33 [00:06<00:00,  5.22it/s]
Epoch 4: 100%|██████████| 33/33 [00:05<00:00,  5.67it/s]
Epoch 5: 100%|██████████| 33/33 [00:06<00:00,  5.17it/s]
Epoch 6: 100%|██████████| 33/33 [00:06<00:00,  5.25it/s]
Epoch 7: 100%|██████████| 33/33 [00:06<00:00,  5.19it/s]
Epoch 8: 100%|██████████| 33/33 [00:06<00:00,  5.23it/s]
Epoch 9: 100%|██████████| 33/33 [00:06<00:00,  5.42it/s]
Epoch 10: 100%|██████████| 33/33 [00:06<00:00,  5.49it/s]
Epoch 11: 100%|██████████| 33/33 [00:06<00:00,  5.19it/s]
Epoch 12: 100%|██████████| 33/33 [00:06<00:00,  5.17it/s]
Epoch 13: 100%|██████████| 33/33 [00:02<00:00, 13.02it/s]
Epoch 14: 100%|██████████| 33/33 [00:02<00:00, 11.77it/s]
Epoch 15: 100%|██████████| 33/33 [00:03<00:00, 10.15it/s]
Epoch 16: 100%|██████████| 33/33 [00:02<00:00, 11.21it/s]
Epoch 17: 100%|██████████| 33/33 [00:03<00:00, 10.74it/s]
Epoch 18: 100%|████████

In [28]:
# Display the first value returned by the LN trainer call as the cell output.
validation_score

0.21406637

In [29]:
# torch.save(model.state_dict(), './model_checkpoints/sensorium_p_ln_model.pth')
# Write the current (LN) model weights to disk.
# Commented-out original target: './model_checkpoints/sensorium_p_ln_model.pth'
ln_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_p_ln_model.pth'
# torch.save does not create missing parent folders, so make sure the
# checkpoint directory exists before writing.
os.makedirs(os.path.dirname(ln_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), ln_checkpoint_path)

In [30]:
# model.load_state_dict(torch.load("./model_checkpoints/pretrained/sensorium_p_ln_model.pth"));
# Read the LN weights back from the checkpoint file and load them into `model`.
# Commented-out original source: "./model_checkpoints/pretrained/sensorium_p_ln_model.pth"
# map_location="cpu" deserialises the tensors on the CPU, so a checkpoint that
# holds CUDA tensors can also be read on a machine without a GPU;
# load_state_dict then copies the values into the model's own parameters.
model.load_state_dict(torch.load(
    "notebooks/model_tutorial/model_checkpoints/IM_p_ln_model.pth",
    map_location="cpu"))

<All keys matched successfully>

---