# Review 3: STFT-Based CNN-Transformer Speech Enhancement

**Key Fix from Review 2:** Replaced mel spectrogram (non-invertible) with STFT (phase-preserving).  
**Architecture:** CNN encoder (1→64→128→256) + 2-layer Pre-LN Transformer (4 heads, d=256) + CNN decoder → Sigmoid mask  
**Reconstruction:** `mask × complex_STFT → ISTFT` (the STFT is invertible and the noisy phase is reused, so no GriffinLim phase estimation is needed)  

**Why R2 failed:** MelSpectrogram is a many-to-one mapping. `InverseMelScale + GriffinLim` introduced  
catastrophic phase artifacts (SI-SDR went from −0.82 dB to −25.58 dB). The model WAS learning  
(val loss improved from 0.1764→0.1485) but reconstruction destroyed the signal.

**Team:** Krishnasinh Jadeja (22BLC1211), Kirtan Sondagar (22BLC1228), Prabhu Kalyan Panda (22BLC1213)  
**Guide:** Dr. Praveen Jaraut — VIT Bhopal Capstone

In [None]:
# ============================================================================
# Cell 1: Install dependencies + Imports + Device check
# ============================================================================
!pip install pesq==0.0.4 pystoi -q

import torch, torch.nn as nn, torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import glob, os, json, time, warnings
warnings.filterwarnings('ignore')
from pesq import pesq as pesq_metric
from pystoi import stoi as stoi_metric

# Seed the two RNGs this notebook draws from so that runs are repeatable.
torch.manual_seed(42)
np.random.seed(42)

# Audio / STFT geometry shared by every later cell (matches CRN baseline).
SR = 16000
MAX_LEN = 3 * SR            # 3 seconds at 16kHz -> 48000 samples
N_FFT = 512
HOP_LENGTH = N_FFT // 2     # 256-sample hop
N_FREQ = 1 + N_FFT // 2     # 257 frequency bins

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
print(f'Device: {device}')

if device == 'cuda':
    props = torch.cuda.get_device_properties(0)
    # The attribute is looked up under two names; 0 is used if neither exists.
    if hasattr(props, 'total_memory'):
        vram = props.total_memory
    else:
        vram = getattr(props, 'total_mem', 0)
    print(f'GPU: {torch.cuda.get_device_name(0)} | VRAM: {vram / 1e9:.1f}GB')

print(f'STFT config: n_fft={N_FFT}, hop={HOP_LENGTH}, freq_bins={N_FREQ}')
print('Imports OK')

## Dataset: LibriSpeech-Noise
Download and extract `earth16/libri-speech-noise-dataset` (6.6 GB, 7000 train + 105 test WAV pairs).

In [None]:
# ============================================================================
# Cell 2: Dataset download & extraction (Kaggle)
# ============================================================================
import shutil
import subprocess
import zipfile

# Working locations on the Kaggle VM.
data_base = '/kaggle/working/data'
dl_tmp = '/kaggle/working/dl_tmp'
os.makedirs(data_base, exist_ok=True)
os.makedirs(dl_tmp, exist_ok=True)
# Marker file: present only after every archive was extracted successfully.
done_flag = os.path.join(data_base, '.done')

if os.path.exists(done_flag):
    print('Dataset already extracted, skipping')
else:
    # Prefer the dataset mounted by Kaggle; otherwise fetch it via the API.
    mounted = '/kaggle/input/libri-speech-noise-dataset'
    if os.path.isdir(mounted) and os.listdir(mounted):
        src = mounted
        print(f'Using mounted dataset at {src}')
    else:
        print('Dataset not mounted, downloading via kaggle API...')
        subprocess.run(['kaggle', 'datasets', 'download',
                        'earth16/libri-speech-noise-dataset', '-p', dl_tmp], check=True)
        zf = os.path.join(dl_tmp, 'libri-speech-noise-dataset.zip')
        if os.path.exists(zf):
            with zipfile.ZipFile(zf, 'r') as z:
                z.extractall(dl_tmp)
            os.remove(zf)
        src = dl_tmp
        print(f'Downloaded to {src}')

    # Install 7z only when it is missing. This stays best-effort: if the
    # install fails, the 7z call below raises and reports the real problem.
    if shutil.which('7z') is None:
        subprocess.run(['apt-get', 'install', '-y', 'p7zip-full'], capture_output=True)

    # Extract 7z archives, checking each exit status instead of ignoring it.
    missing = []
    for arch in ['train.7z', 'y_train.7z', 'test.7z', 'y_test.7z']:
        fp = os.path.join(src, arch)
        if not os.path.exists(fp):
            missing.append(arch)
            continue
        print(f'Extracting {arch}...')
        result = subprocess.run(['7z', 'x', fp, f'-o{data_base}', '-y'],
                                capture_output=True, text=True)
        if result.returncode != 0:
            raise RuntimeError(
                f'7z failed on {arch} (exit code {result.returncode}): '
                f'{result.stderr.strip()[-500:]}')

    if missing:
        # Do not write the marker: a later run must be able to retry.
        print(f'WARNING: archives not found in {src}: {missing}; '
              f'extraction is incomplete and will be retried on the next run')
    else:
        with open(done_flag, 'w'):
            pass

