In [1]:
import pickle
from preprocess import Lang
import numpy as np


In [2]:
# Load the preprocessed corpus from disk.
# NOTE(review): pickle.load can execute arbitrary code contained in the file, so
# only open pickles that were produced by a trusted preprocessing run.
with open('result_0326.pkl', 'rb') as f:
    data = pickle.load(f)
    
# Vocabulary sizes of the two language objects stored in the pickle (English, then French).
print(data['eng_lang'].n_words, data['fra_lang'].n_words)
# data['token'] holds the pairs of token-id sequences.
filter_tokenize = data['token']

# Transpose the list of pairs into two parallel tuples. The first element of each
# pair is treated as French and the second as English — presumably matching the
# order used by the preprocessing step; verify there.
fra_token, eng_token = zip(*filter_tokenize)
print(len(fra_token))

1790 2097
135842


In [12]:
data['token'][:3]

[[[2, 4, 5, 3], [2, 4, 5, 3]],
 [[2, 6, 5, 3], [2, 6, 7, 3]],
 [[2, 1, 5, 3], [2, 6, 7, 3]]]

In [13]:
len(data['token'])

135842

In [16]:
# Sequence lengths of the first (p[0]) and second (p[1]) member of every pair.
len1 = [len(p[0]) for p in data['token']]
len2 = [len(p[1]) for p in data['token']]
# Mean / std / max of both length distributions; the two maxima are the numbers
# the hard-coded max_token_seq_len in Args is derived from.
np.mean(len1), np.std(len1), np.max(len1), np.mean(len2), np.std(len2), np.max(len2)

(10.312686797897557,
 3.160561893345268,
 65,
 9.58004151882334,
 2.6925437290397505,
 54)

In [4]:
# args = {'n_src_vocab': data['fra_lang'].n_words, 
#         'n_tgt_vocab': data['eng_lang'].n_words, 
#         'len_max_seq': 65, # 65 and 54
#         'd_word_vec': 512,
#         'd_model': 512,
#         'd_inner': 2048,
#         'n_layers': 6, 
#         'n_head': 8, 
#         'd_k': 64, 
#         'd_v': 64, 
#         'dropout': 0.1,
#         'tgt_emb_prj_weight_sharing': True,
#         'emb_src_tgt_weight_sharing': False}

class Args:
    """Hyper-parameter container used to build the Transformer.

    Args:
        data: mapping with the keys 'fra_lang' and 'eng_lang'; each value must
            expose an ``n_words`` attribute (the vocabulary size).
        max_token_seq_len: longest sequence the model has to handle. Defaults to
            65, the longest sequence measured on this corpus; pass another value
            when the notebook is reused with a different dataset.
    """

    def __init__(self, data, max_token_seq_len=65):
        # French is the source language, English the target.
        self.src_vocab_size = data['fra_lang'].n_words
        self.tgt_vocab_size = data['eng_lang'].n_words
        self.max_token_seq_len = max_token_seq_len
        self.d_word_vec = 512
        self.d_model = 512
        self.d_inner_hid = 2048
        self.n_layers = 6
        self.n_head = 8
        self.d_k = 64
        self.d_v = 64
        self.dropout = 0.1
        # Passed to Transformer as tgt_emb_prj_weight_sharing.
        self.proj_share_weight = True
        # Passed to Transformer as emb_src_tgt_weight_sharing.
        self.embs_share_weight = False

args = Args(data)
args.src_vocab_size

2097

In [13]:
import Constants
import torch
from torch.utils.data import Dataset

def collate_fn(insts):
    """Pad a batch of token-id lists to a common length and build position ids.

    Every sequence is right-padded with Constants.PAD up to the longest sequence
    in the batch. Positions are counted from 1; any slot whose token equals
    Constants.PAD gets position 0.

    Returns:
        (batch_seq, batch_pos): two LongTensors of identical shape.
    """
    longest = max(len(seq) for seq in insts)

    padded_rows = []
    for seq in insts:
        padded_rows.append(seq + [Constants.PAD] * (longest - len(seq)))
    batch_seq = np.array(padded_rows)

    position_rows = []
    for row in batch_seq:
        position_rows.append(
            [0 if token == Constants.PAD else index + 1 for index, token in enumerate(row)])
    batch_pos = np.array(position_rows)

    return torch.LongTensor(batch_seq), torch.LongTensor(batch_pos)

def paired_collate_fn(insts):
    """Collate a batch of (source, target) pairs.

    The pairs are split into a source batch and a target batch, each is padded
    by collate_fn, and the results are concatenated into one flat tuple:
    source outputs first, then target outputs.
    """
    sources, targets = zip(*insts)
    padded_sources = collate_fn(sources)
    padded_targets = collate_fn(targets)
    return tuple(padded_sources) + tuple(padded_targets)

class Fra2EngDatasets(Dataset):
    """Parallel dataset that yields (source, target) pairs by index.

    Args:
        src: indexable collection of source sequences.
        tgt: indexable collection of target sequences, aligned with ``src``.

    Raises:
        ValueError: if ``src`` and ``tgt`` do not have the same number of items.
    """

    def __init__(self, src, tgt):
        # The reported length comes from src alone, so a shorter tgt would only
        # fail with an IndexError inside the DataLoader, and a longer one would
        # be truncated silently. Fail early instead.
        if len(src) != len(tgt):
            raise ValueError(
                'src and tgt must have the same length, got %d and %d' % (len(src), len(tgt)))
        self.src = src
        self.tgt = tgt

    def __len__(self):
        return len(self.src)

    def __getitem__(self, idx):
        return self.src[idx], self.tgt[idx]

In [14]:
# filter_tokenize = data['token']
# Re-split the token pairs into parallel French / English tuples.
fra_token, eng_token = zip(*filter_tokenize)

# Small smoke-test loader: only the first 100 pairs, batches of 4, shuffled, with
# the last incomplete batch dropped. paired_collate_fn pads each batch and adds
# the position tensors.
train_loader = torch.utils.data.DataLoader(
                    Fra2EngDatasets(fra_token[:100], eng_token[:100]),
                    num_workers = 1,
                    batch_size = 4,
                    collate_fn = paired_collate_fn,
                    shuffle = True,
                    drop_last = True)

