In [1]:
import math
import shutil
import tempfile
import time
from pathlib import Path

import librosa
import librosa.display  # submodule; not guaranteed to be loaded by `import librosa` alone
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import soundfile as sf
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from pesq import pesq
from pystoi import stoi
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

In [2]:
SEED = 777
torch.manual_seed(SEED)
np.random.seed(SEED)

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

Using device: cuda


In [51]:
class Config:
    """Experiment settings: STFT front-end, model size, optimisation, data and output paths.

    Every value is a class attribute; ``config`` below is the shared instance
    the rest of the notebook reads from.
    """

    # Audio / STFT front-end
    sample_rate = 16_000
    n_fft = 512
    win_length = 512
    hop_length = 128

    # Model size
    hidden_dim = 256
    num_layers = 3
    dropout = 0.1

    # Optimisation
    batch_size = 4
    num_epochs = 10
    learning_rate = 1e-3
    gradient_clip = 5.0

    # Paired clean/noisy WAV folders
    train_clean_path = "./data/wav_clean_train"
    train_noisy_path = "./data/wav_noisy_train"
    test_clean_path = "./data/wav_clean_test"
    test_noisy_path = "./data/wav_noisy_test"

    # Outputs
    checkpoint_path = "./checkpoints"
    results_path = "./results"


config = Config()

In [52]:
class VoiceBankDataset(Dataset):
    """Paired clean/noisy WAV dataset (VoiceBank-style folder layout).

    Args:
        clean_dir: folder containing clean ``*.wav`` files.
        noisy_dir: folder containing noisy ``*.wav`` files with the same filenames.
        max_len: crop/pad length in samples, applied only when ``mode == 'train'``
            (default: 4 s at 16 kHz). A falsy value disables cropping/padding.
        mode: ``'test'`` pairs files by filename and skips noisy files without a
            clean counterpart; any other mode pairs the sorted lists positionally
            and requires identical counts and filenames.

    Each item is a dict with 1-D ``'noisy'`` / ``'clean'`` tensors (down-mixed to
    mono and trimmed to a common length), the ``'filename'`` and both source paths.
    NOTE(review): the file sample rate is not checked; files are assumed to be
    stored at the rate the rest of the pipeline expects.
    """

    def __init__(self, clean_dir, noisy_dir, max_len=4*16000, mode='train'):
        self.clean_dir = Path(clean_dir)
        self.noisy_dir = Path(noisy_dir)
        self.max_len = max_len
        self.mode = mode

        self.clean_files = sorted(self.clean_dir.glob("*.wav"))
        self.noisy_files = sorted(self.noisy_dir.glob("*.wav"))

        if mode == 'test':
            clean_by_name = {f.name: f for f in self.clean_files}
            self.file_pairs = [
                (clean_by_name[noisy_file.name], noisy_file)
                for noisy_file in self.noisy_files
                if noisy_file.name in clean_by_name
            ]
        else:
            assert len(self.clean_files) == len(self.noisy_files), \
                f"Mismatch: {len(self.clean_files)} clean vs {len(self.noisy_files)} noisy"
            # Positional pairing is only valid if both sorted lists name the same utterances;
            # otherwise every sample would silently get the wrong target.
            mismatched = [
                (c.name, n.name)
                for c, n in zip(self.clean_files, self.noisy_files)
                if c.name != n.name
            ]
            assert not mismatched, \
                f"Clean/noisy filenames do not match, first mismatch: {mismatched[0]}"
            self.file_pairs = list(zip(self.clean_files, self.noisy_files))

        print(f"Loaded {len(self.file_pairs)} file pairs for {mode} mode")

    def __len__(self):
        return len(self.file_pairs)

    def __getitem__(self, idx):
        clean_file, noisy_file = self.file_pairs[idx]

        clean_wav, _ = torchaudio.load(clean_file)
        noisy_wav, _ = torchaudio.load(noisy_file)

        # Down-mix multi-channel audio to a single (1, num_samples) channel.
        if clean_wav.shape[0] > 1:
            clean_wav = clean_wav.mean(dim=0, keepdim=True)
        if noisy_wav.shape[0] > 1:
            noisy_wav = noisy_wav.mean(dim=0, keepdim=True)

        # Keep the pair sample-aligned and equal-length so crops/pads stay batchable.
        common_len = min(clean_wav.shape[1], noisy_wav.shape[1])
        clean_wav = clean_wav[:, :common_len]
        noisy_wav = noisy_wav[:, :common_len]

        if self.mode == 'train' and self.max_len:
            excess = common_len - self.max_len
            if excess > 0:
                # randint's upper bound is exclusive; +1 makes the last valid start reachable.
                start = int(torch.randint(0, excess + 1, (1,)))
                clean_wav = clean_wav[:, start:start + self.max_len]
                noisy_wav = noisy_wav[:, start:start + self.max_len]
            else:
                pad_len = -excess
                clean_wav = F.pad(clean_wav, (0, pad_len))
                noisy_wav = F.pad(noisy_wav, (0, pad_len))

        return {
            'noisy': noisy_wav.squeeze(0),
            'clean': clean_wav.squeeze(0),
            'filename': clean_file.name,
            'clean_path': str(clean_file),
            'noisy_path': str(noisy_file)
        }

In [53]:
class STFT(nn.Module):
    """Complex STFT / inverse-STFT pair sharing one Hann window.

    Thin wrapper over ``torch.stft`` / ``torch.istft`` with their default
    (``center=True``) framing.

    Args:
        n_fft: FFT size; spectra have ``n_fft // 2 + 1`` frequency bins.
        hop_length: samples between successive frames.
        win_length: Hann window length in samples.
    """

    def __init__(self, n_fft=512, hop_length=128, win_length=512):
        super().__init__()
        self.n_fft = n_fft
        self.hop_length = hop_length
        self.win_length = win_length
        # A buffer follows `module.to(device)`, so forward no longer has to re-assign
        # module state. persistent=False keeps it out of state_dict, exactly like the
        # previous plain attribute, so existing checkpoints still load.
        self.register_buffer("window", torch.hann_window(win_length), persistent=False)

    def forward(self, x):
        """Return the complex spectrogram of ``x`` (time on the last axis)."""
        # Local copy: a no-op when the module already lives on x's device.
        window = self.window.to(x.device)
        return torch.stft(
            x,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            return_complex=True,
        )

    def inverse(self, spec, length=None):
        """Invert a complex spectrogram back to a waveform.

        Args:
            spec: complex spectrogram as produced by ``forward``.
            length: optional exact output length in samples; ``None`` keeps
                ``torch.istft``'s default output length.
        """
        window = self.window.to(spec.device)
        return torch.istft(
            spec,
            n_fft=self.n_fft,
            hop_length=self.hop_length,
            win_length=self.win_length,
            window=window,
            length=length,
        )

