# Train models

To speed up training, the models were trained in parallel on the CSCS Piz Daint infrastructure by running:

cd adrian_sensorium/scripts

bash start_jobs.sh jobs_ensemble.txt

This script starts 5 machines that run the adrian_sensorium/scripts/train_model.py script with the 5 configuration files matching adrian_sensorium/saved_models/config_m4_ens*.yaml

To reproduce this fitting, one can also execute the following code (not tested):

In [None]:
import os

# When started from inside a 'notebooks' folder, move two levels up so that
# the relative paths used below resolve from the main directory.
current_dir = os.getcwd()
if 'notebooks' in current_dir:
    os.chdir('../..')
print('Working directory:', os.getcwd())

In [None]:
# Train the five ensemble members one after another: each line runs
# scripts/train_model.py with a different configuration name passed via -m.
# NOTE(review): the leading `!` is IPython shell-escape syntax, so this cell
# only works inside a notebook / IPython session, and the relative script path
# assumes the working directory set by the previous cell.
!python scripts/train_model.py -m config_m4_ens0
!python scripts/train_model.py -m config_m4_ens1
!python scripts/train_model.py -m config_m4_ens2
!python scripts/train_model.py -m config_m4_ens3
!python scripts/train_model.py -m config_m4_ens4

## Train model in notebook
Alternatively, the model can be trained directly in a notebook with the following code:

### Imports

In [None]:
import os

# Relative paths below are written from the main directory, so step two
# levels up if the kernel was launched somewhere under a 'notebooks' folder.
started_in_notebooks = 'notebooks' in os.getcwd()
if started_in_notebooks:
    os.chdir('../..')

print('Working directory:', os.getcwd())

In [None]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

from sensorium.utility.training import read_config, print_t, set_seed
from sensorium.utility import prediction

### Load configuration for model

In [None]:
# Read the YAML configuration whose entries select the dataset, model and
# trainer used in the cells below.
config_file = "saved_models/02_jobs/only_history.yaml"
config = read_config(config_file)

# Show the full configuration for reference.
print(config)

### Prepare dataloader

In [None]:
# Seed all random generators before any data handling.
set_seed(config['model_seed'])

if config['data_sets'][0] == 'all':
    # 'all' means: use every zipped dataset found in the data folder.
    basepath = "notebooks/data/"
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting keeps the dataset order (and with it the seeded run) identical
    # across machines.
    filenames = sorted(
        os.path.join(basepath, file)
        for file in os.listdir(basepath)
        if ".zip" in file
    )
    # Leave out the 'static26872-17-20' dataset.
    # NOTE(review): the reason for this exclusion is not visible here - confirm.
    filenames = [file for file in filenames if 'static26872-17-20' not in file]
else:
    # Explicit list of dataset paths taken from the configuration, e.g.
    # ['notebooks/data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]
    filenames = config['data_sets']

# Dataset builder function, e.g. 'sensorium.datasets.static_loaders'.
dataset_fn = config['dataset_fn']
# The builder receives the dataset paths plus all options from the config.
dataset_config = {
    'paths': filenames,
    **config['dataset_config'],
}

dataloaders = get_data(dataset_fn, dataset_config)

### Instantiate model and trainer

In [None]:
# --- Model -------------------------------------------------------------
# Builder function (e.g. 'sensorium.models.modulated_stacked_core_full_gauss_readout')
# and its arguments are both taken from the loaded configuration.
model_fn = config['model_fn']
model_config = config['model_config']
model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=config['model_seed'],
)

# --- Trainer -----------------------------------------------------------
# Trainer function (e.g. "sensorium.training.standard_trainer") and its settings.
trainer_fn = config['trainer_fn']
trainer_config = config['trainer_config']
trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

In [None]:
# Cell output is True when the instantiated model's `modulator` attribute is None.
model.modulator is None

### Train model

In [None]:
# Run the training loop; print_t marks the start and end of training.
print_t('Start of model training')
# NOTE(review): the trainer seed is fixed to 42 here, independent of
# config['model_seed'] used for the model and the random generators - confirm
# this is intended.
validation_score, trainer_output, state_dict = trainer(model,
                                                       dataloaders,
                                                       seed=42)
print_t('Model training finished')

# Persist the trained weights. Create the target folder first so that a
# missing results directory does not make the save fail after training.
save_file = 'notebooks/model_walkthrough/results/trained_model.pth'
os.makedirs(os.path.dirname(save_file), exist_ok=True)
torch.save(model.state_dict(), save_file)

### Save all predictions as .npy

In [None]:
# Optionally export the model predictions, controlled by the
# 'save_predictions_npy' entry of the configuration.
if config['save_predictions_npy']:
    # Calculate predictions per dataloader.
    results = prediction.all_predictions_with_trial(model, dataloaders)

    # Merge predictions, sort in time and add behavioral variables
    # (the last call modifies sorted_res in place).
    merged = prediction.merge_predictions(results)
    sorted_res = prediction.sort_predictions_by_time(merged)
    prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)

    # Create the output folder first so that a missing results directory does
    # not discard the computed predictions at save time.
    npy_file = 'notebooks/model_walkthrough/results/prediction_trained_model.npy'
    os.makedirs(os.path.dirname(npy_file), exist_ok=True)
    # NOTE(review): if sorted_res is not a plain ndarray, np.save pickles it
    # and loading will need allow_pickle=True - confirm the object type.
    np.save(npy_file, sorted_res)