# Transformer seq2seq Model

- Modified from https://github.com/bentrevett/pytorch-seq2seq
- also useful: https://pytorch.org/tutorials/beginner/torchtext_translation_tutorial.html

In [1]:
%load_ext autoreload
%autoreload 2

In [6]:
import torch
import torch.nn as nn
import torch.optim as optim

print("Torch Version:", torch.__version__)

import torchtext
# from torchtext.legacy.datasets import Multi30k
# from torchtext.legacy.data import Field, BucketIterator
from torchtext.data import Field, BucketIterator

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker

# import spacy
import numpy as np
import unicodedata
import re
import numpy as np
import os
import io
import pickle

import random
import math
import time
from sklearn.model_selection import train_test_split
from tqdm import tqdm

from transformer_pyt import Encoder, Decoder, Seq2Seq, Seq2SeqMulti, Seq2Class, SimpleDecoder
from transformer_pyt import train, evaluate
from utils import (get_session_data, 
                   build_vocab_from_seqs, 
                   data_process_meta, 
                   data_process_no_meta,
                   epoch_time,
                   get_session_data_test)

Torch Version: 1.5.0


In [3]:
# from torchtext.data.utils import get_tokenizer
from torchtext.data import get_tokenizer
from collections import Counter
from torchtext.vocab import Vocab

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

In [4]:
SEED = 1234

random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

In [20]:
data_dir = "/recsys_data/RecSys/h_and_m_personalized_fashion_recommendation"
file_name = "hnm_3w_sessionized_orig.txt" # "hnm_big.txt", "hnm_3w_sessionized.txt", "hnm_7w_sessionized.txt"
seq_file_name = "seq_" + file_name
test_seq_file = "seq_test_" + file_name
colsep = "\t"
include_meta = False
model_type = "seq2class"  # "seq2seq", "seq2class"
model_name = file_name.split(".")[0] + "_" + model_type + ".pt"

inp_seq_len, tgt_seq_len = 12, 12
BATCH_SIZE = 256
num_examples = None
file_path = os.path.join(data_dir, seq_file_name)
model_path = os.path.join(data_dir, model_name)
test_file_path = os.path.join(data_dir, test_seq_file)

tokenizer = get_tokenizer("basic_english")
# en_tokenizer = get_tokenizer(language='en')

tokens = tokenizer('0924243001 0924243002 0923758001 0918522001 0909370001 0866731001 0751471001 0915529003 0915529005 0448509014 0762846027 0714790020')
tokens

['0924243001',
 '0924243002',
 '0923758001',
 '0918522001',
 '0909370001',
 '0866731001',
 '0751471001',
 '0915529003',
 '0915529005',
 '0448509014',
 '0762846027',
 '0714790020']

Get all the sequence information

In [7]:
inp_file = os.path.join(data_dir, file_name)

all_seqs, prod_dict = get_session_data(inp_file, 
                                       inp_seq_len=inp_seq_len,
                                       tgt_seq_len=tgt_seq_len,
                                       convert_to_integer=False)
print(all_seqs.keys())

364695it [00:02, 163209.13it/s]


Read 48709 user interactions
dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 'prod'])


In [8]:
all_seqs['prod'][0]

(['0921073001'], ['0777148006', '0835801001', '0923134005', '0865929003'])

In [9]:
src_vocab = build_vocab_from_seqs(all_seqs['prod'], tokenizer, extra=["<pad>", "<bos>", "<eos>"])
print(f"Total {len(src_vocab)} words in the vocabulary")

if include_meta:
    all_data = data_process_meta(all_seqs, tokenizer, src_vocab)
else:
    all_data = data_process_no_meta(all_seqs, src_vocab)
train_data, val_data = train_test_split(all_data, test_size=0.2)
# test_data = data_process(test_file_path, tokenizer, src_vocab, test_flag=True)
len(all_data), len(train_data), len(val_data)#, len(test_data)

Total 21532 words in the vocabulary


(71460, 57168, 14292)

