# Coding Attention Mechanism

## Simple Self Attention

This code implements a simple self attention mechanism without trainable weights. Its just done to illustrate the concept of self attention.

![title](./SelfAttention.png)

The above image summrizes how this self attention works.

Consider the following input embeddings for the tokens


In [2]:
import torch
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]  # step     (x^6)
  ])


Now since we want to find Z<sub>(2)</sub> which is the context vector of the second token X<sub>(2)</sub>, we will need have our query vector as X<sub>(2)</sub>.

We will next need to compute the dot product of all the imput embeddings with this query vector

In [16]:
query = inputs[1]  # X(2)
# This same as individually computing the dot product of each input with the query vector,, alternate way is to do inputs @ query
attn_scores_2 = torch.matmul(inputs, query)
print(f"Attention scores for X(2): {attn_scores_2}")
# Normalize this vector
# attn_weights_2_tmp = attn_scores_2 / torch.sum(attn_scores_2)
# print(f"Normalized attention scores for X(2): {attn_weights_2_tmp}")

#Normalize using softmax
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())


Attention scores for X(2): tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


The context vector is weighted sum of these attention weights and the input embesddings

In [39]:
context_vec_2 = torch.sum(
    (inputs   # inputs
     *
    # Reshaping ensures we get the right shape for broadcasting, by having the number of columns as 1 we get a matric is size (N, 1)
    # since inputs is of size (N, D), we are able to broadcast each attention weight to D dimensions
     attn_weights_2.reshape(-1, 1) ),
    dim=0  # By doing dim=0 we aggregate rows retaining the tensor with D columns
)
print("Context vector for X(2):", context_vec_2)

Context vector for X(2): tensor([0.4419, 0.6515, 0.5683])


Conceptually what we did above is shown in the following image

![text](./SelfAttention2.png)

### We will next compute the attention weights of all input tokens

Step 1 is to calculate the attention score of each input token with each other input token. Lets say we have N tokens each of D dimension, the attention score will be of size (N, N) where each row is the attention score of the i<sup>th</sup> token with all other tokens.

In [49]:
# This one liner computes the dot product of each input with every other input
attn_scores = inputs @ inputs.T
print(f"Attention scores:\n{attn_scores}")


Attention scores:
tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


Notice how its a symmetric matrix where each row (or even column) is the attention score of the i<sup>th</sup> token with all other tokens.

We will now apply softmax normalization to these attn_scores

In [62]:
attn_weights = torch.softmax(attn_scores, dim = 1)
print(f"Attention weights:\n{attn_weights}")

Attention weights:
tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


Finally we will compute the context vector for each input token by multiplying the attention weights with the input embeddings

In [66]:
all_context_vecs = attn_weights @ inputs
print(f"Context vectors:\n{all_context_vecs}")

Context vectors:
tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


In [67]:
print("Previous 2nd context vector:", context_vec_2)
print("New 2nd context vector:", all_context_vecs[1])

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])
New 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


## Implement Self Attention with Trainable Weights

As we see in the following image, we have three sets of weights W<sub>Q</sub>, W<sub>K</sub> and W<sub>V</sub> which are used to transform the input embeddings into Query, Key and Value vectors.

![test](./TrainableSelfAttention1.png)

We will again use the same input x<sub>2</sub> to x<sub>6</sub> as before. Generally the input dimension and output dimensions are same but we will keep it different to illustrate the concept.

In [72]:
x_2 = inputs[1]
d_in = x_2.shape[0]
d_out = 2

In [78]:
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [79]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value
print(f"Query vector for x2: {query_2}")

Query vector for x2: tensor([0.4306, 1.4551])


We will need the entire key and value matrices to compute the attention weights and context vector, thus we will transform the entire input matrix

In [82]:
keys = inputs @ W_key
values = inputs @ W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)


keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


With the query vector and the keys and values metrices, the unscaled attention scores is computed using the dot product of the query vector with each of the key vectors

