In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn.utils.rnn import pack_padded_sequence, pad_sequence

In [3]:
term_idx_sample = [ 1, 2, 3, 4, 5, 3, 4, 6, 1, 4, 5, 7 ]

## attributes

## input

In [4]:
term_idx_sample = torch.LongTensor(term_idx_sample)
docs_offsets_sample = torch.LongTensor([ 0, 5, 9 ])
docs_offsets_sample, term_idx_sample

(tensor([0, 5, 9]), tensor([1, 2, 3, 4, 5, 3, 4, 6, 1, 4, 5, 7]))

## forward

In [2]:
class Mask(nn.Module):
    def __init__(self, negative_slope=1000, kappa=0.):
        super(Mask, self).__init__()
        self.negative_slope = negative_slope
        self.kappa = kappa
        self.sig = nn.Sigmoid()
    def forward(self, h):
        w = F.leaky_relu( h, negative_slope=self.negative_slope)
        w = self.sig(w-self.kappa)
        return w

In [5]:
class AttentionBag(nn.Module):
    def __init__(self, vocab_size, hiddens):
        super(AttentionBag, self).__init__()
        self.hiddens    = hiddens
        self.mask       = Mask()
        self.dt_emb     = nn.Embedding(vocab_size, hiddens)
        self.dt_dir_map = nn.Linear(hiddens, hiddens)
        self.ma_term    = nn.MultiheadAttention(hiddens, 1)
    def forward(self, terms_idx, docs_offsets, return_mask=False):
        n = terms_idx.shape[0]
        batch_size = docs_offsets.shape[0]
        
        k         = [ terms_idx[ docs_offsets[i-1]:docs_offsets[i] ] for i in range(1, batch_size) ]
        k.append( terms_idx[ docs_offsets[-1]: ] )
        x_packed  = pad_sequence(k, batch_first=True, padding_value=0)

        bx_packed = x_packed == 0
        pad_mask  = bx_packed.logical_not()
        pad_mask  = pad_mask.view(*bx_packed.shape, 1)
        pad_mask  = pad_mask.logical_and(pad_mask.transpose(1, 2))
        
        dt_h      = self.dt_emb( x_packed )
        dir_dt_h  = self.dt_dir_map( dt_h )

        weights = torch.bmm( dt_h, dir_dt_h.transpose( 1, 2 ) )
        weights = self.mask(weights)
        
        weights_disc = (weights * pad_mask)
        weights_disc = weights_disc.sum(axis=1)
        weights_disc = F.softmax(weights_disc, dim=1)
        weights_disc = weights_disc.view( *weights_disc.shape, 1 )
        
        attn_mask = weights != 0
        attn_mask = attn_mask.logical_and( pad_mask ).logical_not()
        
        dt_h     = dt_h.transpose(0,1)
        dir_dt_h = dir_dt_h.transpose(0,1)
        docs_att, weigths_att = self.ma_term( dt_h, dir_dt_h, dt_h,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )

        weigths_att = torch.where(torch.isnan(weigths_att), torch.zeros_like(weigths_att), weigths_att)
        weigths_att = (weigths_att * pad_mask)
        weigths_att = F.softmax(weigths_att.sum(axis=1), dim=1)
        weigths_att = weigths_att.view( *weigths_att.shape, 1 )
        
        weigths = weigths_att + weights_disc

        docs_att = docs_att.transpose(0,1)
        docs_att = torch.where(torch.isnan(docs_att), torch.zeros_like(docs_att), docs_att)
        
        docs_h = docs_att * weigths
        docs_h = docs_h.sum(axis=1)
        docs_h = docs_h / bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
        docs_h = torch.where(torch.isnan(docs_h), torch.zeros_like(docs_h), docs_h)
        if return_mask:
            return docs_h, bx_packed, pad_mask, attn_mask
        return docs_h

In [6]:
att = AttentionBag(len(set(term_idx_sample.tolist()))+1, 10)
att

AttentionBag(
  (mask): Mask(
    (sig): Sigmoid()
  )
  (dt_emb): Embedding(8, 10)
  (dt_dir_map): Linear(in_features=10, out_features=10, bias=True)
  (ma_term): MultiheadAttention(
    (out_proj): _LinearWithBias(in_features=10, out_features=10, bias=True)
  )
)

In [7]:
att(term_idx_sample, docs_offsets_sample, return_mask=True)