In [10]:
src_vocab['<unk>'], src_vocab['<pad>'], src_vocab['<bos>'], src_vocab['<eos>']

(None, 0, 1, 2)

In [11]:
src_vocab['0921073001'], src_vocab['0777148006'], len(src_vocab)

(980, 12650, 21532)

In [12]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device

device(type='cuda')

In [13]:
train_data[0]

(tensor([12909,  7398]),
 tensor([  52,    5, 1069,   25,   25,  208, 3076, 3076,   15,  169]))

## Load Product Vector - Separately Created

In [14]:
pv_file = os.path.join(data_dir, "prod_vectors.pkl")
with open(pv_file, "rb") as fr:
    prod_vec, prod_list = pickle.load(fr)
print(prod_vec.shape, len(prod_list))

(105542, 50) 105542


### Create the Embedding Matrix

In [15]:
embed_matrix = np.zeros((len(src_vocab), prod_vec.shape[1]))
count = 0
for tok, idx in src_vocab.stoi.items():
    if tok in prod_list:
        list_index = prod_list.index(tok)
        embed_matrix[idx] = prod_vec[list_index]
    else:
        count += 1
print(count)

3


In [16]:
PAD_IDX = src_vocab['<pad>']
BOS_IDX = src_vocab['<bos>']
EOS_IDX = src_vocab['<eos>']

if include_meta:
    def generate_batch(data_batch):
        inp_batch, tgt_batch = [], []
        for (de_item, en_item) in data_batch:
            n = de_item.shape[0]
            before = torch.unsqueeze(torch.tensor([BOS_IDX] * n), 1)
            after = torch.unsqueeze(torch.tensor([EOS_IDX] * n), 1)
            total = torch.cat([before, de_item, after], dim=1)
            total = total.permute(1, 0)
            inp_batch.append(total)
            tgt_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))

        inp_batch = pad_sequence(inp_batch, padding_value=PAD_IDX)
        tgt_batch = pad_sequence(tgt_batch, padding_value=PAD_IDX)
        return inp_batch, tgt_batch

    # input dimension of each product attribute
    prod_dict_dims = [max(prod_dict[k].values()) for k in range(len(prod_dict))]

else:
    def generate_batch(data_batch):
        de_batch, en_batch = [], []
        for (de_item, en_item) in data_batch:
            de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
            en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
        de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
        en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
        return de_batch, en_batch

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iterator = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
    

In [17]:
src, trg = next(iter(train_iterator))
print(src.shape, trg.shape)
print(src)
print(trg)

torch.Size([14, 256]) torch.Size([14, 256])
tensor([[   1,    1,    1,  ...,    1,    1,    1],
        [4499,   92, 5340,  ...,  269,  204, 1452],
        [2932,  105, 2700,  ..., 7063,   51,    2],
        ...,
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0],
        [   0,    0,    0,  ...,    0,    0,    0]])
tensor([[    1,     1,     1,  ...,     1,     1,     1],
        [  223,   212,   130,  ...,    63, 10590,   380],
        [ 1196,   468,    77,  ...,   125,  7063,     2],
        ...,
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0],
        [    0,     0,     0,  ...,     0,     0,     0]])


Create multiple encoders - one for each product attribute

In [118]:
SRC_PAD_IDX = 0
TRG_PAD_IDX = 0

