In [None]:
import os
import math
import random
import warnings

import numpy as np
import matplotlib.pyplot as plt

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchaudio
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm



In [None]:
!pip install speechbrain

In [None]:
# Mount Google Drive; the dataset and all outputs of this notebook live on it.
from google.colab import drive

drive.mount(mountpoint="/content/drive")


Mounted at /content/drive


In [None]:
# Absolute paths under the Drive mount point ("/content/drive"), so they keep working
# even if the notebook's current working directory is not /content.
DRIVE_ROOT = "/content/drive/MyDrive"
BASE_PATH = os.path.join(DRIVE_ROOT, "organized_dataset")
TRAIN_PATH = os.path.join(BASE_PATH, 'train')

# Everything the training run writes (checkpoint, audio, plots) goes under output_base.
output_base = os.path.join(DRIVE_ROOT, "FinalOutputs3")
for _subdir in ("gan_generated_audio", "gan_spectrogram_plots", "gan_comparison_plots"):
    os.makedirs(os.path.join(output_base, _subdir), exist_ok=True)


In [None]:
# Audio / STFT settings used by AudioDataset and by the real-vs-generated comparison plots.
SAMPLE_RATE = 22050        # Hz; inputs at other rates are resampled to this
N_FFT = WIN_LENGTH = 1024  # FFT size and analysis window length, in samples
HOP_LENGTH = 256           # samples between successive STFT frames
N_MELS = 80                # mel bands per frame
MAX_FRAMES = 352           # clips are padded/cropped to (MAX_FRAMES - 1) * HOP_LENGTH samples

In [None]:
# Training configuration.
DEVICE = "cpu"
if torch.cuda.is_available():
    DEVICE = "cuda"

LATENT_DIM = 100           # length of the generator's noise vector
EPOCHS = 300
BATCH_SIZE = 16
LR_G, LR_D = 1e-4, 2e-5    # Adam learning rates: generator, critic
BETAS = (0.0, 0.99)        # Adam betas shared by both optimizers
N_CRITIC = 3               # critic updates per generator update
LAMBDA_GP = 4              # gradient-penalty weight


In [None]:
warnings.filterwarnings("ignore", category=UserWarning, module="torchaudio")

In [None]:
# SpeechBrain hub id of the pretrained HiFi-GAN vocoder, and the local `savedir`
# handed to HIFIGAN.from_hparams when it is loaded.
HIFI_SRC, HIFI_SAVE_DIR = (
    "speechbrain/tts-hifigan-libritts-22050Hz",
    "pretrained_models/tts-hifigan-libritts-22050Hz",
)

In [None]:
from speechbrain.inference.vocoders import HIFIGAN
from speechbrain.lobes.models.FastSpeech2 import mel_spectogram


