In [1]:
# ============================================================================
# CELL 1: Installation & Imports
# ============================================================================

# Run this first in Colab:
!pip install -q einops soundfile torchaudio pystoi pesq

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from torch.optim import AdamW
from torch.cuda.amp import autocast, GradScaler

import torchaudio
import soundfile as sf
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.notebook import tqdm
import warnings
import os
import time
import tarfile
from datetime import datetime
from IPython.display import Audio, display
import json
import gc

warnings.filterwarnings('ignore')

# Memory management
os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:128,expandable_segments:True'

def clear_gpu_memory():
    """Clear GPU memory cache"""
    gc.collect()
    torch.cuda.empty_cache()
    if torch.cuda.is_available():
        torch.cuda.synchronize()

print("‚úÖ Imports completed!")

  Preparing metadata (setup.py) ... [?25l[?25hdone
  Building wheel for pesq (setup.py) ... [?25l[?25hdone
‚úÖ Imports completed!


In [2]:
# ============================================================================
# CELL 2: Setup and Configuration
# ============================================================================

from google.colab import drive

# ‚úÖ Safe Drive mounting
def mount_drive_safe():
    """Mount Drive safely, skip if already mounted"""
    if not os.path.exists('/content/drive/MyDrive'):
        print("üîí Mounting Google Drive...")
        drive.mount('/content/drive')
        print("‚úÖ Drive mounted!")
    else:
        print("‚úÖ Drive already mounted, skipping...")

mount_drive_safe()

# Project setup
PROJECT_ROOT = '/content/drive/MyDrive/PPSI-Net-Split-v3'
os.makedirs(f'{PROJECT_ROOT}/data', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/data/LibriSpeech', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/checkpoints', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/checkpoints/phase', exist_ok=True)
os.makedirs(f'{PROJECT_ROOT}/logs', exist_ok=True)

# ‚úÖ Define variables BEFORE use
tar_path = f'{PROJECT_ROOT}/data/train-clean-100.tar.gz'
extract_to = f'{PROJECT_ROOT}/data'

# ‚úÖ Auto-extract LibriSpeech
if os.path.exists(tar_path):
    target_dir = f'{extract_to}/LibriSpeech/train-clean-100'
    if not os.path.exists(target_dir):
        print("üì¶ Extracting LibriSpeech dataset...")
        print("   This may take 5-10 minutes...")
        try:
            with tarfile.open(tar_path, 'r:gz') as tar:
                tar.extractall(path=extract_to)
            print("‚úÖ Extraction complete!")
        except Exception as e:
            print(f"‚ùå Extraction failed: {e}")
    else:
        print("‚úÖ LibriSpeech already extracted!")
else:
    print(f"‚ö†Ô∏è  train-clean-100.tar.gz not found at: {tar_path}")
    print("   Please download from: https://www.openslr.org/12/")

# Verify dataset
librispeech_path = f'{extract_to}/LibriSpeech/train-clean-100'
if os.path.exists(librispeech_path):
    flac_files = list(Path(librispeech_path).rglob('*.flac'))
    print(f"‚úÖ LibriSpeech verified: {len(flac_files)} .flac files")
else:
    print(f"‚ùå LibriSpeech not found at: {librispeech_path}")

device = 'cuda' if torch.cuda.is_available() else 'cpu'
torch.manual_seed(42)

# ‚úÖ Configuration
class Config:
    """Configuration for Phase CNN (Paper Architecture)"""

    # Paths
    librispeech_root = f'{PROJECT_ROOT}/data/LibriSpeech/train-clean-100'
    save_dir = f'{PROJECT_ROOT}/checkpoints'

    # Audio parameters
    n_fft = 1024
    hop_length = 256
    win_length = 1024
    sample_rate = 16000
    duration = 2.0  # üî• CRITICAL FIX: Reduced from 5.0 to 2.0

    # Model parameters
    freq_bins = 513
    time_frames = 126

    # Phase CNN (Paper Figure 2)
    phase_embed_dim = 64  # üî• ENHANCED: Increased from 50 to 64
    phase_body_depth = 7  # üî• ENHANCED: Increased from 5 to 7

    # üî• VAD parameters
    use_vad = True
    vad_threshold_db = -5      # üî• OPTIMIZED: Stricter from -35 to -30
    vad_mode = "vad"

    # üî• Gradient clipping
    use_grad_clip = True
    max_grad_norm = 0.5        # üî• OPTIMIZED: Stricter from 1.0 to 0.5

    # üî• NEW: Minimum Phase Initialization (CRITICAL!)
    use_min_phase_init = True   # Enable minimum phase initialization

    # üî• NEW: Frequency Weighting
    use_freq_weighting = True
    low_freq_weight = 1.5
    high_freq_weight = 0.8

    # üî• NEW: Learning Rate Warmup
    use_lr_warmup = True
    warmup_epochs = 10

    # Inference options
    use_strided_inference = True
    use_efficient_solver = True

    # Training parameters (üî• OPTIMIZED)
    batch_size = 32            # üî• Reduced from 32 to 16
    num_epochs = 150           # üî• Increased from 100 to 150
    learning_rate = 5e-4       # üî• CRITICAL: Reduced from 2e-3 to 5e-4
    weight_decay = 1e-4        # üî• Increased from 1e-5

    # Optimization
    use_mixed_precision = True

    device = device

print("\n" + "="*70)
print("‚úÖ Configuration ready!")
print("="*70)
print(f"üìä Dataset: LibriSpeech train-clean-100")
print(f"üéµ Sample rate: {Config.sample_rate} Hz")
print(f"‚è±Ô∏è  Duration: {Config.duration}s (üî• OPTIMIZED)")
print(f"\nüîß Model Architecture:")
print(f"  ‚Ä¢ Phase CNN embed dim: {Config.phase_embed_dim} (üî• Enhanced)")
print(f"  ‚Ä¢ Phase CNN body depth: {Config.phase_body_depth} (üî• Enhanced)")
print(f"  ‚Ä¢ Expected parameters: ~12k")
print(f"\nüî• Critical Fixes Applied:")
print(f"  ‚Ä¢ Minimum Phase Init: {Config.use_min_phase_init}")
print(f"  ‚Ä¢ Frequency Weighting: {Config.use_freq_weighting}")
print(f"  ‚Ä¢ LR Warmup: {Config.use_lr_warmup} ({Config.warmup_epochs} epochs)")
print(f"\nüî• Silence Handling:")
print(f"  ‚Ä¢ VAD enabled: {Config.use_vad}")
print(f"  ‚Ä¢ VAD threshold: {Config.vad_threshold_db} dB (üî• Stricter)")
print(f"  ‚Ä¢ Loss mode: {Config.vad_mode}")
print(f"\n‚ö° Training Optimization:")
print(f"  ‚Ä¢ Gradient clipping: {Config.use_grad_clip}")
print(f"  ‚Ä¢ Max grad norm: {Config.max_grad_norm} (üî• Stricter)")
print(f"\nüìà Training Settings:")
print(f"  ‚Ä¢ Batch size: {Config.batch_size} (üî• Optimized)")
print(f"  ‚Ä¢ Epochs: {Config.num_epochs}")
print(f"  ‚Ä¢ Learning rate: {Config.learning_rate} (üî• Conservative)")
print(f"  ‚Ä¢ Device: {device}")
print("="*70)


üîí Mounting Google Drive...
Mounted at /content/drive
‚úÖ Drive mounted!
‚úÖ LibriSpeech already extracted!
‚úÖ LibriSpeech verified: 28539 .flac files

‚úÖ Configuration ready!
üìä Dataset: LibriSpeech train-clean-100
üéµ Sample rate: 16000 Hz
‚è±Ô∏è  Duration: 2.0s (üî• OPTIMIZED)

üîß Model Architecture:
  ‚Ä¢ Phase CNN embed dim: 64 (üî• Enhanced)
  ‚Ä¢ Phase CNN body depth: 7 (üî• Enhanced)
  ‚Ä¢ Expected parameters: ~12k

üî• Critical Fixes Applied:
  ‚Ä¢ Minimum Phase Init: True
  ‚Ä¢ Frequency Weighting: True
  ‚Ä¢ LR Warmup: True (10 epochs)

üî• Silence Handling:
  ‚Ä¢ VAD enabled: True
  ‚Ä¢ VAD threshold: -5 dB (üî• Stricter)
  ‚Ä¢ Loss mode: vad

‚ö° Training Optimization:
  ‚Ä¢ Gradient clipping: True
  ‚Ä¢ Max grad norm: 0.5 (üî• Stricter)

üìà Training Settings:
  ‚Ä¢ Batch size: 32 (üî• Optimized)
  ‚Ä¢ Epochs: 150
  ‚Ä¢ Learning rate: 0.0005 (üî• Conservative)
  ‚Ä¢ Device: cuda


In [17]:
# ============================================================================
# CELL 3: Dataset (LibriSpeech)
# ============================================================================

class LibriSpeechSpectrogramDataset(Dataset):
    """LibriSpeech Dataset for Phase Derivative Training"""

    def __init__(self, root_dir, subset='train', n_fft=1024, hop_length=256,
                 win_length=1024, sample_rate=16000, duration=2.0, train_split=0.9):

        self.root_dir = Path(root_dir)
        self.subset = subset
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        self.sample_rate = sample_rate
        self.duration = duration
        self.segment_length = int(duration * sample_rate)

        if not self.root_dir.exists():
            raise FileNotFoundError(f"Directory not found: {self.root_dir}")

        self.audio_files = sorted(list(self.root_dir.rglob('*.flac')))

        if len(self.audio_files) == 0:
            raise FileNotFoundError(f"No .flac files found in {self.root_dir}")

        # Train/val split
        n_train = int(len(self.audio_files) * train_split)
        if subset == 'train':
            self.audio_files = self.audio_files[:n_train]
        else:
            self.audio_files = self.audio_files[n_train:]

        print(f"  {subset} set: {len(self.audio_files)} files")

        self.window = torch.hann_window(self.win_length)

    def __len__(self):
        return len(self.audio_files)

    def _load_audio(self, filepath):
        waveform, sr = torchaudio.load(filepath)
        if waveform.shape[0] > 1:
            waveform = torch.mean(waveform, dim=0, keepdim=True)
        if sr != self.sample_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            waveform = resampler(waveform)
        return waveform.squeeze(0)

    def _compute_stft(self, waveform):
        stft = torch.stft(
            waveform,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=self.window,
            return_complex=True,
            center=True,
            normalized=False
        )
        return stft

    def _extract_segment(self, waveform):
        if len(waveform) >= self.segment_length:
            max_start = len(waveform) - self.segment_length
            start = np.random.randint(0, max_start + 1)
            segment = waveform[start:start + self.segment_length]
        else:
            segment = F.pad(waveform, (0, self.segment_length - len(waveform)))
        return segment

    @staticmethod
    def _wrap_phase(phase):
        """Wrap phase to [-œÄ, œÄ)"""
        return torch.atan2(torch.sin(phase), torch.cos(phase))

    def __getitem__(self, idx):
        waveform = self._load_audio(self.audio_files[idx])
        segment = self._extract_segment(waveform)
        stft = self._compute_stft(segment).T

        amplitude = torch.abs(stft)
        phase = torch.angle(stft)
        log_mag = torch.log1p(amplitude)

        # FPD: Frequency Phase Difference
        fpd = torch.zeros_like(phase)
        fpd[:, 1:] = self._wrap_phase(phase[:, 1:] - phase[:, :-1])

        # TPD: Time Phase Difference
        tpd = torch.zeros_like(phase)
        tpd[1:, :] = self._wrap_phase(phase[1:, :] - phase[:-1, :])

        # BPD: Baseband Phase Delay
        freq_indices = torch.arange(phase.shape[1], dtype=torch.float32)
        linear_phase = 2 * np.pi * freq_indices * self.hop_length / self.n_fft
        bpd = self._wrap_phase(tpd - linear_phase.unsqueeze(0))

        return {
            'log_mag': log_mag.unsqueeze(0),
            'amplitude_abs': amplitude.unsqueeze(0),
            'phase': phase.unsqueeze(0),
            'fpd': fpd.unsqueeze(0),
            'tpd': tpd.unsqueeze(0),
            'bpd': bpd.unsqueeze(0),
            'waveform': segment
        }

print("‚úÖ Dataset class ready!")

‚úÖ Dataset class ready!


In [18]:
# ============================================================================
# CELL 4: Phase CNN
# ============================================================================

class ScaledSoftsign(nn.Module):
    def __init__(self, scale=np.pi):
        super().__init__()
        self.scale = scale
    def forward(self, x):
        softsign = x / (1 + torch.abs(x))
        tanh = torch.tanh(x)
        return self.scale * (0.75 * softsign + 0.25 * tanh)

class FreqGatedConv(nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=3):
        super().__init__()
        pad = kernel_size // 2
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=(1, 2*pad+1), padding=(0, pad))
        self.gate_conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
    def forward(self, x):
        return self.conv(x) * torch.sigmoid(self.gate_conv(x))

class PhaseCNN(nn.Module):
    """Input: log magnitude [B,1,T,F]  ‚Üí  Output: fpd,bpd [B,1,T,F]"""
    def __init__(self, embed_dim=50, body_depth=5, strided=False):
        super().__init__()
        self.strided = strided
        temporal_stride = 2 if strided else 1

        self.stem = nn.Sequential(
            nn.BatchNorm2d(1),
            nn.Conv2d(1, 16, kernel_size=(7,3), stride=(temporal_stride,1), padding=(0,1)),
            nn.LeakyReLU(0.1),
        )
        self.stem_gated = FreqGatedConv(16, 10, kernel_size=3)

        self.body_blocks = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(10,10,kernel_size=1),
                nn.BatchNorm2d(10),
                nn.LeakyReLU(0.1)
            ) for _ in range(body_depth)
        ])

        self.bottleneck = nn.Sequential(
            nn.Conv2d(10,20,kernel_size=1),
            nn.BatchNorm2d(20),
            nn.LeakyReLU(0.1),
            FreqGatedConv(20, embed_dim, kernel_size=3)
        )

        out_channels = 2 if strided else 1
        self.fpd_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//2, kernel_size=1),
            nn.Conv2d(embed_dim//2, out_channels, kernel_size=1)
        )
        self.bpd_head = nn.Sequential(
            nn.Conv2d(embed_dim, embed_dim//2, kernel_size=1),
            nn.Conv2d(embed_dim//2, out_channels, kernel_size=1)
        )

        self.fpd_activation = ScaledSoftsign(scale=np.pi)
        self.bpd_activation = ScaledSoftsign(scale=np.pi)

    def forward(self, log_mag):
        x = F.pad(log_mag, pad=(0, 0, 6, 0))
        T_in = log_mag.shape[2]

        x = self.stem(x)
        x = self.stem_gated(x)
        for blk in self.body_blocks:
            x = x + blk(x)
        x = self.bottleneck(x)

        fpd = self.fpd_activation(self.fpd_head(x))
        bpd = self.bpd_activation(self.bpd_head(x))

        if self.strided:
            B, C2, T_half, Freq2 = fpd.shape
            fpd = fpd.permute(0,2,1,3).reshape(B,1,T_half*2,Freq2)
            bpd = bpd.permute(0,2,1,3).reshape(B,1,T_half*2,Freq2)

        T_out = fpd.shape[2]
        if T_out > T_in:
            fpd = fpd[:, :, :T_in, :]
            bpd = bpd[:, :, :T_in, :]
        elif T_out < T_in:
            pad_t = T_in - T_out
            fpd = torch.nn.functional.pad(fpd, (0,0,0,pad_t))
            bpd = torch.nn.functional.pad(bpd, (0,0,0,pad_t))

        return fpd, bpd

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

print("‚úÖ Phase CNN ready.")


‚úÖ Phase CNN ready.


In [19]:
# ============================================================================
# CELL 5: Voice Activity Detection (VAD)
# ============================================================================

class SimpleVAD(nn.Module):
    """üî• Simple energy-based Voice Activity Detection"""

    def __init__(self,
                 energy_threshold_db=-40,
                 min_speech_duration_ms=100,
                 sample_rate=16000,
                 hop_length=256):
        super().__init__()
        self.energy_threshold_db = energy_threshold_db
        self.sample_rate = sample_rate
        self.hop_length = hop_length
        self.min_speech_frames = int(min_speech_duration_ms * sample_rate / (1000 * hop_length))

    def forward(self, magnitude):
        """
        Args:
            magnitude: [B, 1, T, F] magnitude spectrogram
        Returns:
            vad_mask: [B, 1, T, 1] VAD mask (1=speech, 0=silence)
        """
        B, C, T, F = magnitude.shape

        # Frame energy
        frame_energy = magnitude.sum(dim=-1, keepdim=True)  # [B, 1, T, 1]

        # Convert to dB
        eps = 1e-8
        frame_energy_db = 20 * torch.log10(frame_energy + eps)

        # Adaptive threshold
        energy_sorted = torch.sort(frame_energy_db.view(B, -1), dim=1)[0]
        noise_floor = energy_sorted[:, int(0.1 * energy_sorted.shape[1])].view(B, 1, 1, 1)
        threshold = noise_floor + self.energy_threshold_db

        # Generate mask
        vad_mask = (frame_energy_db > threshold).float()

        # Morphological filtering
        vad_mask = self._morphological_filter(vad_mask)

        return vad_mask

    def _morphological_filter(self, mask):
        """Remove short speech segments and fill short gaps"""
        B, C, T, _ = mask.shape

        for b in range(B):
            m = mask[b, 0, :, 0]

            diff = torch.cat([torch.tensor([0.0], device=m.device), torch.diff(m)])
            starts = torch.where(diff > 0)[0]
            ends = torch.where(diff < 0)[0]

            # Handle boundaries
            if len(starts) > 0 and (len(ends) == 0 or starts[0] < ends[0]):
                ends = torch.cat([ends, torch.tensor([T], device=m.device)])
            if len(ends) > 0 and (len(starts) == 0 or ends[0] < starts[0]):
                starts = torch.cat([torch.tensor([0], device=m.device), starts])

            # Remove short segments
            for i in range(len(starts)):
                if i < len(ends):
                    duration = ends[i] - starts[i]
                    if duration < self.min_speech_frames:
                        m[starts[i]:ends[i]] = 0

            mask[b, 0, :, 0] = m

        return mask

print("‚úÖ VAD module ready!")

‚úÖ VAD module ready!


In [20]:
# ============================================================================
# CELL 6: Tridiagonal Solver
# ============================================================================

class EfficientTridiagonalSolver(nn.Module):
    def __init__(self):
        super().__init__()

    @torch.no_grad()
    def forward(self, lower, main, upper, b):
        B, Fm1 = lower.shape
        _, F = b.shape

        c_p = torch.zeros(B, F-1, device=b.device, dtype=b.dtype)
        d_p = torch.zeros(B, F, device=b.device, dtype=b.dtype)

        denom = main[:, 0]
        c_p[:, 0] = upper[:, 0] / denom
        d_p[:, 0] = b[:, 0] / denom

        for i in range(1, F-1):
            denom = main[:, i] - lower[:, i-1] * c_p[:, i-1]
            c_p[:, i] = upper[:, i] / denom
            d_p[:, i] = (b[:, i] - lower[:, i-1] * d_p[:, i-1]) / denom

        denom_last = main[:, F-1] - lower[:, F-2] * c_p[:, F-2]
        d_p[:, F-1] = (b[:, F-1] - lower[:, F-2] * d_p[:, F-2]) / denom_last

        x = torch.zeros(B, F, device=b.device, dtype=b.dtype)
        x[:, F-1] = d_p[:, F-1]

        for i in range(F-2, -1, -1):
            x[:, i] = d_p[:, i] - c_p[:, i] * x[:, i+1]

        return x

print("‚úÖ Tridiagonal solver ready!")


‚úÖ Tridiagonal solver ready!


In [21]:
# ============================================================================
# CELL 7: üî• Enhanced Phase Solver with VAD
# ============================================================================

class EnhancedPhaseSolver(nn.Module):
    """üî• Phase Solver with VAD-aware weighting"""

    def __init__(self,
                 use_efficient_solver=True,
                 use_vad=True,
                 vad_threshold_db=-35,
                 hop_length=256,
                 n_fft=1024):
        super().__init__()
        self.use_efficient_solver = use_efficient_solver
        self.use_vad = use_vad
        self.hop_length = hop_length
        self.n_fft = n_fft

        if use_efficient_solver:
            self.solver = EfficientTridiagonalSolver()
            print("    ‚úÖ Using O(L) tridiagonal solver")

        if use_vad:
            self.vad = SimpleVAD(energy_threshold_db=vad_threshold_db)
            print(f"    üî• Solver VAD enabled (threshold: {vad_threshold_db} dB)")

    def forward(self, magnitude, fpd, bpd, Y_prev=None):
        if self.use_efficient_solver and Y_prev is not None:
            return self._solve_efficient_vad(magnitude, fpd, bpd, Y_prev)
        else:
            return self._solve_integration(fpd, bpd)

    def _solve_efficient_vad(self, magnitude, fpd, bpd, Y_prev):
        """üî• Efficient solver with VAD-enhanced weighting"""
        B, _, T, F = magnitude.shape
        dev = magnitude.device

        amp = magnitude.squeeze(1)
        u = torch.exp(1j * fpd.squeeze(1))

        freq = torch.arange(F, device=dev, dtype=amp.dtype).view(1, 1, F)
        linear = 2 * np.pi * (self.hop_length / self.n_fft) * freq
        v = torch.exp(1j * (bpd.squeeze(1) + linear))

        yprev = Y_prev.squeeze(1).to(torch.complex64)

        # === üî• Enhanced VAD-aware weighting ===
        eps = 1e-8
        beta = 0.5

        # 1. Basic energy normalization
        amp_mean = amp.mean(dim=(1, 2), keepdim=True) + eps
        w_energy = (amp / amp_mean).pow(beta)

        # 2. VAD modulation
        if self.use_vad:
            vad_mask = self.vad(magnitude)  # [B, 1, T, 1]
            vad_weight = vad_mask.squeeze(1)  # [B, T, 1]

            # Combine weights: silence regions get 0.01, speech regions keep original
            w = w_energy * (vad_weight + 0.01)
            w = w.clamp_min(0.01)
        else:
            w = w_energy.clamp_min(0.05)

        # Build tridiagonal system
        main = w.clone().to(torch.complex64)
        main[..., :-1] += (w[..., :-1] * (u[..., :-1].abs()**2)).to(torch.complex64)
        main[..., 1:] += w[..., :-1].to(torch.complex64)

        upper = (w[..., :-1] * u[..., :-1]).to(torch.complex64)
        lower = (w[..., :-1] * torch.conj(u[..., :-1])).to(torch.complex64)

        b = (w * (yprev * v)).to(torch.complex64)

        # Solve
        BT = B * T
        x = self.solver(
            lower.reshape(BT, F-1),
            main.reshape(BT, F),
            upper.reshape(BT, F-1),
            b.reshape(BT, F)
        ).reshape(B, T, F)

        phase = torch.angle(x).unsqueeze(1)
        return phase

    def _solve_integration(self, fpd, bpd):
        """Fallback integration method"""
        return 0.5 * (torch.cumsum(fpd, dim=3) + torch.cumsum(bpd, dim=2))

print("‚úÖ Enhanced Phase Solver ready!")


‚úÖ Enhanced Phase Solver ready!


In [22]:
# ============================================================================
# CELL 7.5: üî• Minimum Phase Initialization (CRITICAL FIX!)
# ============================================================================

def compute_minimum_phase_spectrum(magnitude, n_fft=1024):
    """
    üî• CRITICAL: Compute minimum phase spectrum from magnitude

    This fixes the SNR=-0.78 problem by providing proper phase initialization
    instead of zeros. Uses real cepstrum method from PPSI-Net paper.

    Args:
        magnitude: [F, T] magnitude spectrogram (numpy)
        n_fft: FFT size

    Returns:
        min_phase: [F, T] minimum phase spectrum (numpy)
    """
    F, T = magnitude.shape
    min_phase = np.zeros_like(magnitude)

    for t in range(T):
        mag = magnitude[:, t]

        # Avoid log(0)
        mag = np.maximum(mag, 1e-10)

        # Real cepstrum method for minimum phase
        log_mag = np.log(mag)

        # Create symmetric spectrum for IFFT
        full_log_mag = np.concatenate([log_mag, log_mag[-2:0:-1]])

        # Compute cepstrum
        cepstrum = np.fft.ifft(full_log_mag).real

        # Minimum phase window (causal)
        window = np.zeros(len(cepstrum))
        window[0] = 1
        window[1:len(cepstrum)//2] = 2

        # Apply window to get minimum phase cepstrum
        min_phase_cepstrum = cepstrum * window

        # Transform back to frequency domain
        min_phase_spectrum = np.fft.fft(min_phase_cepstrum)

        # Extract phase
        min_phase[:, t] = np.angle(np.exp(min_phase_spectrum[:F]))

    return min_phase


class MinimumPhaseInitializer(nn.Module):
    """
    üî• CRITICAL: Minimum Phase Initialization Module

    This is THE most important fix for your SNR problem!
    Instead of initializing with zeros, use physically-meaningful minimum phase.
    """

    def __init__(self, n_fft=1024, hop_length=256):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length

    def forward(self, magnitude):
        """
        Args:
            magnitude: [B, 1, T, F] magnitude spectrogram
        Returns:
            initial_complex: [B, 1, T, F] complex STFT with minimum phase
        """
        B, C, T, F = magnitude.shape
        device = magnitude.device

        # üî• FIX: Remove only batch and channel dims, keep T and F
        mag_np = magnitude.squeeze(0).squeeze(0).cpu().numpy()  # [T, F]

        # Ensure we have 2D array
        if mag_np.ndim == 1:
            # If somehow still 1D, add time dimension
            mag_np = mag_np.reshape(1, -1)  # [1, F]

        # mag_np is now [T, F], transpose to [F, T] for processing
        mag_np = mag_np.T  # [F, T]

        # Compute minimum phase
        min_phase = compute_minimum_phase_spectrum(mag_np, self.n_fft)  # [F, T]

        # Convert back to [B, 1, T, F]
        min_phase = torch.from_numpy(min_phase).T  # [T, F]
        min_phase = min_phase.unsqueeze(0).unsqueeze(0)  # [1, 1, T, F]
        min_phase = min_phase.to(device).float()

        # Construct complex STFT: magnitude * exp(j * phase)
        complex_stft = magnitude * torch.exp(1j * min_phase)

        return complex_stft

print("‚úÖ üî• Minimum Phase Initializer ready (CRITICAL FIX!)!")

‚úÖ üî• Minimum Phase Initializer ready (CRITICAL FIX!)!


In [23]:
# ============================================================================
# CELL 8: üî• Enhanced Loss with Frequency Weighting
# ============================================================================

class EnhancedVonMisesLoss(nn.Module):
    """üî• Von Mises Loss with VAD + Frequency Weighting"""

    def __init__(self,
                 mode: str = "vad",
                 beta: float = 0.5,
                 k: float = 1.2,
                 eps: float = 1e-8,
                 use_vad: bool = True,
                 vad_threshold_db: float = -35,
                 use_freq_weighting: bool = True,
                 low_freq_weight: float = 1.5,
                 high_freq_weight: float = 0.8,
                 sample_rate: int = 16000,
                 n_fft: int = 1024):
        super().__init__()
        assert mode in ("mask", "weight", "vad")
        self.mode = mode
        self.beta = beta
        self.k = k
        self.eps = eps
        self.use_vad = use_vad
        self.use_freq_weighting = use_freq_weighting

        if use_vad:
            self.vad = SimpleVAD(energy_threshold_db=vad_threshold_db)
            print(f"    üî• Loss VAD enabled (mode: {mode}, threshold: {vad_threshold_db} dB)")

        # üî• Frequency weighting (emphasize speech frequencies)
        if use_freq_weighting:
            freq_bins = n_fft // 2 + 1
            freqs = np.fft.rfftfreq(n_fft, 1.0 / sample_rate)

            # Perceptual weighting curve
            freq_weights = np.ones(freq_bins)

            # Low frequencies (0-200 Hz): moderate weight
            mask_low = freqs < 200
            freq_weights[mask_low] = low_freq_weight

            # Speech range (200-4000 Hz): HIGH weight
            mask_speech = (freqs >= 200) & (freqs <= 4000)
            freq_weights[mask_speech] = 2.0

            # High frequencies (>4000 Hz): lower weight
            mask_high = freqs > 4000
            freq_weights[mask_high] = high_freq_weight

            self.freq_weights = torch.from_numpy(freq_weights).float()
            print(f"    üî• Frequency weighting enabled (speech emphasis: 200-4000Hz)")

    @staticmethod
    def _align_two(a, b):
        Tm = min(a.shape[2], b.shape[2])
        Fm = min(a.shape[3], b.shape[3])
        if a.shape[2] != Tm or a.shape[3] != Fm:
            a = a[:, :, :Tm, :Fm]
        if b.shape[2] != Tm or b.shape[3] != Fm:
            b = b[:, :, :Tm, :Fm]
        return a, b

    @staticmethod
    def _align_three(a, b, c):
        if c is None:
            a2, b2 = EnhancedVonMisesLoss._align_two(a, b)
            return a2, b2, None
        a2, b2 = EnhancedVonMisesLoss._align_two(a, b)
        Tm, Fm = a2.shape[2], a2.shape[3]
        if c.shape[2] != Tm or c.shape[3] != Fm:
            c = c[:, :, :Tm, :Fm]
        return a2, b2, c

    def _vm_core(self, pred, target):
        return 1.0 - torch.cos(pred - target)

    def _apply_energy(self, vm, mag):
        """üî• Apply VAD + Frequency + Energy weighting"""
        if mag is None:
            return vm.mean()

        B, C, T, F = vm.shape
        device = vm.device

        # Initialize weight as ones
        weight = torch.ones_like(vm)

        # 1. VAD weighting
        if self.use_vad and self.mode == "vad":
            vad_mask = self.vad(mag)  # [B, 1, T, 1]
            weight = weight * vad_mask.expand_as(vm)

        # 2. üî• Frequency weighting (NEW!)
        if self.use_freq_weighting and hasattr(self, 'freq_weights'):
            freq_w = self.freq_weights.to(device).view(1, 1, 1, F)
            weight = weight * freq_w.expand_as(vm)

        # 3. Energy weighting
        amp_mean = mag.mean(dim=(2, 3), keepdim=True) + self.eps
        energy_w = (mag / amp_mean).pow(self.beta).clamp_min(0.05)
        weight = weight * energy_w

        # Compute weighted loss
        masked_vm = vm * weight
        loss = masked_vm.sum() / (weight.sum() + self.eps)

        # Track VAD stats
        if self.use_vad and self.mode == "vad":
            vad_ratio = vad_mask.mean().item()
            if hasattr(self, '_vad_ratio_ema'):
                self._vad_ratio_ema = 0.95 * self._vad_ratio_ema + 0.05 * vad_ratio
            else:
                self._vad_ratio_ema = vad_ratio

        return loss

    def forward(self, *args):
        n = len(args)

        if n in (2, 3):
            pred, target = args[0], args[1]
            mag = args[2] if n == 3 else None
            pred, target, mag = self._align_three(pred, target, mag)
            return self._apply_energy(self._vm_core(pred, target), mag)

        elif n in (4, 5):
            pred_fpd, tgt_fpd, pred_bpd, tgt_bpd = args[:4]
            mag = args[4] if n == 5 else None

            pred_fpd, tgt_fpd, mag_f = self._align_three(pred_fpd, tgt_fpd, mag)
            pred_bpd, tgt_bpd, mag_b = self._align_three(pred_bpd, tgt_bpd, mag)

            loss_f = self._apply_energy(self._vm_core(pred_fpd, tgt_fpd), mag_f)
            loss_b = self._apply_energy(self._vm_core(pred_bpd, tgt_bpd), mag_b)

            return 0.5 * (loss_f + loss_b)

        else:
            raise TypeError(f"Expected 2-5 args, got {n}")

    def get_vad_stats(self):
        """Get VAD statistics"""
        if hasattr(self, '_vad_ratio_ema'):
            return {
                'speech_ratio': self._vad_ratio_ema,
                'silence_ratio': 1.0 - self._vad_ratio_ema
            }
        return None

print("‚úÖ Enhanced Loss ready!")

# ============================================================================
# CELL 8.5: üî• Learning Rate Warmup Scheduler
# ============================================================================

class WarmupCosineScheduler:
    """üî• Learning rate scheduler with warmup + cosine annealing"""

    def __init__(self, optimizer, warmup_epochs, total_epochs, base_lr, min_lr=1e-6):
        self.optimizer = optimizer
        self.warmup_epochs = warmup_epochs
        self.total_epochs = total_epochs
        self.base_lr = base_lr
        self.min_lr = min_lr
        self.current_epoch = 0

    def step(self):
        """Update learning rate for current epoch"""
        if self.current_epoch < self.warmup_epochs:
            # Linear warmup
            lr = self.base_lr * (self.current_epoch + 1) / self.warmup_epochs
        else:
            # Cosine annealing
            progress = (self.current_epoch - self.warmup_epochs) / (self.total_epochs - self.warmup_epochs)
            lr = self.min_lr + (self.base_lr - self.min_lr) * 0.5 * (1 + np.cos(np.pi * progress))

        for param_group in self.optimizer.param_groups:
            param_group['lr'] = lr

        self.current_epoch += 1
        return lr

    def get_last_lr(self):
        """Get current learning rate"""
        return [group['lr'] for group in self.optimizer.param_groups]

print("‚úÖ Warmup Scheduler ready!")

‚úÖ Enhanced Loss ready!
‚úÖ Warmup Scheduler ready!


In [24]:
# ============================================================================
# CELL 9: Training Setup
# ============================================================================

def build_criterion_optimizer_for(model, config=None):
    """üî• Build enhanced criterion and optimizer with all optimizations"""
    if config is None:
        config = Config()

    # üî• Use enhanced loss with all features
    criterion = EnhancedVonMisesLoss(
        mode=config.vad_mode,
        beta=0.5,
        use_vad=config.use_vad,
        vad_threshold_db=config.vad_threshold_db,
        use_freq_weighting=getattr(config, 'use_freq_weighting', False),
        low_freq_weight=getattr(config, 'low_freq_weight', 1.5),
        high_freq_weight=getattr(config, 'high_freq_weight', 0.8),
        sample_rate=config.sample_rate,
        n_fft=config.n_fft
    )

    optimizer = AdamW(
        model.parameters(),
        lr=config.learning_rate,
        weight_decay=config.weight_decay,
        betas=(0.9, 0.999)
    )

    return criterion, optimizer

print("‚úÖ Training setup ready!")


‚úÖ Training setup ready!


In [10]:
# ============================================================================
# CELL 10: Training Loop
# ============================================================================

def _make_loaders_from_config(config):
    train_ds = LibriSpeechSpectrogramDataset(
        root_dir=config.librispeech_root, subset='train',
        n_fft=config.n_fft, hop_length=config.hop_length,
        win_length=config.win_length, sample_rate=config.sample_rate,
        duration=config.duration
    )
    val_ds = LibriSpeechSpectrogramDataset(
        root_dir=config.librispeech_root, subset='val',
        n_fft=config.n_fft, hop_length=config.hop_length,
        win_length=config.win_length, sample_rate=config.sample_rate,
        duration=config.duration
    )
    train_loader = DataLoader(train_ds, batch_size=config.batch_size,
                              shuffle=True, num_workers=6, pin_memory=True)
    val_loader = DataLoader(val_ds, batch_size=config.batch_size,
                           shuffle=False, num_workers=6, pin_memory=True)
    return train_loader, val_loader

def _unpack_batch(batch, device):
    if isinstance(batch, dict):
        log_mag = batch['log_mag'].to(device)
        tgt_fpd = batch.get('tgt_fpd', batch['fpd']).to(device)
        tgt_bpd = batch.get('tgt_bpd', batch['bpd']).to(device)
        mag = batch.get('mag', batch.get('amplitude_abs', None))
        if mag is not None:
            mag = mag.to(device)
        return log_mag, tgt_fpd, tgt_bpd, mag

    if len(batch) == 3:
        log_mag, tgt_fpd, tgt_bpd = [x.to(device) for x in batch]
        return log_mag, tgt_fpd, tgt_bpd, None
    if len(batch) >= 4:
        log_mag, tgt_fpd, tgt_bpd, mag = [x.to(device) for x in batch[:4]]
        return log_mag, tgt_fpd, tgt_bpd, mag
    raise ValueError("Unsupported batch format.")

def train_phase_cnn(config=None):
    """üî• Enhanced training with beautiful progress bars"""
    if config is None:
        config = Config()

    train_loader, val_loader = _make_loaders_from_config(config)

    # Create model
    model = PhaseCNN(
        embed_dim=config.phase_embed_dim,
        body_depth=config.phase_body_depth,
        strided=config.use_strided_inference
    ).to(config.device)

    # üî• Enhanced criterion and optimizer
    criterion, optimizer = build_criterion_optimizer_for(model, config)

    best_val = float('inf')
    history = {"train_loss": [], "val_loss": []}

    print(f"\nüöÄ Starting training with {model.count_parameters():,} parameters...")
    print("="*80)

    for epoch in range(1, config.num_epochs + 1):
        # ========== Training Phase ==========
        model.train()
        run_loss = 0.0
        steps = 0

        # üî• Training progress bar with custom format
        train_pbar = tqdm(
            train_loader,
            desc=f"Epoch {epoch:3d}/{config.num_epochs}",
            ncols=140,
            bar_format='{desc}: {percentage:3.0f}%|{bar:50}| {n_fmt}/{total_fmt} [{elapsed}<{remaining}, {rate_fmt}, {postfix}]'
        )

        for batch in train_pbar:
            log_mag, tgt_fpd, tgt_bpd, mag = _unpack_batch(batch, config.device)
            pred_fpd, pred_bpd = model(log_mag)

            optimizer.zero_grad(set_to_none=True)
            loss = criterion(pred_fpd, tgt_fpd, pred_bpd, tgt_bpd, mag)
            loss.backward()
            optimizer.step()

            current_loss = float(loss.item())
            run_loss += current_loss
            steps += 1

            # Update progress bar with current loss (formatted like the example)
            train_pbar.set_postfix_str(f"loss={current_loss:.4f}")

        train_avg = run_loss / max(1, steps)
        history["train_loss"].append(train_avg)

        # ========== Validation Phase ==========
        model.eval()
        val_loss = 0.0
        val_steps = 0

        with torch.no_grad():
            for batch in val_loader:
                log_mag, tgt_fpd, tgt_bpd, mag = _unpack_batch(batch, config.device)
                pred_fpd, pred_bpd = model(log_mag)
                vloss = criterion(pred_fpd, tgt_fpd, pred_bpd, tgt_bpd, mag)
                val_loss += float(vloss.item())
                val_steps += 1

        val_avg = val_loss / max(1, val_steps)
        history["val_loss"].append(val_avg)

        # ========== Epoch Summary ==========
        # üî• Get VAD stats
        vad_stats = criterion.get_vad_stats()
        vad_str = ""
        if vad_stats:
            vad_str = f" | Speech: {vad_stats['speech_ratio']:.1%}"

        # üî• Beautiful epoch summary
        print(f"\n    Epoch {epoch:3d} | Train: {train_avg:.4f} | Val: {val_avg:.4f}{vad_str}")

        # ========== Save Best Model ==========
        if val_avg < best_val:
            best_val = val_avg
            os.makedirs(f"{config.save_dir}/phase", exist_ok=True)
            torch.save(model.state_dict(), f"{config.save_dir}/phase/best_phase_cnn.pth")
            print(f"        ‚úÖ Best model saved!")

        # Show progress every 5 epochs
        if epoch % 5 == 0 or epoch == config.num_epochs:
            print("")  # Empty line for better readability

    print("="*80)
    return model, history

print("‚úÖ Training loop ready!")


‚úÖ Training loop ready!


In [25]:
# ============================================================================
# CELL 11: Inference
# ============================================================================

class PPSIInference:
    """üî• PPSI Inference with Enhanced VAD Solver + Minimum Phase Init"""

    def __init__(self, phase_cnn_path, config):
        self.device = config.device
        self.config = config

        # Load Phase CNN
        self.phase_cnn = PhaseCNN(
            embed_dim=config.phase_embed_dim,
            body_depth=config.phase_body_depth,
            strided=config.use_strided_inference
        ).to(self.device)

        ckpt = torch.load(phase_cnn_path, map_location=self.device)
        self.phase_cnn.load_state_dict(ckpt)
        self.phase_cnn.eval()

        # üî• Enhanced solver with VAD
        self.solver = EnhancedPhaseSolver(
            use_efficient_solver=config.use_efficient_solver,
            use_vad=config.use_vad,
            vad_threshold_db=config.vad_threshold_db,
            hop_length=config.hop_length,
            n_fft=config.n_fft
        ).to(self.device)

        # üî• CRITICAL: Minimum phase initializer
        self.use_min_phase = getattr(config, 'use_min_phase_init', True)
        if self.use_min_phase:
            self.min_phase_init = MinimumPhaseInitializer(
                n_fft=config.n_fft,
                hop_length=config.hop_length
            )
            print("‚úÖ Inference ready with VAD + üî• Minimum Phase Init!")
        else:
            print("‚úÖ Inference ready with VAD (‚ö†Ô∏è  Minimum Phase Init disabled)")

    @torch.no_grad()
    def reconstruct_audio(self, waveform, return_magnitude=False):
        """
        Reconstruct audio with all optimizations

        Args:
            waveform: Input audio waveform (numpy array)
            return_magnitude: If True, also return original magnitude for visualization

        Returns:
            If return_magnitude=False: reconstructed waveform
            If return_magnitude=True: (reconstructed waveform, original magnitude)
        """
        window = torch.hann_window(self.config.win_length).to(self.device)
        waveform_tensor = torch.from_numpy(waveform).float().to(self.device)

        stft = torch.stft(
            waveform_tensor,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            window=window,
            return_complex=True,
            center=True,
            normalized=False
        )

        magnitude = torch.abs(stft).T
        log_mag = torch.log1p(magnitude)

        log_mag = log_mag.unsqueeze(0).unsqueeze(0)
        magnitude = magnitude.unsqueeze(0).unsqueeze(0)

        # Stage 1: Predict FPD/BPD
        pred_fpd, pred_bpd = self.phase_cnn(log_mag)

        # üî• Stage 2: Initialize with minimum phase (CRITICAL!)
        if self.use_min_phase:
            # Use minimum phase for first frame
            yprev = self.min_phase_init(magnitude[:, :, 0:1, :])
        else:
            # Old method (causes SNR=-0.78 problem!)
            yprev = torch.zeros(1, 1, 1, magnitude.shape[-1],
                              device=self.device, dtype=torch.complex64)

        # üî• Stage 3: Recursive solve with VAD
        B, C, T, F = pred_fpd.shape
        pred_phase_slices = []

        for t in range(T):
            mag_t = magnitude[:, :, t:t+1, :]
            fpd_t = pred_fpd[:, :, t:t+1, :]
            bpd_t = pred_bpd[:, :, t:t+1, :]

            phase_t = self.solver(mag_t, fpd_t, bpd_t, Y_prev=yprev)
            pred_phase_slices.append(phase_t)

            yprev = (mag_t.squeeze(1) * torch.exp(1j * phase_t.squeeze(1))).unsqueeze(1).to(torch.complex64)

        pred_phase = torch.cat(pred_phase_slices, dim=2)

        # Reconstruct
        amplitude = magnitude.cpu().squeeze().numpy().T
        phase = pred_phase.cpu().squeeze().numpy().T
        stft_recon = amplitude * np.exp(1j * phase)

        stft_recon_t = torch.from_numpy(stft_recon).to(self.device)
        wav_recon = torch.istft(
            stft_recon_t,
            n_fft=self.config.n_fft,
            hop_length=self.config.hop_length,
            win_length=self.config.win_length,
            window=window,
            center=True,
            normalized=False,
            length=waveform_tensor.numel()
        )

        reconstructed = wav_recon.cpu().float().numpy()

        # üî• Return magnitude for visualization if requested
        if return_magnitude:
            # Return original magnitude [F, T] for visualization
            mag_original = magnitude.cpu().squeeze().numpy().T  # [T, F] -> [F, T]
            return reconstructed, mag_original
        else:
            return reconstructed

print("‚úÖ Inference class ready!")

‚úÖ Inference class ready!


In [12]:
# ============================================================================
# CELL 12: Evaluation Metrics
# ============================================================================

def calculate_lsc(original, reconstructed, n_fft=1024, hop_length=256, win_length=1024):
    """Calculate Log-Spectral Convergence"""
    window = torch.hann_window(win_length)

    orig_tensor = torch.from_numpy(original).float()
    stft_orig = torch.stft(orig_tensor, n_fft=n_fft, hop_length=hop_length,
                          win_length=win_length, window=window,
                          return_complex=True, center=True)

    recon_tensor = torch.from_numpy(reconstructed).float()
    stft_recon = torch.stft(recon_tensor, n_fft=n_fft, hop_length=hop_length,
                           win_length=win_length, window=window,
                           return_complex=True, center=True)

    mag_orig = torch.abs(stft_orig)
    mag_recon = torch.abs(stft_recon)

    log_mag_orig = torch.log(mag_orig + 1e-8)
    log_mag_recon = torch.log(mag_recon + 1e-8)

    numerator = torch.norm(log_mag_orig - log_mag_recon, p='fro')
    denominator = torch.norm(log_mag_orig, p='fro')

    lsc = (numerator / (denominator + 1e-8)).item()
    return lsc

def calculate_estoi(original, reconstructed, sample_rate=16000):
    """Calculate ESTOI"""
    from pystoi import stoi

    min_len = min(len(original), len(reconstructed))
    orig = original[:min_len].astype(np.float64)
    recon = reconstructed[:min_len].astype(np.float64)

    orig = orig / (np.abs(orig).max() + 1e-8)
    recon = recon / (np.abs(recon).max() + 1e-8)

    try:
        estoi_score = stoi(orig, recon, sample_rate, extended=True)
    except Exception as e:
        print(f"‚ö†Ô∏è ESTOI failed: {e}")
        estoi_score = -1.0

    return estoi_score

def calculate_pesq(original, reconstructed, sample_rate=16000):
    """Calculate PESQ"""
    from pesq import pesq

    min_len = min(len(original), len(reconstructed))
    orig = original[:min_len]
    recon = reconstructed[:min_len]

    orig = orig / (np.abs(orig).max() + 1e-8)
    recon = recon / (np.abs(recon).max() + 1e-8)

    try:
        pesq_score = pesq(16000, orig, recon, 'wb')
    except Exception as e:
        print(f"‚ö†Ô∏è PESQ failed: {e}")
        pesq_score = -1.0

    return pesq_score

print("‚úÖ Evaluation metrics ready!")

‚úÖ Evaluation metrics ready!


In [13]:
# ============================================================================
# CELL 13: üî• VAD Visualization Tool
# ============================================================================

def visualize_vad_effect(magnitude_spec, sample_rate=16000, hop_length=256,
                        vad_threshold_db=-35):
    """
    üî• Visualize VAD effect on magnitude spectrogram

    Args:
        magnitude_spec: [1, 1, T, F] magnitude spectrogram tensor
        sample_rate: audio sample rate
        hop_length: STFT hop length
        vad_threshold_db: VAD threshold
    """
    vad = SimpleVAD(energy_threshold_db=vad_threshold_db,
                   sample_rate=sample_rate,
                   hop_length=hop_length)

    with torch.no_grad():
        vad_mask = vad(magnitude_spec)  # [1, 1, T, 1]

    # Plot
    fig, axes = plt.subplots(3, 1, figsize=(14, 8))

    # 1. Magnitude spectrogram
    mag_db = 20 * torch.log10(magnitude_spec[0, 0].cpu() + 1e-8).numpy()
    T, F = mag_db.shape
    time_frames = np.arange(T) * hop_length / sample_rate
    freq_bins = np.arange(F) * sample_rate / 1024 / 1000

    im1 = axes[0].imshow(mag_db.T, aspect='auto', origin='lower', cmap='viridis',
                        extent=[time_frames[0], time_frames[-1], freq_bins[0], freq_bins[-1]])
    axes[0].set_title('üéµ Magnitude Spectrogram (dB)', fontweight='bold', fontsize=12)
    axes[0].set_ylabel('Frequency (kHz)', fontsize=10)
    plt.colorbar(im1, ax=axes[0], label='dB')

    # 2. VAD mask overlay
    vad_vis = vad_mask[0, 0, :, 0].cpu().numpy()
    masked_mag = mag_db.copy()
    masked_mag[vad_vis < 0.5, :] = masked_mag.min()

    im2 = axes[1].imshow(masked_mag.T, aspect='auto', origin='lower', cmap='viridis',
                        extent=[time_frames[0], time_frames[-1], freq_bins[0], freq_bins[-1]])
    axes[1].set_title('üî• VAD-Masked Spectrogram (Speech Only)', fontweight='bold', fontsize=12)
    axes[1].set_ylabel('Frequency (kHz)', fontsize=10)
    plt.colorbar(im2, ax=axes[1], label='dB')

    # 3. VAD decision
    axes[2].plot(time_frames, vad_vis, 'b-', linewidth=2, label='VAD Output')
    axes[2].fill_between(time_frames, 0, vad_vis, alpha=0.3, color='blue')
    axes[2].set_title('üé§ VAD Decision (1=Speech, 0=Silence)', fontweight='bold', fontsize=12)
    axes[2].set_xlabel('Time (s)', fontsize=10)
    axes[2].set_ylabel('VAD', fontsize=10)
    axes[2].set_ylim([-0.1, 1.1])
    axes[2].grid(True, alpha=0.3)
    axes[2].legend(loc='upper right')

    # Add stats
    speech_ratio = vad_vis.mean()
    axes[2].text(0.02, 0.95, f'Speech: {speech_ratio:.1%}\nSilence: {1-speech_ratio:.1%}',
                transform=axes[2].transAxes, fontsize=10,
                verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()
    plt.show()

    print(f"\nüìä VAD Statistics:")
    print(f"  ‚Ä¢ Speech ratio: {speech_ratio:.1%}")
    print(f"  ‚Ä¢ Silence ratio: {1-speech_ratio:.1%}")
    print(f"  ‚Ä¢ Threshold: {vad_threshold_db} dB")

print("‚úÖ VAD visualization tool ready!")

‚úÖ VAD visualization tool ready!


In [14]:
# ============================================================================
# CELL 14: Comprehensive Visualization
# ============================================================================

def visualize_reconstruction_comparison(
    original_waveform,
    reconstructed_waveform,
    sample_rate=16000,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    save_path=None,
    show_plot=True
):
    """Comprehensive visualization of reconstruction quality"""

    print("üìä Generating comprehensive visualization...")

    # Calculate metrics
    lsc = calculate_lsc(original_waveform, reconstructed_waveform, n_fft, hop_length, win_length)
    estoi_score = calculate_estoi(original_waveform, reconstructed_waveform, sample_rate)
    pesq_score = calculate_pesq(original_waveform, reconstructed_waveform, sample_rate)

    # Compute STFTs
    window = torch.hann_window(win_length)

    original_tensor = torch.from_numpy(original_waveform).float()
    stft_original = torch.stft(
        original_tensor, n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, window=window,
        return_complex=True, center=True, normalized=False
    )

    reconstructed_tensor = torch.from_numpy(reconstructed_waveform).float()
    stft_reconstructed = torch.stft(
        reconstructed_tensor, n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, window=window,
        return_complex=True, center=True, normalized=False
    )

    mag_original = torch.abs(stft_original).numpy()
    phase_original = torch.angle(stft_original).numpy()
    mag_reconstructed = torch.abs(stft_reconstructed).numpy()
    phase_reconstructed = torch.angle(stft_reconstructed).numpy()

    mag_original_db = 20 * np.log10(mag_original + 1e-8)
    mag_reconstructed_db = 20 * np.log10(mag_reconstructed + 1e-8)

    mag_diff = np.abs(mag_original_db - mag_reconstructed_db)
    phase_diff = np.abs(np.angle(np.exp(1j * phase_original) / np.exp(1j * phase_reconstructed)))

    times = np.arange(mag_original.shape[1]) * hop_length / sample_rate
    freqs = np.arange(mag_original.shape[0]) * sample_rate / n_fft / 1000

    min_len = min(len(original_waveform), len(reconstructed_waveform))
    original_clip = original_waveform[:min_len]
    reconstructed_clip = reconstructed_waveform[:min_len]

    mse_wav = np.mean((original_clip - reconstructed_clip) ** 2)
    snr = 10 * np.log10(np.var(original_clip) / (mse_wav + 1e-8))
    mag_corr = np.corrcoef(mag_original.flatten(), mag_reconstructed.flatten())[0, 1]

    # Create figure
    fig = plt.figure(figsize=(20, 14))
    gs = fig.add_gridspec(4, 3, hspace=0.35, wspace=0.3)

    # Row 1: Magnitude
    ax1 = fig.add_subplot(gs[0, 0])
    im1 = ax1.imshow(mag_original_db, aspect='auto', origin='lower',
                     cmap='viridis', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax1.set_title('Original - Magnitude (dB)', fontweight='bold')
    ax1.set_ylabel('Frequency (kHz)')
    plt.colorbar(im1, ax=ax1, label='dB')

    ax2 = fig.add_subplot(gs[0, 1])
    im2 = ax2.imshow(mag_reconstructed_db, aspect='auto', origin='lower',
                     cmap='viridis', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax2.set_title('Reconstructed - Magnitude (dB)', fontweight='bold')
    ax2.set_ylabel('Frequency (kHz)')
    plt.colorbar(im2, ax=ax2, label='dB')

    ax3 = fig.add_subplot(gs[0, 2])
    im3 = ax3.imshow(mag_diff, aspect='auto', origin='lower',
                     cmap='hot', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax3.set_title('Magnitude Error', fontweight='bold')
    ax3.set_ylabel('Frequency (kHz)')
    plt.colorbar(im3, ax=ax3, label='Error (dB)')

    # Row 2: Phase
    ax4 = fig.add_subplot(gs[1, 0])
    im4 = ax4.imshow(phase_original, aspect='auto', origin='lower',
                     cmap='twilight', vmin=-np.pi, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax4.set_title('Original - Phase', fontweight='bold')
    ax4.set_ylabel('Frequency (kHz)')
    plt.colorbar(im4, ax=ax4, label='rad')

    ax5 = fig.add_subplot(gs[1, 1])
    im5 = ax5.imshow(phase_reconstructed, aspect='auto', origin='lower',
                     cmap='twilight', vmin=-np.pi, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax5.set_title('Reconstructed - Phase', fontweight='bold')
    ax5.set_ylabel('Frequency (kHz)')
    plt.colorbar(im5, ax=ax5, label='rad')

    ax6 = fig.add_subplot(gs[1, 2])
    im6 = ax6.imshow(phase_diff, aspect='auto', origin='lower',
                     cmap='hot', vmin=0, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax6.set_title('Phase Error', fontweight='bold')
    ax6.set_ylabel('Frequency (kHz)')
    plt.colorbar(im6, ax=ax6, label='rad')

    # Row 3: Waveform
    ax7 = fig.add_subplot(gs[2, :])
    time_axis = np.arange(min_len) / sample_rate
    ax7.plot(time_axis, original_clip, 'b-', alpha=0.7, linewidth=0.8, label='Original')
    ax7.plot(time_axis, reconstructed_clip, 'r-', alpha=0.7, linewidth=0.8, label='Reconstructed')
    ax7.set_title('Waveform Comparison', fontweight='bold')
    ax7.set_xlabel('Time (s)')
    ax7.set_ylabel('Amplitude')
    ax7.legend()
    ax7.grid(True, alpha=0.3)
    ax7.set_xlim([0, min(2.0, time_axis[-1])])

    # Row 4: Metrics
    ax8 = fig.add_subplot(gs[3, :])
    ax8.axis('off')

    pesq_rating = '‚úÖ Excellent' if pesq_score > 3.5 else 'üëç Good' if pesq_score > 3.0 else '‚ö†Ô∏è Fair'
    estoi_rating = '‚úÖ Excellent' if estoi_score > 0.85 else 'üëç Good' if estoi_score > 0.75 else '‚ö†Ô∏è Fair'
    lsc_rating = '‚úÖ Excellent' if lsc < 0.5 else 'üëç Good' if lsc < 1.0 else '‚ö†Ô∏è Fair'

    metrics_text = f"""
    üî• PROFESSIONAL AUDIO QUALITY METRICS (PPSI-Net with VAD)

    PERCEPTUAL QUALITY:
      ‚Ä¢ PESQ (Wideband):  {pesq_score:.3f}  {pesq_rating}
      ‚Ä¢ ESTOI (Extended): {estoi_score:.4f}  {estoi_rating}
      ‚Ä¢ LSC:              {lsc:.4f}  {lsc_rating}

    SIGNAL QUALITY:
      ‚Ä¢ SNR:              {snr:.2f} dB
      ‚Ä¢ Magnitude Corr:   {mag_corr:.4f}
      ‚Ä¢ MSE (Waveform):   {mse_wav:.6f}
    """

    ax8.text(0.05, 0.5, metrics_text, fontsize=11, family='monospace',
             verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

    plt.suptitle('üî• PPSI-Net with VAD: Reconstruction Quality Analysis',
                 fontsize=16, fontweight='bold', y=0.98)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"üíæ Saved to: {save_path}")

    if show_plot:
        plt.show()
    else:
        plt.close()

    print("\n" + "="*70)
    print("üìã QUALITY METRICS SUMMARY")
    print("="*70)
    print(f"PESQ:     {pesq_score:.3f}  {pesq_rating}")
    print(f"ESTOI:    {estoi_score:.4f}  {estoi_rating}")
    print(f"LSC:      {lsc:.4f}  {lsc_rating}")
    print(f"SNR:      {snr:.2f} dB")
    print(f"Mag Corr: {mag_corr:.4f}")
    print("="*70)

    return {
        'pesq': pesq_score,
        'estoi': estoi_score,
        'lsc': lsc,
        'snr': snr,
        'mag_corr': mag_corr
    }

def visualize_training_history(history, save_path=None):
    """
    üî• Visualize training history with gradient monitoring

    Args:
        history: Dictionary with 'train_loss', 'val_loss', 'grad_norm'
        save_path: Path to save the figure
    """
    fig, axes = plt.subplots(2, 1, figsize=(12, 8))

    epochs = range(1, len(history['train_loss']) + 1)

    # Plot 1: Loss curves
    ax1 = axes[0]
    ax1.plot(epochs, history['train_loss'], 'b-', linewidth=2, label='Training Loss', alpha=0.8)
    if 'val_loss' in history and history['val_loss']:
        ax1.plot(epochs, history['val_loss'], 'r-', linewidth=2,
                label='Validation Loss', alpha=0.8, marker='o', markersize=4, markevery=5)

    ax1.set_title('Training & Validation Loss', fontweight='bold', fontsize=14)
    ax1.set_xlabel('Epoch', fontsize=12)
    ax1.set_ylabel('Loss (Von Mises)', fontsize=12)
    ax1.legend(fontsize=11, loc='upper right')
    ax1.grid(True, alpha=0.3)

    # Mark best epoch
    if 'val_loss' in history and history['val_loss']:
        best_epoch = np.argmin(history['val_loss']) + 1
        best_val = np.min(history['val_loss'])
        ax1.axvline(best_epoch, color='g', linestyle='--', alpha=0.5, linewidth=1.5)
        ax1.text(best_epoch, best_val, f'  Best: {best_val:.4f}',
                fontsize=10, color='green', fontweight='bold')

    # Plot 2: Gradient norms
    if 'grad_norm' in history and history['grad_norm']:
        ax2 = axes[1]
        grad_norms = history['grad_norm']
        ax2.plot(epochs, grad_norms, 'g-', linewidth=2, alpha=0.7, label='Avg Gradient Norm')

        # Add gradient clipping threshold line
        max_grad_norm = 1.0  # Default value
        ax2.axhline(max_grad_norm, color='r', linestyle='--', linewidth=1.5,
                   alpha=0.6, label=f'Clip Threshold ({max_grad_norm})')

        # Highlight regions where gradients exceed threshold
        exceeded = np.array(grad_norms) > max_grad_norm
        if np.any(exceeded):
            ax2.fill_between(epochs, 0, np.max(grad_norms),
                           where=exceeded, alpha=0.2, color='red',
                           label='Clipped Region')

        ax2.set_title('üî• Gradient Norm Monitoring', fontweight='bold', fontsize=14)
        ax2.set_xlabel('Epoch', fontsize=12)
        ax2.set_ylabel('Gradient Norm (L2)', fontsize=12)
        ax2.legend(fontsize=11, loc='upper right')
        ax2.grid(True, alpha=0.3)

        # Statistics box
        avg_grad = np.mean(grad_norms)
        max_grad = np.max(grad_norms)
        clip_ratio = np.mean(exceeded)
        stats_text = f'Avg: {avg_grad:.3f}\nMax: {max_grad:.3f}\nClipped: {clip_ratio:.1%}'
        ax2.text(0.02, 0.98, stats_text, transform=ax2.transAxes,
                fontsize=10, verticalalignment='top',
                bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout()

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"üíæ Training history saved to: {save_path}")

    plt.show()

print("‚úÖ Enhanced visualization with gradient monitoring ready!")

‚úÖ Enhanced visualization with gradient monitoring ready!


In [None]:
# ============================================================================
# CELL 15: Main Training Execution
# ============================================================================

def main():
    """Main execution - Train the model"""

    print("\n" + "="*70)
    print("üî• PPSI-Net with Enhanced Silence Handling (VAD)")
    print("="*70)
    print("\n‚úÖ Key Features:")
    print("  ‚Ä¢ Voice Activity Detection (VAD)")
    print("  ‚Ä¢ Silence-aware loss weighting")
    print("  ‚Ä¢ VAD-enhanced phase solver")
    print("  ‚Ä¢ Improved training stability")
    print("="*70)

    config = Config()

    if not os.path.exists(config.librispeech_root):
        print("\n‚ö†Ô∏è  Dataset not found!")
        print(f"   Expected at: {config.librispeech_root}")
        return None

    # üî• Train with enhanced VAD
    print("\nüöÄ Starting enhanced training...")
    phase_cnn, history = train_phase_cnn(config)

    print("\n" + "="*70)
    print("‚úÖ Training Complete!")
    print("="*70)
    print(f"üìä Model Parameters: {phase_cnn.count_parameters():,}")
    print(f"üíæ Model saved to: {config.save_dir}/phase/best_phase_cnn.pth")
    print("="*70)

    return {'phase_cnn': phase_cnn, 'config': config, 'history': history}

if __name__ == '__main__':
    results = main()


üî• PPSI-Net with Enhanced Silence Handling (VAD)

‚úÖ Key Features:
  ‚Ä¢ Voice Activity Detection (VAD)
  ‚Ä¢ Silence-aware loss weighting
  ‚Ä¢ VAD-enhanced phase solver
  ‚Ä¢ Improved training stability

üöÄ Starting enhanced training...
  train set: 25685 files
  val set: 2854 files
    üî• Loss VAD enabled (mode: vad, threshold: -5 dB)
    üî• Frequency weighting enabled (speech emphasis: 200-4000Hz)

üöÄ Starting training with 11,724 parameters...


Epoch   1/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   1 | Train: 0.4012 | Val: 0.3459 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch   2/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   2 | Train: 0.3252 | Val: 0.3153 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch   3/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   3 | Train: 0.3101 | Val: 0.3136 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch   4/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   4 | Train: 0.3050 | Val: 0.3029 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch   5/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   5 | Train: 0.3009 | Val: 0.2995 | Speech: 97.8%
        ‚úÖ Best model saved!



Epoch   6/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   6 | Train: 0.2972 | Val: 0.2966 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch   7/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   7 | Train: 0.2921 | Val: 0.2911 | Speech: 97.4%
        ‚úÖ Best model saved!


Epoch   8/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   8 | Train: 0.2886 | Val: 0.2877 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch   9/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch   9 | Train: 0.2850 | Val: 0.2841 | Speech: 97.3%
        ‚úÖ Best model saved!


Epoch  10/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  10 | Train: 0.2809 | Val: 0.2846 | Speech: 97.8%



Epoch  11/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  11 | Train: 0.2782 | Val: 0.2771 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch  12/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  12 | Train: 0.2758 | Val: 0.2766 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  13/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  13 | Train: 0.2742 | Val: 0.2737 | Speech: 97.7%
        ‚úÖ Best model saved!


Epoch  14/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  14 | Train: 0.2708 | Val: 0.2806 | Speech: 97.5%


Epoch  15/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  15 | Train: 0.2735 | Val: 0.2725 | Speech: 97.6%
        ‚úÖ Best model saved!



Epoch  16/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  16 | Train: 0.2719 | Val: 0.2747 | Speech: 97.7%


Epoch  17/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  17 | Train: 0.2689 | Val: 0.2717 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  18/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  18 | Train: 0.2662 | Val: 0.2695 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch  19/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  19 | Train: 0.2650 | Val: 0.2643 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  20/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  20 | Train: 0.2641 | Val: 0.2637 | Speech: 97.4%
        ‚úÖ Best model saved!



Epoch  21/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  21 | Train: 0.2626 | Val: 0.2621 | Speech: 97.3%
        ‚úÖ Best model saved!


Epoch  22/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  22 | Train: 0.2629 | Val: 0.2626 | Speech: 97.6%


Epoch  23/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  23 | Train: 0.2604 | Val: 0.2636 | Speech: 97.7%


Epoch  24/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  24 | Train: 0.2602 | Val: 0.2638 | Speech: 97.8%


Epoch  25/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  25 | Train: 0.2607 | Val: 0.2616 | Speech: 97.5%
        ‚úÖ Best model saved!



Epoch  26/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  26 | Train: 0.2588 | Val: 0.2584 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  27/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  27 | Train: 0.2579 | Val: 0.2587 | Speech: 97.7%


Epoch  28/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  28 | Train: 0.2569 | Val: 0.2603 | Speech: 97.3%


Epoch  29/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  29 | Train: 0.2566 | Val: 0.2587 | Speech: 97.8%


Epoch  30/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  30 | Train: 0.2564 | Val: 0.2626 | Speech: 97.8%



Epoch  31/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  31 | Train: 0.2550 | Val: 0.2635 | Speech: 97.4%


Epoch  32/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  32 | Train: 0.2543 | Val: 0.2553 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  33/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  33 | Train: 0.2534 | Val: 0.2584 | Speech: 97.5%


Epoch  34/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  34 | Train: 0.2570 | Val: 0.2579 | Speech: 97.6%


Epoch  35/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  35 | Train: 0.2530 | Val: 0.2532 | Speech: 97.4%
        ‚úÖ Best model saved!



Epoch  36/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  36 | Train: 0.2528 | Val: 0.2689 | Speech: 97.7%


Epoch  37/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  37 | Train: 0.2556 | Val: 0.2618 | Speech: 97.4%


Epoch  38/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  38 | Train: 0.2518 | Val: 0.2572 | Speech: 97.8%


Epoch  39/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  39 | Train: 0.2548 | Val: 0.2546 | Speech: 97.7%


Epoch  40/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  40 | Train: 0.2550 | Val: 0.2553 | Speech: 97.5%



Epoch  41/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  41 | Train: 0.2547 | Val: 0.2558 | Speech: 97.5%


Epoch  42/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  42 | Train: 0.2529 | Val: 0.2530 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  43/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  43 | Train: 0.2503 | Val: 0.2522 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  44/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  44 | Train: 0.2515 | Val: 0.2502 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  45/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  45 | Train: 0.2503 | Val: 0.2504 | Speech: 97.9%



Epoch  46/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  46 | Train: 0.2491 | Val: 0.2488 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  47/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  47 | Train: 0.2489 | Val: 0.2498 | Speech: 97.5%


Epoch  48/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  48 | Train: 0.2498 | Val: 0.2497 | Speech: 97.8%


Epoch  49/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  49 | Train: 0.2507 | Val: 0.2505 | Speech: 97.5%


Epoch  50/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  50 | Train: 0.2483 | Val: 0.2489 | Speech: 97.6%



Epoch  51/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  51 | Train: 0.2480 | Val: 0.2485 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  52/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  52 | Train: 0.2465 | Val: 0.2486 | Speech: 97.6%


Epoch  53/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  53 | Train: 0.2474 | Val: 0.2497 | Speech: 97.7%


Epoch  54/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  54 | Train: 0.2475 | Val: 0.2481 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  55/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  55 | Train: 0.2460 | Val: 0.2546 | Speech: 97.6%



Epoch  56/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  56 | Train: 0.2586 | Val: 0.2534 | Speech: 97.8%


Epoch  57/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  57 | Train: 0.2505 | Val: 0.2476 | Speech: 97.8%
        ‚úÖ Best model saved!


Epoch  58/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  58 | Train: 0.2469 | Val: 0.2487 | Speech: 97.6%


Epoch  59/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  59 | Train: 0.2480 | Val: 0.2467 | Speech: 97.7%
        ‚úÖ Best model saved!


Epoch  60/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  60 | Train: 0.2498 | Val: 0.2472 | Speech: 97.5%



Epoch  61/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  61 | Train: 0.2466 | Val: 0.2515 | Speech: 97.5%


Epoch  62/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  62 | Train: 0.2503 | Val: 0.2512 | Speech: 97.7%


Epoch  63/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  63 | Train: 0.2463 | Val: 0.2601 | Speech: 97.5%


Epoch  64/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  64 | Train: 0.2447 | Val: 0.2479 | Speech: 97.8%


Epoch  65/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  65 | Train: 0.2456 | Val: 0.2502 | Speech: 97.5%



Epoch  66/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  66 | Train: 0.2454 | Val: 0.2443 | Speech: 97.7%
        ‚úÖ Best model saved!


Epoch  67/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  67 | Train: 0.2437 | Val: 0.2424 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch  68/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  68 | Train: 0.2434 | Val: 0.2442 | Speech: 97.7%


Epoch  69/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  69 | Train: 0.2424 | Val: 0.2465 | Speech: 97.6%


Epoch  70/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  70 | Train: 0.2427 | Val: 0.2450 | Speech: 97.7%



Epoch  71/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  71 | Train: 0.2435 | Val: 0.2450 | Speech: 97.7%


Epoch  72/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  72 | Train: 0.2434 | Val: 0.2438 | Speech: 97.6%


Epoch  73/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  73 | Train: 0.2439 | Val: 0.2429 | Speech: 97.4%


Epoch  74/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  74 | Train: 0.2425 | Val: 0.2417 | Speech: 97.7%
        ‚úÖ Best model saved!


Epoch  75/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  75 | Train: 0.2420 | Val: 0.2413 | Speech: 97.7%
        ‚úÖ Best model saved!



Epoch  76/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  76 | Train: 0.2432 | Val: 0.2424 | Speech: 97.4%


Epoch  77/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  77 | Train: 0.2410 | Val: 0.2425 | Speech: 97.5%


Epoch  78/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  78 | Train: 0.2405 | Val: 0.2457 | Speech: 97.7%


Epoch  79/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  79 | Train: 0.2405 | Val: 0.2429 | Speech: 97.6%


Epoch  80/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  80 | Train: 0.2403 | Val: 0.2454 | Speech: 97.6%



Epoch  81/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  81 | Train: 0.2410 | Val: 0.2398 | Speech: 97.2%
        ‚úÖ Best model saved!


Epoch  82/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  82 | Train: 0.2409 | Val: 0.2414 | Speech: 97.6%


Epoch  83/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  83 | Train: 0.2413 | Val: 0.2451 | Speech: 97.6%


Epoch  84/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  84 | Train: 0.2408 | Val: 0.2423 | Speech: 97.7%


Epoch  85/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  85 | Train: 0.2396 | Val: 0.2407 | Speech: 97.7%



Epoch  86/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  86 | Train: 0.2409 | Val: 0.2442 | Speech: 97.6%


Epoch  87/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  87 | Train: 0.2418 | Val: 0.2430 | Speech: 97.8%


Epoch  88/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  88 | Train: 0.2404 | Val: 0.2430 | Speech: 97.8%


Epoch  89/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  89 | Train: 0.2405 | Val: 0.2415 | Speech: 97.7%


Epoch  90/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  90 | Train: 0.2397 | Val: 0.2418 | Speech: 97.7%



Epoch  91/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  91 | Train: 0.2408 | Val: 0.2395 | Speech: 97.4%
        ‚úÖ Best model saved!


Epoch  92/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  92 | Train: 0.2398 | Val: 0.2403 | Speech: 97.6%


Epoch  93/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  93 | Train: 0.2387 | Val: 0.2412 | Speech: 97.7%


Epoch  94/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  94 | Train: 0.2391 | Val: 0.2420 | Speech: 97.5%


Epoch  95/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  95 | Train: 0.2395 | Val: 0.2377 | Speech: 97.6%
        ‚úÖ Best model saved!



Epoch  96/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  96 | Train: 0.2396 | Val: 0.2376 | Speech: 97.9%
        ‚úÖ Best model saved!


Epoch  97/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  97 | Train: 0.2401 | Val: 0.2403 | Speech: 97.7%


Epoch  98/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  98 | Train: 0.2402 | Val: 0.2419 | Speech: 97.6%


Epoch  99/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch  99 | Train: 0.2391 | Val: 0.2387 | Speech: 97.6%


Epoch 100/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 100 | Train: 0.2383 | Val: 0.2387 | Speech: 97.4%



Epoch 101/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 101 | Train: 0.2381 | Val: 0.2392 | Speech: 97.6%


Epoch 102/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 102 | Train: 0.2379 | Val: 0.2388 | Speech: 97.4%


Epoch 103/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 103 | Train: 0.2387 | Val: 0.2388 | Speech: 97.3%


Epoch 104/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 104 | Train: 0.2379 | Val: 0.2402 | Speech: 97.6%


Epoch 105/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 105 | Train: 0.2385 | Val: 0.2396 | Speech: 97.5%



Epoch 106/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 106 | Train: 0.2388 | Val: 0.2392 | Speech: 97.5%


Epoch 107/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 107 | Train: 0.2387 | Val: 0.2389 | Speech: 97.7%


Epoch 108/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 108 | Train: 0.2378 | Val: 0.2382 | Speech: 97.6%


Epoch 109/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 109 | Train: 0.2366 | Val: 0.2380 | Speech: 97.6%


Epoch 110/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 110 | Train: 0.2368 | Val: 0.2420 | Speech: 97.5%



Epoch 111/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 111 | Train: 0.2372 | Val: 0.2365 | Speech: 97.5%
        ‚úÖ Best model saved!


Epoch 112/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 112 | Train: 0.2373 | Val: 0.2375 | Speech: 97.6%


Epoch 113/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 113 | Train: 0.2369 | Val: 0.2444 | Speech: 97.9%


Epoch 114/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 114 | Train: 0.2373 | Val: 0.2394 | Speech: 97.6%


Epoch 115/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 115 | Train: 0.2375 | Val: 0.2389 | Speech: 97.5%



Epoch 116/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 116 | Train: 0.2374 | Val: 0.2376 | Speech: 97.7%


Epoch 117/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 117 | Train: 0.2375 | Val: 0.2359 | Speech: 97.4%
        ‚úÖ Best model saved!


Epoch 118/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 118 | Train: 0.2378 | Val: 0.2441 | Speech: 97.6%


Epoch 119/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 119 | Train: 0.2392 | Val: 0.2432 | Speech: 97.5%


Epoch 120/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 120 | Train: 0.2373 | Val: 0.2380 | Speech: 97.8%



Epoch 121/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 121 | Train: 0.2369 | Val: 0.2377 | Speech: 97.5%


Epoch 122/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 122 | Train: 0.2358 | Val: 0.2366 | Speech: 97.3%


Epoch 123/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 123 | Train: 0.2369 | Val: 0.2375 | Speech: 97.8%


Epoch 124/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 124 | Train: 0.2357 | Val: 0.2356 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch 125/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 125 | Train: 0.2371 | Val: 0.2417 | Speech: 97.5%



Epoch 126/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 126 | Train: 0.2369 | Val: 0.2384 | Speech: 97.8%


Epoch 127/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 127 | Train: 0.2360 | Val: 0.2365 | Speech: 97.5%


Epoch 128/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 128 | Train: 0.2364 | Val: 0.2371 | Speech: 97.9%


Epoch 129/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 129 | Train: 0.2392 | Val: 0.2354 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch 130/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 130 | Train: 0.2361 | Val: 0.2393 | Speech: 97.8%



Epoch 131/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 131 | Train: 0.2350 | Val: 0.2383 | Speech: 97.6%


Epoch 132/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 132 | Train: 0.2360 | Val: 0.2364 | Speech: 97.4%


Epoch 133/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 133 | Train: 0.2359 | Val: 0.2369 | Speech: 97.7%


Epoch 134/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 134 | Train: 0.2352 | Val: 0.2381 | Speech: 97.4%


Epoch 135/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 135 | Train: 0.2361 | Val: 0.2381 | Speech: 97.6%



Epoch 136/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 136 | Train: 0.2369 | Val: 0.2376 | Speech: 97.6%


Epoch 137/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 137 | Train: 0.2366 | Val: 0.2402 | Speech: 97.7%


Epoch 138/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 138 | Train: 0.2366 | Val: 0.2351 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch 139/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 139 | Train: 0.2352 | Val: 0.2376 | Speech: 97.7%


Epoch 140/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 140 | Train: 0.2368 | Val: 0.2386 | Speech: 97.6%



Epoch 141/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 141 | Train: 0.2359 | Val: 0.2361 | Speech: 97.4%


Epoch 142/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 142 | Train: 0.2350 | Val: 0.2389 | Speech: 97.7%


Epoch 143/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 143 | Train: 0.2354 | Val: 0.2373 | Speech: 97.8%


Epoch 144/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 144 | Train: 0.2351 | Val: 0.2374 | Speech: 97.5%


Epoch 145/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 145 | Train: 0.2352 | Val: 0.2396 | Speech: 97.7%



Epoch 146/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 146 | Train: 0.2364 | Val: 0.2364 | Speech: 97.6%


Epoch 147/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 147 | Train: 0.2345 | Val: 0.2346 | Speech: 97.7%
        ‚úÖ Best model saved!


Epoch 148/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 148 | Train: 0.2357 | Val: 0.2342 | Speech: 97.6%
        ‚úÖ Best model saved!


Epoch 149/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 149 | Train: 0.2342 | Val: 0.2353 | Speech: 97.6%


Epoch 150/150:   0%|                                                  | 0/803 [00:00<?, ?it/s, ]


    Epoch 150 | Train: 0.2349 | Val: 0.2351 | Speech: 97.6%


‚úÖ Training Complete!
üìä Model Parameters: 11,724
üíæ Model saved to: /content/drive/MyDrive/PPSI-Net-Split-v3/checkpoints/phase/best_phase_cnn.pth


In [15]:
# üî• ‰∏¥Êó∂‰øÆÂ§çÔºöÁõ¥Êé•ÈáçÊñ∞ÂÆö‰πâÂáΩÊï∞
def visualize_reconstruction_comparison(
    original_waveform,
    reconstructed_waveform,
    sample_rate=16000,
    n_fft=1024,
    hop_length=256,
    win_length=1024,
    save_path=None,
    show_plot=True,
    original_magnitude=None  # üî• Êñ∞Â¢ûÂèÇÊï∞
):
    """Comprehensive visualization with optional original magnitude"""

    import torch
    import numpy as np
    import matplotlib.pyplot as plt

    print("üìä Generating comprehensive visualization...")

    # Calculate metrics
    lsc = calculate_lsc(original_waveform, reconstructed_waveform, n_fft, hop_length, win_length)
    estoi_score = calculate_estoi(original_waveform, reconstructed_waveform, sample_rate)
    pesq_score = calculate_pesq(original_waveform, reconstructed_waveform, sample_rate)

    # Compute STFTs
    window = torch.hann_window(win_length)

    original_tensor = torch.from_numpy(original_waveform).float()
    stft_original = torch.stft(
        original_tensor, n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, window=window,
        return_complex=True, center=True, normalized=False
    )

    reconstructed_tensor = torch.from_numpy(reconstructed_waveform).float()
    stft_reconstructed = torch.stft(
        reconstructed_tensor, n_fft=n_fft, hop_length=hop_length,
        win_length=win_length, window=window,
        return_complex=True, center=True, normalized=False
    )

    # üî• Use original magnitude if provided
    if original_magnitude is not None:
        mag_original = original_magnitude
        print("   ‚úÖ Using original magnitude from inference")
    else:
        mag_original = torch.abs(stft_original).numpy()
        print("   ‚ö†Ô∏è  Recomputing magnitude from waveform")

    phase_original = torch.angle(stft_original).numpy()
    mag_reconstructed = torch.abs(stft_reconstructed).numpy()
    phase_reconstructed = torch.angle(stft_reconstructed).numpy()

    # Convert to dB
    mag_original_db = 20 * np.log10(mag_original + 1e-8)
    mag_reconstructed_db = 20 * np.log10(mag_reconstructed + 1e-8)

    # Calculate errors
    mag_diff = np.abs(mag_original_db - mag_reconstructed_db)
    phase_diff = np.abs(np.angle(np.exp(1j * phase_original) / np.exp(1j * phase_reconstructed)))

    # Time and frequency axes
    times = np.arange(mag_original.shape[1]) * hop_length / sample_rate
    freqs = np.arange(mag_original.shape[0]) * sample_rate / n_fft / 1000

    min_len = min(len(original_waveform), len(reconstructed_waveform))
    original_clip = original_waveform[:min_len]
    reconstructed_clip = reconstructed_waveform[:min_len]

    mse_wav = np.mean((original_clip - reconstructed_clip) ** 2)
    snr = 10 * np.log10(np.var(original_clip) / (mse_wav + 1e-8))
    mag_corr = np.corrcoef(mag_original.flatten(), mag_reconstructed.flatten())[0, 1]
    mag_error_mean = np.mean(mag_diff)

    # Create figure
    fig = plt.figure(figsize=(20, 14))
    gs = fig.add_gridspec(4, 3, hspace=0.35, wspace=0.3)

    # Row 1: Magnitude
    ax1 = fig.add_subplot(gs[0, 0])
    im1 = ax1.imshow(mag_original_db, aspect='auto', origin='lower',
                     cmap='viridis', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax1.set_title('Original - Magnitude (dB)', fontweight='bold')
    ax1.set_ylabel('Frequency (kHz)')
    plt.colorbar(im1, ax=ax1, label='dB')

    ax2 = fig.add_subplot(gs[0, 1])
    im2 = ax2.imshow(mag_reconstructed_db, aspect='auto', origin='lower',
                     cmap='viridis', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax2.set_title('Reconstructed - Magnitude (dB)', fontweight='bold')
    ax2.set_ylabel('Frequency (kHz)')
    plt.colorbar(im2, ax=ax2, label='dB')

    ax3 = fig.add_subplot(gs[0, 2])
    im3 = ax3.imshow(mag_diff, aspect='auto', origin='lower',
                     cmap='hot', extent=[times[0], times[-1], freqs[0], freqs[-1]])
    title_str = f'Magnitude Error (mean={mag_error_mean:.2f} dB)'
    if original_magnitude is not None:
        title_str += '\n‚úÖ Using original magnitude'
    ax3.set_title(title_str, fontweight='bold', fontsize=10)
    ax3.set_ylabel('Frequency (kHz)')
    plt.colorbar(im3, ax=ax3, label='Error (dB)')

    # Row 2: Phase
    ax4 = fig.add_subplot(gs[1, 0])
    im4 = ax4.imshow(phase_original, aspect='auto', origin='lower',
                     cmap='twilight', vmin=-np.pi, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax4.set_title('Original - Phase', fontweight='bold')
    ax4.set_ylabel('Frequency (kHz)')
    plt.colorbar(im4, ax=ax4, label='rad')

    ax5 = fig.add_subplot(gs[1, 1])
    im5 = ax5.imshow(phase_reconstructed, aspect='auto', origin='lower',
                     cmap='twilight', vmin=-np.pi, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax5.set_title('Reconstructed - Phase', fontweight='bold')
    ax5.set_ylabel('Frequency (kHz)')
    plt.colorbar(im5, ax=ax5, label='rad')

    ax6 = fig.add_subplot(gs[1, 2])
    im6 = ax6.imshow(phase_diff, aspect='auto', origin='lower',
                     cmap='hot', vmin=0, vmax=np.pi,
                     extent=[times[0], times[-1], freqs[0], freqs[-1]])
    ax6.set_title('Phase Error', fontweight='bold')
    ax6.set_ylabel('Frequency (kHz)')
    plt.colorbar(im6, ax=ax6, label='rad')

    # Row 3: Waveform
    ax7 = fig.add_subplot(gs[2, :])
    time_axis = np.arange(min_len) / sample_rate
    ax7.plot(time_axis, original_clip, 'b-', alpha=0.7, linewidth=0.8, label='Original')
    ax7.plot(time_axis, reconstructed_clip, 'r-', alpha=0.7, linewidth=0.8, label='Reconstructed')
    ax7.set_title('Waveform Comparison', fontweight='bold')
    ax7.set_xlabel('Time (s)')
    ax7.set_ylabel('Amplitude')
    ax7.legend()
    ax7.grid(True, alpha=0.3)
    ax7.set_xlim([0, min(2.0, time_axis[-1])])

    # Row 4: Metrics
    ax8 = fig.add_subplot(gs[3, :])
    ax8.axis('off')

    pesq_rating = '‚úÖ Excellent' if pesq_score > 3.5 else 'üëç Good' if pesq_score > 3.0 else '‚ö†Ô∏è Fair'
    estoi_rating = '‚úÖ Excellent' if estoi_score > 0.85 else 'üëç Good' if estoi_score > 0.75 else '‚ö†Ô∏è Fair'
    lsc_rating = '‚úÖ Excellent' if lsc < 0.5 else 'üëç Good' if lsc < 1.0 else '‚ö†Ô∏è Fair'

    metrics_text = f"""
    QUALITY METRICS

    PERCEPTUAL:
      ‚Ä¢ PESQ:  {pesq_score:.3f}  {pesq_rating}
      ‚Ä¢ ESTOI: {estoi_score:.4f}  {estoi_rating}
      ‚Ä¢ LSC:   {lsc:.4f}  {lsc_rating}

    SIGNAL:
      ‚Ä¢ SNR:      {snr:.2f} dB
      ‚Ä¢ Mag Corr: {mag_corr:.4f}
      ‚Ä¢ Mag Err:  {mag_error_mean:.2f} dB
    """

    ax8.text(0.05, 0.5, metrics_text, fontsize=11, family='monospace',
             verticalalignment='center',
             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.3))

    plt.suptitle('PPSI-Net: Reconstruction Quality Analysis',
                 fontsize=16, fontweight='bold', y=0.98)

    if save_path:
        plt.savefig(save_path, dpi=150, bbox_inches='tight')
        print(f"   üíæ Saved to: {save_path}")

    if show_plot:
        plt.show()
    else:
        plt.close()

    print(f"\nüìã PESQ: {pesq_score:.3f} | ESTOI: {estoi_score:.4f} | SNR: {snr:.2f} dB")

    return {
        'pesq': pesq_score,
        'estoi': estoi_score,
        'lsc': lsc,
        'snr': snr,
        'mag_corr': mag_corr
    }

print("‚úÖ Function redefined with original_magnitude parameter!")

‚úÖ Function redefined with original_magnitude parameter!


In [16]:
# ============================================================================
# CELL 16: üî• Inference & Evaluation with VAD
# ============================================================================

def run_inference_and_evaluation():
    """üî• Run inference with VAD-enhanced solver"""

    print("\n" + "="*70)
    print("üî• PPSI-Net Inference & Evaluation (with VAD)")
    print("="*70)

    config = Config()
    model_path = f'{config.save_dir}/phase/best_phase_cnn.pth'

    if not os.path.exists(model_path):
        print(f"‚ùå Model not found at: {model_path}")
        return None

    print(f"‚úÖ Found trained model")

    # üî• Create inferencer with VAD
    print("\nüîß Loading inference model with VAD...")
    inferencer = PPSIInference(model_path, config)

    # Load validation dataset
    print("\nüìä Loading validation dataset...")
    val_dataset = LibriSpeechSpectrogramDataset(
        root_dir=config.librispeech_root,
        subset='val',
        n_fft=config.n_fft,
        hop_length=config.hop_length,
        win_length=config.win_length,
        sample_rate=config.sample_rate,
        duration=config.duration
    )

    vis_dir = f'{PROJECT_ROOT}/logs/visualizations'
    os.makedirs(vis_dir, exist_ok=True)

    # === üî• Test VAD visualization ===
    print("\n" + "="*70)
    print("üî• Part 1: VAD Effect Visualization")
    print("="*70)

    sample = val_dataset[0]
    mag_spec = sample['amplitude_abs'].unsqueeze(0)  # [1, 1, T, F]

    visualize_vad_effect(
        mag_spec,
        sample_rate=config.sample_rate,
        hop_length=config.hop_length,
        vad_threshold_db=config.vad_threshold_db
    )

    # === Single sample reconstruction ===
    print("\n" + "="*70)
    print("üé® Part 2: Sample Reconstruction")
    print("="*70)

    waveform = sample['waveform'].numpy()
    # üî• Get both reconstructed audio AND original magnitude
    reconstructed, mag_original = inferencer.reconstruct_audio(waveform, return_magnitude=True)

    save_path = f'{vis_dir}/sample_0_comparison_vad.png'
    metrics = visualize_reconstruction_comparison(
        waveform,
        reconstructed,
        sample_rate=config.sample_rate,
        n_fft=config.n_fft,
        hop_length=config.hop_length,
        win_length=config.win_length,
        save_path=save_path,
        original_magnitude=mag_original  # üî• Pass original magnitude
    )

    # Audio playback
    print("\nüéß Audio Playback:")
    print("   Original:")
    display(Audio(waveform / (np.abs(waveform).max() + 1e-8), rate=config.sample_rate))
    print("\n   Reconstructed:")
    display(Audio(reconstructed / (np.abs(reconstructed).max() + 1e-8), rate=config.sample_rate))

    # === Batch evaluation ===
    print("\n" + "="*70)
    print("üìä Part 3: Batch Evaluation")
    print("="*70)

    num_eval = 10
    all_metrics = {'pesq': [], 'estoi': [], 'lsc': [], 'snr': []}

    for idx in tqdm(range(num_eval), desc="Evaluating"):
        try:
            sample = val_dataset[idx]
            waveform = sample['waveform'].numpy()
            # üî• Don't need magnitude for batch evaluation (just metrics)
            reconstructed = inferencer.reconstruct_audio(waveform, return_magnitude=False)

            pesq_score = calculate_pesq(waveform, reconstructed)
            estoi_score = calculate_estoi(waveform, reconstructed)
            lsc = calculate_lsc(waveform, reconstructed)

            min_len = min(len(waveform), len(reconstructed))
            mse = np.mean((waveform[:min_len] - reconstructed[:min_len]) ** 2)
            snr = 10 * np.log10(np.var(waveform[:min_len]) / (mse + 1e-8))

            all_metrics['pesq'].append(pesq_score)
            all_metrics['estoi'].append(estoi_score)
            all_metrics['lsc'].append(lsc)
            all_metrics['snr'].append(snr)
        except Exception as e:
            print(f"‚ö†Ô∏è Sample {idx} failed: {e}")

    # Print results
    print("\n" + "="*70)
    print("üìà EVALUATION RESULTS (with VAD)")
    print("="*70)
    print(f"\nüéØ Average Metrics:")
    print(f"  ‚Ä¢ PESQ:  {np.mean(all_metrics['pesq']):.3f} ¬± {np.std(all_metrics['pesq']):.3f}")
    print(f"  ‚Ä¢ ESTOI: {np.mean(all_metrics['estoi']):.4f} ¬± {np.std(all_metrics['estoi']):.4f}")
    print(f"  ‚Ä¢ LSC:   {np.mean(all_metrics['lsc']):.4f} ¬± {np.std(all_metrics['lsc']):.4f}")
    print(f"  ‚Ä¢ SNR:   {np.mean(all_metrics['snr']):.2f} ¬± {np.std(all_metrics['snr']):.2f} dB")
    print("="*70)

    return {
        'inferencer': inferencer,
        'val_dataset': val_dataset,
        'metrics': all_metrics
    }

# Run inference
print("\nüöÄ Running inference with VAD...")
inference_results = run_inference_and_evaluation()

print("\n" + "="*70)
print("üéâ All Done! PPSI-Net with VAD is ready!")
print("="*70)


üöÄ Running inference with VAD...

üî• PPSI-Net Inference & Evaluation (with VAD)
‚úÖ Found trained model

üîß Loading inference model with VAD...
    ‚úÖ Using O(L) tridiagonal solver
    üî• Solver VAD enabled (threshold: -5 dB)


NameError: name 'MinimumPhaseInitializer' is not defined