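# Train model

Trains the submission-m3 model configured in `notebooks/submission_m3/config_submission_m3.yaml`, saves its weights to `notebooks/submission_m3/results/model_v3.pth` and, if enabled in the config, stores its predictions as an `.npy` file.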

### Imports

In [1]:
import os
if 'notebooks' in os.getcwd(): os.chdir('../..')  # change to main directory
print('Working directory:', os.getcwd() )

Working directory: /scratch/snx3000/bp000429/submission/adrian_sensorium


In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

from sensorium.utility.training import read_config, print_t, set_seed
from sensorium.utility import prediction

### Load configuration for model

In [3]:
# Every data, model and trainer setting used in the cells below comes from this YAML file.
config_file = 'notebooks/submission_m3/config_submission_m3.yaml'

config = read_config(config_file)
print(config)  # echo the full configuration so the run is documented in the output

ordereddict([('data_sets', ['all']), ('dataset_fn', 'sensorium.datasets.static_loaders'), ('dataset_config', ordereddict([('normalize', True), ('include_behavior', True), ('include_eye_position', True), ('batch_size', 128), ('scale', 0.25), ('preload_from_merged_data', True), ('include_trial_id', True), ('include_rank_id', True), ('include_history', True), ('include_behav_state', True), ('adjusted_normalization', True)])), ('model_fn', 'sensorium.models.modulated_stacked_core_full_gauss_readout'), ('model_seed', 3452), ('model_config', ordereddict([('pad_input', False), ('stack', -1), ('layers', 4), ('input_kern', 9), ('gamma_input', 9.8), ('gamma_readout', 0.48), ('hidden_kern', 10), ('hidden_channels', 64), ('depth_separable', True), ('grid_mean_predictor', ordereddict([('type', 'cortex'), ('input_dimensions', 2), ('hidden_layers', 4), ('hidden_features', 20), ('nonlinearity', 'ReLU'), ('final_tanh', True)])), ('init_sigma', 0.14), ('init_mu_range', 0.8), ('gauss_type', 'full'), ('sh

### Prepare dataloader

In [4]:
set_seed(config['model_seed'])  # seed all random generators

if config['data_sets'][0] == 'all':
    basepath = "notebooks/data/"
    # os.listdir returns entries in arbitrary, filesystem-dependent order; sort
    # so the order of the dataset paths handed to get_data is reproducible.
    # endswith() avoids picking up names that merely contain ".zip".
    filenames = [os.path.join(basepath, file)
                 for file in sorted(os.listdir(basepath))
                 if file.endswith(".zip")]
    # NOTE(review): this recording is deliberately left out of training —
    # presumably held out for evaluation; confirm the reason.
    filenames = [file for file in filenames if 'static26872-17-20' not in file]
    assert filenames, f"no .zip dataset archives found in {basepath!r}"
else:
    # explicit list of archives, e.g.
    # ['notebooks/data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]
    filenames = config['data_sets']

dataset_fn = config['dataset_fn']  # 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                  **config['dataset_config'],
                  }

dataloaders = get_data(dataset_fn, dataset_config)

### Instantiate model and trainer

In [5]:
# --- model ---------------------------------------------------------------
# model_fn is a dotted-path string,
# e.g. 'sensorium.models.modulated_stacked_core_full_gauss_readout'
model_fn = config['model_fn']
model_config = config['model_config']

model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=config['model_seed'],
)

# --- trainer -------------------------------------------------------------
# e.g. "sensorium.training.standard_trainer"
trainer_fn = config['trainer_fn']
trainer_config = config['trainer_config']

trainer = get_trainer(
    trainer_fn=trainer_fn,
    trainer_config=trainer_config,
)

### Train model

In [6]:
save_file = 'notebooks/submission_m3/results/model_v3.pth'
# Create the output folder up front: otherwise a missing results/ directory
# only surfaces as a FileNotFoundError after the whole training run.
os.makedirs(os.path.dirname(save_file), exist_ok=True)

print_t('Start of model training')
# NOTE(review): the trainer seed is fixed at 42, independent of config['model_seed'].
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)
print_t('Model training finished')

# Saves the weights currently held by `model`; the `state_dict` returned by the
# trainer is not used here — presumably the two match; TODO confirm.
torch.save(model.state_dict(), save_file)

2022-10-12 15:09:39.090882: Start of model training
correlation 0.019702673
poisson_loss 28384930.0
correlation 0.16340499
poisson_loss 15555482.0
correlation 0.21689159
poisson_loss 14941753.0
correlation 0.24678558
poisson_loss 14581655.0
correlation 0.27456275
poisson_loss 14235428.0
correlation 0.29427832
poisson_loss 13956108.0
correlation 0.31595892
poisson_loss 13684575.0
correlation 0.33865714
poisson_loss 13384968.0
correlation 0.35921317
poisson_loss 13128105.0
correlation 0.37086836
poisson_loss 12980786.0
correlation 0.37931454
poisson_loss 12881538.0
correlation 0.38665688
poisson_loss 12785906.0
correlation 0.39331934
poisson_loss 12709832.0
correlation 0.4005654
poisson_loss 12613548.0
correlation 0.4048122
poisson_loss 12574420.0
correlation 0.40934557
poisson_loss 12512138.0
correlation 0.41198575
poisson_loss 12481812.0
correlation 0.41715196
poisson_loss 12418844.0
correlation 0.42021087
poisson_loss 12376063.0
correlation 0.42490163
poisson_loss 12323685.0
correlati

### Save all predictions as .npy

In [7]:
# Optionally dump the model's predictions for all datasets into one .npy file.
if config['save_predictions_npy']:
    npy_file = 'notebooks/submission_m3/results/prediction_model_v3.npy'

    # predictions (with trial information) for every dataloader
    results = prediction.all_predictions_with_trial(model, dataloaders)

    # combine the datasets, order them in time, then attach behavioral variables
    merged = prediction.merge_predictions(results)
    sorted_res = prediction.sort_predictions_by_time(merged)
    prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)  # return value unused; modifies sorted_res in place per its name

    # NOTE(review): if sorted_res is a dict/object rather than an ndarray, np.save
    # pickles it and reading it back needs np.load(npy_file, allow_pickle=True).
    np.save(npy_file, sorted_res)
    

Iterating datasets: 100%|██████████| 6/6 [01:20<00:00, 13.37s/it]
