In [15]:
import torch
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
import math
import copy
from torch.autograd import Variable

In [18]:
class PositionalEncodings(nn.Module):
    """Fixed sinusoidal positional encodings added to a batch of embeddings.

    A ``(max_len, d)`` table is built once in ``__init__``: even columns hold
    ``sin(pos / 10000^(2i/d))`` and odd columns hold ``cos(pos / 10000^(2i/d))``.
    The table is registered as the buffer ``PE`` (not a trainable parameter).
    """

    def __init__(self, d, drop_prob=.1, max_len=5000):
        """
        :param d: Dimension of embedding (odd values are supported)
        :param drop_prob: Dropout rate. Currently unused: no dropout is applied
            in ``forward``; the argument is kept for interface compatibility.
        :param max_len: Maximum length of a sequence
        """
        super(PositionalEncodings, self).__init__()

        PE = torch.zeros((max_len, d))  # (L, d)
        pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)  # (L, 1)
        div = torch.exp(torch.arange(0., d, 2)/d*math.log(1e4))  # (ceil(d/2),)
        PE[:, ::2] = torch.sin(pos/div)  # (L, ceil(d/2))
        # With an odd d there is one fewer odd column than even columns, so the
        # last frequency is dropped for the cosine half to keep shapes aligned.
        PE[:, 1::2] = torch.cos(pos/div[:d//2])  # (L, floor(d/2))
        self.register_buffer('PE', PE)  # (L, d)

    def forward(self, x):
        """
        :param x: Input (batch, seq_len, d)
        :return: x + PE, where row ``t`` of the table is added to position ``t``
            of every sequence in the batch
        """
        # Slice the table to the first seq_len positions; the (seq_len, d)
        # slice broadcasts over the batch dimension.
        return x + self.PE[:x.shape[1]]

In [16]:
class PositionalEncoding(nn.Module):
    """Implement the PE function.

    Precomputes a ``(1, max_len, d_model)`` table of sinusoidal positional
    encodings (sine in even columns, cosine in odd columns), registers it as
    the buffer ``pe``, and adds the leading ``seq_len`` rows to the input.

    :param d_model: Dimension of embedding (odd values are supported)
    :param max_len: Maximum length of a sequence
    """
    def __init__(self, d_model, max_len=5000):
        super(PositionalEncoding, self).__init__()

        # Compute the positional encodings once in log space.
        pe = torch.zeros(max_len, d_model)

        position = torch.arange(0., max_len).unsqueeze(1)  # (max_len, 1)
        div_term = torch.exp(torch.arange(0., d_model, 2) *
                             -(math.log(1e4) / d_model))  # (ceil(d_model/2),)
        pe[:, 0::2] = torch.sin(position * div_term)
        # With an odd d_model there is one fewer odd column than even columns,
        # so the last frequency is dropped for the cosine half.
        pe[:, 1::2] = torch.cos(position * div_term[:d_model // 2])
        pe = pe.unsqueeze(0)  # (1, max_len, d_model), broadcasts over the batch
        self.register_buffer('pe', pe)

    def forward(self, x):
        """
        :param x: Input (batch, seq_len, d_model)
        :return: x with the first seq_len positional encodings added
        """
        # The table is a buffer built without gradient tracking, so the
        # deprecated Variable(..., requires_grad=False) wrapper is unnecessary.
        return x + self.pe[:, :x.size(1)]

In [14]:
# Reference implementation with a 10-dimensional embedding.
PE1 = PositionalEncoding(d_model=10)

In [17]:
# Zero input (batch=2, seq_len=3, d=10) so the output is the encoding itself.
O1 = PE1(torch.zeros(2, 3, 10))

In [19]:
# Alternative implementation with the same 10-dimensional embedding.
PE2 = PositionalEncodings(d=10)

In [20]:
# Zero input (batch=2, seq_len=3, d=10) so the output is the encoding itself.
O2 = PE2(torch.zeros(2, 3, 10))

In [21]:
# Compare the two implementations natively in torch; this avoids the implicit
# tensor -> NumPy conversion, which fails for CUDA or grad-requiring tensors.
torch.allclose(O1, O2)

True

In [80]:
# Toy batch (2 x 6): three non-zero tokens followed by three zeros per row.
src = torch.cat([torch.ones(2, 3), torch.zeros(2, 3)], dim=-1)

In [81]:
# Constant dummy attention scores of shape (2, 3, 6, 6).
scores = torch.ones(2, 3, 6, 6) * 5.0

In [82]:
# Non-zero entries of src are kept; insert a singleton axis -> B x 1 x L
src_mask = src.ne(0).unsqueeze(-2)

In [83]:
# Insert a singleton axis at dim 1 for the heads dimension.
# Note: this reassigns src_mask, so re-running the cell adds another axis.
src_mask = torch.unsqueeze(src_mask, 1)

In [84]:
# Source mask: overwrite scores with -1e9 wherever the mask is zero.
s1 = scores.masked_fill(src_mask.eq(0), -1e9)

In [85]:
# Rebuild the source mask from src (B x 1 x L), discarding the earlier reassignment.
src_mask = src.ne(0).unsqueeze(-2)

In [86]:
# 6 x 6 lower-triangular uint8 mask, built from an explicit square tensor
# instead of relying on np.tril's behavior for a 1-D input.
tgt_mask = torch.tril(torch.ones(6, 6)).byte()

In [87]:
# Elementwise product of the two masks (broadcast), then a singleton axis at dim 1.
mask = torch.mul(src_mask, tgt_mask).unsqueeze(1)

In [74]:
# Target mask: overwrite scores with -1e9 wherever the combined mask is zero.
s2 = scores.masked_fill(mask.eq(0), -1e9)

In [76]:
# Check: np.tril on the 1-D np.ones(6) displays a 6 x 6 lower-triangular matrix (see output below).
np.tril(np.ones(6))

array([[1., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1.]])

In [79]:
# Same call with an explicit 6 x 6 input, for comparison with the 1-D call; the output below is a 6 x 6 lower-triangular matrix.
np.tril(np.ones((6,6)))

array([[1., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 0.],
       [1., 1., 1., 1., 1., 1.]])