In [None]:
class AudioDataset(Dataset):
    """Class-labelled ``.wav`` clips turned into fixed-length mel spectrograms.

    Expects ``root_dir/<category>/*.wav``; categories whose directory is missing are
    skipped. Each item is ``(mel, one_hot_label)``.

    Args:
        root_dir: dataset root containing one sub-directory per category.
        categories: category names; their order defines the label indices.
        max_frames: every clip is padded/cropped to ``(max_frames - 1) * hop_length``
            samples before the mel transform.
    """

    def __init__(self, root_dir, categories, max_frames=MAX_FRAMES):
        self.root_dir = root_dir
        self.categories = categories
        self.max_frames = max_frames
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}

        for cat_name in categories:
            cat_dir = os.path.join(root_dir, cat_name)
            if not os.path.isdir(cat_dir):
                continue
            # os.listdir order is filesystem-dependent; sorting makes the index -> file
            # mapping (and therefore seeded shuffling) reproducible across runs/machines.
            # lower() so upper-case ".WAV" files are not silently ignored.
            wavs = [
                os.path.join(cat_dir, f)
                for f in sorted(os.listdir(cat_dir))
                if f.lower().endswith(".wav")
            ]
            label_idx = self.class_to_idx[cat_name]
            self.file_list.extend((p, label_idx) for p in wavs)

        self.sample_rate = SAMPLE_RATE
        self.hop_length = HOP_LENGTH
        self.win_length = WIN_LENGTH
        self.n_fft = N_FFT
        self.n_mels = N_MELS
        # Building a Resample transform computes its filter kernel; keep one per source rate
        # instead of rebuilding it for every item.
        self._resamplers = {}

    def __len__(self):
        return len(self.file_list)

    def _resample(self, wav, sr):
        """Resample ``wav`` from ``sr`` to ``self.sample_rate`` (no-op when they match)."""
        if sr == self.sample_rate:
            return wav
        resampler = self._resamplers.get(sr)
        if resampler is None:
            resampler = torchaudio.transforms.Resample(sr, self.sample_rate)
            self._resamplers[sr] = resampler
        return resampler(wav)

    def __getitem__(self, idx):
        """Return ``(mel, one_hot_label)`` for item ``idx``.

        The clip is down-mixed to mono, resampled to ``sample_rate`` and zero-padded or
        cropped to ``(max_frames - 1) * hop_length`` samples. ``mel`` is
        ``[1, n_mels, frames]`` float32 produced by ``mel_spectogram`` with
        ``compression=True``; presumably ``frames == max_frames`` (centred STFT) — TODO confirm.
        ``one_hot_label`` is a float vector of length ``len(categories)``.
        """
        path, label = self.file_list[idx]
        wav, sr = torchaudio.load(path)
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)
        wav = self._resample(wav, sr)
        signal = wav.squeeze(0)

        target_samples = (self.max_frames - 1) * self.hop_length
        current_samples = signal.shape[0]
        if current_samples < target_samples:
            signal = F.pad(signal, (0, target_samples - current_samples))
        else:
            signal = signal[:target_samples]

        spectrogram, _ = mel_spectogram(
            audio=signal,
            sample_rate=self.sample_rate,
            hop_length=self.hop_length,
            win_length=self.win_length,
            n_mels=self.n_mels,
            n_fft=self.n_fft,
            f_min=0.0,
            f_max=8000.0,
            power=1,
            normalized=False,
            min_max_energy_norm=False,
            norm="slaney",
            mel_scale="slaney",
            compression=True
        )
        mel_input = spectrogram.unsqueeze(0)  # add the channel axis

        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return mel_input.float(), label_vec.float()


In [None]:
class Generator(nn.Module):
    """Conditional generator mapping (noise, one-hot label) to an 80 x 352 spectrogram.

    The concatenated ``[z, y]`` vector is projected by a linear layer to a
    256 x 5 x 22 feature map. Four stride-2 transposed convolutions (k=4, p=1) then
    each double the height and width: 5x22 -> 10x44 -> 20x88 -> 40x176 -> 80x352.
    Outputs are clamped to [-15, 1].

    Args:
        latent_dim: length of the noise vector ``z``.
        num_categories: length of the one-hot label vector ``y``.
    """

    def __init__(self, latent_dim, num_categories):
        super().__init__()

        self.latent_dim = latent_dim
        self.num_categories = num_categories

        self.unflatten_shape = (256, 5, 22)
        channels, height, width = self.unflatten_shape
        self.fc = nn.Linear(latent_dim + num_categories, channels * height * width)

        # Three upsample/normalise/activate stages followed by a bare projection to one
        # channel. Layer order (and thus Sequential indices / state_dict keys) matches
        # a hand-written nn.Sequential of the same layers.
        widths = (256, 128, 64, 32)
        layers = []
        for c_in, c_out in zip(widths, widths[1:]):
            layers += [
                nn.ConvTranspose2d(c_in, c_out, kernel_size=4, stride=2, padding=1),
                nn.GroupNorm(8, c_out),
                nn.LeakyReLU(0.2, inplace=True),
            ]
        layers.append(nn.ConvTranspose2d(widths[-1], 1, kernel_size=4, stride=2, padding=1))
        self.net = nn.Sequential(*layers)

    def forward(self, z, y):
        """``z``: [B, latent_dim], ``y``: [B, num_categories] -> [B, 1, 80, 352]."""
        seed_map = self.fc(torch.cat((z, y), dim=1)).view(-1, *self.unflatten_shape)
        return self.net(seed_map).clamp(min=-15.0, max=1.0)



