In [1]:
!pip install nltk tqdm download
!mkdir ./log

Collecting download
  Downloading download-0.3.5-py3-none-any.whl.metadata (3.8 kB)
Downloading download-0.3.5-py3-none-any.whl (8.8 kB)
Installing collected packages: download
Successfully installed download-0.3.5


In [2]:
import re
import os
import torch
import logging
import numpy as np
from download import download
from torch import nn
from tqdm import tqdm
from datetime import datetime
from torch.optim import AdamW
from collections import Counter
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset, DataLoader
from nltk.translate.bleu_score import corpus_bleu
from torch.optim.lr_scheduler import LambdaLR

In [3]:
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # 检查 Apple Silicon GPU 是否可用
    device = torch.device("mps")
else:
    device = torch.device("cpu")

print(f'device: {device}')

device: cuda


## # Vocab

In [4]:
class Vocab:
    def __init__(self, special_tokens=None):
        """
        :param special_tokens: 特殊标记列表，默认包含 '<PAD>': 0, '<UNK>':1, '<BOS>':2, '<EOS>':3
        """
        self.word2idx = {}
        self.idx2word = {}
        if special_tokens is None:
            special_tokens = ["<PAD>", "<UNK>", "<BOS>", "<EOS>"]
        for token in special_tokens:
            self.add_word(token)

    def add_word(self, word):
        """
        向词表中添加单词
        :param word: 要添加的单词
        """
        if word not in self.word2idx:
            idx = len(self.word2idx)
            self.word2idx[word] = idx
            self.idx2word[idx] = word

    def build_vocab(self, sentences, min_freq=1):
        """
        基于句子列表构建词表，根据词频排序分配索引
        :param sentences: 句子列表
        :param min_freq: 最小词频限制
        """
        # 统计词频
        word_freq = Counter()
        for sentence in sentences:
            words = self.tokenize(sentence)  # 使用自定义的分词函数
            for word in words:
                word_freq[word] += 1

        # 根据频率过滤词汇并排序
        sorted_words = sorted(word_freq.items(), key=lambda x: x[1], reverse=True)

        # 添加高频词到词表中
        for word, _ in sorted_words:
            if word_freq[word] >= min_freq:
                self.add_word(word)

    def encode(self, sentence, add_special_tokens=True, max_length=None):
        """
        将句子转换为索引序列，并可选择性地添加<BOS>和<EOS>
        :param sentence: 输入句子
        :param add_special_tokens: 是否添加特殊标记
        :param max_length: 句子的最大长度是多少，少了在末尾加入<PAD>，否则截断。
        :return: 索引序列
        """
        tokens = self.tokenize(sentence)
        if add_special_tokens:
            tokens = ["<BOS>"] + tokens + ["<EOS>"]

        arr = [self.word2idx.get(word, self.word2idx["<UNK>"]) for word in tokens]

        if max_length is None:
            return arr
        else:
            if len(arr) < max_length:
                arr += [0] * (max_length - len(arr))
                return arr
            else:
                return arr[: max_length - 1] + [3]

    def decode(self, indices, ignore_special_tokens=False):
        """
        将索引序列转换回句子，并可选择性地忽略特殊标记
        :param indices: 索引序列
        :param ignore_special_tokens: 是否忽略特殊标记
        :return: 句子
        """
        words = []
        for idx in indices:
            word = self.idx2word.get(idx, "<UNK>")
            if ignore_special_tokens and word in ["<PAD>", "<BOS>", "<EOS>"]:
                continue
            words.append(word)
        return words

    @staticmethod
    def load_data(file_path):
        """加载并返回文件中的所有句子"""
        with open(file_path, "r", encoding="utf-8") as file:
            sentences = file.readlines()
        return [sentence.strip() for sentence in sentences]

    def save_vocab(self, path):
        """将词表保存到指定路径"""
        with open(path, "w", encoding="utf-8") as f:
            for word, idx in self.word2idx.items():
                f.write(f"{word}\t{idx}\n")

    def load_vocab(self, path):
        """从指定路径加载词表"""
        self.word2idx = {}
        self.idx2word = {}
        with open(path, "r", encoding="utf-8") as f:
            for line in f:
                word, idx = line.strip().split("\t")
                idx = int(idx)
                self.word2idx[word] = idx
                self.idx2word[idx] = word

    def tokenize(self, sentence):
        """
        分词函数，保留标点符号作为独立的标记
        :param sentence: 输入句子
        :return: 分词后的列表
        """
        # 使用正则表达式分离单词和标点符号
        words_and_punct = re.findall(r"\w+|[^\w\s]", sentence, re.UNICODE)
        return words_and_punct

## # DataLoader