# Verify extraction
def find_wav_dir(base, name):
    """Return the first directory under *base* called *name* that holds a .wav file.

    Directories are visited in ``os.walk`` order. ``None`` is returned when no
    directory with that basename directly contains a ``.wav`` file.
    """
    candidates = (
        dirpath
        for dirpath, _subdirs, filenames in os.walk(base)
        if os.path.basename(dirpath) == name
        and any(fn.endswith('.wav') for fn in filenames)
    )
    return next(candidates, None)

# Locate the four extracted splits (noisy inputs and their clean targets).
noisy_train = find_wav_dir(data_base, 'train')
clean_train = find_wav_dir(data_base, 'y_train')
noisy_test = find_wav_dir(data_base, 'test')
clean_test = find_wav_dir(data_base, 'y_test')

# Report where each split was found and how many WAV files it holds.
split_dirs = {
    'noisy_train': noisy_train,
    'clean_train': clean_train,
    'noisy_test': noisy_test,
    'clean_test': clean_test,
}
for tag, wav_dir in split_dirs.items():
    if wav_dir:
        wav_count = len(glob.glob(os.path.join(wav_dir, '*.wav')))
    else:
        wav_count = 0  # directory was not found
    print(f'  {tag}: {wav_dir} ({wav_count} files)')

## STFT Dataset Class
Loads WAV pairs → computes STFT (n_fft=512, hop=256) → returns magnitude, phase, and waveforms.  
**Key difference from R2:** No mel filterbank. STFT is losslessly invertible via ISTFT.

In [None]:
# ============================================================================
# Cell 3: STFTSpeechDataset
# ============================================================================
class STFTSpeechDataset(Dataset):
    """Paired noisy/clean speech dataset returning STFT features and waveforms.

    The ``*.wav`` files of both directories are sorted by path and paired by
    position; only the file counts are checked, so both directories are
    expected to sort into the same utterance order.

    Each item is a dict with ``noisy_mag``, ``clean_mag``, ``noisy_phase``
    (each ``(n_fft // 2 + 1, T)``) and ``noisy_wav``, ``clean_wav``
    (each ``(max_len,)``).

    Args:
        noisy_dir: directory holding the noisy WAV files.
        clean_dir: directory holding the matching clean WAV files.
        n_fft: FFT size, also used as the Hann window length.
        hop_length: STFT hop in samples.
        sr: target sample rate; files at another rate are resampled.
        max_len: fixed number of samples every waveform is cropped/padded to.
        random_crop: when True (default) long utterances are cropped at a
            random offset; when False the crop always starts at sample 0,
            which makes validation/test items deterministic.
    """
    def __init__(self, noisy_dir, clean_dir, n_fft=N_FFT, hop_length=HOP_LENGTH,
                 sr=SR, max_len=MAX_LEN, random_crop=True):
        self.noisy_files = sorted(glob.glob(os.path.join(noisy_dir, '*.wav')))
        self.clean_files = sorted(glob.glob(os.path.join(clean_dir, '*.wav')))
        assert len(self.noisy_files) == len(self.clean_files), \
            f'Mismatch: {len(self.noisy_files)} noisy vs {len(self.clean_files)} clean'
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sr = sr
        self.max_len = max_len
        self.random_crop = random_crop
        self.window = torch.hann_window(n_fft)

    def __len__(self):
        return len(self.noisy_files)

    def _load_mono(self, path):
        """Load *path*, resample to ``self.sr`` if needed, return channel 0."""
        wav, sr = torchaudio.load(path)
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        return wav[0]  # first channel only, shape (samples,)

    def _pick_start(self, n_samples):
        """Choose a crop offset for a signal of *n_samples* samples."""
        extra = n_samples - self.max_len
        if extra <= 0 or not self.random_crop:
            return 0
        # randint's upper bound is exclusive: +1 keeps the last offset reachable.
        return torch.randint(0, extra + 1, (1,)).item()

    def _crop_or_pad(self, wav, start):
        """Return exactly ``max_len`` samples: crop at *start* or zero-pad the tail."""
        n_samples = wav.shape[0]
        if n_samples > self.max_len:
            start = min(start, n_samples - self.max_len)  # keep the slice in range
            return wav[start:start + self.max_len]
        if n_samples < self.max_len:
            return F.pad(wav, (0, self.max_len - n_samples))
        return wav

    def _load_fix(self, path, start=None):
        """Load one file and fix its length (kept for backward compatibility).

        With ``start=None`` an offset is chosen for this file alone. Paired
        loading must go through ``__getitem__`` so both files share one offset.
        """
        wav = self._load_mono(path)
        if start is None:
            start = self._pick_start(wav.shape[0])
        return self._crop_or_pad(wav, start)

    def __getitem__(self, idx):
        noisy_wav = self._load_mono(self.noisy_files[idx])
        clean_wav = self._load_mono(self.clean_files[idx])

        # One offset for BOTH signals. Independent offsets would pair a noisy
        # segment with a different stretch of the clean utterance.
        start = self._pick_start(min(noisy_wav.shape[0], clean_wav.shape[0]))
        noisy_wav = self._crop_or_pad(noisy_wav, start)
        clean_wav = self._crop_or_pad(clean_wav, start)

        noisy_stft = torch.stft(noisy_wav, self.n_fft, self.hop_length,
                                window=self.window, return_complex=True)
        clean_stft = torch.stft(clean_wav, self.n_fft, self.hop_length,
                                window=self.window, return_complex=True)

        return {
            'noisy_mag':   noisy_stft.abs(),          # (n_fft//2+1, T)
            'clean_mag':   clean_stft.abs(),          # (n_fft//2+1, T)
            'noisy_phase': torch.angle(noisy_stft),   # (n_fft//2+1, T)
            'noisy_wav':   noisy_wav,                 # (max_len,)
            'clean_wav':   clean_wav,                 # (max_len,)
        }