In [None]:

from torch.nn.utils import spectral_norm

class Discriminator(nn.Module):
    """Conditional critic that scores (spectrogram, one-hot label) pairs.

    The label is projected by a linear layer to a full ``H x W`` map, scaled by 0.1 and
    stacked with the spectrogram as a second input channel. Four spectral-normalised
    stride-2 convolutions (k=4, p=1) each floor-halve the spatial size, leaving an
    ``(H // 16, W // 16)`` map. A final convolution with exactly that kernel reduces it to
    one unbounded score per sample (no sigmoid).

    Args:
        num_categories: length of the one-hot label vector.
        spec_shape: ``(H, W)`` of the input spectrograms; both must be >= 16. The final
            kernel is derived from it, so shapes other than the default (80, 352) work.

    Raises:
        ValueError: if ``H`` or ``W`` is smaller than 16.
    """

    def __init__(self, num_categories, spec_shape=(80, 352)):
        super().__init__()
        H, W = spec_shape  # default: H=80, W=352
        if H < 16 or W < 16:
            raise ValueError(f"spec_shape must be at least (16, 16), got {spec_shape}")

        self.label_embedding = nn.Linear(num_categories, H * W)

        # Each k=4/s=2/p=1 conv maps n -> floor(n / 2), so four of them give n // 16;
        # the last conv covers that whole map (5 x 22 for the default 80 x 352).
        final_kernel = (H // 16, W // 16)

        self.net = nn.Sequential(
            spectral_norm(nn.Conv2d(2, 32, kernel_size=4, stride=2, padding=1)),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(32, 64, kernel_size=4, stride=2, padding=1)),
            nn.GroupNorm(8, 64),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(64, 128, kernel_size=4, stride=2, padding=1)),
            nn.GroupNorm(8, 128),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(128, 256, kernel_size=4, stride=2, padding=1)),
            nn.GroupNorm(8, 256),
            nn.LeakyReLU(0.2, inplace=True),

            spectral_norm(nn.Conv2d(256, 1, kernel_size=final_kernel, stride=1, padding=0)),
        )

    def forward(self, spec, y):
        """``spec``: [B, 1, H, W], ``y``: [B, num_categories] -> critic scores [B, 1]."""
        B = spec.size(0)

        label_map = self.label_embedding(y).view(B, 1, *spec.shape[2:])  # [B,1,H,W]
        label_map = label_map * 0.1  # keep conditional channel small vs spec energy

        x = torch.cat([spec, label_map], dim=1)      # [B,2,H,W]
        out = self.net(x)                            # [B,1,1,1]
        # view(B, 1) (not view(-1, 1)) so an unexpected spatial size errors out instead
        # of being silently folded into extra "batch" rows.
        return out.view(B, 1)


In [None]:
def weights_init(m):
    """Re-initialise conv / transposed-conv / linear layers: weights ~ N(0, 0.02), bias = 0.

    Any other module type is left untouched. Meant to be used as ``model.apply(weights_init)``.
    """
    if not isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
        return
    nn.init.normal_(m.weight, mean=0.0, std=0.02)
    bias = m.bias
    if bias is not None:
        nn.init.constant_(bias, 0)

In [None]:

# Load the pretrained HiFi-GAN vocoder on DEVICE and put it in inference mode.
hifi_gan = HIFIGAN.from_hparams(
    source=HIFI_SRC,
    savedir=HIFI_SAVE_DIR,
    run_opts={"device": DEVICE},
)
hifi_gan.to(DEVICE).eval();  # trailing ';' keeps the very long module repr out of the cell output


