In [None]:
# GTCRN vs Spectral Subtraction vs Wiener Filter - Fair Comparison
# This notebook compares the noisy baseline, GTCRN, GTCRN+SS and GTCRN+WF on the SAME test set

import warnings
from pathlib import Path
import re

import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from IPython.display import display

# ---------------------------------------------------------------------------
# Paths — the notebook is expected to run two directory levels below the repo
# root; every input/output location is derived from that root.
# ---------------------------------------------------------------------------
repo_root = Path.cwd().parent.parent
results_root = repo_root.joinpath("results", "EXP3", "GTCRN")
figures_dir = repo_root.joinpath("reports", "figures", "GTCRN_SS_WF_Comparison")
figures_dir.mkdir(parents=True, exist_ok=True)

# Evaluation grid shared by every later cell.
snr_levels = [-5, 0, 5, 10, 15]
metrics_of_interest = ["PESQ", "STOI", "SI_SDR", "DNSMOS_mos_ovr"]

# Warning filter and plot styling.
warnings.simplefilter("ignore", category=FutureWarning)
sns.set_theme(style="whitegrid")
plt.rcParams.update(
    {
        "figure.dpi": 300,
        "font.size": 14,
        "axes.labelsize": 14,
        "axes.titlesize": 16,
        "xtick.labelsize": 12,
        "ytick.labelsize": 12,
        "legend.fontsize": 12,
    }
)


In [None]:
# Registry of the systems being compared. Each entry maps an experiment key to:
#   label     - name used in legends and tables
#   color     - matplotlib colour; tab10 hues, one clearly distinct hue per system
#   marker    - line-plot marker
#   directory - folder holding the per-SNR merged metric CSVs
#   template  - CSV filename pattern; "{snr}" is replaced by each SNR level
# NOTE(review): results_root already ends in ".../EXP3/GTCRN", so the GTCRN
# entries resolve to ".../EXP3/GTCRN/GTCRN/..." — confirm this nesting is intended.
catalog = {
    "Noisy": {
        "label": "Noisy Baseline",
        "color": "#1f77b4",  # Blue (tab:blue)
        "marker": "o",
        "directory": repo_root / "results" / "BASELINE" / "NOIZEUS_EARS_BASELINE",
        "template": "BASELINE_NOIZEUS_EARS_[{snr}]dB.csv",
    },
    "GTCRN": {
        "label": "GTCRN",
        "color": "#ff7f0e",  # Orange (tab:orange)
        "marker": "s",
        "directory": results_root / "GTCRN" / "GTCRN_EXP3p2a_ss",
        "template": "GTCRN_EXP3p2a_merged_[{snr}]dB.csv",
    },
    "GTCRN_WF": {
        "label": "GTCRN_WF",
        "color": "#2ca02c",  # Green (tab:green)
        "marker": "^",
        "directory": results_root / "GTCRN" / "GTCRNWF_BEST_CONFIG",
        "template": "GTCRNWF_merged_[{snr}]dB.csv",
    },
    "GTCRN_SS": {
        # Label is the spectral-subtraction configuration name.
        "label": "mband_py_log_hybrid_20ms_ov75_fl0p8_nf1_N4",
        # Red (tab:red); the previous "#ff9900" was almost the same orange as GTCRN.
        "color": "#d62728",
        "marker": "D",
        "directory": results_root / "spectral" / "Notebook_analysis",
        "template": "mband_py_log_hybrid_20ms_ov75_fl0p8_nf1_N4_[{snr}]dB.csv",
    },
}


In [None]:


# ============================================================================
# Helper Functions
# ============================================================================

