
# Data Analysis for Pyannote Diarization Protocol

This notebook helps you **verify and analyze** your dataset used with `pyannote.audio`:
1. **Verify** all audio files in each split (`train/dev/test`) can be loaded.
2. Show **split sizes** (file counts).
3. Compute **durations** (per split + total) and print in `H M S`.
4. Compute **speakers-per-chunk** statistics (and the **max** across all chunks).
5. Check the **sample rate (frequency)** distribution and the **silence ratio** per file (with histograms), plus the average **Welch PSD** over a subset of training files.
6. Visualize **spectrograms** for 15 random training files.
7. Apply **MUSAN augmentation** (music, noise, babble) to those 15 files with the **same SNR ranges** as the training code, plot the resulting spectrograms, and **save** the augmented WAVs to `./augmented_samples`.


In [None]:

# --- Imports & settings ---
import os, math, random, warnings, json, glob
from pathlib import Path
from typing import Optional, List

import numpy as np
import torch
import torchaudio
import matplotlib.pyplot as plt

from pyannote.database import registry
from pyannote.core import Segment

# Reduce noisy deprecation warnings seen in recent torchaudio/pandas.
# First filter: UserWarning messages matching the torchaudio `_backend` deprecation text.
warnings.filterwarnings("ignore", category=UserWarning, message=r".*torchaudio\._backend.*deprecated.*")
# Second filter: FutureWarning messages mentioning `delim_whitespace`
# (presumably emitted by pandas -- confirm).
warnings.filterwarnings("ignore", category=FutureWarning, message=r".*delim_whitespace.*")

# Make plots a bit larger by default (width, height in inches).
plt.rcParams['figure.figsize'] = (10, 4)

# Reproducibility: seed the three RNGs used by this notebook (stdlib random,
# NumPy, PyTorch) with the same fixed value.
random.seed(1337)
np.random.seed(1337)
torch.manual_seed(1337)

# Record the library versions in the cell output.
print("PyTorch:", torch.__version__, "| torchaudio:", torchaudio.__version__)


In [None]:

# --- Configuration ---
# Environment variable PYANNOTE_DATABASE_CONFIG must point to your data/database.yml
db_cfg = os.environ.get("PYANNOTE_DATABASE_CONFIG")
# Fail fast when the variable is unset/empty or the file it names does not exist.
assert db_cfg and Path(db_cfg).exists(), "Set PYANNOTE_DATABASE_CONFIG to your data/database.yml"

# Your protocol name:
PROTOCOL = "MyDatabase.SpeakerDiarization.MyProtocol"

# Chunk duration (seconds) used to count speakers-per-chunk (match your training, e.g., 2.0)
CHUNK_DURATION = 2.0

# For MUSAN augmentation (same defaults as your training script)
# MUSAN_ROOT falls back to "/musan" when the MUSAN_ROOT environment variable is unset.
MUSAN_ROOT = Path(os.environ.get("MUSAN_ROOT", "/musan"))
# (low, high) SNR bounds in dB; the augmentation cell draws uniformly between them.
SNR_NOISE = (5.0, 20.0)
SNR_MUSIC = (5.0, 20.0)
SNR_BABBLE = (10.0, 20.0)

# Probabilities (weights) for selecting a background type when augmenting
# NOTE(review): these three values sum to 1.2 and are not referenced by any
# other cell visible in this notebook -- presumably relative weights mirrored
# from the training script; confirm before relying on them.
P_NOISE = 0.4
P_MUSIC = 0.4
P_BABBLE = 0.4

# Safety limits for heavy computations
MAX_FILES_FOR_CHUNK_ANALYSIS = None   # None = use every training file; set an int (e.g., 200) to speed up
MAX_FILES_FOR_SPECTRUM = 100          # maximum number of training files averaged into the PSD
N_SPECTROGRAM_SAMPLES = 15            # how many training files to visualize


In [None]:

# --- Helper utilities ---
def hms(seconds: float) -> str:
    '''Format a duration given in seconds as "<h>h <m>m <s>s".

    The input is rounded to the nearest whole second first; hours are not
    wrapped into days.
    '''
    whole_seconds = int(round(seconds))
    hours, remainder = divmod(whole_seconds, 3600)
    minutes, secs = divmod(remainder, 60)
    return f"{hours}h {minutes}m {secs}s"

