<a href="https://www.kaggle.com/code/aisuko/causal-self-attention?scriptVersionId=164120189" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

In this notebook, we are adapting the previously discussed self-attention mechanism into a causal self-attention mechanism, specifically for GPT-like(decoder-style) LLMs that are used to generate text. The previously notebooks are below:
* [Encoder in Transformers Architecture](https://www.kaggle.com/code/aisuko/encoder-in-transformers-architecture)
* [Decoder in Transformers architecture](https://www.kaggle.com/code/aisuko/decoder-in-transformers-architecture)
* [Coding the Self-Attention Mechanism](https://www.kaggle.com/code/aisuko/coding-the-self-attention-mechanism)
* [Coding the Multi-Head Attention](https://www.kaggle.com/code/aisuko/coding-the-multi-head-attention)
* [Coding Cross-Attention](https://www.kaggle.com/code/aisuko/coding-cross-attention)


And the causal self-attention mechanism is also often referred to as **masked self-attention**. In the original transformer architecture, it corresponds to the "masked multi-head attention" module - for simplicity, we will look at a single attention head in this section, but the same concept generalizes to multiple heads.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/570/310/327/483/original/9c619019f8f9a286.webp" width="80%" heigh="80%" alt="Causal self-attetion/masked self-attention"></div>

**It ensures that the outputs for a certain position in a sequence is based only on the known outputs at previous positions and not on future positions**. In simper terms, it ensures that the prediction for each next word should only depend on the preceding words. To achieve this in GPT-like LLMs, for each token processed, we mask out the feature tokens, which come after the current token in the input text.

The application of a causal mask to the attention weights for hiding future input tokens in the inputs is illustrated in the figure below. Our input is "Life is short, eat desert first."

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/654/573/649/346/original/ee97aa86264ac573.webp" width="80%" heigh="80%" alt="causal mask to the attention weights for hiding future input tokens"></div>

# Self-Attention Mechanism(Scaled dot-product attention)

To illustrate and implement causal self-attention, let's work with the unweighted attention scores and attention weights.


## Tokenization

In [1]:
import torch
import torch.nn as nn

torch.manual_seed(123)

#Tokenization
sentence='Lieft is short, eat dessert first'
sentence_ids={s:i for i,s in enumerate(sorted(sentence.replace(',','').split()))}
sentence_tokens=torch.tensor([sentence_ids[s] for s in sentence.replace(',','').split()])
print(sentence_tokens)

tensor([0, 4, 5, 2, 1, 3])


## Creating Embedding

We get the examples with input ids `[0,4,5,2,1,3]`. Let's create the token embeddings with `vocab_size=`$6$ and the embeddings size is $3$. This would result in a $6*3$ weight matrix.

In [2]:
# Embedding
vocab_size=len(sentence_ids)

embed=torch.nn.Embedding(vocab_size, 3)
print(embed.weight)

Parameter containing:
tensor([[ 0.3374, -0.1778, -0.1690],
        [ 0.9178,  1.5810,  1.3010],
        [ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096]], requires_grad=True)


## Embedding the Inputs

In [3]:
embedded_sentence=embed(sentence_tokens).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.1690],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096],
        [ 1.2753, -0.2010, -0.1606],
        [ 0.9178,  1.5810,  1.3010],
        [-0.4015,  0.9666, -1.1481]])
torch.Size([6, 3])


## Computing Q,K and V Vectors

**Note: the image below illustrate with one input which is show as $x^{(1)}$. However, here we do the calculation for all the inputs tokens, not only for one input token**

Here, we initialize the three weight matrices $[3,2 or 4]$. And the input is $[6,3]$. We projected the input tokens from a 3D onto a 2D/4D embedding space which is means we project the embedded input into query, key and value vectors via matrix multiplication:

* Query vector: $q=W_{q}.x$
* Key vector: $k=W_{k}.x$
* Value vector: $v=W_{v}.x$

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/478/418/490/248/original/d160c70d1d85b897.png)

The embedding dimensions of the input $x$ and the query vector `q` can be the same or different, depending on the model's design and specific implementation.

In [4]:
d_in, d_out_kq, d_out_v=3,2,4

W_query=nn.Parameter(torch.rand(d_in, d_out_kq))
W_key=nn.Parameter(torch.rand(d_in, d_out_kq))
W_value=nn.Parameter(torch.rand(d_in, d_out_v))

