# How to train the Baseline Models for the SENSORIUM track

### This notebook shows how to
- instantiate the dataloaders for the SENSORIUM track
- instantiate a PyTorch model
- instantiate a trainer function
- train two baselines for this competition track
- save the model weights (pretrained SENSORIUM weights can already be found in './model_checkpoints/pretrained/'; this notebook writes its own checkpoints to 'notebooks/model_tutorial/model_checkpoints/')

### Imports

In [1]:
from nnfabrik.builder import get_data, get_model, get_trainer
import torch
import numpy as np
import pandas as pd


import matplotlib.pyplot as plt
import seaborn as sns


import warnings


warnings.filterwarnings('ignore')

In [2]:
import os
from pathlib import Path

# Move to the repository root so that the relative data/checkpoint paths used
# below ('notebooks/data/...', 'notebooks/model_tutorial/...') resolve.
# The root is the outermost path component named exactly 'sensorium'.
# Matching whole components avoids false hits on folders such as
# 'sensorium_old' or 'my-sensorium-copy'. A plain substring test would accept
# those and then chdir into a wrong (possibly nonexistent) directory.
REPO_DIR_NAME = 'sensorium'
cwd_parts = Path(os.getcwd()).parts
if REPO_DIR_NAME not in cwd_parts:
    raise FileNotFoundError(
        'This needs to be run from within the sensorium folder')
# Keep `current_path` as a str pointing at the repo root, as before.
current_path = str(Path(*cwd_parts[:cwd_parts.index(REPO_DIR_NAME) + 1]))
os.chdir(current_path)

### Instantiate DataLoader

In [3]:
# Build dataloaders for the recording(s) listed in `filenames` and keep them
# under the name `dataloaders_sens`.
# NOTE(review): the path is the IM_prezipped recording, not the SENSORIUM
# GrayImageNet archive (commented out). The next cell loads the very same
# recording with the same configuration, so this load is currently a duplicate.
filenames = [
    # 'notebooks/data/static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip',
    'notebooks/data/IM_prezipped/LPE11086/2023_12_16/',
]

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = dict(
    paths=filenames,
    normalize=True,
    include_behavior=False,
    include_eye_position=False,
    batch_size=128,
    scale=0.25,  # NOTE(review): presumably an image down-scaling factor; confirm in static_loaders
)

dataloaders_sens = get_data(dataset_fn, dataset_config)

In [4]:
# Dataloaders used to train every model in this notebook (both the SOTA and
# the LN model below receive `dataloaders`).
# The commented-out entry is the original SENSORIUM GrayImageNet recording.
filenames = [
    # 'notebooks/data/static23964-4-22-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip',
    'notebooks/data/IM_prezipped/LPE11086/2023_12_16/',
]
dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {
    'paths': filenames,
    'normalize': True,
    'include_behavior': False,
    'include_eye_position': False,
    'batch_size': 128,
    'scale': 0.25,
}
dataloaders = get_data(dataset_fn, dataset_config)

## Instantiate the State-of-the-Art (SOTA) Model

In [5]:
# SOTA baseline: a stacked convolutional core followed by a full Gaussian
# readout, built from the training dataloaders with a fixed seed.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=False,
)

model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

## Configure Trainer

In [6]:
# Training-loop configuration. The resulting `trainer` is reused to fit both
# the SOTA model and the LN model further down.
trainer_fn = "sensorium.training.standard_trainer"
trainer_config = dict(
    max_iter=200,
    verbose=True,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)
trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

## Run model training

In [7]:
# Fit the SOTA model with a fixed seed; the trainer returns a validation
# score, its auxiliary output and a state dict.
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1:   6%|▌         | 2/33 [00:14<03:40,  7.10s/it]


KeyboardInterrupt: 

In [25]:
# Validation score returned by the trainer for the SOTA model.
# NOTE(review): the training cell above shows a KeyboardInterrupt and the
# execution count jumps from 7 to 25, so this value appears to come from a
# different run. Re-run training before relying on it.
validation_score

0.18121223

### Save model checkpoints after training is complete

In [26]:
# Sanity check: the checkpoint paths below are relative, so confirm that the
# working directory is the repository root chosen by the os.chdir cell above.
# NOTE(review): the displayed output contains an absolute local path; clear it
# before sharing the notebook.
os.getcwd()

