Reference: https://github.com/bentrevett/pytorch-seq2seq/

# 数据准备

设置随机数种子，得到一致的结果

In [1]:
import random
import numpy as np
import torch

seed = 1234

random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.backends.cudnn.deterministic=True

## Datasets

使用opus-100数据集的en-zh子数据集

In [2]:
import datasets
from datasets import load_dataset

ds = load_dataset("Helsinki-NLP/opus-100", "en-zh")

train_data = ds["train"].shuffle(seed=seed).select(range(100000)) # 采样 x 个数据
valid_data = ds["validation"]
test_data = ds["test"]

  from .autonotebook import tqdm as notebook_tqdm


如果提示hub连接失败，可是试试换源

Huggleface镜像源替换环境变量

export HF_ENDPOINT=https://hf-mirror.com

$env:HF_ENDPOINT = "https://hf-mirror.com"

检验dataset是否下载和加载成功

In [3]:
print(ds)
print(train_data[0])

DatasetDict({
    test: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
    train: Dataset({
        features: ['translation'],
        num_rows: 1000000
    })
    validation: Dataset({
        features: ['translation'],
        num_rows: 2000
    })
})
{'translation': {'en': "So they couldn't question you.", 'zh': '为了你不会为了我而撒谎'}}


## Tokenizer

接下来使用spacy进行分词，即将一个句子中的单词和短语分离出来，方便进行相关处理和学习训练。

在分词之前，我们需要下载spacy的相关分析模型。

In [None]:
!python -m spacy download zh_core_web_sm

!python -m spacy download en_core_web_sm

或者使用pip的github连接下载，本地使用pip安装也可，注意安装环境。

pip install https://github.com/explosion/spacy-models/releases/download/zh_core_web_sm-3.7.0/zh_core_web_sm-3.7.0-py3-none-any.whl

pip install https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.0/en_core_web_sm-3.7.0-py3-none-any.whl

加载模型

In [4]:
import spacy

en_nlp = spacy.load("en_core_web_sm")
zh_nlp = spacy.load("zh_core_web_sm")

测试加载结果

In [5]:
test_text1 = "This is amazing!"
test_text2 = "这好棒啊"

test_token1 = [token.text for token in en_nlp.tokenizer(test_text1)]
test_token2 = [token.text for token in zh_nlp.tokenizer(test_text2)]
print(test_token1)
print(test_token2)

['This', 'is', 'amazing', '!']
['这', '好', '棒', '啊']


接下来创建一个函数用于tokenizer，将相应的数据集数据进行分词。

In [6]:
def tokenize_en_zh(example, en_nlp, zh_nlp, max_length, lower, sos_token, eos_token):
    en_tokens = [token.text for token in en_nlp.tokenizer(example['translation']["en"])][:max_length]
    zh_tokens = [token.text for token in zh_nlp.tokenizer(example['translation']["zh"])][:max_length]
    if lower:
        en_tokens = [token.lower() for token in en_tokens]
    en_tokens = [sos_token] + en_tokens + [eos_token]
    zh_tokens = [sos_token] + zh_tokens + [eos_token]
    return {"en_tokens":en_tokens,"zh_tokens":zh_tokens}

max_length = 100
lower = True
sos_token = "<sos>"
eos_token = "<eos>"

fn_kwargs = {
    "en_nlp": en_nlp,
    "zh_nlp": zh_nlp,
    "max_length": max_length,
    "lower": lower,
    "sos_token": sos_token,
    "eos_token": eos_token,
}