In [54]:
class xLSTMBlock(nn.Module):
    """
    Simplified xLSTM-style block used as a baseline.

    Despite the name, it is a single-layer bidirectional ``nn.LSTM`` scanned
    along the time axis independently for every frequency bin, followed by
    LayerNorm, dropout, a projection back to ``input_dim`` and a residual add.
    Input/output layout: ``(B, T, n_freq, input_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        bi_width = 2 * hidden_dim  # forward and backward states are concatenated
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            batch_first=True,
            dropout=0,
            bidirectional=True,
        )
        self.norm = nn.LayerNorm(bi_width)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(bi_width, input_dim)

    def forward(self, x):
        batch, frames, bins, channels = x.shape

        # Fold frequency into the batch so the LSTM scans each bin's time sequence.
        seq = x.transpose(1, 2).reshape(batch * bins, frames, channels)
        hidden, _ = self.lstm(seq)
        hidden = self.proj(self.dropout(self.norm(hidden)))

        # Unfold back to (batch, frames, bins, channels) and add the skip path.
        restored = hidden.reshape(batch, bins, frames, channels).transpose(1, 2)
        return x + restored

In [55]:
class MambaBlock(nn.Module):
    """
    Simplified Mamba-style block used as a baseline.

    Instead of a selective state-space model it applies a grouped 1-D
    convolution (kernel 3, same-length padding) across the frequency axis,
    then LayerNorm, SiLU, dropout, a projection back to ``input_dim`` and a
    residual add. Input/output layout: ``(B, T, n_freq, input_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        # Heuristic: about 4 input channels per group. Conv1d requires `groups` to
        # divide both in- and out-channels; reducing with gcd keeps the heuristic
        # value whenever it is valid and falls back to a smaller divisor otherwise
        # (previously e.g. input_dim=12, hidden_dim=16 raised at construction).
        groups = input_dim // 4 if input_dim >= 4 else 1
        groups = math.gcd(math.gcd(groups, input_dim), hidden_dim)
        # Conv1D used as a cheap approximation of an SSM.
        self.conv = nn.Conv1d(
            input_dim,
            hidden_dim,
            kernel_size=3,
            padding=1,
            groups=groups,
        )
        self.norm = nn.LayerNorm(hidden_dim)
        self.dropout = nn.Dropout(dropout)
        self.proj = nn.Linear(hidden_dim, input_dim)
        self.activation = nn.SiLU()

    def forward(self, x):
        B, T, n_freq, C = x.shape

        # Apply along the frequency axis: each (batch, frame) is a length-n_freq
        # sequence with C channels.
        x_reshaped = x.permute(0, 1, 3, 2).reshape(B * T, C, n_freq)
        out = self.conv(x_reshaped)
        out = out.permute(0, 2, 1).reshape(B, T, n_freq, -1)
        out = self.norm(out)
        out = self.activation(out)
        out = self.dropout(out)
        out = self.proj(out)

        return out + x  # residual connection

In [None]:
class HybridBlock(nn.Module):
    """
    xLSTMBlock (mixes along time) followed by MambaBlock (mixes along frequency),
    then LayerNorm over the channel dimension.
    """

    def __init__(self, input_dim, hidden_dim, dropout=0.1):
        super().__init__()
        self.xlstm = xLSTMBlock(input_dim, hidden_dim, dropout)
        self.mamba = MambaBlock(input_dim, hidden_dim, dropout)
        self.norm = nn.LayerNorm(input_dim)

    def forward(self, x):
        time_mixed = self.xlstm(x)
        freq_mixed = self.mamba(time_mixed)
        return self.norm(freq_mixed)

In [57]:
class Encoder(nn.Module):
    """Linear lift from ``input_dim`` to ``hidden_dim`` followed by ``num_layers`` HybridBlocks."""

    def __init__(self, input_dim, hidden_dim, num_layers, dropout=0.1):
        super().__init__()
        self.input_proj = nn.Linear(input_dim, hidden_dim)
        blocks = [HybridBlock(hidden_dim, hidden_dim, dropout) for _ in range(num_layers)]
        self.layers = nn.ModuleList(blocks)

    def forward(self, x):
        hidden = self.input_proj(x)
        for block in self.layers:
            hidden = block(hidden)
        return hidden

In [58]:
class MagnitudeDecoder(nn.Module):
    """Stack of HybridBlocks followed by a Linear + ReLU head (non-negative output)."""

    def __init__(self, hidden_dim, output_dim, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            HybridBlock(hidden_dim, hidden_dim, dropout) for _ in range(num_layers)
        )
        head = [nn.Linear(hidden_dim, output_dim), nn.ReLU()]
        self.output_proj = nn.Sequential(*head)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        # ReLU keeps predicted magnitudes >= 0.
        return self.output_proj(hidden)

In [59]:
class PhaseDecoder(nn.Module):
    """Stack of HybridBlocks followed by a Linear + Tanh head scaled to [-pi, pi]."""

    def __init__(self, hidden_dim, output_dim, num_layers, dropout=0.1):
        super().__init__()
        self.layers = nn.ModuleList(
            HybridBlock(hidden_dim, hidden_dim, dropout) for _ in range(num_layers)
        )
        head = [nn.Linear(hidden_dim, output_dim), nn.Tanh()]
        self.output_proj = nn.Sequential(*head)

    def forward(self, x):
        hidden = x
        for block in self.layers:
            hidden = block(hidden)
        # Tanh output lies in [-1, 1]; scale it to a phase angle in [-pi, pi].
        return np.pi * self.output_proj(hidden)

In [None]:
class MPSENet(nn.Module):
    """MP-SENet-style enhancer: noisy waveform in, enhanced waveform out.

    Pipeline: STFT -> per-bin (magnitude, phase) features -> shared ``Encoder``
    -> parallel ``MagnitudeDecoder`` / ``PhaseDecoder`` -> ``torch.polar`` -> iSTFT.

    Args:
        config: object exposing ``n_fft``, ``hop_length``, ``win_length``,
            ``hidden_dim``, ``num_layers`` and ``dropout`` (e.g. ``Config``).
    """

    def __init__(self, config):
        super().__init__()
        self.stft = STFT(
            n_fft=config.n_fft,
            hop_length=config.hop_length,
            win_length=config.win_length,
        )

        # The encoder sees two features per time-frequency bin: magnitude and phase.
        self.encoder = Encoder(
            input_dim=2,
            hidden_dim=config.hidden_dim,
            num_layers=config.num_layers,
            dropout=config.dropout,
        )

        # Parallel decoders, each half as deep as the encoder (integer division).
        self.magnitude_decoder = MagnitudeDecoder(
            hidden_dim=config.hidden_dim,
            output_dim=1,
            num_layers=config.num_layers // 2,
            dropout=config.dropout,
        )

        self.phase_decoder = PhaseDecoder(
            hidden_dim=config.hidden_dim,
            output_dim=1,
            num_layers=config.num_layers // 2,
            dropout=config.dropout,
        )

    def forward(self, noisy_wav):
        """Enhance a batch of waveforms.

        Args:
            noisy_wav: waveform batch with time on the last axis
                (assumed ``(B, L)``, as produced by batching ``VoiceBankDataset``).

        Returns:
            ``(enhanced_wav, enhanced_mag, enhanced_phase, noisy_mag, noisy_phase)``.
            ``enhanced_wav`` is trimmed / zero-padded to the input length; the
            spectral tensors use ``torch.stft``'s (freq, frames) layout on the
            last two axes.
        """
        original_length = noisy_wav.shape[-1]

        noisy_spec = self.stft(noisy_wav)
        noisy_mag = torch.abs(noisy_spec)
        noisy_phase = torch.angle(noisy_spec)

        # (B, F, T) x2 -> (B, F, T, 2) -> (B, T, F, 2): channels last, time before frequency.
        encoder_input = torch.stack([noisy_mag, noisy_phase], dim=-1).permute(0, 2, 1, 3)
        encoded = self.encoder(encoder_input)

        # Decoders emit (B, T, F, 1); drop the channel and return to (B, F, T).
        enhanced_mag = self.magnitude_decoder(encoded).squeeze(-1).permute(0, 2, 1)
        enhanced_phase = self.phase_decoder(encoded).squeeze(-1).permute(0, 2, 1)

        # The magnitude decoder ends in ReLU, so `abs` is non-negative as torch.polar expects.
        enhanced_spec = torch.polar(enhanced_mag, enhanced_phase)
        enhanced_wav = self.stft.inverse(enhanced_spec)

        # The iSTFT output length need not equal the input length; match it exactly.
        if enhanced_wav.shape[-1] > original_length:
            enhanced_wav = enhanced_wav[..., :original_length]
        elif enhanced_wav.shape[-1] < original_length:
            pad_len = original_length - enhanced_wav.shape[-1]
            enhanced_wav = F.pad(enhanced_wav, (0, pad_len))

        return enhanced_wav, enhanced_mag, enhanced_phase, noisy_mag, noisy_phase

In [61]:
class CombinedLoss(nn.Module):
    """
    Weighted sum of a waveform L1 loss, a magnitude MSE loss and a phase-cosine loss:

        total = alpha * L1(wav) + beta * MSE(mag) + gamma * (1 - mean(cos(phase_pred - phase_clean)))

    ``forward`` returns ``(total_loss_tensor, {"time_loss", "mag_loss", "phase_loss"} as floats)``.
    """

    def __init__(self, alpha=0.7, beta=0.2, gamma=0.1):
        super().__init__()
        self.alpha = alpha  # weight of the time-domain (waveform) loss
        self.beta = beta  # weight of the magnitude loss
        self.gamma = gamma  # weight of the phase loss

    def forward(
        self, pred_wav, clean_wav, pred_mag, clean_mag, pred_phase, clean_phase
    ):
        terms = {
            "time_loss": F.l1_loss(pred_wav, clean_wav),
            "mag_loss": F.mse_loss(pred_mag, clean_mag),
            # 1 - cos(delta) is 0 for identical phases and insensitive to 2*pi wrap-around.
            "phase_loss": 1 - torch.cos(pred_phase - clean_phase).mean(),
        }
        total = (
            self.alpha * terms["time_loss"]
            + self.beta * terms["mag_loss"]
            + self.gamma * terms["phase_loss"]
        )
        return total, {name: value.item() for name, value in terms.items()}

In [None]:
class URGENTMetrics:
    """
    Intrusive speech-quality metrics computed from a clean / enhanced WAV pair.

    Metric set follows the URGENT Challenge 2026 (track 1):
    https://github.com/urgent-challenge/urgent2026_challenge_track1
    NOTE(review): these are local implementations on top of ``pesq``, ``pystoi``
    and NumPy, not the challenge's own scoring scripts; values may differ slightly.

    Both signals are read with ``soundfile``, down-mixed to mono if needed and
    truncated to their common length. A metric that fails prints the error and
    returns 0.0 so a single bad file does not abort an evaluation loop.
    """

    @staticmethod
    def _load_aligned(clean_path, enhanced_path):
        """Read both files, down-mix to mono, and trim to the shorter length."""
        clean, _ = sf.read(clean_path)
        enhanced, _ = sf.read(enhanced_path)
        # soundfile returns (frames, channels) for multi-channel files.
        if clean.ndim > 1:
            clean = clean.mean(axis=1)
        if enhanced.ndim > 1:
            enhanced = enhanced.mean(axis=1)
        min_len = min(len(clean), len(enhanced))
        return clean[:min_len], enhanced[:min_len]

    @staticmethod
    def _guarded(name, compute):
        """Return ``compute()``; if it raises, print the error and return 0.0."""
        try:
            return compute()
        except Exception as e:
            print(f"Error computing {name}: {e}")
            return 0.0

    @staticmethod
    def _si_sdr(clean, enhanced):
        """SI-SDR in dB of ``enhanced`` against ``clean`` (equal-length 1-D arrays)."""
        eps = np.finfo(np.float64).eps
        # Optimal scaling of the reference onto the estimate.
        alpha = np.dot(enhanced, clean) / (np.dot(clean, clean) + eps)
        s_target = alpha * clean
        e_noise = enhanced - s_target
        # eps keeps silent references / perfect estimates finite instead of nan or inf
        # (NumPy warns rather than raises on 0/0, so the guard alone would not catch it).
        return 10 * np.log10((np.sum(s_target**2) + eps) / (np.sum(e_noise**2) + eps))

    @staticmethod
    def compute_pesq(clean_path, enhanced_path, sr=16000):
        """PESQ - Perceptual Evaluation of Speech Quality (wide-band mode)."""
        def run():
            clean, enhanced = URGENTMetrics._load_aligned(clean_path, enhanced_path)
            return pesq(sr, clean, enhanced, "wb")
        return URGENTMetrics._guarded("PESQ", run)

    @staticmethod
    def compute_stoi(clean_path, enhanced_path, sr=16000):
        """STOI - Short-Time Objective Intelligibility"""
        def run():
            clean, enhanced = URGENTMetrics._load_aligned(clean_path, enhanced_path)
            return stoi(clean, enhanced, sr, extended=False)
        return URGENTMetrics._guarded("STOI", run)

    @staticmethod
    def compute_estoi(clean_path, enhanced_path, sr=16000):
        """Extended STOI"""
        def run():
            clean, enhanced = URGENTMetrics._load_aligned(clean_path, enhanced_path)
            return stoi(clean, enhanced, sr, extended=True)
        return URGENTMetrics._guarded("eSTOI", run)

    @staticmethod
    def compute_si_sdr(clean_path, enhanced_path):
        """SI-SDR - Scale-Invariant Signal-to-Distortion Ratio"""
        def run():
            clean, enhanced = URGENTMetrics._load_aligned(clean_path, enhanced_path)
            return URGENTMetrics._si_sdr(clean, enhanced)
        return URGENTMetrics._guarded("SI-SDR", run)

    @staticmethod
    def compute_all_metrics(clean_path, enhanced_path, sr=16000):
        """Compute every metric, reading each file only once.

        Returns a dict with keys ``"PESQ"``, ``"STOI"``, ``"eSTOI"``, ``"SI-SDR"``;
        all values are 0.0 if the audio cannot be loaded.
        """
        try:
            clean, enhanced = URGENTMetrics._load_aligned(clean_path, enhanced_path)
        except Exception as e:
            print(f"Error loading audio for metrics: {e}")
            return {"PESQ": 0.0, "STOI": 0.0, "eSTOI": 0.0, "SI-SDR": 0.0}

        guarded = URGENTMetrics._guarded
        return {
            "PESQ": guarded("PESQ", lambda: pesq(sr, clean, enhanced, "wb")),
            "STOI": guarded("STOI", lambda: stoi(clean, enhanced, sr, extended=False)),
            "eSTOI": guarded("eSTOI", lambda: stoi(clean, enhanced, sr, extended=True)),
            "SI-SDR": guarded("SI-SDR", lambda: URGENTMetrics._si_sdr(clean, enhanced)),
        }

In [63]:
def train_epoch(model, train_loader, optimizer, criterion, device, gradient_clip=None):
    """Run one optimisation pass over ``train_loader``.

    Args:
        model: ``MPSENet``-like module returning
            ``(enhanced_wav, pred_mag, pred_phase, noisy_mag, noisy_phase)`` and
            exposing ``model.stft``.
        train_loader: iterable of dicts with ``'noisy'`` and ``'clean'`` waveform tensors.
        optimizer: optimiser over the model's parameters.
        criterion: ``CombinedLoss``-like callable returning ``(loss, components_dict)``
            with ``'time_loss'``, ``'mag_loss'`` and ``'phase_loss'`` entries.
        device: device the batch tensors are moved to.
        gradient_clip: max global gradient norm; ``None`` falls back to the
            notebook-global ``config.gradient_clip``.

    Returns:
        ``(avg_loss, avg_components)`` averaged over the batches seen.

    Raises:
        ValueError: if ``train_loader`` yields no batches.
    """
    if gradient_clip is None:
        gradient_clip = config.gradient_clip

    model.train()
    total_loss = 0.0
    loss_components = {'time_loss': 0.0, 'mag_loss': 0.0, 'phase_loss': 0.0}
    num_batches = 0

    pbar = tqdm(train_loader, desc="Training")
    for batch in pbar:
        noisy = batch['noisy'].to(device)
        clean = batch['clean'].to(device)

        optimizer.zero_grad()
        enhanced_wav, pred_mag, pred_phase, _, _ = model(noisy)

        # Targets need no gradient; skip recording them in the autograd graph.
        with torch.no_grad():
            clean_spec = model.stft(clean)
            clean_mag = torch.abs(clean_spec)
            clean_phase = torch.angle(clean_spec)

        loss, loss_dict = criterion(
            enhanced_wav, clean,
            pred_mag, clean_mag,
            pred_phase, clean_phase
        )

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
        optimizer.step()

        loss_value = loss.item()
        total_loss += loss_value
        for key in loss_components:
            loss_components[key] += loss_dict[key]
        num_batches += 1

        pbar.set_postfix({'loss': f"{loss_value:.4f}"})

    if num_batches == 0:
        raise ValueError("train_loader yielded no batches")

    avg_loss = total_loss / num_batches
    avg_components = {k: v / num_batches for k, v in loss_components.items()}

    return avg_loss, avg_components

In [64]:
def validate(model, val_loader, criterion, device, save_audio=False):
    """Compute the mean loss and mean URGENT metrics of ``model`` over ``val_loader``.

    Metrics are computed from WAV files, so each enhanced utterance is written
    to disk: into ``<results_path>/enhanced_audio`` when ``save_audio`` is True,
    otherwise into a private temporary directory that is removed afterwards.

    Assumes ``batch_size=1``: the loss covers the whole batch, but file-based
    metrics are computed only for the first item of each batch.

    Returns:
        ``(avg_loss, avg_metrics)`` where ``avg_metrics`` is keyed by
        ``'PESQ'``, ``'STOI'``, ``'eSTOI'`` and ``'SI-SDR'``.
    """
    model.eval()
    total_loss = 0.0
    metrics = {'PESQ': [], 'STOI': [], 'eSTOI': [], 'SI-SDR': []}

    # A private temp dir replaces the hard-coded "/tmp/" prefix: portable across
    # platforms and free of filename collisions between concurrent runs.
    tmp_dir = None
    if save_audio:
        audio_dir = Path(config.results_path) / "enhanced_audio"
        audio_dir.mkdir(parents=True, exist_ok=True)
    else:
        tmp_dir = tempfile.TemporaryDirectory()
        audio_dir = Path(tmp_dir.name)

    try:
        with torch.no_grad():
            pbar = tqdm(val_loader, desc="Validation")
            for batch in pbar:
                noisy = batch['noisy'].to(device)
                clean = batch['clean'].to(device)
                filename = batch['filename'][0]
                clean_path = batch['clean_path'][0]

                enhanced_wav, pred_mag, pred_phase, _, _ = model(noisy)

                clean_spec = model.stft(clean)
                clean_mag = torch.abs(clean_spec)
                clean_phase = torch.angle(clean_spec)

                loss, _ = criterion(
                    enhanced_wav, clean,
                    pred_mag, clean_mag,
                    pred_phase, clean_phase
                )
                total_loss += loss.item()

                enhanced_path = audio_dir / filename
                torchaudio.save(
                    str(enhanced_path),
                    enhanced_wav[0].cpu().unsqueeze(0),
                    config.sample_rate
                )

                batch_metrics = URGENTMetrics.compute_all_metrics(
                    clean_path,
                    str(enhanced_path),
                    sr=config.sample_rate
                )
                for key, value in batch_metrics.items():
                    metrics[key].append(value)

                if not save_audio:
                    # Drop each temporary file right away to keep disk usage bounded.
                    enhanced_path.unlink(missing_ok=True)

                pbar.set_postfix({'PESQ': f"{np.mean(metrics['PESQ']):.3f}"})
    finally:
        # Remove the temp directory even if inference or saving raised.
        if tmp_dir is not None:
            tmp_dir.cleanup()

    avg_loss = total_loss / len(val_loader)
    avg_metrics = {k: np.mean(v) for k, v in metrics.items()}

    return avg_loss, avg_metrics

In [65]:
def train_model():
    """Train ``MPSENet`` with the notebook-global ``config`` and return ``(model, history)``.

    Each epoch runs one pass over the training set, then ``validate`` on the test split.
    The checkpoint with the best validation PESQ is written to
    ``<checkpoint_path>/best_model.pt``. The per-epoch history is written to
    ``<results_path>/training_history.csv`` and rewritten after every epoch, so an
    interrupted run (e.g. KeyboardInterrupt) keeps what it has already logged.

    NOTE(review): the test split doubles as the validation set for checkpoint
    selection and LR scheduling, so the best checkpoint's test metrics are
    optimistically biased. Consider holding out part of the training data instead.
    """
    target_pesq = 3.25  # experiment goal for wide-band PESQ
    history_path = f"{config.results_path}/training_history.csv"

    Path(config.checkpoint_path).mkdir(parents=True, exist_ok=True)
    Path(config.results_path).mkdir(parents=True, exist_ok=True)

    print("Loading datasets...")
    train_dataset = VoiceBankDataset(
        config.train_clean_path,
        config.train_noisy_path,
        mode='train'
    )
    test_dataset = VoiceBankDataset(
        config.test_clean_path,
        config.test_noisy_path,
        max_len=None,
        mode='test'
    )

    train_loader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        shuffle=True,
        num_workers=4,
        pin_memory=True
    )
    # batch_size=1: test utterances are not cropped/padded, and validate() scores one file per batch.
    test_loader = DataLoader(
        test_dataset,
        batch_size=1,
        shuffle=False,
        num_workers=2
    )

    print("\nInitializing MP-SENet model...")
    model = MPSENet(config).to(device)
    total_params = sum(p.numel() for p in model.parameters())
    print(f"Total parameters: {total_params / 1e6:.2f}M")

    optimizer = torch.optim.AdamW(model.parameters(), lr=config.learning_rate, weight_decay=0.01)
    # mode='max': the monitored quantity is validation PESQ (higher is better).
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='max', factor=0.5, patience=5
    )
    criterion = CombinedLoss()

    history = {
        'train_loss': [],
        'val_loss': [],
        'val_pesq': [],
        'val_stoi': [],
        'val_estoi': [],
        'val_sisdr': []
    }

    best_pesq = 0.0

    print("\nStarting training...")
    print("="*70)

    for epoch in range(config.num_epochs):
        print(f"\nEpoch {epoch+1}/{config.num_epochs}")
        print("-" * 70)

        train_loss, train_components = train_epoch(
            model, train_loader, optimizer, criterion, device
        )
        print(f"Train Loss: {train_loss:.4f}")
        print(f"Time: {train_components['time_loss']:.4f}, "
              f"Mag: {train_components['mag_loss']:.4f}, "
              f"Phase: {train_components['phase_loss']:.4f}")

        val_loss, val_metrics = validate(model, test_loader, criterion, device, save_audio=False)
        print(f"Val Loss: {val_loss:.4f}")
        print("Val Metrics:")
        print(f"PESQ:   {val_metrics['PESQ']:.4f}")
        print(f"STOI:   {val_metrics['STOI']:.4f}")
        print(f"eSTOI:  {val_metrics['eSTOI']:.4f}")
        print(f"SI-SDR: {val_metrics['SI-SDR']:.2f} dB")

        scheduler.step(val_metrics['PESQ'])

        history['train_loss'].append(train_loss)
        history['val_loss'].append(val_loss)
        history['val_pesq'].append(val_metrics['PESQ'])
        history['val_stoi'].append(val_metrics['STOI'])
        history['val_estoi'].append(val_metrics['eSTOI'])
        history['val_sisdr'].append(val_metrics['SI-SDR'])

        # Persist after every epoch so an interrupted run still leaves its history on disk.
        pd.DataFrame(history).to_csv(history_path, index=False)

        if val_metrics['PESQ'] > best_pesq:
            best_pesq = val_metrics['PESQ']
            # NOTE: 'config' pickles the Config object; torch.load(weights_only=True)
            # will refuse it unless the class is allow-listed (or weights_only=False).
            torch.save({
                'epoch': epoch,
                'model_state_dict': model.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'pesq': best_pesq,
                'metrics': val_metrics,
                'config': config
            }, f"{config.checkpoint_path}/best_model.pt")
            print(f"Saved best model with PESQ: {best_pesq:.4f}")

            if best_pesq >= target_pesq:
                # Previously referenced an undefined name and raised NameError here.
                print(f"Target PESQ {target_pesq} achieved!")

    # Also covers num_epochs == 0 (writes an empty history file).
    pd.DataFrame(history).to_csv(history_path, index=False)
    print("\n" + "="*70)
    print("Training completed!")

    return model, history