'c:\\Users\\asimo\\Documents\\BCCN\\Lab Rotations\\Petreanu Lab\\sensorium'

In [27]:
# Persist the trained SOTA weights. The path is relative to the repository
# root set by the os.chdir cell above. torch.save does not create missing
# parent directories, so ensure the checkpoint folder exists first.
sota_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_sota_model.pth'
os.makedirs(os.path.dirname(sota_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), sota_checkpoint_path)

## Load Model Checkpoints

In [28]:
# Reload the SOTA weights from disk.
# map_location='cpu' lets a checkpoint written on a GPU machine be read on a
# CPU-only one. load_state_dict then copies the tensors onto whatever device
# the model's parameters already live on.
sota_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_sota_model.pth'
model.load_state_dict(torch.load(sota_checkpoint_path, map_location='cpu'))

<All keys matched successfully>

---

# Train a simple LN model

Our LN model uses the same type of architecture as our CNN model (a convolutional core followed by a Gaussian readout), here with 3 instead of 4 core layers, but with all nonlinearities removed except the final ELU+1 nonlinearity.
This effectively turns the CNN model into a fully linear model followed by a single output nonlinearity.


In [29]:
# LN baseline: same model function as the SOTA model, with `linear=True`
# requesting the linear core described in the markdown above. The trained
# SOTA `model` is replaced by this new instance.
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=3,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    depth_separable=True,
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    linear=True,
)
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

In [30]:
# Fit the LN model with the same trainer and seed used for the SOTA model.
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1: 100%|██████████| 33/33 [00:02<00:00, 14.70it/s]
Epoch 2: 100%|██████████| 33/33 [00:02<00:00, 14.78it/s]
Epoch 3: 100%|██████████| 33/33 [00:02<00:00, 15.34it/s]
Epoch 4: 100%|██████████| 33/33 [00:02<00:00, 14.89it/s]
Epoch 5: 100%|██████████| 33/33 [00:02<00:00, 15.81it/s]
Epoch 6: 100%|██████████| 33/33 [00:02<00:00, 15.23it/s]
Epoch 7: 100%|██████████| 33/33 [00:02<00:00, 15.88it/s]
Epoch 8: 100%|██████████| 33/33 [00:02<00:00, 15.53it/s]
Epoch 9: 100%|██████████| 33/33 [00:02<00:00, 14.87it/s]
Epoch 10: 100%|██████████| 33/33 [00:02<00:00, 15.68it/s]
Epoch 11: 100%|██████████| 33/33 [00:02<00:00, 15.45it/s]
Epoch 12: 100%|██████████| 33/33 [00:02<00:00, 14.19it/s]
Epoch 13: 100%|██████████| 33/33 [00:02<00:00, 14.21it/s]
Epoch 14: 100%|██████████| 33/33 [00:02<00:00, 13.30it/s]
Epoch 15: 100%|██████████| 33/33 [00:02<00:00, 15.27it/s]
Epoch 16: 100%|██████████| 33/33 [00:02<00:00, 14.72it/s]
Epoch 17: 100%|██████████| 33/33 [00:02<00:00, 15.17it/s]
Epoch 18: 100%|████████

In [31]:
# Validation score returned by the trainer for the LN model (previous cell).
# Lower than the SOTA score shown earlier, as expected for the linear baseline.
validation_score

0.120164454

In [32]:
# Persist the trained LN weights. The path is relative to the repository root
# set by the os.chdir cell above. torch.save does not create missing parent
# directories, so ensure the checkpoint folder exists first.
ln_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_ln_model.pth'
os.makedirs(os.path.dirname(ln_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), ln_checkpoint_path)

In [33]:
# Reload the LN weights from disk.
# map_location='cpu' lets a checkpoint written on a GPU machine be read on a
# CPU-only one. load_state_dict then copies the tensors onto whatever device
# the model's parameters already live on.
ln_checkpoint_path = 'notebooks/model_tutorial/model_checkpoints/IM_ln_model.pth'
model.load_state_dict(torch.load(ln_checkpoint_path, map_location='cpu'))

<All keys matched successfully>

---