In [5]:
# 定义自定义数据集类
class TranslationDataset(Dataset):
    def __init__(self, src_sentences, trg_sentences, src_vocab, trg_vocab, max_length=50):
        """
        初始化翻译数据集
        :param src_sentences: 源语言句子列表
        :param trg_sentences: 目标语言句子列表
        :param src_vocab: 源语言词表 (Vocab 对象)
        :param trg_vocab: 目标语言词表 (Vocab 对象)
        :param max_length: 最大序列长度
        """
        self.src_sentences = src_sentences
        self.trg_sentences = trg_sentences
        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab
        self.max_length = max_length

    def __len__(self):
        return len(self.src_sentences)

    def __getitem__(self, idx):
        src_sentence = self.src_sentences[idx]
        trg_sentence = self.trg_sentences[idx]

        # 编码源语言和目标语言句子
        src_encoded = self.src_vocab.encode(src_sentence, add_special_tokens=True, max_length=self.max_length)
        trg_encoded = self.trg_vocab.encode(trg_sentence, add_special_tokens=True, max_length=self.max_length)

        return src_encoded, trg_encoded  # 返回元组 (src, trg)


# 数据加载函数
def load_datasets(
    train_path, valid_path, test_path, src_vocab, trg_vocab, task, max_length=50, batch_size=32, drop_last=False
):
    """
    加载训练集、验证集和测试集
    :param train_path: 训练集路径
    :param valid_path: 验证集路径
    :param test_path: 测试集路径
    :param src_vocab: 源语言词表 (Vocab 对象)
    :param trg_vocab: 目标语言词表 (Vocab 对象)
    :param task: 任务，en->de or de->en
    :param max_length: 最大序列长度
    :param batch_size: 批量大小
    :param drop_last: 是否丢弃最后一个不完整的批次
    :return: 训练集、验证集和测试集的 DataLoader
    """
    # 加载句子
    train_src = Vocab.load_data(os.path.join(train_path, "train.de"))
    train_trg = Vocab.load_data(os.path.join(train_path, "train.en"))

    valid_src = Vocab.load_data(os.path.join(valid_path, "val.de"))
    valid_trg = Vocab.load_data(os.path.join(valid_path, "val.en"))

    test_src = Vocab.load_data(os.path.join(test_path, "test2016.de"))
    test_trg = Vocab.load_data(os.path.join(test_path, "test2016.en"))

    if task != 'de->en':
        train_src, train_trg = train_trg, train_src
        valid_src, valid_trg = valid_trg, valid_src
        test_src, test_trg, test_trg, test_src

    # 创建数据集
    train_dataset = TranslationDataset(train_src, train_trg, src_vocab, trg_vocab, max_length=max_length)
    valid_dataset = TranslationDataset(valid_src, valid_trg, src_vocab, trg_vocab, max_length=max_length)
    test_dataset = TranslationDataset(test_src, test_trg, src_vocab, trg_vocab, max_length=max_length)

    # 创建 DataLoader
    train_loader = DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        collate_fn=lambda batch: list(zip(*batch)),
        drop_last=drop_last,  # 控制是否丢弃最后一个批次
    )
    valid_loader = DataLoader(
        valid_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda batch: list(zip(*batch)),
        drop_last=drop_last,  # 控制是否丢弃最后一个批次
    )
    test_loader = DataLoader(
        test_dataset,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=lambda batch: list(zip(*batch)),
        drop_last=drop_last,  # 控制是否丢弃最后一个批次
    )

    return train_loader, valid_loader, test_loader


'''
1. 数据集下载与解压
2. 构建德文和英文的词表
3. 构建DataLoader
'''
url = "https://modelscope.cn/api/v1/datasets/SelinaRR/Multi30K/repo?Revision=master&FilePath=Multi30K.zip"
datasets_path = "./datasets/"
train_path = os.path.join(datasets_path, "train/")
valid_path = os.path.join(datasets_path, "valid/")
test_path = os.path.join(datasets_path, "test/")

if not os.path.exists(datasets_path):
    download(url, "./", kind="zip", replace=True)
    print("Dataset downloaded and extracted.")
else:
    print("Dataset is already downloaded.")


config = {
    "task": "de->en",
    "max_length": 32,
    "batch_size": 128
}


de_sentences = Vocab.load_data(os.path.join(train_path, "train.de"))
en_sentences = Vocab.load_data(os.path.join(train_path, "train.en"))

de_vocab = Vocab()
de_vocab.build_vocab(de_sentences)

en_vocab = Vocab()
en_vocab.build_vocab(en_sentences)

print("德文词表大小:", len(de_vocab.word2idx))
print("英文词表大小:", len(en_vocab.word2idx))


# 词表
src_vocab, trg_vocab = (de_vocab, en_vocab) if config['task'] == 'de->en' else (en_vocab, de_vocab)


train_loader, valid_loader, test_loader = load_datasets(
    train_path=train_path,
    valid_path=valid_path,
    test_path=test_path,
    src_vocab=src_vocab,
    trg_vocab=trg_vocab,
    task=config['task'],
    max_length=config["max_length"],
    batch_size=config["batch_size"],
    drop_last=True
)

Downloading data from https://modelscope.cn/api/v1/datasets/SelinaRR/Multi30K/repo?Revision=master&FilePath=Multi30K.zip (1 byte)

