In [7]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

class GenreClassifier(nn.Module):
    def __init__(self, input_channels=80, ndf=64, num_classes=2):
        super().__init__()
        self.model = nn.Sequential(
            nn.Conv1d(input_channels, ndf, kernel_size=4, stride=2, padding=1),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(ndf, ndf * 2, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm1d(ndf * 2),
            nn.LeakyReLU(0.2, inplace=True),

            nn.Conv1d(ndf * 2, ndf * 4, kernel_size=4, stride=2, padding=1),
            nn.InstanceNorm1d(ndf * 4),
            nn.LeakyReLU(0.2, inplace=True),

            nn.AdaptiveAvgPool1d(1),
        )
        self.fc = nn.Linear(ndf * 4, num_classes)

    def forward(self, x):
        x = self.model(x)
        x = x.view(x.size(0), -1)
        return self.fc(x)

# 分类器加载函数
def load_model(model_path, num_classes, device):
    model = GenreClassifier(input_channels=80, num_classes=num_classes)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()
    return model

# 加载 .npy 的 mel 特征
def load_mel_npy(file_path, target_time_steps=2580):
    mel = np.load(file_path)
    mel = torch.tensor(mel).float()

    # 去除额外的维度
    mel = mel.squeeze(0)  # 将 [1, 1, 80, 2580] 转换为 [1, 80, 2580]
    mel = mel.squeeze(0)  # 将 [1, 80, 2580] 转换为 [80, 2580]

    # 如果时间步长不足，填充；如果太长，裁剪
    if mel.shape[1] < target_time_steps:
        mel = F.pad(mel, (0, target_time_steps - mel.shape[1]))
    elif mel.shape[1] > target_time_steps:
        mel = mel[:, :target_time_steps]

    return mel.unsqueeze(0)  # [1, 80, T] (添加 batch 维度)

# 遍历评估 .npy 文件夹
def evaluate_mel_folder(folder_path, model, label_map, expected_label_id, device):
    files = [f for f in os.listdir(folder_path) if f.endswith(".npy")]
    correct = 0
    total = len(files)

    # 遍历所有文件，逐个进行评估
    for file_name in files:
        mel = load_mel_npy(os.path.join(folder_path, file_name)).to(device)
        print(f"Evaluating file: {file_name}, Mel shape: {mel.shape}")

        with torch.no_grad():
            output = model(mel)
            pred = output.argmax(dim=1).item()

        # 打印每个文件的预测结果
        predicted_label = label_map[pred]
        print(f"Prediction: {predicted_label} (Expected: {label_map[expected_label_id]})")

        if pred == expected_label_id:
            correct += 1

    # 打印最终的准确率结果
    acc = correct / total if total > 0 else 0.0
    print(f"Evaluation complete. Total: {total} samples, Accuracy: {acc:.4f} ({label_map[expected_label_id]})")

    return acc

if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    label_map = {0: "Pop", 1: "Rock"}
    inv_label_map = {v: k for k, v in label_map.items()}

    # 加载模型
    model_path = "best_genre_classifier.pth"
    model = load_model(model_path, num_classes=2, device=device)

    # 指定转换后的 .npy 目录
    converted_mel_dir = "/home/quincy/DATA/music/converted_features"
    target_label = "Rock"

    # 评估
    evaluate_mel_folder(
        folder_path=converted_mel_dir,
        model=model,
        label_map=label_map,
        expected_label_id=inv_label_map[target_label],
        device=device
    )


  model.load_state_dict(torch.load(model_path, map_location=device))


Evaluating file: 027802.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 048454.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Pop (Expected: Rock)
Evaluating file: 122500.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 113808.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 090587.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 126400.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 051267.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 035549.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Prediction: Rock (Expected: Rock)
Evaluating file: 149078.wav_converted.npy, Mel shape: torch.Size([1, 80, 2580])
Predictio