In [1]:
from torch_brain.data import Dataset, collate, chain
from foundational_ssm.data_utils import get_dataset_config
from torch_brain.data.sampler import RandomFixedWindowSampler, SequentialFixedWindowSampler
from torch.utils.data import DataLoader
from foundational_ssm.constants import DATA_ROOT
from foundational_ssm.data_utils.samplers import GroupedRandomFixedWindowSampler
from omegaconf import OmegaConf

In [None]:
import os

# Pretraining config. The default is a machine-specific absolute path; set
# FOUNDATIONAL_SSM_PRETRAIN_CONFIG to point at a different copy without editing
# the notebook.
config_path = os.environ.get(
    "FOUNDATIONAL_SSM_PRETRAIN_CONFIG",
    "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/pretrain.yaml",
)
cfg = OmegaConf.load(config_path)

# Training split of the dataset/subjects named in the config.
train_dataset = Dataset(
    root=DATA_ROOT,
    config=get_dataset_config(
        cfg.train_dataset.name,
        subjects=cfg.train_dataset.subjects,
    ),
    split="train",
)
# Intervals the window samplers in the timing cells draw windows from.
train_sampling_intervals = train_dataset.get_sampling_intervals()
train_dataset.disable_data_leakage_check()

from torch_brain.models import POYO
from torch_brain.registry import MODALITY_REGISTRY
import torch
import numpy as np
from typing import Any, Dict

# Prefer the GPU when torch can see one, otherwise run on the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# POYO architecture hyper-parameters. `sequence_length` is kept equal to the
# samplers' `window_length=1.0` used in the timing cells.
poyo_kwargs = {
    "sequence_length": 1.0,
    "readout_spec": MODALITY_REGISTRY["cursor_velocity_2d"],
    "latent_step": 1.0 / 8,
    "num_latents_per_step": 16,
    "dim": 64,
    "depth": 6,
    "dim_head": 64,
    "cross_heads": 2,
    "self_heads": 8,
}
poyo_model = POYO(**poyo_kwargs).to(device)

# Build the unit and session vocabularies from the ids present in the training set.
poyo_model.unit_emb.initialize_vocab(train_dataset.get_unit_ids())
poyo_model.session_emb.initialize_vocab(train_dataset.get_session_ids())


def transform_brainsets_to_fixed_dim_samples(
    data: Any,
    sampling_rate: int = 100,
    sampling_window_ms: int = 1000
) -> Dict[str, torch.Tensor | str]:
    """Convert a *temporaldata* sample to a dictionary of Torch tensors.

    The function takes care of binning & smoothing spikes, cropping/padding neural
    and behavioural features to a globally consistent dimensionality that depends
    on the *(dataset, subject, task)* triple.

    Parameters
    ----------
    data: temporaldata.Data
        Sample returned by **torch-brain**/**temporaldata**.
    sampling_rate: int, default=100
        Target sampling rate *Hz* used for binning.
    sampling_window_ms: int, default=1000   
        Length of the temporal window after binning.
    kern_sd_ms: int, default=20
        Standard deviation of the Gaussian kernel (in ms) for smoothing spikes.

    Returns
    -------
    Dict[str, torch.Tensor]
        Dictionary with keys ``neural_input``, ``behavior_input``, ``session_id``
        and ``subject_id``.
    """
    def _ensure_dim(arr: np.ndarray, target_dim: int, pad_value: float = 0.0, *, axis: int = 1) -> np.ndarray:
        """
        Crop or pad `arr` along `axis` to match `target_dim`, right-aligning the original data.
        Pads with `pad_value` if needed.
        """
        current_dim = arr.shape[axis]
        
        # Pad if too small
        shape = list(arr.shape)
        shape[axis] = target_dim
        result = np.full(shape, pad_value, dtype=arr.dtype)
        
        # Right-align: place arr at the end along the axis
        idx = [slice(None)] * arr.ndim
        idx[axis] = slice(-current_dim, None)
        result[tuple(idx)] = arr
        return result
    
    num_timesteps = int(sampling_rate * sampling_window_ms / 1000)
    
    # ------------------------------------------------------------------
    # 1. Bin + smooth spikes
    # ------------------------------------------------------------------
    smoothed_spikes = data.smoothed_spikes.smoothed_spikes

    # ------------------------------------------------------------------
    # 2. Prepare behaviour signal (cursor velocity)
    # ------------------------------------------------------------------
    behavior_input = data.cursor.vel  # np.ndarray, (timesteps?, features)


    # ------------------------------------------------------------------
    # 3. Align channel dimensions based on (dataset, subject, task)
    # ------------------------------------------------------------------
    smoothed_spikes = _ensure_dim(smoothed_spikes, 353, axis=1)
    behavior_input = _ensure_dim(behavior_input, 2, axis=1)
    smoothed_spikes = _ensure_dim(smoothed_spikes, num_timesteps, axis=0)
    behavior_input = _ensure_dim(behavior_input, num_timesteps, axis=0)
    # smoothed_spikes = smoothed_spikes[:80, :10]
    # behavior_input = behavior_input[:80, :2]

    # ------------------------------------------------------------------
    # 4. Pack into torch tensors
    # ------------------------------------------------------------------
    # dataset, subject, task = parse_session_id(data.session.id)
    # group_tuple = (dataset, subject, task)
    # group_idx = DATASET_GROUP_TO_IDX[group_tuple]

    return {
        "neural_input": torch.as_tensor(smoothed_spikes, dtype=torch.float32),
        "behavior_input": torch.as_tensor(behavior_input, dtype=torch.float32),
        # "dataset_group_idx": torch.as_tensor(group_idx, dtype=torch.int32),
    }
    




