In [1]:

import glob
import os

import sys
import glob
from pathlib import Path, PurePath
path = Path.cwd()
parent_path = path.parents[1]
sys.path.append(str(PurePath(parent_path, 'neuroformer')))
sys.path.append('neuroformer')
sys.path.append('.')
sys.path.append('../')

import pandas as pd
import numpy as np

import numpy as np
import torch
import pandas as pd
import matplotlib.pyplot as plt

from torch.utils.data.dataloader import DataLoader

import math

from neuroformer.model_neuroformer_2 import Neuroformer, NeuroformerConfig, load_model_and_tokenizer
from neuroformer.utils import get_attr
from neuroformer.trainer import Trainer, TrainerConfig
from neuroformer.utils_2 import (set_seed, update_object, running_jupyter, 
                                 all_device, load_config, 
                                 dict_to_object, object_to_dict, recursive_print,
                                 create_modalities_dict)
from neuroformer.visualize import set_plot_params
from neuroformer.SpikeVidUtils import make_intervals, round_n, SpikeTimeVidData2
from neuroformer.DataUtils import round_n, Tokenizer
from neuroformer.datasets import load_visnav, load_V1AL

parent_path = os.path.dirname(os.path.dirname(os.getcwd())) + "/"
import wandb

# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

from neuroformer.default_args import DefaultArgs, parse_args

if running_jupyter(): # or __name__ == "__main__":
    print("Running in Jupyter")
    args = DefaultArgs()
    args.ckpt_path = "./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/pos_emb/Neuroformer/1_new/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25"
    args.dataset = "lateral"
    args.config = "./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/1_new/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/mconf_finetune_gaze.yaml"
    args.finetune = True
    args.loss_bprop = ["phi", "th"]
else:
    print("Running in terminal")
    args = parse_args()

# SET SEED - VERY IMPORTANT
set_seed(args.seed)

print(f"CONTRASTIUVEEEEEEE {args.contrastive}")
print(f"VISUAL: {args.visual}")
print(f"PAST_STATE: {args.past_state}")

# Use the function
if args.config is None:
    config_path = "./configs/NF_1.5/VisNav_VR_Expt/gru2_only_cls/mconf.yaml"
else:
    config = load_config(args.config)

base_config, base_tokenizer, model = load_model_and_tokenizer(args.ckpt_path)

Running in Jupyter
CONTRASTIUVEEEEEEE False
VISUAL: True
PAST_STATE: True
256 2203


10/19/2023 22:16:08 - INFO - neuroformer.model_neuroformer_2 -   number of parameters: 2.877057e+07


 ///// <=----- Loading model from ./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/pos_emb/Neuroformer/1_new/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25 -----=> \\\


In [2]:
""" 

-- DATA --
neuroformer/data/OneCombo3_V1AL/
df = response
video_stack = stimulus
DOWNLOAD DATA URL = https://drive.google.com/drive/folders/1jNvA4f-epdpRmeG9s2E-2Sfo-pwYbjeY?usp=sharing


"""

if args.dataset in ["lateral", "medial"]:
    data, intervals, train_intervals, \
    test_intervals, finetune_intervals, \
    callback = load_visnav(args.dataset, config, 
                           selection=config.selection if hasattr(config, "selection") else None)
elif args.dataset == "V1AL":
    data, intervals, train_intervals, \
    test_intervals, finetune_intervals, \
    callback = load_V1AL(config)

spikes = data['spikes']
stimulus = data['stimulus']


In [6]:
intervals

array([0.0000e+00, 5.0000e-02, 1.0000e-01, ..., 7.5275e+02, 7.5280e+02,
       7.5285e+02])

In [5]:
for key in data.keys():
    print(key, data[key].shape)

spikes (2023, 150578)
speed (30117,)
stimulus (30117, 30, 100)
phi (30117,)
th (30117,)


In [21]:
window = config.window.curr
window_prev = config.window.prev
dt = config.resolution.dt

# -------- #

spikes_dict = {
    "ID": data['spikes'],
    "Frames": data['stimulus'],
    "Interval": intervals,
    "dt": config.resolution.dt,
    "id_block_size": config.block_size.id,
    "prev_id_block_size": config.block_size.prev_id,
    "frame_block_size": config.block_size.frame,
    "window": config.window.curr,
    "window_prev": config.window.prev,
    "frame_window": config.window.frame,
}

""" 
 - see mconf.yaml "modalities" structure:

modalities:
  behavior:
    n_layers: 4
    window: 0.05
    variables:
      speed:
        data: speed
        dt: 0.05
        predict: true
        objective: regression
      phi:
        data: phi
        dt: 0.05
        predict: true
        objective: regression
      th:
        data: th
        dt: 0.05
        predict: true
        objective: regression


Modalities: any additional modalities other than spikes and frames
    Behavior: the name of the <modality type>
        Variables: the name of the <modality>
            Data: the data of the <modality> in shape (n_samples, n_features)
            dt: the time resolution of the <modality>, used to index n_samples
            Predict: whether to predict this modality or not.
                     If you set predict to false, then it will 
                     not be used as an input in the model,
                     but rather to be predicted as an output. 
            Objective: regression or classification

"""