x=embedded_sentence
keys=x.matmul(W_key)
queries=x.matmul(W_query)
values=x.matmul(W_value)

print(keys)
print(queries)
print(values)

tensor([[-0.0214, -0.1821],
        [-0.6142, -0.2775],
        [-2.1608, -2.1497],
        [ 0.3533,  0.0171],
        [ 1.6910,  2.4233],
        [-0.2527,  0.2558]], grad_fn=<MmBackward0>)
tensor([[ 0.2702, -0.0070],
        [-1.1294, -0.7235],
        [-2.8694, -2.2637],
        [ 1.1285,  0.3962],
        [ 1.1547,  1.6555],
        [-0.4628, -0.4417]], grad_fn=<MmBackward0>)
tensor([[-5.1625e-02,  8.1035e-02,  1.7690e-01, -7.9819e-03],
        [-5.4351e-01, -4.9471e-01, -1.0151e+00, -1.2934e+00],
        [-2.3440e+00, -2.1707e+00, -3.4391e+00, -4.1387e+00],
        [ 3.3215e-01,  6.0992e-01,  1.0510e+00,  7.3947e-01],
        [ 2.0248e+00,  1.5392e+00,  2.0929e+00,  3.0412e+00],
        [-1.4491e-02,  3.1548e-01, -1.4274e-03, -7.1759e-01]],
       grad_fn=<MmBackward0>)


## Computing the Unnomalized Attention Scores

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/484/600/775/671/original/9b79607f37fcf96f.png)

In [5]:
# computing the unnomalized attention weights, atten_scores
attn_scores=queries.matmul(keys.T)

print(attn_scores)
print(attn_scores.shape)

tensor([[-4.5107e-03, -1.6400e-01, -5.6877e-01,  9.5348e-02,  4.3990e-01,
         -7.0085e-02],
        [ 1.5593e-01,  8.9439e-01,  3.9957e+00, -4.1141e-01, -3.6631e+00,
          1.0034e-01],
        [ 4.7362e-01,  2.3904e+00,  1.1067e+01, -1.0525e+00, -1.0338e+01,
          1.4611e-01],
        [-9.6320e-02, -8.0302e-01, -3.2902e+00,  4.0549e-01,  2.8685e+00,
         -1.8383e-01],
        [-3.2614e-01, -1.1686e+00, -6.0539e+00,  4.3631e-01,  5.9644e+00,
          1.3164e-01],
        [ 9.0325e-02,  4.0680e-01,  1.9495e+00, -1.7108e-01, -1.8529e+00,
          3.9907e-03]], grad_fn=<MmBackward0>)
torch.Size([6, 6])


The output above is 6x6 tensor containing these pairwide unnomalized attention weights(also called attention scores) for the 6 input tokens.


## Computing the Attention Weights(Normalized Attention Scores)

We can computed the attention weights(**normalized attention scores** that sum up to 1) via the softmax function as follows. Here we scale the `attention scores` by dividing them by the square root of the embedding dimension $\sqrt{doutkq}$.

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/489/636/461/465/original/4b73761fe18799f7.png)

In [6]:
torch.manual_seed(123)
attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=-1)
attn_weights

tensor([[1.6775e-01, 1.4985e-01, 1.1256e-01, 1.8002e-01, 2.2968e-01, 1.6014e-01],
        [5.1306e-02, 8.6487e-02, 7.7508e-01, 3.4352e-02, 3.4465e-03, 4.9329e-02],
        [5.5658e-04, 2.1586e-03, 9.9665e-01, 1.8917e-04, 2.6631e-07, 4.4152e-04],
        [8.1872e-02, 4.9672e-02, 8.5569e-03, 1.1674e-01, 6.6619e-01, 7.6960e-02],
        [1.1096e-02, 6.1158e-03, 1.9329e-04, 1.9024e-02, 9.4823e-01, 1.5337e-02],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<SoftmaxBackward0>)

## Computing the Context Vector for Input

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/523/121/939/552/original/591cc4fccdfb656c.png)

In [7]:
context_vec=attn_weights@values
print(context_vec)

tensor([[ 0.1686,  0.2090,  0.1601,  0.0557],
        [-1.8487, -1.6793, -2.7011, -3.3196],
        [-2.3373, -2.1642, -3.4295, -4.1279],
        [ 1.3353,  1.0844,  1.4515,  1.9568],
        [ 1.9218,  1.4734,  1.9996,  2.8780],
        [-1.0856, -0.9284, -1.5621, -2.0410]], grad_fn=<MmBackward0>)


