## Imports

In [88]:
import torch
import torch.nn as nn
import tqdm
from nltk.translate.bleu_score import corpus_bleu
from torchtext.legacy.data import BucketIterator
from config import read_training_pipeline_params
from load_data import get_dataset, split_data, _len_sort_key
import my_network
from train_model import evaluate
from utils import generate_translation, get_text
import random
import numpy as np
from helpers import get_bleu
import torch.nn.functional as F
from torch import optim
from loguru import logger

In [57]:
import network_transformer

In [58]:
from nltk.translate.bleu_score import corpus_bleu

## load data

In [74]:
config = read_training_pipeline_params("train_config_pretrained_emb_transformer.yaml")

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

SRC, TRG, dataset = get_dataset(config.dataset_path, config.net_params.transformer)
train_data, valid_data, test_data = split_data(dataset, **config.split_ration.__dict__)
SRC.vocab = torch.load("small_transformer/src_vocab_transformer")
TRG.vocab = torch.load("small_transformer/trg_vocab_transformer")
train_iterator, valid_iterator, test_iterator = BucketIterator.splits(
    (train_data, valid_data, test_data),
    batch_size=64,
    device=device,
    sort_key=_len_sort_key
)

In [75]:
for batch in train_iterator:
    check_batch = batch

In [61]:
src = check_batch.src.permute(1, 0)
trg = check_batch.trg.permute(1, 0)

In [76]:
INPUT_DIM = len(SRC.vocab)
OUTPUT_DIM = len(TRG.vocab)

In [77]:
class Seq2Seq(nn.Module):
    def __init__(self,
                 encoder,
                 decoder,
                 src_pad_idx,
                 trg_pad_idx,
                 device):
        super().__init__()

        self.encoder = encoder
        self.decoder = decoder
        self.src_pad_idx = src_pad_idx
        self.trg_pad_idx = trg_pad_idx
        self.device = device

    def make_src_mask(self, src):

        src_mask = (src != self.src_pad_idx).unsqueeze(1).unsqueeze(2)

        return src_mask

    def make_trg_mask(self, trg):

        trg_pad_mask = (trg != self.trg_pad_idx).unsqueeze(1).unsqueeze(2)

        trg_len = trg.shape[1]

        trg_sub_mask = torch.tril(torch.ones((trg_len, trg_len), device=self.device)).bool()

        trg_mask = trg_pad_mask & trg_sub_mask

        return trg_mask

    def forward(self, src, trg):
        src_mask = self.make_src_mask(src)
        trg_mask = self.make_trg_mask(trg)
        enc_src = self.encoder(src, src_mask)
        output, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        return output, attention

    def translate(self, src, greedy=False, max_len=None, eps=1e-30, **flags):
        src_mask = self.make_src_mask(src)
        enc_src = self.encoder(src, src_mask)
        trg_mask = self.make_trg_mask(trg)
        logits, attention = self.decoder(trg, enc_src, trg_mask, src_mask)
        if greedy:
            output = logits.argmax(dim=-1)
        else:
            out_h = F.softmax(logits, dim=-1)
            output = torch.LongTensor([torch.multinomial(t, 1)[:, 0].detach().numpy() for t in out_h])
        return output, logits

In [78]:
Encoder = network_transformer.Encoder
Decoder = network_transformer.Decoder
# Seq2Seq = network_transformer.Seq2Seq
SRC_PAD_IDX = SRC.vocab.stoi[SRC.pad_token]
TRG_PAD_IDX = TRG.vocab.stoi[TRG.pad_token]
HID_DIM = 256
ENC_LAYERS = 3
DEC_LAYERS = 3
ENC_HEADS = 8
DEC_HEADS = 8
ENC_PF_DIM = 512
DEC_PF_DIM = 512
ENC_DROPOUT = 0.1
DEC_DROPOUT = 0.1

enc = Encoder(INPUT_DIM,
              HID_DIM,
              ENC_LAYERS,
              ENC_HEADS,
              ENC_PF_DIM,
              ENC_DROPOUT,
              device)

dec = Decoder(OUTPUT_DIM,
              HID_DIM,
              DEC_LAYERS,
              DEC_HEADS,
              DEC_PF_DIM,
              DEC_DROPOUT,
              device)
model = Seq2Seq(enc, dec, SRC_PAD_IDX, TRG_PAD_IDX, device)