if include_meta:
    HID_DIM = 128
    HID_DIM2 = 32
    ENC_LAYERS = 2
    DEC_LAYERS = 2
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 128
    DEC_PF_DIM = 128
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1
    OUTPUT_DIM = len(src_vocab)

    encs = torch.nn.ModuleList()
    prod_enc = Encoder(input_dim=len(src_vocab), 
                       hid_dim=HID_DIM, 
                       n_layers=ENC_LAYERS, 
                       n_heads=ENC_HEADS, 
                       pf_dim=ENC_PF_DIM, 
                       dropout=ENC_DROPOUT, 
                       device=device)
    encs.append(prod_enc)
    total_dim = HID_DIM
    for pdim in prod_dict_dims:
        enc_p = Encoder(input_dim=pdim, 
                       hid_dim=HID_DIM2, 
                       n_layers=ENC_LAYERS, 
                       n_heads=ENC_HEADS, 
                       pf_dim=ENC_PF_DIM, 
                       dropout=ENC_DROPOUT, 
                       device=device)
        encs.append(enc_p)
        total_dim += HID_DIM2

    dec = Decoder(OUTPUT_DIM, 
                  HID_DIM, 
                  DEC_LAYERS, 
                  DEC_HEADS, 
                  DEC_PF_DIM, 
                  DEC_DROPOUT, 
                  device)
    model = Seq2SeqMulti(encs, dec, total_dim, HID_DIM, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)    
else:
    INPUT_DIM = len(src_vocab)
    OUTPUT_DIM = len(src_vocab)
    HID_DIM = 256
    ENC_LAYERS = 3
    DEC_LAYERS = 3
    ENC_HEADS = 8
    DEC_HEADS = 8
    ENC_PF_DIM = 512
    DEC_PF_DIM = 512
    ENC_DROPOUT = 0.1
    DEC_DROPOUT = 0.1

    enc = Encoder(INPUT_DIM, 
                  HID_DIM, 
                  ENC_LAYERS, 
                  ENC_HEADS, 
                  ENC_PF_DIM, 
                  ENC_DROPOUT, 
                  device,
                  pretrained=True,
                  weight=torch.FloatTensor(embed_matrix),  # has to be converted
                  ext_embed_dim=embed_matrix.shape[1])

    if model_type == "seq2class":
        dec = SimpleDecoder(OUTPUT_DIM, 
                            HID_DIM,
                            inp_seq_len+2,
                            tgt_seq_len+2,
                            DEC_DROPOUT)
        model = Seq2Class(enc, dec, SRC_PAD_IDX, device).to(device)
    else:
        dec = Decoder(OUTPUT_DIM, 
                      HID_DIM, 
                      DEC_LAYERS, 
                      DEC_HEADS, 
                      DEC_PF_DIM, 
                      DEC_DROPOUT, 
                      device)
        model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)
    

In [119]:
def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'The model has {count_parameters(model):,} trainable parameters')

The model has 7,154,414 trainable parameters


In [120]:
def initialize_weights(m):
    if hasattr(m, 'weight') and m.weight.dim() > 1:
        nn.init.xavier_uniform_(m.weight.data)
        
model.apply(initialize_weights);

## Model Training

In [121]:
import math

N_EPOCHS = 50
LEARNING_RATE = 1e-04
CLIP = 1
patience, max_patience = 0, 10

optimizer = torch.optim.Adam(model.parameters(), lr = LEARNING_RATE)
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)

best_valid_loss = float('inf')
best_valid_map = 0

print(f"Number of steps per epoch: {len(train_data)//BATCH_SIZE}")
for epoch in range(N_EPOCHS):
    
    start_time = time.time()
    
    train_loss = train(model, train_iterator, optimizer, criterion, CLIP, device)
    valid_loss, valid_map = evaluate(model, valid_iterator, criterion, device)
    
    end_time = time.time()
    
    epoch_mins, epoch_secs = epoch_time(start_time, end_time)
    
#     if valid_map > best_valid_map:
    if valid_loss < best_valid_loss:
#         best_valid_map = valid_map
        best_valid_loss = valid_loss
        torch.save(model.state_dict(), model_path)
        patience = 0
    else:
        patience += 1

    if patience == max_patience:
        print("Maximum patience reached ... exiting!")
        break

    print(f'Epoch: {epoch+1:02} | Time: {epoch_mins}m {epoch_secs}s | patience {patience}')
    print(f'\tTrain Loss: {train_loss:.3f} | Train PPL: {math.exp(train_loss):7.3f}')
    print(f'\t Val. Loss: {valid_loss:.3f} |  Val. PPL: {math.exp(valid_loss):7.3f} | Val. MAP: {valid_map:7.3f}')

