# Neural Machine Translation with Transformer Using PyTorch
In this notebook we are going to perform machine translation using Transformer. 

Specifically, we are going to train a sequence to sequence model for French-to-English translation.

# Import libraries

In [1]:
!pip install fastBPE sacremoses

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting fastBPE
  Downloading fastBPE-0.1.0.tar.gz (35 kB)
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting sacremoses
  Downloading sacremoses-0.0.53.tar.gz (880 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m880.6/880.6 kB[0m [31m13.0 MB/s[0m eta [36m0:00:00[0m
[?25h  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: fastBPE, sacremoses
  Building wheel for fastBPE (setup.py) ... [?25l[?25hdone
  Created wheel for fastBPE: filename=fastBPE-0.1.0-cp39-cp39-linux_x86_64.whl size=762563 sha256=f33fc86d86d91d146ae92809af45ae57e6ac6ecdcf1edf4e013b6dbb60d8b599
  Stored in directory: /root/.cache/pip/wheels/e1/10/20/0691b69b472ff8530a7e608674d5bd1cbc772f4d6071c8accf
  Building wheel for sacremoses (setup.py) ... [?25l[?25hdone
  Created wheel for sacremoses: filename=sacremoses-0.0.53-py3-none-any.whl size=

In [2]:
import copy
import math
import re
import time
import unicodedata

import fastBPE
import nltk
import numpy as np
import pandas as pd
from tqdm.notebook import tqdm
from sacremoses import MosesDetokenizer, MosesTokenizer
from sklearn.model_selection import train_test_split

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import IterableDataset
from torch.utils.data import DataLoader

In [3]:
!gdown 1Bhpi9gD_3UHZRFmcn7Czkb73jHN8CUAL

Downloading...
From: https://drive.google.com/uc?id=1Bhpi9gD_3UHZRFmcn7Czkb73jHN8CUAL
To: /content/Dataset-fr-eng.txt
  0% 0.00/28.7M [00:00<?, ?B/s]100% 28.7M/28.7M [00:00<00:00, 286MB/s]100% 28.7M/28.7M [00:00<00:00, 285MB/s]


In [4]:
f = open('Dataset-fr-eng.txt', encoding='UTF-8').read().strip().split('\n')

In [5]:
lines = f
# sample size (smaller sample size to reduce computation)
num_examples = 30000 
# creates lists containing each pair
original_word_pairs = [[w for w in l.split('\t')] for l in lines[:num_examples]]
data = pd.DataFrame(original_word_pairs, columns=["eng", "fr","whatever"])
data = data[['eng','fr']]

In [6]:
data.head(10)

Unnamed: 0,eng,fr
0,Go.,Va !
1,Go.,Marche.
2,Go.,Bouge !
3,Hi.,Salut !
4,Hi.,Salut.
5,Run!,Cours !
6,Run!,Courez !
7,Run!,Prenez vos jambes à vos cous !
8,Run!,File !
9,Run!,Filez !


In [7]:
!gdown 1vXaIlIRnx7rm98K9zTNEsrJodNvUg4F-
!unzip tokenizer.zip

Downloading...
From: https://drive.google.com/uc?id=1vXaIlIRnx7rm98K9zTNEsrJodNvUg4F-
To: /content/tokenizer.zip
  0% 0.00/619k [00:00<?, ?B/s]100% 619k/619k [00:00<00:00, 168MB/s]
Archive:  tokenizer.zip
  inflating: bpecodes                
  inflating: dict.en.txt             
  inflating: dict.fr.txt             


In [8]:
class Tokenizer:
    def __init__(self, bpe_file: str, vocab_file: str):        
        self.bpe = fastBPE.fastBPE(bpe_file)
        self.tokenizer = MosesTokenizer('fr')
        self.detokenizer = MosesDetokenizer('en')
        
        tmp = open(vocab_file, encoding='utf-8').read().split('\n')
        tmp = [x.split()[0] for x in tmp[:-1]]
        self.vocab = ['<sos>', '<pad>', '<eos>', '<unk>'] + tmp
        self.t2i = {t : i for i, t in enumerate(self.vocab)}
    
    def encode(self, sent: str, add_eos = False):
        sent = self.tokenizer.tokenize(sent, aggressive_dash_splits=True, return_str=True)
        tokens = self.bpe.apply([sent])[0].split()
        tokens = [self.t2i[t] if t in self.t2i else self.t2i['<unk>'] for t in tokens]
        if add_eos:
            tokens += [self.t2i['<eos>']]
        return tokens
    
    def decode(self, tokens: list):
        sent = [self.vocab[t] for t in tokens]
        sent = ' '.join(sent)
        sent = (sent + ' ').replace('@@ ', '').rstrip()
        sent = self.detokenizer.detokenize(sent.split())
        return sent

In [9]:
class NMT_dataset(IterableDataset):
    def __init__(self, data):
        self.fr_data = data["fr"]
        self.en_data = data["eng"]
        self.src_tokenizer = Tokenizer(*SRC_VOCAB_FILE)
        self.trg_tokenizer = Tokenizer(*TRG_VOCAB_FILE)
    
    def src_line_mapper(self, line):
        line = line.replace('\n', '')
        tokens = self.src_tokenizer.encode(line, True)
        return tokens
    
    def trg_line_mapper(self, line):
        line = line.replace('\n', '')
        tokens = self.trg_tokenizer.encode(line, True)
        return tokens
    
    def __iter__(self):
        mapped_fr_iter = map(self.src_line_mapper, self.fr_data)
        mapped_en_iter = map(self.trg_line_mapper, self.en_data)
        
        fr_en_iter = zip(mapped_fr_iter, mapped_en_iter)
        return fr_en_iter

In [10]:
def pad_batch(batch):
    max_src_len = max(len(x) for x, _ in batch)
    max_trg_len = max(len(y) for _, y in batch)
    max_src_len = min(max_src_len, MAX_SEQ_LEN)
    max_trg_len = min(max_trg_len, MAX_SEQ_LEN)
    
    src_batch = [x[:max_src_len] for x, _ in batch]
    trg_batch = [y[:max_trg_len] for _, y in batch]
    
    src_batch = [np.pad(x, (0, max_src_len - len(x)), constant_values=PAD_ID) for x in src_batch]
    trg_batch = [np.pad(y, (0, max_trg_len - len(y)), constant_values=PAD_ID) for y in trg_batch]
    return torch.tensor(src_batch, dtype=torch.long), torch.tensor(trg_batch, dtype=torch.long)

In [11]:
SRC_VOCAB_FILE = [f'bpecodes', f'dict.fr.txt']
TRG_VOCAB_FILE = [f'bpecodes', f'dict.en.txt']

MAX_SEQ_LEN = 127
PAD_ID = 1
BOUND_ID = 2

# model params
enc_ffn_dims = 1024
dec_ffn_dims = 512
d_model = 256
heads = 8
N = 3
WARM_UP_STEPS = 4000

EPOCHS = 100
BATCH_SIZE = 300
NUM_BATCHES = len(data) // BATCH_SIZE + (1 if len(data) % BATCH_SIZE > 0 else 0)

DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
DEVICE

device(type='cuda', index=0)

In [12]:
dataset = NMT_dataset(data)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, collate_fn=pad_batch)

In [13]:
src_vocab_size = len(dataset.src_tokenizer.vocab)
trg_vocab_size = len(dataset.trg_tokenizer.vocab)

In [14]:
class PositionalEncoder(nn.Module):
    def __init__(self, d_model, max_seq_len=256):
        super().__init__()
        self.d_model = d_model

        half_dim = d_model // 2
        pe = torch.tensor(10000) / (torch.arange(1, half_dim + 1, dtype=torch.float) * 2 - 1)
        pe = torch.exp(torch.arange(0, half_dim, dtype=torch.float) * -pe.log())
        pe = torch.arange(0, max_seq_len, dtype=torch.float).unsqueeze(1) * pe.unsqueeze(0)
        pe = torch.cat([torch.sin(pe), torch.cos(pe)], dim=1).view(max_seq_len, -1)
        
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)
    
    def forward(self, x):
        length = x.size(1) + 2
        return x + self.pe[:, 2:length]

In [15]:
def attention(q, k, v, d_k, mask=None, dropout=None):
    scores = torch.matmul(q, k.transpose(-2, -1)) / np.sqrt(d_k)
    if mask is not None:
        mask = mask.unsqueeze(1)
        scores = scores.masked_fill(mask == 0, -1e9)
    scores = F.softmax(scores, dim=-1)
    
    if dropout is not None:
        scores = dropout(scores)
        
    output = torch.matmul(scores, v)
    return output

In [16]:
class MultiHeadAttention(nn.Module):
    def __init__(self, heads, d_model, dropout=0.1):
        super().__init__()
        
        self.d_model = d_model
        self.d_k = d_model // heads
        self.h = heads
        
        self.q_linear = nn.Linear(d_model, d_model)
        self.v_linear = nn.Linear(d_model, d_model)
        self.k_linear = nn.Linear(d_model, d_model)

        self.dropout = nn.Dropout(dropout)

        self.out = nn.Linear(d_model, d_model)
    
    def forward(self, q, k, v, mask=None):
        bs = q.size(0)
                
        k = self.k_linear(k).view(bs, -1, self.h, self.d_k)
        q = self.q_linear(q).view(bs, -1, self.h, self.d_k)
        v = self.v_linear(v).view(bs, -1, self.h, self.d_k)
        
        k = k.transpose(1,2)
        q = q.transpose(1,2)
        v = v.transpose(1,2)
        scores = attention(q, k, v, self.d_k, mask, self.dropout)
        
        concat = scores.transpose(1,2).contiguous().view(bs, -1, self.d_model)
        output = self.out(concat)
        return output

In [17]:
class FeedForward(nn.Module):
    def __init__(self, d_model, d_ff, dropout=0.1):
        super().__init__() 
        self.dropout = dropout
        self.linear_1 = nn.Linear(d_model, d_ff)
        self.linear_2 = nn.Linear(d_ff, d_model)

    def forward(self, x):
        x = F.dropout(F.relu(self.linear_1(x)), p=self.dropout, training=self.training)
        x = self.linear_2(x)
        return x


In [18]:
class EncoderLayer(nn.Module):
    def __init__(self, d_model, heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.attn = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, ffn_dim)
        
    def forward(self, x, mask):
        x = x + F.dropout(self.attn(x, x, x, mask), p=self.dropout, training=self.training)
        x = self.norm_1(x)
        x = x + F.dropout(self.ff(x), p=self.dropout, training=self.training)
        x = self.norm_2(x)
        return x

class DecoderLayer(nn.Module):
    def __init__(self, d_model, heads, ffn_dim, dropout=0.1):
        super().__init__()
        self.dropout = dropout
        self.norm_1 = nn.LayerNorm(d_model)
        self.norm_2 = nn.LayerNorm(d_model)
        self.norm_3 = nn.LayerNorm(d_model)
        self.attn_1 = MultiHeadAttention(heads, d_model)
        self.attn_2 = MultiHeadAttention(heads, d_model)
        self.ff = FeedForward(d_model, ffn_dim)
        
    def forward(self, x, e_outputs, src_mask, trg_mask):
        x = x + F.dropout(self.attn_1(x, x, x, trg_mask), p=self.dropout, training=self.training)
        x = self.norm_1(x)
        x = x + F.dropout(self.attn_2(x, e_outputs, e_outputs, src_mask), p=self.dropout, training=self.training)
        x = self.norm_2(x)
        x = x + F.dropout(self.ff(x), p=self.dropout, training=self.training)
        x = self.norm_3(x)
        return x


In [19]:
class Encoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, heads, ffn_dim):
        super().__init__()
        self.N = N
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=1)
        self.pe = PositionalEncoder(d_model)
        self.layers = nn.ModuleList([EncoderLayer(d_model, heads, ffn_dim) for _ in range(N)])

    def forward(self, src, mask):
        x = self.embed(src) * np.sqrt(self.d_model)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, mask)
        return x

class Decoder(nn.Module):
    def __init__(self, vocab_size, d_model, N, heads, ffn_dim):
        super().__init__()
        self.N = N
        self.d_model = d_model
        self.embed = nn.Embedding(vocab_size, d_model, padding_idx=1)
        self.pe = PositionalEncoder(d_model)
        self.layers = nn.ModuleList([DecoderLayer(d_model, heads, ffn_dim) for _ in range(N)])

    def forward(self, trg, e_outputs, src_mask, trg_mask):
        x = self.embed(trg) * np.sqrt(self.d_model)
        x = self.pe(x)
        for i in range(self.N):
            x = self.layers[i](x, e_outputs, src_mask, trg_mask)
        return x

In [20]:
class Transformer(nn.Module):
    def __init__(self, src_vocab_size, trg_vocab_size, d_model, N, heads, enc_ffn_dim, dec_ffn_dim):
        super().__init__()
        self.encoder = Encoder(src_vocab_size, d_model, N, heads, enc_ffn_dim)
        self.decoder = Decoder(trg_vocab_size, d_model, N, heads, dec_ffn_dim)

    def forward(self, src, trg, src_mask, trg_mask):
        e_outputs = self.encoder(src, src_mask)
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = F.linear(d_output, self.decoder.embed.weight)
        return output
    
    def out(self, trg, e_outputs, src_mask, trg_mask):
        d_output = self.decoder(trg, e_outputs, src_mask, trg_mask)
        output = F.linear(d_output, self.decoder.embed.weight)
        return output

# Train

In [21]:
model = Transformer(src_vocab_size, trg_vocab_size, d_model, N, heads, enc_ffn_dims, dec_ffn_dims).to(DEVICE)
for p in model.parameters():
    if p.dim() > 1:
        nn.init.xavier_uniform_(p)

In [22]:
def nopeak_mask(size):
    np_mask = np.triu(np.ones((1, size, size)), k=1).astype('uint8')
    np_mask = torch.from_numpy(np_mask) == 0
    return np_mask.to(DEVICE)


def create_masks(src, trg):
    src_mask = (src != PAD_ID).unsqueeze(-2)

    if trg is not None:
        trg_mask = (trg != PAD_ID).unsqueeze(-2)
        size = trg.size(1)
        np_mask = nopeak_mask(size).to(DEVICE)
        trg_mask = trg_mask & np_mask
    else:
        trg_mask = None
        
    return src_mask, trg_mask

In [23]:
from torch.optim.lr_scheduler import _LRScheduler


class Scheduler(_LRScheduler):
    def __init__(self, optimizer, dim_embed, warmup_steps, last_epoch=-1):
        self.dim_embed = dim_embed
        self.warmup_steps = warmup_steps
        self.num_param_groups = len(optimizer.param_groups)

        super().__init__(optimizer, last_epoch)
        
    def get_lr(self):
        lr = calc_lr(self._step_count, self.dim_embed, self.warmup_steps)
        return [lr] * self.num_param_groups


def calc_lr(step, dim_embed, warmup_steps):
    return dim_embed**(-0.5) * min(step**(-0.5), step * warmup_steps**(-1.5))

In [24]:
optim = torch.optim.Adam(model.parameters(), betas=(0.9, 0.98), eps=1e-9)
scheduler = Scheduler(optim, d_model, WARM_UP_STEPS)

In [25]:
from tqdm import tqdm

def train_model(epochs):
    model.train()
    
    start = time.time()
    temp = start
    total_loss = 0
    
    for epoch in range(epochs):
        for i, (src, trg) in enumerate(tqdm(dataloader, total=NUM_BATCHES)):

            src = src.to(DEVICE)
            trg = trg.to(DEVICE)
            
            trg_input = F.pad(trg[:, :-1], (1, 0), value=BOUND_ID)
            src_mask, trg_mask = create_masks(src, trg_input)
            preds = model(src, trg_input, src_mask, trg_mask)
            
            optim.zero_grad()
            loss = F.cross_entropy(preds.view(-1, preds.size(-1)), trg.view(-1), ignore_index=PAD_ID)
            loss.backward()
            optim.step()
            
            total_loss += loss
            
        loss_avg = total_loss / NUM_BATCHES
        print(f"{epoch} iter = {i + 1}, loss = {loss_avg}, lr = {scheduler.get_lr()[0]}")
        total_loss = 0
        tokens = np.argmax(F.softmax(preds, -1).data.tolist(), 2)
        print(f'Prediction: {dataset.trg_tokenizer.decode(tokens[0])}')
        print(f'True: {dataset.trg_tokenizer.decode(trg[0].data.tolist())}')

        scheduler.step()

In [26]:
train_model(EPOCHS)

  return torch.tensor(src_batch, dtype=torch.long), torch.tensor(trg_batch, dtype=torch.long)
100%|██████████| 100/100 [00:22<00:00,  4.37it/s]


0 iter = 100, loss = 10.648383140563965, lr = 2.4705294220065465e-07
Prediction: BraBraBrasoftware BraBraBraBulletin Bra
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.53it/s]


1 iter = 100, loss = 10.613430976867676, lr = 4.941058844013093e-07
Prediction: BraBraBraCM BraBrapétroinstructions Bra
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.69it/s]


2 iter = 100, loss = 10.559117317199707, lr = 7.41158826601964e-07
Prediction: <eos> papers BraBra204 <eos> <eos> sentencing.
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.52it/s]


3 iter = 100, loss = 10.490191459655762, lr = 9.882117688026186e-07
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


4 iter = 100, loss = 10.41409969329834, lr = 1.2352647110032732e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.53it/s]


5 iter = 100, loss = 10.339945793151855, lr = 1.482317653203928e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:20<00:00,  4.99it/s]


6 iter = 100, loss = 10.269495964050293, lr = 1.7293705954045826e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.59it/s]


