# Hybrid Vision Transformer for Speech Enhancement - Complete Demo

**Just click 'Run All' and wait for the complete demonstration!**

This notebook will:
1. Install all dependencies
2. Download the dataset automatically (or generate a small synthetic dataset if the download fails)
3. Define the complete model architecture
4. Train the model
5. Evaluate performance
6. Demonstrate audio enhancement

**Estimated time**: 30-60 minutes (depending on GPU availability)

## Step 1: Install Dependencies

In [None]:
%%capture
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install librosa soundfile matplotlib seaborn tqdm einops
!pip install kagglehub

print("✓ All dependencies installed!")

## Step 2: Import Libraries and Setup

In [None]:
import os
import sys
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import librosa
import soundfile as sf
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from IPython.display import Audio, display
from tqdm.notebook import tqdm
from einops import rearrange
import warnings
warnings.filterwarnings('ignore')

# Plot appearance shared by every figure below
plt.style.use('seaborn-v0_8-whitegrid')
sns.set_palette('husl')

# Select the compute device: CUDA when PyTorch can see a GPU, CPU otherwise
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

print(f"PyTorch version: {torch.__version__}")
print(f"Device: {device}")
if device == 'cuda':
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    gpu_total_bytes = torch.cuda.get_device_properties(0).total_memory
    print(f"GPU Memory: {gpu_total_bytes / 1e9:.2f} GB")

# Seed the PyTorch, NumPy and (when present) CUDA random generators
torch.manual_seed(42)
np.random.seed(42)
if device == 'cuda':
    torch.cuda.manual_seed_all(42)

print("\n✓ Environment setup complete!")

## Step 3: Download Dataset Automatically

In [None]:
import kagglehub
import shutil

data_root = Path('demo_data')
noisy_dir = data_root / 'noisy'
clean_dir = data_root / 'clean'
# parents=True also creates data_root itself when it does not exist yet
for folder in (noisy_dir, clean_dir):
    folder.mkdir(parents=True, exist_ok=True)


def _download_first_available(dataset_names):
    """Return the local path of the first Kaggle dataset that downloads, or None."""
    for dataset_name in dataset_names:
        try:
            print(f"  Trying {dataset_name}...")
            dataset_path = Path(kagglehub.dataset_download(dataset_name))
        except Exception as err:
            # Report the failure and try the next candidate. Catching Exception
            # (not a bare `except`) lets KeyboardInterrupt still stop the cell.
            print(f"  ✗ {dataset_name} failed: {err}")
            continue
        print(f"  ✓ Downloaded to: {dataset_path}")
        return dataset_path
    return None


def _match_noisy_clean(dataset_path, audio_files):
    """Pair noisy and clean recordings that share the same file name.

    A file counts as noisy/clean when 'noisy'/'clean' occurs in its path
    *relative to* ``dataset_path``. The download location is excluded on
    purpose: it can itself contain one of the keywords (for example
    'valentini-noisy-speech-database'), which would label every file noisy.

    Returns:
        List of ``(noisy_path, clean_path)`` tuples ordered by file name.
    """
    noisy_by_name, clean_by_name = {}, {}
    for wav in sorted(audio_files):
        rel = wav.relative_to(dataset_path).as_posix().lower()
        is_noisy, is_clean = 'noisy' in rel, 'clean' in rel
        if is_noisy == is_clean:
            continue  # unlabelled, or ambiguous (both keywords present)
        bucket = noisy_by_name if is_noisy else clean_by_name
        bucket.setdefault(wav.name, wav)  # keep the first file seen per name

    shared_names = sorted(noisy_by_name.keys() & clean_by_name.keys())
    return [(noisy_by_name[name], clean_by_name[name]) for name in shared_names]