def audio_path_pattern_from_cfg(cfg_path: str) -> str:
    '''Build the audio path template "<config dir>/audio/{uri}.wav".

    The template is derived purely from the location of the database config
    file; change this function if your audio lives somewhere else.
    '''
    audio_dir = Path(cfg_path).parent.joinpath("audio")
    return str(audio_dir.joinpath("{uri}.wav"))

class UriToAudioPath:
    '''Callable that turns a protocol file entry into an audio path.

    The stored pattern is filled in with str.format(uri=file["uri"]).
    '''

    def __init__(self, pattern: str):
        self.pattern = pattern

    def __call__(self, file):
        uri = file["uri"]
        return self.pattern.format(uri=uri)

def safe_audio_info(path: str):
    '''Return (num_frames, sample_rate, num_channels) for an audio file.

    Metadata is read with torchaudio.info. Any exception raised while reading
    it is swallowed and reported as None, so callers can treat None as
    "missing or unreadable".
    '''
    try:
        meta = torchaudio.info(path)
        result = (meta.num_frames, meta.sample_rate, meta.num_channels)
    except Exception:
        result = None
    return result

def load_wave(path: str, target_sr: Optional[int] = None):
    '''Load an audio file as a single-channel tensor and optionally resample it.

    Multi-channel audio is averaged across channels (keeping the channel
    dimension). When target_sr is truthy and differs from the file's rate the
    waveform is resampled to it. Returns (waveform, sample_rate).
    '''
    wav, sr = torchaudio.load(path)
    if wav.shape[0] > 1:
        wav = wav.mean(dim=0, keepdim=True)
    else:
        wav = wav[:1, :]

    # Nothing to do when no target rate was requested or it already matches.
    if not target_sr or sr == target_sr:
        return wav, sr
    return torchaudio.functional.resample(wav, sr, target_sr), target_sr

def compute_silence_ratio(samples: torch.Tensor, sr: int, frame_ms: int = 20, hop_ms: int = 10, threshold_db: float = -40.0) -> float:
    '''Return the fraction of analysis frames whose RMS level is below threshold_db.

    samples is expected as (1, T). Frames are frame_ms long and advance by
    hop_ms (each at least one sample). A signal shorter than one frame is
    scored as a single partial frame; an empty signal yields 0.0.
    '''
    signal = samples.squeeze(0).cpu().numpy()
    n_samples = len(signal)
    win = max(1, int(sr * frame_ms / 1000.0))
    step = max(1, int(sr * hop_ms / 1000.0))

    # Exclusive upper bound for frame starts; force one frame for short input.
    stop = max(0, n_samples - win + 1)
    if stop == 0 and n_samples >= 1:
        stop = 1

    n_frames = 0
    n_silent = 0
    for offset in range(0, stop, step):
        window = signal[offset:offset + win]
        if len(window) == 0:
            continue
        # The small epsilons keep sqrt/log10 finite on digital silence.
        level_db = 20.0 * np.log10(np.sqrt(np.mean(window**2) + 1e-12) + 1e-12)
        n_frames += 1
        if level_db < threshold_db:
            n_silent += 1

    if not n_frames:
        return 0.0
    return n_silent / n_frames

def plot_spectrogram(samples: torch.Tensor, sr: int, title: str = "Spectrogram"):
    '''Plot the log-magnitude STFT spectrogram of a mono signal.

    samples: waveform tensor, expected as (1, T); it is squeezed to 1-D.
    sr: sample rate in Hz. When positive, the axes are labelled in seconds
        and Hz; otherwise they fall back to frame / frequency-bin indices.
    title: figure title.
    '''
    n_fft = 1024
    hop_length = 256
    # detach/cpu so tensors that track gradients or live on another device
    # can still be converted to NumPy below.
    signal = samples.squeeze(0).detach().cpu()
    # Pass an explicit Hann window: without one torch.stft uses a rectangular
    # window (strong spectral leakage) and warns about the missing argument.
    window = torch.hann_window(n_fft, dtype=signal.dtype)
    spec = torch.stft(
        signal,
        n_fft=n_fft,
        hop_length=hop_length,
        window=window,
        return_complex=True,
    )
    # The small offset avoids log10(0) on exactly-zero bins.
    mag_db = 20 * np.log10(spec.abs().numpy() + 1e-10)

    if sr > 0:
        # Map the image onto physical units: x spans the clip duration,
        # y spans 0 .. Nyquist.
        extent = (0.0, signal.shape[-1] / float(sr), 0.0, sr / 2.0)
        xlabel, ylabel = "Time (s)", "Frequency (Hz)"
    else:
        extent = None
        xlabel, ylabel = "Frames", "Frequency bins"

    plt.figure()
    plt.imshow(mag_db, origin="lower", aspect="auto", extent=extent)
    plt.colorbar(label="dB")
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.tight_layout()
    plt.show()

