In [2]:
#!/usr/bin/env python3
"""
bulk_compare_gpu.py

End-to-end bulk comparison script:
1. Generate acoustic world (world.py logic)
2. Unified noise covariance estimation
3. MVDR + SMVB beamforming (GPU accelerated)
4. Evaluation using eval.py metrics (CPU)
5. Log results to CSV

Authoritative experiment loop.
"""

import os
import csv
import subprocess
import numpy as np
import soundfile as sf
import torch
import torchaudio

# ------------------------
# USER CONFIG
# ------------------------
# Number of iterations of the bulk loop in main() (one CSV row each).
N_SAMPLES = 100
# Directory the per-sample WAV files are read from and the beamformer
# outputs are written to.
SAVE_DIR = "sample"
# Path of the results CSV that main() creates (opened in "w" mode, so an
# existing file is overwritten).
CSV_PATH = "bulk_results.csv"

# Sample rate: used when writing the output WAVs, passed to the PESQ call,
# and used to convert an STFT bin index to Hz (f * FS / N_FFT).
FS = 16000
# n_fft argument of torch.stft / torch.istft.
N_FFT = 256
# hop_length argument of torch.stft / torch.istft.
N_HOP = 128
# Spacing term of the two-element steering vector; each element is delayed by
# (D / 2) * cos(angle) / C.  Presumably metres -- confirm against world.py.
D = 0.08
# Propagation speed dividing the delays.  Presumably the speed of sound in
# m/s -- confirm against world.py.
C = 343.0
# Look direction in degrees (converted with torch.deg2rad in steering_vector).
ANGLE_TARGET = 90.0
# Diagonal loading added to the covariance before the MVDR solve.
SIGMA = 1e-3

# Use CUDA when it is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# ------------------------
# IMPORT EVAL FUNCTIONS
# ------------------------
from eval import (
    load_and_align_signals,
    calculate_osnr_and_osir,
    calculate_pesq_metric
)

# ============================================================
# STEERING VECTOR (GPU)
# ============================================================
def steering_vector(f_hz):
    """Return the 2x1 complex steering vector for frequency ``f_hz``.

    The look direction is the module-level ``ANGLE_TARGET`` (degrees).  The
    two elements carry delays of ``(D / 2) * cos(angle) / C`` and
    ``(D / 2) * cos(angle - pi) / C`` respectively, i.e. equal magnitude and
    opposite sign.

    Args:
        f_hz: Frequency in Hz.

    Returns:
        A complex tensor of shape ``(2, 1)`` on ``DEVICE``.
    """
    look_angle = torch.deg2rad(torch.tensor(ANGLE_TARGET, device=DEVICE))
    angular_freq = 2 * np.pi * f_hz

    # Per-element propagation delays relative to the array centre.
    delays = (
        (D / 2) * torch.cos(look_angle) / C,
        (D / 2) * torch.cos(look_angle - np.pi) / C,
    )

    phase_terms = [torch.exp(-1j * angular_freq * delay) for delay in delays]
    return torch.stack(phase_terms, dim=0).reshape(2, 1)