def _prepare_kaggle_pairs(max_pairs=50):
    """Download a Kaggle dataset and copy up to ``max_pairs`` matched pairs.

    Returns:
        Number of pairs copied into ``noisy_dir`` / ``clean_dir``.

    Raises:
        RuntimeError: If nothing could be downloaded, no .wav file was found,
            or no noisy/clean files could be matched by name.
    """
    dataset_path = _download_first_available([
        "saurabhshahane/valentini-noisy-speech-database",
        "muhammadtayyab007/speech-enhancement-dataset",
    ])
    if dataset_path is None:
        raise RuntimeError("Could not download")

    audio_files = list(dataset_path.rglob('*.wav'))
    if not audio_files:
        raise RuntimeError("No audio files found")
    print(f"\n✓ Found {len(audio_files)} audio files!")

    pairs = _match_noisy_clean(dataset_path, audio_files)[:max_pairs]
    if not pairs:
        # Without this check an empty dataset would silently reach training.
        raise RuntimeError("No matching noisy/clean pairs found")

    for i, (noisy_src, clean_src) in enumerate(pairs):
        shutil.copy(noisy_src, noisy_dir / f'sample_{i:03d}.wav')
        shutil.copy(clean_src, clean_dir / f'sample_{i:03d}.wav')
    return len(pairs)


def _write_synthetic_pairs(num_samples=30, sr=16000, duration=2.0):
    """Write ``num_samples`` synthetic clean/noisy pairs as .wav files.

    Each clean signal is a stack of 7 harmonics of a random fundamental with a
    5 Hz amplitude envelope, peak-normalised to 0.8. The noisy signal is the
    clean one plus Gaussian noise with standard deviation 0.15.
    """
    print(f"Generating {num_samples} synthetic samples...")
    # Sample instants spaced exactly 1/sr apart (linspace with its default
    # endpoint=True would stretch the grid by one sample).
    t = np.arange(int(sr * duration)) / sr
    for i in tqdm(range(num_samples), desc="Creating samples"):
        fundamental = 200 + np.random.randn() * 50

        clean = np.zeros_like(t)
        for harmonic in range(1, 8):
            amplitude = 1.0 / harmonic
            phase = np.random.rand() * 2 * np.pi
            clean += amplitude * np.sin(2 * np.pi * fundamental * harmonic * t + phase)

        envelope = 0.5 + 0.5 * np.sin(2 * np.pi * 5 * t)
        clean = clean * envelope
        clean = clean / np.abs(clean).max() * 0.8

        noise = np.random.randn(len(clean)) * 0.15
        noisy = clean + noise

        sf.write(clean_dir / f'sample_{i:03d}.wav', clean, sr)
        sf.write(noisy_dir / f'sample_{i:03d}.wav', noisy, sr)

    print(f"\n✓ Created {num_samples} synthetic audio pairs")


try:
    print("Attempting to download dataset from Kaggle...")
    num_pairs = _prepare_kaggle_pairs()
    print(f"✓ Prepared {num_pairs} audio pairs")
except Exception as e:
    # Any failure on the Kaggle path falls back to locally generated data so
    # the rest of the notebook can still run.
    print(f"\nKaggle download failed: {e}")
    print("Creating synthetic dataset...\n")
    _write_synthetic_pairs()

noisy_files = sorted(list((data_root / 'noisy').glob('*.wav')))
clean_files = sorted(list((data_root / 'clean').glob('*.wav')))

print(f"\n{'='*60}")
print(f"Dataset Summary:")
print(f"  Noisy files: {len(noisy_files)}")
print(f"  Clean files: {len(clean_files)}")
print(f"{'='*60}")

## Step 4: Define Hybrid Vision Transformer Architecture

In [None]:
# Model Components

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention over a token sequence.

    Args:
        embed_dim: Size of each token embedding (last dimension of the input).
        num_heads: Number of attention heads; must evenly divide ``embed_dim``.
        dropout: Dropout probability applied both to the attention weights and
            to the projected output.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim`` (the per-head reshape in ``forward`` would otherwise
            fail with an unrelated-looking shape error).
    """

    def __init__(self, embed_dim, num_heads, dropout=0.1):
        super().__init__()
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        self.scale = self.head_dim ** -0.5  # 1 / sqrt(head_dim)

        # A single linear layer produces queries, keys and values together
        self.qkv = nn.Linear(embed_dim, embed_dim * 3)
        self.proj = nn.Linear(embed_dim, embed_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply self-attention to ``x`` of shape (B, N, C); returns the same shape."""
        B, N, C = x.shape
        # (B, N, 3*C) -> (B, N, 3, heads, head_dim) -> (3, B, heads, N, head_dim)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        # Attention weights per head: (B, heads, N, N), normalised over the keys
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.dropout(attn)

        # Weighted sum of values, then merge the heads back into C channels
        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.dropout(x)
        return x


