# How to train the Baseline Models for the SENSORIUM+ track

### This notebook will show how to
- instantiate dataloader for the Sensorium+ track
- instantiate pytorch model
- instantiate a trainer function
- train two baselines for this competition track
- save the model weights (the model weights can already be found in './model_checkpoints/pretrained/')

### Imports

In [29]:
!pwd

/srv/user/ninasophie.nellen/sensorium/notebooks/model_tutorial


In [63]:
import neuralpredictors

In [64]:
import torch

In [65]:
torch.cuda.is_available()

True

In [66]:
!nvidia-smi

Tue Oct 15 10:27:32 2024       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.14              Driver Version: 550.54.14      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA RTX A5000               On  |   00000000:01:00.0 Off |                  Off |
| 30%   34C    P8             22W /  230W |    1454MiB /  24564MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
|   1  NVIDIA RTX A5000               On  |   00

In [67]:
torch.cuda.set_device("cuda:6")

In [68]:
import torch
import numpy as np
import pandas as pd

#import matplotlib.pyplot as plt
#import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate DataLoader for Sensorium+

The only difference to the Sensorium track is that here, we include the behavioral variables and the eye position,
by setting include_behavior=True, and include_eye_position=True.
this will append the behavioral variables to the input images, and the eye position will be passed to
the shifter network of the model.


In [69]:
!ls /usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/

static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6
static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6


In [70]:
import os

In [71]:
pre = "/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/"
[f"{pre}{i}/" for i in os.listdir(pre)]

['/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/',
 '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/']

In [72]:
import skimage

In [73]:
# loading the SENSORIUM+ dataset
pre = "/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/"
print(pre)
filenames = [f"{pre}{i}/" for i in os.listdir(pre)]
print(filenames)

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                 'normalize': True,
                 'include_behavior': True,
                 'include_eye_position': True,
                 'batch_size': 128,
                 'scale':.25,
                 }

dataloaders = get_data(dataset_fn, dataset_config)

/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/
['/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static22846-10-16-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23343-5-17-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static26872-17-20-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static23656-14-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/', '/usr/users/agecker/datasets/sensorium_2022_pictures/real_dataset/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6/']


In [75]:
dataloaders

