# CRN Baseline — STFT Speech Enhancement (Fixed)

**What changed from the original CRN baseline:**
1. Switched from **Mel spectrogram** (non-invertible) → **STFT** (lossless ISTFT)
2. Real **PESQ / STOI / SI-SDR** on actual waveforms (no more "estimated" metrics)
3. Proper CRN architecture: CNN encoder → LSTM (per-frequency-bin) → CNN decoder → sigmoid mask
4. Checkpoint saving works correctly

**Architecture:**
```
Input: (B, 1, 257, T) ← log1p(STFT magnitude)
  → CNN Encoder (3 layers, stride-2 on freq): (B, 256, 33, T)
  → Reshape → LSTM(256, hidden=256, 2 layers) across time
  → CNN Decoder + Sigmoid → mask (B, 257, T)
Reconstruction: mask × noisy_mag → ISTFT with noisy phase → waveform
```

**Team:** Krishnasinh Jadeja (22BLC1211), Kirtan Sondagar (22BLC1228), Prabhu Kalyan Panda (22BLC1213)
**Guide:** Dr. Praveen Jaraut — VIT Bhopal Capstone

In [1]:
# ============================================================================
# Cell 1: Install deps + Imports + Config
# ============================================================================
!pip install pesq==0.0.4 pystoi -q

import torch, torch.nn as nn, torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from tqdm.auto import tqdm
import glob, os, json, time, warnings
warnings.filterwarnings('ignore')
from pesq import pesq as pesq_metric
from pystoi import stoi as stoi_metric

torch.manual_seed(42)
np.random.seed(42)

# STFT config (same as all reviews — consistent comparison)
N_FFT      = 512
HOP_LENGTH = 256
N_FREQ     = N_FFT // 2 + 1   # 257
SR         = 16000
MAX_LEN    = 48000             # 3 s

device = 'cuda' if torch.cuda.is_available() else 'cpu'
print(f'Device: {device}')
if device == 'cuda':
    props = torch.cuda.get_device_properties(0)
    vram = getattr(props, 'total_memory', getattr(props, 'total_mem', 0))
    print(f'GPU: {torch.cuda.get_device_name(0)} | VRAM: {vram/1e9:.1f}GB')
