# Transformer seq2seq Model

- Modified from https://github.com/bentrevett/pytorch-seq2seq
- also useful: https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html

In [1]:
%load_ext autoreload
%autoreload 2

In [17]:
import torch
import torch.nn as nn
import torch.optim as optim

print("Torch Version:", torch.__version__)

import torchtext
# from torchtext.legacy.datasets import Multi30k
# from torchtext.legacy.data import Field, BucketIterator
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# import spacy
import numpy as np
import unicodedata
import re
import numpy as np
import os
import io
import pickle

import random
import math
import time
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from transformer_pyt import Encoder, Decoder, Seq2Seq, Seq2SeqMulti
from transformer_pyt import train, evaluate
from utils import (get_session_data, 
                   build_vocab_from_seqs, 
                   data_process_meta, 
                   data_process_no_meta,
                   epoch_time,
                   get_session_data_test)

Torch Version: 1.5.0


In [3]:
# from torchtext.data.utils import get_tokenizer
from torchtext.data import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [4]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [16]:
data_dir = "/recsys_data/RecSys/h_and_m_personalized_fashion_recommendation"
file_name = "hnm_3w_sessionized_orig.txt" # "hnm_big.txt", "hnm_3w_sessionized.txt", "hnm_7w_sessionized.txt"
model_name = file_name.split(".")[0] + ".pt"
seq_file_name = "seq_" + file_name
test_seq_file = "seq_test_" + file_name
colsep = "\t"
include_meta = False

inp_seq_len, tgt_seq_len = 12, 12
BATCH_SIZE = 256
num_examples = None
file_path = os.path.join(data_dir, seq_file_name)
model_path = os.path.join(data_dir, model_name)
test_file_path = os.path.join(data_dir, test_seq_file)

tokenizer = get_tokenizer("basic_english")
# en_tokenizer = get_tokenizer(language='en')

tokens = tokenizer('0924243001 0924243002 0923758001 0918522001 0909370001 0866731001 0751471001 0915529003 0915529005 0448509014 0762846027 0714790020')
tokens

['0924243001',
 '0924243002',
 '0923758001',
 '0918522001',
 '0909370001',
 '0866731001',
 '0751471001',
 '0915529003',
 '0915529005',
 '0448509014',
 '0762846027',
 '0714790020']

Get all the sequence information

In [35]:
inp_file = os.path.join(data_dir, file_name)

all_seqs, prod_dict = get_session_data(inp_file, 
                                       inp_seq_len=inp_seq_len,
                                       tgt_seq_len=tgt_seq_len,
                                       convert_to_integer=False)
print(all_seqs.keys())

364695it [00:01, 210909.77it/s]


Read 48709 user interactions
dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 'prod'])


In [36]:
all_seqs['prod'][0]

(['0921073001'], ['0777148006', '0835801001', '0923134005', '0865929003'])

In [37]:
src_vocab = build_vocab_from_seqs(all_seqs['prod'], tokenizer)
print(f"Total {len(src_vocab)} words in the vocabulary")

if include_meta:
    all_data = data_process_meta(all_seqs, tokenizer, src_vocab)
else:
    all_data = data_process_no_meta(all_seqs, src_vocab)
train_data, val_data = train_test_split(all_data, test_size=0.2)
# test_data = data_process(test_file_path, tokenizer, src_vocab, test_flag=True)
len(all_data), len(train_data), len(val_data)#, len(test_data)

Total 21533 words in the vocabulary


(71460, 57168, 14292)

In [38]:
src_vocab['<unk>'], src_vocab['<pad>'], src_vocab['<bos>'], src_vocab['<eos>']

(0, 1, 2, 3)

In [39]:
src_vocab['921073001'], src_vocab['777148006'], len(src_vocab)

(0, 0, 21533)

In [40]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [41]:
train_data[0]

(tensor([13867,  4365,   822,  5939,   625,  2707,   754,  8904,  2454]),
 tensor([51]))

## Load Product Vector - Separately Created

