**A simple self-attention mechanism without trainable weights**

In [2]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [3]:
input_query = inputs[1]
input_query 

tensor([0.5500, 0.8700, 0.6600])

In [4]:
input_1 = inputs[0]
input_1

tensor([0.4300, 0.1500, 0.8900])

In [5]:
torch.dot(input_query, input_1) #dot product calculates the similarity between the query and the key

tensor(0.9544)

In [6]:
query = inputs[1] #2nd input is taken as example and this is called as query

attn_scores_2 = torch.empty(inputs.shape[0]) #empty tensor is created to store the attention score of the query token

for i, x_i in enumerate(inputs): #enumerate function is used to get the index and the value of the input tensor
    attn_scores_2[i] = torch.dot(query, x_i) #dot product is calculated between the query and the key (key refers to other tokens including the query token)

In [7]:
attn_scores_2 #this is the attention score of the query token attention scores refers to the similarity between the query and the key if you see the query token is most similar to the query token itself that is why it is the highest

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])

In [8]:
attn_weights_2 = torch.nn.functional.softmax(attn_scores_2, dim=0) #softmax is applied to the attention scores to get the attention weights
attn_weights_2 #this is the attention weight of the query token

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])

In [9]:
#now we will calculate the context vector of the query token which is the weighted sum of the input tokens with this attention weights

query = inputs[1]

context_vec_2 = torch.zeros(query.shape) #empty tensor is created to store the context vector of the query token and the shape of the context vector is the same as the query token

for i, x_i in enumerate(inputs): #enumerate function is used to get the index and the value of the input tensor
    context_vec_2 += attn_weights_2[i] * x_i #weighted sum is calculated between the attention weights and the input tokens

In [10]:
context_vec_2 #this is the context vector of the query token

tensor([0.4419, 0.6515, 0.5683])

In [11]:
attn_scores = torch.zeros(6,6) #empty tensor is created to store the attention scores of the all the token with respect to all the other tokens so it's a 6x6 matrix

for i, x_i in enumerate(inputs): #enumerate function is used to get the index and the value of the input tensor
    for j, x_j in enumerate(inputs): #enumerate function is used to get the index and the value of the input tensor
        attn_scores[i, j] = torch.dot(x_i, x_j) #dot product is calculated between the input tokens

In [12]:
attn_scores #if you look at this you will see that attn_scores_2 is the second row of this matrix so each row is the attention score of that token with respect to all the other tokens and the diagonal elements are the attention scores of the query token with respect to itself which is the highest

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [13]:
attn_scores = inputs @ inputs.T #this is the same as the above code but it is more efficient and it is a matrix multiplication
attn_scores

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])

In [14]:
attn_weights = torch.softmax(attn_scores, dim=1) #softmax is applied to the attention scores to get the attention weights dim=1 means that the softmax is applied to each row of the matrix so each row adds up to 1

attn_weights

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])

In [15]:
context_vec = attn_weights @ inputs #context vector is the weighted sum of the input tokens with this attention weights
context_vec #each row of the context vector of that token 

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])

**Implementing self attention with trainable weights**

In [16]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

In [17]:
x_2

tensor([0.5500, 0.8700, 0.6600])

In [18]:
torch.manual_seed(123)

W_q = torch.nn.Parameter(torch.rand(d_in, d_out))
W_k = torch.nn.Parameter(torch.rand(d_in, d_out))
W_v = torch.nn.Parameter(torch.rand(d_in, d_out))

# these are trainable weights for the query, key and value matrices we can choose d_out to be any value we want that is the dimension of the output space for simplicity we will choose it to be 2

In [19]:
W_q

Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]], requires_grad=True)

In [20]:
W_k

Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]], requires_grad=True)

In [21]:
W_v

Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]], requires_grad=True)

In [22]:
query_2 = x_2 @ W_q #this is the query vector of the 2nd token
query_2

