In [1]:
%matplotlib inline


Language Translation with nn.Transformer and torchtext
======================================================

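This tutorial shows how to train a translation model from scratch using the
Transformer. It adapts the PyTorch tutorial of the same name, which trains a German-to-English model on the
`Multi30k <http://www.statmt.org/wmt16/multimodal-task.html#task1>`__ dataset; here we instead use the
news-commentary-v16 parallel corpus to train an English-to-Chinese translation model.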



Data Sourcing and Processing
----------------------------

The `torchtext library <https://pytorch.org/text/stable/>`__ has utilities for tokenizing text and
building vocabularies for a language translation model. Rather than torchtext's inbuilt
`Multi30k dataset <https://pytorch.org/text/stable/datasets.html#multi30k>`__, this example loads the
parallel corpus from a TSV file with pandas and wraps it in a small ``Dataset`` that yields pairs of raw
source-target sentences. We then tokenize each raw sentence, build a vocabulary, and numericalize
the tokens into tensors.
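The tokenize, build-vocabulary, numericalize pipeline is handled below by spaCy tokenizers and torchtext's ``build_vocab_from_iterator``. As a framework-free sketch of the same idea (the helper names ``build_vocab`` and ``numericalize`` are hypothetical, and the tie-breaking rule for equally frequent tokens is an assumption), the special symbols take the first indices, unseen tokens fall back to ``<unk>``, and every sentence is wrapped in ``<bos>``/``<eos>``:

```python
from collections import Counter
from typing import Dict, Iterable, List

# Same special-symbol order as the notebook: <unk>=0, <pad>=1, <bos>=2, <eos>=3
SPECIALS = ['<unk>', '<pad>', '<bos>', '<eos>']

def build_vocab(token_lists: Iterable[List[str]], min_freq: int = 1) -> Dict[str, int]:
    """Map tokens to indices: specials first, then tokens by descending frequency."""
    counter = Counter(tok for toks in token_lists for tok in toks)
    vocab = {sym: i for i, sym in enumerate(SPECIALS)}
    # Assumed tie-break: equally frequent tokens are ordered alphabetically
    for tok, freq in sorted(counter.items(), key=lambda kv: (-kv[1], kv[0])):
        if freq >= min_freq and tok not in vocab:
            vocab[tok] = len(vocab)
    return vocab

def numericalize(tokens: List[str], vocab: Dict[str, int]) -> List[int]:
    """Look up each token (unknown -> <unk>) and add <bos>/<eos> around the sentence."""
    unk = vocab['<unk>']
    return [vocab['<bos>']] + [vocab.get(t, unk) for t in tokens] + [vocab['<eos>']]

corpus = [['i', 'like', 'apple'], ['i', 'miss', 'you']]
vocab = build_vocab(corpus)
print(numericalize(['i', 'like', 'pear'], vocab))  # [2, 4, 6, 0, 3]
```

The unseen word ``pear`` maps to index 0, which is what ``set_default_index(UNK_IDX)`` arranges for the real torchtext vocabularies.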





In [2]:
import pandas as pd

data_dir = './data/news-commentary-v16.en-zh.tsv'  # tab-separated English/Chinese sentence pairs
seed = 520

# Load the first 200,000 pairs and drop rows where either side is missing
data = pd.read_csv(data_dir, sep='\t', encoding='utf-8', names=['src', 'tgt'], nrows=200000)
data = data.dropna(axis=0).reset_index(drop=True)
# 80/20 split: sample 80% of the rows for training and keep the rest for validation
train_data = data.sample(frac=0.8, random_state=seed)
val_data = data.drop(train_data.index).reset_index(drop=True)
train_data = train_data.reset_index(drop=True)
data

,src,tgt
0,1929 or 1989?,1929年还是1989年?
1,PARIS – As the economic crisis deepens and wid...,巴黎-随着经济危机不断加深和蔓延，整个世界一直在寻找历史上的类似事件希望有助于我们了解目前正...
2,"At the start of the crisis, many people likene...",一开始，很多人把这次危机比作1982年或1973年所发生的情况，这样得类比是令人宽心的，因为...
3,"Today, the mood is much grimmer, with referenc...",如今人们的心情却是沉重多了，许多人开始把这次危机与1929年和1931年相比，即使一些国家政...
4,The tendency is either excessive restraint (Eu...,目前的趋势是，要么是过度的克制（欧洲 ） ， 要么是努力的扩展（美国 ） 。
...,...,...
194690,"In the near future, drugs may even be repositi...",不久之后，药物甚至可以实时再利用，对准每一位患者的独特需要。
194691,New technologies like AI and augmented reality...,人工智能和增强现实等新科技将有助于医生发现疑难杂症的新疗法。
194692,"To realize this future, however, systematic re...",但要实现这一未来，各大制药企业都必须将系统性再利用纳入商业模式。
194693,"Fortunately, there are strong economic incenti...",幸运的是，公司有强力的经济激励这样做。


In [3]:
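from torch.utils.data import Dataset

class MyDataset(Dataset):
    """Wraps a DataFrame with 'src' and 'tgt' columns and yields (source, target) raw sentence pairs."""
    def __init__(self, data) -> None:
        super().__init__()
        self.data = data

    def __len__(self):
        return len(self.data)

    def __getitem__(self, index):
        # Positional lookup; the DataFrames above have a fresh RangeIndex after reset_index
        return self.data['src'].iloc[index], self.data['tgt'].iloc[index]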

In [4]:
from torchtext.data.utils import get_tokenizer
from torchtext.vocab import build_vocab_from_iterator
from typing import Iterable, Iterator, List
import os
import torch


SRC_LANGUAGE = 'en'
TGT_LANGUAGE = 'zh'
# Place-holders
token_transform = {}
vocab_transform = {}


# Create source and target language tokenizers. Make sure to install the dependencies:
# pip install -U spacy
# python -m spacy download en_core_web_sm
# python -m spacy download zh_core_web_sm

token_transform[SRC_LANGUAGE] = get_tokenizer('spacy', language='en_core_web_sm')
token_transform[TGT_LANGUAGE] = get_tokenizer('spacy', language='zh_core_web_sm')

# helper function to yield the token list of every sentence in one language
def yield_tokens(data_iter: Iterable, language: str) -> Iterator[List[str]]:
    language_index = {SRC_LANGUAGE: 0, TGT_LANGUAGE: 1}

    for i in range(len(data_iter)):
        yield token_transform[language](data_iter[i][language_index[language]])

# Define special symbols and indices
UNK_IDX, PAD_IDX, BOS_IDX, EOS_IDX = 0, 1, 2, 3
# Make sure the tokens are in order of their indices to properly insert them in vocab
special_symbols = ['<unk>', '<pad>', '<bos>', '<eos>']
 
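# The vocabulary is built from the full dataset, i.e. training and validation sentences together
data_iter = MyDataset(data)
vocab_transform_dir = './checkpoints/pytorch_vocab.pkl'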

if not os.path.exists(vocab_transform_dir):
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
        # Create torchtext's Vocab object 
        vocab_transform[ln] = build_vocab_from_iterator(yield_tokens(data_iter, ln),
                                                        min_freq=1,
                                                        specials=special_symbols,
                                                        special_first=True)

    # Set UNK_IDX as the default index. This index is returned when the token is not found. 
    # If not set, it throws RuntimeError when the queried token is not found in the Vocabulary. 
    for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
        vocab_transform[ln].set_default_index(UNK_IDX)
    print('save vocab')
    os.makedirs(os.path.dirname(vocab_transform_dir), exist_ok=True)  # make sure ./checkpoints exists
    torch.save(vocab_transform, vocab_transform_dir)
else:
    print('load vocab')
    vocab_transform=torch.load(vocab_transform_dir)

load vocab


Seq2Seq Network using Transformer
---------------------------------

The Transformer is a Seq2Seq model introduced in the `“Attention is all you
need” <https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf>`__
paper for solving machine translation tasks.
Below, we will create a Seq2Seq network that uses the Transformer. The network
consists of three parts. The first part is the embedding layer, which converts a tensor of input indices
into the corresponding tensor of input embeddings. These embeddings are further augmented with positional
encodings that give the model information about the position of each input token. The second part is the
actual `Transformer <https://pytorch.org/docs/stable/generated/torch.nn.Transformer.html>`__ model.
Finally, the output of the Transformer model is passed through a linear layer
that gives unnormalized log-probabilities (logits) for each token in the target language.
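The ``PositionalEncoding`` module below implements the sinusoidal encoding from the paper: PE(pos, 2k) = sin(pos / 10000^(2k/d)) and PE(pos, 2k+1) = cos(pos / 10000^(2k/d)), where d is the embedding size. A minimal plain-Python sketch of one row of that table (the function name ``sinusoidal_pe`` is hypothetical, and lists stand in for tensors):

```python
import math
from typing import List

def sinusoidal_pe(pos: int, emb_size: int) -> List[float]:
    """Positional-encoding vector for one position; emb_size is assumed to be even."""
    pe = [0.0] * emb_size
    for i in range(0, emb_size, 2):
        # Same frequency term as PositionalEncoding: exp(-i * ln(10000) / d) == 10000 ** (-i / d)
        den = math.exp(-i * math.log(10000) / emb_size)
        pe[i] = math.sin(pos * den)      # even dimensions use sine
        pe[i + 1] = math.cos(pos * den)  # odd dimensions use cosine
    return pe

print(sinusoidal_pe(0, 4))  # [0.0, 1.0, 0.0, 1.0]
```

Each successive pair of dimensions oscillates more slowly than the previous one, so every position receives a distinct pattern that is added to the token embedding.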




In [5]:
from torch import Tensor
import torch
import torch.nn as nn
from torch.nn import Transformer
import math

# helper Module that adds positional encoding to the token embedding to introduce a notion of word order.
class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super(PositionalEncoding, self).__init__()
        den = torch.exp(- torch.arange(0, emb_size, 2)* math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        pos_embedding = pos_embedding.unsqueeze(-2)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        return self.dropout(token_embedding + self.pos_embedding[:token_embedding.size(0), :])

# helper Module to convert tensor of input indices into corresponding tensor of token embeddings
class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size):
        super(TokenEmbedding, self).__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)

# Seq2Seq Network 
class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super(Seq2SeqTransformer, self).__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout)
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(
            emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None, 
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)

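    def predict(self, src: Tensor, tgt: Tensor):
        """Next-token logits for sequence-first src (src_len, batch) and tgt (tgt_len, batch)."""
        device = src.device
        num_tokens = src.shape[0]
        src_mask = torch.zeros(num_tokens, num_tokens, dtype=torch.bool, device=device)
        memory = self.encode(src, src_mask)
        # generate_square_subsequent_mask is defined in the next cell; it is looked up at call time
        tgt_mask = generate_square_subsequent_mask(tgt.size(0)).type(torch.bool).to(device)
        out = self.decode(tgt, memory, tgt_mask)  # (tgt_len, batch, emb_size)
        out = out.transpose(0, 1)                 # (batch, tgt_len, emb_size)
        return self.generator(out[:, -1])         # logits at the last position: (batch, tgt_vocab)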

During training, we need a subsequent word mask that prevents the model from looking at
future words when making predictions. We also need masks to hide
source and target padding tokens. Below, we define a function that takes care of both.
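A plain-Python sketch of the two kinds of masks, assuming the same conventions as the cell below: the causal mask is additive (0.0 means "may attend", -inf means "blocked"), casting it to bool turns -inf into True (blocked), and padding masks are True at padded positions and laid out as (batch, seq_len). The helper names are hypothetical and only illustrate the semantics:

```python
from typing import List

NEG_INF = float('-inf')

def subsequent_mask(sz: int) -> List[List[float]]:
    """Additive causal mask: query i may attend key j only when j <= i."""
    return [[0.0 if j <= i else NEG_INF for j in range(sz)] for i in range(sz)]

def to_bool_mask(mask: List[List[float]]) -> List[List[bool]]:
    """Boolean view like .type(torch.bool): nonzero (-inf) -> True = blocked."""
    return [[v != 0.0 for v in row] for row in mask]

def padding_mask(seq_first_batch: List[List[int]], pad_idx: int) -> List[List[bool]]:
    """(seq_len, batch) token ids -> (batch, seq_len) mask, True where the token is padding."""
    seq_len, batch = len(seq_first_batch), len(seq_first_batch[0])
    return [[seq_first_batch[t][b] == pad_idx for t in range(seq_len)] for b in range(batch)]

for row in subsequent_mask(3):
    print(row)
# Two sequences, time-major: [2, 5, 3] and [2, 3, 1] where 1 is the padding index
print(padding_mask([[2, 2], [5, 3], [3, 1]], pad_idx=1))
```

Row i of the causal mask leaves positions 0..i visible, which is exactly the "no peeking at future words" constraint.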




In [6]:
def generate_square_subsequent_mask(sz):
    mask = (torch.triu(torch.ones((sz, sz), device=DEVICE)) == 1).transpose(0, 1)
    mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
    return mask


def create_mask(src, tgt):
    src_seq_len = src.shape[0]
    tgt_seq_len = tgt.shape[0]

    tgt_mask = generate_square_subsequent_mask(tgt_seq_len)
    src_mask = torch.zeros((src_seq_len, src_seq_len),device=DEVICE).type(torch.bool)

    src_padding_mask = (src == PAD_IDX).transpose(0, 1)
    tgt_padding_mask = (tgt == PAD_IDX).transpose(0, 1)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask

Let's now define the parameters of our model and instantiate it. Below, we also
define our loss function, the cross-entropy loss, and the optimizer used for training.
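The loss function is created with ``ignore_index=PAD_IDX``, so padded target positions contribute nothing to the loss. A plain-Python sketch of what that means under the default mean reduction without class weights (an illustration, not PyTorch's implementation; ``cross_entropy`` is a hypothetical helper): the mean is taken only over the non-ignored targets.

```python
import math
from typing import List

def cross_entropy(logits: List[List[float]], targets: List[int], ignore_index: int) -> float:
    """Mean negative log-likelihood over the targets that are not ignore_index."""
    total, count = 0.0, 0
    for row, target in zip(logits, targets):
        if target == ignore_index:
            continue  # padded positions add neither loss nor weight to the mean
        m = max(row)
        log_z = m + math.log(sum(math.exp(v - m) for v in row))  # stable log-sum-exp
        total += log_z - row[target]
        count += 1
    return total / count if count else float('nan')  # no real targets: the mean is undefined

PAD_IDX = 1
logits = [[2.0, 0.0, 0.0],   # real target: class 0
          [0.0, 5.0, 0.0]]   # padded position: its target equals PAD_IDX
targets = [0, PAD_IDX]
print(cross_entropy(logits, targets, ignore_index=PAD_IDX))
```

The result equals the loss of the first row alone, log(1 + 2e^-2), so longer padding never dilutes or inflates the average.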




In [7]:
#from MyTransformer import Transformer
#from transformer.Models import Transformer
torch.manual_seed(seed)
MAX_SEQ_LEN=1000
SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
EMB_SIZE = 512
NHEAD = 8
FFN_HID_DIM = 512
NUM_LAYERS = 3
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

transformer = Seq2SeqTransformer(NUM_LAYERS, NUM_LAYERS, EMB_SIZE, NHEAD, SRC_VOCAB_SIZE, TGT_VOCAB_SIZE, FFN_HID_DIM)

for p in transformer.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

transformer = transformer.to(DEVICE)

loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
print(DEVICE)

cuda


In [8]:
# #from MyTransformer import Transformer
# from transformer.Models import Transformer
# torch.manual_seed(seed)
# MAX_SEQ_LEN=1000
# SRC_VOCAB_SIZE = len(vocab_transform[SRC_LANGUAGE])
# TGT_VOCAB_SIZE = len(vocab_transform[TGT_LANGUAGE])
# EMB_SIZE = 512
# NHEAD = 8
# FFN_HID_DIM = 1024
# NUM_LAYERS = 3
# DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# transformer=Transformer(
#     n_src_vocab=SRC_VOCAB_SIZE,
#     n_trg_vocab=TGT_VOCAB_SIZE,
#     src_pad_idx=PAD_IDX,
#     trg_pad_idx=PAD_IDX,
#     trg_emb_prj_weight_sharing=True,
#     emb_src_trg_weight_sharing=False,
#     scale_emb_or_prj='emb'
# )

# transformer = transformer.to(DEVICE)

# loss_fn = torch.nn.CrossEntropyLoss(ignore_index=PAD_IDX)

# optimizer = torch.optim.Adam(transformer.parameters(), lr=0.0001, betas=(0.9, 0.98), eps=1e-9)
# print(DEVICE)

Collation
---------

As seen in the ``Data Sourcing and Processing`` section, our data iterator yields pairs of raw strings.
We need to convert these string pairs into batched tensors that can be processed by the ``Seq2Seq`` network
defined previously. Below we define a collate function that converts a batch of raw strings into batch tensors that
can be fed directly into our model.
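The padding step can be pictured with plain lists. This sketch (the name ``collate_seq_first`` is hypothetical) wraps each id sequence in ``<bos>``/``<eos>``, right-pads with ``PAD_IDX``, and returns the sequence-first ``(max_len, batch)`` layout that ``pad_sequence`` produces with its default ``batch_first=False`` and that ``nn.Transformer`` expects by default:

```python
from typing import List

PAD_IDX, BOS_IDX, EOS_IDX = 1, 2, 3  # same indices as the notebook

def collate_seq_first(batch_ids: List[List[int]]) -> List[List[int]]:
    """Add BOS/EOS, right-pad with PAD_IDX, and return rows indexed by time step."""
    seqs = [[BOS_IDX] + ids + [EOS_IDX] for ids in batch_ids]
    max_len = max(len(s) for s in seqs)
    padded = [s + [PAD_IDX] * (max_len - len(s)) for s in seqs]
    # Transpose (batch, max_len) -> (max_len, batch)
    return [list(step) for step in zip(*padded)]

for step in collate_seq_first([[7, 8, 9], [5]]):
    print(step)
```

Each printed row is one time step across the batch; the shorter sentence is filled with ``1`` (``PAD_IDX``) after its ``<eos>``.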




In [9]:
from torch.nn.utils.rnn import pad_sequence

# helper function to chain sequential operations together
def sequential_transforms(*transforms):
    def func(txt_input):
        for transform in transforms:
            txt_input = transform(txt_input)
        return txt_input
    return func

# function to add BOS/EOS and create tensor for input sequence indices
def tensor_transform(token_ids: List[int]):
    return torch.cat((torch.tensor([BOS_IDX]), 
                      torch.tensor(token_ids), 
                      torch.tensor([EOS_IDX])))

# src and tgt language text transforms to convert raw strings into tensors of indices
text_transform = {}
for ln in [SRC_LANGUAGE, TGT_LANGUAGE]:
    text_transform[ln] = sequential_transforms(token_transform[ln], #Tokenization
                                               vocab_transform[ln], #Numericalization
                                               tensor_transform) # Add BOS/EOS and create tensor


# function to collate data samples into batch tensors
def collate_fn(batch):
    src_batch, tgt_batch = [], []
    for src_sample, tgt_sample in batch:
        src_batch.append(text_transform[SRC_LANGUAGE](src_sample.rstrip("\n")))
        tgt_batch.append(text_transform[TGT_LANGUAGE](tgt_sample.rstrip("\n")))

    # Pad to (seq_len, batch): nn.Transformer defaults to batch_first=False,
    # and create_mask, greedy_decode and beam_search all assume this layout
    src_batch = pad_sequence(src_batch, padding_value=PAD_IDX)
    tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
    return src_batch, tgt_batch

Let's define the training and evaluation loops that will be called for each
epoch.




In [10]:
from torch.utils.data import DataLoader
from tqdm import tqdm
BATCH_SIZE = 32
def train_epoch(model, optimizer, epoch):
    model.train()
    losses = 0
    train_iter = MyDataset(train_data)
    train_dataloader = DataLoader(train_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    for src, tgt in tqdm(train_dataloader, desc='Epoch: {}'.format(epoch)):
        src = src.to(DEVICE)  # (src_len, batch)
        tgt = tgt.to(DEVICE)  # (tgt_len, batch)
        # Decoder input: every target token except the last
        tgt_input = tgt[:-1, :]

        src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

        logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

        optimizer.zero_grad()

        # Prediction targets: every target token except the leading <bos>
        tgt_out = tgt[1:, :]
        # Flatten (tgt_len - 1, batch, vocab) logits and the targets for CrossEntropyLoss
        loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
        loss.backward()

        optimizer.step()
        losses += loss.item()

    return losses / len(train_dataloader)


def evaluate(model):
    model.eval()
    losses = 0

    val_iter = MyDataset(val_data)
    val_dataloader = DataLoader(val_iter, batch_size=BATCH_SIZE, collate_fn=collate_fn)

    with torch.no_grad():  # no gradients are needed for validation
        for src, tgt in tqdm(val_dataloader):
            src = src.to(DEVICE)
            tgt = tgt.to(DEVICE)
            tgt_input = tgt[:-1, :]

            src_mask, tgt_mask, src_padding_mask, tgt_padding_mask = create_mask(src, tgt_input)

            logits = model(src, tgt_input, src_mask, tgt_mask, src_padding_mask, tgt_padding_mask, src_padding_mask)

            tgt_out = tgt[1:, :]
            loss = loss_fn(logits.reshape(-1, logits.shape[-1]), tgt_out.reshape(-1))
            losses += loss.item()

    return losses / len(val_dataloader)
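Both loops use teacher forcing: the decoder is fed the target sequence without its last token, and its predictions are scored against the target sequence without its leading ``<bos>``. For a single target sentence the shift looks like this plain-Python sketch (``shift_for_teacher_forcing`` is a hypothetical helper):

```python
from typing import List, Tuple

def shift_for_teacher_forcing(tgt: List[int]) -> Tuple[List[int], List[int]]:
    """Split one target sequence into decoder input (drop last) and training target (drop first)."""
    return tgt[:-1], tgt[1:]

tgt = [2, 11, 12, 3]  # <bos> w1 w2 <eos>, using the notebook's BOS_IDX=2 and EOS_IDX=3
dec_in, dec_out = shift_for_teacher_forcing(tgt)
# Position t: thanks to the causal mask the decoder sees dec_in[:t+1] and must predict dec_out[t]
for t, target in enumerate(dec_out):
    print(t, dec_in[:t + 1], '->', target)
```

With padding, a target such as ``[2, 11, 3, 1]`` yields a final ``PAD_IDX`` target, which ``ignore_index=PAD_IDX`` excludes from the loss.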

Now we have all the ingredients to train our model. Let's do it!




In [11]:
# from timeit import default_timer as timer
# NUM_EPOCHS = 20
# best_loss=1e7
# #BESTMODEL='./checkpoints/pytorch_bestmodel.pkl'
# BESTMODEL='./checkpoints/test_bestmodel.pkl'
# for epoch in range(1, NUM_EPOCHS+1):
#     start_time = timer()
#     train_loss = train_epoch(transformer, optimizer, epoch)
#     end_time = timer()
#     val_loss = evaluate(transformer)
#     if val_loss<best_loss:
#         best_loss=val_loss
#         torch.save(transformer.state_dict(),BESTMODEL)
#     print((f"Epoch: {epoch}, Train loss: {train_loss:.3f}, Val loss: {val_loss:.3f}, "f"Epoch time = {(end_time - start_time):.3f}s"))



In [12]:
# function to generate output sequence using greedy algorithm 
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(DEVICE)
    src_mask=src_mask.to(DEVICE)

    memory = model.encode(src, src_mask)
    ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
    for i in range(max_len-1):
        tgt_mask = (generate_square_subsequent_mask(ys.size(0)).type(torch.bool)).to(DEVICE)
        out = model.decode(ys, memory, tgt_mask)  # (cur_len, 1, emb_size)
        out = out.transpose(0, 1)                 # (1, cur_len, emb_size)
        prob = model.generator(out[:, -1])        # next-token logits: (1, tgt_vocab)
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=0)
        if next_word == EOS_IDX:
            break
    return ys


# actual function to translate input sentence into target language
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1)
    num_tokens = src.shape[0]
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)
    #print('src',src.shape)
    tgt_tokens = greedy_decode(model,  src, src_mask, max_len=num_tokens + 10, start_symbol=BOS_IDX).flatten()
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

In [13]:
# # function to generate output sequence using greedy algorithm 
# def greedy_decode(model, src, max_len, start_symbol):
#     src = src.to(DEVICE)

#     memory = model.encode(src)
#     ys = torch.ones(1, 1).fill_(start_symbol).type(torch.long).to(DEVICE)
#     for i in range(max_len-1):
#         memory = memory.to(DEVICE)
#         out = model.decode(ys, memory)
#         prob = model.trg_word_prj(out[:, -1])
#         _, next_word = torch.max(prob, dim=1)
#         next_word = next_word.item()

#         ys = torch.cat([ys,torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)
#         if next_word == EOS_IDX:
#             break
#     return ys


# # actual function to translate input sentence into target language
# def translate(model: torch.nn.Module, src_sentence: str):
#     model.eval()
#     src = text_transform[SRC_LANGUAGE](src_sentence).view(1, -1)#l*1
#     num_tokens = src.shape[-1]
#     tgt_tokens = greedy_decode(model,  src, max_len=num_tokens + 10, start_symbol=BOS_IDX).flatten()
#     return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

In [14]:
# BESTMODEL='./checkpoints/test_bestmodel.pkl'
# transformer.load_state_dict(torch.load(BESTMODEL))
# src=" i like apple"
# print(translate(transformer, src))

In [15]:
#BESTMODEL='./checkpoints/my_bestmodel.pkl'
BESTMODEL='./checkpoints/pytorch_bestmodel.pkl'
transformer.load_state_dict(torch.load(BESTMODEL, map_location=DEVICE))
src="I like apple"
print(translate(transformer, src))

 我 喜欢 苹果 


In [16]:
src="I like apple"
src = text_transform[SRC_LANGUAGE](src).view(-1, 1).to(DEVICE)
src=src.repeat(1,3)
print(src.shape)
tgt=torch.ones(1, 1).fill_(BOS_IDX).type(torch.long).to(DEVICE)
tgt=tgt.repeat(1,3)
transformer.predict(src, tgt).shape

torch.Size([5, 3])


torch.Size([3, 122260])

In [17]:
x=torch.Tensor([1,3,4,5])
print(x)
y=torch.Tensor([0,1,1,0]).bool()
x[y]

tensor([1., 3., 4., 5.])


tensor([3., 4.])

In [20]:
@torch.no_grad()
def beam_search(model, inputs, topk, maxlen, min_ends=1, min_len=1):
    """
    Beam search decoding.
    Note: topk here is the beam size.
    Returns: the best decoded sequence.
    """
    output_ids = torch.full((1, 1), BOS_IDX, dtype=torch.long, device=DEVICE)
    output_scores = torch.zeros(1, 1, device=DEVICE)
    for step in range(maxlen):
        # Log-probabilities of the next token for every live beam: (num_beams, vocab)
        scores = torch.log_softmax(model.predict(inputs, output_ids), dim=-1)

        if step == 0:  # after the first prediction, repeat the input topk times
            inputs = inputs.repeat(1, topk)
        scores = output_scores.reshape(-1, 1) + scores  # accumulated score of every extension

        indices = torch.topk(scores.reshape(1, -1), topk, dim=1, largest=True)[1]  # keep only the topk

        indices_1 = torch.div(indices, scores.shape[1], rounding_mode='floor').view(-1)  # row index: source beam
        indices_2 = indices % scores.shape[1]  # column index: next token, shape (1, topk)

        output_ids = torch.cat([output_ids[:, indices_1], indices_2], dim=0)  # update the outputs
        output_scores = scores[indices_1, indices_2.view(-1)]  # update the scores
        is_end = output_ids[-1, :] == EOS_IDX  # whether each beam ends with the end token
        end_counts = (output_ids == EOS_IDX).sum(0)  # number of end tokens in each beam

        if output_ids.shape[0] >= min_len:  # minimum length check
            best = output_scores.argmax()  # highest-scoring beam
            if is_end[best] and end_counts[best] >= min_ends:  # it has already finished
                return output_ids[:, best]  # return it directly
            else:  # otherwise, keep only the unfinished beams
                flag = ~is_end | (end_counts < min_ends)  # mark unfinished sequences
                if not flag.all():  # some beams have finished
                    inputs = inputs[:, flag]  # drop the finished sequences
                    output_ids = output_ids[:, flag]  # drop the finished sequences
                    output_scores = output_scores[flag]  # drop the finished sequences
                    end_counts = end_counts[flag]  # drop their end counts
                    topk = int(flag.sum())  # shrink the beam accordingly
    # maximum length reached: return the best sequence so far
    return output_ids[:, output_scores.argmax()]
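``beam_search`` keeps the ``topk`` prefixes with the highest accumulated score. Adding scores across steps is only meaningful for log-probabilities (a log-softmax of the logits that ``predict`` returns), because then the sum is the log of the sequence probability. The framework-free sketch below runs a simplified beam search over a hypothetical hand-written next-token table (finished hypotheses are set aside and the best one is returned at the end, instead of the early-exit rule above); with a beam of 1 it behaves like greedy decoding:

```python
import math
from typing import Callable, Dict, List, Tuple

Seq = Tuple[int, ...]

def beam_search_toy(step_logprobs: Callable[[Seq], Dict[int, float]],
                    bos: int, eos: int, beam_size: int, max_len: int) -> Seq:
    """Simplified beam search that ranks prefixes by summed log-probabilities."""
    beams: List[Tuple[Seq, float]] = [((bos,), 0.0)]
    finished: List[Tuple[Seq, float]] = []
    for _ in range(max_len):
        # Extend every live beam by every candidate token; log-probabilities add up
        candidates = [(prefix + (tok,), score + lp)
                      for prefix, score in beams
                      for tok, lp in step_logprobs(prefix).items()]
        candidates.sort(key=lambda c: c[1], reverse=True)
        beams = []
        for seq, score in candidates[:beam_size]:
            if seq[-1] == eos:
                finished.append((seq, score))  # hypothesis is complete
            else:
                beams.append((seq, score))
        if not beams:
            break
    pool = finished or beams
    return max(pool, key=lambda c: c[1])[0] if pool else (bos,)

BOS, EOS, A, B = 2, 3, 10, 11
TOY_TABLE = {  # hypothetical next-token probabilities
    (BOS,): {A: 0.6, B: 0.4},
    (BOS, A): {EOS: 0.4, A: 0.3, B: 0.3},
    (BOS, B): {EOS: 0.9, A: 0.1},
}

def toy_step(prefix: Seq) -> Dict[int, float]:
    """Log-probabilities for the next token; unseen prefixes always emit EOS."""
    probs = TOY_TABLE.get(prefix, {EOS: 1.0})
    return {tok: math.log(p) for tok, p in probs.items()}

print(beam_search_toy(toy_step, BOS, EOS, beam_size=1, max_len=5))  # (2, 10, 3)
print(beam_search_toy(toy_step, BOS, EOS, beam_size=2, max_len=5))  # (2, 11, 3)
```

Greedy decoding commits to the locally best first token ``A`` and ends with probability 0.6 × 0.4 = 0.24, while a beam of 2 also keeps ``B`` alive and finds the sequence with probability 0.4 × 0.9 = 0.36.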

In [23]:
def translate(model: torch.nn.Module, src_sentence: str):
    model.eval()
    src = text_transform[SRC_LANGUAGE](src_sentence).view(-1, 1).to(DEVICE)
    num_tokens = src.shape[0]
    tgt_tokens = beam_search(model,src,topk=3,maxlen=num_tokens+10)
    return " ".join(vocab_transform[TGT_LANGUAGE].lookup_tokens(list(tgt_tokens.cpu().numpy()))).replace("<bos>", "").replace("<eos>", "")

In [34]:
#BESTMODEL='./checkpoints/my_bestmodel.pkl'
BESTMODEL='./checkpoints/pytorch_bestmodel.pkl'
transformer.load_state_dict(torch.load(BESTMODEL, map_location=DEVICE))
src="I miss you."
print(translate(transformer, src))

 我 忽略 了 你 。 


References
----------

1. Attention is all you need paper.
   https://papers.nips.cc/paper/2017/file/3f5ee243547dee91fbd053c1c4a845aa-Paper.pdf
2. The annotated transformer.
   https://nlp.seas.harvard.edu/2018/04/03/attention.html#positional-encoding