In [31]:
pv_file = os.path.join(data_dir, "prod_vectors.pkl")
with open(pv_file, "rb") as fr:
    prod_vec, prod_list = pickle.load(fr)
print(prod_vec.shape, len(prod_list))

(105542, 50) 105542


### Create the Embedding Matrix

In [42]:
embed_matrix = np.zeros((len(src_vocab), prod_vec.shape[1]))
count = 0
for tok, idx in src_vocab.stoi.items():
    if tok in prod_list:
        list_index = prod_list.index(tok)
        embed_matrix[idx] = prod_vec[list_index]
    else:
        count += 1
print(count)

4


In [52]:
PAD_IDX = src_vocab['<pad>']
BOS_IDX = src_vocab['<bos>']
EOS_IDX = src_vocab['<eos>']

if include_meta:
    def generate_batch(data_batch):
        inp_batch, tgt_batch = [], []
        for (de_item, en_item) in data_batch:
            n = de_item.shape[0]
            before = torch.unsqueeze(torch.tensor([BOS_IDX] * n), 1)
            after = torch.unsqueeze(torch.tensor([EOS_IDX] * n), 1)
            total = torch.cat([before, de_item, after], dim=1)
            total = total.permute(1, 0)
            inp_batch.append(total)
            tgt_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))

        inp_batch = pad_sequence(inp_batch, padding_value=PAD_IDX)
        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return inp_batch, tgt_batch

    # input dimension of each product attribute
    prod_dict_dims = [max(prod_dict[k].values()) for k in range(len(prod_dict))]

else:
    def generate_batch(data_batch):
        de_batch, en_batch = [], []
        for (de_item, en_item) in data_batch:
            de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
            en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
        de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
        en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
        return de_batch, en_batch

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iterator = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
    

Create multiple encoders - one for each product attribute

In [68]:
SRC_PAD_IDX = 0
TRG_PAD_IDX = 0

if include_meta:
    HID_DIM = 128
    HID_DIM2 = 32
    ENC_LAYERS = 2
    DEC_LAYERS = 2
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 128
    DEC_PF_DIM = 128
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1
    OUTPUT_DIM = len(src_vocab)

    encs = torch.nn.ModuleList()
    prod_enc = Encoder(input_dim=len(src_vocab), 
                       hid_dim=HID_DIM, 
                       n_layers=ENC_LAYERS, 
                       n_heads=ENC_HEADS, 
                       pf_dim=ENC_PF_DIM, 
                       dropout=ENC_DROPOUT, 
                       device=device)
    encs.append(prod_enc)
    total_dim = HID_DIM
    for pdim in prod_dict_dims:
        enc_p = Encoder(input_dim=pdim, 
                       hid_dim=HID_DIM2, 
                       n_layers=ENC_LAYERS, 
                       n_heads=ENC_HEADS, 
                       pf_dim=ENC_PF_DIM, 
                       dropout=ENC_DROPOUT, 
                       device=device)
        encs.append(enc_p)
        total_dim += HID_DIM2

    dec = Decoder(OUTPUT_DIM, 
                  HID_DIM, 
                  DEC_LAYERS, 
                  DEC_HEADS, 
                  DEC_PF_DIM, 
                  DEC_DROPOUT, 
                  device)
    model = Seq2SeqMulti(encs, dec, total_dim, HID_DIM, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)    
else:
    INPUT_DIM = len(src_vocab)
    OUTPUT_DIM = len(src_vocab)
    HID_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1

    enc = Encoder(INPUT_DIM, 
                  HID_DIM, 
                  ENC_LAYERS, 
                  ENC_HEADS, 
                  ENC_PF_DIM, 
                  ENC_DROPOUT, 
                  device,
                  pretrained=True,
                  weight=torch.FloatTensor(embed_matrix),  # has to be converted
                  ext_embed_dim=embed_matrix.shape[1])

    dec = Decoder(OUTPUT_DIM, 
                  HID_DIM, 
                  DEC_LAYERS, 
                  DEC_HEADS, 
                  DEC_PF_DIM, 
                  DEC_DROPOUT, 
                  device)
    model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    