def load_experiment(prefix: str, meta: dict) -> pd.DataFrame:
    """Load the merged per-SNR metric CSVs of one experiment into a single frame.

    Args:
        prefix: Experiment key (e.g. ``"GTCRN"``); written to the ``experiment`` column.
        meta: Catalog entry providing ``directory`` (a ``Path``), ``template``
            (formatted with ``snr=<level>``) and ``label``.

    Returns:
        Rows from every SNR file that exists, concatenated, with the added
        columns ``SNR``, ``experiment``, ``label`` and ``noise_type``.

    Raises:
        FileNotFoundError: if no CSV exists for any level in ``snr_levels``.
    """
    noise_pattern = r'NOIZEUS_NOISE_DATASET_(.*?)_SNR'
    directory = meta["directory"]
    template = meta["template"]

    frames = []
    for snr in snr_levels:
        csv_path = directory / template.format(snr=snr)

        if not csv_path.exists():
            # Missing SNR files are skipped rather than fatal, so experiments can
            # end up covering different SNR sets — watch for this warning.
            print(f"⚠️  WARNING: File not found: {csv_path}")
            continue

        df = pd.read_csv(csv_path)
        df['SNR'] = snr
        df['experiment'] = prefix
        df['label'] = meta['label']
        frames.append(df)

    if not frames:
        raise FileNotFoundError(f"No files found for {prefix}. Check paths in catalog.")

    result = pd.concat(frames, ignore_index=True)

    # Derive the noise type from file names: try the enhanced file first, then
    # fill rows whose enhanced name lacks the pattern from the noisy file.
    # (Previously the noisy file was ignored whenever an enhanced_file column existed.)
    noise_type = None
    for column in ('enhanced_file', 'noisy_file'):
        if column not in result.columns:
            continue
        extracted = result[column].astype(str).str.extract(noise_pattern, expand=False)
        noise_type = extracted if noise_type is None else noise_type.fillna(extracted)
    result['noise_type'] = np.nan if noise_type is None else noise_type

    return result


def build_summary_tables() -> tuple[pd.DataFrame, pd.DataFrame]:
    """Load every catalog experiment and aggregate its metrics per SNR.

    Returns:
        ``(full_df, summary_df)``: ``full_df`` holds every loaded row of every
        experiment; ``summary_df`` holds one row per (experiment, SNR) with the
        mean of each metric in ``metrics_of_interest``, followed by the
        ``experiment`` key and its ``label``.
    """
    per_experiment = []
    per_snr_means = []

    for name, spec in catalog.items():
        print(f"Loading {name}...")
        frame = load_experiment(name, spec)
        per_experiment.append(frame)

        per_snr_means.append(
            frame.groupby('SNR')[metrics_of_interest]
            .mean()
            .reset_index()
            .assign(experiment=name, label=spec['label'])
        )

    return (
        pd.concat(per_experiment, ignore_index=True),
        pd.concat(per_snr_means, ignore_index=True),
    )


# ============================================================================
# Load Data
# ============================================================================

print("Loading all experiments...")
full_results, summary_by_snr = build_summary_tables()

# Short sanity report: how much was loaded and which experiments/SNRs it covers.
print(
    "\n✓ Data loaded successfully!",
    f"Total samples: {len(full_results)}",
    f"Experiments: {summary_by_snr['experiment'].unique()}",
    f"SNR levels: {sorted(summary_by_snr['SNR'].unique())}",
    sep="\n",
)

# Peek at the per-SNR summary table.
print("\nSummary by SNR (first 10 rows):")
display(summary_by_snr.iloc[:10])


# ============================================================================
# Overall Performance (Averaged Across All SNRs)
# ============================================================================

# Rank systems by the first DNSMOS metric in metrics_of_interest, falling back to
# the first metric. The sort key must be one of the aggregated columns; the former
# hard-coded 'DNSMOS_p808_mos' is not one of them and raised a KeyError.
ranking_metric = next(
    (m for m in metrics_of_interest if m.startswith("DNSMOS")),
    metrics_of_interest[0],
)

# Unweighted mean of the per-SNR means: every SNR level counts equally,
# regardless of how many utterances it contributed.
overall_mean = (
    summary_by_snr.groupby(['experiment', 'label'])[metrics_of_interest]
    .mean()
    .reset_index()
    .sort_values(ranking_metric, ascending=False)
)