In [66]:
separator = "=" * 70
print(separator)
print("MP-SENet with xLSTM+Mamba for Speech Enhancement")
print(separator)

print("\n[1] Training model...")
model, history = train_model()

MP-SENet with xLSTM+Mamba for Speech Enhancement

[1] Training model...
Loading datasets...
Loaded 11572 file pairs for train mode
Loaded 824 file pairs for test mode

Initializing MP-SENet model...
Total parameters: 6.28M

Starting training...

Epoch 1/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:43<00:00,  2.03it/s, loss=0.0687]


Train Loss: 0.1051
Time: 0.0155, Mag: 0.0636, Phase: 0.8150


Validation: 100%|██████████| 824/824 [03:11<00:00,  4.31it/s, PESQ=1.791]


Val Loss: 0.0649
Val Metrics:
PESQ:   1.7912
STOI:   0.8896
eSTOI:  0.7329
SI-SDR: 15.43 dB
Saved best model with PESQ: 1.7912

Epoch 2/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:42<00:00,  2.03it/s, loss=0.0471]


Train Loss: 0.0572
Time: 0.0079, Mag: 0.0379, Phase: 0.4403


Validation: 100%|██████████| 824/824 [03:09<00:00,  4.35it/s, PESQ=2.000]


Val Loss: 0.0634
Val Metrics:
PESQ:   2.0001
STOI:   0.9024
eSTOI:  0.7687
SI-SDR: 17.27 dB
Saved best model with PESQ: 2.0001