def pick_random_segment(wav: torch.Tensor, target_len: int) -> torch.Tensor:
    '''Return a (1, target_len) excerpt of wav.

    wav: (1, T) tensor on CPU.
    If wav is long enough, a random contiguous window is cut out (using the
    stdlib random generator). If it is too short it is tiled and truncated.
    An empty wav cannot be tiled, so silence (zeros) of the requested length
    is returned to keep the documented output shape.
    '''
    T = wav.shape[-1]
    if T >= target_len:
        start = random.randint(0, T - target_len)
        return wav[:, start:start + target_len]
    if T == 0:
        # Here target_len > 0: repeating an empty tensor would stay empty and
        # break the shape contract (and any later mixing with the clean clip).
        return wav.new_zeros((wav.shape[0], target_len))
    reps = math.ceil(target_len / T)
    return wav.repeat(1, reps)[:, :target_len]

def mix_at_snr(clean: torch.Tensor, noise: torch.Tensor, snr_db: float) -> torch.Tensor:
    '''Add noise to clean at (approximately) snr_db and clip the sum to [-1, 1].

    Both inputs are expected as (1, T). The noise gain is derived from the
    RMS of each signal, floored at 1e-12 so silent inputs do not divide by zero.
    '''
    rms_clean = max(np.sqrt(float(clean.pow(2).mean())), 1e-12)
    rms_noise = max(np.sqrt(float(noise.pow(2).mean())), 1e-12)
    # snr_db is a power ratio in dB, so the amplitude ratio is its square root.
    power_ratio = 10.0 ** (snr_db / 10.0)
    gain = rms_clean / (rms_noise * math.sqrt(power_ratio))
    mixture = clean + gain * noise
    return mixture.clamp(-1.0, 1.0)


In [None]:

# --- Load protocol with preprocessors ---
# Template of the form "<dir of database.yml>/audio/{uri}.wav".
audio_pattern = audio_path_pattern_from_cfg(db_cfg)
# The "audio" preprocessor fills the template from each file's "uri"; later
# cells read the result back as item["audio"].
pre = {"audio": UriToAudioPath(audio_pattern)}
# Register the database config before asking the registry for the protocol.
registry.load_database(db_cfg)
proto = registry.get_protocol(PROTOCOL, preprocessors=pre)

def collect_split(split_iter):
    '''Drain a protocol split iterator into a list so it can be reused and counted.'''
    return list(split_iter)

# Materialize each split once; the remaining cells iterate these lists
# several times.
train_items = collect_split(proto.train())
dev_items   = collect_split(proto.development())
test_items  = collect_split(proto.test())

# Last expression of the cell: shown as the cell output (train, dev, test counts).
len(train_items), len(dev_items), len(test_items)


In [None]:

# --- 1) Check all data loads successfully & 2/3) Splits and durations ---
def verify_and_sum(items, name: str):
    '''Check that every file of a split has readable metadata and total its audio.

    items: protocol file entries providing an "audio" path.
    name: split label used in the printed summary.

    Prints one summary line (plus up to five unreadable paths) and returns
    (total_seconds, {sample_rate: file_count}) over the readable files.
    '''
    unreadable = []
    rate_counts = {}
    seconds = 0.0
    readable = 0
    for entry in items:
        audio_path = entry["audio"]
        meta = safe_audio_info(audio_path)
        if meta is None:
            unreadable.append(audio_path)
            continue
        n_frames, rate, _channels = meta
        # A non-positive sample rate contributes no duration.
        if rate > 0:
            seconds += float(n_frames) / float(rate)
        rate_counts[rate] = rate_counts.get(rate, 0) + 1
        readable += 1

    n_missing = len(unreadable)
    print(f"{name}: {readable}/{len(items)} files OK | audio={hms(seconds)} | missing={n_missing}")
    if unreadable:
        print("  Missing/unreadable example:", unreadable[:5])
    return seconds, rate_counts