![test](./TrainableSelfAttention2.png)

In [98]:
attn_scores_2 = query_2 @ keys.T
print(f"Attention scores for x2: {attn_scores_2}")

# Notice how the attention scores are scaled by sqrt(d_k) where d_k is the dimension of the key vectors
# This is done to avoid very small gradients when d_k is large
# See https://arxiv.org/abs/1706.03762 for more details
d_k = keys.shape[-1]
attn_scores_2 = torch.softmax(attn_scores_2/ d_k**0.5, dim = -1)
print(f"Attention scores for x2: {attn_scores_2}")


Attention scores for x2: tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])
Attention scores for x2: tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


![test](./AttentionScalingReason.png)

Finally the context vector is computed as the weighted sum of the value vectors

In [106]:
context_vec_2 = attn_scores_2 @ values
print(f"Context vector for x2: {context_vec_2}")

Context vector for x2: tensor([0.3061, 0.8210])


![test](./WhyQueryKeyValue.png)

In [114]:
import torch.nn as nn
class SelfAttentionV1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

In [115]:
torch.manual_seed(123)
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


What we see above are the context vectors of each input token after applying self attention. If we notice the second row, it is same as the context vector we computed manually above.

In [121]:
import torch.nn as nn
class SelfAttentionV2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values


In [122]:
torch.manual_seed(789)
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


V2 used ``nn.Linear`` insteas of ``nn.Parameter`` to define the weight matrices. This is more standard way of defining the weights in PyTorch.

Apart from the weight initialization, the two implementations are identical. Lets verify that by using the weights of the linear layers of V2 to initialize nn.Parameter weights

In [128]:
# Initialize nn.Parameter with weights same as V2 weights
W_query_test = nn.Parameter(sa_v2.W_query.weight.T)
W_key_test = nn.Parameter(sa_v2.W_key.weight.T)
W_value_test = nn.Parameter(sa_v2.W_value.weight.T)

# perform the attention calculation manually
queries_test = inputs @ W_query_test
keys_test = inputs @ W_key_test
values_test = inputs @ W_value_test
attn_scores_test = queries_test @ keys_test.T
attn_weights_test = torch.softmax(attn_scores_test / keys_test.shape[-1]**0.5, dim=-1)
print(f"Attention scores calculates using V2 weights manually using nn.Parameter are\n {attn_weights_test @ values_test}")

Attention scores calculates using V2 weights manually using nn.Parameter are
 tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


### Hiding future words with causal attention

For many LLM tasks like text genertion we want to ensure that the model does not attend to future words. This is done using causal attention where we mask the attention scores of future words by setting them to ``-Inf`` before applying softmax.

In [138]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
print("Attention scores before masking:\n", attn_scores)

Attention scores before masking:
 tensor([[ 0.2899,  0.0716,  0.0760, -0.0138,  0.1344, -0.0511],
        [ 0.4656,  0.1723,  0.1751,  0.0259,  0.1771,  0.0085],
        [ 0.4594,  0.1703,  0.1731,  0.0259,  0.1745,  0.0090],
        [ 0.2642,  0.1024,  0.1036,  0.0186,  0.0973,  0.0122],
        [ 0.2183,  0.0874,  0.0882,  0.0177,  0.0786,  0.0144],
        [ 0.3408,  0.1270,  0.1290,  0.0198,  0.1290,  0.0078]],
       grad_fn=<MmBackward0>)


We will now mask the upper triangular part of the attention scores by setting them to -Inf. Remember softmax is
$$
\frac{\exp(x_i)}{\sum_j \exp(x_j)}
$$

Setting x to ``-Inf`` to will make $$\exp(x_i)$$ 0 thus ensuring that the attention weights for those positions are 0


In [148]:
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print("Masked attention scores:\n", masked)
attn_weights = torch.softmax(masked / queries.shape[-1]**0.5, dim=-1)
print("Attention weights after masking:\n", attn_weights)
context_vec = attn_weights @ values

