In [1]:
from utils.global_functions import *
import numpy as np
import matplotlib.pyplot as plt
import os
from models.helper_functions import get_model_and_dataloader, get_model_temp_reach, get_model_and_dataloader_for_nm
from evaluations.single_cell_performance import get_performance_for_single_cell
from meis.visualizer import get_model_activations
from datasets.stas import get_cell_sta, show_sta, get_sta
import pickle

## Configuring which model to load

In [2]:
# --- dataset -----------------------------------------------------------------
data_type = 'marmoset'

# 0-based index of the retina; directory names and loader keys use retina_index + 1
retina_index = 1

# --- paths -------------------------------------------------------------------
# directory from which the trained models are loaded
directory = (
    f"{home}/models/factorized_4_ev_0.15_cnn/marmoset/"
    f"retina{retina_index + 1}/cell_None/readout_isotropic/gmp_0/"
)

# file of the specific model architecture (the name presumably encodes its hyper-parameters)
filename = (
    'lr_0.0060_l_4_ch_[8, 16, 32, 64]_t_27_bs_16_tr_10'
    '_ik_27x(21, 21)x(21, 21)_hk_5x(5, 5)x(5, 5)'
    '_g_48.0000_gt_0.0740_l1_0.0230_l2_0.0000_sg_0.25'
    '_d_1_dt_1_hd_1-2-3_hdt_1-1-1_p_0_bn_1_s_1norm_0_fn_1_0_h_80_w_90'
)

# directory to which MEIs are saved
mei_dir = f"{home}/meis/data/{data_type}/meni/retina{retina_index + 1}"

# --- ensemble ----------------------------------------------------------------
# one trained model per seed; first_seed picks the single model used in later cells
seeds = [16, 64, 8, 128]
first_seed = seeds[0]

device = 'cuda'
model_fn = 'models.FactorizedEncoder.build_trained'
models = {}  # seed -> tuple returned by get_model_and_dataloader_for_nm

## Loading models for ensemble into a dictionary

In [None]:
# Load one trained model per seed into `models`.
# get_model_and_dataloader_for_nm returns a tuple; later cells index it as
# [0] -> dataloaders and [1] -> model. The dataloaders are expected to be the
# same for every seed, so only the first seed's dataloaders are used below.
for seed in seeds:
    models[seed] = get_model_and_dataloader_for_nm(
        directory,
        filename,
        data_type=data_type,
        seed=seed,
        device=device,
        model_fn=model_fn,
        test=False,
        # NOTE(review): hardcoded absolute path; if data_dir is None, root of the project is considered
        data_dir='/home/vystrcilova/',
    )

data_dir: /usr/users/vystrcilova/retinal_circuit_modeling
train idx: [7 6 9 1 2 8 0 3]
val idx: [ 4 10  5]
train responses shape:  (69, 25500, 11)
training trials:  8 [7 6 9 1 2 8 0 3]
validation trials:  2 [4 5]
getting loaders


 79%|███████▉  | 11556/14544 [01:00<00:13, 220.07it/s]

## Exploring the dataloader

In [None]:
# get the dataloaders loaded with the first model (they are expected to be the same for all models)
dataloader = models[first_seed][0]

# the dataloader contains three tiers: train, validation and test
# train and validation are single-trial tiers, test is averaged over multiple trials
print(dataloader.keys())

tier = 'train'
# per-retina loaders are keyed by the 1-based retina number, zero-padded to two digits
# ('01', '02', ...); the old f'0{...}' form would produce '010' for the tenth retina.
# NOTE(review): assumes keys are always two-digit zero-padded — confirm for >9 retinas
inputs, targets = next(iter(dataloader[tier][f'{retina_index + 1:02d}']))

# the printed input shape is (batch_size, in_channels, frames, height, width)
# the number of frames is the number necessary to make one prediction plus time_chunk, which is
# a parameter of the dataloader specifying for how many time steps the model makes predictions at once
# the number of channels is always 1
print(inputs.shape)

# the targets shape is (batch_size, num_of_neurons, time_chunk)
print(targets.shape)


## Exploring the model

The model consists of two parts:
* the **core** is a convolutional neural network. It builds a non-linear feature space that is shared by all neurons whose responses we want to predict.

