In [1]:

import glob
import os

import sys
import glob
from pathlib import Path, PurePath
path = Path.cwd()
parent_path = path.parents[1]
sys.path.append(str(PurePath(parent_path, 'neuroformer')))
sys.path.append('neuroformer')
sys.path.append('.')
sys.path.append('../')

import numpy as np
import pandas as pd

import torch
from torch.utils.data.dataloader import DataLoader

import math

from neuroformer.model_neuroformer import Neuroformer, NeuroformerConfig
from neuroformer.utils import get_attr
from neuroformer.trainer import Trainer, TrainerConfig
from neuroformer.utils import (set_seed, update_object, running_jupyter, 
                                 all_device, load_config, 
                                 dict_to_object, object_to_dict, recursive_print,
                                 create_modalities_dict)
from neuroformer.visualize import set_plot_params
from neuroformer.data_utils import round_n, Tokenizer, NFDataloader
from neuroformer.datasets import load_visnav, load_V1AL

parent_path = os.path.dirname(os.path.dirname(os.getcwd())) + "/"
import wandb

# set up logging
import logging
logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
)

from neuroformer.default_args import DefaultArgs, parse_args

if running_jupyter(): # or __name__ == "__main__":
    print("Running in Jupyter")
    args = DefaultArgs()
else:
    print("Running in terminal")
    args = parse_args()

# SET SEED - VERY IMPORTANT
set_seed(args.seed)

print(f"CONTRASTIUVEEEEEEE {args.contrastive}")
print(f"VISUAL: {args.visual}")
print(f"PAST_STATE: {args.past_state}")

# Use the function
if args.config is None:
    config_path = "./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/predict_all/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/mconf.yaml"

    #### THIS IS THE WEIGHTS YOU WANT TO CONTINUE TRAINING FROM
    args.resume = "./models/NF.15/Visnav_VR_Expt/lateral/Neuroformer/predict_all/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/model.pt"
else:
    config_path = args.config
config = load_config(config_path)  # replace 'config.yaml' with your file path


  from .autonotebook import tqdm as notebook_tqdm


Running in Jupyter
CONTRASTIUVEEEEEEE False
VISUAL: True
PAST_STATE: True


In [2]:
""" 

-- DATA --
neuroformer/data/OneCombo3_V1AL/
df = response
video_stack = stimulus
DOWNLOAD DATA URL = https://drive.google.com/drive/folders/1jNvA4f-epdpRmeG9s2E-2Sfo-pwYbjeY?usp=sharing


"""

if args.dataset in ["lateral", "medial"]:
    data, intervals, train_intervals, \
    test_intervals, finetune_intervals, \
    callback = load_visnav(args.dataset, config, 
                           selection=config.selection if hasattr(config, "selection") else None)
elif args.dataset == "V1AL":
    data, intervals, train_intervals, \
    test_intervals, finetune_intervals, \
    callback = load_V1AL(config)

spikes = data['spikes']
stimulus = data['stimulus']


In [34]:
window = config.window.curr
window_prev = config.window.prev
dt = config.resolution.dt

# -------- #

spikes_dict = {
    "ID": data['spikes'],
    "Frames": data['stimulus'],
    "Interval": intervals,
    "dt": config.resolution.dt,
    "id_block_size": config.block_size.id,
    "prev_id_block_size": config.block_size.prev_id,
    "frame_block_size": config.block_size.frame,
    "window": config.window.curr,
    "window_prev": config.window.prev,
    "frame_window": config.window.frame,
}

""" 
 - see mconf.yaml "modalities" structure:

modalities:
  behavior:
    n_layers: 4
    window: 0.05
    variables:
      speed:
        data: speed
        dt: 0.05
        predict: true
        objective: regression
      phi:
        data: phi
        dt: 0.05
        predict: true
        objective: regression
      th:
        data: th
        dt: 0.05
        predict: true
        objective: regression


Modalities: any additional modalities other than spikes and frames
    Behavior: the name of the <modality type>
        Variables: the name of the <modality>
            Data: the data of the <modality> in shape (n_samples, n_features)
            dt: the time resolution of the <modality>, used to index n_samples
            Predict: whether to predict this modality or not.
                     If you set predict to false, then it will 
                     not be used as an input in the model,
                     but rather to be predicted as an output. 
            Objective: regression or classification

"""