# ============================================================
# GPU PIPELINE
# ============================================================
def run_beamformers_gpu():
    """Beamform the current sample with MVDR and SMVB and write both outputs.

    Reads ``mixture.wav``, ``target.wav``, ``interference.wav`` and
    ``noise.wav`` from ``SAVE_DIR``.  The three reference signals are used to
    build an oracle ratio mask; the complementary mask weights the mixture
    when estimating a 2x2 covariance per frequency bin.  Both beamformers are
    then applied bin by bin, the target mask is applied as a post-filter, and
    the peak-normalised time signals are written to
    ``output_unified_mvdr.wav`` and ``output_unified_smvb.wav`` in
    ``SAVE_DIR``.

    Returns:
        None.  The results are the two WAV files written to ``SAVE_DIR``.

    Raises:
        ValueError: If the mixture is not a two-channel recording.
    """
    # ---------- Load audio ----------
    # NOTE(review): the sample rate stored in the files is ignored; FS is
    # assumed throughout -- confirm world.py writes at FS.
    y_mix, _ = sf.read(f"{SAVE_DIR}/mixture.wav", dtype="float32")
    s_t, _ = sf.read(f"{SAVE_DIR}/target.wav", dtype="float32")
    s_i, _ = sf.read(f"{SAVE_DIR}/interference.wav", dtype="float32")
    s_n, _ = sf.read(f"{SAVE_DIR}/noise.wav", dtype="float32")

    # Everything below (2x2 covariance, 2-element steering vector) assumes a
    # two-microphone mixture; fail early instead of with an obscure shape
    # error deep inside the loops.
    if y_mix.ndim != 2 or y_mix.shape[1] != 2:
        raise ValueError(
            f"mixture.wav must have exactly 2 channels, got shape {y_mix.shape}"
        )

    # Mono references: keep the first channel of multi-channel files.
    s_t = s_t[:, 0] if s_t.ndim > 1 else s_t
    s_i = s_i[:, 0] if s_i.ndim > 1 else s_i
    s_n = s_n[:, 0] if s_n.ndim > 1 else s_n

    # To torch; the mixture becomes (channels, samples).
    Y = torch.tensor(y_mix.T, device=DEVICE)
    s_t = torch.tensor(s_t, device=DEVICE)
    s_i = torch.tensor(s_i, device=DEVICE)
    s_n = torch.tensor(s_n, device=DEVICE)

    # ---------- STFT ----------
    # NOTE(review): no window is passed, so torch.stft / torch.istft use a
    # rectangular window.  Kept as-is to leave the numerical results
    # unchanged.
    Y_stft = torch.stft(Y, N_FFT, N_HOP, return_complex=True)
    S_t = torch.stft(s_t, N_FFT, N_HOP, return_complex=True)
    S_i = torch.stft(s_i, N_FFT, N_HOP, return_complex=True)
    S_n = torch.stft(s_n, N_FFT, N_HOP, return_complex=True)

    # ---------- Oracle mask ----------
    mag_t2 = torch.abs(S_t) ** 2
    mag_i2 = torch.abs(S_i) ** 2
    mag_n2 = torch.abs(S_n) ** 2

    # Ratio of target power to total power; the small constant avoids 0/0.
    mask_t = mag_t2 / (mag_t2 + mag_i2 + mag_n2 + 1e-10)
    mask_in = 1.0 - mask_t

    # ---------- Covariance ----------
    # Mask-weighted covariance of the mixture per frequency bin, normalised by
    # the total mask weight of that bin.
    n_freq = Y_stft.shape[1]
    R_in = torch.zeros((n_freq, 2, 2), dtype=torch.complex64, device=DEVICE)

    for f in range(n_freq):
        w = torch.sqrt(mask_in[f])
        Yw = Y_stft[:, f, :] * w
        R_in[f] = (Yw @ Yw.conj().T) / (torch.sum(w**2) + 1e-8)

    # ---------- Beamforming ----------
    S_mvdr = torch.zeros_like(S_t)
    S_smvb = torch.zeros_like(S_t)

    # Loop invariants, built once instead of once per bin.
    loading = SIGMA * torch.eye(2, device=DEVICE)
    # The right-hand side must share the (complex) dtype of the constraint
    # matrix: torch.linalg.solve rejects a complex A with a real B.
    constraint_rhs = torch.tensor(
        [[1.0], [0.0]], dtype=R_in.dtype, device=DEVICE
    )

    for f in range(n_freq):
        f_hz = f * FS / N_FFT
        if f_hz < 100:
            # Below 100 Hz no beamforming is applied; pass channel 0 through.
            S_mvdr[f] = Y_stft[0, f]
            S_smvb[f] = Y_stft[0, f]
            continue

        R = R_in[f] + loading
        d = steering_vector(f_hz)

        # MVDR: w = R^-1 d / (d^H R^-1 d).
        w_mvdr = torch.linalg.solve(R, d)
        w_mvdr = w_mvdr / (d.conj().T @ w_mvdr + 1e-10)
        S_mvdr[f] = (w_mvdr.conj().T @ Y_stft[:, f]).squeeze()

        # SMVB: unit response toward d and a null on the dominant eigenvector
        # of the (unloaded) covariance.  eigh returns eigenvalues in ascending
        # order, so the last column belongs to the largest one.
        _, eigvecs = torch.linalg.eigh(R_in[f])
        v_int = eigvecs[:, -1].reshape(2, 1)
        Cmat = torch.cat([d, v_int], dim=1)

        if torch.linalg.cond(Cmat) < 10:
            w_smvb = torch.linalg.solve(Cmat.conj().T, constraint_rhs)
        else:
            # Constraints are nearly collinear: fall back to the scaled
            # steering vector (d / 2).
            w_smvb = d / 2.0

        S_smvb[f] = (w_smvb.conj().T @ Y_stft[:, f]).squeeze()

    # ---------- Post-filter + ISTFT ----------
    s_mvdr = torch.istft(S_mvdr * mask_t, N_FFT, N_HOP).cpu().numpy()
    s_smvb = torch.istft(S_smvb * mask_t, N_FFT, N_HOP).cpu().numpy()

    # Peak-normalise; the small constant guards against an all-zero output.
    s_mvdr /= np.max(np.abs(s_mvdr)) + 1e-10
    s_smvb /= np.max(np.abs(s_smvb)) + 1e-10

    sf.write(f"{SAVE_DIR}/output_unified_mvdr.wav", s_mvdr.astype(np.float32), FS)
    sf.write(f"{SAVE_DIR}/output_unified_smvb.wav", s_smvb.astype(np.float32), FS)


