导入必要的包

In [1]:
import random
import re
import time
from collections import Counter
from typing import List

import pandas as pd
import torch
import torch.nn as nn
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from datasets import Dataset, DatasetDict
from src.models.Transformer import TransformerModel

  from .autonotebook import tqdm as notebook_tqdm


ModuleNotFoundError: No module named 'src'

源语言与目标语言

In [None]:
# ---------- Data utilities ----------
# Keys of each example's "translation" dict: Chinese source, English target.
SRC_LANG, TGT_LANG = "zh", "en"

# Reserved vocabulary entries (build_vocabs gives them the lowest ids).
PAD = "<pad>"  # fills shorter sequences up to the batch length
BOS = "<sos>"  # prepended to every encoded sentence
EOS = "<eos>"  # appended to every encoded sentence
UNK = "<unk>"  # stands in for out-of-vocabulary tokens

# Prefer the GPU whenever PyTorch can see one.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

数据导入

In [None]:
# Seed Python's and PyTorch's global RNGs for reproducible runs.
random.seed(42)
torch.manual_seed(42)

print("Loading dataset from CSV files...")
# NOTE(review): `load_csv_data` is defined in a later cell; executing the
# notebook strictly top-to-bottom raises NameError here. Run the cell with the
# function definitions first (or move this cell below it).
# NOTE(review): these absolute paths are machine-specific; consider making them
# configurable.
dataset = load_csv_data(
    train_path="/home/yhhe/project/Deep_Learning/NLP/datasets/tatoeba/train_zh_en.csv",
    test_path="/home/yhhe/project/Deep_Learning/NLP/datasets/tatoeba/test_zh_en.csv",
)

# Split off a validation set: 10% of the training data (fixed seed), while the
# original test split is kept untouched. Downstream code reads the
# "train" / "valid" / "test" keys of this DatasetDict.
split_dataset = dataset["train"].train_test_split(test_size=0.1, seed=42)
dataset = DatasetDict(
    {
        "train": split_dataset["train"],
        "valid": split_dataset["test"],
        "test": dataset["test"],
    }
)

# NOTE(review): no vocabulary is built in this cell; build_vocabs() is
# actually called inside transformer_train().
print("Building vocabularies...")

In [None]:
# CJK Unified Ideographs (+ Extension A and Compatibility Ideographs).
_CJK_IDEOGRAPHS = "\u3400-\u4dbf\u4e00-\u9fff\uf900-\ufaff"
# Either one ideograph, or a maximal run of non-space, non-ideograph characters.
_TOKEN_RE = re.compile(rf"[{_CJK_IDEOGRAPHS}]|[^\s{_CJK_IDEOGRAPHS}]+")


def tokenize_basic(text: str) -> List[str]:
    """Lower-case ``text`` and split it into tokens.

    Non-CJK text is split on whitespace exactly like ``str.split()``. Chinese is
    written without spaces, so whitespace splitting would turn a whole sentence
    into a single (almost always unique, hence UNK) token. Each CJK ideograph is
    therefore emitted as its own token (character-level tokenization), and
    adjacent non-ideograph runs (Latin words, digits, punctuation) stay
    together.

    >>> tokenize_basic("Hello  World")
    ['hello', 'world']
    >>> tokenize_basic("我爱你。")
    ['我', '爱', '你', '。']
    """
    return _TOKEN_RE.findall(text.lower().strip())


def load_csv_data(
    train_path="train_zh_en.csv",
    test_path="test_zh_en.csv",
    src_col="中文",
    tgt_col="英文",
):
    """Load train/test CSV files into a HuggingFace ``DatasetDict``.

    Every example gains a ``translation`` field
    ``{SRC_LANG: <source text>, TGT_LANG: <target text>}`` so that downstream
    code does not depend on the CSV column names.

    Args:
        train_path: path of the training CSV.
        test_path: path of the test CSV.
        src_col: column holding source-language sentences.
        tgt_col: column holding target-language sentences.

    Returns:
        DatasetDict with "train" and "test" splits. Rows whose source or target
        cell is blank are dropped.

    Raises:
        KeyError: if ``src_col`` or ``tgt_col`` is missing from a CSV.
    """

    def _read_split(path):
        # dtype=str + keep_default_na=False: every cell stays a Python str, so an
        # empty cell becomes "" rather than float NaN (NaN has no .lower() and
        # would crash tokenization later).
        df = pd.read_csv(path, dtype=str, keep_default_na=False)
        # Drop pairs where either side is blank; they carry no training signal.
        non_blank = (df[src_col].str.strip() != "") & (df[tgt_col].str.strip() != "")
        # preserve_index=False: do not store the (now non-contiguous) pandas
        # index as an extra dataset column.
        return Dataset.from_pandas(df[non_blank], preserve_index=False)

    dataset = DatasetDict(
        {"train": _read_split(train_path), "test": _read_split(test_path)}
    )

    def add_translation_field(example):
        # Expose each pair under a uniform "translation" dict keyed by language.
        return {
            "translation": {
                SRC_LANG: example[src_col],
                TGT_LANG: example[tgt_col],
            }
        }

    return dataset.map(add_translation_field)