frames = {'feats': stimulus, 'callback': callback, 'window': config.window.frame, 'dt': config.resolution.dt}

  
def configure_token_types(config, modalities, data):
    max_window = max(config.window.curr, config.window.prev)
    dt_range = math.ceil(max_window / config.resolution.dt) + 1

    def round_n(x, resolution):
        return round(x, int(-math.log10(resolution)))

    n_dt = [round_n(x, config.resolution.dt) for x in np.arange(0, max_window + config.resolution.dt, config.resolution.dt)]

    token_types = {
        'ID': {
            'tokens': list(np.arange(0, data['spikes'].shape[0] if isinstance(data['spikes'], np.ndarray) else data['spikes'][1].shape[0]))
        },
        'dt': {
            'tokens': n_dt,
            'resolution': config.resolution.dt
        }
    }

    if modalities is not None and config.modalities is not None:
        for modality, details in modalities.items():
            if details.get('predict', False) and details.get('objective', '') == 'classification':
                token_types[modality] = {
                    'tokens': sorted(list(set(eval(modality)))),
                    'resolution': details.get('resolution')
                }

    return token_types

modalities = create_modalities_dict(data, config.modalities) if get_attr(config, 'modalities', None) else None
token_types = configure_token_types(config, modalities, data)
tokenizer = Tokenizer(token_types)

ID vocab size: 2026
dt vocab size: 9


In [7]:
if modalities is not None:
    for modality_type, modality in modalities.items():
        for variable_type, variable in modality.items():
            print(variable_type, variable)



n_layers 4
variables {'phi': {'data': array([2.15574538, 2.15574538, 3.4186857 , ..., 0.20013816, 0.18952116,
       0.18952116]), 'dt': 0.05, 'window': 0.05, 'predict': True, 'objective': 'regression'}, 'speed': {'data': array([ 0.03216422, -0.08701274, -0.36626688, ..., -0.657169  ,
       -0.79575145, -0.97988075], dtype=float32), 'dt': 0.05, 'window': 0.05, 'predict': True, 'objective': 'regression'}, 'th': {'data': array([0.50604788, 0.50604788, 0.39550107, ..., 1.84192663, 1.7470296 ,
       1.7470296 ]), 'dt': 0.05, 'window': 0.05, 'predict': True, 'objective': 'regression'}}


In [8]:
train_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                             frames=frames, intervals=train_intervals, modalities=modalities)
test_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                            frames=frames, intervals=test_intervals, modalities=modalities)
finetune_dataset = NFDataloader(spikes_dict, tokenizer, config, dataset=args.dataset, 
                                frames=frames, intervals=finetune_intervals, modalities=modalities)

    
# print(f'train: {len(train_dataset)}, test: {len(test_dataset)}')
iterable = iter(train_dataset)
x, y = next(iterable)
print(x['id'])
print(x['dt'])
recursive_print(x)

# Update the config
config.id_vocab_size = tokenizer.ID_vocab_size
model = Neuroformer(config, tokenizer)

# Create a DataLoader
loader = DataLoader(test_dataset, batch_size=2, shuffle=True, num_workers=0)
iterable = iter(loader)
x, y = next(iterable)
recursive_print(y)
preds, features, loss = model(x, y)

# Set training parameters
MAX_EPOCHS = 250
BATCH_SIZE = 32 * 5
SHUFFLE = True

if config.gru_only:
    model_name = "GRU"
elif config.mlp_only:
    model_name = "MLP"
elif config.gru2_only:
    model_name = "GRU_2.0"
else:
    model_name = "Neuroformer"

CKPT_PATH = f"./models/NF.15/Visnav_VR_Expt/{args.dataset}/{model_name}/{args.title}/{str(config.layers)}/{args.seed}"
CKPT_PATH = CKPT_PATH.replace("namespace", "").replace(" ", "_")