file_sizes: 1.37MB [00:01, 1.13MB/s]
Extracting zip file...
Successfully downloaded / unzipped to ./
Dataset downloaded and extracted.
德文词表大小: 18487
英文词表大小: 10829


## # Model

In [6]:

'''
B: batch_size
L: max_len
H: hidden size of K, Q, V after mapping from input
D: embedding_size
'''
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():  # 检查 Apple Silicon GPU 是否可用
    device = torch.device("mps")
else:
    device = torch.device("cpu")


# 旋转位置编码
class PositionalEncoding(nn.Module):

    def __init__(self, embedding_size, p=0., max_len=100):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=p)

        pe = torch.zeros(max_len, embedding_size)  # (L, D)

        position = torch.arange(0, max_len, dtype=torch.float32).unsqueeze(1)  # (L,) -> (L, 1)
        div_term = torch.exp(torch.arange(0, embedding_size, 2, dtype=torch.float32) * (-torch.log(torch.tensor(10000.0)) / embedding_size))  # (D//2, )

        # (L, 1) * (D/2, ) -> (L, D/2) * 2 -> (L, D//2)
        pe[:, 0::2] = torch.sin(position * div_term)  # col列号, row是行号：pe[row] = sin(i / 10000^{col/D})
        pe[:, 1::2] = torch.cos(position * div_term)  # col奇数, row是行号：pe[row] = cos(i / 10000^{col-1/D})

        pe = pe.unsqueeze(0)  # (L, D) -> (1, L, D)
        self.register_buffer('pe', pe)

    def forward(self, x):
        '''
        x: (B, L, D) -> x + pos
        '''

        _, seq_len, _ = x.size()

        pos_encoding = self.pe[:, :seq_len, :]  # (1, L, D)
        x = x + pos_encoding  # (B, L, D)
        return self.dropout(x)


# 自注意力机制
class SelfAttention(nn.Module):
    def __init__(self, p=0.):
        super(SelfAttention, self).__init__()
        self.softmax = nn.Softmax(dim=-1)
        self.dropout = nn.Dropout(p=p)

    def forward(self, q: torch.Tensor, k: torch.Tensor, v, mask=None):
        '''
        q: (B, [N], Lq, D)
        k: (B, [N], Lk, D)
        v: (B, [N], Lk, D)
        mask: (B, [N], Lq, Hk), fill: -inf

        -> res: (B, [N], Lq, D), atten: (B, [N], Lk, Lk)
        '''
        embedding_size = q.size(dim=-1)
        d = torch.sqrt(torch.tensor(embedding_size, dtype=torch.float16))

        attention = torch.matmul(q, k.transpose(-1, -2) / d)  # (B, [N], Lq, D) @ (B, [N], D, Lk) / (1,) -> (B, [N], Lq, Lk)

        if mask is not None:
            attention = attention.masked_fill(mask, -torch.inf)  # mask: (B, 1, Lq, Lk) -> (B, [N], Lq, Lk), set True -> -inf

        attention = self.softmax(attention)  # softmax on dim of Hk
        attention = self.dropout(attention)

        res = torch.matmul(attention, v)  # (B, [N], Lq, Lk) @ (B, [N], Lk, D) -> (B, [N], Lq, D), 这里就是对Lk所在的维度进行加权平均

        return res, attention


# 多头注意力
class MultiHeadAttention(nn.Module):
    def __init__(self, embedding_size, new_embedding_size, head_size, p=0.):
        super(MultiHeadAttention, self).__init__()

        self.n_head = head_size
        self.new_embedding_size = new_embedding_size
        self.linear_q = nn.Linear(embedding_size, self.new_embedding_size * self.n_head)
        self.linear_k = nn.Linear(embedding_size, self.new_embedding_size * self.n_head)
        self.linear_v = nn.Linear(embedding_size, self.new_embedding_size * self.n_head)
        self.linear_o = nn.Linear(self.new_embedding_size * self.n_head, embedding_size)
        self.attention = SelfAttention(p=p)

    def forward(self, q: torch.Tensor, k, v, mask: torch.Tensor):
        '''
        q: (B, Lq, D)
        k: (B, Lk, D)
        v: (B, Lk, D)
        mask: (Lq, Lk), fill: -inf

        -> res: (B, Lq, D), atten: (B, N, Lq, Lk)
        '''
        batch_size = q.size(dim=0)

        Q = self.linear_q(q).view(batch_size, -1, self.n_head, self.new_embedding_size).transpose(1, 2)  # (B, Lq, N * D') -> (B, Lq, N, D') -> (B, N, Lq, D')
        K = self.linear_k(k).view(batch_size, -1, self.n_head, self.new_embedding_size).transpose(1, 2)  # (B, Lk, N * D') -> (B, Lk, N, D') -> (B, N, Lk, D')
        V = self.linear_v(v).view(batch_size, -1, self.n_head, self.new_embedding_size).transpose(1, 2)  # (B, Lk, N * D') -> (B, Lk, N, D') -> (B, N, Lk, D')

        mask = mask.unsqueeze(1)  # (B, 1, Lq, Lk)
        mask = mask.expand((-1, self.n_head, -1, -1))  # (B, N, Lq, Lk)

        res, atten = self.attention(Q, K, V, mask)  # (B, N, Lq, D')

        res = res.transpose(1, 2).reshape(batch_size, -1, self.new_embedding_size * self.n_head)  # (B, N, Lq, D') -> (B, Lq, N, D') -> (B, Lq, N * D')
        res = self.linear_o(res)  # (B, Lq, N * D') -> (B, Lq, D)

        return res, atten


