# Overview

In this notebook, we adapt the previously discussed self-attention mechanism into a causal self-attention mechanism, specifically for GPT-like (decoder-style) LLMs that are used to generate text. The previous notebooks are listed below:
* [Encoder in Transformers Architecture](https://www.kaggle.com/code/aisuko/encoder-in-transformers-architecture)
* [Decoder in Transformers architecture](https://www.kaggle.com/code/aisuko/decoder-in-transformers-architecture)
* [Coding the Self-Attention Mechanism](https://www.kaggle.com/code/aisuko/coding-the-self-attention-mechanism)
* [Coding the Multi-Head Attention](https://www.kaggle.com/code/aisuko/coding-the-multi-head-attention)
* [Coding Cross-Attention](https://www.kaggle.com/code/aisuko/coding-cross-attention)


The causal self-attention mechanism is also often referred to as **masked self-attention**. In the original transformer architecture, it corresponds to the "masked multi-head attention" module. For simplicity, we will look at a single attention head in this section, but the same concept generalizes to multiple heads.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/570/310/327/483/original/9c619019f8f9a286.webp" width="80%" heigh="80%" alt="Causal self-attention/masked self-attention"></div>

**It ensures that the output for a certain position in a sequence is based only on the known outputs at previous positions and not on future positions**. In simpler terms, the prediction for each next word should depend only on the preceding words. To achieve this in GPT-like LLMs, for each token processed, we mask out the future tokens, which come after the current token in the input text.

The application of a causal mask to the attention weights, which hides future input tokens, is illustrated in the figure below. Our input is "Life is short, eat dessert first."

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/654/573/649/346/original/ee97aa86264ac573.webp" width="80%" heigh="80%" alt="causal mask to the attention weights for hiding future input tokens"></div>

To illustrate and implement causal self-attention, let's work with the unnormalized attention scores and attention weights from the previous section. First, we quickly recap the computation of the attention scores from the previous Self-Attention section:

In [1]:
import torch
import torch.nn as nn

torch.manual_seed(123)

#Tokenization
sentence='Life is short, eat dessert first'
sentence_ids={s:i for i,s in enumerate(sorted(sentence.replace(',','').split()))}
sentence_tokens=torch.tensor([sentence_ids[s] for s in sentence.replace(',','').split()])
print(sentence_tokens)


# Embedding
vocab_size=len(sentence_ids)
embed=torch.nn.Embedding(vocab_size, 3)
embedded_sentence=embed(sentence_tokens).detach()
print(embedded_sentence)
print(embedded_sentence.shape)

torch.manual_seed(123)
d_in, d_out_kq, d_out_v=3,2,4

W_query=nn.Parameter(torch.rand(d_in, d_out_kq))
W_key=nn.Parameter(torch.rand(d_in, d_out_kq))
W_value=nn.Parameter(torch.rand(d_in, d_out_v))


x=embedded_sentence

keys=x.matmul(W_key)
queries=x.matmul(W_query)
values=x.matmul(W_value)

# Compute the unnormalized attention weights (attn_scores)
attn_scores=queries.matmul(keys.T)

print(attn_scores)
print(attn_scores.shape)

tensor([0, 4, 5, 2, 1, 3])
tensor([[ 0.3374, -0.1778, -0.1690],
        [-1.1589,  0.3255, -0.6315],
        [-2.8400, -0.7849, -1.4096],
        [ 1.2753, -0.2010, -0.1606],
        [ 0.9178,  1.5810,  1.3010],
        [-0.4015,  0.9666, -1.1481]])
torch.Size([6, 3])
tensor([[ 1.8214e-02,  1.7208e-02,  1.3132e-01,  1.5569e-02, -1.6673e-01,
          1.8989e-03],
        [ 2.0647e-01,  3.8309e-01,  1.9873e+00,  8.9122e-02, -2.2238e+00,
          1.9144e-01],
        [ 7.2581e-01,  1.3610e+00,  7.0241e+00,  3.0666e-01, -7.8430e+00,
          6.8589e-01],
        [-9.2898e-02, -2.1432e-01, -1.0055e+00, -2.0607e-02,  1.0751e+00,
         -1.2405e-01],
        [-5.9726e-01, -1.0785e+00, -5.6701e+00, -2.7160e-01,  6.3803e+00,
         -5.2696e-01],
        [ 1.1150e-01,  1.5801e-01,  9.4357e-01,  7.0838e-02, -1.1142e+00,
          5.9216e-02]], grad_fn=<MmBackward0>)
torch.Size([6, 6])


The output above is a 6x6 tensor containing the pairwise unnormalized attention weights (also called attention scores) for the 6 input tokens.
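
To make the meaning of each entry concrete, here is a minimal sketch showing that entry `[i, j]` of the score matrix is the dot product of query `i` with key `j`. The random `queries` and `keys` tensors below are stand-ins chosen for illustration, not the values from the notebook.

```python
import torch

torch.manual_seed(0)

# Hypothetical stand-ins: 6 tokens, query/key dimension 2
queries = torch.rand(6, 2)
keys = torch.rand(6, 2)

# All pairwise scores at once: row i holds the scores of query i against every key
attn_scores = queries @ keys.T

# A single entry computed by hand
i, j = 2, 4
manual = torch.dot(queries[i], keys[j])

print(attn_scores.shape)
print(torch.allclose(attn_scores[i, j], manual))
```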


# Computing the attention weights

We can compute the scaled dot-product attention weights via the softmax function as follows:

In [2]:
attn_weights=torch.softmax(attn_scores/d_out_kq**0.5, dim=1)
attn_weights

tensor([[1.6816e-01, 1.6804e-01, 1.8216e-01, 1.6785e-01, 1.4755e-01, 1.6623e-01],
        [1.2912e-01, 1.4629e-01, 4.5485e-01, 1.1884e-01, 2.3156e-02, 1.2775e-01],
        [1.1084e-02, 1.7369e-02, 9.5250e-01, 8.2413e-03, 2.5901e-05, 1.0776e-02],
        [1.4800e-01, 1.3582e-01, 7.7628e-02, 1.5576e-01, 3.3802e-01, 1.4477e-01],
        [6.9948e-03, 4.9772e-03, 1.9362e-04, 8.8060e-03, 9.7168e-01, 7.3512e-03],
        [1.6155e-01, 1.6695e-01, 2.9095e-01, 1.5697e-01, 6.7906e-02, 1.5568e-01]],
       grad_fn=<SoftmaxBackward0>)
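
For readers who want to see what the softmax call does, the sketch below spells out the same computation by hand: scale the scores by the square root of the key dimension, exponentiate, and divide each row by its sum. The random `attn_scores` tensor and the name `d_k` are assumptions made for this example only.

```python
import torch

torch.manual_seed(0)

d_k = 2                          # assumed query/key dimension
attn_scores = torch.randn(6, 6)  # hypothetical scores, not the notebook's values

scaled = attn_scores / d_k**0.5

# Subtracting the row maximum does not change the result but avoids overflow in exp
exp = torch.exp(scaled - scaled.max(dim=1, keepdim=True).values)
manual = exp / exp.sum(dim=1, keepdim=True)

builtin = torch.softmax(scaled, dim=1)

print(torch.allclose(manual, builtin))
print(manual.sum(dim=1))  # each row sums to 1 (up to floating-point error)
```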

# Applying a mask to the attention weight matrix

In GPT-like LLMs, we train the model to read and generate one token (or word) at a time, from left to right. If we have a training text sample like "Life is short eat dessert first", we have the following setup, where the context vector for the word to the right of the arrow should only incorporate itself and the previous words:
* "Life" -> "is"
* "Life is" -> "short"
* "Life is short" -> "eat"
* "Life is short eat" -> "dessert"
* "Life is short eat dessert" -> "first"
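
The list above can be generated programmatically. This is a small illustrative sketch in plain Python; splitting on whitespace stands in for a real tokenizer, and the variable names are chosen for this example.

```python
words = "Life is short eat dessert first".split()

# At step i the visible context is words[0..i] and the target is the next word
pairs = [(words[: i + 1], words[i + 1]) for i in range(len(words) - 1)]

for context, target in pairs:
    print(f'"{" ".join(context)}" -> "{target}"')
```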

The simplest way to achieve the setup above is **to mask out all future tokens by applying a mask to the attention weight matrix above the diagonal**, as illustrated in the figure below. This way, "future" words will not be included when creating the context vectors, which are created as an attention-weighted sum over the inputs.

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/820/858/969/543/original/2e0ec840a282071d.webp" width="80%" heigh="80%" alt="mask to the attention weight matrix above the diagonal"></div>


We can achieve this via PyTorch's [tril](https://pytorch.org/docs/stable/generated/torch.tril.html#) function, which we first use to create a mask of 1's and 0's.

In [3]:
block_size=attn_scores.shape[0]
mask_simple=torch.tril(torch.ones(block_size, block_size))
mask_simple

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])

Next, we multiply the attention weights with this mask to zero out all the attention weights above the diagonal:

In [4]:
masked_simple=attn_weights*mask_simple
masked_simple

tensor([[1.6816e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.2912e-01, 1.4629e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.1084e-02, 1.7369e-02, 9.5250e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.4800e-01, 1.3582e-01, 7.7628e-02, 1.5576e-01, 0.0000e+00, 0.0000e+00],
        [6.9948e-03, 4.9772e-03, 1.9362e-04, 8.8060e-03, 9.7168e-01, 0.0000e+00],
        [1.6155e-01, 1.6695e-01, 2.9095e-01, 1.5697e-01, 6.7906e-02, 1.5568e-01]],
       grad_fn=<MulBackward0>)

## Normalize the attention weights

While the above is one way to mask out future words, notice that the attention weights in each row no longer sum to one. To mitigate that, we can normalize the rows so that they sum to 1 again, which is a standard convention for attention weights:

In [5]:
row_sums=masked_simple.sum(dim=1, keepdim=True)
masked_simple_norm=masked_simple/row_sums
masked_simple_norm

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.6882e-01, 5.3118e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.1300e-02, 1.7706e-02, 9.7099e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.8615e-01, 2.6261e-01, 1.5009e-01, 3.0116e-01, 0.0000e+00, 0.0000e+00],
        [7.0466e-03, 5.0141e-03, 1.9505e-04, 8.8712e-03, 9.7887e-01, 0.0000e+00],
        [1.6155e-01, 1.6695e-01, 2.9095e-01, 1.5697e-01, 6.7906e-02, 1.5568e-01]],
       grad_fn=<DivBackward0>)

As we can see, the attention weights in each row now sum up to 1.
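
A quick way to confirm this numerically is to compare the row sums before and after renormalization. The sketch below uses a random score matrix as a stand-in for the notebook's values, so the printed numbers will differ from the output above.

```python
import torch

torch.manual_seed(0)

attn_scores = torch.randn(6, 6)                   # hypothetical scores
attn_weights = torch.softmax(attn_scores, dim=1)

# Zero out everything above the diagonal, then rescale each row
mask = torch.tril(torch.ones(6, 6))
masked = attn_weights * mask
renormalized = masked / masked.sum(dim=1, keepdim=True)

print(masked.sum(dim=1))        # rows no longer sum to 1, except the last row
print(renormalized.sum(dim=1))  # every row sums to 1 again
```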

Normalizing attention weights in neural networks, such as in transformer models, is advantageous over using unnormalized weights for two main reasons.

First, normalized attention weights that sum to 1 resemble a probability distribution. This makes it easier to interpret the model's attention to various parts of the input in terms of proportions.

Second, by constraining the attention weights to sum to 1, this normalization helps control the scale of the weights and gradients to improve the training dynamics.

**More efficient masking without renormalization**

In the causal self-attention procedure we coded above, we first compute the attention scores, then compute the attention weights, mask out the attention weights above the diagonal, and lastly renormalize the attention weights. This is summarized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/915/897/930/128/original/d318a712263ef680.webp" width="80%" heigh="80%" alt="implementing causal self-attention procedure"></div>


Alternatively, there is a more efficient way to achieve the same result. In this approach, we take the attention scores and replace the values above the diagonal with negative infinity before the values are passed into the softmax function to compute the attention weights. This is summarized in the figure below:

<div style="text-align: center"><img src="https://files.mastodon.social/media_attachments/files/111/820/964/024/803/792/original/ca94af87a2bfe55b.webp" width="80%" heigh="80%" alt="more efficient approach to implementing causal self-attention"></div>



In [6]:
mask=torch.triu(torch.ones(block_size, block_size), diagonal=1)
masked=attn_scores.masked_fill(mask.bool(), -torch.inf)
masked

tensor([[ 0.0182,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2065,  0.3831,    -inf,    -inf,    -inf,    -inf],
        [ 0.7258,  1.3610,  7.0241,    -inf,    -inf,    -inf],
        [-0.0929, -0.2143, -1.0055, -0.0206,    -inf,    -inf],
        [-0.5973, -1.0785, -5.6701, -0.2716,  6.3803,    -inf],
        [ 0.1115,  0.1580,  0.9436,  0.0708, -1.1142,  0.0592]],
       grad_fn=<MaskedFillBackward0>)

Then, all we have to do is to apply the softmax function as usual to obtain the normalized and masked attention weights.

In [7]:
attn_weights=torch.softmax(masked/d_out_kq**0.5, dim=1)
attn_weights

tensor([[1.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [4.6882e-01, 5.3118e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [1.1300e-02, 1.7706e-02, 9.7099e-01, 0.0000e+00, 0.0000e+00, 0.0000e+00],
        [2.8615e-01, 2.6261e-01, 1.5009e-01, 3.0116e-01, 0.0000e+00, 0.0000e+00],
        [7.0466e-03, 5.0141e-03, 1.9505e-04, 8.8712e-03, 9.7887e-01, 0.0000e+00],
        [1.6155e-01, 1.6695e-01, 2.9095e-01, 1.5697e-01, 6.7906e-02, 1.5568e-01]],
       grad_fn=<SoftmaxBackward0>)

Why does this work? The softmax function, applied in the last step, converts the input values into a probability distribution. When -inf is present in the inputs, softmax effectively assigns those positions zero probability. This is because e^(-inf) approaches 0, and thus these positions contribute nothing to the output probabilities.
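
The following sketch checks both claims on a small example: that exponentiating negative infinity yields zero, and that the two masking strategies (mask then renormalize, versus fill with -inf then softmax) give the same weights. The random scores and the names `approach_1` and `approach_2` are assumptions made for illustration.

```python
import torch

torch.manual_seed(0)

n, d_k = 6, 2                    # assumed sequence length and query/key dimension
attn_scores = torch.randn(n, n)  # hypothetical scores

# Approach 1: softmax, zero out the future positions, renormalize each row
weights = torch.softmax(attn_scores / d_k**0.5, dim=1)
masked = weights * torch.tril(torch.ones(n, n))
approach_1 = masked / masked.sum(dim=1, keepdim=True)

# Approach 2: set future positions to -inf, then apply a single softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
filled = attn_scores.masked_fill(future, float("-inf"))
approach_2 = torch.softmax(filled / d_k**0.5, dim=1)

print(torch.exp(torch.tensor(float("-inf"))))             # tensor(0.)
print(torch.allclose(approach_1, approach_2, atol=1e-6))
```

The second approach skips the extra division and never computes weights for positions that are discarded anyway.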

Note: For real-world use, consider optimizing self-attention with Flash Attention, which helps reduce the memory footprint and computational load.
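
As one possible route, PyTorch 2.0 and later provide `torch.nn.functional.scaled_dot_product_attention`, which applies the causal mask internally when `is_causal=True` and can dispatch to a fused kernel. Whether a FlashAttention-style kernel is actually selected depends on the hardware, dtype, and PyTorch build, so treat this as a sketch under those assumptions. The random inputs are stand-ins for the notebook's `queries`, `keys`, and `values`.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)

n, d_k, d_v = 6, 2, 4  # assumed sizes, matching the shapes used in this notebook
queries = torch.rand(n, d_k)
keys = torch.rand(n, d_k)
values = torch.rand(n, d_v)

# Manual causal attention, as coded above
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
scores = (queries @ keys.T).masked_fill(future, float("-inf"))
manual = torch.softmax(scores / d_k**0.5, dim=-1) @ values

# Fused call; batch and head dimensions of size 1 are added and then removed
fused = F.scaled_dot_product_attention(
    queries[None, None], keys[None, None], values[None, None], is_causal=True
)[0, 0]

print(torch.allclose(manual, fused, atol=1e-5))
```

The fused function scales by one over the square root of the query dimension by default, which matches the `d_out_kq**0.5` divisor used in this notebook.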


# Credit

* https://magazine.sebastianraschka.com?utm_source=navbar&utm_medium=web&r=fbe14