print(f'STFT: n_fft={N_FFT}, hop={HOP_LENGTH}, freq={N_FREQ}')
print('Imports OK')

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pesq (setup.py) ... [?25l[?25hdone
Device: cuda
GPU: Tesla P100-PCIE-16GB | VRAM: 17.1GB
STFT: n_fft=512, hop=256, freq=257
Imports OK


## Dataset: LibriSpeech-Noise
Download & extract `earth16/libri-speech-noise-dataset` (7000 train + 105 test WAV pairs).

In [2]:
# ============================================================================
# Cell 2: Dataset download & extraction
# ============================================================================
import subprocess, zipfile

data_base = '/kaggle/working/data'
dl_tmp    = '/kaggle/working/dl_tmp'
os.makedirs(data_base, exist_ok=True)
os.makedirs(dl_tmp, exist_ok=True)
done_flag = os.path.join(data_base, '.done')

if os.path.exists(done_flag):
    print('Dataset already extracted, skipping')
else:
    mounted = '/kaggle/input/libri-speech-noise-dataset'
    if os.path.isdir(mounted) and len(os.listdir(mounted)) > 0:
        src = mounted
        print(f'Using mounted dataset at {src}')
    else:
        print('Dataset not mounted, downloading via kaggle API...')
        subprocess.run(['kaggle', 'datasets', 'download',
                        'earth16/libri-speech-noise-dataset', '-p', dl_tmp], check=True)
        zf = os.path.join(dl_tmp, 'libri-speech-noise-dataset.zip')
        if os.path.exists(zf):
            with zipfile.ZipFile(zf, 'r') as z:
                z.extractall(dl_tmp)
            os.remove(zf)
        src = dl_tmp
        print(f'Downloaded to {src}')

    subprocess.run(['apt-get', 'install', '-y', 'p7zip-full'], capture_output=True)
    for arch in ['train.7z', 'y_train.7z', 'test.7z', 'y_test.7z']:
        fp = os.path.join(src, arch)
        if os.path.exists(fp):
            print(f'Extracting {arch}...')
            subprocess.run(['7z', 'x', fp, f'-o{data_base}', '-y'], capture_output=True)
    open(done_flag, 'w').close()

def find_wav_dir(base, name):
    for root, dirs, files in os.walk(base):
        if os.path.basename(root) == name and any(f.endswith('.wav') for f in files):
            return root
    return None

noisy_train = find_wav_dir(data_base, 'train')
clean_train = find_wav_dir(data_base, 'y_train')
noisy_test  = find_wav_dir(data_base, 'test')
clean_test  = find_wav_dir(data_base, 'y_test')

for tag, d in [('noisy_train', noisy_train), ('clean_train', clean_train),
               ('noisy_test', noisy_test), ('clean_test', clean_test)]:
    n = len(glob.glob(os.path.join(d, '*.wav'))) if d else 0
    print(f'  {tag}: {d} ({n} files)')

Dataset not mounted, downloading via kaggle API...
Dataset URL: https://www.kaggle.com/datasets/earth16/libri-speech-noise-dataset
License(s): DbCL-1.0
Downloading libri-speech-noise-dataset.zip to /kaggle/working/dl_tmp


100%|██████████| 6.03G/6.03G [00:26<00:00, 243MB/s]



Downloaded to /kaggle/working/dl_tmp
Extracting train.7z...
Extracting y_train.7z...
Extracting test.7z...
Extracting y_test.7z...
  noisy_train: /kaggle/working/data/train (7000 files)
  clean_train: /kaggle/working/data/y_train (7000 files)
  noisy_test: /kaggle/working/data/test (105 files)
  clean_test: /kaggle/working/data/y_test (105 files)


## STFT Dataset
Same STFT config as R2/R3 for fair comparison. Returns magnitude, phase, waveforms.

In [3]:
# ============================================================================
# Cell 3: STFTSpeechDataset
# ============================================================================
class STFTSpeechDataset(Dataset):
    def __init__(self, noisy_dir, clean_dir, n_fft=N_FFT, hop_length=HOP_LENGTH,
                 sr=SR, max_len=MAX_LEN):
        self.noisy_files = sorted(glob.glob(os.path.join(noisy_dir, '*.wav')))
        self.clean_files = sorted(glob.glob(os.path.join(clean_dir, '*.wav')))
        assert len(self.noisy_files) == len(self.clean_files), \
            f'Mismatch: {len(self.noisy_files)} noisy vs {len(self.clean_files)} clean'
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.sr = sr
        self.max_len = max_len
        self.window = torch.hann_window(n_fft)

    def __len__(self):
        return len(self.noisy_files)

    def _load_fix(self, path):
        wav, sr = torchaudio.load(path)
        if sr != self.sr:
            wav = torchaudio.functional.resample(wav, sr, self.sr)
        wav = wav[0]
        if wav.shape[0] > self.max_len:
            start = torch.randint(0, wav.shape[0] - self.max_len, (1,)).item()
            wav = wav[start:start + self.max_len]
        elif wav.shape[0] < self.max_len:
            wav = F.pad(wav, (0, self.max_len - wav.shape[0]))
        return wav

    def __getitem__(self, idx):
        noisy_wav = self._load_fix(self.noisy_files[idx])
        clean_wav = self._load_fix(self.clean_files[idx])
        noisy_stft = torch.stft(noisy_wav, self.n_fft, self.hop_length,
                                window=self.window, return_complex=True)
        clean_stft = torch.stft(clean_wav, self.n_fft, self.hop_length,
                                window=self.window, return_complex=True)
        return {
            'noisy_mag':   noisy_stft.abs(),
            'clean_mag':   clean_stft.abs(),
            'noisy_phase': torch.angle(noisy_stft),
            'noisy_wav':   noisy_wav,
            'clean_wav':   clean_wav,
        }

print(f'STFTSpeechDataset defined (n_fft={N_FFT}, hop={HOP_LENGTH})')

STFTSpeechDataset defined (n_fft=512, hop=256)


## CRN Model (Fixed)

**Key fixes from original:**
1. **STFT input** (257 freq bins) instead of mel (128) — enables lossless ISTFT reconstruction
2. **Per-frequency LSTM**: reshapes to `(B*F', T, C)` so LSTM models temporal dynamics for each frequency sub-band independently — instead of collapsing frequency with `mean(dim=2)`
3. **Proper CNN freq downsampling/upsampling** with stride-2 convolutions + transposed convolutions

```
CNN Encoder: (B, 1, 257, T)
  -> Conv2d(1->64, stride=(2,1)) -> (B, 64, 129, T)
  -> Conv2d(64->128, stride=(2,1)) -> (B, 128, 65, T)
  -> Conv2d(128->256, stride=(2,1)) -> (B, 256, 33, T)
LSTM: reshape to (B*33, T, 256) -> LSTM(256, 256, 2 layers) -> (B*33, T, 256)
CNN Decoder: (B, 256, 33, T)
  -> ConvT2d(256->128, stride=(2,1)) -> (B, 128, 65, T)
  -> ConvT2d(128->64, stride=(2,1)) -> (B, 64, 129, T)
  -> ConvT2d(64->32, stride=(2,1)) -> (B, 32, 257, T)
  -> Conv2d(32->1, 1x1) + Sigmoid -> mask (B, 1, 257, T)
```

In [4]:
# ============================================================================
# Cell 4: CRN Baseline Model (Fixed -- STFT based)
# ============================================================================
class CRNBaseline(nn.Module):
    """CRN for STFT-based speech enhancement."""
    def __init__(self, n_freq=257):
        super().__init__()
        self.n_freq = n_freq

        # CNN Encoder: downsample frequency with stride-2
        self.enc1 = nn.Sequential(
            nn.Conv2d(1, 64, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.enc2 = nn.Sequential(
            nn.Conv2d(64, 128, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True))
        self.enc3 = nn.Sequential(
            nn.Conv2d(128, 256, kernel_size=3, stride=(2, 1), padding=1),
            nn.BatchNorm2d(256), nn.ReLU(inplace=True))

        # LSTM: processes each frequency sub-band across time
        self.lstm = nn.LSTM(
            input_size=256, hidden_size=256, num_layers=2,
            batch_first=True, dropout=0.1)

        # CNN Decoder: upsample frequency back with transposed convolutions
        self.dec3 = nn.Sequential(
            nn.ConvTranspose2d(256, 128, kernel_size=3, stride=(2, 1), padding=1, output_padding=(1, 0)),
            nn.BatchNorm2d(128), nn.ReLU(inplace=True))
        self.dec2 = nn.Sequential(
            nn.ConvTranspose2d(128, 64, kernel_size=3, stride=(2, 1), padding=1, output_padding=(1, 0)),
            nn.BatchNorm2d(64), nn.ReLU(inplace=True))
        self.dec1 = nn.Sequential(
            nn.ConvTranspose2d(64, 32, kernel_size=3, stride=(2, 1), padding=1, output_padding=(1, 0)),
            nn.BatchNorm2d(32), nn.ReLU(inplace=True))

        # Final 1x1 conv + sigmoid mask
        self.mask_conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.Sigmoid())

        self._init_weights()

    def _init_weights(self):
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.kaiming_normal_(m.weight, nonlinearity='relu')
                if m.bias is not None:
                    nn.init.zeros_(m.bias)
            elif isinstance(m, nn.BatchNorm2d):
                nn.init.ones_(m.weight)
                nn.init.zeros_(m.bias)
            elif isinstance(m, nn.LSTM):
                for name, param in m.named_parameters():
                    if 'weight' in name:
                        nn.init.xavier_normal_(param)
                    elif 'bias' in name:
                        nn.init.zeros_(param)

    def forward(self, x):
        # x: (B, 1, n_freq, T)
        B, _, F_orig, T_orig = x.shape

        # Encode (downsample freq: 257->129->65->33)
        e1 = self.enc1(x)     # (B, 64, 129, T)
        e2 = self.enc2(e1)    # (B, 128, 65, T)
        e3 = self.enc3(e2)    # (B, 256, 33, T)

        # LSTM: per-frequency-bin processing across time
        B2, C, Fenc, T = e3.shape
        # Reshape: (B, C, Fenc, T) -> (B*Fenc, T, C)
        lstm_in = e3.permute(0, 2, 3, 1).reshape(B2 * Fenc, T, C)
        lstm_out, _ = self.lstm(lstm_in)  # (B*Fenc, T, C)
        # Reshape back: (B*Fenc, T, C) -> (B, C, Fenc, T)
        h = lstm_out.reshape(B2, Fenc, T, C).permute(0, 3, 1, 2)  # (B, 256, 33, T)

        # Decode (upsample freq: 33->65->129->257)
        d3 = self.dec3(h)      # (B, 128, 65, T)
        d2 = self.dec2(d3)     # (B, 64, 129, T)
        d1 = self.dec1(d2)     # (B, 32, 257, T)

        # Crop/pad to match original freq dimension
        if d1.shape[2] != F_orig:
            d1 = F.interpolate(d1, size=(F_orig, T_orig), mode='bilinear', align_corners=False)

        mask = self.mask_conv(d1).squeeze(1)  # (B, n_freq, T)
        return mask


# Quick test
model = CRNBaseline(n_freq=N_FREQ).to(device)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'CRNBaseline: {total_params:,} params ({total_params/1e6:.2f}M)')

with torch.no_grad():
    dummy = torch.randn(2, 1, N_FREQ, 188).to(device)
    mask = model(dummy)
    print(f'Input: {dummy.shape} -> Mask: {mask.shape}')
    assert mask.shape == (2, N_FREQ, 188), f'Shape mismatch: {mask.shape}'
    assert mask.min().item() >= 0 and mask.max().item() <= 1
    print(f'Mask range: [{mask.min().item():.4f}, {mask.max().item():.4f}]')
    print('Forward pass OK')

# Architecture breakdown
enc_p = sum(p.numel() for n, p in model.named_parameters() if 'enc' in n)
lstm_p = sum(p.numel() for n, p in model.named_parameters() if 'lstm' in n)
dec_p = sum(p.numel() for n, p in model.named_parameters() if 'dec' in n or 'mask_conv' in n)
print(f'  Encoder:  {enc_p:>10,} ({enc_p/total_params*100:.1f}%)')
print(f'  LSTM:     {lstm_p:>10,} ({lstm_p/total_params*100:.1f}%)')
print(f'  Decoder:  {dec_p:>10,} ({dec_p/total_params*100:.1f}%)')

CRNBaseline: 1,811,009 params (1.81M)
Input: torch.Size([2, 1, 257, 188]) -> Mask: torch.Size([2, 257, 188])
Mask range: [0.0175, 0.9935]
Forward pass OK
  Encoder:     370,560 (20.5%)
  LSTM:      1,052,672 (58.1%)
  Decoder:     387,777 (21.4%)


In [5]:
# ============================================================================
# Cell 5: SI-SDR utility
# ============================================================================
def si_sdr(estimate, reference):
    ref = reference - reference.mean()
    est = estimate  - estimate.mean()
    dot = torch.sum(ref * est)
    s_target = dot * ref / (torch.sum(ref**2) + 1e-8)
    e_noise  = est - s_target
    return 10 * torch.log10(torch.sum(s_target**2) / (torch.sum(e_noise**2) + 1e-8) + 1e-8)

print('si_sdr defined')

si_sdr defined


## Training
L1 loss on log-magnitude, Adam lr=1e-3 with ReduceLROnPlateau.
Batch=16, 25 epochs, patience=10.

In [6]:
# ============================================================================
# Cell 6: Training setup
# ============================================================================
MAX_EPOCHS    = 25
LR            = 1e-3
BATCH         = 16
PATIENCE      = 10
CKPT          = 'crn_baseline_best.pth'

# Re-init model fresh
model = CRNBaseline(n_freq=N_FREQ).to(device)
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f'Params: {total_params:,}')

# Datasets
full_train = STFTSpeechDataset(noisy_train, clean_train)
n_val   = int(0.1 * len(full_train))
n_train = len(full_train) - n_val
train_ds, val_ds = torch.utils.data.random_split(
    full_train, [n_train, n_val], generator=torch.Generator().manual_seed(42))
test_ds = STFTSpeechDataset(noisy_test, clean_test)

train_loader = DataLoader(train_ds, batch_size=BATCH, shuffle=True,
                          num_workers=0, drop_last=True)
val_loader   = DataLoader(val_ds,   batch_size=BATCH, shuffle=False, num_workers=0)

optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-5)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, 'min', factor=0.5, patience=5)

