In [1]:
import os
import sys
import warnings
import logging
import torch

# Silence every Python warning, and drop log records at WARNING severity and below
# (ERROR and CRITICAL records still get through).
warnings.simplefilter('ignore')
logging.disable(level=logging.WARNING)

# Core imports
from foundational_ssm.utils import get_dataset_config
from foundational_ssm.data_utils import get_brainset_train_val_loaders
from foundational_ssm.models.foundational import SSMFoundational

from omegaconf import OmegaConf

from temporaldata import Data
from typing import List, Dict
import time

# Automatically reload edited modules (e.g. foundational_ssm) before each cell runs.
%load_ext autoreload
%autoreload 2


In [2]:
# Path to the pretraining config. Override it with the FOUNDATIONAL_SSM_CONFIG
# environment variable to run on another machine; the default is the original
# hard-coded location, so behaviour is unchanged when the variable is unset.
config_path = os.environ.get(
    "FOUNDATIONAL_SSM_CONFIG",
    "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/pretrain.yaml",
)
cfg = OmegaConf.load(config_path)

# Load dataset: train/val datasets and loaders built from the `train_dataset` and
# `val_dataset` sections of the config.
# NOTE(review): a single batch_size (cfg.train_dataset.batch_size) is passed for both
# splits — confirm the val split is meant to share it.
# num_workers=0 presumably means samples are loaded/transformed in this process.
train_dataset, train_loader, val_dataset, val_loader = get_brainset_train_val_loaders(
    train_config=get_dataset_config(
        cfg.train_dataset.name,
        subjects=cfg.train_dataset.subjects,
    ),
    val_config=get_dataset_config(
        cfg.val_dataset.name,
        subjects=cfg.val_dataset.subjects,
    ),
    batch_size=cfg.train_dataset.batch_size,
    num_workers=0,
)

In [12]:
import jax 
import jax.numpy as jnp
import jax.random as jr
import equinox as eqx
from jax.tree_util import tree_map, tree_flatten_with_path
import optax
from tqdm import tqdm


# Seed JAX's PRNG from the config and split it into a model-init key and a training key.
key = jr.PRNGKey(cfg.rng_seed)
model_key, train_key = jr.split(key, num=2)

# Build the SSM model from the `model` section of the config.
model = SSMFoundational(
    ssm_io_dim=cfg.model.ssm_io_dim,
    ssm_dim=cfg.model.ssm_dim,
    ssm_init_diag_blocks=cfg.model.ssm_init_diag_blocks,
    ssm_num_layers=cfg.model.ssm_num_layers,
    output_dim=cfg.model.output_dim,
    key=model_key,
)
# Equinox state object initialised from the model.
state = eqx.nn.State(model)

In [None]:
import numpy as np
import torch
from typing import Any, Dict
from foundational_ssm.data_utils.dataset import TorchBrainDataset
from foundational_ssm.constants import DATA_ROOT,MAX_NEURAL_INPUT_DIM, MAX_BEHAVIOR_INPUT_DIM
import os
from foundational_ssm.data_utils.loaders import bin_spikes, smooth_spikes, parse_session_id, DATASET_GROUP_TO_IDX

In [12]:
import numpy as np
import torch
from typing import Any, Dict
from foundational_ssm.data_utils.dataset import TorchBrainDataset
from foundational_ssm.constants import DATA_ROOT,MAX_NEURAL_INPUT_DIM, MAX_BEHAVIOR_INPUT_DIM
import os
from foundational_ssm.data_utils.loaders import bin_spikes, smooth_spikes, parse_session_id, DATASET_GROUP_TO_IDX