# 位置前馈神经网络
class FFN(nn.Module):
    def __init__(self, hidden_size, embedding_size, p=0.):
        super(FFN, self).__init__()

        self.linear1 = nn.Linear(embedding_size, hidden_size)
        self.linear2 = nn.Linear(hidden_size, embedding_size)
        self.dropout = nn.Dropout(p=p)
        self.relu = nn.ReLU()

    def forward(self, x):
        '''
        x: (B, L, D)

        -> (B, L, D)
        '''

        res = self.linear1(x)
        res = self.relu(res)
        res = self.dropout(res)
        res = self.linear2(res)
        return res


# Add & Norm
class AddNorm(nn.Module):
    def __init__(self, embedding_size, p=0.):
        super(AddNorm, self).__init__()

        self.LN = nn.LayerNorm((embedding_size,))
        self.dropout = nn.Dropout(p=p)

    def forward(self, x, fx):
        '''
        x: (B, L, D)

        -> (B, L, D)
        '''

        fx = self.dropout(fx)
        return self.LN(x + fx)


# Mask = <PAD> + SeeForward
class Mask:
    '''
    where "True" is the place should be masked.

    seq_q: (B, Lq), seq after vocab encode for q.
    seq_k: (B, Lk), seq after vocab encode for k.

    -> (B, Lq, Lk), mask
    '''

    def get_padding_mask(self, seq_q, seq_k, who_is_pad=0):
        '''
        <PAD> mask

        seq_q: (B, Lq), seq after vocab encode for q.
        seq_k: (B, Lk), seq after vocab encode for k.

        -> (B, Lq, Lk), mask
        '''

        batch_size, Lq = seq_q.size()
        batch_size, Lk = seq_k.size()

        pad_mask = (seq_k == who_is_pad)  # (B, Lk)
        pad_mask = pad_mask.unsqueeze(1).expand(batch_size, Lq, Lk)  # (B, Lk) -> (B, 1, Lk) -> (B, Lq, Lk)

        return pad_mask

    def get_causal_mask(self, seq_q, seq_k):
        '''
        causal mask

        seq_q: (B, Lq), seq after vocab encode for q.
        seq_k: (B, Lk), seq after vocab encode for k.

        -> (B, Lq, Lk), mask
        '''

        B, Lq = seq_q.size()
        _, Lk = seq_k.size()

        mask = ~torch.tril(torch.ones(Lq, Lk)).bool()

        mask = mask.unsqueeze(0).expand(B, -1, -1)

        return mask


'''
Encoder
'''


class EncoderBlock(nn.Module):
    def __init__(self, embedding_size, head_size, ffn_size, p=0.):
        super(EncoderBlock, self).__init__()

        new_embedding_size = embedding_size // head_size

        if new_embedding_size * head_size != embedding_size:
            raise ValueError(f'make sure embedding_size % head_size == 0, but get: {embedding_size} and {head_size}.')

        self.encoder_self_attention = MultiHeadAttention(embedding_size, new_embedding_size, head_size, p)
        self.ffn = FFN(ffn_size, embedding_size, p)

        self.AN1 = AddNorm(embedding_size, p)
        self.AN2 = AddNorm(embedding_size, p)

    def forward(self, encoder_input, encoder_self_mask):
        '''
        encoder_input: (B, L) ~int ~raw_sentence
        encoder_self_mask: (B, L, L) ~bool

        -> (B, L, D) ~float
        '''

        # fx branch
        x = encoder_input

        fx, _ = self.encoder_self_attention(q=x, k=x, v=x, mask=encoder_self_mask)
        out1 = self.AN1(x, fx)

        x = out1

        fx = self.ffn(x)
        res = self.AN2(x, fx)

        return res


class Encoder(nn.Module):
    def __init__(self, src_vocab_size, embedding_size, head_size, ffn_size, num_blocks, p=0.):
        super(Encoder, self).__init__()

        self.embed = nn.Embedding(src_vocab_size, embedding_size)
        self.pos_embed = PositionalEncoding(embedding_size, p)

        self.blocks = nn.ModuleList([
            EncoderBlock(
                embedding_size=embedding_size,
                head_size=head_size,
                ffn_size=ffn_size,
                p=p
            ) for _ in range(num_blocks)
        ])

        self.scaling = torch.sqrt(torch.tensor(embedding_size))

    def forward(self, encoder_input, src_who_is_pad):
        '''
        encoder_input: (B, L) ~int ~raw_sentence
        src_who_is_pad: (,) ~int

        -> (B, L, D) ~float
        '''

        embeded_encoder_input = self.embed(encoder_input)
        embeded_encoder_input = self.pos_embed(embeded_encoder_input)

        encoder_self_mask = Mask().get_padding_mask(seq_q=encoder_input,
                                                    seq_k=encoder_input,
                                                    who_is_pad=src_who_is_pad)

        encoder_output = embeded_encoder_input
        for block in self.blocks:
            encoder_output = block(encoder_output, encoder_self_mask)

        return encoder_output


'''
Decoder
'''


class DecoderBlock(nn.Module):
    def __init__(self, embedding_size, head_size, ffn_size, p=0.):
        super(DecoderBlock, self).__init__()

        one_head_embedding_size = embedding_size // head_size
        if one_head_embedding_size * head_size != embedding_size:
            raise ValueError(f'make sure embedding_size % head_size == 0, but get: {embedding_size} and {head_size}.')

        self.decoder_self_attention = MultiHeadAttention(embedding_size, one_head_embedding_size, head_size, p)
        self.decoder_encoder_attention = MultiHeadAttention(embedding_size, one_head_embedding_size, head_size, p)

        self.ffn = FFN(ffn_size, embedding_size, p)

        self.AN1 = AddNorm(embedding_size, p)
        self.AN2 = AddNorm(embedding_size, p)
        self.AN3 = AddNorm(embedding_size, p)

    def forward(self, decoder_input, encoder_output, decoder_self_mask, decoder_encoder_mask):
        '''
        decoder_input: (B, L) ~int ~raw_sentence
        encoder_output: (B, L, D) ~float
        decoder_self_mask: (B, L, L) ~bool
        decoder_encoder_mask: (B, L, L) ~bool

        -> (B, L, D) ~float
        '''

        x = decoder_input

        fx, _ = self.decoder_self_attention(q=x, k=x, v=x, mask=decoder_self_mask)
        out1 = self.AN1(x, fx)

        x = out1
        fx, _ = self.decoder_encoder_attention(q=x, k=encoder_output, v=encoder_output, mask=decoder_encoder_mask)
        out2 = self.AN2(x, fx)

        x = out2

        fx = self.ffn(x)
        res = self.AN3(x, fx)

        return res


class Decoder(nn.Module):
    def __init__(self, trg_vocab_size, embedding_size, head_size, ffn_size, num_blocks, p=0.):
        super(Decoder, self).__init__()

        self.embed = nn.Embedding(trg_vocab_size, embedding_size)
        self.pos_embed = PositionalEncoding(embedding_size, p)

        self.blocks = nn.ModuleList([
            DecoderBlock(
                embedding_size=embedding_size,
                head_size=head_size,
                ffn_size=ffn_size,
                p=p
            ) for _ in range(num_blocks)
        ])

        self.scaling = torch.sqrt(torch.tensor(embedding_size))
        self.linear_out = nn.Linear(embedding_size, trg_vocab_size)

    def forward(self, decoder_input, encoder_input, encoder_output, src_who_is_pad, trg_who_is_pad):
        '''
        decoder_input: (B, L) ~int ~raw_sentence
        encoder_input: (B, L) ~int ~raw_sentence
        encoder_output: (B, L, D) ~float
        src_who_is_pad, trg_who_is_pad: (,) ~int

        -> (B, L, trg_vocab_size) ~float
        '''

        embeded_decoder_input = self.embed(decoder_input)
        embeded_decoder_input = self.pos_embed(embeded_decoder_input)

        decoder_self_padding_mask = Mask().get_padding_mask(decoder_input, decoder_input, trg_who_is_pad)
        decoder_self_causal_mask = Mask().get_causal_mask(decoder_input, decoder_input)
        decoder_self_mask = decoder_self_padding_mask.to(device) | decoder_self_causal_mask.to(device)  # can not use 'or' here

        decoder_encoder_padding_mask = Mask().get_padding_mask(decoder_input, encoder_input, src_who_is_pad)
        decoder_encoder_mask = decoder_encoder_padding_mask

        decoder_output = embeded_decoder_input
        for block in self.blocks:
            decoder_output = block(decoder_output, encoder_output, decoder_self_mask, decoder_encoder_mask)

        decoder_output = self.linear_out(decoder_output)

        return decoder_output


'''
Transformer
'''


class Transformer(nn.Module):
    def __init__(self, encoder, decoder):
        super(Transformer, self).__init__()

        self.encoder = encoder
        self.decoder = decoder

    def forward(self, encoder_input, decoder_input, src_who_is_pad, trg_who_is_pad):
        '''
        encoder_input: (B, L) ~int ~raw_sentence
        decoder_input: (B, L) ~int ~raw_sentence
        src_who_is_pad, trg_who_is_pad: (,) ~int
        '''
        encoder_output = self.encoder(encoder_input, src_who_is_pad)

        decoder_output = self.decoder(decoder_input, encoder_input, encoder_output, src_who_is_pad, trg_who_is_pad)

        logits = decoder_output.view((-1, decoder_output.shape[-1]))

        return logits

