In [None]:
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import re, io, base64
import numpy as np
import librosa
import librosa.display as lbd
import matplotlib.pyplot as plt
from scipy.signal import ellip, filtfilt
import ipywidgets as W
from IPython.display import HTML, display
from pathlib import Path
import soundfile as sf

# ====== Basic configuration ======
# Directory scanned for categories: the resolved current working directory.
ROOT = Path().resolve()
# Sub-directory names looked up inside each category for generated audio.
MODELS = [
    "TAW-foley",
    "TAW-foley-small",
    "T-foley",
    "Mamba-foley",
]
# Fixed pixel width shared by every figure and every <audio> element.
FIXED_WIDTH_PX = 1000

# ====== STFT / RMS parameters ======
N_FFT = 256
HOP_LENGTH = 128
WIN_LENGTH = 256
WINDOW = "hann"

RMS_FRAME = 512
RMS_HOP = 128

# Upper frequency-axis limit and tick spacing (Hz) for spectrogram plots.
hz_limit = 8000
hz_step = 2000

# ---------- Utility functions ----------
def zero_phased_filter(x):
    """Smooth ``x`` with a forward-backward (zero-phase) elliptic filter.

    The filter is a 4th-order elliptic design with 0.01 dB passband ripple,
    120 dB stopband attenuation and a normalized critical frequency of 0.125
    (SciPy's default band type, i.e. low-pass). ``filtfilt`` is run with
    Gustafsson's method for the edge handling.

    Args:
        x: 1-D array-like signal to smooth.

    Returns:
        The filtered signal, same shape as ``x``.
    """
    numerator, denominator = ellip(N=4, rp=0.01, rs=120, Wn=0.125)
    smoothed = filtfilt(numerator, denominator, x, method="gust")
    return smoothed

def get_rms(y, sr):
    """Compute a smoothed RMS envelope of ``y`` and its time axis.

    Args:
        y: Audio samples.
        sr: Sampling rate of ``y`` in Hz, used to convert frames to seconds.

    Returns:
        ``(t, rms)``: frame times in seconds and the RMS values after
        zero-phase smoothing, both with one entry per RMS frame.
    """
    frame_rms = librosa.feature.rms(
        y=y, frame_length=RMS_FRAME, hop_length=RMS_HOP
    )
    # librosa returns a 2-D array; row 0 holds the RMS curve.
    envelope = zero_phased_filter(frame_rms[0])
    frame_ids = np.arange(len(envelope))
    times = librosa.frames_to_time(frame_ids, sr=sr, hop_length=RMS_HOP)
    return times, envelope

def load_audio(path, sr_target=None):
    """Load an audio file as mono and peak-normalize it.

    Args:
        path: File to load.
        sr_target: Sampling rate passed to ``librosa.load``; ``None`` keeps
            the file's native rate.

    Returns:
        ``(y, sr)``: the samples divided by their peak absolute value, and
        the sampling rate reported by librosa.
    """
    samples, rate = librosa.load(path, sr=sr_target, mono=True)
    # The tiny epsilon keeps the division finite for an all-zero signal.
    peak_level = np.max(np.abs(samples)) + 1e-12
    normalized = samples / peak_level
    return normalized, rate

def pad_or_trim_to_length(y, length):
    """Return ``y`` forced to exactly ``length`` samples.

    Args:
        y: 1-D NumPy array.
        length: Target number of samples.

    Returns:
        ``y`` itself when it already has ``length`` samples, a leading slice
        when it is longer, or a new zero-padded array of the same dtype when
        it is shorter.
    """
    current = len(y)
    if current >= length:
        return y if current == length else y[:length]

    # Too short: copy into a zero-filled buffer so the tail is silence.
    padded = np.zeros(length, dtype=y.dtype)
    padded[:current] = y
    return padded

def compute_spec_db(y, sr):
    """Compute a dB-scaled STFT magnitude spectrogram with its axes.

    Args:
        y: Audio samples.
        sr: Sampling rate in Hz, used for the time and frequency axes.

    Returns:
        ``(S_db, t, f)``: the spectrogram in dB relative to its own maximum
        (limited to an 80 dB range), the frame times in seconds, and the FFT
        bin frequencies in Hz.
    """
    stft = librosa.stft(
        y,
        n_fft=N_FFT,
        hop_length=HOP_LENGTH,
        win_length=WIN_LENGTH,
        window=WINDOW,
        center=True,
    )
    magnitude = np.abs(stft)
    spec_db = librosa.amplitude_to_db(magnitude, ref=np.max, top_db=80)

    n_frames = spec_db.shape[1]
    times = librosa.frames_to_time(
        np.arange(n_frames), sr=sr, hop_length=HOP_LENGTH
    )
    freqs = librosa.fft_frequencies(sr=sr, n_fft=N_FFT)
    return spec_db, times, freqs

