In [None]:
import math
import torch
from torch import nn

### Dot Product Attention

The setting in NW regression has queries, keys, and values of dimension one. More generally (assuming $\mathbf{q}$ and $\mathbf{k}$ have the same dimension $d$):

$$\mathbf{q}, \mathbf{k} \in \mathbf{R}^d, \mathbf{v} \in \mathbf{R}^v$$

When the scaled dot product is used as the attention scoring function, the attention score (before normalization) is a scalar:

$$a(\mathbf{q}, \mathbf{k}) = \frac{\mathbf{q} \cdot \mathbf{k}}{\sqrt{d}} = \frac{\mathbf{q} \mathbf{k}^T}{\sqrt{d}}$$

To get the actual attention weight, we normalize these scores over all keys $\mathbf{k}_j$ with a softmax:

$$\alpha(\mathbf{q}, \mathbf{k}_i) = softmax(a(\mathbf{q}, \mathbf{k}_i)) = softmax(\frac{\mathbf{q} \mathbf{k}_i^T}{\sqrt{d}}) = \frac{\exp\left(\frac{\mathbf{q} \mathbf{k}_i^T}{\sqrt{d}}\right)}{\sum_{j} \exp\left(\frac{\mathbf{q} \mathbf{k}_j^T}{\sqrt{d}}\right)}$$
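As a quick numerical check of the two formulas above, the sketch below scores one query against a few keys and normalizes the scores both with an explicit $\exp(\cdot) / \sum \exp(\cdot)$ and with `torch.softmax`. The sizes `d` and `m` and the random data are arbitrary toy assumptions.

```python
import math
import torch

torch.manual_seed(0)
d, m = 4, 3                                 # toy sizes: dimension d, number of keys m
q = torch.randn(d)                          # a single query
keys = [torch.randn(d) for _ in range(m)]   # keys k_1, ..., k_m

# Unnormalized scores a(q, k_i) = q . k_i / sqrt(d)
scores = torch.stack([torch.dot(q, k) / math.sqrt(d) for k in keys])

# Softmax written out explicitly: exp(score_i) / sum_j exp(score_j)
alpha_manual = torch.exp(scores) / torch.exp(scores).sum()

# The same normalization with PyTorch's built-in softmax
alpha = torch.softmax(scores, dim=0)

print(alpha)  # non-negative weights that sum to 1
```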


For a database of $m$ key-value pairs, we write each key and each value as a row vector and stack them into matrices:

$$K = \begin{bmatrix} \mathbf{k}_1 \\ \mathbf{k}_2 \\ \vdots \\ \mathbf{k}_m \end{bmatrix} \in \mathbf{R}^{m \times d}, V = \begin{bmatrix} \mathbf{v}_1 \\ \mathbf{v}_2 \\ \vdots \\ \mathbf{v}_m \end{bmatrix} \in \mathbf{R}^{m \times v}$$

Now, to answer a single query (represented by the row vector $\mathbf{q}$), we first compute its normalized attention weights on all keys:

$$\mathbf{w} = \begin{bmatrix}
w_1 & w_2 & \dots & w_m
\end{bmatrix} = softmax( \begin{bmatrix}
\frac{\mathbf{q} \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q} \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q} \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix}) = softmax(\frac{\mathbf{q} K^T}{\sqrt{d}})$$

To get the output, we take the weighted sum of all values $\mathbf{v}_j$, which is just:

$$\hat{\mathbf{v}} = w_1 \mathbf{v}_1 + w_2 \mathbf{v}_2 + \dots + w_m \mathbf{v}_m = \mathbf{w} V = softmax(\frac{\mathbf{q} K^T}{\sqrt{d}}) V$$
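The matrix form for a single query can be sketched as below, with $\mathbf{q}$ as a $1 \times d$ row and $K$, $V$ stacked row-wise; it checks that $\mathbf{w} V$ matches the explicit weighted sum $w_1 \mathbf{v}_1 + \dots + w_m \mathbf{v}_m$. The sizes and random data are toy assumptions.

```python
import math
import torch

torch.manual_seed(0)
d, v, m = 4, 2, 5                 # toy sizes: key dim d, value dim v, number of pairs m
q = torch.randn(1, d)             # one query as a 1 x d row vector
K = torch.randn(m, d)             # keys stacked as rows: m x d
V = torch.randn(m, v)             # values stacked as rows: m x v

# w = softmax(q K^T / sqrt(d)): a 1 x m row with one weight per key
w = torch.softmax(q @ K.T / math.sqrt(d), dim=-1)

# v_hat = w V: a 1 x v row
v_hat = w @ V

# The same output as an explicit weighted sum w_1 v_1 + ... + w_m v_m
v_hat_loop = sum(w[0, j] * V[j] for j in range(m))
```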

To perform $n$ queries at the same time, we stack them into a matrix:

$$Q = \begin{bmatrix} \mathbf{q}_1 \\ \mathbf{q}_2 \\ \vdots \\ \mathbf{q}_n \end{bmatrix} \in \mathbf{R}^{n \times d}$$

For each $\mathbf{q}_i$, we compute its weight vector exactly as above:

$$\mathbf{w}_i = \begin{bmatrix}
w_{i1} & w_{i2} & \dots & w_{im}
\end{bmatrix} = softmax( \begin{bmatrix}
\frac{\mathbf{q}_i \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q}_i \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q}_i \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix}) = softmax(\frac{\mathbf{q}_i K^T}{\sqrt{d}})$$


Then the weight matrix for these queries is:

$$W = \begin{bmatrix}
\mathbf{w}_1 \\
\mathbf{w}_2 \\
\vdots \\
\mathbf{w}_n
\end{bmatrix} = Softmax(\frac{QK^T}{\sqrt{d}}) \in \mathbf{R}^{n \times m}$$

where `Softmax` applies softmax to each row in the matrix.
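In PyTorch, the row-wise `Softmax` corresponds to `softmax(..., dim=-1)` on the $n \times m$ score matrix. The sketch below, using arbitrary toy sizes and random data, checks that row $i$ of $W$ equals the single-query weight vector $\mathbf{w}_i$.

```python
import math
import torch

torch.manual_seed(0)
n, m, d = 3, 5, 4                 # toy sizes: n queries, m keys, dimension d
Q = torch.randn(n, d)             # queries stacked as rows: n x d
K = torch.randn(m, d)             # keys stacked as rows: m x d

# Row-wise Softmax: dim=-1 normalizes each row (one row per query) separately
W = torch.softmax(Q @ K.T / math.sqrt(d), dim=-1)   # n x m

# Row i of W equals w_i = softmax(q_i K^T / sqrt(d))
for i in range(n):
    w_i = torch.softmax(Q[i] @ K.T / math.sqrt(d), dim=-1)
    assert torch.allclose(W[i], w_i, atol=1e-6)

print(W.sum(dim=-1))  # each row sums to 1
```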

Now we have the complete prediction for the queries:

$$\hat{V} =
\begin{bmatrix}
\hat{\mathbf{v}}_1 \\
\hat{\mathbf{v}}_2 \\
\vdots \\
\hat{\mathbf{v}}_n
\end{bmatrix}
=
\begin{bmatrix}
\mathbf{w}_1 V \\
\mathbf{w}_2 V \\
\vdots \\
\mathbf{w}_n V
\end{bmatrix}
=
\begin{bmatrix}
w_{11} \mathbf{v}_1 + w_{12} \mathbf{v}_2 + \dots + w_{1m} \mathbf{v}_m \\
w_{21} \mathbf{v}_1 + w_{22} \mathbf{v}_2 + \dots + w_{2m} \mathbf{v}_m \\
\vdots \\
w_{n1} \mathbf{v}_1 + w_{n2} \mathbf{v}_2 + \dots + w_{nm} \mathbf{v}_m
\end{bmatrix}
= WV = Softmax(\frac{QK^T}{\sqrt{d}})V \in \mathbf{R}^{n \times v}$$
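Putting it together, $\hat{V} = Softmax(QK^T/\sqrt{d})V$ is two matrix products and a row-wise softmax. The helper `dot_product_attention` below is a hypothetical name used only for this sketch; it also accepts optional leading batch dimensions, and the check compares it against answering each query separately. Sizes and data are toy assumptions.

```python
import math
import torch

def dot_product_attention(Q, K, V):
    """Hypothetical helper: Softmax(Q K^T / sqrt(d)) V.

    Q: (..., n, d), K: (..., m, d), V: (..., m, v); leading batch dims are optional.
    Returns a tensor of shape (..., n, v).
    """
    d = Q.shape[-1]
    W = torch.softmax(Q @ K.transpose(-2, -1) / math.sqrt(d), dim=-1)  # (..., n, m)
    return W @ V

torch.manual_seed(0)
n, m, d, v = 3, 5, 4, 2
Q, K, V = torch.randn(n, d), torch.randn(m, d), torch.randn(m, v)

V_hat = dot_product_attention(Q, K, V)   # n x v, all queries at once

# Row i of V_hat is the single-query result w_i V
rows = [torch.softmax(Q[i:i + 1] @ K.T / math.sqrt(d), dim=-1) @ V for i in range(n)]
V_hat_loop = torch.cat(rows, dim=0)
```

Recent PyTorch releases (2.0 and later) also provide `torch.nn.functional.scaled_dot_product_attention`, which by default applies the same $1/\sqrt{d}$ scaling.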

### Masked Softmax

In some tasks, different queries operate on different databases of different sizes, so each query $\mathbf{q}_i$ has its own number of keys $m_i$. For example, with $n = 3$ queries and $m_1 = 2$, $m_2 = 4$, $m_3 = 3$:

$$\hat{V} =
\begin{bmatrix}
w_{11} \mathbf{v}_1 + w_{12} \mathbf{v}_2\\
w_{21} \mathbf{v}_1 + w_{22} \mathbf{v}_2 + w_{23} \mathbf{v}_3 + w_{24} \mathbf{v}_4 \\
w_{31} \mathbf{v}_1 + w_{32} \mathbf{v}_2 + w_{33} \mathbf{v}_3
\end{bmatrix}$$

In this case, the shorter rows are padded with placeholders (marked with a tilde) up to the largest size, so that all queries can be computed with the same matrix operations:

$$\hat{V} =
\begin{bmatrix}
w_{11} \mathbf{v}_1 + w_{12} \mathbf{v}_2 + \tilde{w}_{13} \tilde{\mathbf{v}}_3 + \tilde{w}_{14} \tilde{\mathbf{v}}_4 \\
w_{21} \mathbf{v}_1 + w_{22} \mathbf{v}_2 + w_{23} \mathbf{v}_3 + w_{24} \mathbf{v}_4 \\
w_{31} \mathbf{v}_1 + w_{32} \mathbf{v}_2 + w_{33} \mathbf{v}_3 + \tilde{w}_{34} \tilde{\mathbf{v}}_4
\end{bmatrix}$$

To ensure the placeholders contribute nothing to the output or the gradient, every padded weight must be zero, for example:

$$\tilde{w}_{13} = \alpha(\mathbf{q}_1, \tilde{\mathbf{k}}_3) = softmax(a(\mathbf{q}_1, \tilde{\mathbf{k}}_3)) = 0$$

This can be done by setting the score $a(\mathbf{q}_1, \tilde{\mathbf{k}}_3)$ to a very large negative number, such as $-10^6$, because $\exp(-10^6)$ underflows to $0$.
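A small sketch of this trick for one query with four key slots, of which only the first two are real: setting the padded scores to $-10^6$ drives their weights to exactly zero and leaves the valid weights equal to a softmax over the real keys alone. For contrast, setting the padded scores to $0$ still gives them noticeable weight. The score values are made-up toy numbers.

```python
import torch

scores = torch.tensor([0.5, 1.2, 0.3, -0.7])  # toy scores for one query over 4 key slots
valid_len = 2                                 # slots 3 and 4 are padding

masked = scores.clone()
masked[valid_len:] = -1e6                     # very large negative score on padded slots
w = torch.softmax(masked, dim=-1)
print(w)  # padded slots get weight 0; the valid weights sum to 1

# For contrast: a padded score of 0 is not suppressed by the softmax
zeroed = scores.clone()
zeroed[valid_len:] = 0.0
w_zero = torch.softmax(zeroed, dim=-1)
print(w_zero)  # padded slots still receive noticeable weight
```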

The `masked_softmax` function below takes a second argument, `valid_lens`, giving the number of valid keys for each query, and replaces the masked scores with $-10^6$ before computing the softmax:

In [None]:
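def masked_softmax(X, valid_lens):
    """Perform softmax operation by masking elements on the last axis."""
    # X: 3D tensor of shape (batch_size, num_queries, num_keys)
    # valid_lens: None, a 1D tensor of shape (batch_size,) (one length per
    # batch element, shared by all its queries), or a 2D tensor of shape
    # (batch_size, num_queries) (one length per query)
    def _sequence_mask(X, valid_len, value=0):
        # X: 2D tensor (rows, maxlen); valid_len: 1D tensor (rows,)
        maxlen = X.size(1)
        # mask[i, j] is True iff column j lies within the valid length of row i
        mask = torch.arange(maxlen, dtype=torch.float32,
                            device=X.device)[None, :] < valid_len[:, None]
        # masked_fill returns a new tensor instead of writing into X, so the
        # caller's tensor (which X may be a view of) is left unchanged
        return X.masked_fill(~mask, value)

    if valid_lens is None:
        return nn.functional.softmax(X, dim=-1)
    else:
        shape = X.shape
        if valid_lens.dim() == 1:
            # Repeat each batch element's length once per query
            valid_lens = torch.repeat_interleave(valid_lens, shape[1])
        else:
            valid_lens = valid_lens.reshape(-1)
        # On the last axis, replace masked elements with a very large negative
        # value, whose exponentiation outputs 0
        X = _sequence_mask(X.reshape(-1, shape[-1]), valid_lens.to(X.device),
                           value=-1e6)
        return nn.functional.softmax(X.reshape(shape), dim=-1)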
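To connect masking back to the three-query example with 2, 4 and 3 valid keys, the sketch below gives each query its own padded database and runs a full masked attention step. `masked_attention` is a hypothetical compact variant, not the notebook's `masked_softmax`: it builds the padding mask by broadcasting and applies it with `masked_fill`. The padding slots are deliberately filled with junk to show that, after masking, they do not affect the output. Sizes and data are toy assumptions.

```python
import math
import torch

def masked_attention(Q, K, V, valid_lens):
    """Hypothetical sketch of masked dot-product attention.

    Q: (batch, n, d), K: (batch, m, d), V: (batch, m, v),
    valid_lens: (batch,) number of real (non-padding) keys per database.
    """
    d = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d)        # (batch, n, m)
    positions = torch.arange(K.shape[1], device=Q.device)  # 0, 1, ..., m-1
    # mask[b, :, j] is True for padding slots j >= valid_lens[b]
    mask = positions[None, None, :] >= valid_lens[:, None, None]
    scores = scores.masked_fill(mask, -1e6)                # padded scores -> very negative
    W = torch.softmax(scores, dim=-1)                      # padded weights become 0
    return W @ V, W

torch.manual_seed(0)
d, v, max_m = 4, 2, 4
valid_lens = torch.tensor([2, 4, 3])   # each query has its own database size
Q = torch.randn(3, 1, d)               # one query per database
K = torch.randn(3, max_m, d)           # databases padded to 4 keys
V = torch.randn(3, max_m, v)
for b, m_b in enumerate(valid_lens.tolist()):
    K[b, m_b:] = 100.0                 # arbitrary junk in the padding slots
    V[b, m_b:] = 100.0

out, W = masked_attention(Q, K, V, valid_lens)

# Each output should match plain attention over only that database's real keys
for b, m_b in enumerate(valid_lens.tolist()):
    w_b = torch.softmax(Q[b] @ K[b, :m_b].T / math.sqrt(d), dim=-1)
    assert torch.allclose(out[b], w_b @ V[b, :m_b], atol=1e-5)
```

Using `masked_fill` keeps the input scores untouched, and building the mask by broadcasting avoids reshaping; the notebook's version instead flattens to 2D so that one helper handles both 1D and 2D `valid_lens`.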