To avoid full 3d convolutions, every layer splits its 3d convolution into two steps: first a spatial convolution and then a temporal one. So instead of *(out_channels x in_channels x kernel_depth x kernel_width x kernel_height)* parameters, a layer has *(out_channels x in_channels x 1 x kernel_width x kernel_height) + (out_channels x out_channels x kernel_depth x 1 x 1)* parameters.
* the **readout** is neuron-specific. For each neuron, it picks the position the neuron looks at and weights the features of the shared feature space for that neuron.

![CNN architecture](./figures/architecture.png)

In [None]:
# index 1 of the tuple loaded for a seed is the model (index 0 holds the dataloaders)
model_to_explore = models[first_seed][1]

In [None]:
# display the model's repr to inspect its layers (presumably the core and readout submodules)
model_to_explore

In [None]:
# cell_names are indices of cells when considering all cells and not just those that passed a reliability threshold
# cell_indices are indices of cells when considering only cells that passed reliability threshold
# the 0.15 threshold matches the `ev_0.15` in the model directory name
cell_names = get_cell_names(
    retina_index=retina_index,
    explained_variance_threshold=0.15,
    config=models[first_seed][1].config_dict['config'],
)
cell_indices = list(range(len(cell_names)))

## Checking predictive performance of a model

In [None]:
# Compare model predictions with recorded responses on the validation tier
# (presumably one correlation per cell). The model comes from the same seed as the
# dataloaders and config used elsewhere in this notebook; pick another key of
# `models` to evaluate a different ensemble member.
correlations, all_predictions, all_responses = get_performance_for_single_cell(
    model=model_to_explore,
    dataloaders=models[first_seed][0],
    performance='validation',
    device=device,
    retina_index=retina_index,
    rf_size=(model_to_explore.config_dict["img_h"], model_to_explore.config_dict["img_w"]),
    img_h=model_to_explore.config_dict["img_h"],
    img_w=model_to_explore.config_dict["img_w"],
)

## Creating an input array for a model

In [None]:

def get_initial_input(model, init_variance, num_of_predictions=1, spatial_shape=None):
    """Initialize a random 5d tensor that can be fed into a model to be optimized.

    The number of frames is the model's temporal reach plus ``num_of_predictions - 1``.
    With that many frames, the input is meant to elicit ``num_of_predictions`` predictions.

    param model: model the array is meant for; its ``config_dict`` is passed to
        ``get_model_temp_reach`` to get the number of frames one prediction needs
    param init_variance: passed as the *scale* (standard deviation) of
        ``torch.distributions.Normal``. Despite the name, the sample variance is
        ``init_variance ** 2``. NOTE(review): kept as-is so existing results stay
        reproducible; use ``sqrt`` here if a true variance is intended.
    param num_of_predictions: number of predictions the array is supposed to elicit (>= 1)
    param spatial_shape: optional ``(height, width)``. It defaults to the spatial size of
        the notebook-global ``inputs`` batch drawn from the dataloader.
    returns: float64 tensor of shape (1, 1, frames, height, width)
    raises ValueError: if ``num_of_predictions`` is smaller than 1
    """
    if num_of_predictions < 1:
        raise ValueError(f"num_of_predictions must be >= 1, got {num_of_predictions}")
    if spatial_shape is None:
        # falls back to the batch sampled in the dataloader exploration cell
        spatial_shape = inputs.shape[-2:]
    height, width = spatial_shape
    num_frames = get_model_temp_reach(model.config_dict) + num_of_predictions - 1
    input_shape = (1, 1, num_frames, int(height), int(width))
    dist = torch.distributions.Normal(0, init_variance)
    return dist.sample(input_shape).double()

In [None]:
# a random input sized for one prediction of the first-seed model; pass the model
# itself (not the (dataloaders, model, ...) tuple stored in `models`)
initial_input = get_initial_input(model_to_explore, init_variance=0.1, num_of_predictions=1)
activation = get_model_activations(model_to_explore, initial_input)

# expected shape: (num_of_predictions, num_of_neurons) — NOTE(review): confirm against get_model_activations
print(activation.shape)
print(activation)