# Train the baseline model (with and without behavioral inputs) and predict its responses

### Imports

In [1]:
import os

# When the kernel was started somewhere inside the notebooks/ tree,
# step two levels up so that all relative paths below resolve from the main directory.
if 'notebooks' in os.getcwd():
    os.chdir('../..')
print('Working directory:', os.getcwd())

Working directory: /scratch/snx3000/bp000429/adrian_sensorium


In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

### Instantiate DataLoaders for SENSORIUM+

In [12]:
# Load the SENSORIUM+ recordings.
# Single-recording alternatives, kept for reference:
# filenames = ['../data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]
# filenames = ['notebooks/data/static21067-10-18-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip', ]

# Default: every .zip recording archive in `basepath`, minus the excluded recordings.
basepath = "notebooks/data/"
EXCLUDED_RECORDINGS = ('static26872-17-20',)
# os.listdir returns entries in arbitrary, filesystem-dependent order; sort so the list of
# recordings (and therefore the order in which dataloaders are built) is reproducible.
# endswith() avoids picking up names that merely contain ".zip" (e.g. partial downloads).
filenames = sorted(
    os.path.join(basepath, file)
    for file in os.listdir(basepath)
    if file.endswith('.zip') and not any(rec in file for rec in EXCLUDED_RECORDINGS)
)
assert filenames, f'No .zip recordings found in {basepath!r}'

# Set to True for the behavior variant of the model. This must match the checkpoint
# loaded in the prediction section (see the comments there).
INCLUDE_BEHAVIOR = False

dataset_fn = 'sensorium.datasets.static_loaders'
dataset_config = {'paths': filenames,
                  'normalize': True,
                  'include_behavior': INCLUDE_BEHAVIOR,
                  'include_eye_position': True,
                  'batch_size': 128,
                  'scale': .25,
                  'preload_from_merged_data': True,
                  'include_trial_id': True,
                  }

dataloaders = get_data(dataset_fn, dataset_config)

## Instantiate the state-of-the-art (SOTA) model

In [13]:
model_fn = 'sensorium.models.stacked_core_full_gauss_readout'

# Hyperparameters passed unchanged to `model_fn` via nnfabrik's get_model.
model_config = dict(
    pad_input=False,
    stack=-1,
    layers=4,
    input_kern=9,
    gamma_input=6.3831,
    gamma_readout=0.0076,
    hidden_kern=7,
    hidden_channels=64,
    depth_separable=True,
    grid_mean_predictor=dict(
        type='cortex',
        input_dimensions=2,
        hidden_layers=1,
        hidden_features=30,
        final_tanh=True,
    ),
    init_sigma=0.1,
    init_mu_range=0.3,
    gauss_type='full',
    shifter=True,
)

model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=42,
)

## Configure Trainer

In [6]:
trainer_fn = "sensorium.training.standard_trainer"

# Training options forwarded to `trainer_fn` via nnfabrik's get_trainer.
trainer_config = dict(
    max_iter=200,
    verbose=True,
    lr_decay_steps=4,
    avg_loss=False,
    lr_init=0.009,
)

trainer = get_trainer(trainer_fn=trainer_fn, trainer_config=trainer_config)

## Run model training

In [7]:
# Run training. The trainer returns a validation score, an output dict (displayed below as
# {'validation_corr': ...}) and a state dict -- presumably the final/best weights; confirm.
# NOTE(review): the recorded execution of this cell was stopped with KeyboardInterrupt.
validation_score, trainer_output, state_dict = trainer(model, dataloaders, seed=42)

Epoch 1:  11%|█         | 24/216 [00:07<01:02,  3.07it/s]


KeyboardInterrupt: 

In [14]:
# Display the trainer's output dict.
# NOTE(review): the recorded training cell (In [7]) was interrupted, yet this cell (In [14])
# shows a result, so the value must come from an earlier, out-of-order run. Re-run to confirm.
trainer_output

{'validation_corr': 0.40177137}

## Save model checkpoints

In [15]:
# Name the checkpoint after the `include_behavior` setting the model was built with.
# The naming follows the loading cell: all_behavior <-> True, shifter_only <-> False.
# Previously this always wrote the all_behavior file, even for a behavior-free model.
model_dir = 'notebooks/my_models'
os.makedirs(model_dir, exist_ok=True)
model_file = ('all_behavior_allRecordings_model_1.pth'
              if dataset_config.get('include_behavior', False)
              else 'shifter_only_allRecordings_model_1.pth')
torch.save(model.state_dict(), os.path.join(model_dir, model_file))

# Predict model responses

In [5]:
# Each checkpoint matches only the model built with the corresponding setting:
#   all_behavior_allRecordings_model_1.pth  requires include_behavior=True
#   shifter_only_allRecordings_model_1.pth  requires include_behavior=False
# Pick it from the dataset config instead of toggling comments, so the two cannot disagree.
checkpoint_file = ('all_behavior_allRecordings_model_1.pth'
                   if dataset_config.get('include_behavior', False)
                   else 'shifter_only_allRecordings_model_1.pth')
checkpoint_path = os.path.join('notebooks/my_models', checkpoint_file)
# map_location: load tensors straight onto the model's device, so a checkpoint saved on a
# GPU can also be loaded on a CPU-only machine.
model_device = next(model.parameters()).device
model.load_state_dict(torch.load(checkpoint_path, map_location=model_device));

---

### Get predictions for the train, validation and test sets, together with trial IDs

In [6]:
from sensorium.utility import prediction

In [15]:
# calculate predictions per dataloader (the progress bar iterates over the datasets)
results = prediction.all_predictions_with_trial(model, dataloaders)

# merge predictions, sort in time and add behavioral variables
merged = prediction.merge_predictions(results)
sorted_res = prediction.sort_predictions_by_time(merged)
# The return value is discarded, so this presumably mutates `sorted_res` in place (as the
# name suggests). Re-run the whole cell rather than this line alone -- TODO confirm.
prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)

Iterating datasets: 100%|██████████| 6/6 [00:43<00:00,  7.17s/it]


In [16]:
# Name the results file after the model variant, chosen from the dataset config instead of
# toggling comments: allBehavior <-> include_behavior=True, onlyShifter <-> False.
out_folder = 'notebooks/my_results'
file_name = ('res_v1_allBehavior.npy'
             if dataset_config.get('include_behavior', False)
             else 'res_v1_onlyShifter.npy')
os.makedirs(out_folder, exist_ok=True)

# NOTE(review): if `sorted_res` is not a plain ndarray (e.g. a dict), np.save pickles it into
# a 0-d object array, and reading it back needs np.load(..., allow_pickle=True). Confirm type.
np.save(os.path.join(out_folder, file_name), sorted_res)

---