In [10]:
#!/usr/bin/env python
# coding: utf-8

"""
3_model_training.py
-------------------
用途:
1. 读取 2_feature_extraction.py 生成的 feature_index.csv
2. 构建 BiLSTMChorus (示例) 或简化 MLP, 并使用 Weighted BCE Loss
3. 训练后保存 model.pt

运行:
  python 3_model_training.py

输出:
  ../saved_models/bilstm_chorus_model.pt
"""

import os
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.model_selection import train_test_split

# Directory that holds feature_index.csv, the index file read by main().
FEATURE_FOLDER = "../data/processed/"
# Destination file for the state_dict of the best model found during training.
MODEL_PATH = "../saved_models/bilstm_chorus_model.pt"

class BiLSTMChorus(nn.Module):
    """Bidirectional LSTM followed by a linear head that emits one logit.

    Args:
        input_dim: Size of each input feature vector.
        hidden_dim: Hidden size of the LSTM, per direction.
    """

    def __init__(self, input_dim=141, hidden_dim=64):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            batch_first=True,
            bidirectional=True,
        )
        # Both directions are concatenated, so the head sees 2 * hidden_dim
        # features and maps them to a single binary-classification logit.
        self.fc = nn.Linear(in_features=2 * hidden_dim, out_features=1)

    def forward(self, x):
        """Return one logit per sample.

        Args:
            x: Tensor laid out as (batch, seq_len, input_dim); the training
                code in this file feeds seq_len == 1.

        Returns:
            Tensor of shape (batch, 1) holding raw (pre-sigmoid) logits.
        """
        sequence_output, _state = self.lstm(x)
        # Only the last time step is passed to the classifier head.
        final_step = sequence_output[:, -1, :]
        return self.fc(final_step)

def _load_dataset(df):
    """Load the feature arrays and labels listed in the index DataFrame.

    Args:
        df: DataFrame with a "feature_path" column (paths readable by
            ``np.load``) and a "label" column.

    Returns:
        Tuple ``(X, y)`` where ``X`` is a float32 array stacking one loaded
        feature array per row and ``y`` is an int64 array of labels.
    """
    features = [np.load(path) for path in df["feature_path"]]
    X = np.array(features, dtype=np.float32)
    y = df["label"].to_numpy(dtype=np.int64)
    return X, y


def _train_one_epoch(model, X_data, y_data, criterion, optimizer, batch_size=32):
    """Run one shuffled pass over the training data and update the model.

    Args:
        model: Module mapping a (b, 1, input_dim) batch to (b, 1) logits.
        X_data: Float tensor of training features, one row per sample.
        y_data: Float tensor of training labels, one entry per sample.
        criterion: Loss taking ``(logits, targets)`` and returning a batch mean.
        optimizer: Optimizer bound to ``model``'s parameters.
        batch_size: Maximum number of samples per gradient step.

    Returns:
        Mean training loss per sample over the epoch.
    """
    model.train()
    n_samples = len(X_data)
    order = np.arange(n_samples)
    np.random.shuffle(order)

    loss_sum = 0.0
    for start in range(0, n_samples, batch_size):
        batch_idx = order[start:start + batch_size]
        x_batch = X_data[batch_idx].unsqueeze(1)  # (b, 1, input_dim)
        y_batch = y_data[batch_idx].unsqueeze(1)  # (b, 1)

        optimizer.zero_grad()
        loss = criterion(model(x_batch), y_batch)
        loss.backward()
        optimizer.step()

        # The criterion returns a per-batch mean; weight it by the batch size
        # so a short final batch does not skew the epoch average.
        loss_sum += loss.item() * len(batch_idx)

    return loss_sum / n_samples


def _evaluate(model, X_data, y_data, criterion, batch_size=32):
    """Compute mean loss and accuracy without updating the model.

    Args:
        model: Module mapping a (b, 1, input_dim) batch to (b, 1) logits.
        X_data: Float tensor of features, one row per sample.
        y_data: Float tensor of 0/1 labels, one entry per sample.
        criterion: Loss taking ``(logits, targets)`` and returning a batch mean.
        batch_size: Maximum number of samples per forward pass.

    Returns:
        Tuple ``(avg_loss, accuracy)``; both are averaged per sample.
    """
    model.eval()
    n_samples = len(X_data)
    loss_sum = 0.0
    correct = 0

    with torch.no_grad():
        for start in range(0, n_samples, batch_size):
            x_batch = X_data[start:start + batch_size].unsqueeze(1)
            y_batch = y_data[start:start + batch_size].unsqueeze(1)

            logits = model(x_batch)
            # Sample-weighted accumulation, see _train_one_epoch.
            loss_sum += criterion(logits, y_batch).item() * len(x_batch)

            # Predict the positive class when sigmoid(logit) > 0.5.
            pred = (torch.sigmoid(logits) > 0.5).float()
            correct += (pred == y_batch).sum().item()

    return loss_sum / n_samples, correct / n_samples


