# Speech Enhancement using Conditional GANs

In [9]:
# System & Utilities
import os
import numpy as np
import pathlib
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

# Audio I/O & Processing
import torch
import torchaudio
import soundfile as sf
from torchaudio.transforms import MelSpectrogram, GriffinLim

# PyTorch Model Building
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader

# ASR for WER Evaluation
from transformers import Wav2Vec2ForCTC, Wav2Vec2Tokenizer
from jiwer import wer

# Demo / Deployment
import gradio as gr

## Preprocessing Pipeline

In [10]:
# Mel spectrogram front end shared by the dataset: 80 mel bands computed from a
# 512-point FFT with a 128-sample hop, configured for 16 kHz audio (the rate the
# dataset resamples every waveform to before applying this transform).
mel_spec_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=16000,
    n_fft=512,
    hop_length=128,
    n_mels=80,
)

In [11]:
def normalize_minus_one_to_one(data):
    """Min-max rescale ``data`` so its smallest value maps to -1 and its largest to 1.

    Works on NumPy arrays and torch tensors alike (the dataset passes log-Mel
    tensors). Other array-likes (lists, tuples) are converted with ``np.asarray``.

    Args:
        data: ndarray, torch.Tensor, or array-like of numbers.

    Returns:
        An object of the same kind as ``data`` (ndarray for plain array-likes) with
        values in [-1, 1]. If every element is equal, returns all zeros rather than
        the NaNs the bare formula would produce.
    """
    # np.min/np.max on a torch.Tensor forward numpy-only kwargs (axis=, out=) to
    # Tensor.min and raise TypeError, so use the reduction methods both types share.
    if not hasattr(data, "min"):
        data = np.asarray(data, dtype=float)
    x_min = data.min()
    x_max = data.max()
    value_range = x_max - x_min
    if value_range == 0:
        # Constant input: (data - x_min) / 0 would be 0/0 = NaN everywhere.
        return (data - x_min) * 0.0
    return 2 * ((data - x_min) / value_range) - 1

In [12]:
class AudioDataset(Dataset):
    """Paired noisy/clean utterances returned as normalized log-Mel spectrograms.

    ``root`` must contain the two directories ``clean_<mode>set_wav`` and
    ``noisy_<mode>set_wav`` (e.g. ``clean_trainset_wav`` / ``noisy_trainset_wav``
    for ``mode='train'``). Utterance IDs are the stems of the clean ``.wav`` files.
    Each one is expected to have a same-named file in the noisy directory.

    Args:
        root: Dataset root directory (str or path-like).
        mode: Split to load, ``'train'`` or ``'test'``; selects the directory pair.

    Raises:
        ValueError: If ``mode`` is not one of ``SPLITS``.
        AssertionError: If either directory of the selected split is missing.
    """

    SPLITS = ("train", "test")

    def __init__(self, root, mode = 'train'):
        if mode not in self.SPLITS:
            raise ValueError(f"mode must be one of {self.SPLITS}, got {mode!r}")
        # `Path` was never imported; only the `pathlib` module is.
        root = pathlib.Path(root)
        self.clean_dir = root / f"clean_{mode}set_wav"
        self.noisy_dir = root / f"noisy_{mode}set_wav"

        # Ensure directories exist
        assert self.clean_dir.exists(), f"{self.clean_dir} not found."
        assert self.noisy_dir.exists(), f"{self.noisy_dir} not found."

        # Sorted for a deterministic index -> utterance mapping.
        self.file_list = sorted(p.stem for p in self.clean_dir.glob("*.wav"))

    def __len__(self):
        return len(self.file_list)

    @staticmethod
    def _load_16k(path):
        """Load a waveform with torchaudio and resample it to 16 kHz if needed."""
        wav, sr = torchaudio.load(path)
        if sr != 16000:
            wav = torchaudio.functional.resample(wav, sr, 16000)
        return wav

    @staticmethod
    def _to_normalized_log_mel(wav):
        """Mel spectrogram -> log1p compression -> per-utterance min-max to [-1, 1]."""
        return normalize_minus_one_to_one(torch.log1p(mel_spec_transform(wav)))

    def __getitem__(self, idx):
        """Return ``(noisy_mel, clean_mel, basename)`` for utterance ``idx``.

        Both spectrograms are normalized independently of each other, so each one
        spans [-1, 1] on its own scale.
        """
        basename = self.file_list[idx]

        clean_wav = self._load_16k(self.clean_dir / f"{basename}.wav")
        noisy_wav = self._load_16k(self.noisy_dir / f"{basename}.wav")

        # Return the normalized features; the original computed them but returned
        # the raw log-Mel spectrograms instead.
        return (
            self._to_normalized_log_mel(noisy_wav),
            self._to_normalized_log_mel(clean_wav),
            basename,
        )

In [None]:
# Training split and loader. `Pix2PixDataset` was never defined in this notebook;
# the dataset class defined above is `AudioDataset`.
# NOTE(review): batch_size=1 presumably because utterances differ in length and the
# default collate cannot stack variable-length spectrograms; a padding collate_fn
# would be needed to increase it.
train_ds = AudioDataset('./data', mode = 'train')
train_dl = DataLoader(train_ds, batch_size = 1, shuffle = True, num_workers = 0, pin_memory = True)