IMPLEMENTING SELF ATTENTION WITH TRAINABLE WEIGHTS

In [4]:
pip install torch



In [5]:
import torch
inputs = torch.tensor([
    [0.43,0.15,0.89], #Your
    [0.55,0.87,0.66], #Journey
    [0.57,0.85,0.64], #starts
    [0.22,0.58,0.33], #with
    [0.77,0.25,0.10], #one
    [0.05,0.80,0.55]]) #step

In [6]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2
torch.manual_seed(123)

#Weight matrices
w_query = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
w_key = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)
w_value = torch.nn.Parameter(torch.rand(d_in,d_out),requires_grad=False)

print(w_query)
print(w_key)
print(w_value)

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


Note that we are setting requires_grad=False to reduce clutter in the outputs for illustration purposes.
If we were to use the weight matrices for model training, we would set require_grad = True to update these matrices during training.

Next , we compute the query, key and value vectors as shown earlier.

In [7]:
query_2 = x_2 @ w_query
key_2 = x_2 @ w_key
value_2 = x_2 @ w_value

print(query_2)
print(key_2)
print(value_2)

tensor([0.4306, 1.4551])
tensor([0.4433, 1.1419])
tensor([0.3951, 1.0037])


We can obtain all keys and values via matrix multiplication



In [8]:
keys = inputs @ w_key
values = inputs @ w_value
queries = inputs @ w_query

print("keys.shape",keys.shape)
print("values.shape",values.shape)
print("queries.shape",queries.shape)


keys.shape torch.Size([6, 2])
values.shape torch.Size([6, 2])
queries.shape torch.Size([6, 2])


In [9]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


Again, we can generalize this computation to all attention scores via matrix multiplicaiton

In [10]:
attn_scores_2 = query_2 @ keys.T #All attention scores for given query
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


In [11]:
attn_scores = queries @ keys.T
print(attn_scores)

tensor([[0.9231, 1.3545, 1.3241, 0.7910, 0.4032, 1.1330],
        [1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
        [1.2544, 1.8284, 1.7877, 1.0654, 0.5508, 1.5238],
        [0.6973, 1.0167, 0.9941, 0.5925, 0.3061, 0.8475],
        [0.6114, 0.8819, 0.8626, 0.5121, 0.2707, 0.7307],
        [0.8995, 1.3165, 1.2871, 0.7682, 0.3937, 1.0996]])


Scaling by square root of keys dimension before applying softmax.

In [12]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2/d_k**0.5,dim=-1)
print(attn_weights_2)
print(d_k)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])
2


WHY DIVIDE BY SQRT (DIMENSION)


Reason 1: For stability in learning

The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between the exponential values of each input become much more pronounced. This causes the softmax output to become "peaky," where the highest value receives almost all the probability mass, and the rest receive very little.

In attention mechanisms, particularly in transformers, if the dot products between query and key vectors become too large (like multiplying by 8 in this example), the attention scores can become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key." Such sharp distributions can make learning unstable,

In [13]:
import torch

# Define the tensor
tensor = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

# Apply softmax without scaling
softmax_result = torch.softmax(tensor, dim=-1)
print("Softmax without scaling:", softmax_result)

# Multiply the tensor by 8 and then apply softmax
scaled_tensor = tensor * 8
softmax_scaled_result = torch.softmax(scaled_tensor, dim=-1)
print("Softmax after scaling (tensor * 8):", softmax_scaled_result)

Softmax without scaling: tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
Softmax after scaling (tensor * 8): tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


WHY SQRT ?
Reason 2: To make the variance of the dot product stable

The dot product of  Q and K increases the variance because multiplying two random numbers increases the variance.

The increase in variance grows with the dimension.

Dividing by sqrt (dimension) keeps the variance close to 1