(tensor([[ 0.0115, -0.1640,  0.0684, -0.0876, -0.1101,  0.0051,  0.1153, -0.1482,
           0.0171, -0.0103],
         [ 0.0094, -0.1155, -0.0052, -0.0195, -0.0078, -0.0547,  0.0899, -0.0682,
           0.0090,  0.0051],
         [ 0.0467, -0.0305,  0.0470, -0.1557, -0.0030,  0.1063, -0.0692, -0.1062,
           0.0078,  0.0341]], grad_fn=<SWhereBackward>),
 tensor([[False, False, False, False, False],
         [False, False, False, False,  True],
         [False, False, False,  True,  True]]),
 tensor([[[ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True]],
 
         [[ True,  True,  True,  True, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True, False],
          [ True,  True,  True,  True, False],
          [False, False, False, False, False]],
 
         [[ True,  True

In [49]:


k = [ term_idx_sample[ docs_offsets_sample[i-1]:docs_offsets_sample[i] ] for i in range(1, batch_size) ]
k.append( term_idx_sample[ docs_offsets_sample[-1]: ] )
x_packed   = pad_sequence(k, batch_first=True, padding_value=0)

dt_h     = dt_emb( x_packed )
dir_dt_h = dt_dir_map( dt_h )

bx_packed = x_packed == 0
x_packed

tensor([[1, 2, 3, 4, 5],
        [3, 4, 6, 1, 0],
        [4, 5, 7, 0, 0]])

In [51]:
batched_dt     = dt_h.transpose(0,1)
#batched_dir_dt = dt_h.transpose(0,1)
batched_dir_dt = dir_dt_h.transpose(0,1)

weights = torch.bmm( dt_h, dir_dt_h.transpose( 2, 1 ) )
weights = mask(weights)
attn_mask = weights != 0

pad_mask = bx_packed.logical_not()
pad_mask = pad_mask.view(*bx_packed.shape, 1)
pad_mask = pad_mask.logical_and(pad_mask.transpose(2,1))

attn_mask_old = attn_mask
attn_mask = attn_mask.logical_and( pad_mask ).logical_not()

In [52]:
bx_packed.logical_not(), attn_mask_old, pad_mask, attn_mask.logical_not()

(tensor([[ True,  True,  True,  True,  True],
         [ True,  True,  True,  True, False],
         [ True,  True,  True, False, False]]),
 tensor([[[ True, False, False, False,  True],
          [ True, False, False,  True,  True],
          [ True,  True,  True,  True,  True],
          [ True, False,  True,  True,  True],
          [ True, False,  True, False,  True]],
 
         [[ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
          [False, False, False,  True, False],
          [False, False, False,  True, False],
          [ True, False,  True,  True,  True]],
 
         [[ True,  True,  True,  True,  True],
          [False,  True, False,  True,  True],
          [ True,  True, False,  True,  True],
          [False,  True,  True,  True,  True],
          [False,  True,  True,  True,  True]]]),
 tensor([[[ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
          [ True,  True,  True,  True,  True],
     

## Matriz de pesos de co-ocorrências
$(N, L, S)$

In [53]:
weights_disc = (weights * pad_mask)
weights_disc = weights_disc.sum(axis=1)
weights_disc = F.softmax(weights_disc, dim=1)
weights, weights_disc

(tensor([[[0.9142, 0.0000, 0.0000, 0.0000, 0.9766],
          [0.7897, 0.0000, 0.0000, 0.5455, 0.7037],
          [0.9130, 0.5700, 0.8687, 0.7904, 0.7713],
          [0.9163, 0.0000, 0.9870, 0.9730, 0.7838],
          [0.8055, 0.0000, 0.5633, 0.0000, 0.9841]],
 
         [[0.8687, 0.7904, 0.9405, 0.9130, 0.7198],
          [0.9870, 0.9730, 0.9108, 0.9163, 0.7433],
          [0.0000, 0.0000, 0.0000, 0.9840, 0.0000],
          [0.0000, 0.0000, 0.0000, 0.9142, 0.0000],
          [0.5805, 0.0000, 0.9363, 0.9426, 0.5807]],
 
         [[0.9730, 0.7838, 0.6593, 0.7433, 0.7433],
          [0.0000, 0.9841, 0.0000, 0.6082, 0.6082],
          [0.8349, 0.5034, 0.0000, 0.6175, 0.6175],
          [0.0000, 0.8724, 0.5710, 0.5807, 0.5807],
          [0.0000, 0.8724, 0.5710, 0.5807, 0.5807]]], grad_fn=<SigmoidBackward>),
 tensor([[0.4569, 0.0105, 0.0670, 0.0600, 0.4055],
         [0.1046, 0.0953, 0.1041, 0.6797, 0.0163],
         [0.3092, 0.4914, 0.0980, 0.0507, 0.0507]], grad_fn=<SoftmaxBackward>))

tensor([[1.6162, 0.6537, 0.0000, 3.7412, 1.2287],
        [0.0000, 3.2439, 0.6063, 1.6162, 0.0000],
        [1.2826, 0.5478, 1.4665, 0.0000, 0.0000]], grad_fn=<SumBackward1>)

In [20]:
attn_mask.shape

torch.Size([3, 5, 5])

key_padding_mask: $(N,S)$ where $N$ is the batch size, $S$ is the source sequence length. If a $BoolTensor$ is provided, the positions with the value of $True$ will be ignored while the position with the value of $False$ will be unchanged.

attn_mask: 3D mask $(N*num\_heads, L, S)$ where $N$ is the batch size, $L$ is the target sequence length, $S$ is the source sequence length. If a $BoolTensor$ is provided, positions with $True$ is not allowed to attend while $False$ values will be unchanged.

In [54]:
bx_packed.shape, attn_mask.shape

(torch.Size([3, 5]), torch.Size([3, 5, 5]))

attn_output: $(L,N,E)$ where $L$ is the target sequence length, $N$ is the batch size, $E$ is the embedding dimension.

attn_output_weights: $(N,L,S)$ where $N$ is the batch size, $L$ is the target sequence length, $S$ is the source sequence length.

In [55]:
docs_att, att_weights = ma_term( batched_dt, batched_dir_dt, batched_dt,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )
docs_att.shape, att_weights.shape

(torch.Size([5, 3, 10]), torch.Size([3, 5, 5]))

In [56]:
docs_att, att_weights = ma_term( batched_dt, batched_dir_dt, batched_dt,
                                  key_padding_mask=bx_packed, 
                                  attn_mask=attn_mask )
docs_att.shape, att_weights.shape
docs_att = docs_att.transpose(1,0)
docs_att.shape

torch.Size([3, 5, 10])

In [57]:
docs_att = torch.where(torch.isnan(docs_att), torch.zeros_like(docs_att), docs_att)
(docs_att*1000).round()

tensor([[[-727.,   52.,  166.,  421.,  472.,   61.,  200., -246.,  205.,  -34.],
         [-483.,   95.,  128.,  131.,   38., -163.,  188., -161.,   56., -112.],
         [-443.,  123.,  107.,   55.,   33., -131.,  176.,  -90.,   48., -158.],
         [-518.,   25.,  181.,  303.,  143.,  -34.,   66., -195.,   54., -161.],
         [-613.,   65.,  152.,  406.,  331.,   55.,  125., -247.,  147., -125.]],

        [[-322.,  -74.,  224.,  120.,  -81.,  -94.,  -93.,  -36.,  -54., -138.],
         [-412., -159.,  292.,  248.,   57.,   -8., -160.,  -47.,  -45., -122.],
         [-641., -287.,  426.,  499.,  258.,   72., -213., -135.,  -60., -151.],
         [-641., -287.,  426.,  499.,  258.,   72., -213., -135.,  -60., -151.],
         [   0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.,    0.]],

        [[-484., -131.,  244., -106.,  -81., -196.,  443.,  140.,  175., -265.],
         [-789.,  296.,  -20.,  365.,  625.,   53.,  496., -326.,  394.,   50.],
         [-272.,  406., 

In [58]:
att_weights = torch.where(torch.isnan(att_weights), torch.zeros_like(att_weights), att_weights)
weigths_att = F.softmax(att_weights.sum(axis=1), dim=1)

weigths_att  = weigths_att.view( *weigths_att.shape, 1 )
weights_disc = weights_disc.view( *weights_disc.shape, 1 )

docs_att.shape, weigths_att.shape, weights_disc.shape

(torch.Size([3, 5, 10]), torch.Size([3, 5, 1]), torch.Size([3, 5, 1]))

In [59]:
att_weights = torch.where(torch.isnan(att_weights), torch.zeros_like(att_weights), att_weights)
weigths_att = F.softmax(att_weights.sum(axis=1), dim=1)

weigths_att  = weigths_att.view( *weigths_att.shape, 1 )
weights_disc = weights_disc.view( *weights_disc.shape, 1 )
weigths = weigths_att + weights_disc
weigths.shape

torch.Size([3, 5, 1])

In [61]:
att_weights = torch.where(torch.isnan(att_weights), torch.zeros_like(att_weights), att_weights)
weigths_att = F.softmax(att_weights.sum(axis=1), dim=1)

weigths_att  = weigths_att.view( *weigths_att.shape, 1 )
weights_disc = weights_disc.view( *weights_disc.shape, 1 )
weigths = weigths_att + weights_disc

docs_h =  docs_att * weigths
docs_h =  docs_h.sum(axis=1)
docs_h /= bx_packed.logical_not().sum(dim=1).view(batch_size, 1)
docs_h = torch.where(torch.isnan(docs_h), torch.zeros_like(docs_h), docs_h)
(docs_h*100).round()

tensor([[-25.,   3.,   6.,  15.,  14.,   1.,   6.,  -9.,   6.,  -4.],
        [-29., -12.,  19.,  22.,  10.,   2., -10.,  -6.,  -3.,  -7.],
        [-37.,  11.,   2.,   9.,  18.,  -5.,  28., -10.,  17.,  -3.]],
       grad_fn=<RoundBackward>)

In [62]:
bx_packed.logical_not().sum(dim=1).view(batch_size, 1)

tensor([[5],
        [4],
        [3]])

In [63]:
weigths_att, weights_disc

(tensor([[[0.3931],
          [0.0657],
          [0.1015],
          [0.0928],
          [0.3469]],
 
         [[0.0680],
          [0.0607],
          [0.0747],
          [0.7513],
          [0.0454]],
 
         [[0.1625],
          [0.5375],
          [0.1332],
          [0.0834],
          [0.0834]]], grad_fn=<ViewBackward>),
 tensor([[[0.4569],
          [0.0105],
          [0.0670],
          [0.0600],
          [0.4055]],
 
         [[0.1046],
          [0.0953],
          [0.1041],
          [0.6797],
          [0.0163]],
 
         [[0.3092],
          [0.4914],
          [0.0980],
          [0.0507],
          [0.0507]]], grad_fn=<ViewBackward>))