In [2]:
# Imports,
import os
import torch
import torchaudio
import random
import librosa
import numpy as np
from torch import nn, optim
from torch.utils.data import Dataset, DataLoader
from torchaudio.transforms import Resample
from tqdm import tqdm
import matplotlib.pyplot as plt

In [3]:
# Configuration
processed_dir = "VocalSet_processed"
splits = ['train', 'val', 'test']
sample_rate = 22050
duration = 4.0  # seconds
data = {}  # To store file paths or audio data

# Access and optionally load files
for split in splits:
    split_path = os.path.join(processed_dir, split)
    file_paths = [os.path.join(split_path, f) for f in os.listdir(split_path) if f.endswith('.wav')]
    
    # Optionally: Load audio files
    audio_data = []
    for path in file_paths:
        try:
            y, _ = librosa.load(path, sr=sample_rate)
            audio_data.append((path, y))  # or just y if you don’t need the path
        except Exception as e:
            print(f"Could not load {path}: {e}")
    
    data[split] = audio_data  # Contains (path, audio_array) for each file

# Example usage
print(f"Train files loaded: {len(data['train'])}")
print(f"Validation files loaded: {len(data['val'])}")
print(f"Test files loaded: {len(data['test'])}")


Train files loaded: 2890
Validation files loaded: 361
Test files loaded: 362


In [2]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import soundfile as sf

# CONFIG
DATA_DIR = 'VocalSet_processed'
SAMPLE_RATE = 22050
DURATION = 4.0
AUDIO_LENGTH = int(SAMPLE_RATE * DURATION)
BATCH_SIZE = 16
EPOCHS = 100
LATENT_DIM = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "gan_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ==== Dataset ====
class AudioDataset(Dataset):
    def __init__(self, split_dir):
        self.files = [os.path.join(split_dir, f) for f in os.listdir(split_dir) if f.endswith(".wav")]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        audio, _ = librosa.load(path, sr=SAMPLE_RATE)
        audio = librosa.util.fix_length(audio, size=AUDIO_LENGTH)
        return torch.tensor(audio, dtype=torch.float32)

# ==== Models ====
class Generator(nn.Module):
    def __init__(self, latent_dim, output_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, output_size),
            nn.Tanh()
        )


    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(input_size, 2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, 1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 1)
        )

    def forward(self, x):
        return self.model(x)

# ==== DataLoader ====
train_dataset = AudioDataset(os.path.join(DATA_DIR, "train"))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# ==== Initialize Models ====
generator = Generator(LATENT_DIM, AUDIO_LENGTH).to(DEVICE)
discriminator = Discriminator(AUDIO_LENGTH).to(DEVICE)

# ==== Optimizers and Loss ====
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002)
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002)

