In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio as T
import torchaudio.transforms as TT
import matplotlib.pyplot as plt
from tqdm import tqdm

# Genre Classifier Model
class GenreClassifier(nn.Module):
    """Small 1-D convolutional classifier producing one logit per class.

    The feature extractor is three ``Conv1d`` stages (kernel 4, stride 2,
    padding 1) with ``ndf``, ``2 * ndf`` and ``4 * ndf`` output channels.
    Every stage ends in ``LeakyReLU(0.2)``; the second and third stages
    apply ``InstanceNorm1d`` before the activation. The result is averaged
    over the last dimension and passed to a linear layer.

    Args:
        input_channels: Size of dimension 1 of the input tensor.
        ndf: Channel count of the first convolution stage.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels=80, ndf=64, num_classes=2):
        super().__init__()
        stage_widths = (ndf, ndf * 2, ndf * 4)

        stack = []
        in_width = input_channels
        for stage_idx, out_width in enumerate(stage_widths):
            stack.append(
                nn.Conv1d(in_width, out_width, kernel_size=4, stride=2, padding=1)
            )
            # The first stage has no normalization layer.
            if stage_idx > 0:
                stack.append(nn.InstanceNorm1d(out_width))
            stack.append(nn.LeakyReLU(0.2, inplace=True))
            in_width = out_width

        # Collapse the remaining length to a single value per channel.
        stack.append(nn.AdaptiveAvgPool1d(1))

        self.model = nn.Sequential(*stack)
        self.fc = nn.Linear(stage_widths[-1], num_classes)

    def forward(self, x):
        """Return logits of shape ``(x.size(0), num_classes)``."""
        pooled = self.model(x)
        # Drop the trailing length-1 dimension left by the pooling layer.
        features = pooled.flatten(start_dim=1)
        return self.fc(features)

# Dataset Class (with labels)
class LabeledAudioDataset(Dataset):
    """Dataset of ``.wav`` files paired with integer genre labels.

    For every ``genre_name`` key of ``label_dict`` the files are read from
    ``<root_dir>/<genre_name>/<genre_name.lower()>_<split>/``. Each item is a
    normalized log-mel spectrogram, padded or cropped along dimension 1 to
    ``target_time_steps``, together with the label of its genre.

    Args:
        root_dir: Directory containing one folder per genre.
        label_dict: Mapping from genre folder name to label value.
        split: Suffix of the per-genre subfolder (e.g. ``"train"``).
        sr: Required sample rate; other rates raise in ``__getitem__``.
        hop_samples: Hop length of the mel transform; the window length is
            ``4 * hop_samples``.
        n_fft: FFT size of the mel transform.
        n_mels: Number of mel bins.
        target_time_steps: Length every spectrogram is padded/cropped to.

    Raises:
        OSError: If a genre subfolder cannot be listed.
    """

    def __init__(self, root_dir, label_dict, split="train",
                 sr=22050, hop_samples=256, n_fft=1024, n_mels=80, target_time_steps=2580):
        self.label_dict = label_dict
        self.all_files = []
        self.labels = []

        for genre_name, label in label_dict.items():
            subfolder = f"{genre_name.lower()}_{split}"
            genre_path = os.path.join(root_dir, genre_name, subfolder)
            files = self._list_wav_files(genre_path)
            self.all_files.extend(files)
            self.labels.extend([label] * len(files))

        self.sr = sr
        self.hop_samples = hop_samples
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.target_time_steps = target_time_steps

        self.mel_transform = TT.MelSpectrogram(
            sample_rate=sr,
            win_length=hop_samples * 4,
            hop_length=hop_samples,
            n_fft=n_fft,
            f_min=20.0,
            f_max=sr / 2.0,
            n_mels=n_mels,
            power=1.0,
            normalized=True,
        )

    @staticmethod
    def _list_wav_files(directory):
        """Return the paths of the ``.wav`` files in ``directory``, sorted.

        ``os.listdir`` yields entries in arbitrary order; sorting keeps the
        index -> file mapping identical across runs and machines.
        """
        return sorted(
            os.path.join(directory, name)
            for name in os.listdir(directory)
            if name.endswith(".wav")
        )

    def _fit_to_length(self, mel_spec):
        """Zero-pad (at the end) or crop dimension 1 to ``target_time_steps``."""
        time_steps = mel_spec.shape[1]
        if time_steps < self.target_time_steps:
            return F.pad(mel_spec, (0, self.target_time_steps - time_steps))
        return mel_spec[:, :self.target_time_steps]

    def __len__(self):
        """Return the number of discovered audio files."""
        return len(self.all_files)

    def __getitem__(self, idx):
        """Load file ``idx`` and return ``(mel_spec, label)``.

        Raises:
            ValueError: If the file's sample rate differs from ``self.sr``.
        """
        file_path = self.all_files[idx]
        label = self.labels[idx]

        audio, sr = T.load(file_path)
        # Reject unexpected rates before doing any work on the waveform.
        if sr != self.sr:
            raise ValueError(f"Unexpected sampling rate: {sr}")

        # Keep only index 0 of the first dimension (presumably the first
        # channel -- confirm against torchaudio.load) and limit to [-1, 1].
        audio = torch.clamp(audio[0], -1.0, 1.0)

        with torch.no_grad():
            mel_spec = self.mel_transform(audio)
            # Magnitudes -> dB with a 1e-5 floor, shifted down by 20 dB.
            mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5)) - 20
            # Map the [-100, 0] range linearly onto [0, 1], clipping the rest.
            mel_spec = torch.clamp((mel_spec + 100) / 100, 0.0, 1.0)

        return self._fit_to_length(mel_spec), label

# Training Function
def train_classifier(model, train_loader, val_loader, device, epochs=20, lr=1e-4):
    """Train ``model`` with Adam and cross-entropy, validating after each epoch.

    Whenever the validation accuracy exceeds the best value seen so far, the
    weights are written to ``best_genre_classifier.pth`` in the current
    working directory. After the last epoch the training-loss and
    validation-accuracy curves are plotted with matplotlib.

    Args:
        model: Module mapping a batch ``x`` to per-class logits.
        train_loader: Iterable of ``(x, y)`` batches used for optimization.
        val_loader: Iterable of ``(x, y)`` batches used for evaluation only.
        device: Device the model and every batch are moved to.
        epochs: Number of passes over ``train_loader``.
        lr: Adam learning rate.

    Returns:
        A dict with ``"train_losses"`` and ``"val_accuracies"`` (one entry
        per epoch), ``"best_val_acc"`` and ``"best_state_dict"`` (a snapshot
        of the best weights, or ``None`` if nothing was ever saved). The
        model itself is left with the weights of the final epoch.
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()

    train_losses = []
    val_accuracies = []

    best_val_acc = 0.0
    best_model_wts = None

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        pbar = tqdm(train_loader, desc=f"Epoch {epoch+1}/{epochs}", leave=False)

        for x, y in pbar:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y)
            loss.backward()
            optimizer.step()

            # criterion returns the batch mean; weight by batch size so the
            # running average is per-sample even when the last batch is short.
            total_loss += loss.item() * x.size(0)
            correct += (preds.argmax(dim=1) == y).sum().item()
            total += x.size(0)

            pbar.set_postfix(loss=total_loss/total, acc=correct/total)

        train_acc = correct / total
        train_loss = total_loss / total
        train_losses.append(train_loss)

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x)
                correct += (preds.argmax(dim=1) == y).sum().item()
                total += x.size(0)

        val_acc = correct / total
        val_accuracies.append(val_acc)

        print(f"[Train] Epoch {epoch+1}: Loss={train_loss:.4f}, Acc={train_acc:.4f}")
        print(f"[Val]   Epoch {epoch+1}: Acc={val_acc:.4f}")

        # Save best model based on validation accuracy
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            # state_dict() hands back references to the live tensors, which
            # later optimizer steps overwrite in place; clone them so the
            # in-memory snapshot really stays at the best epoch.
            best_model_wts = {
                name: tensor.detach().clone()
                for name, tensor in model.state_dict().items()
            }
            torch.save(best_model_wts, "best_genre_classifier.pth")
            print(f"Best model saved with validation accuracy: {val_acc:.4f}")

    # Plot loss and validation accuracy curves
    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.plot(range(epochs), train_losses, label="Train Loss")
    plt.title("Train Loss Curve")
    plt.xlabel("Epochs")
    plt.ylabel("Loss")
    plt.grid(True)

    plt.subplot(1, 2, 2)
    plt.plot(range(epochs), val_accuracies, label="Validation Accuracy")
    plt.title("Validation Accuracy Curve")
    plt.xlabel("Epochs")
    plt.ylabel("Accuracy")
    plt.grid(True)

    plt.tight_layout()
    plt.show()

    return {
        "train_losses": train_losses,
        "val_accuracies": val_accuracies,
        "best_val_acc": best_val_acc,
        "best_state_dict": best_model_wts,
    }


if __name__ == "__main__":
    # Dataset location and the genre-folder -> class-index mapping.
    root_dir = "/DATA/music/dataset/wav_fma_split"
    label_dict = {"Pop": 0, "Rock": 1}  # Extend this dict according to your genre labels

    # Use the GPU when one is available, otherwise fall back to the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # One dataset per split, both built from the same genre mapping.
    train_dataset = LabeledAudioDataset(root_dir, label_dict, split="train")
    val_dataset = LabeledAudioDataset(root_dir, label_dict, split="val")

    # Only the training batches are shuffled.
    train_loader = DataLoader(
        train_dataset,
        batch_size=16,
        num_workers=4,
        shuffle=True,
    )
    val_loader = DataLoader(
        val_dataset,
        batch_size=16,
        num_workers=4,
        shuffle=False,
    )

    # One output logit per entry of label_dict.
    model = GenreClassifier(num_classes=len(label_dict), input_channels=80)
    train_classifier(
        model,
        train_loader,
        val_loader,
        device,
        lr=1e-4,
        epochs=100,
    )