In [69]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 15,064,349 trainable parameters


In [70]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)
        
model.apply(initialize_weights);

## Model Training

In [71]:
import math

N_EPOCHS = 20
LEARNING_RATE = 1e-04
CLIP = 1

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

best_valid_loss = float('inf')
best_valid_map = 0
patience, max_patience = 0, 5

print(f"Number of steps per epoch: {len(train_data)//BATCH_SIZE}")
for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP, device)
    valid_loss, valid_map = evaluate(model, valid_iterator, criterion, device)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
    if valid_map > best_valid_map:
        best_valid_map = valid_map
        torch.save(model.state_dict(), model_path)
        patience = 0
    else:
        patience += 1

    if patience == max_patience:
        print("Maximum patience reached ... exiting!")
        break

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s | patience {patience}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} | Val. MAP: {valid_map:7.3f}')

0it [00:00, ?it/s]

Number of steps per epoch: 223


224it [00:56,  3.97it/s]
56it [00:05, 11.10it/s]
0it [00:00, ?it/s]

Epoch: 01 | Time: 1m 1s | patience 1
	Train Loss: 4.959 | Train PPL: 142.424
	 Val. Loss: 2.144 |  Val. PPL:   8.534 | Val. MAP:   0.000


224it [00:56,  3.96it/s]
56it [00:05, 11.04it/s]
0it [00:00, ?it/s]

Epoch: 02 | Time: 1m 1s | patience 2
	Train Loss: 2.018 | Train PPL:   7.522
	 Val. Loss: 1.969 |  Val. PPL:   7.164 | Val. MAP:   0.000


224it [00:56,  3.95it/s]
56it [00:05, 10.92it/s]
0it [00:00, ?it/s]

Epoch: 03 | Time: 1m 1s | patience 0
	Train Loss: 1.950 | Train PPL:   7.029
	 Val. Loss: 1.951 |  Val. PPL:   7.039 | Val. MAP:   0.005


224it [00:56,  3.96it/s]
56it [00:05, 10.93it/s]
0it [00:00, ?it/s]

Epoch: 04 | Time: 1m 1s | patience 1
	Train Loss: 1.920 | Train PPL:   6.824
	 Val. Loss: 1.957 |  Val. PPL:   7.079 | Val. MAP:   0.005


224it [00:56,  3.95it/s]
56it [00:05, 10.79it/s]
0it [00:00, ?it/s]

Epoch: 05 | Time: 1m 1s | patience 2
	Train Loss: 1.895 | Train PPL:   6.656
	 Val. Loss: 1.953 |  Val. PPL:   7.053 | Val. MAP:   0.005


224it [00:56,  3.94it/s]
56it [00:05, 10.82it/s]
0it [00:00, ?it/s]

Epoch: 06 | Time: 1m 2s | patience 3
	Train Loss: 1.870 | Train PPL:   6.489
	 Val. Loss: 1.955 |  Val. PPL:   7.064 | Val. MAP:   0.005


224it [00:56,  3.93it/s]
56it [00:05, 10.86it/s]
0it [00:00, ?it/s]

Epoch: 07 | Time: 1m 2s | patience 0
	Train Loss: 1.838 | Train PPL:   6.282
	 Val. Loss: 1.956 |  Val. PPL:   7.070 | Val. MAP:   0.006


224it [00:56,  3.95it/s]
56it [00:05, 10.80it/s]
0it [00:00, ?it/s]

Epoch: 08 | Time: 1m 1s | patience 0
	Train Loss: 1.812 | Train PPL:   6.120
	 Val. Loss: 1.954 |  Val. PPL:   7.060 | Val. MAP:   0.006


224it [00:57,  3.92it/s]
56it [00:05, 10.95it/s]
0it [00:00, ?it/s]

Epoch: 09 | Time: 1m 2s | patience 0
	Train Loss: 1.779 | Train PPL:   5.923
	 Val. Loss: 1.961 |  Val. PPL:   7.105 | Val. MAP:   0.007


