In [1]:
import torch
import torch.nn as nn

In [2]:
# Step 1: start with the input.
# Shape (b, num_tokens, d_in) = (1, 3, 6): one batch, three tokens, six embedding dimensions.
x = torch.tensor(
    [
        [
            [1.0, 2.0, 3.0, 4.0, 5.0, 6.0],  # the
            [6.0, 5.0, 4.0, 3.0, 2.0, 1.0],  # cat
            [1.0, 1.0, 1.0, 1.0, 1.0, 1.0],  # sleeps
        ]
    ]
)
d_in = 6  # embedding dimension of every input token
x.shape

torch.Size([1, 3, 6])

In [3]:
# Step 2: decide d_out and the number of attention heads.
d_out = 6      # width of each token's context vector; the full output is num_tokens x d_out = 3 x 6
num_heads = 2  # number of attention heads
# The heads share d_out evenly, so d_out has to be a multiple of num_heads;
# otherwise the floor division below would silently drop dimensions.
assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
head_dim = d_out // num_heads  # 6 // 2 = 3 dimensions per head

In [4]:
# Initialize the weight matrices for query, key and value.
# Each layer maps d_in -> d_out without a bias term, so each weight matrix is 6 x 6 here.
# Seed the random number generator first: nn.Linear draws its initial weights at random,
# and without a fixed seed every run of this cell would produce different numbers.
torch.manual_seed(123)
W_query = nn.Linear(d_in, d_out, bias=False)
W_key = nn.Linear(d_in, d_out, bias=False)
W_value = nn.Linear(d_in, d_out, bias=False)

In [5]:
# Show the three projection weight matrices.
# `.weight.data` is the underlying tensor of each layer's weight parameter.
print("W_query weight matrix:", W_query.weight.data, sep="\n")
print("\nW_key weight matrix:", W_key.weight.data, sep="\n")
print("\nW_value weight matrix:", W_value.weight.data, sep="\n")

W_query weight matrix:
tensor([[ 0.0994, -0.3706,  0.2431,  0.2046,  0.1399,  0.3392],
        [-0.0107, -0.0927,  0.2060,  0.1684,  0.2498, -0.2904],
        [ 0.3014,  0.1254,  0.1257,  0.1717,  0.0955, -0.2667],
        [-0.3731,  0.3828, -0.3635, -0.3155,  0.1172, -0.0590],
        [-0.2768, -0.0638,  0.1943, -0.2806,  0.2499,  0.1582],
        [-0.2427, -0.1134, -0.0316, -0.0958, -0.2733,  0.1731]])

W_key weight matrix:
tensor([[ 0.1453,  0.2993,  0.3894,  0.0725, -0.3739,  0.1844],
        [-0.1958, -0.2417, -0.2086,  0.0361,  0.2955, -0.2080],
        [-0.3570,  0.3868,  0.1560, -0.3431,  0.3903,  0.0418],
        [ 0.2738,  0.3491, -0.1222, -0.2679,  0.3618, -0.0428],
        [-0.4038, -0.0402, -0.1200,  0.2469, -0.1900, -0.1259],
        [-0.0832, -0.2505,  0.3393, -0.2988,  0.0803, -0.3066]])

