# Retrain only on competition dataset

### Imports

In [1]:
import os
# All paths used further down ('notebooks/...') are relative to the main
# project directory, so step up two levels whenever the current working
# directory path contains 'notebooks'.
if 'notebooks' in os.getcwd():
    os.chdir('../..')

print('Working directory:', os.getcwd())

Working directory: /scratch/snx3000/bp000429/submission/adrian_sensorium


In [2]:
import torch
import numpy as np
import pandas as pd

import matplotlib.pyplot as plt
import seaborn as sns

import warnings
warnings.filterwarnings('ignore')

from nnfabrik.builder import get_data, get_model, get_trainer

from sensorium.utility import submission
from sensorium.utility.training import read_config, print_t, set_seed
from sensorium.utility import prediction

In [3]:
# Read the YAML configuration that supplies the dataset, model and trainer
# settings used by the cells below, and echo it for the record.
config_file = 'notebooks/submission_m3/config_oneModel_m3.yaml'
config = read_config(config_file)
print(config)

ordereddict([('data_sets', ['notebooks/data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip']), ('dataset_fn', 'sensorium.datasets.static_loaders'), ('dataset_config', ordereddict([('normalize', True), ('include_behavior', True), ('include_eye_position', True), ('batch_size', 128), ('scale', 0.25), ('preload_from_merged_data', True), ('include_trial_id', True), ('include_rank_id', True), ('include_history', True), ('include_behav_state', True), ('adjusted_normalization', True)])), ('model_fn', 'sensorium.models.modulated_stacked_core_full_gauss_readout'), ('model_seed', 3452), ('model_config', ordereddict([('pad_input', False), ('stack', -1), ('layers', 4), ('input_kern', 9), ('gamma_input', 9.8), ('gamma_readout', 0.48), ('hidden_kern', 10), ('hidden_channels', 64), ('depth_separable', True), ('grid_mean_predictor', ordereddict([('type', 'cortex'), ('input_dimensions', 2), ('hidden_layers', 4), ('hidden_features', 20), ('nonlinearity', 'ReLU'), ('final_tanh', True)]

### Prepare dataloader

In [4]:
# Seed all random generators before the loaders are built
set_seed(config['model_seed'])

# Dataset archives, e.g.
# ['notebooks/data/static27204-5-13-GrayImageNet-94c6ff995dac583098847cfecd43e7b6.zip']
filenames = config['data_sets']

# Name of the loader function, e.g. 'sensorium.datasets.static_loaders'
dataset_fn = config['dataset_fn']

# Loader arguments: the archive paths first, followed by every option listed
# under 'dataset_config' in the config file
dataset_config = {'paths': filenames}
dataset_config.update(config['dataset_config'])

dataloaders = get_data(dataset_fn, dataset_config)

### Instantiate model

In [5]:
# Build the model named in the config; the builder also receives the
# dataloaders and the model seed.
model_fn = config['model_fn']  # e.g. 'sensorium.models.modulated_stacked_core_full_gauss_readout'
model_config = config['model_config']

model = get_model(
    model_fn=model_fn,
    model_config=model_config,
    dataloaders=dataloaders,
    seed=config['model_seed'],
)
model  # last expression: show the module structure as cell output

ModulatedFiringRateEncoder(
  (core): Stacked2dCore(
    (_input_weights_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv): Conv2d(4, 64, kernel_size=(9, 9), stride=(1, 1), bias=False)
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer1): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
          (spatial_conv): Conv2d(64, 64, kernel_size=(10, 10), stride=(1, 1), padding=(5, 5), groups=64, bias=False)
          (out_depth_conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
        )
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.9, affine=True, track_running_stats=True)
        (nonlin): AdaptiveELU()
      )
      (layer2): Sequential(
        (ds_conv): DepthSeparableConv2d(
          (in_de

In [6]:
# Initialise the core from a previously trained model. Only the core weights
# are taken over; all remaining parts of the network keep their fresh
# initialisation and are fitted in the training cell.
save_file = 'notebooks/submission_m3/results/model_v3.pth'

# map_location='cpu' makes the checkpoint readable on machines that do not
# have the device it was saved from; load_state_dict() below copies the
# tensors into the model's own parameters, wherever those live.
pretrained_dict = torch.load(save_file, map_location='cpu')

# keep only the entries of the core (keys carrying the 'core.' prefix)
core_only = {k: v for k, v in pretrained_dict.items() if k.startswith('core.')}
if not core_only:
    # without this check strict=False would silently load nothing at all
    raise KeyError(f"No 'core.*' entries found in checkpoint {save_file!r}")

# set pretrained core values; strict=False because the non-core entries were
# filtered out on purpose and are therefore reported as missing
ret = model.load_state_dict(core_only, strict=False)

# strict=False would equally hide a core that was restored only partially
missing_core = [k for k in ret.missing_keys if k.startswith('core.')]
if missing_core:
    raise RuntimeError(
        f'Core parameters absent from checkpoint {save_file!r}: {missing_core}'
    )

ret  # show the remaining missing / unexpected keys as cell output

_IncompatibleKeys(missing_keys=['readout.27204-5-13.sigma', 'readout.27204-5-13._features', 'readout.27204-5-13.bias', 'readout.27204-5-13.source_grid', 'readout.27204-5-13.mu_transform.0.weight', 'readout.27204-5-13.mu_transform.0.bias', 'readout.27204-5-13.mu_transform.2.weight', 'readout.27204-5-13.mu_transform.2.bias', 'readout.27204-5-13.mu_transform.4.weight', 'readout.27204-5-13.mu_transform.4.bias', 'readout.27204-5-13.mu_transform.6.weight', 'readout.27204-5-13.mu_transform.6.bias', 'readout.27204-5-13.mu_transform.8.weight', 'readout.27204-5-13.mu_transform.8.bias', 'shifter.27204-5-13.mlp.0.weight', 'shifter.27204-5-13.mlp.0.bias', 'shifter.27204-5-13.mlp.2.weight', 'shifter.27204-5-13.mlp.2.bias', 'shifter.27204-5-13.mlp.4.weight', 'shifter.27204-5-13.mlp.4.bias', 'modulator.27204-5-13.own_gain', 'modulator.27204-5-13.gain_coupling', 'modulator.27204-5-13.coupling_offset', 'modulator.27204-5-13.history_weights', 'modulator.27204-5-13.history_bias', 'modulator.27204-5-13.sta

In [7]:
# Look up the training function and its settings in the config and build
# the trainer from them.
trainer_fn = config['trainer_fn']  # "sensorium.training.standard_trainer"
trainer_config = config['trainer_config']

trainer = get_trainer(
    trainer_fn=trainer_fn,
    trainer_config=trainer_config,
)

### Train model

In [8]:
# Run the training; the trainer call returns three values, unpacked below as
# validation score, trainer output and state dict. print_t marks the start
# and the end of the run.
print_t('Start of model training')
validation_score, trainer_output, state_dict = trainer(
    model,
    dataloaders,
    seed=42,
)
print_t('Model training finished')

# Store the retrained weights under a new file name, so the pretrained
# checkpoint is left untouched.
save_file = 'notebooks/submission_m3/results/model_v3_retrained.pth'
torch.save(model.state_dict(), save_file)

2022-10-12 20:19:30.926701: Start of model training
correlation 0.00808165
poisson_loss 4613276.0
correlation 0.23523669
poisson_loss 2491149.0
correlation 0.341758
poisson_loss 2275904.8
correlation 0.38537893
poisson_loss 2190505.5
correlation 0.40428048
poisson_loss 2154535.0
correlation 0.41647708
poisson_loss 2129727.5
correlation 0.42261142
poisson_loss 2119666.8
correlation 0.42597806
poisson_loss 2113006.2
correlation 0.4302361
poisson_loss 2106623.0
correlation 0.4327055
poisson_loss 2100928.0
correlation 0.43428135
poisson_loss 2101731.0
correlation 0.43525565
poisson_loss 2099643.8
correlation 0.4355723
poisson_loss 2099393.0
correlation 0.43692672
poisson_loss 2097413.0
correlation 0.43711153
poisson_loss 2098214.5
correlation 0.439137
poisson_loss 2094615.0
correlation 0.43968484
poisson_loss 2092834.1
correlation 0.43655047
poisson_loss 2101646.5
correlation 0.4379925
poisson_loss 2098620.0
correlation 0.4407688
poisson_loss 2092525.8
correlation 0.43964487
poisson_loss 2

### Save predictions as .npy

In [9]:
# Optional export, switched on by the 'save_predictions_npy' config entry
if config['save_predictions_npy']:
    # predictions for each dataloader
    results = prediction.all_predictions_with_trial(model, dataloaders)

    # combine the per-dataloader predictions, then order them in time
    merged = prediction.merge_predictions(results)
    sorted_res = prediction.sort_predictions_by_time(merged)

    # attach the behavioral variables (modifies sorted_res in place)
    prediction.inplace_add_behavior_to_sorted_predictions(sorted_res)

    # write the result to disk
    npy_file = 'notebooks/submission_m3/results/prediction_model_v3_retrained.npy'
    np.save(npy_file, sorted_res)
    

Iterating datasets: 100%|██████████| 1/1 [00:14<00:00, 14.67s/it]