print("\n" + "="*80)
print("OVERALL PERFORMANCE (Averaged Across All SNRs)")
print("="*80)
display(overall_mean)


# ============================================================================
# Percentage Gains vs Reference Systems
# ============================================================================

def compute_percentage_gain(reference_key: str, reference_label: str) -> pd.DataFrame:
    """Mean per-SNR percentage gain of every other experiment over a reference.

    For each experiment in ``catalog`` other than ``reference_key`` and each
    metric in ``metrics_of_interest``, the gain at every SNR shared with the
    reference is ``100 * (value - ref) / |ref|``. Dividing by ``|ref|`` keeps the
    sign meaningful for metrics that can be negative, such as SI_SDR. The gains
    are then averaged over SNRs.

    Args:
        reference_key: ``experiment`` key of the reference system.
        reference_label: Human-readable name of the reference, used in warnings.

    Returns:
        One row per compared experiment with ``experiment``, ``label`` and one
        column per metric. Rows are sorted in descending order by the first
        DNSMOS metric in ``metrics_of_interest``, or by the first metric if there
        is none. The frame is empty when nothing could be compared.
    """
    reference = summary_by_snr[summary_by_snr['experiment'] == reference_key]
    # Must be a metric column; a name outside metrics_of_interest raises KeyError.
    sort_metric = next(
        (m for m in metrics_of_interest if m.startswith("DNSMOS")),
        metrics_of_interest[0],
    )

    rows = []
    for exp_name in catalog:
        if exp_name == reference_key:
            continue

        exp_slice = summary_by_snr[summary_by_snr['experiment'] == exp_name]
        merged = exp_slice.merge(reference, on='SNR', suffixes=('', '_ref'))
        if merged.empty:
            print(f"⚠️  WARNING: {exp_name} shares no SNR levels with {reference_label}; skipped.")
            continue

        for metric in metrics_of_interest:
            ref_values = merged[f"{metric}_ref"]
            pct = (merged[metric] - ref_values) / ref_values.abs() * 100
            # A zero reference makes the ratio undefined; exclude it from the mean
            # instead of letting a single ±inf swamp the average.
            pct = pct.replace([np.inf, -np.inf], np.nan)
            rows.append({
                'experiment': exp_name,
                'label': merged['label'].iloc[0],
                'metric': metric,
                'avg_pct_gain': pct.mean(),
            })

    if not rows:
        return pd.DataFrame(columns=['experiment', 'label', *metrics_of_interest])

    pivot = pd.DataFrame(rows).pivot(index=['experiment', 'label'], columns='metric', values='avg_pct_gain')
    return pivot.reset_index().sort_values(sort_metric, ascending=False)


# Gains relative to the unprocessed (noisy) input.
print("\n" + "=" * 80, "PERCENTAGE GAINS vs NOISY BASELINE", "=" * 80, sep="\n")
pct_vs_noise = compute_percentage_gain('Noisy', 'Noisy')
display(pct_vs_noise)

# Gains relative to plain GTCRN, i.e. the effect of each post-processing variant.
print("\n" + "=" * 80, "PERCENTAGE GAINS vs GTCRN BASELINE", "=" * 80, sep="\n")
pct_vs_gtcrn = compute_percentage_gain('GTCRN', 'GTCRN')
display(pct_vs_gtcrn)


# ============================================================================
# Best Configuration per Metric and SNR
# ============================================================================

# For every SNR and metric, record the system with the highest mean score.
# idxmax assumes higher-is-better for every metric in metrics_of_interest.
# SNR/metric cells with no data are skipped; idxmax would raise on them.
best_records = []
for snr in snr_levels:
    subset = summary_by_snr[summary_by_snr['SNR'] == snr]
    for metric in metrics_of_interest:
        scores = subset[metric].dropna()
        if scores.empty:
            print(f"⚠️  WARNING: no {metric} scores at {snr} dB; skipped in best-config table.")
            continue
        row = subset.loc[scores.idxmax()]
        best_records.append({
            'SNR': snr,
            'Metric': metric,
            'Best Config': row['label'],
            'Score': row[metric],
        })