tensor([0.4306, 1.4551], grad_fn=<SqueezeBackward4>)

In [23]:
key_2 = x_2 @ W_k #this is the key vector of the 2nd token
key_2

tensor([0.4433, 1.1419], grad_fn=<SqueezeBackward4>)

In [24]:
value_2 = x_2 @ W_v #this is the value vector of the 2nd token
value_2

tensor([0.3951, 1.0037], grad_fn=<SqueezeBackward4>)

In [25]:
# to calculate the context vector for 2nd token we will use the query vector of the 2nd token and the key vectors of all the other tokens and the value vectors of all the other tokens 
# so instead of calculating the key and value vectors seperately we will calculate them in one go by using the matrix multiplication

keys = inputs @ W_k #this is the key matrix of all the tokens
values = inputs @ W_v #this is the value matrix of all the tokens

In [26]:
keys #2nd row of the key matrix is the key vector of the 2nd token

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]], grad_fn=<MmBackward0>)

In [27]:
values #2nd row of the value matrix is the value vector of the 2nd token

tensor([[0.1855, 0.8812],
        [0.3951, 1.0037],
        [0.3879, 0.9831],
        [0.2393, 0.5493],
        [0.1492, 0.3346],
        [0.3221, 0.7863]], grad_fn=<MmBackward0>)

In [28]:
attn_score_22 = torch.dot(query_2, key_2) #this is the attention score of the 2nd token with respect to itself
attn_score_22

tensor(1.8524, grad_fn=<DotBackward0>)

In [29]:
attn_scores_2 = query_2 @ keys.T #this is the attention score of the 2nd token with respect to all the other tokens look at the 2nd value of the attn_scores_2 tensor it is the same as the attn_score_22 tensor
attn_scores_2

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440],
       grad_fn=<SqueezeBackward4>)

In [30]:
d_k = keys.shape[1] #this is the dimension of the key vector 
d_k

2

In [31]:
attn_weights_2 = torch.softmax(attn_scores_2 / torch.sqrt(torch.tensor(d_k)), dim=0) 
attn_weights_2

#this is the attention weight of the 2nd token we divide the attention scores by the square root of the dimension of the key vector to prevent the attention scores from becoming too large and this was recommended by the authors of the paper attention is all you need


tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820],
       grad_fn=<SoftmaxBackward0>)

In [32]:
context_vec_2 = attn_weights_2 @ values #this is the context vector of the 2nd token
context_vec_2

tensor([0.3061, 0.8210], grad_fn=<SqueezeBackward4>)

**Implementing a Compact Self Attention Class**

In [33]:
class SelfAttention(torch.nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_q = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_k = torch.nn.Parameter(torch.rand(d_in, d_out))
        self.W_v = torch.nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):

        queries = x @ self.W_q
        keys = x @ self.W_k
        values = x @ self.W_v

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / torch.sqrt(torch.tensor(keys.shape[1])), dim=1)
        context_vec = attn_weights @ values

        return context_vec


torch.manual_seed(123)
sa_v1 = SelfAttention(d_in=3, d_out=2)
sa_v1

SelfAttention()

In [34]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [35]:
sa_v1(inputs)

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)

In [36]:
class SelfAttention_v2(torch.nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_q = torch.nn.Linear(d_in, d_out, bias=False) # we use Linear because instead of random this method has a better weight initilization scheme
        self.W_k = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_v = torch.nn.Linear(d_in, d_out, bias=False)

    def forward(self, x):
        queries = self.W_q(inputs)
        keys = self.W_k(inputs)
        values = self.W_v(inputs)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / torch.sqrt(torch.tensor(keys.shape[1])), dim=1)
        context_vec = attn_weights @ values

        return context_vec

In [37]:
d_in, d_out

(3, 2)

In [38]:
torch.manual_seed(123)
sa_v2 = SelfAttention_v2(d_in, d_out)
sa_v2 

SelfAttention_v2(
  (W_q): Linear(in_features=3, out_features=2, bias=False)
  (W_k): Linear(in_features=3, out_features=2, bias=False)
  (W_v): Linear(in_features=3, out_features=2, bias=False)
)

In [39]:
sa_v2(inputs)

tensor([[-0.5337, -0.1051],
        [-0.5323, -0.1080],
        [-0.5323, -0.1079],
        [-0.5297, -0.1076],
        [-0.5311, -0.1066],
        [-0.5299, -0.1081]], grad_fn=<MmBackward0>)

**Applying a causal attention mask**

In [40]:
# Let's say the LLM generates one word at a time so in the example sentense "Your Journey Starts with One step"
# when giving the context vector of "Your" the LLM should only have access to that token only and when given the context vector of "Your Journey" the LLM should not have access to other context tokens
# just because it is in our training set

In [41]:
queries = sa_v2.W_q(inputs)
queries

tensor([[-0.3536,  0.3965],
        [-0.3021, -0.0289],
        [-0.3015, -0.0232],
        [-0.1353, -0.0978],
        [-0.2052,  0.0870],
        [-0.1542, -0.1499]], grad_fn=<MmBackward0>)

In [42]:
keys = sa_v2.W_k(inputs)
keys

tensor([[-0.5740,  0.2727],
        [-0.8709,  0.1008],
        [-0.8628,  0.1060],
        [-0.4789,  0.0051],
        [-0.4744,  0.1696],
        [-0.5888, -0.0388]], grad_fn=<MmBackward0>)

In [43]:
values = sa_v2.W_v(inputs)
keys

tensor([[-0.5740,  0.2727],
        [-0.8709,  0.1008],
        [-0.8628,  0.1060],
        [-0.4789,  0.0051],
        [-0.4744,  0.1696],
        [-0.5888, -0.0388]], grad_fn=<MmBackward0>)

In [44]:
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / torch.sqrt(torch.tensor(keys.shape[1])), dim=1)

In [45]:
attn_weights

tensor([[0.1717, 0.1762, 0.1761, 0.1555, 0.1627, 0.1579],
        [0.1636, 0.1749, 0.1746, 0.1612, 0.1605, 0.1652],
        [0.1637, 0.1749, 0.1746, 0.1611, 0.1606, 0.1651],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.1632, 0.1674],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.1639],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

In [46]:
context_length = attn_scores.shape[0]

In [47]:
mask_simple = torch.tril(torch.ones(context_length, context_length))
mask_simple # creates a mask

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

In [48]:
masked_attn_weights = attn_weights * mask_simple
masked_attn_weights

tensor([[0.1717, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1749, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1637, 0.1749, 0.1746, 0.0000, 0.0000, 0.0000],
        [0.1636, 0.1704, 0.1702, 0.1652, 0.0000, 0.0000],
        [0.1667, 0.1722, 0.1721, 0.1618, 0.1633, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<MulBackward0>)

In [49]:
row_sums = masked_attn_weights.sum(dim=1,keepdim=True)
row_sums

tensor([[0.1717],
        [0.3385],
        [0.5132],
        [0.6693],
        [0.8361],
        [1.0000]], grad_fn=<SumBackward1>)

In [50]:
masked_norm_attn_weights = masked_attn_weights/row_sums
masked_norm_attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<DivBackward0>)

In [51]:
# instead of getting the attention score then normalizing to get attention weights and then getting the masked attention weights and then again normalizing what we can do is get the attention score calculate the masked attention score and then calculate the normalized version

mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
mask

tensor([[0., 1., 1., 1., 1., 1.],
        [0., 0., 1., 1., 1., 1.],
        [0., 0., 0., 1., 1., 1.],
        [0., 0., 0., 0., 1., 1.],
        [0., 0., 0., 0., 0., 1.],
        [0., 0., 0., 0., 0., 0.]])

In [52]:
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.3111,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.1655, 0.2602,   -inf,   -inf,   -inf,   -inf],
        [0.1667, 0.2602, 0.2577,   -inf,   -inf,   -inf],
        [0.0510, 0.1080, 0.1064, 0.0643,   -inf,   -inf],
        [0.1415, 0.1875, 0.1863, 0.0987, 0.1121,   -inf],
        [0.0476, 0.1192, 0.1171, 0.0731, 0.0477, 0.0966]],
       grad_fn=<MaskedFillBackward0>)

