In [1]:
seed = 42
import sys

sys.path.append("/srv/user/turishcheva/sensorium_replicate/sensorium_2023/")
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from nnfabrik.utility.nn_helpers import set_random_seed
from torch.distributions.kl import kl_divergence
from torch.distributions.normal import Normal

set_random_seed(seed)
import os

import wandb
from neuralpredictors.layers.cores.conv2d import Stacked2dCore
from neuralpredictors.layers.encoders.mean_variance_functions import \
    fitted_zig_mean
from neuralpredictors.layers.encoders.zero_inflation_encoders import ZIGEncoder
from neuralpredictors.measures import modules, zero_inflated_losses
from neuralpredictors.training import early_stopping
from nnfabrik.utility.nn_helpers import set_random_seed
from tqdm import tqdm

from eval import eval_model
from moments import load_mean_variance
from sensorium.datasets.mouse_video_loaders import mouse_video_loader
from sensorium.models.make_model import make_video_model
from sensorium.utility import scores
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)
from sensorium.utility.scores import get_correlations
# wandb.login('0066bbfe063ba7e14bd2b068d18ae8ab2d6299ed')

cuda


In [2]:
def kl_divergence_gaussian(mu, sigma):
    """
    Compute the KL divergence D_KL(q || p) where:
    q(z) = N(mu, diag(sigma^2)) and p(z) = N(0, I)

    Parameters:
    - mu (torch.Tensor): Mean vector of the Gaussian distribution q(z), shape (B,time,hidden).
    - sigma (torch.Tensor): Standard deviation vector of q(z), not squared, shape (B,time,).

    Returns:
    - kl_div (torch.Tensor): The KL divergence.
    """

    sigma_squared = sigma**2
    inverse_sigma_squared = 1.0 / sigma_squared

    # Trace term: Sum of inverse variances (since trace of a diagonal matrix is the sum of its diagonal elements)
    dim = mu.shape[0] * mu.shape[1] * mu.shape[2]

    trace_term = (
        inverse_sigma_squared * dim
    )  # sigma is a scalar which is constant acrosss all dims

    # Quadratic term: (mu^T * Sigma_0^{-1} * mu)
    quadratic_term = torch.sum(mu**2) * inverse_sigma_squared

    # Log-determinant term: 2 * sum(log(sigma))
    log_det_term = (
        2 * torch.log(sigma) * dim
    )  # sigma is a scalar which is constant acrosss all dims

    kl_div = 0.5 * (trace_term + quadratic_term - dim + log_det_term)
    return kl_div


# compute exponentail moving average of correlation
def calculate_ema(data, alpha):
    ema = torch.zeros(len(data))
    ema[0] = data[0]

    for t in range(1, len(data)):
        ema[t] = alpha * data[t] + (1 - alpha) * ema[t - 1]

    return ema

In [3]:
# batch.videos.shape, batch.pupil_center.shape, batch.behavior.shape, batch.responses.shape
'''
(torch.Size([8, 3, 80, 36, 64]),
 torch.Size([8, 2, 80]),
 torch.Size([8, 2, 80]),
 torch.Size([8, 7440, 80]))
'''

'\n(torch.Size([8, 3, 80, 36, 64]),\n torch.Size([8, 2, 80]),\n torch.Size([8, 2, 80]),\n torch.Size([8, 7440, 80]))\n'

### Make experanto dataloader

In [3]:
import experanto
from experanto.dataloaders import get_multisession_dataloader
from experanto.configs import DEFAULT_CONFIG as cfg
from tqdm import tqdm

In [4]:
pre_path_tr = '/mnt/vast-react/projects/neural_foundation_model/full_foundation_export_30hz_correct/'
pre_path_test = '/mnt/vast-react/projects/neural_foundation_model/test_set_resampled/interpolate_with_hamming_30_30.0Hz/'

train = [
    'dynamic29156-11-10-Video-021a75e56847d574b9acbcc06c675055_30hz', 
    'dynamic29228-2-10-Video-021a75e56847d574b9acbcc06c675055_30hz', 
    'dynamic29234-6-9-Video-021a75e56847d574b9acbcc06c675055_30hz', 
    'dynamic29513-3-5-Video-021a75e56847d574b9acbcc06c675055_30hz', 
    'dynamic29514-2-9-Video-021a75e56847d574b9acbcc06c675055_30hz'
]

test_folder_scans = [
    'dynamic26872-17-20-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic27204-5-13-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic29515-10-12-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic29623-4-9-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic29647-19-8-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic29712-5-9-Video-021a75e56847d574b9acbcc06c675055_30hz',
    'dynamic29755-2-8-Video-021a75e56847d574b9acbcc06c675055_30hz'
]

full_paths = [f'{pre_path_tr}{t}/' for t in train] + [f'{pre_path_test}{t}/' for t in test_folder_scans]

cfg['dataset']['modality_config']['responses']['sampling_rate'] = 30
cfg['dataset']['modality_config']['responses']['chunk_size'] = 80

cfg['dataset']['modality_config']['eye_tracker']['sampling_rate'] = 30
cfg['dataset']['modality_config']['eye_tracker']['chunk_size'] = 80