def transform_brainsets_to_fixed_dim_samples(
    data: Any,
    sampling_rate: int = 100,
    sampling_window_ms: int = 1000
) -> Dict[str, torch.Tensor | str]:
    """Convert a *temporaldata* sample to a dictionary of Torch tensors.

    The function takes care of binning & smoothing spikes, cropping/padding neural
    and behavioural features to a globally consistent dimensionality that depends
    on the *(dataset, subject, task)* triple.

    Parameters
    ----------
    data: temporaldata.Data
        Sample returned by **torch-brain**/**temporaldata**.
    sampling_rate: int, default=100
        Target sampling rate *Hz* used for binning.
    sampling_window_ms: int, default=1000   
        Length of the temporal window after binning.
    kern_sd_ms: int, default=20
        Standard deviation of the Gaussian kernel (in ms) for smoothing spikes.

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary with keys ``neural_input``, ``behavior_input``, ``session_id``
        and ``subject_id``.
    """
    def _ensure_dim(arr: np.ndarray, target_dim: int, pad_value: float = 0.0, *, axis: int = 1) -> np.ndarray:
        """
        Crop or pad `arr` along `axis` to match `target_dim`, right-aligning the original data.
        Pads with `pad_value` if needed.
        """
        current_dim = arr.shape[axis]
        
        # Pad if too small
        shape = list(arr.shape)
        shape[axis] = target_dim
        result = np.full(shape, pad_value, dtype=arr.dtype)
        
        # Right-align: place arr at the end along the axis
        idx = [slice(None)] * arr.ndim
        idx[axis] = slice(-current_dim, None)
        result[tuple(idx)] = arr
        return result
    
    num_timesteps = int(sampling_rate * sampling_window_ms / 1000)
    
    # ------------------------------------------------------------------
    # 1. Bin + smooth spikes
    # ------------------------------------------------------------------
    smoothed_spikes = data.smoothed_spikes.smoothed_spikes

    # ------------------------------------------------------------------
    # 2. Prepare behaviour signal (cursor velocity)
    # ------------------------------------------------------------------
    behavior_input = data.cursor.vel  # np.ndarray, (timesteps?, features)


    # ------------------------------------------------------------------
    # 3. Align channel dimensions based on (dataset, subject, task)
    # ------------------------------------------------------------------
    # smoothed_spikes = _ensure_dim(smoothed_spikes, MAX_NEURAL_INPUT_DIM, axis=1)
    # behavior_input = _ensure_dim(behavior_input, MAX_BEHAVIOR_INPUT_DIM, axis=1)
    # smoothed_spikes = _ensure_dim(smoothed_spikes, num_timesteps, axis=0)
    # behavior_input = _ensure_dim(behavior_input, num_timesteps, axis=0)
    smoothed_spikes = smoothed_spikes[:80, :10]
    behavior_input = behavior_input[:80, :2]

    # ------------------------------------------------------------------
    # 4. Pack into torch tensors
    # ------------------------------------------------------------------
    # dataset, subject, task = parse_session_id(data.session.id)
    # group_tuple = (dataset, subject, task)
    # group_idx = DATASET_GROUP_TO_IDX[group_tuple]

    return {
        "neural_input": torch.as_tensor(smoothed_spikes, dtype=torch.float32),
        "behavior_input": torch.as_tensor(behavior_input, dtype=torch.float32),
        # "dataset_group_idx": torch.as_tensor(group_idx, dtype=torch.int32),
    }
# Install the transform on the training dataset; presumably applied to every sample the
# dataset returns (torch-brain convention) — confirm.
train_dataset.transform = transform_brainsets_to_fixed_dim_samples

In [6]:
%lprun -f _ensure_dim _ensure_dim(arr, target_dim, axis=1)

UsageError: Line magic function `%lprun` not found.


In [13]:
# Time how long it takes to draw the first batch from `train_loader`.
# time.perf_counter() is monotonic and high-resolution, which makes it the right clock
# for measuring intervals (time.time() can jump if the system clock is adjusted).
# An empty loader raises StopIteration here instead of silently timing nothing.
start_time = time.perf_counter()
batch = next(iter(train_loader))
end_time = time.perf_counter()
print(f"Time taken: {end_time - start_time} seconds")

