In [1]:
from tqdm import tqdm
import torch
from datasets import load_dataset, concatenate_datasets, Dataset
from tokenizers import Tokenizer

# Default compute device for the notebook: the GPU when CUDA is usable, otherwise the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

In [2]:
class AttentionEDM(torch.nn.Module):
    """LSTM encoder-decoder translation model with dot-product attention.

    The encoder LSTM is stepped one source token at a time and the hidden
    state of every layer is recorded at every step. At each decoding step the
    embedding of the current decoder input token is used as the attention
    query over those recorded states; the attention result replaces the
    decoder LSTM's hidden state, while the cell state is carried over from
    the previous step (initially from the encoder's final step).
    """

    def __init__(
        self,
        en_tokenizer: Tokenizer,
        zh_tokenizer: Tokenizer,
        emb_size: int,
        hidden_size: int,
        num_layers: int,
        dropout: float = 0,
    ):
        """Build the embeddings, the two LSTMs and the output projection.

        Args:
            en_tokenizer: Tokenizer for the source text; its vocabulary size
                sizes the encoder embedding.
            zh_tokenizer: Tokenizer for the target text; its vocabulary size
                sizes the decoder embedding and the output layer, and it is
                queried for the ids of "[BOS]", "[PAD]" and "[EOS]".
            emb_size: Width of the encoder embedding.
            hidden_size: Hidden width of both LSTMs. It is also the width of
                the decoder embedding, because that embedding is multiplied
                directly against the encoder hidden states.
            num_layers: Number of layers in both LSTMs.
            dropout: Dropout passed to both LSTMs.
        """
        super().__init__()
        self.en_tokenizer = en_tokenizer
        self.zh_tokenizer = zh_tokenizer
        input_vs = en_tokenizer.get_vocab_size()
        output_vs = zh_tokenizer.get_vocab_size()
        self.bos_ind = zh_tokenizer.token_to_id("[BOS]")
        self.pad_ind = zh_tokenizer.token_to_id("[PAD]")
        self.eos_ind = zh_tokenizer.token_to_id("[EOS]")
        # Model body
        self.enc_embedding = torch.nn.Embedding(input_vs, emb_size)
        self.enc_lstm = torch.nn.LSTM(emb_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.dec_embedding = torch.nn.Embedding(output_vs, hidden_size)
        self.dec_lstm = torch.nn.LSTM(hidden_size, hidden_size, num_layers, dropout=dropout, batch_first=True)
        self.lm = torch.nn.Linear(hidden_size, output_vs)

    def _encode(self, source: torch.Tensor):
        """Run the encoder over `source` and collect its per-step hidden states.

        The LSTM is stepped token by token (rather than called once on the
        whole sequence) because the hidden state of *every* layer is needed at
        every time step, which a single call does not return.

        Returns:
            A pair `(encoded, cell)`: `encoded` stacks the per-step hidden
            states along a new second-to-last axis, and `cell` is the cell
            state after the last source token.
        """
        embed_source: torch.Tensor = self.enc_embedding(source)
        encoder_hiddens = []
        state = None
        for x in embed_source.unbind(1):
            _, state = self.enc_lstm(x.unsqueeze(1), state)
            encoder_hiddens.append(state[0])
        # torch.stack raises here if `source` has no time steps.
        encoded = torch.stack(encoder_hiddens, dim=-2)
        return encoded, state[1]

    def _decode_step(self, x: torch.Tensor, encoded: torch.Tensor, encoded_trans: torch.Tensor, scale: float, cell: torch.Tensor):
        """Run one attention + decoder-LSTM step for the input token ids `x`.

        Returns:
            A pair `(output, cell)`: the vocabulary logits for this step and
            the decoder cell state to feed into the next step.
        """
        embed = self.dec_embedding(x)
        # The leading axis added here broadcasts the query against the
        # leading (per-layer) axis of the recorded encoder states.
        scores = embed.unsqueeze(0) @ encoded_trans
        weights = torch.softmax(scores / scale, dim=-1)
        # Attention-weighted mix of encoder states, used as the LSTM's h_0.
        hidden = (weights @ encoded).squeeze(-2)
        dec_output, (_, cell) = self.dec_lstm(embed, (hidden, cell))
        return self.lm(dec_output), cell

    def forward(self, source: torch.Tensor, target: torch.Tensor, teacher_forcing_ratio: float = 0.5):
        """Compute logits for every target position.

        Args:
            source: Source token ids; assumed to be (batch, source_len).
            target: Target token ids; assumed to be (batch, target_len).
            teacher_forcing_ratio: Probability, drawn independently at each
                step, of feeding the reference token as the next decoder
                input instead of the model's own argmax prediction. Values
                <= 0 never teacher-force; values >= 1 always do.

        Returns:
            The per-step logits concatenated along dim 1, one step per target
            position.
        """
        encoded, cell = self._encode(source)
        encoded_trans = torch.transpose(encoded, -2, -1)
        # Scores are scaled by the square root of the number of source steps
        # (the last axis of `encoded_trans`), not of the hidden width.
        sqrtT: float = encoded_trans.size(-1) ** 0.5
        outputs = []
        x = target[:, 0].unsqueeze(-1)
        for target_t in target.unbind(1):
            output, cell = self._decode_step(x, encoded, encoded_trans, sqrtT, cell)
            outputs.append(output)
            if teacher_forcing_ratio > 0 and torch.rand(1).item() < teacher_forcing_ratio:
                x = target_t.unsqueeze(1)
            else:
                x = output.argmax(dim=-1)
        return torch.cat(outputs, 1)

    def translate(self, inputs: str, device=DEVICE, max_len_ratio=2.0):
        """Greedily translate one sentence and return the decoded text.

        Decoding runs without gradient tracking and in eval mode (so LSTM
        dropout is inactive); the module's previous train/eval mode is
        restored before returning.

        Args:
            inputs: Source sentence.
            device: Device on which the input tensors are created; it must
                match the device holding the model parameters.
            max_len_ratio: The number of decoding steps is capped at
                `int(len(inputs) * max_len_ratio)`, where `len(inputs)` is
                the length of the input string in characters.

        Returns:
            The decoded target string with " ##" word-piece joins removed.
        """
        # NOTE(review): the source is wrapped with the BOS/EOS ids looked up
        # in the *target* tokenizer; this presumes both tokenizers assign the
        # same ids to these special tokens — confirm.
        encoded_inputs = [self.bos_ind] + self.en_tokenizer.encode(inputs).ids + [self.eos_ind]
        source = torch.tensor(encoded_inputs, dtype=torch.long, device=device).unsqueeze(0)
        outs = []
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                encoded, cell = self._encode(source)
                encoded_trans = torch.transpose(encoded, -2, -1)
                sqrtT: float = encoded_trans.size(-1) ** 0.5
                x = torch.tensor([self.bos_ind], dtype=torch.long, device=device).unsqueeze(0)
                for _ in range(int(len(inputs) * max_len_ratio)):
                    output, cell = self._decode_step(x, encoded, encoded_trans, sqrtT, cell)
                    x = output.argmax(dim=-1)
                    xi = x[0].item()
                    # BOS and PAD predictions are dropped from the result but
                    # are still fed back as the next decoder input.
                    if xi == self.bos_ind or xi == self.pad_ind:
                        continue
                    if xi == self.eos_ind:
                        break
                    outs.append(xi)
        finally:
            self.train(was_training)
        return self.zh_tokenizer.decode(outs, skip_special_tokens=False).replace(" ##", "")

In [83]:
# NOTE(security): weights_only=False unpickles arbitrary Python objects, so only
# load checkpoint files from a trusted source.
# map_location lets a checkpoint that was saved on a GPU load on a CPU-only machine.
model = torch.load("translator.pth", map_location=DEVICE, weights_only=False)
if not isinstance(model, AttentionEDM):
    raise TypeError(f"translator.pth holds a {type(model).__name__}, expected AttentionEDM")

model.to(DEVICE)

print("模型参数量 ", sum(p.numel() for p in model.parameters()))

模型参数量  10153384


In [89]:
# Smoke-test the loaded model on a couple of short sentences.
for sentence in ("I love You", "All roads lead to Rome"):
    print(model.translate(sentence))

我爱你
所有的道路都变成了罗马


In [76]:
def dataset_processing(dataset: Dataset, bos_id: int = 2, eos_id: int = 3, en_tokenizer=None, zh_tokenizer=None):
    """Filter a parallel corpus by character set and tokenize both sides.

    A pair is kept only when every character of its "english" text passes
    `english_rule` and every character of its "non_english" text passes
    `chinese_rule`.

    Args:
        dataset: Iterable of records with "english" and "non_english" keys.
        bos_id: Token id prepended to every encoded sentence.
        eos_id: Token id appended to every encoded sentence.
        en_tokenizer: Tokenizer for the English side; defaults to the
            notebook-level `model.en_tokenizer`.
        zh_tokenizer: Tokenizer for the Chinese side; defaults to the
            notebook-level `model.zh_tokenizer`.

    Returns:
        `(en_dataset, zh_dataset)`: two index-aligned lists of token-id lists,
        each wrapped as `[bos_id, ..., eos_id]`.
    """
    # Resolve the tokenizers up front so a missing `model` fails before the
    # (long) filtering pass rather than after it.
    if en_tokenizer is None:
        en_tokenizer = model.en_tokenizer
    if zh_tokenizer is None:
        zh_tokenizer = model.zh_tokenizer

    def english_rule(code: int):
        # ASCII, CJK Symbols and Punctuation (U+3000–U+303F), or fullwidth
        # ASCII variants (U+FF01–U+FF5E).
        return code < 128 or 0x3000 <= code <= 0x303F or 0xFF01 <= code <= 0xFF5E

    def chinese_rule(code: int):
        # Everything allowed on the English side plus CJK Unified Ideographs
        # (U+4E00–U+9FFF).
        return code < 128 or 0x4E00 <= code <= 0x9FFF or 0x3000 <= code <= 0x303F or 0xFF01 <= code <= 0xFF5E

    en_list = []
    zh_list = []
    for data in tqdm(dataset, desc="pre-processing"):
        en = data["english"]
        zh = data["non_english"]
        # Generator expressions let all() stop at the first disallowed character.
        if all(english_rule(ord(c)) for c in en) and all(chinese_rule(ord(c)) for c in zh):
            en_list.append(en)
            zh_list.append(zh)

    print(f"训练文本长度: {len(en_list)}")

    en_dataset = [[bos_id] + en_tokenizer.encode(en).ids + [eos_id] for en in en_list]
    zh_dataset = [[bos_id] + zh_tokenizer.encode(zh).ids + [eos_id] for zh in zh_list]
    return en_dataset, zh_dataset

In [None]:
# Load the "train" split of each English–Chinese parallel corpus, in this order.
wikimatrix, opensubtitles, talks, tatoeba = (
    load_dataset(repo, config)["train"]
    for repo, config in (
        ("sentence-transformers/parallel-sentences-wikimatrix", "en-zh"),
        ("sentence-transformers/parallel-sentences-opensubtitles", "en-zh_cn"),
        ("sentence-transformers/parallel-sentences-talks", "en-zh-cn"),
        ("sentence-transformers/parallel-sentences-tatoeba", "en-zh"),
    )
)
# The ccmatrix corpus is currently left out:
# ccmatrix = load_dataset("sentence-transformers/parallel-sentences-ccmatrix", "en-zh")["train"]

# Merge all loaded corpora into a single training corpus.
dataset = concatenate_datasets([wikimatrix, opensubtitles, talks, tatoeba])

In [None]:
# Filter the merged corpus and tokenize it into two index-aligned lists of token-id lists.
en_dataset, zh_dataset = dataset_processing(dataset)

In [None]:
from torch.nn.utils.rnn import pad_sequence


def batch_processing(dataset: list[list[int]], batch_size: int, device=DEVICE, padding_value: int = 0):
    """Split token-id sequences into padded batches.

    Args:
        dataset: Token-id sequences, in the order they should be batched.
        batch_size: Maximum number of sequences per batch; must be positive.
        device: Device the finished batches are moved to.
        padding_value: Id used to right-pad shorter sequences in a batch.

    Returns:
        A list of int64 tensors, one per consecutive slice of `batch_size`
        sequences, each of shape (sequences_in_batch, longest_sequence_in_batch).
        The last batch may hold fewer sequences.

    Raises:
        ValueError: If `batch_size` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    return [
        # Pad on the CPU and move each finished batch once, instead of
        # copying every individual sequence to the device separately.
        pad_sequence(
            [torch.tensor(x, dtype=torch.long) for x in dataset[i : i + batch_size]],
            batch_first=True,
            padding_value=padding_value,
        ).to(device)
        for i in range(0, len(dataset), batch_size)
    ]

In [None]:
# Batch into new names instead of overwriting en_dataset / zh_dataset, so this
# cell can be re-run without feeding already-batched tensors back into
# batch_processing. If host memory is tight, `del en_dataset, zh_dataset` afterwards.
BATCH_SIZE = 32
en_batches = batch_processing(en_dataset, BATCH_SIZE)
zh_batches = batch_processing(zh_dataset, BATCH_SIZE)
if len(en_batches) != len(zh_batches):
    raise ValueError(f"source/target batch counts differ: {len(en_batches)} vs {len(zh_batches)}")

# Pair each source batch with the target batch built from the same sentence slice.
train_dataset = list(zip(en_batches, zh_batches))

In [None]:
import random

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

# Training loop
train_dataset_length = len(train_dataset)
if train_dataset_length == 0:
    raise ValueError("train_dataset is empty; nothing to train on")

# Teacher-forcing schedule: the ratio starts at start_rate and falls linearly by
# epoch_rate per epoch (step_rate per batch). Ratios above 1 always feed the
# reference token; ratios at or below 0 always feed the model's own prediction.
start_rate = 1.2
epoch_rate = 0.1
step_rate = epoch_rate / train_dataset_length

num_epochs = 20
lossi = []  # per-batch loss history

for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    random.shuffle(train_dataset)
    for step, (en_batch, zh_batch) in enumerate(tqdm(train_dataset, desc=f"Epoch {epoch+1}/{num_epochs}")):
        tf_rate = start_rate - epoch * epoch_rate - step * step_rate
        # Call the module itself rather than .forward() so registered hooks run.
        outputs = model(en_batch, zh_batch, tf_rate)
        # cross_entropy expects the class axis second, hence the transpose.
        # NOTE(review): padded target positions also count toward the loss;
        # consider ignore_index=<pad id> once the padding id is confirmed.
        loss = torch.nn.functional.cross_entropy(outputs.transpose(-2, -1), zh_batch)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        loss = loss.item()
        lossi.append(loss)
        total_loss += loss
    # Checkpoint the whole model object after every epoch (0-based epoch in the name).
    torch.save(model, f"translator-epoch{epoch}.pth")
    print(f"Average Loss {total_loss/train_dataset_length:.4f}")

# Leave training mode before running inference so dropout is disabled.
model.eval()
print(model.translate("hello world"))

In [None]:
# Spot-check the model on one more short phrase.
print(model.translate("site packages"))