<a href="https://colab.research.google.com/github/MayankKhoria2007/Decibel-Duel-solution/blob/main/AUDIOGENERATIONCGAN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Ask the user to upload file(s) from their machine into the Colab working
# directory. The next cell copies `kaggle.json` from the working directory, so
# that file is expected to be part of this upload.
from google.colab import files
# `uploaded` holds whatever files.upload() returns; it is not read again in the
# cells shown here.
uploaded = files.upload()


In [None]:
# Put the uploaded kaggle.json into ~/.kaggle/ so the `kaggle` command used in
# the download cell can find it.
!mkdir -p ~/.kaggle
!cp kaggle.json ~/.kaggle/
# Mode 600: readable and writable by the owner only.
!chmod 600 ~/.kaggle/kaggle.json


In [None]:
# Download the `mayankkhoria/frequencytrain` dataset archive with the Kaggle CLI.
# No target path is given, so the archive lands in the current working
# directory; the unzip cell expects it at /content/frequencytrain.zip.
!kaggle datasets download -d mayankkhoria/frequencytrain


In [None]:
!unzip /content/frequencytrain.zip -d /content/frequencytrain



In [None]:
# ============================
# INSTALLS
# ============================
!pip install torch torchaudio matplotlib tqdm soundfile

# ============================
# IMPORTS
# ============================
import os
import random
from pathlib import Path
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio
from torchaudio.transforms import MelSpectrogram, TimeMasking, FrequencyMasking
from tqdm import tqdm
import matplotlib.pyplot as plt
import soundfile as sf
from IPython.display import Audio, display


In [None]:
# Mount Google Drive at /content/drive.
# NOTE(review): none of the cells shown here read from or write to the mounted
# drive (outputs go to OUTPUT_DIR under /content) -- confirm whether this mount
# is still needed, e.g. for persisting results.
from google.colab import drive
drive.mount('/content/drive')

In [None]:
# ============================
# CONFIG
# ============================
# Root folder that contains one sub-directory per sound class.
DATASET_TRAIN_PATH = "/content/frequencytrain/train"
# Generated spectrogram images and audio clips are written below this folder.
OUTPUT_DIR = "/content/gan_outputs"

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed the Python and PyTorch random number generators up front so that latent
# noise, weight initialisation and shuffling start from a known state on every
# "Restart & Run All".
SEED = 42
random.seed(SEED)
torch.manual_seed(SEED)

N_MELS = 128            # mel bins per spectrogram (spectrogram height)
MAX_FRAMES = 512        # spectrograms are padded / cropped to this many frames
LATENT_DIM = 100        # size of the generator's noise vector
BATCH_SIZE = 8
EPOCHS = 100
LR = 2e-4               # learning rate shared by both Adam optimizers
SAMPLE_RATE = 22050     # audio is resampled to this rate before the mel transform
SAMPLES_PER_CLASS = 2   # samples generated per class at the end of every epoch

# Create the output tree: the root plus one folder per artefact type.
for _out_dir in (
    OUTPUT_DIR,
    os.path.join(OUTPUT_DIR, "audio"),
    os.path.join(OUTPUT_DIR, "specs"),
):
    os.makedirs(_out_dir, exist_ok=True)