Time taken: 17.20944857597351 seconds


In [1]:
def transform_brainsets_to_fixed_dim_samples(
    data: Any,
    sampling_rate: int = 100,
    sampling_window_ms: int = 1000
) -> Dict[str, torch.Tensor | str]:
    """Convert a *temporaldata* sample to a dictionary of Torch tensors.

    Reads the pre-smoothed spikes (``data.smoothed_spikes.smoothed_spikes``) and the
    cursor velocity (``data.cursor.vel``). It then looks up the integer index of the
    sample's *(dataset, subject, task)* group, parsed from ``data.session.id``.

    NOTE(review): the crop/pad alignment to ``MAX_NEURAL_INPUT_DIM``,
    ``MAX_BEHAVIOR_INPUT_DIM`` and ``num_timesteps`` is currently commented out, so
    the tensors keep whatever shape the sample has.

    Parameters
    ----------
    data: temporaldata.Data
        Sample returned by **torch-brain**/**temporaldata**; must expose
        ``data.smoothed_spikes.smoothed_spikes``, ``data.cursor.vel`` and
        ``data.session.id``.
    sampling_rate: int, default=100
        Target sampling rate *Hz*; only used to compute ``num_timesteps`` for the
        (currently disabled) temporal alignment.
    sampling_window_ms: int, default=1000
        Window length in ms; only used to compute ``num_timesteps`` for the
        (currently disabled) temporal alignment.

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary with float32 tensors ``neural_input`` and ``behavior_input`` and an
        int32 scalar tensor ``dataset_group_idx``.

    Raises
    ------
    KeyError
        If the parsed *(dataset, subject, task)* group is not in
        ``DATASET_GROUP_TO_IDX``.
    """
    def _ensure_dim(arr: np.ndarray, target_dim: int, *, axis: int = 1) -> np.ndarray:
        """Crop or zero-pad *arr* along *axis* to match *target_dim*.

        Cropping keeps the first ``target_dim`` entries; padding appends zeros at the
        end of the axis (:func:`numpy.pad`, ``mode="constant"``).
        """
        current_dim = arr.shape[axis]
        if current_dim == target_dim:
            return arr  # nothing to do
        if current_dim > target_dim:
            # Crop
            slicer = [slice(None)] * arr.ndim
            slicer[axis] = slice(None, target_dim)
            return arr[tuple(slicer)]
        # Pad (current_dim < target_dim)
        pad_width = [(0, 0)] * arr.ndim
        pad_width[axis] = (0, target_dim - current_dim)
        return np.pad(arr, pad_width, mode="constant")

    # Only referenced by the commented-out temporal alignment below.
    num_timesteps = int(sampling_rate * sampling_window_ms / 1000)

    # ------------------------------------------------------------------
    # 1. Spikes (already binned + smoothed upstream)
    # ------------------------------------------------------------------
    smoothed_spikes = data.smoothed_spikes.smoothed_spikes

    # ------------------------------------------------------------------
    # 2. Prepare behaviour signal (cursor velocity)
    # ------------------------------------------------------------------
    behavior_input = data.cursor.vel  # np.ndarray, (timesteps?, features)

    # ------------------------------------------------------------------
    # 3. Align channel dimensions based on (dataset, subject, task)
    # ------------------------------------------------------------------
    # smoothed_spikes = _ensure_dim(smoothed_spikes, MAX_NEURAL_INPUT_DIM, axis=1)
    # behavior_input = _ensure_dim(behavior_input, MAX_BEHAVIOR_INPUT_DIM, axis=1)
    # smoothed_spikes = _ensure_dim(smoothed_spikes, num_timesteps, axis=0)
    # behavior_input = _ensure_dim(behavior_input, num_timesteps, axis=0)

    # ------------------------------------------------------------------
    # 4. Pack into torch tensors
    # ------------------------------------------------------------------
    dataset, subject, task = parse_session_id(data.session.id)
    group_tuple = (dataset, subject, task)
    try:
        group_idx = DATASET_GROUP_TO_IDX[group_tuple]
    except KeyError as err:
        # Re-raise with context: a bare tuple KeyError inside a DataLoader worker is
        # hard to trace back to the offending session.
        raise KeyError(
            f"No dataset group index for {group_tuple!r} "
            f"(session id {data.session.id!r}); check DATASET_GROUP_TO_IDX."
        ) from err

    return {
        "neural_input": torch.as_tensor(smoothed_spikes, dtype=torch.float32),
        "behavior_input": torch.as_tensor(behavior_input, dtype=torch.float32),
        "dataset_group_idx": torch.as_tensor(group_idx, dtype=torch.int32),
    }
