## Imports

In [9]:
import torch
import torch.nn as nn
import tqdm
from nltk.translate.bleu_score import corpus_bleu
from torchtext.legacy.data import BucketIterator
from config import read_training_pipeline_params
from load_data import get_dataset, split_data, _len_sort_key
import my_network
from train_model import evaluate
from utils import generate_translation, get_text
import random
import numpy as np
from helpers import get_bleu

In [4]:
import network_transformer

## load data

In [7]:
config = read_training_pipeline_params("train_config_pretrained_emb_transformer.yaml")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SRC, TRG, dataset = get_dataset(config.dataset_path, config.net_params.transformer)
train_data, valid_data, test_data = split_data(dataset, **config.split_ration.__dict__)
SRC.vocab = torch.load("small_transformer/src_vocab_transformer")
TRG.vocab = torch.load("small_transformer/trg_vocab_transformer")
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=config.BATCH_SIZE,
    device=device,
    sort_key=_len_sort_key
)

In [11]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

In [65]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_trg_mask(self, trg):

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()

        trg_mask = trg_pad_mask & trg_sub_mask

        return trg_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

    def translate(self, src, greedy=False, max_len=None, eps=1e-30, **flags):
#         device = next(self.parameters()).device
#         batch_size = inp.shape[0]
#         bos = torch.tensor([self.out_voc.bos_ix] * batch_size, dtype=torch.long, device=device)
#         mask = torch.ones(batch_size, dtype=torch.uint8, device=device)
#         logits_seq = [torch.log(to_one_hot(bos, len(self.out_voc)) + eps)]
#         out_seq = [bos]

#         hid_state = self.encode(inp, **flags)
        src_mask = self.make_src_mask(src)
        enc_src = self.encoder(src, src_mask)
        return enc_src
#         while True:
#             hid_state, logits = self.decode(hid_state, out_seq[-1], **flags)
#             if greedy:
#                 _, y_t = torch.max(logits, dim=-1)
#             else:
#                 probs = F.softmax(logits, dim=-1)
#                 y_t = torch.multinomial(probs, 1)[:, 0]

#             logits_seq.append(logits)
#             out_seq.append(y_t)
#             mask *= y_t != self.out_voc.eos_ix

#             if not mask.any(): break
#             if max_len and len(out_seq) >= max_len: break

#         return torch.stack(out_seq, 1), F.log_softmax(torch.stack(logits_seq, 1), dim=-1)

In [66]:
Encoder = network_transformer.Encoder
Decoder = network_transformer.Decoder
# Seq2Seq = network_transformer.Seq2Seq
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM,
              HID_DIM,
              ENC_LAYERS,
              ENC_HEADS,
              ENC_PF_DIM,
              ENC_DROPOUT,
              device)

dec = Decoder(OUTPUT_DIM,
              HID_DIM,
              DEC_LAYERS,
              DEC_HEADS,
              DEC_PF_DIM,
              DEC_DROPOUT,
              device)
model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device)

In [67]:
checkpoint = torch.load("models/transformer_model.pt", map_location='cpu')
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [68]:
for batch in train_iterator:
    check_batch = batch

In [69]:
src = check_batch.src.permute(1, 0)
trg = check_batch.trg.permute(1, 0)

In [71]:
enc_src = model.translate(src)

In [72]:
enc_src.shape

torch.Size([128, 49, 256])

In [73]:
src.shape

torch.Size([128, 49])

In [75]:
o, _ = model(src, trg[:, :-1])

In [77]:
o[0].shape

torch.Size([42, 6734])

In [80]:
o = o.permute(1, 0, 2)

In [85]:
o.argmax(dim=-1).permute(1, 0).shape

torch.Size([128, 42])

In [91]:
sample = o.argmax(dim=-1).permute(1, 0)

In [93]:
sample[:].shape

torch.Size([128, 42])

In [94]:
sample1 = sample[0]

In [89]:
trg[0]

tensor([   2,    0,  100,   48,    9,  526,   14,   20,    7, 5375,   88,   28,
           9,  354,   97,  453,   43,   39,   20,    4,    3,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,    1,
           1,    1,    1,    1,    1,    1,    1])

In [95]:
get_text(sample1, TRG.vocab)

['the',
 'city',
 'centre',
 'is',
 '17',
 'km',
 'away',
 'and',
 'bogotá',
 'international',
 'airport',
 'is',
 'about',
 '3',
 'hours',
 '’',
 'drive',
 'away',
 '.']

## end