# ============================================================
# MAIN BULK LOOP
# ============================================================
def main():
    """Run the bulk experiment and log per-sample metrics to ``CSV_PATH``.

    For each of ``N_SAMPLES`` iterations this:

    1. runs ``world.py`` as a subprocess to generate a new sample,
    2. runs both beamformers via ``run_beamformers_gpu()``,
    3. evaluates each beamformer output with the ``eval`` helpers, and
    4. appends one row ``[sample_id, mvdr_sir, smvb_sir, mvdr_pesq,
       smvb_pesq]`` to the CSV.

    Raises:
        subprocess.CalledProcessError: If ``world.py`` exits with a non-zero
            status (``check=True``).
    """
    # Local import: only needed to locate the interpreter running this script.
    import sys

    # Evaluated in this order; the keys name the CSV columns they feed.
    outputs = (
        ("mvdr", f"{SAVE_DIR}/output_unified_mvdr.wav"),
        ("smvb", f"{SAVE_DIR}/output_unified_smvb.wav"),
    )

    with open(CSV_PATH, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["sample_id", "mvdr_sir", "smvb_sir",
                          "mvdr_pesq", "smvb_pesq"])

        for i in range(N_SAMPLES):
            print(f"\n=== Sample {i} ===")

            # 1. World generation.  Use the interpreter running this script
            # rather than whatever "python" resolves to on PATH, so the
            # subprocess sees the same environment and installed packages.
            subprocess.run([
                sys.executable, "world.py",
                "--no-reverb",
                "--dataset", "ljspeech",
                "--n", "1"
            ], check=True)

            # 2. GPU processing
            run_beamformers_gpu()

            # 3./4. Evaluation (MVDR, then SMVB)
            sir = {}
            pesq = {}
            for name, wav_path in outputs:
                s_est, s_tgt, s_int, _, _ = load_and_align_signals(
                    wav_path, SAVE_DIR)
                _, sir_value, _, _, _ = calculate_osnr_and_osir(
                    s_est, s_tgt, s_int)
                pesq_value, _ = calculate_pesq_metric(s_tgt, s_est, FS)
                sir[name] = sir_value
                pesq[name] = pesq_value

            # 5. Log.  Flush after every row so completed samples survive a
            # crash or interruption later in the run.
            writer.writerow([i, sir["mvdr"], sir["smvb"],
                             pesq["mvdr"], pesq["smvb"]])
            f.flush()
            print(f"Logged sample {i}")

    print("\nBulk comparison complete.")
    print(f"Results saved to {CSV_PATH}")


# Run the experiment only when executed as a script, not when imported.
if __name__ == "__main__":
    main()


ModuleNotFoundError: No module named 'eval'

In [1]:
!pip install soundfile torch torchaudio