# ==== Training Loop ====
for epoch in range(EPOCHS):
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        batch = batch.to(DEVICE)

        real_labels = torch.full((BATCH_SIZE, 1), 0.9).to(DEVICE)  # label smoothing
        fake_labels = torch.zeros(BATCH_SIZE, 1).to(DEVICE)

        # === Train Discriminator ===
        optimizer_D.zero_grad()
        outputs_real = discriminator(batch)
        loss_real = criterion(outputs_real, real_labels)

        z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)
        fake_audio = generator(z)
        outputs_fake = discriminator(fake_audio.detach())
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        outputs = discriminator(fake_audio)
        loss_G = criterion(outputs, real_labels)
        loss_G.backward()
        optimizer_G.step()

    print(f"Epoch [{epoch+1}/{EPOCHS}] Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

    # === Save sample audio ===
    if (epoch + 1) % 10 == 0:
        generator.eval()
        with torch.no_grad():
            z = torch.randn(1, LATENT_DIM).to(DEVICE)
            gen_audio = generator(z).cpu().numpy().flatten()
            sf.write(f"{OUTPUT_DIR}/sample_epoch_{epoch+1}.wav", gen_audio, SAMPLE_RATE)
        generator.train()

# ==== Save models ====
torch.save(generator.state_dict(), os.path.join(OUTPUT_DIR, "generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(OUTPUT_DIR, "discriminator.pth"))
print("Models saved.")


Epoch 1/100: 100%|██████████| 180/180 [17:55<00:00,  5.97s/it]


Epoch [1/100] Loss_D: 2.7504, Loss_G: 82.0034


Epoch 2/100: 100%|██████████| 180/180 [17:07<00:00,  5.71s/it]


Epoch [2/100] Loss_D: 0.4604, Loss_G: 84.0246


Epoch 3/100: 100%|██████████| 180/180 [17:20<00:00,  5.78s/it]


Epoch [3/100] Loss_D: 0.3689, Loss_G: 87.2727


Epoch 4/100: 100%|██████████| 180/180 [17:37<00:00,  5.87s/it]


Epoch [4/100] Loss_D: 0.6819, Loss_G: 37.9613


Epoch 5/100: 100%|██████████| 180/180 [16:53<00:00,  5.63s/it]


Epoch [5/100] Loss_D: 0.4746, Loss_G: 71.3485


Epoch 6/100: 100%|██████████| 180/180 [15:48<00:00,  5.27s/it]


Epoch [6/100] Loss_D: 0.3510, Loss_G: 50.3391


Epoch 7/100: 100%|██████████| 180/180 [15:06<00:00,  5.03s/it]


Epoch [7/100] Loss_D: 0.3363, Loss_G: 33.9157


Epoch 8/100: 100%|██████████| 180/180 [15:02<00:00,  5.02s/it]


Epoch [8/100] Loss_D: 0.3288, Loss_G: 43.1696


Epoch 9/100: 100%|██████████| 180/180 [14:54<00:00,  4.97s/it]


Epoch [9/100] Loss_D: 0.6036, Loss_G: 27.3989


Epoch 10/100: 100%|██████████| 180/180 [14:59<00:00,  5.00s/it]


Epoch [10/100] Loss_D: 0.4364, Loss_G: 24.2004


Epoch 11/100: 100%|██████████| 180/180 [14:50<00:00,  4.95s/it]


Epoch [11/100] Loss_D: 0.3326, Loss_G: 38.5049


Epoch 12/100: 100%|██████████| 180/180 [16:20<00:00,  5.45s/it]


Epoch [12/100] Loss_D: 0.3397, Loss_G: 36.6177


Epoch 13/100: 100%|██████████| 180/180 [17:09<00:00,  5.72s/it]


Epoch [13/100] Loss_D: 0.4921, Loss_G: 26.5129


Epoch 14/100: 100%|██████████| 180/180 [16:59<00:00,  5.66s/it]


Epoch [14/100] Loss_D: 0.3331, Loss_G: 50.8038


Epoch 15/100: 100%|██████████| 180/180 [16:52<00:00,  5.63s/it]


Epoch [15/100] Loss_D: 0.3377, Loss_G: 25.5413


Epoch 16/100: 100%|██████████| 180/180 [16:44<00:00,  5.58s/it]


Epoch [16/100] Loss_D: 0.3448, Loss_G: 30.4880


Epoch 17/100: 100%|██████████| 180/180 [16:46<00:00,  5.59s/it]


Epoch [17/100] Loss_D: 0.3399, Loss_G: 31.6621


Epoch 18/100: 100%|██████████| 180/180 [17:13<00:00,  5.74s/it]


Epoch [18/100] Loss_D: 0.3312, Loss_G: 55.7919


Epoch 19/100: 100%|██████████| 180/180 [15:20<00:00,  5.11s/it]


Epoch [19/100] Loss_D: 0.3308, Loss_G: 35.0517


Epoch 20/100: 100%|██████████| 180/180 [15:34<00:00,  5.19s/it]


Epoch [20/100] Loss_D: 0.3348, Loss_G: 31.7132


Epoch 21/100: 100%|██████████| 180/180 [15:36<00:00,  5.20s/it]


Epoch [21/100] Loss_D: 0.3759, Loss_G: 38.1981


Epoch 22/100: 100%|██████████| 180/180 [15:54<00:00,  5.30s/it]


Epoch [22/100] Loss_D: 0.3486, Loss_G: 36.9954


Epoch 23/100: 100%|██████████| 180/180 [15:49<00:00,  5.28s/it]


Epoch [23/100] Loss_D: 0.3389, Loss_G: 20.1535


Epoch 24/100: 100%|██████████| 180/180 [15:34<00:00,  5.19s/it]


Epoch [24/100] Loss_D: 0.4897, Loss_G: 52.1396


Epoch 25/100: 100%|██████████| 180/180 [15:29<00:00,  5.17s/it]


Epoch [25/100] Loss_D: 0.3370, Loss_G: 34.2022


Epoch 26/100: 100%|██████████| 180/180 [15:28<00:00,  5.16s/it]


Epoch [26/100] Loss_D: 0.3623, Loss_G: 33.4491


Epoch 27/100: 100%|██████████| 180/180 [15:28<00:00,  5.16s/it]


Epoch [27/100] Loss_D: 0.3262, Loss_G: 28.0464


Epoch 28/100: 100%|██████████| 180/180 [15:28<00:00,  5.16s/it]


Epoch [28/100] Loss_D: 0.3279, Loss_G: 33.7441


Epoch 29/100: 100%|██████████| 180/180 [15:26<00:00,  5.15s/it]


Epoch [29/100] Loss_D: 0.3298, Loss_G: 24.1679


Epoch 30/100: 100%|██████████| 180/180 [15:21<00:00,  5.12s/it]


Epoch [30/100] Loss_D: 0.3266, Loss_G: 22.3231


Epoch 31/100: 100%|██████████| 180/180 [15:13<00:00,  5.08s/it]


Epoch [31/100] Loss_D: 0.3271, Loss_G: 30.2627


Epoch 32/100: 100%|██████████| 180/180 [15:09<00:00,  5.05s/it]


Epoch [32/100] Loss_D: 0.3271, Loss_G: 45.3598


Epoch 33/100: 100%|██████████| 180/180 [15:16<00:00,  5.09s/it]


Epoch [33/100] Loss_D: 0.3325, Loss_G: 18.4624


Epoch 34/100: 100%|██████████| 180/180 [16:44<00:00,  5.58s/it]


Epoch [34/100] Loss_D: 0.3363, Loss_G: 25.4705


Epoch 35/100: 100%|██████████| 180/180 [16:45<00:00,  5.59s/it]


Epoch [35/100] Loss_D: 0.3278, Loss_G: 19.0532


Epoch 36/100: 100%|██████████| 180/180 [17:38<00:00,  5.88s/it]


Epoch [36/100] Loss_D: 0.3270, Loss_G: 17.0098


Epoch 37/100: 100%|██████████| 180/180 [18:11<00:00,  6.06s/it]


Epoch [37/100] Loss_D: 0.3280, Loss_G: 24.8119


Epoch 38/100: 100%|██████████| 180/180 [18:54<00:00,  6.30s/it]


Epoch [38/100] Loss_D: 0.3279, Loss_G: 30.3532


Epoch 39/100: 100%|██████████| 180/180 [17:52<00:00,  5.96s/it]


Epoch [39/100] Loss_D: 0.3299, Loss_G: 36.0546


Epoch 40/100: 100%|██████████| 180/180 [15:48<00:00,  5.27s/it]


Epoch [40/100] Loss_D: 0.3254, Loss_G: 28.9557


Epoch 41/100: 100%|██████████| 180/180 [17:02<00:00,  5.68s/it]


Epoch [41/100] Loss_D: 0.3266, Loss_G: 36.5267


Epoch 42/100: 100%|██████████| 180/180 [19:46<00:00,  6.59s/it]


Epoch [42/100] Loss_D: 0.3273, Loss_G: 28.1142


Epoch 43/100: 100%|██████████| 180/180 [22:20<00:00,  7.45s/it]


Epoch [43/100] Loss_D: 0.3340, Loss_G: 20.0828


Epoch 44/100: 100%|██████████| 180/180 [18:15<00:00,  6.08s/it]


Epoch [44/100] Loss_D: 0.3260, Loss_G: 33.2990


Epoch 45/100:  26%|██▌       | 47/180 [05:05<14:24,  6.50s/it]


KeyboardInterrupt: 

In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import torchaudio
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import soundfile as sf
from torch.cuda.amp import autocast, GradScaler

# === CONFIG ===
DATA_DIR = 'VocalSet_processed'
SAMPLE_RATE = 16000  # Reduced sample rate
DURATION = 2.0  # Reduced duration to 2 seconds
AUDIO_LENGTH = int(SAMPLE_RATE * DURATION)
BATCH_SIZE = 32
EPOCHS = 50
LATENT_DIM = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "gan_outputs_improved"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# === Dataset ===
class AudioDataset(Dataset):
    def __init__(self, split_dir):
        self.files = [os.path.join(split_dir, f) for f in os.listdir(split_dir) if f.endswith(".wav")]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        waveform, sr = torchaudio.load(path)
        waveform = torch.mean(waveform, dim=0)  # mono
        if sr != SAMPLE_RATE:
            waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=SAMPLE_RATE)
        waveform = torch.nn.functional.pad(waveform, (0, max(0, AUDIO_LENGTH - waveform.size(0))))
        waveform = waveform[:AUDIO_LENGTH]
        return waveform

# === Generator ===
class Generator(nn.Module):
    def __init__(self, latent_dim):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 256 * 64),
            nn.ReLU(True),
            nn.Unflatten(1, (256, 64)),
            nn.ConvTranspose1d(256, 128, 25, stride=4, padding=11, output_padding=1),
            nn.BatchNorm1d(128),
            nn.ReLU(True),
            nn.ConvTranspose1d(128, 64, 25, stride=4, padding=11, output_padding=1),
            nn.BatchNorm1d(64),
            nn.ReLU(True),
            nn.ConvTranspose1d(64, 1, 25, stride=2, padding=12, output_padding=1),
            nn.Tanh()
        )

    def forward(self, z):
        x = self.model(z)
        x = x.squeeze(1)
        if x.size(1) < AUDIO_LENGTH:
            x = torch.nn.functional.pad(x, (0, AUDIO_LENGTH - x.size(1)))
        return x[:, :AUDIO_LENGTH]