In [14]:
import numpy as np
# Function to compute variance before and after scaling
def compute_variance(dim, num_trials=1000):
    dot_products = []
    scaled_dot_products = []

    # Generate multiple random vectors and compute dot products
    for _ in range(num_trials):
        q = np.random.randn(dim)
        k = np.random.randn(dim)

        # Compute dot product
        dot_product = np.dot(q, k)
        dot_products.append(dot_product)

        # Scale the dot product by sqrt(dim)
        scaled_dot_product = dot_product / np.sqrt(dim)
        scaled_dot_products.append(scaled_dot_product)

    # Calculate variance of the dot products
    variance_before_scaling = np.var(dot_products)
    variance_after_scaling = np.var(scaled_dot_products)

    return variance_before_scaling, variance_after_scaling

# For dimension 5
variance_before_5, variance_after_5 = compute_variance(5)
print(f"Variance before scaling (dim=5): {variance_before_5}")
print(f"Variance after scaling (dim=5): {variance_after_5}")

# For dimension 20
variance_before_100, variance_after_100 = compute_variance(100)
print(f"Variance before scaling (dim=100): {variance_before_100}")
print(f"Variance after scaling (dim=100): {variance_after_100}")



Variance before scaling (dim=5): 5.265040115409276
Variance after scaling (dim=5): 1.0530080230818553
Variance before scaling (dim=100): 92.70069475714222
Variance after scaling (dim=100): 0.9270069475714222


In [15]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


IMPLEMENTING A COMPACT SELF ATTENTION PYTHONC LASS

In [16]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
  def __init__(self,d_in,d_out):
    super().__init__()
    self.w_query = nn.Parameter(torch.rand(d_in,d_out))
    self.w_key = nn.Parameter(torch.rand(d_in,d_out))
    self.w_value = nn.Parameter(torch.rand(d_in,d_out))

  def forward(self,x):
    keys = x @ self.w_key
    values = x @ self.w_value
    queries = x @ self.w_query

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    context_vec = attn_weights @ values
    return context_vec

torch.manual_seed(123)
self_attn = SelfAttention_v1(d_in,d_out)
print(self_attn(inputs))


tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Using nn.Linear instead of random vector for key ,, query value as nn.Linear has optimized weight initialization scheme, contributing to more stable and effective model training.

In [17]:
class SelfAttention_v2(nn.Module):
  def __init__(self, d_in, d_out, qkv_bias = False):
    super().__init__()
    self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)

  def forward(self,x):
    keys = self.w_key(x)
    queries = self.w_query(x)
    values = self.w_value(x)

    attn_scores = queries @ keys.T
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    context_vec = attn_weights @ values
    return context_vec

torch.manual_seed(789)
self_attn_v2 = SelfAttention_v2(d_in,d_out)
print(self_attn_v2(inputs))



tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


HIDING FUTURE WORDS WITH CAUSAL ATTENTION

In [18]:
inputs = torch.tensor([
    [0.43,0.15,0.89], #Your
    [0.55,0.87,0.66], #Journey
    [0.57,0.85,0.64], #starts
    [0.22,0.58,0.33], #with
    [0.77,0.25,0.10], #one
    [0.05,0.80,0.55]]) #step

queries = self_attn_v2.w_query(inputs)
keys = self_attn_v2.w_key(inputs)
values = self_attn_v2.w_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


Using pytorch tril function to create a mask where the values above the diagonal are zero

In [19]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length,context_length))
print(mask_simple)


tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [20]:
masked_simple = mask_simple * attn_weights
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [21]:
row_sums = masked_simple.sum(dim=1,keepdim=True)
masked_simple_norm = masked_simple/row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


Due to softmax applying above, the above approach future tokens have already influenced all the inputs . This leads to data leakage problem. To avoid this , there is smarter way to do renormalization.

In [22]:
print(attn_scores)
torch.triu(torch.ones(context_length, context_length))

tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)


tensor([[1., 1., 1., 1., 1., 1.],
        [0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.]])

In [23]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [24]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [25]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


DROPOUT 0.5 Example

In [26]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)
example = torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


In [27]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


IMPLEMENTING A COMPACT CAUSAL ATTENTION CLASS

