In [105]:
import random
import time
import numpy as np
import polars as pl
from pathlib import Path
from IPython.display import Audio, display
from hume_wsds import WSDataset

In [191]:
datasets = {
    'wyndlabs-1M-en': '/mnt/weka/data-wsds/wyndlabs_1M-en/v4-vad_ws_continuous/',
    'wyndlabs-1M-ja': '/mnt/weka/data-wsds/wyndlabs_1M-ja/v4-vad_ws_continuous/',
    'wyndlabs-1M-de': '/mnt/weka/data-wsds/wyndlabs_1M-de/v4-vad_ws_continuous/',
    'wyndlabs-v2-ko': '/mnt/weka/data-wsds/wyndlabs_v2-ko/v4-vad_ws_continuous/',
    'librilight': '/mnt/weka/data-wsds/librilight/v3-vad_ws',

    "spotify2": "/mnt/weka/data-wsds/spotify2/v6-vad_wsds/",
    "youtube-cc": "/mnt/weka/data-wsds/youtube-cc/v4-vad_ws_continuous", 
}
ds_name = "wyndlabs-1M-en"
ds = WSDataset(datasets[ds_name])
validation_list_path = datasets[ds_name]
lanugage = 'en'
data_pct = 0.01

In [192]:
print("Number of samples (rows):", ds.index.n_samples)

Number of samples (rows): 44229406


In [195]:
import polars as pl
from pathlib import Path

df = ds.sql_select("__key__", "pq", "`language_whisper.txt`")

df = df.filter(pl.col("language_whisper.txt") == "en")

q1 = df.select(pl.col("pq").quantile(0.20)).item()
q2 = df.select(pl.col("pq").quantile(0.70)).item()

df_low = df.filter(pl.col("pq") < q1)
df_mid = df.filter((pl.col("pq") >= q1) & (pl.col("pq") < q2))
df_high = df.filter(pl.col("pq") >= q2)

print(f"Dynamic PQ bins: low < {q1:.2f}, mid {q1:.2f}–{q2:.2f}, high ≥ {q2:.2f}")

df_filtered = pl.concat([
    df_low.with_columns(pl.lit("low").alias("bin")),
    df_mid.with_columns(pl.lit("mid").alias("bin")),
    df_high.with_columns(pl.lit("high").alias("bin")),
])

df_bases = (
    df_filtered
    .with_columns(
        pl.col("__key__").str.replace(r"(_\d+)$", "").alias("base_key")
    )
    .unique(subset=["base_key"])
)

# % of unique base keys per bin
samples = []
counts = {}
for b in ["low", "mid", "high"]:
    subset = df_bases.filter(pl.col("bin") == b)
    n_samples = max(1, int(data_pct * subset.height))
    n_samples = min(n_samples, subset.height)
    counts[b] = n_samples
    if n_samples > 0:
        samples.append(subset.sample(n=n_samples, seed=42))
df_sampled = pl.concat(samples)

val_keys_before = df_sampled["__key__"].to_list()
Path(f"{validation_list_path}/validation_segments.list").write_text("\n".join(val_keys_before))
print(f"Saved {len(val_keys_before)} base sample keys")

base_keys = (
    df_sampled
    .with_columns(pl.col("__key__").str.replace(r"(_\d+)$", "").alias("base_key"))
    ["base_key"]
    .to_list()
)

df_expanded = (
    df.with_columns(pl.col("__key__").str.replace(r"(_\d+)$", "").alias("base_key"))
      .filter(pl.col("base_key").is_in(base_keys))
      .sort("__key__")
)

val_keys_after = df_expanded["__key__"].to_list()
Path(f"{validation_list_path}/validation_segments_all_source.list").write_text("\n".join(val_keys_after))
print(f"Saved {len(val_keys_after)} expanded segment keys")
print(f"PQ bins used: low < {q1:.2f}, mid {q1:.2f}–{q2:.2f}, high ≥ {q2:.2f}")

print("Samples per quality bucket:")
for b in ["low", "mid", "high"]:
    pct = (counts[b] / len(df_bases) * 100) if len(df_bases) > 0 else 0
    print(f"  {b:>4}: {counts[b]} samples ({pct:.2f}% of that bin)")

Dynamic PQ bins: low < 3.89, mid 3.89–5.65, high ≥ 5.65
Saved 4591 base sample keys
Saved 253968 expanded segment keys
PQ bins used: low < 3.89, mid 3.89–5.65, high ≥ 5.65
Samples per quality bucket:
   low: 3029 samples (0.66% of that bin)
   mid: 1267 samples (0.28% of that bin)
  high: 295 samples (0.06% of that bin)
