In [61]:
from d2l import torch as d2l
import math
import torch
from torch import nn
x = torch.arange(16, dtype = torch.float32).reshape(2,2,4)

In [57]:
# Some text sequences are padded with special tokens that do not carry meaning. To get an attention pooling over only 
# meaningful tokens as values, we can specify a valid sequence length (in number of tokens) to filter out those beyond this 
# specifiedrange when computing softmax.

def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # `X`: 3D tensor, `valid_lens`: 1D or 2D tensor
    if valid_lens is None:
        # Sum along every slice of width will be 1
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        print(shape[1])
        if valid_lens.dim() == 1:
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
            print(valid_lens)
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = d2l.sequence_mask(X.reshape(-1, shape[-1]), valid_lens,value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)

In [62]:
print(x)
masked_softmax(torch.rand(2, 2, 4), torch.tensor([2, 3]))

tensor([[[ 0.,  1.,  2.,  3.],
         [ 4.,  5.,  6.,  7.]],

        [[ 8.,  9., 10., 11.],
         [12., 13., 14., 15.]]])
2
tensor([2, 2, 3, 3])


tensor([[[0.3755, 0.6245, 0.0000, 0.0000],
         [0.6540, 0.3460, 0.0000, 0.0000]],

        [[0.3787, 0.3050, 0.3163, 0.0000],
         [0.2806, 0.4673, 0.2522, 0.0000]]])

In [86]:
# Additive Attention
# When queries and keys are vectors of different lengths, we can use additive attention as the scoring function. The query and 
# the key are concatenated and fed into an MLP with a single hidden layer whose number of hidden units is h, a hyperparameter. 
# By using tanh as the activation function and disabling bias terms.

class AdditiveAttention(nn.Module):
    """Additive attention."""
    def __init__(self, key_size, query_size, num_hiddens, dropout, **kwargs):
        super(AdditiveAttention, self).__init__(**kwargs)
        self.W_k = nn.Linear(key_size, num_hiddens, bias=False)
        self.W_q = nn.Linear(query_size, num_hiddens, bias=False)
        self.w_v = nn.Linear(num_hiddens, 1, bias=False)
        self.dropout = nn.Dropout(dropout)
    def forward(self, queries, keys, values, valid_lens):
        queries, keys = self.W_q(queries), self.W_k(keys)
#         print(queries.shape)
#         print(keys.shape)
#         print('Queries unsqueeze:', queries.unsqueeze(2).shape)
#         print('Keys unsqueeze:', keys.unsqueeze(1).shape)
        features = queries.unsqueeze(2) + keys.unsqueeze(1)
        print(features.shape)
        features = torch.tanh(features)
        # There is only one output of `self.w_v`, so we remove the last
        # one-dimensional entry from the shape. Shape of `scores`:
        # (`batch_size`, no. of queries, no. of key-value pairs)
        scores = self.w_v(features).squeeze(-1)
        print(self.w_v(features).shape)
        print(scores.shape)
        self.attention_weights = masked_softmax(scores, valid_lens)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        return torch.bmm(self.dropout(self.attention_weights), values)
    
queries, keys = torch.normal(0, 1, (2, 1, 20)), torch.ones((2, 10, 2))
# print(queries)
# print(keys)
# The two value matrices in the `values` minibatch are identical
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
# print(values)
valid_lens = torch.tensor([2, 6])
attention = AdditiveAttention(key_size=2, query_size=20, num_hiddens=8,dropout=0.1)
# print(attention.eval())
attention(queries, keys, values, valid_lens)

torch.Size([2, 1, 10, 8])
torch.Size([2, 1, 10, 1])
torch.Size([2, 1, 10])
1
tensor([2, 6])


tensor([[[ 2.2222,  3.3333,  4.4444,  5.5556]],

        [[11.1111, 12.2222, 13.3333, 14.4444]]], grad_fn=<BmmBackward0>)

In [90]:
# Scaled Dot Product Attention
#  When queries and keys are vectors of different lengths, we can use the additive attention scoring function. When they are 
# the same, the scaled dot-product attention scoring function is more computationally efficient.

class DotProductAttention(nn.Module):
    """Scaled dot product attention."""
    def __init__(self, dropout, **kwargs):
        super(DotProductAttention, self).__init__(**kwargs)
        self.dropout = nn.Dropout(dropout)
        # Shape of `queries`: (`batch_size`, no. of queries, `d`)
        # Shape of `keys`: (`batch_size`, no. of key-value pairs, `d`)
        # Shape of `values`: (`batch_size`, no. of key-value pairs, value
        # dimension)
        # Shape of `valid_lens`: (`batch_size`,) or (`batch_size`, no. of queries)
    def forward(self, queries, keys, values, valid_lens=None):
        d = queries.shape[-1]
        # Set `transpose_b=True` to swap the last two dimensions of `keys`
        scores = torch.bmm(queries, keys.transpose(1,2)) / math.sqrt(d)
        print('Scores: ',scores.shape)
        self.attention_weights = masked_softmax(scores, valid_lens)
        return torch.bmm(self.dropout(self.attention_weights), values)
queries = torch.normal(0, 1, (2, 1, 2))
attention = DotProductAttention(dropout=0.5)
attention.eval()
print(queries.shape)
print(keys.shape)
attention(queries, keys, values, valid_lens)
print

torch.Size([2, 1, 2])
torch.Size([2, 10, 2])
Scores:  torch.Size([2, 1, 10])
1
tensor([2, 6])


<function print>