In [13]:
from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5
from nlb_tools.evaluation import evaluate
import os
import pandas as pd
import numpy as np
from foundational_ssm.collate import pad_collate
import torch

dataset_folder = '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/' 

datasets = [
    {'name':'mc_maze', 'subpath':'./000128/sub-Jenkins/'},
    {'name':'mc_rtt', 'subpath':'./000129/sub-Indy/'},
    {'name':'area2_bump', 'subpath':'./000127/sub-Han/'},
    {'name':'dmfc_rsg', 'subpath':'./000130/sub-Haydn/'},
]

import h5py
DTYPE_FLOAT = np.float32 

def dict_to_h5(tensor_dict, output_h5_file):
    with h5py.File(output_h5_file, 'w') as f:
        # Iterate through the items in your 'batch' dictionary
        for key, value in tensor_dict.items():
            # Convert PyTorch tensor to NumPy array before saving
            # Ensure data type consistency for saving
            if torch.is_tensor(value):
                # For boolean tensors, convert to int8 if you want them to take less space in HDF5
                # (np.bool_ usually takes 1 byte, but some systems/HDF5 viewers prefer integer)
                if value.dtype == torch.bool:
                    data_to_save = value.to(torch.int8).numpy()
                else:
                    data_to_save = value.numpy()
            else:
                data_to_save = value # If any value isn't a tensor (e.g., a simple scalar or list)

            f.create_dataset(key, data=data_to_save)
            print(f"  - Saved '{key}' with shape {data_to_save.shape} and dtype {data_to_save.dtype}")

        print(f"Successfully saved all data from 'batch' to {output_h5_file}")

def pad_collate(batch, fixed_seq_len=None):
    # Assume batch is a list of dicts with keys: 'neural_input', 'behavior_input', etc.
    # Each 'neural_input' is a tensor of shape (timesteps, units)
    neural_inputs = [item['neural_input'] for item in batch if item is not None]  # (timesteps, units)
    behavioral_inputs = [item['behavior_input'] for item in batch if item is not None]
    
    # Determine the fixed sequence length
    if fixed_seq_len is None:
        max_len = max(x.shape[0] for x in neural_inputs)
    else:
        max_len = fixed_seq_len

    # Pad or truncate each sequence to fixed length
    def pad_or_truncate(tensor, max_len):
        seq_len = tensor.shape[0]
        if seq_len == max_len:
            return tensor
        elif seq_len > max_len:
            return tensor[:max_len]
        else:
            pad_shape = (max_len - seq_len,) + tensor.shape[1:]
            pad_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
            return torch.cat([tensor, pad_tensor], dim=0)

    padded_neural = torch.stack([pad_or_truncate(x, max_len) for x in neural_inputs if x is not None])  # (batch, max_len, units)
    padded_behavior = torch.stack([pad_or_truncate(x, max_len) for x in behavioral_inputs if x is not None])

    # Create mask: 1 for real data, 0 for padding
    lengths = [x.shape[0] for x in neural_inputs]
    mask = torch.zeros((len(batch), max_len), dtype=torch.bool)
    for i, l in enumerate(lengths):
        mask[i, :min(l, max_len)] = 1

    # Stack other fields (e.g., dataset_group_idx)
    dataset_group_idx = torch.stack([item['dataset_group_idx'] for item in batch])
    
    return {
        'neural_input': padded_neural,
        'behavior_input': padded_behavior,
        'mask': mask,
        'dataset_group_idx': dataset_group_idx,
        # add other fields as needed
    }
    
import numpy as np
import torch
from typing import Optional