print(f'Train:{n_train} Val:{n_val} Test:{len(test_ds)} | BS={BATCH} LR={LR}')

Params: 1,811,009
Train:6300 Val:700 Test:105 | BS=16 LR=0.001


In [7]:
# ============================================================================
# Cell 7: Training loop
# ============================================================================
history = {'train_loss': [], 'val_loss': []}
best_val = float('inf')
patience_ctr = 0
t0 = time.time()

for epoch in range(1, MAX_EPOCHS + 1):
    # --- Train ---
    model.train()
    train_losses = []
    for batch in tqdm(train_loader, desc=f'Ep{epoch}/{MAX_EPOCHS}', leave=False):
        noisy_mag = batch['noisy_mag'].to(device)
        clean_mag = batch['clean_mag'].to(device)
        inp  = torch.log1p(noisy_mag).unsqueeze(1)   # (B, 1, 257, T)
        mask = model(inp)                              # (B, 257, T)
        enhanced_mag = mask * noisy_mag
        loss = F.l1_loss(torch.log1p(enhanced_mag), torch.log1p(clean_mag))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 5.0)
        optimizer.step()
        train_losses.append(loss.item())

    # --- Validate ---
    model.eval()
    val_losses = []
    with torch.no_grad():
        for batch in val_loader:
            noisy_mag = batch['noisy_mag'].to(device)
            clean_mag = batch['clean_mag'].to(device)
            inp  = torch.log1p(noisy_mag).unsqueeze(1)
            mask = model(inp)
            enhanced_mag = mask * noisy_mag
            loss = F.l1_loss(torch.log1p(enhanced_mag), torch.log1p(clean_mag))
            val_losses.append(loss.item())

    tr_loss = np.mean(train_losses)
    va_loss = np.mean(val_losses)
    history['train_loss'].append(tr_loss)
    history['val_loss'].append(va_loss)
    scheduler.step(va_loss)

    elapsed = time.time() - t0
    lr_now  = optimizer.param_groups[0]['lr']
    line = f'Ep{epoch:02d} tr={tr_loss:.4f} va={va_loss:.4f} lr={lr_now:.1e} [{elapsed:.0f}s]'

    if va_loss < best_val:
        best_val = va_loss
        patience_ctr = 0
        torch.save({'epoch': epoch, 'model': model.state_dict(),
                     'val_loss': float(va_loss)}, CKPT)
        print(f'{line}  SAVED best={va_loss:.4f}')
    else:
        patience_ctr += 1
        print(f'{line}  no improve ({patience_ctr}/{PATIENCE})')

    if epoch % 5 == 0:
        torch.save(model.state_dict(), f'ckpt_ep{epoch}.pth')

    if patience_ctr >= PATIENCE:
        print(f'Early stopping at epoch {epoch}')
        break