if os.path.exists(CKPT_PATH):
    counter = 1
    print(f"CKPT_PATH {CKPT_PATH} exists!")
    while os.path.exists(CKPT_PATH + f"_{counter}"):
        counter += 1

if args.resume is not None:
    model.load_state_dict(torch.load(args.resume))

Min Interval: 0.1
Intervals:  24092
Window:  0.05
Window Prev:  0.05
Population Size:  202204
ID Population Size:  202204
DT Population Size:  9
Using explicitly passed intervals
Min Interval: 0.1
Intervals:  24092
Window:  0.05
Window Prev:  0.05
Population Size:  202204
ID Population Size:  202204
DT Population Size:  9
Using explicitly passed intervals
Min Interval: 0.1
Intervals:  240
Window:  0.05
Window Prev:  0.05
Population Size:  202204
ID Population Size:  202204
DT Population Size:  9
Using explicitly passed intervals
tensor([202201, 122300, 136800, 132100, 134300,    100,  10800,  62500, 132100,
        132100, 132100,  73200, 132100, 133800,  12100,  19900, 132100, 185300,
        202202, 202203, 202203, 202203, 202203, 202203, 202203, 202203, 202203,
        202203, 202203, 202203, 202203, 202203, 202203, 202203, 202203, 202203,
        202203, 202203, 202203, 202203, 202203, 202203, 202203, 202203, 202203,
        202203, 202203, 202203, 202203, 202203, 202203, 202203, 2

04/23/2024 12:24:32 - INFO - neuroformer.model_neuroformer -   number of parameters: 1.306942e+08


id torch.Size([2, 100]) torch.int64
dt torch.Size([2, 100]) torch.int64
modalities_behavior_phi_value torch.Size([2, 1]) torch.float32
modalities_behavior_phi_dt torch.Size([2]) torch.float32
modalities_behavior_speed_value torch.Size([2, 1]) torch.float32
modalities_behavior_speed_dt torch.Size([2]) torch.float32
modalities_behavior_th_value torch.Size([2, 1]) torch.float32
modalities_behavior_th_dt torch.Size([2]) torch.float32


RuntimeError: Error(s) in loading state_dict for Neuroformer:
	size mismatch for tok_emb.weight: copying a param with shape torch.Size([2026, 256]) from checkpoint, the shape in current model is torch.Size([202204, 256]).
	size mismatch for head_id.weight: copying a param with shape torch.Size([2026, 256]) from checkpoint, the shape in current model is torch.Size([202204, 256]).

In [None]:
"""

Here's how to load all weights except
for the token embeddings!

"""

from neuroformer.utils import load_pretrained_weights

if args.resume is not None:
    load_pretrained_weights(model, args.resume,
                            omit_modules='tok_emb')

Pretrained weights loaded from ./models/predict_all_behavior/(state_history=6,_state=6,_stimulus=6,_behavior=6,_self_att=6,_modalities=(n_behavior=25))/25/model.pt, omitting modules: tok_emb


In [None]:
if args.sweep_id is not None:
    # this is for hyperparameter sweeps
    from neuroformer.hparam_sweep import train_sweep
    print(f"-- SWEEP_ID -- {args.sweep_id}")
    wandb.agent(args.sweep_id, function=train_sweep)
else:
    # Create a TrainerConfig and Trainer
    tconf = TrainerConfig(max_epochs=MAX_EPOCHS, batch_size=BATCH_SIZE, learning_rate=1e-4, 
                          num_workers=16, lr_decay=True, patience=3, warmup_tokens=8e7, 
                          decay_weights=True, weight_decay=1.0, shuffle=SHUFFLE,
                          final_tokens=len(train_dataset)*(config.block_size.id) * (MAX_EPOCHS),
                          clip_norm=1.0, grad_norm_clip=1.0,
                          show_grads=False,
                          ckpt_path=CKPT_PATH, no_pbar=False, 
                          dist=args.dist, save_every=0, eval_every=5, min_eval_epoch=50,
                          use_wandb=True, wandb_project="neuroformer", 
                          wandb_group=f"1.5.1_visnav_{args.dataset}", wandb_name=args.title)

    trainer = Trainer(model, train_dataset, test_dataset, tconf, config)
    trainer.train()