In [79]:
checkpoint = torch.load("models/transformer_model.pt", map_location='cpu')
model.load_state_dict(checkpoint, strict=True)

<All keys matched successfully>

In [80]:
translate, _ = model.translate(src, greedy=True)

In [81]:
translate_n, logits = model.translate(src, greedy=False)

In [82]:
def to_one_hot(y, n_dims=None):
    """ Take integer y (tensor or variable) with n dims and convert it to 1-hot representation with n+1 dims. """
    y_tensor = y.data
    y_tensor = y_tensor.to(dtype=torch.long).view(-1, 1)
    n_dims = n_dims if n_dims is not None else int(torch.max(y_tensor)) + 1
    y_one_hot = torch.zeros(y_tensor.size()[0], n_dims, device=y.device).scatter_(1, y_tensor, 1)
    y_one_hot = y_one_hot.view(*y.shape, -1)
    return y_one_hot

In [83]:
def scst_objective_on_batch(src, trg):
    translate, _ = model.translate(src, greedy=True)
    translate_n, logits = model.translate(src, greedy=False)
    baseline = []
    reward = []
    mask = []
    i=0
    for g, n, t in zip(translate, translate_n, trg):
        g_list = g.tolist()
        n_list = n.tolist()
        t_list = t.tolist()
        if 3 in g_list:
            g_ind = g_list.index(3)
        else:
            g_ind = len(g_list)
        if 3 in n_list:
            n_ind = n_list.index(3)
        else:
            n_ind = len(n_list)
        if 1 in t_list:
            t_ind = t_list.index(1)
        else:
            t_ind = len(t_list)
    #     greedy_tokens.append([tok for tok in g_list[: g_ind] if tok not in [3, 1, 2, 0]])
    #     not_greedy_tokens.append([tok for tok in n_list[: n_ind] if tok not in [3, 1, 2, 0]])
    #     target_tokens.append([tok for tok in t_list[: t_ind] if tok not in [3, 1, 2, 0]])
        mask.append([True] *n_ind + [False] * (g.shape[0] - n_ind))
        greed_txt = [tok for tok in g_list[: g_ind] if tok not in [3, 1, 2, 0]]
        not_greed_txt = [tok for tok in n_list[: n_ind] if tok not in [3, 1, 2, 0]]
        trg_txt = [tok for tok in t_list[: t_ind] if tok not in [3, 1, 2, 0]]
        baseline.append(corpus_bleu([[trg_txt]], [greed_txt]))
        reward.append(corpus_bleu([[trg_txt]], [not_greed_txt]))
    advantage = (torch.FloatTensor(reward) - torch.FloatTensor(baseline)).to(device)
    logp_sample = torch.sum(to_one_hot(translate_n, n_dims=OUTPUT_DIM) * logits, dim=-1)
    J = logp_sample * torch.FloatTensor(advantage)[:, None]
    mask_t = torch.Tensor(mask)
    loss = torch.sum(J * mask_t) / torch.sum(mask_t)
    entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)
    reg = - 1e-10 * torch.sum(entropy * mask_t) / torch.sum(mask_t)
    return loss + reg, torch.sum(entropy * mask_t) / torch.sum(mask_t)

In [84]:
l,e = scst_objective_on_batch(src, trg)

In [85]:
optimizer = optim.Adam(model.parameters(), config.lr)

In [86]:
def train(model, iterator, optimizer):
    print("here1")
    model.train()
    epoch_loss = 0
    for i, batch in enumerate(iterator):

        src = batch.src.permute(1, 0)
        trg = batch.trg.permute(1, 0)

        optimizer.zero_grad()

#         output, _ = model(src, trg[:, :-1])

        # output = [batch size, trg len - 1, output dim]
        # trg = [batch size, trg len]

#         output_dim = output.shape[-1]

#         output = output.contiguous().view(-1, output_dim)
#         trg = trg[:, 1:].contiguous().view(-1)
        loss, entropy = scst_objective_on_batch(src, trg)
        # output = [batch size * trg len - 1, output dim]
        # trg = [batch size * trg len - 1]

#         loss = criterion(output, trg)

        loss.backward()
        logger.info(loss)
#         torch.nn.utils.clip_grad_norm_(model.parameters(), clip)

        optimizer.step()

        epoch_loss += loss.item()

    return epoch_loss / len(iterator)

In [87]:
train(model, train_iterator, optimizer)