In [None]:
# ============================
# DATASET CLASS
# ============================
class TrainAudioSpectrogramDataset(Dataset):
    """Folder-per-class audio dataset yielding (log-mel spectrogram, one-hot label).

    Files are discovered as ``root_dir/<category>/*.wav`` (extension matched
    case-insensitively). Categories whose folder does not exist are skipped.

    Args:
        root_dir: Directory containing one sub-directory per category.
        categories: Ordered category names; a category's position is its class index.
        max_frames: Every spectrogram is zero-padded or cropped to this many frames.
        n_mels: Number of mel bins passed to the mel transform.
        sample_rate: Rate the audio is resampled to before the mel transform.
            ``None`` (default) uses the notebook-level ``SAMPLE_RATE``.
    """

    def __init__(self, root_dir, categories, max_frames=512, n_mels=80,
                 sample_rate=None):
        self.root_dir = root_dir
        self.categories = categories
        self.max_frames = max_frames
        self.n_mels = n_mels
        # Fall back to the global so existing call sites keep their behaviour.
        self.sample_rate = SAMPLE_RATE if sample_rate is None else sample_rate
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}
        self.file_list = self._collect_files()

        self.mel_transform = MelSpectrogram(
            sample_rate=self.sample_rate,
            n_fft=1024,
            hop_length=256,
            n_mels=self.n_mels
        )
        self.time_mask = TimeMasking(time_mask_param=40)
        self.freq_mask = FrequencyMasking(freq_mask_param=12)

    def _collect_files(self):
        """Return (path, class_index) pairs for every .wav file, in sorted order.

        os.listdir() returns names in arbitrary order; sorting makes the
        index -> file mapping identical from run to run.
        """
        found = []
        for cat_name in self.categories:
            cat_dir = os.path.join(self.root_dir, cat_name)
            if not os.path.isdir(cat_dir):
                continue
            class_idx = self.class_to_idx[cat_name]
            for fname in sorted(os.listdir(cat_dir)):
                if fname.lower().endswith(".wav"):
                    found.append((os.path.join(cat_dir, fname), class_idx))
        return found

    def _fit_frames(self, log_mel):
        """Zero-pad or crop the last (frame) axis to exactly ``max_frames``."""
        frames = log_mel.size(-1)
        if frames < self.max_frames:
            return F.pad(log_mel, (0, self.max_frames - frames))
        return log_mel[..., :self.max_frames]

    def __len__(self):
        return len(self.file_list)

    def __getitem__(self, idx):
        """Load one file and return ``(log_mel, label_vec)``.

        ``log_mel`` is ``log1p`` of the mel spectrogram of the mono, resampled
        waveform, fitted to ``max_frames`` and then frequency- and time-masked.
        ``label_vec`` is a float one-hot vector of length ``len(categories)``.
        """
        path, label = self.file_list[idx]
        wav, sr = torchaudio.load(path)

        if sr != self.sample_rate:
            wav = torchaudio.functional.resample(wav, sr, self.sample_rate)

        # Down-mix multi-channel audio to a single channel.
        if wav.size(0) > 1:
            wav = wav.mean(dim=0, keepdim=True)

        log_mel = torch.log1p(self.mel_transform(wav))
        log_mel = self._fit_frames(log_mel)

        # NOTE(review): the masks are applied to every sample this dataset
        # returns, so the "real" batches seen by the discriminator contain
        # masked bands -- confirm that this is intended for GAN training.
        log_mel = self.freq_mask(log_mel)
        log_mel = self.time_mask(log_mel)

        label_vec = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return log_mel, label_vec


