In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import torchaudio as T
import torchaudio.transforms as TT

# --------------------------
# 分类器模型
# --------------------------
class GenreClassifier(nn.Module):
    """Small 1-D CNN that maps a mel spectrogram to per-class logits.

    Three Conv1d stages (kernel 4, stride 2, padding 1) with LeakyReLU(0.2);
    the second and third stages are followed by InstanceNorm1d. The time axis
    is then averaged away with AdaptiveAvgPool1d(1) and a linear layer
    produces the logits.

    Args:
        input_channels: Number of input channels (mel bins) of the Conv1d input.
        ndf: Base width; the stages use ndf, 2*ndf and 4*ndf channels.
        num_classes: Number of output logits.
    """

    def __init__(self, input_channels=80, ndf=64, num_classes=2):
        super().__init__()
        widths = [input_channels, ndf, ndf * 2, ndf * 4]

        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers.append(nn.Conv1d(c_in, c_out, kernel_size=4, stride=2, padding=1))
            # The first stage is intentionally left without normalisation.
            if stage > 0:
                layers.append(nn.InstanceNorm1d(c_out))
            layers.append(nn.LeakyReLU(0.2, inplace=True))
        layers.append(nn.AdaptiveAvgPool1d(1))

        # Module order (and therefore state_dict keys such as "model.0.weight")
        # matches the original hand-written Sequential, so checkpoints still load.
        self.model = nn.Sequential(*layers)
        self.fc = nn.Linear(widths[-1], num_classes)

    def forward(self, x):
        """Return logits of shape (batch, num_classes) for input x.

        x is passed straight to Conv1d; presumably (batch, input_channels, time)
        — TODO confirm against callers.
        """
        pooled = self.model(x)
        return self.fc(torch.flatten(pooled, start_dim=1))

# --------------------------
# 数据集类（支持标签）
# --------------------------
class LabeledAudioDataset(Dataset):
    """Mel-spectrogram dataset over per-genre folders of ``.wav`` files.

    Files are looked up at ``<root_dir>/<genre_name>/<genre_name.lower()>_<split>/``
    (e.g. ``Pop/pop_train/*.wav``) for every key of ``label_dict``. Each item
    is ``(mel_spec, label)``: a log-mel spectrogram scaled and clamped to
    [0, 1], zero-padded or cropped along the last axis to
    ``target_time_steps`` frames.

    Args:
        root_dir: Root directory containing one folder per genre.
        label_dict: Mapping of genre folder name -> integer class label.
        split: Split suffix used in the sub-folder name ("train", "val", ...).
        sr: Required sampling rate; files with another rate raise ValueError.
        hop_samples: STFT hop length in samples (window is 4x this, capped at n_fft).
        n_fft: FFT size.
        n_mels: Number of mel bins.
        target_time_steps: Fixed number of frames returned per item.
    """

    def __init__(self, root_dir, label_dict, split="train",
                 sr=22050, hop_samples=256, n_fft=1024, n_mels=80, target_time_steps=2580):
        self.label_dict = label_dict
        self.all_files = []
        self.labels = []

        for genre_name, label in label_dict.items():
            subfolder = f"{genre_name.lower()}_{split}"
            genre_path = os.path.join(root_dir, genre_name, subfolder)
            files = self._list_wav_files(genre_path)
            self.all_files.extend(files)
            self.labels.extend([label] * len(files))

        self.sr = sr
        self.hop_samples = hop_samples
        self.n_fft = n_fft
        self.n_mels = n_mels
        self.target_time_steps = target_time_steps

        self.mel_transform = TT.MelSpectrogram(
            sample_rate=sr,
            # torch.stft requires win_length <= n_fft; unchanged for the defaults.
            win_length=min(hop_samples * 4, n_fft),
            hop_length=hop_samples,
            n_fft=n_fft,
            f_min=20.0,
            f_max=sr / 2.0,
            n_mels=n_mels,
            power=1.0,
            normalized=True,
        )

    @staticmethod
    def _list_wav_files(folder):
        """Return the sorted paths of ``.wav`` files (any case) directly in *folder*."""
        # Sorting makes the index -> file mapping reproducible across machines;
        # os.listdir order is arbitrary.
        return sorted(
            os.path.join(folder, name)
            for name in os.listdir(folder)
            if name.lower().endswith(".wav")
        )

    @staticmethod
    def _pad_or_crop(mel_spec, target_time_steps):
        """Zero-pad on the right or crop the last axis to *target_time_steps*."""
        time_steps = mel_spec.shape[-1]
        if time_steps < target_time_steps:
            return F.pad(mel_spec, (0, target_time_steps - time_steps))
        return mel_spec[..., :target_time_steps]

    def __len__(self):
        return len(self.all_files)

    def __getitem__(self, idx):
        file_path = self.all_files[idx]
        label = self.labels[idx]

        audio, sr = T.load(file_path)
        if sr != self.sr:
            raise ValueError(f"Unexpected sampling rate: {sr}")
        # Only the first channel is used; any other channels are ignored.
        audio = torch.clamp(audio[0], -1.0, 1.0)

        with torch.no_grad():
            mel_spec = self.mel_transform(audio)
            # Log-magnitude in dB (floored at 1e-5 to avoid log(0)), shifted by -20 dB,
            # then mapped from [-100, 0] dB to [0, 1].
            mel_spec = 20 * torch.log10(torch.clamp(mel_spec, min=1e-5)) - 20
            mel_spec = torch.clamp((mel_spec + 100) / 100, 0.0, 1.0)

        return self._pad_or_crop(mel_spec, self.target_time_steps), label

