In [9]:
import os
import numpy as np
from pathlib import Path
import random
from tqdm import tqdm

from utils.utils import (
    set_seed,
    get_args,
)

set_seed(42)

DATA_DIR = Path('../../data/encoding/ibl-mouse-separate')
TRAIN_RATIO = 0.7
VAL_RATIO = 0.1
TEST_RATIO = 0.2

seed set to 42


In [11]:
for session_dir in DATA_DIR.iterdir():
    if not session_dir.is_dir():
        continue
    print(f"Processing session {session_dir.name}")
    # get all .npy files in the session directory
    all_files = list(session_dir.glob('*.npy'))
    num_trials = len(all_files)
    trial_idxs = np.random.choice(np.arange(num_trials), num_trials, replace=False)
    train_idxs = trial_idxs[:int(num_trials*0.7)]
    val_idxs = trial_idxs[int(num_trials*0.7):int(num_trials*0.8)]
    test_idxs = trial_idxs[int(num_trials*0.8):]
    print(f"Train/Val/Test: {len(train_idxs)}/{len(val_idxs)}/{len(test_idxs)}")
    # create train, val, test directories
    (session_dir / 'train').mkdir(exist_ok=True)
    (session_dir / 'val').mkdir(exist_ok=True)
    (session_dir / 'test').mkdir(exist_ok=True)
    for file in tqdm(all_files, desc=f"Processing session {session_dir.name}"):
        trial_idx = int(file.stem)
        if trial_idx in train_idxs:
            os.rename(file, session_dir / 'train' / file.name)
        elif trial_idx in val_idxs:
            os.rename(file, session_dir / 'val' / file.name)
        elif trial_idx in test_idxs:
            os.rename(file, session_dir / 'test' / file.name)
        else:
            raise ValueError(f"Trial idx {trial_idx} not in train/val/test idxs")

Processing session 3638d102-e8b6-4230-8742-e548cd87a949
Train/Val/Test: 486/70/139


Processing session 3638d102-e8b6-4230-8742-e548cd87a949:   9%|▉         | 63/695 [00:00<00:04, 155.23it/s]

Processing session 3638d102-e8b6-4230-8742-e548cd87a949: 100%|██████████| 695/695 [00:04<00:00, 149.17it/s]


Processing session d23a44ef-1402-4ed7-97f5-47e9a7a504d9
Train/Val/Test: 287/41/82


Processing session d23a44ef-1402-4ed7-97f5-47e9a7a504d9: 100%|██████████| 410/410 [00:01<00:00, 232.73it/s]


Processing session db4df448-e449-4a6f-a0e7-288711e7a75a
Train/Val/Test: 281/40/81


Processing session db4df448-e449-4a6f-a0e7-288711e7a75a: 100%|██████████| 402/402 [00:02<00:00, 179.68it/s]


Processing session 03d9a098-07bf-4765-88b7-85f8d8f620cc
Train/Val/Test: 406/58/117


Processing session 03d9a098-07bf-4765-88b7-85f8d8f620cc: 100%|██████████| 581/581 [00:03<00:00, 163.01it/s]


Processing session 4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b
Train/Val/Test: 589/84/169


Processing session 4b7fbad4-f6de-43b4-9b15-c7c7ef44db4b: 100%|██████████| 842/842 [00:05<00:00, 146.40it/s]


Processing session 9b528ad0-4599-4a55-9148-96cc1d93fb24
Train/Val/Test: 406/58/117


Processing session 9b528ad0-4599-4a55-9148-96cc1d93fb24: 100%|██████████| 581/581 [00:03<00:00, 152.74it/s]


Processing session 687017d4-c9fc-458f-a7d5-0979fe1a7470
Train/Val/Test: 305/43/88


Processing session 687017d4-c9fc-458f-a7d5-0979fe1a7470: 100%|██████████| 436/436 [00:02<00:00, 150.38it/s]


Processing session 0841d188-8ef2-4f20-9828-76a94d5343a4
Train/Val/Test: 353/51/101


Processing session 0841d188-8ef2-4f20-9828-76a94d5343a4: 100%|██████████| 505/505 [00:03<00:00, 147.18it/s]