best_ep = history['val_loss'].index(min(history['val_loss'])) + 1
print(f'\nDONE best_ep={best_ep} best_val={best_val:.4f} time={time.time()-t0:.0f}s')

Ep1/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep01 tr=0.1042 va=0.1035 lr=1.0e-03 [274s]  SAVED best=0.1035


Ep2/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep02 tr=0.1031 va=0.1037 lr=1.0e-03 [545s]  no improve (1/10)


Ep3/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep03 tr=0.1029 va=0.1030 lr=1.0e-03 [815s]  SAVED best=0.1030


Ep4/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep04 tr=0.1028 va=0.1023 lr=1.0e-03 [1086s]  SAVED best=0.1023


Ep5/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep05 tr=0.1027 va=0.1026 lr=1.0e-03 [1357s]  no improve (1/10)


Ep6/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep06 tr=0.1022 va=0.1031 lr=1.0e-03 [1628s]  no improve (2/10)


Ep7/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep07 tr=0.1028 va=0.1027 lr=1.0e-03 [1900s]  no improve (3/10)


Ep8/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep08 tr=0.1021 va=0.1026 lr=1.0e-03 [2172s]  no improve (4/10)


Ep9/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep09 tr=0.1023 va=0.1029 lr=1.0e-03 [2444s]  no improve (5/10)


