In [13]:
from nlb_tools.nwb_interface import NWBDataset
from nlb_tools.make_tensors import make_train_input_tensors, make_eval_input_tensors, make_eval_target_tensors, save_to_h5
from nlb_tools.evaluation import evaluate
import os
import pandas as pd
import numpy as np
from foundational_ssm.collate import pad_collate


# Root directory; the raw and processed data folders are both resolved beneath it.
dataset_folder = '/cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/'

# One entry per dataset: a short name plus the sub-directory (joined under
# <dataset_folder>/raw/dandi/ later in the notebook) that holds its files.
datasets = [
    dict(name='mc_maze', subpath='./000128/sub-Jenkins/'),
    dict(name='mc_rtt', subpath='./000129/sub-Indy/'),
    dict(name='area2_bump', subpath='./000127/sub-Han/'),
    dict(name='dmfc_rsg', subpath='./000130/sub-Haydn/'),
]

In [2]:
# Pick the dataset to process; index 1 is the 'mc_rtt' entry of `datasets`.
d = datasets[1]
dataset_name, dataset_subpath = d['name'], d['subpath']
train_trial_split = ['train', 'val']

# Input location of the raw files and output locations of the processed ones,
# all resolved relative to dataset_folder.
raw_data_path = os.path.join(dataset_folder, 'raw', 'dandi', dataset_subpath)
processed_data_folder = os.path.join(dataset_folder, 'processed', 'nlb')
processed_data_path = os.path.join(processed_data_folder, dataset_name + '.h5')
trial_info_path = os.path.join(processed_data_folder, dataset_name + '.csv')

# Make sure the output directory exists before anything is written to it.
if not os.path.exists(processed_data_folder):
    print(f"Creating directory: {processed_data_folder}")
    os.makedirs(processed_data_folder, exist_ok=True)

# Load the raw recording without splitting held-out units.
nwb_dataset = NWBDataset(raw_data_path, split_heldout=False)


# Build a per-trial table from the continuous recording. A "trial" here is the
# interval between two consecutive rows at which target_pos changes.

# Flag every row where any target_pos column differs from the previous row.
has_change = nwb_dataset.data.target_pos.fillna(-1000).diff(axis=0).any(axis=1) # NaNs are filled with an arbitrary scalar so a run of NaNs is treated as one block
# For each change row, flag whether any column of that row is NaN
# (presumably the NaN-padded gap between recording files — TODO confirm).
change_nan = nwb_dataset.data[has_change].isna().any(axis=1)
# Drop intervals that start on a NaN change row or whose previous/next change row is NaN,
# as those trials may be cut short. fill_value=True means the first interval (no predecessor)
# is always dropped; [:-1] trims the mask to the length of change_times[:-1].
drop_trial = (change_nan | change_nan.shift(1, fill_value=True) | change_nan.shift(-1, fill_value=True))[:-1]
# Consecutive change times give each interval's start and end; keep only the intervals not dropped.
change_times = nwb_dataset.data.index[has_change]
start_times = change_times[:-1][~drop_trial]
end_times = change_times[1:][~drop_trial]
# Target position 1 ms before the trial starts (start_pos) and at the trial start (target_pos).
# NOTE(review): assumes the un-resampled index has a row exactly 1 ms before each change — confirm.
start_pos = nwb_dataset.data.target_pos.loc[start_times - pd.Timedelta(1, 'ms')].to_numpy().tolist()
target_pos = nwb_dataset.data.target_pos.loc[start_times].to_numpy().tolist()
# Reach vector: target_pos 1 ms before the trial end minus target_pos 1 ms before the trial start.
reach_dist = nwb_dataset.data.target_pos.loc[end_times - pd.Timedelta(1, 'ms')].to_numpy() - nwb_dataset.data.target_pos.loc[start_times - pd.Timedelta(1, 'ms')].to_numpy()
# Reach angle in degrees, from arctan2(column 1, column 0) of the reach vector.
reach_angle = np.arctan2(reach_dist[:, 1], reach_dist[:, 0]) / np.pi * 180
# Assemble the trial table and attach it to the dataset object.
nwb_dataset.trial_info = pd.DataFrame({
    'trial_id': np.arange(len(start_times)),
    'start_time': start_times,
    'end_time': end_times,
    'duration': (end_times - start_times).total_seconds(),
    'start_pos': start_pos,
    'target_pos': target_pos,
    'reach_dist_x': reach_dist[:, 0],
    'reach_dist_y': reach_dist[:, 1],
    'reach_angle': reach_angle,
})
# Resample only after the trial table is built, since the lookups above use 1 ms offsets.
# NOTE(review): the argument is presumably the bin width in ms — confirm against nlb_tools.
nwb_dataset.resample(5)

In [3]:
# Save the trial table alongside the other processed files. The path is built from
# processed_data_folder and dataset_name (set in the previous cell) rather than
# repeating the absolute path; for 'mc_rtt' it resolves to the same file as before.
nwb_dataset.trial_info.to_csv(os.path.join(processed_data_folder, f'{dataset_name}_trialized.csv'))