In [None]:
def generated_mel_to_waveform(mel_tensor, hifi_gan_model, device=DEVICE):
    """Render mel spectrogram(s) to audio with a HiFi-GAN vocoder.

    Args:
        mel_tensor: mel spectrogram(s) in one of these layouts:
            ``[B, 1, n_mels, T]`` (generator output; the channel axis is dropped),
            ``[B, n_mels, T]`` (a batch of any size, passed through), or
            ``[n_mels, T]`` (a single mel; a batch axis of 1 is added).
        hifi_gan_model: object whose ``decode_batch(mel)`` turns a ``[B, n_mels, T]``
            batch into waveforms; a tuple/list result is reduced to its first element.
        device: device the mel is moved to (as float32) before decoding.

    Returns:
        Detached CPU waveform tensor with at least 2 dims; presumably ``[B, samples]``
        when ``decode_batch`` returns ``[B, 1, samples]`` — TODO confirm for other vocoders.

    Raises:
        ValueError: if ``mel_tensor`` has any other layout.
    """
    if mel_tensor.dim() == 4 and mel_tensor.shape[1] == 1:
        mel = mel_tensor.squeeze(1)
    elif mel_tensor.dim() == 3:
        # Already [B, n_mels, T]; any batch size is accepted.
        mel = mel_tensor
    elif mel_tensor.dim() == 2:
        mel = mel_tensor.unsqueeze(0)
    else:
        raise ValueError(f"Unexpected mel shape: {mel_tensor.shape}. ")

    mel = mel.to(device, dtype=torch.float32)

    with torch.no_grad():
        wav = hifi_gan_model.decode_batch(mel)
    if isinstance(wav, (tuple, list)):
        wav = wav[0]

    # Drop the vocoder's channel axis, then guarantee a [1, samples]-style 2-D result.
    wav = wav.squeeze(1).detach().cpu()
    if wav.dim() == 1:
        wav = wav.unsqueeze(0)

    return wav

In [None]:
def compute_gradient_penalty(D, real_samples, fake_samples, labels, device, lambda_gp = LAMBDA_GP):
    """WGAN-GP gradient penalty: ``lambda_gp * mean((||grad_x D(x_hat, labels)||_2 - 1)^2)``.

    ``x_hat`` is a per-sample random interpolation between ``real_samples`` and
    ``fake_samples`` (same shape, batch first). ``fake_samples`` should be detached
    by the caller. The returned scalar keeps its graph (``create_graph=True``),
    so it can be back-propagated into ``D``'s parameters.

    Args:
        D: critic called as ``D(x, labels)``.
        real_samples, fake_samples: tensors of identical shape ``[B, ...]`` (any rank).
        labels: conditioning passed unchanged to ``D``.
        device: device on which the interpolation coefficients are drawn.
        lambda_gp: penalty weight.
    """
    batch_size = real_samples.size(0)
    # One mixing coefficient per sample, broadcast over every remaining dim
    # ((B, 1, 1, 1) for 4-D spectrogram batches, but not tied to that rank).
    eps_shape = (batch_size,) + (1,) * (real_samples.dim() - 1)
    epsilon = torch.rand(eps_shape, device=device)
    interpolates = (epsilon * real_samples + (1 - epsilon) * fake_samples).requires_grad_(True)
    d_interpolates = D(interpolates, labels)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=torch.ones_like(d_interpolates),
        create_graph=True,
        retain_graph=True,
        only_inputs=True,
    )[0]

    # reshape (not view): autograd may hand back a non-contiguous gradient.
    grad_norm = gradients.reshape(batch_size, -1).norm(2, dim=1)
    return lambda_gp * ((grad_norm - 1) ** 2).mean()

In [None]:
def _load_checkpoint(ckpt_path, generator, discriminator, opt_G, opt_D, device):
    """Restore models/optimizers from ``ckpt_path`` if it exists; return the first epoch to run."""
    if not os.path.exists(ckpt_path):
        return 1
    ckpt = torch.load(ckpt_path, map_location=device)
    generator.load_state_dict(ckpt['gen'])
    discriminator.load_state_dict(ckpt['disc'])
    opt_G.load_state_dict(ckpt['optG'])
    opt_D.load_state_dict(ckpt['optD'])
    start_epoch = ckpt.get('epoch', 1) + 1
    print("Resumed from checkpoint epoch", start_epoch - 1)
    return start_epoch