Ep10/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep10 tr=0.1021 va=0.1023 lr=5.0e-04 [2716s]  no improve (6/10)


Ep11/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep11 tr=0.1019 va=0.1031 lr=5.0e-04 [2988s]  no improve (7/10)


Ep12/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep12 tr=0.1022 va=0.1026 lr=5.0e-04 [3261s]  no improve (8/10)


Ep13/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep13 tr=0.1020 va=0.1029 lr=5.0e-04 [3536s]  no improve (9/10)


Ep14/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep14 tr=0.1018 va=0.1020 lr=5.0e-04 [3809s]  SAVED best=0.1020


Ep15/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep15 tr=0.1018 va=0.1021 lr=5.0e-04 [4083s]  no improve (1/10)


Ep16/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep16 tr=0.1017 va=0.1026 lr=5.0e-04 [4357s]  no improve (2/10)


Ep17/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep17 tr=0.1019 va=0.1021 lr=5.0e-04 [4627s]  no improve (3/10)


Ep18/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep18 tr=0.1018 va=0.1017 lr=5.0e-04 [4895s]  SAVED best=0.1017


Ep19/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep19 tr=0.1019 va=0.1024 lr=5.0e-04 [5162s]  no improve (1/10)


Ep20/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep20 tr=0.1016 va=0.1016 lr=5.0e-04 [5430s]  SAVED best=0.1016


