In [None]:
import math
import torch
from torch import nn

### Dot Product Attention

In Nadaraya-Watson (NW) regression, the queries, keys and values are all scalars (dimension one).

More generally, the data is multi-dimensional and is represented by row vectors (we assume $\mathbf{q}$ and $\mathbf{k}_i$ have the same dimension):

$$\mathbf{q}, \mathbf{k}_i \in \mathbf{R}^{1\times d}, \mathbf{v}_j \in \mathbf{R}^{1\times v}$$

When the scaled dot product is used as the attention scoring function, the scalar attention score (before softmax normalization) is:

$$a(\mathbf{q}, \mathbf{k}_i) = \frac{\mathbf{q} \cdot \mathbf{k}_i}{\sqrt{d}} = \frac{\mathbf{q} \mathbf{k}^T_i}{\sqrt{d}}$$


For a database with $m$ keys (and values), we write each key (and value) as a row vector and stack the rows into a matrix:

$$K = \begin{bmatrix} \mathbf{k}_1 \\ \mathbf{k}_2 \\ \dots \\ \mathbf{k}_m \end{bmatrix} \in \mathbf{R}^{m \times d}, V = \begin{bmatrix} \mathbf{v}_1 \\ \mathbf{v}_2 \\ \dots \\ \mathbf{v}_m \end{bmatrix} \in \mathbf{R}^{m \times v}$$

Now, to perform a single query (represented by row vector $\mathbf{q}$), we first calculate the attention scores on all keys:

$$\mathbf{a} = \begin{bmatrix}
a_1 & a_2 & \dots & a_m
\end{bmatrix} = \begin{bmatrix}
\frac{\mathbf{q} \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q} \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q} \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix} = \frac{\mathbf{q} K^T}{\sqrt{d}}$$

Then we apply softmax to get the attention weights:

$$\mathbf{w} = \begin{bmatrix}
w_1 & w_2 & \dots & w_m
\end{bmatrix} = softmax( \begin{bmatrix}
\frac{\mathbf{q} \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q} \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q} \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix}) = softmax(\frac{\mathbf{q} K^T}{\sqrt{d}})$$

To get the output, we take the weighted sum of all values $\mathbf{v}_j$, which is just:

$$\hat{\mathbf{v}} = w_1 \mathbf{v}_1 + w_2 \mathbf{v}_2 + \dots + w_m \mathbf{v}_m = \mathbf{w} V = softmax(\frac{\mathbf{q} K^T}{\sqrt{d}}) V$$
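The single-query computation above can be sketched in PyTorch as follows. The sizes `m`, `d`, `v` and the random data are illustrative assumptions, not values from the text.

```python
import math
import torch

torch.manual_seed(0)
m, d, v = 5, 4, 3             # illustrative sizes (assumed)

q = torch.randn(1, d)         # one query as a row vector
K = torch.randn(m, d)         # m keys stacked as rows
V = torch.randn(m, v)         # m values stacked as rows

a = q @ K.T / math.sqrt(d)    # attention scores, shape (1, m)
w = torch.softmax(a, dim=-1)  # attention weights, shape (1, m), sum to 1
v_hat = w @ V                 # weighted sum of values, shape (1, v)

# The product w @ V equals the explicit weighted sum of the rows of V.
explicit = sum(w[0, j] * V[j] for j in range(m))
```

Here `explicit` and `v_hat[0]` should agree up to floating-point error, which is the identity $\hat{\mathbf{v}} = \mathbf{w} V$ written out term by term.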

To perform $n$ queries at the same time, we stack them into a matrix of queries:

$$Q = \begin{bmatrix} \mathbf{q}_1 \\ \mathbf{q}_2 \\ \dots \\ \mathbf{q}_n \end{bmatrix} \in \mathbf{R}^{n \times d}$$

For each $\mathbf{q}_i$, the attention scores and weights are computed as above:

$$\mathbf{a}_i = \begin{bmatrix}
a_{i1} & a_{i2} & \dots & a_{im}
\end{bmatrix} = \begin{bmatrix}
\frac{\mathbf{q}_i \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q}_i \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q}_i \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix} = \frac{\mathbf{q}_i K^T}{\sqrt{d}}$$

$$\mathbf{w}_i = \begin{bmatrix}
w_{i1} & w_{i2} & \dots & w_{im}
\end{bmatrix} = softmax( \begin{bmatrix}
\frac{\mathbf{q}_i \mathbf{k}_1^T}{\sqrt{d}} & \frac{\mathbf{q}_i \mathbf{k}_2^T}{\sqrt{d}} & \dots & \frac{\mathbf{q}_i \mathbf{k}_m^T}{\sqrt{d}}
\end{bmatrix}) = softmax(\frac{\mathbf{q}_i K^T}{\sqrt{d}})$$


Stacking them up, the score and weight matrices for these queries are calculated as:

$$A = \begin{bmatrix}
\mathbf{a}_1 \\
\mathbf{a}_2 \\
\dots \\
\mathbf{a}_n
\end{bmatrix} = \frac{QK^T}{\sqrt{d}} \in \mathbf{R}^{n \times m}$$

$$W = \begin{bmatrix}
\mathbf{w}_1 \\
\mathbf{w}_2 \\
\dots \\
\mathbf{w}_n
\end{bmatrix} = Softmax(\frac{QK^T}{\sqrt{d}}) \in \mathbf{R}^{n \times m}$$

where `Softmax` applies softmax to each row in the matrix.

Now we have the complete prediction for the queries:

$$\hat{V} =
\begin{bmatrix}
\hat{\mathbf{v}}_1 \\
\hat{\mathbf{v}}_2 \\
\dots \\
\hat{\mathbf{v}}_n
\end{bmatrix}
=
\begin{bmatrix}
\mathbf{w}_1 V \\
\mathbf{w}_2 V \\
\dots \\
\mathbf{w}_n V
\end{bmatrix}
=
\begin{bmatrix}
w_{11} \mathbf{v}_1 + w_{12} \mathbf{v}_2 + \dots + w_{1m} \mathbf{v}_m \\
w_{21} \mathbf{v}_1 + w_{22} \mathbf{v}_2 + \dots + w_{2m} \mathbf{v}_m \\
\dots \\
w_{n1} \mathbf{v}_1 + w_{n2} \mathbf{v}_2 + \dots + w_{nm} \mathbf{v}_m \\
\end{bmatrix}
= WV = Softmax(\frac{QK^T}{\sqrt{d}})V \in \mathbf{R}^{n \times v}$$

For convenience, we denote the attention pooling operation by $f$ from now on:

$$f(Q, K, V) = Softmax(\frac{QK^T}{\sqrt{d}}) V$$
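A minimal sketch of $f$ in PyTorch, assuming unbatched 2D inputs. The function name `attention_pooling` and the sizes are hypothetical choices for illustration.

```python
import math
import torch

def attention_pooling(Q, K, V):
    """f(Q, K, V) = Softmax(Q K^T / sqrt(d)) V, with softmax applied to each row."""
    d = Q.shape[-1]
    A = Q @ K.transpose(-2, -1) / math.sqrt(d)  # scores, shape (n, m)
    W = torch.softmax(A, dim=-1)                # weights, each row sums to 1
    return W @ V                                # predictions, shape (n, v)

torch.manual_seed(0)
n, m, d, v = 2, 5, 4, 3                         # illustrative sizes (assumed)
Q, K, V = torch.randn(n, d), torch.randn(m, d), torch.randn(m, v)
V_hat = attention_pooling(Q, K, V)

# Row i of the result equals attention pooling on query q_i alone.
row0 = attention_pooling(Q[0:1], K, V)
```

Because the softmax is taken per row, stacking queries does not change any individual result: `V_hat[0:1]` should match `row0`.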

### Batch Matrix Multiplication

When the task is to perform a batch (of size $b$) of attention pooling operations, each with $n$ queries, we carry out $b$ matrix multiplications at once with a batch matrix multiplication (BMM):

$$BMM(\mathbf{Q}_{\text{batch}} \in \mathbf{R}^{b \times n \times d}, \mathbf{K}^T_{\text{batch}} \in \mathbf{R}^{b \times d \times m}) \in \mathbf{R}^{b \times n \times m}$$

GPUs come in handy as they can compute multiple matmuls at the same time.
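As a small illustration, `torch.bmm` performs this batched product directly. The tensor names and sizes below are assumptions chosen for the example.

```python
import torch

torch.manual_seed(0)
b, n, m, d = 2, 3, 5, 4                  # illustrative sizes (assumed)
Q_batch = torch.randn(b, n, d)
K_batch = torch.randn(b, m, d)

# torch.bmm multiplies the i-th (n, d) matrix by the i-th (d, m) matrix.
scores = torch.bmm(Q_batch, K_batch.transpose(1, 2))  # shape (b, n, m)

# Same result as looping over the batch and stacking the products.
looped = torch.stack([Q_batch[i] @ K_batch[i].T for i in range(b)])
```

The loop version is only there to show what the batched call computes; in practice the single `torch.bmm` call is the one to use.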

### Masked Softmax

In many tasks, different queries operate on databases of different sizes, since sentences can have different lengths.

This means the database length $m$ can differ for each of the $n$ queries, so the rows of attention scores have different lengths, for example:

$$A =
\begin{bmatrix}
a_{11} & a_{12} \\
a_{21} & a_{22} & a_{23} \\
a_{31} 
\end{bmatrix}$$

Even for the same database, you might want different queries to access different entries.

For example, in the Transformer decoder, the database is the target sentence and each query is the current token.

Each token should only see the tokens generated so far, so each query may only access the database entries up to and including its own position:

$$A =
\begin{bmatrix}
a_{11} \\
a_{21} & a_{22} \\
a_{31} & a_{32} & a_{33}
\end{bmatrix}$$
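The lower-triangular access pattern above can be sketched with `torch.tril`. This is only an illustration of which entries each query may see; the variable names are assumptions.

```python
import torch

n = 3
# allowed[i, j] is True when query i may access database entry j (j <= i).
allowed = torch.tril(torch.ones(n, n, dtype=torch.bool))

# Number of accessible entries per query: 1, 2, ..., n.
valid_lens = allowed.sum(dim=1)
```

Each row of `allowed` has one more accessible entry than the row before it, matching the triangular shape of $A$.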

In either case, the rows are padded with meaningless placeholders so that the computation operates on a regular matrix:

$$A =
\begin{bmatrix}
a_{11} & a_{12} & a_{13} \\
a_{21} & a_{22} & a_{23} \\
a_{31} & a_{32} & a_{33}
\end{bmatrix}$$

To ensure the placeholders contribute to neither the output nor the gradient, the weights at padded positions must be zero after the softmax:

$$w_{13}= \alpha(\mathbf{q}_1, \mathbf{k}_3) = softmax(a(\mathbf{q}_1, \mathbf{k}_3)) = 0$$

which can be done by setting $a_{13}$ to $-\infty$ (`-inf`), since $e^{-\infty} = 0$.
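A quick check of this behaviour in PyTorch, using made-up scores for illustration:

```python
import torch

# The third score plays the role of a padded position.
scores = torch.tensor([1.0, 2.0, float("-inf")])
weights = torch.softmax(scores, dim=-1)

# Softmax over only the two real scores, for comparison.
unpadded = torch.softmax(torch.tensor([1.0, 2.0]), dim=-1)
```

The padded position receives weight zero, and the remaining weights match the softmax over the real scores alone.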

In other words, we want to mask the attention score matrix $A$ so that it looks like:

$$A =
\begin{bmatrix}
a_{11} & a_{12} & -\infty \\
a_{21} & a_{22} & a_{23} \\
a_{31} & -\infty & -\infty
\end{bmatrix}$$

In practice, masked attention functions take an argument, here called `key_padding_mask`, that specifies the valid length for each query, i.e. how many leading keys that query may attend to.

To achieve the above masking, `key_padding_mask` would be $\begin{bmatrix} 2 & 3 & 1 \end{bmatrix}$.

After original attention scores

$$A =
\begin{bmatrix}
a_{11} & a_{12} & a_{13} \\
a_{21} & a_{22} & a_{23} \\
a_{31} & a_{32} & a_{33}
\end{bmatrix}$$

are calculated, we create a masking matrix according to `key_padding_mask`

$$M =
\begin{bmatrix}
0 & 0 & -\infty \\
0 & 0 & 0 \\
0 & -\infty & -\infty
\end{bmatrix}$$

and add it to the original attention score matrix, resulting in the masked attention score matrix $A$

$$A =
\begin{bmatrix}
a_{11} & a_{12} & -\infty \\
a_{21} & a_{22} & a_{23} \\
a_{31} & -\infty & -\infty
\end{bmatrix}$$

After this, taking the softmax of each row produces the desired attention weights

$$W =
\begin{bmatrix}
w_{11} & w_{12} & 0 \\
w_{21} & w_{22} & w_{23} \\
w_{31} & 0 & 0
\end{bmatrix}$$
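The masking steps above can be sketched as a small helper. The function `masked_softmax` is hypothetical, and following the text it treats `key_padding_mask` as a tensor of valid lengths. Note that the argument of the same name in `torch.nn.MultiheadAttention` is a mask tensor rather than a list of lengths, so this sketch is a standalone illustration, not that API.

```python
import torch

def masked_softmax(A, key_padding_mask):
    """Row-wise softmax over A of shape (n, m).

    Row i keeps only its first key_padding_mask[i] entries. Treating
    key_padding_mask as valid lengths is an assumption taken from the text.
    """
    m = A.shape[-1]
    positions = torch.arange(m)                              # shape (m,)
    valid = positions[None, :] < key_padding_mask[:, None]   # shape (n, m), bool
    # Masking matrix M: 0 at valid positions, -inf at padded ones.
    M = torch.zeros_like(A).masked_fill(~valid, float("-inf"))
    return torch.softmax(A + M, dim=-1)

# Made-up scores; the valid lengths [2, 3, 1] are the ones used in the text.
A = torch.tensor([[1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0],
                  [1.0, 2.0, 3.0]])
W = masked_softmax(A, torch.tensor([2, 3, 1]))
```

With these inputs, `W` has zeros exactly where $M$ has $-\infty$, and every row still sums to one. A valid length of zero would make an entire row $-\infty$ and yield NaN, so this sketch assumes every query has at least one valid key.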

For batches of queries, `key_padding_mask` takes a 2D form `[[...], [...], ...]`, with one list of valid lengths per batch element.
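A sketch of the batched case, assuming scores of shape $(b, n, m)$ and a `key_padding_mask` of shape $(b, n)$ holding valid lengths as described in the text. The all-zero scores and the second row of lengths are made-up values for illustration.

```python
import torch

b, n, m = 2, 3, 3
A = torch.zeros(b, n, m)                       # dummy scores, all equal
key_padding_mask = torch.tensor([[2, 3, 1],    # valid lengths, batch element 0
                                 [1, 1, 3]])   # valid lengths, batch element 1

# valid[i, j, k] is True when query j in batch element i may attend to key k.
valid = torch.arange(m)[None, None, :] < key_padding_mask[:, :, None]
W = torch.softmax(A.masked_fill(~valid, float("-inf")), dim=-1)
```

Since all dummy scores are equal, each query spreads its weight uniformly over its valid keys, which makes the effect of the mask easy to read off from `W`.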