def main():
    """Train BiLSTMChorus on the indexed features and save the best weights.

    Reads ``feature_index.csv`` from FEATURE_FOLDER, makes a stratified 80/20
    train/validation split, trains for a fixed number of epochs with a
    positively-weighted BCE-with-logits loss, and writes the state_dict of the
    epoch with the highest validation accuracy to MODEL_PATH.
    """
    df = pd.read_csv(os.path.join(FEATURE_FOLDER, "feature_index.csv"))
    print("读取 feature_index.csv =>")
    print(df.head())

    X, y = _load_dataset(df)

    print("X shape=", X.shape, ", y shape=", y.shape)
    print("正例数=", int((y == 1).sum()), ", 负例数=", int((y == 0).sum()))

    # Stratified split keeps the class ratio the same in both subsets.
    X_train, X_val, y_train, y_val = train_test_split(
        X, y, test_size=0.2, random_state=42, stratify=y
    )
    print("Train size=", X_train.shape, "Val size=", X_val.shape)

    input_dim = X.shape[1]
    hidden_dim = 64
    model = BiLSTMChorus(input_dim, hidden_dim)
    print(model)

    # Weight positives by the negative/positive ratio to raise recall on the
    # minority (label == 1) class; fall back to 1.0 if there are no positives.
    pos_count = int((y_train == 1).sum())
    neg_count = int((y_train == 0).sum())
    pos_weight_val = neg_count / pos_count if pos_count else 1.0
    pos_weight_tensor = torch.tensor([pos_weight_val], dtype=torch.float32)
    criterion = nn.BCEWithLogitsLoss(pos_weight=pos_weight_tensor)

    optimizer = optim.Adam(model.parameters(), lr=1e-3)

    X_train_t = torch.from_numpy(X_train)
    y_train_t = torch.from_numpy(y_train).float()  # BCE expects float targets
    X_val_t = torch.from_numpy(X_val)
    y_val_t = torch.from_numpy(y_val).float()

    # Make sure the checkpoint directory exists before the first save.
    model_dir = os.path.dirname(MODEL_PATH)
    if model_dir:
        os.makedirs(model_dir, exist_ok=True)

    epochs = 30
    # Start below any reachable accuracy so the first epoch is always saved,
    # which guarantees a checkpoint exists when the final message is printed.
    best_val_acc = float("-inf")
    for ep in range(1, epochs + 1):
        tr_loss = _train_one_epoch(
            model, X_train_t, y_train_t, criterion, optimizer, batch_size=32
        )
        val_loss, val_acc = _evaluate(
            model, X_val_t, y_val_t, criterion, batch_size=32
        )

        # Keep only the weights with the best validation accuracy so far.
        if val_acc > best_val_acc:
            best_val_acc = val_acc
            torch.save(model.state_dict(), MODEL_PATH)

        print(f"Epoch {ep} => train_loss={tr_loss:.4f}, val_loss={val_loss:.4f}, val_acc={val_acc:.4f}")

    print("最优 val_acc=", best_val_acc)
    print("已保存模型 =>", MODEL_PATH)

# Run the training pipeline only when this file is executed as a script.
if __name__ == "__main__":
    main()

读取 feature_index.csv =>
                                        feature_path  label  \
0  ../data/processed/Taylor Swift - Cruel Summer....      0   
1  ../data/processed/Taylor Swift - Cruel Summer....      0   
2  ../data/processed/Taylor Swift - Cruel Summer....      0   
3  ../data/processed/Taylor Swift - Cruel Summer....      0   
4  ../data/processed/Taylor Swift - Cruel Summer....      0   

                          filename  
0  Taylor Swift - Cruel Summer.mp3  
1  Taylor Swift - Cruel Summer.mp3  
2  Taylor Swift - Cruel Summer.mp3  
3  Taylor Swift - Cruel Summer.mp3  
4  Taylor Swift - Cruel Summer.mp3  
X shape= (467, 141) , y shape= (467,)
正例数= 82 , 负例数= 385
Train size= (373, 141) Val size= (94, 141)
BiLSTMChorus(
  (lstm): LSTM(141, 64, batch_first=True, bidirectional=True)
  (fc): Linear(in_features=128, out_features=1, bias=True)
)
Epoch 1 => train_loss=1.1376, val_loss=1.1109, val_acc=0.4681
Epoch 2 => train_loss=1.0694, val_loss=1.0587, val_acc=0.4894
Epoch 3 => tra