Ep21/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep21 tr=0.1017 va=0.1015 lr=5.0e-04 [5698s]  SAVED best=0.1015


Ep22/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep22 tr=0.1018 va=0.1009 lr=5.0e-04 [5964s]  SAVED best=0.1009


Ep23/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep23 tr=0.1016 va=0.1032 lr=5.0e-04 [6230s]  no improve (1/10)


Ep24/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep24 tr=0.1017 va=0.1026 lr=5.0e-04 [6498s]  no improve (2/10)


Ep25/25:   0%|          | 0/393 [00:00<?, ?it/s]

Ep25 tr=0.1015 va=0.1011 lr=5.0e-04 [6765s]  no improve (3/10)

DONE best_ep=22 best_val=0.1009 time=6765s


## Results

In [8]:
# ============================================================================
# Cell 8: Training curves
# ============================================================================
fig, ax = plt.subplots(1, 1, figsize=(10, 5))
eps = range(1, len(history['train_loss']) + 1)
ax.plot(eps, history['train_loss'], 'b-o', label='Train', ms=3)
ax.plot(eps, history['val_loss'], 'r-s', label='Val', ms=3)
ax.axvline(best_ep, color='g', ls='--', alpha=0.7, label=f'Best (ep{best_ep})')
ax.set_xlabel('Epoch')
ax.set_ylabel('L1 Loss (log-magnitude)')
ax.set_title('CRN Baseline (Fixed STFT) — Training Curves')
ax.legend()
ax.grid(True, alpha=0.3)
plt.tight_layout()
plt.savefig('training_curves.png', dpi=150)
plt.show()
print('Saved training_curves.png')

Saved training_curves.png


## Evaluation
**REAL** PESQ / STOI / SI-SDR on 105 test samples via ISTFT waveform reconstruction.
No more estimated metrics!