def _save_checkpoint(ckpt_path, epoch, generator, discriminator, opt_G, opt_D):
    """Save a resumable checkpoint, replacing ``ckpt_path`` only after the write completes.

    Writing to a temporary file and then ``os.replace``-ing it means an interrupted save
    (e.g. a Colab disconnect mid-write) leaves the previous checkpoint intact instead of
    a truncated file that cannot be loaded.
    """
    tmp_path = ckpt_path + ".tmp"
    torch.save({
        'epoch': epoch,
        'gen': generator.state_dict(),
        'disc': discriminator.state_dict(),
        'optG': opt_G.state_dict(),
        'optD': opt_D.state_dict(),
    }, tmp_path)
    os.replace(tmp_path, ckpt_path)


def _real_log_mel(path):
    """Load ``path`` and return its mel spectrogram as a numpy array.

    Uses the same mono / resample / pad-or-crop / ``mel_spectogram`` settings as
    ``AudioDataset.__getitem__`` so it is directly comparable to training inputs.
    """
    wav, sr = torchaudio.load(path)
    if wav.size(0) > 1:
        wav = wav.mean(dim=0, keepdim=True)
    if sr != SAMPLE_RATE:
        wav = torchaudio.transforms.Resample(sr, SAMPLE_RATE)(wav)
    signal = wav.squeeze(0)

    target_samples = (MAX_FRAMES - 1) * HOP_LENGTH
    if signal.shape[0] < target_samples:
        signal = F.pad(signal, (0, target_samples - signal.shape[0]))
    else:
        signal = signal[:target_samples]

    mel, _ = mel_spectogram(
        audio=signal,
        sample_rate=SAMPLE_RATE,
        hop_length=HOP_LENGTH,
        win_length=WIN_LENGTH,
        n_mels=N_MELS,
        n_fft=N_FFT,
        f_min=0.0, f_max=8000.0,
        power=1,
        normalized=False,
        min_max_energy_norm=False,
        norm="slaney",
        mel_scale="slaney",
        compression=True
    )
    return mel.squeeze().numpy()


def _save_epoch_samples(generator, categories, fixed_z, epoch, output_base, vocoder,
                        device, real_samples_per_cat, real_mel_cache):
    """Generate one spectrogram per category from ``fixed_z`` and save, for ``epoch``:
    a spectrogram plot, the vocoded audio, and a real-vs-generated comparison plot.

    ``real_mel_cache`` (label index -> numpy mel) is filled lazily and reused across calls.
    The generator is put in eval mode for sampling and always returned to train mode.
    """
    plot_dir = os.path.join(output_base, "gan_spectrogram_plots")
    audio_dir = os.path.join(output_base, "gan_generated_audio")
    cmp_dir = os.path.join(output_base, "gan_comparison_plots")
    for d in (plot_dir, audio_dir, cmp_dir):
        os.makedirs(d, exist_ok=True)

    num_cats = len(categories)
    y_fixed = F.one_hot(torch.arange(num_cats), num_classes=num_cats).float().to(device)

    generator.eval()
    try:
        with torch.no_grad():
            specs_gen = generator(fixed_z, y_fixed)

            for i, cat in enumerate(categories):
                spec_gen = specs_gen[i].squeeze().cpu()  # (n_mels, frames)

                fig, ax = plt.subplots(figsize=(6, 3))
                im = ax.imshow(spec_gen.numpy(), aspect='auto', origin='lower', cmap='viridis')
                ax.set_title(f"{cat} - epoch {epoch}")
                ax.axis('off')
                fig.colorbar(im, ax=ax)
                fig.tight_layout()
                fig.savefig(os.path.join(plot_dir, f"{cat}_epoch{epoch}.png"))
                plt.close(fig)

                # (1, 1, n_mels, frames) -> (1, samples)
                wav = generated_mel_to_waveform(specs_gen[i].unsqueeze(0), vocoder, device=device)
                out_path = os.path.join(audio_dir, f"{cat}_epoch{epoch}.wav")
                torchaudio.save(out_path, wav.unsqueeze(0) if wav.dim() == 1 else wav, SAMPLE_RATE)
                print("[SAVED]", out_path)

                print(f"\n Category: {cat}, Epoch: {epoch}")
                print("spec_gen shape", spec_gen.shape, "min/max/mean/std:", spec_gen.min().item(), spec_gen.max().item(), spec_gen.mean().item(), spec_gen.std().item())

            # Comparison plots: real reference clip vs. generated sample, per category.
            for i, cat in enumerate(categories):
                real_path = real_samples_per_cat.get(i)
                if real_path is None:
                    continue
                # The reference clip never changes, so its mel is computed once per run.
                if i not in real_mel_cache:
                    real_mel_cache[i] = _real_log_mel(real_path)
                logmel_real = real_mel_cache[i]
                spec_gen_np = specs_gen[i].squeeze().cpu().numpy()

                fig, ax = plt.subplots(1, 2, figsize=(10, 4))
                im0 = ax[0].imshow(logmel_real, aspect='auto', origin='lower', cmap='viridis')
                ax[0].set_title(f"Real {cat}")
                ax[0].axis('off')
                fig.colorbar(im0, ax=ax[0])

                im1 = ax[1].imshow(spec_gen_np, aspect='auto', origin='lower', cmap='viridis')
                ax[1].set_title(f"Gen {cat} (epoch {epoch})")
                ax[1].axis('off')
                fig.colorbar(im1, ax=ax[1])

                fig.tight_layout()
                fig.savefig(os.path.join(cmp_dir, f"{cat}_cmp_epoch{epoch}.png"))
                plt.close(fig)
    finally:
        generator.train()