# Implementing a Compact SelfAttention class

We can also use `nn.Linear` over the manual `nn.Parameter(torch.rand())` approach is that `nn.Linear` has preferred weight initialization scheme, whcih leads to more stable model training.

In [8]:
class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Parameter(torch.rand(d_in, d_out))
        self.W_key=nn.Parameter(torch.rand(d_in, d_out))
        self.W_value=nn.Parameter(torch.rand(d_in, d_out))
        
    def forward(self, x):
        keys=x@self.W_key
        queries=x@self.W_query
        values=x@self.W_value
        
        attn_scores=queries@keys.T # omega
        atten_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=-1)
        
        context_vec=attn_weights@values
        return context_vec

sa_vl=SelfAttention_v1(d_in, d_out_kq)
print(sa_vl(embedded_sentence))


class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value=nn.Linear(d_in, d_out, bias=qkv_bias)
    
    def forward(self, x):
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)
        
        attn_scores=queries@keys.T
        attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=-1)
        
        context_vec=attn_weights@values
        return context_vec

sa_v2=SelfAttention_v2(d_in, d_out_kq)
print(sa_v2(embedded_sentence))

tensor([[ 0.1027, -0.0081],
        [-0.4864, -1.6660],
        [-0.6283, -2.0349],
        [ 0.4821,  1.1554],
        [ 0.6878,  1.7792],
        [-0.2683, -1.0762]], grad_fn=<MmBackward0>)
tensor([[-0.1454,  0.0954],
        [-0.2835,  0.0154],
        [-0.4602, -0.1037],
        [-0.0793,  0.1291],
        [-0.1149,  0.1911],
        [-0.2499,  0.0251]], grad_fn=<MmBackward0>)


# Diding future words with causal attention

## Applying a Mask to the Attention Weight Matrix

In GPT-like LLMs, we train the model to read and generate one token (or word) at a time, from left to right. If we have a training text sample like "Life is short eat desert first" we have the following setup, where the context vectors for the word to the right side of the arrow should only incorportate itself and the previous words:
* "Life" -> "is"
* "Life is" -> "short"
* "Life is short" -> "eat"
* "Life is short eat" -> "desert"
* "Life is short eat desert" -> "first"

The simplest way to achieve this setup above is **to mask out all future tokens by applying a mask to the attention weight matrix above the diagonal**, as illustrated in the figure below. This way, "future" words will not be included when creating the context vectors, which are craeted as a attention-weighted sum over the inputs.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/820/858/969/543/original/2e0ec840a282071d.webp" width="80%" heigh="80%" alt="mask to the attention weight matrix above the diagonal"></div>


We can achieve this via PyTorch's [tril](https://pytorch.org/docs/stable/generated/torch.tril.html#) funciton, which we first use to create a mask of 1's and 0's.

In [9]:
block_size=attn_scores.shape[0]
mask_simple=torch.tril(torch.ones(block_size, block_size))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Next, we multiply the attention weights with this mask to zero out all the attention weights above the diagonal:

In [10]:
torch.manual_seed(123)
masked_simple=attn_weights*mask_simple
masked_simple