In [9]:
# ============================================================================
# Cell 9: Evaluation — REAL PESQ / STOI / SI-SDR
# ============================================================================
ckpt = torch.load(CKPT, map_location=device, weights_only=False)
model.load_state_dict(ckpt['model'])
model.eval()
print(f'Loaded: epoch={ckpt["epoch"]}, val_loss={ckpt["val_loss"]:.4f}')

window_eval = torch.hann_window(N_FFT).to(device)

pesq_noisy_list, pesq_enh_list   = [], []
stoi_noisy_list, stoi_enh_list   = [], []
sisdr_noisy_list, sisdr_enh_list = [], []

for i in tqdm(range(len(test_ds)), desc='Eval'):
    s           = test_ds[i]
    noisy_mag   = s['noisy_mag'].unsqueeze(0).to(device)
    noisy_phase = s['noisy_phase'].unsqueeze(0).to(device)
    clean_np    = s['clean_wav'].numpy()
    noisy_np    = s['noisy_wav'].numpy()

    with torch.no_grad():
        inp     = torch.log1p(noisy_mag).unsqueeze(1)
        mask    = model(inp)
        enh_mag = (mask * noisy_mag).squeeze(0)

    enh_stft = enh_mag * torch.exp(1j * noisy_phase.squeeze(0))
    enh_wav  = torch.istft(enh_stft, N_FFT, HOP_LENGTH,
                           window=window_eval, length=MAX_LEN)
    enh_np   = enh_wav.cpu().numpy()

    try:
        pesq_noisy_list.append(pesq_metric(SR, clean_np, noisy_np, 'wb'))
        pesq_enh_list.append(  pesq_metric(SR, clean_np, enh_np,   'wb'))
    except Exception as e:
        print(f'  PESQ err {i}: {e}')

    stoi_noisy_list.append(stoi_metric(clean_np, noisy_np, SR, extended=False))
    stoi_enh_list.append(  stoi_metric(clean_np, enh_np,   SR, extended=False))

    c_t = torch.from_numpy(clean_np).float()
    n_t = torch.from_numpy(noisy_np).float()
    e_t = torch.from_numpy(enh_np).float()
    sisdr_noisy_list.append(si_sdr(n_t, c_t).item())
    sisdr_enh_list.append(  si_sdr(e_t, c_t).item())

avg = lambda lst: float(np.mean(lst)) if lst else 0.0
avg_pesq_n,  avg_pesq_e  = avg(pesq_noisy_list),  avg(pesq_enh_list)
avg_stoi_n,  avg_stoi_e  = avg(stoi_noisy_list),  avg(stoi_enh_list)
avg_sisdr_n, avg_sisdr_e = avg(sisdr_noisy_list), avg(sisdr_enh_list)

print(f'\nResults on {len(test_ds)} test files:')
print(f'  PESQ  : noisy={avg_pesq_n:.3f}  enhanced={avg_pesq_e:.3f}  d={avg_pesq_e-avg_pesq_n:+.3f}')
print(f'  STOI  : noisy={avg_stoi_n:.3f}  enhanced={avg_stoi_e:.3f}  d={avg_stoi_e-avg_stoi_n:+.4f}')
print(f'  SI-SDR: noisy={avg_sisdr_n:.2f}dB  enh={avg_sisdr_e:.2f}dB  d={avg_sisdr_e-avg_sisdr_n:+.2f}dB')

Loaded: epoch=22, val_loss=0.1009


Eval:   0%|          | 0/105 [00:00<?, ?it/s]


Results on 105 test files:
  PESQ  : noisy=1.126  enhanced=1.144  d=+0.017
  STOI  : noisy=0.215  enhanced=0.336  d=+0.1209
  SI-SDR: noisy=-44.04dB  enh=-41.03dB  d=+3.01dB


## Visualization
Spectrogram comparison: noisy vs enhanced vs clean, plus predicted mask.

In [10]:
# ============================================================================
# Cell 10: Spectrogram comparison
# ============================================================================
sample = test_ds[0]
noisy_mag_s = sample['noisy_mag'].unsqueeze(0).to(device)

