In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from util import ManualMFCC, SpeechDataset, SpeechRecognizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# 初始化模型
model = SpeechRecognizer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [None]:
# 初始化MFCC提取器
mfcc_extractor = ManualMFCC(sample_rate=8000)

# 创建数据集和数据加载器
VOCAB = {"00": "数字", "01": "语音", "02": "语言", "03": "处理", "04": "中国", "05": "忠告", "06": "北京", "07": "背景", "08": "上海", "09": "商行", "10": "Speech", "11": "Speaker", "12": "Signal", "13": "Sequence", "14": "Processing", "15": "Print", "16": "Project", "17": "File", "18": "Open", "19": "Close"}
dataset = SpeechDataset("", VOCAB, mfcc_extractor)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

训练集: 1822 个音频文件


In [None]:
num_epochs = 200
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0

    for batch_idx, (mfcc, labels) in enumerate(dataloader):
        mfcc = mfcc.permute(0, 2, 1).to(device)  # (batch, time, features)
        labels = labels.squeeze().to(device)

        # 前向传播
        outputs = model(mfcc)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 统计
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # 更新学习率
    scheduler.step()
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss*100:.2f}, Accuracy: {epoch_acc:.3f}%")

torch.save(model.state_dict(), "model.pth")

ValueError: Expected input batch_size (1) to match target batch_size (0).