### Single-Head Attention

Given $n$ queries $Q \in \mathbf{R}^{n \times d}$, $m$ keys and values $K \in \mathbf{R}^{m \times d}, V \in \mathbf{R}^{m \times v}$, apply attention pooling to get a "head":

$$H = 
\begin{bmatrix} \hat{\mathbf{v}}_1 \\ \hat{\mathbf{v}}_2 \\ \dots \\ \hat{\mathbf{v}}_n \end{bmatrix}
= Softmax(\frac{QK^T}{\sqrt{d}}) V \in \mathbf{R}^{n \times v}$$

Basically, for each query, attention pooling constructs a weighted average of the values, where each value's weight reflects its key's "similarity" with the query.
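The formula above can be sketched in a few lines of NumPy. This is a minimal illustration rather than a reference implementation; the helper names `softmax` and `attention_pooling` and the toy sizes are assumptions made for this example.

```python
import numpy as np

def softmax(scores):
    # Row-wise softmax; subtracting the row max keeps exp() numerically stable.
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_pooling(Q, K, V):
    # Q: (n, d), K: (m, d), V: (m, v) -> H: (n, v)
    d = Q.shape[-1]
    weights = softmax(Q @ K.T / np.sqrt(d))  # (n, m), each row sums to 1
    return weights @ V                       # weighted average of the values

# Toy sizes, chosen only for illustration.
rng = np.random.default_rng(0)
n, m, d, v = 3, 5, 4, 2
Q = rng.normal(size=(n, d))
K = rng.normal(size=(m, d))
V = rng.normal(size=(m, v))

H = attention_pooling(Q, K, V)
print(H.shape)  # (3, 2)
```

Because every row of the weight matrix sums to 1, each row of `H` is a convex combination of the rows of `V`.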

Currently there are no parameters we can tune. We always measure "similarity" with a fixed criterion -- dot product of the raw query and keys.

In other words, our attention has a fixed "perspective", or "behavior".

### Multi-Head Attention

To adjust how we measure similarity, one possible option is adding a learnable transformation before measuring similarity and constructing value:

$$QW_{Q} \in \mathbf{R}^{n \times d}, \quad KW_{K} \in \mathbf{R}^{m \times d}, \quad VW_{V} \in \mathbf{R}^{m \times v}$$

(here $W_{Q}, W_{K} \in \mathbf{R}^{d \times d}$ and $W_{V} \in \mathbf{R}^{v \times v}$; note that we assume the transformations don't change the dimensions $d$ and $v$, when in practice they often do)

The similarity then becomes $QW_{Q}W_{K}^TK^T$, no longer just the dot product of the raw queries and keys -- we are customizing our attention's "perspective".

Let $f$ denote the attention pooling operation:

$$f(Q, K, V) = Softmax(\frac{QK^T}{\sqrt{d}}) V$$

Then our customized attention pooling can be written as:

$$H = f(QW_{Q}, KW_{K}, VW_{V}) = Softmax(\frac{QW_{Q}W_{K}^TK^T}{\sqrt{d}}) VW_{V}$$
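As a rough sketch of this customized pooling, the NumPy snippet below applies the three transformations before pooling. The function name `projected_attention` and the random weights are assumptions for illustration; in a real model the weights would be learned.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_pooling(Q, K, V):
    # f(Q, K, V) = softmax(Q K^T / sqrt(d)) V
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

def projected_attention(Q, K, V, W_Q, W_K, W_V):
    # H = f(Q W_Q, K W_K, V W_V)
    return attention_pooling(Q @ W_Q, K @ W_K, V @ W_V)

rng = np.random.default_rng(0)
n, m, d, v = 3, 5, 4, 2
Q = rng.normal(size=(n, d))
K = rng.normal(size=(m, d))
V = rng.normal(size=(m, v))

# Random stand-ins for learnable weights (dimensions d and v are preserved).
W_Q = rng.normal(size=(d, d))
W_K = rng.normal(size=(d, d))
W_V = rng.normal(size=(v, v))

H = projected_attention(Q, K, V, W_Q, W_K, W_V)

# With identity transformations we fall back to the fixed "perspective".
H_raw = projected_attention(Q, K, V, np.eye(d), np.eye(d), np.eye(v))
print(np.allclose(H_raw, attention_pooling(Q, K, V)))  # True
```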

To combine attention from different "perspectives", get $h$ "heads" from $Q, K, V$ with $h$ different transformations $W_{Q}^{(i)}, W_{K}^{(i)}, W_{V}^{(i)}$:

$$H^{(i)} = f(QW_{Q}^{(i)}, KW_{K}^{(i)}, VW_{V}^{(i)}), \quad i = 1, 2, \dots, h$$

Then, concatenate them along axis 1:

$$\mathbf{H} =
\begin{bmatrix} H^{(1)} & H^{(2)} & \dots & H^{(h)} \end{bmatrix}
=
\begin{bmatrix}
\hat{\mathbf{v}}_1^{(1)} & \hat{\mathbf{v}}_1^{(2)} & \dots & \hat{\mathbf{v}}_1^{(h)} \\
\hat{\mathbf{v}}_2^{(1)} & \hat{\mathbf{v}}_2^{(2)} & \dots & \hat{\mathbf{v}}_2^{(h)} \\
& & \dots & \\
\hat{\mathbf{v}}_n^{(1)} & \hat{\mathbf{v}}_n^{(2)} & \dots & \hat{\mathbf{v}}_n^{(h)}
\end{bmatrix}
\in \mathbf{R}^{n \times hv}
$$

Note that $\mathbf{H}$ is not 3-dimensional. It's still 2-dimensional, just wider.

A block of $v$ adjacent columns in $\mathbf{H}$ (one head $H^{(i)}$) contains the values generated for all queries from the same "perspective".

A row in $\mathbf{H}$ contains the values that the different "perspectives" produced for the same query.

Finally, we apply a transformation $W \in \mathbf{R}^{hv \times v}$ to merge the results from different "perspectives":

$$\text{result}= \mathbf{H}W \in \mathbf{R}^{n \times v}$$

(again, this final transformation may produce some dimension other than $v$)
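The whole pipeline (one head per transformation triple, concatenation along axis 1, final merge) might look like the following NumPy sketch. The function name `multi_head_attention`, the weight layout `(h, d, d)` and the random weights are assumptions made for this example.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_pooling(Q, K, V):
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

def multi_head_attention(Q, K, V, W_Qs, W_Ks, W_Vs, W_out):
    # One head per (W_Q, W_K, W_V) triple, each of shape (n, v).
    heads = [
        attention_pooling(Q @ W_Q, K @ W_K, V @ W_V)
        for W_Q, W_K, W_V in zip(W_Qs, W_Ks, W_Vs)
    ]
    H = np.concatenate(heads, axis=1)  # (n, h*v): still 2-dimensional, just wider
    return H @ W_out, H                # merged result (n, v) and the raw heads

rng = np.random.default_rng(0)
n, m, d, v, h = 3, 5, 4, 2, 3
Q = rng.normal(size=(n, d))
K = rng.normal(size=(m, d))
V = rng.normal(size=(m, v))
W_Qs = rng.normal(size=(h, d, d))   # h transformations for the queries
W_Ks = rng.normal(size=(h, d, d))   # h transformations for the keys
W_Vs = rng.normal(size=(h, v, v))   # h transformations for the values
W_out = rng.normal(size=(h * v, v)) # final merge

result, H = multi_head_attention(Q, K, V, W_Qs, W_Ks, W_Vs, W_out)
print(H.shape, result.shape)  # (3, 6) (3, 2)
```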

In practice, we don't need to keep separate transformations $W_{Q}^{(i)}, W_{K}^{(i)}, W_{V}^{(i)}$, calculate the heads $H^{(i)}$ one by one, and then concatenate them.

Instead, we can stack the transformations $W_{Q}^{(i)}, W_{K}^{(i)}, W_{V}^{(i)}$ side by side from the beginning:

$$\mathbf{W}_Q = \begin{bmatrix} W_{Q}^{(1)} & W_{Q}^{(2)} & \dots & W_{Q}^{(h)} \end{bmatrix} \in \mathbf{R}^{d \times hd}$$
$$\mathbf{W}_K = \begin{bmatrix} W_{K}^{(1)} & W_{K}^{(2)} & \dots & W_{K}^{(h)} \end{bmatrix} \in \mathbf{R}^{d \times hd}$$
$$\mathbf{W}_V = \begin{bmatrix} W_{V}^{(1)} & W_{V}^{(2)} & \dots & W_{V}^{(h)} \end{bmatrix} \in \mathbf{R}^{v \times hv}$$

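Apply each stacked transformation with a single matrix multiplication:

$$Q\mathbf{W}_{Q} \in \mathbf{R}^{n \times hd}, K\mathbf{W}_{K} \in \mathbf{R}^{m \times hd}, V\mathbf{W}_{V} \in \mathbf{R}^{m \times hv}$$

The $i$-th column block of each product is exactly $QW_{Q}^{(i)}$, $KW_{K}^{(i)}$, $VW_{V}^{(i)}$, so we split the products back into $h$ blocks, apply attention pooling to each block (in practice, to all blocks at once as a batch), and get the combined heads:

$$\mathbf{H} = \begin{bmatrix} f(QW_{Q}^{(1)}, KW_{K}^{(1)}, VW_{V}^{(1)}) & \dots & f(QW_{Q}^{(h)}, KW_{K}^{(h)}, VW_{V}^{(h)}) \end{bmatrix} \in \mathbf{R}^{n \times hv}$$

(note that applying $f$ directly to the stacked products would not give the same result: $Q\mathbf{W}_{Q}\mathbf{W}_{K}^TK^T$ sums the similarity scores of all heads before the softmax)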
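To make the stacking concrete, here is a hedged NumPy sketch that compares heads computed one by one against heads obtained from the stacked transformations. Note one assumption of this sketch: after the single stacked projection, the products are split back into $h$ column blocks before pooling, because pooling the stacked matrices in a single call would sum the per-head scores before the softmax. All names and sizes are illustrative.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_pooling(Q, K, V):
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

rng = np.random.default_rng(0)
n, m, d, v, h = 3, 5, 4, 2, 3
Q = rng.normal(size=(n, d))
K = rng.normal(size=(m, d))
V = rng.normal(size=(m, v))
W_Qs = rng.normal(size=(h, d, d))
W_Ks = rng.normal(size=(h, d, d))
W_Vs = rng.normal(size=(h, v, v))

# Reference: heads computed one by one, then concatenated along axis 1.
H_loop = np.concatenate(
    [attention_pooling(Q @ W_Qs[i], K @ W_Ks[i], V @ W_Vs[i]) for i in range(h)],
    axis=1,
)

# Stack the per-head transformations side by side.
W_Q_all = np.hstack(list(W_Qs))  # (d, h*d)
W_K_all = np.hstack(list(W_Ks))  # (d, h*d)
W_V_all = np.hstack(list(W_Vs))  # (v, h*v)

# One matrix multiplication per input instead of h.
QW, KW, VW = Q @ W_Q_all, K @ W_K_all, V @ W_V_all

# Split back into h column blocks and pool each block separately.
heads = [
    attention_pooling(q, k, val)
    for q, k, val in zip(
        np.split(QW, h, axis=1), np.split(KW, h, axis=1), np.split(VW, h, axis=1)
    )
]
H_stacked = np.hstack(heads)  # (n, h*v)

# Pooling the stacked matrices in one call mixes the heads' scores instead.
H_naive = attention_pooling(QW, KW, VW)

print(np.allclose(H_loop, H_stacked))
```

Real implementations typically replace the Python loop with a reshape to a `(h, n, d)`-style batch, but the per-block pooling is the same idea.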

### Self-Attention

In seq2seq tasks, we are given a sequence of embedded tokens $X$ as input, and a sequence of tokens $Y$ is expected as output.

Let the sequence length be $n$ and the embedding dimension of each token be $v$, then:

$$X = \begin{bmatrix} \mathbf{x}_1 \\ \mathbf{x}_2 \\ \dots \\ \mathbf{x}_n \end{bmatrix} \in \mathbf{R}^{n \times v}$$

The self-attention mechanism makes the entire input sequence the "database", with each token mapping to itself, acting as both key and value:

$$\text{Database} = \{ (\mathbf{x}_1, \mathbf{x}_1), (\mathbf{x}_2, \mathbf{x}_2), \dots, (\mathbf{x}_n, \mathbf{x}_n) \}$$

This is the same as setting $K = V = X \in \mathbf{R}^{n \times v}$.

With this database, self-attention produces each output token by having the corresponding input token act as the query, $\mathbf{q} = \mathbf{x}_i$ (so the query/key dimension is $d = v$):

$$\mathbf{y}_i = softmax(\frac{\mathbf{q} K^T}{\sqrt{d}}) V = softmax(\frac{\mathbf{x}_i X^T}{\sqrt{v}}) X \in \mathbf{R}^v$$

The entire sequence written in matrix form:

$$Y = Softmax(\frac{XX^T}{\sqrt{v}}) X = f(X, X, X) \in \mathbf{R}^{n \times v}$$
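In code, self-attention without learnable transformations is just the pooling function called with the same matrix three times. The NumPy sketch below uses a random `X` as a stand-in for real token embeddings; the helper names are assumptions for this example.

```python
import numpy as np

def softmax(scores):
    shifted = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(shifted)
    return weights / weights.sum(axis=-1, keepdims=True)

def attention_pooling(Q, K, V):
    d = Q.shape[-1]
    return softmax(Q @ K.T / np.sqrt(d)) @ V

rng = np.random.default_rng(0)
n, v = 4, 6                  # sequence length, embedding dimension
X = rng.normal(size=(n, v))  # stand-in for a sequence of embedded tokens

# Matrix form: Y = f(X, X, X)
Y = attention_pooling(X, X, X)

# Token-by-token form for the first token: x_1 acts as the query.
y_first = softmax(X[0] @ X.T / np.sqrt(v)) @ X

print(Y.shape)  # (4, 6)
```

The first row of `Y` matches `y_first`, which is the per-token formula applied to $\mathbf{x}_1$.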



### Positional Encoding

Unlike an RNN, self-attention attends to all tokens simultaneously, and does not explicitly take into account a token's position in the sequence.

If a token's position truly matters (which is often the case in seq2seq tasks), additional encoding will be needed.

For an input sequence $X \in \mathbf{R}^{n \times v}$, add a positional embedding matrix $P \in \mathbf{R}^{n \times v}$ to it, whose elements (row $i$ is the position, columns $2j$ and $2j+1$ are a pair of embedding dimensions) satisfy:

$$p_{i, 2j} = \sin(i\omega_j), \quad p_{i, 2j+1} = \cos(i\omega_j), \quad \text{where } \omega_j = \frac{1}{10000^{2j/v}}$$

Note that in $X+P$, a row contains the embedded representation (length $v$) of a token, while columns are the embedding dimensions.
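A possible way to build $P$ with NumPy is sketched below. It assumes that positions are counted from $i = 0$ and that $v$ is even; neither is stated above, so treat both as assumptions of this example, along with the function name `positional_encoding`.

```python
import numpy as np

def positional_encoding(n, v):
    # Assumption: v is even, so dimensions pair up as (2j, 2j+1).
    assert v % 2 == 0
    i = np.arange(n)[:, None]           # positions, shape (n, 1); assumed 0-based
    j = np.arange(v // 2)[None, :]      # pair index, shape (1, v/2)
    omega = 1.0 / 10000 ** (2 * j / v)  # one frequency per pair of dimensions
    P = np.zeros((n, v))
    P[:, 0::2] = np.sin(i * omega)      # even columns: p_{i, 2j}
    P[:, 1::2] = np.cos(i * omega)      # odd columns:  p_{i, 2j+1}
    return P

n, v = 6, 8
P = positional_encoding(n, v)

rng = np.random.default_rng(0)
X = rng.normal(size=(n, v))  # stand-in for embedded tokens
X_with_position = X + P      # same shape as X

print(P.shape)  # (6, 8)
```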

#### Absolute Positional Information

As $j$ increases, the frequency $\omega_j$ of the trigonometric functions decreases, so higher embedding dimensions change more slowly along the sequence.

Think of the embedding dimensions as digits of an increasing number. Higher digits change less frequently than lower digits.

In this way, each embedding dimension is assigned a fixed frequency, and row $i$ of $P$ depends only on the position $i$, so its positional information is "absolute".
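As a small worked example (taking $v = 4$ purely for illustration), the two frequencies are:

$$\omega_0 = \frac{1}{10000^{0/4}} = 1, \quad \omega_1 = \frac{1}{10000^{2/4}} = \frac{1}{100} = 0.01$$

So embedding dimensions $0$ and $1$ complete a full cycle roughly every $2\pi/\omega_0 \approx 6.28$ positions, while dimensions $2$ and $3$ need about $2\pi/\omega_1 \approx 628$ positions, just like a higher digit that rolls over far less often.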

#### Relative Positional Information

When considering a token's position in a sequence, though, we are more interested in its relative position to other tokens.

In the 2 sentences

```"I love you."```

```"I really love you."```

the phrase "love you" appears at different absolute positions, but the relative position of its two words stays the same, and it expresses the same meaning.

Most human languages convey much of their meaning through the relative positions of words. The absolute position matters far less.

The key here is that the way each token treats its neighboring tokens should remain the same across all tokens.

In ```"12345"```, ```2``` should treat ```1``` in just the same way as ```4``` treats ```3```, and ```1``` should treat ```3``` in just the same way as ```3``` treats ```5```.

$P$ achieves this by satisfying the following projection property:

For any $p_{i, 2j}$ and any offset $\delta$,

$$A \begin{bmatrix} p_{i, 2j} \\ p_{i, 2j+1} \end{bmatrix} = \begin{bmatrix} p_{i+\delta, 2j} \\ p_{i+\delta, 2j+1} \end{bmatrix}, \quad \text{where } A = \begin{bmatrix} \cos(\delta\omega_j) & \sin(\delta\omega_j) \\ -\sin(\delta\omega_j) & \cos(\delta\omega_j) \end{bmatrix}$$

is independent of $i$.
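The property follows from the angle-addition identities. Writing out the two rows of the product, with $p_{i, 2j} = \sin(i\omega_j)$ and $p_{i, 2j+1} = \cos(i\omega_j)$:

$$\cos(\delta\omega_j)\sin(i\omega_j) + \sin(\delta\omega_j)\cos(i\omega_j) = \sin((i+\delta)\omega_j) = p_{i+\delta, 2j}$$

$$-\sin(\delta\omega_j)\sin(i\omega_j) + \cos(\delta\omega_j)\cos(i\omega_j) = \cos((i+\delta)\omega_j) = p_{i+\delta, 2j+1}$$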

This tells us that the positional encoding of any token can be projected (in the same way) to that of another token, as long as the two tokens' relative position stays the same.

The projection from ```2``` to ```1``` is the same as the one from ```4``` to ```3```!
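The projection property can be checked numerically. The NumPy sketch below picks arbitrary illustrative values for $v$ and $j$, and the helper names `pair` and `shift_matrix` are assumptions for this example.

```python
import numpy as np

v, j = 8, 1                         # illustrative embedding size and pair index
omega = 1.0 / 10000 ** (2 * j / v)  # frequency of dimensions 2j and 2j+1

def pair(i):
    # (p_{i, 2j}, p_{i, 2j+1}) for position i
    return np.array([np.sin(i * omega), np.cos(i * omega)])

def shift_matrix(delta):
    # The matrix A from the text; it depends on delta and omega, not on i.
    c, s = np.cos(delta * omega), np.sin(delta * omega)
    return np.array([[c, s], [-s, c]])

# The projection from 2 to 1 and from 4 to 3 uses the same matrix (delta = -1).
A = shift_matrix(-1)
print(np.allclose(A @ pair(2), pair(1)))  # True
print(np.allclose(A @ pair(4), pair(3)))  # True
```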