Epoch 3/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:43<00:00,  2.03it/s, loss=0.0664]


Train Loss: 0.0557
Time: 0.0074, Mag: 0.0327, Phase: 0.4397


Validation: 100%|██████████| 824/824 [03:06<00:00,  4.41it/s, PESQ=2.043]


Val Loss: 0.0633
Val Metrics:
PESQ:   2.0430
STOI:   0.9039
eSTOI:  0.7690
SI-SDR: 15.95 dB
Saved best model with PESQ: 2.0430

Epoch 4/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:42<00:00,  2.03it/s, loss=0.0674]


Train Loss: 0.0549
Time: 0.0071, Mag: 0.0300, Phase: 0.4390


Validation: 100%|██████████| 824/824 [03:08<00:00,  4.38it/s, PESQ=1.900]


Val Loss: 0.0671
Val Metrics:
PESQ:   1.8998
STOI:   0.8942
eSTOI:  0.7405
SI-SDR: 15.22 dB

Epoch 5/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:43<00:00,  2.03it/s, loss=0.0631]


Train Loss: 0.0544
Time: 0.0070, Mag: 0.0284, Phase: 0.4385


Validation: 100%|██████████| 824/824 [03:08<00:00,  4.36it/s, PESQ=2.176]


Val Loss: 0.0619
Val Metrics:
PESQ:   2.1756
STOI:   0.9178
eSTOI:  0.7964
SI-SDR: 17.92 dB
Saved best model with PESQ: 2.1756