tensor([[1.6775e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.1306e-02, 8.6487e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5658e-04, 2.1586e-03, 9.9665e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [8.1872e-02, 4.9672e-02, 8.5569e-03, 1.1674e-01, 0.0000e+00, 0.0000e+00],
        [1.1096e-02, 6.1158e-03, 1.9329e-04, 1.9024e-02, 9.4823e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<MulBackward0>)

## Normalize the attention weights

While the above is one wat to masked out future words, notice that the attention weights in each row don't sum to one anymore. To mitigate that, we can normalize the rows such that they sum up to 1 again, which is a standard convention for attention weights:

In [11]:
row_sums=masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm=masked_simple/row_sums
masked_simple_norm

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.7234e-01, 6.2766e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5693e-04, 2.1600e-03, 9.9728e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.1876e-01, 1.9339e-01, 3.3315e-02, 4.5453e-01, 0.0000e+00, 0.0000e+00],
        [1.1268e-02, 6.2110e-03, 1.9630e-04, 1.9320e-02, 9.6300e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<DivBackward0>)

As we can see, the attention weights in each row now sum up to 1.

Normalizing attention weights in neural networks, such as intransformer models, is advantageous over unnormalized weights for two main reasons. 

First, normalized attention weights that sum to 1 resemble a probability distribution. This makes it easier to interpret the model's attention to various parts of the input in terms of proportions.

Second, by constraining the attention weights to sum to 1, this normalization helps control the scale of the weights and gradients to improve the training dynamics.


## More efficient masking without renormalization

In the causal self-attention procesure we codede above, we first compute the attention scores, then compute the attention weights, mask out attention weights above the diagonal, and lastly renormalize the attention weights. This is summarized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/915/897/930/128/original/d318a712263ef680.webp" width="80%" heigh="80%" alt="implementing causal self-attention procedure"></div>


Alternatively, there is a more efficient way to achieve the same results. In this approach, we take the attention scores and replace the values above the diagonal with negative infinity before the value are input into the softmax function to compute the attention weights. This is summatized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/964/024/803/792/original/ca94af87a2bfe55b.webp" width="80%" heigh="80%" alt="more efficient approach to implementing causal self-attention"></div>

In [12]:
mask=torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked=attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[-4.5107e-03,        -inf,        -inf,        -inf,        -inf,
                -inf],
        [ 1.5593e-01,  8.9439e-01,        -inf,        -inf,        -inf,
                -inf],
        [ 4.7362e-01,  2.3904e+00,  1.1067e+01,        -inf,        -inf,
                -inf],
        [-9.6320e-02, -8.0302e-01, -3.2902e+00,  4.0549e-01,        -inf,
                -inf],
        [-3.2614e-01, -1.1686e+00, -6.0539e+00,  4.3631e-01,  5.9644e+00,
                -inf],
        [ 9.0325e-02,  4.0680e-01,  1.9495e+00, -1.7108e-01, -1.8529e+00,
          3.9907e-03]], grad_fn=<MaskedFillBackward0>)

Then, all we have to do is to apply the softmax function as usual to obtain the normalized and masked attention weights.

The softmax function, applied in the last step, converts the input values into a probability distribution. When $-inf$ is present in the inputs, softmax effectively treats them as zero probability. This is because $e^{inf}$ approaches $0$, and thus these positions contribute nothing to the output probabilites.

In [13]:
attn_weights=torch.softmax(masked/d_out_kq**0.5, dim=1)
attn_weights

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.7234e-01, 6.2766e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5693e-04, 2.1600e-03, 9.9728e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.1876e-01, 1.9339e-01, 3.3315e-02, 4.5453e-01, 0.0000e+00, 0.0000e+00],
        [1.1268e-02, 6.2110e-03, 1.9630e-04, 1.9320e-02, 9.6300e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<SoftmaxBackward0>)

## Masking Additional Attention Weights With Dropout

We also apply dropout to reduce overfitting during training. Dropout can be applied in several places:
* After computing the attention weights
* After multiplying the attention weights with the value vectors

Here we will apply the **dropout mask after computing the attention weights** because it's more common. Furthermore, in this specific example, we used a dropout rate of 50%, which means randomly masking out half of the attention weights.

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/980/465/043/909/999/original/579cbc84220685e2.png)

If we apply a dropout rate of 0.5(50%), the non-dropped values will be scaled accordingly by a factor of $1/0.5=2$

In [14]:
torch.manual_seed(123)
dropout=torch.nn.Dropout(0.5)
example=torch.ones(6,6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


In [15]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 1.2553, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 1.9946, 0.0000, 0.0000, 0.0000],
        [0.6375, 0.3868, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0225, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3127, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


## Implementing a Compact Causal Self-Attention Class

Let's implement self-attention, including the causal and dropout masks. One more thing is to implement the code to handle batches consisting of more than one input so taht our `CausualAttention` class supports the batch outputs. For simplicy, we duplicate `embedded_sentence` which is defined on the top.

In [16]:
batch=torch.stack((embedded_sentence,embedded_sentence), dim=0)
print(batch.shape)

torch.Size([2, 6, 3])


In [17]:
class CausalAttention(nn.Module):
    
    def __init__(self, d_in, d_out, block_size, dropout, qkv_bias=False):
        super().__init__()
        self.d_out=d_out
        self.W_query=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout=nn.Dropout(dropout) # new
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1)) # new
    
    def forward(self, x):
        b, num_tokens, d_in=x.shape # new batch dimension b
        keys=self.W_key(x)
        queries=self.W_query(x)
        values=self.W_value(x)
        
        attn_scores=queries@keys.transpose(1,2)
        attn_scores.masked_fill(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=1)
        attn_weights=self.dropout(attn_weights) # new
        
        context_vec=attn_weights @ values
        return context_vec