print(f'STFTSpeechDataset defined — n_fft={N_FFT}, hop={HOP_LENGTH}, freq_bins={N_FREQ}')

## Model: STFTTransformerEnhancer

**Same architecture as Review 2**, only the input dimension changes (257 STFT bins instead of 128 mel bins).

```
Input: (B, 1, 257, T) ← log1p(STFT magnitude)
  → CNN Encoder: Conv2d 1→64→128→256 (3×3, BN, ReLU)
  → mean(freq dim) → (B, T, 256) → Linear(256→256)
  → Sinusoidal PositionalEncoding
  → 2-layer Pre-LN TransformerEncoder (4 heads, d=256, ff=1024)
  → Linear(256→256) → expand to (B, 256, 257, T)
  → CNN Decoder: Conv2d 256→128→64→1 + Sigmoid
Output: mask (B, 257, T) in [0, 1]
Enhanced = mask × noisy_magnitude → ISTFT with noisy phase → waveform
```

In [None]:
# ============================================================================
# Cell 4: Model definition + forward pass test
# ============================================================================
class ConvBlock(nn.Module):
    """Conv2d followed by BatchNorm2d and an in-place ReLU.

    Args:
        in_c: input channels.
        out_c: output channels.
        k: kernel size.
        s: stride.
        p: zero padding.
    """
    def __init__(self, in_c, out_c, k=3, s=1, p=1):
        super().__init__()
        # Layer order (and therefore the `net.<i>` state-dict keys) is fixed.
        stages = [
            nn.Conv2d(in_c, out_c, kernel_size=k, stride=s, padding=p),
            nn.BatchNorm2d(out_c),
            nn.ReLU(inplace=True),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        return self.net(x)

class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position codes to a (B, T, d_model) sequence, then dropout.

    Even feature indices carry sines and odd indices carry cosines. An odd
    ``d_model`` is supported: the last sine column then has no cosine partner.

    Args:
        d_model: feature width of the sequence.
        dropout: dropout probability applied after the addition.
        max_len: longest sequence (in frames) the precomputed table covers.
    """
    def __init__(self, d_model, dropout=0.1, max_len=2000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        pos = torch.arange(max_len).unsqueeze(1).float()
        div = torch.exp(torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(pos * div)
        # For odd d_model there is one cosine column fewer than sine columns.
        pe[:, 1::2] = torch.cos(pos * div[: d_model // 2])
        # Buffer: follows .to(device) and is saved in the state dict, not trained.
        self.register_buffer('pe', pe.unsqueeze(0))

    def forward(self, x):
        seq_len = x.size(1)
        if seq_len > self.pe.size(1):
            # Fail with an actionable message instead of a broadcast error.
            raise RuntimeError(
                f'Sequence length {seq_len} exceeds positional table size '
                f'{self.pe.size(1)}; construct PositionalEncoding with a larger max_len')
        return self.dropout(x + self.pe[:, :seq_len])

class STFTTransformerEnhancer(nn.Module):
    """CNN encoder -> Pre-LN Transformer over time -> CNN decoder -> sigmoid mask.

    Input is ``(B, 1, n_freq, T)``; output is a mask ``(B, n_freq, T)`` in [0, 1].

    NOTE(review): with ``freq_skip=False`` (the default, identical to the
    original topology) the encoder output is averaged over frequency and the
    transformer output is copied to every frequency bin. Every frequency row
    of the decoder input is therefore identical, so the mask can differ across
    frequency only through the zero padding at the band edges.
    ``freq_skip=True`` adds the encoder feature map back before decoding so the
    decoder sees per-bin information. It adds no parameters, so state-dict keys
    are unchanged, but weights trained in one mode are not meaningful in the
    other.

    Args:
        n_freq: number of frequency bins the temporal features are expanded to.
        d_model: transformer width.
        nhead: attention heads per layer.
        num_layers: number of transformer encoder layers.
        dim_ff: transformer feed-forward width.
        dropout: dropout used in the positional encoding and transformer.
        freq_skip: add the encoder feature map to the broadcast temporal
            features before the decoder (default False = original behaviour).
    """
    def __init__(self, n_freq=257, d_model=256, nhead=4, num_layers=2,
                 dim_ff=1024, dropout=0.1, freq_skip=False):
        super().__init__()
        self.n_freq = n_freq
        self.freq_skip = freq_skip
        # CNN Encoder
        self.encoder = nn.Sequential(
            ConvBlock(1, 64), ConvBlock(64, 128), ConvBlock(128, 256))
        # Transformer
        self.pre_proj  = nn.Linear(256, d_model)
        self.pos_enc   = PositionalEncoding(d_model, dropout)
        enc_layer = nn.TransformerEncoderLayer(
            d_model, nhead, dim_ff, dropout, batch_first=True, norm_first=True)
        self.transformer = nn.TransformerEncoder(enc_layer, num_layers)
        self.post_proj = nn.Linear(d_model, 256)
        # CNN Decoder
        self.decoder = nn.Sequential(
            ConvBlock(256, 128), ConvBlock(128, 64),
            nn.Conv2d(64, 1, 3, 1, 1), nn.Sigmoid())
        self._init_weights()

    def _init_weights(self):
        """Kaiming-normal init for every Conv2d/Linear; identity init for BatchNorm2d."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        # x: (B, 1, n_freq, T)
        enc = self.encoder(x)                                    # (B, 256, n_freq, T)
        feat = enc.mean(dim=2).permute(0, 2, 1)                 # (B, T, 256)
        feat = self.pos_enc(self.pre_proj(feat))                 # (B, T, d_model)
        feat = self.post_proj(self.transformer(feat))            # (B, T, 256)
        feat = feat.permute(0, 2, 1).unsqueeze(2)                # (B, 256, 1, T)
        if self.freq_skip:
            # Broadcast-add keeps the per-bin encoder detail for the decoder.
            feat = feat + enc                                    # (B, 256, n_freq, T)
        else:
            feat = feat.expand(-1, -1, self.n_freq, -1)          # (B, 256, n_freq, T)
        mask = self.decoder(feat).squeeze(1)                     # (B, n_freq, T)
        return mask

# ---- Smoke test: build the model and push one random batch through it ----
model = STFTTransformerEnhancer(n_freq=N_FREQ).to(device)
trainable = [p for p in model.parameters() if p.requires_grad]
total_params = sum(p.numel() for p in trainable)
print(f'Params: {total_params:,} ({total_params/1e6:.2f}M)')

dummy = torch.randn(2, 1, N_FREQ, 188).to(device)
with torch.no_grad():
    out = model(dummy)
print(f'Input: {dummy.shape} → Mask: {out.shape}')

# The mask must keep the (batch, freq, time) layout and stay inside [0, 1].
expected_shape = (2, N_FREQ, 188)
assert out.shape == expected_shape, f'Shape mismatch: {out.shape}'
mask_lo, mask_hi = out.min(), out.max()
assert 0 <= mask_lo and mask_hi <= 1, 'Mask not in [0,1]'
print('Forward pass OK')

In [None]:
# ============================================================================
# Cell 5: Attention weight extraction + SI-SDR utility
# ============================================================================
def get_attention_weights(mdl, x):
    """Return per-layer self-attention weights from a manual Pre-LN forward pass.

    The encoder, frequency mean, projection and positional encoding are run as
    in the model's own forward; each transformer layer is then replayed by hand
    so that the attention module can be asked for its weights.

    The model is switched to eval mode for the extraction and its previous
    train/eval mode is restored before returning, so calling this in the middle
    of training does not silently disable dropout/BatchNorm updates afterwards.

    Args:
        mdl: model exposing ``encoder``, ``pre_proj``, ``pos_enc`` and
            ``transformer.layers`` (as STFTTransformerEnhancer does).
        x: input batch, ``(B, 1, n_freq, T)``.

    Returns:
        List with one CPU tensor per layer, each ``(B, nhead, T, T)``.
    """
    was_training = mdl.training
    mdl.eval()
    weights = []
    try:
        with torch.no_grad():
            enc = mdl.encoder(x)
            feat = enc.mean(dim=2).permute(0, 2, 1)
            feat = mdl.pos_enc(mdl.pre_proj(feat))
            for layer in mdl.transformer.layers:
                # Attention sub-block: LayerNorm first (Pre-LN), then residual add.
                normed = layer.norm1(feat)
                attn_out, w = layer.self_attn(
                    normed, normed, normed,
                    need_weights=True, average_attn_weights=False)
                weights.append(w.cpu())  # (B, nhead, T, T)
                feat = feat + layer.dropout1(attn_out)
                # Feedforward: replicate Pre-LN FF block (dropout is a no-op here)
                normed2 = layer.norm2(feat)
                ff_out = layer.linear2(
                    F.dropout(layer.activation(layer.linear1(normed2)),
                              p=0.0, training=False))
                feat = feat + ff_out
    finally:
        # Hand the model back in whatever mode the caller had it in.
        mdl.train(was_training)
    return weights

def si_sdr(estimate, reference):
    """Scale-Invariant Signal-to-Distortion Ratio (dB) of *estimate* against *reference*.

    Both signals are mean-removed, the estimate is projected onto the
    reference, and the energy ratio of that projection to the residual is
    returned in dB. All sums run over every element of the inputs, and a 1e-8
    floor guards each division and the logarithm.
    """
    eps = 1e-8
    ref_zm = reference - reference.mean()
    est_zm = estimate - estimate.mean()

    # Component of the estimate that lies along the reference.
    inner = (ref_zm * est_zm).sum()
    ref_energy = (ref_zm ** 2).sum()
    projection = inner * ref_zm / (ref_energy + eps)

    # Whatever is left over counts as distortion.
    residual = est_zm - projection
    ratio = (projection ** 2).sum() / ((residual ** 2).sum() + eps)
    return 10 * torch.log10(ratio + eps)

# Smoke test: attention extraction on one random spectrogram-shaped input.
dummy = torch.randn(1, 1, N_FREQ, 188).to(device)
with torch.no_grad():
    attn = get_attention_weights(model, dummy)
n_attn_layers = len(attn)
first_layer_shape = attn[0].shape
print(f'Attention: {n_attn_layers} layers, shape: {first_layer_shape}')
print('Attention extraction OK')

## Training

**Config:** L1 loss on log1p-compressed magnitudes, Adam (lr=1e-3, weight_decay=1e-5), ReduceLROnPlateau (factor=0.5, patience=5), batch size 16, up to 25 epochs, early stopping with patience=10.  
**Same hyperparameters as Review 2** — only the spectrogram pipeline changed.

In [None]:
# ============================================================================
# Cell 6: Training setup — dataloaders, optimizer, scheduler
# ============================================================================
# Hyperparameters for this run.
MAX_EPOCHS, PATIENCE = 25, 10
LR = 1e-3
BATCH = 16
CKPT = 'stft_transformer_best.pth'

# Start from a freshly initialised model (discarding the smoke-test instance).
model = STFTTransformerEnhancer(n_freq=N_FREQ).to(device)
n_trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Params: {n_trainable:,}')
print('Fresh training, Kaiming init')

# 90/10 train/validation split of the training pairs, with a fixed split seed.
full_train = STFTSpeechDataset(noisy_train, clean_train)
n_val = int(0.1 * len(full_train))
n_train = len(full_train) - n_val
split_rng = torch.Generator().manual_seed(42)
train_ds, val_ds = torch.utils.data.random_split(
    full_train, [n_train, n_val], generator=split_rng)
test_ds = STFTSpeechDataset(noisy_test, clean_test)

# Training batches are shuffled and the ragged last batch is dropped.
train_loader = DataLoader(
    train_ds, batch_size=BATCH, shuffle=True, num_workers=0, drop_last=True)
val_loader = DataLoader(
    val_ds, batch_size=BATCH, shuffle=False, num_workers=0)

# Adam with light weight decay; the LR is halved after 5 epochs without
# validation improvement.
optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode='min', factor=0.5, patience=5)