def train_gan(generator, discriminator, dataloader, device, categories, epochs, latent_dim ,output_base,vocoder,lr_g,lr_d ,real_samples_per_cat):
    """Train the conditional GAN with a WGAN-GP critic loss plus an L1 generator term.

    A checkpoint (``checkpoint_last.pt`` in ``output_base``) is written after every
    epoch and used to resume if present. Every 10th epoch, samples for each category are
    generated from fixed noise and saved as plots and audio.

    Args:
        generator, discriminator: conditional models called as ``model(x, one_hot_labels)``.
        dataloader: yields ``(real_specs, one_hot_labels)`` batches.
        device: device batches and noise are placed on.
        categories: class names in label-index order.
        epochs: last epoch to run (1-based, inclusive).
        latent_dim: length of the generator's noise vector.
        output_base: directory for the checkpoint and the sample sub-folders.
        vocoder: HiFi-GAN model used to render generated mels to audio.
        lr_g, lr_d: Adam learning rates for generator / critic.
        real_samples_per_cat: ``{label_index: wav_path}`` used for comparison plots.
    """
    opt_G = torch.optim.Adam(generator.parameters(), lr=lr_g, betas=BETAS)
    opt_D = torch.optim.Adam(discriminator.parameters(), lr=lr_d, betas=BETAS)

    ckpt_path = os.path.join(output_base, "checkpoint_last.pt")
    start_epoch = _load_checkpoint(ckpt_path, generator, discriminator, opt_G, opt_D, device)

    # Fixed noise for the periodic samples, drawn from a private CPU RNG. Re-seeding the
    # *global* RNG here (torch.manual_seed) would reset the stream that drives training
    # noise and DataLoader shuffling, so every epoch after a sampling epoch would replay
    # the same random sequence.
    fixed_z = torch.randn(
        len(categories), latent_dim, generator=torch.Generator().manual_seed(42)
    ).to(device)
    real_mel_cache = {}

    lambda_l1 = 1.0
    for epoch in range(start_epoch, epochs + 1):
        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}")
        for real_specs, labels in loop:
            real_specs, labels = real_specs.to(device), labels.to(device)
            B = real_specs.size(0)

            # Critic: N_CRITIC updates per generator update.
            for _ in range(N_CRITIC):
                z = torch.randn(B, latent_dim, device=device)
                fake_specs = generator(z, labels)

                opt_D.zero_grad()
                real_valid = discriminator(real_specs, labels)
                fake_valid = discriminator(fake_specs.detach(), labels)

                loss_D = -(real_valid.mean() - fake_valid.mean())
                gp = compute_gradient_penalty(discriminator, real_specs, fake_specs.detach(), labels, device)
                (loss_D + gp).backward()
                opt_D.step()

            # Generator: adversarial term + L1 to the real batch.
            # NOTE(review): fake_specs[k] and real_specs[k] share a label but are otherwise
            # unpaired, so the L1 term pulls outputs roughly towards a per-class average —
            # confirm this is intended.
            opt_G.zero_grad()
            z = torch.randn(B, latent_dim, device=device)
            fake_specs = generator(z, labels)
            fake_valid = discriminator(fake_specs, labels)
            l1_term = lambda_l1 * F.l1_loss(fake_specs, real_specs)
            loss_G = -fake_valid.mean() + l1_term

            loss_G.backward()
            opt_G.step()
            loop.set_postfix({
                'loss_D': f"{loss_D.item():.2f}",
                'loss_G': f"{loss_G.item():.2f}",
                'GP': f"{gp.item():.2f}",
                'L1': f"{l1_term.item():.2f}",
                'D_real': f"{real_valid.mean().item():.2f}",
                'D_fake': f"{fake_valid.mean().item():.2f}"
            })

        _save_checkpoint(ckpt_path, epoch, generator, discriminator, opt_G, opt_D)

        if epoch % 10 == 0:
            _save_epoch_samples(generator, categories, fixed_z, epoch, output_base, vocoder,
                                device, real_samples_per_cat, real_mel_cache)