torch.manual_seed(123)

block_size=batch.shape[1]
ca=CausalAttention(d_in, d_out_kq, block_size, 0.0)

context_vecs=ca(batch)

print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

tensor([[[-0.1567, -0.2143],
         [ 0.2843, -0.2274],
         [ 0.9825, -0.2057],
         [-0.2977, -0.2224],
         [-0.4175, -0.2595],
         [ 0.3927, -0.2499]],

        [[-0.1567, -0.2143],
         [ 0.2843, -0.2274],
         [ 0.9825, -0.2057],
         [-0.2977, -0.2224],
         [-0.4175, -0.2595],
         [ 0.3927, -0.2499]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


# Extending Single-head Attention to Multi-head Attention

Below is a summary of the self-attention implemented previously (causal and dropout masks not shown for simplicity), this is also called single-head attention:

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/980/619/336/586/573/original/31d2f35596068266.png)


We simply stack multiple single-head attention modules to obtain a multi-head attention module. The main idea behind multi-head attention is to run the attention mechanism multiple times(om parallel) with different, learned lienar projections. This allows the model to jointly attend to information form different representation subspaces at different positions.

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/980/625/578/877/206/original/127f6ca33fd35d7c.png)

In [18]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads=nn.ModuleList(
            [CausalAttention(d_in, d_out, block_size, dropout, qkv_bias) for _ in range(num_heads)]
        )
    
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)

block_size=batch.shape[1] # this is thew number of tokens
mha=MultiHeadAttentionWrapper(d_in, d_out_kq, block_size, 0.0, num_heads=2)

context_vecs=mha(batch)

print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

