In [5]:
import os
import sys
import warnings
import logging
import torch

# Suppress warnings and logging
# NOTE: both calls act on the whole kernel, not just this notebook's code:
# every Python warning is ignored, and log records at WARNING level and below
# are dropped. Comment these out when debugging data-loading problems.
warnings.filterwarnings('ignore')
logging.disable(logging.WARNING)

# Core imports
from foundational_ssm.utils import get_train_val_loaders, get_dataset_config
from foundational_ssm.data_preprocessing import bin_spikes, map_binned_features_to_global, smooth_spikes
# from foundational_ssm.models import SSMFoundational
from foundational_ssm.loss import CombinedLoss
from foundational_ssm.metrics import ValidationMetrics

from omegaconf import OmegaConf

from temporaldata import Data
from typing import List, Dict


In [6]:
# Location of the experiment config. The default is the original cluster path;
# set the FOUNDATIONAL_SSM_CONFIG environment variable to run this notebook
# from another checkout without editing the cell.
config_path = os.environ.get(
    "FOUNDATIONAL_SSM_CONFIG",
    "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/cmt.yaml",
)
config = OmegaConf.load(config_path)

# Load dataset
# The dataset name, subject list and batch size all come from the `dataset`
# section of the config loaded above.
train_dataset, train_loader, val_dataset, val_loader = get_train_val_loaders(
    train_config=get_dataset_config(
        config.dataset.name,
        subjects=config.dataset.subjects
    ),
    batch_size=config.dataset.batch_size
)

In [10]:
import re
import numpy as np

# Per-group target dimensions used by `transform`.
#   key:   (dataset, subject, task) as parsed from a session id of the form
#          "<dataset>/<subject>_<date>_<task>"
#   value: (neural_dim, output_dim) — the number of columns the neural input
#          and the behaviour input are cropped / zero-padded to.
# A session whose parsed key is missing from this table makes `transform`
# raise ValueError, so new subjects/tasks must be registered here first.
GROUP_DIMS = {
    ("perich_miller_population_2018", "c", "center_out_reaching"): (353, 2),
    ("perich_miller_population_2018", "c", "random_target_reaching"): (88, 2),
    ("perich_miller_population_2018", "m", "center_out_reaching"): (159, 2),
    ("perich_miller_population_2018", "m", "random_target_reaching"): (165, 2),
    ("perich_miller_population_2018", "t", "center_out_reaching"): (65, 2),
    ("perich_miller_population_2018", "t", "random_target_reaching"): (73, 2),
    ("perich_miller_population_2018", "j", "center_out_reaching"): (38, 2),
}

def _parse_group_key(session_id):
    """Split a session id into its ``(dataset, subject, task)`` group key."""
    # Example: "perich_miller_population_2018/c_20131003_center_out_reaching"
    # -> ("perich_miller_population_2018", "c", "center_out_reaching")
    match = re.match(r'([^/]+)/([^_]+)_[^_]+_(.+)', session_id)
    if match is None:
        raise ValueError(f"Could not parse session_id: {session_id}")
    return match.groups()


def _fit_axis(array, target_len, axis):
    """Crop, or zero-pad at the end, so ``array`` has ``target_len`` entries along ``axis``."""
    current_len = array.shape[axis]
    if current_len > target_len:
        index = [slice(None)] * array.ndim
        index[axis] = slice(target_len)
        return array[tuple(index)]
    if current_len < target_len:
        pad_spec = [(0, 0)] * array.ndim
        pad_spec[axis] = (0, target_len - current_len)
        return np.pad(array, pad_spec, mode='constant')
    return array


