In [1]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
import time
from torch.utils.data import DataLoader
import torchtext.datasets as datasets
from torch.utils.data.distributed import DistributedSampler

import numpy as np


In [24]:
def get_attn_subsequent_mask(seq):
    """Build the "subsequent position" (causal) mask for a batch of sequences.

    Args:
        seq: 2-D tensor; only its two dimension sizes and its device are used.

    Returns:
        Tensor of shape (seq.size(0), seq.size(1), seq.size(1)) in the default
        floating dtype, allocated on ``seq``'s device. Entry ``[b, i, j]`` is 1
        when ``j > i`` (strictly above the diagonal) and 0 otherwise.

    Raises:
        AssertionError: if ``seq`` is not 2-dimensional.
    """
    assert seq.dim() == 2
    attn_shape = [seq.size(0), seq.size(1), seq.size(1)]

    # Allocate directly on seq's device. A bare `.cuda()` would move the mask
    # to the *current* CUDA device, which may differ from seq's device on a
    # multi-GPU machine, and it does nothing for non-CUDA accelerators.
    return torch.triu(torch.ones(attn_shape, device=seq.device), diagonal=1)

In [3]:
def get_attn_pad_mask(inputs, input_lengths, expand_length):
    """Return a boolean attention mask that is True at padded positions.

    The per-sequence non-pad mask from ``get_transformer_non_pad_mask`` is
    inverted (anything below 1 counts as padding), then a new axis is inserted
    at dim 1 and expanded to ``expand_length`` so every row of a sequence's
    attention matrix shares the same padding pattern.
    """
    is_padding = get_transformer_non_pad_mask(inputs, input_lengths) < 1
    # `expand` returns a broadcast view; the pattern is not copied per row.
    return is_padding[:, None, :].expand(-1, expand_length, -1)

In [4]:
def get_transformer_non_pad_mask(inputs, input_lengths):
    """Build a mask that is 1 at valid positions and 0 at padded positions.

    Args:
        inputs: 2-D (B x T) or 3-D (B x T x D) tensor; only its shape, dtype
            and device are used.
        input_lengths: tensor or sequence of integers giving the number of
            valid leading positions for each batch row.

    Returns:
        B x T tensor with the dtype and device of ``inputs``; row ``i`` is 1
        for positions ``< input_lengths[i]`` and 0 from there on.

    Raises:
        ValueError: if ``inputs`` is neither 2-D nor 3-D.
        IndexError: if fewer lengths than batch rows are supplied.
    """
    if inputs.dim() == 2:
        batch_size, max_len = inputs.size()
    elif inputs.dim() == 3:
        batch_size, max_len = inputs.size()[:-1]
    else:
        raise ValueError(f"Unsupported input shape {inputs.size()}")

    lengths = torch.as_tensor(input_lengths, device=inputs.device).reshape(-1)
    if lengths.numel() < batch_size:
        raise IndexError(
            f"input_lengths has {lengths.numel()} entries for a batch of {batch_size}"
        )
    lengths = lengths[:batch_size]

    # Keep the semantics of the slice assignment `mask[i, length:] = 0`:
    # a negative length counts from the end of the row.
    lengths = torch.where(lengths < 0, lengths + max_len, lengths)

    # One broadcast comparison replaces the per-row Python loop, which indexed
    # a tensor element per iteration.
    positions = torch.arange(max_len, device=inputs.device)
    valid = positions.unsqueeze(0) < lengths.unsqueeze(1)  # B x T
    return valid.to(inputs.dtype)

In [5]:
# Sanity check: two rows with valid lengths 3 and 4, expanded to 5 rows each.
get_attn_pad_mask(
    inputs=torch.tensor([[1, 2, 3, 0, 0], [1, 2, 3, 4, 0]]),
    input_lengths=torch.tensor([3, 4]),
    expand_length=5,
)

tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

In [6]:
# Sanity check: the same toy batch, showing the 1/0 non-pad mask directly.
get_transformer_non_pad_mask(
    inputs=torch.tensor([[1, 2, 3, 0, 0], [1, 2, 3, 4, 0]]),
    input_lengths=torch.tensor([3, 4]),
)

tensor([[1, 1, 1, 0, 0],
        [1, 1, 1, 1, 0]])

In [11]:
# Three identical rows holding the integers 1..9.
x = torch.tensor([list(range(1, 10))] * 3)
x

tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9],
        [1, 2, 3, 4, 5, 6, 7, 8, 9]])

In [13]:
# Boolean indexing flattens the result, so reshape back to one row per sequence.
x[x.ne(1)].view(3, -1)

tensor([[2, 3, 4, 5, 6, 7, 8, 9],
        [2, 3, 4, 5, 6, 7, 8, 9],
        [2, 3, 4, 5, 6, 7, 8, 9]])

In [18]:
# Toy batch of two sequences; the second tensor lists 3 and 4 as their lengths.
targets = torch.tensor(
    [
        [1, 2, 3, 0, 0],
        [1, 2, 3, 4, 0],
    ]
)
target_lengths = torch.tensor([3, 4])
targets

tensor([[1, 2, 3, 0, 0],
        [1, 2, 3, 4, 0]])

In [19]:
# Size of dim 1 of the batch, as a plain int.
target_length = targets.shape[1]
target_length

5

In [21]:
# Decoder-side aliases for the toy batch defined above.
decoder_inputs = targets
decoder_input_lengths = target_lengths
# No trailing comma here: `name = value,` binds the 1-tuple `(value,)`,
# not the integer itself.
positional_encoding_length = target_length

In [27]:
# Padding mask for decoder self-attention: expanded to a square over dim 1
# of the decoder inputs.
dec_self_attn_pad_mask = get_attn_pad_mask(
    inputs=decoder_inputs,
    input_lengths=decoder_input_lengths,
    expand_length=decoder_inputs.shape[1],
)
dec_self_attn_pad_mask

tensor([[[False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])

In [26]:
# Strictly-upper-triangular mask built from the decoder inputs' shape.
dec_self_attn_subsequent_mask = get_attn_subsequent_mask(seq=decoder_inputs)
dec_self_attn_subsequent_mask

tensor([[[0., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]],

        [[0., 1., 1., 1., 1.],
         [0., 0., 1., 1., 1.],
         [0., 0., 0., 1., 1.],
         [0., 0., 0., 0., 1.],
         [0., 0., 0., 0., 0.]]])

In [28]:
# A position is masked if either the padding mask or the subsequent-position
# mask flags it. Both masks hold only 0/1 (False/True), so a logical OR gives
# the same result as summing them and testing for > 0.
self_attn_mask = torch.logical_or(
    dec_self_attn_pad_mask,
    dec_self_attn_subsequent_mask,
)
self_attn_mask

tensor([[[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False,  True,  True]],

        [[False,  True,  True,  True,  True],
         [False, False,  True,  True,  True],
         [False, False, False,  True,  True],
         [False, False, False, False,  True],
         [False, False, False, False,  True]]])