frames = {'feats': stimulus, 'callback': callback, 'window': config.window.frame, 'dt': config.resolution.dt}


"""
callback: this function is used to get the frame at a given time point
given your specific stimulus/video data structure and resolution.
See neuroformer.visnav_callback / combo3_V1AL_callback
"""

modalities = create_modalities_dict(data, config.modalities) if get_attr(config, 'modalities', None) else None

max_window = max(config.window.curr, config.window.prev)
dt_range = math.ceil(max_window / dt) + 1
n_dt = [round_n(x, dt) for x in np.arange(0, max_window + dt, dt)]

token_types = {
    'ID': {'tokens': list(np.arange(0, data['spikes'].shape[0] if isinstance(data['spikes'], np.ndarray) \
                                  else data['spikes'][1].shape[0]))},
    'dt': {'tokens': n_dt, 'resolution': dt},
    **({
        modality: {
            'tokens': sorted(list(set(eval(modality)))),
            'resolution': details.get('resolution')
        }
        # if we have to classify the modality, 
        # then we need to tokenize it
        for modality, details in modalities.items()
        if details.get('predict', False) and details.get('objective', '') == 'classification'
    } if modalities is not None else {})
}
tokenizer = Tokenizer(token_types, max_window, dt)

ID vocab size: 2026
dt vocab size: 14


In [None]:
if modalities is not None:
    for modality_type, modality in modalities.items():
        for variable_type, variable in modality.items():
            print(variable_type, variable)

In [None]:
def compare_configs(config1, config2, path=''):
    if isinstance(config1, dict) and isinstance(config2, dict):
        for key in set(config1.keys()).union(config2.keys()):
            if key in config1 and key in config2:
                compare_configs(config1[key], config2[key], path + str(key) + '.')
            elif key in config1:
                print(f'Key {path + str(key)}: present in first config, missing in second config.')
            else:
                print(f'Key {path + str(key)}: missing in first config, present in second config.')
    elif isinstance(config1, list) and isinstance(config2, list):
        if config1 != config2:
            print(f'Key {path}: {config1} != {config2}')
    else:
        if config1 != config2:
            print(f'Key {path}: {config1} != {config2}')

compare_configs(config, base_config)

Key : namespace(dropout=namespace(attn=0.2, embd=0.2, pos=0.2, resid=0.2, temp=0.2, b=0.45, id=0.35, im=0.35), block_size=namespace(behavior=15, frame=446, id=100, prev_id=700), layers=namespace(state_history=6, state=6, stimulus=6, behavior=6, self_att=6, modalities=namespace(n_behavior=25)), sparse=namespace(p=None, mask=False, topk=None, topk_frame=None, topk_id=None, topk_prev_id=None), window=namespace(frame=None, curr=0.05, prev=0.05, speed=0.05), modalities=namespace(behavior=namespace(n_layers=4, window=0.05, variables=namespace(speed=namespace(data='speed', dt=0.05, predict=True, objective='regression'), phi=namespace(data='phi', dt=0.05, predict=True, objective='regression'), th=namespace(data='th', dt=0.05, predict=True, objective='regression')))), predict=None, frame_encoder=namespace(conv_layer=True, kernel_size=[4, 5, 5], n_embd=256, n_embd_frames=64, resnet_backbone=False), contrastive=namespace(contrastive=False, vars=['id', 'frames', 'speed'], clip_embd=1024, clip_temp

In [19]:
from neuroformer.DataUtils import NFDataloader

train_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                             frames=frames, intervals=train_intervals, modalities=modalities)
test_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                            frames=frames, intervals=test_intervals, modalities=modalities)
finetune_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                                frames=frames, intervals=finetune_intervals, modalities=modalities)

    
# print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')
iterable = iter(train_dataset)
x, y = next(iterable)
recursive_print(x)

# Update the config
model = Neuroformer(config, tokenizer)

# Create a DataLoader
loader = DataLoader(test_dataset, batch_size=2, shuffle=True, num_workers=0)
iterable = iter(loader)
x, y = next(iterable)
recursive_print(y)
preds, features, loss = model(x, y)


Min Interval: 0.1
Intervals:  12046
Window:  0.05
Window Prev:  0.05
Population Size:  2026
ID Population Size:  2026
DT Population Size:  14
Using explicitly passed intervals
Min Interval: 0.1
Intervals:  12046
Window:  0.05
Window Prev:  0.05
Population Size:  2026
ID Population Size:  2026
DT Population Size:  14
Using explicitly passed intervals
Min Interval: 0.1
Intervals:  120
Window:  0.05
Window Prev:  0.05
Population Size:  2026
ID Population Size:  2026
DT Population Size:  14
Using explicitly passed intervals
id_prev torch.Size([700]) torch.int64
dt_prev torch.Size([700]) torch.float32
pad_prev torch.Size([]) torch.int64
id torch.Size([100]) torch.int64
dt torch.Size([100]) torch.float32
pad torch.Size([]) torch.int64
interval torch.Size([]) torch.float32
trial torch.Size([]) torch.int64
cid torch.Size([2]) torch.float32
pid torch.Size([2]) torch.float32
256 14