# Look at a single batch and stop; the loop leaves `batch` bound for the next cells.
for batch in train_loader:
    print(batch)
    break

(tensor([[ 2, 77,  1,  5,  3],
        [ 2, 75, 79, 14,  3],
        [ 2,  1,  5,  3,  0],
        [ 2, 77,  1,  5,  3]]), tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 0],
        [1, 2, 3, 4, 5]]), tensor([[ 2, 40, 43,  5,  3],
        [ 2, 40, 44,  5,  3],
        [ 2, 18,  7,  3,  0],
        [ 2, 40, 45,  5,  3]]), tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 0],
        [1, 2, 3, 4, 5]]))


In [15]:
print(batch[0])
print(batch[1])

tensor([[ 2, 77,  1,  5,  3],
        [ 2, 75, 79, 14,  3],
        [ 2,  1,  5,  3,  0],
        [ 2, 77,  1,  5,  3]])
tensor([[1, 2, 3, 4, 5],
        [1, 2, 3, 4, 5],
        [1, 2, 3, 4, 0],
        [1, 2, 3, 4, 5]])


In [5]:
from model import Transformer
transformer = Transformer(
        args.src_vocab_size,
        args.tgt_vocab_size,
        args.max_token_seq_len,
        tgt_emb_prj_weight_sharing = args.proj_share_weight,
        emb_src_tgt_weight_sharing = args.embs_share_weight,
        d_k = args.d_k,
        d_v = args.d_v,
        d_model = args.d_model,
        d_word_vec = args.d_word_vec,
        d_inner = args.d_inner_hid,
        n_layers = args.n_layers,
        n_head = args.n_head,
        dropout = args.dropout)

In [16]:
# Switch the model to training mode before the forward pass.
transformer.train()

# Run one batch through the model and stop; `pred`, `tgt_seq` and the other
# unpacked tensors stay bound and are reused by the accuracy cells below.
for batch in train_loader:
    # Order matches the tuple produced by paired_collate_fn.
    src_seq, src_pos, tgt_seq, tgt_pos = batch
    # NOTE(review): the recorded output is 16 x 1790 for a batch of 4, so the model
    # presumably flattens to (batch * (tgt_len - 1)) x tgt_vocab — confirm in model.py.
    pred = transformer(src_seq, src_pos, tgt_seq, tgt_pos)
    print(pred)
    break

tensor([[ 0.0000, -0.7139,  1.5896,  ...,  0.4070,  0.7291, -0.3694],
        [ 0.0000, -0.7342,  2.0671,  ...,  0.8151, -0.1122, -0.7192],
        [ 0.0000, -0.5631,  1.4950,  ...,  0.7610,  0.6359, -1.0212],
        ...,
        [ 0.0000,  0.8169,  1.8479,  ..., -0.2909, -0.4258, -1.3035],
        [ 0.0000,  0.0144,  0.6998,  ..., -0.2520,  0.2831, -0.8460],
        [ 0.0000,  0.2966,  2.0965,  ...,  0.6988, -0.0550, -1.1140]],
       grad_fn=<ViewBackward>)


In [17]:
pred.size()

torch.Size([16, 1790])

In [29]:
pred = pred.max(1)
pred

(tensor([2.9606, 2.7730, 3.1195, 2.9669, 3.3332, 2.7936, 2.8986, 3.2726, 3.0928,
         2.9018, 2.9004, 3.0930, 2.9547, 3.3579, 3.3206, 3.1634],
        grad_fn=<MaxBackward0>),
 tensor([1441,  470,  120,  216,  514,  428, 1040,  514, 1399,  466, 1399,  757,
          994,  713,  495,    3]))

In [30]:
pred = pred[1]
pred

tensor([1441,  470,  120,  216,  514,  428, 1040,  514, 1399,  466, 1399,  757,
         994,  713,  495,    3])

In [31]:
gold = tgt_seq[:, 1:]
gold.size()

torch.Size([4, 4])

In [32]:
gold = gold.contiguous().view(-1)
# tensor([[17, 33,  7,  3],
#         [18,  7,  3,  0],
#         [40, 41,  5,  3],
#         [49, 50,  7,  3]])

In [33]:
non_pad_mask = gold.ne(Constants.PAD)
gold, non_pad_mask