# Fixed-pixel-width image + an <audio> element of the same width
def _audio_html(y, sr, width_px):
    """Wrap audio samples in an inline HTML ``<audio>`` player.

    The samples are encoded as WAV in memory and embedded as a base64 data
    URI, so the player needs no file on disk.

    Args:
        y: Audio samples to encode.
        sr: Sampling rate written to the WAV data.
        width_px: CSS width of the player in pixels.

    Returns:
        An ``IPython.display.HTML`` object containing the player markup.
    """
    wav_buffer = io.BytesIO()
    sf.write(wav_buffer, y, sr, format='WAV')
    payload = base64.b64encode(wav_buffer.getvalue()).decode('ascii')
    markup = (
        f'<audio controls style="width:{width_px}px; max-width:none; outline:none;">'
        f'<source src="data:audio/wav;base64,{payload}" type="audio/wav"></audio>'
    )
    return HTML(markup)

def _show_fig_fixed(fig, width_px):
    """Render ``fig`` to PNG and display it at a fixed CSS pixel width.

    The figure is resized so its canvas is ``width_px`` pixels wide at its own
    DPI (height unchanged), saved to an in-memory PNG, closed, and shown as an
    inline ``<img>`` whose CSS width is pinned to ``width_px``.

    Args:
        fig: Matplotlib figure to render. It is always closed by this call,
            even if rendering fails.
        width_px: Target display width in pixels.

    Raises:
        Whatever ``fig.savefig`` raises; the figure is still closed first.
    """
    buf = io.BytesIO()
    try:
        dpi = fig.dpi
        height_in = fig.get_size_inches()[1]
        fig.set_size_inches(width_px / dpi, height_in)
        # bbox_inches="tight" may crop the canvas, so the PNG itself is not
        # guaranteed to be exactly width_px wide; the CSS width below is what
        # fixes the displayed size.
        fig.savefig(buf, format="png", dpi=dpi, bbox_inches="tight")
    finally:
        # Close unconditionally so a failed render does not leave the figure
        # registered with pyplot.
        plt.close(fig)

    b64 = base64.b64encode(buf.getvalue()).decode("ascii")
    display(HTML(f'<img src="data:image/png;base64,{b64}" '
                 f'style="width:{width_px}px; max-width:none; display:block;">'))

# ---------- Row renderers (figure + audio) ----------
def row_rms_with_audio(y_ref, sr_ref, t_rms, rms_ref, t_spec):
    """Build an Output widget with the reference RMS curve and its audio player.

    Args:
        y_ref: Reference audio samples (used for the player).
        sr_ref: Sampling rate of ``y_ref``.
        t_rms: Time axis of the RMS curve, in seconds.
        rms_ref: RMS values to plot against ``t_rms``.
        t_spec: Spectrogram time axis; its first and last entries set the
            x-limits so this row lines up with the spectrogram rows.

    Returns:
        An ``ipywidgets.Output`` containing the figure and the player.
    """
    # One tick per whole second, derived from the duration.
    last_second = int(np.ceil(t_spec[-1]))
    second_ticks = np.arange(0, last_second + 1, 1)

    panel = W.Output()
    with panel:
        fig, ax = plt.subplots(1, 1, figsize=(14, 2.0), constrained_layout=True)
        ax.plot(t_rms, rms_ref, linewidth=1.5, color="black")
        ax.set_title("Reference: RMS Energy", fontsize=12)
        ax.set_xlabel("Time (s)")
        ax.set_ylabel("RMS")
        ax.set_xlim(t_spec[0], t_spec[-1])
        ax.set_xticks(second_ticks)
        _show_fig_fixed(fig, FIXED_WIDTH_PX)
        display(_audio_html(y_ref, sr_ref, FIXED_WIDTH_PX))
    return panel