In [4]:
import torch
import sys 

DTYPE_FLOAT = np.float32 
def pad_collate(batch, fixed_seq_len=None):
    """Collate variable-length samples into zero-padded tensors plus a validity mask.

    NOTE: this definition shadows the `pad_collate` imported from
    `foundational_ssm.collate` at the top of the notebook.

    Args:
        batch: list of sample dicts (entries may be None and are skipped). Each
            dict must provide 'neural_input', 'behavior_input' and
            'dataset_group_idx' tensors; the first axis of the two inputs is
            treated as time.
        fixed_seq_len: if given, every sequence is truncated or zero-padded to
            this length; otherwise the longest 'neural_input' in the batch
            sets the length.

    Returns:
        dict with
          'neural_input':      (n_samples, max_len, ...) stacked, padded tensor
          'behavior_input':    (n_samples, max_len, ...) stacked, padded tensor
          'mask':              (n_samples, max_len) bool, True on real timesteps
          'dataset_group_idx': stacked per-sample tensors
        where n_samples counts only the non-None entries of `batch`.

    Raises:
        ValueError: if `batch` contains no non-None samples.
    """
    # Filter None entries exactly once so that every output (including the mask
    # and dataset_group_idx) is built from the same set of samples.
    samples = [item for item in batch if item is not None]
    if not samples:
        raise ValueError("pad_collate received a batch with no valid samples")

    neural_inputs = [item['neural_input'] for item in samples]  # (timesteps, units)
    behavioral_inputs = [item['behavior_input'] for item in samples]

    # Determine the common sequence length
    if fixed_seq_len is None:
        max_len = max(x.shape[0] for x in neural_inputs)
    else:
        max_len = fixed_seq_len

    def pad_or_truncate(tensor, max_len):
        """Cut `tensor` to `max_len` rows or append zero rows up to `max_len`."""
        seq_len = tensor.shape[0]
        if seq_len == max_len:
            return tensor
        if seq_len > max_len:
            return tensor[:max_len]
        pad_shape = (max_len - seq_len,) + tensor.shape[1:]
        pad_tensor = torch.zeros(pad_shape, dtype=tensor.dtype, device=tensor.device)
        return torch.cat([tensor, pad_tensor], dim=0)

    padded_neural = torch.stack([pad_or_truncate(x, max_len) for x in neural_inputs])  # (n_samples, max_len, units)
    padded_behavior = torch.stack([pad_or_truncate(x, max_len) for x in behavioral_inputs])

    # Mask: True for real data, False for padding. Lengths come from
    # 'neural_input'; 'behavior_input' is assumed to have the same length.
    mask = torch.zeros((len(samples), max_len), dtype=torch.bool)
    for row, seq in enumerate(neural_inputs):
        mask[row, :min(seq.shape[0], max_len)] = True

    # Stack other fields (e.g., dataset_group_idx)
    dataset_group_idx = torch.stack([item['dataset_group_idx'] for item in samples])

    return {
        'neural_input': padded_neural,
        'behavior_input': padded_behavior,
        'mask': mask,
        'dataset_group_idx': dataset_group_idx,
        # add other fields as needed
    }
    
# Build one sample dict per trial, then collate all trials into one padded batch.
batch = []
max_len = 0
for start, end in zip(nwb_dataset.trial_info['start_time'], nwb_dataset.trial_info['end_time']):
    # Slice the continuous data down to this trial's time window.
    trial_window = slice(start, end)
    behavior_data = nwb_dataset.data.finger_vel[trial_window].to_numpy()
    neural_data = nwb_dataset.data.spikes[trial_window].to_numpy(dtype=np.int8)
    cursor_pos = nwb_dataset.data.cursor_pos[trial_window].to_numpy()

    # Wrap the numpy arrays as torch tensors (dtypes are carried over from numpy).
    sample = {
        'behavior_input': torch.from_numpy(behavior_data),
        'neural_input': torch.from_numpy(neural_data),
        'cursor_pos': torch.from_numpy(cursor_pos),
        'dataset_group_idx': torch.tensor(9, dtype=torch.int8) # must be a torch.Tensor so it can be stacked
    }
    batch.append(sample)

    # Track the longest trial (in samples) seen so far.
    max_len = max(max_len, behavior_data.shape[0])

# NOTE: the pad_collate defined in this notebook returns only 'neural_input',
# 'behavior_input', 'mask' and 'dataset_group_idx'; 'cursor_pos' is not carried over.
batch = pad_collate(batch)