7 iter = 100, loss = 10.197644233703613, lr = 1.976423537605237e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.52it/s]


8 iter = 100, loss = 10.12109088897705, lr = 2.223476479805892e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.66it/s]


9 iter = 100, loss = 10.036911010742188, lr = 2.4705294220065464e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.38it/s]


10 iter = 100, loss = 9.94424057006836, lr = 2.717582364207201e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.66it/s]


11 iter = 100, loss = 9.842901229858398, lr = 2.964635306407856e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.50it/s]


12 iter = 100, loss = 9.732585906982422, lr = 3.2116882486085104e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.65it/s]


13 iter = 100, loss = 9.61385440826416, lr = 3.4587411908091652e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.55it/s]


14 iter = 100, loss = 9.486661911010742, lr = 3.7057941330098196e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.72it/s]


15 iter = 100, loss = 9.351572036743164, lr = 3.952847075210474e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


16 iter = 100, loss = 9.208643913269043, lr = 4.199900017411129e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


17 iter = 100, loss = 9.058470726013184, lr = 4.446952959611784e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


18 iter = 100, loss = 8.90114688873291, lr = 4.694005901812438e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


19 iter = 100, loss = 8.737567901611328, lr = 4.941058844013093e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


20 iter = 100, loss = 8.568110466003418, lr = 5.188111786213748e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.59it/s]


