In [None]:
import torch
from torch import nn

In [None]:
def masked_softmax(X, valid_lens):  # @save
    """Perform softmax operation by masking elements on the last axis.

    Args:
        X: 3D tensor of scores; the softmax is taken over its last axis.
        valid_lens: ``None``, a 1D tensor with one length per batch entry
            (axis 0 of ``X``), or a 2D tensor with one length per row
            (axes 0 and 1 of ``X``). Positions at index ``>= length`` on the
            last axis are masked out.

    Returns:
        A new tensor with the same shape as ``X`` whose last axis sums to 1
        and whose masked positions are (numerically) 0. ``X`` itself is left
        unmodified.
    """

    def _sequence_mask(rows, valid_len, value=0):
        """Return a copy of 2D ``rows`` with entries past ``valid_len`` set to ``value``."""
        maxlen = rows.size(1)  # sequence max length
        # keep[i, j] is True while column j is inside row i's valid length.
        keep = (
            torch.arange(maxlen, dtype=torch.float32, device=rows.device)[None, :]
            < valid_len[:, None]
        )
        # Out-of-place fill: `rows` may be a view of the caller's tensor, so an
        # in-place write would leak into the caller's data (and fails on leaf
        # tensors that require grad).
        return rows.masked_fill(~keep, value)

    if valid_lens is None:
        # No masking requested: plain softmax over the last axis, consistent
        # with the masked branch below.
        return nn.functional.softmax(X, dim=-1)

    shape = X.shape
    if valid_lens.dim() == 1:
        # One length per batch entry: repeat it for every row of that entry.
        valid_lens = torch.repeat_interleave(valid_lens, shape[1])
    else:
        # One length per row: flatten to line up with the flattened rows of X.
        valid_lens = valid_lens.reshape(-1)
    # On the last axis, replace masked elements with a very large negative
    # value, whose exponentiation outputs 0
    masked = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens, value=-1e6)
    return nn.functional.softmax(masked.reshape(shape), dim=-1)


# Demo: random scores for 2 batch entries x 2 rows x 4 positions, keeping the
# first 2 positions of batch entry 0 and the first 3 of batch entry 1.
scores = masked_softmax(
    X=torch.rand(2, 2, 4),
    valid_lens=torch.tensor([2, 3]),
)
print(scores, scores.shape)

In [None]:
import sys, os

sys.path.append(os.path.expanduser("~"))
from d2l.pytorch.d2l.torch import show_heatmaps, check_shape


class DotProductAttention(nn.Module):  # @save
    """Scaled dot product attention.

    Scores every query against every key with a batched dot product divided
    by the square root of the query feature size, normalises the scores with
    ``masked_softmax`` and returns the weighted combination of the values.
    The most recent weights are kept on ``self.attention_weights``.
    """

    def __init__(self, dropout):
        super().__init__()
        # Dropout is applied to the attention weights in forward().
        self.dropout = nn.Dropout(dropout)

    def forward(self, queries, keys, values, valid_lens=None):
        """Compute the attention output for a batch.

        Args:
            queries: (batch_size, no. of queries, d)
            keys: (batch_size, no. of key-value pairs, d)
            values: (batch_size, no. of key-value pairs, value dimension)
            valid_lens: (batch_size,) or (batch_size, no. of queries);
                forwarded unchanged to ``masked_softmax``.

        Returns:
            The batched product of the (dropout-regularised) attention
            weights with ``values``.
        """
        feature_dim = queries.shape[-1]
        scale = feature_dim**0.5
        # Put the feature axis of the keys in the middle so bmm contracts it.
        keys_t = keys.transpose(1, 2)
        scores = torch.bmm(queries, keys_t) / scale
        weights = masked_softmax(scores, valid_lens)
        # Stored before dropout so callers can inspect the raw weights.
        self.attention_weights = weights
        dropped = self.dropout(weights)
        return torch.bmm(dropped, values)


# Toy batch of 2: one 2-dim query each, attending over 10 key-value pairs
# with 4-dim values.
queries = torch.normal(mean=0, std=1, size=(2, 1, 2))
keys = torch.normal(mean=0, std=1, size=(2, 10, 2))
values = torch.normal(mean=0, std=1, size=(2, 10, 4))
# Only the first 2 keys are visible to batch entry 0, the first 6 to entry 1.
valid_lens = torch.tensor([2, 6])

# eval() returns the module itself and switches dropout off for the demo.
attention = DotProductAttention(0.5).eval()

check_shape(
    attention(queries, keys, values, valid_lens=valid_lens),
    (2, 1, 4),
)
# Reshape the stored (2, 1, 10) weights to the 4D layout passed to show_heatmaps.
show_heatmaps(
    attention.attention_weights.reshape(1, 1, 2, 10),
    xlabel="Keys",
    ylabel="Queries",
)