tensor([[[-0.1567, -0.2143, -0.1012,  0.1098],
         [ 0.2843, -0.2274, -0.2223,  0.0359],
         [ 0.9825, -0.2057, -0.3437, -0.0636],
         [-0.2977, -0.2224, -0.0402,  0.1407],
         [-0.4175, -0.2595, -0.1181,  0.2790],
         [ 0.3927, -0.2499, -0.1829,  0.0432]],

        [[-0.1567, -0.2143, -0.1012,  0.1098],
         [ 0.2843, -0.2274, -0.2223,  0.0359],
         [ 0.9825, -0.2057, -0.3437, -0.0636],
         [-0.2977, -0.2224, -0.0402,  0.1407],
         [-0.4175, -0.2595, -0.1181,  0.2790],
         [ 0.3927, -0.2499, -0.1829,  0.0432]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


In the implementation above, the embedding dimension is 4, because we $d_out_kq=2$ as the embedding dimension for the key, query, andd value vectors as well as the context vector. And since we have 2 attention heads, we have the output embedding dimension $2*2=4$.

If we want to have an output dimension of 2, as earlier in single-head attention, we can have to change the projection dimension `d_out` to 1:

In [19]:
torch.manual_seed(123)

d_out=1
mha=MultiHeadAttentionWrapper(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs=mha(batch)

print(context_vecs)
print('context_vecs.shape:', context_vecs.shape)

tensor([[[ 0.1757, -0.2133],
         [ 0.0478, -0.2435],
         [-0.0677, -0.2676],
         [ 0.2460, -0.1999],
         [ 0.3556, -0.2085],
         [ 0.0518, -0.2464]],

        [[ 0.1757, -0.2133],
         [ 0.0478, -0.2435],
         [-0.0677, -0.2676],
         [ 0.2460, -0.1999],
         [ 0.3556, -0.2085],
         [ 0.0518, -0.2464]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


## Implementing Multi-head Attention with Weight Splits

We can write a stand-alone class called `MultiHeadAttention` to achieve the same. We don't concatenate single attention heads for this stand-alone `MultiHeadAttention` class. Instead, we create single W_query, W_key, and W_value weight matrices and then split those into individual matrices for each attention head:

In [20]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, block_size, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert d_out % num_heads==0, "d_out must be divisibel by n_heads"
        
        self.d_out=d_out
        self.num_heads=num_heads
        self.head_dim=d_out//num_heads # reduce the projection dim to match desired output dim
        
        self.W_query=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value=nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj=nn.Linear(d_out, d_out) # linear layer to combine head outputs
        self.dropout=nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(block_size, block_size), diagonal=1))
    
    def forward(self, x):
        b, num_tokens, d_in=x.shape
        
        keys=self.W_key(x) # shape:(b, num_tokens, d_out)
        queries=self.W_query(x)
        values=self.W_value(x)
        
        # we implicitly split the matrix by adding a `num_heads` dimension
        # Untoll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys=keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values=values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries=queries.view(b, num_tokens, self.num_heads, self.head_dim)
        
        # transpose" (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys=keys.transpose(1,2)
        queries=queries.transpose(1,2)
        values=values.transpose(1,2)
        
        # compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores=queries@keys.transpose(2,3) # dot product for each head
        #original mask truncated to the number of tokens and converted to boolean
        mask_bool=self.mask.bool()[:num_tokens, :num_tokens]
        
        # unsqueeze the mask twice to match dimensions
        mask_unsqueezed=mask_bool.unsqueeze(0).unsqueeze(0)
        
        # use the unsqueezed mask to fill attention scores
        attn_scores.masked_fill_(mask_unsqueezed, -torch.inf)
        
        attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=-1)
        attn_weights=self.dropout(attn_weights)
        
        # shape: (b, num_tokens, num_heads, head_dim)
        context_vec=(attn_weights@values).transpose(1,2)
        
        # combine heads, where self.d_out=self.num_leads * self.head_dim
        context_vec=context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec=self.out_proj(context_vec) # optional projection
        
        return context_vec
    
torch.manual_seed(123)

batch_size, block_size, d_in=batch.shape
d_out=2
mha=MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads=2)

context_vecs=mha(batch)

print(context_vecs)
print('context_vec.shape:', context_vecs.shape)

tensor([[[ 0.1992,  0.6458],
         [ 0.1101,  0.8075],
         [-0.0464,  1.3741],
         [ 0.1206,  0.8446],
         [ 0.2225,  0.5039],
         [ 0.0595,  0.8968]],

        [[ 0.1992,  0.6458],
         [ 0.1101,  0.8075],
         [-0.0464,  1.3741],
         [ 0.1206,  0.8446],
         [ 0.2225,  0.5039],
         [ 0.0595,  0.8968]]], grad_fn=<ViewBackward0>)
context_vec.shape: torch.Size([2, 6, 2])


We added a linear projection layer `self.out_proj` to the `MultiHeadAttention` class above. This is simply a linear transformation that doesn't change the dimensions. It's a standard convention to use such a projection layer in LLM implementation, but it's not strictly necessary(recent research has shown that it can be removed without afecting the modeling performance).

Since the above implementation may look a bit complex at first glance, let's look at what happens when executing `attn_scores=queries@keys.transpose(2,3)`:

In [21]:
# (b, num_heads, num_tokens, head_dim) = (1, 2, 3, 4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],

                   [[0.0772, 0.3565, 0.1479, 0.5331],
                    [0.4066, 0.2318, 0.4545, 0.9737],
                    [0.4606, 0.5159, 0.4220, 0.5786]]]])

print(a@a.transpose(2,3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In this case, the matrix multiplication implementation in PyTorch will handle the 4-dimensional input tensor so that the matrix multiplication is carried out between the 2 last dimensions(num_tokens, head_dim) and then repeated for the individual heads. Let's compute the matrix multiplication for each head separately:

In [22]:
first_head=a[0,0,:,:]
first_res=first_head@first_head.T
print('First head:\n', first_res)

second_head=a[0,1,:,:]
second_res=second_head@second_head.T
print('\nSecond head:\n', second_res)

block_size=1024
d_in, d_out=768, 768
num_heads=12

mha=MultiHeadAttention(d_in, d_out, block_size, 0.0, num_heads)

def count_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

count_parameters(mha)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])

Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


2360064

**Note: For the real word case, please consider optimize the self-attention by using Flash Attention, it helps to reducing memory footprint and computational load.**


# Credit

* https://magazine.sebastianraschka.com?utm_source=navbar&utm_medium=web&r=fbe14
* https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb