In [1]:
import os
from pathlib import Path
from pnpl.datasets.libribrain2025 import constants_utils

# Point pnpl's constants source at the local constants.json (passed as a file:// URI)
# and reload the constants before any dataset is constructed further down.
constants_utils.set_remote_constants_url((Path.cwd() / "constants.json").as_uri())
constants_utils.refresh_constants();

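## Before Averaging

Compare the per-channel standardization statistics (`channel_means`, `channel_stds`) of the train and validation partitions. Here the validation dataset is *not* given the train statistics.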

In [None]:
from pnpl.datasets import LibriBrainPhoneme

# Window (tmin/tmax) and standardization settings shared by both partitions.
# The validation dataset is not handed the train channel statistics here; the
# following cells compare the channel_means / channel_stds each one ends up with.
phoneme_kwargs = dict(data_path="./data/", tmin=0.0, tmax=0.5, standardize=True)
train_dataset, validation_dataset = (
    LibriBrainPhoneme(**phoneme_kwargs, partition=partition_name)
    for partition_name in ("train", "validation")
)

In [None]:
import numpy as np

# Per-channel absolute gap between the train and validation statistics.
# NOTE(review): assumes channel_means / channel_stds are array-like with matching shapes.
mean_diff, std_diff = (
    np.abs(train_stat - validation_stat)
    for train_stat, validation_stat in (
        (train_dataset.channel_means, validation_dataset.channel_means),
        (train_dataset.channel_stds, validation_dataset.channel_stds),
    )
)

In [None]:
float(mean_diff.mean()), float(mean_diff.std())  # (mean, std) of the per-channel mean gap

In [None]:
std_diff.mean().item(), std_diff.std().item()

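## After Averaging

Group the chosen labelled split (standardized with the train statistics) into groups of 100 averaged samples. Then compare its per-channel time statistics against the competition holdout set.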

In [15]:
from pathlib import Path
from libribrain_experiments.grouped_dataset import MyGroupedDatasetV3
from typing import Literal
from pnpl.datasets import LibriBrainCompetitionHoldout, LibriBrainPhoneme

# Which labelled split to group/average and compare against the holdout set.
split: Literal["train", "val"] = "val"
# `Literal` is not enforced at runtime: without this check a typo would silently
# fall through to the validation branch below and write a mis-named cache file.
assert split in ("train", "val"), f"unknown split: {split!r}"

# Number of raw samples per group. It is used for BOTH grouped_samples and the
# cache filename so the two can never drift apart (a stale cache for a different
# group size would otherwise be loaded silently).
GROUPED_SAMPLES = 100

# Window / standardization settings shared by the raw train and validation datasets.
raw_phoneme_kwargs = dict(data_path="./data/", tmin=0.0, tmax=0.5, standardize=True)

raw_train_dataset = LibriBrainPhoneme(**raw_phoneme_kwargs, partition="train")
# Validation is standardized with the *train* channel statistics, not its own.
raw_validation_dataset = LibriBrainPhoneme(
    **raw_phoneme_kwargs,
    partition="validation",
    channel_means=raw_train_dataset.channel_means,
    channel_stds=raw_train_dataset.channel_stds,
)
source_dataset = MyGroupedDatasetV3(
    raw_train_dataset if split == "train" else raw_validation_dataset,
    grouped_samples=GROUPED_SAMPLES,
    drop_remaining=False,
    average_grouped_samples=True,
    state_cache_path=Path(f"./data_preprocessed/groupedv3/{split}_grouped_{GROUPED_SAMPLES}.pt"),
    balance=True,
    shuffle=True,
)

# NOTE(review): standardize=False relies on the holdout data already being
# standardized upstream. Confirm this. If it is not, standardize it with the train
# statistics instead (presumably channel_means/channel_stds=raw_train_dataset.*,
# as an earlier commented-out version of this call did).
holdout_dataset = LibriBrainCompetitionHoldout(
    data_path="./data/",
    task="phoneme",
    tmin=0.0,
    tmax=0.5,
    standardize=False,  # already standardized
)

In [19]:
# Summarize the raw (pre-standardization) train channel statistics across channels.
for stat_name, stat_values in (
    ("means", raw_train_dataset.channel_means),
    ("stds", raw_train_dataset.channel_stds),
):
    print(f"Raw train {stat_name}: {stat_values.mean().item():.6e} ± {stat_values.std().item():.6e}")

Raw train means: 1.072113e-15 ± 9.067206e-14
Raw train stds: 7.629011e-11 ± 8.890892e-11


In [None]:
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader


def per_channel_time_stats(dataset, batch_size=256, num_workers=4):
    """Return per-channel ``(means, stds)`` averaged over every sample of ``dataset``.

    For each sample, the mean and standard deviation are taken over axis 2 of the
    batch ("time"; assumes batches are ``(batch, channels, time)``, TODO confirm).
    The resulting per-sample vectors are then averaged across all samples.
    The returned "std" is therefore the *mean of per-sample stds*, not a std pooled
    over all time points of all samples.

    Batches may be bare tensors (as the holdout loader yields) or lists/tuples whose
    first element is the data (as the grouped source loader yields). Any further
    elements are ignored.

    Raises:
        ValueError: if the dataset yields no samples.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=num_workers)
    batch_means, batch_stds = [], []
    for batch in tqdm(loader):
        data = batch[0] if isinstance(batch, (list, tuple)) else batch
        arr = data.numpy()
        batch_means.append(arr.mean(axis=2))  # mean over time
        batch_stds.append(arr.std(axis=2))    # std over time
    if not batch_means:
        raise ValueError("dataset yielded no samples")
    # Concatenating the (batch, channels) chunks gives one (n_samples, channels)
    # array; averaging over axis 0 yields one value per channel.
    return (
        np.concatenate(batch_means).mean(axis=0),
        np.concatenate(batch_stds).mean(axis=0),
    )


source_channel_means, source_channel_stds = per_channel_time_stats(source_dataset)
holdout_channel_means, holdout_channel_stds = per_channel_time_stats(holdout_dataset)

100%|██████████| 43/43 [01:21<00:00,  1.89s/it]
100%|██████████| 10/10 [00:00<00:00, 19.83it/s]


In [7]:
import numpy as np

# Per-channel absolute gap between the time-averaged statistics of the grouped
# source data and of the holdout data.
# Note: this rebinds mean_diff / std_diff from the "Before Averaging" section.
mean_diff = np.abs(np.subtract(source_channel_means, holdout_channel_means))
std_diff = np.abs(np.subtract(source_channel_stds, holdout_channel_stds))

In [13]:
# Summarize each per-channel statistic as "mean ± std across channels".
for stat_label, stat_values in (
    ("Source means", source_channel_means),
    ("Source stds", source_channel_stds),
    ("Holdout means", holdout_channel_means),
    ("Holdout stds", holdout_channel_stds),
):
    print(f"{stat_label}: {stat_values.mean().item():.6f} ± {stat_values.std().item():.6f}")

Source means: -0.000416 ± 0.004037
Source stds: 0.010629 ± 0.007592
Holdout means: -0.002820 ± 0.041547
Holdout stds: 0.499323 ± 0.205752


In [11]:
# Summarize the source-vs-holdout per-channel gaps as "mean ± std across channels".
for diff_label, diff_values in (("Mean diff", mean_diff), ("Std diff", std_diff)):
    print(f"{diff_label}: {diff_values.mean().item():.6f} ± {diff_values.std().item():.6f}")

Mean diff: 0.037174 ± 0.015907
Std diff: 0.488695 ± 0.204510
