<a href="https://colab.research.google.com/github/Champei/mine/blob/main/task2/jupyter.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:

# 1. INSTALL
# Installs the notebook's Python dependencies via the IPython `!` shell escape.
# NOTE(review): torchvision and transformers are not imported by any visible cell; confirm they are still needed.
!pip install torch torchaudio torchvision transformers tqdm matplotlib


In [None]:

# 2. MOUNT GOOGLE DRIVE & VERIFY DATA

import os
from google.colab import drive

# Expose Google Drive under /content/drive so the audio folders are reachable.
drive.mount('/content/drive')

# Dataset root: each sub-folder holds the clips of one category.
BASE_PATH = '/content/drive/MyDrive/Audios'

entries = os.listdir(BASE_PATH)
print("Subfolders (categories):", entries)


In [None]:

# 3. IMPORTS & DATASET CLASS

import torch, torch.nn as nn, torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio, random, matplotlib.pyplot as plt
from tqdm import tqdm
from IPython.display import Audio, display

class TrainAudioSpectrogramDataset(Dataset):
    """Dataset of log-mel spectrograms built from per-category folders of .wav files.

    Each item is ``(logmel, y)``:
      * ``logmel`` -- ``log1p`` of a torchaudio ``MelSpectrogram`` (n_fft=1024,
        hop_length=256, n_mels=128) of the mono-mixed clip, zero-padded or
        truncated along the last (time) axis to ``max_frames`` columns.
      * ``y`` -- float one-hot vector over ``categories``.

    Args:
        root_dir: Directory that contains one sub-directory per category.
        categories: Sub-directory names; their order defines the class indices.
        max_frames: Number of time frames every spectrogram is padded/cropped to.
            CGAN_Discriminator reshapes its label plane to 128 x 512, so keep
            the default of 512 when training that model.
        fraction: Share (0..1) of each category's .wav files to keep; a random
            subset is drawn with ``random.sample``.

    Raises:
        ValueError: if ``fraction`` is outside [0, 1].
    """

    def __init__(self, root_dir, categories, max_frames=512, fraction=1.0):
        if not 0.0 <= fraction <= 1.0:
            raise ValueError(f"fraction must be in [0, 1], got {fraction}")
        self.root_dir, self.categories, self.max_frames = root_dir, categories, max_frames
        self.file_list = []
        self.class_to_idx = {cat: i for i, cat in enumerate(categories)}
        # MelSpectrogram modules keyed by sample rate, so the mel filter bank
        # is built once per rate instead of once per item.
        self._mel_transforms = {}

        for cat in categories:
            path = os.path.join(root_dir, cat)
            files = [os.path.join(path, f) for f in os.listdir(path) if f.endswith('.wav')]
            n = int(len(files) * fraction)
            for f in random.sample(files, n):
                self.file_list.append((f, self.class_to_idx[cat]))

    def __len__(self):
        return len(self.file_list)

    def _mel_transform(self, sr):
        """Return the (cached) MelSpectrogram transform for sample rate ``sr``."""
        transform = self._mel_transforms.get(sr)
        if transform is None:
            transform = torchaudio.transforms.MelSpectrogram(sr, n_fft=1024, hop_length=256, n_mels=128)
            self._mel_transforms[sr] = transform
        return transform

    @staticmethod
    def _fit_frames(spec, max_frames):
        """Zero-pad or truncate the last (time) axis of ``spec`` to exactly ``max_frames``."""
        n_frames = spec.shape[-1]
        if n_frames < max_frames:
            return F.pad(spec, (0, max_frames - n_frames))
        return spec[..., :max_frames]

    def __getitem__(self, idx):
        path, label = self.file_list[idx]
        wav, sr = torchaudio.load(path)
        if wav.size(0) > 1:
            wav = wav.mean(0, keepdim=True)  # down-mix multi-channel audio to mono
        logmel = torch.log1p(self._mel_transform(sr)(wav))
        # Honour the configured length (previously hard-coded to 512).
        logmel = self._fit_frames(logmel, self.max_frames)
        y = F.one_hot(torch.tensor(label), num_classes=len(self.categories)).float()
        return logmel, y


In [None]:

# 4. MODELS