In [28]:
batch = torch.stack((inputs, inputs),dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [29]:
class CausalAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.w_query = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_key = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.w_value = nn.Linear(d_in,d_out,bias=qkv_bias)
    self.dropout = nn.Dropout(dropout)
    self.register_buffer('mask',torch.triu(torch.ones(context_length,context_length),diagonal=1))

  def forward(self,x):
    b, num_tokens, d_in = x.shape
    keys = self.w_key(x)
    queries = self.w_query(x)
    values = self.w_value(x)

    attn_scores = queries @ keys.transpose(1,2) # Batches will be processed sequentially
    attn_scores.masked_fill(self.mask.bool()[:num_tokens,:num_tokens],-torch.inf) #[:num tokens to handle if token length is less than context size]
    attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5,dim=-1)
    attn_weights = self.dropout(attn_weights)
    context_vec = attn_weights @ values
    return context_vec

In [30]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in,d_out,context_length,0.0)
context_vecs = ca(batch)
print("context_vecs.shape:",context_vecs.shape)
print(context_vecs)

context_vecs.shape: torch.Size([2, 6, 2])
tensor([[[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]],

        [[-0.5337, -0.1051],
         [-0.5323, -0.1080],
         [-0.5323, -0.1079],
         [-0.5297, -0.1076],
         [-0.5311, -0.1066],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)


EXTENDING SINGLE HEAD ATTENTION TO MULTI HEAD ATTENTION

In [31]:
inputs = torch.tensor([
    [0.43,0.15,0.89], #Your
    [0.55,0.87,0.66], #Journey
    [0.57,0.85,0.64], #starts
    [0.22,0.58,0.33], #with
    [0.77,0.25,0.10], #one
    [0.05,0.80,0.55]]) #step

batch = torch.stack((inputs, inputs),dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [32]:
class MultiHeadAttention(nn.Module):
  def __init__(self,d_in,d_out,context_length,dropout,num_heads,qkv_bias=False):
    super().__init__()
    self.heads = nn.ModuleList([CausalAttention(d_in,d_out,context_length,dropout,qkv_bias) for _ in range(num_heads)])

  def forward(self,x):
    return torch.cat([head(x) for head in self.heads],dim=-1)

In [33]:
torch.manual_seed(123)
context_length = batch.shape[1]
d_in, d_out = 3,2
mha = MultiHeadAttention(d_in,d_out,context_length,0.0,num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape",context_vecs.shape)

tensor([[[-0.5337, -0.1051,  0.5085,  0.3508],
         [-0.5323, -0.1080,  0.5084,  0.3508],
         [-0.5323, -0.1079,  0.5084,  0.3506],
         [-0.5297, -0.1076,  0.5074,  0.3471],
         [-0.5311, -0.1066,  0.5076,  0.3446],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.5337, -0.1051,  0.5085,  0.3508],
         [-0.5323, -0.1080,  0.5084,  0.3508],
         [-0.5323, -0.1079,  0.5084,  0.3506],
         [-0.5297, -0.1076,  0.5074,  0.3471],
         [-0.5311, -0.1066,  0.5076,  0.3446],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape torch.Size([2, 6, 4])


IMPLEMENTING MULTI-HEAD ATTENTION WITH WEIGHT SPLITS - Toy example

STEP1 : START WITH THE INPUT
b, num_tokens, d_in = (1,3,6)

In [34]:
x = torch.tensor([[[1.0,2.0,3.0,4.0,5.0,6.0],
                   [6.0,5.0,4.0,3.0,2.0,1.0],
                   [1.0,1.0,1.0,1.0,1.0,1.0]]])
print(x)

tensor([[[1., 2., 3., 4., 5., 6.],
         [6., 5., 4., 3., 2., 1.],
         [1., 1., 1., 1., 1., 1.]]])


STEP2 : DECIDE d_out, num_heads which is 6, 2
d_out = 2nd dimension of context embedding vector

head_dim = d_out/no_of_heads = 6/2 = 3

STEP3: INITIALIZE TRAINABLE WEIGHT MATRICES FOR key (6 *6 ), query (6 *6 ) and value (6 *6) value


In [35]:
wq = torch.tensor([[0.6323,-0.2366,1.2455,0.3465,1.2458,0.3229],
             [0.6571,-0.2378,-0.5311,-0.2610,-1.4819,-1.6418],
             [-0.2990,0.4216,0.2114,-0.0271,-0.5682,0.6937],
             [-1.1291,-1.0102,0.6946,0.1094,0.5130,-0.8669],
             [0.3480,0.2593,0.4412,1.0017,-0.3913,-0.2878],
             [0.2484,0.2846,-0.3386,-0.6164,1.2722,0.5754]])
wk = torch.tensor([[ -0.3703,  0.5431, -0.0372,  -0.4406,  0.4103, -0.1773],
             [ 1.5993, -0.2777, -1.1909, -0.4301, 0.6927, -1.3304],
             [ 1.2470, -0.1872, -0.1670,  1.4302,  1.2927,  0.4822],
             [-0.0984, -0.8983,  0.3334, -0.6312,  0.1022,  -1.0715],
             [-0.7647, -0.1734,  0.6305,  1.0155,  0.8474,  0.1454],
             [-1.5085, -0.4529,  0.0997, -0.1084,  0.8046,  0.3459]])

wv = torch.tensor([[ 1.6395,  1.1234, -0.1001,  0.5021,  -1.0590,  0.1412],
             [-0.4271,  0.5681, 0.4164, -1.2534,  1.3061,  0.3610],
             [-0.2824, -0.4314,  1.2358,  0.1181,  -1.2467,  0.1893],
             [ 1.3440,  0.1487, -0.6174,  0.8890, -0.3282, 1.4662],
             [ 0.1814, -0.4761, -0.0402,  0.7326,  0.7654,  -0.1080],
             [-0.8974,  0.6786, 0.5602, -0.2443, -0.4883,  1.3996]])

print(wq)
print(wk)
print(wv)

tensor([[ 0.6323, -0.2366,  1.2455,  0.3465,  1.2458,  0.3229],
        [ 0.6571, -0.2378, -0.5311, -0.2610, -1.4819, -1.6418],
        [-0.2990,  0.4216,  0.2114, -0.0271, -0.5682,  0.6937],
        [-1.1291, -1.0102,  0.6946,  0.1094,  0.5130, -0.8669],
        [ 0.3480,  0.2593,  0.4412,  1.0017, -0.3913, -0.2878],
        [ 0.2484,  0.2846, -0.3386, -0.6164,  1.2722,  0.5754]])
tensor([[-0.3703,  0.5431, -0.0372, -0.4406,  0.4103, -0.1773],
        [ 1.5993, -0.2777, -1.1909, -0.4301,  0.6927, -1.3304],
        [ 1.2470, -0.1872, -0.1670,  1.4302,  1.2927,  0.4822],
        [-0.0984, -0.8983,  0.3334, -0.6312,  0.1022, -1.0715],
        [-0.7647, -0.1734,  0.6305,  1.0155,  0.8474,  0.1454],
        [-1.5085, -0.4529,  0.0997, -0.1084,  0.8046,  0.3459]])
tensor([[ 1.6395,  1.1234, -0.1001,  0.5021, -1.0590,  0.1412],
        [-0.4271,  0.5681,  0.4164, -1.2534,  1.3061,  0.3610],
        [-0.2824, -0.4314,  1.2358,  0.1181, -1.2467,  0.1893],
        [ 1.3440,  0.1487, -0.6174,  0

STEP 4: Calculate keys, queries and value matrices

calc_keys = x * wk

calc_queries = x * wq

calc_values = x * wv

print(calc_keys)

print(calc_queries)

print(calc_values)

In [36]:
calc_keys = torch.matmul(x,wk)
calc_queries = torch.matmul(x,wq)
calc_values = torch.matmul(x,wv)

print(calc_keys)
print(calc_queries)
print(calc_values)

tensor([[[-6.6988, -7.7515,  2.1643,  4.8921, 15.1472, -2.8751],
         [ 7.4296, -2.3733, -4.4848,  0.9557, 13.9021, -8.3648],
         [ 0.1044, -1.4464, -0.3315,  0.8354,  4.1499, -1.6057]]])
tensor([[[-0.2365, -0.4841,  3.7703,  1.4909,  4.3061, -2.3338],
         [ 3.4404, -3.1496,  8.2907,  2.3808, -0.1789, -6.0977],
         [ 0.4577, -0.5191,  1.7230,  0.5531,  0.5896, -1.2045]]])
tensor([[[ 0.8367,  3.2513,  5.1307,  4.1028, -2.6025, 15.1535],
         [10.0693,  8.0278,  5.0522,  1.1059, -4.7524,  8.9916],
         [ 1.5580,  1.6113,  1.4547,  0.7441, -1.0507,  3.4493]]])


STEP5 :Unroll last dimension of keys, queries and value to invlude num_heads and head_dim

head_dim = d_out/num_heads = 6/2 = 3

b, num_tokens, d_out --> b, num_tokens , num_heads, head_dim
(1, 3, 6) -> (1, 3, 2, 3)

In [37]:
keys_unrolled = calc_keys.view(1,3,2,3) # b , num_tokens, num_heads, head_dim
values_unrolled = calc_values.view(1,3,2,3)
queries_unrolled = calc_queries.view(1,3,2,3)

print(keys_unrolled)
print(values_unrolled)
print(queries_unrolled)

tensor([[[[-6.6988, -7.7515,  2.1643],
          [ 4.8921, 15.1472, -2.8751]],

         [[ 7.4296, -2.3733, -4.4848],
          [ 0.9557, 13.9021, -8.3648]],

         [[ 0.1044, -1.4464, -0.3315],
          [ 0.8354,  4.1499, -1.6057]]]])
tensor([[[[ 0.8367,  3.2513,  5.1307],
          [ 4.1028, -2.6025, 15.1535]],

         [[10.0693,  8.0278,  5.0522],
          [ 1.1059, -4.7524,  8.9916]],

         [[ 1.5580,  1.6113,  1.4547],
          [ 0.7441, -1.0507,  3.4493]]]])
tensor([[[[-0.2365, -0.4841,  3.7703],
          [ 1.4909,  4.3061, -2.3338]],

         [[ 3.4404, -3.1496,  8.2907],
          [ 2.3808, -0.1789, -6.0977]],

         [[ 0.4577, -0.5191,  1.7230],
          [ 0.5531,  0.5896, -1.2045]]]])


          first token
          [[-6.6988, -7.7515,  2.1643], head 1
          [ 4.8921, 15.1472, -2.8751]], head 2
          head_dim1, head_dim2, head_dim3

          second token
         [[ 7.4296, -2.3733, -4.4848],
          [ 0.9557, 13.9021, -8.3648]],

          third token
         [[ 0.1044, -1.4464, -0.3315],
          [ 0.8354,  4.1499, -1.6057]]]])

STEP6 : GROUP MATRICES BY "Number of heads"

(b, num_tokens, num_heads , head_dim ) --> ( b, num_heads. num_tokens, head_dim)

(1,3,2,3)                              --> (1,2,3,3)

In [38]:
keys_grp = keys_unrolled.transpose(1,2)
values_grp = values_unrolled.transpose(1,2)
queries_grp = queries_unrolled.transpose(1,2)

print(keys_grp)
print(values_grp)
print(queries_grp)

tensor([[[[-6.6988, -7.7515,  2.1643],
          [ 7.4296, -2.3733, -4.4848],
          [ 0.1044, -1.4464, -0.3315]],

         [[ 4.8921, 15.1472, -2.8751],
          [ 0.9557, 13.9021, -8.3648],
          [ 0.8354,  4.1499, -1.6057]]]])
tensor([[[[ 0.8367,  3.2513,  5.1307],
          [10.0693,  8.0278,  5.0522],
          [ 1.5580,  1.6113,  1.4547]],

         [[ 4.1028, -2.6025, 15.1535],
          [ 1.1059, -4.7524,  8.9916],
          [ 0.7441, -1.0507,  3.4493]]]])
tensor([[[[-0.2365, -0.4841,  3.7703],
          [ 3.4404, -3.1496,  8.2907],
          [ 0.4577, -0.5191,  1.7230]],

         [[ 1.4909,  4.3061, -2.3338],
          [ 2.3808, -0.1789, -6.0977],
          [ 0.5531,  0.5896, -1.2045]]]])



          Head1
          Token1 [[-6.6988, -7.7515,  2.1643],
          Token2 [ 7.4296, -2.3733, -4.4848],
          Token3 [ 0.1044, -1.4464, -0.3315]],
                  headdim1 headdim2  headdim3

          Head2
         [[ 4.8921, 15.1472, -2.8751],
          [ 0.9557, 13.9021, -8.3648],
          [ 0.8354,  4.1499, -1.6057]]]])

STEP7 : Find the attention scores

queries     matmul    keys.transpose

(1, 2, 3, 3)   matmul    (1, 2, 3, 3)

(b, num_heads, num_tokens , head_dim )     matmul   (b, num_heads, head_dim, num_tokens) -----> (b, num_heads, num_tokens , num_tokens)





In [42]:

print(keys_grp)
print(keys_grp.transpose(2,3))
print(queries_grp)

tensor([[[[-6.6988, -7.7515,  2.1643],
          [ 7.4296, -2.3733, -4.4848],
          [ 0.1044, -1.4464, -0.3315]],

         [[ 4.8921, 15.1472, -2.8751],
          [ 0.9557, 13.9021, -8.3648],
          [ 0.8354,  4.1499, -1.6057]]]])
tensor([[[[-6.6988,  7.4296,  0.1044],
          [-7.7515, -2.3733, -1.4464],
          [ 2.1643, -4.4848, -0.3315]],

         [[ 4.8921,  0.9557,  0.8354],
          [15.1472, 13.9021,  4.1499],
          [-2.8751, -8.3648, -1.6057]]]])
tensor([[[[-0.2365, -0.4841,  3.7703],
          [ 3.4404, -3.1496,  8.2907],
          [ 0.4577, -0.5191,  1.7230]],

         [[ 1.4909,  4.3061, -2.3338],
          [ 2.3808, -0.1789, -6.0977],
          [ 0.5531,  0.5896, -1.2045]]]])


In [43]:
attn_score = queries_grp @ keys_grp.transpose(2,3)
print(attn_score)

tensor([[[[ 13.4968, -17.5172,  -0.5743],
          [ 19.3111,  -4.1464,   2.1664],
          [  4.6869,  -3.0948,   0.2274]],

         [[ 79.2289,  80.8105,  22.8628],
          [ 26.4688,  50.7943,  11.0376],
          [ 15.0997,  18.8007,   4.8429]]]])


Step8 : Finding the attention weight

Mask attentions scores to implement causal attention

Divide by sq.rt of head_dimension = sq.rt ( d_out/num_heads)

Apply softmax

We can apply dropout (optional)


In [50]:
mask = torch.triu(torch.ones(3,3),diagonal =1)
mask_bool = mask.bool()[:3,:3]
print(mask)
print(mask_bool)

attn_score = attn_score.masked_fill(mask_bool,-torch.inf)
print(attn_score)
print(keys_grp.shape[-1])
attn_weights = torch.softmax(attn_score/keys_grp.shape[-1]**0.5,dim=-1)
attn_weights = dropout(attn_weights)
print(attn_weights)


tensor([[0., 1., 1.],
        [0., 0., 1.],
        [0., 0., 0.]])
tensor([[False,  True,  True],
        [False, False,  True],
        [False, False, False]])
tensor([[[[13.4968,    -inf,    -inf],
          [19.3111, -4.1464,    -inf],
          [ 4.6869, -3.0948,  0.2274]],

         [[79.2289,    -inf,    -inf],
          [26.4688, 50.7943,    -inf],
          [15.0997, 18.8007,  4.8429]]]])
3
tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.4012e-01]],

         [[2.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 2.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 5.6582e-04]]]])


