# Speech Enhancement using Conditional GANs

In [1]:
# System & Utilities
import os
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm.auto import tqdm
import random
random.seed(0)

# Audio I/O & Processing
import torch
torch.manual_seed(0)
import torchaudio
import soundfile as sf
import librosa
import librosa.display
from torchaudio.transforms import MelSpectrogram

# PyTorch Model Building
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader, Sampler
from torch.nn.utils import spectral_norm

# Demo / Deployment
import gradio as gr

  from .autonotebook import tqdm as notebook_tqdm


## Preprocessing Pipeline

In [2]:
# Shared mel-spectrogram front end used by the dataset and the inference code:
# 16 kHz audio, 512-point FFT, hop of 128 samples, 80 mel bands.
mel_spec_transform = MelSpectrogram(
    sample_rate=16000,
    n_fft=512,
    hop_length=128,
    n_mels=80,
)

In [3]:
def normalize_minus_one_to_one(data):
    """Min-max scale ``data`` to the range [-1, 1].

    Args:
        data: Array-like object exposing ``min()`` / ``max()`` and arithmetic
            (the notebook passes torch tensors).

    Returns:
        Tuple ``(normalized_data, x_min, x_max)``. ``x_min`` and ``x_max`` are
        the statistics needed to undo the scaling later.

    A constant input (for example digital silence) has ``x_max == x_min``.
    Dividing by that zero range would fill the result with NaNs, so the range
    is replaced by 1 and every element maps to -1, the lower bound.
    """
    x_min = data.min()
    x_max = data.max()

    span = x_max - x_min
    if span == 0:
        # Degenerate range: avoid 0/0 -> NaN. (data - x_min) is all zeros, so
        # any non-zero divisor yields the same all -1 result.
        span = 1

    normalized_data = 2 * ((data - x_min) / span) - 1
    return normalized_data, x_min, x_max

In [4]:
class BucketBatchSampler(Sampler):
    """Batch sampler that groups items of similar length.

    Indices are sorted by length and cut into consecutive buckets of
    ``bucket_size`` items; each bucket is then cut into batches of
    ``batch_size`` indices. Every batch therefore only contains items from
    one bucket.

    Args:
        lengths: Sequence with one length per dataset item.
        batch_size: Maximum number of indices per batch (>= 1).
        bucket_size: Number of indices per bucket (>= 1).
        shuffle: When true, the order inside each bucket and the order of the
            resulting batches are re-drawn on every iteration. When false,
            batches come out in ascending-length order.

    Raises:
        ValueError: If ``batch_size`` or ``bucket_size`` is smaller than 1.
    """

    def __init__(self, lengths, batch_size, bucket_size = 1000, shuffle = True):
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")
        if bucket_size < 1:
            raise ValueError(f"bucket_size must be >= 1, got {bucket_size}")

        self.batch_size = batch_size
        self.shuffle = shuffle

        # The sort is stable, so equal lengths keep ascending index order,
        # exactly like sorting (length, index) pairs.
        sorted_indices = sorted(range(len(lengths)), key=lambda i: lengths[i])

        self.buckets = [
            sorted_indices[start : start + bucket_size]
            for start in range(0, len(sorted_indices), bucket_size)
        ]

    def __iter__(self):
        batches = []

        for bucket in self.buckets:
            # Work on a copy so the stored buckets keep their sorted order.
            order = list(bucket)
            if self.shuffle:
                random.shuffle(order)

            batches.extend(
                order[start : start + self.batch_size]
                for start in range(0, len(order), self.batch_size)
            )

        if self.shuffle:
            # Without this, every epoch would run from the shortest bucket to
            # the longest one, even though shuffling was requested.
            random.shuffle(batches)

        yield from batches

    def __len__(self):
        # Ceil-divide each bucket: the last batch of a bucket may be partial.
        return sum(
            (len(bucket) + self.batch_size - 1) // self.batch_size
            for bucket in self.buckets
        )

In [5]:
def _fit_time_axis(spec, T_max):
    """Truncate or right-pad (with zeros) the last axis of ``spec`` to ``T_max``."""
    spec = spec[..., :T_max]
    missing = T_max - spec.shape[-1]
    if missing > 0:
        spec = F.pad(spec, (0, missing))
    return spec