(tensor([17, 33,  7,  3, 18,  7,  3,  0, 40, 41,  5,  3, 49, 50,  7,  3]),
 tensor([1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8))

In [34]:
n_correct = pred.eq(gold)
n_correct

tensor([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], dtype=torch.uint8)

In [35]:
n_correct = n_correct.masked_select(non_pad_mask).sum().item()
n_correct

1

In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F

In [4]:
n_bm = 3
src_seq = torch.randint(0, 100, (4, 10))
src_seq

tensor([[ 4, 41, 82, 29, 24, 24, 15, 62, 55, 45],
        [30, 84, 35, 60, 54, 94, 45, 66, 84, 55],
        [40, 40, 82, 65, 73, 66, 19, 92, 47, 22],
        [85, 19, 11, 40, 57, 58, 78, 16, 92, 97]])

In [5]:
src_seq = src_seq.repeat(1, n_bm).view(4 *3, 10)
src_seq

tensor([[ 4, 41, 82, 29, 24, 24, 15, 62, 55, 45],
        [ 4, 41, 82, 29, 24, 24, 15, 62, 55, 45],
        [ 4, 41, 82, 29, 24, 24, 15, 62, 55, 45],
        [30, 84, 35, 60, 54, 94, 45, 66, 84, 55],
        [30, 84, 35, 60, 54, 94, 45, 66, 84, 55],
        [30, 84, 35, 60, 54, 94, 45, 66, 84, 55],
        [40, 40, 82, 65, 73, 66, 19, 92, 47, 22],
        [40, 40, 82, 65, 73, 66, 19, 92, 47, 22],
        [40, 40, 82, 65, 73, 66, 19, 92, 47, 22],
        [85, 19, 11, 40, 57, 58, 78, 16, 92, 97],
        [85, 19, 11, 40, 57, 58, 78, 16, 92, 97],
        [85, 19, 11, 40, 57, 58, 78, 16, 92, 97]])

In [8]:
src_enc = torch.randn(4,10,20)
src_enc[:, :, :2]

tensor([[[ 0.7888,  0.5766],
         [-0.5495,  0.1515],
         [-1.4031, -0.8542],
         [ 0.5637,  0.7123],
         [-0.4461,  1.6773],
         [ 0.5510,  0.2058],
         [ 0.2617,  0.9290],
         [-0.1693,  0.3658],
         [ 0.9981, -0.7347],
         [ 0.7736, -2.5646]],

        [[-0.2205,  0.1890],
         [-0.0167, -0.0480],
         [-1.2016,  1.2185],
         [ 0.2475, -0.8693],
         [-1.8263,  1.3872],
         [ 0.8244,  0.9413],
         [-0.2949,  0.8623],
         [ 0.5073, -1.4607],
         [ 1.0888, -0.7030],
         [-0.0861, -0.5368]],

        [[ 1.0727, -0.4317],
         [ 0.2789, -1.3457],
         [-0.4172,  0.9617],
         [ 0.9100, -0.0614],
         [-0.2955,  0.8950],
         [-1.2553,  0.1769],
         [-1.7727,  0.8991],
         [-0.3868,  0.5926],
         [-0.7110,  0.3795],
         [ 1.1802, -0.9833]],

        [[ 0.3980, -0.4244],
         [-0.2083, -0.3469],
         [ 0.0430,  0.3725],
         [ 0.3185,  1.1351],
        

In [9]:
d_h = 20
len_s = 10
n_inst = 4
src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h)
src_enc.size()

torch.Size([12, 10, 20])

In [11]:
src_enc[:6, :3, :3]

tensor([[[ 0.7888,  0.5766,  2.7153],
         [-0.5495,  0.1515,  0.1876],
         [-1.4031, -0.8542, -0.2365]],

        [[ 0.7888,  0.5766,  2.7153],
         [-0.5495,  0.1515,  0.1876],
         [-1.4031, -0.8542, -0.2365]],

        [[ 0.7888,  0.5766,  2.7153],
         [-0.5495,  0.1515,  0.1876],
         [-1.4031, -0.8542, -0.2365]],

        [[-0.2205,  0.1890,  0.0079],
         [-0.0167, -0.0480,  0.1409],
         [-1.2016,  1.2185,  0.1445]],

        [[-0.2205,  0.1890,  0.0079],
         [-0.0167, -0.0480,  0.1409],
         [-1.2016,  1.2185,  0.1445]],

        [[-0.2205,  0.1890,  0.0079],
         [-0.0167, -0.0480,  0.1409],
         [-1.2016,  1.2185,  0.1445]]])

In [15]:
import numpy as np
import Constants 
from beam import Beam

In [16]:
# Prefer the second GPU, as before, but only when it actually exists. Asking for
# 'cuda:1' whenever *any* GPU is available breaks on single-GPU machines as soon
# as a tensor is created on that device.
if torch.cuda.device_count() > 1:
    device = torch.device('cuda:1')
elif torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

# One beam of width n_bm per instance to translate.
inst_dec_beams = [Beam(n_bm, device=device) for _ in range(n_inst)]

In [26]:
len(inst_dec_beams)

4

In [18]:
inst_dec_beams[0].next_ys

[tensor([2, 0, 0], device='cuda:1')]

In [19]:
active_inst_idx_list = list(range(n_inst))
active_inst_idx_list

[0, 1, 2, 3]

In [20]:
def get_inst_idx_to_tensor_position_map(inst_idx_list):
    ''' Map every instance index to the row it occupies in the batched tensors.

    The position is simply the index of the instance inside inst_idx_list.
    '''
    position_of = {}
    for row, inst_idx in enumerate(inst_idx_list):
        position_of[inst_idx] = row
    return position_of

inst_idx_to_position_map = get_inst_idx_to_tensor_position_map(active_inst_idx_list)
inst_idx_to_position_map

{0: 0, 1: 1, 2: 2, 3: 3}

In [22]:
#-- Decode

def beam_decode_step(inst_dec_beams, len_dec_seq, src_seq, enc_output, inst_idx_to_position_map, n_bm,
                     model=None, device=None):
    ''' Run one beam-search decoding step and return the instances that are still active.

    This code still referred to ``self.model`` and ``self.device``, which do not
    exist in a free function, so the first call raised NameError. The model and
    the device are now taken as optional keyword arguments.

    Args:
        inst_dec_beams: one beam object per instance; each must provide ``done``,
            ``get_current_state()`` and ``advance(word_prob)``.
        len_dec_seq: number of target tokens decoded so far (the width of the
            partial target sequences).
        src_seq: source token ids, handed to the decoder unchanged.
        enc_output: encoder output, handed to the decoder unchanged.
        inst_idx_to_position_map: dict mapping an instance index to the row
            position of that instance in the batched tensors.
        n_bm: beam size.
        model: object exposing ``decoder(dec_seq, dec_pos, src_seq, enc_output)``
            and ``tgt_word_prj(dec_output)``. Required.
        device: device for the decoder inputs; defaults to ``enc_output.device``.

    Returns:
        List of the instance indices whose beam did not report completion.

    Raises:
        ValueError: if ``model`` is not supplied.
    '''
    if model is None:
        raise ValueError('beam_decode_step requires the trained Transformer via the `model` argument')
    if device is None:
        device = enc_output.device

    def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
        # Stack the partial hypotheses of every unfinished beam and flatten them
        # to one row per hypothesis.
        dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
        dec_partial_seq = torch.stack(dec_partial_seq).to(device)
        dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
        return dec_partial_seq

    def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
        # Positions 1..len_dec_seq, repeated for every hypothesis row.
        dec_partial_pos = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=device)
        dec_partial_pos = dec_partial_pos.unsqueeze(0).repeat(n_active_inst * n_bm, 1)
        return dec_partial_pos

    def predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm):
        dec_output, *_ = model.decoder(dec_seq, dec_pos, src_seq, enc_output)
        dec_output = dec_output[:, -1, :]  # Pick the last step: (bh * bm) * d_h
        word_prob = F.log_softmax(model.tgt_word_prj(dec_output), dim=1)
        # Regroup the rows per instance: n_active_inst x n_bm x vocabulary.
        word_prob = word_prob.view(n_active_inst, n_bm, -1)

        return word_prob

    def collect_active_inst_idx_list(inst_beams, word_prob, inst_idx_to_position_map):
        # advance() updates the beam in place and reports whether it is complete.
        active_inst_idx_list = []
        for inst_idx, inst_position in inst_idx_to_position_map.items():
            is_inst_complete = inst_beams[inst_idx].advance(word_prob[inst_position])
            if not is_inst_complete:
                active_inst_idx_list += [inst_idx]

        return active_inst_idx_list

    n_active_inst = len(inst_idx_to_position_map)

    dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
    dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
    word_prob = predict_word(dec_seq, dec_pos, src_seq, enc_output, n_active_inst, n_bm)

    # Update the beam with predicted word prob information and collect incomplete instances
    active_inst_idx_list = collect_active_inst_idx_list(
        inst_dec_beams, word_prob, inst_idx_to_position_map)

    return active_inst_idx_list

max_token_seq_len = 65
# for len_dec_seq in range(1, max_token_seq_len + 1):
#     print(len_dec_seq)
#     active_inst_idx_list = beam_decode_step(
#                     inst_dec_beams, len_dec_seq, src_seq, src_enc, inst_idx_to_position_map, n_bm)
#     break

In [24]:
n_active_inst = len(inst_idx_to_position_map)

# The decoding loop that would normally bind len_dec_seq is commented out, so the
# name is undefined on a fresh kernel. Pin it to the first decoding step, which
# is what the recorded 12 x 1 output of this cell corresponds to.
len_dec_seq = 1

def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
    """Stack the current hypotheses of all unfinished beams into one 2-D tensor.

    Args:
        inst_dec_beams: beam objects exposing ``done`` and ``get_current_state()``.
        len_dec_seq: number of tokens decoded so far; becomes the row width.

    Returns:
        Tensor on the notebook-level ``device`` with one row per hypothesis and
        ``len_dec_seq`` columns.
    """
    dec_partial_seq = [b.get_current_state() for b in inst_dec_beams if not b.done]
    dec_partial_seq = torch.stack(dec_partial_seq).to(device)
    # Collapse the (instance, beam) axes into a single row axis.
    dec_partial_seq = dec_partial_seq.view(-1, len_dec_seq)
    return dec_partial_seq

dec_seq = prepare_beam_dec_seq(inst_dec_beams, len_dec_seq)
dec_seq

tensor([[2],
        [0],
        [0],
        [2],
        [0],
        [0],
        [2],
        [0],
        [0],
        [2],
        [0],
        [0]], device='cuda:1')

In [25]:
def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
    """Build the position ids 1..len_dec_seq for every active hypothesis.

    Returns a LongTensor on the notebook-level ``device`` with
    n_active_inst * n_bm identical rows.
    """
    n_rows = n_active_inst * n_bm
    positions = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=device)
    return positions.unsqueeze(0).repeat(n_rows, 1)
