<a href="https://colab.research.google.com/github/Nitinyad/LLM-from-scratch/blob/main/chapter_02_self_attention_with_trainable_weights.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Implementing self-attention with trainable weights

In [21]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts    (x^3)
    [0.22, 0.58, 0.33], # with      (x^4)
    [0.77, 0.25, 0.10], # one       (x^5)
    [0.05, 0.80, 0.55]] # step      (x^6)
)
print(inputs.shape)

torch.Size([6, 3])


In [22]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d_in = 3
d_out = 2 # the output embedding size, d_out = 2

In GPT-like models, the input and output embedding dimensions are usually the same.
Here we use different dimensions (d_in = 3, d_out = 2) purely to make the computation easier to follow.

Initialize the three weight matrices Wq, Wk, and Wv:

In [23]:
torch.manual_seed(123) # by doing this the same random numbers will be generated every time you run your code.
Wq = torch.nn.Parameter(torch.rand(d_in, d_out) , requires_grad=False)
Wk = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
Wv = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [24]:
print(Wq)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])


For now, we set requires_grad=False to reduce clutter in the outputs for illustration purposes.
If we were to use the weight matrices for model training, we would set requires_grad=True so that these matrices are updated during training.
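A minimal sketch of the trainable case, assuming the default requires_grad=True of nn.Parameter: PyTorch records the operations on W_query, and calling backward() on a scalar fills in W_query.grad, which an optimizer would then use to update the matrix. The loss below (a plain sum of the queries) is only a placeholder assumption to have something to differentiate; a real model would use a task-specific loss.

```python
import torch

# The six 3-d input embeddings from the first cell
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(123)
# requires_grad=True is the default for nn.Parameter
W_query = torch.nn.Parameter(torch.rand(3, 2))
print(W_query.requires_grad)  # True

queries = inputs @ W_query
loss = queries.sum()  # placeholder scalar loss, only used to trigger backpropagation
loss.backward()

# The gradient has the same shape as the weight matrix, so an optimizer can update it
print(W_query.grad.shape)  # torch.Size([3, 2])
```

Because this placeholder loss is a plain sum, each column of the gradient equals inputs.sum(dim=0), which makes the result easy to check by hand.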

Now, compute the query, key, and value vectors for the second input element, x^(2):

In [25]:
query_2 = x_2 @ Wq
key_2 = x_2 @ Wk
value_2 = x_2 @ Wv
print(query_2)

tensor([0.4306, 1.4551])


As the output shows, the query is a 2-dimensional vector, because we set the number of columns of the corresponding weight matrix to 2 via d_out.

Even though our temporary goal is to compute only the one context vector z^(2), we still need the key and value vectors for all input elements,
because they are involved in computing the attention weights with respect to the query q^(2).

In [26]:
# obtain the keys, values, and queries for all inputs via matrix multiplication
keys = inputs @ Wk
values = inputs @ Wv
querys = inputs @ Wq
print("key.shape" , keys.shape)
print("value.shape" , values.shape)
print("query.shape" , querys.shape)

key.shape torch.Size([6, 2])
value.shape torch.Size([6, 2])
query.shape torch.Size([6, 2])


As the output shows, we successfully projected the six input tokens from a 3-dimensional onto a 2-dimensional embedding space.


In [27]:
key_2 = keys[1]
attn_score_22 = query_2.dot(key_2)
print(attn_score_22)

tensor(1.8524)


We can compute the attention scores of query 2 with respect to all inputs at once via matrix multiplication:

In [28]:
attn_score_22 = query_2 @ keys.T # scores of query 2 against all six keys (overwrites the scalar above)
print(attn_score_22)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [29]:
attn_scores = querys @ keys.T
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


We compute the attention weights by scaling the attention scores and applying the softmax function we used earlier.
The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys, d_k.

In [30]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_score_22 / d_k ** 0.5 , dim = -1)
print(attn_weights_2)
print(d_k)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
2
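The cell above relies on torch.softmax. As a sketch, the same scaled softmax can be written out by hand. The scores are copied (rounded to four decimals) from the attention-score output for query 2, and subtracting the maximum before exponentiating is a standard numerical-stability trick that leaves the result unchanged.

```python
import torch

# Attention scores of query 2 against all six keys, copied (rounded) from the output above
attn_scores_2 = torch.tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
d_k = 2  # embedding dimension of the keys

scaled = attn_scores_2 / d_k ** 0.5

# Softmax written out by hand; subtracting the max does not change the result
# but keeps exp() from overflowing for large scores
exp_scores = torch.exp(scaled - scaled.max())
manual_weights = exp_scores / exp_scores.sum()

builtin_weights = torch.softmax(scaled, dim=-1)
print(manual_weights)
print(torch.allclose(manual_weights, builtin_weights, atol=1e-6))  # True
```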


## Why divide by sqrt(dimension)?

Reason 1: stability in learning
The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between the exponentials of the inputs become much more pronounced. This makes the softmax output 'peaky': the highest value receives almost all of the probability mass, and the rest receive very little.

In [31]:
import torch

tensor = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

softmax_result = torch.softmax(tensor, dim=-1) # softmax of the original values
print("softmax without scaling , " , softmax_result)

# softmax after scaling the values up: multiply the tensor by 8
scaled_tensor = tensor * 8
softmax_scaled_result = torch.softmax(scaled_tensor, dim=-1)
print("softmax with scaling , " , softmax_scaled_result)

softmax without scaling ,  tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
softmax with scaling ,  tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


In attention mechanisms, particularly in transformers, if the dot products between query and key vectors become too large (as when we multiplied the values by 8 above), the attention scores become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key". Such sharp distributions can make learning unstable.
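One concrete way to see why a sharp distribution hurts training is to look at the gradient of the softmax. Its Jacobian is diag(p) - p p^T, and as the distribution saturates toward a one-hot vector, every entry of this matrix shrinks toward zero, so almost no gradient flows back to the attention scores. The sketch below reuses the toy tensor from the previous cell and checks the analytic Jacobian against torch.autograd.functional.jacobian; the scaling factors 1, 8, and 50 are arbitrary choices for illustration.

```python
import torch
from torch.autograd.functional import jacobian


def softmax_jacobian(z: torch.Tensor) -> torch.Tensor:
    """Analytic Jacobian of softmax(z) with respect to z: diag(p) - p p^T."""
    p = torch.softmax(z, dim=-1)
    return torch.diag(p) - torch.outer(p, p)


scores = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

for factor in (1.0, 8.0, 50.0):
    jac = softmax_jacobian(scores * factor)
    # Cross-check the analytic formula against autograd
    auto = jacobian(lambda z: torch.softmax(z, dim=-1), scores * factor)
    assert torch.allclose(jac, auto, atol=1e-6)
    print(f"factor {factor:>4}: largest |d softmax / d score| = {jac.abs().max().item():.6f}")
```

With the largest factor, the softmax is almost one-hot and the largest Jacobian entry is several orders of magnitude smaller than for the unscaled scores.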

## Why sqrt?


Reason 2: to keep the variance of the dot product stable.
If the components of q and k are independent random numbers with mean 0 and variance 1, each product q_i * k_i also has variance 1, and the dot product q · k sums d such terms, so its variance grows linearly with the dimension d.
Dividing by sqrt(d) divides the variance by d, which brings it back close to 1 regardless of the dimension.

The variance should stay close to 1 because, if it grows with the dimension, the attention scores change scale with the model size and push the softmax toward the peaky regime described above, which makes learning unstable. Keeping the variance, and therefore the standard deviation, close to 1 keeps the scores on a consistent scale.

Doing this also avoids numerical problems, such as a saturated softmax whose gradients are close to zero.
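A short derivation of Reason 2, assuming (as in the experiment in the next cell) that all components q_i and k_i are independent with mean 0 and variance 1:

$$
\operatorname{Var}(q_i k_i) = \mathbb{E}[q_i^2]\,\mathbb{E}[k_i^2] - \big(\mathbb{E}[q_i]\,\mathbb{E}[k_i]\big)^2 = 1 \cdot 1 - 0 = 1
$$

$$
\operatorname{Var}(q \cdot k) = \operatorname{Var}\Big(\sum_{i=1}^{d} q_i k_i\Big) = \sum_{i=1}^{d} \operatorname{Var}(q_i k_i) = d
$$

$$
\operatorname{Var}\Big(\frac{q \cdot k}{\sqrt{d}}\Big) = \frac{1}{d}\,\operatorname{Var}(q \cdot k) = 1
$$

The second line uses the fact that the terms q_i k_i are independent across i, so their variances add. Dividing by d instead of sqrt(d) would shrink the variance to 1/d; the square root is exactly the factor that restores unit variance.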

In [32]:
import numpy as np


def compute_variance(dim, num_trials=1000):
  dot_products = []
  scaled_dot_products = []

  for _ in range(num_trials):
    # q and k have independent standard-normal components (mean 0, variance 1)
    q = np.random.randn(dim)
    k = np.random.randn(dim)

    dot_product = np.dot(q, k)
    dot_products.append(dot_product)

    # the same dot product, scaled by 1/sqrt(dim)
    scaled_dot_product = dot_product / np.sqrt(dim)
    scaled_dot_products.append(scaled_dot_product)

  variance = np.var(dot_products)
  scaled_variance = np.var(scaled_dot_products)

  return variance, scaled_variance

variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim = 5): {variance_before_5}")
print(f"Variance after scaling (dim = 5): {variance_after_5}")
variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim = 100): {variance_before_100}")
print(f"Variance after scaling (dim = 100): {variance_after_100}")


Variance before scaling (dim = 5): 5.244715543858811
Variance after scaling (dim = 5): 1.0489431087717622
Variance before scaling (dim = 100): 100.44738499329094
Variance after scaling (dim = 100): 1.0044738499329093


## Compute the context vector

We compute the context vector as a weighted sum over the value vectors.
Here, the attention weights serve as weighting factors that determine the relative importance of each value vector.

We can use matrix multiplication to obtain the output:

In [33]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])
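The same idea extends from z^(2) to all six inputs at once. The sketch below repeats the cells above with plain tensors (torch.rand with seed 123 in the same draw order, instead of nn.Parameter) and checks that row 2 of the matrix product equals the explicit weighted sum of the value vectors.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)

torch.manual_seed(123)  # same seed and draw order as the cells above
W_query = torch.rand(3, 2)
W_key = torch.rand(3, 2)
W_value = torch.rand(3, 2)

queries = inputs @ W_query
keys = inputs @ W_key
values = inputs @ W_value
d_k = keys.shape[-1]

# (6, 6) matrix: row i holds the attention weights of query i over all six keys
attn_weights = torch.softmax(queries @ keys.T / d_k ** 0.5, dim=-1)
# (6, 2) matrix: row i is the context vector z^(i)
context_vecs = attn_weights @ values

# The context vector z^(2), written as an explicit weighted sum of the value vectors
z_2 = sum(attn_weights[1, j] * values[j] for j in range(values.shape[0]))

print(context_vecs.shape)  # torch.Size([6, 2])
print(torch.allclose(context_vecs[1], z_2, atol=1e-6))  # True
```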


## Implementing a compact self-attention Python class

In [34]:
import torch.nn as nn

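class SelfAttention_v1(nn.Module):
  def __init__(self, d_in, d_out):
    super().__init__()
    # trainable projection matrices, each of shape (d_in, d_out)
    self.W_query = nn.Parameter(torch.rand(d_in, d_out))
    self.W_key = nn.Parameter(torch.rand(d_in, d_out))
    self.W_value = nn.Parameter(torch.rand(d_in, d_out))

  def forward(self, x):
    # project the inputs (num_tokens, d_in) to keys, queries, and values (num_tokens, d_out)
    key = x @ self.W_key
    query = x @ self.W_query
    value = x @ self.W_value
    attn_scores = query @ key.T # (num_tokens, num_tokens)
    d_k = key.shape[-1]
    # scaled dot-product attention: normalize each row of scores with softmax
    attn_weights = torch.softmax(attn_scores / (d_k ** 0.5), dim=-1)
    context_vec = attn_weights @ value # weighted sum of the value vectors
    return context_vec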


In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, the fundamental building block of PyTorch models, which provides the necessary functionality for creating and managing model layers.

The init method initializes trainable weight matrices (W_query, W_key, W_value) for the queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out.

In the forward method, we compute the attention scores by multiplying queries and keys, scale them by sqrt(d_k) and normalize them with softmax to obtain the attention weights, and finally compute the context vectors by weighting the value vectors with these attention weights.
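Because the weight matrices are nn.Parameter attributes, nn.Module registers them automatically, which is part of the layer management mentioned above. A minimal sketch, with the class repeated so the snippet runs on its own, random inputs as a hypothetical stand-in, and a placeholder sum loss, lists the registered parameters and confirms that a backward pass populates a gradient for each of them:

```python
import torch
import torch.nn as nn


class SelfAttention_v1(nn.Module):  # same computation as the class above
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(123)
sa = SelfAttention_v1(d_in=3, d_out=2)
x = torch.rand(6, 3)  # hypothetical random inputs

# nn.Module registers every nn.Parameter attribute automatically
for name, param in sa.named_parameters():
    print(name, tuple(param.shape), param.requires_grad)

# ...so a backward pass fills in .grad for all three matrices
sa(x).sum().backward()  # placeholder loss
print(all(p.grad is not None for p in sa.parameters()))  # True
```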



In [36]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in , d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can improve the SelfAttention_v1 implementation further by using PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled.

An advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, which contributes to more stable and effective model training.
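A small check of the claim that nn.Linear without bias is just a matrix multiplication. nn.Linear stores its weight with shape (d_out, d_in), so it computes x @ weight.T. Per the PyTorch documentation, the default weights are drawn from U(-1/sqrt(d_in), 1/sqrt(d_in)), which is centered on zero, whereas torch.rand only yields values in [0, 1). The random inputs are a hypothetical stand-in.

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
d_in, d_out = 3, 2
linear = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)  # hypothetical inputs

# nn.Linear stores its weight as (d_out, d_in); without a bias it computes x @ W.T
print(linear.weight.shape)  # torch.Size([2, 3])
print(torch.allclose(linear(x), x @ linear.weight.T))  # True

# Default init draws from U(-1/sqrt(d_in), 1/sqrt(d_in)): centered on zero,
# unlike torch.rand, which only produces values in [0, 1)
bound = 1 / d_in ** 0.5
print(bool((linear.weight.abs() <= bound).all()))  # True
```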

In [37]:
class SelfAttention_v2(nn.Module):
  def __init__(self , d_in , d_out , qkv_bias = False):
    super().__init__()
    self.W_query = nn.Linear(d_in , d_out , bias = qkv_bias)
    self.W_key = nn.Linear(d_in , d_out , bias = qkv_bias)
    self.W_value = nn.Linear(d_in , d_out , bias = qkv_bias)
  def forward(self , x):
    keys = self.W_key(x)
    queries = self.W_query(x)
    values = self.W_value(x)
    attn_scores = queries @ keys.T
    d_k = keys.shape[-1] # compute d_k from the keys instead of relying on a global variable
    attn_weights = torch.softmax(attn_scores / (d_k ** 0.5), dim=-1)
    context_vec = attn_weights @ values
    return context_vec

In [38]:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in , d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


SelfAttention_v1 and SelfAttention_v2 give different outputs because they use different initial weights for the weight matrices: nn.Linear uses a more sophisticated weight initialization scheme (and the two instances were also created with different random seeds, 123 and 789).
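To confirm that the difference comes only from the weights, we can copy the weights of a SelfAttention_v2 instance into a SelfAttention_v1 instance; the weights must be transposed because nn.Linear stores them as (d_out, d_in). This is a sketch with both classes repeated so it runs on its own; the v2 class here computes d_k from its own keys.

```python
import torch
import torch.nn as nn

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)


class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.W_query, x @ self.W_key, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1] ** 0.5, dim=-1)
        return attn_weights @ values


torch.manual_seed(789)
sa_v2 = SelfAttention_v2(3, 2)
sa_v1 = SelfAttention_v1(3, 2)

# Copy v2's weights into v1; transpose because nn.Linear stores (d_out, d_in)
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

print(torch.allclose(sa_v1(inputs), sa_v2(inputs), atol=1e-6))  # True
```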