In [None]:
# ============================
# GENERATOR
# ============================
class CGAN_Generator(nn.Module):
    """Conditional generator: (noise, one-hot label) -> non-negative spectrogram.

    The noise and label vectors are concatenated, projected by a linear layer
    to a 256-channel feature map and upsampled by transposed convolutions.
    Four stride-2 stages double height and width (x16 in total); the last stage
    doubles the width only, so the seed feature map is ``(H // 16, W // 32)``
    for an ``(H, W)`` output.

    Args:
        latent_dim: Length of the noise vector ``z``.
        num_categories: Length of the one-hot label vector ``y``.
        spec_shape: ``(H, W)`` of the generated spectrogram. ``H`` must be a
            multiple of 16 and ``W`` a multiple of 32.

    Raises:
        ValueError: If ``spec_shape`` cannot be produced by this architecture.
    """

    def __init__(self, latent_dim, num_categories, spec_shape=(128, 512)):
        super().__init__()
        H, W = spec_shape
        if H <= 0 or W <= 0 or H % 16 != 0 or W % 32 != 0:
            raise ValueError(
                f"spec_shape must be positive with H % 16 == 0 and W % 32 == 0, got {spec_shape}"
            )
        self.spec_shape = (H, W)
        # (256, 8, 16) for the default 128 x 512 output.
        self.unflatten_shape = (256, H // 16, W // 32)
        self.fc = nn.Linear(
            latent_dim + num_categories,
            self.unflatten_shape[0] * self.unflatten_shape[1] * self.unflatten_shape[2],
        )

        self.net = nn.Sequential(
            nn.ConvTranspose2d(256, 128, 4, 2, 1),
            nn.BatchNorm2d(128), nn.ReLU(True),

            nn.ConvTranspose2d(128, 64, 4, 2, 1),
            nn.BatchNorm2d(64), nn.ReLU(True),

            nn.ConvTranspose2d(64, 32, 4, 2, 1),
            nn.BatchNorm2d(32), nn.ReLU(True),

            nn.ConvTranspose2d(32, 16, 4, 2, 1),
            nn.BatchNorm2d(16), nn.ReLU(True),

            # Width-only upsampling: kernel and stride are 1 along the height.
            nn.ConvTranspose2d(16, 1, kernel_size=(1, 2), stride=(1, 2)),
            # Final ReLU keeps every output value >= 0.
            nn.ReLU()
        )

    def forward(self, z, y):
        """Return a ``(batch, 1, H, W)`` tensor for noise ``z`` and labels ``y``.

        ``z`` is ``(batch, latent_dim)`` and ``y`` is ``(batch, num_categories)``;
        they are concatenated along dim 1 before the linear projection.
        """
        h = torch.cat([z, y], dim=1)
        h = self.fc(h)
        h = h.view(-1, *self.unflatten_shape)
        return self.net(h)


In [None]:
# ============================
# DISCRIMINATOR
# ============================
class CGAN_Discriminator(nn.Module):
    """Conditional discriminator: (spectrogram, one-hot label) -> one logit.

    The label is projected by a linear layer to an ``H x W`` map and stacked
    with the spectrogram as a second input channel. Four stride-2 convolutions
    halve both axes each time (/16 in total, rounding down); a final
    convolution whose kernel covers the whole remaining feature map reduces it
    to a single logit per sample.

    Args:
        num_categories: Length of the one-hot label vector.
        spec_shape: ``(H, W)`` of the input spectrograms; both must be >= 16.

    Raises:
        ValueError: If ``spec_shape`` is too small for the four downsampling stages.
    """

    def __init__(self, num_categories, spec_shape=(128, 512)):
        super().__init__()
        H, W = spec_shape
        if H < 16 or W < 16:
            raise ValueError(f"spec_shape must be at least (16, 16), got {spec_shape}")
        self.spec_shape = (H, W)

        self.label_embedding = nn.Linear(num_categories, H * W)

        self.net = nn.Sequential(
            nn.Conv2d(2, 32, 4, 2, 1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv2d(32, 64, 4, 2, 1),
            nn.BatchNorm2d(64), nn.LeakyReLU(0.2),

            nn.Conv2d(64, 128, 4, 2, 1),
            nn.BatchNorm2d(128), nn.LeakyReLU(0.2),

            nn.Conv2d(128, 256, 4, 2, 1),
            nn.BatchNorm2d(256), nn.LeakyReLU(0.2),

            # Kernel spans the whole downsampled map: (8, 32) for 128 x 512.
            nn.Conv2d(256, 1, kernel_size=(H // 16, W // 16), stride=1, padding=0)
        )

    def forward(self, spec, y):
        """Return ``(batch, 1)`` raw logits (no sigmoid is applied here).

        ``spec`` is ``(batch, 1, H, W)`` and ``y`` is ``(batch, num_categories)``.
        """
        H, W = self.spec_shape
        label_map = self.label_embedding(y).view(-1, 1, H, W)
        h = torch.cat([spec, label_map], dim=1)
        logits = self.net(h)
        return logits.view(-1, 1)


In [None]:
# ============================
# HIFI-GAN LOADING
# ============================
# Fetch the pretrained "hifigan" entry point of the bshall/hifigan repository
# (main branch) through torch.hub, move it to DEVICE and switch it to eval mode.
# NOTE(review): confirm that the mel settings this checkpoint was trained with
# (sample rate, number of mel bins, hop length, log scaling) match the ones
# used by the dataset class -- TODO verify against the bshall/hifigan README.
print("Loading HiFi-GAN from PyTorch Hub...")
hifigan = torch.hub.load("bshall/hifigan:main", "hifigan", pretrained=True)
hifigan = hifigan.to(DEVICE).eval()
print("HiFi-GAN loaded.")

def mel_to_audio_hifi(log_spec_tensor):
    """Run the global ``hifigan`` vocoder on a log1p-scaled spectrogram.

    Dimension 1 of ``log_spec_tensor`` is squeezed away (a no-op unless it has
    size 1), ``expm1`` inverts the ``log1p`` scaling, and the result is moved
    to ``DEVICE`` and passed to ``hifigan`` with gradients disabled.

    Returns:
        Whatever ``hifigan`` produces for that input, moved to the CPU.
    """
    # presumably the input is (batch, 1, n_mels, frames) -- TODO confirm
    linear_mel = log_spec_tensor.squeeze(1).expm1().to(DEVICE)
    with torch.no_grad():
        waveform = hifigan(linear_mel)
    return waveform.cpu()


In [None]:
# ============================
# DATASET & LOADER
# ============================
# Each sub-directory of the training root is one class; sorting fixes the
# class -> index assignment.
train_categories = sorted(
    child.name
    for child in Path(DATASET_TRAIN_PATH).iterdir()
    if child.is_dir()
)
NUM_CATEGORIES = len(train_categories)

print("Classes:", train_categories)

# Spectrogram size comes from the config constants.
dataset = TrainAudioSpectrogramDataset(
    root_dir=DATASET_TRAIN_PATH,
    categories=train_categories,
    max_frames=MAX_FRAMES,
    n_mels=N_MELS,
)

# Reshuffled every epoch; two worker processes load and transform the audio.
loader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=2,
)


Classes: ['dog_bark', 'drilling', 'engine_idling', 'siren', 'street_music']


In [None]:
# ============================
# MODELS & OPTIMIZERS
# ============================
# Forward the configured spectrogram size to both constructors instead of
# relying on their default spec_shape happening to equal (N_MELS, MAX_FRAMES).
SPEC_SHAPE = (N_MELS, MAX_FRAMES)

generator = CGAN_Generator(LATENT_DIM, NUM_CATEGORIES, spec_shape=SPEC_SHAPE).to(DEVICE)
discriminator = CGAN_Discriminator(NUM_CATEGORIES, spec_shape=SPEC_SHAPE).to(DEVICE)

# Both networks use Adam with the same learning rate and betas.
optG = torch.optim.Adam(generator.parameters(), lr=LR, betas=(0.5, 0.999))
optD = torch.optim.Adam(discriminator.parameters(), lr=LR, betas=(0.5, 0.999))

# The discriminator returns raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()


In [None]:
# ============================
# TRAINING LOOP
# ============================
# One pass over `loader` per epoch: alternate a discriminator step and a
# generator step on every batch, then write SAMPLES_PER_CLASS generated
# spectrogram images and audio clips per class to OUTPUT_DIR.
for epoch in range(1, EPOCHS+1):
    # Back to train mode: the generator is put in eval mode for sampling at the
    # end of every epoch.
    generator.train(); discriminator.train()
    loop = tqdm(loader, desc=f"Epoch {epoch}/{EPOCHS}")

    for real_specs, labels in loop:
        real_specs = real_specs.to(DEVICE)
        labels = labels.to(DEVICE)
        b = real_specs.size(0)  # actual batch size of this batch

        # BCE targets: 1 for "real", 0 for "fake".
        real = torch.ones(b,1, device=DEVICE)
        fake = torch.zeros(b,1, device=DEVICE)

        # --- Train D ---
        optD.zero_grad()

        real_out = discriminator(real_specs, labels)
        loss_real = criterion(real_out, real)

        # Fakes are conditioned on the same labels as the real batch.
        z = torch.randn(b, LATENT_DIM, device=DEVICE)
        fake_specs = generator(z, labels)

        # detach(): the discriminator loss must not backpropagate into the generator.
        fake_out = discriminator(fake_specs.detach(), labels)
        loss_fake = criterion(fake_out, fake)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        optD.step()

        # --- Train G ---
        # The generator is rewarded when the just-updated discriminator labels
        # its fakes as real. fake_specs is not detached here, so gradients reach
        # the generator; the discriminator gradients accumulated by this
        # backward pass are cleared by optD.zero_grad() on the next batch.
        optG.zero_grad()
        out = discriminator(fake_specs, labels)
        loss_G = criterion(out, real)
        loss_G.backward()
        optG.step()

        loop.set_postfix(D=loss_D.item(), G=loss_G.item())

    # ============================
    # GENERATE AND SAVE SAMPLES
    # ============================
    print("\nGenerating samples...")
    generator.eval()

    for cat_idx, cat_name in enumerate(train_categories):
        for i in range(SAMPLES_PER_CLASS):
            # A single sample: fresh noise plus the one-hot vector of this class.
            z = torch.randn(1, LATENT_DIM, device=DEVICE)
            y = F.one_hot(torch.tensor([cat_idx]), NUM_CATEGORIES).float().to(DEVICE)
            with torch.no_grad():
                spec = generator(z, y)

            # save spectrogram
            spec_np = spec.squeeze().cpu().numpy()
            plt.imshow(spec_np, aspect='auto', origin='lower')
            plt.title(f"{cat_name} ep{epoch} s{i}")
            plt.axis('off')
            plt.savefig(f"{OUTPUT_DIR}/specs/{cat_name}_ep{epoch}_s{i}.png",
                        bbox_inches='tight', pad_inches=0)
            # Close the figure so figures do not pile up across the loop.
            plt.close()

            # save audio
            # NOTE(review): the clip is written with SAMPLE_RATE as its sample
            # rate -- confirm the vocoder produces audio at that rate.
            wav = mel_to_audio_hifi(spec)
            sf.write(f"{OUTPUT_DIR}/audio/{cat_name}_ep{epoch}_s{i}.wav",
                     wav.squeeze().numpy(),
                     SAMPLE_RATE)

    print(f"Epoch {epoch} done.\n")
