In [None]:
import multiprocessing as mp

# Foundational SSM core imports
from foundational_ssm.loaders import get_brainset_data_loader
from foundational_ssm.constants import DATA_ROOT
from foundational_ssm.samplers import TrialSampler
import os 
import equinox as eqx

mp.set_start_method("spawn", force=True) # otherwise causes deadlock on jax.

data_root = '../' + DATA_ROOT # change to the folder holding the brainsets
config_dir = '../configs/dataset' # change
dataset_args = {
    'keep_files_open': False,
    'lazy': True,
    'split': 'train' # or 'train' 
    'min_window_length': 3.280
}
dataloader_args = {
    'batch_size': 128, # Adjust per your system capacity
    'num_workers': 10,
    'persistent_workers': False
}
sampler = 'SequentialFixedWindowSampler'
sampler_args = { 
                'window_length': 3.279,
                'drop_short': False 
                }

dataset, data_loader = get_brainset_data_loader(
    dataset_args=dataset_args,
    sampler = sampler,
    sampler_args = sampler_args,
    dataloader_args = dataloader_args,
    sampling_rate = 200,
    dataset_cfg = os.path.join(config_dir, 'reaching.yaml'),
    data_root = data_root
)

sessions = dataset.get_session_ids() # list of sessions in your dataset
sampling_intervals = dataset.get_sampling_intervals() # list of sampling intervals for each session

In [7]:
import jax 
import jax.numpy as jnp 
import numpy as np
from tqdm import tqdm 

metrics = {}  # New: store metrics per group
all_preds = []
all_targets = []
all_dataset_group_idxs = []
for batch_idx, batch in tqdm(enumerate(data_loader)):
    batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
    dataset_group_idxs = batch["dataset_group_idx"]
    inputs = batch["neural_input"]
    targets = batch["behavior_input"]

    all_targets.append(targets)
    all_dataset_group_idxs.append(dataset_group_idxs)
all_targets = jnp.concatenate(all_targets, axis=0)
all_dataset_group_idxs = jnp.concatenate(all_dataset_group_idxs, axis=0)

for group_idx in tqdm(jnp.unique(all_dataset_group_idxs)):
    group_targets = all_targets[all_dataset_group_idxs == group_idx]
    variance = jnp.var(group_targets.reshape(-1, 2), axis=0)
    metrics[int(group_idx)] = {
        "variance": variance.mean()
    }
metrics

  batch = {k: jax.device_put(np.array(v)) for k, v in batch.items()}
  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
  slope = (y_hi - y_lo) / (x_hi - x_lo)[:, None]
560it [05:28,  1.71it/s]
100%|██████████| 8/8 [00:00<00:00, 619.77it/s]


{0: {'variance': Array(1., dtype=float32)},
 1: {'variance': Array(1.0000001, dtype=float32)},
 2: {'variance': Array(0.9999999, dtype=float32)},
 3: {'variance': Array(1., dtype=float32)},
 5: {'variance': Array(0.99999994, dtype=float32)},
 6: {'variance': Array(1., dtype=float32)},
 7: {'variance': Array(0.9999999, dtype=float32)},
 8: {'variance': Array(0.9999999, dtype=float32)}}

In [2]:
pretrain_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/pretrain_train_and_val.yaml"

pretrain_dataset = TorchBrainDataset(
        root="../"+DATA_ROOT,                # root directory where .h5 files are found
        # recording_id=recording_id,  # you either specify a single recording ID
        config=pretrain_config_path,                 # or a config for multi-session training / more complex configs
        keep_files_open=True,
        lazy=True,
        split="train"
    )

pretrain_sampling_intervals = pretrain_dataset.get_sampling_intervals()        
DATASET_GROUP_INFO: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

for i, (recording_id, train_intervals) in tqdm(enumerate(pretrain_sampling_intervals.items())):
    recording_data = pretrain_dataset.get_recording_data(recording_id)
    num_units = int(np.max(recording_data.spikes.unit_index))
    dataset, subject, task = parse_session_id(recording_id)
    train_duration = np.sum(train_intervals.end - train_intervals.start)
    # print(type(recording_data.cursor.vel))
    if recording_id.startswith("pei_pandarinath_nlb_2021"):
        behavior_sampling_rate = np.min(recording_data.hand.timestamps[1:] - recording_data.hand.timestamps[:-1])
    else:    
        behavior_sampling_rate = np.min(recording_data.cursor.timestamps[1:] - recording_data.cursor.timestamps[:-1])
    if (dataset, subject, task) not in DATASET_GROUP_INFO:
        DATASET_GROUP_INFO[(dataset, subject, task)] = {
            "max_num_units": num_units,
            "behavior_dim": 2,
            "train_duration": float(round(train_duration, 2)),
            "min_behavior_sampling_rate": behavior_sampling_rate,
        }
    else:
        DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"] = max(DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"], num_units)
        DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] = DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] + train_duration
        DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"] = min(DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"], behavior_sampling_rate)

148it [25:58, 10.53s/it] 


In [7]:
downstream_t_rt_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/downstream_t_rt.yaml"
downstream_t_co_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/downstream_t_co.yaml"

downstream_t_co_dataset = TorchBrainDataset(
        root="../"+DATA_ROOT,                # root directory where .h5 files are found
        # recording_id=recording_id,  # you either specify a single recording ID
        config=downstream_t_co_config_path,                 # or a config for multi-session training / more complex configs
        keep_files_open=True,
        lazy=True,
        split="train"
    )

downstream_t_co_sampling_intervals = downstream_t_co_dataset.get_sampling_intervals()        
DATASET_GROUP_INFO: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

for i, (recording_id, train_intervals) in tqdm(enumerate(downstream_t_co_sampling_intervals.items())):
    recording_data = downstream_t_co_dataset.get_recording_data(recording_id)
    num_units = int(np.max(recording_data.spikes.unit_index))
    dataset, subject, task = parse_session_id(recording_id)
    train_duration = np.sum(train_intervals.end - train_intervals.start)
    # print(type(recording_data.cursor.vel))
    if recording_id.startswith("pei_pandarinath_nlb_2021"):
        behavior_sampling_rate = np.min(recording_data.hand.timestamps[1:] - recording_data.hand.timestamps[:-1])
    else:    
        behavior_sampling_rate = np.min(recording_data.cursor.timestamps[1:] - recording_data.cursor.timestamps[:-1])
    if (dataset, subject, task) not in DATASET_GROUP_INFO:
        DATASET_GROUP_INFO[(dataset, subject, task)] = {
            "max_num_units": num_units,
            "behavior_dim": 2,
            "train_duration": float(round(train_duration, 2)),
            "min_behavior_sampling_rate": behavior_sampling_rate,
        }
    else:
        DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"] = max(DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"], num_units)
        DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] = DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] + train_duration
        DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"] = min(DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"], behavior_sampling_rate)

4it [00:01,  3.77it/s]


In [8]:
DATASET_GROUP_INFO

{('perich_miller_population_2018',
  't',
  'center_out_reaching'): {'max_num_units': 58, 'behavior_dim': 2, 'train_duration': np.float64(2193.2977666666666), 'min_behavior_sampling_rate': np.float64(0.009999999999990905)}}