with torch.no_grad():
    inp_s  = torch.log1p(noisy_mag_s).unsqueeze(1)
    mask_s = model(inp_s)
    enh_mag_s = (mask_s * noisy_mag_s).squeeze(0).cpu()

noisy_spec = sample['noisy_mag'].numpy()
clean_spec = sample['clean_mag'].numpy()
enh_spec   = enh_mag_s.numpy()
mask_np    = mask_s.squeeze(0).cpu().numpy()

fig, axes = plt.subplots(2, 2, figsize=(14, 10))

for ax, spec, title in [
    (axes[0,0], np.log1p(noisy_spec), 'Noisy Input'),
    (axes[0,1], np.log1p(clean_spec), 'Clean Target'),
    (axes[1,0], np.log1p(enh_spec),   'Enhanced (CRN)'),
    (axes[1,1], mask_np,              'Predicted Mask'),
]:
    im = ax.imshow(spec, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title, fontsize=13)
    ax.set_xlabel('Time frame')
    ax.set_ylabel('Frequency bin')
    plt.colorbar(im, ax=ax, fraction=0.046)

plt.suptitle('CRN Baseline: Spectrogram Comparison (Test Sample 0)', fontsize=14)
plt.tight_layout()
plt.savefig('spectrogram_comparison.png', dpi=150)
plt.show()
print('Saved spectrogram_comparison.png')

Saved spectrogram_comparison.png


In [11]:
# ============================================================================
# Cell 11: Summary JSON + comparison table
# ============================================================================
summary = {
    'review': 'R1_CRN_Fixed',
    'model': 'CRNBaseline',
    'approach': 'Conv-Recurrent Network (STFT-based)',
    'pipeline': 'STFT (n_fft=512, hop=256) -> CRN mask -> ISTFT',
    'params': total_params,
    'checkpoint': {'epoch': int(ckpt['epoch']), 'val_loss': float(ckpt['val_loss'])},
    'test_samples': len(test_ds),
    'metrics': {
        'pesq_noisy':        round(avg_pesq_n, 3),
        'pesq_enhanced':     round(avg_pesq_e, 3),
        'stoi_noisy':        round(avg_stoi_n, 4),
        'stoi_enhanced':     round(avg_stoi_e, 4),
        'sisdr_noisy_dB':    round(avg_sisdr_n, 2),
        'sisdr_enhanced_dB': round(avg_sisdr_e, 2),
    },
    'history': history,
}
with open('crn_baseline_summary.json', 'w') as f:
    json.dump(summary, f, indent=2)
print('Saved crn_baseline_summary.json')

W = 80
print('\n' + '='*W)
print(f'{"Model":>15} {"Pipeline":>20} {"PESQ":>8} {"STOI":>7} {"SI-SDR":>10} {"Params":>10}')
print('='*W)
print(f'{"Noisy":>15} {"---":>20} {avg_pesq_n:>8.3f} {avg_stoi_n:>7.3f} {avg_sisdr_n:>9.2f}dB {"---":>10}')
print(f'{"CRN (fixed)":>15} {"LSTM+STFT":>20} {avg_pesq_e:>8.3f} {avg_stoi_e:>7.3f} {avg_sisdr_e:>9.2f}dB {total_params/1e6:>8.2f}M')
print('='*W)
dp = avg_pesq_e - avg_pesq_n
ds = avg_stoi_e - avg_stoi_n
dd = avg_sisdr_e - avg_sisdr_n
print(f'CRN vs noisy: dPESQ={dp:+.3f}  dSTOI={ds:+.4f}  dSI-SDR={dd:+.2f}dB')

Saved crn_baseline_summary.json

          Model             Pipeline     PESQ    STOI     SI-SDR     Params
          Noisy                  ---    1.126   0.215    -44.04dB        ---
    CRN (fixed)            LSTM+STFT    1.144   0.336    -41.03dB     1.81M
CRN vs noisy: dPESQ=+0.017  dSTOI=+0.1209  dSI-SDR=+3.01dB