# Install the (re-defined) transform on the training dataset; presumably applied to
# every sample the dataset returns (torch-brain convention) — confirm.
train_dataset.transform = transform_brainsets_to_fixed_dim_samples

NameError: name 'Any' is not defined

In [22]:
# Time how long it takes to draw the first batch from `train_loader`.
# time.perf_counter() is monotonic and high-resolution, which makes it the right clock
# for measuring intervals (time.time() can jump if the system clock is adjusted).
# An empty loader raises StopIteration here instead of silently timing nothing.
import time

start_time = time.perf_counter()
batch = next(iter(train_loader))
end_time = time.perf_counter()
print(f"Time taken: {end_time - start_time} seconds")

Time taken: 24.671035289764404 seconds


In [10]:
import re
import numpy as np

# Map of (dataset, subject, task) -> dimension pair for that group.
# NOTE(review): the pair presumably is (neural input dim, behaviour dim); that meaning
# is not established anywhere in this notebook, so confirm against the data.
GROUP_DIMS = dict(
    (("perich_miller_population_2018", subject, task), (first_dim, 2))
    for subject, task, first_dim in (
        ("c", "center_out_reaching", 353),
        ("c", "random_target_reaching", 88),
        ("m", "center_out_reaching", 159),
        ("m", "random_target_reaching", 165),
        ("t", "center_out_reaching", 65),
        ("t", "random_target_reaching", 73),
        ("j", "center_out_reaching", 38),
    )
)


In [12]:
# Example usage of GroupedRandomFixedWindowSampler: build a batch sampler over the
# training sampling intervals, wrap it in a DataLoader, list the session groups it
# formed, and pull a few batches.
from foundational_ssm.utils.grouped_sampler import GroupedRandomFixedWindowSampler
sampling_intervals = train_dataset.get_sampling_intervals()

# Create the grouped sampler
# NOTE(review): window_length=1.0 is presumably in the same units as the sampling
# intervals (seconds?) — confirm. The generator is seeded for reproducible sampling.
grouped_sampler = GroupedRandomFixedWindowSampler(
    sampling_intervals=sampling_intervals,
    window_length=1.0,
    batch_size=128,
    generator=torch.Generator().manual_seed(42)
)

# Create dataloader with the grouped sampler
grouped_dataloader = torch.utils.data.DataLoader(
    train_dataset, 
    batch_sampler=grouped_sampler, 
    num_workers=4, 
    pin_memory=True
)
# NOTE(review): `transform` is not defined in any visible cell, so this line depends on
# leftover kernel state and raises NameError in a fresh kernel. Neither visible
# `transform_brainsets_to_fixed_dim_samples` version returns the "dataset_group_key"
# key read in the loop below — confirm which transform is intended.
# The transform is set after the DataLoader is built; it presumably still takes effect
# because the loader keeps a reference to `train_dataset` and starts workers on iteration.
train_dataset.transform = transform