STEP9 : Compute the context vector.

Context_vector = attn_weights  * Value

                (b,num_heads,num_tokens , num_tokens)     (b, num_heads,num_tokens, head_dim)


In [57]:
print(attn_weights)
print(values_grp)
print(attn_weights.shape)
print(values_grp.shape)

context_vec = (attn_weights @ values_grp)
print(context_vec)

tensor([[[[0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 1.4012e-01]],

         [[2.0000e+00, 0.0000e+00, 0.0000e+00],
          [0.0000e+00, 2.0000e+00, 0.0000e+00],
          [0.0000e+00, 0.0000e+00, 5.6582e-04]]]])
tensor([[[[ 0.8367,  3.2513,  5.1307],
          [10.0693,  8.0278,  5.0522],
          [ 1.5580,  1.6113,  1.4547]],

         [[ 4.1028, -2.6025, 15.1535],
          [ 1.1059, -4.7524,  8.9916],
          [ 0.7441, -1.0507,  3.4493]]]])
tensor([[-0.0872,  0.0286],
        [-0.1137,  0.0766],
        [-0.1018,  0.0927],
        [-0.0912, -0.0026],
        [ 0.1395,  0.3580],
        [-0.2085, -0.1546]], grad_fn=<MmBackward0>)
torch.Size([1, 2, 3, 3])
torch.Size([1, 2, 3, 3])
tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 2.1831e-01,  2.2577e-01,  2.0383e-01]],

         [[ 8.2056e+00, -5.2050e+00,  3.0307e+01],
          [ 2.2118e+

          
          Head1
          Headdim1       Headdim2     Headdim3
          [[ 0.0000e+00,  0.0000e+00,  0.0000e+00], Token 1
          [ 0.0000e+00,  0.0000e+00,  0.0000e+00],  Token 2
          [ 2.1831e-01,  2.2577e-01,  2.0383e-01]], Token 3

          Head2
         [[ 8.2056e+00, -5.2050e+00,  3.0307e+01],
          [ 2.2118e+00, -9.5048e+00,  1.7983e+01],
          [ 4.2103e-04, -5.9451e-04,  1.9517e-03]]]])