Epoch 6/10
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:42<00:00,  2.03it/s, loss=0.0667]


Train Loss: 0.0537
Time: 0.0067, Mag: 0.0262, Phase: 0.4379


Validation: 100%|██████████| 824/824 [03:05<00:00,  4.44it/s, PESQ=2.243]


Val Loss: 0.0620
Val Metrics:
PESQ:   2.2428
STOI:   0.9202
eSTOI:  0.7971
SI-SDR: 17.37 dB
Saved best model with PESQ: 2.2428

Epoch 7/10
----------------------------------------------------------------------


Training:  25%|██▍       | 711/2893 [05:50<17:54,  2.03it/s, loss=0.0651]


KeyboardInterrupt: 

In [31]:
def plot_spectrograms(noisy_wav, enhanced_wav, clean_wav, sr=16000, save_path=None,
                      n_fft=512, hop_length=128, win_length=512):
    """
    Show noisy, enhanced and clean log-magnitude spectrograms stacked vertically.

    Args:
        noisy_wav, enhanced_wav, clean_wav: 1-D waveforms (NumPy arrays or torch tensors).
        sr: sample rate used for the time / frequency axes.
        save_path: if given, the figure is also written there at 300 dpi.
        n_fft, hop_length, win_length: STFT settings (defaults match ``Config``).

    Each panel is in dB relative to its own maximum, so colours are not directly
    comparable across panels.
    """
    fig, axes = plt.subplots(3, 1, figsize=(15, 12))

    panels = (("Noisy", noisy_wav), ("Enhanced", enhanced_wav), ("Clean", clean_wav))
    for ax, (title, wav) in zip(axes, panels):
        if isinstance(wav, torch.Tensor):
            # detach(): tensors that require grad cannot be converted with .numpy().
            wav = wav.detach().cpu().numpy()

        D = librosa.stft(wav, n_fft=n_fft, hop_length=hop_length, win_length=win_length)
        D_db = librosa.amplitude_to_db(np.abs(D), ref=np.max)

        img = librosa.display.specshow(
            D_db,
            sr=sr,
            hop_length=hop_length,
            x_axis="time",
            y_axis="hz",
            ax=ax,
            cmap="viridis",
        )
        ax.set_title(f"{title} Spectrogram", fontsize=14)
        ax.set_ylabel("Frequency (Hz)")
        fig.colorbar(img, ax=ax, format="%+2.0f dB")

    axes[-1].set_xlabel("Time (s)")
    fig.tight_layout()

    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches="tight")
    plt.show()
    # Called once per sample in evaluation loops; free the figure explicitly.
    plt.close(fig)