class CGAN_Generator(nn.Module):
    """Conditional generator mapping (noise z, one-hot label y) to a 1 x 128 x 512 map.

    ``[z, y]`` is projected by a linear layer to a 256 x 8 x 32 feature map;
    four ConvTranspose2d(kernel 4, stride 2, padding 1) layers each double the
    height and width (8 x 32 -> 128 x 512). The final ReLU keeps every output
    value non-negative.
    """

    def __init__(self, latent_dim, num_classes):
        super().__init__()
        self.latent_dim = latent_dim
        self.num_classes = num_classes
        self.num_categories = num_classes  # alias; train_gan reads this name
        self.fc = nn.Linear(latent_dim + num_classes, 256 * 8 * 32)

        layers = []
        for c_in, c_out in ((256, 128), (128, 64), (64, 32)):
            layers += [nn.ConvTranspose2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.ReLU()]
        layers += [nn.ConvTranspose2d(32, 1, 4, 2, 1), nn.ReLU()]
        self.net = nn.Sequential(*layers)

    def forward(self, z, y):
        features = self.fc(torch.cat((z, y), dim=1))
        return self.net(features.view(-1, 256, 8, 32))

class CGAN_Discriminator(nn.Module):
    """Conditional discriminator producing one raw logit per (spectrogram, label) pair.

    The one-hot label is projected by a linear layer to a 128 x 512 plane and
    stacked with the input as a second channel, so ``x`` must be
    (batch, 1, 128, 512). Four stride-2 convolutions reduce the map to 8 x 32,
    and a final (8, 32) convolution collapses it to a single unnormalised score
    (no sigmoid is applied here).
    """

    def __init__(self, num_classes):
        super().__init__()
        self.label_emb = nn.Linear(num_classes, 128 * 512)
        self.net = nn.Sequential(
            nn.Conv2d(2, 32, 4, 2, 1),
            nn.LeakyReLU(0.2),
            *self._down_block(32, 64),
            *self._down_block(64, 128),
            *self._down_block(128, 256),
            nn.Conv2d(256, 1, (8, 32), 1, 0),
        )

    @staticmethod
    def _down_block(c_in, c_out):
        """Conv(4, stride 2, pad 1) -> BatchNorm -> LeakyReLU(0.2); halves height and width."""
        return (nn.Conv2d(c_in, c_out, 4, 2, 1), nn.BatchNorm2d(c_out), nn.LeakyReLU(0.2))

    def forward(self, x, y):
        label_plane = self.label_emb(y).view(-1, 1, 128, 512)
        stacked = torch.cat((x, label_plane), dim=1)
        return self.net(stacked).view(-1, 1)


In [None]:
# 5. UTILITIES FOR AUDIO GENERATION & PLAYBACK

def generate_audio_gan(generator, category_idx, num_samples, device, sr=22050):
    """Generate waveforms for one category with a trained conditional generator.

    Draws ``num_samples`` noise vectors and generates log1p-mel spectrograms.
    It undoes ``log1p`` with ``expm1`` and maps mel bins back to a
    linear-frequency spectrogram with ``InverseMelScale``. It then reconstructs
    phase with Griffin-Lim. The transform settings (n_fft=1024 -> 513 bins,
    hop_length=256, 128 mels) mirror TrainAudioSpectrogramDataset.

    Args:
        generator: CGAN_Generator; its ``latent_dim`` and ``num_classes`` are read.
        category_idx: Class index to condition every sample on.
        num_samples: Number of clips to generate.
        device: Device the generator lives on.
        sr: Sample rate used for the inverse mel filter bank.
            NOTE(review): the dataset builds mel filters at each file's own
            rate; 22050 matches the ffmpeg conversion (``-ar 22050``) but not
            necessarily .wav files that were uploaded directly.

    Returns:
        CPU tensor with one waveform per row, presumably (num_samples, time).
    """
    was_training = generator.training
    generator.eval()
    try:
        y = F.one_hot(torch.tensor([category_idx]), num_classes=generator.num_classes).float().to(device)
        # The generator concatenates [z, y] along dim 1, so y needs one row per
        # noise vector; a single (1, C) row only worked for num_samples == 1.
        y = y.expand(num_samples, -1)
        z = torch.randn(num_samples, generator.latent_dim, device=device)
        with torch.no_grad():
            logmel = generator(z, y)
    finally:
        # Do not leave a generator that is still being trained stuck in eval mode.
        generator.train(was_training)

    mel = torch.expm1(logmel).squeeze(1)  # (N, 1, 128, T) -> (N, 128, T), undo log1p
    invmel = torchaudio.transforms.InverseMelScale(n_stft=513, n_mels=128, sample_rate=sr).to(device)
    spec = invmel(mel)
    griffin = torchaudio.transforms.GriffinLim(1024, hop_length=256, n_iter=32).to(device)
    return griffin(spec).cpu()

def play_and_save(wav, sr, name):
    """Save a waveform to ``name`` with torchaudio and show an inline audio player.

    ``torchaudio.save`` needs a 2-D (channels, time) tensor, so the input is
    normalised first:
      * (time,)              -> (1, time)
      * (channels, time)     -> unchanged
      * (1, channels, time)  -> (channels, time)

    Args:
        wav: Waveform tensor in one of the shapes above (any device).
        sr: Sample rate in Hz.
        name: Output file path.

    Raises:
        ValueError: if ``wav`` has any other shape.
    """
    audio = wav.detach().cpu()  # .numpy() below requires a CPU tensor
    if audio.dim() == 3 and audio.size(0) == 1:
        audio = audio.squeeze(0)
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    if audio.dim() != 2:
        raise ValueError(
            f"expected a (time,), (channels, time) or (1, channels, time) waveform, got shape {tuple(wav.shape)}"
        )
    torchaudio.save(name, audio, sr)
    print("Saved:", name)
    display(Audio(audio.numpy().squeeze(), rate=sr))


In [None]:
# 6. TRAINING FUNCTION

def _save_epoch_samples(generator, categories, device, epoch, out_dir="gan_spectrogram_plots"):
    """Generate one spectrogram per category and save it as ``{out_dir}/{category}_ep{epoch}.png``."""
    generator.eval()
    try:
        for cat_idx, cat_name in enumerate(categories):
            y_cond = F.one_hot(torch.tensor([cat_idx]), num_classes=generator.num_classes).float().to(device)
            # Allocate the noise directly on the target device (no CPU round-trip).
            z_sample = torch.randn(1, generator.latent_dim, device=device)

            with torch.no_grad():
                spec_log = generator(z_sample, y_cond).squeeze().cpu().numpy()

            fig = plt.figure(figsize=(6, 4))
            plt.imshow(spec_log, aspect='auto', origin='lower', cmap='viridis')
            plt.title(f"{cat_name} (Epoch {epoch})")
            plt.axis('off')
            fig.savefig(os.path.join(out_dir, f"{cat_name}_ep{epoch}.png"))
            plt.close(fig)
    finally:
        # Return the generator to training mode even if plotting fails.
        generator.train()


def train_gan(generator, discriminator, dataloader, device, categories, epochs, lr, latent_dim):
    """Train a conditional GAN on batches of (spectrogram, one-hot label).

    Both networks use Adam(lr, betas=(0.5, 0.999)) and BCEWithLogitsLoss, so
    the discriminator is expected to output raw logits. Each batch does one
    discriminator step (real -> 1, fake -> 0) followed by one generator step
    (fakes scored against the "real" target). After every epoch one sample
    spectrogram per category is written to ``gan_spectrogram_plots/``.

    Args:
        generator: Conditional generator called as ``generator(z, labels)``.
        discriminator: Conditional discriminator called as ``discriminator(x, labels)``.
        dataloader: Yields ``(real_specs, labels)`` batches.
        device: Device the models live on; batches are moved there.
        categories: Category names, in class-index order (used for sample plots).
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate for both optimizers.
        latent_dim: Size of the noise vector fed to the generator during training.
    """
    optimizer_G = torch.optim.Adam(generator.parameters(), lr=lr, betas=(0.5, 0.999))
    optimizer_D = torch.optim.Adam(discriminator.parameters(), lr=lr, betas=(0.5, 0.999))
    criterion = nn.BCEWithLogitsLoss()

    os.makedirs("gan_spectrogram_plots", exist_ok=True)

    for epoch in range(1, epochs + 1):
        loop = tqdm(dataloader, desc=f"Epoch {epoch}/{epochs}", leave=True)

        for real_specs, labels in loop:
            real_specs = real_specs.to(device)
            labels = labels.to(device)
            batch_size = real_specs.size(0)

            real_targets = torch.ones(batch_size, 1, device=device)
            fake_targets = torch.zeros(batch_size, 1, device=device)

            # Discriminator step. detach() keeps this loss from reaching the generator.
            optimizer_D.zero_grad()
            loss_D_real = criterion(discriminator(real_specs, labels), real_targets)

            z = torch.randn(batch_size, latent_dim, device=device)
            fake_specs = generator(z, labels)
            loss_D_fake = criterion(discriminator(fake_specs.detach(), labels), fake_targets)

            loss_D = loss_D_real + loss_D_fake
            loss_D.backward()
            optimizer_D.step()

            # Generator step: push D to score the same fakes as real. Gradients
            # this leaves on D's parameters are cleared by the next zero_grad().
            optimizer_G.zero_grad()
            loss_G = criterion(discriminator(fake_specs, labels), real_targets)
            loss_G.backward()
            optimizer_G.step()

            loop.set_postfix(lossD=loss_D.item(), lossG=loss_G.item())

        print(f"\nSample generation after epoch {epoch}")
        _save_epoch_samples(generator, categories, device, epoch)



