# Train modulated model

### Imports

In [1]:
import os

# The relative paths used further down (e.g. 'notebooks/data/') are written
# relative to the main directory, so step up two levels when the kernel was
# started from somewhere inside a 'notebooks' folder.
if 'notebooks' in os.getcwd():
    os.chdir('../..')

print('Working directory:', os.getcwd())

Working directory: /scratch/snx3000/bp000429/adrian_sensorium


In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Prepare data loader

In [3]:
# Select the dataset archives to train on.
#
# Alternative single-recording selections, kept for reference:
#   loading the SENSORIUM+ dataset
#   filenames = ['../data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]
#   another dataset
#   filenames = ['notebooks/data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]

# all datasets with behavior data
basepath = "notebooks/data/"

# os.listdir() returns entries in arbitrary order, so sort them: the order of
# the datasets is then identical on every run and machine. Matching on the file
# extension (instead of a '.zip' substring) keeps leftovers such as
# 'x.zip.part' out of the list.
filenames = sorted(
    os.path.join(basepath, file)
    for file in os.listdir(basepath)
    if file.endswith(".zip")
)

# leave out recording 26872-17-20
filenames = [file for file in filenames if 'static26872-17-20' not in file]

# Fail here with a clear message rather than somewhere inside the loader.
if not filenames:
    raise FileNotFoundError(f"No usable .zip datasets found in {basepath!r}")

# Loader function (referenced by name) and the settings passed to it via get_data.
dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {
    'paths': filenames,
    'normalize': True,
    'include_behavior': True,
    # 'include_behavior': False,
    'include_eye_position': True,
    'batch_size': 128,
    'scale': 0.25,
    'preload_from_merged_data': True,
    # The three flags below presumably add the `trial_id`, `history` and `gain`
    # fields to every batch -- confirm against the loader.
    'include_trial_id': True,
    'include_history': True,
    'include_gain': True,
}

dataloaders = get_data(dataset_fn, dataset_config)

In [4]:
# Look at a single training batch of one recording to see which fields the
# loader yields and what shape each of them has.
tier = 'train'
dataset_name = '21067-10-18'

# next(iter(...)) fetches exactly one batch and raises right away when the
# loader is empty. The former `for ...: break` idiom would instead leave
# `batch` undefined -- or, worse, silently reuse a stale `batch` from an
# earlier execution of this cell.
batch = next(iter(dataloaders[tier][dataset_name]))

for i, field in enumerate(batch._fields):
    print(f"{field}, {batch[i].shape}")

images, torch.Size([128, 4, 36, 64])
responses, torch.Size([128, 8372])
behavior, torch.Size([128, 3])
pupil_center, torch.Size([128, 2])
trial_id, torch.Size([128, 1])
history, torch.Size([128, 8372, 5])
gain, torch.Size([128, 1])


### Instantiate modulated model

In [5]:
# Re-import the model definitions so that edits made to the
# sensorium.models.models module are picked up without restarting the kernel.
# The reloaded module object is the cell's displayed value.
import importlib
from sensorium.models import models
importlib.reload(models)

<module 'sensorium.models.models' from '/scratch/snx3000/bp000429/adrian_sensorium/sensorium/models/models.py'>

In [6]:
# Model builder (referenced by name) and the settings handed to it by get_model.
model_fn = 'sensorium.models.modulated_stacked_core_full_gauss_readout'

model_config = {
    'pad_input': False,
    'stack': -1,
    'layers': 4,
    'input_kern': 9,
    'gamma_input': 6.3831,
    'gamma_readout': 0.0076,
    'hidden_kern': 7,
    'hidden_channels': 64,
    'depth_separable': True,
    'grid_mean_predictor': {
        'type': 'cortex',
        'input_dimensions': 2,
        'hidden_layers': 1,
        'hidden_features': 30,
        'final_tanh': True,
    },
    'init_sigma': 0.1,
    'init_mu_range': 0.3,
    'gauss_type': 'full',
    'shifter': True,
    # Switches specific to the modulated model; presumably they make the model
    # consume the `history` and `gain` batch fields -- confirm in the builder.
    'with_history': True,
    'with_gain': True,
}

# The dataloaders are passed along as well; fixed seed for the model construction.
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

In [7]:
# Display the module tree of the instantiated model.
model

ModulatedFiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(4, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_dept

### Configure trainer

In [8]:
# Training routine (referenced by name) and the settings handed to it by get_trainer.
trainer_fn = "sensorium.training.standard_trainer"

trainer_config = {
    'max_iter': 200,
    'verbose': True,
    'lr_decay_steps': 4,
    'avg_loss': False,
    'lr_init': 0.009,
    'track_training': True,
}

trainer = get_trainer(
    trainer_fn=trainer_fn,
    trainer_config=trainer_config,
)

### Train model

In [None]:
# Run the training with a fixed seed. The trainer's return value is unpacked
# into three parts; `model` itself is what gets saved and used for the
# predictions further down.
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

correlation 0.021659302
poisson_loss 28039642.0


Epoch 1: 100%|██████████| 216/216 [01:15<00:00,  2.88it/s]


correlation 0.032460805
poisson_loss 38090884.0


Epoch 2: 100%|██████████| 216/216 [01:14<00:00,  2.88it/s]


correlation 0.05982633
poisson_loss 35142812.0


Epoch 3: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.07480242
poisson_loss 32631860.0


Epoch 4: 100%|██████████| 216/216 [01:14<00:00,  2.88it/s]


correlation 0.08353358
poisson_loss 32968652.0


Epoch 5: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.09707296
poisson_loss 29369500.0


Epoch 6: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.09585443
poisson_loss 28664730.0


Epoch 7: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.10476508
poisson_loss 27182264.0


Epoch 8: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.10571827
poisson_loss 26748952.0


Epoch 9: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.10353032
poisson_loss 27290496.0


Epoch 10: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.109802976
poisson_loss 24772576.0


Epoch 11: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.10908372
poisson_loss 23806090.0


Epoch 12: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.11456754
poisson_loss 22930808.0


Epoch 13: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.11299714
poisson_loss 22858308.0


Epoch 14: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.11593269
poisson_loss 22555468.0


Epoch 15: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.11715897
poisson_loss 21609306.0


Epoch 16: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.11802587
poisson_loss 21404690.0


Epoch 17: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.119625024
poisson_loss 20839072.0


Epoch 18: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12043563
poisson_loss 19787484.0


Epoch 19: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12405147
poisson_loss 19099262.0


Epoch 20: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12381343
poisson_loss 19198186.0


Epoch 21: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.12601659
poisson_loss 18731854.0


Epoch 22: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12819296
poisson_loss 18484184.0


Epoch 23: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12976663
poisson_loss 19169126.0


Epoch 24: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.12935127
poisson_loss 18698468.0


Epoch 25: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.12461507
poisson_loss 19377424.0


Epoch 26: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.12571584
poisson_loss 18825680.0


Epoch 27: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.124217965
poisson_loss 18460248.0


Epoch 28: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.12975726
poisson_loss 19171744.0


Epoch 29: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


Epoch    29: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.13085936
poisson_loss 19163046.0


Epoch 30: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.13420855
poisson_loss 18302238.0


Epoch 31: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.13504171
poisson_loss 17961950.0


Epoch 32: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.13582371
poisson_loss 18013080.0


Epoch 33: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.13660808
poisson_loss 17622748.0


Epoch 34: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.13747653
poisson_loss 17358324.0


Epoch 35: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.13835442
poisson_loss 17189780.0


Epoch 36: 100%|██████████| 216/216 [01:14<00:00,  2.91it/s]


correlation 0.13932641
poisson_loss 17068604.0


Epoch 37: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.14039895
poisson_loss 16959784.0


Epoch 38: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.14184295
poisson_loss 16827952.0


Epoch 39: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.14334084
poisson_loss 16725041.0


Epoch 40: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.14420411
poisson_loss 16644942.0


Epoch 41: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.1449989
poisson_loss 16569525.0


Epoch 42: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.1459616
poisson_loss 16504432.0


Epoch 43: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.14687797
poisson_loss 16442501.0


Epoch 44: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.1475651
poisson_loss 16380590.0


Epoch 45: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.14872928
poisson_loss 16330948.0


Epoch 46: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.14969017
poisson_loss 16280370.0


Epoch 47: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.15029559
poisson_loss 16229265.0


Epoch 48: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.15077324
poisson_loss 16186675.0


Epoch 49: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.15229523
poisson_loss 16144566.0


Epoch 50: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.15283455
poisson_loss 16102704.0


Epoch 51: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.15371917
poisson_loss 16068622.0


Epoch 52: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.15499136
poisson_loss 16035478.0


Epoch 53: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.15566918
poisson_loss 16001922.0


Epoch 54: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.1561937
poisson_loss 15965501.0


Epoch 55: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.1573367
poisson_loss 15939796.0


Epoch 56: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.15815358
poisson_loss 15915852.0


Epoch 57: 100%|██████████| 216/216 [01:13<00:00,  2.94it/s]


correlation 0.15904883
poisson_loss 15884643.0


Epoch 58: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.15952909
poisson_loss 15857888.0


Epoch 59: 100%|██████████| 216/216 [01:13<00:00,  2.94it/s]


correlation 0.16047446
poisson_loss 15833138.0


Epoch 60: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.1609392
poisson_loss 15813623.0


Epoch 61: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.16191012
poisson_loss 15789859.0


Epoch 62: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.16273592
poisson_loss 15769422.0


Epoch 63: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.16351815
poisson_loss 15752296.0


Epoch 64: 100%|██████████| 216/216 [01:13<00:00,  2.94it/s]


correlation 0.16428375
poisson_loss 15726659.0


Epoch 65: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.16518028
poisson_loss 15707864.0


Epoch 66: 100%|██████████| 216/216 [01:13<00:00,  2.92it/s]


correlation 0.16566674
poisson_loss 15692328.0


Epoch 67: 100%|██████████| 216/216 [01:14<00:00,  2.92it/s]


correlation 0.16644049
poisson_loss 15668376.0


Epoch 68: 100%|██████████| 216/216 [01:13<00:00,  2.92it/s]


correlation 0.16695447
poisson_loss 15653784.0


Epoch 69: 100%|██████████| 216/216 [01:13<00:00,  2.92it/s]


correlation 0.16786629
poisson_loss 15636368.0


Epoch 70: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.1686203
poisson_loss 15615638.0


Epoch 71: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.169562
poisson_loss 15603844.0


Epoch 72: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.17014147
poisson_loss 15587348.0


Epoch 73: 100%|██████████| 216/216 [01:13<00:00,  2.92it/s]


correlation 0.17087375
poisson_loss 15569363.0


Epoch 74: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.17181973
poisson_loss 15559016.0


Epoch 75: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.17249021
poisson_loss 15539118.0


Epoch 76: 100%|██████████| 216/216 [01:13<00:00,  2.92it/s]


correlation 0.17285357
poisson_loss 15522672.0


Epoch 77: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.17383343
poisson_loss 15512695.0


Epoch 78: 100%|██████████| 216/216 [01:13<00:00,  2.93it/s]


correlation 0.17498161
poisson_loss 15497676.0


Epoch 79: 100%|██████████| 216/216 [01:14<00:00,  2.90it/s]


correlation 0.17547983
poisson_loss 15482318.0


Epoch 80: 100%|██████████| 216/216 [01:14<00:00,  2.88it/s]


correlation 0.17587332
poisson_loss 15467213.0


Epoch 81: 100%|██████████| 216/216 [01:14<00:00,  2.88it/s]


correlation 0.17733844
poisson_loss 15454879.0


Epoch 82: 100%|██████████| 216/216 [01:14<00:00,  2.88it/s]


correlation 0.1779207
poisson_loss 15439754.0


Epoch 83: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.17893349
poisson_loss 15427276.0


Epoch 84: 100%|██████████| 216/216 [01:14<00:00,  2.89it/s]


correlation 0.17963554
poisson_loss 15412110.0


Epoch 85:  82%|████████▏ | 177/216 [01:01<00:13,  2.92it/s]

### Save model

In [None]:
# torch.save does not create missing directories. Make sure the target folder
# exists first, so that a long training run cannot end in a failed save.
model_path = 'notebooks/my_models/withHistory_withState_allRecs_model_1.pth'
os.makedirs(os.path.dirname(model_path), exist_ok=True)
torch.save(model.state_dict(), model_path)

In [None]:
# this model requires include_behavior=True
# model.load_state_dict(torch.load('notebooks/my_models/withHistory_21067-10-18_model_1.pth'));

### Predict responses

In [None]:
from sensorium.utility import prediction

In [None]:
# calculate predictions per dataloader
results = prediction.all_predictions_with_trial(model, dataloaders)

# merge predictions, sort in time and add behavioral variables
# (the last call's return value is not used; judging by its name it modifies
# sorted_res in place, so it has to come after the sorting step)
merged = prediction.merge_predictions(results)
sorted_res = prediction.sort_predictions_by_time(merged)
prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)

In [None]:
# Where the sorted predictions are written.
out_folder = 'notebooks/my_results'
# file_name = 'res_v1_allBehavior.npy'
file_name = 'res_v2_withModulatorAndState_all.npy'

# np.save does not create missing directories; create the output folder first
# so the computed predictions are not lost to a FileNotFoundError.
os.makedirs(out_folder, exist_ok=True)
np.save(os.path.join(out_folder, file_name), sorted_res)

In [None]:
# Ask the JupyterLab front end to run its 'docmanager:save' command, i.e. save
# the notebook from code -- presumably so that the outputs of an unattended
# run are persisted once all cells have finished.
from ipylab import JupyterFrontEnd

app = JupyterFrontEnd()
app.commands.execute('docmanager:save')