## Attention

Attention is an integral part of the LLM architecture. To understand it well, we will implement four variants of the attention mechanism, each building on the previous one, with the goal of arriving at a compact and efficient implementation of multi-head attention that plugs into the LLM architecture.

In [1]:
import torch



### 1. A simple self-attention mechanism without trainable weights

This simplified version highlights a few key concepts in self-attention before we add trainable weights.

An input sequence is denoted as x and consists of n elements, x^1 to x^n. The sequence represents text that has already been transformed into token embeddings. For example, the input text below is "Your journey starts with one step". Each element of the sequence, such as x^1, is a d-dimensional embedding vector representing a single token, like "Your" (in the example below, d = 3).

In self-attention, the goal is to calculate a context vector z^i for each element x^i in the input sequence. A <i>context vector</i> can be interpreted as an enriched embedding vector: each z^i combines information from x^i and all other input elements, x^1 to x^n. This is essential in an LLM, which needs to understand how the words in a sentence relate to one another. In practice, trainable weights let the LLM learn to build context vectors that are useful for generating the next token.

In [None]:
# Input sequence
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your      (x^1)
     [0.55, 0.87, 0.66], # journey   (x^2)
     [0.57, 0.85, 0.64], # starts    (x^3)
     [0.22, 0.58, 0.33], # with      (x^4)
     [0.77, 0.25, 0.10], # one       (x^5)
     [0.05, 0.80, 0.55]  # step      (x^6)
    ]
)

The first step in implementing self-attention is to compute the intermediate values w, known as attention scores. We do this by computing the dot product between the query (here x^2, as an example) and every input token, including the query itself. <u>A dot product multiplies two vectors element-wise and then sums the products</u>. It serves as a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a higher degree of similarity. In self-attention, the dot product determines the extent to which each element in a sequence focuses on (attends to) every other element: the higher the dot product, the higher the attention score between the two elements.

In [4]:
query = inputs[1] # x^2
atten_scores_2 = torch.empty(inputs.shape[0])

for i, x_i in enumerate(inputs):
    atten_scores_2[i] = torch.dot(x_i, query)

print(atten_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
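
To make the dot product concrete, the first score in the output above (between x^1 and the query x^2) can be reproduced by hand. This is a minimal sketch in plain Python, without PyTorch, so that the element-wise products stay visible; the variable names are illustrative only.

```python
x_1 = [0.43, 0.15, 0.89]  # "Your"
x_2 = [0.55, 0.87, 0.66]  # "journey" (the query)

# Multiply element-wise, then sum the products
products = [a * b for a, b in zip(x_1, x_2)]
score = sum(products)

print([round(p, 4) for p in products])  # [0.2365, 0.1305, 0.5874]
print(round(score, 4))                  # 0.9544
```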


Next, we normalise the scores to obtain attention weights that sum to 1. This makes the weights easier to interpret and helps keep training stable in an LLM. The simplest way is to divide each score by the sum of all scores. In practice we use the softmax function instead, which handles extreme values better and has more favourable gradient properties during training. Both are shown below.

In [5]:
# Naive normalisation: divide each score by the sum of all scores
atten_weights_2_tmp = atten_scores_2 / atten_scores_2.sum()
print("Attention weights:", atten_weights_2_tmp)
print("Sum:", atten_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


In [6]:
# With softmax
atten_weights_2 = torch.softmax(atten_scores_2, dim=0)
print("Attention weights:", atten_weights_2)
print("Sum:", atten_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
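
One way to see why softmax is preferred over dividing by the sum is to feed both a score vector that contains a negative value. The sketch below uses plain Python and made-up scores; `sum_normalise` and `softmax` are illustrative helpers written for this example, not the PyTorch implementations. Subtracting the maximum before exponentiating is a common trick to avoid overflow, assumed here for safety.

```python
import math

def sum_normalise(scores):
    """Divide each score by the total of all scores."""
    total = sum(scores)
    return [s / total for s in scores]

def softmax(scores):
    """Exponentiate, then normalise; subtracting the max avoids overflow."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

scores = [2.0, -1.0, 0.5]  # made-up scores, one of them negative

naive = sum_normalise(scores)
soft = softmax(scores)

print(naive)  # sums to 1, but contains a negative "weight" and one above 1
print(soft)   # every weight lies strictly between 0 and 1
```

Because the exponential is always positive, softmax weights can be read as proportions even when the raw scores are negative.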


Now that we have normalised attention weights, we calculate the context vector z^2 by multiplying each embedded input token x^i by its corresponding attention weight and then summing the resulting vectors.

In [None]:
# For x^2 alone
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)

for i, x_i in enumerate(inputs):
    context_vec_2 += atten_weights_2[i] * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])
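
The weighted sum in the loop above is the same operation as a single vector-matrix product. The following sketch recomputes everything from the inputs so that it runs on its own; `loop_vec` and `matmul_vec` are names chosen for this example.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

query = inputs[1]  # x^2
atten_weights_2 = torch.softmax(inputs @ query, dim=0)  # shape (6,)

# Loop version: accumulate weight * token embedding
loop_vec = torch.zeros(query.shape)
for w, x_i in zip(atten_weights_2, inputs):
    loop_vec += w * x_i

# Same result as one product: (6,) @ (6, 3) -> (3,)
matmul_vec = atten_weights_2 @ inputs

print(torch.allclose(loop_vec, matmul_vec))  # True
```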


In [9]:
# For all input tokens
attn_scores = torch.empty(inputs.shape[0], inputs.shape[0])  # one row and one column per token

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Above, the element in row i and column j is the attention score between inputs x^i and x^j. The same result can be computed more concisely, and without the slow nested for loops, using matrix multiplication.

In [11]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
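
Two properties of this score matrix are worth checking, since they follow directly from the dot product: the matrix is symmetric (the score between x^i and x^j equals the score between x^j and x^i), and its diagonal holds each token's dot product with itself. A minimal sketch, recomputing the scores so that it is self-contained:

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

attn_scores = inputs @ inputs.T  # shape (6, 6)

# Symmetry: score(i, j) == score(j, i)
is_symmetric = torch.allclose(attn_scores, attn_scores.T)

# Diagonal: each token's dot product with itself (its squared length)
squared_lengths = (inputs ** 2).sum(dim=1)
diag_matches = torch.allclose(attn_scores.diagonal(), squared_lengths)

print(is_symmetric, diag_matches)  # True True
```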


In [None]:
# Normalise so each row sums to 1
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


Above, by setting dim=-1, we instruct the softmax function to normalise along the last dimension of the tensor. Since attn_scores is a 2D tensor of shape [rows, columns], softmax normalises across the columns, so the values in each row sum to 1.
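
The effect of the `dim` argument is easiest to see on a small tensor. The sketch below uses made-up scores and compares `dim=-1` with `dim=0`; only the choice of dimension differs between the two calls.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])  # made-up 2x3 scores

row_wise = torch.softmax(scores, dim=-1)  # normalise within each row
col_wise = torch.softmax(scores, dim=0)   # normalise within each column

print(row_wise.sum(dim=-1))  # two values, each approximately 1
print(col_wise.sum(dim=0))   # three values, each approximately 1
```

For attention weights we want each query's row to sum to 1, which is why the cell above uses `dim=-1`.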

In [13]:
# Context vectors via matrix multiplication
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
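
The three steps of this section (scores, weights, context vectors) can be collected into one small function. This is a sketch for illustration: `simple_self_attention` is a hypothetical helper name, not a PyTorch function. As a sanity check, its second row should match the context vector z^2 computed earlier.

```python
import torch

def simple_self_attention(x):
    """Self-attention without trainable weights (hypothetical helper)."""
    attn_scores = x @ x.T                              # (n, n) pairwise dot products
    attn_weights = torch.softmax(attn_scores, dim=-1)  # each row sums to 1
    return attn_weights @ x                            # (n, d) context vectors

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)

all_context_vecs = simple_self_attention(inputs)
print(all_context_vecs.shape)  # torch.Size([6, 3])
print(all_context_vecs[1])     # approximately [0.4419, 0.6515, 0.5683]
```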


### 2. Implementing self-attention with trainable weights

Self-attention with trainable weights is the mechanism used in the original Transformer architecture, where it is also called <i>scaled dot-product attention</i>.