0it [00:00, ?it/s]

Number of steps per epoch: 223


224it [00:38,  5.79it/s]
56it [00:04, 13.75it/s]
1it [00:00,  5.74it/s]

Epoch: 01 | Time: 0m 42s | patience 0
	Train Loss: 8.294 | Train PPL: 3999.802
	 Val. Loss: 6.885 |  Val. PPL: 977.736 | Val. MAP:   0.004


224it [00:39,  5.74it/s]
56it [00:04, 13.92it/s]
1it [00:00,  5.71it/s]

Epoch: 02 | Time: 0m 43s | patience 0
	Train Loss: 6.520 | Train PPL: 678.802
	 Val. Loss: 6.134 |  Val. PPL: 461.280 | Val. MAP:   0.004


224it [00:39,  5.71it/s]
56it [00:04, 13.82it/s]
1it [00:00,  5.48it/s]

Epoch: 03 | Time: 0m 43s | patience 0
	Train Loss: 5.624 | Train PPL: 277.103
	 Val. Loss: 5.547 |  Val. PPL: 256.415 | Val. MAP:   0.001


224it [00:39,  5.70it/s]
56it [00:04, 13.80it/s]
1it [00:00,  5.55it/s]

Epoch: 04 | Time: 0m 43s | patience 0
	Train Loss: 5.400 | Train PPL: 221.302
	 Val. Loss: 5.533 |  Val. PPL: 252.955 | Val. MAP:   0.000


224it [00:39,  5.70it/s]
56it [00:04, 13.38it/s]
1it [00:00,  5.65it/s]

Epoch: 05 | Time: 0m 43s | patience 1
	Train Loss: 5.372 | Train PPL: 215.394
	 Val. Loss: 5.534 |  Val. PPL: 253.114 | Val. MAP:   0.000


224it [00:39,  5.68it/s]
56it [00:04, 13.57it/s]
1it [00:00,  5.59it/s]

Epoch: 06 | Time: 0m 43s | patience 0
	Train Loss: 5.358 | Train PPL: 212.205
	 Val. Loss: 5.532 |  Val. PPL: 252.652 | Val. MAP:   0.005


224it [00:39,  5.68it/s]
56it [00:04, 13.90it/s]
1it [00:00,  5.46it/s]

Epoch: 07 | Time: 0m 43s | patience 0
	Train Loss: 5.345 | Train PPL: 209.588
	 Val. Loss: 5.531 |  Val. PPL: 252.483 | Val. MAP:   0.004


224it [00:39,  5.70it/s]
56it [00:04, 13.65it/s]
1it [00:00,  5.68it/s]

Epoch: 08 | Time: 0m 43s | patience 1
	Train Loss: 5.335 | Train PPL: 207.391
	 Val. Loss: 5.541 |  Val. PPL: 254.970 | Val. MAP:   0.005


224it [00:39,  5.71it/s]
56it [00:04, 13.62it/s]
1it [00:00,  5.46it/s]

Epoch: 09 | Time: 0m 43s | patience 2
	Train Loss: 5.322 | Train PPL: 204.699
	 Val. Loss: 5.546 |  Val. PPL: 256.316 | Val. MAP:   0.005


224it [00:39,  5.68it/s]
56it [00:04, 13.78it/s]
1it [00:00,  5.49it/s]

Epoch: 10 | Time: 0m 43s | patience 3
	Train Loss: 5.310 | Train PPL: 202.412
	 Val. Loss: 5.560 |  Val. PPL: 259.801 | Val. MAP:   0.005


224it [00:39,  5.70it/s]
56it [00:04, 13.53it/s]
1it [00:00,  5.44it/s]

Epoch: 11 | Time: 0m 43s | patience 4
	Train Loss: 5.301 | Train PPL: 200.561
	 Val. Loss: 5.567 |  Val. PPL: 261.681 | Val. MAP:   0.005


224it [00:39,  5.70it/s]
56it [00:04, 13.73it/s]
1it [00:00,  5.46it/s]