In [53]:
attn_weights = torch.softmax(masked/torch.sqrt(torch.tensor(keys.shape[1])), dim=1)
attn_weights # this is same as what we got before

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4833, 0.5167, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3190, 0.3408, 0.3402, 0.0000, 0.0000, 0.0000],
        [0.2445, 0.2545, 0.2542, 0.2468, 0.0000, 0.0000],
        [0.1994, 0.2060, 0.2058, 0.1935, 0.1953, 0.0000],
        [0.1624, 0.1709, 0.1706, 0.1654, 0.1625, 0.1682]],
       grad_fn=<SoftmaxBackward0>)

**Masking additional attention weights with dropout**

In [54]:
#masking will drop random position from our attn_weight so model will reduce overfitting
drop_layer = torch.nn.Dropout(0.5)
drop_layer

Dropout(p=0.5, inplace=False)

In [55]:
example = torch.ones(6,6)
example

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])

In [56]:
drop_layer(example) # Other values become bigger because PyTorch rescales them to keep the average unchanged.

tensor([[2., 0., 2., 2., 2., 2.],
        [2., 0., 2., 0., 2., 2.],
        [0., 0., 2., 0., 2., 0.],
        [2., 0., 2., 0., 2., 2.],
        [2., 2., 0., 2., 0., 0.],
        [0., 0., 0., 0., 0., 2.]])

**create a final class of everything we learned causal self attention plus dropout**

In [65]:
class CausalAttention(torch.nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout):
        super().__init__()
        self.W_q = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_k = torch.nn.Linear(d_in, d_out, bias=False)
        self.W_v = torch.nn.Linear(d_in, d_out, bias=False)
        self.dropout = torch.nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1)) #this is a buffer so it is not a parameter and it is not updated during the training process and when we train with GPU it will be moved to the GPU

    def forward(self, x):
        b, num_tokens, d_in = x.shape #b is the batch size, num_tokens is the number of tokens in the input, d_in is the dimension of the input
        queries = self.W_q(x)
        keys = self.W_k(x)
        values = self.W_v(x)
        attn_scores = queries @ keys.transpose(1, 2) #transpose is used to make the keys matrix a 3d tensor and why we do that because we want to make the dot product of the queries and the keys and we want to do it for all the tokens in the input so we need to make the keys matrix a 3d tensor
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) #this is the mask that we created earlier the causal mask
        attn_weights = torch.softmax(attn_scores / torch.sqrt(torch.tensor(keys.shape[1])), dim=1)
        attn_weights = self.dropout(attn_weights)
        context_vec = attn_weights @ values
        return context_vec


torch.manual_seed(123)
ca = CausalAttention(d_in=3, d_out=2, context_length=6, dropout=0.5)
ca

CausalAttention(
  (W_q): Linear(in_features=3, out_features=2, bias=False)
  (W_k): Linear(in_features=3, out_features=2, bias=False)
  (W_v): Linear(in_features=3, out_features=2, bias=False)
  (dropout): Dropout(p=0.5, inplace=False)
)

In [66]:
batch = torch.stack((inputs, inputs), dim=0)
batch.shape

torch.Size([2, 6, 3])

In [69]:
batch

tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])

In [67]:
ca(batch)