# Explicit columns keep the table (and its CSV export) well-formed even if empty.
best_table = pd.DataFrame(best_records, columns=['SNR', 'Metric', 'Best Config', 'Score'])

print("\n" + "="*80)
print("BEST CONFIGURATION PER METRIC AND SNR")
print("="*80)
display(best_table)


# ============================================================================
# SNR-Specific Performance Tables
# ============================================================================

# Per-SNR snapshots for the low and mid SNR regime, one table per level.
for snr_db, heading in (
    (-5, "PERFORMANCE AT -5dB (Critical Low SNR)"),
    (0, "PERFORMANCE AT 0dB"),
    (5, "PERFORMANCE AT 5dB"),
):
    print("\n" + "=" * 80, heading, "=" * 80, sep="\n")
    display(summary_by_snr.loc[summary_by_snr['SNR'] == snr_db, ['label', *metrics_of_interest]])


# ============================================================================
# VISUALIZATION: Multi-SNR Performance Plot
# ============================================================================

# One panel per metric, two panels per row, with SNR on the x-axis. The grid is
# sized from len(metrics_of_interest) instead of assuming exactly four metrics.
n_cols = 2
n_rows = -(-len(metrics_of_interest) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(8 * n_cols, 6 * n_rows), squeeze=False)
axes = axes.flatten()

for ax, metric in zip(axes, metrics_of_interest):
    for exp_name, meta in catalog.items():
        # Sort by SNR so each line is drawn left-to-right.
        data = summary_by_snr[summary_by_snr['experiment'] == exp_name].sort_values('SNR')
        ax.plot(
            data['SNR'],
            data[metric],
            label=meta['label'],
            color=meta['color'],
            marker=meta['marker'],
            linewidth=2.5,
            markersize=8,
        )

    ax.set_title(f"{metric} Performance Across SNR Levels", fontsize=16, fontweight='bold')
    ax.set_xlabel('SNR (dB)', fontsize=14)
    ax.set_ylabel(metric, fontsize=14)
    ax.grid(True, alpha=0.3)
    ax.legend(loc='best', frameon=True, shadow=True)

# Hide the leftover panel when the number of metrics is odd.
for ax in axes[len(metrics_of_interest):]:
    ax.set_visible(False)

fig.tight_layout()
plot_path = figures_dir / "gtcrn_ss_wf_comparison.png"
fig.savefig(plot_path, bbox_inches='tight', dpi=300)
print(f"\n✓ Plot saved: {plot_path}")
plt.show()


# ============================================================================
# HEATMAP: Performance by Noise Type and SNR (if available)
# ============================================================================

if 'noise_type' in full_results.columns and full_results['noise_type'].notna().any():
    print("\n" + "="*80)
    print("GENERATING HEATMAPS BY NOISE TYPE")
    print("="*80)

    # Shared row order (noise types) and column order (SNR levels) for every panel.
    # The panels use sharey=True, so they must have identical rows; otherwise an
    # experiment missing some noise types would be drawn against the wrong labels.
    noise_types = sorted(full_results['noise_type'].dropna().unique())

    for metric in metrics_of_interest:
        # squeeze=False always returns a 2-D array, so a single-entry catalog also works.
        fig, axes = plt.subplots(
            1, len(catalog), figsize=(6 * len(catalog), 5), sharey=True, squeeze=False
        )
        axes = axes[0]

        for idx, (exp_name, meta) in enumerate(catalog.items()):
            ax = axes[idx]
            exp_data = full_results[full_results['experiment'] == exp_name]

            if exp_data.empty:
                continue

            pivot = exp_data.pivot_table(
                values=metric,
                index='noise_type',
                columns='SNR',
                aggfunc='mean',
            ).reindex(index=noise_types, columns=snr_levels)

            # No values for this metric in this experiment: leave the panel blank.
            if pivot.isna().all().all():
                continue

            sns.heatmap(
                pivot,
                annot=True,
                fmt='.3f',
                cmap='YlGnBu',
                ax=ax,
                cbar_kws={'label': metric}
            )
            ax.set_title(f"{meta['label']}", fontsize=14, fontweight='bold')
            ax.set_xlabel('SNR (dB)')
            ax.set_ylabel('Noise Type' if idx == 0 else '')

        fig.suptitle(f"{metric} by Noise Type and SNR", fontsize=16, fontweight='bold')
        fig.tight_layout()
        heatmap_path = figures_dir / f"{metric}_heatmap_by_noise.png"
        fig.savefig(heatmap_path, bbox_inches='tight', dpi=300)
        print(f"✓ Heatmap saved: {heatmap_path}")
        plt.show()
        # One figure is created per metric; close it so figures do not accumulate.
        plt.close(fig)


# ============================================================================
# DIRECT COMPARISON: SS vs WF at Low SNRs
# ============================================================================

print("\n" + "="*80)
print("DIRECT COMPARISON: SPECTRAL SUBTRACTION vs WIENER FILTER")
print("="*80)

# SNR levels examined in this section, kept in a single list.
low_snrs = [-5, 0]

low_snr_data = summary_by_snr[summary_by_snr['SNR'].isin(low_snrs)]
ss_data = low_snr_data[low_snr_data['experiment'] == 'GTCRN_SS']
wf_data = low_snr_data[low_snr_data['experiment'] == 'GTCRN_WF']
gtcrn_data = low_snr_data[low_snr_data['experiment'] == 'GTCRN']

# Inner merges: an SNR missing for any of the three systems drops out of the table.
comparison = pd.merge(
    ss_data[['SNR'] + metrics_of_interest],
    wf_data[['SNR'] + metrics_of_interest],
    on='SNR',
    suffixes=('_SS', '_WF')
)

# Plain GTCRN columns keep their unsuffixed metric names (e.g. "PESQ").
comparison = pd.merge(
    comparison,
    gtcrn_data[['SNR'] + metrics_of_interest],
    on='SNR'
)

for metric in metrics_of_interest:
    comparison[f'{metric}_SS_advantage'] = comparison[f'{metric}_SS'] - comparison[f'{metric}_WF']
    comparison[f'{metric}_SS_vs_GTCRN'] = comparison[f'{metric}_SS'] - comparison[metric]
    comparison[f'{metric}_WF_vs_GTCRN'] = comparison[f'{metric}_WF'] - comparison[metric]

print("\nAt -5dB and 0dB:")
display(comparison)

print("\n" + "="*80)
print("SUMMARY: SS Advantage Over WF (Positive = SS Wins)")
print("="*80)
for metric in metrics_of_interest:
    print(f"\n{metric}:")
    for snr in low_snrs:
        advantage = comparison.loc[comparison['SNR'] == snr, f'{metric}_SS_advantage']
        # {snr:>4} reproduces the aligned "  -5dB" / "   0dB" layout.
        if advantage.empty:
            print(f"{snr:>4}dB: n/a (missing for SS, WF or GTCRN)")
        else:
            print(f"{snr:>4}dB: {advantage.iloc[0]:.4f}")


# ============================================================================
# EXPORT RESULTS
# ============================================================================

# Persist every summary table as CSV next to the figures.
output_dir = figures_dir / "tables"
output_dir.mkdir(exist_ok=True)

for table, filename in (
    (summary_by_snr, "summary_by_snr.csv"),
    (overall_mean, "overall_mean.csv"),
    (pct_vs_noise, "percentage_gains_vs_noise.csv"),
    (pct_vs_gtcrn, "percentage_gains_vs_gtcrn.csv"),
    (best_table, "best_config_per_metric_snr.csv"),
):
    table.to_csv(output_dir / filename, index=False)

print(f"\n✓ All tables exported to: {output_dir}")
print("\n" + "=" * 80, "ANALYSIS COMPLETE!", "=" * 80, sep="\n")