def transform(data: "Data", sampling_rate=100, num_timesteps=100, kern_sd_ms=40, bin_width=5) -> Dict:
    """Turn one sampled window into fixed-size model inputs for its session group.

    Spikes are binned and smoothed, the cursor velocity is fitted to
    ``num_timesteps`` rows, and both are cropped / zero-padded along the
    feature axis to the ``(neural_dim, output_dim)`` registered in
    ``GROUP_DIMS`` for the session's ``(dataset, subject, task)`` group.

    Args:
        data: Sample exposing ``units.id``, ``spikes``, ``cursor.vel`` and
            ``session.id``.
        sampling_rate: Bins per unit of time; ``1 / sampling_rate`` is passed
            to ``bin_spikes`` as ``bin_size``.
        num_timesteps: Number of bins requested from ``bin_spikes`` and the
            number of rows the behaviour signal is fitted to.
        kern_sd_ms: Forwarded unchanged to ``smooth_spikes``.
        bin_width: Forwarded unchanged to ``smooth_spikes``.

    Returns:
        Dict with ``"neural_input"`` and ``"behavior_input"`` as float32
        tensors and ``"dataset_group_key"`` as the ``(dataset, subject, task)``
        tuple.

    Raises:
        ValueError: If the session id cannot be parsed or its group is not
            registered in ``GROUP_DIMS``.
    """
    # NOTE(review): with the defaults, bin_size is 1/100 while bin_width is 5.
    # If smooth_spikes expects bin_width in milliseconds these disagree
    # (10 ms vs 5 ms) — confirm against smooth_spikes before relying on the
    # kernel width.
    unit_ids = data.units.id
    binned_spikes = bin_spikes(
        spikes=data.spikes,
        num_units=len(unit_ids),
        bin_size=1 / sampling_rate,
        num_bins=num_timesteps,
    ).T  # (timesteps, units)

    smoothed_spikes = smooth_spikes(binned_spikes, kern_sd_ms=kern_sd_ms, bin_width=bin_width)

    # Fit the behaviour signal to the window length: (timesteps, features).
    behavior_input = _fit_axis(data.cursor.vel, num_timesteps, axis=0)

    # --- Resolve the session group and its target dims ---
    group_key = _parse_group_key(data.session.id)
    if group_key not in GROUP_DIMS:
        raise ValueError(f"Group {group_key} not found in GROUP_DIMS")
    neural_dim, output_dim = GROUP_DIMS[group_key]

    # --- Pad/crop feature axes to the group's dims ---
    # The behaviour width is driven solely by output_dim; there is no separate
    # hard-coded crop to 2 columns, which would have zeroed real columns for
    # any group registered with output_dim > 2.
    smoothed_spikes = _fit_axis(smoothed_spikes, neural_dim, axis=1)
    behavior_input = _fit_axis(behavior_input, output_dim, axis=1)

    return {
        "neural_input": torch.tensor(smoothed_spikes, dtype=torch.float32),
        "behavior_input": torch.tensor(behavior_input, dtype=torch.float32),
        "dataset_group_key": group_key,
    }

In [12]:
# Example usage of GroupedRandomFixedWindowSampler
from foundational_ssm.utils.grouped_sampler import GroupedRandomFixedWindowSampler

# Number of batches printed by the smoke test at the bottom of this cell.
# Iterating the whole training loader here is what previously forced a manual
# KeyboardInterrupt.
MAX_PREVIEW_BATCHES = 10

sampling_intervals = train_dataset.get_sampling_intervals()

# Create the grouped sampler
grouped_sampler = GroupedRandomFixedWindowSampler(
    sampling_intervals=sampling_intervals,
    window_length=1.0,
    batch_size=128,
    generator=torch.Generator().manual_seed(42)  # fixed seed for reproducible batches
)

# Attach the per-sample transform before building the loader so it is obvious
# which transform the batches below were produced with.
train_dataset.transform = transform

# Create dataloader with the grouped sampler
grouped_dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=grouped_sampler,
    num_workers=4,
    pin_memory=True
)

# Print the groups that were created
print("Session groups:")
for group_key, sessions in grouped_sampler.session_groups.items():
    print(f"  {group_key}: {len(sessions)} sessions")
    for session in sessions[:3]:  # Show first 3 sessions
        print(f"    - {session}")
    if len(sessions) > 3:
        print(f"    ... and {len(sessions) - 3} more")
    print()

# Test the dataloader on a bounded number of batches
print("Testing grouped dataloader:")
for i, batch in enumerate(grouped_dataloader):
    if i >= MAX_PREVIEW_BATCHES:
        break
    print(batch["dataset_group_key"], batch["neural_input"].shape)

Session groups:
  perich_miller_population_2018/c_center_out_reaching: 53 sessions
    - perich_miller_population_2018/c_20131003_center_out_reaching
    - perich_miller_population_2018/c_20131022_center_out_reaching
    - perich_miller_population_2018/c_20131023_center_out_reaching
    ... and 50 more

  perich_miller_population_2018/c_random_target_reaching: 15 sessions
    - perich_miller_population_2018/c_20131009_random_target_reaching
    - perich_miller_population_2018/c_20131010_random_target_reaching
    - perich_miller_population_2018/c_20131011_random_target_reaching
    ... and 12 more

  perich_miller_population_2018/m_random_target_reaching: 6 sessions
    - perich_miller_population_2018/m_20140114_random_target_reaching
    - perich_miller_population_2018/m_20140115_random_target_reaching
    - perich_miller_population_2018/m_20140116_random_target_reaching
    ... and 3 more

  perich_miller_population_2018/m_center_out_reaching: 22 sessions
    - perich_miller_populati

KeyboardInterrupt: 

In [4]:
# Relative path to one raw .nwb recording file. `raw` is not read by any other
# visible cell, and the output recorded below this cell does not come from
# this assignment.
raw = "../data/foundational_ssm/raw/perich_miller_population_2018/000688/sub-C/sub-C_ses-CO-20160921_behavior+ecephys.nwb"


