# Train model

### Imports

In [1]:
import os
from pathlib import Path

# This notebook lives under <repo>/notebooks/...; every path used below is relative to <repo>.
# Change to the parent of the deepest path component literally named 'notebooks'.
# (A substring test plus a fixed '../..' misfires when 'notebooks' appears elsewhere in
#  the path, or when the kernel is started at a different directory depth.)
_cwd = Path.cwd()
_repo_root = next((p.parent for p in (_cwd, *_cwd.parents) if p.name == 'notebooks'), None)
if _repo_root is not None:
    os.chdir(_repo_root)  # change to main directory
print('Working directory:', os.getcwd() )

Working directory: /scratch/snx3000/bp000429/submission/adrian_sensorium


import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

from sensorium.utility.training import read_config, print_t, set_seed
from sensorium.utility import prediction

### Load configuration for model

In [3]:
config_file = 'notebooks/submission/config_submission_m2.yaml'
config = read_config( config_file )

# Fail fast on an incomplete config: the data, model and trainer cells index these keys
# directly, and some of them (e.g. 'model_fn', 'trainer_fn') are only read after the
# dataloaders have already been built.
_required_keys = ('data_sets', 'dataset_fn', 'dataset_config', 'model_fn',
                  'model_seed', 'model_config', 'trainer_fn', 'trainer_config')
_missing_keys = [key for key in _required_keys if key not in config]
if _missing_keys:
    raise KeyError(f"{config_file} is missing required keys: {_missing_keys}")

print(config)