# Print the groups that were created
print("Session groups:")
for group_key, sessions in grouped_sampler.session_groups.items():
    print(f"  {group_key}: {len(sessions)} sessions")
    for session in sessions[:3]:  # Show first 3 sessions
        print(f"    - {session}")
    if len(sessions) > 3:
        print(f"    ... and {len(sessions) - 3} more")
    print()
    
# Test the dataloader: prints 12 batches (i = 0..11) before stopping.
print("Testing grouped dataloader:")
for i, batch in enumerate(grouped_dataloader):
    print(batch["dataset_group_key"], batch["neural_input"].shape)
    if i > 10:
        break

Session groups:
  perich_miller_population_2018/c_center_out_reaching: 53 sessions
    - perich_miller_population_2018/c_20131003_center_out_reaching
    - perich_miller_population_2018/c_20131022_center_out_reaching
    - perich_miller_population_2018/c_20131023_center_out_reaching
    ... and 50 more

  perich_miller_population_2018/c_random_target_reaching: 15 sessions
    - perich_miller_population_2018/c_20131009_random_target_reaching
    - perich_miller_population_2018/c_20131010_random_target_reaching
    - perich_miller_population_2018/c_20131011_random_target_reaching
    ... and 12 more

  perich_miller_population_2018/m_random_target_reaching: 6 sessions
    - perich_miller_population_2018/m_20140114_random_target_reaching
    - perich_miller_population_2018/m_20140115_random_target_reaching
    - perich_miller_population_2018/m_20140116_random_target_reaching
    ... and 3 more

  perich_miller_population_2018/m_center_out_reaching: 22 sessions
    - perich_miller_populati

KeyboardInterrupt: 

In [4]:
# Path to one raw NWB recording (subject C, session CO-20160921), relative to the notebook.
raw = (
    "../data/foundational_ssm/raw/perich_miller_population_2018/000688/"
    "sub-C/sub-C_ses-CO-20160921_behavior+ecephys.nwb"
)


perich_miller_population_2018/c_20131003_center_out_reaching 38
perich_miller_population_2018/c_20131009_random_target_reaching 30
perich_miller_population_2018/c_20131010_random_target_reaching 36
perich_miller_population_2018/c_20131011_random_target_reaching 32
perich_miller_population_2018/c_20131022_center_out_reaching 33
perich_miller_population_2018/c_20131023_center_out_reaching 40
perich_miller_population_2018/c_20131028_random_target_reaching 32
perich_miller_population_2018/c_20131029_random_target_reaching 34
perich_miller_population_2018/c_20131031_center_out_reaching 50
perich_miller_population_2018/c_20131101_center_out_reaching 57
perich_miller_population_2018/c_20131203_center_out_reaching 39
perich_miller_population_2018/c_20131204_center_out_reaching 34
perich_miller_population_2018/c_20131209_random_target_reaching 39
perich_miller_population_2018/c_20131210_random_target_reaching 39
perich_miller_population_2018/c_20131212_random_target_reaching 38
perich_miller_po

In [20]:
from torch_brain.data.dataset import DatasetIndex
from collections import defaultdict

def get_session_indices(sampling_intervals, window_length, step=None):
    """Enumerate fixed-length windows inside each sampling interval, grouped by session.

    Parameters
    ----------
    sampling_intervals:
        Mapping of session name -> object with parallel ``start`` / ``end`` sequences
        (presumably the result of ``train_dataset.get_sampling_intervals()``).
    window_length:
        Window duration, in the same units as the interval bounds.
    step:
        Stride between consecutive window starts. Any falsy value (``None`` or ``0``)
        falls back to ``window_length``, i.e. non-overlapping windows.

    Returns
    -------
    defaultdict(list)
        Session name -> list of ``DatasetIndex(session_name, start, end)``. Windows that
        would run past their interval's end are dropped, and a session is only inserted
        once it has at least one full window.
    """
    if not step:
        step = window_length

    indices_per_session = defaultdict(list)
    for session_name, intervals in sampling_intervals.items():
        for interval_start, interval_end in zip(intervals.start, intervals.end):
            window_starts = torch.arange(interval_start, interval_end, step, dtype=torch.float64)
            for window_start in window_starts:
                window_end = window_start + window_length
                if window_end <= interval_end:
                    indices_per_session[session_name].append(
                        DatasetIndex(session_name, window_start.item(), window_end.item())
                    )
    return indices_per_session