for split_name, split_size in (('train', n_train), ('val', n_val)):
    print(f'{split_name}: {split_size} samples')
print(f'Train:{n_train} Val:{n_val} Test:{len(test_ds)} | BS={BATCH} LR={LR}')

In [None]:
# ============================================================================
# Cell 7: Training loop
# ============================================================================
history = {'train_loss': [], 'val_loss': []}
best_val = float('inf')
patience_ctr = 0
t0 = time.time()


def _masked_log_l1(batch):
    """L1 distance between log1p(mask * noisy magnitude) and log1p(clean magnitude).

    Shared by the training and validation passes so both score the model the
    same way. Reads the notebook-level ``model`` and ``device``.
    """
    noisy = batch['noisy_mag'].to(device)
    target = batch['clean_mag'].to(device)
    gain = model(torch.log1p(noisy).unsqueeze(1))   # add the channel axis
    return F.l1_loss(torch.log1p(gain * noisy), torch.log1p(target))


for epoch in range(1, MAX_EPOCHS + 1):
    # ---- Train ----
    model.train()
    train_losses = []
    progress = tqdm(train_loader, desc=f'Ep{epoch}/{MAX_EPOCHS}', leave=False)
    for batch in progress:
        loss = _masked_log_l1(batch)

        optimizer.zero_grad()
        loss.backward()
        # Cap the gradient norm at 5.0 before stepping.
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_losses.append(loss.item())

    # ---- Validate ----
    model.eval()
    with torch.no_grad():
        val_losses = [_masked_log_l1(batch).item() for batch in val_loader]

    tr_loss, va_loss = np.mean(train_losses), np.mean(val_losses)
    history['train_loss'].append(tr_loss)
    history['val_loss'].append(va_loss)
    scheduler.step(va_loss)

    elapsed = time.time() - t0
    # Read after scheduler.step(), so this is the LR the next epoch will use.
    lr_now = optimizer.param_groups[0]['lr']
    line = f'Ep{epoch:02d} tr={tr_loss:.4f} va={va_loss:.4f} lr={lr_now:.1e} [{elapsed:.0f}s]'

    improved = va_loss < best_val
    if improved:
        best_val = va_loss
        patience_ctr = 0
        best_state = {'epoch': epoch, 'model': model.state_dict(), 'val_loss': float(va_loss)}
        torch.save(best_state, CKPT)
        print(f'{line}  SAVED best val={va_loss:.4f}')
    else:
        patience_ctr += 1
        print(f'{line}  no improve ({patience_ctr}/{PATIENCE})')

    # Periodic snapshot every 5 epochs, independent of the best checkpoint.
    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'ckpt_ep{epoch}.pth')

    if patience_ctr >= PATIENCE:
        print(f'Early stopping at epoch {epoch}')
        break