In [None]:
import time
# Tokenize each sampled window with POYO's tokenizer. Set before building the
# loader so the configuration reads top to bottom.
train_dataset.transform = poyo_model.tokenize

# Individual 1.0-window samples; the DataLoader groups them into batches of 256.
train_sampler = RandomFixedWindowSampler(
    sampling_intervals=train_sampling_intervals,
    window_length=1.0,
)
train_loader = DataLoader(
    dataset=train_dataset,
    sampler=train_sampler,
    collate_fn=collate,
    num_workers=4,
    pin_memory=True,
    batch_size=256,
)

# Print the gap between consecutive batches for the first 12 batches.
# perf_counter is monotonic and meant for interval timing; time.time() follows
# the wall clock and can jump. The first gap presumably includes worker start-up.
prev_start_time = time.perf_counter()
for i, batch in enumerate(train_loader):
    start_time = time.perf_counter()
    print(f"Time taken: {start_time - prev_start_time} seconds")
    prev_start_time = start_time
    if i > 10:  # stop after 12 batches
        break



Time taken: 3.5574705600738525 seconds
Time taken: 0.40118980407714844 seconds
Time taken: 0.0017473697662353516 seconds
Time taken: 0.0007100105285644531 seconds
Time taken: 2.014241933822632 seconds
Time taken: 0.13263964653015137 seconds
Time taken: 0.0009927749633789062 seconds
Time taken: 0.0015914440155029297 seconds
Time taken: 1.6969027519226074 seconds
Time taken: 0.19640517234802246 seconds
Time taken: 0.001524209976196289 seconds
Time taken: 0.0007107257843017578 seconds


In [15]:
# Imported here too so this cell does not depend on the earlier timing cell.
import time

# Convert each window to fixed-size tensors.
train_dataset.transform = transform_brainsets_to_fixed_dim_samples

# The sampler yields whole batches of 256 windows itself, so it is passed as
# `batch_sampler` and no `batch_size` is given to the DataLoader.
train_sampler = GroupedRandomFixedWindowSampler(
    sampling_intervals=train_sampling_intervals,
    window_length=1.0,
    batch_size=256,
)
train_loader = DataLoader(
    dataset=train_dataset,
    batch_sampler=train_sampler,
    collate_fn=collate,
    num_workers=4,
    pin_memory=True,
)

# Print the gap between consecutive batches for the first 102 batches, using
# the monotonic perf_counter clock (time.time() can jump with the wall clock).
prev_start_time = time.perf_counter()
for i, batch in enumerate(train_loader):
    start_time = time.perf_counter()
    print(f"Time taken: {start_time - prev_start_time} seconds")
    prev_start_time = start_time
    if i > 100:  # stop after 102 batches
        break



Time taken: 4.721935272216797 seconds
Time taken: 1.561842441558838 seconds
Time taken: 0.001127481460571289 seconds
Time taken: 5.91278076171875e-05 seconds
Time taken: 2.118638038635254 seconds
Time taken: 0.0010046958923339844 seconds
Time taken: 0.7510652542114258 seconds
Time taken: 0.0008804798126220703 seconds
Time taken: 2.420652151107788 seconds
Time taken: 0.0010461807250976562 seconds
Time taken: 0.6492063999176025 seconds
Time taken: 0.0013842582702636719 seconds
Time taken: 3.3137524127960205 seconds
Time taken: 0.0006413459777832031 seconds
Time taken: 6.723403930664062e-05 seconds
Time taken: 0.6956057548522949 seconds
Time taken: 2.3636980056762695 seconds
Time taken: 0.0008573532104492188 seconds
Time taken: 0.6097381114959717 seconds
Time taken: 1.0648438930511475 seconds
Time taken: 2.3183610439300537 seconds
Time taken: 0.0002963542938232422 seconds
Time taken: 0.4427821636199951 seconds
Time taken: 0.9433753490447998 seconds
Time taken: 1.2553966045379639 seconds
T

KeyboardInterrupt: 