## # Training

In [7]:
# 配置日志
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler(f'log/training_{datetime.now().strftime("%Y%m%d_%H%M")}.log'),
        logging.StreamHandler()
    ]
)


class TransformerTrainer:
    def __init__(self, config):
        self.config = config

        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        self.device = device

        self.model = Transformer(
            encoder=Encoder(
                src_vocab_size=len(src_vocab.word2idx),
                embedding_size=config['d_model'],
                head_size=config['n_head'],
                ffn_size=config['ffn_size'],
                num_blocks=config['num_blocks'],
                p=config['dropout']
            ),
            decoder=Decoder(
                trg_vocab_size=len(trg_vocab.word2idx),
                embedding_size=config['d_model'],
                head_size=config['n_head'],
                ffn_size=config['ffn_size'],
                num_blocks=config['num_blocks'],
                p=config['dropout']
            )
        ).to(self.device)

        self.optimizer = AdamW(
            self.model.parameters(),
            lr=config['lr'],
            betas=(0.9, 0.98),
            eps=1e-9,
            weight_decay=config['weight_decay']
        )

        self.lr_scheduler = LambdaLR(
            self.optimizer,
            lr_lambda=lambda step: min(
                (step + 1) ** -0.5,
                (step + 1) * (config['warmup_steps'] ** -1.5)
            )
        )

        # 混合精度训练
        self.scaler = torch.amp.GradScaler(enabled=config['use_amp'])

        # 损失函数（忽略padding）
        self.criterion = nn.CrossEntropyLoss(ignore_index=0)

        os.makedirs(config['output_dir'], exist_ok=True)

    def _prepare_batch(self, batch):
        """处理原始批次数据"""
        src_batch, trg_batch = batch
        src_tensor = torch.tensor(src_batch).to(self.device)
        trg_tensor = torch.tensor(trg_batch).to(self.device)

        # 生成decoder输入输出
        trg_input = trg_tensor[:, :-1]  # 移除最后一个token
        trg_output = trg_tensor[:, 1:]  # 移除第一个token

        # 生成mask
        src_pad_mask = (src_tensor == 0).to(self.device)
        trg_pad_mask = (trg_input == 0).to(self.device)

        return {
            'src': src_tensor,
            'trg_input': trg_input,
            'trg_output': trg_output,
            'src_pad_mask': src_pad_mask,
            'trg_pad_mask': trg_pad_mask
        }

    def train_epoch(self, train_loader):
        self.model.train()
        total_loss = 0
        progress_bar = tqdm(train_loader, desc="Training", leave=False)

        for batch in progress_bar:
            prepared_batch = self._prepare_batch(batch)

            self.optimizer.zero_grad()

            with torch.cuda.amp.autocast(enabled=self.config['use_amp']):
                logits = self.model(
                    encoder_input=prepared_batch['src'],
                    decoder_input=prepared_batch['trg_input'],
                    src_who_is_pad=0,
                    trg_who_is_pad=0
                )

                loss = self.criterion(
                    logits.view(-1, len(trg_vocab.idx2word)),
                    prepared_batch['trg_output'].contiguous().view(-1)
                )

            self.scaler.scale(loss).backward()
            self.scaler.unscale_(self.optimizer)
            torch.nn.utils.clip_grad_norm_(
                self.model.parameters(),
                self.config['max_grad_norm']
            )

            self.scaler.step(self.optimizer)
            self.scaler.update()
            self.lr_scheduler.step()

            total_loss += loss.item()
            progress_bar.set_postfix({"loss": f"{loss.item():.4f}"})

        return total_loss / len(train_loader)

    @torch.no_grad()
    def validate(self, val_loader):
        self.model.eval()
        total_loss = 0

        for batch in val_loader:
            prepared_batch = self._prepare_batch(batch)

            logits = self.model(
                encoder_input=prepared_batch['src'],
                decoder_input=prepared_batch['trg_input'],
                src_who_is_pad=0,
                trg_who_is_pad=0
            )

            loss = self.criterion(
                logits.view(-1, len(trg_vocab.idx2word)),
                prepared_batch['trg_output'].contiguous().view(-1)
            )
            total_loss += loss.item()

        return total_loss / len(val_loader)

    def save_checkpoint(self, epoch, is_best=False):
        state = {
            'epoch': epoch,
            'model_state': self.model.state_dict(),
            'optimizer_state': self.optimizer.state_dict(),
            'scaler_state': self.scaler.state_dict(),
            'config': self.config
        }

        filename = f"checkpoint_epoch_{epoch}.pt" if not is_best else "best_model.pt"
        torch.save(state, os.path.join(self.config['output_dir'], filename))
        logging.info(f"Checkpoint saved: {filename}")

    def train(self, train_loader, val_loader=None):
        best_loss = float('inf')

        for epoch in range(1, self.config['num_epochs'] + 1):
            logging.info(f"Epoch {epoch}/{self.config['num_epochs']}")
            print(f"Epoch {epoch}/{self.config['num_epochs']}")

            train_loss = self.train_epoch(train_loader)
            logging.info(f"Train Loss: {train_loss:.4f}")
            print(f"Train Loss: {train_loss:.4f}")

            if val_loader:
                val_loss = self.validate(val_loader)
                logging.info(f"Val Loss: {val_loss:.4f}")
                print(f"Val Loss: {val_loss:.4f}")

                if val_loss < best_loss:
                    best_loss = val_loss
                    self.save_checkpoint(epoch, is_best=True)

            if epoch % self.config['save_interval'] == 0:
                self.save_checkpoint(epoch)

In [8]:
# 加载训练配置
config = {
    "d_model": 512,
    "n_head": 8,
    "ffn_size": 2048,
    "num_blocks": 6,
    "dropout": 0.1,
    "num_epochs": 20,
    "lr": 0.005,
    "weight_decay": 0.01,
    "warmup_steps": 4000,
    "max_grad_norm": 1.0,
    "output_dir": "./ckpts",
    "save_interval": 5,
    "use_amp": True
}

# 初始化训练器
trainer = TransformerTrainer(config)

try:
    trainer.train(train_loader, valid_loader)
except KeyboardInterrupt:
    logging.info("Training interrupted. Saving checkpoint...")
    trainer.save_checkpoint(epoch='interrupted')
except Exception as e:
    logging.error(f"Training failed: {str(e)}")
    raise

Epoch 1/20


  with torch.cuda.amp.autocast(enabled=self.config['use_amp']):
                                                                        

Train Loss: 8.0778
Val Loss: 6.7698
Epoch 2/20


                                                                        

Train Loss: 6.2368
Val Loss: 5.5060
Epoch 3/20


                                                                        

Train Loss: 5.1223
Val Loss: 4.5634
Epoch 4/20


                                                                        

Train Loss: 4.3753
Val Loss: 3.9801
Epoch 5/20


                                                                        

Train Loss: 3.9249
Val Loss: 3.5942
Epoch 6/20


                                                                        

Train Loss: 3.6047
Val Loss: 3.3150
Epoch 7/20


                                                                        

Train Loss: 3.3390
Val Loss: 3.0749
Epoch 8/20


                                                                        

Train Loss: 3.1117
Val Loss: 2.8909
Epoch 9/20


                                                                        

Train Loss: 2.9123
Val Loss: 2.7255
Epoch 10/20


                                                                        

Train Loss: 2.7363
Val Loss: 2.5890
Epoch 11/20


                                                                        

Train Loss: 2.5833
Val Loss: 2.4877
Epoch 12/20


                                                                        

Train Loss: 2.4444
Val Loss: 2.3877
Epoch 13/20


                                                                        

Train Loss: 2.3160
Val Loss: 2.2937
Epoch 14/20


                                                                        

Train Loss: 2.1969
Val Loss: 2.2342
Epoch 15/20


                                                                        

Train Loss: 2.0871
Val Loss: 2.1601
Epoch 16/20


                                                                        

Train Loss: 1.9821
Val Loss: 2.1062
Epoch 17/20


                                                                        

Train Loss: 1.8810
Val Loss: 2.0488
Epoch 18/20


                                                                        

Train Loss: 1.7832
Val Loss: 2.0077
Epoch 19/20


                                                                        

Train Loss: 1.6854
Val Loss: 1.9783
Epoch 20/20


                                                                        

Train Loss: 1.5947
Val Loss: 1.9475


## # Inference