def build_split_batch(nwb_dataset, splits, prepend_duration=299, dataset_group_idx=9, bin_width=5, dtype_neural=np.int8, dtype_behavior=np.float32, train_std: Optional[np.ndarray]=None, behavior_attribute="cursor_vel", start_field='start_time', end_field='end_time'):
    """Build padded batch for a given split.

    If split == 'train' and train_std is None, compute per-axis std across all valid behavior timepoints and normalize behavior by it.
    If split != 'train' and train_std is provided, normalize behavior by the provided train_std.

    Returns a dict with tensors; when train std is computed it is added under key 'behavior_std' (numpy array).
    """
    # Ensure data resampled already if required
    if bin_width is not None:
        nwb_dataset.resample(bin_width)

    # Trial rows for this split
    trials = nwb_dataset.trial_info
    sel = trials['split'].isin(splits if isinstance(splits, list) else [splits])
    trials_sel = trials[sel].reset_index(drop=True)
    if len(trials_sel) == 0:
        return {'neural_input': torch.empty(0), 'behavior_input': torch.empty(0), 'mask': torch.empty(0), 'dataset_group_idx': torch.empty(0)}

    # Convert series to numpy arrays once
    # adjust these attribute names to match what you actually use
    behavior_arr = getattr(nwb_dataset.data, behavior_attribute).to_numpy(dtype=dtype_behavior)    # shape (T, B)
    neural_arr   = nwb_dataset.data.spikes.to_numpy(dtype=dtype_neural)         # shape (T, N)
    time_index = nwb_dataset.data.index.values

    # compute start/end times for each trial (include prepend)
    starts = trials_sel[start_field] - pd.Timedelta(prepend_duration, 'ms')
    ends   = trials_sel[end_field]

    # Map timestamps to integer indices via searchsorted (faster than repeated .loc)
    start_idx = np.searchsorted(time_index, starts.values)
    end_idx = np.searchsorted(time_index, ends.values, side='right')

    # Clip indices to valid range
    T = behavior_arr.shape[0]
    start_idx = np.clip(start_idx, 0, T-1)
    end_idx = np.clip(end_idx, 1, T)  # end exclusive

    lengths = (end_idx - start_idx).astype(int)
    max_len = lengths.max()

    # dims
    beh_dim = behavior_arr.shape[1] if behavior_arr.ndim > 1 else 1
    neu_dim = neural_arr.shape[1] if neural_arr.ndim > 1 else 1
    n_trials = len(trials_sel)

    # Pre-allocate padded numpy arrays
    padded_behavior = np.zeros((n_trials, max_len, beh_dim), dtype=behavior_arr.dtype)
    padded_neural   = np.zeros((n_trials, max_len, neu_dim), dtype=neural_arr.dtype)
    mask = np.zeros((n_trials, max_len), dtype=bool)

    # Fill slices (short Python loop; memory copies are contiguous)
    for i, (s, e) in enumerate(zip(start_idx, end_idx)):
        L = int(e - s)
        if L <= 0:
            continue
        padded_behavior[i, :L, :] = behavior_arr[s:e]
        padded_neural[i, :L, :] = neural_arr[s:e]
        mask[i, :L] = True

    # Compute / apply behavior normalization
    behavior_std = None
    # Flatten only the valid (masked) timepoints across trials
    offset = prepend_duration // bin_width + 1
    mask_flat = mask[:, offset:].reshape(-1)
    if mask_flat.any():
        flat_beh = padded_behavior[:, offset:, ].reshape(-1, beh_dim)
        valid_beh = flat_beh[mask_flat]
        if len(splits) == 1 and splits[0] == 'train' and train_std is None:
            behavior_std = np.nanstd(valid_beh, axis=0, ddof=0)
            padded_behavior = padded_behavior / behavior_std[None, None, :]
        elif train_std is not None:
            # Use provided train std to normalize (expected shape (beh_dim,))
            ts = np.asarray(train_std)
            if ts.size != beh_dim:
                raise ValueError(f"train_std length {ts.size} does not match behavior dim {beh_dim}")
            padded_behavior = padded_behavior / ts[None, None, :]
        else:
            behavior_std = np.std(valid_beh, axis=0, ddof=0)
            padded_behavior = padded_behavior / behavior_std[None, None, :]
    else:
        # No valid samples found; leave data as-is
        behavior_std = None

    # Convert to torch tensors
    t_behavior = torch.from_numpy(padded_behavior)
    t_neural = torch.from_numpy(padded_neural)
    t_mask = torch.from_numpy(mask)
    t_dataset_idx = torch.full((n_trials,), dataset_group_idx, dtype=torch.int8)

    # Optionally include trial_idx
    t_trial_idx = torch.from_numpy(trials_sel['trial_id'].values.astype(np.int32))

    out = {
        'neural_input': t_neural,
        'behavior_input': t_behavior,
        'mask': t_mask,
        'dataset_group_idx': t_dataset_idx,
        'trial_idx': t_trial_idx
    }
    if behavior_std is not None:
        out['behavior_std'] = behavior_std

    return out