OrderedDict([('train',
              OrderedDict([('23964-4-22',
                            <torch.utils.data.dataloader.DataLoader at 0x7f532499cb00>),
                           ('21067-10-18',
                            <torch.utils.data.dataloader.DataLoader at 0x7f53248c11c0>),
                           ('22846-10-16',
                            <torch.utils.data.dataloader.DataLoader at 0x7f532487e690>),
                           ('23343-5-17',
                            <torch.utils.data.dataloader.DataLoader at 0x7f53248b78c0>),
                           ('26872-17-20',
                            <torch.utils.data.dataloader.DataLoader at 0x7f532499d7c0>),
                           ('23656-14-22',
                            <torch.utils.data.dataloader.DataLoader at 0x7f5324950650>),
                           ('27204-5-13',
                            <torch.utils.data.dataloader.DataLoader at 0x7f532487f3b0>)])),
             ('validation',
              OrderedDict

In [76]:
elem = next(iter(dataloaders["train"]["23343-5-17"]))

In [77]:
elem.images.shape
# images, channels, height, width

torch.Size([128, 4, 36, 64])

In [41]:
elem.images[0, 1, :, :]

tensor([[2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980],
        [2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980],
        [2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980],
        ...,
        [2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980],
        [2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980],
        [2.0980, 2.0980, 2.0980,  ..., 2.0980, 2.0980, 2.0980]],
       device='cuda:6')

In [78]:
elem.responses.shape
#images, neurons

torch.Size([128, 7334])

In [79]:
#elem

In [80]:
elem.behavior.shape,  elem.pupil_center.shape

(torch.Size([128, 3]), torch.Size([128, 2]))

# Instantiate State of the Art Model (SOTA)

Because the behavioral variables are available, we instantiate the Shifter network
by setting Shifter=True in the model configuration.

In [45]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': True,
}

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

Because the behavioral variables are available, we instantiate the Shifter network
by setting Shifter=True in the model configuration.

In [81]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,
  'stack': -1,
  'layers': 4,
  'input_kern': 9,
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'hidden_kern': 7,
  'hidden_channels': 64,
  'depth_separable': True,
  'grid_mean_predictor': {'type': 'cortex',
   'input_dimensions': 2,
   'hidden_layers': 1,
   'hidden_features': 30,
   'final_tanh': True},
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': True,
}

model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

In [90]:
device = torch.device("cuda:6")
#'device' : f"cuda:6"
model.to(device)

SyntaxError: illegal target for annotation (2281648631.py, line 2)

## Configure Trainer

In [92]:
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {'max_iter': 200,
                 'verbose': False,
                 'lr_decay_steps': 4,
                 'avg_loss': False,
                 'lr_init': 0.009,
                 'device' : f"cuda:6"
                 }

trainer = get_trainer(trainer_fn=trainer_fn, 
                     trainer_config=trainer_config)

#trainer.to(device)

In [93]:
model.readout['23343-5-17'].features.device
next(iter(dataloaders['train']['23343-5-17'])).images.device

device(type='cuda', index=6)

# Run model training

In [95]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 252/252 [00:40<00:00,  6.23it/s]
Epoch 2: 100%|██████████| 252/252 [00:42<00:00,  5.96it/s]
Epoch 3: 100%|██████████| 252/252 [00:40<00:00,  6.15it/s]
Epoch 4: 100%|██████████| 252/252 [00:40<00:00,  6.17it/s]
Epoch 5: 100%|██████████| 252/252 [00:40<00:00,  6.19it/s]
Epoch 6: 100%|██████████| 252/252 [00:40<00:00,  6.27it/s]
Epoch 7: 100%|██████████| 252/252 [00:40<00:00,  6.26it/s]
Epoch 8: 100%|██████████| 252/252 [00:40<00:00,  6.25it/s]
Epoch 9: 100%|██████████| 252/252 [00:40<00:00,  6.20it/s]
Epoch 10: 100%|██████████| 252/252 [00:39<00:00,  6.30it/s]
Epoch 11: 100%|██████████| 252/252 [00:40<00:00,  6.25it/s]
Epoch 12: 100%|██████████| 252/252 [00:40<00:00,  6.23it/s]
Epoch 13: 100%|██████████| 252/252 [00:40<00:00,  6.27it/s]
Epoch 14: 100%|██████████| 252/252 [00:40<00:00,  6.26it/s]
Epoch 15: 100%|██████████| 252/252 [00:40<00:00,  6.18it/s]
Epoch 16: 100%|██████████| 252/252 [00:40<00:00,  6.27it/s]
Epoch 17: 100%|██████████| 252/252 [00:40<00:00, 

## Save model checkpoints

In [99]:
torch.save(model.state_dict(), './model_checkpoints/sensorium_p_sota_model.pth')

## Load Model Checkpoints

In [None]:
model.load_state_dict(torch.load("./model_checkpoints/sensorium_p_sota_model.pth"));

In [105]:
model.load_state_dict(torch.load("./model_checkpoints/sensorium_p_sota_model.pth"));

RuntimeError: Error(s) in loading state_dict for FiringRateEncoder:
	Unexpected key(s) in state_dict: "core.features.layer3.ds_conv.in_depth_conv.weight", "core.features.layer3.ds_conv.in_depth_conv.bias", "core.features.layer3.ds_conv.spatial_conv.weight", "core.features.layer3.ds_conv.spatial_conv.bias", "core.features.layer3.ds_conv.out_depth_conv.weight", "core.features.layer3.ds_conv.out_depth_conv.bias", "core.features.layer3.norm.weight", "core.features.layer3.norm.bias", "core.features.layer3.norm.running_mean", "core.features.layer3.norm.running_var", "core.features.layer3.norm.num_batches_tracked". 

---

# Train a simple LN model

In [101]:
# this will remove all nonlinearities from the CNN, and creates essentially a ln model: linear core + readout, with a subsequent non-linearity

model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,  #no padding
              'stack': -1,  #no stacking
              'layers': 3,
              'input_kern': 9,  #filter dimension 9x9
              'gamma_input': 6.3831,
              'gamma_readout': 0.0076,
              'hidden_kern': 7,  #convolution in hidden layers
              'hidden_channels': 64, #hidden filters that are applied
              'grid_mean_predictor': {'type': 'cortex', #submodel, uses positions of the brain cells to help to learn their receptive fields (because often cells that are actually close in the brain look in close pixels areas)
              'input_dimensions': 2,
              'hidden_layers': 1,
              'hidden_features': 30,
              'final_tanh': True}, #applying tanh activation function in final layer
              'depth_separable': True, #depthwise separable convolutions
              'init_sigma': 0.1, #nitial standard deviation for the weights
              'init_mu_range': 0.3, #range for initial mean of weights
              'gauss_type': 'full', #Gaussian type, full covariance matrix
              'linear': True, #include linear layers
              'shifter': True, #include shifting?
               }
model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

In [102]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 252/252 [00:37<00:00,  6.65it/s]
Epoch 2: 100%|██████████| 252/252 [00:38<00:00,  6.57it/s]
Epoch 3: 100%|██████████| 252/252 [00:38<00:00,  6.60it/s]
Epoch 4: 100%|██████████| 252/252 [00:38<00:00,  6.62it/s]
Epoch 5: 100%|██████████| 252/252 [00:38<00:00,  6.60it/s]
Epoch 6: 100%|██████████| 252/252 [00:38<00:00,  6.57it/s]
Epoch 7: 100%|██████████| 252/252 [00:37<00:00,  6.66it/s]
Epoch 8: 100%|██████████| 252/252 [00:37<00:00,  6.66it/s]
Epoch 9: 100%|██████████| 252/252 [00:37<00:00,  6.68it/s]
Epoch 10: 100%|██████████| 252/252 [00:37<00:00,  6.66it/s]
Epoch 11: 100%|██████████| 252/252 [00:37<00:00,  6.69it/s]
Epoch 12: 100%|██████████| 252/252 [00:37<00:00,  6.71it/s]
Epoch 13: 100%|██████████| 252/252 [00:37<00:00,  6.73it/s]
Epoch 14: 100%|██████████| 252/252 [00:38<00:00,  6.61it/s]
Epoch 15: 100%|██████████| 252/252 [00:37<00:00,  6.63it/s]
Epoch 16: 100%|██████████| 252/252 [00:38<00:00,  6.63it/s]
Epoch 17: 100%|██████████| 252/252 [00:37<00:00, 

In [106]:
torch.save(model.state_dict(), './model_checkpoints/sensorium_p_ln_model.pth')

In [107]:
model.load_state_dict(torch.load("./model_checkpoints/sensorium_p_ln_model.pth"));

In [109]:
print(model)

FiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(4, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (spatial_conv): Conv2d(64, 64, kernel_si

In [110]:
# this will remove all nonlinearities from the CNN, and creates essentially a ln model: linear core + readout, with a subsequent non-linearity

model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = {'pad_input': False,  #no padding
              'stack': -1,  #no stacking
              'layers': 3,
              'input_kern': 9,  #filter dimension 9x9
              'gamma_input': 6.3831,
              'gamma_readout': 0.0076,
              'hidden_kern': 7,  #convolution in hidden layers
              'hidden_channels': 64, #hidden filters that are applied
              'grid_mean_predictor': {'type': 'cortex', #submodel, uses positions of the brain cells to help to learn their receptive fields (because often cells that are actually close in the brain look in close pixels areas)
              'input_dimensions': 2,
              'hidden_layers': 1,
              'hidden_features': 30,
              'final_tanh': True}, #applying tanh activation function in final layer
              'depth_separable': True, #depthwise separable convolutions
              'init_sigma': 0.1, #nitial standard deviation for the weights
              'init_mu_range': 0.3, #range for initial mean of weights
              'gauss_type': 'full', #Gaussian type, full covariance matrix
              'linear': True, #include linear layers
              'shifter': False, #include shifting?
               }
model = get_model(model_fn=model_fn,
                  model_config=model_config,
                  dataloaders=dataloaders,
                  seed=42,)

In [111]:
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 252/252 [00:37<00:00,  6.63it/s]
Epoch 2: 100%|██████████| 252/252 [00:37<00:00,  6.69it/s]
Epoch 3: 100%|██████████| 252/252 [00:37<00:00,  6.66it/s]
Epoch 4: 100%|██████████| 252/252 [00:37<00:00,  6.68it/s]
Epoch 5: 100%|██████████| 252/252 [00:37<00:00,  6.65it/s]
Epoch 6: 100%|██████████| 252/252 [00:38<00:00,  6.62it/s]
Epoch 7: 100%|██████████| 252/252 [00:38<00:00,  6.56it/s]
Epoch 8: 100%|██████████| 252/252 [00:37<00:00,  6.64it/s]
Epoch 9: 100%|██████████| 252/252 [00:38<00:00,  6.54it/s]
Epoch 10: 100%|██████████| 252/252 [00:38<00:00,  6.56it/s]
Epoch 11: 100%|██████████| 252/252 [00:38<00:00,  6.61it/s]
Epoch 12: 100%|██████████| 252/252 [00:38<00:00,  6.58it/s]
Epoch 13: 100%|██████████| 252/252 [00:38<00:00,  6.63it/s]
Epoch 14: 100%|██████████| 252/252 [00:38<00:00,  6.56it/s]
Epoch 15: 100%|██████████| 252/252 [00:38<00:00,  6.62it/s]
Epoch 16: 100%|██████████| 252/252 [00:37<00:00,  6.65it/s]
Epoch 17: 100%|██████████| 252/252 [00:38<00:00, 

In [112]:
print(model)

FiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(4, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1))
          (spatial_conv): Conv2d(64, 64, kernel_si

In [None]:
model_fn = 'sensorium.models.ecker_core_full_gauss_readout'
model_config = {
  'pad_input': False,
  'stack': -1,
  'layers': 4,
  'hidden_channels': 16,
  'num_rotations': 8,
  'input_kern': 9,
  'hidden_kern': 7,
    
  'gamma_input': 6.3831,
  'gamma_readout': 0.0076,
  'feature_reg_weight':0.0076,
 
  'depth_separable': True,
                
   'grid_mean_predictor': {
        'type': 'cortex',
        'input_dimensions': 2,
        'hidden_layers': 1,
        'hidden_features': 30,
        'final_tanh': True
   },
                
  'init_sigma': 0.1,
  'init_mu_range': 0.3,
  'gauss_type': 'full',
  'shifter': True,
  
}

---