# Non-overlapping windows of length 1 (same units as the sampling intervals).
indices_per_session = get_session_indices(sampling_intervals, window_length=1, step=1)
# NOTE(review): `SessionBatchSampler` is neither defined nor imported in any visible
# cell; this line only works if it exists in the kernel from elsewhere.
batch_sampler = SessionBatchSampler(indices_per_session, batch_size=16)
dataloader = torch.utils.data.DataLoader(
    train_dataset,
    batch_sampler=batch_sampler,
    num_workers=4,
    pin_memory=True,
)

In [21]:
from foundational_ssm.data_preprocessing import bin_spikes, map_binned_features_to_global

# Binning configuration read by `transform_batch` below.
sampling_rate = 100  # presumably Hz; bin width is 1 / sampling_rate
num_timesteps = int(sampling_rate * 1.0)  # presumably a 1.0 s window -> 100 bins
num_neural_features = 64  # size of the global unit layout (max_global_units)

def transform_batch(data):
    """Turn one temporaldata sample into a dict of model inputs plus identifiers.

    Spikes are binned with ``bin_spikes`` (``num_timesteps`` bins of width
    ``1 / sampling_rate``), transposed, and mapped onto ``num_neural_features``
    global units with ``map_binned_features_to_global``. Cursor velocity is passed
    through unchanged (no cropping or padding).

    Returns a dict with float32 tensors ``neural_input`` and ``behavior_input`` plus
    the raw ``session_id`` and ``subject_id`` values.
    NOTE(review): shapes are not checked here; they are presumably
    (N_timesteps, N_global_units) and (N_timesteps, N_behavior_features) — confirm.
    """
    unit_ids = data.units.id
    binned = bin_spikes(
        spikes=data.spikes,
        num_units=len(unit_ids),
        bin_size=1 / sampling_rate,
        num_bins=num_timesteps,
    )
    neural_input = map_binned_features_to_global(
        session_binned_features=binned.T,
        session_unit_id_strings=unit_ids,
        max_global_units=num_neural_features,
    )
    behavior_input = data.cursor.vel

    return {
        "neural_input": torch.tensor(neural_input, dtype=torch.float32),
        "behavior_input": torch.tensor(behavior_input, dtype=torch.float32),
        "session_id": data.session.id,
        "subject_id": data.subject.id,
    }


In [22]:
# Fetch and print only the first batch produced by `dataloader`.
batch = next(iter(dataloader), None)
if batch is not None:
    print(batch)

ValueError: too many dimensions 'str'

In [2]:
import torch
from torch.utils.data.dataloader import default_collate
from jax.tree_util import tree_map


def jax_collate_fn(batch):
    """Collate a list of samples with PyTorch, then hand tensors back as NumPy arrays.

    ``default_collate`` stacks the samples (dicts, lists, tensors, strings...). Every
    ``torch.Tensor`` leaf of the collated structure is then converted with
    ``.numpy()``; all other leaves (e.g. lists of strings) are returned untouched.
    """
    def _to_numpy(leaf):
        # Only torch tensors are converted; everything else passes through as-is.
        if isinstance(leaf, torch.Tensor):
            return leaf.numpy()
        return leaf

    stacked = default_collate(batch)
    return tree_map(_to_numpy, stacked)

In [3]:
# NOTE(review): `batch` is only bound by earlier dataloader cells; in a fresh kernel
# this raises NameError (as in the recorded output).
batch

NameError: name 'batch' is not defined