21 iter = 100, loss = 8.39338207244873, lr = 5.435164728414402e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.49it/s]


22 iter = 100, loss = 8.214130401611328, lr = 5.682217670615057e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.65it/s]


23 iter = 100, loss = 8.030685424804688, lr = 5.929270612815712e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.50it/s]


24 iter = 100, loss = 7.844273567199707, lr = 6.176323555016366e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


25 iter = 100, loss = 7.655529499053955, lr = 6.423376497217021e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


26 iter = 100, loss = 7.465632915496826, lr = 6.670429439417676e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.57it/s]


27 iter = 100, loss = 7.2753214836120605, lr = 6.9174823816183304e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.45it/s]


28 iter = 100, loss = 7.085670471191406, lr = 7.164535323818985e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.64it/s]


29 iter = 100, loss = 6.898202419281006, lr = 7.411588266019639e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.40it/s]


30 iter = 100, loss = 6.713583946228027, lr = 7.658641208220294e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


31 iter = 100, loss = 6.533064842224121, lr = 7.905694150420949e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.45it/s]


32 iter = 100, loss = 6.3578033447265625, lr = 8.152747092621604e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


33 iter = 100, loss = 6.1885504722595215, lr = 8.399800034822258e-06
Prediction: <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.44it/s]


34 iter = 100, loss = 6.026398181915283, lr = 8.646852977022913e-06
Prediction: <eos> <eos> <eos> <eos>. <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