class FeedForward(nn.Module):
    """Two-layer MLP: Linear -> GELU -> Dropout -> Linear -> Dropout.

    The first linear layer maps ``dim`` to ``hidden_dim`` and the second maps
    back to ``dim``, so the output has the same last dimension as the input.
    """

    def __init__(self, dim, hidden_dim, dropout=0.1):
        super().__init__()
        expand = nn.Linear(dim, hidden_dim)
        contract = nn.Linear(hidden_dim, dim)
        stages = [
            expand,
            nn.GELU(),
            nn.Dropout(dropout),
            contract,
            nn.Dropout(dropout),
        ]
        self.net = nn.Sequential(*stages)

    def forward(self, x):
        out = self.net(x)
        return out


class TransformerBlock(nn.Module):
    """Pre-norm transformer layer: attention and MLP, each with a residual add.

    LayerNorm is applied to the input of each sub-layer, and the sub-layer's
    output is added back onto the un-normalised stream.
    """

    def __init__(self, embed_dim, num_heads, mlp_ratio=4.0, dropout=0.1):
        super().__init__()
        hidden_dim = int(embed_dim * mlp_ratio)
        self.norm1 = nn.LayerNorm(embed_dim)
        self.attn = MultiHeadAttention(embed_dim, num_heads, dropout)
        self.norm2 = nn.LayerNorm(embed_dim)
        self.mlp = FeedForward(embed_dim, hidden_dim, dropout)

    def forward(self, x):
        attended = self.attn(self.norm1(x))
        x = x + attended
        transformed = self.mlp(self.norm2(x))
        return x + transformed