dec_pos = prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm)
dec_pos

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1],
        [1]], device='cuda:1')

# loss function

In [None]:
# Label-smoothed cross entropy, written out step by step.
# NOTE(review): this expects `pred` to be 2-D logits (rows x classes) and `gold` to
# be the flattened target ids. Earlier cells rebind `pred` to the arg-max indices,
# so `pred` has to be recreated before this cell can run.
eps = 0.1  # probability mass moved away from the gold class
n_class = pred.size(1)

# One-hot encode the gold ids: a 1 in column gold[i] of row i.
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
# Smooth: the gold class keeps 1 - eps, the other n_class - 1 classes share eps equally.
one_hot = one_hot * (1 - eps) + (1 - one_hot) * eps / (n_class - 1)
log_prb = F.log_softmax(pred, dim=1)

# True wherever the target token is not padding.
non_pad_mask = gold.ne(Constants.PAD)
# Cross entropy against the smoothed target distribution, one value per row.
loss = -(one_hot * log_prb).sum(dim=1)
# Keep only the non-padding rows and sum them.
loss = loss.masked_select(non_pad_mask).sum()  # average later

In [85]:
pred = torch.randn(80, 20)
gold = torch.randint(0, 19, (8, 10))
gold = gold.contiguous().view(-1)
gold

tensor([10,  2,  8,  4,  9,  5,  7, 15, 18,  5,  5,  6,  2,  7, 16, 17,  2, 14,
        16,  2,  3,  4,  8, 18,  5, 13,  5, 17, 12,  9,  1, 18,  8,  1,  4, 14,
        12,  6,  8,  7, 18, 18, 15, 12,  9,  3,  3,  4, 10,  8,  5,  4, 17,  4,
         6,  8,  9,  6,  3,  2, 14,  2,  9,  4, 18, 18, 16,  2,  2, 11, 14,  4,
        18,  9,  1, 16, 16, 12, 15,  4])

In [95]:
one_hot = torch.zeros_like(pred).scatter(1, gold.view(-1, 1), 1)
print(one_hot.size())
one_hot[:15, :15]

torch.Size([80, 20])