Step10 : Reformat context vectors

(b, num_heads, num_tokens, head_dim ) ----> (b, num_tokens, num_heads, head_dim)

In [59]:
context_vec_fin = context_vec.transpose(1,2)
print(context_vec_fin)


tensor([[[[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 8.2056e+00, -5.2050e+00,  3.0307e+01]],

         [[ 0.0000e+00,  0.0000e+00,  0.0000e+00],
          [ 2.2118e+00, -9.5048e+00,  1.7983e+01]],

         [[ 2.1831e-01,  2.2577e-01,  2.0383e-01],
          [ 4.2103e-04, -5.9451e-04,  1.9517e-03]]]])


Step11 : Combine heads

In [60]:
context_vec_fin = context_vec_fin.contiguous().view(1,3,6)
print(context_vec_fin)

tensor([[[ 0.0000e+00,  0.0000e+00,  0.0000e+00,  8.2056e+00, -5.2050e+00,
           3.0307e+01],
         [ 0.0000e+00,  0.0000e+00,  0.0000e+00,  2.2118e+00, -9.5048e+00,
           1.7983e+01],
         [ 2.1831e-01,  2.2577e-01,  2.0383e-01,  4.2103e-04, -5.9451e-04,
           1.9517e-03]]])


MULTIHEAD ATTENTION CLASS

In [61]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

In [62]:
torch.manual_seed(123)

inputs = torch.tensor(
    [[0.43,0.15,0.89,0.55,0.87,0.66],
     [0.57,0.85,0.64,0.22,0.58,0.33],
     [0.77,0.25,0.10,0.05,0.80,0.55]])

batch = torch.stack((inputs,inputs),dim=0)
print(batch.shape)

batch_size , context_length , d_in = batch.shape
d_out = 6
mha = MultiHeadAttention(d_in,d_out,context_length,0.0,num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print(context_vecs.shape)

torch.Size([2, 3, 6])
tensor([[[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]],

        [[ 0.1569, -0.0873,  0.0210,  0.0215, -0.3243, -0.2518],
         [ 0.1117, -0.0547,  0.0406, -0.0213, -0.3251, -0.2993],
         [ 0.1196, -0.0491,  0.0318, -0.0635, -0.2788, -0.2578]]],
       grad_fn=<ViewBackward0>)
torch.Size([2, 3, 6])