W_value weight matrix:
tensor([[ 0.3367, -0.3055,  0.0278,  0.3592,  0.3428, -0.0699],
        [-0.0987, -0.0080,  0.3444,  0.3132,  0.2908, -0.3063],
        [-0.3165,  0.1133, -0.381

In [6]:
# Compute the query, key and value matrices by passing the same input x
# through each of the three linear projections.
queries, keys, values = (projection(x) for projection in (W_query, W_key, W_value))


In [7]:
# Projected queries followed by their shape.
print(queries, queries.shape, sep="\n")

tensor([[[ 3.6407,  0.6016,  0.4935, -1.7284,  1.2552, -1.2758],
         [ 0.9491,  1.0102,  3.3780, -2.5499, -1.3864, -2.8108],
         [ 0.6557,  0.2303,  0.5531, -0.6112, -0.0187, -0.5838]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 3, 6])


In [8]:
# Projected keys followed by their shape.
print(keys, keys.shape, sep="\n")

tensor([[[ 1.4392, -0.9314,  1.7147,  1.0859, -1.5617, -2.1994],
         [ 3.5801, -2.7269,  0.2090,  2.7768, -2.8686, -1.4366],
         [ 0.7170, -0.5226,  0.2748,  0.5518, -0.6329, -0.5194]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 3, 6])


In [9]:
# Projected values followed by their shape.
print(values, values.shape, sep="\n")

tensor([[[ 2.5402,  1.7876, -4.5756,  0.3422, -1.5177, -1.8528],
         [ 2.2967,  1.9605, -4.0580, -0.0142, -3.4818, -3.3014],
         [ 0.6910,  0.5354, -1.2334,  0.0469, -0.7142, -0.7363]]],
       grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 3, 6])


In [10]:
# Step 4: implicitly split the projections into heads by adding a num_heads dimension.
# Unroll the last dimension of queries, keys and values into (num_heads, head_dim),
# where head_dim = d_out / num_heads = 6 / 2 = 3:
# (b=1, num_tokens=3, d_out=6) -> (b=1, num_tokens=3, num_heads=2, head_dim=3)
# (1, 3, 6) -> (1, 3, 2, 3)
# The sizes are read from the tensor instead of hard-coding 1, 3 and 2. The three-way
# unpack also fails loudly if `queries` is not the 3-D projection output (for example
# when this cell is re-run after `queries` has been rebound to a 4-D tensor).
b, num_tokens, _ = queries.shape
reshaped_Q = queries.reshape(b, num_tokens, num_heads, head_dim)
reshaped_K = keys.reshape(b, num_tokens, num_heads, head_dim)
reshaped_V = values.reshape(b, num_tokens, num_heads, head_dim)


In [11]:
# Queries after the reshape: each token now holds one head_dim-sized row per head.
print(reshaped_Q)  #(batch, num of tokens, number of heads, head dimension)

tensor([[[[ 3.6407,  0.6016,  0.4935],
          [-1.7284,  1.2552, -1.2758]],

         [[ 0.9491,  1.0102,  3.3780],
          [-2.5499, -1.3864, -2.8108]],

         [[ 0.6557,  0.2303,  0.5531],
          [-0.6112, -0.0187, -0.5838]]]], grad_fn=<ViewBackward0>)


In [12]:
# Keys after the reshape: (batch, num of tokens, number of heads, head dimension)
print(reshaped_K)

tensor([[[[ 1.4392, -0.9314,  1.7147],
          [ 1.0859, -1.5617, -2.1994]],

         [[ 3.5801, -2.7269,  0.2090],
          [ 2.7768, -2.8686, -1.4366]],

         [[ 0.7170, -0.5226,  0.2748],
          [ 0.5518, -0.6329, -0.5194]]]], grad_fn=<ViewBackward0>)


In [13]:
# Values after the reshape: (batch, num of tokens, number of heads, head dimension)
print(reshaped_V)

tensor([[[[ 2.5402,  1.7876, -4.5756],
          [ 0.3422, -1.5177, -1.8528]],

         [[ 2.2967,  1.9605, -4.0580],
          [-0.0142, -3.4818, -3.3014]],

         [[ 0.6910,  0.5354, -1.2334],
          [ 0.0469, -0.7142, -0.7363]]]], grad_fn=<ViewBackward0>)


In [14]:
# Step 5: swap the token and head axes.
# After the reshape the rows are grouped by token; attention is computed per head,
# so we want them grouped by head instead:
# (b=1, num_tokens=3, num_heads=2, head_dim=3) -> (b=1, num_heads=2, num_tokens=3, head_dim=3)
# (1, 3, 2, 3) -> (1, 2, 3, 3)
transposed_Q, transposed_K, transposed_V = (
    per_token.transpose(1, 2) for per_token in (reshaped_Q, reshaped_K, reshaped_V)
)

In [15]:
# Queries grouped by head, followed by their shape.
print(transposed_Q, transposed_Q.shape, sep="\n")

tensor([[[[ 3.6407,  0.6016,  0.4935],
          [ 0.9491,  1.0102,  3.3780],
          [ 0.6557,  0.2303,  0.5531]],

         [[-1.7284,  1.2552, -1.2758],
          [-2.5499, -1.3864, -2.8108],
          [-0.6112, -0.0187, -0.5838]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [16]:
# Keys grouped by head, followed by their shape.
print(transposed_K, transposed_K.shape, sep="\n")

tensor([[[[ 1.4392, -0.9314,  1.7147],
          [ 3.5801, -2.7269,  0.2090],
          [ 0.7170, -0.5226,  0.2748]],

         [[ 1.0859, -1.5617, -2.1994],
          [ 2.7768, -2.8686, -1.4366],
          [ 0.5518, -0.6329, -0.5194]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [17]:
# Values grouped by head, followed by their shape.
print(transposed_V, transposed_V.shape, sep="\n")

tensor([[[[ 2.5402,  1.7876, -4.5756],
          [ 2.2967,  1.9605, -4.0580],
          [ 0.6910,  0.5354, -1.2334]],

         [[ 0.3422, -1.5177, -1.8528],
          [-0.0142, -3.4818, -3.3014],
          [ 0.0469, -0.7142, -0.7363]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [18]:
# For ease of understanding, use the short names for the per-head tensors from here on.
# Note: this rebinds queries/keys/values from the (b, num_tokens, d_out) projections
# to the transposed (b, num_heads, num_tokens, head_dim) tensors.
queries, keys, values = transposed_Q, transposed_K, transposed_V
print(keys, keys.shape, sep="\n")

tensor([[[[ 1.4392, -0.9314,  1.7147],
          [ 3.5801, -2.7269,  0.2090],
          [ 0.7170, -0.5226,  0.2748]],

         [[ 1.0859, -1.5617, -2.1994],
          [ 2.7768, -2.8686, -1.4366],
          [ 0.5518, -0.6329, -0.5194]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [19]:
# Keys with their last two axes (dims 2 and 3) swapped, followed by the resulting shape.
print(keys.transpose(2, 3), keys.transpose(2, 3).shape, sep="\n")

tensor([[[[ 1.4392,  3.5801,  0.7170],
          [-0.9314, -2.7269, -0.5226],
          [ 1.7147,  0.2090,  0.2748]],

         [[ 1.0859,  2.7768,  0.5518],
          [-1.5617, -2.8686, -0.6329],
          [-2.1994, -1.4366, -0.5194]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [20]:
# Per-head queries followed by their shape.
print(queries, queries.shape, sep="\n")

tensor([[[[ 3.6407,  0.6016,  0.4935],
          [ 0.9491,  1.0102,  3.3780],
          [ 0.6557,  0.2303,  0.5531]],

         [[-1.7284,  1.2552, -1.2758],
          [-2.5499, -1.3864, -2.8108],
          [-0.6112, -0.0187, -0.5838]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [21]:
# Attention scores: queries multiplied by the transposed keys, independently per head.
# Why transpose(2, 3)? The tensors are laid out as (b, num_heads, num_tokens, head_dim),
# so swapping dims 2 and 3 of keys gives (b, num_heads, head_dim, num_tokens) and the
# product has shape (b, num_heads, num_tokens, num_tokens).
attn_scores = torch.matmul(queries, keys.transpose(2, 3))

In [22]:
# One 3x3 attention-score matrix (the, cat, sleeps) per head, followed by the shape.
print(attn_scores, attn_scores.shape, sep="\n")

tensor([[[[ 5.5256, 11.4966,  2.4317],
          [ 6.2173,  1.3493,  1.0809],
          [ 1.6776,  1.8351,  0.5018]],

         [[-1.0312, -6.5674, -1.0855],
          [ 5.5780,  0.9344,  0.9303],
          [ 0.6496, -0.8047, -0.0222]]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 3, 3])


In [23]:
# Step 7: build the mask and truncate it to the number of tokens.
context_length = 3  # or whatever your context length is
num_tokens = 3
# Ones strictly above the main diagonal, zeros on and below it.
mask = torch.ones(context_length, context_length).triu(diagonal=1)
# Keep only the top-left num_tokens x num_tokens corner, as booleans.
mask_bool = mask[:num_tokens, :num_tokens].bool()

In [24]:
# Step 8: use the mask to fill the attention scores.
# Every position where mask_bool is True is overwritten with -inf.
# The trailing underscore means attn_scores itself is modified in place;
# the call's return value is what this cell displays.
attn_scores.masked_fill_(mask=mask_bool, value=float("-inf"))

tensor([[[[ 5.5256,    -inf,    -inf],
          [ 6.2173,  1.3493,    -inf],
          [ 1.6776,  1.8351,  0.5018]],

         [[-1.0312,    -inf,    -inf],
          [ 5.5780,  0.9344,    -inf],
          [ 0.6496, -0.8047, -0.0222]]]], grad_fn=<MaskedFillBackward0>)

In [25]:
# Scaled softmax: divide the scores by sqrt(head_dim) and normalise over the last axis.
print(keys.size(-1))   # head dimension 3
attn_weights = (attn_scores / keys.size(-1) ** 0.5).softmax(dim=-1)
print(attn_weights)

3
tensor([[[[1.0000, 0.0000, 0.0000],
          [0.9432, 0.0568, 0.0000],
          [0.3843, 0.4208, 0.1949]],

         [[1.0000, 0.0000, 0.0000],
          [0.9359, 0.0641, 0.0000],
          [0.4738, 0.2046, 0.3215]]]], grad_fn=<SoftmaxBackward0>)


In [26]:
# You can also apply dropout after this - not shown here

In [27]:
# Shape of the attention weights computed by the softmax above.
print(attn_weights.shape)

torch.Size([1, 2, 3, 3])


In [28]:
# Per-head values followed by their shape.
print(values, values.shape, sep="\n")

tensor([[[[ 2.5402,  1.7876, -4.5756],
          [ 2.2967,  1.9605, -4.0580],
          [ 0.6910,  0.5354, -1.2334]],

         [[ 0.3422, -1.5177, -1.8528],
          [-0.0142, -3.4818, -3.3014],
          [ 0.0469, -0.7142, -0.7363]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 2, 3, 3])


In [29]:
# Context vector = attention weights multiplied by the values.
# attn  -> (b, heads, num_tokens, num_tokens) = (1, 2, 3, 3)
# value -> (b, heads, num_tokens, head_dim)   = (1, 2, 3, 3)
# result -> (b, heads, num_tokens, head_dim); per head: (3 x 3) @ (3 x 3) -> 3 x 3
context_vec = torch.matmul(attn_weights, values)
print(context_vec, context_vec.shape, sep="\n")

tensor([[[[ 2.5402,  1.7876, -4.5756],
          [ 2.5264,  1.7975, -4.5462],
          [ 2.0773,  1.6163, -3.7064]],

         [[ 0.3422, -1.5177, -1.8528],
          [ 0.3194, -1.6436, -1.9457],
          [ 0.1743, -1.6613, -1.7903]]]], grad_fn=<UnsafeViewBackward0>)
torch.Size([1, 2, 3, 3])


In [30]:
# To concatenate the heads in the output, the head rows of each token must sit together:
# (b, heads, num_tokens, head_dim) -> (b, num_tokens, heads, head_dim)
# transpose(1, 2) swaps the heads axis with the num_tokens axis.
context_vec = torch.transpose(attn_weights @ values, 1, 2)
print(context_vec, context_vec.shape, sep="\n")

tensor([[[[ 2.5402,  1.7876, -4.5756],
          [ 0.3422, -1.5177, -1.8528]],

         [[ 2.5264,  1.7975, -4.5462],
          [ 0.3194, -1.6436, -1.9457]],

         [[ 2.0773,  1.6163, -3.7064],
          [ 0.1743, -1.6613, -1.7903]]]], grad_fn=<TransposeBackward0>)
torch.Size([1, 3, 2, 3])


In [31]:
# Now flatten each token's heads into a single row.
# Combine heads, where d_out = num_heads * head_dim.
# Take the batch and token counts from the tensor itself so nothing is hard-coded.
# b must be a plain int: writing `b = 1,` (trailing comma) would create the tuple (1,).
b, num_tokens = context_vec.shape[:2]
# contiguous() is called first because context_vec comes from a transpose, and
# view() cannot reinterpret a transposed (non-contiguous) tensor.
context_vec = context_vec.contiguous().view(b, num_tokens, d_out)
print(context_vec)
print(context_vec.shape)


tensor([[[ 2.5402,  1.7876, -4.5756,  0.3422, -1.5177, -1.8528],
         [ 2.5264,  1.7975, -4.5462,  0.3194, -1.6436, -1.9457],
         [ 2.0773,  1.6163, -3.7064,  0.1743, -1.6613, -1.7903]]],
       grad_fn=<ViewBackward0>)
torch.Size([1, 3, 6])