cfg['dataset']['modality_config']['treadmill']['sampling_rate'] = 30
cfg['dataset']['modality_config']['treadmill']['chunk_size'] = 80

cfg['dataset']['modality_config']['screen']['sampling_rate'] = 30
cfg['dataset']['modality_config']['screen']['chunk_size'] =  80 

for k in cfg.dataset.modality_config.keys():
    print(k, cfg.dataset.modality_config[k].sampling_rate, cfg.dataset.modality_config[k].chunk_size)

cfg['dataloader']['prefetch_factor'] = 2
cfg['dataloader']['num_workers'] = 4
cfg['dataloader']['shuffle'] = False
cfg['dataloader']['pin_memory'] = False
# cfg['dataset']['add_behavior_as_channels'] = True
cfg['dataset']['modality_config']['screen']['transforms']['Resize']['size'] = [36, 64]


train_dl = get_multisession_dataloader(full_paths, cfg)

screen 30 80
responses 30 80
eye_tracker 30 80
treadmill 30 80




In [6]:
# batch.videos.shape, batch.pupil_center.shape, batch.behavior.shape, batch.responses.shape
'''
(torch.Size([8, 3, 80, 36, 64]),
 torch.Size([8, 2, 80]),
 torch.Size([8, 2, 80]),
 torch.Size([8, 7440, 80]))
'''
# dilation, dilation derivative, center x , center y

'\n(torch.Size([8, 3, 80, 36, 64]),\n torch.Size([8, 2, 80]),\n torch.Size([8, 2, 80]),\n torch.Size([8, 7440, 80]))\n'

In [37]:
# beh = torch.cat([batch['eye_tracker'][:, :, :2].transpose(2, 1), batch['treadmill'][:, :, :2].transpose(2, 1)], axis=1)
# video = batch['screen']

# b_expanded = beh.unsqueeze(-1).unsqueeze(-1)  # or b[:, :, :, None, None]

# # Now broadcast b to match the spatial dimensions [16, 3, 60, 144, 256]
# beh_tiled = b_expanded.expand(-1, -1, -1, video.shape[3], video.shape[4])

# # Concatenate along dim=1 to get [16, 4, 60, 144, 256]
# video = torch.cat([video, beh_tiled], dim=1)

In [None]:
# beh = torch.cat([batch['eye_tracker'][:, :, :2].transpose(2, 1), batch['treadmill'][:, :, :2].transpose(2, 1)], axis=1)
# video = batch['screen']
# b_expanded = beh.unsqueeze(-1).unsqueeze(-1)  # or b[:, :, :, None, None]
# # Now broadcast b to match the spatial dimensions [16, 3, 60, 144, 256]
# beh_tiled = b_expanded.expand(-1, -1, -1, video.shape[3], video.shape[4])
# # Concatenate along dim=1 to get [16, 4, 60, 144, 256]
# video = torch.cat([video, beh_tiled], dim=1)

# resp = batch['responses'].transpose(2, 1)
# [video, resp,  ]
# batch_kwargs = {
#     'videos': video,
#     'pupil_center': batch['eye_tracker'][:, :, 2:].transpose(2, 1),
#     'responses': resp
# }

In [14]:
# print(cfg)

