In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torch.optim.lr_scheduler import CosineAnnealingLR
from util import ManualMFCC, SpeechDataset, SpeechRecognizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# 初始化模型
model = SpeechRecognizer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

In [3]:
# 初始化MFCC提取器
mfcc_extractor = ManualMFCC(sample_rate=8000)

# 创建数据集和数据加载器
VOCAB = {"00": "数字", "01": "语音", "02": "语言", "03": "处理", "04": "中国", "05": "忠告", "06": "北京", "07": "背景", "08": "上海", "09": "商行", "10": "Speech", "11": "Speaker", "12": "Signal", "13": "Sequence", "14": "Processing", "15": "Print", "16": "Project", "17": "File", "18": "Open", "19": "Close"}
dataset = SpeechDataset("", VOCAB, mfcc_extractor)
dataloader = DataLoader(dataset, batch_size=16, shuffle=True)

训练集: 2400 个音频文件


In [4]:
num_epochs = 100
scheduler = CosineAnnealingLR(optimizer, T_max=num_epochs, eta_min=1e-5)
for epoch in range(num_epochs):
    total_loss, correct, total = 0.0, 0, 0

    for batch_idx, (mfcc, labels) in enumerate(dataloader):
        mfcc = mfcc.permute(0, 2, 1).to(device)  # (batch, time, features)
        labels = labels.squeeze().to(device)

        # 前向传播
        outputs = model(mfcc)
        loss = criterion(outputs, labels)

        # 反向传播和优化
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # 统计
        total_loss += loss.item()
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

    # 更新学习率
    scheduler.step()
    
    epoch_loss = total_loss / len(dataloader)
    epoch_acc = 100 * correct / total
    current_lr = optimizer.param_groups[0]['lr']
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss*100:.2f}, Accuracy: {epoch_acc:.3f}%")

torch.save(model.state_dict(), "model.pth")

Epoch [1/200], Loss: 285.49, Accuracy: 8.292%
Epoch [2/200], Loss: 229.36, Accuracy: 23.792%
Epoch [3/200], Loss: 189.81, Accuracy: 32.917%
Epoch [4/200], Loss: 164.47, Accuracy: 40.208%
Epoch [5/200], Loss: 144.81, Accuracy: 46.500%
Epoch [6/200], Loss: 128.61, Accuracy: 52.792%
Epoch [7/200], Loss: 106.24, Accuracy: 61.625%
Epoch [8/200], Loss: 86.56, Accuracy: 67.667%
Epoch [9/200], Loss: 71.65, Accuracy: 75.375%
Epoch [10/200], Loss: 61.32, Accuracy: 78.500%
Epoch [11/200], Loss: 58.92, Accuracy: 79.667%
Epoch [12/200], Loss: 50.94, Accuracy: 82.125%
Epoch [13/200], Loss: 42.96, Accuracy: 84.375%
Epoch [14/200], Loss: 38.13, Accuracy: 87.333%
Epoch [15/200], Loss: 40.15, Accuracy: 86.500%
Epoch [16/200], Loss: 35.70, Accuracy: 87.750%
Epoch [17/200], Loss: 27.08, Accuracy: 91.208%
Epoch [18/200], Loss: 23.52, Accuracy: 91.833%
Epoch [19/200], Loss: 22.68, Accuracy: 92.208%
Epoch [20/200], Loss: 19.93, Accuracy: 93.250%
Epoch [21/200], Loss: 22.54, Accuracy: 92.708%
Epoch [22/200], 