def pad_collate(batch, T_max=512):
    """Collate ``(noisy, clean, stats, name)`` items into fixed-length batches.

    Each noisy and clean tensor is cut to at most ``T_max`` entries along its
    last axis and right-padded with zeros up to ``T_max``. The two tensors are
    padded independently: if only the clean one is shorter than ``T_max`` it
    still gets padded, so the stacked batch always has a uniform last axis.

    Args:
        batch: Iterable of ``(noisy, clean, stats, name)`` tuples.
        T_max: Target length of the last axis.

    Returns:
        ``(noisy_batch, clean_batch, stats, names)`` where the first two are
        stacked tensors and the last two are plain Python lists.
    """
    noisys, cleans, stats, names = [], [], [], []

    for noisy, clean, st, name in batch:
        noisys.append(_fit_time_axis(noisy, T_max))
        cleans.append(_fit_time_axis(clean, T_max))
        stats.append(st)
        names.append(name)

    return torch.stack(noisys), torch.stack(cleans), stats, names


In [6]:
class AudioDataset(Dataset):
    """Paired noisy/clean speech dataset producing normalized log-mel spectrograms.

    The directory layout under ``root`` is ``clean_<mode>set_wav`` and
    ``noisy_<mode>set_wav``; with the default ``mode='train'`` these are
    ``clean_trainset_wav`` and ``noisy_trainset_wav``. File stems are taken
    from the clean directory, and the noisy file with the same stem is loaded
    on access.

    Attributes:
        file_list: Sorted stems of the ``*.wav`` files in the clean directory.
        lengths: Mel-frame count of every clean file, in ``file_list`` order
            (used for length bucketing).

    Raises:
        FileNotFoundError: If either directory does not exist.
    """

    SAMPLE_RATE = 16000   # all audio is resampled to this rate
    FREQ_BINS = 128       # frequency axis is zero-padded up to this size

    def __init__(self, root, mode = 'train'):
        root = Path(root)
        self.mode = mode
        self.clean_dir = root / f'clean_{mode}set_wav'
        self.noisy_dir = root / f'noisy_{mode}set_wav'

        # Explicit exceptions instead of `assert`, which is stripped under -O.
        for directory in (self.clean_dir, self.noisy_dir):
            if not directory.exists():
                raise FileNotFoundError(f"{directory} not found.")

        self.file_list = sorted(p.stem for p in self.clean_dir.glob("*.wav"))

        # Precompute mel-frame lengths for bucketing (no log / normalization).
        # NOTE: this decodes every clean file once at construction time.
        self.lengths = [
            mel_spec_transform(self._load(self.clean_dir / f"{stem}.wav")).shape[-1]
            for stem in self.file_list
        ]

    @classmethod
    def _load(cls, path):
        """Load ``path`` and resample it to ``SAMPLE_RATE`` when needed."""
        wav, sr = torchaudio.load(path)
        if sr != cls.SAMPLE_RATE:
            wav = torchaudio.functional.resample(wav, sr, cls.SAMPLE_RATE)
        return wav

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Return ``(noisy, clean, (noisy_min, noisy_max, clean_min, clean_max), stem)``.

        ``noisy`` and ``clean`` are log1p-mel spectrograms scaled to [-1, 1]
        per file; the min/max values allow the scaling to be undone.
        """
        basename = self.file_list[idx]

        clean_wav = self._load(self.clean_dir / f"{basename}.wav")
        noisy_wav = self._load(self.noisy_dir / f"{basename}.wav")

        # Log-compressed mel power spectrograms.
        clean_mel = torch.log1p(mel_spec_transform(clean_wav))
        noisy_mel = torch.log1p(mel_spec_transform(noisy_wav))

        # Each spectrogram is normalized with its own min/max.
        clean_norm, clean_mn, clean_mx = normalize_minus_one_to_one(clean_mel)
        noisy_norm, noisy_mn, noisy_mx = normalize_minus_one_to_one(noisy_mel)

        # Dim 1 is the mel axis (assumes mono input, i.e. shape [1, n_mels, T]).
        freq_pad = self.FREQ_BINS - clean_norm.size(1)
        if freq_pad > 0:
            # F.pad order: (time_left, time_right, freq_top, freq_bottom)
            clean_norm = F.pad(clean_norm, (0, 0, 0, freq_pad))
            noisy_norm = F.pad(noisy_norm, (0, 0, 0, freq_pad))

        return (
            noisy_norm, clean_norm,
            (noisy_mn, noisy_mx, clean_mn, clean_mx),
            basename
        )

In [7]:
from functools import partial

train_ds  = AudioDataset('./data', mode = 'train')

# Batches are drawn from buckets of 10 batches' worth of similar-length clips.
sampler = BucketBatchSampler(
    lengths    = train_ds.lengths,
    batch_size = 32,
    bucket_size= 32*10,
)

train_dl = DataLoader(
    train_ds,
    batch_sampler = sampler,
    # functools.partial instead of a lambda: worker processes started with the
    # "spawn" method need to pickle collate_fn, and lambdas cannot be pickled.
    collate_fn    = partial(pad_collate, T_max=512),
    num_workers   = 4,
)

In [8]:
# noisy_batch, clean_batch, stats, names = next(iter(train_dl))
# fig, axes = plt.subplots(4, 2, figsize=(10, 12))

# for i in range(4):
#     mn_no, mx_no, mn_cl, mx_cl = stats[i]

#     noisy_log = (noisy_batch[i, 0] + 1) / 2 * (mx_no - mn_no) + mn_no
#     clean_log = (clean_batch[i, 0] + 1) / 2 * (mx_cl - mn_cl) + mn_cl

#     noisy_db = librosa.power_to_db(np.expm1(noisy_log.cpu().numpy()), ref=np.max)
#     clean_db = librosa.power_to_db(np.expm1(clean_log.cpu().numpy()), ref=np.max)

#     ax = axes[i, 0]
#     librosa.display.specshow(
#         noisy_db, sr=16000, hop_length=128,
#         x_axis='time', y_axis='mel', ax=ax
#     )
#     ax.set_title(f"Noisy ({names[i]})")

#     ax = axes[i, 1]
#     librosa.display.specshow(
#         clean_db, sr=16000, hop_length=128,
#         x_axis='time', y_axis='mel', ax=ax
#     )
#     ax.set_title(f"Clean ({names[i]})")

# fig.colorbar(axes[0,0].get_images()[0], ax=axes[:, :], format="%+2.f dB")
# plt.tight_layout()
# plt.show()

## Discriminator Network

In [9]:
class SpecPatchDiscriminator(nn.Module):
    """Patch-style conditional discriminator over spectrogram pairs.

    The conditioning spectrogram and the candidate spectrogram are stacked on
    the channel axis and passed through four spectrally-normalized
    Conv -> BatchNorm -> LeakyReLU(0.2) stages (strides 2, 2, 2, 1; widths
    ``base_features`` x 1, 2, 4, 8), followed by a one-channel
    spectrally-normalized convolution. All convolutions use a 4x4 kernel,
    padding 1 and no bias. The output is a grid of unnormalized logits, one
    per patch.

    For an input of size H x W (both divisible by 8) the spatial size goes
    H/2 -> H/4 -> H/8 -> H/8 - 1 -> H/8 - 2 (same for W).
    """

    def __init__(self, in_channels=1, base_features=64):
        super().__init__()

        # Two spectrograms are concatenated, hence in_channels * 2 at the input.
        widths = [in_channels * 2] + [base_features * mult for mult in (1, 2, 4, 8)]
        strides = (2, 2, 2, 1)

        layers = []
        for c_in, c_out, stride in zip(widths[:-1], widths[1:], strides):
            conv = nn.Conv2d(c_in, c_out, kernel_size=4, stride=stride, padding=1, bias=False)
            layers.append(spectral_norm(conv))
            layers.append(nn.BatchNorm2d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))

        # Final "patch" convolution: one logit per receptive-field patch.
        head = nn.Conv2d(widths[-1], 1, kernel_size=4, stride=1, padding=1, bias=False)
        layers.append(spectral_norm(head))

        self.model = nn.Sequential(*layers)

    def forward(self, spec_input, spec_target):
        # Both inputs are [B, C, H, W]; the pair becomes [B, 2C, H, W].
        pair = torch.cat([spec_input, spec_target], dim=1)
        return self.model(pair)

## Generator Network

In [10]:
class SpecUNetGenerator(nn.Module):
    """U-Net generator mapping a spectrogram to a spectrogram of the same size.

    Encoder: six stride-2 4x4 convolutions (LeakyReLU 0.2; BatchNorm on all
    but the first), taking the spatial size from H down to H/64.
    Decoder: six stride-2 4x4 transposed convolutions (BatchNorm + ReLU, with
    Dropout 0.5 in the first three), taking H/64 back up to H. The outputs of
    the first five decoder stages are concatenated with the encoder feature
    map of matching resolution before the next stage.
    Head: a 1x1 transposed convolution to ``out_channels`` with no activation.

    Input height and width must be divisible by 64 for the skip connections
    to line up.
    """

    def __init__(self, in_channels=1, out_channels=1, features=64):
        super().__init__()
        f = features

        def down(c_in, c_out, norm=True):
            # Halves H and W.
            stack = [nn.Conv2d(c_in, c_out, 4, 2, 1, bias=False)]
            if norm:
                stack.append(nn.BatchNorm2d(c_out))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            return nn.Sequential(*stack)

        def up(c_in, c_out, dropout=False):
            # Doubles H and W.
            stack = [
                nn.ConvTranspose2d(c_in, c_out, 4, 2, 1, bias=False),
                nn.BatchNorm2d(c_out),
            ]
            if dropout:
                stack.append(nn.Dropout(0.5))
            stack.append(nn.ReLU(inplace=True))
            return nn.Sequential(*stack)

        # --- Encoder ---
        self.enc1 = down(in_channels, f, norm=False)   # H    -> H/2
        self.enc2 = down(f, f * 2)                     # H/2  -> H/4
        self.enc3 = down(f * 2, f * 4)                 # H/4  -> H/8
        self.enc4 = down(f * 4, f * 8)                 # H/8  -> H/16
        self.enc5 = down(f * 8, f * 8)                 # H/16 -> H/32
        self.enc6 = down(f * 8, f * 8)                 # H/32 -> H/64

        # --- Decoder (input widths include the concatenated skip) ---
        self.dec1 = up(f * 8, f * 8, dropout=True)     # H/64 -> H/32
        self.dec2 = up(f * 16, f * 8, dropout=True)    # H/32 -> H/16
        self.dec3 = up(f * 16, f * 8, dropout=True)    # H/16 -> H/8
        self.dec4 = up(f * 12, f * 4)                  # H/8  -> H/4
        self.dec5 = up(f * 6, f * 2)                   # H/4  -> H/2
        self.dec6 = up(f * 3, f)                       # H/2  -> H

        # 1x1 projection to the output channels; resolution is unchanged.
        self.final = nn.Sequential(
            nn.ConvTranspose2d(f, out_channels, kernel_size=1, bias=False),
        )

    def forward(self, x):
        # Encoder pass, remembering every stage output for the skips.
        skips = []
        h = x
        for stage in (self.enc1, self.enc2, self.enc3, self.enc4, self.enc5, self.enc6):
            h = stage(h)
            skips.append(h)

        # The bottleneck is the deepest encoder output.
        h = skips.pop()

        # Each of the first five decoder outputs is joined with the encoder
        # feature map of the same resolution (deepest remaining first).
        for stage in (self.dec1, self.dec2, self.dec3, self.dec4, self.dec5):
            h = torch.cat([stage(h), skips.pop()], dim=1)

        return self.final(self.dec6(h))

## Discriminator Training 

In [11]:
def train_discriminator(discriminator, generator, noisy, clean, opt_d):
    """Run one discriminator optimisation step.

    The discriminator is trained to score ``(noisy, clean)`` pairs as real and
    ``(noisy, generator(noisy))`` pairs as fake, using binary cross-entropy on
    its raw logits. Real targets are 0.9 instead of 1.0 (one-sided label
    smoothing); fake targets are 0.

    Args:
        discriminator: Module called as ``discriminator(noisy, candidate)``.
        generator: Module called as ``generator(noisy)``; it is not updated.
        noisy: Conditioning batch.
        clean: Ground-truth batch paired with ``noisy``.
        opt_d: Optimizer over the discriminator parameters.

    Returns:
        ``(loss, real_score, fake_score)`` as Python floats. The scores are
        the mean raw logits (not probabilities) on real and fake pairs.
    """
    discriminator.train()

    # Clear discriminator gradients
    opt_d.zero_grad()

    # --- Real pairs: D(noisy, clean) should be classified as real ---
    real_preds = discriminator(noisy, clean)
    real_targets = torch.full_like(real_preds, 0.9)
    real_loss = F.binary_cross_entropy_with_logits(real_preds, real_targets)
    real_score = real_preds.mean().item()

    # --- Fake pairs: D(noisy, G(noisy)) should be classified as fake ---
    # The generator is not trained here, so skip building its autograd graph
    # altogether rather than building it and detaching the result.
    with torch.no_grad():
        fake_audios = generator(noisy)
    fake_preds = discriminator(noisy, fake_audios)
    fake_targets = torch.zeros_like(fake_preds)
    fake_loss = F.binary_cross_entropy_with_logits(fake_preds, fake_targets)
    fake_score = fake_preds.mean().item()

    # Update discriminator weights
    loss = real_loss + fake_loss
    loss.backward()
    opt_d.step()

    return loss.item(), real_score, fake_score

## Generator Training

In [12]:
def train_generator(discriminator, generator, noisy, clean, opt_g, lambda_L1 = 100):
    """Run one generator optimisation step.

    The objective is the adversarial loss (binary cross-entropy of the
    discriminator logits on the generated pair against an all-ones target)
    plus ``lambda_L1`` times the L1 distance between the generated and the
    clean batch.

    Returns:
        ``(total_loss, adversarial_loss, l1_loss)`` as Python floats.
    """
    generator.train()
    opt_g.zero_grad()

    enhanced = generator(noisy)

    # Adversarial term: the generator wants the discriminator to say "real".
    logits = discriminator(noisy, enhanced)
    adv_loss = F.binary_cross_entropy_with_logits(logits, torch.ones_like(logits))

    # Reconstruction term: stay close to the clean target.
    l1_loss = F.l1_loss(enhanced, clean)

    total_loss = adv_loss + (lambda_L1 * l1_loss)
    total_loss.backward()
    opt_g.step()

    return tuple(term.item() for term in (total_loss, adv_loss, l1_loss))



## Saving Generated Samples

In [14]:
# Inverse of min-max scaling to [-1, 1]
def denorm(normed, mn, mx):
    """Map values from [-1, 1] back to the range [mn, mx].

    -1 maps to ``mn``, +1 maps to ``mx``, and values in between are
    interpolated linearly.
    """
    unit = (normed + 1) / 2      # position inside the range, 0..1
    span = mx - mn
    return unit * span + mn

In [15]:
def save_audio_samples(
    index,
    noisy_batch,
    clean_batch,
    generator,
    denorm,
    stats,
    sample_rate=16000,
    sample_dir="audio_samples",
    show=True,
    n_mels=80
):
    """Run the generator on a batch and write each result as a WAV file.

    Every generated spectrogram is de-normalized with the clean min/max of its
    item, converted from log1p-mel back to mel power, and inverted to a
    waveform with Griffin-Lim. Files are written to
    ``<sample_dir>/<index:04d>_<i>_denoised.wav``.

    Args:
        index: Integer tag used as the file-name prefix (e.g. the epoch).
        noisy_batch: Generator input batch.
        clean_batch: Unused; kept for call compatibility.
        generator: Model to sample from. It is switched to eval mode for the
            forward pass and restored to its previous mode afterwards.
        denorm: Callable ``denorm(normed, mn, mx)`` undoing the normalization.
        stats: Per-item ``(noisy_min, noisy_max, clean_min, clean_max)``.
        sample_rate: Sample rate used for inversion and for the WAV files.
        sample_dir: Output directory (created if missing).
        show: When true, also plot each generated spectrogram.
        n_mels: Number of leading frequency rows that are real mel bands.
            Rows beyond that are treated as padding and dropped before
            inversion.
    """
    os.makedirs(sample_dir, exist_ok=True)

    was_training = generator.training
    generator.eval()
    try:
        with torch.no_grad():
            fake_batch = generator(noisy_batch.to(next(generator.parameters()).device))
    finally:
        # Restore the previous mode even if the forward pass fails.
        generator.train(was_training)

    for i, (fake, (_, _, cmn, cmx)) in enumerate(zip(fake_batch, stats)):
        # Assumes each item is [channels, freq, time] with one channel: keep
        # channel 0 and only the real mel rows, so the inversion gets a 2-D
        # [n_mels, time] matrix rather than padded bins.
        fake_denorm = denorm(fake[0, :n_mels], cmn, cmx)

        # Undo log1p. The generator output is unbounded, so clamp to keep the
        # mel *power* non-negative.
        lin_mel = np.maximum(np.expm1(fake_denorm.cpu().numpy()), 0.0)

        # htk=True / norm=None match torchaudio's MelSpectrogram defaults;
        # librosa's own defaults (Slaney scale and normalization) describe a
        # different filterbank.
        fake_wav = librosa.feature.inverse.mel_to_audio(
            lin_mel, sr=sample_rate, hop_length=128, n_fft=512, n_iter=32,
            htk=True, norm=None
        )

        prefix = f"{index:04d}_{i}"
        sf.write(os.path.join(sample_dir, f"{prefix}_denoised.wav"), fake_wav, sample_rate)

        if show:
            db = librosa.power_to_db(lin_mel, ref=np.max)
            plt.figure(figsize=(6,2))
            librosa.display.specshow(db, sr=sample_rate, hop_length=128, y_axis="mel", x_axis="time")
            plt.title(f"Denoised ({prefix})")
            plt.colorbar(format="%+2.0f dB")
            plt.show()

    print(f"Saved audio samples for batch index {index} → `{sample_dir}`")

## Full Training Loop with ASR-Aware Training

In [16]:
def fit(
    discriminator: nn.Module,
    generator:     nn.Module,
    train_dl,
    fixed_noisy,
    fixed_clean,
    fixed_stats,
    denorm,
    device,
    epochs     = 200,
    lr         = 2e-4,
    lambda_L1  = 100,
    start_idx  = 1
):
    """Train the generator/discriminator pair and return per-epoch history.

    Each batch performs one discriminator step followed by one generator
    step. After every epoch the averaged statistics are printed and recorded,
    audio samples for the fixed batch are written, and every fifth epoch the
    generator weights are saved to ``checkpoint_gen_epoch<epoch>.pth``.

    Args:
        discriminator, generator: Models to train (already on ``device``).
        train_dl: Iterable yielding ``(noisy, clean, stats, names)`` batches.
        fixed_noisy, fixed_clean, fixed_stats: Batch reused every epoch for
            the saved audio samples.
        denorm: De-normalization callable forwarded to ``save_audio_samples``.
        device: Device the batches are moved to.
        epochs: Number of epochs to run.
        lr: Adam learning rate for both optimizers (betas are 0.5, 0.999).
        lambda_L1: Weight of the L1 term in the generator loss.
        start_idx: Number assigned to the first epoch.

    Returns:
        ``(losses_g, losses_d, real_scores, fake_scores)``: lists with one
        averaged value per epoch.

    Raises:
        ValueError: If ``train_dl`` yields no batches in an epoch.
    """
    # Optimizers
    opt_d = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    opt_g = torch.optim.Adam(generator.parameters(),     lr=lr, betas=(0.5, 0.999))

    # History
    losses_d, losses_g = [], []
    real_scores, fake_scores = [], []

    last_epoch = start_idx + epochs - 1

    for epoch in range(start_idx, start_idx + epochs):
        sum_d = sum_g = 0.0
        sum_real = sum_fake = 0.0
        batches = 0

        pbar = tqdm(train_dl, desc=f"Epoch {epoch}/{last_epoch}")
        for noisy, clean, _stats, _names in pbar:
            noisy = noisy.to(device)
            clean = clean.to(device)

            # --- Train D ---
            opt_d.zero_grad()
            d_loss, real_s, fake_s = train_discriminator(
                discriminator, generator,
                noisy, clean,
                opt_d
            )

            # --- Train G ---
            opt_g.zero_grad()
            g_loss, _adv_loss, _l1_loss = train_generator(
                discriminator, generator,
                noisy, clean,
                opt_g,
                lambda_L1
            )

            # Accumulate stats
            sum_d    += d_loss
            sum_g    += g_loss
            sum_real += real_s
            sum_fake += fake_s
            batches  += 1

        if batches == 0:
            # Fail with a clear message instead of a ZeroDivisionError below.
            raise ValueError("train_dl yielded no batches; cannot compute epoch averages.")

        # Compute averages
        avg_d    = sum_d    / batches
        avg_g    = sum_g    / batches
        avg_real = sum_real / batches
        avg_fake = sum_fake / batches

        # Record losses & scores
        losses_d.append(avg_d)
        losses_g.append(avg_g)
        real_scores.append(avg_real)
        fake_scores.append(avg_fake)

        # Log losses & scores
        print(
            f"Epoch [{epoch}]  "
            f"loss_g: {avg_g:.4f}, loss_d: {avg_d:.4f}, "
            f"real_score: {avg_real:.4f}, fake_score: {avg_fake:.4f}"
        )

        # Generate & save samples for the fixed batch
        save_audio_samples(
            index       = epoch,
            noisy_batch = fixed_noisy.to(device),
            clean_batch = fixed_clean.to(device),
            generator   = generator,
            denorm      = denorm,
            stats       = fixed_stats,
            show        = False
            )

        if epoch % 5 == 0:
            torch.save(generator.state_dict(), f"checkpoint_gen_epoch{epoch}.pth")

    return losses_g, losses_d, real_scores, fake_scores

In [17]:
# Weight-initialisation helper
def init_weights(m):
    """DCGAN-style weight init, meant to be used with ``module.apply``.

    * Conv2d / ConvTranspose2d: weights ~ N(0, 0.02).
    * BatchNorm2d: scale ~ N(1, 0.02), shift = 0. A scale centred on 0 would
      make every normalized activation start out close to zero.

    For layers wrapped with ``torch.nn.utils.spectral_norm`` the trainable
    tensor is ``weight_orig`` (``weight`` is recomputed from it), so that is
    the tensor initialised. Other module types are left untouched.
    """
    if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
        weight = getattr(m, "weight_orig", None)
        if weight is None:
            weight = m.weight
        nn.init.normal_(weight, 0.0, 0.02)
    elif isinstance(m, nn.BatchNorm2d):
        # affine=False layers have no weight/bias to initialise.
        if m.weight is not None:
            nn.init.normal_(m.weight, 1.0, 0.02)
        if m.bias is not None:
            nn.init.zeros_(m.bias)

In [18]:
# Train on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Build both networks with their default sizes.
# NOTE: the order of construction and initialisation below determines how the
# random number generator is consumed, and therefore the starting weights.
generator = SpecUNetGenerator()
discriminator = SpecPatchDiscriminator()

# Module.apply visits every submodule, so init_weights sees each layer.
generator.apply(init_weights) 
discriminator.apply(init_weights) 

discriminator = discriminator.to(device)
generator     = generator.to(device)

# Take the first batch the loader yields and reuse it as the fixed batch that
# fit() renders to audio after every epoch (the batch names are not needed).
fixed_noisy, fixed_clean, fixed_stats, _= next(iter(train_dl))

fixed_noisy = fixed_noisy.to(device)
fixed_clean = fixed_clean.to(device)

# Run the full training; fit() returns the per-epoch history lists.
history = fit(
    discriminator=discriminator,
    generator=generator,
    train_dl=train_dl,
    fixed_noisy=fixed_noisy,
    fixed_clean=fixed_clean,
    fixed_stats=fixed_stats,
    denorm=denorm,
    device=device,    
    epochs=200,
    lr=2e-4,
    lambda_L1=100,
    start_idx=1
)

# Unpack in the order fit() returns them: generator losses first.
losses_g, losses_d, real_scores, fake_scores = history

## Checkpointing 

In [None]:
# Persist the final weights of both networks (state dicts only).
for network, checkpoint_path in ((generator, 'G.pth'), (discriminator, 'D.pth')):
    torch.save(network.state_dict(), checkpoint_path)

In [None]:
from IPython.display import Audio, display

# Epochs whose first saved sample (item 0 of the fixed batch) is played back.
epochs = [1, 5, 10, 50, 100, 150, 200]

for e in epochs:
    tag = "{:04d}_0".format(e)
    wav_path = f"audio_samples/{tag}_denoised.wav"

    print(f"Epoch {e}:")
    display(Audio(wav_path, rate=16000))
    print()

## Plotting Loss of Generator & Discriminator

In [None]:
# X axis for the history plots: one entry per recorded epoch, counted from 1.
epochs_range = [number for number, _ in enumerate(losses_d, start=1)]

In [None]:
# Per-epoch average loss of both networks on a shared axis.
plt.clf()
for curve, curve_label in ((losses_d, "Discriminator"), (losses_g, "Generator")):
    plt.plot(epochs_range, curve, label=curve_label)
plt.title("Training Loss per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Losses")
plt.grid(True)
plt.legend()
plt.show()

## Plotting Real & Fake Scores

In [None]:
# Per-epoch average discriminator output on real and on generated pairs.
plt.clf()
for curve, curve_label in ((real_scores, "Real Score"), (fake_scores, "Fake Score")):
    plt.plot(epochs_range, curve, label=curve_label)
plt.title("Real vs Fake Scores per Epoch")
plt.xlabel("Epoch")
plt.ylabel("Scores")
plt.grid(True)
plt.legend()
plt.show()

## User Interface

In [None]:
def enhance_audio(audio_path):
    """Denoise the audio file at ``audio_path`` with the trained generator.

    The file is loaded at 16 kHz, converted to a normalized log-mel
    spectrogram, padded to the 128 x 512 grid the generator is fed, enhanced,
    and inverted back to a waveform with Griffin-Lim.

    Only the first 512 mel frames are processed; anything beyond that is
    dropped. Shorter inputs are zero-padded for the generator, and the padded
    frames are removed again before inversion.

    Args:
        audio_path: Path of the audio file to enhance.

    Returns:
        ``(sample_rate, waveform)`` with ``waveform`` a NumPy array.
    """
    freq_bins, max_frames = 128, 512

    # ensure the generator is in eval mode for inference
    generator.eval()

    wav, sr = librosa.load(audio_path, sr=16000)

    # Feature extraction stays on the CPU: mel_spec_transform is never moved
    # to `device`, so feeding it a CUDA tensor would fail.
    wav_t = torch.from_numpy(wav).unsqueeze(0)
    mel   = torch.log1p(mel_spec_transform(wav_t))          # [1,n_mels,T']
    mel_norm, mn, mx = normalize_minus_one_to_one(mel)      # [1,n_mels,T']

    # Remember the real extent so padding can be stripped from the output.
    n_mels   = mel_norm.size(1)
    n_frames = min(mel_norm.shape[-1], max_frames)

    # Pad frequency axis to 128 bins
    freq_pad = freq_bins - n_mels
    if freq_pad > 0:
        mel_norm = F.pad(mel_norm, (0, 0, 0, freq_pad))     # [1,128,T']

    # Truncate/pad to T_max=512
    mel_norm = mel_norm[..., :max_frames]
    time_pad = max_frames - mel_norm.shape[-1]
    if time_pad > 0:
        mel_norm = F.pad(mel_norm, (0, time_pad))

    mel_in = mel_norm.unsqueeze(0).to(device)               # [1,1,128,512]

    with torch.no_grad():
        fake = generator(mel_in)                            # [1,1,128,512]

    fake = fake.squeeze(0).squeeze(0).cpu()                 # [128,512]
    fake = fake[:n_mels, :n_frames]                         # drop padded bins/frames
    fake = denorm(fake, mn, mx)                             # back to log-Mel

    # Undo log1p; clamp because the generator output is unbounded and mel
    # *power* cannot be negative.
    lin_mel = np.maximum(np.expm1(fake.numpy()), 0.0)

    # htk=True / norm=None match torchaudio's MelSpectrogram defaults.
    denoised = librosa.feature.inverse.mel_to_audio(
        lin_mel, sr=sr, hop_length=128, n_fft=512, n_iter=32,
        htk=True, norm=None
    )

    return sr, denoised


In [None]:
# Gradio UI: one uploaded audio file in, the denoised audio out.
iface = gr.Interface(
    fn=enhance_audio,
    # No `source=` argument: Gradio 4 renamed it to `sources` and rejects the
    # old keyword, while on Gradio 3 "upload" is already the default.
    inputs=gr.Audio(type="filepath", label="Noisy Audio"),
    outputs=gr.Audio(type="numpy", label="Denoised Audio"),
    title="Speech Enhancement GAN",
    description="Upload a WAV file sampled at 16 kHz and get back the denoised audio."
)

In [None]:
# Start serving the Gradio interface built in the previous cell.
iface.launch()