here1
tensor(-0.0024, grad_fn=<AddBackward0>)
tensor(-0.0532, grad_fn=<AddBackward0>)
tensor(0.0284, grad_fn=<AddBackward0>)
tensor(0.0308, grad_fn=<AddBackward0>)
tensor(-0.0125, grad_fn=<AddBackward0>)
tensor(0.0074, grad_fn=<AddBackward0>)
tensor(0.0310, grad_fn=<AddBackward0>)
tensor(-0.0185, grad_fn=<AddBackward0>)
tensor(0.0269, grad_fn=<AddBackward0>)
tensor(-0.0093, grad_fn=<AddBackward0>)
tensor(0.0122, grad_fn=<AddBackward0>)
tensor(-0.0458, grad_fn=<AddBackward0>)
tensor(-0.0751, grad_fn=<AddBackward0>)
tensor(-0.0282, grad_fn=<AddBackward0>)
tensor(-0.0495, grad_fn=<AddBackward0>)
tensor(-0.0734, grad_fn=<AddBackward0>)
tensor(0.0303, grad_fn=<AddBackward0>)
tensor(-0.0085, grad_fn=<AddBackward0>)
tensor(0.0026, grad_fn=<AddBackward0>)
tensor(0.0039, grad_fn=<AddBackward0>)
tensor(-0.0457, grad_fn=<AddBackward0>)
tensor(0.0263, grad_fn=<AddBackward0>)
tensor(-0.0745, grad_fn=<AddBackward0>)
tensor(0.0017, grad_fn=<AddBackward0>)
tensor(-0.0584, grad_fn=<AddBackward0>)
tenso

KeyboardInterrupt: 

In [78]:
# baseline = corpus_bleu([[text] for text in target_tokens], greedy_tokens)
# rewards = corpus_bleu([[text] for text in target_tokens], not_greedy_tokens)

In [103]:
advantage = (torch.FloatTensor(rewards) - torch.FloatTensor(baseline)).to(device)

In [104]:
advantage

tensor([ 0.0000,  0.1897, -0.1767, -0.0450])

In [86]:
len(mask[0]), trg.shape

(22, torch.Size([4, 22]))

In [69]:
logp_sample = torch.sum(to_one_hot(translate_n, n_dims=OUTPUT_DIM) * logits, dim=-1)

In [72]:
J = logp_sample * torch.FloatTensor(advantage)[:, None]

In [73]:
J.shape

torch.Size([4, 22])

In [87]:
mask_t = torch.Tensor(mask)

In [90]:
torch.sum(J * mask_t) / torch.sum(mask_t)

tensor(-0.5534, grad_fn=<DivBackward0>)

In [96]:
entropy = -torch.sum(torch.exp(logits) * logits, dim=-1)

In [97]:
entropy

