# In [None]:
# ================================
# Speech Enhancement Artifact Analysis with MSC
# (Multiple Enhanced Files)
# ================================

import numpy as np
import matplotlib.pyplot as plt
from scipy.signal import stft, get_window, find_peaks
from scipy.io import wavfile
import pandas as pd

# ------------------ Utility Functions ------------------

def compute_stft(x, fs, win_len=0.032, hop_ratio=0.5, nfft=512):
    """Compute the STFT of a signal with a Hann analysis window.

    Parameters
    ----------
    x : array_like
        Signal to analyse (passed directly to ``scipy.signal.stft``).
    fs : int or float
        Sample rate in Hz.
    win_len : float, optional
        Analysis window length in seconds (converted to samples with ``int``).
    hop_ratio : float, optional
        Hop size as a fraction of the window length.
    nfft : int or None, optional
        Requested FFT length. ``scipy.signal.stft`` rejects an ``nfft``
        shorter than the window, which happens with the default 32 ms
        window at sample rates such as 22.05 kHz or 48 kHz. In that case
        (or when ``None``) the next power of two that holds the whole
        window is used instead.

    Returns
    -------
    f : ndarray
        Frequency bins in Hz.
    t : ndarray
        Frame times in seconds.
    Z : ndarray
        Complex STFT, shape ``(len(f), len(t))`` for a 1-D input.
    hop : int
        Hop size in samples actually used.

    Raises
    ------
    ValueError
        If the window length rounds down to zero samples.
    """
    win_len_samples = int(win_len * fs)
    if win_len_samples < 1:
        raise ValueError(
            f"win_len={win_len} s is shorter than one sample at fs={fs} Hz"
        )
    # A zero hop would make noverlap == nperseg, which scipy rejects.
    hop = max(1, int(win_len_samples * hop_ratio))
    # scipy requires nfft >= nperseg; grow to the next power of two if needed.
    if nfft is None or nfft < win_len_samples:
        nfft = 1 << (win_len_samples - 1).bit_length()
    window = get_window('hann', win_len_samples)
    f, t, Z = stft(x, fs=fs, window=window, nperseg=win_len_samples,
                   noverlap=win_len_samples - hop, nfft=nfft, boundary=None)
    return f, t, Z, hop

def compute_msc(X, Y, smooth_frames=7):
    """Compute the magnitude-squared coherence (MSC) between two STFTs.

    Auto- and cross-spectra are averaged over a centred window of
    ``2 * (smooth_frames // 2) + 1`` frames. Near the edges the window is
    truncated and the average uses only the frames that exist. The MSC is
    ``|Pxy|**2 / (Pxx * Pyy + 1e-12)``.

    Parameters
    ----------
    X, Y : ndarray, shape (n_freq, n_frames)
        Complex STFTs of the clean and enhanced signals. If their frame
        counts differ (signals of slightly different lengths), only the
        frames common to both (the first ``min`` frames) are compared.
    smooth_frames : int, optional
        Nominal length of the temporal smoothing window, in frames.

    Returns
    -------
    MSC : ndarray
        Coherence per time-frequency bin.
    X2, Y2 : ndarray
        Power spectrograms ``|X|**2`` and ``|Y|**2``, trimmed to the common
        frame count.

    Raises
    ------
    ValueError
        If ``smooth_frames`` is negative.
    """
    if smooth_frames < 0:
        raise ValueError(f"smooth_frames must be non-negative, got {smooth_frames}")

    # Differing frame counts would otherwise fail with a broadcasting error.
    n_frames = min(X.shape[1], Y.shape[1])
    X = X[:, :n_frames]
    Y = Y[:, :n_frames]

    X2 = np.abs(X) ** 2
    Y2 = np.abs(Y) ** 2
    XY = X * np.conj(Y)

    pad = smooth_frames // 2
    sum_xx = np.zeros(X2.shape)
    sum_yy = np.zeros(Y2.shape)
    sum_xy = np.zeros(XY.shape, dtype=complex)
    counts = np.zeros(n_frames)
    # Loop over the few frame offsets (not every frame) and add whole slices.
    for shift in range(-pad, pad + 1):
        # Output frames tt whose neighbour tt + shift lies inside [0, n_frames).
        lo = max(0, -shift)
        hi = min(n_frames, n_frames - shift)
        if lo >= hi:
            continue
        sum_xx[:, lo:hi] += X2[:, lo + shift:hi + shift]
        sum_yy[:, lo:hi] += Y2[:, lo + shift:hi + shift]
        sum_xy[:, lo:hi] += XY[:, lo + shift:hi + shift]
        counts[lo:hi] += 1

    # Every frame has counts >= 1 because shift 0 is always included.
    Pxx = sum_xx / counts
    Pyy = sum_yy / counts
    Pxy = sum_xy / counts

    MSC = (np.abs(Pxy) ** 2) / (Pxx * Pyy + 1e-12)
    return MSC, X2, Y2