# --------------------------
# 训练函数
# --------------------------
def train_classifier(model, train_loader, val_loader, device, epochs=20, lr=1e-4):
    """Train *model* with Adam and cross-entropy, reporting accuracy every epoch.

    Args:
        model: Classifier returning logits of shape (batch, num_classes).
        train_loader: Iterable of (inputs, integer labels) batches used for updates.
        val_loader: Iterable of (inputs, integer labels) batches used for evaluation.
        device: Device the model and batches are moved to.
        epochs: Number of passes over train_loader.
        lr: Adam learning rate.

    Returns:
        list[dict]: One entry per epoch with keys ``"epoch"`` (1-based),
        ``"train_loss"``, ``"train_acc"`` and ``"val_acc"``. A metric is
        ``nan`` when its loader yielded no samples (instead of raising
        ZeroDivisionError).
    """
    model.to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    history = []

    for epoch in range(epochs):
        model.train()
        total_loss, correct, total = 0.0, 0, 0
        for x, y in train_loader:
            x, y = x.to(device), y.to(device)
            optimizer.zero_grad()
            preds = model(x)
            loss = criterion(preds, y)
            loss.backward()
            optimizer.step()

            # loss is a batch mean; weight by batch size for an exact epoch mean.
            total_loss += loss.item() * x.size(0)
            correct += (preds.argmax(dim=1) == y).sum().item()
            total += x.size(0)

        train_loss = total_loss / total if total else float("nan")
        train_acc = correct / total if total else float("nan")
        print(f"[Train] Epoch {epoch+1}: Loss={train_loss:.4f}, Acc={train_acc:.4f}")

        # Validation
        model.eval()
        correct, total = 0, 0
        with torch.no_grad():
            for x, y in val_loader:
                x, y = x.to(device), y.to(device)
                preds = model(x)
                correct += (preds.argmax(dim=1) == y).sum().item()
                total += x.size(0)

        val_acc = correct / total if total else float("nan")
        print(f"[Val]   Epoch {epoch+1}: Acc={val_acc:.4f}")

        history.append({
            "epoch": epoch + 1,
            "train_loss": train_loss,
            "train_acc": train_acc,
            "val_acc": val_acc,
        })

    return history

# --------------------------
# 主程序
# --------------------------
if __name__ == "__main__":
    root_dir = "/home/quincy/DATA/music/dataset/wav_fma_split"
    label_dict = {"Pop": 0, "Rock": 1}  # Extend with more genres as needed.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Single source of truth for the checkpoint path, so the saved file and the
    # reported file name can no longer disagree.
    save_path = "genre_classifier1.pth"

    train_dataset = LabeledAudioDataset(root_dir=root_dir, label_dict=label_dict, split="train")
    val_dataset = LabeledAudioDataset(root_dir=root_dir, label_dict=label_dict, split="val")

    train_loader = DataLoader(train_dataset, batch_size=8, shuffle=True, num_workers=4)
    val_loader = DataLoader(val_dataset, batch_size=8, shuffle=False, num_workers=4)

    # Keep the classifier's input width in sync with the dataset's mel bins.
    model = GenreClassifier(input_channels=train_dataset.n_mels, num_classes=len(label_dict))
    train_classifier(model, train_loader, val_loader, device, epochs=100, lr=1e-4)

    # Optional: save the trained weights (final epoch, not the best validation epoch).
    torch.save(model.state_dict(), save_path)
    print(f"模型已保存为 {save_path}")


[Train] Epoch 1: Loss=0.6570, Acc=0.6467
[Val]   Epoch 1: Acc=0.6784
[Train] Epoch 2: Loss=0.6235, Acc=0.7139
[Val]   Epoch 2: Acc=0.6935
[Train] Epoch 3: Loss=0.6080, Acc=0.7089
[Val]   Epoch 3: Acc=0.6884
[Train] Epoch 4: Loss=0.5963, Acc=0.7306
[Val]   Epoch 4: Acc=0.6985
[Train] Epoch 5: Loss=0.5866, Acc=0.7261
[Val]   Epoch 5: Acc=0.7085
[Train] Epoch 6: Loss=0.5770, Acc=0.7339
[Val]   Epoch 6: Acc=0.7286
[Train] Epoch 7: Loss=0.5678, Acc=0.7317
[Val]   Epoch 7: Acc=0.6985
[Train] Epoch 8: Loss=0.5619, Acc=0.7372
[Val]   Epoch 8: Acc=0.7136
[Train] Epoch 9: Loss=0.5540, Acc=0.7456
[Val]   Epoch 9: Acc=0.7236
[Train] Epoch 10: Loss=0.5473, Acc=0.7467
[Val]   Epoch 10: Acc=0.7186
[Train] Epoch 11: Loss=0.5387, Acc=0.7572
[Val]   Epoch 11: Acc=0.7136
[Train] Epoch 12: Loss=0.5318, Acc=0.7544
[Val]   Epoch 12: Acc=0.7085
[Train] Epoch 13: Loss=0.5257, Acc=0.7589
[Val]   Epoch 13: Acc=0.7588
[Train] Epoch 14: Loss=0.5229, Acc=0.7611
[Val]   Epoch 14: Acc=0.7538
[Train] Epoch 15: Loss=0