In [32]:
def evaluate_test_set(model, test_loader, num_samples=5):
    """Score every test file with the URGENT metric set and save the artefacts.

    Enhanced WAVs go to ``<results_path>/audio_samples``. For the first
    ``num_samples`` files, copies of the noisy and clean originals and a
    spectrogram PNG are saved alongside. Per-file metrics are written to
    ``<results_path>/urgent_test_results.csv`` and summarised on stdout.

    Assumes ``test_loader`` yields one utterance per batch (``batch_size=1``).

    Returns:
        DataFrame with one row per file: ``filename``, ``PESQ``, ``STOI``,
        ``eSTOI``, ``SI-SDR`` (empty if the loader yielded nothing).
    """
    target_pesq = 3.25  # experiment goal for wide-band PESQ
    model.eval()

    results = []
    sample_count = 0

    enhanced_dir = Path(config.results_path) / "audio_samples"
    enhanced_dir.mkdir(parents=True, exist_ok=True)

    print("Evaluating test set with URGENT Challenge metrics...")
    print("="*70)

    with torch.no_grad():
        for batch in tqdm(test_loader):
            noisy = batch['noisy'].to(device)
            clean = batch['clean'].to(device)
            filename = batch['filename'][0]
            clean_path = batch['clean_path'][0]
            noisy_path = batch['noisy_path'][0]

            enhanced_wav, _, _, _, _ = model(noisy)

            enhanced_path = enhanced_dir / filename
            torchaudio.save(
                str(enhanced_path),
                enhanced_wav[0].cpu().unsqueeze(0),
                config.sample_rate
            )

            metrics = URGENTMetrics.compute_all_metrics(
                clean_path,
                str(enhanced_path),
                sr=config.sample_rate
            )
            results.append({'filename': filename, **metrics})

            if sample_count < num_samples:
                stem = Path(filename).stem
                shutil.copy(noisy_path, enhanced_dir / f"{stem}_noisy.wav")
                shutil.copy(clean_path, enhanced_dir / f"{stem}_clean.wav")

                plot_spectrograms(
                    noisy[0].cpu(),
                    enhanced_wav[0].cpu(),
                    clean[0].cpu(),
                    sr=config.sample_rate,
                    save_path=str(enhanced_dir / f"{stem}_spectrogram.png")
                )

                sample_count += 1

    results_df = pd.DataFrame(results)
    results_df.to_csv(f"{config.results_path}/urgent_test_results.csv", index=False)

    if results_df.empty:
        # Without rows there are no metric columns to summarise.
        print("No test files were evaluated.")
        return results_df

    passing = (results_df['PESQ'] >= target_pesq).sum()

    print("\n" + "="*70)
    print("URGENT CHALLENGE METRICS - Test Set Results")
    # The blocks are simplified stand-ins (see xLSTMBlock / MambaBlock), not official implementations.
    print("Model: MP-SENet with simplified xLSTM + Mamba blocks")
    print("="*70)
    print(f"PESQ:   {results_df['PESQ'].mean():.4f} ± {results_df['PESQ'].std():.4f}  (Target: ≥{target_pesq})")
    print(f"STOI:   {results_df['STOI'].mean():.4f} ± {results_df['STOI'].std():.4f}")
    print(f"eSTOI:  {results_df['eSTOI'].mean():.4f} ± {results_df['eSTOI'].std():.4f}")
    print(f"SI-SDR: {results_df['SI-SDR'].mean():.2f} ± {results_df['SI-SDR'].std():.2f} dB")
    print("="*70)
    print(f"Files passing PESQ ≥ {target_pesq}: {passing}/{len(results_df)} "
          f"({passing / len(results_df) * 100:.1f}%)")
    print(f"Best PESQ: {results_df['PESQ'].max():.4f}")
    print(f"Worst PESQ: {results_df['PESQ'].min():.4f}")
    print(f"Median PESQ: {results_df['PESQ'].median():.4f}")
    print("="*70)

    return results_df