Epoch: 12 | Time: 0m 43s | patience 5
	Train Loss: 5.291 | Train PPL: 198.446
	 Val. Loss: 5.567 |  Val. PPL: 261.619 | Val. MAP:   0.005


224it [00:39,  5.70it/s]
56it [00:04, 13.92it/s]
1it [00:00,  5.51it/s]

Epoch: 13 | Time: 0m 43s | patience 6
	Train Loss: 5.279 | Train PPL: 196.223
	 Val. Loss: 5.579 |  Val. PPL: 264.916 | Val. MAP:   0.005


224it [00:39,  5.70it/s]
56it [00:04, 13.43it/s]
1it [00:00,  5.61it/s]

Epoch: 14 | Time: 0m 43s | patience 7
	Train Loss: 5.270 | Train PPL: 194.319
	 Val. Loss: 5.586 |  Val. PPL: 266.765 | Val. MAP:   0.005


224it [00:39,  5.69it/s]
56it [00:04, 13.71it/s]
1it [00:00,  5.64it/s]

Epoch: 15 | Time: 0m 43s | patience 8
	Train Loss: 5.259 | Train PPL: 192.347
	 Val. Loss: 5.594 |  Val. PPL: 268.903 | Val. MAP:   0.005


224it [00:39,  5.69it/s]
56it [00:04, 13.70it/s]
1it [00:00,  5.46it/s]

Epoch: 16 | Time: 0m 43s | patience 9
	Train Loss: 5.250 | Train PPL: 190.536
	 Val. Loss: 5.598 |  Val. PPL: 269.834 | Val. MAP:   0.005


224it [00:39,  5.69it/s]
56it [00:04, 13.70it/s]

Maximum patience reached ... exiting!





In [109]:
model.load_state_dict(torch.load(model_path))

<All keys matched successfully>

In [112]:
from transformer_pyt import map_batch, mAP_k

In [110]:
criterion = nn.CrossEntropyLoss(ignore_index = TRG_PAD_IDX)
valid_loss, valid_map = evaluate(model, valid_iterator, criterion, device)
print(valid_map)

56it [00:04, 13.59it/s]

0.005455325555929919





In [114]:
model.eval()
for _, (src, trg) in enumerate(valid_iterator):
    src, trg = src.to(device), trg.to(device)
    if src.dim() == 3:
        src = src.permute(1, 0, 2)
    else:
        src = src.permute(1, 0)
    trg = trg.permute(1, 0)

    output, _ = model(src, trg[:, :-1])
    prediction = torch.argmax(output, axis=-1)
    mapr = map_batch(trg[:, 1:], prediction)
    print(mapr)

0.0
0.00390625
0.0078125
0.0078125
0.0078125
0.0
0.0078125
0.0078125
0.0078125
0.00390625
0.0
0.00390625
0.0078125
0.0078125
0.0
0.00390625
0.0
0.00390625
0.01171875
0.0
0.00390625
0.01171875
0.0078125
0.01171875
0.00390625
0.0
0.0
0.00390625
0.015625
0.00390625
0.0
0.00390625
0.00390625
0.0078125
0.00390625
0.00390625
0.0078125
0.01171875
0.01171875
0.01953125
0.01171875
0.00390625
0.00390625
0.0
0.0
0.00390625
0.00390625
0.0078125
0.00390625
0.0078125
0.0078125
0.0
0.0078125
0.00390625
0.0078125
0.0


In [115]:
prediction

tensor([[1, 3, 2,  ..., 2, 2, 2],
        [1, 3, 2,  ..., 2, 2, 2],
        [1, 3, 2,  ..., 2, 2, 2],
        ...,
        [1, 3, 2,  ..., 2, 2, 2],
        [1, 3, 2,  ..., 2, 2, 2],
        [1, 3, 2,  ..., 2, 2, 2]], device='cuda:0', grad_fn=<NotImplemented>)

In [116]:
trg[:,1:]

tensor([[1579,    2,    0,  ...,    0,    0,    0],
        [  88,  310,    2,  ...,    0,    0,    0],
        [2595,    2,    0,  ...,    0,    0,    0],
        ...,
        [5322,    2,    0,  ...,    0,    0,    0],
        [ 219, 1061, 1292,  ...,    0,    0,    0],
        [3275,    2,    0,  ...,    0,    0,    0]], device='cuda:0')

In [117]:
label = trg[:, 1:]
pred = prediction.cpu().numpy()
label = label.cpu().numpy()
    
for ii in range(prediction.shape[0]):
    l_ii = [x for x in label[ii, :] if x not in [0, 1, 2]]
    p_ii = [x for x in pred[ii, :] if x not in [0, 1, 2]]
    if mAP_k(l_ii, p_ii) > 0:
        print(l_ii, p_ii, mAP_k(l_ii, p_ii))

### Get the test data - last session

In [75]:
test_seqs = get_session_data_test(inp_file, prod_dict, inp_seq_len=inp_seq_len)
print(len(test_seqs), test_seqs.keys())

364695it [00:02, 169774.17it/s]


Read 48709 user interactions
10 dict_keys([0, 1, 2, 3, 4, 5, 6, 7, 8, 'prod'])


In [77]:
if include_meta:
    def generate_batch_test(data_batch):
        inp_batch = []
        for de_item in data_batch:
            n = de_item.shape[0]
            before = torch.unsqueeze(torch.tensor([BOS_IDX] * n), 1)
            after = torch.unsqueeze(torch.tensor([EOS_IDX] * n), 1)
            total = torch.cat([before, de_item, after], dim=1)
            total = total.permute(1, 0)
            inp_batch.append(total)

        inp_batch = pad_sequence(inp_batch, padding_value=PAD_IDX)
        return inp_batch
    
    test_data = data_process_meta(test_seqs, tokenizer, src_vocab, test_flag=True)
else:
    def generate_batch_test(data_batch):
        de_batch = []
        for de_item in data_batch:
            de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
        return de_batch
    test_data = data_process_no_meta(test_seqs, src_vocab, test_flag=True)
len(test_data)

48709

In [78]:
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE, shuffle=False, collate_fn=generate_batch_test)

In [79]:
model.eval()

epoch_loss = 0
all_maps = []
with torch.no_grad():

    for _, src in tqdm(enumerate(test_iterator)):

        src = src.to(device)
        # make the batch dimension first
        if src.dim() == 3:
            src = src.permute(1, 0, 2)
        else:
            src = src.permute(1, 0)

        output, _ = model(src, trg[:, :-1])
        prediction = torch.argmax(output, axis=-1)
        print(prediction.shape)
        sys.exit()


0it [00:00, ?it/s]


NameError: name 'trg' is not defined

In [81]:
src.shape

torch.Size([256, 14])

In [82]:
def translate_sentence(src, model, vocab, device, max_len = 50):
    """
    inp: (b, s, f), batch_size X sequence length X feature dimension
    """
    src = src.unsqueeze(0).to(device)  # add the batch dimension
    
    model.eval()
    if src.dim() == 3:
        src_mask = model.make_src_mask(src[:, :, 0])
    else:
        src_mask = model.make_src_mask(src)
    
    with torch.no_grad():
        if src.dim() == 3:
            enc_src = []
            for ii in range(model.num_features):
                enc_src.append(model.encoders[ii](src_tensor[:, :, ii], src_mask))
            enc_src = torch.cat(enc_src, dim=-1)
            enc_src = model.linear(enc_src)
        else:
            enc_src = model.encoder(src, src_mask)

    trg_indexes = [vocab.stoi['<bos>']]

    for i in range(max_len):
        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == vocab.stoi['<eos>']:
            break
    
    trg_tokens = [vocab.itos[i] for i in trg_indexes]
    return trg_tokens[1:], attention

