In [1]:
import numpy as np
from typing import Dict, Tuple, Any
from tqdm import tqdm
import re
from foundational_ssm.data_utils.dataset import TorchBrainDataset
from foundational_ssm.constants import DATA_ROOT

def parse_session_id(session_id: str) -> Tuple[str, str, str]:
    patterns = {
        "churchland_shenoy_neural_2012": re.compile(r"([^/]+)/([^_]+)_[0-9]+_(.+)"),
        "flint_slutzky_accurate_2012": re.compile(r"([^/]+)/monkey_([^_]+)_e1_(.+)"),
        "odoherty_sabes_nonhuman_2017": re.compile(r"([^/]+)/([^_]+)_[0-9]{8}_[0-9]+"),
        "pei_pandarinath_nlb_2021": re.compile(r"([^/]+)/([^_]+)_(.+)"),
        "perich_miller_population_2018": re.compile(r"([^/]+)/([^_]+)_[0-9]+_(.+)"),
    }

    dataset = session_id.split('/')[0]
    if dataset not in patterns:
        raise ValueError(f"Unknown dataset: {dataset}")

    match = patterns[dataset].match(session_id)
    if not match:
        raise ValueError(f"Could not parse session_id: {session_id!r}")

    if dataset == "odoherty_sabes_nonhuman_2017":
        # Always assign task as 'random_target_reaching'
        _, subject = match.groups()
        return dataset, subject, "random_target_reaching"
    elif dataset == "flint_slutzky_accurate_2012":
        # task is always 'center_out_reaching'
        _, subject, _ = match.groups()
        return dataset, subject, "center_out_reaching"
    else:
        return match.groups()

In [2]:
pretrain_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/pretrain_train_and_val.yaml"

pretrain_dataset = TorchBrainDataset(
        root="../"+DATA_ROOT,                # root directory where .h5 files are found
        # recording_id=recording_id,  # you either specify a single recording ID
        config=pretrain_config_path,                 # or a config for multi-session training / more complex configs
        keep_files_open=True,
        lazy=True,
        split="train"
    )

pretrain_sampling_intervals = pretrain_dataset.get_sampling_intervals()        
DATASET_GROUP_INFO: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

for i, (recording_id, train_intervals) in tqdm(enumerate(pretrain_sampling_intervals.items())):
    recording_data = pretrain_dataset.get_recording_data(recording_id)
    num_units = int(np.max(recording_data.spikes.unit_index))
    dataset, subject, task = parse_session_id(recording_id)
    train_duration = np.sum(train_intervals.end - train_intervals.start)
    # print(type(recording_data.cursor.vel))
    if recording_id.startswith("pei_pandarinath_nlb_2021"):
        behavior_sampling_rate = np.min(recording_data.hand.timestamps[1:] - recording_data.hand.timestamps[:-1])
    else:    
        behavior_sampling_rate = np.min(recording_data.cursor.timestamps[1:] - recording_data.cursor.timestamps[:-1])
    if (dataset, subject, task) not in DATASET_GROUP_INFO:
        DATASET_GROUP_INFO[(dataset, subject, task)] = {
            "max_num_units": num_units,
            "behavior_dim": 2,
            "train_duration": float(round(train_duration, 2)),
            "min_behavior_sampling_rate": behavior_sampling_rate,
        }
    else:
        DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"] = max(DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"], num_units)
        DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] = DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] + train_duration
        DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"] = min(DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"], behavior_sampling_rate)

148it [25:58, 10.53s/it] 


In [7]:
downstream_t_rt_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/downstream_t_rt.yaml"
downstream_t_co_config_path = "/cs/student/projects1/ml/2024/mlaimon/foundational_ssm/configs/dataset/downstream_t_co.yaml"

downstream_t_co_dataset = TorchBrainDataset(
        root="../"+DATA_ROOT,                # root directory where .h5 files are found
        # recording_id=recording_id,  # you either specify a single recording ID
        config=downstream_t_co_config_path,                 # or a config for multi-session training / more complex configs
        keep_files_open=True,
        lazy=True,
        split="train"
    )

downstream_t_co_sampling_intervals = downstream_t_co_dataset.get_sampling_intervals()        
DATASET_GROUP_INFO: Dict[Tuple[str, str, str], Dict[str, Any]] = {}

for i, (recording_id, train_intervals) in tqdm(enumerate(downstream_t_co_sampling_intervals.items())):
    recording_data = downstream_t_co_dataset.get_recording_data(recording_id)
    num_units = int(np.max(recording_data.spikes.unit_index))
    dataset, subject, task = parse_session_id(recording_id)
    train_duration = np.sum(train_intervals.end - train_intervals.start)
    # print(type(recording_data.cursor.vel))
    if recording_id.startswith("pei_pandarinath_nlb_2021"):
        behavior_sampling_rate = np.min(recording_data.hand.timestamps[1:] - recording_data.hand.timestamps[:-1])
    else:    
        behavior_sampling_rate = np.min(recording_data.cursor.timestamps[1:] - recording_data.cursor.timestamps[:-1])
    if (dataset, subject, task) not in DATASET_GROUP_INFO:
        DATASET_GROUP_INFO[(dataset, subject, task)] = {
            "max_num_units": num_units,
            "behavior_dim": 2,
            "train_duration": float(round(train_duration, 2)),
            "min_behavior_sampling_rate": behavior_sampling_rate,
        }
    else:
        DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"] = max(DATASET_GROUP_INFO[(dataset, subject, task)]["max_num_units"], num_units)
        DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] = DATASET_GROUP_INFO[(dataset, subject, task)]["train_duration"] + train_duration
        DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"] = min(DATASET_GROUP_INFO[(dataset, subject, task)]["min_behavior_sampling_rate"], behavior_sampling_rate)

4it [00:01,  3.77it/s]


In [8]:
DATASET_GROUP_INFO

{('perich_miller_population_2018',
  't',
  'center_out_reaching'): {'max_num_units': 58, 'behavior_dim': 2, 'train_duration': np.float64(2193.2977666666666), 'min_behavior_sampling_rate': np.float64(0.009999999999990905)}}