class HybridViT(nn.Module):
    """Hybrid Vision Transformer for Speech Enhancement.

    Data flow (see ``forward``): convolutional encoder -> strided-convolution
    patch embedding -> transformer blocks -> linear projection back to feature
    channels -> convolutional decoder ending in a sigmoid -> bilinear resize to
    the input's spatial size.

    Args:
        input_channels: Channels of the input image/spectrogram.
        output_channels: Channels of the output map.
        encoder_channels: Output width of each encoder stage. Every stage but
            the last halves the spatial size with max pooling; the decoder
            mirrors these widths in reverse with one 2x upsampling per pooling.
        embed_dim: Transformer embedding size.
        num_heads: Attention heads per transformer block.
        num_layers: Number of transformer blocks.
        mlp_ratio: Hidden/embedding size ratio of each block's MLP.
        patch_size: Kernel size and stride of the patch-embedding convolution.
        dropout: Dropout probability passed to the transformer blocks.

    Raises:
        ValueError: If ``encoder_channels`` is empty.
    """

    def __init__(
        self,
        input_channels=1,
        output_channels=1,
        encoder_channels=(32, 64, 128),
        embed_dim=256,
        num_heads=4,
        num_layers=3,
        mlp_ratio=4.0,
        patch_size=4,
        dropout=0.1,
    ):
        super().__init__()

        # Accept any sequence (list or tuple) without sharing a mutable default
        encoder_channels = list(encoder_channels)
        if not encoder_channels:
            raise ValueError("encoder_channels must contain at least one entry")

        # CNN Encoder
        self.encoder = nn.ModuleList()
        last_stage = len(encoder_channels) - 1
        in_ch = input_channels
        for stage, out_ch in enumerate(encoder_channels):
            self.encoder.append(nn.Sequential(
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
                # Decide by stage position, not channel value, so that repeated
                # widths (e.g. [64, 64]) still pool on every stage but the last.
                nn.MaxPool2d(2) if stage != last_stage else nn.Identity()
            ))
            in_ch = out_ch

        # Patch Embedding
        self.patch_embed = nn.Conv2d(encoder_channels[-1], embed_dim, patch_size, patch_size)

        # Positional Encoding: learnable table with 1000 positions, zero-initialised
        self.pos_embed = nn.Parameter(torch.zeros(1, 1000, embed_dim))

        # Transformer
        self.transformer = nn.ModuleList([
            TransformerBlock(embed_dim, num_heads, mlp_ratio, dropout)
            for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(embed_dim)

        # Projection back
        self.to_features = nn.Linear(embed_dim, encoder_channels[-1])

        # CNN Decoder: mirror of the encoder widths. For the default
        # [32, 64, 128] this yields 128->64, 64->32, 32->output_channels.
        decoder_channels = encoder_channels[::-1]
        decoder_stages = [
            nn.Sequential(
                nn.Upsample(scale_factor=2, mode='nearest'),
                nn.Conv2d(dec_in, dec_out, 3, padding=1),
                nn.BatchNorm2d(dec_out),
                nn.ReLU(inplace=True)
            )
            for dec_in, dec_out in zip(decoder_channels[:-1], decoder_channels[1:])
        ]
        decoder_stages.append(nn.Sequential(
            nn.Conv2d(decoder_channels[-1], output_channels, 3, padding=1),
            nn.Sigmoid()
        ))
        self.decoder = nn.ModuleList(decoder_stages)

    def _positional_encoding(self, num_tokens):
        """Return positional embeddings of shape (1, num_tokens, embed_dim).

        Sequences that fit in the table use its first ``num_tokens`` rows.
        Longer sequences get the whole table linearly interpolated along the
        position axis instead of failing on a shape mismatch.
        """
        if num_tokens <= self.pos_embed.size(1):
            return self.pos_embed[:, :num_tokens, :]
        stretched = F.interpolate(
            self.pos_embed.transpose(1, 2),  # (1, embed_dim, positions)
            size=num_tokens,
            mode='linear',
            align_corners=False,
        )
        return stretched.transpose(1, 2)

    def forward(self, x):
        """Map ``x`` of shape (B, input_channels, H, W) to (B, output_channels, H, W)."""
        # Save input shape
        input_shape = x.shape[2:]

        # Encoder
        for block in self.encoder:
            x = block(x)

        # Patch Embedding
        x = self.patch_embed(x)  # [B, embed_dim, H', W']
        B, C, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)  # [B, N, embed_dim]

        # Add positional encoding
        x = x + self._positional_encoding(x.size(1))

        # Transformer
        for block in self.transformer:
            x = block(x)

        x = self.norm(x)

        # Project back to feature channels and restore the 2-D patch grid
        x = self.to_features(x)
        x = x.transpose(1, 2).reshape(B, -1, H, W)

        # Decoder
        for block in self.decoder:
            x = block(x)

        # Resize to input shape
        if x.shape[2:] != input_shape:
            x = F.interpolate(x, size=input_shape, mode='bilinear', align_corners=False)

        return x

print("✓ Model architecture defined!")

## Step 5: Create Dataset Loader

In [None]:
from torch.utils.data import Dataset, DataLoader, random_split

class AudioDataset(Dataset):
    """Paired noisy/clean spectrograms loaded from two folders of .wav files.

    Files are paired by position after sorting each folder's ``*.wav`` paths,
    so both folders must contain the same number of files.

    Each item is a ``(noisy, clean)`` tuple of float tensors shaped
    (1, frequency_bins, frames) holding STFT magnitudes in dB, rescaled from
    [-80, 0] dB to [0, 1]. Every spectrogram is referenced to its own maximum
    (``ref=np.max``), so absolute signal level is not preserved.

    Args:
        noisy_dir: Folder with the noisy recordings.
        clean_dir: Folder with the matching clean recordings.
        sr: Sample rate the audio is resampled to on load.
        n_fft: STFT window size.
        hop_length: STFT hop size.

    Raises:
        ValueError: If the two folders hold a different number of .wav files.
    """

    def __init__(self, noisy_dir, clean_dir, sr=16000, n_fft=512, hop_length=128):
        self.noisy_files = sorted(list(Path(noisy_dir).glob('*.wav')))
        self.clean_files = sorted(list(Path(clean_dir).glob('*.wav')))
        if len(self.noisy_files) != len(self.clean_files):
            # Positional pairing would otherwise mis-pair files silently or
            # fail later with an IndexError in __getitem__.
            raise ValueError(
                f"Found {len(self.noisy_files)} noisy but "
                f"{len(self.clean_files)} clean .wav files; counts must match"
            )
        self.sr = sr
        self.n_fft = n_fft
        self.hop_length = hop_length

    def __len__(self):
        return len(self.noisy_files)

    def _normalized_spectrogram(self, audio):
        """Return the [0, 1]-normalised dB magnitude of ``audio`` as a (1, F, T) tensor."""
        stft = librosa.stft(audio, n_fft=self.n_fft, hop_length=self.hop_length)
        # Use log-scale normalization for better audio processing
        mag_db = librosa.amplitude_to_db(np.abs(stft), ref=np.max)
        # Normalize to [0, 1] range (assumes -80dB to 0dB range), then clip
        mag_norm = np.clip((mag_db + 80) / 80, 0, 1)
        return torch.from_numpy(mag_norm).float().unsqueeze(0)

    def __getitem__(self, idx):
        noisy, _ = librosa.load(self.noisy_files[idx], sr=self.sr)
        clean, _ = librosa.load(self.clean_files[idx], sr=self.sr)

        # Trim both signals to the shorter one so the spectrograms align
        min_len = min(len(noisy), len(clean))
        noisy = noisy[:min_len]
        clean = clean[:min_len]

        return self._normalized_spectrogram(noisy), self._normalized_spectrogram(clean)

def pad_collate(batch):
    """Collate (noisy, clean) spectrogram pairs whose frame counts may differ.

    Every spectrogram is zero-padded on the right of its last (time) axis up
    to the longest item in the batch, then the items are stacked. Zero is the
    bottom of the dataset's [0, 1] range. For batches whose items already
    share a length this is a plain stack.
    """
    max_frames = max(noisy.shape[-1] for noisy, _ in batch)

    def _pad(spec):
        return F.pad(spec, (0, max_frames - spec.shape[-1]))

    noisy_batch = torch.stack([_pad(noisy) for noisy, _ in batch])
    clean_batch = torch.stack([_pad(clean) for _, clean in batch])
    return noisy_batch, clean_batch


dataset = AudioDataset(data_root / 'noisy', data_root / 'clean')
train_size = int(0.8 * len(dataset))
val_size = len(dataset) - train_size
train_dataset, val_dataset = random_split(dataset, [train_size, val_size])

# The default collate function cannot stack recordings of different durations,
# so both loaders pad each batch to its longest spectrogram.
train_loader = DataLoader(train_dataset, batch_size=4, shuffle=True, num_workers=0,
                          collate_fn=pad_collate)
val_loader = DataLoader(val_dataset, batch_size=4, shuffle=False, num_workers=0,
                        collate_fn=pad_collate)

print(f"✓ Dataset: {len(train_dataset)} train, {len(val_dataset)} val")

## Step 6: Visualize Sample

In [None]:
sample_noisy, sample_clean = dataset[0]

fig, axes = plt.subplots(1, 2, figsize=(14, 4))
sample_panels = (
    (sample_noisy, 'Noisy Spectrogram'),
    (sample_clean, 'Clean Spectrogram'),
)
for ax, (spec, title) in zip(axes, sample_panels):
    # The dataset already returns dB values rescaled to [0, 1]. Undo only that
    # rescaling; running amplitude_to_db again would take the log of a log.
    spec_db = spec.squeeze(0).numpy() * 80 - 80
    im = ax.imshow(spec_db, aspect='auto', origin='lower', cmap='viridis')
    ax.set_title(title, fontweight='bold')
    ax.set_xlabel('Frame')
    ax.set_ylabel('Frequency bin')
    fig.colorbar(im, ax=ax, label='dB relative to peak')
plt.tight_layout()
plt.show()

## Step 7: Create Model

In [None]:
model = HybridViT(
    input_channels=1,
    output_channels=1,
    encoder_channels=[32, 64, 128],
    embed_dim=256,
    num_heads=4,
    num_layers=3,
    mlp_ratio=4.0,
    patch_size=4,
    dropout=0.1,
).to(device)

total_params = sum(p.numel() for p in model.parameters())
print(f"Model Parameters: {total_params:,}")

# Shape smoke test on random input. Run it in eval mode: in train mode the
# BatchNorm layers would fold this random batch into their running statistics
# even under no_grad. Switch back to train mode afterwards.
model.eval()
with torch.no_grad():
    test_input = torch.randn(1, 1, 257, 100).to(device)
    test_output = model(test_input)
model.train()

print(f"Test: {test_input.shape} → {test_output.shape}")
print("✓ Model ready!")

## Step 8: Train Model

In [None]:
num_epochs = 50
criterion = nn.L1Loss()
optimizer = torch.optim.AdamW(model.parameters(), lr=0.001, weight_decay=0.01)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=num_epochs)