# === Discriminator ===
class Discriminator(nn.Module):
    def __init__(self):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv1d(1, 32, 25, stride=4, padding=11),  # Reduced number of channels
            nn.LeakyReLU(0.2),
            nn.Conv1d(32, 64, 25, stride=4, padding=11),
            nn.LeakyReLU(0.2),
            nn.Conv1d(64, 128, 25, stride=4, padding=11),
            nn.LeakyReLU(0.2),
        )
        
        # Compute output size dynamically
        with torch.no_grad():
            dummy_input = torch.randn(1, 1, AUDIO_LENGTH)
            out = self.features(dummy_input)
            self.flattened_size = out.view(1, -1).size(1)

        self.classifier = nn.Linear(self.flattened_size, 1)

    def forward(self, x):
        x = x.unsqueeze(1)  # [B, 1, AUDIO_LENGTH]
        x = self.features(x)
        x = x.view(x.size(0), -1)
        return self.classifier(x)

# === Gradient Penalty ===
def gradient_penalty(D, real, fake):
    batch_size = real.size(0)
    alpha = torch.rand(batch_size, 1).to(DEVICE)
    alpha = alpha.expand_as(real)

    interpolates = alpha * real + (1 - alpha) * fake
    interpolates = interpolates.requires_grad_(True)

    d_interpolates = D(interpolates)
    ones = torch.ones_like(d_interpolates).to(DEVICE)

    gradients = torch.autograd.grad(
        outputs=d_interpolates,
        inputs=interpolates,
        grad_outputs=ones,
        create_graph=True,
        retain_graph=True,
        only_inputs=True
    )[0]

    gradients = gradients.view(batch_size, -1)
    penalty = ((gradients.norm(2, dim=1) - 1) ** 2).mean()
    return penalty

