# Self-attention with trainable query, key, and value weights
- Todo:
- Create query, key, and value weight matrices to transform the input embeddings.
- We need these weight matrices because without them the relationship between tokens such as "journey" and "starts" would be fixed across all contexts: attention could only be determined in the original embedding space.
- Multiply the input embeddings by the query, key, and value weight matrices to obtain the queries, keys, and values.
- Derive the attention scores by multiplying the queries by the transposed keys.
- Normalize these attention scores with softmax to obtain the attention weights.
- Multiply the attention weight matrix by the value matrix to obtain the context vectors.
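- In matrix form, with $X$ holding one input embedding per row, the steps above can be summarized as follows. The symbols $W_q$, $W_k$, $W_v$, $Q$, $K$, $V$, $Z$, and $d_k$ are notation assumed here for illustration, softmax is applied row-wise, and the division by $\sqrt{d_k}$ is the scaling step that this notebook introduces further down:

$$
Q = X W_q, \qquad K = X W_k, \qquad V = X W_v, \qquad Z = \operatorname{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V
$$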

In [22]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x^1)
     [0.55, 0.87, 0.66],  # journey  (x^2)
     [0.57, 0.85, 0.64],  # starts   (x^3)
     [0.22, 0.58, 0.33],  # with     (x^4)
     [0.77, 0.25, 0.10],  # one      (x^5)
     [0.05, 0.80, 0.55]]  # step     (x^6)
)

In [23]:
inputs.shape

torch.Size([6, 3])

In [24]:
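# Each weight matrix has shape (d_in, d_out): inputs.shape[1] x the number of output dimensions we want.
# Note that in GPT-like models, the input and output dimensions are usually the same;
# in this example they differ (d_in = 3, d_out = 2).
x_2 = inputs[1]          # second input element ("journey")
d_in = inputs.shape[1]   # input embedding size
d_out = 2                # output embedding size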

In [25]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

- These are the query, key, and value weight matrices, initialized with random values. In a real model they would be trained.
- Note that we set requires_grad=False to reduce clutter in the outputs for illustration purposes.
- If we were to use the weight matrices for model training, we would set requires_grad=True so that they are updated during training.
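- To see what requires_grad=True changes, here is a minimal sketch. The name `W_trainable` and the toy loss are assumptions made for illustration and are not part of the notebook's pipeline:

```python
import torch

torch.manual_seed(123)
d_in, d_out = 3, 2  # same sizes as in the cells above

# Hypothetical trainable weight matrix; requires_grad=True is the default for nn.Parameter
W_trainable = torch.nn.Parameter(torch.rand(d_in, d_out))
x = torch.tensor([0.55, 0.87, 0.66])  # embedding of "journey" (x^2)

out = x @ W_trainable   # the result now carries a grad_fn
loss = out.sum()        # toy scalar loss, only for illustration
loss.backward()         # fills W_trainable.grad

print(out.grad_fn is not None)  # True
print(W_trainable.grad)         # each row i holds x[i], the derivative of the sum
```

- With requires_grad=False, `out` would have no grad_fn and calling backward() on the loss would raise an error, which is why that setting is only suitable for illustration.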

In [26]:
# compute the query, key, and value vectors for the second input embedding x_2
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(query_2, key_2)

tensor([0.4306, 1.4551]) tensor([0.4433, 1.1419])


- Even though our temporary goal is to compute only the one context vector z(2), we still require the key and value vectors for all input elements.
- This is because the keys of all inputs are needed to compute the attention weights with respect to the query, and the values are the vectors those weights are then applied to.

In [27]:
keys = inputs @ W_key
values = inputs @ W_value
print(f"keys shape :{keys.shape}, values shape :{values.shape}")

keys shape :torch.Size([6, 2]), values shape :torch.Size([6, 2])


In [28]:
# attention score of query_2 with respect to only keys[1]
keys_2 = keys[1]
attention_score_query_2 = query_2.dot(keys_2)
print(attention_score_query_2)

tensor(1.8524)


In [29]:
# attention scores of query_2 with respect to all keys
attention_scores_query_2 = query_2 @ keys.T
print(attention_scores_query_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


- We compute the attention weights by scaling the attention scores and applying the softmax function we used earlier.
- The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys.

In [30]:
dim_k = keys.shape[-1]
attention_weights_q2 = torch.softmax(attention_scores_query_2 / dim_k ** 0.5, dim=-1)
print(f"attention weights for q2: {attention_weights_q2}")

attention weights for q2: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
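
- A commonly cited motivation for the scaling is that dot products tend to grow in magnitude with the key dimension, which makes softmax very peaked and its gradients small. The sketch below illustrates this with random standard-normal stand-ins for queries and keys (an assumption for illustration, not the notebook's data):

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only so the sketch is reproducible
n_pairs = 10_000

for d_k in (2, 64):
    # random stand-ins for query and key vectors
    q = torch.randn(n_pairs, d_k)
    k = torch.randn(n_pairs, d_k)
    dots = (q * k).sum(dim=-1)   # one dot product per (query, key) pair
    scaled = dots / d_k ** 0.5   # same scaling as in the cell above
    print(f"d_k={d_k:>2}: std of raw scores {dots.std().item():.2f}, "
          f"std after scaling {scaled.std().item():.2f}")
```

- The spread of the raw scores grows roughly like the square root of d_k, while the scaled scores stay at a spread of about 1 regardless of the dimension.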


- We now compute the context vector as a weighted sum over the value vectors.
- Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector.
- We can use matrix multiplication to obtain the output in one step:

In [31]:
context_vectors_z2 = attention_weights_q2 @ values
print(context_vectors_z2)

tensor([0.3061, 0.8210])
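
- The matrix product above is shorthand for an explicit weighted sum over the value vectors. A self-contained sketch that rebuilds the tensors with the same seed and call order as above and compares both forms (the names `z2_matmul` and `z2_loop` are illustrative):

```python
import torch

torch.manual_seed(123)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
W_value = torch.rand(3, 2)

query_2 = inputs[1] @ W_query
keys = inputs @ W_key
values = inputs @ W_value
weights = torch.softmax(query_2 @ keys.T / keys.shape[-1] ** 0.5, dim=-1)

# one-step version, as in the cell above
z2_matmul = weights @ values

# explicit version: add up each value vector scaled by its attention weight
z2_loop = torch.zeros(values.shape[1])
for w, v in zip(weights, values):
    z2_loop += w * v

print(torch.allclose(z2_matmul, z2_loop))  # True
```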


- Now, let's implement a class so that we can do this for all the input embeddings at once.

In [32]:
import torch.nn as nn


class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        # requires_grad=False only keeps the printed output uncluttered;
        # for training, these parameters would need requires_grad=True
        self.W_query = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_key = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
        self.W_value = nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

    def forward(self, x):
        # project every input embedding into query, key, and value space
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        # scaled dot-product attention over all tokens at once
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1] ** 0.5, dim=-1)
        context_vectors = attention_weights @ values
        return context_vectors

In [33]:
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[1.4035, 1.0391],
        [1.4410, 1.0669],
        [1.4391, 1.0655],
        [1.3786, 1.0178],
        [1.3653, 1.0086],
        [1.4025, 1.0361]])


- These are the context vectors corresponding to our inputs.
- Since inputs contains six embedding vectors, we get a matrix storing the six context vectors, as shown in the above result.
- We can improve the SelfAttentionV1 implementation further by using PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled.
- Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, which contributes to more stable and effective model training.
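- A quick check of the first point, as a minimal sketch: with bias=False, an nn.Linear layer computes a plain matrix product. Note that it stores its weight with shape (d_out, d_in), so the equivalent manual expression uses the transpose. The tensor `x` is a random stand-in for `inputs`:

```python
import torch
import torch.nn as nn

torch.manual_seed(789)
d_in, d_out = 3, 2

linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)  # random stand-in for the inputs tensor

# nn.Linear keeps its weight as (d_out, d_in), hence the transpose
manual = x @ linear.weight.T

print(linear.weight.shape)                # torch.Size([2, 3])
print(torch.allclose(linear(x), manual))  # True
```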

In [40]:
class SelfAttentionV2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

In [41]:
torch.manual_seed(789)
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
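
- SelfAttentionV1 and SelfAttentionV2 print different numbers because their weights are initialized differently, not because the attention computation differs. The sketch below restates both classes compactly (with default requires_grad, an illustrative simplification) and copies the transposed nn.Linear weights into the nn.Parameter matrices. The tensor `x` is a random stand-in for `inputs`:

```python
import torch
import torch.nn as nn


def attend(queries, keys, values):
    # scaled dot-product attention shared by both classes
    weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
    return weights @ values


class SelfAttentionV1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        return attend(x @ self.W_query, x @ self.W_key, x @ self.W_value)


class SelfAttentionV2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        return attend(self.W_query(x), self.W_key(x), self.W_value(x))


torch.manual_seed(789)
x = torch.rand(6, 3)  # random stand-in for the inputs tensor
sa_v1, sa_v2 = SelfAttentionV1(3, 2), SelfAttentionV2(3, 2)

# nn.Linear stores weights as (d_out, d_in), so transpose before copying
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(x), sa_v2(x), atol=1e-6))  # True
```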