def artifact_energy_fraction(MSC, X2, Y2, thresh=0.25):
    """Return the share of enhanced-signal energy attributed to artifacts.

    A time-frequency bin counts as an artifact when the enhanced power
    exceeds the clean power (excess > 1e-8) and the coherence there is
    below ``thresh``. The artifact fraction is the summed excess over
    those bins divided by the total enhanced power.

    Returns
    -------
    AEF : float
        Artifact energy fraction.
    residual : ndarray
        Non-negative excess power ``max(Y2 - X2, 0)`` per bin.
    artifact_mask : ndarray of bool
        Bins classified as artifacts.
    """
    excess = np.maximum(Y2 - X2, 0)
    incoherent = MSC < thresh
    has_excess = excess > 1e-8
    mask = np.logical_and(incoherent, has_excess)

    total_energy = Y2.sum() + 1e-12
    fraction = excess[mask].sum() / total_energy
    return fraction, excess, mask

def compute_musical_noise_index(residual, fs, hop, dur_thresh=0.1):
    """Rough musical-noise index (MNI): spectral peaks per second in the residual.

    In every frame (column) of ``residual``, local maxima whose height is
    at least 20 % of that frame's maximum are counted. The mean count per
    frame is then scaled by the frame rate ``fs / hop``.

    Parameters
    ----------
    residual : ndarray, shape (n_freq, n_frames)
        Non-negative residual energy per time-frequency bin.
    fs : int or float
        Sample rate in Hz.
    hop : int
        STFT hop size in samples.
    dur_thresh : float, optional
        Currently unused; kept only for backward compatibility.
        NOTE(review): presumably meant as a maximum peak duration (s) for
        counting short-lived tonal components -- not implemented.

    Returns
    -------
    float
        Average number of peaks per second; 0.0 if ``residual`` has no frames.
    """
    if residual.shape[1] == 0:
        # np.mean([]) would return NaN (with a warning); no frames means no peaks.
        return 0.0

    peaks_per_frame = []
    for spec in residual.T:
        spec_max = np.max(spec)
        if spec_max <= 0:
            # No positive residual energy in this frame, so nothing to count.
            peaks_per_frame.append(0)
            continue
        peaks, _ = find_peaks(spec, height=spec_max * 0.2)
        peaks_per_frame.append(len(peaks))

    # Convert the mean peaks-per-frame into peaks-per-second.
    frame_rate = fs / hop
    return np.mean(peaks_per_frame) * frame_rate

def plot_spectrogram(data, f, t, title, vmin=None, vmax=None, db=True):
    """Show a time-frequency map with ``pcolormesh``.

    Parameters
    ----------
    data : ndarray, shape (len(f), len(t))
        Values to display (power, residual energy, coherence, ...).
    f : ndarray
        Frequency axis in Hz (y axis).
    t : ndarray
        Time axis in seconds (x axis).
    title : str
        Figure title.
    vmin, vmax : float, optional
        Colour limits, in the units actually plotted (dB when ``db`` is True).
    db : bool, optional
        When True (default, the original behaviour), plot
        ``10*log10(data + 1e-12)``. When False, plot ``data`` linearly,
        which suits quantities already in [0, 1] such as coherence.
    """
    if db:
        values = 10*np.log10(data+1e-12)
        cbar_label = 'dB'
    else:
        values = data
        cbar_label = 'linear'
    plt.figure(figsize=(10, 4))
    # Gouraud shading needs `values` to match the (f, t) grid shape exactly.
    plt.pcolormesh(t, f, values, shading='gouraud', vmin=vmin, vmax=vmax)
    plt.title(title)
    plt.ylabel('Frequency [Hz]')
    plt.xlabel('Time [s]')
    plt.colorbar(label=cbar_label)
    plt.tight_layout()
    plt.show()

