# Self-attention with trainable query, key, and value weights
- To do:
- Create query, key, and value weight matrices to transform the input embeddings.
- We need these weight matrices because, without them, the relationship between tokens such as "journey" and "starts" would be fixed across all contexts: we would be limited to the original embedding space when determining attention.
- Multiply the input embeddings by the query, key, and value weight matrices to obtain the query, key, and value vectors.
- Derive the attention scores by multiplying the query matrix by the transpose of the key matrix.
- Scale the attention scores and normalize them with softmax to obtain the attention weights.
- Multiply the attention weight matrix by the value matrix to obtain the context vectors.
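The steps above can be condensed into one function. This is a minimal sketch that processes all tokens at once; the function name `self_attention` and the random stand-in inputs are hypothetical, and it already includes the division by the square root of the key dimension that the notebook introduces further down.

```python
import torch


def self_attention(x, W_query, W_key, W_value):
    """Scaled dot-product self-attention for a single sequence (sketch).

    x: (num_tokens, d_in); each W_*: (d_in, d_out). Returns (num_tokens, d_out).
    """
    queries = x @ W_query          # (num_tokens, d_out)
    keys = x @ W_key               # (num_tokens, d_out)
    values = x @ W_value           # (num_tokens, d_out)

    scores = queries @ keys.T      # (num_tokens, num_tokens): scores[i, j] = q_i . k_j
    d_k = keys.shape[-1]
    weights = torch.softmax(scores / d_k ** 0.5, dim=-1)  # each row sums to 1
    return weights @ values        # one context vector per token


torch.manual_seed(123)
inputs = torch.rand(6, 3)          # hypothetical stand-in for six token embeddings
W_q, W_k, W_v = (torch.rand(3, 2) for _ in range(3))
context = self_attention(inputs, W_q, W_k, W_v)
print(context.shape)               # torch.Size([6, 2])
```

Because each row of attention weights sums to 1, every context vector is a weighted average of the value vectors.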

In [1]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [2]:
inputs.shape

torch.Size([6, 3])

In [4]:
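# Each weight matrix has shape (d_in, d_out): d_in = inputs.shape[1] is the input embedding size,
# and d_out is the number of output dimensions we want for the query/key/value vectors.
# Note that in GPT-like models, the input and output dimensions are usually the same;
# here d_out = 2 keeps the example easy to follow.
x_2 = inputs[1]         # second input element ("journey")
d_in = inputs.shape[1]  # input embedding size (3)
d_out = 2               # output embedding size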
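The shape rule in the comments above can be checked directly. This sketch uses random stand-in tensors (not the notebook's values) and shows that a `(d_in, d_out)` matrix maps one embedding to `d_out` dimensions, and a stack of embeddings to one `d_out`-dimensional row per token.

```python
import torch

d_in, d_out = 3, 2
W = torch.rand(d_in, d_out)   # one projection matrix

x = torch.rand(d_in)          # a single token embedding
print((x @ W).shape)          # torch.Size([2]):    (d_in,)   @ (d_in, d_out) -> (d_out,)

X = torch.rand(6, d_in)       # six token embeddings stacked as rows
print((X @ W).shape)          # torch.Size([6, 2]): (6, d_in) @ (d_in, d_out) -> (6, d_out)
```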

In [7]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

- These are the query, key, and value weight matrices, initialized with random values; they are the weights that get trained.
- Note that we set requires_grad=False to reduce clutter in the outputs for illustration purposes.
- If we were to use the weight matrices for model training, we would set requires_grad=True so that these matrices are updated during training.
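To illustrate the difference between the two settings, here is a minimal sketch with a random stand-in input and a hypothetical scalar loss (just the sum of the query vector). With `requires_grad=True`, `backward()` fills in a gradient for the weight matrix; with `requires_grad=False`, PyTorch does not track the computation at all.

```python
import torch

torch.manual_seed(123)
d_in, d_out = 3, 2

# Trainable projection: requires_grad=True (also the default for nn.Parameter)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=True)
# Frozen projection, as in the illustration above
W_frozen = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

x = torch.rand(d_in)         # stand-in for one input embedding
loss = (x @ W_query).sum()   # hypothetical scalar loss, for illustration only
loss.backward()              # populates W_query.grad

print(W_query.grad.shape)            # same shape as W_query: (d_in, d_out)
print((x @ W_frozen).requires_grad)  # False: nothing here requires a gradient
```

In practice, `torch.nn.Linear(d_in, d_out, bias=False)` is a common alternative to hand-built `nn.Parameter` matrices; note that it stores its weight transposed, with shape `(d_out, d_in)`.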

In [8]:
# Computing the query, key, and value vectors for the second input element, x_2
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2, key_2)

tensor([0.4306, 1.4551]) tensor([0.4433, 1.1419])


- Even though our current goal is to compute only one context vector, z^(2), we still need the key and value vectors for all input elements.
- This is because the query q^(2) is compared against every key to compute the attention weights, and the context vector is then a weighted sum over all value vectors.

In [9]:
keys = inputs @ W_key
values = inputs @ W_value
print(f"keys shape :{keys.shape}, values shape :{values.shape}")

keys shape :torch.Size([6, 2]), values shape :torch.Size([6, 2])
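Computing `inputs @ W_key` in one step is equivalent to projecting each token separately and stacking the results. The sketch below checks this with random stand-in tensors (not the notebook's values).

```python
import torch

torch.manual_seed(0)
inputs = torch.rand(6, 3)   # stand-in for the six token embeddings
W_key = torch.rand(3, 2)

keys = inputs @ W_key                                     # all keys in one matmul
keys_loop = torch.stack([x_i @ W_key for x_i in inputs])  # one token at a time

print(torch.allclose(keys, keys_loop))  # True
```

The single matrix multiplication is what the notebook uses, since it avoids a Python loop over tokens.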


In [11]:
# attention score of query_2 with respect to only keys[1]
keys_2 = keys[1]
attention_score_query_2 = query_2.dot(keys_2)
print(attention_score_query_2)

tensor(1.8524)
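The attention score is just a dot product: multiply the two vectors element-wise and sum. This sketch uses the four-decimal values printed above for `query_2` and `key_2`, so the result agrees with `1.8524` only up to rounding in the last digit.

```python
# Rounded values printed above for query_2 and key_2 (= keys[1])
query_2 = [0.4306, 1.4551]
key_2 = [0.4433, 1.1419]

# Dot product: element-wise multiplication followed by a sum
score = sum(q * k for q, k in zip(query_2, key_2))
print(round(score, 4))  # 1.8525
```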


In [12]:
# Attention scores of query_2 with respect to all keys: (2,) @ (2, 6) -> (6,)
attention_scores_query_2 = query_2 @ keys.T
print(attention_scores_query_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
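Multiplying by `keys.T` computes the dot product of the query with every key at once. The sketch below, with random stand-in tensors, confirms that it matches a loop of individual dot products.

```python
import torch

torch.manual_seed(0)
query = torch.rand(2)     # stand-in for query_2
keys = torch.rand(6, 2)   # stand-in for the six key vectors

scores_matmul = query @ keys.T                           # (2,) @ (2, 6) -> (6,)
scores_loop = torch.stack([query.dot(k) for k in keys])  # one dot product per key

print(torch.allclose(scores_matmul, scores_loop))  # True
```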


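- We compute the attention weights by scaling the attention scores and then applying the softmax function we used earlier.
- The difference from before is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, d_k. This keeps the dot products from growing large as d_k increases, which would otherwise push softmax toward very peaked outputs with small gradients.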
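A rough numerical illustration of why the scaling helps, under the assumption that query and key components are independent with mean 0 and variance 1 (real learned vectors need not satisfy this). The dot products then have variance close to d_k, dividing by sqrt(d_k) brings it back near 1, and large scores make softmax nearly one-hot. The value `d_k = 64` and the example scores are arbitrary choices for the demo.

```python
import torch

torch.manual_seed(0)
d_k, n = 64, 10_000
q = torch.randn(n, d_k)   # assumption: components with mean 0, variance 1
k = torch.randn(n, d_k)

dots = (q * k).sum(dim=-1)   # n unscaled dot products
scaled = dots / d_k ** 0.5   # the same dot products divided by sqrt(d_k)
print(dots.var().item())     # close to d_k
print(scaled.var().item())   # close to 1

# Large scores make softmax nearly one-hot; smaller ones keep it smoother
scores = torch.tensor([1.0, 2.0, 3.0])        # hypothetical attention scores
sharp = torch.softmax(scores * 8.0, dim=-1)   # almost all weight on the last entry
smooth = torch.softmax(scores, dim=-1)
print(sharp, smooth)
```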

In [14]:
dim_k = keys.shape[-1]  # embedding dimension of the keys (d_k = 2)
attention_weights_q2 = torch.softmax(attention_scores_query_2 / dim_k ** 0.5, dim=-1)
print(f"attention weights for q2: {attention_weights_q2}")

attention weights for q2: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
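The last step of the plan multiplies the attention weights by the value vectors to obtain the context vector z^(2). The sketch below rebuilds the notebook's computation end to end so it runs on its own; it uses plain tensors instead of `nn.Parameter`, which does not change the random values drawn after `torch.manual_seed(123)`, so the intermediate attention weights should match those printed above. The name `context_vec_2` is our own.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your
     [0.55, 0.87, 0.66],  # journey
     [0.57, 0.85, 0.64],  # starts
     [0.22, 0.58, 0.33],  # with
     [0.77, 0.25, 0.10],  # one
     [0.05, 0.80, 0.55]]  # step
)
d_in, d_out = inputs.shape[1], 2

torch.manual_seed(123)
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key
values = inputs @ W_value

attention_scores_query_2 = query_2 @ keys.T
dim_k = keys.shape[-1]
attention_weights_q2 = torch.softmax(attention_scores_query_2 / dim_k ** 0.5, dim=-1)

# Final step: weighted sum of the value vectors, (6,) @ (6, 2) -> (2,)
context_vec_2 = attention_weights_q2 @ values
print(context_vec_2)
```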