def run_epoch(loader, train, desc=None):
    """Run one pass over ``loader`` and return the mean loss per sample.

    Args:
        loader: DataLoader yielding ``(noisy, clean)`` batches.
        train: When True the model is put in train mode and its weights are
            updated (with gradient-norm clipping at 1.0); when False the model
            is put in eval mode and gradients are disabled.
        desc: Progress-bar label (a bar is only shown when training).

    Raises:
        ValueError: If ``loader`` yields no batches.
    """
    model.train(train)
    batches = tqdm(loader, desc=desc, leave=False) if train else loader
    total_loss, num_samples = 0.0, 0

    with torch.set_grad_enabled(train):
        for noisy, clean in batches:
            noisy, clean = noisy.to(device), clean.to(device)
            if train:
                optimizer.zero_grad()
            loss = criterion(model(noisy), clean)
            if train:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                batches.set_postfix({'loss': f'{loss.item():.4f}'})
            # Weight by batch size so a smaller final batch does not skew the mean
            total_loss += loss.item() * noisy.size(0)
            num_samples += noisy.size(0)

    if num_samples == 0:
        raise ValueError("DataLoader produced no batches; the dataset split is empty")
    return total_loss / num_samples


train_losses = []
val_losses = []
best_val_loss = float('inf')

print(f"Starting training for {num_epochs} epochs...\n")