Masked attention scores:
 tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)
Attention weights after masking:
 tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


### Masked attention weights with dropout

To prevent overfitting dropout is applied at two possible places
 1. After calculating the attention weights
 2. After calculating the context vectors

Here we will dropout after we calculate the attention weights

Lets also illustrate with an example what dropout exactly does

In [149]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # 50% dropout
example = torch.ones(6, 6)
print(f"Example before dropout:\n{example}")
dropped = dropout(example)
print(f"Example after dropout:\n{dropped}")

Example before dropout:
tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])
Example after dropout:
tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


Notice how the values of the remaining are scaled by 1/(1-p) where p is the dropout probability. This ensures that the expected value of the tensor remains same before and after dropout. Notice the values dropped doesn't have to exactly half of the values since its a random process. In above example we dropped 15/36 values instead of 18/36 for exact 50% dropout. However the scaling is deterministic and is always done by 1/(1-p)

Now lets apply the dropout to the attention weights

In [155]:
torch.manual_seed(123)
print(f"Attention weights with dropout\n{dropout(attn_weights)}")

Attention weights with dropout
tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)


Lets now implement the self attention with causal masking and dropout. Also lets simulate the batching

In [157]:
batch = torch.stack([inputs, inputs], dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [163]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # register_buffer ensures that the mask is not a trainable parameter and is moved to the right device when model.to(device) is called
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in_ca = x.shape
        queries_ca = self.W_query(x)
        keys_ca = self.W_key(x)
        values_ca = self.W_value(x)
        # Transpose last two dimensions so that we don't transpose the batch size
        attn_scores_ca = queries_ca @ keys_ca.transpose(-2, -1)
        # The _ indicates that the operation is done in place, also [:num_tokens, :num_tokens]
        # ensures that we can handle variable length sequences if there are less than context_length tokens
        attn_scores_ca.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights_ca = torch.softmax(attn_scores_ca / d_in_ca ** 0.5, dim=-1)
        attn_weights_ca = self.dropout(attn_weights_ca)
        return attn_weights_ca @ values_ca


torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)


context_vecs.shape: torch.Size([2, 6, 2])


### Extending single-head attention to multi-head attention

 A single causal attention module can be considered single-head attention, where there is only one set of attention weights processing the input sequentially.

We will build the intuition of multi-head attention by simply stacking the ``CausalAttention`` modules

![test](./MultiheadAttentionConcept.png)

The output dimension of the multi-head attention will be ``num_heads * d_out`` since we are concatenating the output of each head

In [179]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                        dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList([CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)])

    def forward(self, x):
        # Exercise 3.2: Make the dim=-2 instead of dim=-1 and we will concat the context vectors vertically such that dimension of
        # of each context vector is preserved however we will double the number of rows in eacxh batch since now we generate
        # contect vectors for each token
        return torch.cat([head(x) for head in self.heads], dim=-1)

torch.manual_seed(123)
context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)
print(f"context_vecs {context_vecs}")
print(f"context_vecs.shape: {context_vecs.shape}")


context_vecs tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5866,  0.0071,  0.5869,  0.3214],
         [-0.6293, -0.0621,  0.6184,  0.3825],
         [-0.5670, -0.0838,  0.5474,  0.3575],
         [-0.5519, -0.0979,  0.5319,  0.3423],
         [-0.5295, -0.1077,  0.5074,  0.3481]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5866,  0.0071,  0.5869,  0.3214],
         [-0.6293, -0.0621,  0.6184,  0.3825],
         [-0.5670, -0.0838,  0.5474,  0.3575],
         [-0.5519, -0.0979,  0.5319,  0.3423],
         [-0.5295, -0.1077,  0.5074,  0.3481]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


#### Implementing multi-head attention with weight splits

What we see above is a naive implementation of multi-head attention where we simply stack multiple single-head attention modules. All these single heads ate processed sequentially which is not efficient.