train_data = train_data.map(tokenize_en_zh, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(tokenize_en_zh, fn_kwargs=fn_kwargs)
test_data = test_data.map(tokenize_en_zh, fn_kwargs=fn_kwargs)

Map: 100%|██████████| 100000/100000 [01:30<00:00, 1109.30 examples/s]
Map: 100%|██████████| 2000/2000 [00:01<00:00, 1025.20 examples/s]
Map: 100%|██████████| 2000/2000 [00:02<00:00, 963.36 examples/s]


测试一下分词结果。

In [7]:
print(train_data)
print(train_data[0])

Dataset({
    features: ['translation', 'en_tokens', 'zh_tokens'],
    num_rows: 100000
})
{'translation': {'en': "So they couldn't question you.", 'zh': '为了你不会为了我而撒谎'}, 'en_tokens': ['<sos>', 'so', 'they', 'could', "n't", 'question', 'you', '.', '<eos>'], 'zh_tokens': ['<sos>', '为了', '你', '不', '会', '为了', '我', '而', '撒谎', '<eos>']}


## Vocabularies

接下来开始构建词表，将每个单词用一个对应的索引编号来表示。

In [8]:
import torchtext.vocab

min_freq = 2 # 出现次数少于这个的不建立索引
# 特殊词元
unk_token = "<unk>"
pad_token = "<pad>"
sos_token = "<sos>"
eos_token = "<eos>"

special_tokens = {
    unk_token,
    pad_token,
    sos_token,
    eos_token,
}

en_vocab = torchtext.vocab.build_vocab_from_iterator(
    train_data["en_tokens"],
    min_freq=min_freq,
    specials=special_tokens,
    # max_tokens=20000, # 限制词表长度，一般不用设置
)

zh_vocab = torchtext.vocab.build_vocab_from_iterator(
    train_data["zh_tokens"],
    min_freq=min_freq,
    specials=special_tokens,
    # max_tokens=60000, # 限制词表长度，一般不用设置
)

# 处理默认返回结果
en_vocab.set_default_index(en_vocab[unk_token])
zh_vocab.set_default_index(zh_vocab[unk_token])



查看词表建立结果

In [9]:
print(en_vocab.get_itos()[:10])
print(zh_vocab.get_itos()[:10])
print(len(en_vocab))
print(len(zh_vocab))

['<eos>', '<sos>', '<unk>', '<pad>', 'the', ',', '.', 'of', 'and', 'to']
['<eos>', '<sos>', '<unk>', '<pad>', '的', '，', '。', '和', '在', '、']
23384
29412


接下来创建一个对数据集进行numericalize编码的函数。

In [10]:
def numericalize_en_zh(example, en_vocab, zh_vocab):
    en_ids = en_vocab.lookup_indices(example["en_tokens"])
    zh_ids = zh_vocab.lookup_indices(example["zh_tokens"])
    return {"en_ids": en_ids, "zh_ids": zh_ids}

fn_kwargs = {"en_vocab": en_vocab, "zh_vocab": zh_vocab}
train_data = train_data.map(numericalize_en_zh, fn_kwargs=fn_kwargs)
valid_data = valid_data.map(numericalize_en_zh, fn_kwargs=fn_kwargs)
test_data = test_data.map(numericalize_en_zh, fn_kwargs=fn_kwargs)

Map: 100%|██████████| 100000/100000 [00:10<00:00, 9417.99 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 7967.01 examples/s]
Map: 100%|██████████| 2000/2000 [00:00<00:00, 7632.27 examples/s]


查看numericalize结果

In [11]:
train_data[0]

{'translation': {'en': "So they couldn't question you.", 'zh': '为了你不会为了我而撒谎'},
 'en_tokens': ['<sos>',
  'so',
  'they',
  'could',
  "n't",
  'question',
  'you',
  '.',
  '<eos>'],
 'zh_tokens': ['<sos>', '为了', '你', '不', '会', '为了', '我', '而', '撒谎', '<eos>'],
 'en_ids': [1, 82, 61, 150, 46, 442, 16, 6, 0],
 'zh_ids': [1, 354, 14, 16, 60, 354, 10, 89, 5518, 0]}

将ids使用with_format转换为pytorch的tensor类型

In [12]:
data_type = "torch"
format_columns = ["en_ids","zh_ids"]

train_data = train_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

valid_data = valid_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

test_data = test_data.with_format(
    type=data_type, columns=format_columns, output_all_columns=True
)

检查结果

In [13]:
print(type(train_data[0]["en_ids"]))
print(train_data[0])

<class 'torch.Tensor'>
{'en_ids': tensor([  1,  82,  61, 150,  46, 442,  16,   6,   0]), 'zh_ids': tensor([   1,  354,   14,   16,   60,  354,   10,   89, 5518,    0]), 'translation': {'en': "So they couldn't question you.", 'zh': '为了你不会为了我而撒谎'}, 'en_tokens': ['<sos>', 'so', 'they', 'could', "n't", 'question', 'you', '.', '<eos>'], 'zh_tokens': ['<sos>', '为了', '你', '不', '会', '为了', '我', '而', '撒谎', '<eos>']}


## DataLoader

最后一步将数据装入pytorch的DataLoader中

collate_fn 接收一个batch将其中的en_ids和zh_ids进行padding

In [14]:
from torch import nn

def get_collate_fn(pad_index):
    def collate_fn(batch):
        batch_en_ids = [example["en_ids"] for example in batch]
        batch_zh_ids = [example["zh_ids"] for example in batch]
        batch_en_ids = nn.utils.rnn.pad_sequence(batch_en_ids, padding_value=pad_index)
        batch_zh_ids = nn.utils.rnn.pad_sequence(batch_zh_ids, padding_value=pad_index)
        batch = {
            "en_ids": batch_en_ids,
            "zh_ids": batch_zh_ids,
        }
        return batch

    return collate_fn

接下来创建dataloader

In [15]:
import torch.utils.data

def get_data_loader(dataset, batch_size, pad_index, shuffle=False):
    collate_fn = get_collate_fn(pad_index)
    data_loader = torch.utils.data.DataLoader(
        dataset=dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        shuffle=shuffle
    )
    return data_loader

In [24]:
batch_size = 128
pad_index = en_vocab[pad_token]

train_data_loader = get_data_loader(train_data, batch_size, pad_index, shuffle=True)
valid_data_loader = get_data_loader(valid_data, batch_size, pad_index)
test_data_loader = get_data_loader(test_data, batch_size, pad_index)

# 创建模型

## Encoder

In [16]:
from torch import nn

class Encoder(nn.Module):
    def __init__(self, input_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.hidden_dim=hidden_dim
        self.n_layers=n_layers

        self.embedding = nn.Embedding(input_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        # x = [length, batch]
        embedded = self.dropout(self.embedding(x))
        # embedded = [length, batch, embedding dim]
        outputs, (hidden, cell) = self.rnn(embedded)
        # outputs = [length, batch, hidden_dim * n directions] # n directions 单向LSTM为1双向为2
        # hidden,cell = [n layers * n directions, batch, hidden dim]
        return hidden, cell

## Decoder

In [17]:
from torch import nn

class Decoder(nn.Module):
    def __init__(self, output_dim, embedding_dim, hidden_dim, n_layers, dropout):
        super().__init__()
        self.output_dim=output_dim
        self.hidden_dim=hidden_dim
        self.n_layers=n_layers

        self.embedding = nn.Embedding(output_dim, embedding_dim)
        self.rnn = nn.LSTM(embedding_dim, hidden_dim, n_layers, dropout=dropout)
        self.fc_out = nn.Linear(hidden_dim, output_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x, hidden, cell):
        # x = [batch]
        # hidden,cell = [n layers * n directions, batch, hidden dim]
        x = x.unsqueeze(0)
        # x = [1, batch]
        embedded = self.dropout(self.embedding(x))
        # embedded = [1, batch, embedding dim]
        output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
        # output = [seq length, batch, hidden dim * n directions]
        # hidden/cell = [n layers * n directions, batch, hidden dim]
        # seq length & n directions = 1 here
        prediction = self.fc_out(output.squeeze(0))
        # prediction = [batch, output dim]
        return prediction, hidden, cell

## Seq2Seq

In [18]:
from torch import nn
import random

class Seq2Seq(nn.Module):
    def __init__(self, encoder, decoder, device):
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.device = device
        assert (
            encoder.hidden_dim == decoder.hidden_dim
            ), "Hidden dimensions of encoder and decoder must be equal!"
        assert (
            encoder.n_layers == decoder.n_layers
            ), "Encoder and decoder must have equal number of layers!"
        
    def forward(self, src, trg, teacher_forcing_ratio):
        # src = [src length, batch]
        # trg = [trg length, batch]
        # if teacher_forcing_ratio is 0.75 we user ground-truth inputs 75% of the time
        batch_size = trg.shape[1]
        trg_length = trg.shape[0]
        trg_vocab_size = self.decoder.output_dim
        # tensor to store decoder outputs
        outputs = torch.zeros(trg_length, batch_size, trg_vocab_size).to(self.device)
        # last hidden state of encoder
        hidden, cell = self.encoder(src)
        # first input to decoder is <sos>
        input = trg[0, :]
        for t in range(1, trg_length):
            output, hidden, cell = self.decoder(input, hidden, cell)
            # place predictions in a tensor
            outputs[t] = output
            # decide if we are going to use teacher forcing or not
            teacher_force = random.random() < teacher_forcing_ratio
            # get higgest predicted token from predictions
            top1 = output.argmax(1)
            # if teacher_force use actual as input else top1
            input = trg[t] if teacher_force else top1
        return outputs

# 训练

## 初始化模型

训练之前需要先初始化模型和加载数据集（已完成）。

In [19]:
input_dim = len(en_vocab)
output_dim = len(zh_vocab)
encoder_embedding_dim = 256
decoder_embedding_dim = 256
hidden_dim = 512
n_layers = 2
encoder_dropout = 0.5
decoder_dropout = 0.5
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

encoder = Encoder(
    input_dim,
    encoder_embedding_dim,
    hidden_dim,
    n_layers,
    encoder_dropout
)

decoder = Decoder(
    output_dim,
    decoder_embedding_dim,
    hidden_dim,
    n_layers,
    decoder_dropout
)

model = Seq2Seq(encoder, decoder, device).to(device)

均匀分布初始化权重

使用apply的时候，这个函数将会在每个模块和子模块中调用，对每个模块使用nn.init.uniform_进行均匀采样

In [20]:
def init_weights(m):
    for name, param in m.named_parameters():
        nn.init.uniform_(param.data, -0.08, 0.08)

model.apply(init_weights)

Seq2Seq(
  (encoder): Encoder(
    (embedding): Embedding(23384, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (dropout): Dropout(p=0.5, inplace=False)
  )
  (decoder): Decoder(
    (embedding): Embedding(29412, 256)
    (rnn): LSTM(256, 512, num_layers=2, dropout=0.5)
    (fc_out): Linear(in_features=512, out_features=29412, bias=True)
    (dropout): Dropout(p=0.5, inplace=False)
  )
)

查看模型参数个数

In [21]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f"The mode has {count_parameters(model):,} trainable parameters")

The mode has 35,960,548 trainable parameters


## 优化器

In [22]:
import torch.optim

optimizer = torch.optim.Adam(model.parameters())

## 损失函数

In [25]:
criterion = nn.CrossEntropyLoss(ignore_index=pad_index)

## 训练

训练用函数

In [26]:
def train_fn(model, data_loader, optimizer, criterion, clip, teacher_forcing_ratio, device):
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(data_loader):
        src = batch["en_ids"].to(device)
        trg = batch["zh_ids"].to(device)
        # src/trg = [length, batch]
        optimizer.zero_grad()
        output = model(src, trg, teacher_forcing_ratio)
        # output = [length, batch, vocab size]
        output_dim = output.shape[-1]
        output = output[1:].view(-1, output_dim)
        # output = [(length-1)*batch, vocab size]
        trg = trg[1:].view(-1)
        # trg = [(length-1)*batch]
        loss = criterion(output, trg)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
        optimizer.step()
        epoch_loss += loss.item()
    return epoch_loss / len(data_loader)

evaluate用函数

In [27]:
def evaluate_fn(model, data_loader, criterion, device):
    model.eval()
    epoch_loss = 0
    with torch.no_grad():
        for i, batch in enumerate(data_loader):
            src = batch["en_ids"].to(device)
            trg = batch["zh_ids"].to(device)
            output = model(src, trg, 0) # turnoff teacher forcing
            output_dim = output.shape[-1]
            output = output[1:].view(-1, output_dim)
            trg = trg[1:].view(-1)
            loss = criterion(output, trg)
            epoch_loss += loss.item()
    return epoch_loss/len(data_loader)

开始训练

In [28]:
import tqdm

n_epochs = 20
clip = 1.0
teacher_forcing_ratio = 0.5

best_valid_loss = float("inf")

for epoch in tqdm.tqdm(range(n_epochs)):
    train_loss = train_fn(
        model,
        train_data_loader,
        optimizer,
        criterion,
        clip,
        teacher_forcing_ratio,
        device
    )
    
    valid_loss = evaluate_fn(
        model,
        valid_data_loader,
        criterion,
        device
    )
    if valid_loss < best_valid_loss:
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), "./model/s2s-enzh-model.pt")
    print(f"\tTrain Loss: {train_loss:7.3f} | Train PPL: {np.exp(train_loss):7.3f}")
    print(f"\tValid Loss: {valid_loss:7.3f} | Valid PPL: {np.exp(valid_loss):7.3f}")

  5%|▌         | 1/20 [24:00<7:36:07, 1440.38s/it]

	Train Loss:   6.667 | Train PPL: 785.712
	Valid Loss:   6.594 | Valid PPL: 730.956


 10%|█         | 2/20 [47:59<7:11:58, 1439.91s/it]

	Train Loss:   6.056 | Train PPL: 426.514
	Valid Loss:   6.443 | Valid PPL: 628.249


 15%|█▌        | 3/20 [1:11:56<6:47:31, 1438.32s/it]

	Train Loss:   5.760 | Train PPL: 317.385
	Valid Loss:   6.391 | Valid PPL: 596.191


 20%|██        | 4/20 [1:35:57<6:23:48, 1439.30s/it]

	Train Loss:   5.554 | Train PPL: 258.144
	Valid Loss:   6.273 | Valid PPL: 529.931


 25%|██▌       | 5/20 [2:00:29<6:02:46, 1451.12s/it]

	Train Loss:   5.392 | Train PPL: 219.718
	Valid Loss:   6.210 | Valid PPL: 497.898


 30%|███       | 6/20 [2:25:10<5:40:59, 1461.39s/it]

	Train Loss:   5.241 | Train PPL: 188.834
	Valid Loss:   6.155 | Valid PPL: 471.056


 35%|███▌      | 7/20 [2:49:27<5:16:18, 1459.85s/it]

	Train Loss:   5.116 | Train PPL: 166.673
	Valid Loss:   6.103 | Valid PPL: 447.229


 40%|████      | 8/20 [3:15:35<4:58:52, 1494.40s/it]

	Train Loss:   5.000 | Train PPL: 148.407
	Valid Loss:   6.084 | Valid PPL: 438.979


 45%|████▌     | 9/20 [3:41:22<4:36:58, 1510.79s/it]

	Train Loss:   4.896 | Train PPL: 133.725
	Valid Loss:   6.087 | Valid PPL: 439.964


 50%|█████     | 10/20 [4:06:59<4:13:09, 1518.97s/it]

	Train Loss:   4.810 | Train PPL: 122.781
	Valid Loss:   6.082 | Valid PPL: 438.076


 55%|█████▌    | 11/20 [4:31:19<3:45:06, 1500.72s/it]

	Train Loss:   4.726 | Train PPL: 112.858
	Valid Loss:   6.026 | Valid PPL: 414.110


 60%|██████    | 12/20 [4:54:32<3:15:43, 1467.93s/it]

	Train Loss:   4.648 | Train PPL: 104.374
	Valid Loss:   6.037 | Valid PPL: 418.483


 65%|██████▌   | 13/20 [5:17:54<2:48:57, 1448.20s/it]

	Train Loss:   4.568 | Train PPL:  96.335
	Valid Loss:   6.081 | Valid PPL: 437.270


 70%|███████   | 14/20 [5:41:23<2:23:38, 1436.34s/it]

	Train Loss:   4.506 | Train PPL:  90.532
	Valid Loss:   6.066 | Valid PPL: 430.855


 75%|███████▌  | 15/20 [6:04:45<1:58:48, 1425.78s/it]

	Train Loss:   4.440 | Train PPL:  84.774
	Valid Loss:   6.038 | Valid PPL: 419.192


 80%|████████  | 16/20 [6:29:24<1:36:07, 1441.88s/it]

	Train Loss:   4.384 | Train PPL:  80.181
	Valid Loss:   6.084 | Valid PPL: 438.971


 85%|████████▌ | 17/20 [6:53:14<1:11:55, 1438.39s/it]

	Train Loss:   4.329 | Train PPL:  75.862
	Valid Loss:   6.106 | Valid PPL: 448.764


 90%|█████████ | 18/20 [7:16:45<47:40, 1430.20s/it]  

	Train Loss:   4.269 | Train PPL:  71.452
	Valid Loss:   6.136 | Valid PPL: 462.366


 95%|█████████▌| 19/20 [7:40:18<23:44, 1424.84s/it]

	Train Loss:   4.228 | Train PPL:  68.547
	Valid Loss:   6.143 | Valid PPL: 465.648


100%|██████████| 20/20 [8:03:37<00:00, 1450.87s/it]

	Train Loss:   4.176 | Train PPL:  65.131
	Valid Loss:   6.143 | Valid PPL: 465.423





# 评估模型

首先测试模型的loss

In [29]:
model.load_state_dict(torch.load("./model/s2s-enzh-model.pt"))
test_loss = evaluate_fn(model, test_data_loader, criterion, device)
print(f"\tTest Loss: {test_loss:7.3f} | Test PPL: {np.exp(test_loss):7.3f}")

	Test Loss:   5.983 | Test PPL: 396.545


接下来评估模型的BELU

首先是翻译用函数

In [30]:
def translate_sentence(sentence, model, en_nlp, zh_nlp, en_vocab, zh_vocab, 
                       lower, sos_token, eos_token, device, max_output_length=25,):
    model.eval()
    with torch.no_grad():
        if isinstance(sentence, str):
            tokens = [token.text for token in en_nlp.tokenizer(sentence)]
        else:
            tokens = [token for token in sentence]
        tokens = [sos_token] + tokens + [eos_token]
        ids = en_vocab.lookup_indices(tokens)
        tensor = torch.LongTensor(ids).unsqueeze(-1).to(device)
        hidden, cell = model.encoder(tensor)
        inputs = zh_vocab.lookup_indices([sos_token])
        for _ in range(max_output_length):
            inputs_tensor = torch.LongTensor([inputs[-1]]).to(device)
            output, hidden, cell = model.decoder(inputs_tensor, hidden, cell)
            predicted_token = output.argmax(-1).item()
            inputs.append(predicted_token)
            if predicted_token == zh_vocab[eos_token]:
                break
        tokens = zh_vocab.lookup_tokens(inputs)
    return tokens

测试翻译函数

In [31]:
sentence = test_data[0]['translation']['en']
expected_translation = test_data[0]['translation']['zh']

print(sentence)
print(expected_translation)

translation = translate_sentence(sentence, model, en_nlp, zh_nlp, en_vocab, zh_vocab,
                                 lower, sos_token, eos_token, device)
print(translation)

The Global Programme of Action Coordination Office, with the financial support of Belgium, is currently assisting Egypt, Nigeria, United Republic of Tanzania, Sri Lanka and Yemen to develop pilot national programmes of action for the protection of the marine environment from land-based activities.
9. 《全球行动纲领》协调处得到比利时的财政支持，目前正帮助埃及、尼日利亚、坦桑尼亚联合共和国、斯里兰卡和也门制订试行的保护海洋环境免受陆地活动影响的国家行动方案。
['<sos>', '<unk>', '.', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>', '<unk>']


接下来将test_data进行翻译

In [32]:
translations = [translate_sentence(example['translation']["en"], model, en_nlp, zh_nlp, en_vocab, zh_vocab,
                                 lower, sos_token, eos_token, device)
                for example in tqdm.tqdm(test_data)]

predictions = ["".join(translation[1:-1]) for translation in translations]
references = [example['translation']["zh"] for example in test_data]

100%|██████████| 2000/2000 [00:26<00:00, 75.52it/s]


查看预测和参考内容

In [33]:
print(predictions[0])
print(references[0])

<unk>.<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk>
9. 《全球行动纲领》协调处得到比利时的财政支持，目前正帮助埃及、尼日利亚、坦桑尼亚联合共和国、斯里兰卡和也门制订试行的保护海洋环境免受陆地活动影响的国家行动方案。


将结果tokenize

In [34]:
def get_tokenizer_fn(nlp, lower):
    def tokenizer_fn(s):
        tokens = [token.text for token in nlp.tokenizer(s)]
        if lower:
            tokens = [token.lower() for token in tokens]
        return tokens
    return tokenizer_fn

tokenizer_fn = get_tokenizer_fn(zh_nlp, lower)

测试函数

In [35]:
print(tokenizer_fn(predictions[0]))
print(tokenizer_fn(references[0]))

['<un', 'k>', '.', '<unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><unk><un', 'k>']
['9', '.', '《', '全球', '行动', '纲领', '》', '协调', '处', '得到', '比利时', '的', '财政', '支持', '，', '目前', '正', '帮助', '埃及', '、', '尼日利亚', '、', '坦桑尼亚', '联合', '共和国', '、', '斯里兰卡', '和', '也门', '制订', '试行', '的', '保护', '海洋', '环境', '免受', '陆地', '活动', '影响', '的', '国家', '行动', '方案', '。']


计算bleu

In [36]:
import evaluate

bleu = evaluate.load("bleu")
results = bleu.compute(
    predictions=predictions, references=references, tokenizer=tokenizer_fn
)

查看结果BLEU

In [37]:
print(results)

{'bleu': 0.002740053664609093, 'precisions': [0.2146313838550247, 0.013195639701663799, 0.0032295569047926624, 0.0012392477037469019], 'brevity_penalty': 0.26555402230057806, 'length_ratio': 0.4299342614931716, 'translation_length': 19424, 'reference_length': 45179}