perich_miller_population_2018/c_20131003_center_out_reaching 38
perich_miller_population_2018/c_20131009_random_target_reaching 30
perich_miller_population_2018/c_20131010_random_target_reaching 36
perich_miller_population_2018/c_20131011_random_target_reaching 32
perich_miller_population_2018/c_20131022_center_out_reaching 33
perich_miller_population_2018/c_20131023_center_out_reaching 40
perich_miller_population_2018/c_20131028_random_target_reaching 32
perich_miller_population_2018/c_20131029_random_target_reaching 34
perich_miller_population_2018/c_20131031_center_out_reaching 50
perich_miller_population_2018/c_20131101_center_out_reaching 57
perich_miller_population_2018/c_20131203_center_out_reaching 39
perich_miller_population_2018/c_20131204_center_out_reaching 34
perich_miller_population_2018/c_20131209_random_target_reaching 39
perich_miller_population_2018/c_20131210_random_target_reaching 39
perich_miller_population_2018/c_20131212_random_target_reaching 38
perich_miller_po

In [20]:
from torch_brain.data.dataset import DatasetIndex
from collections import defaultdict

def get_session_indices(sampling_intervals, window_length, step=None):
    """Enumerate fixed-length windows inside every session's sampling intervals.

    Args:
        sampling_intervals: Mapping of session name to an object exposing
            parallel ``start`` and ``end`` sequences.
        window_length: Length of each window.
        step: Offset between consecutive window starts; falls back to
            ``window_length`` when omitted (or falsy).

    Returns:
        ``defaultdict(list)`` mapping session name to its ``DatasetIndex``
        windows. A window that would extend past its interval's end is
        dropped, and a session with no complete window gets no entry.
    """
    stride = step if step else window_length
    windows_by_session = defaultdict(list)

    for session_name, intervals in sampling_intervals.items():
        for interval_start, interval_end in zip(intervals.start, intervals.end):
            # Candidate start times, generated in float64.
            window_starts = torch.arange(
                interval_start, interval_end, stride, dtype=torch.float64
            ).tolist()
            for window_start in window_starts:
                window_end = window_start + window_length
                if window_end <= interval_end:
                    windows_by_session[session_name].append(
                        DatasetIndex(session_name, window_start, window_end)
                    )

    return windows_by_session

# Enumerate windows (window_length=1, step=1) for every session, then batch
# them 16 at a time through SessionBatchSampler.
# `sampling_intervals` comes from the GroupedRandomFixedWindowSampler example
# cell above, so that cell must have been run first.
# NOTE(review): SessionBatchSampler is neither imported nor defined in any
# visible cell — on a fresh kernel this line raises NameError. Add the import
# (or the definition) before this cell.
indices_per_session = get_session_indices(sampling_intervals, 1, 1)
batch_sampler = SessionBatchSampler(indices_per_session, batch_size=16)
dataloader = torch.utils.data.DataLoader(train_dataset, batch_sampler=batch_sampler, num_workers=4, pin_memory=True)

In [21]:
from foundational_ssm.data_preprocessing import bin_spikes, map_binned_features_to_global

# Module-level settings read by `transform_batch` below.
# sampling_rate: bins per unit of time; 1 / sampling_rate is passed to
# bin_spikes as bin_size.
sampling_rate = 100 
# num_timesteps: number of bins requested from bin_spikes (one unit of time
# at sampling_rate).
num_timesteps = int(1.0 * sampling_rate)
# num_neural_features: passed to map_binned_features_to_global as
# max_global_units.
num_neural_features = 64

def transform_batch(data):
    """Convert one sample into neural / behaviour tensors plus identifiers.

    Reads the module-level ``sampling_rate``, ``num_timesteps`` and
    ``num_neural_features`` settings.

    Args:
        data: Sample exposing ``units.id``, ``spikes``, ``cursor.vel``,
            ``session.id`` and ``subject.id``.

    Returns:
        Dict with float32 tensors ``"neural_input"`` and ``"behavior_input"``
        and the raw ``"session_id"`` / ``"subject_id"`` values.
    """
    session_unit_ids = data.units.id

    # Bin the spikes, then transpose the result before mapping.
    spike_counts = bin_spikes(
        spikes=data.spikes,
        num_units=len(session_unit_ids),
        bin_size=1 / sampling_rate,
        num_bins=num_timesteps,
    )

    # Map this session's units onto the shared global unit layout.
    global_features = map_binned_features_to_global(
        session_binned_features=spike_counts.T,
        session_unit_id_strings=session_unit_ids,
        max_global_units=num_neural_features,
    )  # (N_timesteps, N_global_units)

    # NOTE(review): unlike `transform`, the behaviour signal is not cropped or
    # padded to num_timesteps here — confirm every window yields the same
    # number of samples, otherwise batching may fail.
    return {
        "neural_input": torch.tensor(global_features, dtype=torch.float32),
        "behavior_input": torch.tensor(data.cursor.vel, dtype=torch.float32),  # (N_timesteps, N_behavior_features)
        "session_id": data.session.id,
        "subject_id": data.subject.id,
    }


In [22]:
# Smoke test: pull a single batch from the session-batched loader and print it.
# NOTE(review): the output recorded for this cell is
# `ValueError: too many dimensions 'str'`. That message is what torch.tensor
# raises when handed string data, so presumably a string field reaches a
# tensor conversion somewhere in the dataset/transform/collate path — confirm
# which transform is attached to train_dataset when this cell runs.
for batch in dataloader:
    print(batch)
    break

ValueError: too many dimensions 'str'