In [None]:

if __name__ == "__main__":

    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    if not os.path.isdir(TRAIN_PATH):
        raise RuntimeError(f"TRAIN_PATH '{TRAIN_PATH}' not found.")
    train_categories = sorted([d for d in os.listdir(TRAIN_PATH) if os.path.isdir(os.path.join(TRAIN_PATH, d))])
    print(f"Found {len(train_categories)} categories: {train_categories}")
    if not train_categories:
        raise RuntimeError(f"No category sub-directories found in '{TRAIN_PATH}'.")

    dataset = AudioDataset(TRAIN_PATH, train_categories)
    # With drop_last=True, fewer clips than BATCH_SIZE would give zero batches per epoch,
    # i.e. a "training" run that silently does nothing.
    if len(dataset) < BATCH_SIZE:
        raise RuntimeError(
            f"Found only {len(dataset)} .wav files under '{TRAIN_PATH}'; need at least "
            f"BATCH_SIZE={BATCH_SIZE} because the DataLoader uses drop_last=True."
        )
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=2, drop_last=True)

    # One real clip per label (the first listed) for the periodic comparison plots.
    real_samples_per_cat = {}
    for path, label in dataset.file_list:
        real_samples_per_cat.setdefault(label, path)

    sample, _ = dataset[0]
    print("Sample stats:")
    print(f"min: {sample.min().item():.3f}, max: {sample.max().item():.3f}, mean: {sample.mean().item():.3f}, std: {sample.std().item():.3f}")

    gen = Generator(LATENT_DIM, len(train_categories)).to(DEVICE)
    disc = Discriminator(len(train_categories)).to(DEVICE)
    gen.apply(weights_init)
    disc.apply(weights_init)

    # Shape smoke test; only shapes are inspected, so no autograd graph is needed.
    z = torch.randn(2, LATENT_DIM).to(DEVICE)
    y = F.one_hot(torch.tensor([0, 1 % len(train_categories)]), num_classes=len(train_categories)).float().to(DEVICE)
    with torch.no_grad():
        fake = gen(z, y)
        print("Generator output shape:", fake.shape)
        out_d = disc(fake, y)
        print("Discriminator output shape:", out_d.shape)

    train_gan(gen, disc, dataloader, DEVICE, train_categories, EPOCHS, LATENT_DIM, output_base, vocoder=hifi_gan, lr_g=LR_G, lr_d=LR_D, real_samples_per_cat=real_samples_per_cat)