# Verify each split; each call prints its own summary line and returns the
# total duration in seconds plus a {sample_rate: file_count} dict.
train_sec, train_sr_hist = verify_and_sum(train_items, "train")
dev_sec,   dev_sr_hist   = verify_and_sum(dev_items,   "dev")
test_sec,  test_sr_hist  = verify_and_sum(test_items,  "test")

# File counts include unreadable files; the total duration only covers readable ones.
print("\nSplit sizes -> train/dev/test:", len(train_items), len(dev_items), len(test_items))
print("Total duration:", hms(train_sec + dev_sec + test_sec))

# Merge SR histograms for a global view
from collections import Counter
sr_counter = Counter()
# Counter.update with a mapping adds its counts to the running totals.
sr_counter.update(train_sr_hist)
sr_counter.update(dev_sr_hist)
sr_counter.update(test_sr_hist)

# List sample rates in ascending order.
print("\nSample-rate distribution (all splits):")
for sr, cnt in sorted(sr_counter.items()):
    print(f"  {sr} Hz: {cnt} files")


In [None]:

# --- 4) Speakers per CHUNK (e.g., 2 seconds) on TRAIN ---
def speakers_per_chunk_for_file(item, chunk_duration: float):
    '''Count distinct speaker labels in consecutive fixed-length chunks of one file.

    item: protocol file entry providing "audio" (path) and "annotation".
    chunk_duration: chunk length in seconds; must be positive.

    The file is cut into back-to-back chunks starting at 0 s; the last chunk
    is truncated at the end of the file. Returns one count per chunk, or an
    empty list when the audio metadata cannot be read.

    Raises ValueError if chunk_duration is not positive (the chunk loop would
    otherwise never advance).
    '''
    if chunk_duration <= 0:
        raise ValueError("chunk_duration must be positive")

    info = safe_audio_info(item["audio"])
    if info is None:
        return []
    num_frames, sr, _ = info
    file_dur = float(num_frames) / float(sr) if sr > 0 else 0.0
    annotation = item["annotation"]

    counts = []
    index = 0
    while True:
        # Derive each start from the chunk index rather than accumulating
        # `start += chunk_duration`, which drifts on long files.
        start = index * chunk_duration
        if start >= file_dur - 1e-6:
            break
        seg = Segment(start, min(file_dur, start + chunk_duration))
        # Annotation.crop takes only (support, mode); the previous
        # `return_timeline` keyword is not part of its signature.
        cropped = annotation.crop(seg, mode="intersection")
        labels = {label for _, _, label in cropped.itertracks(yield_label=True)}
        counts.append(len(labels))
        index += 1
    return counts

# Optionally cap the number of training files analysed (None = all of them).
subset = train_items if (MAX_FILES_FOR_CHUNK_ANALYSIS is None) else train_items[:MAX_FILES_FOR_CHUNK_ANALYSIS]
# Flat list with one speaker count per chunk, pooled over all files in the subset.
all_counts = []
for it in subset:
    all_counts.extend(speakers_per_chunk_for_file(it, CHUNK_DURATION))

if all_counts:
    print(f"Max speakers in any {CHUNK_DURATION:.1f}s chunk (train):", max(all_counts))
    # Quick histogram printout
    # Maps "number of speakers in a chunk" -> "number of chunks with that count".
    hist = {}
    for c in all_counts:
        hist[c] = hist.get(c, 0) + 1
    print("Speakers-per-chunk histogram (count of chunks):", hist)
else:
    print("No counts computed (maybe items unreadable).")


In [None]:

# --- 5) Plot sample-rate distribution & silence ratio on TRAIN ---
# Plot SR histogram:
# Expand the {sample_rate: count} dict back into one entry per file so it
# can be passed to plt.hist.
all_sr = []
for sr, cnt in train_sr_hist.items():
    all_sr.extend([sr]*cnt)

if all_sr:
    plt.figure()
    # One bin per distinct sample rate found in the training split.
    plt.hist(all_sr, bins=len(set(all_sr)))
    plt.title("Training sample-rate distribution (Hz)")
    plt.xlabel("Sample rate (Hz)")
    plt.ylabel("Count")
    plt.tight_layout()
    plt.show()