def build_vocabs(dataset, min_freq=2):
    """Build source and target vocabularies from ``dataset["train"]``.

    Args:
        dataset: HuggingFace DatasetDict whose examples carry a
            ``translation`` dict keyed by SRC_LANG / TGT_LANG.
        min_freq: minimum number of occurrences for a token to get its own id.

    Returns:
        ((src_stoi, src_itos), (tgt_stoi, tgt_itos)), where each ``stoi`` maps
        token -> id and each ``itos`` is the inverse dict. PAD/BOS/EOS/UNK take
        ids 0-3; the remaining tokens get consecutive ids in first-seen order.
    """
    src_freq = Counter()
    tgt_freq = Counter()
    for pair in (example["translation"] for example in dataset["train"]):
        src_freq.update(tokenize_basic(pair[SRC_LANG]))
        tgt_freq.update(tokenize_basic(pair[TGT_LANG]))

    def _vocab_from(freq):
        # Specials first, in fixed order, so their ids are always 0..3.
        stoi = {tok: i for i, tok in enumerate((PAD, BOS, EOS, UNK))}
        for tok, count in freq.items():
            if count < min_freq or tok in stoi:
                continue
            stoi[tok] = len(stoi)
        # Inverse lookup: id -> token.
        return stoi, {i: tok for tok, i in stoi.items()}

    return _vocab_from(src_freq), _vocab_from(tgt_freq)


def numericalize(vocab, tokens: List[str]) -> List[int]:
    """Map ``tokens`` to ids and wrap them as ``[BOS id, ..., EOS id]``.

    Tokens absent from ``vocab`` are replaced by the UNK id.
    """
    ids = [vocab[BOS]]
    for tok in tokens:
        ids.append(vocab.get(tok, vocab[UNK]))
    ids.append(vocab[EOS])
    return ids


def collate_fn(batch, src_vocab, tgt_vocab, device):
    """Turn a list of dataset examples into padded id tensors on ``device``.

    Each sentence is tokenized, wrapped in BOS/EOS and mapped to ids. The
    sequences are then right-padded with each vocabulary's PAD id into
    ``(batch, max_len)`` tensors (``batch_first=True``).

    Returns:
        (src_batch, tgt_batch) LongTensors moved to ``device``.
    """

    def _to_ids(text, vocab):
        return torch.tensor(numericalize(vocab, tokenize_basic(text)), dtype=torch.long)

    sources = []
    targets = []
    for pair in (example["translation"] for example in batch):
        sources.append(_to_ids(pair[SRC_LANG], src_vocab))
        targets.append(_to_ids(pair[TGT_LANG], tgt_vocab))

    padded_src = pad_sequence(sources, batch_first=True, padding_value=src_vocab[PAD])
    padded_tgt = pad_sequence(targets, batch_first=True, padding_value=tgt_vocab[PAD])
    return padded_src.to(device), padded_tgt.to(device)


# ---------- Mask helpers ----------
def make_src_key_padding_mask(src, pad_idx):
    """Return a bool tensor shaped like ``src`` that is True at padding positions."""
    return src.eq(pad_idx)