224it [00:57,  3.93it/s]
56it [00:05, 10.91it/s]
0it [00:00, ?it/s]

Epoch: 10 | Time: 1m 2s | patience 0
	Train Loss: 1.747 | Train PPL:   5.739
	 Val. Loss: 1.964 |  Val. PPL:   7.127 | Val. MAP:   0.007


224it [00:56,  3.94it/s]
56it [00:05, 10.88it/s]
0it [00:00, ?it/s]

Epoch: 11 | Time: 1m 2s | patience 0
	Train Loss: 1.717 | Train PPL:   5.566
	 Val. Loss: 1.966 |  Val. PPL:   7.144 | Val. MAP:   0.009


224it [00:57,  3.92it/s]
56it [00:05, 10.92it/s]
0it [00:00, ?it/s]

Epoch: 12 | Time: 1m 2s | patience 0
	Train Loss: 1.686 | Train PPL:   5.399
	 Val. Loss: 1.967 |  Val. PPL:   7.151 | Val. MAP:   0.010


224it [00:57,  3.92it/s]
56it [00:05, 10.87it/s]
0it [00:00, ?it/s]

Epoch: 13 | Time: 1m 2s | patience 0
	Train Loss: 1.656 | Train PPL:   5.236
	 Val. Loss: 1.965 |  Val. PPL:   7.135 | Val. MAP:   0.014


224it [00:57,  3.92it/s]
56it [00:05, 10.91it/s]
0it [00:00, ?it/s]

Epoch: 14 | Time: 1m 2s | patience 0
	Train Loss: 1.628 | Train PPL:   5.092
	 Val. Loss: 1.977 |  Val. PPL:   7.220 | Val. MAP:   0.017


224it [00:56,  3.94it/s]
56it [00:05, 10.64it/s]
0it [00:00, ?it/s]

Epoch: 15 | Time: 1m 2s | patience 1
	Train Loss: 1.600 | Train PPL:   4.955
	 Val. Loss: 1.986 |  Val. PPL:   7.284 | Val. MAP:   0.016


224it [00:56,  3.94it/s]
56it [00:05, 10.91it/s]
0it [00:00, ?it/s]

Epoch: 16 | Time: 1m 2s | patience 0
	Train Loss: 1.572 | Train PPL:   4.816
	 Val. Loss: 1.990 |  Val. PPL:   7.315 | Val. MAP:   0.025


224it [00:57,  3.93it/s]
56it [00:05, 10.66it/s]
0it [00:00, ?it/s]

Epoch: 17 | Time: 1m 2s | patience 0
	Train Loss: 1.545 | Train PPL:   4.687
	 Val. Loss: 1.994 |  Val. PPL:   7.346 | Val. MAP:   0.033


224it [00:57,  3.93it/s]
56it [00:05, 10.82it/s]
0it [00:00, ?it/s]

Epoch: 18 | Time: 1m 2s | patience 1
	Train Loss: 1.518 | Train PPL:   4.562
	 Val. Loss: 2.007 |  Val. PPL:   7.444 | Val. MAP:   0.031


224it [00:56,  3.93it/s]
56it [00:05, 10.75it/s]
0it [00:00, ?it/s]

Epoch: 19 | Time: 1m 2s | patience 0
	Train Loss: 1.492 | Train PPL:   4.444
	 Val. Loss: 2.015 |  Val. PPL:   7.503 | Val. MAP:   0.037


224it [00:57,  3.92it/s]
56it [00:05, 10.82it/s]


Epoch: 20 | Time: 1m 2s | patience 0
	Train Loss: 1.464 | Train PPL:   4.324
	 Val. Loss: 2.017 |  Val. PPL:   7.514 | Val. MAP:   0.049


In [26]:
file_name

'hnm_7w_sessionized.txt'

### Get the test data - last session

In [52]:
test_seqs = get_session_data_test(inp_file, prod_dict, inp_seq_len=inp_seq_len)
print(len(test_seqs), test_seqs.keys())