In [None]:
import h5py
import torch 
def dict_to_h5(tensor_dict, output_h5_file):
    """Write every entry of a dict to its own dataset in a new HDF5 file.

    Args:
        tensor_dict: mapping of dataset name -> value. torch tensors are
            converted to numpy arrays first (boolean tensors are stored as
            int8); any other value (numpy array, scalar, list, ...) is handed
            to h5py unchanged.
        output_h5_file: path of the file to write. It is opened in 'w' mode,
            so an existing file at that path is replaced.
    """
    with h5py.File(output_h5_file, 'w') as f:
        for key, value in tensor_dict.items():
            if torch.is_tensor(value):
                # detach()/cpu() make the conversion work for tensors that
                # require grad or live on an accelerator; both are no-ops for
                # plain CPU tensors.
                tensor = value.detach().cpu()
                # Boolean tensors are stored as int8
                # (np.bool_ usually takes 1 byte, but some systems/HDF5 viewers prefer integer)
                if tensor.dtype == torch.bool:
                    tensor = tensor.to(torch.int8)
                data_to_save = tensor.numpy()
            else:
                data_to_save = value  # non-tensor values (arrays, scalars, lists) are passed through

            dataset = f.create_dataset(key, data=data_to_save)
            # Report shape/dtype from the stored dataset so that values without
            # .shape/.dtype attributes (scalars, lists) do not crash the message.
            print(f"  - Saved '{key}' with shape {dataset.shape} and dtype {dataset.dtype}")

        print(f"Successfully saved all data from 'batch' to {output_h5_file}")

# Write the collated trialized batch next to the other processed files. The path is
# built from processed_data_folder and dataset_name (set in an earlier cell) rather
# than repeating the absolute path; for 'mc_rtt' it resolves to the same file as before.
output_h5_file = os.path.join(processed_data_folder, f'{dataset_name}_trialized.h5')
dict_to_h5(batch, output_h5_file)


# Original Trials

In [None]:
def _add_model_fields(dataset_dict):
    """Add 'neural_input', 'behavior_input' and 'mask' entries to a tensor dict, in place.

    'neural_input' is the held-in and held-out spikes concatenated along axis 2,
    'behavior_input' aliases 'train_behavior', and 'mask' is an all-ones
    (float64) array over the first two axes of 'neural_input'. The dict is
    also returned for convenience.
    """
    dataset_dict['neural_input'] = np.concatenate(
        [dataset_dict['train_spikes_heldin'], dataset_dict['train_spikes_heldout']], axis=2)
    dataset_dict['behavior_input'] = dataset_dict['train_behavior']
    dataset_dict['mask'] = np.ones(dataset_dict['neural_input'].shape[:2])
    return dataset_dict


d = datasets[1]
dataset_name = d['name']
dataset_subpath = d['subpath']
train_trial_split = ['train', 'val']

raw_data_path = os.path.join(dataset_folder, 'raw', 'dandi', dataset_subpath) 

# Reload the raw recording (default NWBDataset options this time) and resample it.
nwb_dataset = NWBDataset(raw_data_path)
bin_width = 5
nwb_dataset.resample(bin_width)
suffix = '' if (bin_width == 5) else f'_{int(round(bin_width))}'

# NOTE(review): both calls pass save_file=True with the same save_path, so the second
# call presumably overwrites (or collides with) the file written by the first —
# confirm against make_train_input_tensors. The dicts are saved separately below.
train_dataset_dict = make_train_input_tensors(nwb_dataset, dataset_name=dataset_name, trial_split='train', save_file=True, save_path=processed_data_path, include_forward_pred=True, include_behavior=True)
val_dataset_dict = make_train_input_tensors(nwb_dataset, dataset_name=dataset_name, trial_split='val', save_file=True, save_path=processed_data_path, include_forward_pred=True, include_behavior=True)

# Both dicts use the same 'train_*' key names, so one helper covers train and val.
_add_model_fields(train_dataset_dict)
_add_model_fields(val_dataset_dict)

# Output paths are derived from processed_data_folder / dataset_name (set in an earlier
# cell); for 'mc_rtt' they resolve to the same files as the previous absolute paths.
dict_to_h5(train_dataset_dict,
           os.path.join(processed_data_folder, f'{dataset_name}_not_trialized_train.h5'))
dict_to_h5(val_dataset_dict,
           os.path.join(processed_data_folder, f'{dataset_name}_not_trialized_val.h5'))

  - Saved 'train_spikes_heldin' with shape (810, 120, 98) and dtype float16
  - Saved 'train_spikes_heldout' with shape (810, 120, 32) and dtype float16
  - Saved 'train_behavior' with shape (810, 120, 2) and dtype float64
  - Saved 'train_spikes_heldin_forward' with shape (810, 40, 98) and dtype float16
  - Saved 'train_spikes_heldout_forward' with shape (810, 40, 32) and dtype float16
  - Saved 'neural_input' with shape (810, 120, 130) and dtype float16
  - Saved 'behavior_input' with shape (810, 120, 2) and dtype float64
  - Saved 'mask' with shape (810, 120) and dtype float64
Successfully saved all data from 'batch' to /cs/student/projects1/ml/2024/mlaimon/data/foundational_ssm/processed/nlb/mc_rtt_not_trialized_train.h5
  - Saved 'train_spikes_heldin' with shape (270, 120, 98) and dtype float16
  - Saved 'train_spikes_heldout' with shape (270, 120, 32) and dtype float16
  - Saved 'train_behavior' with shape (270, 120, 2) and dtype float64
  - Saved 'train_spikes_heldin_forward' wi