ordereddict([('data_sets', ['all']), ('dataset_fn', 'sensorium.datasets.static_loaders'), ('dataset_config', ordereddict([('normalize', True), ('include_behavior', True), ('include_eye_position', True), ('batch_size', 128), ('scale', 0.25), ('preload_from_merged_data', True), ('include_trial_id', True), ('include_rank_id', True), ('include_history', True), ('include_behav_state', True), ('adjusted_normalization', True)])), ('model_fn', 'sensorium.models.modulated_stacked_core_full_gauss_readout'), ('model_seed', 3452), ('model_config', ordereddict([('pad_input', False), ('stack', -1), ('layers', 4), ('input_kern', 9), ('gamma_input', 6.3831), ('gamma_readout', 0.0076), ('hidden_kern', 7), ('hidden_channels', 64), ('depth_separable', True), ('grid_mean_predictor', ordereddict([('type', 'cortex'), ('input_dimensions', 2), ('hidden_layers', 4), ('hidden_features', 20), ('nonlinearity', 'ReLU'), ('final_tanh', True)])), ('init_sigma', 0.1), ('init_mu_range', 0.3), ('gauss_type', 'full'), (

### Prepare dataloaders

In [4]:
set_seed( config['model_seed'] )  # seed all random generators

# Recording(s) left out when the 'all' option is used.
# NOTE(review): the reason for excluding this recording is not stated in this notebook; document it.
EXCLUDED_DATASETS = ('static26872-17-20',)

if config['data_sets'][0] == 'all':
    basepath = "notebooks/data/"
    # Use every .zip archive in basepath except the excluded ones. os.listdir returns entries
    # in arbitrary order, so sort to keep the dataset order (and anything built from it
    # downstream) reproducible across machines and filesystems.
    filenames = sorted(
        os.path.join(basepath, file)
        for file in os.listdir(basepath)
        if file.endswith(".zip") and not any(excluded in file for excluded in EXCLUDED_DATASETS)
    )
    if not filenames:
        raise FileNotFoundError(f"No .zip datasets found in {basepath!r}")
else:
    # explicit list of dataset archives, e.g.
    # ['notebooks/data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]
    filenames = config['data_sets']

dataset_fn = config['dataset_fn']  # e.g. 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                  **config['dataset_config'],
                  }

dataloaders = get_data(dataset_fn, dataset_config)

### Instantiate model and trainer

In [5]:
# --- Model: builder function and hyper-parameters both come from the YAML config ---
model_fn = config['model_fn']          # e.g. 'sensorium.models.modulated_stacked_core_full_gauss_readout'
model_config = config['model_config']
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=config['model_seed'],
)

# --- Trainer: callable invoked below as trainer(model, dataloaders, seed=...) ---
trainer_fn = config['trainer_fn']      # e.g. "sensorium.training.standard_trainer"
trainer_config = config['trainer_config']
trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

### Train model

In [6]:
# Create the output folder *before* the multi-hour training run (see the timestamps in the
# log below), so that torch.save at the end cannot fail on a missing directory.
save_file = 'notebooks/submission/results/model_v2.pth'
os.makedirs(os.path.dirname(save_file), exist_ok=True)

# Seed passed to the trainer.
# NOTE(review): differs from config['model_seed'] used for data/model setup; confirm this is intended.
TRAINER_SEED = 42

print_t('Start of model training')
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=TRAINER_SEED)
print_t('Model training finished')

torch.save(model.state_dict(), save_file )

2022-10-09 11:00:15.425493: Start of model training
correlation 0.0040906374
poisson_loss 28096714.0


Epoch 1: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.16588514
poisson_loss 15558070.0


Epoch 2: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.2173214
poisson_loss 14931198.0


Epoch 3: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.24135531
poisson_loss 14642372.0


Epoch 4: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.2705906
poisson_loss 14280804.0


Epoch 5: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.29458615
poisson_loss 13961966.0


Epoch 6: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.31677744
poisson_loss 13684361.0


Epoch 7: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.34000996
poisson_loss 13388418.0


Epoch 8: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.35798052
poisson_loss 13155551.0


Epoch 9: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.36800304
poisson_loss 13038218.0


Epoch 10: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.3761991
poisson_loss 12932908.0


Epoch 11: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.38532028
poisson_loss 12813110.0


Epoch 12: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.38949
poisson_loss 12774718.0


Epoch 13: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.39846846
poisson_loss 12650885.0


Epoch 14: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4010617
poisson_loss 12635144.0


Epoch 15: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.40566584
poisson_loss 12570400.0


Epoch 16: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.40991342
poisson_loss 12516355.0


Epoch 17: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.41232136
poisson_loss 12489952.0


Epoch 18: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4158495
poisson_loss 12440696.0


Epoch 19: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.41854668
poisson_loss 12408627.0


Epoch 20: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.41979498
poisson_loss 12402345.0


Epoch 21: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.42282867
poisson_loss 12355625.0


Epoch 22: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.42342263
poisson_loss 12348248.0


Epoch 23: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.42843363
poisson_loss 12299989.0


Epoch 24: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.43020043
poisson_loss 12267109.0


Epoch 25: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.43139347
poisson_loss 12259650.0


Epoch 26: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.43033522
poisson_loss 12270830.0


Epoch 27: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4344676
poisson_loss 12216340.0


Epoch 28: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.43510768
poisson_loss 12211379.0


Epoch 29: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.43776295
poisson_loss 12177393.0


Epoch 30: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4359335
poisson_loss 12210258.0


Epoch 31: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4398992
poisson_loss 12160616.0


Epoch 32: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4408663
poisson_loss 12146759.0


Epoch 33: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4412584
poisson_loss 12134203.0


Epoch 34: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.442475
poisson_loss 12120536.0


Epoch 35: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.44392315
poisson_loss 12098658.0


Epoch 36: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4442567
poisson_loss 12113207.0


Epoch 37: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.44360852
poisson_loss 12113034.0


Epoch 38: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4461465
poisson_loss 12082354.0


Epoch 39: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4464592
poisson_loss 12073703.0


Epoch 40: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.44564837
poisson_loss 12088389.0


Epoch 41: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.44775456
poisson_loss 12054725.0


Epoch 42: 100%|██████████| 216/216 [01:17<00:00,  2.78it/s]


correlation 0.449173
poisson_loss 12046092.0


Epoch 43: 100%|██████████| 216/216 [01:17<00:00,  2.78it/s]


correlation 0.44902515
poisson_loss 12042200.0


Epoch 44: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.44971457
poisson_loss 12039057.0


Epoch 45: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.44983587
poisson_loss 12040374.0


Epoch 46: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.44999433
poisson_loss 12035894.0


Epoch 47: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4509823
poisson_loss 12024478.0


Epoch 48: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.451626
poisson_loss 12011753.0


Epoch 49: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.45209286
poisson_loss 12016200.0


Epoch 50: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4513029
poisson_loss 12025725.0


Epoch 51: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4517763
poisson_loss 12015206.0


Epoch 52: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4521626
poisson_loss 12012432.0


Epoch 53: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45298755
poisson_loss 11995997.0


Epoch 54: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4532369
poisson_loss 11995358.0


Epoch 55: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4530249
poisson_loss 12003359.0


Epoch 56: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45428854
poisson_loss 11989930.0


Epoch 57: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45329812
poisson_loss 11998141.0


Epoch 58: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4531028
poisson_loss 12003013.0


Epoch 59: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.45483717
poisson_loss 11984916.0


Epoch 60: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45394552
poisson_loss 11992094.0


Epoch 61: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45466045
poisson_loss 11978017.0


Epoch 62: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45431036
poisson_loss 11988973.0


Epoch 63: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45531338
poisson_loss 11979690.0


Epoch 64: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45504293
poisson_loss 11976848.0


Epoch 65: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.454963
poisson_loss 11977460.0


Epoch 66: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45583528
poisson_loss 11970704.0


Epoch 67: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45575288
poisson_loss 11976288.0


Epoch 68: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45590794
poisson_loss 11974727.0


Epoch 69: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45574382
poisson_loss 11969775.0


Epoch 70: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45667765
poisson_loss 11971926.0


Epoch 71: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45726252
poisson_loss 11954885.0


Epoch 72: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45645604
poisson_loss 11969152.0


Epoch 73: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.455246
poisson_loss 11983576.0


Epoch 74: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4561357
poisson_loss 11973064.0


Epoch 75: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45695218
poisson_loss 11959046.0


Epoch 76: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.45714808
poisson_loss 11955849.0


Epoch 77: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


Epoch    77: reducing learning rate of group 0 to 2.7000e-03.
correlation 0.4566706
poisson_loss 11957911.0


Epoch 78: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4656053
poisson_loss 11844687.0


Epoch 79: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46572816
poisson_loss 11839302.0


Epoch 80: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46587235
poisson_loss 11839943.0


Epoch 81: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46549252
poisson_loss 11844730.0


Epoch 82: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46582446
poisson_loss 11844995.0


Epoch 83: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46584278
poisson_loss 11840330.0


Epoch 84: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4657834
poisson_loss 11841664.0


Epoch 85: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.46608514
poisson_loss 11842506.0


Epoch 86: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46572146
poisson_loss 11846463.0


Epoch 87: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4656091
poisson_loss 11843405.0


Epoch 88: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46629128
poisson_loss 11839828.0


Epoch 89: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46594033
poisson_loss 11840885.0


Epoch 90: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.46554038
poisson_loss 11843007.0


Epoch 91: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.46589765
poisson_loss 11844050.0


Epoch 92: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4656975
poisson_loss 11845023.0


Epoch 93: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46631652
poisson_loss 11839626.0


Epoch 94: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


Epoch    94: reducing learning rate of group 0 to 8.1000e-04.
correlation 0.46575025
poisson_loss 11843066.0


Epoch 95: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.46790034
poisson_loss 11816017.0


Epoch 96: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.46825567
poisson_loss 11814150.0


Epoch 97: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46832007
poisson_loss 11814662.0


Epoch 98: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4682537
poisson_loss 11814727.0


Epoch 99: 100%|██████████| 216/216 [01:17<00:00,  2.79it/s]


correlation 0.4683431
poisson_loss 11813508.0


Epoch 100: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46837798
poisson_loss 11813936.0


Epoch 101: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46830225
poisson_loss 11813935.0


Epoch 102: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4684136
poisson_loss 11811687.0


Epoch 103: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46863744
poisson_loss 11810918.0


Epoch 104: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4683212
poisson_loss 11813889.0


Epoch 105: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46842644
poisson_loss 11813726.0


Epoch 106: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46838796
poisson_loss 11813971.0


Epoch 107: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46839276
poisson_loss 11811720.0


Epoch 108: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4683116
poisson_loss 11816313.0


Epoch 109: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.4682966
poisson_loss 11813849.0


Epoch 110: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


Epoch   110: reducing learning rate of group 0 to 2.4300e-04.
correlation 0.46832773
poisson_loss 11813424.0


Epoch 111: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46877605
poisson_loss 11808900.0


Epoch 112: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46885255
poisson_loss 11808024.0


Epoch 113: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


correlation 0.46894723
poisson_loss 11807556.0


Epoch 114: 100%|██████████| 216/216 [01:17<00:00,  2.80it/s]


2022-10-09 14:03:24.853475: Model training finished


### Save all predictions as .npy

In [7]:
# Optional export: a missing flag means "don't export", instead of raising a KeyError
# after training has already finished.
if config.get('save_predictions_npy', False):
    # calculate predictions per dataloader
    results = prediction.all_predictions_with_trial(model, dataloaders)

    # merge predictions, sort in time and add behavioral variables
    merged = prediction.merge_predictions(results)
    sorted_res = prediction.sort_predictions_by_time(merged)
    # return value is unused, so this presumably modifies sorted_res in place
    prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)

    npy_file = 'notebooks/submission/results/prediction_model_v2.npy'
    os.makedirs(os.path.dirname(npy_file), exist_ok=True)
    # NOTE(review): if sorted_res is a dict rather than an ndarray, np.save stores it as a
    # pickled 0-d object array; read it back with np.load(npy_file, allow_pickle=True).item().
    np.save( npy_file, sorted_res)
    

Iterating datasets: 100%|██████████| 6/6 [01:17<00:00, 12.85s/it]