In [33]:
def plot_training_history(history_df):
    """Plot per-epoch training loss and validation metrics in a 2x3 grid.

    Args:
        history_df: either the ``history`` dict returned by ``train_model`` or the
            DataFrame read back from ``training_history.csv``. It needs the columns
            ``train_loss``, ``val_loss``, ``val_pesq``, ``val_stoi``, ``val_estoi``
            and ``val_sisdr``.

    The figure is saved to ``<results_path>/training_curves.png`` and shown.
    """
    # Accept the raw history dict too: len() of a dict counts keys, not epochs.
    history_df = pd.DataFrame(history_df)
    # 1-based epochs, matching the "Epoch k/N" lines printed during training.
    epochs = np.arange(1, len(history_df) + 1)
    target_pesq = 3.25

    fig, axes = plt.subplots(2, 3, figsize=(18, 10))
    fig.suptitle('MP-SENet Training History (xLSTM + Mamba)', fontsize=16, fontweight='bold')

    def style_axis(ax, ylabel, title):
        """Apply the labels, legend and grid shared by every panel."""
        ax.set_xlabel('Epoch')
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    # Loss
    ax = axes[0, 0]
    ax.plot(epochs, history_df['train_loss'], label='Train Loss', linewidth=2)
    ax.plot(epochs, history_df['val_loss'], label='Val Loss', linewidth=2)
    style_axis(ax, 'Loss', 'Training and Validation Loss')

    # PESQ against the target; shade epochs at or above it.
    pesq_values = history_df['val_pesq'].to_numpy()
    ax = axes[0, 1]
    ax.plot(epochs, pesq_values, label='PESQ', color='green', linewidth=2)
    ax.axhline(y=target_pesq, color='r', linestyle='--', label=f'Target ({target_pesq})', linewidth=2)
    ax.fill_between(epochs, target_pesq, pesq_values,
                    where=(pesq_values >= target_pesq), alpha=0.3, color='green')
    style_axis(ax, 'PESQ', 'Validation PESQ')

    # Single-metric panels
    for ax, column, label, color, ylabel in (
        (axes[0, 2], 'val_stoi', 'STOI', 'orange', 'STOI'),
        (axes[1, 0], 'val_estoi', 'eSTOI', 'purple', 'eSTOI'),
        (axes[1, 1], 'val_sisdr', 'SI-SDR', 'brown', 'SI-SDR (dB)'),
    ):
        ax.plot(epochs, history_df[column], label=label, color=color, linewidth=2)
        style_axis(ax, ylabel, f'Validation {label}')

    # Bounded metrics on a shared scale; PESQ mapped with (x - 1) / 3.5,
    # i.e. roughly [1, 4.5] -> [0, 1].
    ax = axes[1, 2]
    ax.plot(epochs, (pesq_values - 1) / 3.5, label='PESQ (norm)', linewidth=2)
    ax.plot(epochs, history_df['val_stoi'], label='STOI', linewidth=2)
    ax.plot(epochs, history_df['val_estoi'], label='eSTOI', linewidth=2)
    style_axis(ax, 'Score (normalized)', 'All Metrics Comparison')

    fig.tight_layout()
    fig.savefig(f"{config.results_path}/training_curves.png", dpi=300, bbox_inches='tight')
    plt.show()
    plt.close(fig)

In [34]:
def analyze_model_performance(model, duration_s=4.0, num_runs=100, warmup_runs=10):
    """Print parameter counts, approximate weight size and inference speed of ``model``.

    Speed is measured on the notebook-global ``device`` by running ``model``
    ``num_runs`` times on one random waveform of ``duration_s`` seconds at
    ``config.sample_rate``, after ``warmup_runs`` untimed passes.

    Args:
        model: module accepting a ``(1, num_samples)`` waveform batch.
        duration_s: length of the synthetic input in seconds.
        num_runs: number of timed forward passes to average over (>= 1).
        warmup_runs: number of untimed passes run first.

    Raises:
        ValueError: if ``num_runs`` < 1.
    """
    if num_runs < 1:
        raise ValueError("num_runs must be >= 1")

    def sync():
        # CUDA launches are asynchronous; wait for queued kernels so wall-clock timings are real.
        if device.type == 'cuda':
            torch.cuda.synchronize()

    total_params = sum(p.numel() for p in model.parameters())
    trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

    print("\n" + "="*70)
    print("MODEL ARCHITECTURE ANALYSIS")
    print("="*70)
    print("Architecture: MP-SENet with xLSTM + Mamba")
    print(f"\nTotal parameters: {total_params / 1e6:.2f}M")
    print(f"Trainable parameters: {trainable_params / 1e6:.2f}M")
    # Size estimates assume 4 / 2 bytes per parameter.
    print(f"Model size (FP32): {total_params * 4 / 1024**2:.2f} MB")
    print(f"Model size (FP16): {total_params * 2 / 1024**2:.2f} MB")

    model.eval()
    test_input = torch.randn(1, int(config.sample_rate * duration_s)).to(device)

    with torch.no_grad():
        for _ in range(warmup_runs):
            model(test_input)
        # Drain warm-up work first; otherwise it is billed to the timed loop.
        sync()
        start_time = time.perf_counter()
        for _ in range(num_runs):
            model(test_input)
        sync()
        elapsed = time.perf_counter() - start_time

    avg_time = elapsed / num_runs
    rtf = avg_time / duration_s  # Real-Time Factor: processing time / audio duration

    print("\nINFERENCE SPEED ANALYSIS")
    print(f"Average inference time: {avg_time*1000:.2f} ms (for {duration_s:g}s audio)")
    print(f"Throughput: {duration_s/avg_time:.2f} seconds audio per second")
    print(f"Real-Time Factor (RTF): {rtf:.4f}")
    print(f"Can process in real-time: {'Yes' if rtf < 1.0 else 'No'}")

    if rtf < 1.0:
        print(f"\n  Processing speed: {1/rtf:.2f}x faster than real-time")
    else:
        print(f"\n  Processing speed: {rtf:.2f}x slower than real-time")

    print("="*70)

In [35]:
def _plot_metric_histogram(ax, values, metric, xlabel, color, mean_fmt, target=None):
    """Draw a 20-bin histogram of one metric with a dashed mean line and optional target line.

    Args:
        ax: matplotlib Axes to draw on.
        values: 1-D array of scores for one metric, already stripped of NaN.
        metric: metric name, used in the panel title.
        xlabel: x-axis label.
        color: bar colour.
        mean_fmt: ``str.format`` template for the mean line's legend label.
        target: optional reference value drawn as a blue dashed line.
    """
    ax.hist(values, bins=20, edgecolor='black', alpha=0.7, color=color)
    if len(values) > 0:  # the mean of an empty column is NaN; skip the line instead
        mean_value = values.mean()
        ax.axvline(mean_value, color='r', linestyle='--', linewidth=2,
                   label=mean_fmt.format(mean_value))
    if target is not None:
        ax.axvline(target, color='b', linestyle='--', linewidth=2, label=f'Target: {target}')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Frequency')
    ax.set_title(f'{metric} Distribution')
    if ax.get_legend_handles_labels()[0]:  # avoid "no artists with labels" warning
        ax.legend()
    ax.grid(True, alpha=0.3)


def _plot_metric_scatter(ax, results_df, x_col, y_col, ylabel, color):
    """Scatter two metric columns against each other, using only rows where both are present."""
    pairs = results_df[[x_col, y_col]].dropna()
    ax.scatter(pairs[x_col], pairs[y_col], alpha=0.6, c=color)
    ax.set_xlabel(x_col)
    ax.set_ylabel(ylabel)
    ax.set_title(f'{x_col} vs {y_col} Correlation')
    ax.grid(True, alpha=0.3)