# 1-based index of the first epoch that reached the lowest validation loss.
best_ep = history['val_loss'].index(min(history['val_loss'])) + 1
print(f'\nDONE best_ep={best_ep} best_val={best_val:.4f}')

## Results

In [None]:
# ============================================================================
# Cell 8: Training curves
# ============================================================================
# One panel: train and validation loss per epoch, with the best epoch marked.
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
epochs_range = range(1, len(history['train_loss']) + 1)

curve_specs = (
    ('train_loss', 'b-o', 'Train Loss'),
    ('val_loss', 'r-s', 'Val Loss'),
)
for hist_key, line_fmt, curve_label in curve_specs:
    ax.plot(epochs_range, history[hist_key], line_fmt, label=curve_label, markersize=4)

ax.axvline(x=best_ep, color='green', linestyle='--', alpha=0.7, label=f'Best epoch ({best_ep})')
ax.set(
    xlabel='Epoch',
    ylabel='L1 Loss (log-magnitude)',
    title='Review 3: STFT-Transformer Training Curves',
)
ax.legend()
ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()
print('Saved training_curves.png')

In [None]:
# ============================================================================
# Cell 9: Evaluation — PESQ / STOI / SI-SDR with ISTFT waveform reconstruction
# ============================================================================
# Load best model (checkpoint written by the training loop in this notebook)
ckpt = torch.load(CKPT, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model'])
model.eval()
print(f'Loaded ep {ckpt["epoch"]}, val={ckpt["val_loss"]:.4f}')

window = torch.hann_window(N_FFT).to(device)

pesq_noisy_list, pesq_enh_list = [], []
stoi_noisy_list, stoi_enh_list = [], []
sisdr_noisy_list, sisdr_enh_list = [], []
pesq_failed = []  # (sample index, error text) for samples PESQ could not score

for i in tqdm(range(len(test_ds)), desc='Eval'):
    sample = test_ds[i]
    noisy_mag   = sample['noisy_mag'].unsqueeze(0).to(device)    # (1, 257, T)
    noisy_phase = sample['noisy_phase'].unsqueeze(0).to(device)  # (1, 257, T)
    clean_wav_np = sample['clean_wav'].numpy()
    noisy_wav_np = sample['noisy_wav'].numpy()

    with torch.no_grad():
        inp  = torch.log1p(noisy_mag).unsqueeze(1)   # (1, 1, 257, T)
        mask = model(inp)                             # (1, 257, T)
        enhanced_mag = (mask * noisy_mag).squeeze(0)  # (257, T)

    # Reconstruct waveform: magnitude × exp(j·phase) → ISTFT
    enhanced_stft = enhanced_mag * torch.exp(1j * noisy_phase.squeeze(0))
    enhanced_wav = torch.istft(enhanced_stft, N_FFT, HOP_LENGTH,
                               window=window, length=MAX_LEN)
    enh_np = enhanced_wav.cpu().numpy()

    # PESQ: score both signals BEFORE appending either, so that the noisy and
    # enhanced lists always cover exactly the same samples.
    try:
        pesq_noisy = pesq_metric(SR, clean_wav_np, noisy_wav_np, 'wb')
        pesq_enh = pesq_metric(SR, clean_wav_np, enh_np, 'wb')
    except Exception as err:
        # Best-effort: record the failure and carry on with the other metrics.
        pesq_failed.append((i, repr(err)))
    else:
        pesq_noisy_list.append(pesq_noisy)
        pesq_enh_list.append(pesq_enh)

    # STOI
    stoi_noisy_list.append(stoi_metric(clean_wav_np, noisy_wav_np, SR, extended=False))
    stoi_enh_list.append(stoi_metric(clean_wav_np, enh_np, SR, extended=False))

    # SI-SDR
    c_t = torch.from_numpy(clean_wav_np).float()
    n_t = torch.from_numpy(noisy_wav_np).float()
    e_t = torch.from_numpy(enh_np).float()
    sisdr_noisy_list.append(si_sdr(n_t, c_t).item())
    sisdr_enh_list.append(si_sdr(e_t, c_t).item())

# Aggregate (PESQ falls back to 0.0 only when no sample could be scored)
avg_pesq_n = np.mean(pesq_noisy_list) if pesq_noisy_list else 0.0
avg_pesq_e = np.mean(pesq_enh_list) if pesq_enh_list else 0.0
avg_stoi_n = np.mean(stoi_noisy_list)
avg_stoi_e = np.mean(stoi_enh_list)
avg_sisdr_n = np.mean(sisdr_noisy_list)
avg_sisdr_e = np.mean(sisdr_enh_list)

if pesq_failed:
    print(f'WARNING: PESQ failed on {len(pesq_failed)}/{len(test_ds)} samples '
          f'(excluded from both PESQ averages); first error at sample '
          f'{pesq_failed[0][0]}: {pesq_failed[0][1]}')

print(f'\nPESQ: noisy={avg_pesq_n:.3f}  enhanced={avg_pesq_e:.3f}')
print(f'STOI: {avg_stoi_n:.3f} -> {avg_stoi_e:.3f}')
print(f'SI-SDR: {avg_sisdr_n:.2f}dB -> {avg_sisdr_e:.2f}dB')
print(f'\n--- Comparison with previous reviews ---')
print(f'R1 CRN (estimated, no waveform recon): PESQ~3.10')
print(f'R2 Transformer+Mel (real, GriffinLim):  PESQ={1.141:.3f}')
print(f'R3 Transformer+STFT (real, ISTFT):      PESQ={avg_pesq_e:.3f}  delta={avg_pesq_e - 1.141:+.3f}')
print(f'\nParams: {sum(p.numel() for p in model.parameters() if p.requires_grad):,}')

## Analysis: Attention Visualization & Summary

In [None]:
# ============================================================================
# Cell 10: Attention visualization
# ============================================================================
# Use first test sample
sample = test_ds[0]
inp = torch.log1p(sample['noisy_mag'].unsqueeze(0).unsqueeze(0)).to(device)  # (1,1,257,T)
attn_weights = get_attention_weights(model, inp)

# Size the grid from the extracted weights (one row per layer, one column per
# head) instead of assuming 2 layers x 4 heads. For the default model this
# reproduces the original 2x4 grid and 20x8 figure.
n_layers = len(attn_weights)
n_heads = attn_weights[0].shape[1]
fig, axes = plt.subplots(n_layers, n_heads,
                         figsize=(5 * n_heads, 4 * n_layers),
                         squeeze=False)  # keep `axes` 2-D even for 1 row/column
for layer_idx, aw in enumerate(attn_weights):
    for head_idx in range(n_heads):
        ax = axes[layer_idx, head_idx]
        w = aw[0, head_idx].numpy()  # (T, T)
        # Only the top-left 64x64 corner is drawn to keep the panels legible.
        ax.imshow(w[:64, :64], aspect='auto', cmap='viridis')
        ax.set_title(f'L{layer_idx+1} H{head_idx+1}', fontsize=11)
        ax.set_xlabel('Key')
        ax.set_ylabel('Query')
plt.suptitle('Self-Attention Weights (first 64 frames)', fontsize=14)
plt.tight_layout()
plt.savefig('attention_weights.png', dpi=150)
plt.show()
print('Saved attention_weights.png')

In [None]:
# ============================================================================
# Cell 11: Save review3_summary.json + final comparison
# ============================================================================
# Collect the headline numbers first, then assemble the JSON document from them.
n_trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

test_metrics = {
    'pesq_noisy': round(avg_pesq_n, 3),
    'pesq_enhanced': round(avg_pesq_e, 3),
    'stoi_noisy': round(avg_stoi_n, 4),
    'stoi_enhanced': round(avg_stoi_e, 4),
    'sisdr_noisy_dB': round(avg_sisdr_n, 2),
    'sisdr_enhanced_dB': round(avg_sisdr_e, 2),
}

# PESQ across the three reviews (R1 is an estimate, R2/R3 are measured).
pesq_by_review = {
    'R1_CRN_pesq_estimated': 3.10,
    'R2_Transformer_Mel_pesq_real': 1.141,
    'R3_Transformer_STFT_pesq_real': round(avg_pesq_e, 3),
}

summary = {
    'review': 3,
    'model': 'STFTTransformerEnhancer',
    'pipeline': 'STFT (n_fft=512, hop=256) → mask → ISTFT',
    'params': n_trainable_params,
    'best_epoch': int(ckpt['epoch']),
    'best_val_loss': float(ckpt['val_loss']),
    'epochs_trained': len(history['train_loss']),
    'metrics': test_metrics,
    'comparison': pesq_by_review,
    'history': history,
}

with open('review3_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)
print('Saved: review3_summary.json')

# ---- Pretty comparison table ----
rule = '=' * 70
print('\n' + rule)
print(f'{"Review":>10} {"Pipeline":>25} {"PESQ":>8} {"STOI":>8} {"SI-SDR":>10} {"Params":>10}')
print(rule)
print(f'{"R1 CRN":>10} {"Mel (estimated)":>25} {"~3.10":>8} {"—":>8} {"—":>10} {"~2.5M":>10}')
print(f'{"R2 Trans":>10} {"Mel+GriffinLim":>25} {1.141:>8.3f} {0.695:>8.3f} {-25.58:>9.2f}dB {"2.45M":>10}')
print(f'{"R3 Trans":>10} {"STFT+ISTFT":>25} {avg_pesq_e:>8.3f} {avg_stoi_e:>8.3f} {avg_sisdr_e:>9.2f}dB {total_params/1e6:>8.2f}M')
print(rule)

# Gains of the enhanced output over the unprocessed noisy input.
delta_pesq = avg_pesq_e - avg_pesq_n
delta_stoi = avg_stoi_e - avg_stoi_n
delta_sisdr = avg_sisdr_e - avg_sisdr_n
print(f'\nNoisy baseline: PESQ={avg_pesq_n:.3f}  STOI={avg_stoi_n:.3f}  SI-SDR={avg_sisdr_n:.2f}dB')
print(f'R3 improvement: dPESQ={delta_pesq:+.3f}  dSTOI={delta_stoi:+.4f}  dSI-SDR={delta_sisdr:+.2f}dB')