# RTT , Target-based trials

In [111]:

d = datasets[1]
dataset_name = d['name']
dataset_subpath = d['subpath']
train_trial_split = ['train', 'val']

raw_data_path = os.path.join(dataset_folder, 'raw', 'dandi', dataset_subpath) 
processed_data_folder = os.path.join(dataset_folder, 'processed', 'nlb')
processed_data_path = os.path.join(processed_data_folder, dataset_name + '.h5')
trial_info_path = os.path.join(processed_data_folder, dataset_name + '.csv')

if not os.path.exists(processed_data_folder):
    print(f"Creating directory: {processed_data_folder}")
    os.makedirs(processed_data_folder, exist_ok=True)

nwb_dataset = NWBDataset(raw_data_path, split_heldout=False)

# Find when target pos changes
has_change = nwb_dataset.data.target_pos.fillna(-1000).diff(axis=0).any(axis=1) # filling NaNs with arbitrary scalar to treat as one block
# Find if target pos change corresponds to NaN-padded gap between files
change_nan = nwb_dataset.data[has_change].isna().any(axis=1)
# Drop trials containing the gap and immediately before and after, as those trials may be cut short
drop_trial = (change_nan | change_nan.shift(1, fill_value=True) | change_nan.shift(-1, fill_value=True))[:-1]
# Add start and end times to trial info
change_times = nwb_dataset.data.index[has_change]
start_times = change_times[:-1][~drop_trial]
end_times = change_times[1:][~drop_trial]
# Get target position per trial
start_pos = nwb_dataset.data.target_pos.loc[start_times - pd.Timedelta(1, 'ms')].to_numpy().tolist()
target_pos = nwb_dataset.data.target_pos.loc[start_times].to_numpy().tolist()
# Compute reach distance and angle
reach_dist = nwb_dataset.data.target_pos.loc[end_times - pd.Timedelta(1, 'ms')].to_numpy() - nwb_dataset.data.target_pos.loc[start_times - pd.Timedelta(1, 'ms')].to_numpy()
reach_angle = np.arctan2(reach_dist[:, 1], reach_dist[:, 0]) / np.pi * 180
# Create trial info
nwb_dataset.trial_info = pd.DataFrame({
    'trial_id': np.arange(len(start_times)),
    'start_time': start_times,
    'end_time': end_times,
    'duration': (end_times - start_times).total_seconds(),
    'start_pos': start_pos,
    'target_pos': target_pos,
    'reach_dist_x': reach_dist[:, 0],
    'reach_dist_y': reach_dist[:, 1],
    'reach_angle': reach_angle,
})
nwb_dataset.resample(5)
nwb_dataset.trial_info['split'] = 'train'
nwb_dataset.trial_info.loc[int(len(nwb_dataset.trial_info)*0.7):, 'split'] = 'val' # 70% train
nwb_dataset.trial_info.to_csv('/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_trialized.csv')
bin_width = 5
nwb_dataset.resample(bin_width)
suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'
train_batch = []

# Build train split, compute train std, then build val split using the train_std
train_out = build_split_batch(nwb_dataset, 'train', prepend_duration=299, dataset_group_idx=9, bin_width=5, behavior_attribute="finger_vel")
train_std = train_out.get('behavior_std')
# Remove behavior_std from the batch dict before saving
train_batch = {k: v for k, v in train_out.items() if k != 'behavior_std'}

# Build val batch using train statistics for normalization
val_out = build_split_batch(nwb_dataset, 'val', prepend_duration=299, dataset_group_idx=9, bin_width=5, train_std=train_std, behavior_attribute="finger_vel")
val_batch = {k: v for k, v in val_out.items() if k != 'behavior_std'}