35 iter = 100, loss = 5.873161315917969, lr = 8.893905919223568e-06
Prediction: <eos> <eos> <eos>.. <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.41it/s]


36 iter = 100, loss = 5.731993198394775, lr = 9.140958861424223e-06
Prediction: <eos> <eos> <eos>.. <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


37 iter = 100, loss = 5.604893207550049, lr = 9.388011803624876e-06
Prediction: <eos> <eos> <eos>.. <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


38 iter = 100, loss = 5.492290019989014, lr = 9.63506474582553e-06
Prediction: <eos> <eos>... <eos> <eos> <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


39 iter = 100, loss = 5.394021511077881, lr = 9.882117688026186e-06
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


40 iter = 100, loss = 5.306873798370361, lr = 1.012917063022684e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


41 iter = 100, loss = 5.227205753326416, lr = 1.0376223572427495e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.53it/s]


42 iter = 100, loss = 5.151206970214844, lr = 1.062327651462815e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.65it/s]


43 iter = 100, loss = 5.077214241027832, lr = 1.0870329456828805e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.54it/s]


44 iter = 100, loss = 5.004464149475098, lr = 1.111738239902946e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


45 iter = 100, loss = 4.933346271514893, lr = 1.1364435341230114e-05
Prediction: <eos>...... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