In [199]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # This is new, we will add an optional Linear layer to project the output.
        self.out_proj = nn.Linear(d_out, d_out)
        self.register_buffer("mask", torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in_ca = x.shape
        keys_mha = self.W_key(x)       # (b, num_tokens, d_out)
        values_mha = self.W_value(x)   # (b, num_tokens, d_out)
        queries_mha = self.W_query(x)  # (b, num_tokens, d_out)

        # d_out is same as num_heads * head_dim
        # view_ reshapes the tensor without changing its data, in this case we project the
        # last d_out dimension to (num_heads, head_dim)
        keys_mha = keys_mha.view(b, num_tokens, self.num_heads, self.head_dim) # (b, num_tokens, num_heads, head_dim)
        values_mha = values_mha.view(b, num_tokens, self.num_heads, self.head_dim) # (b, num_tokens, num_heads, head_dim)
        queries_mha = queries_mha.view(b, num_tokens, self.num_heads, self.head_dim) # (b, num_tokens, num_heads, head_dim)

        # To calculation the attention score, we need the last two dimensions to be num_tokens and head_dim
        # thus we need to transpose the 1st and 2nd dimensions
        queries_mha.transpose_(1, 2)  # (b, num_heads, num_tokens, head_dim)
        keys_mha.transpose_(1, 2)     # (b, num_heads, num_tokens, head_dim)
        values_mha.transpose_(1, 2)   # (b, num_heads, num_tokens, head_dim)

         # Let calculate the attention scores, this is the dot product of queries and keys
        attn_scores_mha = queries_mha @ keys_mha.transpose(-2, -1) # (b, num_heads, num_tokens, num_tokens)

        # Apply the mask, the dimensions of the attn scores are still (_, _, context_length, context_length)
        #  the mask is 2D and is applied to the last two dimensions only
        attn_scores_mha.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) #(b, num_heads, num_tokens, num_tokens)
        attn_weights_mha = torch.softmax(attn_scores_mha / self.head_dim ** 0.5, dim=-1) #(b, num_heads, num_tokens, num_tokens)
        # Apply dropout to the attention weights
        attn_weights_mha = self.dropout(attn_weights_mha) # (b, num_heads, num_tokens, num_tokens)
        # attn_weights_mha @ values_mha gives (b, num_heads, num_tokens, head_dim)
        # We need to transpose the 1st and 2nd dimensions to get (b, num_tokens, num_heads, head_dim)
        context_vecs_mha = (attn_weights_mha @ values_mha).transpose(1,2) # (b, num_tokens, num_heads, head_dim)
        # We will reshape the context vectors back to (b, num_tokens, d_out) where d_out = num_heads * head_dim
        context_vecs_mha = context_vecs_mha.contiguous().view(b, num_tokens, self.d_out) # (b, num_tokens, d_out)
        # Finally we will project the output using the out_proj layer
        context_vecs_mha = self.out_proj(context_vecs_mha)
        return context_vecs_mha



In [201]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


### Exercise 3.3

Using the MultiHeadAttention class, initialize a multi-head attention module that has the same number of attention heads as the smallest GPT-2 model (12 attention heads). Also ensure that you use the respective input and output embedding sizes similar to GPT-2 (768 dimensions). Note that the smallest GPT-2 model supports a context length of 1,024 tokens.

In [203]:
d_in, d_out = 768, 768
context_length = 1024
num_heads = 12
dropout = 0
mha_gpt2 = MultiHeadAttention(d_in, d_out, context_length, dropout, num_heads)
print(mha_gpt2)

MultiHeadAttention(
  (W_query): Linear(in_features=768, out_features=768, bias=False)
  (W_key): Linear(in_features=768, out_features=768, bias=False)
  (W_value): Linear(in_features=768, out_features=768, bias=False)
  (dropout): Dropout(p=0, inplace=False)
  (out_proj): Linear(in_features=768, out_features=768, bias=True)
)