tensor([[-6.0753e+07, -8.1388e+08, -1.2782e+09, -1.0742e+09, -5.7912e+08,
         -1.9203e+10, -1.5264e+09, -1.2574e+10, -5.1099e+10, -1.0268e+09,
         -3.8313e+08, -4.4544e+09, -1.4655e+06, -1.4429e+06, -1.7198e+06,
         -9.4949e+05, -1.0066e+06, -2.4669e+06, -1.0126e+06, -1.7618e+06,
         -1.3701e+06, -1.3021e+06],
        [-9.6211e+06, -1.0642e+09, -1.0032e+09, -8.5795e+08, -5.3567e+07,
         -3.1642e+07, -3.1013e+07, -5.0274e+07, -8.1468e+07, -2.4153e+07,
         -5.5048e+07, -6.9020e+04, -5.2866e+05, -1.9378e+05, -6.4577e+05,
         -2.4502e+05, -5.5868e+05, -2.1592e+05, -2.4842e+05, -3.1824e+05,
         -2.5404e+05, -3.8890e+05],
        [-1.3160e+08, -5.0753e+08, -1.3021e+10, -7.3488e+07, -3.1621e+08,
         -9.1614e+08, -2.5509e+09, -3.5756e+09, -1.8619e+08, -4.5795e+07,
         -1.4923e+09, -3.8585e+09, -3.1395e+08, -2.8090e+08, -1.7708e+08,
         -5.5537e+07, -9.0597e+09, -4.3076e+06, -3.0087e+06, -3.3406e+06,
         -1.7003e+06, -8.2921e+06],
    

In [98]:
reg = - 0.01 * torch.sum(entropy * mask_t) / torch.sum(mask_t)

In [99]:
reg

tensor(36425564., grad_fn=<DivBackward0>)

In [48]:
logits.shape, to_one_hot(translate_n, n_dims=OUTPUT_DIM).shape

(torch.Size([4, 22, 6734]), torch.Size([4, 22, 6734]))

## end

In [16]:
corpus_bleu([[translate[1].tolist()]], [trg[1].tolist()])

0.161692143534558

In [17]:
texts_greedy = [get_text(sentence, TRG.vocab) for sentence in translate]

In [18]:
texts_not_greedy = [get_text(sentence, TRG.vocab) for sentence in translate_n]

In [19]:
texts_original = [get_text(sentence, TRG.vocab) for sentence in trg]

In [24]:
texts_original[1], texts_greedy[1]

(['<sos>',
  'there',
  'is',
  'also',
  'a',
  'buffet',
  'restaurant',
  'and',
  'a',
  'bar',
  '.'],
 ['guests',
  'is',
  'also',
  'a',
  'restaurant',
  'restaurant',
  ',',
  'a',
  'bar',
  '.'])

In [21]:
corpus_bleu([[text] for text in texts_original], texts_greedy)

0.5275388348709146

In [22]:
corpus_bleu([[text] for text in texts_original], texts_not_greedy)

0.4778895598771266

In [40]:
references = [[[1, 2, 3, 4]]]
candidates = [[1, 2, 3, 4]]
score = corpus_bleu(references, candidates)

In [41]:
score

1.0

In [24]:
corpus_bleu([[texts_not_greedy]], [texts_original])

0.4699739598002697

In [25]:
score

1.0

In [52]:
from nltk.translate.bleu_score import corpus_bleu,sentence_bleu

In [36]:
output, _ = model(src, trg[:, :-1])

# output = [batch size, trg len - 1, output dim]
# trg = [batch size, trg len]

output_dim = output.shape[-1]

output = output.contiguous().view(-1, output_dim)
trg = trg[:, 1:].contiguous().view(-1)

In [37]:
F.cross_entropy(output, trg)

tensor(11.8226, grad_fn=<NllLossBackward>)

In [None]:
o, _ = model(src, trg[:, :-1])

In [None]:
o[0].shape

In [30]:
o = o.permute(1, 0, 2)

In [31]:
o.argmax(dim=-1).permute(1, 0).shape

torch.Size([46, 128])

In [32]:
sample = o.argmax(dim=-1).permute(1, 0)

In [33]:
sample[:].shape

torch.Size([46, 128])

In [19]:
sample1 = sample[0]

In [20]:
trg[0]

tensor([  2,  24, 127,  11,   5, 128,  37,   4,   3,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,   1,
          1,   1,   1,   1,   1])

In [21]:
get_text(sample1, TRG.vocab)

['rooms', 'have', 'with', 'a', 'shared', 'bathroom', '.']

In [23]:
model.translate(sample)

tensor([[[-0.4832,  0.3313, -0.3978,  ..., -0.7825,  0.2135, -0.0957],
         [-0.1670,  0.1979, -0.5312,  ..., -0.4924, -0.1184, -0.5667],
         [-0.6390,  0.5016, -0.3676,  ..., -0.1893, -0.2710, -0.3325],
         ...,
         [-0.3731, -0.1327, -0.2330,  ..., -0.8845,  0.2529, -0.1635],
         [ 0.0035,  0.6701,  0.2840,  ..., -0.6974, -0.1369, -0.3067],
         [-0.1932, -0.1651,  0.3286,  ..., -0.4547, -0.2638, -0.1974]],

        [[ 0.4543,  0.3039, -0.0067,  ...,  0.0411, -0.2525, -0.4166],
         [ 0.3526,  0.6596, -0.1921,  ..., -0.4458, -0.2560, -0.1452],
         [ 0.5380,  0.2005,  0.2439,  ..., -0.1337, -0.2762,  0.0474],
         ...,
         [ 0.1708,  0.1671,  0.2324,  ..., -0.2832, -0.3041, -0.6673],
         [ 0.2035, -0.0374,  0.2155,  ..., -0.1811, -0.3452, -0.6700],
         [ 0.1592,  0.3354,  0.0645,  ..., -0.1179, -0.4021, -0.6699]],

        [[-0.0418, -0.1311,  0.0444,  ..., -0.1227,  0.0559, -0.1906],
         [-0.2211,  0.5425, -0.4063,  ...,  0

## end