def row_spec_with_audio(title, Sdb, y, sr_ref, t_spec, f_spec, vmin, vmax):
    """Build an Output widget with one spectrogram and its audio player.

    Args:
        title: Title drawn above the spectrogram.
        Sdb: Spectrogram in dB to draw.
        y: Audio samples for the player.
        sr_ref: Sampling rate used to encode ``y``.
        t_spec: Time coordinates of the spectrogram columns, in seconds.
        f_spec: Frequency coordinates of the spectrogram rows, in Hz.
        vmin: Lower bound of the colour scale.
        vmax: Upper bound of the colour scale.

    Returns:
        An ``ipywidgets.Output`` containing the figure and the player.
    """
    last_second = int(np.ceil(t_spec[-1]))
    second_ticks = np.arange(0, last_second + 1, 1)
    freq_ticks = np.arange(0, hz_limit + hz_step, hz_step)

    panel = W.Output()
    with panel:
        fig, ax = plt.subplots(1, 1, figsize=(14, 2.6), constrained_layout=True)
        lbd.specshow(
            Sdb,
            x_coords=t_spec,
            y_coords=f_spec,
            ax=ax,
            cmap="magma",
            vmin=vmin,
            vmax=vmax,
        )
        # Frequency axis: clipped to hz_limit with evenly spaced ticks.
        ax.set_ylabel("Freq (Hz)")
        ax.set_ylim(0, hz_limit)
        ax.set_yticks(freq_ticks)
        # Time axis: full duration with one tick per second.
        ax.set_xlabel("Time (s)")
        ax.set_xlim(t_spec[0], t_spec[-1])
        ax.set_xticks(second_ticks)
        ax.set_title(title, fontsize=12)
        _show_fig_fixed(fig, FIXED_WIDTH_PX)
        display(_audio_html(y, sr_ref, FIXED_WIDTH_PX))
    return panel

# ---------- Build the index (category -> sample -> paths) ----------
def build_index(root: Path, models=None):
    """Scan ``root`` for categories and index their reference/generated files.

    A category is any direct sub-directory of ``root`` that contains a ``ref``
    entry. Every ``ref/*.wav`` file becomes one sample; for each model the
    matching generated file is looked up at ``<category>/<model>/<gen_name>``,
    where ``gen_name`` is the reference file name with a trailing
    ``_ref.wav`` swapped for ``_gen.wav`` (or the unchanged name when it has
    no such suffix).

    Args:
        root: Directory whose sub-directories are scanned.
        models: Optional iterable of model sub-directory names. Defaults to
            the module-level ``MODELS`` list.

    Returns:
        ``dict[category] -> {sample_key -> {"ref": Path, model_name: Path or None}}``.
        Categories without any reference ``.wav`` file are omitted.
    """
    model_names = MODELS if models is None else list(models)
    ref_stem_suffix = "_ref"
    ref_name_suffix = "_ref.wav"
    idx = {}

    def collect_category(cat_dir: Path, cat_name: str):
        rec = {}
        ref_dir = cat_dir / "ref"
        if not ref_dir.exists():
            return
        for rf in sorted(ref_dir.glob("*.wav")):
            # Sample key: prefer the sample_xxx_labelY_ref.wav pattern;
            # otherwise strip only a *trailing* "_ref" from the stem, so names
            # that merely contain "_ref" (e.g. "my_refined_ref") stay intact.
            m = re.search(r"(sample_\d+_label\d+)_ref\.wav$", rf.name)
            if m:
                key = m.group(1)
            elif rf.stem.endswith(ref_stem_suffix):
                key = rf.stem[:-len(ref_stem_suffix)]
            else:
                key = rf.stem
            rec[key] = {"ref": rf}

            # Same suffix-only rule for the generated file's name.
            if rf.name.endswith(ref_name_suffix):
                gen_name = rf.name[:-len(ref_name_suffix)] + "_gen.wav"
            else:
                gen_name = rf.name
            for model in model_names:
                p = cat_dir / model / gen_name
                rec[key][model] = p if p.exists() else None
        if rec:
            idx[cat_name] = rec

    # Each category lives in its own sub-directory.
    for sub in sorted(d for d in root.iterdir() if d.is_dir()):
        if (sub / "ref").exists():
            collect_category(sub, sub.name)

    return idx

# Scan ROOT once when the cell runs; the widgets below are populated from
# this index, so stop immediately if no category was found.
INDEX = build_index(ROOT)
if not INDEX:
    raise RuntimeError(f"在 {ROOT} 下没有发现有效的 ref/ 目录结构。请检查目录组织是否正确。")

# ---------- Interactive controls ----------
# Category dropdown, multi-select list of samples, render button, and the
# output area the rendered accordion is written into.
dd_cat   = W.Dropdown(options=sorted(INDEX.keys()), description="类别:", layout=W.Layout(width="300px"))
sel_samp = W.SelectMultiple(options=[], description="样本:", rows=8, layout=W.Layout(width="300px"))
btn      = W.Button(description="Render", button_style="primary")
out_area = W.Output()