# ------------------ Load Clean File ------------------

clean_path = "clean.wav"   # <-- set your clean file
fs, clean = wavfile.read(clean_path)
if clean.ndim > 1:
    clean = clean[:, 0]  # keep only the first channel
if clean.dtype == np.uint8:
    # 8-bit WAV PCM is unsigned with silence at 128; centre it so the
    # peak normalisation and the spectra are not skewed by a DC offset.
    clean = clean.astype(np.float32) - 128.0
clean = clean.astype(np.float32)
if clean.size == 0:
    raise ValueError(f"{clean_path} contains no audio samples")
# Peak-normalise to [-1, 1]; the epsilon avoids dividing by zero on silence.
clean /= np.max(np.abs(clean) + 1e-12)

# Compute clean STFT
f, tt, X, hop = compute_stft(clean, fs)

# ------------------ Enhanced Files List ------------------

enhanced_files = [
    "enhanced_method1.wav",
    "enhanced_method2.wav",
    # add more here
]

results = []

# ------------------ Loop over Enhanced Files ------------------
for enh_path in enhanced_files:
    fs_enh, enhanced = wavfile.read(enh_path)
    # Use an explicit check: `assert` is stripped when Python runs with -O.
    if fs_enh != fs:
        raise ValueError(
            f"Sample rate mismatch! {enh_path}: {fs_enh} Hz, clean: {fs} Hz"
        )
    if enhanced.ndim > 1:
        enhanced = enhanced[:, 0]  # keep only the first channel
    if enhanced.dtype == np.uint8:
        # 8-bit WAV PCM is unsigned with silence at 128; remove the offset.
        enhanced = enhanced.astype(np.float32) - 128.0
    enhanced = enhanced.astype(np.float32)

    # Truncate / zero-pad to the clean length so both STFTs have the same
    # number of frames; otherwise compute_msc and the plots against `tt`
    # fail on mismatched shapes.
    n_clean = len(clean)
    if len(enhanced) >= n_clean:
        enhanced = enhanced[:n_clean]
    else:
        enhanced = np.pad(enhanced, (0, n_clean - len(enhanced)))

    # NOTE(review): clean and enhanced are peak-normalised independently, so
    # the residual Y2 - X2 also reflects any overall gain difference between
    # them -- consider energy/least-squares gain matching before comparing.
    enhanced /= np.max(np.abs(enhanced) + 1e-12)

    _, _, Y, _ = compute_stft(enhanced, fs)
    MSC, X2, Y2 = compute_msc(X, Y)
    AEF, residual, artifact_mask = artifact_energy_fraction(MSC, X2, Y2, thresh=0.25)
    MNI = compute_musical_noise_index(residual, fs, hop)

    # Save results
    results.append({"File": enh_path, "AEF": AEF, "MNI": MNI})

    # ---- Visualization per file ----
    print(f"\n==== {enh_path} ====")
    plot_spectrogram(Y2, f, tt, f"Enhanced Spectrogram Power - {enh_path}")
    plot_spectrogram(residual, f, tt, f"Residual Energy - {enh_path}", vmin=-60, vmax=20)
    # plot_spectrogram shows 10*log10(data) by default; MSC lies in [0, 1],
    # i.e. <= 0 dB, so the colour range must be given in dB as well.
    plot_spectrogram(MSC, f, tt, f"MSC (Clean vs {enh_path}) [dB]", vmin=-30, vmax=0)
    print("Artifact Energy Fraction (AEF):", AEF)
    print("Musical Noise Index (MNI):", MNI)

# ------------------ Summary Table ------------------
df = pd.DataFrame(results)
print("\n===== Summary of Results =====")
print(df)