def make_tgt_masks(tgt, pad_idx):
    """Build the boolean masks used for decoder self-attention.

    Args:
        tgt: 2-D tensor of target token ids, unpacked as (batch, time).
        pad_idx: id of the padding token.

    Returns:
        (subsequent, pad_mask):
            subsequent -- (T, T) bool tensor, True strictly above the diagonal,
                          i.e. position i must not attend to positions j > i.
            pad_mask   -- bool tensor shaped like ``tgt``, True where
                          ``tgt == pad_idx``.
        Both tensors live on ``tgt.device``.
    """
    _, seq_len = tgt.size()  # also enforces a 2-D input
    # Allocate directly on the target's device instead of building on CPU and
    # copying over.
    subsequent = torch.ones(
        (seq_len, seq_len), dtype=torch.bool, device=tgt.device
    ).triu(diagonal=1)
    pad_mask = tgt.eq(pad_idx)
    return subsequent, pad_mask


# ---------- Training and evaluation ----------
def train_epoch(
    model, dataloader, optimizer, criterion, src_pad_idx, tgt_pad_idx, device, clip=1.0
):
    """Run one teacher-forced training pass over ``dataloader``.

    Args:
        model: called as ``model(src, tgt_in, src_mask=..., tgt_mask=...,
            memory_mask=...)`` and expected to return logits whose last
            dimension is the target vocabulary size.
        dataloader: yields ``(src_batch, tgt_batch)`` id tensors of shape
            (batch, len).
        optimizer: optimizer stepping ``model``'s parameters.
        criterion: loss over flattened logits/targets (e.g. CrossEntropyLoss
            with ``ignore_index=tgt_pad_idx``).
        src_pad_idx / tgt_pad_idx: padding ids of the two vocabularies.
        device: batches are moved here (a no-op if they are already there).
        clip: max global gradient norm.

    Returns:
        Mean of the per-batch losses, or 0.0 for an empty dataloader.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    for src_batch, tgt_batch in dataloader:
        src_batch = src_batch.to(device)
        tgt_batch = tgt_batch.to(device)

        # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
        tgt_input = tgt_batch[:, :-1]
        tgt_target = tgt_batch[:, 1:]
        batch_size, tgt_len = tgt_input.shape

        # Cross-attention mask (True = masked), (B, T_tgt, S): hide padded
        # source keys from every target query.
        src_key_pad_mask = make_src_key_padding_mask(src_batch, src_pad_idx)
        memory_mask = src_key_pad_mask.unsqueeze(1).expand(-1, tgt_len, -1)

        # Self-attention mask (True = masked), (B, T, T): future positions OR
        # padded *keys*. Padding goes on the key axis (unsqueeze(1)), matching
        # memory_mask above. Masking the query axis instead would blank out
        # entire rows for padded positions, which can yield NaN if the model
        # fills masked scores with -inf (model code not visible here).
        subsequent_mask, tgt_key_pad_mask = make_tgt_masks(tgt_input, tgt_pad_idx)
        causal = subsequent_mask.unsqueeze(0).expand(batch_size, -1, -1)
        tgt_mask = causal | tgt_key_pad_mask.unsqueeze(1)

        optimizer.zero_grad()
        logits = model(
            src_batch,
            tgt_input,
            src_mask=None,
            tgt_mask=tgt_mask,
            memory_mask=memory_mask,
        )
        # reshape (not view): works even if the model returns non-contiguous logits.
        loss = criterion(logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1))
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        total_loss += loss.item()
        num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def evaluate(model, dataloader, criterion, src_pad_idx, tgt_pad_idx, device):
    """Compute the teacher-forced loss of ``model`` on ``dataloader`` without gradients.

    Args:
        model: called as ``model(src, tgt_in, src_mask=..., tgt_mask=...,
            memory_mask=...)`` and expected to return logits whose last
            dimension is the target vocabulary size.
        dataloader: yields ``(src_batch, tgt_batch)`` id tensors of shape
            (batch, len).
        criterion: loss over flattened logits/targets.
        src_pad_idx / tgt_pad_idx: padding ids of the two vocabularies.
        device: batches are moved here (a no-op if they are already there).

    Returns:
        Mean of the per-batch losses, or 0.0 for an empty dataloader.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for src_batch, tgt_batch in dataloader:
            src_batch = src_batch.to(device)
            tgt_batch = tgt_batch.to(device)

            # Teacher forcing: feed tokens [0, T-1) and predict tokens [1, T).
            tgt_input = tgt_batch[:, :-1]
            tgt_target = tgt_batch[:, 1:]
            batch_size, tgt_len = tgt_input.shape

            # Cross-attention mask (True = masked), (B, T_tgt, S): padded source keys.
            src_key_pad_mask = make_src_key_padding_mask(src_batch, src_pad_idx)
            memory_mask = src_key_pad_mask.unsqueeze(1).expand(-1, tgt_len, -1)

            # Self-attention mask (True = masked), (B, T, T): future positions OR
            # padded keys. Key-axis padding (unsqueeze(1)) never masks a whole
            # row, unlike query-axis padding.
            subsequent_mask, tgt_key_pad_mask = make_tgt_masks(tgt_input, tgt_pad_idx)
            causal = subsequent_mask.unsqueeze(0).expand(batch_size, -1, -1)
            tgt_mask = causal | tgt_key_pad_mask.unsqueeze(1)

            logits = model(
                src_batch,
                tgt_input,
                src_mask=None,
                tgt_mask=tgt_mask,
                memory_mask=memory_mask,
            )
            loss = criterion(
                logits.reshape(-1, logits.size(-1)), tgt_target.reshape(-1)
            )
            total_loss += loss.item()
            num_batches += 1
    return total_loss / num_batches if num_batches else 0.0