{'dataset': {'global_sampling_rate': None, 'global_chunk_size': None, 'add_behavior_as_channels': False, 'replace_nans_with_means': False, 'cache_data': False, 'out_keys': ['screen', 'responses', 'eye_tracker', 'treadmill', 'timestamps'], 'normalize_timestamps': True, 'modality_config': {'screen': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'valid_condition': {'tier': 'train'}, 'offset': 0, 'sample_stride': 1, 'include_blanks': True, 'transforms': {'normalization': 'normalize', 'Resize': {'_target_': 'torchvision.transforms.v2.Resize', 'size': [144, 256]}}, 'interpolation': {'rescale': True, 'rescale_size': [144, 256]}}, 'responses': {'keep_nans': False, 'sampling_rate': 30, 'chunk_size': 60, 'offset': 0.0, 'transforms': {'normalization': 'standardize'}, 'interpolation': {'interpolation_mode': 'nearest_neighbor'}, 'filters': {'nan_filter': {'__target__': 'experanto.filters.common_filters.nan_filter', '__partial__': True, 'vicinity': 0.05}}}, 'eye_tracker': {'keep_nans':

In [16]:
mean_activity_dict = {}
n_neurons_dict = {}
data_keys = list(train_dl.loaders.keys())
for k in data_keys:
    batch = next(iter(train_dl.loaders[k]))
    n_neurons_dict[k] = batch['responses'].shape[-1]
    mean_activity_dict[k] = batch['responses'].reshape(-1, n_neurons_dict[k]).mean(axis=0)

batch_size = batch['responses'].shape

In [18]:
factorised_3D_core_dict = dict(
    input_channels=4,
    hidden_channels=[32, 64, 128],
    spatial_input_kernel=(11, 11),
    temporal_input_kernel=11,
    spatial_hidden_kernel=(5, 5),
    temporal_hidden_kernel=5,
    stride=1,
    layers=3,
    gamma_input_spatial=10,
    gamma_input_temporal=0.01,
    bias=True,
    hidden_nonlinearities="elu",
    x_shift=0,
    y_shift=0,
    batch_norm=True,
    laplace_padding=None,
    input_regularizer="LaplaceL2norm",
    padding=False,
    final_nonlin=True,
    momentum=0.7,
)
shifter_dict = dict(
    gamma_shifter=0,
    shift_layers=3,
    input_channels_shifter=2,
    hidden_channels_shifter=5,
)


readout_dict = dict(
    bias=False,
    init_mu_range=0.2,
    init_sigma=1.0,
    gamma_readout=0.0,
    gauss_type="full",
    # grid_mean_predictor={
    #     "type": "cortex",
    #     "input_dimensions": 2,
    #     "hidden_layers": 1,
    #     "hidden_features": 30,
    #     "final_tanh": True,
    # },
    grid_mean_predictor = None,
    share_features=False,
    share_grid=False,
    shared_match_ids=None,
    gamma_grid_dispersion=0.0,
    zig=True,
    out_channels=2,
    kernel_size=(11, 5),
    batch_size=cfg['dataloader']['batch_size'],
    # conv_out = conv_out
)

In [19]:
factorised_3d_model = make_video_model(
    None,
    seed,
    core_dict=factorised_3D_core_dict,
    core_type="3D_factorised",
    readout_dict=readout_dict.copy(),
    readout_type="gaussian",
    use_gru=False,
    gru_dict=None,
    use_shifter=True,  # set to True if behavior is included
    shifter_dict=shifter_dict,
    shifter_type="MLP",
    deeplake_ds=False,
    n_neurons_dict=n_neurons_dict,
    mean_activity_dict=mean_activity_dict,
    experanto=True,
    readout_dim=factorised_3D_core_dict['hidden_channels'][-1]
)



In [20]:
factorised_3d_model

VideoFiringRateEncoder(
  (core): Factorized3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regularizer): DepthLaplaceL21d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv_spatial): Conv3d(4, 32, kernel_size=(1, 11, 11), stride=(1, 1, 1))
        (conv_temporal): Conv3d(32, 32, kernel_size=(11, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(32, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_spatial_1): Conv3d(32, 64, kernel_size=(1, 5, 5), stride=(1, 1, 1))
        (conv_temporal_1): Conv3d(64, 64, kernel_size=(5, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer2): Sequential(
        (conv_spatial_2): Conv3d(64, 128, kernel_size=(1, 5, 5), stride=

### Precompute statistics

In [10]:
# from scipy.optimize import root
# from scipy.stats import gamma
# import gc
# import torch

In [11]:
# cfg.dataset.modality_config['screen']['sample_stride'] = cfg['dataset']['modality_config']['screen']['chunk_size']
# cfg.dataset.modality_config['screen']['sample_stride']

In [12]:
'''
for m in full_paths:
    train_dl = get_multisession_dataloader([m], cfg)
    mouse_name = m.split('dynamic')[-1].split('-Video')[0]
    num_neurons = train_dl.loaders[mouse_name].dataset.__getitem__(0)['responses'].cpu().numpy().shape[-1]
    step = 16 * 60
    concatenated_array = np.zeros((num_neurons, len(train_dl) * step), dtype='float32')
    for idx, batch in enumerate(tqdm(train_dl)):
        batch = batch[1]
        resp_b = batch['responses'].cpu().numpy()
        resp_b = resp_b.reshape(-1, *resp_b.shape[2:]).T
        # # print(resp_b.shape)
        # resp_b = np.where(resp_b < 0.005, np.nan, resp_b)
        # # print(resp_b.shape)
        # resp_b = np.minimum(resp_b, 80)
        # # print(resp_b.shape)
        # # resp_b = resp_b[1, ~np.isnan(resp_b[0, :])]
        # # print(resp_b.shape)
        concatenated_array[:, step*idx : step*(idx+1)] = resp_b
        
    print(concatenated_array.shape)

    # Filter out values smaller than 0.005, this is roughly the threshold of the ZIG distribution. 
    # Only values bigger than this threshold are considered in the gamma distribution
    filtered_array = np.where(
        concatenated_array <= 0.005, np.nan, concatenated_array
    )
    filtered_array = np.minimum(filtered_array, 80)  # cut outliers off
    first_row_valid_values = filtered_array[
        1, ~np.isnan(filtered_array[0, :])
    ]
    k_params = []

    for i in tqdm(range(filtered_array.shape[0])):
        neuron_valid_values = filtered_array[
            i, ~np.isnan(filtered_array[i, :])
        ]
        # Fix the loc parameter to 0.05
        params = gamma.fit(neuron_valid_values, floc=0.005)

        # Get the shape and scale parameters
        k, loc, theta = params
        k_params.append(k)

    # Get the shape, location, and scale parameters
    k, loc, theta = params
    x = np.linspace(
        min(first_row_valid_values), max(first_row_valid_values), 1000
    )
    fitted_gamma = gamma.pdf(x, k, loc, theta)

    # Calculate mean and variance along the first dimension (axis 1) ignoring nan values
    mean_array = np.nanmean(filtered_array, axis=1)
    variance_array = np.nanvar(filtered_array, axis=1)
    
    print('before saving')
    os.makedirs(f'/scratch-grete/projects/nim00012/experanto_stats_finn_model/{m.split("/")[-2]}', exist_ok=True) 
    np.save(f'/scratch-grete/projects/nim00012/experanto_stats_finn_model/{m.split("/")[-2]}/new_mean.npy', mean_array)
    np.save(f'/scratch-grete/projects/nim00012/experanto_stats_finn_model/{m.split("/")[-2]}/new_variance.npy', variance_array)
    np.save(f'/scratch-grete/projects/nim00012/experanto_stats_finn_model/{m.split("/")[-2]}/k_fitted.npy', np.array(k_params))
    del concatenated_array
    torch.cuda.empty_cache()  # If using GPU tensors
    gc.collect()              # Force Python garbage collection

'''

'\nfor m in full_paths:\n    train_dl = get_multisession_dataloader([m], cfg)\n    mouse_name = m.split(\'dynamic\')[-1].split(\'-Video\')[0]\n    num_neurons = train_dl.loaders[mouse_name].dataset.__getitem__(0)[\'responses\'].cpu().numpy().shape[-1]\n    step = 16 * 60\n    concatenated_array = np.zeros((num_neurons, len(train_dl) * step), dtype=\'float32\')\n    for idx, batch in enumerate(tqdm(train_dl)):\n        batch = batch[1]\n        resp_b = batch[\'responses\'].cpu().numpy()\n        resp_b = resp_b.reshape(-1, *resp_b.shape[2:]).T\n        # # print(resp_b.shape)\n        # resp_b = np.where(resp_b < 0.005, np.nan, resp_b)\n        # # print(resp_b.shape)\n        # resp_b = np.minimum(resp_b, 80)\n        # # print(resp_b.shape)\n        # # resp_b = resp_b[1, ~np.isnan(resp_b[0, :])]\n        # # print(resp_b.shape)\n        concatenated_array[:, step*idx : step*(idx+1)] = resp_b\n        \n    print(concatenated_array.shape)\n\n    # Filter out values smaller than 0.005

### End of statistics precompute

In [9]:
base_dir = "/scratch-grete/projects/nim00012/experanto_stats_finn_model/"
mean_variance_dict = load_mean_variance(base_dir, device)

In [10]:
# mean_variance_dict.keys()

In [11]:
# 'dynamic29712-5-9-Video-021a75e56847d574b9acbcc06c675055_30hz_mean'.replace('dynamic', '').replace('-Video-021a75e56847d574b9acbcc06c675055_30hz', '')

In [12]:
mean_variance_dict_new = {}
for k, v in mean_variance_dict.items():
    mean_variance_dict_new[k.replace('dynamic', '').replace('-Video-021a75e56847d574b9acbcc06c675055_30hz', '')] = v

In [13]:
mean_variance_dict_new

{'29712-5-9_mean': tensor([0.7078, 0.8658, 0.9356,  ..., 0.8265, 0.9993, 0.8470], device='cuda:0'),
 '29712-5-9_variance': tensor([3.8925, 2.0632, 2.1666,  ..., 1.5529, 1.5811, 1.6078], device='cuda:0'),
 '29712-5-9fitted_k': tensor([0.4327, 0.5413, 0.5173,  ..., 0.5679, 0.5819, 0.5654], device='cuda:0'),
 '29755-2-8_mean': tensor([0.8733, 0.8881, 0.9002,  ..., 0.7974, 0.8424, 0.7204], device='cuda:0'),
 '29755-2-8_variance': tensor([1.6552, 2.1725, 1.9426,  ..., 1.8062, 2.9628, 2.5200], device='cuda:0'),
 '29755-2-8fitted_k': tensor([0.5569, 0.5238, 0.5345,  ..., 0.5534, 0.4401, 0.4859], device='cuda:0'),
 '29228-2-10_mean': tensor([0.5274, 0.4098, 0.5304,  ..., 0.9011, 0.7784, 0.8361], device='cuda:0'),
 '29228-2-10_variance': tensor([1.4311, 1.5626, 1.6986,  ..., 2.0438, 1.6766, 1.3538], device='cuda:0'),
 '29228-2-10fitted_k': tensor([0.4827, 0.4904, 0.4743,  ..., 0.4863, 0.5019, 0.5391], device='cuda:0'),
 '29647-19-8_mean': tensor([0.8073, 0.8573, 0.9913,  ..., 0.8002, 0.8457, 1.

In [14]:
mean_variance_dict_new['29712-5-9_mean'].shape, mean_variance_dict_new['29712-5-9_variance'].shape, mean_variance_dict_new['29712-5-9fitted_k'].shape

(torch.Size([7939]), torch.Size([7939]), torch.Size([7939]))

In [17]:
n_neurons_dict

{'29156-11-10': 7440,
 '29228-2-10': 7928,
 '29234-6-9': 8285,
 '29513-3-5': 7671,
 '29514-2-9': 7495,
 '26872-17-20': 7776,
 '27204-5-13': 7538,
 '29515-10-12': 7863,
 '29623-4-9': 7908,
 '29647-19-8': 8202,
 '29712-5-9': 7939,
 '29755-2-8': 8122}

In [21]:
# determine maximal number of neurons
max_neurons = max(list(n_neurons_dict.values()))

output_dim = 12
dropout = "across_time"
dropout_prob = 0.5
encoder_dict = {}
encoder_dict["input_dim"] = max_neurons
encoder_dict["hidden_dim"] = 42  # 42
encoder_dict["hidden_gru"] = 20  # 20
encoder_dict["output_dim"] = output_dim
encoder_dict["hidden_layers"] = 1
encoder_dict["n_samples"] = 70
encoder_dict["mice_dim"] = 0
encoder_dict["use_cnn"] = False
encoder_dict["residual"] = False
encoder_dict["kernel_size"] = [11, 5, 5]
encoder_dict["channel_size"] = [32, 32, 20]
encoder_dict["use_resnet"] = False
encoder_dict["pretrained"] = True

decoder_dict = {}
decoder_dict["hidden_dim"] = output_dim
decoder_dict["hidden_layers"] = 1
decoder_dict["use_cnn"] = False
decoder_dict["kernel_size"] = [5, 11]
decoder_dict["channel_size"] = [12, 12]

# TODO latent should be TRUE after ZIG pretraining
latent = False
zig_model = ZIGEncoder(
    core=factorised_3d_model.core,
    readout=factorised_3d_model.readout,
    shifter = factorised_3d_model.shifter,
    # shifter=None,
    k_image_dependent=False,
    loc_image_dependent=False,
    mle_fitting=mean_variance_dict_new,
    latent=latent,
    encoder=encoder_dict,
    # decoder = decoder_dict,
    norm_layer="layer_flex",
    non_linearity=True,
    dropout=dropout,
    dropout_prob=dropout_prob,
    future_prediction=False,
    flow=False,
    position_features = None,
    behavior_in_encoder = None
)
if not latent:
    zig_model.flow = False

### Now lets adjust the train loop for experanto

In [22]:
def standard_trainer(
    model,
    dataloaders,
    seed,
    avg_loss=False,
    scale_loss=True,
    loss_function="PoissonLoss",
    stop_function="get_correlations",
    loss_accum_batch_n=None,
    device="cuda",
    verbose=True,
    interval=1,
    patience=5,
    epoch=0,
    lr_init=0.005,
    max_iter=200,
    maximize=True,
    tolerance=1e-6,
    restore_best=True,
    lr_decay_steps=3,
    lr_decay_factor=0.3,
    min_lr=0.0001,
    cb=None,
    detach_core=False,
    use_wandb=True,
    wandb_project='finn_mode_with_experanto',
    wandb_entity="ecker-lab",
    wandb_name=None,
    wandb_model_config=None,
    wandb_dataset_config=None,
    save_checkpoints=True,
    checkpoint_save_path="local/",
    chpt_save_step=15,
    k_reg=False,
    ema_span=0.3,  # ema for validation correlation
    scheduler_patience=6,  # patience for decaying learning rate
    latent=False,
    **kwargs,
):
    """

    Args:
        model: model to be trained
        dataloaders: dataloaders containing the data to train the model with
        seed: random seed
        avg_loss: whether to average (or sum) the loss over a batch
        scale_loss: whether to scale the loss according to the size of the dataset
        loss_function: loss function to use
        stop_function: the function (metric) that is used to determine the end of the training in early stopping
        loss_accum_batch_n: number of batches to accumulate the loss over
        device: device to run the training on
        verbose: whether to print out a message for each optimizer step
        interval: interval at which objective is evaluated to consider early stopping
        patience: number of times the objective is allowed to not become better before the iterator terminates
        epoch: starting epoch
        lr_init: initial learning rate
        max_iter: maximum number of training iterations
        maximize: whether to maximize or minimize the objective function
        tolerance: tolerance for early stopping
        restore_best: whether to restore the model to the best state after early stopping
        lr_decay_steps: how many times to decay the learning rate after no improvement
        lr_decay_factor: factor to decay the learning rate with
        min_lr: minimum learning rate
        warmup_steps: number of batch steps for linear lr warump
        T_max: epoch periodicity of cosine annealing schedluer
        cb: whether to execute callback function
        zig : True if ZIG encoder is used as model
        k_reg: is a dictonary containg the fitted k_values for each mice, applies regularization to size of shape parameter k of gamma distribution if k_reg is not None but a dictionary,
        ema-Span: alpha factor of exponential moving avaerage of validation correlation
        **kwargs:

    Returns:

    """
    print(loss_function)

    def full_objective(model, dataloader, data_key, *args, k_regu=k_reg, **kwargs):
        loss_scale = (
            np.sqrt(len(dataloader[data_key].dataset) / args[0].shape[0])
            if scale_loss
            else 1.0
        )
        if not isinstance(model.core.regularizer(), tuple):
            regularizers = int(
                not detach_core
            ) * model.core.regularizer() + model.readout.regularizer(data_key)
        else:
            regularizers = int(not detach_core) * sum(
                model.core.regularizer()
            ) + model.readout.regularizer(data_key)
        if loss_function == "ZIGLoss":
            # one entry in a tuple corresponds to one paramter of ZIG
            # the output is (theta,k,loc,q)
            positions = None
            # args[0][0:1] removes behavior from the video input data.
            model_output = model(
                args[0].to(device),
                data_key=data_key,
                out_predicts=False,
                train=True,
                positions=positions,
                **kwargs,
            )
            theta = model_output[0]
            k = model_output[1]
            loc = model_output[2]
            q = model_output[3]
            time_left = k.shape[1]

            original_data = args[1].transpose(2, 1)[:, -time_left:, :].to(device)

            # create zero, non zero masks
            comparison_result = original_data >= loc
            nonzero_mask = comparison_result.int()

            comparison_result = original_data <= loc
            zero_mask = comparison_result.int()

            if k_regu:
                k_fitted = k_regu[data_key + "fitted_k"]
                # k is constant over time and has shape (batch_size,time,num_neurons)
                k_output = k[0, 0, :]
                # punish values of k that are far away from the fitted k value, scale the regularization to the size of zig loss
                k_regularized = ((k_fitted - k_output) ** 2).mean() * 7 * 10**7

            else:
                k_regularized = 0

            if (
                len(model_output) > 4
            ):  # that is the case only for the latent space model
                k = k.unsqueeze(-1)
                loc = loc.unsqueeze(-1)
                zero_mask = zero_mask.unsqueeze(-1)
                nonzero_mask = nonzero_mask.unsqueeze(-1)
                original_data = original_data.unsqueeze(-1)
                means = model_output[4]
                sigma_squared = model_output[5]
                n_samples = model_output[7]

                sigma = torch.sqrt(sigma_squared)
                # Mask neurons, which were given in Encoder
                neuron_mask = model_output[8].to(means.device)
                neuron_mask = neuron_mask.unsqueeze(-1).repeat(1, 1, 1, n_samples)
                zig_loss = (
                    -1
                    * loss_scale
                    * (
                        criterion(
                            theta,
                            k,
                            loc=loc,
                            q=q,
                            target=original_data,
                            zero_mask=zero_mask,
                            nonzero_mask=nonzero_mask,
                        )[0]
                    )
                )

                zig_loss.masked_fill_(~neuron_mask, 0)
                zig_loss = zig_loss.sum() + regularizers

                zig_loss = zig_loss / (
                    n_samples
                )  # zigloss is in that case an MC approximate for the mean of log p(y|x,z)
                # calculate KL divergence between Gaussian prior and approximate posterior

                kl_divergence = kl_divergence_gaussian(
                    means, sigma
                )  # * q.shape[2] #kl_divergence is constant across neuron dimension since latent is the same for all neurons

                differences = means[:, 1:] - means[:, :-1]
                neighbor_loss = torch.norm(differences, p=2, dim=2).sum()

                # average loss over batch_size and time
                zig_loss = (
                    1
                    / (means.shape[0] * means.shape[1])
                    * (zig_loss + 5 * kl_divergence)
                )  # loss is ElBO = p(y|x,z) + KL(q(z|x),p(z)), log_det=0 if no flow is applied

            else:
                zig_loss = (
                    -1
                    * loss_scale
                    * criterion(
                        theta,
                        k,
                        loc=loc,
                        q=q,
                        target=original_data,
                        zero_mask=zero_mask,
                        nonzero_mask=nonzero_mask,
                    )[0].sum()
                    + regularizers
                    + k_regularized
                )
            # only zig loss
            if len(model_output) > 4:
                return zig_loss, kl_divergence
            else:
                return zig_loss
            """
            return (
                -1*loss_scale
                * criterion(theta, k,
                            loc=loc, 
                            q=q, 
                            target=original_data, 
                            zero_mask=zero_mask, 
                            nonzero_mask=nonzero_mask)[0].sum()
                + regularizers
            )
            """
        else:
            model_output = model(args[0].to(device), data_key=data_key, **kwargs)
            time_left = model_output.shape[1]

            original_data = args[1].transpose(2, 1)[:, -time_left:, :].to(device)

            return (
                loss_scale
                * criterion(
                    model_output,
                    original_data,
                )
                + regularizers
            )

    ##### Model training ####################################################################################################
    model.to(device)
    set_random_seed(seed)
    model.train()
    if loss_function == "ZIGLoss":
        zig_loss_instance = zero_inflated_losses.ZIGLoss()
        criterion = zig_loss_instance.get_slab_logl
    else:
        criterion = getattr(modules, loss_function)(avg=avg_loss)
    stop_closure = partial(
        getattr(scores, stop_function),
        dataloaders=dataloaders["oracle"],
        device=device,
        per_neuron=False,
        avg=True,
        flow=model.flow,
        cell_coordinates=None,
    )

    n_iterations = len(dataloaders["train"])

    optimizer = torch.optim.AdamW(model.parameters(), lr=lr_init)
    # Define the optimizer to only include parameters that require gradients
    # optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr_init)

    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer,
        mode="max" if maximize else "min",
        factor=lr_decay_factor,
        patience=scheduler_patience,
        threshold=tolerance,
        min_lr=min_lr,
        verbose=verbose,
        threshold_mode="abs",
    )
    # set the number of iterations over which you would like to accummulate gradients
    optim_step_count = (
        len(dataloaders["train"].loaders.keys())
        if loss_accum_batch_n is None
        else loss_accum_batch_n
    )
    print(f"optim_step_count = {optim_step_count}")

    if use_wandb:
        wandb.init(
            project=wandb_project,
            entity=wandb_entity,
            # We pass a run name (otherwise it’ll be randomly assigned, like sunshine-lollypop-10)
            name=wandb_name,
            # Track hyperparameters and run metadata
            config={
                "learning_rate": lr_init,
                "architecture": wandb_model_config,
                "dataset": wandb_dataset_config,
                "cur_epochs": max_iter,
                "starting epoch": epoch,
                "lr_decay_steps": lr_decay_steps,
                "lr_decay_factor": lr_decay_factor,
                "min_lr": min_lr,
            },
        )

        wandb.define_metric(name="Epoch", hidden=True)
        wandb.define_metric(name="Batch", hidden=True)
        
    print('wandb initialized')

    batch_no_tot = 0
    ema_values = []
    best_validation_correlation = 0
    # train over epochs
    for epoch, val_obj in early_stopping(
        model,
        stop_closure,
        interval=interval,
        patience=patience,
        start=epoch,
        max_iter=max_iter,
        maximize=maximize,
        tolerance=tolerance,
        restore_best=restore_best,
        scheduler=scheduler,
        lr_decay_steps=lr_decay_steps,
    ):

        # executes callback function if passed in keyword args
        if cb is not None:
            cb()

        # train over batches
        optimizer.zero_grad(set_to_none=True)
        epoch_loss = 0
        epoch_val_loss = 0
        for batch_no, (data_key, batch) in tqdm(
            enumerate(dataloaders["train"]),
            total=n_iterations,
            desc="Epoch {}".format(epoch),
        ):
            batch_no_tot += 1
            # TODO - polly, these two lines are basically the ones you want to change!
            beh = torch.cat([batch['eye_tracker'][:, :, :2].transpose(2, 1), batch['treadmill'][:, :, :2].transpose(2, 1)], axis=1)
            video = batch['screen']
            b_expanded = beh.unsqueeze(-1).unsqueeze(-1)  # or b[:, :, :, None, None]
            # Now broadcast b to match the spatial dimensions [16, 3, 60, 144, 256]
            beh_tiled = b_expanded.expand(-1, -1, -1, video.shape[3], video.shape[4])
            # Concatenate along dim=1 to get [16, 4, 60, 144, 256]
            video = torch.cat([video, beh_tiled], dim=1).to('cuda:0')

            resp = batch['responses'].transpose(2, 1).to('cuda:0')
            
            batch_kwargs = {
                'videos': video,
                'pupil_center_core': batch['eye_tracker'][:, :, 2:].transpose(2, 1).to('cuda:0'),
                'responses': resp
            }
            batch_args = [video, resp  ]
            # batch_args = list(data)
            # batch_kwargs = data._asdict() if not isinstance(data, dict) else data
            # -----
            loss = full_objective(
                model,
                dataloaders["train"],
                data_key,
                *batch_args,
                **batch_kwargs,
                detach_core=detach_core,
            )[0]
            loss.backward()
            epoch_loss += loss.detach()
            if (batch_no + 1) % optim_step_count == 0:
                optimizer.step()
                optimizer.zero_grad(set_to_none=True)
        model.eval()
        ###
        print(epoch_loss / 225, "loss", epoch, "epoch")
        lr = optimizer.param_groups[0]["lr"]
        ###
        ## after - epoch-analysis

        validation_correlation = get_correlations(
            model,
            dataloaders["oracle"],
            device=device,
            as_dict=False,
            per_neuron=False,
            deeplake_ds=False,
            flow=model.flow,
            cell_coordinates=None,
        )

        if save_checkpoints:
            if validation_correlation > best_validation_correlation:
                torch.save(model.state_dict(), f"{checkpoint_save_path}best.pth")
                best_validation_correlation = validation_correlation

        if loss_function == "PoissonLoss" or (not model.latent):
            val_loss = full_objective(
                model,
                dataloaders["oracle"],
                data_key,
                *batch_args,
                **batch_kwargs,
                detach_core=detach_core,
            )
        else:
            val_loss, kl_div = full_objective(
                model,
                dataloaders["oracle"],
                data_key,
                *batch_args,
                **batch_kwargs,
                detach_core=detach_core,
            )

        # torch.save(
        # model.state_dict(), f"toymodels2/temp_save.pth"
        # )

        print(
            f"Epoch {epoch}, Batch {batch_no}, Train loss {loss}, Validation loss {val_loss}"
        )
        print(f"EPOCH={epoch}  validation_correlation={validation_correlation}")

        ema_values.append(validation_correlation)
        ema = calculate_ema(torch.tensor(ema_values), ema_span)[-1]
        if use_wandb:
            wandb_dict = {
                "Epoch Train loss": epoch_loss,
                "Batch": batch_no_tot,
                "Epoch": epoch,
                "validation_correlation": validation_correlation,
                # "log_det": log_det,
                "Epoch validation loss": val_loss,
                "EMA validation loss": ema,
                # "Poisson Loss": pos_loss,
                "ZIG Loss": val_loss,
                "Epoch": epoch,
                "Learning rate": lr,
            }
            wandb.log(wandb_dict)

        model.train()

    ##### Model evaluation ####################################################################################################
    model.eval()
    # if save_checkpoints:
    # torch.save(model.state_dict(), f"{checkpoint_save_path}final.pth")

    # Compute avg validation and test correlation
    validation_correlation = get_correlations(
        model,
        dataloaders["oracle"],
        device=device,
        as_dict=False,
        per_neuron=False,
        deeplake_ds=False,
        flow=model.flow,
        cell_coordinates=None,
    )
    print(f"\n\n FINAL validation_correlation {validation_correlation} \n\n")

    output = {}
    output["validation_corr"] = validation_correlation

    score = np.mean(validation_correlation)
    if use_wandb:
        wandb.finish()

    # removing the checkpoints except the last one
    # to_clean = os.listdir(checkpoint_save_path)
    to_clean = os.listdir("toymodels")
    for f2c in to_clean:
        if "epoch" in f2c:
            os.remove(os.path.join("toymodels", f2c))

    return score

In [17]:
cfg['dataset']['modality_config']['responses']['sampling_rate'] = 30
cfg['dataset']['modality_config']['responses']['chunk_size'] = 60

cfg['dataset']['modality_config']['eye_tracker']['sampling_rate'] = 30
cfg['dataset']['modality_config']['eye_tracker']['chunk_size'] = 60

cfg['dataset']['modality_config']['treadmill']['sampling_rate'] = 30
cfg['dataset']['modality_config']['treadmill']['chunk_size'] = 60

cfg['dataset']['modality_config']['screen']['sampling_rate'] = 30
cfg['dataset']['modality_config']['screen']['chunk_size'] =  60 

cfg.dataset.modality_config.screen.valid_condition = {"tier": "validation"}

val_dl = get_multisession_dataloader(full_paths, cfg)

KeyboardInterrupt: 

In [18]:
print('mew')

mew


In [23]:
dataloaders = {}
dataloaders['train'] = train_dl
# todo - undo it after the validation set labels are updated
# dataloaders["oracle"] = val_dl
dataloaders["oracle"] = {}
for m in full_paths:
    dataloaders["oracle"][m.split('dynamic')[-1].split('-Video')[0]] = get_multisession_dataloader([m], cfg)



In [24]:
lr_inint = 5e-3
min_lr = 1e-5

In [25]:
zig_model.to('cuda:0')

ZIGEncoder(
  (core): Factorized3dCore(
    (_input_weight_regularizer): LaplaceL2norm(
      (laplace): Laplace()
    )
    (temporal_regularizer): DepthLaplaceL21d(
      (laplace): Laplace1d()
    )
    (features): Sequential(
      (layer0): Sequential(
        (conv_spatial): Conv3d(4, 32, kernel_size=(1, 11, 11), stride=(1, 1, 1))
        (conv_temporal): Conv3d(32, 32, kernel_size=(11, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(32, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer1): Sequential(
        (conv_spatial_1): Conv3d(32, 64, kernel_size=(1, 5, 5), stride=(1, 1, 1))
        (conv_temporal_1): Conv3d(64, 64, kernel_size=(5, 1, 1), stride=(1, 1, 1))
        (norm): BatchNorm3d(64, eps=1e-05, momentum=0.7, affine=True, track_running_stats=True)
        (nonlin): ELU(alpha=1.0)
      )
      (layer2): Sequential(
        (conv_spatial_2): Conv3d(64, 128, kernel_size=(1, 5, 5), stride=(1, 1, 1))
 

In [26]:
batch['responses'].device

device(type='cpu')

In [None]:
validation_score = standard_trainer(
    zig_model,
    dataloaders,
    111,
    use_wandb=True,
    wandb_name="16_core_channels_latent_mice6_10",
    loss_function="ZIGLoss",
    # loss_function= "PoissonLoss",
    verbose=True,
    lr_decay_steps=4,
    lr_init=lr_inint,
    min_lr=min_lr,
    device=device,
    patience=12,  # 12#8,
    scheduler_patience=10,  # 10#6,
    checkpoint_save_path="./test_training",
)

ZIGLoss
optim_step_count = 12


Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mpollytur[0m ([33mecker-lab[0m). Use [1m`wandb login --relogin`[0m to force relogin
cat: /sys/module/amdgpu/initstate: No such file or directory
ERROR:root:Driver not initialized (amdgpu not found in modules)


wandb initialized


In [None]:
%debug

In [None]:
print('mew')