46 iter = 100, loss = 4.86358118057251, lr = 1.161148828343077e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


47 iter = 100, loss = 4.796123504638672, lr = 1.1858541225631424e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.56it/s]


48 iter = 100, loss = 4.730810642242432, lr = 1.2105594167832077e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.52it/s]


49 iter = 100, loss = 4.667760372161865, lr = 1.2352647110032732e-05
Prediction: <eos> <eos> <eos>.. <eos>. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.43it/s]


50 iter = 100, loss = 4.605944633483887, lr = 1.2599700052233387e-05
Prediction: <eos> <eos>..... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


51 iter = 100, loss = 4.546443462371826, lr = 1.2846752994434042e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.53it/s]


52 iter = 100, loss = 4.490264415740967, lr = 1.3093805936634696e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


53 iter = 100, loss = 4.4340620040893555, lr = 1.3340858878835351e-05
Prediction: <eos> <eos> <eos>.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.50it/s]


54 iter = 100, loss = 4.381743431091309, lr = 1.3587911821036006e-05
Prediction: <eos> <eos> '.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.56it/s]


55 iter = 100, loss = 4.328509330749512, lr = 1.3834964763236661e-05
Prediction: <eos> <eos> ''... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.57it/s]


56 iter = 100, loss = 4.278448104858398, lr = 1.4082017705437316e-05
Prediction: <eos> I I.... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.57it/s]