def greedy_decode(
    model, src_sentence_tensor, src_vocab, tgt_vocab, tgt_itos, max_len=50, device="cpu"
):
    """Greedily translate a single source sentence.

    Args:
        model: must provide ``encode``, ``decode`` and ``output_proj``, used
            exactly as called below.
        src_sentence_tensor: source ids including BOS/EOS for ONE sentence
            (batch size 1 -- each step calls ``.item()`` on the argmax).
        src_vocab: source vocabulary; only its PAD id is used.
        tgt_vocab: target token -> id mapping (BOS/EOS/PAD are read).
        tgt_itos: target id -> token mapping.
        max_len: maximum length of the output, counting the leading BOS.
        device: device on which encoding/decoding runs.

    Returns:
        List of target tokens starting with BOS. It ends with EOS if one was
        produced within ``max_len``. Ids missing from ``tgt_itos`` become UNK.
    """
    model.eval()
    src = src_sentence_tensor.to(device)
    bos_idx, eos_idx, tgt_pad_idx = tgt_vocab[BOS], tgt_vocab[EOS], tgt_vocab[PAD]
    # Key-side source padding mask, (1, 1, S). It is constant across decoding
    # steps, so it is built once here.
    src_key_pad_mask = make_src_key_padding_mask(src, src_vocab[PAD]).unsqueeze(1)

    with torch.no_grad():
        memory = model.encode(src, src_mask=None)
        ys = torch.full((1, 1), bos_idx, dtype=torch.long, device=device)
        for _ in range(max_len - 1):
            # Same mask layout as training: (1, T, T) self-attention mask
            # (future OR padded keys) and (1, T, S) cross-attention mask.
            subsequent_mask, tgt_key_pad_mask = make_tgt_masks(ys, tgt_pad_idx)
            tgt_mask = subsequent_mask.unsqueeze(0) | tgt_key_pad_mask.unsqueeze(1)
            memory_mask = src_key_pad_mask.expand(-1, ys.size(1), -1)

            out = model.decode(ys, memory, tgt_mask=tgt_mask, memory_mask=memory_mask)
            # Only the newest position's prediction matters for the next token.
            next_word = model.output_proj(out[:, -1, :]).argmax(dim=-1).item()
            ys = torch.cat([ys, ys.new_full((1, 1), next_word)], dim=1)
            if next_word == eos_idx:
                break
    return [tgt_itos.get(i, UNK) for i in ys.squeeze(0).tolist()]