In [None]:
# Install the ffmpeg binary (IPython shell escape); the .mp3 -> .wav conversion cell below invokes it.
!apt install ffmpeg -y


In [None]:
import os
import subprocess

base_dir = '/content/drive/MyDrive/Audios'


def convert_mp3s_to_wav(root, sample_rate=22050):
    """Convert every .mp3 under ``root`` (recursively) to a mono .wav beside it.

    Files whose .wav counterpart already exists are skipped, so re-running the
    cell only converts new uploads. ffmpeg is called with an argument list
    (no shell), so file names containing quotes, ``$`` or backticks are passed
    through verbatim. A failed conversion is reported and the loop continues.

    Args:
        root: Directory tree to scan.
        sample_rate: Output sample rate passed to ffmpeg's ``-ar``.

    Returns:
        List of .wav paths that ffmpeg produced successfully.

    Raises:
        FileNotFoundError: if the ``ffmpeg`` executable is not installed.
    """
    converted = []
    for dirpath, _dirnames, filenames in os.walk(root):
        for filename in filenames:
            if not filename.lower().endswith('.mp3'):
                continue
            mp3_path = os.path.join(dirpath, filename)
            wav_path = os.path.splitext(mp3_path)[0] + '.wav'

            # Only convert if WAV doesn't exist yet
            if os.path.exists(wav_path):
                continue

            print(f"Converting: {mp3_path}")
            result = subprocess.run(
                ['ffmpeg', '-y', '-i', mp3_path, '-ar', str(sample_rate), '-ac', '1', wav_path],
                capture_output=True,
                text=True,
            )
            if result.returncode == 0:
                converted.append(wav_path)
            else:
                print(f"ffmpeg failed (exit {result.returncode}) for {mp3_path}:\n{result.stderr.strip()}")
    return converted


converted_wavs = convert_mp3s_to_wav(base_dir)
print(f"Converted {len(converted_wavs)} file(s).")


In [None]:
# Report how many .wav files each category folder under base_dir now holds.
for category in os.listdir(base_dir):
    category_dir = os.path.join(base_dir, category)
    if not os.path.isdir(category_dir):
        continue  # ignore stray files at the top level
    n_wavs = sum(1 for name in os.listdir(category_dir) if name.endswith('.wav'))
    print(f"{category}: {n_wavs} wav files")


In [None]:
# Training configuration used by the cells below.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

LATENT_DIM = 100  # length of the generator's noise vector z
EPOCHS = 5        # passes over the dataset
BATCH = 8         # DataLoader batch size
LR = 0.0002       # Adam learning rate shared by generator and discriminator

In [None]:
train_path = '/content/drive/MyDrive/Audios'

# Categories are the sub-directories of train_path, sorted so class indices are stable.
cats = sorted(
    entry for entry in os.listdir(train_path)
    if os.path.isdir(os.path.join(train_path, entry))
)
print("Categories:", cats)

ds = TrainAudioSpectrogramDataset(train_path, cats)
dl = DataLoader(ds, batch_size=BATCH, shuffle=True, num_workers=2)

n_classes = len(cats)
G = CGAN_Generator(LATENT_DIM, n_classes).to(DEVICE)
D = CGAN_Discriminator(n_classes).to(DEVICE)

train_gan(G, D, dl, DEVICE, cats, EPOCHS, LR, LATENT_DIM)