def comparative_analysis(results_df, save_dir=None):
    """Plot distributions, pairwise correlations and box plots of per-file test metrics.

    Saves the figure as ``metrics_analysis.png``, shows it, and prints
    ``describe()`` statistics for the metric columns.

    Args:
        results_df: DataFrame with at least the columns ``PESQ``, ``STOI``, ``eSTOI``
            and ``SI-SDR`` (SI-SDR in dB), one row per evaluated file. Missing (NaN)
            values are ignored per metric.
        save_dir: directory for the PNG; defaults to ``config.results_path`` and is
            created if it does not exist.
    """
    metric_cols = ['PESQ', 'STOI', 'eSTOI', 'SI-SDR']
    save_dir = Path(config.results_path if save_dir is None else save_dir)

    # Per-metric NaN-free arrays. A single NaN would turn an entire box in the box
    # plot into NaN, so failed per-file metrics are dropped (and reported) instead.
    metric_values = {col: results_df[col].dropna().to_numpy() for col in metric_cols}
    n_missing = {col: int(results_df[col].isna().sum()) for col in metric_cols}
    if any(n_missing.values()):
        print(f"Warning: ignoring missing metric values per column: {n_missing}")

    fig = plt.figure(figsize=(16, 12))
    gs = fig.add_gridspec(3, 3, hspace=0.3, wspace=0.3)
    fig.suptitle('URGENT Challenge Metrics Analysis (xLSTM + Mamba)',
                 fontsize=16, fontweight='bold')

    # Row 1-2: per-metric histograms (PESQ also shows the 3.25 target).
    _plot_metric_histogram(fig.add_subplot(gs[0, 0]), metric_values['PESQ'], 'PESQ',
                           'PESQ Score', 'green', 'Mean: {:.4f}', target=3.25)
    _plot_metric_histogram(fig.add_subplot(gs[0, 1]), metric_values['STOI'], 'STOI',
                           'STOI Score', 'orange', 'Mean: {:.4f}')
    _plot_metric_histogram(fig.add_subplot(gs[0, 2]), metric_values['eSTOI'], 'eSTOI',
                           'eSTOI Score', 'purple', 'Mean: {:.4f}')
    _plot_metric_histogram(fig.add_subplot(gs[1, 0]), metric_values['SI-SDR'], 'SI-SDR',
                           'SI-SDR (dB)', 'brown', 'Mean: {:.2f} dB')

    # Row 2: pairwise correlations against PESQ.
    _plot_metric_scatter(fig.add_subplot(gs[1, 1]), results_df, 'PESQ', 'STOI', 'STOI', 'blue')
    _plot_metric_scatter(fig.add_subplot(gs[1, 2]), results_df, 'PESQ', 'SI-SDR',
                         'SI-SDR (dB)', 'red')

    # Row 3: box plots. SI-SDR (dB) is divided by 10 so it shares a visual scale.
    ax7 = fig.add_subplot(gs[2, :])
    box_data = [metric_values['PESQ'], metric_values['STOI'],
                metric_values['eSTOI'], metric_values['SI-SDR'] / 10]
    # Tick labels are set separately because boxplot's `labels=` keyword was renamed
    # to `tick_labels=` in Matplotlib 3.9; this form works on old and new versions.
    bp = ax7.boxplot(box_data, patch_artist=True, showmeans=True)
    ax7.set_xticks(range(1, len(box_data) + 1))
    ax7.set_xticklabels(['PESQ', 'STOI', 'eSTOI', 'SI-SDR/10'])
    for patch, color in zip(bp['boxes'], ['green', 'orange', 'purple', 'brown']):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)
    ax7.set_ylabel('Score')
    ax7.set_title('Metrics Box Plot Comparison')
    ax7.grid(True, alpha=0.3, axis='y')
    ax7.axhline(y=3.25, color='r', linestyle='--', alpha=0.5, label='PESQ Target')
    ax7.legend()

    save_dir.mkdir(parents=True, exist_ok=True)
    fig.savefig(save_dir / "metrics_analysis.png", dpi=300, bbox_inches='tight')
    plt.show()

    # Detailed statistics (describe() already skips NaN; `count` shows valid rows).
    print("\nDETAILED STATISTICS")
    print("="*70)
    print(results_df[metric_cols].describe())
    print("="*70)

In [23]:
# Train the model from scratch; `train_model` returns the trained model and its history.
# NOTE(review): in the saved run this cell crashed at the first validation step with
#   RuntimeError: The size of tensor a (27776) must match the size of tensor b (27861)
# right after the warning for `time_loss = F.l1_loss(pred_wav, clean_wav)`, i.e. the
# predicted waveform appears to come back shorter than the clean reference. 27776 is a
# multiple of hop_length (128), so the iSTFT output presumably needs to be trimmed or
# padded to the input length; TODO confirm in the model / loss code.
separator = "=" * 70
for banner_line in (separator, "MP-SENet with xLSTM+Mamba for Speech Enhancement", separator):
    print(banner_line)

print("\n[1] Training model...")
model, history = train_model()

MP-SENet with xLSTM+Mamba for Speech Enhancement

[1] Training model...
Loading datasets...
Loaded 11572 file pairs for train mode
Loaded 824 file pairs for test mode

Initializing MP-SENet model...
Total parameters: 6.28M

Starting training...

Epoch 1/1
----------------------------------------------------------------------


Training: 100%|██████████| 2893/2893 [23:44<00:00,  2.03it/s, loss=0.1668]


Train Loss: 0.1605
Time: 0.0232, Mag: 0.0720, Phase: 1.2979


  time_loss = F.l1_loss(pred_wav, clean_wav)
Validation:   0%|          | 0/824 [00:00<?, ?it/s]


RuntimeError: The size of tensor a (27776) must match the size of tensor b (27861) at non-singleton dimension 1

In [None]:
# Post-training analysis: plot history, reload the best checkpoint, benchmark, evaluate.
# This cell needs `model` from the training cell; fail with a clear message instead of
# a bare NameError when that cell was skipped or raised before assigning it.
if "model" not in globals():
    raise RuntimeError(
        "`model` is not defined - run the training cell successfully before this cell."
    )

# Training history visualization
print("\n[2] Plotting training history...")
history_df = pd.read_csv(f"{config.results_path}/training_history.csv")
plot_training_history(history_df)

# Load the best model. map_location lets a checkpoint saved on GPU be restored on a
# CPU-only machine (and vice versa).
# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True; if the
# checkpoint stores non-tensor objects this call may need adjusting - confirm.
print("\n[3] Loading best model...")
checkpoint = torch.load(f"{config.checkpoint_path}/best_model.pt", map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
model.to(device)

# Performance analysis (size and inference speed)
print("\n[4] Analyzing model performance...")
analyze_model_performance(model)

# Evaluation on the test split. mode='test' pairs clean/noisy files by file name,
# as the training run did for this split, rather than zipping two sorted file lists.
print("\n[5] Evaluating on test set...")
test_dataset = VoiceBankDataset(
    config.test_clean_path, config.test_noisy_path, max_len=None, mode='test'
)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False)
results_df = evaluate_test_set(model, test_loader, num_samples=5)

# Comparative metrics analysis
print("\n[6] Performing comparative analysis...")
comparative_analysis(results_df)

print("\n" + "=" * 70)
print("All tasks completed successfully!")
print(f"Results saved to: {config.results_path}")
print("=" * 70)