I am not going to give a full explanation of Attention here. These are just the notes I think are important, mostly so I remember them for my own tasks.


Attention works by creating query $Q$, key $K$, and value $V$ matrices from the inputs $X$ via linear layers with learnable weights $W^Q$, $W^K$, and $W^V$.


$$
\begin{aligned}
Q &= XW^Q \\
K &= XW^K \\
V &= XW^V
\end{aligned}
$$


![Self-Attention Matrix Calculation](https://benjaminwarner.dev/img/2022/tinkering-with-attention-pooling/self-attention-matrix-calculation-queries.png)

where $W^Q \in \mathbb{R}^{d_{\text{model}} \times d_Q}$. Or, less formally, $Q = XW^Q$ is a set of linear equations:

$$
Q = X A^Q + B^Q
$$

where $A^Q$ plays the role of $W^Q$ and $B^Q$ is an optional bias term; both are learnable parameters for calculating $Q$ from $X$.
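In PyTorch terms, this linear-equation view is what `nn.Linear` computes. It stores a `weight` of shape `(d_Q, d_model)` and a `bias` of shape `(d_Q,)`, so $A^Q$ corresponds to `weight.T` and $B^Q$ to `bias`. A minimal sketch, with arbitrary example sizes:

```python
import torch
from torch import nn

torch.manual_seed(0)
d_model, d_q = 8, 4            # example sizes, chosen arbitrarily
X = torch.randn(5, d_model)    # 5 tokens, each a d_model-dim embedding

W_q = nn.Linear(d_model, d_q)  # holds A^Q (stored transposed) and B^Q
Q = W_q(X)

# the same projection written out as X A^Q + B^Q
A_q = W_q.weight.T             # shape (d_model, d_q)
B_q = W_q.bias                 # shape (d_q,), broadcast over all tokens
Q_manual = X @ A_q + B_q

print(Q.shape)                                  # torch.Size([5, 4])
print(torch.allclose(Q, Q_manual, atol=1e-6))
```

Passing `bias=False` to `nn.Linear` drops $B^Q$ and leaves exactly the formal $Q = XW^Q$.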

Attention is then calculated by:

$$
\text{Attention}(Q, K, V) = \text{softmax} \left( \frac{QK^T}{\sqrt{d_k}} \right) V
$$

where $d_k$ is the dimension of the keys (the individual head dimension), and dividing by $\sqrt{d_k}$ keeps the dot products from growing with the head size.
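The formula translates almost line for line into PyTorch. This sketch checks a hand-written version against `torch.nn.functional.scaled_dot_product_attention`, which assumes PyTorch 2.0 or newer and uses $1/\sqrt{d_k}$ as its default scale:

```python
import math
import torch
import torch.nn.functional as F

def attention(Q, K, V):
    """softmax(Q K^T / sqrt(d_k)) V for tensors shaped (..., S, d_k)."""
    d_k = K.size(-1)
    scores = Q @ K.transpose(-2, -1) / math.sqrt(d_k)  # (..., S, S)
    weights = scores.softmax(dim=-1)                   # each row sums to 1
    return weights @ V                                 # (..., S, d_v)

torch.manual_seed(0)
Q, K, V = (torch.randn(2, 6, 16) for _ in range(3))   # (batch, seq, d_k)

out = attention(Q, K, V)
ref = F.scaled_dot_product_attention(Q, K, V)          # fused PyTorch version
print(torch.allclose(out, ref, atol=1e-4))
```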


![Self-Attention Calculation](https://benjaminwarner.dev/img/2022/tinkering-with-attention-pooling/self-attention-matrix-calculation-attention.webp)

The resulting $\text{Attention}(Q, K, V)$ is usually passed through a linear output projection $W^O$

$$
\text{Output} = \text{Attention}(Q, K, V) W^O
$$

as the final step of the Attention layer.

Despite all the math, Attention is simply a learned weighted average. Attention learns to generate weights between tokens via the queries $XW^Q$ and keys $XW^K$, and $\text{softmax}(QK^T / \sqrt{d_k})$ turns them into per-token weights. The values $XW^V$ are the token representations being averaged: the final matrix product $\text{softmax}(\dots) V$ replaces each token with a weighted average of the value vectors of all tokens in the sequence, including itself. When someone says a token attends to a second token, it means the second token's weight in the first token's weighted average is large relative to all the other tokens.
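To make the "weighted average" reading concrete, this sketch (random $Q$, $K$, $V$ standing in for the projected inputs) computes the same output twice: once as the matrix product and once as an explicit per-token weighted sum of value vectors.

```python
import math
import torch

torch.manual_seed(0)
S, d = 4, 8
Q, K, V = torch.randn(S, d), torch.randn(S, d), torch.randn(S, d)

weights = (Q @ K.T / math.sqrt(d)).softmax(dim=-1)  # (S, S): row i = weights for token i
out = weights @ V                                   # matrix form

# the same thing as an explicit weighted average, one token at a time
out_loop = torch.stack([
    sum(weights[i, j] * V[j] for j in range(S))     # token i's weighted average of all values
    for i in range(S)
])
print(torch.allclose(out, out_loop, atol=1e-5))
```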


### Single-head attention initialization
Our Attention layers allow disabling the bias terms of the linear layers, since recent papers and models such as Cramming, Pythia, and PaLM have shown that disabling the bias term results in little to no downstream performance drop while decreasing computational and memory requirements.
In a well-trained NLP Transformer such as Pythia, the bias terms end up near or at zero, which is why we can disable them without causing performance issues.

```python
import torch.nn as nn

class SingleHeadAttention(nn.Module):
    def __init__(self,
        hidden_size: int,
        bias: bool = True,  # set to False to disable the linear layers' bias terms
    ):
        super().__init__()
```

Instead of three separate layers, we can merge $W^Q$, $W^K$, and $W^V$ into a single matrix $W^{QKV}$ (the `Wqkv` layer below), compute all three projections with one matrix multiplication, and then `unbind` the output into separate $Q$, $K$, and $V$ tensors.
In Multi-Head Attention, each individual head size is smaller than the input size, so for Single Head Attention we will arbitrarily set the head size to one quarter of the input dimension.
```python
# linear layer to project queries, keys, and values in one matmul
self.Wqkv = nn.Linear(hidden_size, (hidden_size//4)*3, bias=bias)
# linear layer to project the final output back to hidden_size
self.Wo = nn.Linear(hidden_size//4, hidden_size, bias=bias)
```

And that’s it for the Attention initialization. The Attention mechanism in a Transformer only has two layers of learnable parameters. Everything else in Attention is an operation on the output of the Wqkv linear layer.
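Counting those parameters shows how small the bias terms are. This sketch rebuilds only the two layers above with a hypothetical helper `count_params`:

```python
from torch import nn

def count_params(hidden_size: int, bias: bool = True) -> int:
    # the two learnable layers of the single-head Attention defined above
    Wqkv = nn.Linear(hidden_size, (hidden_size // 4) * 3, bias=bias)
    Wo = nn.Linear(hidden_size // 4, hidden_size, bias=bias)
    return sum(p.numel() for layer in (Wqkv, Wo) for p in layer.parameters())

print(count_params(64))              # 64*48 + 48 + 16*64 + 64 = 4208
print(count_params(64, bias=False))  # 64*48 + 16*64 = 4096
```

With `hidden_size=64`, the biases are 112 of 4208 parameters; the larger saving from disabling them is usually the extra memory traffic, not the parameter count.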

#### Single Head Forward

After some input shape housekeeping, the first computational step is to generate our queries, keys, and values. First, we pass the input $x$ through the `Wqkv` layer.
Then we reshape the `Wqkv` output to (batch size, sequence length, one dimension for $Q$, $K$, $V$, head size).
Finally, we split the single tensor into the query, key, and value tensors using `unbind`, each of shape (B, S, C//4).
```python
# batch size (B), sequence length (S), input dimension (C)
B, S, C = x.shape

# split into queries, keys, & values of shape
# batch size (B), sequence length (S), head size (HS)
q, k, v = self.Wqkv(x).reshape(B, S, 3, C//4).unbind(dim=2)
```
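A quick shape check, using example sizes and a freshly initialized `Wqkv`: `reshape` plus `unbind` is the same as slicing the `Wqkv` output into three equal chunks along the last dimension, with $Q$ first.

```python
import torch
from torch import nn

torch.manual_seed(0)
B, S, C = 2, 5, 32
x = torch.randn(B, S, C)
Wqkv = nn.Linear(C, (C // 4) * 3)

qkv = Wqkv(x)                                        # (B, S, 3 * C//4)
q, k, v = qkv.reshape(B, S, 3, C // 4).unbind(dim=2)
print(q.shape, k.shape, v.shape)                     # each torch.Size([2, 5, 8])

# reshape + unbind == three contiguous chunks of the last dimension
HS = C // 4
print(torch.equal(q, qkv[..., :HS]),
      torch.equal(k, qkv[..., HS:2 * HS]),
      torch.equal(v, qkv[..., 2 * HS:]))
```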

With the queries, keys, and values generated, we can move to the mathematical operations of the Attention mechanism.

First, we transpose $K$ and take the matrix product of $Q$ and $K^T$.

```python
# calculate dot product of queries and keys of shape
# (B, S, S) = (B, S, HS) @ (B, HS, S)
attn = q @ k.transpose(-2, -1)
```

Next, we scale $QK^T$ by dividing it by $\sqrt{d_k}$, where $d_k$ is the head size.

```python
# scale by square root of head dimension
attn = attn / math.sqrt(k.size(-1))
```
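Why $\sqrt{d_k}$: if the entries of a query and a key are independent with zero mean and unit variance (an idealized assumption), their dot product has variance $d_k$. Unscaled scores therefore grow with the head size and push softmax toward a one-hot output. A small simulation of that assumption:

```python
import math
import torch

torch.manual_seed(0)
n, d_k = 10_000, 64
q = torch.randn(n, d_k)  # unit-variance entries (simplifying assumption)
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)  # n independent dot products q . k
scaled = dots / math.sqrt(d_k)

print(dots.var().item())    # roughly d_k = 64
print(scaled.var().item())  # roughly 1
```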

Then it’s time to calculate the per-token Attention weights by applying softmax over the last (key) dimension.

```python
# apply softmax to get attention weights
attn = attn.softmax(dim=-1)
```

This softmax output of $QK^T/\sqrt{d_k}$ is how the Attention mechanism weights the strength of the relationship between each pair of tokens. Each row is a set of non-negative weights that sums to one: higher values mean Attention places more importance on that pair of tokens, and lower values mean the pair is deemed less important.
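Those row properties are easy to verify on random queries and keys (example sizes only). Here `attn[b, i, j]` is how much token `i` attends to token `j` in batch element `b`:

```python
import math
import torch

torch.manual_seed(0)
B, S, HS = 2, 5, 8
q, k = torch.randn(B, S, HS), torch.randn(B, S, HS)

attn = (q @ k.transpose(-2, -1) / math.sqrt(HS)).softmax(dim=-1)  # (B, S, S)

print((attn >= 0).all().item())                            # all weights non-negative
print(torch.allclose(attn.sum(dim=-1), torch.ones(B, S)))  # each row sums to 1
print(attn[0, 0].argmax().item())  # key token that token 0 attends to most in sequence 0
```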

Next, we matrix multiply the Attention weights with the value matrix $V$, which replaces each token embedding with its Attention-weighted average of the values.

```python
# dot product attention weights to values
# (B, S, HS) = (B, S, S) @ (B, S, HS)
x = attn @ v
```

Finally, we project the output of the Attention mechanism back to the original input dimension using the `Wo` linear layer.

```python
# project back to original dimension
x = self.Wo(x)
```

And there you have it. A simple rendition of Single Head Bidirectional Attention in code.

```python
import math
from torch import nn, Tensor

class SingleHeadAttention(nn.Module):
    def __init__(self,
        hidden_size: int,
        bias: bool = True,
    ):
        super().__init__()
        self.Wqkv = nn.Linear(hidden_size, (hidden_size//4)*3, bias=bias)
        self.Wo = nn.Linear(hidden_size//4, hidden_size, bias=bias)

    def forward(self, x: Tensor):
        B, S, C = x.shape

        q, k, v = self.Wqkv(x).reshape(B, S, 3, C//4).unbind(dim=2)

        attn = q @ k.transpose(-2, -1)
        attn = attn / math.sqrt(k.size(-1))

        attn = attn.softmax(dim=-1)

        x = attn @ v

        return self.Wo(x)
```
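Bidirectional means every output position can see every input position. A quick way to see this, restating the forward pass above as a plain function with randomly initialized `Wqkv` and `Wo` layers of the same shapes: changing only the last token also changes the output at the first position.

```python
import math
import torch
from torch import nn

torch.manual_seed(0)
C = 16
Wqkv = nn.Linear(C, (C // 4) * 3)  # same shapes as SingleHeadAttention
Wo = nn.Linear(C // 4, C)

def single_head(x):
    B, S, _ = x.shape
    q, k, v = Wqkv(x).reshape(B, S, 3, C // 4).unbind(dim=2)
    attn = (q @ k.transpose(-2, -1) / math.sqrt(k.size(-1))).softmax(dim=-1)
    return Wo(attn @ v)

x = torch.randn(1, 6, C)
x_changed = x.clone()
x_changed[:, -1] = torch.randn(C)  # perturb only the last token

with torch.no_grad():
    y, y_changed = single_head(x), single_head(x_changed)

# no mask: the first token's output depends on the last token too
print(torch.allclose(y[:, 0], y_changed[:, 0]))  # False
```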

## Multi-Head Self-Attention

Formally, Multi-Head Attention creates one query $ Q_h $, key $ K_h $, and value $ V_h $ per head $ h $, calculates the scaled dot-product Attention per head $ \text{Attention}(Q_h, K_h, V_h) $, concatenates all the Attention outputs back into one tensor $ \text{MultiHead}(Q, K, V) $, before passing the Multi-Head Attention output through the final linear layer $ W^O $:

$$
Q_h = X W_h^Q \quad K_h = X W_h^K \quad V_h = X W_h^V
$$

$$
\text{Attention}(Q_h, K_h, V_h) = \text{softmax} \left( \frac{Q_h K_h^T}{\sqrt{d_h}} \right) V_h
$$

$$
\text{MultiHead}(Q, K, V) = \text{concat}(\text{Attention}(Q_h, K_h, V_h), \text{ for all } h)
$$

$$
\text{Output} = \text{MultiHead}(Q, K, V) W^O
$$


```python
def __init__(self,
    hidden_size: int,
    num_heads: int,
    bias: bool = True,
):
    super().__init__()
    # input dimension must be divisible by num_heads
    assert hidden_size % num_heads == 0
    # number of attention heads
    self.nh = num_heads
    # linear layer to project queries, keys, values for all heads at once
    self.Wqkv = nn.Linear(hidden_size, hidden_size*3, bias=bias)
    # linear layer to project final output
    self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias)
```


Our Multi-Head forward method is largely the same, with a few changes to account for the multiple heads.

Our input sequence is projected through the linear `Wqkv` layer as before. Then we reshape the output to (batch size, sequence length, one dimension for $Q_h$, $K_h$, $V_h$, number of heads, head size) and transpose it to (batch size, number of heads, $Q_h$/$K_h$/$V_h$, sequence length, head size), where the head size in most Transformers is the embedding dimension divided by the number of heads. Then we unbind the reshaped and transposed output into separate queries, keys, and values, each of shape (B, NH, S, HS).
```python
# batch size (B), sequence length (S), input dimension (C)
B, S, C = x.shape

# split into queries, keys, & values of shape
# batch size (B), num_heads (NH), sequence length (S), head size (HS)
x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh)
q, k, v = x.transpose(3, 1).unbind(dim=2)
```
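To confirm what each head receives, this sketch (example sizes, freshly initialized `Wqkv`) checks the shapes and shows that head `h`'s queries are columns `h*HS:(h+1)*HS` of the query block, which is the first `C` outputs of `Wqkv`:

```python
import torch
from torch import nn

torch.manual_seed(0)
B, S, C, NH = 2, 5, 32, 4
HS = C // NH
x = torch.randn(B, S, C)
Wqkv = nn.Linear(C, C * 3)

qkv = Wqkv(x)                                   # (B, S, 3*C)
q, k, v = qkv.reshape(B, S, 3, NH, HS).transpose(3, 1).unbind(dim=2)
print(q.shape)                                  # torch.Size([2, 4, 5, 8]) = (B, NH, S, HS)

h = 2
print(torch.equal(q[:, h], qkv[..., h * HS:(h + 1) * HS]))             # query slice of head h
print(torch.equal(k[:, h], qkv[..., C + h * HS:C + (h + 1) * HS]))     # key slice of head h
```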

The Attention mechanism is exactly the same as the Single Head code, but the extra head dimension means the softmax is computed independently for each head.

```python
# calculate dot product of queries and keys
# (B, NH, S, S) = (B, NH, S, HS) @ (B, NH, HS, S)
attn = q @ k.transpose(-2, -1)

# scale by square root of head dimension
attn = attn / math.sqrt(k.size(-1))

# apply softmax to get attention weights
attn = attn.softmax(dim=-1)
```

Our remaining steps are to matrix multiply the Attention weights with $V_h$, then concatenate the per-head Attention outputs into one output of our input shape.

We perform this by transposing the heads and sequence dimensions, then reshaping to (B, S, C). This gives the same result as concatenating the heads along the last dimension without building a list of per-head tensors, although `reshape` still copies memory here because the transposed tensor is not contiguous.

```python

# dot product attention weights with values
# (B, NH, S, HS) = (B, NH, S, S) @ (B, NH, S, HS)
x = attn @ v

# transpose heads & sequence then reshape back to (B, S, C)
x = x.transpose(1, 2).reshape(B, S, C)

# apply final linear layer to get output
return self.Wo(x)

```
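A sketch of the concatenation equivalence on random per-head outputs: `transpose(1, 2).reshape(...)` lays out each token's heads side by side, exactly like `torch.cat` over the heads.

```python
import torch

torch.manual_seed(0)
B, NH, S, HS = 2, 4, 5, 8
x = torch.randn(B, NH, S, HS)  # stand-in for the per-head outputs of attn @ v

merged = x.transpose(1, 2).reshape(B, S, NH * HS)
concatenated = torch.cat([x[:, h] for h in range(NH)], dim=-1)  # concat(head_0, ..., head_{NH-1})

print(torch.equal(merged, concatenated))   # True
print(x.transpose(1, 2).is_contiguous())   # False, so reshape must copy here
```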

With all the pieces defined, we now have a working, albeit incomplete, implementation of Bidirectional Self-Attention.

```python
import math
from torch import nn, Tensor

class MultiHeadAttention(nn.Module):
    def __init__(self,
        hidden_size: int,
        num_heads: int,
        bias: bool = True,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.nh = num_heads
        self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias)

    def forward(self, x: Tensor):
        B, S, C = x.shape

        x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh)
        q, k, v = x.transpose(3, 1).unbind(dim=2)

        attn = q @ k.transpose(-2, -1)
        attn = attn / math.sqrt(k.size(-1))

        attn = attn.softmax(dim=-1)

        x = attn @ v

        return self.Wo(x.transpose(1, 2).reshape(B, S, C))
```

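## Causal Self-Attention

Causal Self-Attention stops each token from attending to tokens that come after it. In `__init__`, we register a boolean upper-triangular mask where `True` marks a future position that must be masked out: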

```python
# causal mask to ensure that attention is not applied to future tokens
# where context_size is the maximum sequence length of the transformer
self.register_buffer('causal_mask',
    torch.triu(torch.ones([context_size, context_size],
               dtype=torch.bool), diagonal=1)
        .view(1, 1, context_size, context_size))
```
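For a small `context_size`, the mask looks like this. Row `i` is `True` at every column `j > i`, the future tokens that token `i` must not see, and the forward pass crops it to the current sequence length `S`:

```python
import torch

context_size = 4
causal_mask = torch.triu(torch.ones([context_size, context_size], dtype=torch.bool), diagonal=1)
print(causal_mask)               # True strictly above the diagonal
print(causal_mask.sum().item())  # 3 + 2 + 1 = 6 masked (query, key) pairs

batched = causal_mask.view(1, 1, context_size, context_size)  # broadcasts over (B, NH)
S = 3
print(batched[:, :, :S, :S].shape)  # torch.Size([1, 1, 3, 3])
```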


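Then in our `CausalAttention` forward method, we use `masked_fill` to apply the causal mask, cropped to the current sequence length `S`, to the scaled Attention scores before applying softmax to calculate the Attention weights. Masked positions are set to $-\infty$, so softmax gives them a weight of exactly zero.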

```python
# scale by square root of head dimension
attn = attn / math.sqrt(k.size(-1))

# apply causal mask
attn = attn.masked_fill(self.causal_mask[:, :, :S, :S], float('-inf'))

# apply softmax to get attention weights
attn = attn.softmax(dim=-1)
```
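A sketch of these steps on random multi-head tensors (example sizes), checked against PyTorch's built-in causal path `F.scaled_dot_product_attention(..., is_causal=True)`, which assumes PyTorch 2.0 or newer. The first token can only attend to itself, so its weight row is one-hot:

```python
import math
import torch
import torch.nn.functional as F

torch.manual_seed(0)
B, NH, S, HS, context_size = 2, 4, 6, 8, 16
q, k, v = (torch.randn(B, NH, S, HS) for _ in range(3))
causal_mask = torch.triu(torch.ones(context_size, context_size, dtype=torch.bool),
                         diagonal=1).view(1, 1, context_size, context_size)

attn = q @ k.transpose(-2, -1) / math.sqrt(HS)
attn = attn.masked_fill(causal_mask[:, :, :S, :S], float('-inf'))  # crop mask to length S
attn = attn.softmax(dim=-1)
out = attn @ v

ref = F.scaled_dot_product_attention(q, k, v, is_causal=True)
print(attn[0, 0, 0])                          # weight 1 on token 0, zeros elsewhere
print(torch.allclose(out, ref, atol=1e-4))
```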

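And that’s it! We now have a working implementation of Causal Self-Attention. The full class below also applies dropout to the Attention weights and to the output, and combines the causal mask with a boolean padding `mask` of shape (B, S), where `True` marks a token to ignore.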

```python
import math
import torch
from torch import nn, Tensor

class CausalAttention(nn.Module):
    def __init__(self,
        hidden_size: int,
        num_heads: int,
        context_size: int,
        attn_drop: float = 0.1,
        out_drop: float = 0.1,
        bias: bool = True,
    ):
        super().__init__()
        assert hidden_size % num_heads == 0
        self.nh = num_heads
        self.Wqkv = nn.Linear(hidden_size, hidden_size * 3, bias=bias)
        self.Wo = nn.Linear(hidden_size, hidden_size, bias=bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.out_drop = nn.Dropout(out_drop)
        # True above the diagonal = future positions that must be masked out
        self.register_buffer('causal_mask',
            torch.triu(torch.ones([context_size, context_size],
                       dtype=torch.bool), diagonal=1)
                .view(1, 1, context_size, context_size))

    def forward(self, x: Tensor, mask: Tensor):
        # mask: bool tensor of shape (B, S), True marks padding tokens to ignore
        B, S, C = x.shape

        x = self.Wqkv(x).reshape(B, S, 3, self.nh, C//self.nh)
        q, k, v = x.transpose(3, 1).unbind(dim=2)

        attn = q @ k.transpose(-2, -1)
        attn = attn / math.sqrt(k.size(-1))

        # logical OR of causal and padding masks, broadcast to (B, 1, S, S)
        combined_mask = self.causal_mask[:, :, :S, :S] | mask.view(B, 1, 1, S)
        attn = attn.masked_fill(combined_mask, float('-inf'))

        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = attn @ v

        x = x.transpose(1, 2).reshape(B, S, C)
        return self.out_drop(self.Wo(x))
```
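One caveat with the combined mask: if a query row ends up with every key masked, softmax over a row of $-\infty$ returns NaN. This sketch rebuilds only the masking step with hypothetical padding masks (`True` = padding). Right padding is fine, but left padding fully masks the first token's row:

```python
import torch

B, S = 2, 4
# causal mask as in CausalAttention: True above the diagonal (future tokens)
causal = torch.triu(torch.ones(S, S, dtype=torch.bool), diagonal=1).view(1, 1, S, S)
# hypothetical padding masks: sequence 0 is right-padded, sequence 1 is left-padded
pad = torch.tensor([[False, False, False, True],
                    [True, False, False, False]])
combined = causal | pad.view(B, 1, 1, S)  # broadcasts to (B, 1, S, S)

scores = torch.zeros(B, 1, S, S)          # stand-in for the scaled scores
weights = scores.masked_fill(combined, float('-inf')).softmax(dim=-1)
print(weights[0, 0])  # right padding: every row keeps at least one visible key
print(weights[1, 0])  # left padding: row 0 can only see a padded key, so it is NaN
```

Padding rows are usually ignored downstream, but NaNs can still propagate through later layers, so the padding side matters with this masking scheme.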