In [1]:
from pathlib import Path
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from pnpl.datasets import LibriBrainPhoneme, LibriBrainCompetitionHoldout

from libribrain_experiments.grouped_dataset import MyGroupedDatasetV3

In [2]:
_train_dataset = LibriBrainPhoneme(
    data_path="./data/",
    tmin=0.0,
    tmax=0.5,
    standardize=True,
    partition="train",
)

train_dataset = MyGroupedDatasetV3(
    _train_dataset,
    grouped_samples=100,
    drop_remaining=False,
    average_grouped_samples=True,
    state_cache_path=Path("./data_preprocessed/groupedv3/train_grouped_100.pt"),
    shuffle=True,
    # augment=True,  # Set to True if you want to apply data augmentation
)

_val_dataset = LibriBrainPhoneme(
    data_path="./data/",
    tmin=0.0,
    tmax=0.5,
    standardize=True,
    partition="validation",
)

val_dataset = MyGroupedDatasetV3(
    _val_dataset,
    grouped_samples=100,
    drop_remaining=False,
    average_grouped_samples=True,
    state_cache_path=Path("./data_preprocessed/groupedv3/val_grouped_100.pt"),
    shuffle=True,
    # augment=True,  # Set to True if you want to apply data augmentation
)
holdout_dataset = LibriBrainCompetitionHoldout(
    data_path="./data/",
    task="phoneme",
    tmin=0.0,
    tmax=0.5,
)



In [3]:
def print_shape(dataset, holdout=False):
    # dataset[0].shape == (306, 125)
    loader = DataLoader(
        dataset,
        batch_size=256,
        num_workers=4,
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_sum = torch.zeros((1, 306)).to(device)
    all_sum_square = torch.zeros((1, 306)).to(device)
    count = torch.tensor(len(dataset), dtype=torch.float64)

    for batch in tqdm(loader):
        batch: np.ndarray = batch.to(
            device) if holdout else batch[0].to(device)
        all_sum += batch.sum(axis=(0, 2))
        all_sum_square += ((batch / torch.sqrt(count)) ** 2).sum(axis=(0, 2))
        count += batch.shape[0]

    mean = all_sum / count
    std = torch.sqrt(all_sum_square - mean ** 2)
    
    print("  Mean of Mean:", mean[~mean.isnan()].mean().item())
    print("  Std of Mean:", mean[~mean.isnan()].std().item())
    print("  Mean of Std:", std[~std.isnan()].mean().item())
    print("  Std of Std:", std[~std.isnan()].std().item())

In [4]:
print("Holdout:")
print_shape(holdout_dataset, holdout=True)

Holdout:


100%|██████████| 10/10 [00:00<00:00, 13.80it/s]


  Mean of Mean: -0.17622576653957367
  Std of Mean: 2.6009180545806885
  Mean of Std: 5.760135173797607
  Std of Std: 2.3829505443573


In [5]:
print("Train:")
print_shape(_train_dataset)

print("Train (Grouped):")
print_shape(train_dataset)

Train:


100%|██████████| 5222/5222 [02:44<00:00, 31.74it/s]


  Mean of Mean: -0.0296162161976099
  Std of Mean: 0.28274792432785034
  Mean of Std: 2.2516109943389893
  Std of Std: 0.6436986327171326
Train (Grouped):


100%|██████████| 53/53 [02:03<00:00,  2.33s/it]

  Mean of Mean: -0.029526973143219948
  Std of Mean: 0.2823629379272461
  Mean of Std: 0.1705937534570694
  Std of Std: 0.0689060166478157





In [6]:
print("Val:")
print_shape(_val_dataset)

print("Val (Grouped):")
print_shape(val_dataset)

Val:


100%|██████████| 1109/1109 [00:32<00:00, 34.44it/s]


  Mean of Mean: 0.07536603510379791
  Std of Mean: 1.0090750455856323
  Mean of Std: 4.097787857055664
  Std of Std: 1.6319663524627686
Val (Grouped):


100%|██████████| 12/12 [00:45<00:00,  3.76s/it]

  Mean of Mean: 0.07475172728300095
  Std of Mean: 1.0109822750091553
  Mean of Std: 0.39236754179000854
  Std of Std: 0.17127688229084015