# ---------- Main script ----------
def transformer_train(device_str="cuda" if torch.cuda.is_available() else "cpu"):
    """Train a Transformer translation model on the module-level ``dataset``.

    The function:
      1. builds vocabularies from ``dataset["train"]``;
      2. trains for a fixed number of epochs, saving the weights with the
         lowest validation loss to ``best_transformer_mt.pt``;
      3. reloads that checkpoint and reports its test loss;
      4. prints greedy translations for the first few test examples.

    ``dataset`` must be a DatasetDict with "train", "valid" and "test" splits.

    Args:
        device_str: torch device spec such as "cpu", "cuda" or "cuda:0".
    """
    # Use the requested device; the module-level ``device`` is only a default.
    device = torch.device(device_str)

    (src_vocab, src_itos), (tgt_vocab, tgt_itos) = build_vocabs(dataset, min_freq=2)
    print(f"Vocab sizes -> SRC: {len(src_vocab)}, TGT: {len(tgt_vocab)}")

    # Hyperparams
    NUM_LAYERS = 2
    EMBED_DIM = 256
    NUM_HEADS = 4
    FF_DIM = 512
    BATCH_SIZE = 64
    N_EPOCHS = 10
    LR = 1e-3
    N_SAMPLES = 3  # test sentences to translate at the end
    ckpt_path = "best_transformer_mt.pt"

    src_pad_idx = src_vocab[PAD]
    tgt_pad_idx = tgt_vocab[PAD]

    def _collate(batch):
        # Binds the vocabularies and the chosen device for the DataLoaders.
        return collate_fn(batch, src_vocab, tgt_vocab, device)

    def _loader(split, shuffle):
        # One DataLoader per split, all sharing batch size and collate function.
        return DataLoader(
            dataset[split], batch_size=BATCH_SIZE, shuffle=shuffle, collate_fn=_collate
        )

    train_loader = _loader("train", shuffle=True)
    valid_loader = _loader("valid", shuffle=False)
    test_loader = _loader("test", shuffle=False)

    model = TransformerModel(
        len(src_vocab),
        len(tgt_vocab),
        NUM_LAYERS,
        EMBED_DIM,
        NUM_HEADS,
        FF_DIM,
        max_len=100,
        dropout=0.1,
        pad_idx=src_pad_idx,
    ).to(device)

    criterion = nn.CrossEntropyLoss(ignore_index=tgt_pad_idx)
    optimizer = torch.optim.Adam(model.parameters(), lr=LR)

    best_valid_loss = float("inf")
    checkpoint_saved = False
    for epoch in range(1, N_EPOCHS + 1):
        start_time = time.time()
        train_loss = train_epoch(
            model, train_loader, optimizer, criterion, src_pad_idx, tgt_pad_idx, device
        )
        valid_loss = evaluate(
            model, valid_loader, criterion, src_pad_idx, tgt_pad_idx, device
        )
        end_time = time.time()
        print(
            f"Epoch: {epoch:02} | Train Loss: {train_loss:.4f} | Val Loss: {valid_loss:.4f} | Time: {(end_time-start_time):.2f}s"
        )

        if valid_loss < best_valid_loss:
            best_valid_loss = valid_loss
            torch.save(model.state_dict(), ckpt_path)
            checkpoint_saved = True
            print("\tSaved best model.")

    if not checkpoint_saved:
        # Nothing was written this run (e.g. zero epochs or NaN losses). Loading
        # now would hit a missing file or a stale checkpoint from an earlier run.
        print("No checkpoint saved in this run; skipping test evaluation and decoding.")
        return

    model.load_state_dict(torch.load(ckpt_path, map_location=device))
    model.to(device)

    if len(dataset["test"]) > 0:
        test_loss = evaluate(
            model, test_loader, criterion, src_pad_idx, tgt_pad_idx, device
        )
        print(f"Test Loss: {test_loss:.4f}")

    # Test greedy decode on the first few test examples (fewer if the split is small).
    n_samples = min(N_SAMPLES, len(dataset["test"]))
    for example in dataset["test"].select(range(n_samples)):
        src_sent = example["translation"][SRC_LANG]
        tgt_sent = example["translation"][TGT_LANG]
        src_idx = torch.tensor(
            [numericalize(src_vocab, tokenize_basic(src_sent))],
            dtype=torch.long,
            device=device,
        )
        pred_tokens = greedy_decode(
            model, src_idx, src_vocab, tgt_vocab, tgt_itos, max_len=50, device=device
        )
        # Drop the leading BOS and everything from the first EOS onwards.
        end = pred_tokens.index(EOS) if EOS in pred_tokens else len(pred_tokens)
        pred_tokens = pred_tokens[1:end]
        print("SRC:", src_sent)
        print("REF:", tgt_sent)
        print("PRED:", " ".join(pred_tokens))
        print("-" * 40)


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--device",
        type=str,
        default="cuda" if torch.cuda.is_available() else "cpu",
        help="Device to use, e.g. 'cpu', 'cuda', 'cuda:0'",
    )
    # parse_known_args: when this runs inside a Jupyter kernel, sys.argv carries
    # kernel flags (e.g. "-f <connection file>") that parse_args() would reject
    # with SystemExit. Unknown arguments are ignored instead.
    args, _unknown = parser.parse_known_args()
    transformer_train(args.device)