tensor([[[-0.1610,  0.0789],
         [-0.1517,  0.0744],
         [-0.3697, -0.1022],
         [-0.4923, -0.0251],
         [-0.6918, -0.1094],
         [-0.8427, -0.3002]],

        [[ 0.0000,  0.0000],
         [-0.4459, -0.0064],
         [-0.5214, -0.0278],
         [-0.8765, -0.2752],
         [-0.7949, -0.1041],
         [-1.1554, -0.2665]]], grad_fn=<UnsafeViewBackward0>)

**Stacking Multiple single attention to make multi head attention**

In [86]:
#what we implemented above is a single head attention mechanism but in a multi head attention mechanism we have multiple attention heads and each attention head has its own query, key and value matrices
#so we will stack multiple single head attention mechanisms to make a multi head attention mechanism


class MultiHeadAttention(torch.nn.Module):
    def __init__(self, d_in, d_out, num_heads, dropout, context_length):
        super().__init__()
        self.heads = torch.nn.ModuleList([CausalAttention(d_in=d_in, d_out=d_out, context_length=6, dropout=0.5) for _ in range(num_heads)])

    def forward(self, x):
        out = torch.cat([head(x) for head in self.heads], dim=-1) #dim=-1 means that we are concatenating the output of the heads along the last dimension
        return out


# what we are doing here is we are stacking multiple single head attention mechanisms to make a multi head attention mechanism
# so we are creating a list of causal attention mechanisms and we are stacking them up to make a multi head attention mechanism

context_length = batch.shape[1]
d_in, d_out = batch.shape[-1], 2



In [87]:
d_in, d_out, context_length

(3, 2, 6)

In [88]:
mha = MultiHeadAttention(d_in, d_out, context_length=context_length, dropout=0, num_heads=2)
context_vec = mha(batch)
print(context_vec.shape)

torch.Size([2, 6, 4])


In [89]:
context_vec

tensor([[[ 0.0000,  0.0000, -0.0364,  0.0342],
         [-0.0183,  0.0773, -0.0976,  0.0848],
         [-0.1362,  0.1180, -0.1617,  0.1407],
         [-0.3344,  0.1102, -0.0535,  0.0437],
         [ 0.0000,  0.0000, -0.1070,  0.0905],
         [-0.5707,  0.2531,  0.0542,  0.0884]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [-0.0945,  0.0350, -0.0976,  0.0848],
         [-0.0182,  0.0773, -0.0972,  0.0844],
         [-0.1174,  0.0404, -0.1050,  0.0757],
         [-0.4025,  0.0223,  0.0488,  0.0912],
         [-0.5957,  0.3066, -0.0634,  0.1779]]], grad_fn=<CatBackward0>)

In [95]:
class MultiHeadAttentionv2(torch.nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`, 
        # this will result in errors in the mask creation further below. 
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs  
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec


torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 3
mha = MultiHeadAttention(d_in=d_in, d_out=d_out, context_length=context_length, dropout=0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.1029,  0.1751, -0.0969,  0.0276,  0.1675,  0.1001],
         [ 0.2526,  0.4133, -0.1646,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000, -0.0072,  0.6432,  0.2947],
         [ 0.3149,  0.5051, -0.1411,  0.0090,  0.7815,  0.3809],
         [ 0.4092,  0.6672, -0.2384, -0.0188,  0.8696,  0.4006],
         [ 0.2921,  0.4824, -0.1905, -0.0576,  1.9283,  0.8964]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.2526,  0.4133, -0.1646,  0.0000,  0.0000,  0.0000],
         [ 0.2523,  0.4127, -0.1645,  0.0206,  0.4314,  0.2188],
         [ 0.2595,  0.4251, -0.1714, -0.0075,  0.6650,  0.3046],
         [ 0.3920,  0.8400, -0.4137, -0.0150,  0.5256,  0.2440],
         [ 0.7320,  0.8764, -0.0160, -0.0583,  1.4093,  0.7391]]],
       grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 6])


In [94]:
batch.shape

torch.Size([2, 6, 3])