10/19/2023 15:03:56 - INFO - neuroformer.model_neuroformer_2 -   number of parameters: 2.800590e+07


id torch.Size([2, 100]) torch.int64
dt torch.Size([2, 100]) torch.int64
modalities_behavior_speed_value torch.Size([2, 1]) torch.float32
modalities_behavior_speed_dt torch.Size([2]) torch.float32
modalities_behavior_phi_value torch.Size([2, 1]) torch.float32
modalities_behavior_phi_dt torch.Size([2]) torch.float32
modalities_behavior_th_value torch.Size([2, 1]) torch.float32
modalities_behavior_th_dt torch.Size([2]) torch.float32


In [18]:
# Set training parameters
MAX_EPOCHS = args.epochs
BATCH_SIZE = args.batch_size
SHUFFLE = True

if config.gru_only:
    model_name = "GRU"
elif config.mlp_only:
    model_name = "MLP"
elif config.gru2_only:
    model_name = "GRU_2.0"
else:
    model_name = "Neuroformer"

MODE = "finetune" if args.finetune else "train"
CKPT_PATH = f"./models/NF.15/Visnav_VR_Expt/{args.dataset}/{model_name}/{args.title}/{MODE}/{str(config.layers)}/{args.seed}".replace("namespace", "").replace(" ", "_")

if args.sweep_id is not None:
    from neuroformer.hparam_sweep import train_sweep
    print(f"-- SWEEP_ID -- {args.sweep_id}")
    wandb.agent(args.sweep_id, function=train_sweep)
else:
    # Create a TrainerConfig and Trainer
    tconf = TrainerConfig(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE, learning_rate=1e-4, 
                          num_workers=16, lr_decay=True, patience=3, warmup_tokens=8e7, 
                          decay_weights=True, weight_decay=1.0, shuffle=SHUFFLE,
                          final_tokens=len(train_dataset)*(config.block_size.id) * (MAX_EPOCHS),
                          clip_norm=1.0, grad_norm_clip=1.0,
                          show_grads=False,
                          ckpt_path=CKPT_PATH, no_pbar=False, 
                          dist=args.dist, save_every=0, eval_every=5, min_eval_epoch=50,
                          use_wandb=False, wandb_project="neuroformer", 
                          wandb_group=f"1.5.1_visnav_{args.dataset}", wandb_name=args.title,
                          loss_bprop=args.loss_bprop)

    # finetuning to a new task use the finetuning holdout dataset
    if args.finetune:
        print(f"// FINETUNING... //")
        # update model with weights
        # model.load_state_dict(torch.load(os.path.join(args.ckpt_path, "model.pt"),
        #                                   map_location="cpu"), strict=False)
        trainer = Trainer(model, finetune_dataset, None, tconf, config)
    else:
        trainer = Trainer(model, train_dataset, test_dataset, tconf, config)
    trainer.train()

// FINETUNING... //
-- USE WANDB: False --
not decaying: temp_emb.temp_emb.0.weight
not decaying: temp_emb.temp_emb.0.bias
not decaying: temp_emb.temp_emb.2.weight
not decaying: temp_emb.temp_emb.2.bias
not decaying: temp_emb_prev.temp_emb.0.weight
not decaying: temp_emb_prev.temp_emb.0.bias
not decaying: temp_emb_prev.temp_emb.2.weight
not decaying: temp_emb_prev.temp_emb.2.bias
weight_decay: 1.0


epoch 1  speed_train: 0.83709  phi_train: 0.44549  th_train: 0.60861  id_train: 4.61476  time_train: 0.39752  total_loss: 1.05410 lr 1.200000e-08 precision: 0.00219: 100%|██████████| 4/4 [00:02<00:00,  1.93it/s]
10/19/2023 15:03:46 - INFO - neuroformer.trainer -   saving ./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/None/finetune/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/69


namespace(dropout=namespace(attn=0.2, embd=0.2, pos=0.2, resid=0.2, temp=0.2, b=0.45, id=0.35, im=0.35), block_size=namespace(behavior=15, frame=446, id=100, prev_id=700), layers=namespace(state_history=6, state=6, stimulus=6, behavior=6, self_att=6, modalities=namespace(n_behavior=25)), sparse=namespace(p=None, mask=False, topk=None, topk_frame=None, topk_id=None, topk_prev_id=None), window=namespace(frame=None, curr=0.05, prev=0.05, speed=0.05), modalities=namespace(behavior=namespace(n_layers=4, window=0.05, variables=namespace(speed=namespace(data='speed', dt=0.05, predict=True, objective='regression'), phi=namespace(data='phi', dt=0.05, predict=True, objective='regression'), th=namespace(data='th', dt=0.05, predict=True, objective='regression')))), predict=None, frame_encoder=namespace(conv_layer=True, kernel_size=[4, 5, 5], n_embd=256, n_embd_frames=64, resnet_backbone=False), contrastive=namespace(contrastive=False, vars=['id', 'frames', 'speed'], clip_embd=1024, clip_temp=0.5),