# === DataLoader ===
train_dataset = AudioDataset(os.path.join(DATA_DIR, "train"))
train_loader = DataLoader(
    train_dataset, 
    batch_size=BATCH_SIZE, 
    shuffle=True, 
    drop_last=True, 
    num_workers=8,  # Increased number of workers
    pin_memory=True  # Enable memory pinning for faster transfer to GPU
)

# === Model Initialization ===
generator = Generator(LATENT_DIM).to(DEVICE)
discriminator = Discriminator().to(DEVICE)

# === Optimizers ===
optimizer_G = optim.Adam(generator.parameters(), lr=1e-4, betas=(0.5, 0.9))
optimizer_D = optim.Adam(discriminator.parameters(), lr=1e-4, betas=(0.5, 0.9))

# === Mixed Precision Training ===
scaler = GradScaler()

# === Training Loop ===
lambda_gp = 10
n_critic = 1  # Reduced from 2 to 1 to speed up training

for epoch in range(EPOCHS):
    for i, batch in enumerate(tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}")):
        real_audio = batch.to(DEVICE)

        # Train Discriminator
        z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)
        fake_audio = generator(z).detach()

        for _ in range(n_critic):
            with autocast():
                d_real = discriminator(real_audio)
                d_fake = discriminator(fake_audio)
                gp = gradient_penalty(discriminator, real_audio, fake_audio)
                loss_D = -(torch.mean(d_real) - torch.mean(d_fake)) + lambda_gp * gp

            scaler.scale(loss_D).backward()
            scaler.step(optimizer_D)
            scaler.update()

        # Train Generator
        z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)
        fake_audio = generator(z)

        with autocast():
            loss_G = -torch.mean(discriminator(fake_audio))

        scaler.scale(loss_G).backward()
        scaler.step(optimizer_G)
        scaler.update()

    print(f"Epoch {epoch+1}/{EPOCHS}, Loss_D: {loss_D.item():.4f}, Loss_G: {loss_G.item():.4f}")

    if (epoch + 1) % 10 == 0:
        generator.eval()
        with torch.no_grad():
            z = torch.randn(1, LATENT_DIM).to(DEVICE)
            gen_audio = generator(z).cpu().numpy().flatten()
            sf.write(f"{OUTPUT_DIR}/sample_epoch_{epoch+1}.wav", gen_audio, SAMPLE_RATE)
        generator.train()