57 iter = 100, loss = 4.230011463165283, lr = 1.432907064763797e-05
Prediction: <eos> I ''... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.59it/s]


58 iter = 100, loss = 4.180129528045654, lr = 1.4576123589838624e-05
Prediction: I I ''... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.54it/s]


59 iter = 100, loss = 4.133031368255615, lr = 1.4823176532039278e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


60 iter = 100, loss = 4.0855255126953125, lr = 1.5070229474239933e-05
Prediction: I I 'you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.51it/s]


61 iter = 100, loss = 4.040093421936035, lr = 1.5317282416440588e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


62 iter = 100, loss = 3.9943957328796387, lr = 1.5564335358641243e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


63 iter = 100, loss = 3.9503467082977295, lr = 1.5811388300841898e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.57it/s]


64 iter = 100, loss = 3.907180070877075, lr = 1.6058441243042552e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


65 iter = 100, loss = 3.864348888397217, lr = 1.6305494185243207e-05
Prediction: I I you you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


66 iter = 100, loss = 3.8198864459991455, lr = 1.6552547127443862e-05
Prediction: I I I you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.45it/s]


67 iter = 100, loss = 3.777959108352661, lr = 1.6799600069644517e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


68 iter = 100, loss = 3.735535144805908, lr = 1.7046653011845172e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


69 iter = 100, loss = 3.693073272705078, lr = 1.7293705954045827e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


