# GTCRN + Wiener Filter (EXP3p2a) Analysis

This notebook reviews the GTCRNWF EXP3p2a experiments that append causal Wiener filtering to the GTCRN neural enhancer. We focus on how frame length (8/20/25 ms) and the Wiener strength parameter `mu` (intelligibility, balanced, and quality settings) affect objective metrics on the NOIZEUS-based noisy test mixtures used by the baseline (`BASELINE_NOIZEUS_EARS`).

## Experiment Catalogue
- **Noisy Baseline:** Raw noisy mixtures (no enhancement).
- **GTCRN (No WF):** Baseline neural enhancer without post-processing.
- **GTCRN+WF 8 ms:** `mu` tuned for intelligibility (0.123), balanced (0.552), and quality (0.980).
- **GTCRN+WF 20 ms:** Larger analysis window; `mu` values 0.374, 0.677, 0.980.
- **GTCRN+WF 25 ms:** Longest frame; `mu` values 0.374, 0.677, 0.980.
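- **WF 25 ms Default (EXP1p1d):** Standalone Wiener filter (`mu` = 0.98) without GTCRN, used as a reference.
- Metrics tracked per SNR (`-5` to `15` dB in 5 dB steps): PESQ, STOI, SI-SDR, and DNSMOS (P.808 MOS, column `DNSMOS_p808_mos`).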

In [None]:
import warnings
from pathlib import Path
import re
from typing import List, Optional

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display

# Global side effects first: silence pandas/seaborn FutureWarnings and set the
# plotting theme and font sizes used by every figure in this notebook.
warnings.simplefilter("ignore", category=FutureWarning)
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "figure.dpi": 300,
        "font.size": 14,
        "axes.labelsize": 14,
        "axes.titlesize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    }
)

# Directory layout. The repository root is taken to be two levels above the
# working directory (assumes the notebook is run from <repo>/<dir>/<subdir>).
repo_root = Path.cwd().parent.parent
results_root = repo_root.joinpath("results", "EXP3", "GTCRN")
figures_dir = repo_root.joinpath("reports", "figures", "GTCRNWF_EXP3p2a")
figures_dir.mkdir(exist_ok=True, parents=True)

# SNR grid in dB (-5 .. 15 in 5 dB steps) and the metric columns read from every CSV.
snr_levels = list(range(-5, 20, 5))
metrics_of_interest = ["PESQ", "STOI", "SI_SDR", "DNSMOS_p808_mos"]

In [None]:
# GTCRN+WF sweep, one row per run: (frame length in ms, directory suffix,
# focus label, Wiener mu). Row order is the order used in tables and legends.
_GTCRNWF_SWEEP = [
    (8, "intelligibility", "Intelligibility", 0.123),
    (8, "balance", "Balanced", 0.552),
    (8, "quality", "Quality", 0.98),
    (20, "intelligibility", "Intelligibility", 0.374),
    (20, "balance", "Balanced", 0.677),
    (20, "quality", "Quality", 0.98),
    (25, "intelligibility", "Intelligibility", 0.374),
    (25, "balance", "Balanced", 0.677),
    (25, "quality", "Quality", 0.98),
]


def _gtcrnwf_entry(frame_ms, slug, focus, mu):
    """Return the ``(key, metadata)`` catalog pair for one GTCRN+WF run."""
    key = f"GTCRNWF_EXP3p2a_{frame_ms}ms_{slug}"
    entry = {
        "label": f"GTCRN+WF {frame_ms}ms {focus} (mu={mu:.3f})",
        "frame_ms": frame_ms,
        "mu": mu,
        "focus": focus,
        "directory": results_root / key,
        "template": f"{key}_merged_[{{snr}}]dB.csv",
    }
    return key, entry


# Every analysed configuration keyed by its results prefix. Each entry carries
# display metadata (label, frame_ms, mu, focus) and the location of its per-SNR
# CSVs: ``directory / template.format(snr=...)``. Insertion order drives the
# order of tables and plot legends.
catalog = {
    "BASELINE_NOIZEUS_EARS": {
        "label": "Noisy Baseline",
        "frame_ms": None,
        "mu": None,
        "focus": "No Enhancement",
        "directory": repo_root / "results" / "BASELINE" / "NOIZEUS_EARS_BASELINE",
        "template": "BASELINE_NOIZEUS_EARS_[{snr}]dB.csv",
    },
    "GTCRN_EXP3p2a": {
        "label": "GTCRN (No WF)",
        "frame_ms": None,
        "mu": None,
        "focus": "Neural",
        "directory": results_root / "GTCRN_EXP3p2a",
        "template": "GTCRN_EXP3p2a_merged_[{snr}]dB.csv",
    },
    **dict(_gtcrnwf_entry(*spec) for spec in _GTCRNWF_SWEEP),
    "WF_EXP1p1d": {
        "label": "WF 25ms Default (mu=0.98)",
        "frame_ms": 25,
        "mu": 0.98,
        "focus": "Default",
        "directory": repo_root / "results" / "EXP1" / "wiener" / "WF_EXP1p1d",
        "template": "WF_EXP1p1d_merged_[{snr}]dB.csv",
    },
}

# Catalog fields copied onto every loaded row and summary row.
meta_fields = ["label", "frame_ms", "mu", "focus"]

In [None]:
def load_experiment(
    prefix: str,
    meta: dict,
    snr_values: Optional[List[int]] = None,
) -> pd.DataFrame:
    """Load the merged per-SNR CSV files of one experiment into a single frame.

    Parameters
    ----------
    prefix:
        Experiment key; written to the ``experiment`` column of every row.
    meta:
        Catalog entry providing ``directory`` (a ``Path``) and ``template``
        (a file name containing an ``{snr}`` placeholder).
    snr_values:
        SNR levels (dB) to load. Defaults to the notebook-wide ``snr_levels``.

    Returns
    -------
    pd.DataFrame
        All CSV rows concatenated, with added ``SNR``, ``experiment`` and
        ``noise_type`` columns. ``noise_type`` is parsed from ``enhanced_file``
        when that column exists and is NaN otherwise.

    Raises
    ------
    FileNotFoundError
        If a per-SNR CSV is missing; the message names the experiment and SNR.
    """
    levels = snr_levels if snr_values is None else snr_values
    directory = meta["directory"]
    template = meta["template"]
    frames = []
    for snr in levels:
        csv_path = directory / template.format(snr=snr)
        if not csv_path.is_file():
            raise FileNotFoundError(
                f"Missing results for experiment {prefix!r} at {snr} dB: {csv_path}"
            )
        df = pd.read_csv(csv_path)
        df["SNR"] = snr
        df["experiment"] = prefix
        frames.append(df)
    result = pd.concat(frames, ignore_index=True)
    if "enhanced_file" in result.columns:
        # expand=False returns a Series for the single capture group, so the
        # assignment does not rely on DataFrame-to-column coercion.
        result["noise_type"] = result["enhanced_file"].str.extract(
            r"NOIZEUS_NOISE_DATASET_(.*?)_SNR", expand=False
        )
    else:
        result["noise_type"] = np.nan
    return result


def build_summary_tables() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load every catalogued experiment and aggregate its metrics by SNR.

    Returns
    -------
    (full_df, summary_df)
        ``full_df`` stacks every per-file row of every experiment, tagged with
        the ``meta_fields`` columns from the catalog. ``summary_df`` has one row
        per (experiment, SNR) holding the mean of each ``metrics_of_interest``
        column, plus ``experiment`` and the same metadata columns.
    """
    per_file_tables = []
    per_snr_tables = []
    for name, entry in catalog.items():
        tags = {column: entry.get(column) for column in meta_fields}
        loaded = load_experiment(name, entry).assign(**tags)
        per_file_tables.append(loaded)

        snr_means = loaded.groupby("SNR")[metrics_of_interest].mean().reset_index()
        per_snr_tables.append(snr_means.assign(experiment=name, **tags))
    return (
        pd.concat(per_file_tables, ignore_index=True),
        pd.concat(per_snr_tables, ignore_index=True),
    )


# Load every experiment once; all later cells reuse these two tables.
full_results, summary_by_snr = build_summary_tables()
summary_by_snr.head(n=5)

## Check Loaded Data
The cell above concatenates all metrics from the experiment CSVs and previews the first rows of the per-SNR summary, so the expected schema can be confirmed.

## Overall SNR-Averaged Performance
We compute the mean of each metric across all SNRs for each configuration to establish a high-level ranking.

In [None]:
# Per-experiment metadata (label, frame_ms, mu, focus) indexed by catalog key.
# Built from records so each column gets its own inferred dtype: frame_ms and
# mu become numeric (NaN for the systems without a Wiener stage). Transposing a
# dict-of-dicts would instead leave every column as object dtype.
meta_lookup = pd.DataFrame(
    [{field: meta.get(field) for field in meta_fields} for meta in catalog.values()],
    index=list(catalog),
    columns=meta_fields,
)

# Average the per-SNR means over all SNR levels, attach metadata, rank by DNSMOS.
overall_mean = (
    summary_by_snr.groupby("experiment")[metrics_of_interest]
    .mean()
    .join(meta_lookup)
    .reset_index()
    .rename(columns={"DNSMOS_p808_mos": "DNSMOS"})
)
overall_mean.sort_values("DNSMOS", ascending=False)

## Percentage Gains vs Reference Systems
Tables below summarise the mean percentage change in each metric relative to (a) the noisy baseline mixtures and (b) the GTCRN output. Positive values mean the row configuration improves upon the reference.

In [None]:
def compute_percentage_gain(reference_key: str) -> pd.DataFrame:
    """Return the SNR-averaged percentage gain of every experiment over a reference.

    For each experiment other than ``reference_key`` and each metric in
    ``metrics_of_interest``, the per-SNR change ``(value - ref) / |ref| * 100``
    is computed on the ``summary_by_snr`` rows that share an SNR with the
    reference, then averaged across those SNRs.

    Dividing by ``|ref|`` keeps the sign meaningful for metrics that can be
    negative (e.g. SI-SDR of noisy mixtures at low SNR): a positive value always
    means the experiment scored higher than the reference. SNRs where the
    reference is exactly zero contribute NaN (ignored by the mean) instead of
    +/-inf.

    Parameters
    ----------
    reference_key:
        Catalog key of the reference experiment.

    Returns
    -------
    pd.DataFrame
        Indexed by experiment key with a ``label`` column and one column per
        metric (``DNSMOS_p808_mos`` renamed to ``DNSMOS``), sorted by DNSMOS
        gain in descending order.
    """
    reference = summary_by_snr[summary_by_snr["experiment"] == reference_key]
    rows = []
    for exp_name in catalog:
        if exp_name == reference_key:
            continue
        exp_slice = summary_by_snr[summary_by_snr["experiment"] == exp_name]
        merged = exp_slice.merge(reference, on="SNR", suffixes=("", "_ref"))
        for metric in metrics_of_interest:
            ref = merged[f"{metric}_ref"]
            # |ref| so an increase is positive even when ref < 0; a zero
            # reference has no meaningful relative change.
            pct = (merged[metric] - ref) / ref.abs().replace(0, np.nan) * 100
            rows.append({
                "experiment": exp_name,
                "metric": metric,
                "avg_pct": pct.mean(),
            })
    pivot = pd.DataFrame(rows).pivot(index="experiment", columns="metric", values="avg_pct")
    pivot = pivot.rename(columns={"DNSMOS_p808_mos": "DNSMOS"})
    pivot = pivot.join(meta_lookup[["label"]], how="left")
    pivot = pivot[["label", "PESQ", "STOI", "SI_SDR", "DNSMOS"]]
    return pivot.sort_values("DNSMOS", ascending=False)

pct_vs_noise = compute_percentage_gain('BASELINE_NOIZEUS_EARS')
pct_vs_gtcrn = compute_percentage_gain('GTCRN_EXP3p2a')

display(pct_vs_noise)
display(pct_vs_gtcrn)

## Best Configuration per Metric and SNR
Identifies which setup delivers the highest score for each metric at each SNR level.

In [None]:
# For every SNR and metric, record the configuration with the highest mean score.
best_records = []
for snr in snr_levels:
    snr_rows = summary_by_snr.loc[summary_by_snr["SNR"] == snr]
    for metric in metrics_of_interest:
        winner = snr_rows.loc[snr_rows[metric].idxmax()]
        best_records.append(
            {
                "SNR": snr,
                "Metric": metric,
                "Best Config": winner["label"],
                "Score": winner[metric],
            }
        )
best_table = pd.DataFrame(best_records)
best_table

In [None]:
# Human-readable metric names for table headers and plot labels.
metric_display = dict(
    PESQ="PESQ",
    STOI="STOI",
    SI_SDR="SI-SDR",
    DNSMOS_p808_mos="DNSMOS",
)

# Experiment key -> display label, falling back to the key itself when unlabeled.
label_lookup = {name: entry.get("label", name) for name, entry in catalog.items()}

def best_by_snr_with_gain(reference_key: str) -> pd.DataFrame:
    """Return the best configuration per SNR and metric with % gain over a reference.

    For every SNR in ``snr_levels`` and metric in ``metrics_of_interest``, the
    ``summary_by_snr`` row with the highest score (across all experiments,
    reference included) is compared with the reference experiment at the same
    SNR.

    ``Pct Gain (%)`` is ``(best - ref) / |ref| * 100``, so a higher score is a
    positive gain even when the reference is negative (e.g. SI-SDR at low SNR).
    It is NaN when the reference value is missing or zero. Metrics whose scores
    are all NaN at an SNR are skipped rather than crashing ``idxmax``.

    Parameters
    ----------
    reference_key:
        Catalog key of the reference experiment.

    Returns
    -------
    pd.DataFrame
        Columns ``SNR (dB)``, ``Metric``, ``Best Config``, ``Score``,
        ``Reference``, ``Ref Score`` and ``Pct Gain (%)``, sorted by SNR and
        metric. Empty (with these columns) when nothing can be ranked.
    """
    columns = ["SNR (dB)", "Metric", "Best Config", "Score", "Reference", "Ref Score", "Pct Gain (%)"]
    reference_slice = (
        summary_by_snr[summary_by_snr["experiment"] == reference_key]
        .set_index("SNR")
    )
    reference_label = label_lookup.get(reference_key, reference_key)
    rows = []
    for snr in snr_levels:
        snr_slice = summary_by_snr[summary_by_snr["SNR"] == snr]
        for metric in metrics_of_interest:
            scores = snr_slice[metric].dropna()
            if scores.empty:
                continue
            best_row = snr_slice.loc[scores.idxmax()]
            ref_value = reference_slice.loc[snr, metric] if snr in reference_slice.index else np.nan
            pct_gain = np.nan
            if pd.notnull(ref_value) and ref_value != 0:
                # |ref| keeps "positive = improvement" when the reference is negative.
                pct_gain = (best_row[metric] - ref_value) / abs(ref_value) * 100
            rows.append({
                "SNR (dB)": snr,
                "Metric": metric_display.get(metric, metric),
                "Best Config": label_lookup.get(best_row["experiment"], best_row["experiment"]),
                "Score": best_row[metric],
                "Reference": reference_label,
                "Ref Score": ref_value,
                "Pct Gain (%)": pct_gain,
            })
    result = pd.DataFrame(rows, columns=columns)
    return result.sort_values(["SNR (dB)", "Metric"])

best_vs_noisy = best_by_snr_with_gain("BASELINE_NOIZEUS_EARS")
best_vs_gtcrn = best_by_snr_with_gain("GTCRN_EXP3p2a")

# Round only for display; the stored tables keep full precision.
_display_rounding = {"Score": 3, "Ref Score": 3, "Pct Gain (%)": 1}
for _table in (best_vs_noisy, best_vs_gtcrn):
    display(_table.round(_display_rounding))

## Metric Trends vs SNR
Visualize how GTCRN compares against the Wiener variants across SNR for each objective metric.

In [None]:
# Build an explicit color mapping so each label gets a consistent, unique color
label_order = [catalog[key]['label'] for key in catalog]
label_map = {key: meta['label'] for key, meta in catalog.items()}
palette_colors = sns.color_palette("tab20", n_colors=len(label_order))
palette = {label: color for label, color in zip(label_order, palette_colors)}
# GTCRN without Wiener post-filtering gets an "X" marker; every other system uses "o".
marker_map = {label: ('X' if label == 'GTCRN (No WF)' else 'o') for label in label_order}
plot_df = summary_by_snr.copy()
plot_df['label'] = plot_df['experiment'].map(label_map)
# Categorical labels keep hue/style assignment in catalog order
plot_df['label'] = pd.Categorical(plot_df['label'], categories=label_order, ordered=True)
fig, axes = plt.subplots(2, 2, figsize=(14, 10), sharex=True)
axes = axes.flatten()
for idx, metric in enumerate(metrics_of_interest):
    ax = axes[idx]
    sns.lineplot(
        data=plot_df,
        x='SNR',
        y=metric,
        hue='label',
        style='label',
        markers=marker_map,
        dashes=False,
        palette=palette,
        hue_order=label_order,
        style_order=label_order,
        ax=ax,
        linewidth=2.0,
    )
    # Use readable metric names ("SI-SDR", "DNSMOS") for the title and y-axis
    # instead of raw CSV column names such as "DNSMOS_p808_mos".
    display_name = metric_display.get(metric, metric)
    ax.set_title(display_name)
    ax.set_ylabel(display_name)
    ax.set_xlabel('SNR (dB)')
    # Drop the per-axes legend; one shared legend is added to the figure below.
    ax.legend().remove()
    ax.grid(True, alpha=0.3)
# Every subplot shows the same systems, so the first axes' handles cover all labels.
handles, labels = axes[0].get_legend_handles_labels()
# Lay out subplots in the bottom 88% of the figure, leaving room for the legend.
fig.tight_layout(rect=[0, 0, 1, 0.88])
fig.legend(
    handles,
    labels,
    loc='upper center',
    ncol=4,
    frameon=False,
    bbox_to_anchor=(0.5, 0.99),
)
plot_path = figures_dir / "gtcrnwf_metric_trends.png"
fig.savefig(plot_path, bbox_inches='tight')
plot_path

## DNSMOS per SNR for Key Systems
We compare DNSMOS (P.808 MOS) across SNR levels for the standalone GTCRN, the standalone 25 ms Wiener filter in its default configuration, and the combined GTCRN+WF (25 ms, $\mu=0.98$).

In [None]:
# Standalone GTCRN, standalone 25 ms Wiener filter, and their combination.
selected_keys = [
    "GTCRN_EXP3p2a",
    "WF_EXP1p1d",
    "GTCRNWF_EXP3p2a_25ms_quality",
]
bar_metric = "DNSMOS_p808_mos"
bar_df = summary_by_snr[summary_by_snr["experiment"].isin(selected_keys)].copy()
bar_df["label"] = bar_df["experiment"].map(label_map)
fig, ax = plt.subplots(figsize=(10, 6))
sns.barplot(
    data=bar_df,
    x="SNR",
    y=bar_metric,
    hue="label",
    # Bars and legend follow selected_keys order rather than row order in the data.
    hue_order=[label_map[key] for key in selected_keys],
    order=snr_levels,
    ax=ax,
    palette="Set2",
)
ax.set_xlabel("SNR (dB)")
# The plotted column is DNSMOS_p808_mos, i.e. the P.808 MOS score.
ax.set_ylabel("DNSMOS (P.808 MOS)")
ax.set_title("DNSMOS vs SNR for GTCRN and Wiener Variants")
ax.legend(title="System", loc="best")
ax.grid(True, axis="y", alpha=0.3)
fig.tight_layout()
bars_path = figures_dir / "gtcrnwf_25ms_dnsmos_bars.png"
fig.savefig(bars_path, bbox_inches="tight")
bars_path

## Noise-Type Breakdown
Replicates the per-noise analysis from earlier experiments by averaging each metric per noise class and SNR for every configuration.

In [None]:
# Mean of every metric per (noise type, SNR, experiment). Rows whose noise type
# could not be parsed (NaN noise_type) are left out.
_has_noise_type = full_results["noise_type"].notna()
noise_breakdown = (
    full_results.loc[_has_noise_type]
    .groupby(["noise_type", "SNR", "experiment"], as_index=False)[metrics_of_interest]
    .mean()
)
noise_breakdown["label"] = noise_breakdown["experiment"].map(label_map)
noise_breakdown = noise_breakdown.sort_values(by=["noise_type", "SNR", "label"])

# Catalog label order, restricted to systems that actually have noise-type rows.
_labels_with_noise = set(noise_breakdown["label"].unique())
noise_label_order = [label for label in label_order if label in _labels_with_noise]
display(noise_breakdown.head())

In [None]:
def plot_noise_type_breakdown(
    data: pd.DataFrame,
    metrics: List[str],
    snr_levels: List[int],
    label_order: List[str],
    save_dir: Optional[Path] = None,
) -> None:
    """Draw one point plot per (metric, SNR) comparing systems across noise types.

    Parameters
    ----------
    data:
        Long-format table with ``noise_type``, ``SNR``, ``label`` and one column
        per entry of ``metrics``.
    metrics:
        Metric columns to plot.
    snr_levels:
        SNR values (dB) to plot; values with no rows in ``data`` are skipped.
    label_order:
        ``label`` values (systems) in legend order.
    save_dir:
        If given, each figure is also saved there as
        ``{metric}_noise_breakdown_{snr}dB.png``; the directory is created.
    """
    # tab10 has only 10 colours; with more systems seaborn would cycle and give
    # two systems the same colour, so fall back to tab20.
    palette = "tab10" if len(label_order) <= 10 else "tab20"
    n_plots = 0
    for metric in metrics:
        for snr in snr_levels:
            subset = data[data["SNR"] == snr]
            if subset.empty:
                continue
            fig, ax = plt.subplots(figsize=(14, 7))
            sns.pointplot(
                data=subset,
                x="noise_type",
                y=metric,
                hue="label",
                hue_order=label_order,
                palette=palette,
                dodge=0.3,
                markers="o",
                linestyles="-",
                errorbar=None,
                ax=ax,
            )
            ax.set_title(f"{metric} by Noise Type at {snr} dB", fontweight="bold", pad=14)
            ax.set_xlabel("Noise Type", fontweight="bold")
            ax.set_ylabel(metric, fontweight="bold")
            ax.tick_params(axis="x", labelrotation=90)
            plt.setp(ax.get_xticklabels(), ha="center")
            ax.grid(True, axis="y", alpha=0.3, linestyle="--")
            ax.legend(title="System", bbox_to_anchor=(1.02, 1), loc="upper left", frameon=True)
            fig.tight_layout()
            if save_dir is not None:
                save_dir.mkdir(parents=True, exist_ok=True)
                fig_path = save_dir / f"{metric}_noise_breakdown_{snr}dB.png"
                fig.savefig(fig_path, bbox_inches="tight")
            plt.show()
            # Release the figure so len(metrics) * len(snr_levels) figures are not kept open.
            plt.close(fig)
            n_plots += 1
    print(
        f"Generated {n_plots} noise-type plots "
        f"({len(metrics)} metrics × {len(snr_levels)} SNR levels requested)."
    )

In [None]:
def best_by_noise_type_with_gain(reference_key: str) -> pd.DataFrame:
    """Return the best configuration per noise type, SNR and metric with % gain.

    For every (noise type, SNR) group of ``noise_breakdown`` and every metric in
    ``metrics_of_interest``, the row with the highest score (all experiments,
    reference included) is compared with the reference experiment in the same
    group.

    ``Pct Gain (%)`` is ``(best - ref) / |ref| * 100``, so a higher score is a
    positive gain even when the reference is negative (e.g. SI-SDR at low SNR).
    It is NaN when the reference value is missing or zero. Metrics whose scores
    are all NaN in a group are skipped rather than crashing ``idxmax``.

    Parameters
    ----------
    reference_key:
        Catalog key of the reference experiment.

    Returns
    -------
    pd.DataFrame
        Columns ``Noise Type``, ``SNR (dB)``, ``Metric``, ``Best Config``,
        ``Score``, ``Reference``, ``Ref Score`` and ``Pct Gain (%)``, sorted by
        noise type, SNR and metric. Empty (with these columns) when nothing can
        be ranked.
    """
    columns = [
        "Noise Type", "SNR (dB)", "Metric", "Best Config",
        "Score", "Reference", "Ref Score", "Pct Gain (%)",
    ]
    reference_slice = (
        noise_breakdown[noise_breakdown["experiment"] == reference_key]
        .set_index(["noise_type", "SNR"])
    )
    reference_label = label_lookup.get(reference_key, reference_key)
    rows = []
    for (noise, snr), noise_slice in noise_breakdown.groupby(["noise_type", "SNR"], sort=False):
        for metric in metrics_of_interest:
            scores = noise_slice[metric].dropna()
            if scores.empty:
                continue
            best_row = noise_slice.loc[scores.idxmax()]
            ref_value = np.nan
            if (noise, snr) in reference_slice.index:
                ref_value = reference_slice.loc[(noise, snr), metric]
            pct_gain = np.nan
            if pd.notnull(ref_value) and ref_value != 0:
                # |ref| keeps "positive = improvement" when the reference is negative.
                pct_gain = (best_row[metric] - ref_value) / abs(ref_value) * 100
            rows.append({
                "Noise Type": noise,
                "SNR (dB)": snr,
                "Metric": metric_display.get(metric, metric),
                "Best Config": label_lookup.get(best_row["experiment"], best_row["experiment"]),
                "Score": best_row[metric],
                "Reference": reference_label,
                "Ref Score": ref_value,
                "Pct Gain (%)": pct_gain,
            })
    result = pd.DataFrame(rows, columns=columns)
    return result.sort_values(["Noise Type", "SNR (dB)", "Metric"])

best_noise_vs_noisy = best_by_noise_type_with_gain("BASELINE_NOIZEUS_EARS")
best_noise_vs_gtcrn = best_by_noise_type_with_gain("GTCRN_EXP3p2a")

# Round only for display; the stored tables keep full precision.
_noise_rounding = {"Score": 3, "Ref Score": 3, "Pct Gain (%)": 1}
for _noise_table in (best_noise_vs_noisy, best_noise_vs_gtcrn):
    display(_noise_table.round(_noise_rounding))

In [None]:
# Every (metric, SNR) noise-type figure goes into a dedicated sub-folder.
noise_breakdown_dir = figures_dir.joinpath("noise_type_breakdown")
plot_noise_type_breakdown(
    noise_breakdown,
    metrics_of_interest,
    snr_levels,
    noise_label_order,
    save_dir=noise_breakdown_dir,
)
noise_breakdown_dir

## Focused Comparison: GTCRN vs 25 ms Wiener Variants
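A focused comparison of the 25 ms WF default, GTCRN alone, and GTCRN+WF 25 ms ($\mu=0.98$). The tables report which of the three scores best per SNR and per noise type, with the percentage gain relative to GTCRN.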

In [None]:
# The three systems compared head-to-head: standalone 25 ms WF, standalone
# GTCRN, and GTCRN followed by the 25 ms WF at mu = 0.98.
focused_experiments = ["WF_EXP1p1d", "GTCRN_EXP3p2a", "GTCRNWF_EXP3p2a_25ms_quality"]
focused_summary = summary_by_snr.loc[summary_by_snr["experiment"].isin(focused_experiments)].copy()
focused_noise = noise_breakdown.loc[noise_breakdown["experiment"].isin(focused_experiments)].copy()

def best_within_subset_by_snr(data: pd.DataFrame, reference_key: str) -> pd.DataFrame:
    """Return the best system per SNR and metric within ``data``, with % gain.

    Parameters
    ----------
    data:
        Per-SNR summary rows (same columns as ``summary_by_snr``), usually
        restricted to a subset of experiments.
    reference_key:
        Experiment key in ``data`` used as the baseline for ``Pct Gain (%)``.

    Returns
    -------
    pd.DataFrame
        Columns ``SNR (dB)``, ``Metric``, ``Best Config``, ``Score``,
        ``Reference``, ``Ref Score`` and ``Pct Gain (%)``, sorted by SNR and
        metric. ``Pct Gain (%)`` is ``(best - ref) / |ref| * 100``, so a higher
        score is positive even when the reference is negative. It is NaN when
        the reference is missing or zero. Metrics whose scores are all NaN at an
        SNR are skipped.
    """
    columns = ["SNR (dB)", "Metric", "Best Config", "Score", "Reference", "Ref Score", "Pct Gain (%)"]
    reference_slice = data[data["experiment"] == reference_key].set_index("SNR")
    reference_label = label_lookup.get(reference_key, reference_key)
    rows = []
    for snr in snr_levels:
        snr_slice = data[data["SNR"] == snr]
        for metric in metrics_of_interest:
            scores = snr_slice[metric].dropna()
            if scores.empty:
                continue
            best_row = snr_slice.loc[scores.idxmax()]
            ref_value = reference_slice.loc[snr, metric] if snr in reference_slice.index else np.nan
            pct_gain = np.nan
            if pd.notnull(ref_value) and ref_value != 0:
                # |ref| keeps "positive = improvement" when the reference is negative.
                pct_gain = (best_row[metric] - ref_value) / abs(ref_value) * 100
            rows.append({
                "SNR (dB)": snr,
                "Metric": metric_display.get(metric, metric),
                "Best Config": label_lookup.get(best_row["experiment"], best_row["experiment"]),
                "Score": best_row[metric],
                "Reference": reference_label,
                "Ref Score": ref_value,
                "Pct Gain (%)": pct_gain,
            })
    result = pd.DataFrame(rows, columns=columns)
    return result.sort_values(["SNR (dB)", "Metric"])

def best_within_subset_by_noise(data: pd.DataFrame, reference_key: str) -> pd.DataFrame:
    """Return the best system per noise type, SNR and metric within ``data``, with % gain.

    Parameters
    ----------
    data:
        Per-(noise type, SNR) rows (same columns as ``noise_breakdown``),
        usually restricted to a subset of experiments.
    reference_key:
        Experiment key in ``data`` used as the baseline for ``Pct Gain (%)``.

    Returns
    -------
    pd.DataFrame
        Columns ``Noise Type``, ``SNR (dB)``, ``Metric``, ``Best Config``,
        ``Score``, ``Reference``, ``Ref Score`` and ``Pct Gain (%)``, sorted by
        noise type, SNR and metric. ``Pct Gain (%)`` is
        ``(best - ref) / |ref| * 100``, so a higher score is positive even when
        the reference is negative. It is NaN when the reference is missing or
        zero. Metrics whose scores are all NaN in a group are skipped.
    """
    columns = [
        "Noise Type", "SNR (dB)", "Metric", "Best Config",
        "Score", "Reference", "Ref Score", "Pct Gain (%)",
    ]
    reference_slice = data[data["experiment"] == reference_key].set_index(["noise_type", "SNR"])
    reference_label = label_lookup.get(reference_key, reference_key)
    rows = []
    for (noise, snr), noise_slice in data.groupby(["noise_type", "SNR"], sort=False):
        for metric in metrics_of_interest:
            scores = noise_slice[metric].dropna()
            if scores.empty:
                continue
            best_row = noise_slice.loc[scores.idxmax()]
            ref_value = np.nan
            if (noise, snr) in reference_slice.index:
                ref_value = reference_slice.loc[(noise, snr), metric]
            pct_gain = np.nan
            if pd.notnull(ref_value) and ref_value != 0:
                # |ref| keeps "positive = improvement" when the reference is negative.
                pct_gain = (best_row[metric] - ref_value) / abs(ref_value) * 100
            rows.append({
                "Noise Type": noise,
                "SNR (dB)": snr,
                "Metric": metric_display.get(metric, metric),
                "Best Config": label_lookup.get(best_row["experiment"], best_row["experiment"]),
                "Score": best_row[metric],
                "Reference": reference_label,
                "Ref Score": ref_value,
                "Pct Gain (%)": pct_gain,
            })
    result = pd.DataFrame(rows, columns=columns)
    return result.sort_values(["Noise Type", "SNR (dB)", "Metric"])

focused_best_vs_gtcrn = best_within_subset_by_snr(focused_summary, "GTCRN_EXP3p2a")
focused_noise_best_vs_gtcrn = best_within_subset_by_noise(focused_noise, "GTCRN_EXP3p2a")

# Round only for display; the stored tables keep full precision.
_focused_rounding = {"Score": 3, "Ref Score": 3, "Pct Gain (%)": 1}
for _focused_table in (focused_best_vs_gtcrn, focused_noise_best_vs_gtcrn):
    display(_focused_table.round(_focused_rounding))