# === Save Models ===
torch.save(generator.state_dict(), os.path.join(OUTPUT_DIR, "generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(OUTPUT_DIR, "discriminator.pth"))
print("Models saved.")


  scaler = GradScaler()


In [None]:
import os
import torch
import torch.nn as nn
import torch.optim as optim
import librosa
import numpy as np
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
import soundfile as sf
import random

# CONFIG
DATA_DIR = 'VocalSet_processed'
SAMPLE_RATE = 22050
DURATION = 4.0
AUDIO_LENGTH = int(SAMPLE_RATE * DURATION)
BATCH_SIZE = 16
EPOCHS = 100
LATENT_DIM = 100
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
OUTPUT_DIR = "gan_outputs_improved"
os.makedirs(OUTPUT_DIR, exist_ok=True)

# ==== Dataset ====
class AudioDataset(Dataset):
    def __init__(self, split_dir):
        self.files = [os.path.join(split_dir, f) for f in os.listdir(split_dir) if f.endswith(".wav")]

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        path = self.files[idx]
        audio, _ = librosa.load(path, sr=SAMPLE_RATE)
        audio = librosa.util.fix_length(audio, size=AUDIO_LENGTH)
        return torch.tensor(audio, dtype=torch.float32)

# ==== Models ====
class Generator(nn.Module):
    def __init__(self, latent_dim, output_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(latent_dim, 1024),
            nn.BatchNorm1d(1024),
            nn.LeakyReLU(0.2),
            nn.Linear(1024, 2048),
            nn.BatchNorm1d(2048),
            nn.LeakyReLU(0.2),
            nn.Linear(2048, output_size),
            nn.Tanh()
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.model:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, z):
        return self.model(z)

class Discriminator(nn.Module):
    def __init__(self, input_size):
        super().__init__()
        self.model = nn.Sequential(
            nn.utils.spectral_norm(nn.Linear(input_size, 2048)),
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Linear(2048, 1024)),
            nn.LeakyReLU(0.2),
            nn.utils.spectral_norm(nn.Linear(1024, 1))
        )
        self._init_weights()

    def _init_weights(self):
        for m in self.model:
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)

    def forward(self, x):
        return self.model(x)