70 iter = 100, loss = 3.6539952754974365, lr = 1.754075889624648e-05
Prediction: I I you you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.40it/s]


71 iter = 100, loss = 3.6126482486724854, lr = 1.7787811838447136e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


72 iter = 100, loss = 3.5712573528289795, lr = 1.803486478064779e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.42it/s]


73 iter = 100, loss = 3.5320773124694824, lr = 1.8281917722848446e-05
Prediction: I I I you... <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


74 iter = 100, loss = 3.4927620887756348, lr = 1.8528970665049097e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


75 iter = 100, loss = 3.453685760498047, lr = 1.8776023607249752e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


76 iter = 100, loss = 3.4143905639648438, lr = 1.9023076549450407e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


77 iter = 100, loss = 3.377269744873047, lr = 1.927012949165106e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


78 iter = 100, loss = 3.3400394916534424, lr = 1.9517182433851716e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


79 iter = 100, loss = 3.3016984462738037, lr = 1.976423537605237e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


80 iter = 100, loss = 3.262112855911255, lr = 2.0011288318253026e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


81 iter = 100, loss = 3.2272136211395264, lr = 2.025834126045368e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


82 iter = 100, loss = 3.1933064460754395, lr = 2.0505394202654336e-05
Prediction: I m a you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


83 iter = 100, loss = 3.1595897674560547, lr = 2.075244714485499e-05
Prediction: I I you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


84 iter = 100, loss = 3.124579906463623, lr = 2.0999500087055645e-05
Prediction: I m to you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.45it/s]


85 iter = 100, loss = 3.0876402854919434, lr = 2.12465530292563e-05
Prediction: I m a you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


86 iter = 100, loss = 3.051502227783203, lr = 2.1493605971456955e-05
Prediction: I m you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


87 iter = 100, loss = 3.018120288848877, lr = 2.174065891365761e-05
Prediction: I m you you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.62it/s]


88 iter = 100, loss = 2.98427677154541, lr = 2.1987711855858265e-05
Prediction: I m you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.42it/s]


89 iter = 100, loss = 2.950727701187134, lr = 2.223476479805892e-05
Prediction: I m you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.65it/s]


90 iter = 100, loss = 2.916313648223877, lr = 2.2481817740259574e-05
Prediction: I m you you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.42it/s]


91 iter = 100, loss = 2.8824706077575684, lr = 2.272887068246023e-05
Prediction: I m you you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.60it/s]


92 iter = 100, loss = 2.852226495742798, lr = 2.2975923624660884e-05
Prediction: I m you you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


93 iter = 100, loss = 2.818902015686035, lr = 2.322297656686154e-05
Prediction: I m a you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.61it/s]


94 iter = 100, loss = 2.7890281677246094, lr = 2.3470029509062193e-05
Prediction: I' m go you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.49it/s]


95 iter = 100, loss = 2.7568533420562744, lr = 2.3717082451262848e-05
Prediction: I' m go you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.58it/s]


96 iter = 100, loss = 2.7265191078186035, lr = 2.39641353934635e-05
Prediction: I' m you you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.46it/s]


97 iter = 100, loss = 2.6932079792022705, lr = 2.4211188335664154e-05
Prediction: I m go you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:17<00:00,  5.63it/s]


98 iter = 100, loss = 2.6669790744781494, lr = 2.445824127786481e-05
Prediction: I' m go you you you. <eos> <eos>
True: I' m letting you go. <eos> <pad>


100%|██████████| 100/100 [00:18<00:00,  5.48it/s]


99 iter = 100, loss = 2.637155055999756, lr = 2.4705294220065464e-05
Prediction: I m go you you.. <eos> <eos>
True: I' m letting you go. <eos> <pad>


# Beam search

In [27]:
import math


def init_vars(sentence, model, K):
    src_mask = (sentence != PAD_ID)
    e_output = model.encoder(sentence, src_mask)
            
    out = model.out(torch.LongTensor([[BOUND_ID]]).to(DEVICE), 
                    e_output, src_mask, 
                    nopeak_mask(1))
    out = F.softmax(out, -1)
    
    probs, ix = out[:, -1].data.topk(K)
    log_scores = torch.Tensor([math.log(prob) for prob in probs.data[0]]).unsqueeze(0).to(DEVICE)
    
    outputs = torch.zeros(K, MAX_SEQ_LEN).long().to(DEVICE)
    outputs[:, 0] = BOUND_ID
    outputs[:, 1] = ix[0]
    
    e_outputs = torch.zeros(K, e_output.size(-2),e_output.size(-1)).to(DEVICE)
    e_outputs[:, :] = e_output[0]
    
    return outputs, e_outputs, log_scores


