In [1]:
!pip install -U git+https://github.com/d2l-ai/d2l-en.git@master

Collecting git+https://github.com/d2l-ai/d2l-en.git@master
  Cloning https://github.com/d2l-ai/d2l-en.git (to revision master) to /tmp/pip-req-build-g32xovoe
  Running command git clone -q https://github.com/d2l-ai/d2l-en.git /tmp/pip-req-build-g32xovoe
Building wheels for collected packages: d2l
  Building wheel for d2l (setup.py) ... [?25l[?25hdone
  Created wheel for d2l: filename=d2l-0.15.1-cp36-none-any.whl size=63507 sha256=2278423ef13dc83aba84c88058bff8bcb60a6e793cc16c0439caa7bb8db56bd0
  Stored in directory: /tmp/pip-ephem-wheel-cache-qdusrjfa/wheels/0f/41/8f/72ece70ede8a0e37eec72c03087eb4604925ba212b804f8cad
Successfully built d2l
Installing collected packages: d2l
Successfully installed d2l-0.15.1


In [2]:
from d2l import torch as d2l
import math
import torch
from torch import nn

In [3]:
#@save
def masked_softmax(X, valid_len):
    """Perform softmax by filtering out some elements.

    Softmax is taken over the last axis of `X`. Along that axis, positions at
    or beyond the valid length are replaced by a large negative value before
    the softmax, so their weights come out as zero.

    Args:
        X: 3-D tensor of scores, (`batch_size`, #queries, #kv_pairs).
        valid_len: `None` (no masking); a 1-D tensor of shape (`batch_size`,)
            holding one valid length per batch element, shared by all of its
            queries; or a 2-D tensor of shape (`batch_size`, #queries) holding
            one valid length per query.

    Returns:
        A tensor with the same shape as `X` whose last axis sums to 1.
        `X` itself is not modified.
    """
    if valid_len is None:
        return nn.functional.softmax(X, dim=-1)
    shape = X.shape
    if valid_len.dim() == 1:
        # One length per batch element: repeat it for each query row.
        valid_len = torch.repeat_interleave(valid_len, repeats=shape[1],
                                            dim=0)
    else:
        valid_len = valid_len.reshape(-1)
    # Flatten to (`batch_size` * #queries, #kv_pairs): one length per row.
    X_flat = X.reshape(-1, shape[-1])
    positions = torch.arange(shape[-1], device=X.device)
    mask = positions[None, :] < valid_len[:, None]
    # Fill masked elements with a large negative, whose exp is 0. The
    # out-of-place `masked_fill` matters: `reshape` returns a view of a
    # contiguous `X`, so an in-place write would overwrite the caller's tensor.
    X_flat = X_flat.masked_fill(~mask, -1e6)
    return nn.functional.softmax(X_flat.reshape(shape), dim=-1)

In [4]:
# A seeded local generator makes this demo reproducible without touching the
# global RNG that later cells (e.g. layer initialization) draw from.
masked_softmax(torch.rand(2, 2, 4, generator=torch.Generator().manual_seed(0)),
               torch.tensor([2, 3]))

tensor([[[0.6182, 0.3818, 0.0000, 0.0000],
         [0.4314, 0.5686, 0.0000, 0.0000]],

        [[0.3172, 0.4922, 0.1906, 0.0000],
         [0.2735, 0.4507, 0.2758, 0.0000]]])

In [5]:
# Batched matmul: (2, 1, 3) @ (2, 3, 2) -> (2, 1, 2); each entry sums three 1*1 products.
torch.ones(2, 1, 3).bmm(torch.ones(2, 3, 2))

tensor([[[3., 3.]],

        [[3., 3.]]])

In [6]:
#@save
class DotProductAttention(nn.Module):
    """Scaled dot-product attention.

    Scores are `query @ key^T / sqrt(d)`. They are normalized by
    `masked_softmax`, passed through dropout, and used to take a weighted sum
    of `value`.
    """

    def __init__(self, dropout, **kwargs):
        super().__init__(**kwargs)
        # Dropout on the attention weights (identity once `.eval()` is called).
        self.dropout = nn.Dropout(dropout)

    # `query`: (`batch_size`, #queries, `d`)
    # `key`: (`batch_size`, #kv_pairs, `d`)
    # `value`: (`batch_size`, #kv_pairs, `dim_v`)
    # `valid_len`: None, (`batch_size`,) or (`batch_size`, #queries)
    def forward(self, query, key, value, valid_len=None):
        """Return the attention output, (`batch_size`, #queries, `dim_v`)."""
        d = query.shape[-1]
        # Swap the last two axes of `key` -> (`batch_size`, `d`, #kv_pairs) so
        # that `bmm` gives scores of shape (`batch_size`, #queries, #kv_pairs).
        scores = torch.bmm(query, key.transpose(1, 2)) / math.sqrt(d)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return torch.bmm(attention_weights, value)

In [7]:
# Every key is identical, so all valid keys receive the same score and the
# output is the mean of the first `valid_len` rows of `values`.
keys = torch.ones(2, 10, 2)
values = torch.arange(40, dtype=torch.float32).reshape(1, 10, 4).repeat(2, 1, 1)
atten = DotProductAttention(dropout=0.5).eval()  # eval(): dropout disabled
atten(torch.ones(2, 1, 2), keys, values, torch.tensor([2, 6]))

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]])