# ==== DataLoader ====
train_dataset = AudioDataset(os.path.join(DATA_DIR, "train"))
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)

# ==== Initialize Models ====
generator = Generator(LATENT_DIM, AUDIO_LENGTH).to(DEVICE)
discriminator = Discriminator(AUDIO_LENGTH).to(DEVICE)

# ==== Optimizers and Loss ====
criterion = nn.BCEWithLogitsLoss()
optimizer_G = optim.Adam(generator.parameters(), lr=0.0002, betas=(0.5, 0.999))
optimizer_D = optim.Adam(discriminator.parameters(), lr=0.0002, betas=(0.5, 0.999))

# Learning Rate Schedulers
scheduler_G = optim.lr_scheduler.StepLR(optimizer_G, step_size=30, gamma=0.5)
scheduler_D = optim.lr_scheduler.StepLR(optimizer_D, step_size=30, gamma=0.5)

# ==== Training Loop ====
for epoch in range(EPOCHS):
    total_loss_D = 0.0
    total_loss_G = 0.0

    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}/{EPOCHS}"):
        batch = batch.to(DEVICE)

        # Label smoothing and flipping
        real_labels = torch.full((BATCH_SIZE, 1), 0.9).to(DEVICE)
        fake_labels = torch.zeros((BATCH_SIZE, 1)).to(DEVICE)

        if random.random() < 0.05:  # Randomly flip labels (5% chance)
            real_labels, fake_labels = fake_labels, real_labels

        # === Train Discriminator ===
        optimizer_D.zero_grad()
        outputs_real = discriminator(batch)
        loss_real = criterion(outputs_real, real_labels)

        z = torch.randn(BATCH_SIZE, LATENT_DIM).to(DEVICE)
        fake_audio = generator(z)
        outputs_fake = discriminator(fake_audio.detach())
        loss_fake = criterion(outputs_fake, fake_labels)

        loss_D = loss_real + loss_fake
        loss_D.backward()
        torch.nn.utils.clip_grad_norm_(discriminator.parameters(), 1.0)
        optimizer_D.step()

        # === Train Generator ===
        optimizer_G.zero_grad()
        outputs = discriminator(fake_audio)
        loss_G = criterion(outputs, real_labels)
        loss_G.backward()
        torch.nn.utils.clip_grad_norm_(generator.parameters(), 1.0)
        optimizer_G.step()

        total_loss_D += loss_D.item()
        total_loss_G += loss_G.item()

    scheduler_G.step()
    scheduler_D.step()

    avg_loss_D = total_loss_D / len(train_loader)
    avg_loss_G = total_loss_G / len(train_loader)

    print(f"Epoch [{epoch+1}/{EPOCHS}] Avg Loss_D: {avg_loss_D:.4f}, Avg Loss_G: {avg_loss_G:.4f}")

    # === Save sample audio ===
    if (epoch + 1) % 10 == 0:
        generator.eval()
        with torch.no_grad():
            z = torch.randn(1, LATENT_DIM).to(DEVICE)
            gen_audio = generator(z).cpu().numpy().flatten()
            sf.write(f"{OUTPUT_DIR}/sample_epoch_{epoch+1}.wav", gen_audio, SAMPLE_RATE)
        generator.train()

# ==== Save models ====
torch.save(generator.state_dict(), os.path.join(OUTPUT_DIR, "generator.pth"))
torch.save(discriminator.state_dict(), os.path.join(OUTPUT_DIR, "discriminator.pth"))
print("Models saved.")


Epoch 1/100: 100%|██████████| 180/180 [49:19<00:00, 16.44s/it]


Epoch [1/100] Avg Loss_D: 0.9690, Avg Loss_G: 6.4710


Epoch 2/100: 100%|██████████| 180/180 [38:38<00:00, 12.88s/it]


Epoch [2/100] Avg Loss_D: 0.8485, Avg Loss_G: 6.8569


Epoch 3/100: 100%|██████████| 180/180 [34:34<00:00, 11.52s/it]


Epoch [3/100] Avg Loss_D: 0.9552, Avg Loss_G: 5.0825