def k_best_outputs(outputs, out, log_scores, i, k):
    probs, ix = out[:, -1].data.topk(k)
    log_probs = torch.Tensor([math.log(p) for p in probs.data.view(-1)]).view(k, -1).to(DEVICE) + log_scores.transpose(0,1)
    log_probs = log_probs.to(DEVICE)
    
    k_probs, k_ix = log_probs.view(-1).topk(k)
    
    row = k_ix // k
    col = k_ix % k

    outputs[:, :i] = outputs[row, :i]
    outputs[:, i] = ix[row, col]

    log_scores = k_probs.unsqueeze(0)
    
    return outputs, log_scores


def beam_search(sentence, model, K = 10):    
    outputs, e_outputs, log_scores = init_vars(sentence, model, K)
    src_mask = (sentence != PAD_ID).to(DEVICE)
    ind = None
    for i in range(2, MAX_SEQ_LEN):
    
        trg_mask = nopeak_mask(i)

        out = model.out(outputs[:,:i], e_outputs, src_mask, trg_mask)

        out = F.softmax(out, dim=-1)
    
        outputs, log_scores = k_best_outputs(outputs, out, log_scores, i, K)
        
        sentence_lengths = torch.zeros(len(outputs), dtype=torch.long).to(DEVICE)
        
        ones = (outputs==BOUND_ID).nonzero()
        for vec in ones:
            i = vec[0]
            if sentence_lengths[i] == 0:
                sentence_lengths[i] = vec[1]
        num_finished_sentences = len([s for s in sentence_lengths if s > 0])
        if num_finished_sentences == K:
            alpha = 0.6
            div = 1/(sentence_lengths.type_as(log_scores)**alpha)
            _, ind = torch.max(log_scores * div, 1)
            ind = ind.data[0]
            break
        
    if ind is None:
        length = (outputs[0] == BOUND_ID).nonzero()
        if len(length) != 0:
            length = length[1]
        else:
            length = 10
        res = outputs[0][1:length].data.tolist()
        return dataset.trg_tokenizer.decode(res)
    else:
        length = (outputs[ind] == BOUND_ID).nonzero()[1]
        res = outputs[ind][1:length].data.tolist()
        return dataset.trg_tokenizer.decode(res)

# Inference

In [28]:
model.eval()

Transformer(
  (encoder): Encoder(
    (embed): Embedding(43807, 256, padding_idx=1)
    (pe): PositionalEncoder()
    (layers): ModuleList(
      (0-2): 3 x EncoderLayer(
        (norm_1): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (norm_2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
        (attn): MultiHeadAttention(
          (q_linear): Linear(in_features=256, out_features=256, bias=True)
          (v_linear): Linear(in_features=256, out_features=256, bias=True)
          (k_linear): Linear(in_features=256, out_features=256, bias=True)
          (dropout): Dropout(p=0.1, inplace=False)
          (out): Linear(in_features=256, out_features=256, bias=True)
        )
        (ff): FeedForward(
          (linear_1): Linear(in_features=256, out_features=1024, bias=True)
          (linear_2): Linear(in_features=1024, out_features=256, bias=True)
        )
      )
    )
  )
  (decoder): Decoder(
    (embed): Embedding(43771, 256, padding_idx=1)
    (pe): Pos

In [29]:
def translate(sent):
    sent = torch.LongTensor([dataset.src_line_mapper(sent)]).to(DEVICE)
    return beam_search(sent, model, 5)

In [30]:
translate('j adore les fleurs')

'I I I I I I like like like like like like books.'

Transformer model is trickier to tune than LSTM...

3 bonus points to the first three students who tune this notebook to make Transformer correctly translate most of training set in under 30 min training in Colab. Send me the colab link in telegram.