1332519it [00:05, 227087.95it/s]


Read 144202 user interactions
10 dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 'prod'])


In [57]:
test_data = data_process_meta(test_seqs, tokenizer, src_vocab, test_flag=True)
len(test_data)

144202

In [61]:
def generate_batch_test(data_batch):
    inp_batch = []
    for de_item in data_batch:
        n = de_item.shape[0]
        before = torch.unsqueeze(torch.tensor([BOS_IDX] * n), 1)
        after = torch.unsqueeze(torch.tensor([EOS_IDX] * n), 1)
        total = torch.cat([before, de_item, after], dim=1)
        total = total.permute(1, 0)
        inp_batch.append(total)
    
#     print(torch.cat(inp_batch, dim=0).shape)
    inp_batch = pad_sequence(inp_batch, padding_value=PAD_IDX)
    return inp_batch

test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch_test)

In [63]:
model.eval()

epoch_loss = 0
all_maps = []
with torch.no_grad():

    for _, src in tqdm(enumerate(test_iterator)):

        src = src.to(device)
        # make the batch dimension first
        if src.dim() == 3:
            src = src.permute(1, 0, 2)
        else:
            src = src.permute(1, 0)

        output, _ = model(src, trg[:, :-1])
        prediction = torch.argmax(output, axis=-1)
        print(prediction.shape)
        sys.exit()


0it [00:00, ?it/s]


NameError: name 'trg' is not defined

In [90]:
def translate_sentence(src_tensor, model, vocab, device, max_len = 50):
    """
    inp: (b, s, f), batch_size X sequence length X feature dimension
    """
    src_tensor = src_tensor.unsqueeze(0).to(device)  # add the batch dimension
    
    model.eval()
    src_mask = model.make_src_mask(src_tensor[:, :, 0])
    
    with torch.no_grad():
        enc_src = []
        for ii in range(model.num_features):
            enc_src.append(model.encoders[ii](src_tensor[:, :, ii], src_mask))
        enc_src = torch.cat(enc_src, dim=-1)
        enc_src = model.linear(enc_src)

    trg_indexes = [vocab.stoi['<bos>']]

    for i in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == vocab.stoi['<eos>']:
            break
    
    trg_tokens = [vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention

In [94]:
pred, _ = translate_sentence(src[4, :, :], model, src_vocab, device, max_len = tgt_seq_len)
pred

['220', '<eos>']

# Only Product Sequence

In [7]:
src_vocab = build_vocab_from_file(file_path, tokenizer)
all_data = data_process(file_path, tokenizer, src_vocab)
train_data, val_data = train_test_split(all_data, test_size=0.2)
test_data = data_process(test_file_path, tokenizer, src_vocab, test_flag=True)
len(all_data), len(train_data), len(val_data), len(test_data)

(71460, 57168, 14292, 48709)

In [23]:
INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(src_vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)

In [10]:
PAD_IDX = src_vocab['<pad>']
BOS_IDX = src_vocab['<bos>']
EOS_IDX = src_vocab['<eos>']

def generate_batch(data_batch):
    de_batch, en_batch = [], []
    for (de_item, en_item) in data_batch:
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return de_batch, en_batch

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iterator = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

In [24]:
SRC_PAD_IDX = 0
TRG_PAD_IDX = 0

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

In [34]:
def predict(sentence, src_vocab, trg_vocab, model, device, max_len = tgt_seq_len):
    
    model.eval()

    tokens = tokenizer(sentence)
    tokens = [src_vocab['<bos>']] + tokens + [src_vocab['<eos>']]
    src_indexes = [src_vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)
    
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_indexes = [trg_vocab.stoi['<bos>']]
    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == trg_vocab.stoi['<eos>']:
            break
    
    trg_tokens = [trg_vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:], attention

In [37]:
pred, _ = predict('13112 16042 3871 35', src_vocab, src_vocab, model, device)
pred

['1566', '<eos>']