Epoch 4/100: 100%|██████████| 180/180 [29:56<00:00,  9.98s/it]


Epoch [4/100] Avg Loss_D: 0.6479, Avg Loss_G: 5.7771


Epoch 5/100: 100%|██████████| 180/180 [29:25<00:00,  9.81s/it]


Epoch [5/100] Avg Loss_D: 0.7573, Avg Loss_G: 5.2758


Epoch 6/100: 100%|██████████| 180/180 [29:23<00:00,  9.80s/it]


Epoch [6/100] Avg Loss_D: 0.7926, Avg Loss_G: 4.9114


Epoch 7/100: 100%|██████████| 180/180 [33:51<00:00, 11.28s/it]


Epoch [7/100] Avg Loss_D: 0.8733, Avg Loss_G: 4.3537


Epoch 8/100: 100%|██████████| 180/180 [34:22<00:00, 11.46s/it]


Epoch [8/100] Avg Loss_D: 0.9979, Avg Loss_G: 3.5559


Epoch 9/100: 100%|██████████| 180/180 [30:43<00:00, 10.24s/it]


Epoch [9/100] Avg Loss_D: 0.9340, Avg Loss_G: 3.7159


Epoch 10/100: 100%|██████████| 180/180 [29:45<00:00,  9.92s/it]


Epoch [10/100] Avg Loss_D: 1.0730, Avg Loss_G: 2.8634


Epoch 11/100: 100%|██████████| 180/180 [29:26<00:00,  9.81s/it]


Epoch [11/100] Avg Loss_D: 1.0297, Avg Loss_G: 2.1685


Epoch 12/100: 100%|██████████| 180/180 [29:29<00:00,  9.83s/it]


Epoch [12/100] Avg Loss_D: 1.0233, Avg Loss_G: 1.4096


Epoch 13/100: 100%|██████████| 180/180 [32:34<00:00, 10.86s/it]


Epoch [13/100] Avg Loss_D: 1.0192, Avg Loss_G: 1.2347


Epoch 14/100: 100%|██████████| 180/180 [34:15<00:00, 11.42s/it]


Epoch [14/100] Avg Loss_D: 1.0012, Avg Loss_G: 1.2079


Epoch 15/100: 100%|██████████| 180/180 [34:12<00:00, 11.40s/it]


Epoch [15/100] Avg Loss_D: 1.0163, Avg Loss_G: 1.1785


Epoch 16/100: 100%|██████████| 180/180 [35:42<00:00, 11.90s/it]


Epoch [16/100] Avg Loss_D: 1.0275, Avg Loss_G: 1.1771


Epoch 17/100: 100%|██████████| 180/180 [34:58<00:00, 11.66s/it]


Epoch [17/100] Avg Loss_D: 0.9838, Avg Loss_G: 1.1915


Epoch 18/100: 100%|██████████| 180/180 [29:21<00:00,  9.79s/it]


Epoch [18/100] Avg Loss_D: 1.0122, Avg Loss_G: 1.1947


Epoch 19/100: 100%|██████████| 180/180 [31:30<00:00, 10.50s/it]


Epoch [19/100] Avg Loss_D: 1.0219, Avg Loss_G: 1.1943


Epoch 20/100: 100%|██████████| 180/180 [35:34<00:00, 11.86s/it]


Epoch [20/100] Avg Loss_D: 1.0030, Avg Loss_G: 1.2005


Epoch 21/100: 100%|██████████| 180/180 [29:48<00:00,  9.93s/it]


Epoch [21/100] Avg Loss_D: 0.9935, Avg Loss_G: 1.2110


Epoch 22/100: 100%|██████████| 180/180 [28:52<00:00,  9.62s/it]


Epoch [22/100] Avg Loss_D: 1.0246, Avg Loss_G: 1.2165


Epoch 23/100: 100%|██████████| 180/180 [28:43<00:00,  9.58s/it]


Epoch [23/100] Avg Loss_D: 1.0212, Avg Loss_G: 1.2163


Epoch 24/100: 100%|██████████| 180/180 [28:57<00:00,  9.65s/it]


Epoch [24/100] Avg Loss_D: 1.0338, Avg Loss_G: 1.2221