for epoch in range(num_epochs):
    train_loss = run_epoch(train_loader, train=True, desc=f'Epoch {epoch+1}/{num_epochs}')
    val_loss = run_epoch(val_loader, train=False)
    train_losses.append(train_loss)
    val_losses.append(val_loss)
    scheduler.step()

    print(f"Epoch {epoch+1}/{num_epochs} - Train: {train_loss:.4f}, Val: {val_loss:.4f}")

    # Keep the weights of the epoch with the lowest validation loss
    if val_loss < best_val_loss:
        best_val_loss = val_loss
        torch.save(model.state_dict(), 'best_model.pth')

print(f"\n✓ Training complete! Best val loss: {best_val_loss:.4f}")

## Step 9: Plot Training

In [None]:
# Loss curves for both splits on a single axis
fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(train_losses, label='Train', linewidth=2, marker='o', markersize=4)
ax.plot(val_losses, label='Val', linewidth=2, marker='s', markersize=4)
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Progress', fontweight='bold')
ax.legend()
ax.grid(True, alpha=0.3)
plt.show()

# Relative drop of the training loss from the first to the last epoch
first_loss, last_loss = train_losses[0], train_losses[-1]
relative_drop = (first_loss - last_loss) / first_loss * 100
print(f"Improvement: {relative_drop:.1f}%")

## Step 10: Test Enhancement

In [None]:
# map_location keeps the checkpoint loadable on the device selected at setup
model.load_state_dict(torch.load('best_model.pth', map_location=device))
model.eval()

# Demonstrate on a held-out recording: take the first validation index when the
# validation split is non-empty, otherwise fall back to the first file.
test_idx = val_dataset.indices[0] if len(val_dataset) > 0 else 0
test_noisy_file = dataset.noisy_files[test_idx]
test_clean_file = dataset.clean_files[test_idx]

noisy_audio, sr = librosa.load(test_noisy_file, sr=16000)
clean_audio, sr = librosa.load(test_clean_file, sr=16000)

# Enhance
noisy_stft = librosa.stft(noisy_audio, n_fft=512, hop_length=128)
noisy_mag = np.abs(noisy_stft)
noisy_phase = np.angle(noisy_stft)

