## Attention

Attention is an integral part of the LLM architecture.

In [1]:
import torch

In [3]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your      (x^1)
     [0.55, 0.87, 0.66], # journey   (x^2)
     [0.57, 0.85, 0.64], # starts    (x^3)
     [0.22, 0.58, 0.33], # with      (x^4)
     [0.77, 0.25, 0.10], # one       (x^5)
     [0.05, 0.80, 0.55]  # step      (x^6)
    ]
)

The tensor above represents the six-token input sentence "Your journey starts with one step", with each token already embedded as a three-dimensional vector (a small dimension chosen for illustration).

The first step of implementing self-attention is to compute the intermediate values w, known as attention scores. Taking x^2 as the query, we compute the dot product between the query and every input token, including x^2 itself. <u>A dot product multiplies two vectors element-wise and then sums the products</u>. It serves as a measure of similarity because it quantifies how closely two vectors are aligned: for vectors of comparable length, a higher dot product indicates a higher degree of similarity. In self-attention, the dot product determines the extent to which each element in a sequence focuses on (attends to) every other element: the higher the dot product, the higher the attention score between the two elements.
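The definition of the dot product can be checked directly. The sketch below re-declares the `inputs` tensor so the cell is self-contained, multiplies x^1 and the query element-wise, sums the products, and compares the result with `torch.dot`. Both values should match the first score printed further down (0.9544). The names `manual` and `builtin` are illustrative only.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your      (x^1)
     [0.55, 0.87, 0.66],  # journey   (x^2)
     [0.57, 0.85, 0.64],  # starts    (x^3)
     [0.22, 0.58, 0.33],  # with      (x^4)
     [0.77, 0.25, 0.10],  # one       (x^5)
     [0.05, 0.80, 0.55]]  # step      (x^6)
)
query = inputs[1]  # x^2

elementwise = inputs[0] * query        # element-wise products of x^1 and the query
manual = elementwise.sum()             # sum of those products = dot product
builtin = torch.dot(inputs[0], query)  # PyTorch's built-in dot product

print(manual, builtin)
```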

In [4]:
query = inputs[1] # x^2
atten_scores_2 = torch.empty(inputs.shape[0])  # one score per input token

for i, x_i in enumerate(inputs):
    atten_scores_2[i] = torch.dot(x_i, query)  # similarity of x^i to the query

print(atten_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
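The loop makes each dot product explicit, but the same scores can be computed in a single call: row i of the matrix-vector product `inputs @ query` is exactly the dot product of x^i with the query. A minimal sketch; the name `atten_scores_2_vec` is ours, not from the original notebook.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your      (x^1)
     [0.55, 0.87, 0.66],  # journey   (x^2)
     [0.57, 0.85, 0.64],  # starts    (x^3)
     [0.22, 0.58, 0.33],  # with      (x^4)
     [0.77, 0.25, 0.10],  # one       (x^5)
     [0.05, 0.80, 0.55]]  # step      (x^6)
)
query = inputs[1]  # x^2

# Loop version, as in the cell above
atten_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    atten_scores_2[i] = torch.dot(x_i, query)

# Vectorised version: (6, 3) @ (3,) -> (6,), one dot product per row
atten_scores_2_vec = inputs @ query

print(torch.allclose(atten_scores_2, atten_scores_2_vec))
```

The vectorised form is what real implementations use, since it runs as one batched operation instead of a Python loop.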


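We then normalise the attention scores so that they sum to 1; the normalised values are the attention weights. Normalisation makes the weights easier to interpret (each one is the share of attention the query gives to a token) and helps keep training stable in an LLM. The simplest approach, used in the next cell, is to divide each score by the sum of all scores. In practice, the softmax function is used instead, because it handles extreme values better and has more favourable gradient properties during training.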

In [5]:
atten_weights_2_tmp = atten_scores_2 / atten_scores_2.sum()  # simple sum normalisation
print("Attention weights:", atten_weights_2_tmp)
print("Sum:", atten_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)
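Since softmax is the normalisation used in practice, here is a minimal sketch of it applied to the same scores. `softmax_naive` and `softmax_stable` are hypothetical helper names for illustration; `torch.softmax` is PyTorch's built-in. Softmax exponentiates each score before dividing by the sum, so the weights are always positive, even for negative scores. Subtracting the maximum score first leaves the result mathematically unchanged (the common factor cancels) but keeps `exp` from overflowing on very large scores, which is one concrete sense in which softmax can be made robust to extreme values.

```python
import torch

atten_scores_2 = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

def softmax_naive(x):
    # exp of each score divided by the sum of all exps
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Shifting by the max does not change the result mathematically,
    # but keeps exp() from overflowing for large scores
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

atten_weights_2 = torch.softmax(atten_scores_2, dim=0)
print("Attention weights:", atten_weights_2)
print("Sum:", atten_weights_2.sum())

# Both hand-written versions agree with torch.softmax on ordinary scores
print(torch.allclose(softmax_naive(atten_scores_2), atten_weights_2))
print(torch.allclose(softmax_stable(atten_scores_2), atten_weights_2))

# With extreme scores, exp() overflows to inf in float32 and inf / inf gives nan
extreme = torch.tensor([1000.0, 999.0, 998.0])
print(softmax_naive(extreme))   # all nan
print(softmax_stable(extreme))  # finite weights that sum to 1
```

In practice `torch.softmax` is preferable to a hand-written version, since it is optimised and handles large inputs without overflowing.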
