In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from util import ManualMFCC, SpeechDataset, SpeechRecognizer

# Seed torch's default RNG so torch-based randomness (e.g. default weight
# initialisation, DataLoader shuffling) is reproducible across Restart & Run All.
# torch.manual_seed seeds every device (CPU and CUDA). Bit-exact determinism on
# CUDA may additionally require torch.use_deterministic_algorithms(True).
SEED = 42
torch.manual_seed(SEED)

# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [2]:
# Configuration: values a reader may want to tune.
SAMPLE_RATE = 8000    # passed to ManualMFCC; presumably the recordings' sample rate in Hz — TODO confirm
DATA_DIR = ""         # root passed to SpeechDataset; "" presumably means the current directory — TODO confirm
BATCH_SIZE = 32
LEARNING_RATE = 0.001

# Initialise the MFCC feature extractor.
mfcc_extractor = ManualMFCC(sample_rate=SAMPLE_RATE)

# Build the dataset and data loader.
# VOCAB maps a two-digit string key ("00"-"19") to a word: 10 Chinese and 10 English.
# NOTE(review): the keys presumably match codes embedded in the audio file names — confirm in SpeechDataset.
VOCAB = {
    "00": "数字", "01": "语音", "02": "语言", "03": "处理", "04": "中国",
    "05": "忠告", "06": "北京", "07": "背景", "08": "上海", "09": "商行",
    "10": "Speech", "11": "Speaker", "12": "Signal", "13": "Sequence", "14": "Processing",
    "15": "Print", "16": "Project", "17": "File", "18": "Open", "19": "Close",
}
dataset = SpeechDataset(DATA_DIR, VOCAB, mfcc_extractor)
# Fail fast with a clear message rather than DataLoader's generic
# "num_samples should be a positive integer" error when no files were found.
assert len(dataset) > 0, f"SpeechDataset found no audio files under {DATA_DIR!r}"
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)

# Initialise the model, loss function and optimizer.
# NOTE(review): confirm SpeechRecognizer's default output size equals len(VOCAB) (20).
model = SpeechRecognizer().to(device)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=LEARNING_RATE)

找到 2442 个音频文件


In [3]:
# Train on the full dataset and save the final weights.
# NOTE: there is no held-out split, so the accuracy printed below is *training*
# accuracy, not an estimate of generalisation.
num_epochs = 100
model.train()  # explicit, so re-running this cell after an eval-mode cell still trains correctly
for epoch in range(num_epochs):
    total_loss = 0.0
    correct = 0
    total = 0

    for mfcc, labels in dataloader:
        # Swap the last two axes so the model receives (batch, time, features);
        # assumes the dataset yields (batch, features, time) — TODO confirm against SpeechDataset.
        mfcc = mfcc.permute(0, 2, 1).to(device)
        # Flatten to one class index per sample. Unlike squeeze(), reshape(-1) keeps
        # the batch axis when a batch holds a single sample (squeeze would yield a
        # 0-d tensor that CrossEntropyLoss and labels.size(0) reject).
        labels = labels.reshape(-1).to(device)

        # Forward pass
        outputs = model(mfcc)
        loss = criterion(outputs, labels)

        # Backward pass and parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        # Statistics. The loss is a batch mean, so weight it by batch size to get a
        # true per-sample epoch mean even when the final batch is smaller.
        n_samples = labels.size(0)
        total_loss += loss.item() * n_samples
        predicted = outputs.detach().argmax(dim=1)
        total += n_samples
        correct += (predicted == labels).sum().item()

    epoch_loss = total_loss / total
    epoch_acc = 100 * correct / total
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {epoch_loss:.2f}, Accuracy: {epoch_acc:.1f}%")

torch.save(model.state_dict(), "model.pth")

Epoch [1/100], Loss: 2.82, Accuracy: 6.1%
Epoch [2/100], Loss: 2.61, Accuracy: 12.7%
Epoch [3/100], Loss: 2.32, Accuracy: 21.4%
Epoch [4/100], Loss: 1.97, Accuracy: 29.3%
Epoch [5/100], Loss: 1.83, Accuracy: 33.3%
Epoch [6/100], Loss: 1.64, Accuracy: 42.6%
Epoch [7/100], Loss: 1.32, Accuracy: 51.9%
Epoch [8/100], Loss: 1.15, Accuracy: 58.8%
Epoch [9/100], Loss: 1.00, Accuracy: 64.3%
Epoch [10/100], Loss: 0.81, Accuracy: 71.2%
Epoch [11/100], Loss: 0.90, Accuracy: 67.3%
Epoch [12/100], Loss: 0.91, Accuracy: 67.9%
Epoch [13/100], Loss: 0.78, Accuracy: 72.2%
Epoch [14/100], Loss: 0.57, Accuracy: 80.2%
Epoch [15/100], Loss: 0.48, Accuracy: 83.7%
Epoch [16/100], Loss: 0.52, Accuracy: 81.2%
Epoch [17/100], Loss: 0.47, Accuracy: 84.8%
Epoch [18/100], Loss: 0.46, Accuracy: 83.3%
Epoch [19/100], Loss: 0.50, Accuracy: 82.8%
Epoch [20/100], Loss: 0.85, Accuracy: 71.6%
Epoch [21/100], Loss: 0.69, Accuracy: 75.4%
Epoch [22/100], Loss: 0.49, Accuracy: 83.1%
Epoch [23/100], Loss: 0.56, Accuracy: 81.5