In [None]:
from __future__ import annotations

import argparse
import json
import math
from dataclasses import dataclass
from typing import Iterable, Optional
from pathlib import Path

import numpy as np

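# Third-party dependencies used by this notebook:
# - soundfile (audio I/O)
# - scipy (signal processing)
# - pretty_midi (MIDI writing)
# - matplotlib (optional; the plotting helpers do nothing when it is missing)
# All but matplotlib must be available in your environment.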

### Parameters

In [None]:
from dataclasses import dataclass
from math import ceil


@dataclass(slots=True)
class Params:
    """Global parameters for the Sing→MIDI pipeline.

    All durations and sizes are validated; any derived values are computed on init.
    """

    # Audio
    sr: int = 48_000
    frame: int = 2048
    hop: int = 480

    # Conditioning
    hpf_cutoff: float = 90.0

    # Smoothing
    median_k: int = 7
    ma_k: int = 3

    # Segmentation / hysteresis
    min_note_ms: int = 70
    debounce_frames: int = 3
    cents_tolerance: float = 35.0

    # Key snapping
    use_key_snap: bool = True
    key_snap_cents: float = 40.0

    # Onsets
    onset_prom: float = 1.5

    # Derived
    min_note_frames: int = 0  # computed

    # YIN band (Hz)
    fmin_hz: float = 80.0
    fmax_hz: float = 1000.0

    # Voicing thresholds
    min_confidence: float = 0.2
    min_rms_dbfs: float = -50.0

    def __post_init__(self) -> None:
        if self.sr <= 0:
            raise ValueError("sr must be positive")
        if self.frame <= 0 or self.hop <= 0:
            raise ValueError("frame and hop must be positive")
        if self.median_k < 1 or self.median_k % 2 == 0:
            raise ValueError("median_k must be odd and >= 1")
        if self.ma_k < 1:
            raise ValueError("ma_k must be >= 1")
        if self.debounce_frames < 1:
            raise ValueError("debounce_frames must be >= 1")
        if not (0 < self.fmin_hz < self.fmax_hz):
            raise ValueError("fmin_hz must be < fmax_hz and > 0")
        # Derived
        self.min_note_frames = max(1, ceil((self.min_note_ms / 1000.0) * (self.sr / self.hop)))

## Setup

### Utilities: windows, stats, filters, smoothing, guards, peaks, timescale

In [None]:
def hann_window(frame: int) -> np.ndarray:
    """Build a symmetric Hann window with `frame` samples.

    The window comes straight from `np.hanning`, i.e. the symmetric
    (non-periodic) variant.

    Raises:
        ValueError: if `frame` is zero or negative.
    """
    if frame <= 0:
        raise ValueError("frame must be positive")
    window = np.hanning(frame)
    return window


def frame_audio(x: np.ndarray, frame: int, hop: int) -> np.ndarray:
    """Cut mono audio into overlapping frames of shape (n_frames, frame).

    Consecutive frames start `hop` samples apart. The signal is zero-padded
    at the end so the final frame is full length. Empty input yields an
    array with zero rows. The result is an independent copy.

    Raises:
        ValueError: if `x` is not one-dimensional.
    """
    if x.ndim != 1:
        raise ValueError("x must be 1D mono audio")
    n_samples = len(x)
    if n_samples == 0:
        return np.zeros((0, frame), dtype=x.dtype)

    if n_samples > frame:
        n_frames = 1 + int(np.ceil((n_samples - frame) / hop))
    else:
        n_frames = 1

    # Samples needed so that the last frame is complete.
    needed = (n_frames - 1) * hop + frame
    missing = needed - n_samples
    if missing > 0:
        x = np.pad(x, (0, missing))

    # Every length-`frame` window, then keep one window per hop.
    windows = np.lib.stride_tricks.sliding_window_view(x, frame)
    return windows[::hop].copy()


def frame_rms(frames: np.ndarray) -> np.ndarray:
    """Root-mean-square of each row of a (n_frames, frame_len) array.

    The computation is carried out in float64.
    """
    as_double = frames.astype(np.float64)
    mean_square = np.mean(as_double ** 2, axis=1)
    return np.sqrt(mean_square)


def dbfs(val: float, eps: float = 1e-12) -> float:
    """Express a linear amplitude in decibels (20·log10).

    Values below `eps` are clamped to `eps` so the logarithm stays finite.
    """
    clamped = max(val, eps)
    return 20.0 * np.log10(clamped)


def dc_blocker(x: np.ndarray, R: float = 0.995) -> np.ndarray:
    """Simple first-order DC blocker.

    y[n] = x[n] - x[n-1] + R * y[n-1], with x[-1] = y[-1] = 0.

    The recursion is evaluated by `scipy.signal.lfilter` in float64 instead
    of a per-sample Python loop. Floating-point input keeps its dtype;
    integer input is returned as float64 (accumulating into an integer
    buffer would truncate the `R * y[n-1]` feedback term on every sample).

    Args:
        x: 1D signal.
        R: feedback coefficient of the recursion.

    Returns:
        Filtered signal with the same length as `x`.
    """
    x = np.asarray(x)
    out_dtype = x.dtype if np.issubdtype(x.dtype, np.floating) else np.float64
    if x.size == 0:
        return x.astype(out_dtype)
    from scipy.signal import lfilter  # local import to keep top-level light
    # Transfer function (1 - z^-1) / (1 - R z^-1) of the difference equation.
    y = lfilter([1.0, -1.0], [1.0, -float(R)], x.astype(np.float64))
    return y.astype(out_dtype, copy=False)


def highpass_butter(x: np.ndarray, sr: int, cutoff_hz: float, order: int = 2) -> np.ndarray:
    """High-pass `x` with a Butterworth filter applied forwards and backwards.

    - A `cutoff_hz` of zero or below disables the filter and returns a copy of `x`.
    - The cutoff is normalized by the Nyquist frequency (`sr / 2`) and the
      filter is run through `filtfilt` along the last axis.
    """
    if cutoff_hz <= 0:
        return x.copy()
    from scipy.signal import butter, filtfilt  # local import to keep top-level light
    normalized_cutoff = cutoff_hz / (0.5 * sr)
    numerator, denominator = butter(order, normalized_cutoff, btype="highpass")
    return filtfilt(numerator, denominator, x, axis=-1)


def median_filter(x: np.ndarray, k: int) -> np.ndarray:
    """Run a 1D median filter with kernel size `k` over `x`.

    A kernel of 1 or less is a no-op that returns a copy of `x`. An even `k`
    is bumped to the next odd number before filtering with
    `scipy.signal.medfilt`.
    """
    if k <= 1:
        return x.copy()
    from scipy.signal import medfilt  # local import
    kernel = k + 1 if k % 2 == 0 else k
    return medfilt(x, kernel_size=kernel)


def moving_average(x: np.ndarray, k: int) -> np.ndarray:
    """Centered moving average of width `k`; output has the length of `x`.

    Samples outside the signal are treated as zero, so values within
    `k // 2` samples of either end are attenuated.

    Args:
        x: 1D signal.
        k: window width in samples; `k <= 1` returns a copy of `x`.

    Returns:
        Smoothed signal with exactly `len(x)` samples. Empty input is
        returned as an empty copy.
    """
    k = int(k)
    n = len(x)
    if k <= 1 or n == 0:
        return x.copy()
    w = np.ones(k) / float(k)
    if n >= k:
        return np.convolve(x, w, mode="same")
    # For signals shorter than the window, mode="same" would return
    # max(len(x), k) samples. Take the centered, x-aligned part of the full
    # convolution instead so the output length always matches the input.
    full = np.convolve(x, w, mode="full")
    start = (k - 1) // 2
    return full[start:start + n]


def band_clamp(x: np.ndarray, lo: float, hi: float) -> np.ndarray:
    """Return a copy of `x` with every value outside [lo, hi] set to NaN.

    Values inside the band (bounds included) are left untouched.
    """
    clamped = x.copy()
    out_of_band = np.logical_or(clamped < lo, clamped > hi)
    clamped[out_of_band] = np.nan
    return clamped


def doubling_halving_guard(f0: np.ndarray, window: int = 5, cents_tol: float = 35.0) -> np.ndarray:
    """Undo isolated octave errors (x2 / x0.5) using the local median pitch.

    Each finite, positive sample is compared with the median of the valid
    samples within `window` frames on either side. A sample close to twice
    (ratio 1.8–2.2) or half (ratio 0.45–0.55) that median is shifted by an
    octave, provided the shifted value lands within `cents_tol` cents of the
    median and is closer to it than before. Corrections are written into the
    working copy as the scan proceeds, so later medians see earlier fixes.
    """

    def _abs_cents(a: float, b: float) -> float:
        return abs(1200.0 * math.log2(a / b))

    out = f0.copy()
    total = len(out)
    for pos in range(total):
        value = out[pos]
        if not (np.isfinite(value) and value > 0):
            continue
        neighbours = out[max(0, pos - window):min(total, pos + window + 1)]
        neighbours = neighbours[np.isfinite(neighbours) & (neighbours > 0)]
        if neighbours.size < 3:
            continue  # not enough context for a meaningful median
        med = float(np.median(neighbours))
        if med <= 0:
            continue
        ratio = value / med
        if 1.8 <= ratio <= 2.2:
            shifted = value / 2.0
        elif 0.45 <= ratio <= 0.55:
            shifted = value * 2.0
        else:
            continue
        before = _abs_cents(value, med)
        after = _abs_cents(shifted, med)
        if after + 1e-6 < before and after <= cents_tol:
            out[pos] = shifted
    return out


def find_peaks(signal: np.ndarray, min_distance: int = 1, threshold: float = 0.0) -> np.ndarray:
    """Pick local maxima at or above `threshold`, at least `min_distance` apart.

    A sample is a candidate when it is strictly greater than its left
    neighbour and not smaller than its right neighbour. Candidates are then
    accepted greedily from the highest downwards, skipping any that lie
    closer than `min_distance` to an already accepted peak. Returns the
    accepted indices in ascending order.
    """
    if len(signal) < 3:
        return np.array([], dtype=int)
    interior = signal[1:-1]
    is_local_max = (interior > signal[:-2]) & (interior >= signal[2:])
    cand = np.flatnonzero(is_local_max) + 1
    cand = cand[signal[cand] >= threshold]
    if cand.size == 0:
        return cand
    by_height = cand[np.argsort(signal[cand])[::-1]]
    kept: list[int] = []
    for idx in by_height:
        crowded = any(abs(idx - other) < min_distance for other in kept)
        if not crowded:
            kept.append(int(idx))
    return np.asarray(sorted(kept), dtype=int)


def frame_to_time(frame_idx: int, hop: int, sr: int) -> float:
    """Time in seconds at which frame `frame_idx` starts (`frame_idx * hop / sr`)."""
    sample_offset = frame_idx * hop
    return sample_offset / float(sr)


def time_to_frame(time_s: float, hop: int, sr: int) -> int:
    """Index of the frame nearest to `time_s` seconds (`round(time_s * sr / hop)`)."""
    fractional_frame = time_s * sr / hop
    return int(round(fractional_frame))

### DSP: f0 (YIN), stabilization, onsets, hysteresis

In [None]:
def f0_yin(
    y: np.ndarray,
    sr: int,
    frame: int,
    hop: int,
    fmin_hz: float,
    fmax_hz: float,
) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
    """Frame-wise f0 estimate with a light YIN variant.

    Returns `(f0_hz, conf, rms)`, one value per frame:
    - `f0_hz`: estimated pitch in Hz, NaN where no estimate was made or the
      frame RMS is below 1e-3 of the loudest frame.
    - `conf`: frame RMS divided by the maximum frame RMS, clipped to [0, 1]
      (an energy-based proxy, not derived from the YIN dip depth).
    - `rms`: RMS of each Hann-windowed frame.

    If the lag range implied by `fmin_hz`/`fmax_hz` does not fit into the
    frame, `f0_hz` is all NaN and `conf` all zero.
    """
    windowed = frame_audio(y, frame=frame, hop=hop).astype(np.float64)
    windowed *= hann_window(frame)[None, :]
    n_frames = windowed.shape[0]
    f0 = np.full(n_frames, np.nan, dtype=float)
    rms = frame_rms(windowed.astype(np.float64))

    # Lag search range in samples; the upper bound must stay inside the frame.
    lag_lo = int(np.floor(sr / fmax_hz))
    lag_hi = min(int(np.ceil(sr / fmin_hz)), frame - 2)
    if lag_hi <= lag_lo:
        return f0, np.zeros_like(f0), rms.astype(float)

    for idx in range(n_frames):
        centered = windowed[idx] - np.mean(windowed[idx])
        cmnd = _yin_cmnd(_yin_difference(centered, lag_hi))
        lag = _yin_absolute_threshold(cmnd, 0.1, lag_lo)
        if lag == -1:
            # No dip below the threshold: fall back to the global minimum.
            lag = int(np.argmin(cmnd[lag_lo:lag_hi]) + lag_lo)
        if 1 <= lag < len(cmnd) - 1:
            lag = _parabolic_interpolate(cmnd, lag)
        if lag > 0:
            f0[idx] = sr / lag

    conf = np.zeros_like(f0)
    if rms.size > 0:
        peak = float(np.max(rms)) + 1e-12
        conf = np.clip(rms / peak, 0.0, 1.0)
        f0[rms < (1e-3 * peak)] = np.nan
    return f0.astype(float), conf.astype(float), rms.astype(float)


def _yin_difference(x: np.ndarray, tmax: int) -> np.ndarray:
    N = len(x)
    d = np.zeros(tmax, dtype=np.float64)
    for tau in range(1, tmax):
        diff = x[:-tau] - x[tau:]
        d[tau] = np.dot(diff, diff)
    d[0] = 0.0
    return d


def _yin_cmnd(d: np.ndarray) -> np.ndarray:
    cmnd = np.zeros_like(d)
    cmnd[0] = 1.0
    running_sum = 0.0
    for tau in range(1, len(d)):
        running_sum += d[tau]
        cmnd[tau] = d[tau] * tau / (running_sum + 1e-12)
    return cmnd


def _yin_absolute_threshold(cmnd: np.ndarray, thresh: float, start: int) -> int:
    for tau in range(max(2, start), len(cmnd)):
        if cmnd[tau] < thresh:
            while tau + 1 < len(cmnd) and cmnd[tau + 1] < cmnd[tau]:
                tau += 1
            return tau
    return -1


def _parabolic_interpolate(y: np.ndarray, x: int) -> float:
    xm1, x0, xp1 = y[x - 1], y[x], y[x + 1]
    denom = (xm1 - 2 * x0 + xp1)
    if abs(denom) < 1e-12:
        return float(x)
    delta = 0.5 * (xm1 - xp1) / denom
    return float(x) + delta


def forward_fill(x: np.ndarray) -> np.ndarray:
    """Replace non-finite samples with the nearest earlier finite sample.

    Leading non-finite samples, which have no earlier value, take the first
    finite sample instead. If `x` has no finite sample at all, every entry
    of the result is NaN. Empty input is returned as an empty copy.

    Args:
        x: 1D array.

    Returns:
        A new array with the same length and dtype as `x`.
    """
    out = x.copy()
    n = out.size
    if n == 0:
        return out
    finite = np.isfinite(out)
    if not finite.any():
        out[:] = np.nan
        return out
    # For each position, index of the most recent finite sample (-1 if none yet).
    src = np.maximum.accumulate(np.where(finite, np.arange(n), -1))
    # Leading gap: borrow the first finite sample.
    src[src < 0] = int(np.argmax(finite))
    return out[src]


def stabilize_f0(
    f0_hz: np.ndarray,
    fmin_hz: float,
    fmax_hz: float,
    median_k: int,
    ma_k: int,
    guard_window: int = 5,
    cents_tol: float = 35.0,
) -> np.ndarray:
    """Clean up a raw f0 track: band-limit, smooth, and fix octave jumps.

    Steps, in order:
    1. Values outside [fmin_hz, fmax_hz] become NaN.
    2. Median filter of size `median_k`, with NaN treated as 0 during
       filtering; non-positive results become NaN again.
    3. If `ma_k > 1`, a moving average over the gap-filled track, applied
       only where step 2 left a finite value.
    4. Octave doubling/halving correction against the local median.
    """
    banded = band_clamp(f0_hz, fmin_hz, fmax_hz)
    track = median_filter(np.nan_to_num(banded, nan=0.0), k=median_k)
    track[track <= 0] = np.nan
    if ma_k > 1:
        averaged = moving_average(forward_fill(track), k=ma_k)
        # Keep unvoiced (NaN) frames unvoiced; only smooth existing values.
        track = np.where(np.isfinite(track), averaged, track)
    return doubling_halving_guard(track, window=guard_window, cents_tol=cents_tol)


def spectral_flux_onsets(
    y: np.ndarray,
    sr: int,
    frame: int,
    hop: int,
    prom: float = 1.5,
    min_distance_frames: int = 2,
) -> np.ndarray:
    """Detect onsets as peaks of the positive spectral flux.

    The signal is framed and Hann-windowed, the magnitude spectrum is taken
    with an FFT size rounded up to a power of two, and the flux is the sum of
    positive bin-wise magnitude increases between consecutive frames. Peaks
    at or above `prom` times the median flux, at least `min_distance_frames`
    apart, are returned as frame indices. `sr` is accepted but not used.
    """
    windowed = frame_audio(y, frame=frame, hop=hop).astype(float)
    windowed *= hann_window(frame)[None, :]
    fft_size = int(2 ** np.ceil(np.log2(frame)))
    magnitude = np.abs(np.fft.rfft(windowed, n=fft_size))
    rise = np.maximum(0.0, magnitude[1:, :] - magnitude[:-1, :])
    flux = rise.sum(axis=1)
    if flux.size == 0:
        return np.array([], dtype=int)
    cutoff = np.median(flux) * prom
    peak_idx = find_peaks(flux, min_distance=min_distance_frames, threshold=cutoff)
    # flux[i] compares frame i+1 with frame i, so shift back to frame indices.
    return peak_idx + 1


def hysteresis_round(
    midi_float: np.ndarray,
    debounce_frames: int,
    cents_tolerance: float,
) -> np.ndarray:
    """Quantize a fractional-MIDI track to integer notes with debouncing.

    A note is only adopted (initially, or as a replacement for the current
    note) after `debounce_frames` consecutive qualifying frames round to the
    same integer. A frame qualifies when it lies within `cents_tolerance`
    cents of its nearest integer note.

    Args:
        midi_float: per-frame fractional MIDI numbers; non-finite = unvoiced.
        debounce_frames: streak length required to adopt a new note.
        cents_tolerance: maximum distance (cents) from an integer note.

    Returns:
        Float array of the same length holding the current integer note per
        frame, or NaN for non-finite input frames and for frames before the
        first note has been adopted.
    """
    out = np.full_like(midi_float, np.nan, dtype=float)
    # `current` is the note currently held; `streak_*` track a challenger.
    current: Optional[int] = None
    streak_note: Optional[int] = None
    streak_len = 0

    def cents_err(nf: float, ni: int) -> float:
        # One semitone is 100 cents.
        return 100.0 * abs(nf - float(ni))

    for i, nf in enumerate(midi_float):
        if not np.isfinite(nf):
            # Unvoiced frame: emit NaN and drop any pending challenger.
            # `current` is kept, so the held note resumes after the gap.
            out[i] = np.nan
            streak_note = None
            streak_len = 0
            continue
        cand = int(np.round(nf))
        if current is None:
            # No note adopted yet: build up the first streak. A frame outside
            # the tolerance neither extends nor resets the streak here.
            if cents_err(nf, cand) <= cents_tolerance:
                streak_note = cand if streak_note is None else streak_note
                if streak_note == cand:
                    streak_len += 1
                else:
                    streak_note = cand
                    streak_len = 1
                if streak_len >= debounce_frames:
                    current = cand
            out[i] = np.nan if current is None else float(current)
            continue

        # Frame agrees with the held note: keep it and cancel any challenger.
        if cand == current or cents_err(nf, current) <= cents_tolerance:
            streak_note = None
            streak_len = 0
            out[i] = float(current)
            continue

        # Frame points at a different note: count it towards a switch only if
        # it is close enough to that note; otherwise reset the challenger.
        if cents_err(nf, cand) <= cents_tolerance:
            if streak_note == cand:
                streak_len += 1
            else:
                streak_note = cand
                streak_len = 1
            if streak_len >= debounce_frames:
                current = cand
                streak_note = None
                streak_len = 0
        else:
            streak_note = None
            streak_len = 0
        # Emits the new note on the frame where the switch happens.
        out[i] = float(current)
    return out


def segments_from_notes(
    note_series: np.ndarray,
    min_note_frames: int,
) -> list[tuple[int, int, int]]:
    """Turn a per-frame note track into (midi, start, end) frame segments.

    Args:
        note_series: per-frame note numbers; non-finite entries are gaps.
        min_note_frames: segments shorter than this are merged or dropped.

    Returns:
        List of `(midi, start_frame, end_frame)` with an exclusive end index.
    """
    n = len(note_series)
    segments: list[tuple[int, int, int]] = []
    i = 0
    # Pass 1: collect maximal runs of consecutive finite frames that share
    # the same integer note value.
    while i < n:
        if not np.isfinite(note_series[i]):
            i += 1
            continue
        midi = int(note_series[i])
        j = i + 1
        while j < n and np.isfinite(note_series[j]) and int(note_series[j]) == midi:
            j += 1
        start, end = i, j
        segments.append((midi, start, end))
        i = j

    # Pass 2: a too-short run with the same pitch as the previously kept
    # entry is absorbed into it by extending that entry's end.
    # NOTE(review): there is no adjacency check, so the extension also spans
    # any non-finite gap between the two runs — confirm this is intended.
    merged: list[tuple[int, int, int]] = []
    for seg in segments:
        if not merged:
            merged.append(seg)
            continue
        midi, s, e = seg
        prev_m, ps, pe = merged[-1]
        if e - s < min_note_frames:
            if midi == prev_m:
                merged[-1] = (prev_m, ps, e)
                continue
        merged.append(seg)

    # Pass 3: keep long-enough entries; try to rescue short ones by joining
    # them with a same-pitch neighbour, otherwise drop them.
    final: list[tuple[int, int, int]] = []
    for idx, (m, s, e) in enumerate(merged):
        if e - s >= min_note_frames:
            final.append((m, s, e))
            continue
        attached = False
        if idx > 0 and merged[idx - 1][0] == m:
            # NOTE(review): assumes final[-1] corresponds to merged[idx - 1];
            # that looks true only for the zero-length leftovers created by
            # the branch below — verify before relying on it elsewhere.
            pm, ps, pe = final[-1]
            final[-1] = (pm, ps, e)
            attached = True
        elif idx + 1 < len(merged) and merged[idx + 1][0] == m:
            # Join with the following same-pitch entry and collapse that
            # entry to zero length so it adds nothing when it is visited.
            nm, ns, ne = merged[idx + 1]
            final.append((m, s, ne))
            merged[idx + 1] = (nm, ne, ne)
            attached = True
        if not attached:
            # Short entry with no same-pitch neighbour: dropped.
            pass
    return final

### Music helpers: Hz↔MIDI, basic key detection/snapping

In [None]:
def hz_to_midi(f_hz: np.ndarray | float) -> np.ndarray | float:
    """Map frequency in Hz to a fractional MIDI note number (A4 = 440 Hz = 69).

    Arrays are converted element-wise, with every non-finite result replaced
    by NaN. Scalars that are non-finite or not positive map to NaN.
    """
    if isinstance(f_hz, np.ndarray):
        semitones_from_a4 = 12.0 * np.log2(np.asarray(f_hz) / 440.0)
        midi = 69.0 + semitones_from_a4
        midi[~np.isfinite(midi)] = np.nan
        return midi
    if math.isfinite(f_hz) and f_hz > 0:
        return 69.0 + 12.0 * math.log2(f_hz / 440.0)
    return float("nan")


def midi_to_hz(n: float) -> float:
    """Frequency in Hz of MIDI note number `n` (69 maps to 440 Hz)."""
    octaves_from_a4 = (n - 69.0) / 12.0
    return 440.0 * (2.0 ** octaves_from_a4)


@dataclass(slots=True)
class KeyEstimate:
    """A musical key: tonic pitch class, mode, and the score it was chosen with."""

    tonic: int  # 0..11 (C=0)
    mode: str   # 'major' or 'minor'
    score: float  # profile match score; +inf when built from a manual override


# Pitch-class weight profiles indexed by semitones above the tonic
# (index 0 = tonic). estimate_key rotates them with np.roll for each
# candidate tonic and scores them against the observed histogram.
# NOTE(review): the values look like the Krumhansl–Kessler key profiles —
# confirm the source.
MAJOR_KS = np.array([6.35, 2.23, 3.48, 2.33, 4.38, 4.09, 2.52, 5.19, 2.39, 3.66, 2.29, 2.88])
MINOR_KS = np.array([6.33, 2.68, 3.52, 5.38, 2.60, 3.53, 2.54, 4.75, 3.98, 2.69, 3.34, 3.17])

# Note name -> pitch class (C = 0). Keys are upper-case because lookups
# upper-case the user-supplied name first, so flats are spelled with a
# trailing "B" (e.g. "DB" is D-flat, "BB" is B-flat).
KEY_NAME_TO_INT = {
    "C": 0,
    "C#": 1,
    "DB": 1,
    "D": 2,
    "D#": 3,
    "EB": 3,
    "E": 4,
    "F": 5,
    "F#": 6,
    "GB": 6,
    "G": 7,
    "G#": 8,
    "AB": 8,
    "A": 9,
    "A#": 10,
    "BB": 10,
    "B": 11,
}


def pitch_class_hist(midi_notes: Iterable[int]) -> np.ndarray:
    """Normalized 12-bin histogram of the pitch classes of `midi_notes`.

    Each note counts once in bin `int(note) % 12`. The bins sum to 1 unless
    there are no notes, in which case all bins are zero.
    """
    classes = np.asarray([int(note) % 12 for note in midi_notes], dtype=np.intp)
    counts = np.bincount(classes, minlength=12).astype(float)
    total = counts.sum()
    if total > 0:
        counts /= total
    return counts


def estimate_key(midi_notes: Iterable[int]) -> KeyEstimate | None:
    """Pick the best-matching major or minor key for a set of notes.

    The pitch-class histogram of the notes is scored against each of the 24
    rotated key profiles with a dot product; the highest score wins. On a
    tie the earlier candidate is kept (tonics ascending, major before
    minor). Returns None when there are no notes.
    """
    notes = list(midi_notes)
    if not notes:
        return None
    hist = pitch_class_hist(notes)
    best: KeyEstimate | None = None
    best_score = -1.0
    profiles = (("major", MAJOR_KS), ("minor", MINOR_KS))
    for tonic in range(12):
        for mode, profile in profiles:
            score = float(np.dot(hist, np.roll(profile, tonic)))
            if score > best_score:
                best_score = score
                best = KeyEstimate(tonic=tonic, mode=mode, score=score)
    return best


def snap_to_key(n: int, key: KeyEstimate) -> int:
    """Move MIDI note `n` to the nearest note of the scale of `key`.

    The scale is the major scale for `key.mode == "major"` and the natural
    minor scale otherwise, built on `key.tonic`. Notes already in the scale
    are returned unchanged. Out-of-scale notes move by at most two
    semitones; when the distance up and down is equal, the lower note wins.

    Args:
        n: MIDI note number.
        key: key whose `tonic` (pitch class, C = 0) and `mode` define the scale.

    Returns:
        The snapped MIDI note number.
    """
    # Scale degrees in semitones above the tonic.
    if key.mode == "major":
        scale = {0, 2, 4, 5, 7, 9, 11}
    else:
        scale = {0, 2, 3, 5, 7, 8, 10}
    # Work relative to the tonic; a plain `n % 12` would test every note
    # against the scale built on C regardless of the actual key.
    degree = (n - key.tonic) % 12
    if degree in scale:
        return n
    candidates = [
        (abs(off), n + off)
        for off in range(-2, 3)
        if (degree + off) % 12 in scale
    ]
    if not candidates:
        return n
    # Smallest shift first; ties resolve to the lower note.
    return min(candidates)[1]


def key_override_to_estimate(value: object) -> KeyEstimate | None:
    """Convert a user-supplied key override into a `KeyEstimate`.

    Accepted forms:
    - a `KeyEstimate`, returned as is;
    - a `(tonic_name, mode_name)` tuple, e.g. `("E", "minor")`;
    - a string such as `"C major"` or `"C-major"` (hyphens and tabs count as
      separators; only the first two words are used).

    Mode names `maj`/`major`/`ionian` and `min`/`minor`/`aeolian` are
    recognized, case-insensitively. Returns None for anything that cannot be
    interpreted. Overrides get a score of +inf.
    """
    if isinstance(value, KeyEstimate):
        return value

    tonic_name: str | None = None
    mode_name: str | None = None
    if isinstance(value, tuple) and len(value) == 2:
        tonic_name, mode_name = value
    elif isinstance(value, str):
        words = value.replace("-", " ").replace("\t", " ").split()
        if len(words) >= 2:
            tonic_name, mode_name = words[0], words[1]
    if tonic_name is None or mode_name is None:
        return None

    tonic_key = KEY_NAME_TO_INT.get(tonic_name.upper())
    if tonic_key is None:
        return None

    mode_aliases = {
        "maj": "major",
        "major": "major",
        "ionian": "major",
        "min": "minor",
        "minor": "minor",
        "aeolian": "minor",
    }
    mode = mode_aliases.get(mode_name.strip().lower())
    if mode is None:
        return None
    return KeyEstimate(tonic=tonic_key, mode=mode, score=float("inf"))

### IO: audio load/condition, MIDI writer

In [None]:
from dataclasses import dataclass as _dataclass_for_audio


@_dataclass_for_audio(slots=True)
class Audio:
    """Mono audio samples together with their sample rate."""

    y: np.ndarray  # mono float32
    sr: int  # sample rate in Hz


def load_mono(path: str, sr: int) -> Audio:
    """Read an audio file and return it as mono float32 at sample rate `sr`.

    Channels are averaged into one. If the file's sample rate differs from
    `sr`, the signal is resampled with polyphase filtering using the reduced
    ratio `sr / file_sr`.
    NOTE(review): nothing here clips the result, so resampling may overshoot
    the file's original amplitude range slightly.
    """
    import soundfile as sf
    from scipy.signal import resample_poly
    samples, native_sr = sf.read(path, always_2d=True)
    mono = np.mean(samples.astype(np.float32), axis=1)
    if native_sr != sr:
        from math import gcd
        common = gcd(native_sr, sr)
        mono = resample_poly(mono, sr // common, native_sr // common)
    return Audio(y=mono.astype(np.float32, copy=False), sr=sr)


def save_wav(path: str, audio: Audio) -> None:
    """Write `audio` to `path` as 16-bit PCM using soundfile."""
    import soundfile as sf
    samples, rate = audio.y, audio.sr
    sf.write(path, samples, rate, subtype="PCM_16")


def normalize_rms(y: np.ndarray, target_dbfs: float = -20.0, eps: float = 1e-9) -> np.ndarray:
    """Scale `y` so its overall RMS sits at `target_dbfs` (dB re. full scale 1.0).

    If the scaled signal would exceed ±1.0 it is scaled down again so the
    peak is exactly 1.0 (the RMS then ends up below the target).

    Args:
        y: audio samples.
        target_dbfs: desired RMS level in dB.
        eps: RMS below this counts as silence.

    Returns:
        The normalized signal as float32. Empty or silent input is returned
        as an unmodified copy (keeping its dtype).
    """
    if y.size == 0:
        # Nothing to measure; np.max below would raise on an empty array.
        return y.copy()
    rms = float(np.sqrt(np.mean(np.square(y), dtype=np.float64)))
    if rms < eps:
        return y.copy()
    current_dbfs = 20.0 * np.log10(max(rms, eps))
    gain_db = target_dbfs - current_dbfs
    gain = 10.0 ** (gain_db / 20.0)
    out = y * gain
    max_abs = float(np.max(np.abs(out)))
    if max_abs > 1.0:
        out /= max_abs
    return out.astype(np.float32, copy=False)


def condition(y: np.ndarray, sr: int, hpf_cutoff: float) -> np.ndarray:
    """Pre-condition audio: DC blocker, then a 2nd-order high-pass at `hpf_cutoff` Hz.

    The result is returned as float32.
    """
    without_dc = dc_blocker(y)
    filtered = highpass_butter(without_dc, sr=sr, cutoff_hz=hpf_cutoff, order=2)
    return filtered.astype(np.float32, copy=False)


def velocity_from_rms(rms: float, rms_min: float, rms_max: float) -> int:
    """Map an RMS level linearly onto the MIDI velocity range 40..100.

    `rms_min` maps to 40 and `rms_max` to 100; values outside are clamped
    and the result is truncated to an int. When the range is empty
    (`rms_max <= rms_min`) the midpoint velocity is returned.
    """
    v_lo, v_hi = 40, 100
    if rms_max <= rms_min:
        return int((v_lo + v_hi) // 2)
    position = (rms - rms_min) / (rms_max - rms_min)
    scaled = v_lo + position * (v_hi - v_lo)
    return int(np.clip(scaled, v_lo, v_hi))


def write_midi(
    path: str,
    segments: list[tuple[int, float, float, float]],
    program: int = 0,
    tempo_bpm: float | None = None,
) -> None:
    """Write note segments to a single-instrument MIDI file with pretty_midi.

    Args:
        path: output file path.
        segments: `(midi, start_s, end_s, rms)` tuples.
        program: MIDI program number of the instrument.
        tempo_bpm: initial tempo; 120 BPM is used when it is None or 0.

    Velocities are derived from each segment's RMS relative to the smallest
    and largest RMS among all segments.
    """
    import pretty_midi
    initial_tempo = float(tempo_bpm) if tempo_bpm else 120.0
    midi_file = pretty_midi.PrettyMIDI(initial_tempo=initial_tempo)
    instrument = pretty_midi.Instrument(program=program)
    if segments:
        levels = [seg[3] for seg in segments]
        level_lo, level_hi = min(levels), max(levels)
    else:
        level_lo = level_hi = 0.0
    for pitch, start_s, end_s, level in segments:
        velocity = velocity_from_rms(level, level_lo, level_hi)
        instrument.notes.append(
            pretty_midi.Note(
                start=float(start_s),
                end=float(end_s),
                pitch=int(pitch),
                velocity=int(velocity),
            )
        )
    midi_file.instruments.append(instrument)
    midi_file.write(path)


def voice_mask_from_confidence(
    f0_hz: np.ndarray,
    voiced_conf: np.ndarray | None,
    frame_rms_vals: np.ndarray,
    min_confidence: float,
    min_rms_dbfs: float,
) -> np.ndarray:
    """Boolean per-frame mask of frames considered voiced.

    A frame is voiced when all of the following hold:
    - its RMS, in dB relative to the loudest frame, is at least
      `min_rms_dbfs` (the threshold is relative to the loudest frame, not to
      digital full scale);
    - its confidence is at least `min_confidence` (skipped when
      `voiced_conf` is None);
    - its f0 is finite and positive.

    Args:
        f0_hz: per-frame pitch in Hz (NaN where unknown).
        voiced_conf: per-frame confidence, or None.
        frame_rms_vals: per-frame RMS.
        min_confidence: confidence threshold.
        min_rms_dbfs: level threshold in dB relative to the loudest frame.

    Returns:
        Boolean array with one entry per frame; empty for empty input.
    """
    if np.size(frame_rms_vals) == 0:
        # No frames at all: np.max below would raise on an empty array.
        return np.zeros(np.shape(f0_hz), dtype=bool)
    eps = 1e-12
    ref = max(np.max(frame_rms_vals), eps)
    db = 20.0 * np.log10(np.clip(frame_rms_vals / ref, eps, None))
    rms_ok = db >= min_rms_dbfs
    if voiced_conf is None:
        conf_ok = np.isfinite(f0_hz)
    else:
        conf_ok = voiced_conf >= min_confidence
    return rms_ok & conf_ok & np.isfinite(f0_hz) & (f0_hz > 0)

## Pipeline (full)

### Pipeline: step-by-step functions and a convenience runner

In [None]:
def extract_pipeline_features(y: np.ndarray, params: Params) -> dict:
    """Run the analysis chain on `y` and return every intermediate result.

    Stages: YIN f0 → voicing mask → f0 stabilization → spectral-flux onsets
    → Hz-to-MIDI → hysteresis rounding → frame segments → key estimate
    (only when `params.use_key_snap` is set and segments exist) → per-frame
    RMS of the Hann-windowed frames for velocity mapping.

    Returns a dict with the keys `f0_hz`, `conf`, `frame_rms`,
    `voiced_mask`, `f0_stable`, `onsets`, `midi_float`, `note_series`,
    `seg_idx`, `key` and `rms_frames`.
    """
    # Raw pitch track with its energy-based confidence and frame RMS.
    f0_hz, conf, rms_vals = f0_yin(
        y=y,
        sr=params.sr,
        frame=params.frame,
        hop=params.hop,
        fmin_hz=params.fmin_hz,
        fmax_hz=params.fmax_hz,
    )

    # Blank out frames that fail the voicing test.
    voiced_mask = voice_mask_from_confidence(
        f0_hz=f0_hz,
        voiced_conf=conf,
        frame_rms_vals=rms_vals,
        min_confidence=params.min_confidence,
        min_rms_dbfs=params.min_rms_dbfs,
    )
    f0_voiced_only = f0_hz.copy()
    f0_voiced_only[~voiced_mask] = np.nan

    f0_stable = stabilize_f0(
        f0_hz=f0_voiced_only,
        fmin_hz=params.fmin_hz,
        fmax_hz=params.fmax_hz,
        median_k=params.median_k,
        ma_k=params.ma_k,
        guard_window=5,
        cents_tol=params.cents_tolerance,
    )

    # Onsets are not consumed below; they are returned for inspection.
    onsets = spectral_flux_onsets(
        y=y,
        sr=params.sr,
        frame=params.frame,
        hop=params.hop,
        prom=params.onset_prom,
    )

    # Pitch track → integer notes → frame-index segments.
    midi_float = hz_to_midi(f0_stable)
    note_series = hysteresis_round(midi_float, params.debounce_frames, params.cents_tolerance)
    seg_idx = segments_from_notes(note_series, min_note_frames=params.min_note_frames)

    # Key estimate from one pitch per segment.
    segment_pitches = [pitch for (pitch, _start, _end) in seg_idx]
    if params.use_key_snap and segment_pitches:
        key = estimate_key(segment_pitches)
    else:
        key = None

    # Frame RMS (windowed) used later to derive note velocities.
    raw_frames = frame_audio(y, frame=params.frame, hop=params.hop)
    rms_frames = frame_rms(raw_frames * hann_window(params.frame)[None, :])

    return {
        "f0_hz": f0_hz,
        "conf": conf,
        "frame_rms": rms_vals,
        "voiced_mask": voiced_mask,
        "f0_stable": f0_stable,
        "onsets": onsets,
        "midi_float": midi_float,
        "note_series": note_series,
        "seg_idx": seg_idx,
        "key": key,
        "rms_frames": rms_frames,
    }


def segments_to_timed(
    seg_idx: list[tuple[int, int, int]],
    rms_frames: np.ndarray,
    params: Params,
) -> list[tuple[int, float, float, float]]:
    """Convert frame-index segments to `(midi, start_s, end_s, rms)` tuples.

    Start and end frames are converted to seconds with `params.hop` and
    `params.sr`. The RMS value is the median of `rms_frames` over the
    segment, or 0.0 when the segment is empty or extends past the end of
    `rms_frames`.
    """
    timed: list[tuple[int, float, float, float]] = []
    n_rms = len(rms_frames)
    for pitch, first, last in seg_idx:
        start_s = frame_to_time(first, hop=params.hop, sr=params.sr)
        end_s = frame_to_time(last, hop=params.hop, sr=params.sr)
        if first < last <= n_rms:
            level = float(np.median(rms_frames[first:last]))
        else:
            level = 0.0
        timed.append((int(pitch), float(start_s), float(end_s), float(level)))
    return timed


def apply_key_snap(segments: list[tuple[int, float, float, float]], key: KeyEstimate | None) -> list[tuple[int, float, float, float]]:
    """Snap the pitch of every segment to `key`; timing and RMS are untouched.

    With `key=None` the input list itself is returned unchanged.
    """
    if key is None:
        return segments
    return [
        (snap_to_key(pitch, key), start_s, end_s, level)
        for pitch, start_s, end_s, level in segments
    ]


def wav_to_midi_simple(in_wav: str, out_mid: str, params: Params) -> dict:
    """End-to-end run: load audio, analyse it, write a MIDI file.

    The audio is loaded at `params.sr`, conditioned, RMS-normalized and
    passed through `extract_pipeline_features`. The resulting segments are
    converted to seconds, snapped to the estimated key when key snapping is
    enabled and a key was found, and written to `out_mid`.

    Returns the feature dict extended with `segments`, `out_mid`, `audio`
    (the conditioned, normalized signal) and `sr`.
    """
    loaded = load_mono(in_wav, sr=params.sr)
    conditioned = condition(loaded.y, sr=loaded.sr, hpf_cutoff=params.hpf_cutoff)
    conditioned = normalize_rms(conditioned)

    feats = extract_pipeline_features(conditioned, params)
    key = feats["key"]
    segments = segments_to_timed(feats["seg_idx"], feats["rms_frames"], params)
    if params.use_key_snap and key is not None:
        segments = apply_key_snap(segments, key)
    write_midi(out_mid, segments)

    feats.update(segments=segments, out_mid=out_mid, audio=conditioned, sr=params.sr)
    return feats

### Demo Run: CLI summary

In [None]:
def _print_summary(feats: dict, params: Params) -> None:
    """Print a short text report of a pipeline run.

    Args:
        feats: feature dict as returned by the pipeline functions; must
            contain `f0_hz`, `conf`, `frame_rms`, `voiced_mask`,
            `f0_stable`, `onsets`, `seg_idx` and `key`. `segments` is
            optional and defaults to an empty list.
        params: parameters used for the run (hop, frame and sr are printed).
    """
    f0 = feats["f0_hz"]
    conf = feats["conf"]  # looked up but not used in the report
    rms_vals = feats["frame_rms"]
    voiced = feats["voiced_mask"]
    f0_stable = feats["f0_stable"]
    onsets = feats["onsets"]
    seg_idx = feats["seg_idx"]
    key = feats["key"]
    segments = feats.get("segments", [])

    n_frames = len(f0)
    n_voiced = int(np.sum(voiced))
    # max(1, ...) avoids a division by zero when there are no frames.
    voiced_pct = 100.0 * n_voiced / max(1, n_frames)

    def mednan(a: np.ndarray) -> float:
        # Median over the finite entries only; NaN if there are none.
        a = a[np.isfinite(a)]
        return float(np.median(a)) if a.size else float("nan")

    print("Frames:", n_frames, " Hop:", params.hop, " Frame:", params.frame, " SR:", params.sr)
    print("RMS median:", round(float(np.median(rms_vals)), 6))
    print("Voiced frames:", n_voiced, f"({voiced_pct:.1f}%)")
    print("f0 median (raw):", round(mednan(f0), 2), "Hz")
    print("f0 median (stable):", round(mednan(f0_stable), 2), "Hz")
    print("Onsets detected:", int(len(onsets)))
    print("Frame segments:", int(len(seg_idx)))
    if key is not None:
        tonic_names = ["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"]
        print("Estimated key:", tonic_names[key.tonic], key.mode, f"(score={key.score:.3f})")
    print("Notes exported:", int(len(segments)))
    if segments:
        print("First 5 notes (midi, start, end, rms):")
        for row in segments[:5]:
            print("  ", row)


# if __name__ == "__main__":
#     # Simple CLI for quick runs; you can also run cells interactively.
#     cli = argparse.ArgumentParser(description="Sing→MIDI simple runner (notebook-style)")
#     cli.add_argument("input", type=str, help="Input WAV path")
#     cli.add_argument("--out", type=str, default=None, help="Output MIDI path (default: input.mid)")
#     cli.add_argument("--sr", type=int, default=48000)
#     args = cli.parse_args()
#     params = Params(sr=args.sr)
#     in_wav = args.input
#     out_mid = args.out or str(Path(in_wav).with_suffix(".mid"))
#     feats = wav_to_midi_simple(in_wav, out_mid, params)
#     print(f"Wrote: {Path(out_mid).resolve()}")
#     _print_summary(feats, params)

## Pipeline (states)

### Quick Start (Notebook cell): set INPUT_WAV and run

In [None]:
# Registry of example files, keyed first by kind and then by dataset name.
# "audio" holds the input recordings and "midi_gt" the matching MIDI files
# (presumably ground-truth transcriptions — confirm). Paths are relative to
# the notebook's working directory.
datasets = {
    "audio": {
        "example": "examples/example.mp3",
    },
    "midi_gt": {
        "example": "examples/example.midi",
    }
}

In [None]:
# Example interactive usage:
# - Set `selection` to a key of `datasets["audio"]`; INPUT_WAV is looked up from it.
# - Optionally set OUT_MIDI, or leave it as None to default to INPUT_WAV with a .mid suffix.
# - Tweak Params() and re-run this cell to iterate.
#

selection = "example"

# Input audio path for the selected dataset (raises KeyError if unknown).
INPUT_WAV = datasets["audio"][selection]

OUT_MIDI: str | None = None
params = Params()

# Matching MIDI file from the "midi_gt" registry, or None when there is none.
GT_MIDI: str | None = None
GT_MIDI = datasets["midi_gt"].get(selection, None)

# if INPUT_WAV:
#     out_mid = OUT_MIDI or str(Path(INPUT_WAV).with_suffix(".mid"))
#     feats = wav_to_midi_simple(INPUT_WAV, out_mid, params)
#     print(f"Wrote: {Path(out_mid).resolve()}")
#     _print_summary(feats, params)

# Sanity check: print whether the selected files exist on disk.
if INPUT_WAV: print(Path(INPUT_WAV).exists())
if GT_MIDI: print(Path(GT_MIDI).exists())

### Optional Overrides: Key / Tempo

In [None]:
# Manual key override, e.g. ("E", "minor") or "C major".
KEY_OVERRIDE = ("C", "major")
# Keep a TEMPO_BPM defined earlier in the session (e.g. 96.0); otherwise None.
TEMPO_BPM = globals().get("TEMPO_BPM", None)

### Plotting Helpers

In [None]:
# Minimal plotting functions with graceful fallback if matplotlib isn't installed.
try:
    # Optional dependency: the plotting helpers check _HAVE_MPL before drawing.
    import matplotlib.pyplot as plt  # type: ignore
    _HAVE_MPL = True
except Exception:  # pragma: no cover
    # Any failure while importing matplotlib disables plotting.
    plt = None
    _HAVE_MPL = False


def _head_finite(a: np.ndarray, k: int = 10) -> list:
    if not isinstance(a, np.ndarray):
        return []
    af = a[np.isfinite(a)]
    return [float(x) for x in af[:k]]


def _tonic_name(i: int) -> str:
    names = ["C","C#","D","D#","E","F","F#","G","G#","A","A#","B"]
    return names[int(i) % 12]


def _time_axis(n_frames: int, hop: int, sr: int) -> np.ndarray:
    return (np.arange(n_frames) * hop) / float(sr)


def plot_wave(y: np.ndarray, sr: int, title: str = "wave") -> None:
    """Plot a waveform against time in seconds; does nothing without matplotlib."""
    if not _HAVE_MPL:
        return
    seconds = np.arange(len(y)) / float(sr)
    fig, ax = plt.subplots(figsize=(10, 2.5))
    ax.plot(seconds, y, lw=0.8)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("Amp")
    fig.tight_layout()
    plt.show()


def plot_track(t: np.ndarray, y: np.ndarray, title: str, ylabel: str) -> None:
    """Plot one track `y` over the time axis `t`; does nothing without matplotlib."""
    if not _HAVE_MPL:
        return
    fig, ax = plt.subplots(figsize=(10, 3))
    ax.plot(t, y, lw=0.8)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()


def plot_two_tracks(t: np.ndarray, a: np.ndarray, b: np.ndarray, la: str, lb: str, title: str, ylabel: str) -> None:
    """Overlay tracks `a` and `b` (legend labels `la`, `lb`) on the time axis `t`.

    Does nothing without matplotlib.
    """
    if not _HAVE_MPL:
        return
    fig, ax = plt.subplots(figsize=(10, 3))
    for track, label in ((a, la), (b, lb)):
        ax.plot(t, track, lw=0.8, label=label)
    ax.legend()
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel(ylabel)
    fig.tight_layout()
    plt.show()


def plot_onsets(t: np.ndarray, onsets_idx: np.ndarray, title: str = "onsets") -> None:
    """Draw a vertical line at the time of each onset frame.

    Only the first 200 onsets are drawn, and indices beyond the end of `t`
    are skipped. Does nothing without matplotlib.
    """
    if not _HAVE_MPL:
        return
    fig, ax = plt.subplots(figsize=(10, 2))
    for raw_idx in onsets_idx[:200]:  # cap drawn lines
        frame_idx = int(raw_idx)
        if frame_idx < len(t):
            ax.axvline(float(t[frame_idx]), color="tomato", alpha=0.6, lw=0.8)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_yticks([])
    fig.tight_layout()
    plt.show()


def plot_segments(segments: list[tuple[int, float, float, float]], title: str = "segments") -> None:
    """Piano-roll style plot: one horizontal bar per `(midi, start, end, rms)` segment.

    Does nothing without matplotlib.
    """
    if not _HAVE_MPL:
        return
    fig, ax = plt.subplots(figsize=(10, 2.5))
    for pitch, start_s, end_s, _level in segments:
        ax.hlines(y=pitch, xmin=start_s, xmax=end_s, colors="royalblue", lw=3)
    ax.set_title(title)
    ax.set_xlabel("Time (s)")
    ax.set_ylabel("MIDI")
    fig.tight_layout()
    plt.show()


def plot_segments_compare(pred, gt, title: str = "Pred vs GT (notes)") -> None:
    """Plot predicted note segments against ground truth on three stacked axes.

    The top axis shows ground truth, the middle axis shows predictions and the
    bottom axis overlays both. Returns immediately when ``_HAVE_MPL`` is false.

    Args:
        pred: sequence of ``(midi, start, end[, rms])`` tuples/lists, or
            objects exposing ``midi``, ``start`` and ``end`` attributes.
        gt: sequence of ``_EvalNote``-like objects or ``(start, end, midi)``
            tuples/lists.
        title: figure title, drawn above the top axis.

    Note:
        The element order of a tuple is guessed: it is read as
        ``(midi, start, end)`` when the first value lies strictly between 20
        and 140 and the second is smaller than the third; otherwise it is read
        as ``(start, end, midi)``. A ``(start, end, midi)`` tuple whose start
        time is above 20 s can therefore be misread; pass objects with
        ``midi``/``start``/``end`` attributes when that matters.
    """
    if not _HAVE_MPL:
        return

    # NumPy scalars (e.g. np.int64, np.float32) are not instances of the
    # built-in int/float, so they must be listed explicitly or such notes
    # would be dropped without any warning.
    numeric_types = (int, float, np.integer, np.floating)

    def _to_triples(seq):
        """Normalize heterogeneous note records to ``(midi, start, end)`` triples."""
        out = []
        for x in seq:
            if isinstance(x, (tuple, list)):
                if len(x) < 3:
                    continue
                a, b, c = x[0], x[1], x[2]
                if not all(isinstance(v, numeric_types) for v in (a, b, c)):
                    continue
                if 20 < a < 140 and b < c:  # likely (midi, start, end)
                    out.append((int(a), float(b), float(c)))
                else:  # assume (start, end, midi)
                    out.append((int(c), float(a), float(b)))
            else:
                # Object with attributes start, end, midi; skip anything else.
                try:
                    out.append((int(x.midi), float(x.start), float(x.end)))
                except (AttributeError, TypeError, ValueError, OverflowError):
                    continue
        return out

    pred_tri = _to_triples(pred)
    gt_tri = _to_triples(gt)
    fig, axes = plt.subplots(3, 1, figsize=(10, 4.5), sharex=True)
    ax_gt, ax_pred, ax_both = axes[0], axes[1], axes[2]
    ax_gt.set_title(title)

    # Ground truth: own axis plus the overlay.
    for m, s, e in gt_tri:
        ax_gt.hlines(y=m, xmin=s, xmax=e, colors="seagreen", lw=3)
        ax_both.hlines(y=m, xmin=s, xmax=e, colors="seagreen", lw=3)
    ax_gt.set_ylabel("GT MIDI")

    # Prediction: own axis plus the overlay.
    for m, s, e in pred_tri:
        ax_pred.hlines(y=m, xmin=s, xmax=e, colors="royalblue", lw=3)
        ax_both.hlines(y=m, xmin=s, xmax=e, colors="royalblue", lw=3)
    ax_pred.set_ylabel("Pred MIDI")

    # With sharex=True only the bottom axis shows tick labels, so the x-axis
    # label belongs there rather than on the middle axis.
    ax_both.set_ylabel("MIDI")
    ax_both.set_xlabel("Time (s)")

    fig.tight_layout()
    plt.show()

### Notebook Variables (shared across stages)

### Stage 0: Load → Condition → Normalize

In [None]:
if INPUT_WAV:
    # Load -> condition -> normalize. Each stage keeps its own name (y0, yc, yn)
    # so later cells can inspect or play back any of them.
    audio = load_mono(INPUT_WAV, sr=params.sr)
    y0 = audio.y
    yc = condition(y0, sr=audio.sr, hpf_cutoff=params.hpf_cutoff)
    yn = normalize_rms(yc)

    # Summary statistics; empty signals report 0.0 instead of dividing by zero.
    dur_s = 0.0
    if len(y0):
        dur_s = len(y0) / params.sr
    rms0 = 0.0
    if y0.size:
        rms0 = float(np.sqrt(np.mean(np.square(y0.astype(np.float64)))))
    rms1 = 0.0
    if yn.size:
        rms1 = float(np.sqrt(np.mean(np.square(yn.astype(np.float64)))))
    print(f"Loaded mono @ {params.sr} Hz, duration {dur_s:.2f} s, RMS {rms0:.6f} → post {rms1:.6f}")

    plot_wave(y0, params.sr, title="Wave (raw)")
    plot_wave(yn, params.sr, title="Wave (conditioned + normalized)")

### Stage 1: f0 (YIN) + Confidence + Frame RMS

In [None]:
if "yn" in globals():
    # Pitch tracking on the normalized signal; f0_yin yields three per-frame
    # results that are unpacked as f0 (Hz), confidence and frame RMS.
    f0_hz, conf, rms_frames = f0_yin(
        y=yn,
        sr=params.sr,
        frame=params.frame,
        hop=params.hop,
        fmin_hz=params.fmin_hz,
        fmax_hz=params.fmax_hz,
    )
    print("Frames:", len(f0_hz), " f0 head(Hz):", [round(hz, 2) for hz in _head_finite(f0_hz)])
    print("conf head:", [round(c, 3) for c in _head_finite(conf)])
    t = _time_axis(n_frames=len(f0_hz), hop=params.hop, sr=params.sr)
    plot_track(t, f0_hz, title="f0 (raw)", ylabel="Hz")

### Stage 2: Voicing Mask

In [None]:
if {"f0_hz", "conf", "rms_frames"}.issubset(globals()):
    voiced_mask = voice_mask_from_confidence(
        f0_hz=f0_hz,
        voiced_conf=conf,
        frame_rms_vals=rms_frames,
        min_confidence=params.min_confidence,
        min_rms_dbfs=params.min_rms_dbfs,
    )
    # Work on a copy so the raw f0 track stays intact; unvoiced frames become NaN.
    f0_voiced = f0_hz.copy()
    f0_voiced[~voiced_mask] = np.nan

    n_frames = len(voiced_mask)
    n_voiced = int(voiced_mask.sum())
    print(f"Voiced frames: {n_voiced}/{n_frames} ({(100*n_voiced/max(1,n_frames)):.1f}%)")

    t = _time_axis(n_frames=len(f0_hz), hop=params.hop, sr=params.sr)
    plot_track(t, f0_voiced, title="f0 (Masked)", ylabel="Hz")

### Stage 3: Stabilize f0

In [None]:
if "f0_voiced" in globals():
    f0_stable = stabilize_f0(
        f0_hz=f0_voiced,
        fmin_hz=params.fmin_hz,
        fmax_hz=params.fmax_hz,
        median_k=params.median_k,
        ma_k=params.ma_k,
        guard_window=5,
        cents_tol=params.cents_tolerance,
    )
    # nanmedian is only taken when at least one finite frame exists;
    # otherwise the median is reported as NaN.
    if np.isfinite(f0_stable).any():
        med_f0 = float(np.nanmedian(f0_stable))
    else:
        med_f0 = float('nan')
    print("f0_stable median:", round(med_f0, 2), "Hz", " head:", [round(hz, 2) for hz in _head_finite(f0_stable)])
    t = _time_axis(n_frames=len(f0_stable), hop=params.hop, sr=params.sr)
    plot_two_tracks(t, f0_hz, f0_stable, la="raw", lb="stable", title="f0: raw vs stable", ylabel="Hz")

### Stage 4: Onsets (Spectral Flux)

In [None]:
if 'yn' in globals():
    onsets = spectral_flux_onsets(y=yn, sr=params.sr, frame=params.frame, hop=params.hop, prom=params.onset_prom)
    onsets_s = [round(frame_to_time(int(i), params.hop, params.sr), 3) for i in onsets[:10]]
    print("Onsets count:", len(onsets), " first(s):", onsets_s)
    # The time axis must cover every onset frame index. When Stage 1 has not
    # run (no f0_hz), size it from the largest onset index instead of the
    # number of onsets; otherwise plot_onsets would silently skip every onset
    # whose frame index is >= len(onsets).
    if 'f0_hz' in globals():
        n_axis_frames = len(f0_hz)
    elif len(onsets):
        n_axis_frames = int(np.max(onsets)) + 1
    else:
        n_axis_frames = 0
    t = _time_axis(n_axis_frames, params.hop, params.sr)
    plot_onsets(t, onsets, title="Onsets (vertical lines)")

### Stage 5: Map to MIDI Float

In [None]:
if "f0_stable" in globals():
    # Convert the stabilized f0 track to fractional MIDI numbers
    # (no rounding at this stage).
    midi_float = hz_to_midi(f0_stable)
    print("midi_float head:", [round(v, 2) for v in _head_finite(midi_float)])
    t = _time_axis(n_frames=len(midi_float), hop=params.hop, sr=params.sr)
    plot_track(t, midi_float, title="MIDI (float)", ylabel="MIDI")

### Stage 6: Hysteresis Rounding + Segments

In [None]:
if "midi_float" in globals():
    # Round the fractional MIDI track with hysteresis, then group the result
    # into segments of at least params.min_note_frames frames.
    note_series = hysteresis_round(
        midi_float,
        params.debounce_frames,
        params.cents_tolerance,
    )
    seg_idx = segments_from_notes(note_series, min_note_frames=params.min_note_frames)
    print("Segments:", len(seg_idx), " first:", seg_idx[:5])
    t = _time_axis(n_frames=len(note_series), hop=params.hop, sr=params.sr)
    plot_track(t, note_series, title="MIDI (integer with NaNs)", ylabel="MIDI int")

### Stage 7: Key Detection (Optional)

In [None]:
if "seg_idx" in globals():
    # Pitches of all segments; each segment is unpacked as a 3-tuple whose
    # first element is the MIDI note.
    midi_list = [m for m, _s, _e in seg_idx]

    # Automatic estimate, only when key snapping is enabled and notes exist.
    key = None
    if params.use_key_snap and midi_list:
        key = estimate_key(midi_list)

    # A valid manual override replaces the estimate.
    override_used = False
    if KEY_OVERRIDE:
        key_override_est = key_override_to_estimate(KEY_OVERRIDE)
        if key_override_est is None:
            print("Key override invalid, falling back to detected key.")
        else:
            key = key_override_est
            override_used = True
            print(f"Key override applied: {_tonic_name(key.tonic)} {key.mode}")

    if key is None:
        print("Key: None / skipped")
    elif override_used:
        print("Key (override):", _tonic_name(key.tonic), key.mode)
    else:
        print(f"Key: {_tonic_name(key.tonic)} {key.mode} (score={key.score:.3f})")

### Stage 8: Timed Segments + Write MIDI

In [None]:
if 'seg_idx' in globals() and 'rms_frames' in globals():
    segments = segments_to_timed(seg_idx, rms_frames, params)
    if 'key' in globals() and key is not None and params.use_key_snap:
        segments = apply_key_snap(segments, key)

    # Resolve the output path. The default is "<input stem>_predicted.mid"
    # next to the input file. It is built in a single with_name() call because
    # chaining .with_suffix(".mid") onto "<stem>_predicted" would treat
    # everything after the last dot of a dotted stem as a suffix and replace
    # it (e.g. "take.1.wav" would become "take.mid"). INPUT_WAV is only
    # consulted when OUT_MIDI is not set.
    if OUT_MIDI:
        out_mid = OUT_MIDI
    else:
        p = Path(INPUT_WAV)
        midi_path = p.with_name(p.stem + "_predicted.mid")
        out_mid = str(midi_path)
    write_midi(out_mid, segments, tempo_bpm=TEMPO_BPM)
    print(f"Wrote: {Path(out_mid).resolve()}")

    if TEMPO_BPM:
        print(f"Tempo (override): {float(TEMPO_BPM):.2f} BPM")
    print("First 5 notes:")
    for row in segments[:5]:
        print("  ", row)
    plot_segments(segments, title="Exported notes")

### Listen: Audio (optional)

In [None]:
# Inline audio players are optional: without IPython this cell does nothing.
try:
    from IPython.display import Audio as IPyAudio, display as ipy_display
except Exception:  # pragma: no cover
    _HAVE_IPY = False
else:
    _HAVE_IPY = True

if _HAVE_IPY:
    # Play back whichever pipeline stages exist in this kernel session.
    if "y0" in globals():
        print("Raw audio (y0): the original loaded signal")
        ipy_display(IPyAudio(data=y0, rate=params.sr))  # type: ignore
    if "yc" in globals():
        print("Conditioned (yc): after DC blocker + high‑pass")
        ipy_display(IPyAudio(data=yc, rate=params.sr))  # type: ignore
    if "yn" in globals():
        print("Normalized (yn): after RMS normalization")
        ipy_display(IPyAudio(data=yn, rate=params.sr))  # type: ignore

    # Synthesized MIDI preview from detected segments (if available).
    # Any failure here (missing pretty_midi, synthesis error, display error)
    # is printed rather than raised, so the preview stays best-effort.
    if "segments" in globals():
        try:
            import pretty_midi

            pm = pretty_midi.PrettyMIDI()
            inst = pretty_midi.Instrument(program=0)
            for midi, s, e, _rms in segments:
                note = pretty_midi.Note(start=float(s), end=float(e), pitch=int(midi), velocity=80)
                inst.notes.append(note)
            pm.instruments.append(inst)

            try:
                y_midi = pm.synthesize(fs=params.sr)
            except Exception as e:
                y_midi = None
                print(e)
            else:
                if y_midi is not None:
                    print("Synthesized MIDI (simple synth): renders the detected notes to audio")
                    ipy_display(IPyAudio(data=y_midi.astype(np.float32, copy=False), rate=params.sr)) # type: ignore
        except Exception as e:
            print(e)

### Evaluation Helpers (pred vs. ground-truth MIDI)

In [None]:
@dataclass
class _EvalNote:
    """One note used by the evaluation helpers: onset, offset and MIDI pitch."""

    # Onset time. eval_match_notes compares differences of this value against
    # `onset_ms / 1000.0`, i.e. it is treated as seconds.
    start: float
    # Offset time, on the same time base as `start`.
    end: float
    # Integer MIDI pitch number.
    midi: int


def _load_midi_notes(path: str) -> list[_EvalNote]:
    """Read every note of every instrument in the MIDI file at ``path``.

    Notes whose end is not strictly after their start are discarded. The
    result is ordered by start time (stable with respect to file order).
    """
    import pretty_midi

    midi_file = pretty_midi.PrettyMIDI(path)
    collected = [
        _EvalNote(start=float(note.start), end=float(note.end), midi=int(note.pitch))
        for instrument in midi_file.instruments
        for note in instrument.notes
        if note.end > note.start
    ]
    return sorted(collected, key=lambda ev: ev.start)


def _overlap_ratio(a: _EvalNote, b: _EvalNote) -> float:
    inter = max(0.0, min(a.end, b.end) - max(a.start, b.start))
    if inter <= 0:
        return 0.0
    denom = max(a.end - a.start, b.end - b.start)
    return inter / denom if denom > 0 else 0.0


def _cents_diff(midi_a: int, midi_b: int) -> float:
    return 100.0 * abs(midi_a - midi_b)


def eval_match_notes(
    gt_mid: str,
    pred_mid: str,
    onset_ms: float = 50.0,
    offset_ms: float = 50.0,
    overlap_thr: float = 0.5,
    pitch_cents: float = 50.0,
) -> dict:
    """Evaluate a predicted MIDI file against a ground-truth MIDI file.

    Two independent greedy one-to-one matchings are computed:

    * Note-level: a (gt, pred) pair is a candidate when the pitch difference
      is at most ``pitch_cents`` AND either both onset and offset are within
      ``onset_ms`` / ``offset_ms``, or the overlap ratio is at least
      ``overlap_thr``. Candidates are taken in order of increasing
      ``|onset error| + |offset error|``.
    * Onset-only: a pair is a candidate when the onsets differ by at most
      ``onset_ms``; pitch is ignored. Candidates are taken in order of
      increasing onset error.

    Args:
        gt_mid: path to the ground-truth MIDI file.
        pred_mid: path to the predicted MIDI file.
        onset_ms: onset tolerance in milliseconds.
        offset_ms: offset tolerance in milliseconds.
        overlap_thr: minimum overlap ratio accepted as an alternative to the
            onset/offset rule.
        pitch_cents: maximum pitch difference in cents for a note-level match.

    Returns:
        A dict with the keys ``counts``, ``note_f1``, ``onset_f1``,
        ``pitch_error_cents``, ``duration_error_ms`` and ``params``. Error
        statistics are 0.0 when no notes were matched; float metrics are
        rounded to 4 decimals.
    """
    gt = _load_midi_notes(gt_mid)
    pr = _load_midi_notes(pred_mid)

    # Tolerances in seconds, hoisted out of the O(len(gt) * len(pr)) loop.
    onset_tol_s = onset_ms / 1000.0
    offset_tol_s = offset_ms / 1000.0

    # Collect candidate edges for both matchings in one pass over all pairs.
    note_edges: list[tuple[float, int, int]] = []
    onset_edges: list[tuple[float, int, int]] = []
    for i, g in enumerate(gt):
        for j, p in enumerate(pr):
            onset_err = abs(p.start - g.start)
            onset_ok = onset_err <= onset_tol_s
            if onset_ok:
                onset_edges.append((onset_err, i, j))

            # Note-level candidates additionally require a pitch match ...
            if _cents_diff(g.midi, p.midi) > pitch_cents:
                continue
            # ... and either tight onset+offset agreement or sufficient overlap.
            offset_err = abs(p.end - g.end)
            offset_ok = offset_err <= offset_tol_s
            if not ((onset_ok and offset_ok) or _overlap_ratio(g, p) >= overlap_thr):
                continue
            # Cost: prefer better temporal alignment.
            note_edges.append((onset_err + offset_err, i, j))

    # ---------- Note-level metrics ----------
    pairs = _greedy_one_to_one(note_edges)
    tp = len(pairs)
    fp, fn, P, R, F1 = _precision_recall_f1(tp, n_pred=len(pr), n_gt=len(gt))

    # Pitch error (cents) on matched pairs.
    cents_errs = [_cents_diff(gt[i].midi, pr[j].midi) for (i, j) in pairs]
    mae_cents = float(np.mean(cents_errs)) if cents_errs else 0.0
    med_cents = float(np.median(cents_errs)) if cents_errs else 0.0

    # Duration error (ms) on matched pairs.
    dur_errs_ms = [
        abs((pr[j].end - pr[j].start) - (gt[i].end - gt[i].start)) * 1000.0
        for (i, j) in pairs
    ]
    mae_dur_ms = float(np.mean(dur_errs_ms)) if dur_errs_ms else 0.0
    med_dur_ms = float(np.median(dur_errs_ms)) if dur_errs_ms else 0.0

    # ---------- Onset-only metrics (independent from note pairs) ----------
    tp_o = len(_greedy_one_to_one(onset_edges))
    _fp_o, _fn_o, P_o, R_o, F1_o = _precision_recall_f1(tp_o, n_pred=len(pr), n_gt=len(gt))

    # Rounding for readability of the printed/serialized report.
    def r4(x: float) -> float:
        return float(np.round(x, 4))

    return {
        "counts": {"gt": len(gt), "pred": len(pr), "tp": tp, "fp": fp, "fn": fn},
        "note_f1": {"precision": r4(P), "recall": r4(R), "f1": r4(F1)},
        "onset_f1": {"precision": r4(P_o), "recall": r4(R_o), "f1": r4(F1_o)},
        "pitch_error_cents": {"mae": r4(mae_cents), "median": r4(med_cents)},
        "duration_error_ms": {"mae": r4(mae_dur_ms), "median": r4(med_dur_ms)},
        "params": {
            "onset_ms": onset_ms,
            "offset_ms": offset_ms,
            "overlap_thr": overlap_thr,
            "pitch_cents": pitch_cents,
        },
    }


def _greedy_one_to_one(edges: list[tuple[float, int, int]]) -> list[tuple[int, int]]:
    """Greedily pick ``(i, j)`` pairs from ``(cost, i, j)`` edges, cheapest first, using each index once."""
    used_i: set[int] = set()
    used_j: set[int] = set()
    chosen: list[tuple[int, int]] = []
    # sorted() is stable, so equal-cost edges keep their insertion order.
    for _, i, j in sorted(edges, key=lambda edge: edge[0]):
        if i in used_i or j in used_j:
            continue
        used_i.add(i)
        used_j.add(j)
        chosen.append((i, j))
    return chosen


def _precision_recall_f1(tp: int, n_pred: int, n_gt: int) -> tuple[int, int, float, float, float]:
    """Return ``(fp, fn, precision, recall, f1)`` for ``tp`` matches; ratios are 0.0 when undefined."""
    fp = max(0, n_pred - tp)
    fn = max(0, n_gt - tp)
    precision = tp / (tp + fp) if tp + fp > 0 else 0.0
    recall = tp / (tp + fn) if tp + fn > 0 else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall > 0 else 0.0
    return fp, fn, precision, recall, f1

### Noise Robustness Evaluation (multi-SNR)

In [None]:
def evaluate_with_noise_levels(
    clean_wav: str,
    gt_midi: str,
    params: Params,
    noise_levels_db: list[float] = [0, 10, 20],
    render_audio: bool = True,
    seed: Optional[int] = None,
) -> dict:
    """Evaluate the pipeline under multiple noise levels (in dB SNR).

    For each SNR the conditioned clean signal gets additive Gaussian white
    noise, is RMS-normalized, written to a temporary WAV, transcribed with
    ``wav_to_midi_simple`` and scored with ``eval_match_notes`` (default
    tolerances). A compact summary table is printed at the end.

    Args:
        clean_wav: path to the clean input audio.
        gt_midi: path to the ground-truth MIDI file.
        params: pipeline parameters, also used for the sample rate.
        noise_levels_db: SNR values in dB. Any iterable is accepted; it is
            copied to a list once and the argument itself is never mutated.
        render_audio: when true and IPython is importable, show an inline
            audio player for every noisy signal.
        seed: optional seed for the noise generator. ``None`` (the default)
            draws from NumPy's global random state, as before; an integer
            makes the noise, and therefore the scores, reproducible.

    Returns:
        A dict mapping ``f"{snr:g}dB"`` to the metrics dict of that level.
    """
    import tempfile
    from pathlib import Path as _Path

    # Iterated twice below (evaluation + summary), so materialize once.
    levels = list(noise_levels_db)

    # Load and condition the clean audio.
    audio = load_mono(clean_wav, sr=params.sr)
    y_clean = condition(audio.y, sr=audio.sr, hpf_cutoff=params.hpf_cutoff)

    # A dedicated generator is only created when a seed is requested, so
    # unseeded calls keep using the global NumPy random state.
    rng = np.random.default_rng(seed) if seed is not None else None

    def add_noise(y: np.ndarray, snr_db: float) -> np.ndarray:
        """Return ``y`` plus Gaussian noise scaled to the target SNR, clipped to [-1, 1]."""
        rms_signal = float(np.sqrt(np.mean(np.square(y), dtype=np.float64)))
        # Avoid division by zero for silent inputs.
        if rms_signal <= 1e-12:
            return y.copy()
        rms_noise = rms_signal / (10.0 ** (snr_db / 20.0))
        if rng is None:
            noise = np.random.normal(0.0, rms_noise, size=y.shape)
        else:
            noise = rng.normal(0.0, rms_noise, size=y.shape)
        out = y + noise.astype(y.dtype, copy=False)
        return np.clip(out, -1.0, 1.0).astype(np.float32, copy=False)

    # Try to enable inline audio rendering if in IPython.
    _can_render = False
    if render_audio:
        try:
            from IPython.display import Audio as _IPyAudio, display as _ipy_display  # type: ignore
            _can_render = True
        except Exception:
            _can_render = False

    results: dict[str, dict] = {}

    # The temporary WAV/MIDI files are only needed while scoring; their paths
    # are never returned, so the directory is removed on exit (also on error).
    with tempfile.TemporaryDirectory(prefix="s2m_noise_eval_") as _tmp:
        tmpdir = _Path(_tmp)
        for snr in levels:
            # Generate noisy signal and normalize its RMS.
            y_noisy = add_noise(y_clean, float(snr))
            y_noisy = normalize_rms(y_noisy)

            # Save noisy audio to a temporary WAV.
            tag = (f"{snr:g}").replace(".", "p")
            wav_path = tmpdir / f"temp_snr{tag}.wav"
            midi_path = tmpdir / f"temp_snr{tag}.mid"
            save_wav(str(wav_path), Audio(y=y_noisy, sr=params.sr))

            # Optional inline audio rendering (best-effort: a failure is
            # reported but must not abort the evaluation).
            if _can_render:
                print(f"SNR {snr:g} dB")
                try:
                    _ipy_display(_IPyAudio(data=y_noisy, rate=params.sr))  # type: ignore
                except Exception as exc:
                    print(f"(audio preview unavailable: {exc})")

            # Run pipeline and evaluate.
            wav_to_midi_simple(str(wav_path), str(midi_path), params)
            results[f"{snr:g}dB"] = eval_match_notes(gt_midi, str(midi_path))

    # Print compact summary table.
    print("\n--- Evaluation Summary ---")
    print("SNR | Note_F1 | Onset_F1 | Duration_MAE(ms)")
    print("--------------------------------------------")

    def _get(d: dict, *keys: str, default: float = 0.0) -> float:
        """Walk nested dict keys and return the value as float, or ``default`` on any miss."""
        x = d
        for k in keys:
            if not isinstance(x, dict) or k not in x:
                return default
            x = x[k]
        try:
            return float(x)
        except Exception:
            return default

    for snr in levels:
        key = f"{snr:g}dB"
        s = results.get(key, {})
        note_f1 = _get(s, "note_f1", "f1")
        onset_f1 = _get(s, "onset_f1", "f1")
        dur_mae = _get(s, "duration_error_ms", "mae")
        print(f"{key:>4} | {note_f1:6.2f} | {onset_f1:8.2f} | {dur_mae:16.1f}")

    return results

### Multi-SNR Evaluation (example cell)

In [None]:
# Example: sweep a custom set of SNR levels (dB) over the current input.
CUSTOM_SNR_LEVELS = [-10, -5, 0, 5, 10, 20, 99]  # change here if needed
# Runs only when both the input WAV and the ground-truth MIDI are configured.
if globals().get('INPUT_WAV') and globals().get('GT_MIDI'):
    try:
        _noise_results = evaluate_with_noise_levels(
            INPUT_WAV,
            GT_MIDI,
            params,
            noise_levels_db=CUSTOM_SNR_LEVELS,
        )
    except Exception as _e:
        print(_e)

### Evaluation: compare predicted vs. ground-truth MIDI

In [None]:
# Configure these before running this cell (optional overrides); each falls
# back to its default when it has not been defined in the kernel.
EVAL_ONSET_MS = globals().get('EVAL_ONSET_MS', 50.0)
EVAL_OFFSET_MS = globals().get('EVAL_OFFSET_MS', 50.0)
EVAL_OVERLAP = globals().get('EVAL_OVERLAP', 0.5)
EVAL_PITCH_CENTS = globals().get('EVAL_PITCH_CENTS', 50.0)

# GT_MIDI gets the same guard as the other configuration names, so a fresh
# kernel yields the "Set GT_MIDI ..." hint below instead of a NameError.
GT_MIDI = globals().get('GT_MIDI')

# Resolve predicted MIDI path from earlier cells (Stage 8 or Quick Start)
pred_mid = None
if 'out_mid' in globals():
    pred_mid = out_mid
elif 'OUT_MIDI' in globals() and OUT_MIDI:
    pred_mid = OUT_MIDI
elif 'INPUT_WAV' in globals() and INPUT_WAV:
    # NOTE(review): this fallback is "<input>.mid", whereas Stage 8 defaults
    # to "<input stem>_predicted.mid". Confirm which path the Quick Start
    # cell actually writes.
    pred_mid = str(Path(INPUT_WAV).with_suffix('.mid'))

if GT_MIDI and pred_mid:
    scores = eval_match_notes(
        gt_mid=GT_MIDI,
        pred_mid=pred_mid,
        onset_ms=EVAL_ONSET_MS,
        offset_ms=EVAL_OFFSET_MS,
        overlap_thr=EVAL_OVERLAP,
        pitch_cents=EVAL_PITCH_CENTS,
    )
    print(json.dumps(scores, indent=2, sort_keys=True))
    # Visual comparison (if matplotlib is available)
    if globals().get('_HAVE_MPL', False):
        gt_notes = _load_midi_notes(GT_MIDI)
        if 'segments' in globals() and segments:
            # Prefer the in-memory segments from Stage 8; drop the 4th field.
            pred_segs = [(m, s, e) for (m, s, e, _r) in segments]
        else:
            pr_notes = _load_midi_notes(pred_mid)
            pred_segs = [(n.midi, n.start, n.end) for n in pr_notes]
        plot_segments_compare(pred_segs, gt_notes, title="Pred vs GT (notes)")
    else:
        print("Matplotlib not available. Skipping visual comparison.")
else:
    if not GT_MIDI:
        print("Set GT_MIDI to your ground-truth MIDI path.")
    if not pred_mid:
        print("Set OUT_MIDI to your predicted MIDI path.")