In [87]:
for ii in range(src.shape[0]):
    pred, _ = translate_sentence(src[ii, :], model, src_vocab, device, max_len = tgt_seq_len)
    print(' '.join(src[ii, :]), ' '.join(pred[:-1]))

TypeError: sequence item 0: expected str instance, Tensor found

In [88]:
src[0,:]

tensor([2, 0, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], device='cuda:0')

In [94]:
pred, _ = translate_sentence(src[4, :, :], model, src_vocab, device, max_len = tgt_seq_len)
pred

['220', '<eos>']

# Only Product Sequence

In [7]:
src_vocab = build_vocab_from_file(file_path, tokenizer)
all_data = data_process(file_path, tokenizer, src_vocab)
train_data, val_data = train_test_split(all_data, test_size=0.2)
test_data = data_process(test_file_path, tokenizer, src_vocab, test_flag=True)
len(all_data), len(train_data), len(val_data), len(test_data)

(71460, 57168, 14292, 48709)

In [23]:
INPUT_DIM = len(src_vocab)
OUTPUT_DIM = len(src_vocab)
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM, 
              HID_DIM, 
              ENC_LAYERS, 
              ENC_HEADS, 
              ENC_PF_DIM, 
              ENC_DROPOUT, 
              device)

dec = Decoder(OUTPUT_DIM, 
              HID_DIM, 
              DEC_LAYERS, 
              DEC_HEADS, 
              DEC_PF_DIM, 
              DEC_DROPOUT, 
              device)

In [10]:
PAD_IDX = src_vocab['<pad>']
BOS_IDX = src_vocab['<bos>']
EOS_IDX = src_vocab['<eos>']

def generate_batch(data_batch):
    de_batch, en_batch = [], []
    for (de_item, en_item) in data_batch:
        de_batch.append(torch.cat([torch.tensor([BOS_IDX]), de_item, torch.tensor([EOS_IDX])], dim=0))
        en_batch.append(torch.cat([torch.tensor([BOS_IDX]), en_item, torch.tensor([EOS_IDX])], dim=0))
    de_batch = pad_sequence(de_batch, padding_value=PAD_IDX)
    en_batch = pad_sequence(en_batch, padding_value=PAD_IDX)
    return de_batch, en_batch

train_iterator = DataLoader(train_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
valid_iterator = DataLoader(val_data, batch_size=BATCH_SIZE,
                        shuffle=True, collate_fn=generate_batch)
test_iterator = DataLoader(test_data, batch_size=BATCH_SIZE,
                       shuffle=False, collate_fn=generate_batch)

In [24]:
SRC_PAD_IDX = 0
TRG_PAD_IDX = 0

model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device).to(device)

In [34]:
def predict(sentence, src_vocab, trg_vocab, model, device, max_len = tgt_seq_len):
    
    model.eval()

    tokens = tokenizer(sentence)
    tokens = [src_vocab['<bos>']] + tokens + [src_vocab['<eos>']]
    src_indexes = [src_vocab.stoi[token] for token in tokens]
    src_tensor = torch.LongTensor(src_indexes).unsqueeze(0).to(device)
    src_mask = model.make_src_mask(src_tensor)
    
    with torch.no_grad():
        enc_src = model.encoder(src_tensor, src_mask)

    trg_indexes = [trg_vocab.stoi['<bos>']]
    for i in range(max_len):

        trg_tensor = torch.LongTensor(trg_indexes).unsqueeze(0).to(device)
        trg_mask = model.make_trg_mask(trg_tensor)
        
        with torch.no_grad():
            output, attention = model.decoder(trg_tensor, enc_src, trg_mask, src_mask)
        
        pred_token = output.argmax(2)[:,-1].item()
        trg_indexes.append(pred_token)
        if pred_token == trg_vocab.stoi['<eos>']:
            break
    
    trg_tokens = [trg_vocab.itos[i] for i in trg_indexes]
    
    return trg_tokens[1:], attention

In [37]:
pred, _ = predict('13112 16042 3871 35', src_vocab, src_vocab, model, device)
pred

['1566', '<eos>']