
The self-attention mechanism with trainable weights implemented in this notebook is also called **scaled dot-product attention**.
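For reference, the standard formulation of scaled dot-product attention is shown below. Here $X$ is the input matrix (one token vector per row), $W_q$, $W_k$, $W_v$ are the trainable matrices introduced next, $d_k$ is the dimension of the key vectors, and the softmax is applied row-wise. The division by $\sqrt{d_k}$ is the "scaling" in the name.

$$
\text{Attention}(Q, K, V) = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V, \qquad Q = X W_q,\; K = X W_k,\; V = X W_v
$$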

Now we introduce trainable weight matrices that are updated during training. They are used so that the model learns to produce good context vectors. These matrices are:

* $W_q$ for queries
* $W_k$ for keys
* $W_v$ for values

These matrices project each input vector $x^{(i)}$ into three new vectors: a query vector $q^{(i)} = x^{(i)} W_q$, a key vector $k^{(i)} = x^{(i)} W_k$, and a value vector $v^{(i)} = x^{(i)} W_v$.
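What makes these matrices "trainable" is that gradients flow into them. The following is a minimal sketch, not code from this notebook: it uses random hypothetical inputs, a dummy loss chosen only for illustration, and an assumed learning rate of 0.1, and shows that each matrix receives a gradient of its own shape that a gradient-descent step can use to update it.

```python
import torch

# Hypothetical setup: six random 3D input vectors, projected to 2D.
torch.manual_seed(0)
d_in, d_out = 3, 2
x = torch.rand(6, d_in)

# nn.Parameter sets requires_grad=True by default, so these are trainable.
W_q = torch.nn.Parameter(torch.randn(d_in, d_out))
W_k = torch.nn.Parameter(torch.randn(d_in, d_out))
W_v = torch.nn.Parameter(torch.randn(d_in, d_out))

queries = x @ W_q
keys = x @ W_k
values = x @ W_v

# Dummy scalar loss purely for illustration; real training uses a task loss.
loss = queries.sum() + keys.sum() + values.sum()
loss.backward()

for name, W in [("W_q", W_q), ("W_k", W_k), ("W_v", W_v)]:
    print(name, W.grad.shape)  # each gradient has the same shape as its matrix

# One plain gradient-descent step with an assumed learning rate of 0.1.
with torch.no_grad():
    for W in (W_q, W_k, W_v):
        W -= 0.1 * W.grad
```

In practice this update step is usually delegated to an optimizer such as `torch.optim.SGD` or `torch.optim.Adam`.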

In [1]:
import torch

# Create the input tensor (one 3D embedding vector per token)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
# Select the second input vector for demonstration purposes
x_2 = inputs[1]

# Select the dimension of the inputs. 3D in this case
dimension_inputs = inputs.shape[1]

# Select the output dimension for the projected vectors. 2D in this case
dimension_outputs = 2

In [9]:
# Define the weight matrices, initialized in a random manner
# requires_grad=False because we're not training yet
torch.manual_seed(42)
W_q = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
W_k = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
W_v = torch.nn.Parameter(torch.randn(dimension_inputs, dimension_outputs), requires_grad=False)
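As an aside, the same projection can also be expressed with `torch.nn.Linear` without a bias term. This is an alternative sketch, not the approach used in this notebook: `nn.Linear` stores its weight with shape `(d_out, d_in)` and computes `x @ weight.T`, so copying `W_q.T` into it reproduces `x_2 @ W_q` exactly.

```python
import torch

# Recreate W_q with the same seed and shape as in the notebook.
torch.manual_seed(42)
d_in, d_out = 3, 2
W_q = torch.nn.Parameter(torch.randn(d_in, d_out), requires_grad=False)

# nn.Linear stores its weight as (d_out, d_in) and computes x @ weight.T.
linear_q = torch.nn.Linear(d_in, d_out, bias=False)
with torch.no_grad():
    linear_q.weight.copy_(W_q.T)  # copy the same values for a direct comparison

x_2 = torch.tensor([0.55, 0.87, 0.66])  # "journey"
print(x_2 @ W_q)
print(linear_q(x_2))
print(torch.allclose(x_2 @ W_q, linear_q(x_2)))
```

One practical difference is that `nn.Linear` comes with its own default weight initialization, so in a real model one would normally not copy weights in by hand.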

In [10]:
print(W_q)

Parameter containing:
tensor([[ 0.3367,  0.1288],
        [ 0.2345,  0.2303],
        [-1.1229, -0.1863]])


In [11]:
print(W_k)

Parameter containing:
tensor([[ 2.2082, -0.6380],
        [ 0.4617,  0.2674],
        [ 0.5349,  0.8094]])


In [12]:
print(W_v)

Parameter containing:
tensor([[ 1.1103, -1.6898],
        [-0.9890,  0.9580],
        [ 1.3221,  0.8172]])


Now let's compute the query, key, and value vectors for the input vector $x^{(2)}$, which corresponds to "journey".

In [13]:
query_2 = x_2 @ W_q
key_2 = x_2 @ W_k
value_2 = x_2 @ W_v

print(query_2)
print(key_2)
print(value_2)

tensor([-0.3519,  0.1483])
tensor([1.9692, 0.4159])
tensor([0.6229, 0.4434])
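To see where these numbers come from, the first component of `query_2` is the dot product of $x^{(2)} = (0.55, 0.87, 0.66)$ with the first column of $W_q$ (weights as printed above, products rounded to four decimals):

$$
q^{(2)}_1 = 0.55 \cdot 0.3367 + 0.87 \cdot 0.2345 + 0.66 \cdot (-1.1229) \approx 0.1852 + 0.2040 - 0.7411 = -0.3519
$$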


Next, obtain the key, query, and value vectors for all input vectors at once by multiplying the whole input matrix by each weight matrix.

In [16]:
keys = inputs @ W_k
queries = inputs @ W_q
values = inputs @ W_v

print("Keys:")
print(keys)
print("Queries:")
print(queries)
print("Values:")
print(values)

Keys:
tensor([[ 1.4948,  0.4861],
        [ 1.9692,  0.4159],
        [ 1.9934,  0.3816],
        [ 0.9301,  0.2818],
        [ 1.8692, -0.3435],
        [ 0.7739,  0.6271]])
Queries:
tensor([[-0.8194, -0.0759],
        [-0.3519,  0.1483],
        [-0.3274,  0.1500],
        [-0.1605,  0.1004],
        [ 0.2056,  0.1381],
        [-0.4132,  0.0882]])
Values:
tensor([[ 1.5058,  0.1444],
        [ 0.6229,  0.4434],
        [ 0.6384,  0.3741],
        [ 0.1070,  0.4535],
        [ 0.7399, -0.9799],
        [-0.0085,  1.1313]])
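A quick sanity check, sketched here rather than taken from the notebook: multiplying the whole input matrix by a weight matrix gives the same result as projecting each input vector on its own and stacking the rows. With the same seed, `torch.randn(3, 2)` reproduces the values of `W_q` used above, since it is the first random draw.

```python
import torch

torch.manual_seed(42)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
W_q = torch.randn(3, 2)  # same values as the notebook's W_q

# Batched projection: one matrix multiplication for all six tokens.
queries_batched = inputs @ W_q  # shape (6, 2)

# Per-token projection: iterate over rows, then stack the results.
queries_loop = torch.stack([x_i @ W_q for x_i in inputs])  # shape (6, 2)

print(queries_batched.shape)
print(torch.allclose(queries_batched, queries_loop))
```

The batched form is preferred in practice because a single matrix multiplication is far more efficient than a Python loop.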


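As can be seen, this yields three matrices with 6 rows each (one per input token) and 2 columns (the output dimension). The second row of each matrix matches `query_2`, `key_2`, and `value_2` computed above.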
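For orientation, here is a sketch of how these matrices are typically combined in scaled dot-product attention to produce the context vector for "journey". It previews the standard next steps and is not code from this notebook: the query of $x^{(2)}$ is dotted with every key, the scores are divided by $\sqrt{d_k}$, normalized with softmax, and used to weight the value vectors. The random draws follow the same seed and order as the notebook.

```python
import torch

torch.manual_seed(42)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
d_in, d_out = inputs.shape[1], 2

# Same draws, in the same order, as W_q, W_k, W_v in the notebook.
W_q = torch.randn(d_in, d_out)
W_k = torch.randn(d_in, d_out)
W_v = torch.randn(d_in, d_out)

queries = inputs @ W_q
keys = inputs @ W_k
values = inputs @ W_v

query_2 = queries[1]              # query for "journey"
attn_scores_2 = query_2 @ keys.T  # one dot product per token, shape (6,)

d_k = keys.shape[-1]              # key dimension (2 here)
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)  # sums to 1

context_vec_2 = attn_weights_2 @ values  # weighted sum of value vectors, shape (2,)

print(attn_weights_2)
print(context_vec_2)
```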