<a href="https://colab.research.google.com/github/Song20011219/song/blob/main/train.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import torch
from sklearn.metrics import accuracy_score
from tools import wer

def train_epoch(model, criterion, optimizer, dataloader, device, epoch, logger, log_interval, writer):
    """Run one training epoch for a classifier and log loss/accuracy.

    Args:
        model: Module called as ``model(inputs)``. If it returns a ``list``,
            the first element is used as the class scores.
        criterion: Loss called as ``criterion(outputs, labels)``.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        dataloader: Iterable of dicts with ``'data'`` and ``'label'`` tensors.
            Labels are presumably shaped ``(batch,)`` or ``(batch, 1)`` --
            TODO confirm against the dataset.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (reported as ``epoch + 1``).
        logger: Object with an ``info(message)`` method.
        log_interval: Emit a per-batch log line every this many batches.
        writer: Object with ``add_scalars(tag, values, step)``.

    Returns:
        None. Results are reported through ``logger`` and ``writer``.
    """
    model.train()
    losses = []
    # Running counters replace per-sample tensor lists: same accuracy, but no
    # growing list of device tensors and no host round-trip through NumPy.
    total_correct = 0
    total_seen = 0

    for batch_idx, data in enumerate(dataloader):
        # Fetch inputs and labels
        inputs, labels = data['data'].to(device), data['label'].to(device)
        labels = labels.squeeze()
        if labels.dim() == 0:
            # squeeze() collapses a single-sample batch to a 0-d tensor, which
            # the loss and the accuracy bookkeeping cannot treat as a batch.
            labels = labels.unsqueeze(0)

        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        if isinstance(outputs, list):
            outputs = outputs[0]

        # Loss
        loss = criterion(outputs, labels)
        losses.append(loss.item())

        # Batch accuracy: fraction of samples whose arg-max class equals the label
        prediction = torch.max(outputs, 1)[1]
        batch_correct = (prediction.reshape(-1) == labels.reshape(-1)).sum().item()
        batch_seen = labels.numel()
        score = batch_correct / batch_seen if batch_seen else 0.0
        total_correct += batch_correct
        total_seen += batch_seen

        # Backward pass & optimisation step
        loss.backward()
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            logger.info("epoch {:3d} | iteration {:5d} | Loss {:.6f} | Acc {:.2f}%".format(epoch+1, batch_idx+1, loss.item(), score*100))

    if not losses:
        # An empty dataloader would otherwise divide by zero below.
        logger.info("epoch {:3d} | no batches to train on; skipping epoch summary".format(epoch+1))
        return

    # Epoch-level averages
    training_loss = sum(losses)/len(losses)
    training_acc = total_correct / total_seen if total_seen else 0.0
    # Record
    writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
    writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1)
    logger.info("第 {} 轮平均训练损失: {:.6f} | 准确率: {:.2f}%".format(epoch+1, training_loss, training_acc*100))


def train_seq2seq(model, criterion, optimizer, clip, dataloader, device, epoch, logger, log_interval, writer):
    """Run one training epoch for a seq2seq model and log loss, accuracy and WER.

    Args:
        model: Module called as ``model(imgs, target)``. Its output is treated
            as ``(trg_len, batch_size, output_dim)``.
        criterion: Loss called on the flattened scores and flattened targets.
        optimizer: Optimizer whose gradients are zeroed and stepped per batch.
        clip: Maximum gradient norm passed to ``clip_grad_norm_``.
        dataloader: Iterable of ``(imgs, target)`` pairs; ``target`` is treated
            as ``(batch_size, trg_len)``.
        device: Device the batch tensors are moved to.
        epoch: Zero-based epoch index (reported as ``epoch + 1``).
        logger: Object with an ``info(message)`` method.
        log_interval: Emit a per-batch log line every this many batches.
        writer: Object with ``add_scalars(tag, values, step)``.

    Returns:
        None. Results are reported through ``logger`` and ``writer``.
    """
    model.train()
    # Token ids dropped before WER scoring (padding, sos and eos).
    ignored_tokens = {0, 1, 2}
    losses = []
    all_wer = []
    # Running counters replace per-token tensor lists: same accuracy, but no
    # list holding one device tensor per token and no NumPy round-trip.
    total_correct = 0
    total_tokens = 0

    for batch_idx, (imgs, target) in enumerate(dataloader):
        imgs = imgs.to(device)
        target = target.to(device)

        optimizer.zero_grad()
        # Forward pass
        outputs = model(imgs, target)

        # target: (batch_size, trg_len)
        # outputs: (trg_len, batch_size, output_dim)
        # Skip the first time step (sos) and flatten both in time-major order
        output_dim = outputs.shape[-1]
        outputs = outputs[1:].view(-1, output_dim)
        target = target.permute(1,0)[1:].reshape(-1)

        # Loss
        loss = criterion(outputs, target)
        losses.append(loss.item())

        # Token-level accuracy over every flattened position
        prediction = torch.max(outputs, 1)[1]
        batch_correct = (prediction == target).sum().item()
        batch_tokens = target.numel()
        score = batch_correct / batch_tokens if batch_tokens else 0.0
        total_correct += batch_correct
        total_tokens += batch_tokens

        # WER
        # prediction: ((trg_len-1)*batch_size)
        # target: ((trg_len-1)*batch_size)
        batch_size = imgs.shape[0]
        pred_seqs = prediction.view(-1, batch_size).permute(1,0).tolist()
        trg_seqs = target.view(-1, batch_size).permute(1,0).tolist()
        wers = []
        for trg_seq, pred_seq in zip(trg_seqs, pred_seqs):
            # Mask out the special tokens before scoring
            hypothesis = [item for item in pred_seq if item not in ignored_tokens]
            reference = [item for item in trg_seq if item not in ignored_tokens]
            wers.append(wer(reference, hypothesis))
        all_wer.extend(wers)

        # Backward pass & optimisation step (with gradient-norm clipping)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()

        if (batch_idx + 1) % log_interval == 0:
            logger.info("第 {} 轮 | 迭代 {:5d} | 损失 {:.6f} | 准确率 {:.2f}% | WER {:.2f}%".format(epoch+1, batch_idx+1, loss.item(), score*100, sum(wers)/len(wers)))

    if not losses:
        # An empty dataloader would otherwise divide by zero below.
        logger.info("epoch {:3d} | no batches to train on; skipping epoch summary".format(epoch+1))
        return

    # Epoch-level average loss, accuracy and WER
    training_loss = sum(losses)/len(losses)
    training_acc = total_correct / total_tokens if total_tokens else 0.0
    training_wer = sum(all_wer)/len(all_wer)
    # Record
    writer.add_scalars('Loss', {'train': training_loss}, epoch+1)
    writer.add_scalars('Accuracy', {'train': training_acc}, epoch+1)
    writer.add_scalars('WER', {'train': training_wer}, epoch+1)
    logger.info("第 {} 轮平均训练损失: {:.6f} | 准确率: {:.2f}% | WER {:.2f}%".format(epoch+1, training_loss, training_acc*100, training_wer))
