
Self-attention with trainable weights is also called **scaled dot-product attention**.

Now, we introduce trainable weight matrices that are updated during training. They are used so that the model learns to produce good context vectors. These matrices are:

* $W_q$ for queries
* $W_k$ for keys
* $W_v$ for values

These matrices project each input vector $x^{(i)}$ into three new vectors: a query, a key, and a value.
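
In symbols, the projections can be sketched as follows. This assumes each input $x^{(i)}$ is treated as a row vector of size $d_{in}$ and each weight matrix has shape $d_{in} \times d_{out}$; the names $d_{in}$ and $d_{out}$ are introduced here only for illustration.

$$
q^{(i)} = x^{(i)} W_q, \qquad k^{(i)} = x^{(i)} W_k, \qquad v^{(i)} = x^{(i)} W_v
$$

Each of $q^{(i)}$, $k^{(i)}$, and $v^{(i)}$ is then a row vector of size $d_{out}$.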

In [1]:
import torch

# Create the input tensors
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

In [2]:
# Select the second input vector for demonstration purposes
x_2 = inputs[1]

# Select the dimension of the inputs. 3D in this case
dimension_inputs = inputs.shape[1]

# Select the output dimension for the projected vectors. 2D in this case
dimension_outputs = 2

In [3]:
# Define the weight matrices, initialized in a random manner
# requires_grad=False because we're not training yet
torch.manual_seed(1)
W_q = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
W_k = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
W_v = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)

In [4]:
print(W_q)

Parameter containing:
tensor([[ 0.6614,  0.2669],
        [ 0.0617,  0.6213],
        [-0.4519, -0.1661]])


In [5]:
print(W_k)

Parameter containing:
tensor([[-1.5228,  0.3817],
        [-1.0276, -0.5631],
        [-0.8923, -0.0583]])


In [6]:
print(W_v)

Parameter containing:
tensor([[-0.1955, -0.9656],
        [ 0.4224,  0.2673],
        [-0.4212, -0.5107]])


Now, in this case, let's obtain the query, key, and value vectors for the input vector $x^{(2)}$ corresponding to "journey".

In [7]:
query_2 = x_2 @ W_q
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v

print(query_2)
print(key_2)
print(value_2)

tensor([0.1191, 0.5777])
tensor([-2.3205, -0.3184])
tensor([-0.0180, -0.6356])
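
To make the matrix product less of a black box, here is a small sketch that recomputes the query for "journey" by hand in plain Python. The numbers are copied from the printed tensors above, so they are rounded to four decimals and the result should agree with `query_2` only up to rounding. The name `query_2_by_hand` is made up for this illustration.

```python
# Values copied (rounded) from the printed tensors above
x_2 = [0.55, 0.87, 0.66]
W_q = [[0.6614, 0.2669],
       [0.0617, 0.6213],
       [-0.4519, -0.1661]]

# Row vector times matrix: entry j is the dot product of x_2
# with column j of W_q
query_2_by_hand = [
    sum(x_2[i] * W_q[i][j] for i in range(len(x_2)))
    for j in range(len(W_q[0]))
]

print(query_2_by_hand)
```

The same pattern with `W_k` and `W_v` gives the key and value vectors.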


Now, obtain the key, query, and value vectors for all input vectors at once.

In [8]:
keys = inputs @ W_k
queries = inputs @ W_q
values = inputs @ W_v

print("Keys:")
print(keys)
print("Queries:")
print(queries)
print("Values:")
print(values)

Keys:
tensor([[-1.6031,  0.0278],
        [-2.3205, -0.3184],
        [-2.3125, -0.2983],
        [-1.2255, -0.2618],
        [-1.5187,  0.1473],
        [-1.3890, -0.4634]])
Queries:
tensor([[-0.1086,  0.0601],
        [ 0.1191,  0.5777],
        [ 0.1402,  0.5739],
        [ 0.0321,  0.3643],
        [ 0.4795,  0.3442],
        [-0.1661,  0.4190]])
Values:
tensor([[-0.3956, -0.8296],
        [-0.0180, -0.6356],
        [-0.0220, -0.6500],
        [ 0.0630, -0.2259],
        [-0.0871, -0.7278],
        [ 0.0965, -0.1153]])


As can be seen, we have 3 matrices, each containing 6 vectors (rows) of dimension 2 (columns).

## Attention scores

Recall that in the previous topic we calculated the attention scores directly between the vector of interest $x^{(2)}$ and the other input vectors. Here, we calculate those attention scores using the query of the vector of interest and the keys of all inputs.

Let's calculate the attention scores for $x^{(2)}$ only. This gives a vector of 6 attention scores, one per input vector (including $x^{(2)}$ itself), each measuring how strongly the query of $x^{(2)}$ matches the corresponding key.

In [9]:
# Calculate the vector of attention scores for x^2 only

query_2 = queries[1] # x^2 as a query

attention_scores_2 = query_2 @ keys.T
print(attention_scores_2)

tensor([-0.1749, -0.4604, -0.4479, -0.2973, -0.0958, -0.4332])


These are the 6 attention scores of $x^{(2)}$ with respect to all input vectors, including itself.
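
Each entry of that vector is a single dot product between one query and one key. As a rough check, the sketch below recomputes the second entry (the query of $x^{(2)}$ against its own key) in plain Python. The values are copied from the printed tensors, so the result should match the second score above only up to rounding. The name `score_22` is made up for this illustration.

```python
# Query and key of x^(2), copied (rounded) from the printed output above
query_2 = [0.1191, 0.5777]
key_2 = [-2.3205, -0.3184]

# One attention score is the dot product of a query with a key
score_22 = sum(q * k for q, k in zip(query_2, key_2))

print(score_22)
```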

Now, let's calculate the big attention scores matrix for all input vectors at once.

In [10]:
attention_scores = queries @ keys.T # These are called omega
print(attention_scores)

tensor([[ 0.1757,  0.2328,  0.2331,  0.1173,  0.1737,  0.1229],
        [-0.1749, -0.4604, -0.4479, -0.2973, -0.0958, -0.4332],
        [-0.2087, -0.5080, -0.4954, -0.3221, -0.1283, -0.4607],
        [-0.0414, -0.1906, -0.1830, -0.1348,  0.0048, -0.2134],
        [-0.7590, -1.2222, -1.2115, -0.6777, -0.6774, -0.8255],
        [ 0.2780,  0.2521,  0.2592,  0.0939,  0.3140,  0.0366]])


As can be seen, row $i$ contains the attention scores of the query of input $i$ against the keys of all inputs.

## Attention weights

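Now, as before, we'll calculate attention weights by normalizing the attention scores with softmax. This time, however, we first divide the scores by the square root of the dimension of the keys. Softmax is sensitive to the magnitude of its inputs: large scores push it towards assigning almost all the weight to a single entry, where the gradients become very small, so scaling makes learning more stable. The square root comes from the variance. If the entries of the queries and keys have roughly zero mean and unit variance, the variance of their dot product grows linearly with the dimension, so dividing by the square root of the dimension keeps the variance close to 1.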
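
The variance argument can be checked numerically. The sketch below is not part of the attention computation itself. It assumes queries and keys with independent standard normal entries, which learned projections need not satisfy, and uses an arbitrary seed, sample count, and set of dimensions.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, chosen only for reproducibility

num_samples = 10_000
variances = {}

for dimension in (2, 64, 512):
    # Random queries and keys with zero-mean, unit-variance entries (assumption)
    q = torch.randn(num_samples, dimension)
    k = torch.randn(num_samples, dimension)

    dots = (q * k).sum(dim=-1)        # one dot product per sample
    scaled = dots / dimension ** 0.5  # same scaling used for the attention weights

    variances[dimension] = (dots.var().item(), scaled.var().item())
    print(dimension, variances[dimension])
```

Under these assumptions, the first number in each pair should be close to the dimension and the second close to 1.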

In [13]:
dimension_keys = keys.shape[-1]
attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)
print(attention_weights)

tensor([[0.1666, 0.1734, 0.1735, 0.1598, 0.1663, 0.1605],
        [0.1835, 0.1500, 0.1513, 0.1683, 0.1941, 0.1529],
        [0.1837, 0.1487, 0.1500, 0.1695, 0.1944, 0.1537],
        [0.1767, 0.1590, 0.1599, 0.1654, 0.1826, 0.1565],
        [0.1812, 0.1306, 0.1316, 0.1919, 0.1919, 0.1729],
        [0.1750, 0.1718, 0.1727, 0.1536, 0.1795, 0.1475]])


In [14]:
# Check that the results are normalized
print(attention_weights.sum(dim=-1))

tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])
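
To see what the scaling does to the weights themselves, the following sketch compares softmax with and without the division by the square root of the key dimension. It uses the scores of $x^{(2)}$ copied from the output above. The `large_scores` row is hypothetical (the same scores multiplied by 10) and is included only to imitate the larger dot products that appear in higher dimensions.

```python
import torch

# Attention scores of x^(2), copied (rounded) from the printed output above
scores_2 = torch.tensor([-0.1749, -0.4604, -0.4479, -0.2973, -0.0958, -0.4332])
dimension_keys = 2

unscaled = torch.softmax(scores_2, dim=-1)
scaled = torch.softmax(scores_2 / dimension_keys ** 0.5, dim=-1)

# Hypothetical scores with a 10x larger magnitude
large_scores = 10 * scores_2
large_unscaled = torch.softmax(large_scores, dim=-1)

print(unscaled)
print(scaled)
print(large_unscaled)
```

The `scaled` row should reproduce the second row of `attention_weights` up to rounding. The larger the scores, the more the weight concentrates on the largest entry, which is the effect the scaling is meant to dampen.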


## Context vectors

Finally, we are ready to calculate the context vectors, just as we did with the simplified self-attention mechanism. For a given input, we multiply each of its attention weights by the corresponding row of the values matrix and sum the results. That sum is the context vector for that input.

Recall that every row in the attention weights matrix and every row in the values matrix correspond to a particular input token.

Also, recall that in the simplified mechanism a context vector was obtained by scaling the original input vectors by the attention weights and adding them up as a regular vector addition. Here, the same weighted sum is taken over the value vectors instead of the raw inputs.

In [15]:
# Calculate the context vector matrix for all tokens at once
context_vectors = attention_weights @ values
print(context_vectors)

tensor([[-0.0617, -0.5368],
        [-0.0702, -0.5428],
        [-0.0700, -0.5419],
        [-0.0666, -0.5399],
        [-0.0648, -0.5218],
        [-0.0678, -0.5489]])
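
The matrix product hides the weighted sum, so here is a sketch that spells it out for $x^{(2)}$ in plain Python. The weights and values are copied from the printed tensors above and are therefore rounded, so the result should match the second row of `context_vectors` only up to rounding. The name `context_2_by_hand` is made up for this illustration.

```python
# Attention weights of x^(2): second row of the printed attention_weights
weights_2 = [0.1835, 0.1500, 0.1513, 0.1683, 0.1941, 0.1529]

# Value vectors, copied (rounded) from the printed values matrix
values = [[-0.3956, -0.8296],
          [-0.0180, -0.6356],
          [-0.0220, -0.6500],
          [ 0.0630, -0.2259],
          [-0.0871, -0.7278],
          [ 0.0965, -0.1153]]

# Scale each value vector by its weight, then add them component by component
context_2_by_hand = [0.0, 0.0]
for weight, value in zip(weights_2, values):
    for j in range(len(value)):
        context_2_by_hand[j] += weight * value[j]

print(context_2_by_hand)
```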


## Self-attention class

We will now implement a Python class to calculate self-attention.

In [16]:
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    # Initialize the object with the number of input and output dimensions
    def __init__(self, dimension_inputs, dimension_outputs):
        super().__init__()
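        # Initialize the weight matrices randomly.
        # requires_grad=False keeps them frozen, as in the cells above;
        # set it to True for the weights to be updated during training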
        self.W_q = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
        self.W_k = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
        self.W_v = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)

    # Method to calculate the context vector
    def forward(self, input_vectors):
        # Calculate the query, key, and value vectors
        queries = input_vectors @ self.W_q
        keys = input_vectors @ self.W_k
        values = input_vectors @ self.W_v

        # Calculate attention scores, that is, omega
        attention_scores = queries @ keys.T

        # Calculate attention weights
        dimension_keys = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)

        # Calculate the context vectors
        context_vectors = attention_weights @ values

        # return context vectors
        return context_vectors
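
A possible way to use the class is sketched below. So that the snippet runs on its own, it repeats a condensed copy of `SelfAttentionV1` and the input tensor from the beginning. The seed `123` is an arbitrary choice, so the context vectors will differ from the ones printed earlier.

```python
import torch
import torch.nn as nn

class SelfAttentionV1(nn.Module):
    # Condensed copy of the class above, repeated so the snippet is self-contained
    def __init__(self, dimension_inputs, dimension_outputs):
        super().__init__()
        self.W_q = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
        self.W_k = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
        self.W_v = nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)

    def forward(self, input_vectors):
        queries = input_vectors @ self.W_q
        keys = input_vectors @ self.W_k
        values = input_vectors @ self.W_v
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        return attention_weights @ values

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(123)  # arbitrary seed, an assumption for reproducibility
self_attention = SelfAttentionV1(dimension_inputs=3, dimension_outputs=2)

# Calling the module runs forward() and returns one context vector per input
context_vectors = self_attention(inputs)
print(context_vectors.shape)
print(context_vectors)
```

The result has one row per input token and one column per output dimension. Note that `keys.T` assumes a 2D input (tokens by features), so this version does not handle batched inputs.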