In [1]:
import collections

import torch
import torch.nn as nn

torch.manual_seed(239)  # For reproducibility

<torch._C.Generator at 0x13753c430>

## Overview
This educational notebook delves into the mathematics of self-attention and multi-head attention mechanisms that
make up the Transformer architecture. It also provides implementations in PyTorch. 
This architecture is the core technology behind large language models (LLMs), 
such as OpenAI's GPT models.

### Table of Contents
1. [Attention Weights](#Attention-Weights)
2. [Self-Attention Mechanism Without Trainable Parameters](#Self-Attention-Mechanism-Without-Trainable-Parameters)
3. [Adding Trainable Parameters to Self-Attention](#Adding-Trainable-Parameters-to-Self-Attention)
4. [Positional Encoding](#Positional-Encoding)
5. [Causal Attention](#Causal-Attention)
6. [Multi-head Attention](#Multi-head-Attention)

## Attention Weights

  
The input data to a transformer is a set of vectors $x_1, \cdots, x_N$, each with a
dimensionality of $D$. These can be collected into an $N \times D$ matrix $\mathbf{X}$ whose $n$th row is $x_n$.
In (Bishop & Bishop, 2023) the vectors $x_1, \cdots, x_N$ are referred to as "tokens."
I will refer to these as token embedding vectors, or token embeddings for short.
In language modeling, through a process known as "tokenization," a sentence or
sequence of words is split into a sequence of natural numbers, which I will
refer to as "tokens." Each token is then converted to a $D$-dimensional
vector (i.e. a token embedding) to create the input for the transformer.
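
To make the tokenization and embedding steps concrete, here is a minimal sketch in plain Python. The whitespace "tokenizer," the toy vocabulary, and the random embedding table are all assumptions made for illustration; real language models use sub-word tokenizers and learned embedding tables.

```python
import random

# Hypothetical toy setup: split on whitespace and build a vocabulary from
# the sentence itself. Real tokenizers work on sub-word units.
sentence = "I swam across the river to get to the other bank"
words = sentence.lower().split()
vocab = {word: idx for idx, word in enumerate(sorted(set(words)))}

# Tokenization: map each word to its integer id (the "token")
tokens = [vocab[word] for word in words]

# Embedding table: one D-dimensional vector per vocabulary entry.
# In a real model these values are learned; here they are random.
D = 4
rng = random.Random(0)
embedding_table = [[rng.gauss(0.0, 1.0) for _ in range(D)] for _ in vocab]

# Token embeddings: row n of X is the embedding x_n of the n-th token
X = [embedding_table[t] for t in tokens]

print(f'tokens = {tokens}')
print(f'X has {len(X)} rows (N) and {len(X[0])} columns (D)')
```

Notice that the two occurrences of "the" receive the same token and therefore the same embedding vector, regardless of the words around them.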

Now suppose we have a sentence (e.g. "I swam across the river to get to the other bank") and this
sentence has gone through the tokenization process and has been converted into a sequence of 
token embedding vectors, $x_1, \cdots, x_N$, each with dimensionality of $D$. Our goal is to map this
sequence of vectors $x_1, \cdots, x_N$ to a new sequence of vectors $y_1, \cdots, y_N$
in a new space that captures important "semantic" information within the full sentence. 
For example, when reading the sentence "I swam across the river to get to the other bank" 
the words "swam" and "river" give us information about the meaning of "bank" in that sentence.
As in (Raschka, 2024) we will refer to the $y_n$ as a context vector to highlight
that each $y_n$ depends on the other token embeddings in the token embedding sequence or "context."

To this end, we want the context vector $y_n$ to depend not only on $x_n$ but also on
the full token embedding sequence $x_1, \cdots, x_N$. (For those who are already familiar with
this material, note that we will ignore masking for now.) Returning to our example sentence,
this would mean the context vector $y_n$ would depend not only on the 
embedding vector $x_n$ corresponding to the word "bank" but on 
every embedding vector in the sequence, each corresponding to a word (i.e. token) 
in the sentence.

A possible approach to constructing the context vectors $y_n$ is to define each $y_n$ 
to be a linear combination of the token embeddings $x_1, \cdots, x_N$. That is, $y_n = \sum_{m=1}^N a_{nm}x_m$.
The question then becomes how to define the weights $a_{nm}$. To start, we constrain the
weights to be non-negative and to sum to one. Mathematically, $a_{nm} \geq 0$ for $n, m=1, \cdots, N$ and $\sum_{m=1}^N a_{nm}=1$ for $n=1, \cdots, N$.

In summary, we have
$$
y_n = \sum_{m=1}^N a_{nm}x_m
$$
where,
$$
a_{nm} \geq 0 \quad n, m=1, \cdots N \quad \text{and}\\
\sum_{m=1}^N a_{nm}=1 \quad n=1, \cdots N.
$$
The $a_{nm}$ parameters are called **attention weights**.
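
As a small illustration of these constraints, the sketch below builds context vectors from a hand-picked weight matrix. The embeddings and the weights are made-up values chosen only so the arithmetic is easy to follow; how the weights are actually computed is the subject of the next section.

```python
# Three token embeddings (N=3) of dimension D=2; the values are made up.
X = [[1.0, 0.0],
     [0.0, 1.0],
     [1.0, 1.0]]
N, D = len(X), len(X[0])

# One hypothetical choice of attention weights: row n holds a_{n1}, ..., a_{nN}
A = [[1.0, 0.0, 0.0],   # y_1 copies x_1
     [0.5, 0.5, 0.0],   # y_2 averages x_1 and x_2
     [1/3, 1/3, 1/3]]   # y_3 is the mean of all three tokens

# Check the two constraints: non-negative, and each row sums to one
for row in A:
    assert all(a >= 0 for a in row)
    assert abs(sum(row) - 1.0) < 1e-12

# y_n = sum_m a_{nm} x_m
Y = [[sum(A[n][m] * X[m][d] for m in range(N)) for d in range(D)]
     for n in range(N)]
print(Y)
```

Because the weights in each row are non-negative and sum to one, every context vector is a weighted average of the token embeddings.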

## Self-Attention Mechanism Without Trainable Parameters

Self-attention is the name given to a process for calculating the attention weights.
Many of the terms used here are from the field of information retrieval, so we
will begin with some definitions. If you are familiar with Python, recall
the dictionary object. Dictionaries in Python have **keys** and **values** (e.g. {'key': 'value'}).
Given a dictionary, a user provides a query. The query is what they want to find in the dictionary.
The query is compared against the keys in the dictionary, and when a match is made
the corresponding value is returned. This looks like this in Python:

In [2]:
example_dictionary = collections.defaultdict(str)
example_dictionary['self'] = 'attention'  # {key1: value1}
example_dictionary['attention'] = 'weights'  # {key2: value2}
example_dictionary['python'] = 'dictionary'  # {key3: value3}
example_dictionary['trainable'] = 'weights'  # {key4: value4}

query = 'self'
value = example_dictionary.get(query, 'not found')
print(f'Query: {query}, Value: {value}')

Query: self, Value: attention


The Python dictionary example is just one simple example of the idea behind key, value, and query triples.
You can imagine extending this so that the query and key do not have to match exactly as they did in our dictionary example.
We could take our query and search for the most "similar" key (we have to define what similar means). This idea will help us understand
the self-attention mechanism.
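
A rough sketch of such an inexact lookup is shown below. The vector-valued keys, the string values, and the choice of the dot product as the similarity measure are assumptions made for illustration only.

```python
def dot(u, v):
    """Dot product of two equal-length vectors."""
    return sum(a * b for a, b in zip(u, v))


# Hypothetical keys are small vectors rather than strings
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = ['east', 'north', 'west']

query = [0.9, 0.2]  # matches no key exactly

# Score every key against the query and return the value of the best match
scores = [dot(query, key) for key in keys]
best = max(range(len(keys)), key=lambda i: scores[i])
print(f'Scores: {scores}, Value: {values[best]}')
```

This lookup still returns a single value. Self-attention softens it further by returning a weighted blend of all the values, with weights determined by the similarity scores.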

In the self-attention mechanism the sequence of token embedding vectors, $\{x_1, \cdots, x_N\}$, can be thought of as the values, and these values will be used to create the context vectors. Furthermore, we will also use the sequence of token embedding vectors, $\{x_1, \cdots, x_N\}$, directly as the keys for each corresponding value. So, using the dictionary example from before, you could imagine our dictionary looks like $\{x_1:x_1, x_2:x_2, \cdots, x_N:x_N\}$. This is simply an analogy to help our understanding; we will of course not actually be using dictionaries in the attention mechanism.

Now consider a single token embedding vector, $x_n$, from our sequence of token embedding vectors, $\{x_1, \cdots, x_N\}$. The token embedding vector $x_n$ will be our query. For each $x_n$ we will measure the "degree of match" (or similarity) between $x_n$ and all the keys (our sequence of token embedding vectors $\{x_1, \cdots, x_N\}$). This "degree of match" will then be used as the weights in a linear combination of the values to produce the context vector $y_n$. We will repeat this procedure, treating each embedding vector as a query and using it to find the weights to use in a linear combination of the values to produce the corresponding context vector. Of course, using matrix multiplication we can calculate all the necessary weights at once.

A common method to determine the degree of match (or similarity) between a query vector and a key vector is to take their [dot product](https://en.wikipedia.org/wiki/Dot_product). Let $a^*_{nm}$ be the unconstrained attention weights. We will refer to these as attention scores as in (Raschka, 2024). Then for a given query $x_n$ and key $x_m$ the attention score is defined to be $a^*_{nm} = x^T_n x_m$. To constrain the values to be non-negative and sum to one we will apply the [softmax function](https://en.wikipedia.org/wiki/Softmax_function). This results in the attention weights:
$$
a_{nm} = \frac{\exp(a^*_{nm})}{\sum_{k=1}^N \exp(a^*_{nk})} = \frac{\exp(x^T_n x_m)}{\sum_{k=1}^N \exp(x^T_n x_k)}.
$$

Each attention weight $a_{nm}$ is the contribution of the key $x_m$ to the output context vector $y_n$. Mathematically, 
$$
y_n = \sum_{m=1}^N a_{nm}x_m.
$$
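
The two formulas above can be checked by hand for a single query. The sketch below is written in plain Python rather than PyTorch so that each step is explicit; it uses the same example embeddings as the PyTorch cell in this section and computes only the first row of attention weights and the first context vector. The `softmax` helper is my own and subtracts the maximum score before exponentiating, a common trick for numerical stability.

```python
import math

X = [
    [0.5341, 0.3316, 0.5995],  # x1
    [0.9891, 0.8921, 0.4602],  # x2
    [0.1234, 0.5678, 0.9101],  # x3
    [0.4567, 0.7890, 0.1234],  # x4
    [0.2345, 0.6789, 0.3456],  # x5
]


def softmax(scores):
    """Map a list of scores to non-negative weights that sum to one."""
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


x_1 = X[0]  # the query

# Attention scores a*_{1m} = x_1^T x_m for every key x_m
scores = [sum(a * b for a, b in zip(x_1, x_m)) for x_m in X]

# Attention weights a_{1m}
weights = softmax(scores)

# Context vector y_1 = sum_m a_{1m} x_m
y_1 = [sum(w * x_m[d] for w, x_m in zip(weights, X)) for d in range(len(x_1))]

print([round(w, 4) for w in weights])
print([round(v, 4) for v in y_1])
```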

Let $\mathbf{X}$ be a matrix of token embedding vectors where each **row** corresponds to a token embedding, $x_n$, and let $\mathbf{A}^*$ be the matrix of unconstrained attention scores (before the softmax function has been applied). Then we have $\mathbf{A}^* = \mathbf{X}\mathbf{X}^T$, where the $n$th row of $\mathbf{A}^*$ is a vector of scores that, once constrained using the softmax function, determine how much each input token embedding vector contributes to the linear combination that produces the context vector $y_n$. Applying the softmax function to each row independently gives $\mathbf{A} = \text{Softmax}[\mathbf{X}\mathbf{X}^T] = \text{Softmax}[\mathbf{A}^*]$. Finally, we can apply these attention weights to the values (another copy of the input token embeddings) using matrix multiplication. Let $\mathbf{Y}$ be the matrix of output context vectors where each **row** is a context vector in the new (hopefully) more contextually rich space. Then
$$
\mathbf{Y} = \text{Softmax}[\mathbf{X}\mathbf{X}^T]\mathbf{X}  = \text{Softmax}[\mathbf{A}^*]\mathbf{X} = \mathbf{AX}.
$$

Now let's see how to program this in [PyTorch](https://pytorch.org).

In [3]:
# Suppose we have a sentence of 5 words or "tokens"
# (e.g. "This is an example sentence.")
# Then we can represent each token as a vector of dimension d (e.g. d=3).

X = torch.tensor(
    [
        [0.5341, 0.3316, 0.5995],  # This        x1
        [0.9891, 0.8921, 0.4602],  # is          x2
        [0.1234, 0.5678, 0.9101],  # an          x3
        [0.4567, 0.7890, 0.1234],  # example     x4
        [0.2345, 0.6789, 0.3456],  # sentence    x5
    ]
)
print(f'X = {X}\n')

A_star = X @ X.T  # Unconstrained attention scores
print(f'A_star = {A_star}\n')

# Output:
# A_star = tensor([[0.7546, 1.1000, 0.7998, 0.5795, 0.5576],
#                  [1.1000, 1.9859, 1.0474, 1.2124, 0.9966],
#                  [0.7998, 1.0474, 1.1659, 0.6167, 0.7289],
#                  [0.5795, 1.2124, 0.6167, 0.8463, 0.6854],
#                  [0.5576, 0.9966, 0.7289, 0.6854, 0.6353]])

# Apply softmax across rows to constrain attention scores
A = torch.softmax(A_star, dim=1)
print(f'A = {A}\n')

# Output:
# A = tensor([[0.1953, 0.2759, 0.2044, 0.1640, 0.1604],
#         [0.1564, 0.3793, 0.1484, 0.1750, 0.1410],
#         [0.1822, 0.2334, 0.2628, 0.1517, 0.1698],
#         [0.1578, 0.2971, 0.1637, 0.2060, 0.1754],
#         [0.1679, 0.2605, 0.1993, 0.1908, 0.1815]])
#
# Looking at the first row of A, we see that to compute y1, the context vector
# for the first token embedding x1 ("This"), we will weight x1 (itself) by
# 0.1953, x2 ("is") by 0.2759, x3 ("an") by 0.2044, x4 ("example") by 0.1640,
# and x5 ("sentence") by 0.1604. This means that y1 is a weighted average of
# all five token embeddings, with x2 ("is") contributing the most.

# Compute the context vectors using matrix multiplication
Y = A @ X
print(f'Y = {Y}')  # notice that Y is the same shape as X

# Y = tensor([[0.5150, 0.6652, 0.5058], # This      y1
#         [0.5899, 0.7082, 0.4736],     # is        y2
#         [0.4698, 0.6529, 0.5333],     # an        y3
#         [0.5335, 0.6919, 0.4664],     # example   y4
#         [0.5016, 0.6750, 0.4882]])    # sentence  y5

X = tensor([[0.5341, 0.3316, 0.5995],
        [0.9891, 0.8921, 0.4602],
        [0.1234, 0.5678, 0.9101],
        [0.4567, 0.7890, 0.1234],
        [0.2345, 0.6789, 0.3456]])

A_star = tensor([[0.7546, 1.1000, 0.7998, 0.5795, 0.5576],
        [1.1000, 1.9859, 1.0474, 1.2124, 0.9966],
        [0.7998, 1.0474, 1.1659, 0.6167, 0.7289],
        [0.5795, 1.2124, 0.6167, 0.8463, 0.6854],
        [0.5576, 0.9966, 0.7289, 0.6854, 0.6353]])

A = tensor([[0.1953, 0.2759, 0.2044, 0.1640, 0.1604],
        [0.1564, 0.3793, 0.1484, 0.1750, 0.1410],
        [0.1822, 0.2334, 0.2628, 0.1517, 0.1698],
        [0.1578, 0.2971, 0.1637, 0.2060, 0.1754],
        [0.1679, 0.2605, 0.1993, 0.1908, 0.1815]])

Y = tensor([[0.5150, 0.6652, 0.5058],
        [0.5899, 0.7082, 0.4736],
        [0.4698, 0.6529, 0.5333],
        [0.5335, 0.6919, 0.4664],
        [0.5016, 0.6750, 0.4882]])


## Adding Trainable Parameters to Self-Attention

At this point the transformation from a token embedding vector $x_n$ to a context
vector $y_n$ is fixed, with no ability to learn "good" representations from the data.
To this end we will add weights that, through training, will allow us to learn
good context vectors $y_n$. In this context "good" means the context vectors are useful
for the task we want to accomplish.

There is one more subtle issue. The token embedding vectors can be considered 
as a set of features that describe the token the embeddings represent. In the 
self-attention mechanism described above each of these features has equal
weight in its contribution to the attention scores (recall $a_{nm}^* = x_n^Tx_m$). However,
it may be beneficial to allow some features to contribute more heavily to the attention scores.

We can address both these issues by adding trainable parameters to our self-attention
mechanism. Let $U$ be a $D \times D$ matrix of trainable weights and define
$$
\widetilde{\mathbf{X}} = \mathbf{XU}.
$$

Recall from linear algebra that every [linear transformation can be represented
by a matrix multiplication](https://en.wikipedia.org/wiki/Matrix_multiplication). So this
is simply applying a linear transformation to the rows of $\mathbf{X}$. This is also
the same thing as adding a linear "layer" (with no bias or activation function) to an artificial neural network.
This will allow the model to "learn" how to weight the features in each of the
token embedding vectors when computing the attention scores that will be used
to compute the context vectors.

The new unconstrained attention scores, computed from the linearly transformed token embedding vectors,
are given by
$$
\widetilde{\mathbf{A}}^* = \widetilde{\mathbf{X}} \widetilde{\mathbf{X}}^T = \mathbf{XU}\mathbf{U}^T\mathbf{X}^T.
$$

Plugging these new unconstrained attention scores into the formula for calculating the context vectors gives
$$
\mathbf{Y} = \text{Softmax}[\widetilde{\mathbf{A}}^*]\widetilde{\mathbf{X}} = \text{Softmax}[\widetilde{\mathbf{X}} \widetilde{\mathbf{X}}^T]\widetilde{\mathbf{X}} = \text{Softmax}[\mathbf{XU}\mathbf{U}^T\mathbf{X}^T]\mathbf{XU}.
$$

There are more improvements we can make to our attention scores. As summarized in
(Bishop & Bishop, 2023), the new unconstrained attention scores,
$\widetilde{\mathbf{A}}^* = \mathbf{XU}\mathbf{U}^T\mathbf{X}^T$, are symmetric.
However, there are many cases where we will want the scores to be asymmetric. For example,
we **may** want MacBook to have a stronger association with Apple (because the Apple company
makes the MacBook) than the association from Apple to MacBook, because Apple can be
used in many more contexts. Another limitation is that the same parameter matrix
(the learned linear transformation) is used both to compute the attention scores
and to transform the value vectors. We can get more flexibility by defining
separate transformed matrices, $\mathbf{Q}$ for queries, $\mathbf{K}$ for keys,
and $\mathbf{V}$ for values, each with its own independent set of parameters:
$$
\mathbf{Q} = \mathbf{XW}^{(q)}, \\
\mathbf{K} = \mathbf{XW}^{(k)}, \\
\mathbf{V} = \mathbf{XW}^{(v)}.
$$
Using our independently linearly transformed query and key matrices, we get the following unconstrained attention scores
$$
\mathbf{A}^* = \mathbf{Q}\mathbf{K}^T.
$$
Finally, using $\mathbf{V}$ for the independently linearly transformed values and plugging in our unconstrained attention scores to the 
formula for the context vectors gives
$$
\mathbf{Y} = \text{Softmax}[\mathbf{Q}\mathbf{K}^T]{\mathbf{V}}.
$$
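
The symmetry argument can be checked numerically. The sketch below uses plain Python lists and random matrices; the sizes, the seed, and the helper names (`matmul`, `transpose`, `is_symmetric`, `rand_matrix`) are my own choices for illustration and are not part of the PyTorch implementation that follows.

```python
import random


def matmul(A, B):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def transpose(A):
    return [list(col) for col in zip(*A)]


def is_symmetric(A, tol=1e-9):
    n = len(A)
    return all(abs(A[i][j] - A[j][i]) < tol
               for i in range(n) for j in range(n))


rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


N, D = 4, 3
X = rand_matrix(N, D)

# One shared matrix U: scores are (XU)(XU)^T, which is always symmetric
U = rand_matrix(D, D)
X_tilde = matmul(X, U)
shared_scores = matmul(X_tilde, transpose(X_tilde))

# Separate matrices for queries and keys: scores are Q K^T
W_q = rand_matrix(D, D)
W_k = rand_matrix(D, D)
Q = matmul(X, W_q)
K = matmul(X, W_k)
qk_scores = matmul(Q, transpose(K))

print(f'Shared U symmetric?     {is_symmetric(shared_scores)}')
print(f'Separate Q/K symmetric? {is_symmetric(qk_scores)}')
```

With separate query and key matrices, the score from token $n$ to token $m$ is free to differ from the score from $m$ to $n$.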

There is one last adjustment to make: we scale
the attention scores before applying the softmax. When its inputs are large in magnitude the softmax
saturates and its gradients become very small, so scaling the attention scores to keep
them from growing too large improves model training. We normalize by the square root
of the dimension of the query and key vectors, $D_k$. The reasoning behind this choice is
that if the elements of the query and key vectors were all independent random numbers
with mean $0$ and variance $1$, then their dot product would have variance $D_k$. So
we divide by the standard deviation, $\sqrt{D_k}$, to again achieve unit variance. This final step results in the
**scaled dot-product self-attention** form of self-attention
that is used in most modern language models:

$$
\mathbf{Y} = \text{Softmax}\left[\frac{\mathbf{Q}\mathbf{K}^T}{\sqrt{D_k}}\right]{\mathbf{V}}.
$$
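
The variance argument can be checked empirically. The sketch below draws query and key vectors with independent standard normal entries (the idealized assumption from the text, which will not hold exactly in a trained model) and compares the variance of the dot products before and after scaling. The values of `D_k`, `num_samples`, and the seed are arbitrary choices.

```python
import random
import statistics

rng = random.Random(0)
D_k = 64
num_samples = 2000

dots = []
for _ in range(num_samples):
    q = [rng.gauss(0.0, 1.0) for _ in range(D_k)]
    k = [rng.gauss(0.0, 1.0) for _ in range(D_k)]
    dots.append(sum(a * b for a, b in zip(q, k)))

# Variance of the raw dot products (should be close to D_k)
var_raw = statistics.pvariance(dots)

# Variance after dividing by sqrt(D_k) (should be close to 1)
var_scaled = statistics.pvariance([d / D_k**0.5 for d in dots])

print(f'D_k = {D_k}')
print(f'Variance of q.k:             {var_raw:.2f}')
print(f'Variance of q.k / sqrt(D_k): {var_scaled:.2f}')
```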

**The amazing and maybe the most important feature of the self-attention mechanism is that
the attention weights, which act like a set of weights in the network (i.e. a
linear transformation applied to $\mathbf{V}$), depend on the input data.**
This may not seem like much, but it is a rare feature of
artificial neural networks. In most artificial neural networks the weights in the
network are fixed after training.

Now let's implement this in PyTorch.

In [4]:
# Adapted from Section 3.4 from Raschka, S. (2024).
class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        # Recall linear layers, nn.Linear, are linear transformations which
        # are simply matrices.
        self.D_k = d_out
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, X):
        # Apply linear transformations to the input X to compute the
        # query, key, and value matrices
        Q = self.W_q(X)  # Q = XW_q
        K = self.W_k(X)  # K = XW_k
        V = self.W_v(X)  # V = XW_v

        # Compute attention weights
        # A = Softmax[Q K^t / sqrt(D_k)]
        attention_weights = torch.softmax(Q @ K.T / self.D_k**0.5, dim=-1)
        print(f'Attention Weights:\n{attention_weights}\n')

        # Apply attention weights to the value matrix V to
        # compute the context vectors Y
        Y = attention_weights @ V

        return Y


# Example usage
d_in = 3  # Dimension of input token embedding vectors
d_out = 3  # Dimension of output context vectors

self_attention = SelfAttentionV1(d_in, d_out)
output = self_attention(X)
print(f'Output Context Vectors of Self-Attention:\n{output}')

Attention Weights:
tensor([[0.1981, 0.1991, 0.2002, 0.2013, 0.2013],
        [0.2040, 0.1963, 0.2096, 0.1923, 0.1977],
        [0.2006, 0.2032, 0.2054, 0.1943, 0.1965],
        [0.2070, 0.1984, 0.2124, 0.1877, 0.1944],
        [0.2051, 0.2005, 0.2102, 0.1895, 0.1947]], grad_fn=<SoftmaxBackward0>)

Output Context Vectors of Self-Attention:
tensor([[ 0.3362, -0.4088,  0.2517],
        [ 0.3290, -0.4012,  0.2486],
        [ 0.3344, -0.4069,  0.2510],
        [ 0.3276, -0.3996,  0.2480],
        [ 0.3299, -0.4021,  0.2490]], grad_fn=<MmBackward0>)


## Positional Encoding

Given the current formulation of the self-attention mechanism, the ordering of
the token embedding vectors has no influence on the attention weight between any pair of tokens:
permuting the input tokens simply permutes the context vectors in the same way. One approach
to incorporating positional information is to define a position vector for each
position. These vectors could be concatenated with the token embedding vectors; however,
this would increase the dimensionality of the input, and with it the number of parameters in the model
and the computational cost. Another option is to force the positional vectors
to have the same dimension as the token embedding vectors and then add them together.
Let $\mathbf{r}_n$ be the position vector for the $n$th position and let $\widetilde{\mathbf{x}}_n$ be
the sum of the token embedding vector $\mathbf{x}_n$ and the position vector,
$$
\widetilde{\mathbf{x}}_n = \mathbf{x}_n + \mathbf{r}_n.
$$
There are many different variants of position vectors, and they are typically categorized
into two groups. The first are **absolute** position vectors, where each
position $n$ is assigned its own vector $\mathbf{r}_n$. A weakness of absolute position vectors
is that they generally do not generalize to sequences longer than the ones used during training.
The second group comprises **relative** position vectors.
Instead of encoding the absolute position, these vectors encode the relationship between two positions.
For example, the attention mechanism might consider the relative distance between a query token at position $i$ and a key token at position $j$.
This can help the model generalize to sequence lengths not observed during training.

We will follow the same approach as used in OpenAI's series of GPT models while acknowledging its weaknesses. The
GPT models learn absolute position vectors jointly with the rest of the model weights. Since these are
absolute position vectors, they will not generalize to sequences of unobserved length, because the
position vectors will be untrained for any position not observed during training.
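
To see why position information is needed at all, the sketch below applies the parameter-free self-attention from earlier, $\mathbf{Y} = \text{Softmax}[\mathbf{X}\mathbf{X}^T]\mathbf{X}$, to a sequence and to a shuffled copy of it. It is written in plain Python, and the position vectors `R` are made-up fixed values standing in for the learned ones.

```python
import math


def softmax(scores):
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def attention(X):
    """Parameter-free self-attention: Y = Softmax[X X^T] X."""
    Y = []
    for x_n in X:
        w = softmax([sum(a * b for a, b in zip(x_n, x_m)) for x_m in X])
        Y.append([sum(w_m * x_m[d] for w_m, x_m in zip(w, X))
                  for d in range(len(x_n))])
    return Y


def rows_match(Y_a, Y_b, perm, tol=1e-9):
    """True if row i of Y_a equals row perm[i] of Y_b (up to rounding)."""
    return all(abs(a - b) < tol
               for i, p in enumerate(perm)
               for a, b in zip(Y_a[i], Y_b[p]))


X = [[0.5341, 0.3316, 0.5995],
     [0.9891, 0.8921, 0.4602],
     [0.1234, 0.5678, 0.9101]]
perm = [2, 0, 1]                 # a reordering of the three tokens
X_perm = [X[p] for p in perm]

# Without position vectors, shuffling the inputs just shuffles the outputs
no_pos = rows_match(attention(X_perm), attention(X), perm)

# Hypothetical position vectors r_n (learned in a real model)
R = [[0.1 * n, -0.1 * n, 0.05 * n] for n in range(len(X))]


def add_positions(X):
    return [[x + r for x, r in zip(x_n, r_n)] for x_n, r_n in zip(X, R)]


# With position vectors, the same token gets a different context vector
# depending on where it appears in the sequence
with_pos = rows_match(attention(add_positions(X_perm)),
                      attention(add_positions(X)), perm)

print(f'Outputs only shuffled (no positions):   {no_pos}')
print(f'Outputs only shuffled (with positions): {with_pos}')
```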

Now let's add these position vectors to our self-attention mechanism.

In [5]:
class SelfAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, qkv_bias=False):
        super().__init__()
        # Recall linear layers, nn.Linear, are linear transformations which
        # are simply matrices.
        self.D_k = d_out
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

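        # POSITIONAL ENCODING
        # The position vectors are added to X, so they must have the same
        # dimension as the input token embeddings (d_in).
        self.positional_encoding = nn.Embedding(context_len, d_in)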

    def forward(self, X):
        seq_len, _ = X.shape  # Get the sequence length from the input X

        # Apply positional encoding to the input X
        X = X + self.positional_encoding(
            torch.arange(seq_len, device=X.device)
        )

        # Apply linear transformations to the input X to compute the
        # query, key, and value matrices
        Q = self.W_q(X)  # Q = XW_q
        K = self.W_k(X)  # K = XW_k
        V = self.W_v(X)  # V = XW_v

        # Compute attention weights
        # A = Softmax[Q K^t / sqrt(D_k)]
        attention_weights = torch.softmax(Q @ K.T / self.D_k**0.5, dim=-1)
        print(f'Attention Weights:\n{attention_weights}\n')

        # Apply attention weights to the value matrix V to
        # compute the context vectors Y
        Y = attention_weights @ V

        return Y


# Example usage
d_in = 3  # Dimension of input token embedding vectors
d_out = 3  # Dimension of output context vectors

self_attention = SelfAttention(d_in, d_out, context_len=X.shape[0])
output = self_attention(X)
print(f'Output Context Vectors of Self-Attention:\n{output}')

Attention Weights:
tensor([[0.1458, 0.0471, 0.2460, 0.2505, 0.3105],
        [0.0947, 0.0834, 0.3150, 0.1842, 0.3226],
        [0.2007, 0.4527, 0.0939, 0.1501, 0.1026],
        [0.1880, 0.0915, 0.2341, 0.2317, 0.2547],
        [0.2253, 0.2973, 0.1442, 0.1847, 0.1486]], grad_fn=<SoftmaxBackward0>)

Output Context Vectors of Self-Attention:
tensor([[ 0.0295,  0.0785,  0.1139],
        [ 0.0946,  0.1900,  0.1150],
        [-0.2950,  0.1823, -0.1397],
        [-0.0281,  0.0799,  0.0833],
        [-0.2052,  0.1193, -0.0484]], grad_fn=<MmBackward0>)


## Causal Attention

Foundational large language models are pre-trained on a one-token-ahead prediction task. 
Consider the sentence ``This is an example sentence`` and suppose we have only observed 
a sub-sequence of this sentence (e.g. ``This is an``) and we then want to predict the
next word, ``example``. In order not to cheat at this task, we need to make sure we
only consider the words up to the current word in our sub-sequence, ``This is an``.
This is the idea of causal attention, or masked attention. In causal attention we
restrict the model to only consider the previous and current inputs when
computing the attention weights.

Luckily, achieving this in practice is simple. We can simply set all of the attention
weights above the diagonal of our attention weight matrix (these correspond to the "future" inputs in the sequence) to zero
and then renormalize the weights in each row to sum to one. When we code this in PyTorch
we will instead set the attention scores above the diagonal to "-inf" before we apply the softmax. Since
$\exp(-\infty) = 0$, the softmax assigns those positions a weight of exactly zero and normalizes
over the remaining positions in one step. This is a simpler implementation in PyTorch.

**What about information leakage?** There is a subtle reason why this process does not
result in information leakage. You may notice the attention score matrix was
originally computed using the full sequence of inputs and then we set the values above
the diagonal to -inf, so it might appear that there is information about future
inputs in our attention scores. However, each score $a^*_{nm}$ depends only on the pair of inputs
at positions $n$ and $m$, and since we normalize the attention
scores after setting the values above the diagonal to -inf, the weights in row $n$ are computed
only from scores with $m \leq n$. Information
from future inputs is therefore destroyed (or not used) in the attention weights.
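
The two masking routes described above (zero out and renormalize, versus fill with -inf and then apply the softmax) can be compared directly. The sketch below does so in plain Python on a made-up $3 \times 3$ score matrix; the helper names are my own.

```python
import math

NEG_INF = float('-inf')


def softmax(scores):
    largest = max(scores)
    exps = [math.exp(s - largest) for s in scores]  # exp(-inf) is 0.0
    total = sum(exps)
    return [e / total for e in exps]


def causal_weights(scores):
    """Fill entries above the diagonal with -inf, then softmax each row."""
    n = len(scores)
    masked = [[scores[i][j] if j <= i else NEG_INF for j in range(n)]
              for i in range(n)]
    return [softmax(row) for row in masked]


# Made-up attention scores for a 3-token sequence
scores = [[0.2, 1.5, -0.3],
          [0.7, 0.1, 2.0],
          [-0.4, 0.9, 0.3]]

# Route 1: mask with -inf before the softmax
A = causal_weights(scores)

# Route 2: softmax first, zero out the future, then renormalize each row
A_alt = []
for i, row in enumerate(scores):
    w = [w_j if j <= i else 0.0 for j, w_j in enumerate(softmax(row))]
    total = sum(w)
    A_alt.append([w_j / total for w_j in w])

# Changing the scores of "future" positions must not change the weights
scores_changed = [[s if j <= i else 100.0 for j, s in enumerate(row)]
                  for i, row in enumerate(scores)]
A_changed = causal_weights(scores_changed)

for row in A:
    print([round(w, 4) for w in row])
```

Both routes give the same weights, and the weights do not change when the scores above the diagonal are altered.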

Let's now add this masking process to our self-attention class in PyTorch.

In [6]:
# Adapted from Section 3.5 from Raschka, S. (2024).
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, qkv_bias=False):
        super().__init__()
        # Recall linear layers, nn.Linear, are linear transformations which
        # are simply matrices.
        self.D_k = d_out
        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

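        # The position vectors are added to X, so they must have the same
        # dimension as the input token embeddings (d_in).
        self.positional_encoding = nn.Embedding(context_len, d_in)

        # Register a buffer for the causal mask to ensure the mask matrix
        # is not considered a trainable parameter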
        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones((context_len, context_len), dtype=torch.bool),
                diagonal=1,
            ),
        )

    def forward(self, X):
        seq_len, d_in = X.shape

        # Apply positional encoding to the input X
        X = X + self.positional_encoding(
            torch.arange(seq_len, device=X.device)
        )

        # Apply linear transformations to the input X to compute the
        # query, key, and value matrices
        Q = self.W_q(X)  # Q = XW_q
        K = self.W_k(X)  # K = XW_k
        V = self.W_v(X)  # V = XW_v

        # Compute attention scores
        attention_scores = Q @ K.T

        print(f'Attention Scores before masking:\n{attention_scores}\n')

        # Apply the causal mask to the attention scores
        # This ensures that information from future inputs is not used
        # in the attention scores.
        attention_scores.masked_fill_(
            self.mask[:seq_len, :seq_len],
            -torch.inf,  # type: ignore
        )

        print(f'Attention Scores after masking:\n{attention_scores}\n')

        # A = Softmax[Q K^t / sqrt(D_k)]
        attention_weights = torch.softmax(
            attention_scores / self.D_k**0.5, dim=-1
        )

        print(f'Constrained masked attention weights:\n{attention_weights}\n')

        # Apply attention weights to the value matrix V to
        # compute the context vectors Y
        Y = attention_weights @ V

        return Y


# Example usage
d_in = 3  # Dimension of input token embedding vectors
d_out = 3  # Dimension of output context vectors

causal_attention = CausalAttention(d_in, d_out, context_len=5)
output = causal_attention(X)
print(f'Output Context Vectors of Causal Attention:\n{output}')

Attention Scores before masking:
tensor([[-0.4652, -1.4155, -0.3993, -0.4332, -0.1389],
        [ 0.1301,  0.3980,  0.1119,  0.1782,  0.0268],
        [-0.1884, -0.5737, -0.1618, -0.1677, -0.0579],
        [-0.3804, -0.7427, -0.2699, -0.3657, -0.1499],
        [-0.2941, -0.9808, -0.2641, -0.2768, -0.0792]], grad_fn=<MmBackward0>)

Attention Scores after masking:
tensor([[-0.4652,    -inf,    -inf,    -inf,    -inf],
        [ 0.1301,  0.3980,    -inf,    -inf,    -inf],
        [-0.1884, -0.5737, -0.1618,    -inf,    -inf],
        [-0.3804, -0.7427, -0.2699, -0.3657,    -inf],
        [-0.2941, -0.9808, -0.2641, -0.2768, -0.0792]],
       grad_fn=<MaskedFillBackward0>)


## Multi-head Attention

A single self-attention mechanism in an artificial neural network is called an "attention head."
A single attention head, with its single set of learned query, key, and value matrices,
can be limited in the set of combinations of features it can construct. If several
different combinations of features might be important, a single attention
head will average over these combinations.

To address this, we can use multiple attention heads in parallel, where each attention
head has its own independent set of query, key, and value matrices. This
allows us to construct different combinations of features in each attention head
and therefore focus on different aspects of the input. Suppose we have $H$ attention
heads indexed by $h=1, \cdots, H$, each with the same form as derived above:
$$
\mathbf{H}_h =  \text{Softmax}\left[\frac{\mathbf{Q}_h\mathbf{K}^T_h}{\sqrt{D_k}}\right]{\mathbf{V}_h}
$$
where
$$
\mathbf{Q}_h = \mathbf{XW}^{(q)}_h, \\
\mathbf{K}_h = \mathbf{XW}^{(k)}_h, \\
\mathbf{V}_h = \mathbf{XW}^{(v)}_h. \\
$$
That is, each attention head gets its own triple of learnable weight matrices for queries, keys, and values, $\mathbf{W}^{(i)}_h$ for $i \in \{q, k, v\}$ and $h=1, \cdots, H$. The outputs of the $H$ attention heads, each with dimension $N \times D_v$, are concatenated into a single matrix and then linearly transformed with another matrix $\mathbf{W}^{(o)}$ of learnable parameters. This gives the matrix of context vectors
$$
\mathbf{Y} = \text{Concat}\left[\mathbf{H}_1 \cdots  \mathbf{H}_H\right]\mathbf{W}^{(o)}.
$$
Since each attention head output $\mathbf{H}_h$ has dimension $N \times D_v$, the concatenated matrix has dimension $N \times HD_v$, and therefore the matrix $\mathbf{W}^{(o)}$ has dimension $HD_v \times D$ to allow for the matrix multiplication and to produce a context vector matrix of the desired size (i.e. $N \times D$). In multi-head attention $D_v$ is typically chosen to equal $D/H$ so that the resulting concatenated matrix has dimension $N \times D$.
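
The shape bookkeeping can be traced with a small sketch in plain Python. Only the dimensions matter here, so each head output is a stand-in (a random linear map of $\mathbf{X}$) rather than a real attention computation; the sizes `N`, `D`, `H` and the helper names are arbitrary choices of mine.

```python
import random


def matmul(A, B):
    """Matrix product of two matrices stored as lists of rows."""
    return [[sum(a * b for a, b in zip(row, col)) for col in zip(*B)]
            for row in A]


def shape(A):
    return (len(A), len(A[0]))


rng = random.Random(0)


def rand_matrix(rows, cols):
    return [[rng.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


N, D, H = 5, 8, 2
D_v = D // H  # the typical choice D_v = D / H

X = rand_matrix(N, D)

# Stand-ins for the head outputs H_h, each of shape N x D_v. A real head
# would be Softmax[Q_h K_h^T / sqrt(D_k)] V_h.
heads = [matmul(X, rand_matrix(D, D_v)) for _ in range(H)]

# Concat[H_1 ... H_H]: join the heads side by side, row by row
concat = [sum((head[n] for head in heads), []) for n in range(N)]

# Output projection W_o has shape (H * D_v) x D
W_o = rand_matrix(H * D_v, D)
Y = matmul(concat, W_o)

print(f'Each head: {shape(heads[0])}')
print(f'Concat:    {shape(concat)}')
print(f'W_o:       {shape(W_o)}')
print(f'Y:         {shape(Y)}')
```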

Now let's code this up in PyTorch.

In [7]:
# Adapted from Section 3.6 from Raschka, S. (2024).
# Note: For educational purposes, this implementation is not optimized.
# Simply concatenating the outputs of multiple heads is not an
# efficient way to implement multi-head attention. See section
# 3.6.2 in Raschka, S. (2024) for a more efficient implementation where
# the heads are processed in parallel.


class MultiHeadAttentionV1(nn.Module):
    def __init__(self, d_in, d_out, context_len, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, context_len, qkv_bias)
                for _ in range(num_heads)
            ]
        )
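        # Each head outputs d_out features, so the concatenation of all heads
        # has num_heads * d_out features.
        self.W_o = nn.Linear(num_heads * d_out, d_out, bias=qkv_bias)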

    def forward(self, X):
        # Apply each head to the input X and concatenate the results
        head_outputs = torch.cat([head(X) for head in self.heads], dim=-1)
        Y = self.W_o(head_outputs)

        return Y


# Example usage
d_in = 3  # Dimension of input token embedding vectors
d_out = 3  # Dimension of output context vectors

# You will see the printed statements from the CausalAttention class
# twice because MultiHeadAttentionV1 uses CausalAttention for each head.
mha = MultiHeadAttentionV1(d_in, d_out, context_len=5, num_heads=2)
output = mha(X)
print(f'Output Context Vectors of Multihead Attention:\n{output}')

Attention Scores before masking:
tensor([[ 0.0691, -0.1044, -0.0681, -0.1564,  0.0174],
        [ 0.5019,  0.0876,  1.3337, -0.4958,  0.2746],
        [ 0.1776, -0.0414,  0.2506, -0.2281,  0.0801],
        [ 0.3385,  0.1831,  1.1695, -0.2406,  0.2070],
        [ 0.0517, -0.0381,  0.0312, -0.0867,  0.0198]], grad_fn=<MmBackward0>)

Attention Scores after masking:
tensor([[ 0.0691,    -inf,    -inf,    -inf,    -inf],
        [ 0.5019,  0.0876,    -inf,    -inf,    -inf],
        [ 0.1776, -0.0414,  0.2506,    -inf,    -inf],
        [ 0.3385,  0.1831,  1.1695, -0.2406,    -inf],
        [ 0.0517, -0.0381,  0.0312, -0.0867,  0.0198]],
       grad_fn=<MaskedFillBackward0>)


We can make our implementation more efficient by using a single query, key, and value
weight matrix triple for all heads instead of a separate triple per head.
Consider the query weight matrix (the logic is the same for all three matrices).
Conceptually, the large matrix $\mathbf{W}_Q$ is a stack of the individual head weight matrices
$\mathbf{W}_{Q1}, \mathbf{W}_{Q2}, \cdots, \mathbf{W}_{QH}$ placed side by side, where $H$ is the number of heads. That is,

$$
\begin{aligned}
\mathbf{W}_Q &= \left[\mathbf{W}_{Q1} \ \mathbf{W}_{Q2} \ \cdots \ \mathbf{W}_{QH} \right], \\
\mathbf{W}_K &= \left[\mathbf{W}_{K1} \ \mathbf{W}_{K2} \ \cdots \ \mathbf{W}_{KH} \right], \\
\mathbf{W}_V &= \left[\mathbf{W}_{V1} \ \mathbf{W}_{V2} \ \cdots \ \mathbf{W}_{VH} \right].
\end{aligned}
$$

Then it's easy to see that

$$
\begin{aligned}
\mathbf{XW}_Q &= \left[\mathbf{XW}_{Q1} \ \mathbf{XW}_{Q2} \ \cdots \ \mathbf{XW}_{QH} \right], \\
\mathbf{XW}_K &= \left[\mathbf{XW}_{K1} \ \mathbf{XW}_{K2} \ \cdots \ \mathbf{XW}_{KH} \right], \\
\mathbf{XW}_V &= \left[\mathbf{XW}_{V1} \ \mathbf{XW}_{V2} \ \cdots \ \mathbf{XW}_{VH} \right].
\end{aligned}
$$
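
The block identity can be checked numerically. The following is a minimal sketch, assuming tokens are stored as the rows of `X` (shape `(num_tokens, d_in)`), as in the code cells of this notebook. The names `W_heads`, `W_Q`, and `head_dim`, and the random values, are illustrative only and are not taken from the classes defined here.

```python
import torch

torch.manual_seed(0)

num_tokens, d_in, num_heads, head_dim = 5, 6, 2, 3

X = torch.rand(num_tokens, d_in)

# One (d_in, head_dim) query weight matrix per head (illustrative random values).
W_heads = [torch.rand(d_in, head_dim) for _ in range(num_heads)]

# Place the per-head matrices side by side: W_Q = [W_Q1 W_Q2 ... W_QH].
W_Q = torch.cat(W_heads, dim=1)  # shape (d_in, num_heads * head_dim)

# A single large matrix product ...
stacked = X @ W_Q

# ... matches the per-head products placed side by side.
per_head = torch.cat([X @ W for W in W_heads], dim=1)

identity_holds = torch.allclose(stacked, per_head)
print(identity_holds)
```

The single product does the work of all `num_heads` smaller products in one call, which is where the efficiency gain comes from.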

Using the ``.view`` method we can essentially unstack these matrices. From there 
we can use PyTorch's [broadcasting](https://docs.pytorch.org/docs/stable/notes/broadcasting.html)
semantics to calculate the attention weights and apply the causal mask.
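
To make the unstacking step concrete, the sketch below splits a stacked projection with `.view` and checks that each head's slice equals the matching block of columns. It then applies a single `(num_tokens, num_tokens)` causal mask, which broadcasts across the head dimension. The tensor `stacked_queries` stands in for $\mathbf{XW}_Q$ and is filled with random values purely for illustration; reusing the queries as keys is a shortcut to keep the example short.

```python
import torch

torch.manual_seed(0)

num_tokens, num_heads, head_dim = 5, 2, 3

# Stand-in for X @ W_Q with shape (num_tokens, num_heads * head_dim).
stacked_queries = torch.rand(num_tokens, num_heads * head_dim)

# Unstack: (num_tokens, num_heads * head_dim) -> (num_heads, num_tokens, head_dim).
queries = stacked_queries.view(num_tokens, num_heads, head_dim).transpose(0, 1)

# Head h corresponds to columns [h * head_dim, (h + 1) * head_dim).
blocks_match = all(
    torch.equal(queries[h], stacked_queries[:, h * head_dim:(h + 1) * head_dim])
    for h in range(num_heads)
)

# For brevity, reuse the queries as keys (an assumption of this sketch).
# scores has shape (num_heads, num_tokens, num_tokens).
scores = queries @ queries.transpose(-2, -1)

# One 2-D mask broadcasts over the leading head dimension.
mask = torch.triu(torch.ones(num_tokens, num_tokens, dtype=torch.bool), diagonal=1)
masked_scores = scores.masked_fill(mask, -torch.inf)

weights = torch.softmax(masked_scores / head_dim**0.5, dim=-1)
print(blocks_match, weights.shape)
```

Because the mask has no head dimension of its own, the same upper-triangular pattern is applied to every head without copying it.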

With this we can implement the multi-head attention mechanism as it is commonly
used in the Transformer architecture behind large language models.
These classes do not work for batched inputs, but that fix is easy to
implement and, as your math teacher might say, is left as an exercise for the reader.

In [8]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_len, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, 'd_out must be divisible by num_heads'

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_q = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_k = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_v = nn.Linear(d_in, d_out, bias=qkv_bias)

        self.W_o = nn.Linear(d_out, d_out)

        self.register_buffer(
            'mask',
            torch.triu(
                torch.ones((context_len, context_len), dtype=torch.bool),
                diagonal=1,
            ),
        )

    def forward(self, X):
        num_tokens, d_in = X.shape

        # Compute the stacked query, key, and value matrices (all heads at once)
        keys = self.W_k(X)
        queries = self.W_q(X)
        values = self.W_v(X)

        # Unstack the query, key, and value matrices
        # into multiple heads.
        # keys, queries, and values will now have shape
        # (num_tokens, num_heads, head_dim). Recall that
        # head_dim = d_out // num_heads
        keys = keys.view(num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(num_tokens, self.num_heads, self.head_dim)
        values = values.view(num_tokens, self.num_heads, self.head_dim)

        # swap the first two dimensions to have shape
        # (num_heads, num_tokens, head_dim)
        keys = keys.transpose(-3, -2)
        queries = queries.transpose(-3, -2)
        values = values.transpose(-3, -2)

        # Use PyTorch's broadcasting semantics to compute attention scores
        # for each head in parallel.
        attn_scores = queries @ keys.transpose(-2, -1)
        attn_scores.masked_fill_(
            self.mask[:num_tokens, :num_tokens], -torch.inf
        )

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1] ** 0.5, dim=-1
        )

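        # Swap back: (num_heads, num_tokens, head_dim)
        # -> (num_tokens, num_heads, head_dim)
        concat_heads = (attn_weights @ values).transpose(-3, -2)
        # Concatenate the heads into shape (num_tokens, d_out)
        concat_heads = concat_heads.contiguous().view(num_tokens, self.d_out)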
        Y = self.W_o(concat_heads)

        return Y

In [9]:
# Example usage
d_in = 6  # Dimension of input token embedding vectors
d_out = 6  # Dimension of output context vectors
context_len = 5  # Length of the context (number of tokens)

inputs = torch.rand(context_len, d_in)

# All heads are computed in a single pass, so unlike MultiHeadAttentionV1
# no per-head CausalAttention statements are printed.
mha = MultiHeadAttention(d_in, d_out, context_len, num_heads=2)
output = mha(inputs)
print(f'Output Context Vectors of Multihead Attention:\n{output}')

Output Context Vectors of Multihead Attention:
tensor([[-0.2345,  0.1695, -0.2975, -0.0606,  0.4637, -0.3039],
        [-0.1418,  0.1676, -0.2328, -0.1002,  0.3881, -0.2657],
        [-0.2019,  0.3206, -0.1069, -0.3274,  0.4580, -0.1621],
        [-0.0782,  0.2902, -0.2329, -0.1236,  0.1868, -0.3234],
        [-0.1032,  0.4003, -0.1043, -0.4110,  0.2032, -0.0542]],
       grad_fn=<AddmmBackward0>)
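
The classes above assume a single unbatched sequence. One possible way to handle batched inputs is to add a leading batch dimension and split the heads one dimension further in. The function below is a hedged sketch of that idea, not a drop-in replacement for the classes in this notebook: `batched_causal_mha` and its arguments are hypothetical names, the weights are passed as plain tensors rather than `nn.Linear` layers, and the output projection is omitted.

```python
import torch


def batched_causal_mha(X, W_q, W_k, W_v, num_heads):
    """Causal multi-head attention for X of shape (batch, num_tokens, d_in).

    W_q, W_k, and W_v each have shape (d_in, d_out); d_out must be
    divisible by num_heads. The output projection is omitted for brevity.
    """
    batch, num_tokens, _ = X.shape
    d_out = W_q.shape[1]
    head_dim = d_out // num_heads

    def split_heads(T):
        # (batch, num_tokens, d_out) -> (batch, num_heads, num_tokens, head_dim)
        return T.view(batch, num_tokens, num_heads, head_dim).transpose(1, 2)

    queries = split_heads(X @ W_q)
    keys = split_heads(X @ W_k)
    values = split_heads(X @ W_v)

    # scores has shape (batch, num_heads, num_tokens, num_tokens)
    scores = queries @ keys.transpose(-2, -1)
    mask = torch.triu(
        torch.ones(num_tokens, num_tokens, dtype=torch.bool, device=X.device),
        diagonal=1,
    )
    # The 2-D mask broadcasts over both the batch and head dimensions.
    scores = scores.masked_fill(mask, -torch.inf)
    weights = torch.softmax(scores / head_dim**0.5, dim=-1)

    # (batch, num_heads, num_tokens, head_dim) -> (batch, num_tokens, d_out)
    context = (weights @ values).transpose(1, 2)
    return context.contiguous().view(batch, num_tokens, d_out)


torch.manual_seed(0)
X_batch = torch.rand(4, 5, 6)  # (batch, num_tokens, d_in)
W_q, W_k, W_v = (torch.rand(6, 6) for _ in range(3))
Y_batch = batched_causal_mha(X_batch, W_q, W_k, W_v, num_heads=2)
print(Y_batch.shape)  # torch.Size([4, 5, 6])
```

A quick sanity check for any causal implementation is to perturb the last token and confirm that the outputs for all earlier positions are unchanged.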


## References
1. Bishop, C. M., & Bishop, H. (2023). Deep Learning: Foundations and Concepts. Springer.  
2. Raschka, S. (2024). Build a Large Language Model (From Scratch). Manning.