# Persist
dict_to_h5(train_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5')
dict_to_h5(val_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5')

# Optionally print the computed train std
print('Train behavior per-axis std:', train_std)
    




Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...


  - Saved 'neural_input' with shape (378, 768, 130) and dtype int8
  - Saved 'behavior_input' with shape (378, 768, 2) and dtype float32
  - Saved 'mask' with shape (378, 768) and dtype int8
  - Saved 'dataset_group_idx' with shape (378,) and dtype int8
  - Saved 'trial_idx' with shape (378,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5
  - Saved 'neural_input' with shape (163, 694, 130) and dtype int8
  - Saved 'behavior_input' with shape (163, 694, 2) and dtype float32
  - Saved 'mask' with shape (163, 694) and dtype int8
  - Saved 'dataset_group_idx' with shape (163,) and dtype int8
  - Saved 'trial_idx' with shape (163,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5
Train behavior per-axis std: [68.90626 67.18307]


# RTT, Original Trials Prepended


In [97]:
prepend_duration = 299  # ms to prepend to each trial


d = datasets[1]
dataset_name = d['name']
dataset_subpath = d['subpath']
train_trial_split = ['train', 'val']

raw_data_path = os.path.join(dataset_folder, 'raw', 'dandi', dataset_subpath) 

nwb_dataset = NWBDataset(raw_data_path, split_heldout=False)
bin_width = 5
nwb_dataset.resample(bin_width)
suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'
train_batch = []

# Build train split, compute train std, then build val split using the train_std
train_out = build_split_batch(nwb_dataset, 'train', prepend_duration=299, dataset_group_idx=9, bin_width=5, behavior_attribute="finger_vel")
train_std = train_out.get('behavior_std')
# Remove behavior_std from the batch dict before saving
train_batch = {k: v for k, v in train_out.items() if k != 'behavior_std'}

# Build val batch using train statistics for normalization
val_out = build_split_batch(nwb_dataset, 'val', prepend_duration=299, dataset_group_idx=9, bin_width=5, train_std=train_std, behavior_attribute="finger_vel")
val_batch = {k: v for k, v in val_out.items() if k != 'behavior_std'}

# Persist
dict_to_h5(train_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5')
dict_to_h5(val_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5')

# Optionally print the computed train std
print('Train behavior per-axis std:', train_std)

Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...


  - Saved 'neural_input' with shape (810, 180, 130) and dtype int8
  - Saved 'behavior_input' with shape (810, 180, 2) and dtype float32
  - Saved 'mask' with shape (810, 180) and dtype int8
  - Saved 'dataset_group_idx' with shape (810,) and dtype int8
  - Saved 'trial_idx' with shape (810,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5
  - Saved 'neural_input' with shape (270, 180, 130) and dtype int8
  - Saved 'behavior_input' with shape (270, 180, 2) and dtype float32
  - Saved 'mask' with shape (270, 180) and dtype int8
  - Saved 'dataset_group_idx' with shape (270,) and dtype int8
  - Saved 'trial_idx' with shape (270,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5
Train behavior per-axis std: [73.42055  67.324356]
Train behavior per-axis std: [73.42

# Area2 Bump, Prepended

In [14]:
prepend_duration = 299  # ms to prepend to each trial


d = datasets[2]
dataset_name = d['name']
dataset_subpath = d['subpath']
train_trial_split = ['train', 'val']

raw_data_path = os.path.join(dataset_folder, 'raw', 'dandi', dataset_subpath) 

# nwb_dataset = NWBDataset(raw_data_path, split_heldout=False)
# bin_width = 5
# nwb_dataset.resample(bin_width)
# suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'
# train_batch = []

# Build train split, compute train std, then build val split using the train_std
train_out = build_split_batch(nwb_dataset, 'train', prepend_duration=299, dataset_group_idx=9, bin_width=5, behavior_attribute="hand_vel", start_field='move_onset_time')
train_std = train_out.get('behavior_std')
# Remove behavior_std from the batch dict before saving
train_batch = {k: v for k, v in train_out.items() if k != 'behavior_std'}

# Build val batch using train statistics for normalization
val_out = build_split_batch(nwb_dataset, 'val', prepend_duration=299, dataset_group_idx=9, bin_width=5, train_std=train_std, behavior_attribute="hand_vel", start_field='move_onset_time')
val_batch = {k: v for k, v in val_out.items() if k != 'behavior_std'}

# Build val batch using train statistics for normalization
all_out = build_split_batch(nwb_dataset, ['train','val'], prepend_duration=299, dataset_group_idx=9, bin_width=5, train_std=train_std, behavior_attribute="hand_vel", start_field='move_onset_time')
all_batch = {k: v for k, v in all_out.items() if k != 'behavior_std'}

# Persist
dict_to_h5(train_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5')
dict_to_h5(val_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5')
dict_to_h5(all_batch, 
           '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_all.h5')


# Optionally print the computed train std
print('Train behavior per-axis std:', train_std)

Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...
Dataset already at 5 ms resolution, skipping resampling...


  - Saved 'neural_input' with shape (272, 872, 65) and dtype int8
  - Saved 'behavior_input' with shape (272, 872, 2) and dtype float32
  - Saved 'mask' with shape (272, 872) and dtype int8
  - Saved 'dataset_group_idx' with shape (272,) and dtype int8
  - Saved 'trial_idx' with shape (272,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_train.h5
  - Saved 'neural_input' with shape (92, 644, 65) and dtype int8
  - Saved 'behavior_input' with shape (92, 644, 2) and dtype float32
  - Saved 'mask' with shape (92, 644) and dtype int8
  - Saved 'dataset_group_idx' with shape (92,) and dtype int8
  - Saved 'trial_idx' with shape (92,) and dtype int32
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_area2bump_prepend_val.h5
  - Saved 'neural_input' with shape (364, 872, 65) and dtype int8
  - Saved 'behavior_input' 

In [None]:
# check data variance
beh = train_batch['behavior_input'][:, prepend_duration // bin_width + 1:, :]
mask = train_batch['mask'][:, prepend_duration // bin_width + 1:].reshape(-1)
beh = np.array(beh.reshape(-1, beh.shape[-1])[mask])
np.std(beh, axis=0)

In [12]:
nwb_dataset.trial_info

Unnamed: 0,trial_id,start_time,end_time,move_onset_time,split,result,ctr_hold,ctr_hold_bump,bump_dir,target_on_time,target_dir,go_cue_time,bump_time,cond_dir
0,0,0 days 00:00:00,0 days 00:00:00.600000,0 days 00:00:00.100000,test,,,,,NaT,,NaT,NaT,
1,1,0 days 00:00:00.700000,0 days 00:00:01.300000,0 days 00:00:00.800000,test,,,,,NaT,,NaT,NaT,
2,2,0 days 00:00:01.400000,0 days 00:00:02,0 days 00:00:01.500000,test,,,,,NaT,,NaT,NaT,
3,3,0 days 00:00:02.100000,0 days 00:00:02.700000,0 days 00:00:02.200000,test,,,,,NaT,,NaT,NaT,
4,4,0 days 00:00:02.800000,0 days 00:00:03.400000,0 days 00:00:02.900000,test,,,,,NaT,,NaT,NaT,
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
919,919,0 days 00:37:55.997000,0 days 00:37:58.982000,0 days 00:37:57.457000,train,R,0.951575,True,180.0,0 days 00:37:58.380000,135.0,0 days 00:37:58.381000,0 days 00:37:57.462000,180.0
920,920,0 days 00:37:59.485000,0 days 00:38:02.696000,0 days 00:38:02.314000,train,R,1.031749,False,,0 days 00:38:02.002000,45.0,0 days 00:38:02.003000,NaT,45.0
921,921,0 days 00:38:03.199000,0 days 00:38:05.391000,0 days 00:38:04.291000,train,R,0.585132,True,315.0,0 days 00:38:04.748000,45.0,0 days 00:38:04.749000,0 days 00:38:04.276000,315.0
922,922,0 days 00:38:05.893000,0 days 00:38:09.100000,0 days 00:38:07.371000,none,I,1.283537,True,180.0,0 days 00:38:08.095000,225.0,0 days 00:38:08.096000,0 days 00:38:07.371000,180.0


In [10]:
dataset_dict = make_train_input_tensors(nwb_dataset, dataset_name=dataset_name, trial_split=['train'], save_file=False, save_path='processed_data_path', include_forward_pred=True, include_behavior=True)


KeyError: 'heldout_spikes'