In [8]:
#@save
class MLPAttention(nn.Module):
    """Additive (MLP) attention.

    Queries and keys are projected to `units` dimensions, added for every
    (query, key) pair, passed through `tanh`, and reduced to one score by a
    final linear layer. Scores are normalized by `masked_softmax`, passed
    through dropout, and used to take a weighted sum of `value`.
    """

    def __init__(self, key_size, query_size, units, dropout, **kwargs):
        super().__init__(**kwargs)
        self.W_k = nn.Linear(key_size, units, bias=False)
        self.W_q = nn.Linear(query_size, units, bias=False)
        self.v = nn.Linear(units, 1, bias=False)
        self.dropout = nn.Dropout(dropout)

    # `query`: (`batch_size`, #queries, `query_size`)
    # `key`: (`batch_size`, #kv_pairs, `key_size`)
    # `value`: (`batch_size`, #kv_pairs, `dim_v`)
    # `valid_len`: None, (`batch_size`,) or (`batch_size`, #queries)
    def forward(self, query, key, value, valid_len=None):
        """Return the attention output, (`batch_size`, #queries, `dim_v`).

        `valid_len` defaults to `None` (no masking), the same default as
        `DotProductAttention.forward`.
        """
        query, key = self.W_q(query), self.W_k(key)
        # Expand query to (`batch_size`, #queries, 1, units) and key to
        # (`batch_size`, 1, #kv_pairs, units), then add them with broadcasting
        # to get (`batch_size`, #queries, #kv_pairs, units).
        features = query.unsqueeze(2) + key.unsqueeze(1)
        features = torch.tanh(features)
        # `v` maps each feature vector to a single score; drop that size-1
        # axis -> (`batch_size`, #queries, #kv_pairs).
        scores = self.v(features).squeeze(-1)
        attention_weights = self.dropout(masked_softmax(scores, valid_len))
        return torch.bmm(attention_weights, value)

In [9]:
# Identical keys again yield identical scores, so the randomly initialized
# projections do not affect the result: it matches the dot-product demo.
atten = MLPAttention(key_size=2, query_size=2, units=8, dropout=0.1).eval()
atten(torch.ones(2, 1, 2), keys, values, torch.tensor([2, 6]))

tensor([[[ 2.0000,  3.0000,  4.0000,  5.0000]],

        [[10.0000, 11.0000, 12.0000, 13.0000]]], grad_fn=<BmmBackward0>)