else:
    print("No sample-rate data to plot.")

# Silence ratio per file (train)
# Each row is (uri, fraction of frames below the silence threshold).
silence_rows = []
for it in train_items:
    # Skip files whose metadata cannot be read.
    info = safe_audio_info(it["audio"])
    if info is None:
        continue
    # Every readable training file is fully decoded here, at its native sample rate.
    wav, sr = load_wave(it["audio"], target_sr=None)
    ratio = compute_silence_ratio(wav, sr, frame_ms=20, hop_ms=10, threshold_db=-40.0)
    silence_rows.append((it["uri"], ratio))

# Report
if silence_rows:
    # Unweighted mean over files: a short file counts as much as a long one.
    avg_ratio = float(np.mean([r for _, r in silence_rows]))
    print(f"Average silence ratio (per file) in TRAIN: {avg_ratio*100:.2f}%")
    # Plot histogram of ratios
    plt.figure()
    plt.hist([r for _, r in silence_rows], bins=30)
    plt.title("Silence ratio per training file")
    plt.xlabel("Fraction of silent frames")
    plt.ylabel("Files")
    plt.tight_layout()
    plt.show()

    # Top 10 most-silent files
    top10 = sorted(silence_rows, key=lambda x: x[1], reverse=True)[:10]
    print("Top 10 most-silent training files:")
    for uri, r in top10:
        print(f"  {uri}: {r*100:.2f}% silent")
else:
    print("No silence data computed (no train files loaded).")


In [None]:

# --- 6) Average frequency content (Welch PSD) on a subset of TRAIN ---
# Best-effort cell: a missing scipy or any processing error is reported and
# the plot is skipped instead of stopping the notebook.
try:
    from scipy.signal import welch
    subset_for_psd = train_items[:min(len(train_items), MAX_FILES_FOR_SPECTRUM)]
    psd_sum = None
    freqs_ref = None
    # Sample rate of the first file that contributes a PSD. Later files are
    # resampled to it so every PSD shares one frequency grid; summing PSDs
    # computed at different rates would add unrelated frequencies bin-by-bin.
    ref_sr = None
    count = 0

    for it in subset_for_psd:
        info = safe_audio_info(it["audio"])
        if info is None:
            continue
        # While ref_sr is still None the file is loaded at its native rate.
        wav, sr = load_wave(it["audio"], target_sr=ref_sr)
        x = wav.squeeze(0).numpy()
        # Need at least one full Welch segment.
        if len(x) < 2048:
            continue
        f, Pxx = welch(x, fs=sr, nperseg=2048)
        if psd_sum is None:
            # Accumulate in float64 in a fresh array.
            psd_sum = Pxx.astype(np.float64)
            freqs_ref = f
            ref_sr = sr
        else:
            psd_sum += Pxx
        count += 1

    if count > 0:
        psd_avg = psd_sum / count
        plt.figure()
        plt.semilogy(freqs_ref[:len(psd_avg)], psd_avg)
        plt.title("Average Welch PSD (training subset)")
        plt.xlabel("Frequency (Hz)")
        plt.ylabel("Power spectral density")
        plt.tight_layout()
        plt.show()
    else:
        print("No PSD plotted (no valid training audio).")
except Exception as e:
    print("Skipping PSD plot (scipy not available or other error):", e)


In [None]:

# --- 7) Spectrograms for random 15 training files ---
# Never ask for more samples than there are training files.
N = min(N_SPECTROGRAM_SAMPLES, len(train_items))
if N == 0:
    print("No training items to visualize.")
else:
    # `chosen` is kept as a notebook global: the MUSAN augmentation cell
    # reuses the same files. It is not defined when N == 0.
    chosen = random.sample(train_items, k=N)
    for it in chosen:
        path = it["audio"]
        info = safe_audio_info(path)
        if info is None:
            print("Unreadable:", path)
            continue
        # Load at the native sample rate (no resampling).
        wav, sr = load_wave(path, target_sr=None)
        title = f"{it['uri']} (sr={sr})"
        plot_spectrogram(wav, sr, title=title)


In [None]:

# --- 8) MUSAN augmentation (music, noise, babble) on the same files ---
def list_audio_files(root: Path):
    '''Recursively list audio files (.wav/.flac/.mp3/.ogg, case-insensitive) under root.

    Returns a sorted list of Path objects, or an empty list when root does not
    exist. Sorting makes the result independent of filesystem enumeration
    order, so a seeded random.choice over it is reproducible.
    '''
    if not root.exists():
        return []
    exts = {".wav", ".flac", ".mp3", ".ogg"}
    # Escape the root so characters such as "[" or "*" in the directory name
    # are matched literally instead of being treated as glob syntax.
    pattern = str(Path(glob.escape(str(root))) / "**" / "*")
    candidates = (Path(p) for p in glob.glob(pattern, recursive=True))
    # is_file() drops directories that merely carry an audio-like suffix.
    return sorted(p for p in candidates if p.suffix.lower() in exts and p.is_file())

def load_random_bg(files: List[Path], target_len: int, sr: int):
    '''Pick a random background file and return a target_len-sample excerpt of it.

    files: candidate audio files; one is chosen with random.choice.
    target_len: number of samples to return.
    sr: sample rate the background is resampled to when it differs.

    Returns (segment, path) where segment has shape (1, target_len), or
    (None, None) when files is empty, the chosen file cannot be loaded, or it
    contains no samples. The caller checks for None and skips that job.
    '''
    if not files:
        return None, None
    path = random.choice(files)
    try:
        # Same mono down-mix and resampling as the rest of the notebook.
        wav, _ = load_wave(str(path), target_sr=sr)
    except Exception as exc:
        # Best effort, mirroring safe_audio_info: report and let the caller skip.
        print(f"Could not load background {path}: {exc}")
        return None, None
    if wav.shape[-1] == 0:
        # An empty file cannot provide a segment of the requested length.
        return None, None
    return pick_random_segment(wav, target_len), path

# Build one pool of candidate background files per MUSAN category.
musan_noise  = list_audio_files(MUSAN_ROOT / "noise")
musan_music  = list_audio_files(MUSAN_ROOT / "music")
musan_speech = list_audio_files(MUSAN_ROOT / "speech")

print(f"MUSAN pools -> noise={len(musan_noise)}, music={len(musan_music)}, babble(speech)={len(musan_speech)}")

# Augmented WAVs are written here, relative to the current working directory.
out_dir = Path("./augmented_samples")
out_dir.mkdir(parents=True, exist_ok=True)

# `chosen` comes from the spectrogram cell; it is missing if that cell was
# not run or found no training items.
if 'chosen' not in globals() or not chosen:
    print("No previously chosen files from training set; skipping augmentation cell.")
else:
    for it in chosen:
        path = it["audio"]
        info = safe_audio_info(path)
        if info is None:
            print("Unreadable:", path)
            continue
        # Clean signal at its native sample rate; backgrounds are resampled to match.
        wav, sr = load_wave(path, target_sr=None)
        T = wav.shape[-1]

        # Every file gets all three augmentations: (label, file pool, SNR range in dB).
        # NOTE(review): "babble" mixes in a single randomly chosen MUSAN speech
        # file, not an overlap of several talkers -- confirm this matches training.
        jobs = [
            ("music",  musan_music,  SNR_MUSIC),
            ("noise",  musan_noise,  SNR_NOISE),
            ("babble", musan_speech, SNR_BABBLE),
        ]

        for kind, pool, (lo, hi) in jobs:
            if not pool:
                print(f"Skipping {kind}: no files found under MUSAN.")
                continue
            # Background excerpt with the same number of samples as the clean clip.
            bg, bg_path = load_random_bg(pool, T, sr)
            if bg is None:
                print(f"Skipping {kind}: failed to load bg.")
                continue
            # Draw the SNR uniformly from the configured range.
            snr = random.uniform(lo, hi)
            mixed = mix_at_snr(wav, bg, snr_db=snr)

            # Save
            # One file per (uri, kind); re-running the cell overwrites earlier output.
            out_path = out_dir / f"{it['uri']}_{kind}.wav"
            torchaudio.save(str(out_path), mixed, sr)
            print(f"Saved: {out_path} (SNR≈{snr:.1f} dB)")

            # Plot
            plot_spectrogram(mixed, sr, title=f"{it['uri']} + {kind} (SNR≈{snr:.1f} dB)")