def on_cat_change(change):
    """Refresh the sample list when the selected category changes.

    Args:
        change: Mapping with the newly selected category under ``"new"``
            (the shape of an ipywidgets trait-change notification).
    """
    category = change["new"]
    if category not in INDEX:
        return
    sample_keys = sorted(INDEX[category])
    sel_samp.options = sample_keys
    # Pre-select only the first three so a render is not huge by default.
    sel_samp.value = tuple(sample_keys[:3])


dd_cat.observe(on_cat_change, names="value")
# Populate the sample list for the initially selected category.
on_cat_change({"new": dd_cat.value})

def render_samples(_):
    """Render every selected sample of the current category into ``out_area``.

    For each sample this loads the reference and the per-model generated
    audio, aligns all clips to the reference length, and builds one accordion
    panel containing: the reference RMS curve, the reference spectrogram, and
    one spectrogram (or a "missing" note) per model, each with an audio
    player. All spectrograms of a sample share one colour scale.

    Args:
        _: The clicked button (unused); present because this is a button
            click callback.
    """
    out_area.clear_output(wait=True)
    cat = dd_cat.value
    samples = list(sel_samp.value)
    if not samples:
        with out_area:
            display(HTML("<b>请选择至少一个样本。</b>"))
        return

    # Display names, in the same order as MODELS.
    titles = ["TAW-Foley", "TAW-Foley-small", "T-Foley", "Mamba-Foley"]

    accord_items = []
    # Keys that actually produced a panel. Titles are assigned from this list
    # rather than from `samples`, so a skipped sample cannot shift the titles
    # onto the wrong panels.
    rendered_keys = []
    for key in samples:
        paths = INDEX[cat][key]
        if not paths.get("ref"):  # no reference audio: skip this sample
            continue

        # ---- Load & align lengths ----
        y_ref, sr_ref = load_audio(paths["ref"], sr_target=None)
        ys = []
        for model in MODELS:
            p = paths.get(model)
            if p and p.exists():
                # Resample generated audio to the reference rate.
                y, _sr = load_audio(p, sr_target=sr_ref)
                ys.append(y)
            else:
                ys.append(None)
        ref_len = len(y_ref)
        ys = [pad_or_trim_to_length(y, ref_len) if y is not None else None for y in ys]

        # ---- Features and spectrograms ----
        t_rms, rms_ref = get_rms(y_ref, sr_ref)
        Sdb_ref, t_spec, f_spec = compute_spec_db(y_ref, sr_ref)
        Sdbs = [compute_spec_db(y, sr_ref)[0] if y is not None else None for y in ys]

        # Shared colour scale (only spectrograms that exist take part).
        vmax = np.max([np.max(Sdb_ref)] + [np.max(S) for S in Sdbs if S is not None])
        vmin = vmax - 80.0

        # ---- Output rows: RMS + reference + one per model ----
        rows = [
            row_rms_with_audio(y_ref, sr_ref, t_rms, rms_ref, t_spec),
            row_spec_with_audio("Reference: STFT Spectrogram", Sdb_ref, y_ref, sr_ref,
                                t_spec, f_spec, vmin, vmax),
        ]
        for title, Sdb, y in zip(titles, Sdbs, ys):
            if Sdb is None or y is None:
                miss = W.Output()
                with miss:
                    display(HTML(f"<i>{title}: 文件缺失，已跳过。</i>"))
                rows.append(miss)
            else:
                rows.append(row_spec_with_audio(f"{title}: STFT Spectrogram", Sdb, y, sr_ref,
                                                t_spec, f_spec, vmin, vmax))

        accord_items.append(W.VBox(rows, layout=W.Layout(border="1px solid #eee", padding="6px")))
        rendered_keys.append(key)

    acc = W.Accordion(children=accord_items)
    for i, key in enumerate(rendered_keys):
        acc.set_title(i, f"{cat} · {key}")

    with out_area:
        display(acc)

# Hook the render callback to the button, then show the control column
# followed by the output area that render_samples writes into.
btn.on_click(render_samples)

controls = W.HBox([W.VBox([dd_cat, sel_samp, btn])])
display(controls, out_area)


HBox(children=(VBox(children=(Dropdown(description='类别:', layout=Layout(width='300px'), options=('Label0', 'La…

Output()