In [9]:
class Translator:
    def __init__(self, model_path, src_vocab, trg_vocab):
        if torch.cuda.is_available():
            device = torch.device("cuda")
        elif torch.backends.mps.is_available():  # 检查 Apple Silicon GPU 是否可用
            device = torch.device("mps")
        else:
            device = torch.device("cpu")

        self.device = device

        self.src_vocab = src_vocab
        self.trg_vocab = trg_vocab

        # 加载模型
        checkpoint = torch.load(model_path, map_location=self.device)
        self.config = checkpoint['config']

        self.model = Transformer(
            encoder=Encoder(
                src_vocab_size=len(src_vocab.word2idx),
                embedding_size=self.config['d_model'],
                head_size=self.config['n_head'],
                ffn_size=self.config['ffn_size'],
                num_blocks=self.config['num_blocks'],
                p=0.0  # 推理时关闭dropout
            ),
            decoder=Decoder(
                trg_vocab_size=len(trg_vocab.word2idx),
                embedding_size=self.config['d_model'],
                head_size=self.config['n_head'],
                ffn_size=self.config['ffn_size'],
                num_blocks=self.config['num_blocks'],
                p=0.0
            )
        ).to(self.device)

        self.model.load_state_dict(checkpoint['model_state'])
        self.model.eval()

    def _prepare_input(self, src_seq):
        """处理输入序列"""
        src_tensor = torch.tensor([2] + src_seq + [3]).to(self.device)

        # 添加batch维度并填充
        src_tensor = src_tensor.unsqueeze(0)  # [1, seq_len]
        src_pad_mask = (src_tensor == 0)
        return src_tensor, src_pad_mask

    def translate(self, src_seq, max_length=50):
        """使用贪心算法进行翻译"""
        src_tensor, src_pad_mask = self._prepare_input(src_seq)

        # 初始化decoder输入
        decoder_input = torch.tensor([[2]]).to(self.device)

        # 自回归生成
        for _ in range(max_length):
            with torch.no_grad():
                logits = self.model(
                    encoder_input=src_tensor,
                    decoder_input=decoder_input,
                    src_who_is_pad=0,
                    trg_who_is_pad=0
                )

            # 获取最后一个token的预测
            # print(logits)
            next_token = logits[-1, :].argmax(-1)
            # print(next_token)
            # print()
            # print(f'decoder input: {decoder_input}')
            decoder_input = torch.cat(
                [decoder_input, next_token.unsqueeze(0).unsqueeze(0)], dim=-1
            )

            # 遇到EOS则停止，同时拼接EOS
            if next_token.item() == 3:
                decoder_input = torch.cat(
                    [decoder_input, torch.tensor(3).unsqueeze(0).unsqueeze(0).to(self.device)], dim=-1
                )
                break

        # 转换为token列表
        output_tokens = decoder_input[0].cpu().tolist()

        # 去除特殊token并解码
        filtered = [
            t for t in output_tokens
            if t not in {2, 3, 0}
        ]

        return self.trg_vocab.decode(filtered)

    def calculate_bleu(self, test_loader_):
        """计算整个测试集的BLEU分数"""
        references = []
        hypotheses = []

        for batch in tqdm(test_loader_, desc="Calculating BLEU"):
            src_batch, trg_batch = batch

            # 处理每个样本
            for src_seq, trg_seq in zip(src_batch, trg_batch):
                # 解码参考翻译
                ref = [self.trg_vocab.decode([t for t in trg_seq if t not in {0}])]

                # 生成模型翻译
                hyp = self.translate(src_seq)

                references.append(ref)
                hypotheses.append(hyp)

        # 计算corpus BLEU
        return corpus_bleu(references, hypotheses)

In [10]:
# 初始化翻译器
translator = Translator(
    model_path='ckpts/best_model.pt',
    src_vocab=src_vocab,
    trg_vocab=trg_vocab
)

  checkpoint = torch.load(model_path, map_location=self.device)


### # 1. Some Examples 

In [11]:
# 示例翻译 - 德语到英语，包含标准答案
example_src_with_refs = [
    ("Die Katze ist auf dem Tisch .", "The cat is on the table."),
    ("Sie liest jeden Tag ein Buch .", "She reads a book every day."),
    ("Der Hund spielt im Garten .", "The dog is playing in the garden."),
    ("Wir gehen ins Kino heute Abend .", "We are going to the cinema this evening."),
    ("Das Wetter ist sehr schön heute .", "The weather is very nice today."),
    ("Er hat eine rote Jacke an .", "He is wearing a red jacket."),
    ("Ich habe Hunger .", "I am hungry."),
    ("Es gibt viele Bücher im Regal .", "There are many books on the shelf."),
    ("Kannst du mir helfen, bitte ?", "Can you help me, please?"),
    ("Morgen werde ich einkaufen gehen .", "Tomorrow I will go shopping.")
]

print("\nExample Translations (German -> English):")
for src, ref_translation in example_src_with_refs:
    # 编码源句子
    src_encoded = src_vocab.encode(src, add_special_tokens=False)

    # 生成翻译
    translation = translator.translate(src_encoded)
    translation = " ".join(translation)

    # 打印双语对照结果，包括标准答案
    print()
    print(f"Source (German)                : {src}")
    print(f"Reference Translation (English): {ref_translation}")
    print(f"Model Translation (English)    : {translation}")


Example Translations (German -> English):

Source (German)                : Die Katze ist auf dem Tisch .
Reference Translation (English): The cat is on the table.
Model Translation (English)    : The cat is on the table .

Source (German)                : Sie liest jeden Tag ein Buch .
Reference Translation (English): She reads a book every day.
Model Translation (English)    : They are reading a book on the day .

Source (German)                : Der Hund spielt im Garten .
Reference Translation (English): The dog is playing in the garden.
Model Translation (English)    : The dog is playing in the garden .

Source (German)                : Wir gehen ins Kino heute Abend .
Reference Translation (English): We are going to the cinema this evening.
Model Translation (English)    : We are walking in the air at the ocean .

Source (German)                : Das Wetter ist sehr schön heute .
Reference Translation (English): The weather is very nice today.
Model Translation (English)    : Th

### # 2. BLEU

In [12]:
# 计算 BLEU 分数
if True:
    bleu_score = translator.calculate_bleu(test_loader)
    print(f"\nBLEU Score: {bleu_score:.4f}")

Calculating BLEU: 100%|██████████| 7/7 [02:20<00:00, 20.00s/it]


BLEU Score: 0.2966