tensor([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
        [0., 0., 0.,

In [96]:
log_prb = F.log_softmax(pred, dim=1)
print(log_prb.size())

torch.Size([80, 20])


In [97]:
loss = -(one_hot * log_prb).sum(dim=1)
loss.size()

torch.Size([80])

In [98]:
loss

tensor([2.9130, 3.6655, 3.4123, 3.6186, 3.3510, 4.3486, 3.9666, 2.4910, 4.6356,
        3.9776, 4.2750, 5.4937, 3.3387, 4.0871, 4.2915, 3.8214, 2.2534, 2.5865,
        1.8448, 4.1983, 2.9567, 4.2920, 2.9304, 2.0025, 3.8469, 3.8488, 3.0192,
        4.0202, 4.5486, 2.5903, 2.8071, 3.0103, 4.1669, 3.0809, 1.9707, 4.0798,
        2.9313, 5.6030, 2.8944, 5.1632, 4.4811, 4.1843, 4.2205, 4.4061, 3.2128,
        3.5347, 3.8706, 0.6721, 3.8442, 2.8956, 2.9133, 2.9370, 1.4976, 4.8199,
        3.9985, 3.0168, 4.4859, 4.4383, 2.2513, 5.1611, 4.3127, 4.6679, 4.1396,
        3.1560, 4.1702, 2.5973, 2.3852, 3.9929, 2.6124, 2.6030, 2.7387, 5.6631,
        2.6383, 3.8226, 3.2649, 3.4971, 4.1436, 4.3607, 1.6295, 3.1739])

In [101]:
non_pad_mask = gold.ne(0)
print(non_pad_mask.size())
non_pad_mask

torch.Size([80])


tensor([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1,
        1, 1, 1, 1, 1, 1, 1, 1], dtype=torch.uint8)

In [104]:
loss1 = torch.tensor([0.2, 0.3, 0.4, 0.5])
non_pad_mask1 = torch.ByteTensor([0,1,1,0])
loss1= loss1.masked_select(non_pad_mask1)
print(loss1)
loss1.sum()

tensor([0.3000, 0.4000])


tensor(0.7000)

In [106]:
pred.max(1)[1].size()

torch.Size([80])

In [1]:
import torch 
from train import Args
from preprocess import Lang
# NOTE(review): torch.load unpickles the file, so only load checkpoints that come
# from a trusted training run.
# NOTE(review): no map_location is given; if the checkpoint stores CUDA tensors it
# presumably needs the same GPU to be present — confirm, or pass map_location when
# loading on another machine.
cp = torch.load('best_model.chkpt')

In [2]:
from model import Transformer
import pickle
import torch.nn as nn


with open('result_0326.pkl', 'rb') as f:
    data = pickle.load(f)

# Rebuild the hyper-parameters from the same preprocessed data so the network has
# the same vocabulary sizes as the one that was trained.
model_opt = Args(data)

# Construct an untrained Transformer with the architecture described by model_opt.
model = Transformer(
    model_opt.src_vocab_size,
    model_opt.tgt_vocab_size,
    model_opt.max_token_seq_len,
    tgt_emb_prj_weight_sharing = model_opt.proj_share_weight,
    emb_src_tgt_weight_sharing = model_opt.embs_share_weight,
    d_k = model_opt.d_k,
    d_v = model_opt.d_v,
    d_model = model_opt.d_model,
    d_word_vec = model_opt.d_word_vec,
    d_inner = model_opt.d_inner_hid,
    n_layers = model_opt.n_layers,
    n_head = model_opt.n_head,
    dropout = model_opt.dropout)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# The model is wrapped before the state dict is loaded in the next cell.
# NOTE(review): presumably the checkpoint was saved from a DataParallel model, so its
# keys carry the 'module.' prefix — verify against the training script.
model = nn.DataParallel(model)

In [3]:
model.load_state_dict(cp['model'])
print('[Info] Trained model state loaded.')

[Info] Trained model state loaded.


In [1]:
# model = model.module

## Mask Mechanism

In [3]:
import torch
import torch.nn as nn

PAD = 0
def get_non_pad_mask(seq):
    """Return a float mask that is 1.0 at real tokens and 0.0 at PAD tokens.

    ``seq`` must be 2-D (batch x seq_len); the result has a trailing singleton
    axis, i.e. batch x seq_len x 1.
    """
    assert seq.dim() == 2
    is_token = seq.ne(PAD)
    # b x sl -> b x sl x 1
    return torch.unsqueeze(is_token.float(), -1)

In [4]:
scc_seq = torch.LongTensor([[3,2,4,8,5,6,1], 
                            [4,5,3,2,1,0,0], 
                            [2,3,1,2,0,0,0]])
scc_seq

tensor([[3, 2, 4, 8, 5, 6, 1],
        [4, 5, 3, 2, 1, 0, 0],
        [2, 3, 1, 2, 0, 0, 0]])

In [6]:
non_pad_mask = get_non_pad_mask(scc_seq)
non_pad_mask.size()

torch.Size([3, 7, 1])

In [9]:
print(non_pad_mask)

tensor([[[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [1.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [0.]],

        [[1.],
         [1.],
         [1.],
         [1.],
         [0.],
         [0.],
         [0.]]])


In [7]:
def get_attn_key_pad_mask(seq_k, seq_q):
    ''' Build the mask that hides padded key positions from every query position.

    The result is b x lq x lk and is non-zero wherever the key token equals PAD.
    '''
    n_queries = seq_q.size(1)
    key_is_pad = seq_k.eq(PAD)  # b x lk
    # Insert a query axis and broadcast it to the query length: b x lq x lk.
    return key_is_pad.unsqueeze(1).expand(-1, n_queries, -1)
    
slf_attn_mask = get_attn_key_pad_mask(seq_k=scc_seq, seq_q=scc_seq)
slf_attn_mask.size()

torch.Size([3, 7, 7])

In [8]:
slf_attn_mask

tensor([[[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1]],

        [[0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1]]], dtype=torch.uint8)

In [13]:
import numpy as np
attn = torch.randn(3,7,7)
attn = attn.masked_fill(slf_attn_mask, -np.inf)
attn

tensor([[[ 1.1907, -0.1930, -1.7384, -1.3472, -0.2542, -1.0443, -2.3004],
         [-1.2975, -1.2419,  0.9577, -0.6475,  1.0010, -0.2028,  0.4403],
         [-0.6398,  0.4331, -1.6618, -1.5534,  0.5599, -0.8390, -0.6250],
         [-1.8951,  0.7345, -0.3899, -0.9053,  1.4683,  0.0615, -0.0028],
         [-0.7552,  0.2180, -0.0062, -0.3522, -1.2483,  0.0866, -0.9544],
         [ 0.5425, -0.2690,  0.6157, -1.1469,  0.6495, -1.3331, -2.0436],
         [-0.2885,  0.1965,  1.2770, -0.6580,  0.6872, -0.5895, -1.2989]],

        [[ 0.1168,  0.0962,  0.1755, -0.2229, -0.4230,    -inf,    -inf],
         [ 1.4319,  0.5290, -0.6110,  0.8114,  0.6045,    -inf,    -inf],
         [ 1.0251, -0.4113, -1.3442, -1.6052, -0.5357,    -inf,    -inf],
         [-0.9391,  2.3830,  1.1223, -1.8243,  0.2361,    -inf,    -inf],
         [-0.3730, -0.1633, -0.6894, -0.5727,  0.4126,    -inf,    -inf],
         [ 0.2491,  1.1800,  0.0675,  2.0645,  1.5759,    -inf,    -inf],
         [ 0.8019, -1.0352, -0.4519,

In [11]:
enc_output = torch.randn(3, 7, 5)
enc_output

tensor([[[ 0.1054, -1.0855, -0.2113,  0.8852, -1.1243],
         [ 0.8511,  1.9635, -0.0993,  0.5729,  0.5146],
         [-0.8496, -0.6644,  0.0154,  0.0218,  0.1121],
         [-0.1432,  0.6329, -1.2499, -1.4691,  0.0487],
         [ 0.4552,  1.5328,  0.8571,  0.0168, -0.9150],
         [ 1.1203, -0.5750,  0.7145, -1.5010, -0.5726],
         [ 1.4359, -0.2541,  1.7998,  1.1124, -0.7505]],

        [[ 0.5389, -1.2988,  0.4234, -0.9288, -0.8077],
         [-0.6864,  0.8567, -0.8808,  1.8583, -0.4359],
         [ 0.0725, -0.1528, -0.2866, -0.4291, -1.0751],
         [ 0.5324,  0.7143, -0.3355,  0.3752,  1.0411],
         [ 0.5307, -0.4071,  0.8978,  0.9533,  0.3029],
         [ 1.5017, -1.0790,  0.6965, -0.0104, -1.4753],
         [ 0.5082,  0.1516,  0.1859, -0.8744,  0.9237]],

        [[ 0.4984,  0.2248, -0.2131, -0.9260,  0.3960],
         [ 0.5335, -0.7710,  0.9514, -0.6420, -0.7203],
         [ 0.2399, -1.6857,  0.1937, -0.9144,  0.3296],
         [ 0.8215, -0.7160, -2.5166, -0.0378

In [12]:
enc_output *= non_pad_mask
enc_output

tensor([[[ 0.1054, -1.0855, -0.2113,  0.8852, -1.1243],
         [ 0.8511,  1.9635, -0.0993,  0.5729,  0.5146],
         [-0.8496, -0.6644,  0.0154,  0.0218,  0.1121],
         [-0.1432,  0.6329, -1.2499, -1.4691,  0.0487],
         [ 0.4552,  1.5328,  0.8571,  0.0168, -0.9150],
         [ 1.1203, -0.5750,  0.7145, -1.5010, -0.5726],
         [ 1.4359, -0.2541,  1.7998,  1.1124, -0.7505]],

        [[ 0.5389, -1.2988,  0.4234, -0.9288, -0.8077],
         [-0.6864,  0.8567, -0.8808,  1.8583, -0.4359],
         [ 0.0725, -0.1528, -0.2866, -0.4291, -1.0751],
         [ 0.5324,  0.7143, -0.3355,  0.3752,  1.0411],
         [ 0.5307, -0.4071,  0.8978,  0.9533,  0.3029],
         [ 0.0000, -0.0000,  0.0000, -0.0000, -0.0000],
         [ 0.0000,  0.0000,  0.0000, -0.0000,  0.0000]],

        [[ 0.4984,  0.2248, -0.2131, -0.9260,  0.3960],
         [ 0.5335, -0.7710,  0.9514, -0.6420, -0.7203],
         [ 0.2399, -1.6857,  0.1937, -0.9144,  0.3296],
         [ 0.8215, -0.7160, -2.5166, -0.0378

In [32]:
tgt_seq = torch.LongTensor([[4,8,5,6,1,6], 
                            [3,2,1,0,0,0], 
                            [1,2,0,0,0,0]])

def get_subsequent_mask(seq):
    ''' Build the mask that hides later positions from each position.

    For a b x ls input the result is a b x ls x ls uint8 tensor whose entry
    [*, i, j] is 1 when j > i (strictly above the diagonal) and 0 otherwise.
    '''
    batch_size, seq_len = seq.size()
    all_ones = torch.ones((seq_len, seq_len), device=seq.device, dtype=torch.uint8)
    strictly_upper = torch.triu(all_ones, diagonal=1)
    # Share the same ls x ls pattern across the batch axis.
    return strictly_upper.unsqueeze(0).expand(batch_size, -1, -1)

slf_attn_mask_subseq = get_subsequent_mask(tgt_seq)
print(slf_attn_mask_subseq.size())
slf_attn_mask_subseq

torch.Size([3, 6, 6])


tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]]], dtype=torch.uint8)

In [33]:
slf_attn_mask_keypad = get_attn_key_pad_mask(seq_k=tgt_seq, seq_q=tgt_seq)
slf_attn_mask_keypad

tensor([[[0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1]],

        [[0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

In [34]:
slf_attn_mask = (slf_attn_mask_keypad + slf_attn_mask_subseq).gt(0)
slf_attn_mask

tensor([[[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1],
         [0, 0, 0, 0, 0, 0]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1],
         [0, 0, 0, 1, 1, 1]],

        [[0, 1, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1],
         [0, 0, 1, 1, 1, 1]]], dtype=torch.uint8)

In [35]:
attn = torch.randn(3,6,6)
attn = attn.masked_fill(slf_attn_mask, -np.inf)
print(attn)

tensor([[[-0.0390,    -inf,    -inf,    -inf,    -inf,    -inf],
         [ 0.3725,  0.4847,    -inf,    -inf,    -inf,    -inf],
         [ 0.2521,  0.9616,  0.4055,    -inf,    -inf,    -inf],
         [ 0.5334, -0.3027,  0.1682,  0.4592,    -inf,    -inf],
         [ 1.8542,  0.0811, -0.9784, -0.2155, -0.3622,    -inf],
         [-0.8214,  0.1133,  0.5731, -1.6596,  1.4122,  0.3225]],

        [[-0.4751,    -inf,    -inf,    -inf,    -inf,    -inf],
         [ 0.4183,  0.9962,    -inf,    -inf,    -inf,    -inf],
         [-0.1224,  0.3569,  1.8051,    -inf,    -inf,    -inf],
         [ 0.4057,  0.3073,  0.2014,    -inf,    -inf,    -inf],
         [ 0.1666, -0.4830, -1.9803,    -inf,    -inf,    -inf],
         [-1.2039,  0.1367, -1.6129,    -inf,    -inf,    -inf]],

        [[-1.4043,    -inf,    -inf,    -inf,    -inf,    -inf],
         [ 0.8128,  1.8255,    -inf,    -inf,    -inf,    -inf],
         [-1.3506,  1.1077,    -inf,    -inf,    -inf,    -inf],
         [ 1.4936,  0

In [36]:
dec_enc_attn_mask = get_attn_key_pad_mask(seq_k=scc_seq, seq_q=tgt_seq)
dec_enc_attn_mask

tensor([[[0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0],
         [0, 0, 0, 0, 0, 0, 0]],

        [[0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1],
         [0, 0, 0, 0, 0, 1, 1]],

        [[0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1],
         [0, 0, 0, 0, 1, 1, 1]]], dtype=torch.uint8)

# Translate

In [39]:
n_bm = 3 
n_inst = 2
src_seq = torch.randint(0, 20, (2, 5))
print(src_seq)
src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, 5)
src_seq

tensor([[ 7,  2,  7, 17, 17],
        [19,  3, 18,  8, 10]])


tensor([[ 7,  2,  7, 17, 17],
        [ 7,  2,  7, 17, 17],
        [ 7,  2,  7, 17, 17],
        [19,  3, 18,  8, 10],
        [19,  3, 18,  8, 10],
        [19,  3, 18,  8, 10]])

In [54]:
n_inst = 5
src_seq = torch.randint(0, 20, (n_inst, 5))
print(src_seq)
src_seq = src_seq.repeat(1, n_bm).view(n_inst * n_bm, 5)
src_seq.size()

tensor([[ 1,  7, 13,  7,  4],
        [14,  2, 11, 18,  1],
        [15, 13,  1, 18,  7],
        [13,  0, 10,  4,  0],
        [10, 14, 10,  6,  0]])


torch.Size([15, 5])

In [58]:
src_seq = src_seq.view(n_inst, -1)
ids = torch.LongTensor([0, 2, 4])
src_seq = src_seq.index_select(0, ids)
src_seq.view(-1, 5)

tensor([[ 1,  7, 13,  7,  4],
        [ 1,  7, 13,  7,  4],
        [ 1,  7, 13,  7,  4],
        [15, 13,  1, 18,  7],
        [15, 13,  1, 18,  7],
        [15, 13,  1, 18,  7],
        [10, 14, 10,  6,  0],
        [10, 14, 10,  6,  0],
        [10, 14, 10,  6,  0]])

In [42]:
# This cell was recorded with two instances (its output is 6 x 5 x 4 with n_bm = 3),
# but read top to bottom the previous cells leave n_inst at 5, which makes the
# view() below fail. Pin the value here and derive the random tensor's shape from
# the same names instead of repeating the literals.
n_inst = 2
len_s = 5  # source sequence length
d_h = 4    # hidden size of the fake encoder output
src_enc = torch.randn(n_inst, len_s, d_h)
src_enc = src_enc.repeat(1, n_bm, 1).view(n_inst * n_bm, len_s, d_h) # repeat : n_inst x (n_bm * len_s) x d_h ; view : (n_inst * n_bm) x len_s x d_h
src_enc.size()

torch.Size([6, 5, 4])

In [43]:
src_enc

tensor([[[ 1.5695,  1.0529, -0.1525, -0.8965],
         [ 0.9586,  0.7160,  0.6085,  0.2505],
         [-0.8403,  0.7315, -0.9570,  0.2432],
         [-1.5587, -0.5631,  1.3815,  0.7977],
         [ 0.5308, -0.6769,  1.8018, -0.8129]],

        [[ 1.5695,  1.0529, -0.1525, -0.8965],
         [ 0.9586,  0.7160,  0.6085,  0.2505],
         [-0.8403,  0.7315, -0.9570,  0.2432],
         [-1.5587, -0.5631,  1.3815,  0.7977],
         [ 0.5308, -0.6769,  1.8018, -0.8129]],

        [[ 1.5695,  1.0529, -0.1525, -0.8965],
         [ 0.9586,  0.7160,  0.6085,  0.2505],
         [-0.8403,  0.7315, -0.9570,  0.2432],
         [-1.5587, -0.5631,  1.3815,  0.7977],
         [ 0.5308, -0.6769,  1.8018, -0.8129]],

        [[-0.4781,  1.1903, -0.8891, -1.2384],
         [ 0.1979, -0.2469, -0.5289, -1.5441],
         [-1.6505, -0.5915,  0.4779, -0.5341],
         [ 0.9907,  0.1285, -0.2244,  0.1127],
         [-0.0596, -1.9247, -0.4590, -0.2597]],

        [[-0.4781,  1.1903, -0.8891, -1.2384],
     

In [44]:
torch.arange(1,2)

tensor([1])

In [48]:
from beam import Beam

device = torch.device("cpu")
inst_dec_beams = [Beam(n_bm, device = device) for _ in range(n_inst)]

def prepare_beam_dec_seq(inst_dec_beams, len_dec_seq):
    """Debug variant: stack the hypotheses of all unfinished beams, printing each stage.

    Prints the list of per-beam states and the stacked tensor, then returns the
    stacked tensor flattened to one row per hypothesis with len_dec_seq columns.
    """
    live_states = []
    for beam in inst_dec_beams:
        if not beam.done:
            live_states.append(beam.get_current_state())
    print(live_states)

    stacked = torch.stack(live_states)
    print(stacked)

    return stacked.view(-1, len_dec_seq)

prepare_beam_dec_seq(inst_dec_beams, 1)

[tensor([[2],
        [0],
        [0]]), tensor([[2],
        [0],
        [0]])]
tensor([[[2],
         [0],
         [0]],

        [[2],
         [0],
         [0]]])


tensor([[2],
        [0],
        [0],
        [2],
        [0],
        [0]])

In [49]:
def prepare_beam_dec_pos(len_dec_seq, n_active_inst, n_bm):
    """Build the position ids 1..len_dec_seq for every active hypothesis.

    Returns a LongTensor on the notebook-level ``device`` with
    n_active_inst * n_bm identical rows.
    """
    n_rows = n_active_inst * n_bm
    positions = torch.arange(1, len_dec_seq + 1, dtype=torch.long, device=device)
    return positions.unsqueeze(0).repeat(n_rows, 1)

dec_pos = prepare_beam_dec_pos(1, 2, n_bm)
dec_pos

tensor([[1],
        [1],
        [1],
        [1],
        [1],
        [1]])

In [50]:
flat_beam_lk = torch.randn(9)
print(flat_beam_lk)
best_scores, best_scores_id = flat_beam_lk.topk(3, 0, True, True) # 1st sort
print(best_scores, best_scores_id)
best_scores, best_scores_id = flat_beam_lk.topk(3, 0, True, True) # 2nd sort
print(best_scores, best_scores_id)

tensor([-1.1397, -0.2724, -0.4226,  1.2623, -0.5551,  1.5296, -0.8335,  0.8877,
        -0.9859])
tensor([1.5296, 1.2623, 0.8877]) tensor([5, 3, 7])
tensor([1.5296, 1.2623, 0.8877]) tensor([5, 3, 7])


In [52]:
# Recover the beam index from a flattened (beam x vocabulary) index. Floor division
# keeps the result an integer tensor on every PyTorch version; `/` performs true
# division on recent versions and would return floats instead of the recorded
# integer quotients. 20 stands in for the vocabulary size in this toy example —
# presumably mirroring the index arithmetic in beam.py; verify there.
best_scores_id // 20

tensor([0, 0, 0])

In [None]:
[[tensor([2, 0, 0], device='cuda:1'), 
  tensor([185, 677, 478], device='cuda:1'), 
  tensor([478, 170, 185], device='cuda:1'), 
  tensor([572, 739, 728], device='cuda:1'), 
  tensor([ 79,  79, 268], device='cuda:1'), 
  tensor([204, 204, 813], device='cuda:1'),
  tensor([180, 180, 180], device='cuda:1'), 
  tensor([287, 287, 633], device='cuda:1'), 
  tensor([23, 23, 23], device='cuda:1'), 
  tensor([3, 3, 3], device='cuda:1')],  [[185, 478, 572, 79, 204, 180, 287, 23, 3]]
 
 [tensor([2, 0, 0], device='cuda:1'), 
  tensor([ 22,  79, 153], device='cuda:1'), 
  tensor([71, 185, 277], device='cuda:1'), 
  tensor([413, 186, 491], device='cuda:1'), 
  tensor([1655,  798,  186], device='cuda:1'), 
  tensor([384, 684, 684], device='cuda:1'), 
  tensor([180, 180, 180], device='cuda:1'), 
  tensor([287, 287, 287], device='cuda:1'), 
  tensor([413,   1,   1], device='cuda:1'), 
  tensor([798,   5,   5], device='cuda:1'), 
  tensor([  1,   3, 395], device='cuda:1'),
  tensor([5, 5, 5], device='cuda:1')]] [22, 71, 413, 1655, 384, 180, 287, 413, 798, 1, 5, 3]

[[tensor([0, 0, 0], device='cuda:1'), 
  tensor([0, 1, 0], device='cuda:1'), 
  tensor([0, 0, 0], device='cuda:1'), 
  tensor([0, 1,0], device='cuda:1'), 
  tensor([0, 1, 0], device='cuda:1'), 
  tensor([0, 1, 2], device='cuda:1'), 
  tensor([0, 1, 0], device='cuda:1'), 
  tensor([0, 1, 2], device='cuda:1'), 
  tensor([0, 1, 2], device='cuda:1')], 
 [tensor([0, 0, 0], device='cuda:1'), tensor([0, 0, 1], device='cuda:1'), tensor([0, 0, 0], device='cuda:1'), tensor([0, 0, 0], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1')]]

[[tensor([2, 0, 0], device='cuda:1'), tensor([185, 677, 478], device='cuda:1'), tensor([478, 170, 185], device='cuda:1'), tensor([572, 739, 728], device='cuda:1'), tensor([ 79,  79, 268], device='cuda:1'), tensor([204, 204, 813], device='cuda:1'),tensor([180, 180, 180], device='cuda:1'), tensor([287, 287, 633], device='cuda:1'), tensor([23, 23, 23], device='cuda:1'), tensor([3, 3, 3], device='cuda:1')], [tensor([2, 0, 0], device='cuda:1'), tensor([ 22,  79, 153], device='cuda:1'), tensor([71, 185, 277], device='cuda:1'), tensor([413, 186, 491], device='cuda:1'), tensor([1655,  798,  186], device='cuda:1'), tensor([384, 684, 684], device='cuda:1'), tensor([180, 180, 180], device='cuda:1'), tensor([287, 287, 287], device='cuda:1'), tensor([413,   1,   1], device='cuda:1'), tensor([798,   5,   5], device='cuda:1'), tensor([  1,   3, 395], device='cuda:1'),tensor([5, 5, 5], device='cuda:1'), tensor([3, 3, 3], device='cuda:1')]]
[[tensor([0, 0, 0], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 0, 0], device='cuda:1'), tensor([0, 1,0], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1')], [tensor([0, 0, 0], device='cuda:1'), tensor([0, 0, 1], device='cuda:1'), tensor([0, 0, 0], device='cuda:1'), tensor([0, 0, 0], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 0], device='cuda:1'), tensor([0, 1, 2], device='cuda:1'), tensor([0, 1, 2], device='cuda:1')]]

[[[185, 478, 572, 79, 204, 180, 287, 23, 3]], 
 [[22, 71, 413, 1655, 384, 180, 287, 413, 798, 1, 5, 3]]]