In [14]:
#!/usr/bin/env python
# coding: utf-8

"""
4_evaluation.py
---------------
用途:
1. 加载 2_feature_extraction.py 生成的 feature_index.csv
2. 使用与 3_model_training.py 相同的 BiLSTMChorus 模型结构
3. 读取 ../saved_models/bilstm_chorus_model.pt
4. 输出 Accuracy / Precision / Recall / F1 / AUC 等指标
5. 多阈值 sweeping, 查看 P/R/F1 如何变化

运行:
  python 4_evaluation.py

(可选) 若要进一步对单首歌检测 + ≥15秒合并，可在 detect_chorus 函数中额外实现。
"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import confusion_matrix, precision_recall_fscore_support, roc_curve, auc

FEATURE_FOLDER = "../data/processed/"
MODEL_PATH = "../saved_models/bilstm_chorus_model.pt"

class BiLSTMChorus(nn.Module):
    def __init__(self, input_dim=141, hidden_dim=64):
        super().__init__()
        self.lstm = nn.LSTM(input_dim, hidden_dim, batch_first=True, bidirectional=True)
        self.fc   = nn.Linear(hidden_dim*2, 1)  # 二分类 => 输出一个 logit

    def forward(self, x):
        # x => (batch, seq_len=1, input_dim)
        lstm_out, _= self.lstm(x)
        # 取最后时刻
        out= self.fc(lstm_out[:, -1, :])  # (batch,1)
        return out

def main():
    # 1. 加载 feature_index.csv
    df = pd.read_csv(os.path.join(FEATURE_FOLDER, "feature_index.csv"))
    features_list = []
    labels_list   = []

    for idx, row in df.iterrows():
        feat_path = row["feature_path"]
        feats     = np.load(feat_path)
        label     = row["label"]
        features_list.append(feats)
        labels_list.append(label)

    X = np.array(features_list, dtype=np.float32)
    y = np.array(labels_list,   dtype=np.int64)

    print("X shape=", X.shape, ", y shape=", y.shape)
    print("Positive=", sum(y==1), "Negative=", sum(y==0))

    # 2. 构建与训练阶段相同的 BiLSTM
    model= BiLSTMChorus(input_dim=X.shape[1], hidden_dim=64)
    # 加载
    model.load_state_dict(torch.load(MODEL_PATH))
    model.eval()

    # 3. 逐条推理 => 收集 logits
    X_t= torch.from_numpy(X)
    logits_list= []
    batch_size=32
    for start_idx in range(0, len(X_t), batch_size):
        end_idx = start_idx + batch_size
        x_batch= X_t[start_idx:end_idx].unsqueeze(1)  # (b,1, input_dim)
        with torch.no_grad():
            out= model(x_batch)  # shape=(b,1)
        logits_list.append(out.squeeze(1).numpy())

    logits_all = np.concatenate(logits_list, axis=0)  # (N,)
    prob_chorus= 1.0 / (1.0 + np.exp(-logits_all))    # sigmoid => (N,)

    # 4. 默认阈值=0.5 => pred
    pred_05= (prob_chorus>0.5).astype(int)
    acc= (pred_05==y).mean()
    print(f"Accuracy= {acc*100:.2f}%")

    cm= confusion_matrix(y, pred_05)
    print("Confusion Matrix=\n", cm)

    p,r,f,_= precision_recall_fscore_support(y, pred_05, average='binary')
    print(f"Precision={p:.3f}, Recall={r:.3f}, F1={f:.3f}")

    # 5. AUC
    from sklearn.metrics import roc_curve, auc
    fpr,tpr,thresholds= roc_curve(y, prob_chorus, pos_label=1)
    roc_auc= auc(fpr,tpr)
    print(f"ROC AUC= {roc_auc:.3f}")

    # 6. 多阈值 sweeping
    th_candidates= [0.1,0.2,0.3,0.4,0.5,0.6,0.7]
    for th in th_candidates:
        pred_th= (prob_chorus>th).astype(int)
        p2,r2,f2,_= precision_recall_fscore_support(y, pred_th, average='binary')
        print(f"threshold={th:.1f}: P={p2:.3f}, R={r2:.3f}, F1={f2:.3f}")

    print("\n[说明] 若要在单首歌层面合并帧、确保≥15秒，可在 detect_chorus 中处理.")
    print("测试完毕.")

if __name__=="__main__":
    main()
    main()

X shape= (467, 141) , y shape= (467,)
Positive= 82 Negative= 385
Accuracy= 94.65%
Confusion Matrix=
 [[362  23]
 [  2  80]]
Precision=0.777, Recall=0.976, F1=0.865
ROC AUC= 0.987
threshold=0.1: P=0.488, R=1.000, F1=0.656
threshold=0.2: P=0.584, R=0.976, F1=0.731
threshold=0.3: P=0.650, R=0.976, F1=0.780
threshold=0.4: P=0.708, R=0.976, F1=0.821
threshold=0.5: P=0.777, R=0.976, F1=0.865
threshold=0.6: P=0.792, R=0.976, F1=0.874
threshold=0.7: P=0.825, R=0.976, F1=0.894

[说明] 若要在单首歌层面合并帧、确保≥15秒，可在 detect_chorus 中处理.
测试完毕.
X shape= (467, 141) , y shape= (467,)
Positive= 82 Negative= 385
Accuracy= 94.65%
Confusion Matrix=
 [[362  23]
 [  2  80]]
Precision=0.777, Recall=0.976, F1=0.865
ROC AUC= 0.987
threshold=0.1: P=0.488, R=1.000, F1=0.656
threshold=0.2: P=0.584, R=0.976, F1=0.731
threshold=0.3: P=0.650, R=0.976, F1=0.780
threshold=0.4: P=0.708, R=0.976, F1=0.821
threshold=0.5: P=0.777, R=0.976, F1=0.865
threshold=0.6: P=0.792, R=0.976, F1=0.874
threshold=0.7: P=0.825, R=0.976, F1=0.894

