<a href="https://www.kaggle.com/code/aisuko/causal-self-attention?scriptVersionId=163963344" target="_blank"><img align="left" alt="Kaggle" title="Open in Kaggle" src="https://kaggle.com/static/images/open-in-kaggle.svg"></a>

# Overview

In this notebook, we adapt the previously discussed self-attention mechanism into a causal self-attention mechanism, specifically for GPT-like (decoder-style) LLMs that are used to generate text. The previous notebooks in this series are listed below:
* [Encoder in Transformers Architecture](https://www.kaggle.com/code/aisuko/encoder-in-transformers-architecture)
* [Decoder in Transformers architecture](https://www.kaggle.com/code/aisuko/decoder-in-transformers-architecture)
* [Coding the Self-Attention Mechanism](https://www.kaggle.com/code/aisuko/coding-the-self-attention-mechanism)
* [Coding the Multi-Head Attention](https://www.kaggle.com/code/aisuko/coding-the-multi-head-attention)
* [Coding Cross-Attention](https://www.kaggle.com/code/aisuko/coding-cross-attention)


The causal self-attention mechanism is also often referred to as **masked self-attention**. In the original transformer architecture, it corresponds to the "masked multi-head attention" module. For simplicity, we will look at a single attention head in this notebook, but the same concept generalizes to multiple heads.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/570/310/327/483/original/9c619019f8f9a286.webp" width="80%" heigh="80%" alt="Causal self-attention/masked self-attention"></div>

**It ensures that the output for a given position in a sequence is based only on the known outputs at previous positions and not on future positions**. In simpler terms, the prediction for each next word should depend only on the preceding words. To achieve this in GPT-like LLMs, for each token processed, we mask out the future tokens, which come after the current token in the input text.

The figure below illustrates how a causal mask is applied to the attention weights to hide future input tokens. Our input is "Life is short, eat dessert first."

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/654/573/649/346/original/ee97aa86264ac573.webp" width="80%" heigh="80%" alt="causal mask to the attention weights for hiding future input tokens"></div>

# Self-Attention Mechanism (Scaled Dot-Product Attention)

To illustrate and implement causal self-attention, let's first compute the unnormalized attention scores and the attention weights.


## Tokenization

In [1]:
import torch
import torch.nn as nn

torch.manual_seed(123)

# Tokenization
sentence='Life is short, eat dessert first'
sentence_ids={s:i for i,s in enumerate(sorted(sentence.replace(',','').split()))}
sentence_tokens=torch.tensor([sentence_ids[s] for s in sentence.replace(',','').split()])
print(sentence_tokens)

tensor([0, 4, 5, 2, 1, 3])


## Creating Embedding

The sentence is now represented by the input ids `[0,4,5,2,1,3]`. Let's create the token embeddings with `vocab_size=`$6$ and an embedding size of $3$. This results in a $6 \times 3$ weight matrix.

In [2]:
# Embedding
vocab_size=len(sentence_ids)

embed=torch.nn.Embedding(vocab_size, 3)
print(embed.weight)

Parameter containing:
tensor([[ 0.3374, -0.1778, -0.1690],
        [ 0.9178,  1.5810,  1.3010],
        [ 1.2753, -0.2010, -0.1606],
        [-0.4015,  0.9666, -1.1481],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096]], requires_grad=True)


## Embedding the Inputs

In [3]:
embedded_sentence=embed(sentence_tokens).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

tensor([[ 0.3374, -0.1778, -0.1690],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096],
        [ 1.2753, -0.2010, -0.1606],
        [ 0.9178,  1.5810,  1.3010],
        [-0.4015,  0.9666, -1.1481]])
torch.Size([6, 3])
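
As a small aside, the embedding step above is a row lookup: each token id selects one row of the weight matrix. The sketch below illustrates this with a freshly initialized layer of the same shape; the variable names `token_ids`, `looked_up` and `indexed` are chosen here for illustration only.

```python
import torch

torch.manual_seed(123)

# Toy embedding table with the same shape as above: 6 tokens, 3 dimensions each
embed = torch.nn.Embedding(6, 3)
token_ids = torch.tensor([0, 4, 5, 2, 1, 3])

# Calling the layer and indexing the weight matrix select the same rows
looked_up = embed(token_ids).detach()
indexed = embed.weight[token_ids].detach()

print(torch.equal(looked_up, indexed))
print(looked_up.shape)
```

This is why the rows of the embedded sentence appear in token order rather than in vocabulary order.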


## Computing Q,K and V Vectors

**Note: the image below illustrates the computation for a single input, shown as $x^{(1)}$. Here, however, we do the calculation for all input tokens, not just one.**

Here, we initialize three weight matrices: $W_{q}$ and $W_{k}$ with shape $[3, 2]$, and $W_{v}$ with shape $[3, 4]$. The input has shape $[6, 3]$. We project the embedded input tokens from the 3-dimensional embedding space into 2-dimensional query and key vectors and 4-dimensional value vectors via matrix multiplication:

* Query vector: $q = x W_{q}$
* Key vector: $k = x W_{k}$
* Value vector: $v = x W_{v}$

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/478/418/490/248/original/d160c70d1d85b897.png)

The embedding dimensions of the input $x$ and the query vector `q` can be the same or different, depending on the model's design and specific implementation.

In [4]:
d_in, d_out_kq, d_out_v=3,2,4

W_query=nn.Parameter(torch.rand(d_in, d_out_kq))
W_key=nn.Parameter(torch.rand(d_in, d_out_kq))
W_value=nn.Parameter(torch.rand(d_in, d_out_v))

x=embedded_sentence
keys=x.matmul(W_key)
queries=x.matmul(W_query)
values=x.matmul(W_value)

print(queries)
print(values)

tensor([[ 0.2702, -0.0070],
        [-1.1294, -0.7235],
        [-2.8694, -2.2637],
        [ 1.1285,  0.3962],
        [ 1.1547,  1.6555],
        [-0.4628, -0.4417]], grad_fn=<MmBackward0>)
tensor([[-5.1625e-02,  8.1035e-02,  1.7690e-01, -7.9819e-03],
        [-5.4351e-01, -4.9471e-01, -1.0151e+00, -1.2934e+00],
        [-2.3440e+00, -2.1707e+00, -3.4391e+00, -4.1387e+00],
        [ 3.3215e-01,  6.0992e-01,  1.0510e+00,  7.3947e-01],
        [ 2.0248e+00,  1.5392e+00,  2.0929e+00,  3.0412e+00],
        [-1.4491e-02,  3.1548e-01, -1.4274e-03, -7.1759e-01]],
       grad_fn=<MmBackward0>)
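
To connect the batched computation above with the single-input view in the figure, the following sketch checks that projecting one token on its own gives the same query vector as the corresponding row of the batched result. The input `x` is a random stand-in for the embedded sentence, not the notebook's actual values.

```python
import torch

torch.manual_seed(123)

d_in, d_out_kq = 3, 2
x = torch.randn(6, d_in)                 # stand-in for the 6 embedded tokens
W_query = torch.rand(d_in, d_out_kq)

# All tokens at once: [6, 3] @ [3, 2] -> [6, 2]
queries = x @ W_query

# One token at a time, as in the figure: [3] @ [3, 2] -> [2]
x_1 = x[0]
q_1 = x_1 @ W_query

print(queries.shape, q_1.shape)
print(torch.allclose(queries[0], q_1, atol=1e-6))
```

The same reasoning applies to the key and value projections.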


## Computing the Unnormalized Attention Scores

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/484/600/775/671/original/9b79607f37fcf96f.png)

In [5]:
# Compute the unnormalized attention weights (attention scores)
attn_scores=queries.matmul(keys.T)

print(attn_scores)
print(attn_scores.shape)

tensor([[-4.5107e-03, -1.6400e-01, -5.6877e-01,  9.5348e-02,  4.3990e-01,
         -7.0085e-02],
        [ 1.5593e-01,  8.9439e-01,  3.9957e+00, -4.1141e-01, -3.6631e+00,
          1.0034e-01],
        [ 4.7362e-01,  2.3904e+00,  1.1067e+01, -1.0525e+00, -1.0338e+01,
          1.4611e-01],
        [-9.6320e-02, -8.0302e-01, -3.2902e+00,  4.0549e-01,  2.8685e+00,
         -1.8383e-01],
        [-3.2614e-01, -1.1686e+00, -6.0539e+00,  4.3631e-01,  5.9644e+00,
          1.3164e-01],
        [ 9.0325e-02,  4.0680e-01,  1.9495e+00, -1.7108e-01, -1.8529e+00,
          3.9907e-03]], grad_fn=<MmBackward0>)
torch.Size([6, 6])


The output above is a 6x6 tensor containing the pairwise unnormalized attention weights (also called attention scores) for the 6 input tokens.
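
Each entry of this matrix is a dot product between one query and one key. The sketch below makes that explicit for a single pair of positions, using random stand-in tensors; the indices `i` and `j` are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(123)

queries = torch.randn(6, 2)   # stand-in queries, one row per token
keys = torch.randn(6, 2)      # stand-in keys, one row per token

# All pairwise scores at once
attn_scores = queries @ keys.T

# A single score: how strongly token i attends to token j before normalization
i, j = 1, 4
omega_ij = torch.dot(queries[i], keys[j])

print(torch.allclose(attn_scores[i, j], omega_ij, atol=1e-6))
```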


## Computing the Attention Weights (Normalized Attention Scores)

We can compute the attention weights (**normalized attention scores** that sum up to 1) via the softmax function as follows. Here we scale the attention scores by dividing them by the square root of the query/key dimension, $\sqrt{d_k}$, where $d_k$ corresponds to `d_out_kq` in the code.

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/489/636/461/465/original/4b73761fe18799f7.png)

In [6]:
attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=-1)
attn_weights

tensor([[1.6775e-01, 1.4985e-01, 1.1256e-01, 1.8002e-01, 2.2968e-01, 1.6014e-01],
        [5.1306e-02, 8.6487e-02, 7.7508e-01, 3.4352e-02, 3.4465e-03, 4.9329e-02],
        [5.5658e-04, 2.1586e-03, 9.9665e-01, 1.8917e-04, 2.6631e-07, 4.4152e-04],
        [8.1872e-02, 4.9672e-02, 8.5569e-03, 1.1674e-01, 6.6619e-01, 7.6960e-02],
        [1.1096e-02, 6.1158e-03, 1.9329e-04, 1.9024e-02, 9.4823e-01, 1.5337e-02],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<SoftmaxBackward0>)
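
To get a feel for what the scaling does, the sketch below compares softmax with and without the division by $\sqrt{d_k}$ on random stand-in scores. Dividing by a number greater than 1 flattens each row, so the largest weight in a row can only shrink or stay the same.

```python
import torch

torch.manual_seed(123)

d_k = 2
attn_scores = torch.randn(6, 6) * 3   # stand-in scores with a fairly wide spread

unscaled = torch.softmax(attn_scores, dim=-1)
scaled = torch.softmax(attn_scores / d_k**0.5, dim=-1)

# Both variants produce rows that sum to 1
print(scaled.sum(dim=-1))

# The scaled version is less peaked: compare the largest weight per row
print(unscaled.max(dim=-1).values)
print(scaled.max(dim=-1).values)
```

With larger key dimensions the unscaled dot products tend to grow, which is where this scaling matters most.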

## Computing the Context Vector for Input

![](https://cdn.masto.host/sigmoidsocial/media_attachments/files/111/979/523/121/939/552/original/591cc4fccdfb656c.png)

In [7]:
context_vec=attn_weights@values
print(context_vec)

tensor([[ 0.1686,  0.2090,  0.1601,  0.0557],
        [-1.8487, -1.6793, -2.7011, -3.3196],
        [-2.3373, -2.1642, -3.4295, -4.1279],
        [ 1.3353,  1.0844,  1.4515,  1.9568],
        [ 1.9218,  1.4734,  1.9996,  2.8780],
        [-1.0856, -0.9284, -1.5621, -2.0410]], grad_fn=<MmBackward0>)
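
The matrix product above computes, for every token, a weighted sum of all value vectors. The following sketch spells out that sum for one row with an explicit loop, using random stand-in weights and values rather than the notebook's tensors.

```python
import torch

torch.manual_seed(123)

attn_weights = torch.softmax(torch.randn(6, 6), dim=-1)  # stand-in weights, rows sum to 1
values = torch.randn(6, 4)                               # stand-in value vectors

# All context vectors at once
context_vec = attn_weights @ values

# The context vector for token i, written out as a weighted sum of value rows
i = 2
manual = torch.zeros(4)
for j in range(6):
    manual += attn_weights[i, j] * values[j]

print(torch.allclose(context_vec[i], manual, atol=1e-6))
```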


# Applying a Mask to the Attention Weight Matrix

In GPT-like LLMs, we train the model to read and generate one token (or word) at a time, from left to right. If we have a training text sample like "Life is short eat dessert first", we have the following setup, where the context vector for the word to the right of the arrow should only incorporate itself and the previous words:
* "Life" -> "is"
* "Life is" -> "short"
* "Life is short" -> "eat"
* "Life is short eat" -> "dessert"
* "Life is short eat dessert" -> "first"
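
The context/target pairs listed above can be generated mechanically from the sentence. The sketch below does this at the word level in plain Python; a real tokenizer would operate on subword tokens, so this is only an illustration of the left-to-right setup.

```python
words = "Life is short eat dessert first".split()

# Every prefix of the sentence is a context; the word that follows it is the target
pairs = [(" ".join(words[:i]), words[i]) for i in range(1, len(words))]

for context, target in pairs:
    print(f'"{context}" -> "{target}"')
```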

The simplest way to achieve the setup above is **to mask out all future tokens by applying a mask to the attention weight matrix above the diagonal**, as illustrated in the figure below. This way, "future" words will not be included when creating the context vectors, which are created as an attention-weighted sum over the inputs.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/820/858/969/543/original/2e0ec840a282071d.webp" width="80%" heigh="80%" alt="mask to the attention weight matrix above the diagonal"></div>


We can achieve this via PyTorch's [tril](https://pytorch.org/docs/stable/generated/torch.tril.html#) function, which we first use to create a mask of 1's and 0's.

In [8]:
block_size=attn_scores.shape[0]
mask_simple=torch.tril(torch.ones(block_size, block_size))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Next, we multiply the attention weights with this mask to zero out all the attention weights above the diagonal:

In [9]:
masked_simple=attn_weights*mask_simple
masked_simple

tensor([[1.6775e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.1306e-02, 8.6487e-02, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5658e-04, 2.1586e-03, 9.9665e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [8.1872e-02, 4.9672e-02, 8.5569e-03, 1.1674e-01, 0.0000e+00, 0.0000e+00],
        [1.1096e-02, 6.1158e-03, 1.9329e-04, 1.9024e-02, 9.4823e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<MulBackward0>)

## Normalizing the Attention Weights

While the above is one way to mask out future words, notice that the attention weights in each row no longer sum to one. To fix that, we can normalize the rows so that they sum up to 1 again, which is the standard convention for attention weights:

In [10]:
row_sums=masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm=masked_simple/row_sums
masked_simple_norm

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.7234e-01, 6.2766e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5693e-04, 2.1600e-03, 9.9728e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.1876e-01, 1.9339e-01, 3.3315e-02, 4.5453e-01, 0.0000e+00, 0.0000e+00],
        [1.1268e-02, 6.2110e-03, 1.9630e-04, 1.9320e-02, 9.6300e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<DivBackward0>)

As we can see, the attention weights in each row now sum up to 1.

Normalizing attention weights in neural networks, such as in transformer models, is advantageous over leaving them unnormalized for two main reasons.

First, normalized attention weights that sum to 1 resemble a probability distribution. This makes it easier to interpret the model's attention to various parts of the input in terms of proportions.

Second, by constraining the attention weights to sum to 1, this normalization helps control the scale of the weights and gradients to improve the training dynamics.
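
One way to convince ourselves that the masked and renormalized weights really are causal is to change a later token and check that the context vectors of earlier positions stay the same. The sketch below does this with a hypothetical helper, `causal_context`, that bundles the steps from this notebook; inputs and weight matrices are random stand-ins.

```python
import torch

torch.manual_seed(123)

def causal_context(x, W_q, W_k, W_v):
    """Softmax, zero out future positions, renormalize, then weight the values."""
    q, k, v = x @ W_q, x @ W_k, x @ W_v
    weights = torch.softmax(q @ k.T / k.shape[-1] ** 0.5, dim=-1)
    masked = weights * torch.tril(torch.ones(len(x), len(x)))
    masked = masked / masked.sum(dim=1, keepdim=True)
    return masked @ v

x = torch.randn(6, 3)
W_q, W_k, W_v = torch.rand(3, 2), torch.rand(3, 2), torch.rand(3, 4)

out = causal_context(x, W_q, W_k, W_v)

# Replace the last token with a different vector and recompute
x_changed = x.clone()
x_changed[-1] = torch.randn(3)
out_changed = causal_context(x_changed, W_q, W_k, W_v)

# Context vectors of the first five positions should not depend on the last token
print(torch.allclose(out[:-1], out_changed[:-1], atol=1e-5))
```

Only the context vector of the last position is allowed to change, because it is the only one that attends to the modified token.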

**More efficient masking without renormalization**

In the causal self-attention procedure we coded above, we first compute the attention scores, then compute the attention weights, mask out the attention weights above the diagonal, and lastly renormalize the attention weights. This is summarized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/915/897/930/128/original/d318a712263ef680.webp" width="80%" heigh="80%" alt="implementing causal self-attention procedure"></div>


Alternatively, there is a more efficient way to achieve the same results. In this approach, we take the attention scores and replace the values above the diagonal with negative infinity before the values are passed into the softmax function to compute the attention weights. This is summarized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/964/024/803/792/original/ca94af87a2bfe55b.webp" width="80%" heigh="80%" alt="more efficient approach to implementing causal self-attention"></div>

In [11]:
mask=torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked=attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[-4.5107e-03,        -inf,        -inf,        -inf,        -inf,
                -inf],
        [ 1.5593e-01,  8.9439e-01,        -inf,        -inf,        -inf,
                -inf],
        [ 4.7362e-01,  2.3904e+00,  1.1067e+01,        -inf,        -inf,
                -inf],
        [-9.6320e-02, -8.0302e-01, -3.2902e+00,  4.0549e-01,        -inf,
                -inf],
        [-3.2614e-01, -1.1686e+00, -6.0539e+00,  4.3631e-01,  5.9644e+00,
                -inf],
        [ 9.0325e-02,  4.0680e-01,  1.9495e+00, -1.7108e-01, -1.8529e+00,
          3.9907e-03]], grad_fn=<MaskedFillBackward0>)

Then, all we have to do is to apply the softmax function as usual to obtain the normalized and masked attention weights.

In [12]:
attn_weights=torch.softmax(masked/d_out_kq**0.5, dim=-1)
attn_weights

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.7234e-01, 6.2766e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [5.5693e-04, 2.1600e-03, 9.9728e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [3.1876e-01, 1.9339e-01, 3.3315e-02, 4.5453e-01, 0.0000e+00, 0.0000e+00],
        [1.1268e-02, 6.2110e-03, 1.9630e-04, 1.9320e-02, 9.6300e-01, 0.0000e+00],
        [1.2501e-01, 1.5636e-01, 4.6546e-01, 1.0391e-01, 3.1637e-02, 1.1761e-01]],
       grad_fn=<SoftmaxBackward0>)
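
The two procedures can also be compared directly in code. The sketch below runs both on the same random stand-in scores and checks that the resulting weights agree up to floating-point tolerance; `approach_1` and `approach_2` are names chosen here for illustration.

```python
import torch

torch.manual_seed(123)

d_k = 2
attn_scores = torch.randn(6, 6)   # stand-in for the unnormalized attention scores
n = attn_scores.shape[0]

# Approach 1: softmax, zero out future positions, renormalize each row
weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
approach_1 = masked / masked.sum(dim=1, keepdim=True)

# Approach 2: fill future positions with -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
approach_2 = torch.softmax(
    attn_scores.masked_fill(future, -torch.inf) / d_k**0.5, dim=-1
)

print(torch.allclose(approach_1, approach_2, atol=1e-6))
```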

Why does this work? The softmax function, applied in the last step, converts the input values into a probability distribution. When `-inf` is present in the inputs, softmax assigns those positions zero probability. This is because $e^{-\infty}$ approaches 0, and thus these positions contribute nothing to the output probabilities.
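
A tiny example makes this concrete. The sketch below applies softmax to a row with one `-inf` entry, and also shows an edge case worth knowing about: a row that is entirely `-inf` produces `nan`. With a causal mask this edge case does not arise, because the diagonal entry is never masked.

```python
import torch

# One masked position: it receives exactly zero probability
row = torch.tensor([1.0, 2.0, -torch.inf])
probs = torch.softmax(row, dim=-1)
print(probs)

# Edge case: every position masked, so there is nothing left to normalize over
fully_masked = torch.full((3,), -torch.inf)
nan_row = torch.softmax(fully_masked, dim=-1)
print(nan_row)
```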

**Note: For real-world use, consider optimizing self-attention with Flash Attention, which helps reduce the memory footprint and computational load.**
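
As a rough pointer in that direction, PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`, which can dispatch to fused kernels (including FlashAttention-style ones) on supported hardware. Whether a fused kernel is actually used depends on the device, dtype and PyTorch build, so the sketch below only checks that the function with `is_causal=True` matches the manual masked computation on random stand-in tensors. It assumes PyTorch 2.0 or newer.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(123)

# Random stand-ins in (batch, heads, tokens, dim) layout
queries = torch.randn(1, 1, 6, 2)
keys = torch.randn(1, 1, 6, 2)
values = torch.randn(1, 1, 6, 4)

# Built-in causal attention; by default it scales scores by 1/sqrt(query dim)
fused = F.scaled_dot_product_attention(queries, keys, values, is_causal=True)

# Manual version: mask future positions with -inf, softmax, weight the values
scores = queries @ keys.transpose(-2, -1) / queries.shape[-1] ** 0.5
future = torch.triu(torch.ones(6, 6), diagonal=1).bool()
weights = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)
manual = weights @ values

print(torch.allclose(fused, manual, atol=1e-5))
```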


# Credit

* https://magazine.sebastianraschka.com?utm_source=navbar&utm_medium=web&r=fbe14
* https://github.com/rasbt/LLMs-from-scratch/blob/main/ch03/01_main-chapter-code/ch03.ipynb