# Apply same normalization as training
noisy_mag_db = librosa.amplitude_to_db(noisy_mag, ref=np.max)
noisy_mag_norm = (noisy_mag_db + 80) / 80
noisy_mag_norm = np.clip(noisy_mag_norm, 0, 1)

input_tensor = torch.from_numpy(noisy_mag_norm).float().unsqueeze(0).unsqueeze(0).to(device)

with torch.no_grad():
    # Index batch and channel explicitly; squeeze() would also drop a
    # frequency or time axis of length 1.
    enhanced_mag_norm = model(input_tensor)[0, 0].cpu().numpy()

# Denormalize: reverse the normalization.
# amplitude_to_db(ref=np.max) expressed the input relative to its peak, so that
# peak has to be supplied as the reference again; without it the output
# magnitude is on the wrong scale relative to the input.
enhanced_mag_db = (enhanced_mag_norm * 80) - 80
enhanced_mag = librosa.db_to_amplitude(enhanced_mag_db, ref=noisy_mag.max())

# Reconstruct with original phase
enhanced_stft = enhanced_mag * np.exp(1j * noisy_phase)
enhanced_audio = librosa.istft(enhanced_stft, hop_length=128, length=len(noisy_audio))

print("✓ Enhancement complete!")

## Step 11: Visualize Results

In [None]:
clean_stft = librosa.stft(clean_audio, n_fft=512, hop_length=128)
clean_mag = np.abs(clean_stft)

# One row per signal: waveform on the left, spectrogram on the right
panels = [
    ('Noisy', noisy_audio, noisy_mag),
    ('Clean', clean_audio, clean_mag),
    ('Enhanced', enhanced_audio, enhanced_mag),
]

fig, axes = plt.subplots(3, 2, figsize=(15, 10))

for (wave_ax, spec_ax), (label, audio, magnitude) in zip(axes, panels):
    # Only the first `sr` samples of the waveform are drawn
    wave_ax.plot(audio[:sr], linewidth=0.5)
    wave_ax.set_title(f'{label} Waveform', fontweight='bold')
    wave_ax.grid(True, alpha=0.3)

    image = spec_ax.imshow(
        librosa.amplitude_to_db(magnitude),
        aspect='auto',
        origin='lower',
        cmap='viridis',
    )
    spec_ax.set_title(f'{label} Spectrogram', fontweight='bold')
    plt.colorbar(image, ax=spec_ax)

plt.tight_layout()
plt.show()

## Step 12: Audio Playback

In [None]:
# Caption and signal for each player, in display order
playback = (
    ("Noisy Audio:", noisy_audio),
    ("\nClean Audio:", clean_audio),
    ("\nEnhanced Audio:", enhanced_audio),
)
for caption, waveform in playback:
    print(caption)
    display(Audio(waveform, rate=sr))

## Step 13: Compute Metrics

In [None]:
def compute_snr(clean, enhanced):
    """Return the signal-to-noise ratio of ``enhanced`` against ``clean`` in dB.

    The noise is the element-wise difference ``enhanced - clean``. A constant
    1e-10 is added to the noise power so identical signals do not divide by zero.
    """
    residual = enhanced - clean
    power_ratio = np.mean(clean ** 2) / (np.mean(residual ** 2) + 1e-10)
    return 10 * np.log10(power_ratio)

# Trim all three signals to a common length before comparing them
min_len = min(map(len, (clean_audio, enhanced_audio, noisy_audio)))
clean_audio, enhanced_audio, noisy_audio = (
    signal[:min_len] for signal in (clean_audio, enhanced_audio, noisy_audio)
)

# SNR of the input and of the model output, both measured against the clean signal
snr_noisy = compute_snr(clean_audio, noisy_audio)
snr_enhanced = compute_snr(clean_audio, enhanced_audio)
improvement = snr_enhanced - snr_noisy

banner = "=" * 60
report_lines = [
    banner,
    "RESULTS",
    banner,
    f"SNR Noisy:      {snr_noisy:8.2f} dB",
    f"SNR Enhanced:   {snr_enhanced:8.2f} dB",
    f"Improvement:    {improvement:8.2f} dB",
    banner,
    "\n✓ Demo Complete!",
]
print("\n".join(report_lines))