Epoch 25/100: 100%|██████████| 180/180 [28:51<00:00,  9.62s/it]


Epoch [25/100] Avg Loss_D: 0.9877, Avg Loss_G: 1.2237


Epoch 26/100: 100%|██████████| 180/180 [29:06<00:00,  9.70s/it]


Epoch [26/100] Avg Loss_D: 1.0109, Avg Loss_G: 1.2283


Epoch 27/100: 100%|██████████| 180/180 [28:56<00:00,  9.64s/it]


Epoch [27/100] Avg Loss_D: 0.9908, Avg Loss_G: 1.2423


Epoch 28/100: 100%|██████████| 180/180 [28:59<00:00,  9.66s/it]


Epoch [28/100] Avg Loss_D: 0.9947, Avg Loss_G: 1.2442


Epoch 29/100: 100%|██████████| 180/180 [28:49<00:00,  9.61s/it]


Epoch [29/100] Avg Loss_D: 1.0256, Avg Loss_G: 1.2272


Epoch 30/100: 100%|██████████| 180/180 [28:58<00:00,  9.66s/it]


Epoch [30/100] Avg Loss_D: 0.9784, Avg Loss_G: 1.2391


Epoch 31/100: 100%|██████████| 180/180 [28:54<00:00,  9.64s/it]


Epoch [31/100] Avg Loss_D: 1.1015, Avg Loss_G: 1.0193


Epoch 32/100: 100%|██████████| 180/180 [28:52<00:00,  9.62s/it]


Epoch [32/100] Avg Loss_D: 1.1094, Avg Loss_G: 1.0201


Epoch 33/100: 100%|██████████| 180/180 [28:49<00:00,  9.61s/it]


Epoch [33/100] Avg Loss_D: 1.1055, Avg Loss_G: 1.0209


Epoch 34/100: 100%|██████████| 180/180 [29:06<00:00,  9.70s/it]


Epoch [34/100] Avg Loss_D: 1.1135, Avg Loss_G: 1.0173


Epoch 35/100: 100%|██████████| 180/180 [28:58<00:00,  9.66s/it]


Epoch [35/100] Avg Loss_D: 1.1215, Avg Loss_G: 1.0024


Epoch 36/100: 100%|██████████| 180/180 [29:02<00:00,  9.68s/it]


Epoch [36/100] Avg Loss_D: 1.1263, Avg Loss_G: 1.0020


Epoch 37/100: 100%|██████████| 180/180 [29:40<00:00,  9.89s/it]


Epoch [37/100] Avg Loss_D: 1.1187, Avg Loss_G: 1.0090


Epoch 38/100: 100%|██████████| 180/180 [34:01<00:00, 11.34s/it]


Epoch [38/100] Avg Loss_D: 1.1226, Avg Loss_G: 0.9964


Epoch 39/100: 100%|██████████| 180/180 [34:00<00:00, 11.34s/it]


Epoch [39/100] Avg Loss_D: 1.1147, Avg Loss_G: 1.0080


Epoch 40/100: 100%|██████████| 180/180 [36:02<00:00, 12.02s/it]


Epoch [40/100] Avg Loss_D: 1.1307, Avg Loss_G: 0.9879


Epoch 41/100: 100%|██████████| 180/180 [36:20<00:00, 12.11s/it]


Epoch [41/100] Avg Loss_D: 1.1185, Avg Loss_G: 0.9937


Epoch 42/100: 100%|██████████| 180/180 [34:29<00:00, 11.50s/it]


Epoch [42/100] Avg Loss_D: 1.1357, Avg Loss_G: 0.9853


Epoch 43/100: 100%|██████████| 180/180 [35:47<00:00, 11.93s/it]


Epoch [43/100] Avg Loss_D: 1.1377, Avg Loss_G: 0.9835


Epoch 44/100: 100%|██████████| 180/180 [1:03:02<00:00, 21.02s/it]


Epoch [44/100] Avg Loss_D: 1.1266, Avg Loss_G: 0.9822


Epoch 45/100:  68%|██████▊   | 122/180 [37:51<11:52, 12.29s/it] 