## Coding attention mechanisms

At this point, you know how to prepare the input text for training LLMs by splitting
text into individual word and subword tokens, which can be encoded into vector
representations (embeddings) for the LLM.
Now, we will look at an integral part of the LLM architecture itself, attention
mechanisms, as illustrated in figure 3.1. We will largely look at attention mechanisms
in isolation and focus on them at a mechanistic level. Then we will code the remaining parts of the LLM surrounding the self-attention mechanism to see it in action and to
create a model to generate text.

In [55]:
from __future__ import annotations

import torch
import torch.nn as nn

### A simple self-attention mechanism without trainable weights

Consider the following input sentence, “Your journey starts with one step,” which has already been embedded into
three-dimensional vectors:

In [2]:
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

The first step of implementing self-attention is to compute the intermediate values ω,
referred to as attention scores.
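In formula form, the attention score between the query x^(2) and an input token x^(i) is their dot product, and the simple normalization used in the next cell divides each score by the sum of all scores (here d = 3 is the embedding size and T = 6 is the number of tokens):

$$
\omega_{2i} = x^{(2)} \cdot x^{(i)} = \sum_{k=1}^{d} x^{(2)}_k \, x^{(i)}_k,
\qquad
\alpha_{2i} = \frac{\omega_{2i}}{\sum_{j=1}^{T} \omega_{2j}}
$$

For example, ω_21 = 0.55·0.43 + 0.87·0.15 + 0.66·0.89 = 0.9544.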

- Attention scores (and simply normalized weights) for the second input token, “journey” (x^2), used as the query:

In [6]:
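query2 = inputs[1]  # the second token, "journey", acts as the query
attn_scores_2 = torch.empty(inputs.shape[0])
for i, key in enumerate(inputs):
    attn_scores_2[i] = torch.dot(query2, key)  # dot product of the query with each input token

# Simple normalization: divide by the sum so that the weights sum to 1
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights:", attn_weights_2_tmp)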

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])


In practice, it’s more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. The following is a basic implementation of the softmax function for normalizing the attention scores:

In [7]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

In [8]:
attn_weights_2_naive = softmax_naive(attn_scores_2)
print("Attention weights:", attn_weights_2_naive)

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
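One concrete advantage of softmax is that it always produces positive weights, while dividing by the sum does not once some scores are negative (dot products of real-valued embeddings can be). A small sketch with made-up scores, not taken from the example sentence:

```python
import torch

# Hypothetical scores that include a negative value
scores = torch.tensor([2.0, -1.0, 0.5])

# Sum normalization: the weights sum to 1, but one of them is negative (about -0.67)
simple = scores / scores.sum()

# Softmax: exp() makes every term strictly positive before normalizing
soft = torch.softmax(scores, dim=0)

print("sum-normalized:", simple)
print("softmax:       ", soft)
```

Both versions sum to 1, but only the softmax weights can be read as proportions of attention.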


Note that this naive softmax implementation (softmax_naive) may encounter
numerical instability problems, such as overflow and underflow, when dealing with
large or small input values. Therefore, in practice, it’s advisable to use the PyTorch
implementation of softmax, which has been extensively optimized for performance:

In [9]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print("Attention weights:", attn_weights_2)

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
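The overflow problem is easy to reproduce. The sketch below compares the naive version with the common max-subtraction trick; `softmax_stable` is our own illustrative name, not a PyTorch function.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the maximum does not change the result (the factor exp(-max)
    # cancels in numerator and denominator), but every exponent becomes <= 0,
    # so torch.exp() can no longer overflow to inf.
    shifted = x - x.max(dim=0, keepdim=True).values
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

print(softmax_naive(large_scores))  # exp(1000) overflows float32: inf / inf gives nan
print(softmax_stable(large_scores))
print(torch.softmax(large_scores, dim=0))
```

For ordinary inputs such as the attention scores above, both versions agree; they only diverge once the exponentials leave the float32 range.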


Now that we have computed the normalized attention weights, we are ready for the
final step: calculating the context vector z^(2) by multiplying the
embedded input tokens, x^(i), with the corresponding attention weights and then
summing the resulting vectors.
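Written as a formula, with α_2i the normalized attention weights and T = 6 the number of input tokens:

$$
z^{(2)} = \sum_{i=1}^{T} \alpha_{2i} \, x^{(i)}
$$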

In [24]:
query = inputs[1]
context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i]*x_i  # scale each input vector by its weight and accumulate
print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


- Computing attention weights for all input tokens

In [28]:
attn_scores = torch.zeros(inputs.shape[0], inputs.shape[0])
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


When computing the preceding attention score tensor, we used for loops in
Python. However, for loops are generally slow, and we can achieve the same results
using matrix multiplication:

In [38]:
attn_scores = inputs@inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
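The same matrix formulation extends to a batch of sequences, which is the input shape the attention classes later in this chapter work with. `.T` is meant for 2D tensors, so with a leading batch dimension we swap only the last two dimensions with `transpose(-2, -1)`. A small sketch with random placeholder inputs:

```python
import torch

torch.manual_seed(0)
batch_inputs = torch.rand(2, 6, 3)  # placeholder batch: 2 sequences, 6 tokens, 3 dims

# Swap only the last two dimensions: (2, 6, 3) -> (2, 3, 6)
batch_scores = batch_inputs @ batch_inputs.transpose(-2, -1)
print(batch_scores.shape)  # torch.Size([2, 6, 6])

# Each batch element matches the 2D computation x @ x.T
for i in range(batch_inputs.shape[0]):
    seq = batch_inputs[i]
    print(torch.allclose(batch_scores[i], seq @ seq.T))  # True
```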


We can now normalize each row with softmax so that the values in each row sum to 1:

In [39]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])
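To confirm that `dim=-1` normalizes each row (and not each column), we can sum along the corresponding dimension. This quick check re-declares the example inputs so that it runs on its own:

```python
import torch

example_inputs = torch.tensor(
    [[0.43, 0.15, 0.89],   # Your
     [0.55, 0.87, 0.66],   # journey
     [0.57, 0.85, 0.64],   # starts
     [0.22, 0.58, 0.33],   # with
     [0.77, 0.25, 0.10],   # one
     [0.05, 0.80, 0.55]]   # step
)
scores = example_inputs @ example_inputs.T

row_weights = torch.softmax(scores, dim=-1)  # normalize across each row
col_weights = torch.softmax(scores, dim=0)   # normalize down each column instead

print(row_weights.sum(dim=-1))  # all ones: each row sums to 1
print(col_weights.sum(dim=0))   # all ones: each column sums to 1
```

For attention, each row belongs to one query token, so the row-wise version is the one we want.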


In the third and final step, we use these attention weights to compute all
context vectors via matrix multiplication:

In [42]:
context_vecs = attn_weights@inputs
print(context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
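The three steps fit into a few lines. The helper name `simple_self_attention` is our own and is used only for illustration; the check at the end recomputes z^(2) with the explicit loop from earlier and compares it with the second row:

```python
import torch

def simple_self_attention(x: torch.Tensor) -> torch.Tensor:
    """Illustrative helper: self-attention without trainable weights."""
    attn_scores = x @ x.T                              # step 1: pairwise dot products
    attn_weights = torch.softmax(attn_scores, dim=-1)  # step 2: row-wise softmax
    return attn_weights @ x                            # step 3: weighted sums of inputs

example_inputs = torch.tensor(
    [[0.43, 0.15, 0.89],
     [0.55, 0.87, 0.66],
     [0.57, 0.85, 0.64],
     [0.22, 0.58, 0.33],
     [0.77, 0.25, 0.10],
     [0.05, 0.80, 0.55]]
)
all_context = simple_self_attention(example_inputs)

# Recompute z^(2) with an explicit loop over the inputs
weights_2 = torch.softmax(example_inputs @ example_inputs[1], dim=0)
z_2 = torch.zeros(3)
for w, x_i in zip(weights_2, example_inputs):
    z_2 += w * x_i

print(all_context[1])                       # same values as context_vec_2 above
print(torch.allclose(all_context[1], z_2))  # True
```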


### Implementing self-attention with trainable weights

Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention.

The most notable difference is the introduction of weight matrices that are
updated during model training. These trainable weight matrices are crucial so that
the model (specifically, the attention module inside the model) can learn to produce
“good” context vectors.

- Computing the attention weights step by step

  We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices W_q, W_k, and W_v (named W_query, W_key, and W_value in the code). These three matrices are used to project the embedded input tokens, x^(i), into query, key, and value vectors, respectively.

In [None]:
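x_2 = inputs[1]         # the second input token, "journey"
d_in = inputs.shape[1]  # input embedding size, d_in = 3
d_out = 2               # output embedding size, d_out = 2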

In [45]:
torch.manual_seed(123)
# requires_grad=False keeps the printed outputs uncluttered; to train these
# matrices, we would set requires_grad=True
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [48]:
query_2 = x_2@W_query
key_2 = x_2@W_key
value_2 = x_2@W_value
print(query_2)

tensor([0.4306, 1.4551])


The output for the query is a two-dimensional vector because we set the number
of columns of the corresponding weight matrix, via d_out, to 2.

Even though our immediate goal is only to compute the one context vector, z^(2), we still
require the key and value vectors for all input elements because they are involved in
computing the attention weights with respect to the query q^(2).
We can obtain all keys and values via matrix multiplication:

In [None]:
keys = inputs@W_key
values = inputs@W_value
print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


The second step is to compute the attention scores.

In [51]:
attn_scores_2 = query_2@keys.T
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Now, we want to go from the attention scores to the attention weights. We compute the attention weights by scaling the attention scores and
applying the softmax function. Unlike in the simplified version, we now scale the attention scores by dividing
them by the square root of the embedding dimension of the keys, d_k (taking the square root is the same as raising to the power 0.5):
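A rough sketch of why this scaling matters, using random vectors as stand-ins for real queries and keys (the dimension 512 is an arbitrary choice for illustration). For vectors with independent, unit-variance components, a dot product has a variance of about d_k, so without scaling the softmax tends to put almost all of the weight on a single token, a saturated regime where gradients become very small. Dividing by sqrt(d_k) brings the scores back to roughly unit variance:

```python
import torch

torch.manual_seed(0)
big_d_k = 512  # hypothetical, larger key dimension so the effect is easy to see

rand_query = torch.randn(big_d_k)     # stand-in for a query vector
rand_keys = torch.randn(6, big_d_k)   # stand-ins for six key vectors
rand_scores = rand_keys @ rand_query  # each dot product has variance of about big_d_k

unscaled = torch.softmax(rand_scores, dim=-1)
scaled = torch.softmax(rand_scores / big_d_k**0.5, dim=-1)

print("score std:       ", rand_scores.std().item())
print("unscaled weights:", unscaled)
print("scaled weights:  ", scaled)
```

Because dividing by a constant larger than 1 acts like raising the softmax temperature, the largest scaled weight is always smaller than the largest unscaled one.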

In [53]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2/d_k**0.5, dim=-1)  # divide by sqrt(d_k)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


Now we can use matrix multiplication between `attn_weights_2` and `values` to obtain the output context vector for the token `x^2`:

In [54]:
context_vec_2 = attn_weights_2@values
print(context_vec_2)

tensor([0.3061, 0.8210])


At this point, we have gone through a lot of steps to compute the self-attention
outputs. We did so mainly for illustration purposes so we could go through one step at a
time. In practice, with the LLM implementation in the next chapter in mind, it is
helpful to organize this code into a Python class, as shown in the following listing:

In [60]:
class SelfAttentionV1(nn.Module):
    
    def __init__(self, d_in:int, d_out:int):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys = x@self.W_key
        queries = x@self.W_query
        values = x@self.W_value
        attn_scores = queries@keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        context_vecs = attn_weights@values
        return context_vecs

We can use this class as follows:

In [75]:
torch.manual_seed(123)
sa_v1 = SelfAttentionV1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


We can improve the SelfAttentionV1 implementation further by using
PyTorch’s nn.Linear layers, which perform a plain matrix multiplication when
the bias units are disabled. Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear
has an optimized weight initialization scheme, contributing to more stable and
effective model training.

In [65]:
class SelfAttentionV2(nn.Module):
    
    def __init__(self, d_in:int, d_out:int, qkv_bias:bool=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        
    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries@keys.T
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        context_vecs = attn_weights@values
        return context_vecs


You can use SelfAttentionV2 in the same way as SelfAttentionV1:

In [84]:
torch.manual_seed(789)
sa_v2 = SelfAttentionV2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


To check that both implementations, SelfAttentionV1 and SelfAttentionV2, are otherwise
similar, we can transfer the weight matrices from a SelfAttentionV2 object to a
SelfAttentionV1 object, such that both objects then produce the same results:

In [91]:
# nn.Linear stores its weight as (d_out, d_in), so we transpose it to the
# (d_in, d_out) layout that SelfAttentionV1 uses
sa_v1.W_query = torch.nn.Parameter(sa_v2.W_query.weight.T)
sa_v1.W_key = torch.nn.Parameter(sa_v2.W_key.weight.T)
sa_v1.W_value = torch.nn.Parameter(sa_v2.W_value.weight.T)
sa_v1(inputs)

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)
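The `.T` in the weight transfer follows from how nn.Linear stores its weights: with `bias=False`, the layer computes `x @ weight.T`, where `weight` has shape (d_out, d_in). A minimal check with random placeholder inputs:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2, bias=False)  # d_in=3, d_out=2 as in this section
x_demo = torch.rand(6, 3)             # random stand-in for six 3-dimensional inputs

# The weight is stored as (d_out, d_in), and the layer computes x @ weight.T
print(linear.weight.shape)  # torch.Size([2, 3])
print(torch.allclose(linear(x_demo), x_demo @ linear.weight.T))  # True
```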

### Hiding future words with causal attention

Causal attention, also known as masked attention, is a specialized form of
self-attention. It restricts a model to consider only the previous and current inputs in a
sequence when computing the attention scores for any given token.

#### Applying a causal attention mask

One way to obtain the masked attention weight matrix in causal attention takes three steps: (1) apply the
softmax function to the attention scores, (2) zero out the elements above the diagonal, and (3) renormalize
the resulting matrix so that each row sums to 1.

In [93]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries@keys.T
attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
attn_weights

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

We can implement the second step using PyTorch’s tril function to create a mask
where the values above the diagonal are zero:

In [97]:
context_lengt = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_lengt, context_lengt))
masked_simple = attn_weights*mask_simple
masked_simple

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)

In [98]:
masked_simple_norm = masked_simple/masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)

When we apply a mask and then renormalize the attention weights, it might initially
appear that information from future tokens (which we intend to mask) could still
influence the current token because their values are part of the softmax calculation.
However, the key insight is that when we renormalize the attention weights after masking, what we’re essentially doing is recalculating the softmax over a smaller subset (since
masked positions don’t contribute to the softmax value).
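The argument can be made precise. If p_i = exp(s_i) / Z is the full softmax, renormalizing over the visible positions gives p_i / Σ_visible p_j = exp(s_i) / Σ_visible exp(s_j), because the denominator Z cancels; that is exactly a softmax over the visible scores alone. A small numerical sketch with made-up scores for a single query row:

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4)                          # hypothetical scores for one query row
keep = torch.tensor([True, True, False, False])  # only the first two positions are visible

# Approach 1: softmax over all positions, zero out the masked ones, renormalize
weights = torch.softmax(scores, dim=0) * keep.float()
renormalized = weights / weights.sum()

# Approach 2: softmax computed only over the visible positions
subset = torch.softmax(scores[keep], dim=0)

print(renormalized)
print(subset)
print(torch.allclose(renormalized[keep], subset))  # True
```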

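A more efficient way to obtain the masked attention weight matrix in
causal attention is to mask the attention scores with negative infinity values before
applying the softmax function. Because exp(-∞) = 0, the masked positions receive zero weight, and the remaining weights in each row still sum to 1.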

In [105]:
mask = torch.triu(torch.ones(context_lengt, context_lengt), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)

In [106]:
attn_weights = torch.softmax(masked/keys.shape[-1]**0.5, dim=-1)  # scale by sqrt(d_k)
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

#### Masking additional attention weights with dropout

Dropout in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively “dropping” them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It’s important to emphasize that dropout is only used during training and is disabled afterward.

In [107]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5)  # zero each element with probability 0.5
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


When applying dropout to an attention weight matrix with a rate of 50%, each element
is set to zero with probability 0.5, so roughly half of the elements are zeroed. To compensate for the reduction in
active elements, the values of the remaining elements in the matrix are scaled up by a
factor of 1/(1 - 0.5) = 2.

Now let’s apply dropout to the attention weight matrix itself:

In [112]:
attn_weights

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)

In [111]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8967, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)
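Since dropout should only be active during training, it helps to see how PyTorch switches it off: nn.Dropout follows the module's training flag, which `.train()` and `.eval()` set. A minimal sketch:

```python
import torch
import torch.nn as nn

torch.manual_seed(123)
demo_dropout = nn.Dropout(0.5)
ones = torch.ones(6, 6)

demo_dropout.train()              # training mode (the default for a new module)
train_out = demo_dropout(ones)    # every entry is either 0 or 1 / (1 - 0.5) = 2

demo_dropout.eval()               # evaluation/inference mode
eval_out = demo_dropout(ones)     # dropout is a no-op: the input comes back unchanged

print(train_out)
print(eval_out)
```

Calling `.eval()` on a parent model propagates to every nn.Dropout it contains, which is how dropout gets disabled after training.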


#### Implementing a compact causal attention class

In [None]:
# -----------------------------------------------------------------------------
# CausalAttention: a GPT-style "causal" attention implementation
#
# - This layer projects the inputs (x) into three matrices: queries (Q),
#   keys (K), and values (V) via three linear layers.
#
# - The causal mask (an upper-triangular matrix) prevents each token from
#   accessing future tokens: a token t can only attend to positions
#   <= t. This enforces the *autoregressive* property essential to GPT models.
#
# - The attention scores are computed as Q @ Kᵀ, then scaled by sqrt(d_k) and
#   normalized with softmax (after the mask has been applied).
#
# - Dropout is then applied to the attention weights for regularization.
#
# - The context vectors are obtained as AttentionWeights @ V,
#   which lets each token aggregate the relevant information from the
#   current and preceding tokens.
#
# Result: a correct attention layer for an autoregressive language
# model, where each position depends only on the past, never on the future.


# register_buffer() stores a tensor in the module
# without treating it as a trainable parameter.
# Advantages:
# - moved automatically to CPU/GPU with model.to(device)
# - saved in state_dict()
# - not updated by the optimizer
#
# Ideal for causal masks, constants, etc.

# -----------------------------------------------------------------------------

In [118]:
class CausalAttention(nn.Module):
    
    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, qkv_bias:bool=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
    
    def forward(self, x:torch.Tensor):
        num_tokens = x.shape[1]
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)
        attn_scores = queries@keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vecs = attn_weights@values
        return context_vecs

In [113]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [145]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs:", context_vecs)
print("\ncontext_vecs.shape:", context_vecs.shape)

context_vecs: tensor([[[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]],

        [[-0.4519,  0.2216],
         [-0.5874,  0.0058],
         [-0.6300, -0.0632],
         [-0.5675, -0.0843],
         [-0.5526, -0.0981],
         [-0.5299, -0.1081]]], grad_fn=<UnsafeViewBackward0>)

context_vecs.shape: torch.Size([2, 6, 2])
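One way to check that the mask really enforces causality is to change only the last token of a sequence and confirm that the context vectors of the earlier tokens stay the same. The sketch below re-declares a condensed copy of CausalAttention under the hypothetical name `CausalAttentionSketch` so it runs on its own; the random inputs are placeholders:

```python
import torch
import torch.nn as nn

class CausalAttentionSketch(nn.Module):
    """Condensed copy of CausalAttention above (illustrative)."""

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        num_tokens = x.shape[1]
        queries, keys, values = self.W_query(x), self.W_key(x), self.W_value(x)
        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores = attn_scores.masked_fill(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1] ** 0.5, dim=-1)
        return self.dropout(attn_weights) @ values


torch.manual_seed(123)
causal = CausalAttentionSketch(d_in=3, d_out=2, context_length=6, dropout=0.0)

seq = torch.rand(1, 6, 3)           # placeholder input: 1 sequence, 6 tokens
seq_changed = seq.clone()
seq_changed[0, -1] = torch.rand(3)  # modify only the last token

with torch.no_grad():
    out, out_changed = causal(seq), causal(seq_changed)

# Tokens 1-5 cannot attend to token 6, so their context vectors are unchanged
print(torch.allclose(out[0, :5], out_changed[0, :5]))  # True
```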


### Extending single-head attention to multi-head attention

Our final step will be to extend the previously implemented causal attention class over
multiple heads. This is also called multi-head attention.

The term “multi-head” refers to dividing the attention mechanism into multiple
“heads,” each operating independently. In this context, a single causal attention
module can be considered single-head attention, where there is only one set of attention
weights processing the input sequentially.

#### Stacking multiple single-head attention layers

In [None]:
class MultiHeadAttentionWrapper(nn.Module):
    
    def __init__(self, d_in:int, d_out:int, context_length:int, dropout:float, num_heads:int, qkv_bias:bool=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [
                CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)
            ]
        )
        
    def forward(self, x:torch.Tensor):
        
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [None]:
torch.manual_seed(123)
mhaw = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=3)
context_vecs = mhaw(batch)
print("context_vecs:", context_vecs)
print("\ncontext_vecs.shape:", context_vecs.shape)

context_vecs: tensor([[[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]],

        [[-0.4519,  0.2216,  0.4772,  0.1063,  0.4566,  0.2729],
         [-0.5874,  0.0058,  0.5891,  0.3257,  0.5792,  0.3011],
         [-0.6300, -0.0632,  0.6202,  0.3860,  0.6249,  0.3102],
         [-0.5675, -0.0843,  0.5478,  0.3589,  0.5691,  0.2785],
         [-0.5526, -0.0981,  0.5321,  0.3428,  0.5543,  0.2520],
         [-0.5299, -0.1081,  0.5077,  0.3493,  0.5337,  0.2499]]],
       grad_fn=<CatBackward0>)

context_vecs.shape: torch.Size([2, 6, 6])


#### Implementing multi-head attention with weight splits

In the MultiHeadAttentionWrapper, multiple heads are implemented by creating
a list of CausalAttention objects (self.heads), each representing a separate attention head. The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated. In contrast, the following
MultiHeadAttention class integrates the multi-head functionality within a single class.
It splits the input into multiple heads by reshaping the projected query, key, and value
tensors and then combines the results from these heads after computing attention.
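The reshaping is the least obvious part of this class. The sketch below uses a small, hypothetical tensor of consecutive numbers to show that `view` followed by `transpose(1, 2)` gives each head its own contiguous slice of the d_out features, and that the inverse transpose plus `view` restores the original layout:

```python
import torch

batch_size, seq_len, n_heads, head_dim = 2, 6, 2, 3
width = n_heads * head_dim  # plays the role of d_out

# Stand-in for a projected tensor such as W_query(x): (batch_size, seq_len, width)
projected = torch.arange(batch_size * seq_len * width, dtype=torch.float32).view(
    batch_size, seq_len, width
)

# Split the last dimension into (n_heads, head_dim), then move heads before tokens
split = projected.view(batch_size, seq_len, n_heads, head_dim).transpose(1, 2)
print(split.shape)  # torch.Size([2, 2, 6, 3])

# Head h holds features h*head_dim ... (h+1)*head_dim - 1 of every token
for h in range(n_heads):
    print(h, torch.equal(split[:, h], projected[..., h * head_dim:(h + 1) * head_dim]))

# Inverse: transpose back, ensure contiguous memory for view(), then merge the heads
merged = split.transpose(1, 2).contiguous().view(batch_size, seq_len, width)
print(torch.equal(merged, projected))  # True
```

Moving the head dimension before the token dimension lets a single batched matrix multiplication compute the attention scores of all heads at once.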

In [213]:
class MultiHeadAttention(nn.Module):
    
    def __init__(self, d_in:int, d_out:int, num_heads:int, context_length:int, dropout:float, qkv_bias:bool=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"
        
        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out//num_heads  # each head works on a slice of size head_dim
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)  # combines the outputs of all heads
   
    def forward(self, x:torch.Tensor):
        
        num_batches, num_tokens, _ = x.shape
        
        keys = self.W_key(x) # -> (num_batches, num_tokens, d_out=num_heads*head_dim)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        # Split the last dimension into separate heads
        keys = keys.view(num_batches, num_tokens, self.num_heads, self.head_dim) # -> (num_batches, num_tokens, num_heads, head_dim)
        queries = queries.view(num_batches, num_tokens, self.num_heads, self.head_dim)
        values = values.view(num_batches, num_tokens, self.num_heads, self.head_dim)
        
        # Move the head dimension forward so all heads are processed as a batch
        keys = keys.transpose(1, 2) # -> (num_batches, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)
        
        attn_scores = queries@keys.transpose(-2, -1) # -> (num_batches, num_heads, num_tokens, num_tokens)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        
        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        
        context_vecs = (attn_weights@values).transpose(1, 2) # -> (num_batches, num_tokens, num_heads, head_dim)
        
        # Merge the heads back: -> (num_batches, num_tokens, d_out = num_heads*head_dim)
        context_vecs = context_vecs.contiguous().view(num_batches, num_tokens, self.d_out)
        context_vecs = self.out_proj(context_vecs)
        
        return context_vecs

In [214]:
mha = MultiHeadAttention(d_in=3, d_out=6, num_heads=2, context_length=6, dropout=0.1)
mha(batch)

tensor([[[ 0.4615,  0.0466, -0.0957, -0.0997,  0.4638, -0.1654],
         [ 0.5027,  0.0054, -0.2106, -0.0236,  0.5276, -0.1192],
         [ 0.5174, -0.0086, -0.2539,  0.0018,  0.5512, -0.1015],
         [ 0.4743, -0.0484, -0.1432, -0.0860,  0.4775, -0.0169],
         [ 0.4799, -0.0101, -0.2614, -0.0658,  0.5580, -0.0273],
         [ 0.5062, -0.0405, -0.2377, -0.0285,  0.5370, -0.0331]],

        [[ 0.4615,  0.0466, -0.0957, -0.0997,  0.4638, -0.1654],
         [ 0.4000, -0.0204, -0.1297, -0.1882,  0.4936,  0.0512],
         [ 0.5174, -0.0086, -0.2539,  0.0018,  0.5512, -0.1015],
         [ 0.5103, -0.0508, -0.1956, -0.0288,  0.5020, -0.0481],
         [ 0.4813, -0.0433, -0.2037, -0.0852,  0.5100,  0.0032],
         [ 0.5091, -0.0790, -0.1353, -0.0593,  0.4673,  0.0006]]],
       grad_fn=<ViewBackward0>)
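In the output above, the two sequences in `batch` are identical, yet some of their context vectors differ. That is expected: the module was created with `dropout=0.1` and is still in training mode, so dropout zeroes different attention weights for each sequence. The following sketch repeats a condensed version of MultiHeadAttention under the hypothetical name `MultiHeadAttentionSketch`, uses random placeholder inputs, and switches the module to evaluation mode with `.eval()` so that the two outputs match:

```python
import torch
import torch.nn as nn

class MultiHeadAttentionSketch(nn.Module):
    """Condensed copy of MultiHeadAttention above (illustrative)."""

    def __init__(self, d_in, d_out, num_heads, context_length, dropout, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"
        self.d_out, self.num_heads = d_out, num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.register_buffer(
            "mask", torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )
        self.dropout = nn.Dropout(dropout)
        self.out_proj = nn.Linear(d_out, d_out)

    def _split(self, z, b, t):
        # (b, t, d_out) -> (b, num_heads, t, head_dim)
        return z.view(b, t, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x):
        b, t, _ = x.shape
        q = self._split(self.W_query(x), b, t)
        k = self._split(self.W_key(x), b, t)
        v = self._split(self.W_value(x), b, t)
        scores = (q @ k.transpose(-2, -1)).masked_fill(self.mask.bool()[:t, :t], -torch.inf)
        weights = self.dropout(torch.softmax(scores / self.head_dim**0.5, dim=-1))
        context = (weights @ v).transpose(1, 2).contiguous().view(b, t, self.d_out)
        return self.out_proj(context)


torch.manual_seed(123)
seq = torch.rand(6, 3)
two_copies = torch.stack((seq, seq), dim=0)  # two identical sequences: (2, 6, 3)

sketch_mha = MultiHeadAttentionSketch(d_in=3, d_out=6, num_heads=2, context_length=6, dropout=0.1)
sketch_mha.eval()  # evaluation mode: dropout becomes a no-op

with torch.no_grad():
    out = sketch_mha(two_copies)

print(out.shape)                       # torch.Size([2, 6, 6])
print(torch.allclose(out[0], out[1]))  # True
```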