## Attention

Attention is an integral part of the LLM architecture. To get a good understanding of it, we will implement 4 variants of attention mechanisms which build on each other, with the goal of arriving at a compact and efficient implementation of multi-head attention to plug into the LLM architecture.

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn

### 1. A simple self-attention mechanism without trainable weights

This will highlight a few key concepts in self-attention before adding trainable weights.

An input sequence is denoted as x, consisting of n elements represented as x<sup>1</sup> to x<sup>n</sup>. This sequence represents text that has already been transformed into token embeddings. In the example below, the input text is "Your journey starts with one step". Each element of the sequence, such as x<sup>1</sup>, corresponds to a d-dimensional embedding vector representing a single token, like "Your" (in the example below, d = 3).

In self-attention, the goal is to calculate context vectors z<sup>i</sup> for each element x<sup>i</sup> in the input sequence. A <i>context vector</i> can be interpreted as an enriched embedding vector. Each context vector z<sup>i</sup> contains information about x<sup>i</sup> and all other input elements, x<sup>1</sup> to x<sup>n</sup>. This is essential in an LLM, which needs to understand the relationship and relevance of words in a sentence to each other. In practice, trainable weights help an LLM learn these context vectors to help it generate the next token.

In [2]:
# Input sequence
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your      (x^1)
     [0.55, 0.87, 0.66], # journey   (x^2)
     [0.57, 0.85, 0.64], # starts    (x^3)
     [0.22, 0.58, 0.33], # with      (x^4)
     [0.77, 0.25, 0.10], # one       (x^5)
     [0.05, 0.80, 0.55]  # step      (x^6)
    ]
)

The first step of implementing self-attention is to compute the intermediate values w, known as attention scores. We do this by computing the dot product between the query (as an example, x<sup>2</sup>) and every input token, including the query itself. <u>A dot product is the multiplication of two vectors element-wise, followed by summing the products</u>. It is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a higher degree of similarity between the vectors. For self-attention, the dot product determines the extent to which each element in a sequence focuses on (attends to) any other element. The higher the dot product, the higher the similarity and attention score between two elements.

In [3]:
query = inputs[1] # x^2
atten_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    atten_scores_2[i] = torch.dot(x_i, query)

print(atten_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [40]:
# Dot product manually
res = 0
for idx, element in enumerate(inputs[0]): # for each element in the first vector
    print(f"Element {idx+1}: {torch.round(element, decimals=2):.2f} * query position {idx+1}: "
          f"{torch.round(query[idx], decimals=2):.2f} = "
          f"{torch.round(element * query[idx], decimals=2):.3f}") 
    res += inputs[0][idx] * query[idx]
print(f"Sum of products: {torch.round(res, decimals=4).item():.4f}")
print("Using torch.dot():", torch.dot(inputs[0], query))

Element 1: 0.43 * query position 1: 0.55 = 0.240
Element 2: 0.15 * query position 2: 0.87 = 0.130
Element 3: 0.89 * query position 3: 0.66 = 0.590
Sum of products: 0.9544
Using torch.dot(): tensor(0.9544)


We then normalise the scores to obtain attention weights that sum to 1. This is useful for interpretation and for maintaining training stability in an LLM. The simplest approach is to divide each score by the sum of all scores, as in the first cell below. In practice we use the softmax function instead, which is better at managing extreme values, always produces positive weights, and offers more favourable gradient properties during training.

In [7]:
atten_scores_2_tmp = atten_scores_2 / atten_scores_2.sum()
print("Attention weights:", atten_scores_2_tmp)
print("Sum:", atten_scores_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [8]:
# With softmax
atten_weights_2 = torch.softmax(atten_scores_2, dim=0)
print("Attention weights:", atten_weights_2)
print("Sum:", atten_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
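
To make the difference between the two normalisation schemes more concrete, the sketch below re-implements both in plain Python (no PyTorch, so the arithmetic stays visible). The helper names `sum_normalise` and `softmax`, and the example scores with a negative entry, are assumptions made for illustration and are not part of the notebook. Subtracting the maximum before exponentiating is a common numerical-stability trick that leaves the result unchanged.

```python
import math

def sum_normalise(scores):
    """Divide each score by the total of all scores."""
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    """Exponentiate, then normalise; subtracting the max avoids overflow."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

# Hypothetical scores including a negative value (dot products can be negative)
scores = [2.0, -1.0, 0.5]

naive = sum_normalise(scores)
soft = softmax(scores)

print("Sum-normalised:", [round(w, 4) for w in naive])  # includes a negative "weight"
print("Softmax:       ", [round(w, 4) for w in soft])   # every weight is positive
```

Both results sum to 1, but only the softmax output can be read as a set of proportions, because exponentiation maps every score to a positive number.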


Now that we have normalised attention weights, we calculate the context vector z<sup>2</sup> by multiplying the embedded input tokens x<sup>i</sup> by the corresponding attention weights and then summing the resulting vectors.

In [9]:
# For x^2 alone
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += atten_weights_2[i] * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


Why are there only 3 elements in the vector, rather than 6? Because each of the input tokens is 3-dimensional and the context vector is a weighted sum of those tokens. Notice below that multiplying each 3-d input vector by the corresponding attention weight also results in a 3-d vector. The context vector is the sum of these six vectors, taken column by column.

In [14]:
# One example
print("First input token vector:", inputs[0])
print("First attention weight:", atten_weights_2[0])
print('')
print("First row when calculating context vector:", atten_weights_2[0] * inputs[0])
print("Second row when calculating context vector:", atten_weights_2[1] * inputs[1])
print("Third row when calculating context vector:", atten_weights_2[2] * inputs[2])
print("Fourth row when calculating context vector:", atten_weights_2[3] * inputs[3])
print("Fifth row when calculating context vector:", atten_weights_2[4] * inputs[4])
print("Sixth row when calculating context vector:", atten_weights_2[5] * inputs[5])

First input token vector: tensor([0.4300, 0.1500, 0.8900])
First attention weight: tensor(0.1385)

First row when calculating context vector: tensor([0.0596, 0.0208, 0.1233])
Second row when calculating context vector: tensor([0.1308, 0.2070, 0.1570])
Third row when calculating context vector: tensor([0.1330, 0.1983, 0.1493])
Fourth row when calculating context vector: tensor([0.0273, 0.0719, 0.0409])
Fifth row when calculating context vector: tensor([0.0833, 0.0270, 0.0108])
Sixth row when calculating context vector: tensor([0.0079, 0.1265, 0.0870])


In [9]:
# For all input tokens
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


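Above, each element in the tensor is the attention score for one pair of inputs: row i, column j holds the dot product of x<sup>i</sup> and x<sup>j</sup>. The same result can be computed more concisely with a matrix multiplication instead of the double for loop.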

In [11]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [None]:
# Normalise so each row sums to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights) 

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


Above, by setting dim = -1, we instruct the softmax function to normalise along the last dimension of the tensor. Since it is a 2d-tensor [rows, columns], it will normalise across the columns so that the values in each row sum to 1.
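
As a rough illustration of what normalising along the last dimension means, the plain-Python sketch below applies softmax to each row of a small nested list. The helper `softmax_rows` and the 2×3 example matrix are assumptions made for illustration; in the notebook itself, `torch.softmax(..., dim=-1)` does this work.

```python
import math

def softmax(row):
    """Softmax over a single list of scores."""
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def softmax_rows(matrix):
    """Normalise each row independently, mirroring dim=-1 on a 2-d tensor."""
    return [softmax(row) for row in matrix]

# Hypothetical 2x3 score matrix
scores = [[1.0, 2.0, 3.0],
          [1.0, 1.0, 1.0]]

weights = softmax_rows(scores)
row_sums = [sum(row) for row in weights]
col_sums = [sum(col) for col in zip(*weights)]

print("Row sums:", row_sums)     # each is 1, up to floating-point error
print("Column sums:", col_sums)  # generally not 1
```

Each row is one query's distribution over the inputs, which is why the rows, not the columns, need to sum to 1.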

In [13]:
# Context vectors via matrix multiplication
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
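
The three steps of this section (scores, softmax, weighted sum) can also be collected into one small function. The version below is a sketch in plain Python rather than PyTorch, so that every multiplication and sum is explicit; the function name `simple_self_attention` and the helpers `dot` and `softmax` are hypothetical names chosen for this example.

```python
import math

inputs = [
    [0.43, 0.15, 0.89],  # Your
    [0.55, 0.87, 0.66],  # journey
    [0.57, 0.85, 0.64],  # starts
    [0.22, 0.58, 0.33],  # with
    [0.77, 0.25, 0.10],  # one
    [0.05, 0.80, 0.55],  # step
]

def dot(a, b):
    """Multiply element-wise, then sum."""
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def simple_self_attention(xs):
    """Return one context vector per input, using no trainable weights."""
    context_vecs = []
    for query in xs:
        scores = [dot(query, x) for x in xs]   # attention scores
        weights = softmax(scores)              # attention weights
        # Weighted sum of the input vectors, one dimension at a time
        context = [sum(w * x[d] for w, x in zip(weights, xs))
                   for d in range(len(query))]
        context_vecs.append(context)
    return context_vecs

context_vecs = simple_self_attention(inputs)
print([round(v, 4) for v in context_vecs[1]])  # context vector for "journey"
```

The printed vector should agree with the second row of the tensor above to about three decimal places; small differences in the last digit can arise because Python floats are 64-bit while the tensors here are 32-bit.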


### 2. Implementing self-attention with trainable weights

This is the mechanism used in the original Transformer architecture - also called <i>scaled dot-product attention</i>. The principle is the same as the previous section; the most notable difference is that we introduce weight matrices that are updated during model training. These are crucial so that the model (the attention module inside the model) can learn to produce "good" context vectors.

The trainable weight matrices are W<sub>q</sub>, W<sub>k</sub>, and W<sub>v</sub>, which project the embedded input tokens x<sup>i</sup> into query, key, and value vectors respectively.

A <i>query</i> is analogous to a search query in a database. It represents the current token the model focuses on or tries to understand. It is used to probe the other parts of the input sequence to determine how much attention to pay to them.

A <i>key</i> is like a database key used for indexing. Each item in the input sequence (each token) has an associated key, which is used to match against the query.

A <i>value</i> is similar to the value in a key-value pair in a database. It represents the actual content of the input items. Once the model determines which keys (which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values. 
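
One way to read the database analogy: a Python dict performs a hard lookup, where the query must equal a key exactly and a single value comes back. Attention behaves more like a soft lookup, where every key is compared with the query and the values are blended according to how well their keys match. The sketch below illustrates this under stated assumptions: the keys, values and query are made-up 2-d vectors, and `soft_lookup` is a hypothetical helper, not something defined in the notebook.

```python
import math

def dot(a, b):
    return sum(x * y for x, y in zip(a, b))

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

def soft_lookup(query, keys, values):
    """Blend all values, weighted by how well each key matches the query."""
    weights = softmax([dot(query, k) for k in keys])
    dim = len(values[0])
    blended = [sum(w * v[d] for w, v in zip(weights, values)) for d in range(dim)]
    return blended, weights

# Made-up keys and values for three "database entries"
keys = [[1.0, 0.0], [0.0, 1.0], [-1.0, 0.0]]
values = [[10.0, 0.0], [0.0, 10.0], [5.0, 5.0]]

# Points in the same direction as the first key, opposite to the third
query = [3.0, 0.0]
result, weights = soft_lookup(query, keys, values)

print("Match weights:", [round(w, 3) for w in weights])
print("Retrieved value:", [round(r, 3) for r in result])
```

The retrieved value is dominated by the first entry but still contains a little of the others. In the trainable mechanism, the queries, keys and values are not hand-picked like this; they are produced by projecting the inputs with W<sub>q</sub>, W<sub>k</sub> and W<sub>v</sub>.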

In [43]:
# Again we calculate one context vector
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # input embedding size, d = 3
d_out = 2 # output embedding size, d_out = 2

Note that in GPT-like models the input and output dimensions are usually the same; they differ here to make the calculation easier to follow. Next, we initialise the three weight matrices. requires_grad is set to False to reduce clutter in the outputs; if we were to use the weight matrices for model training, we would set it to True so that they are updated.

In [44]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [None]:
# What one looks like - random values of the right shape
print(W_query)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


In [None]:
# Compute the vectors by multiplying the input by the weight matrices
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2) # query vector for input vector 2

tensor([0.4306, 1.4551])


In [58]:
# Manually
# x_2 is one token * 3 dimensions
# W_query is 3 rows by 2 dimensions (output dimension)
row_mul = 0
total = 0
for idx, element in enumerate(x_2):
    print(f"Element {idx + 1}: {torch.round(element, decimals=2):.2f} * "
          f"W_query row {idx + 1}: {W_query[idx]}")
    row_mul += x_2[idx] * W_query[idx]
    total += x_2[idx] * W_query[idx]
    print(f"Result of matrix multiplication for row {idx + 1}: {row_mul}")
    row_mul = 0
    print(' ')

print(f"Sum along the columns: {total}")

Element 1: 0.55 * W_query row 1: tensor([0.2961, 0.5166])
Result of matrix multiplication for row 1: tensor([0.1629, 0.2841])
 
Element 2: 0.87 * W_query row 2: tensor([0.2517, 0.6886])
Result of matrix multiplication for row 2: tensor([0.2190, 0.5990])
 
Element 3: 0.66 * W_query row 3: tensor([0.0740, 0.8665])
Result of matrix multiplication for row 3: tensor([0.0488, 0.5719])
 
Sum along the columns: tensor([0.4306, 1.4551])


<u>Note</u>: in the weight matrices, W, "weight" is short for "weight parameters", the values of a neural network that are optimised during training. This is not to be confused with attention weights, which determine the extent to which a context vector depends on the different parts of the input. So, weight parameters are the fundamental, learned coefficients, while attention weights are dynamic, context-specific values.

We still require the key and value vectors for all the input elements as they are involved in computing the weights with respect to the query, q<sup>2</sup>.

In [None]:
# For all the input vectors
keys = inputs @ W_key
values = inputs @ W_value
print("Keys.shape:", keys.shape)
print("Values.shape:", values.shape)

Keys.shape: torch.Size([6, 2])
Values.shape: torch.Size([6, 2])


The six input tokens have been projected from a 3-d embedding space to a 2-d one. Next, we compute the attention scores.

In [None]:
# Attention score w_22 only
key_2 = keys[1]
attn_score_22 = query_2.dot(key_2) # query vector . key vector to compute attention score
print(attn_score_22)

tensor(1.8524)


In [None]:
# To compute all attention scores for query 2
attn_scores_2 = query_2 @ keys.T
print(attn_scores_2) # Note second element matches the value above

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


To move from attention scores to attention weights, we apply the softmax function, but first scale the scores by dividing them by the square root of the embedding dimension of the keys (taking the square root is the same as raising to the power 0.5, hence d_k**0.5 in the code). This improves training performance by avoiding very small gradients. GPT-like LLMs typically have embedding dimensions greater than 1000, and the resulting large dot products can lead to very small gradients when the softmax function is applied. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing 0, which slows down learning or causes training to stagnate.

In [21]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
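
The saturation effect described above can be seen with a few numbers. The sketch below is in plain Python and uses made-up values throughout: `d_k = 1024` stands in for a large key dimension, and `raw_scores` are hypothetical unscaled dot products of the size such a dimension might produce. It uses the standard softmax derivative, d p<sub>i</sub> / d s<sub>i</sub> = p<sub>i</sub> (1 - p<sub>i</sub>), as a rough proxy for how much gradient flows back through the largest weight.

```python
import math

def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 1024                              # hypothetical key dimension
raw_scores = [32.0, 16.0, 0.0, -16.0]   # made-up unscaled dot products
scaled_scores = [s / d_k**0.5 for s in raw_scores]

unscaled = softmax(raw_scores)
scaled = softmax(scaled_scores)

# Diagonal of the softmax derivative for the largest score: p * (1 - p)
grad_unscaled = unscaled[0] * (1 - unscaled[0])
grad_scaled = scaled[0] * (1 - scaled[0])

print("Largest weight, unscaled:", unscaled[0])
print("Largest weight, scaled:  ", scaled[0])
print("Gradient proxy, unscaled:", grad_unscaled)
print("Gradient proxy, scaled:  ", grad_scaled)
```

Without scaling, almost all of the weight lands on one element and the gradient proxy is close to zero; after dividing by the square root of d_k, the weights are spread out and the gradient is several orders of magnitude larger.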


Lastly, we compute the context vector as a weighted sum over the value vectors. The attention weights serve as weighting factors that determine the respective importance of each value vector.

In [22]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])
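
Because the attention weights are positive and sum to 1, each context vector is a blend of the value vectors and cannot fall outside the range they span. The plain-Python sketch below walks through the same steps as this section (project, score, scale, softmax, weighted sum) and checks that property. The two input rows are taken from the example sequence, but the weight matrices are hypothetical values chosen for readability, not the randomly initialised ones used above, and the helper names are assumptions of this example.

```python
import math

def matmul(a, b):
    """Multiply two matrices given as nested lists."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]

def transpose(m):
    return [list(col) for col in zip(*m)]

def softmax(row):
    m = max(row)
    exps = [math.exp(v - m) for v in row]
    total = sum(exps)
    return [e / total for e in exps]

def scaled_dot_product_attention(x, w_query, w_key, w_value):
    queries = matmul(x, w_query)
    keys = matmul(x, w_key)
    values = matmul(x, w_value)
    d_k = len(keys[0])                             # key dimension
    scores = matmul(queries, transpose(keys))      # attention scores
    weights = [softmax([s / d_k**0.5 for s in row]) for row in scores]
    return matmul(weights, values), weights, values

x = [[0.43, 0.15, 0.89],   # Your
     [0.55, 0.87, 0.66]]   # journey

# Hypothetical 3x2 weight matrices (d_in = 3, d_out = 2)
w_query = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
w_key = [[0.6, 0.5], [0.4, 0.3], [0.2, 0.1]]
w_value = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]

context, weights, values = scaled_dot_product_attention(x, w_query, w_key, w_value)
print("Value vectors:    ", values)
print("Attention weights:", weights)
print("Context vectors:  ", context)
```

In each column, the context entries should lie between the smallest and largest entries of the value vectors.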


### Implementing a compact self-attention class

It is helpful to organise all the previous code into a class.

In [3]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

This class is a subclass of nn.Module, which provides necessary functionalities for model layer creation and management. 

The <i>\_\_init\_\_</i> method initialises trainable weight matrices, each transforming the input dimension d_in to an output dimension d_out.

The forward method computes the attention scores by multiplying queries and keys, then normalises these scores using softmax to obtain attention weights. Finally, it creates the context vectors by weighting the values with these attention weights.

In [7]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Since inputs contains 6 embedding vectors, this results in a matrix storing the six context vectors, with 2 columns matching d_out. The second row matches what we got manually before.

We can improve the class further by utilising PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. nn.Linear also has an optimised weight initialisation scheme, contributing to more stable and effective model training.

In [8]:
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [None]:
# Different outputs due to different initial weights
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
