In [3]:
import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn

# MODEL_PATH: checkpoint file that is passed to torch.load below.
# FEATURE_FOLDER: directory that contains feature_index.csv.
# Both are relative paths, so they resolve against the current working
# directory of the notebook kernel.
MODEL_PATH, FEATURE_FOLDER = (
    "../saved_models/mlp_chorus_model.pt",
    "../data/processed/",
)

# Same MLP structure as the one the checkpoint at MODEL_PATH was saved from.
# Keep the attribute names (fc1, relu, fc2) unchanged: they become the
# state_dict keys that load_state_dict() matches against the checkpoint.
class MLP(nn.Module):
    """Two-layer perceptron: Linear -> ReLU -> Linear.

    Args:
        input_dim: number of input features per sample.
        hidden_dim: width of the hidden layer.
        output_dim: number of output logits per sample.
    """

    def __init__(self, input_dim=13, hidden_dim=16, output_dim=2):
        super().__init__()
        self.fc1 = nn.Linear(input_dim, hidden_dim)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """Return the raw logits for batch *x* (no softmax is applied)."""
        hidden = self.relu(self.fc1(x))
        return self.fc2(hidden)

# 1. Load the trained model.
model = MLP(input_dim=13, hidden_dim=16, output_dim=2)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only machine; all inference below runs on CPU tensors anyway.
# NOTE(review): torch.load unpickles the file; on PyTorch >= 1.13 consider
# passing weights_only=True as well.
model.load_state_dict(torch.load(MODEL_PATH, map_location="cpu"))
model.eval()

# 2. Read feature_index.csv (one row per feature file: path + label).
df = pd.read_csv(os.path.join(FEATURE_FOLDER, "feature_index.csv"))
print("feature_index.csv 内容:")
print(df.head())

X_list = []
y_list = []

# 3. Load every feature vector together with its ground-truth label from the
#    CSV's "label" column. Iterating the two columns directly avoids
#    iterrows(), which builds a Series per row. feature_path values are used
#    as-is, so relative paths resolve against the current working directory.
for feat_path, label in zip(df["feature_path"], df["label"]):
    feats = np.load(feat_path)
    X_list.append(feats)
    y_list.append(label)

X = np.array(X_list, dtype=np.float32)
y = np.array(y_list, dtype=np.int64)

print("X shape=", X.shape, "y shape=", y.shape)

# Fail early with a readable message instead of an opaque matmul/argmax error
# inside the model. This also catches an empty feature_index.csv, for which
# X ends up 1-D with shape (0,).
expected_dim = model.fc1.in_features
if X.ndim != 2 or X.shape[1] != expected_dim:
    raise ValueError(
        f"expected features of shape (n_samples, {expected_dim}), got {X.shape}"
    )

# 4. Predict: the class with the largest logit per sample.
X_t = torch.from_numpy(X)
with torch.no_grad():
    output = model(X_t)
pred = output.argmax(dim=1).numpy()

# 5. Compute accuracy.
acc = (pred == y).mean()
print("Prediction:", pred, "\nLabel:     ", y)
print(f"Accuracy = {acc*100:.2f}%")
# Accuracy alone is misleading when one class dominates, so also print the
# label-vs-prediction counts (confusion table).
print(pd.crosstab(pd.Series(y, name="label"), pd.Series(pred, name="pred")))

feature_index.csv 内容:
                                        feature_path  label
0  ../data/processed/Taylor Swift - Cruel Summer....      0
1  ../data/processed/Taylor Swift - Cruel Summer....      0
2  ../data/processed/Taylor Swift - Cruel Summer....      0
3  ../data/processed/Taylor Swift - Cruel Summer....      0
4  ../data/processed/Taylor Swift - Cruel Summer....      0
X shape= (425, 13) y shape= (425,)
Prediction: [0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 1 0 0 0 1 0 1 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0
 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0 0