# Attention

This notebook implements chapter three of the book [Build a Large Language Model (From Scratch) by Sebastian Raschka](https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167), which deals with one of the most important architectural aspects of Large Language Models (LLMs): [**Attention**](https://arxiv.org/abs/1706.03762). Attention is the main architectural feature of [transformers](https://en.wikipedia.org/wiki/Transformer_(deep_learning)) that gives LLMs superior performance compared to other [autoregressive methods](https://www.geeksforgeeks.org/nlp/autoregressive-models-in-natural-language-processing/), especially for the task of text generation, which is the primary focus of this notebook.

The chapter covers **self-attention**, **causal-attention** and **multi-head attention**, the three main attention mechanisms used in today's transformer architectures. In the following sections, we will ultimately calculate the **context vectors** (i.e., enriched embedding vectors that incorporate information from all of the other elements in the sentence/sequence) for the following sentence: **Your journey starts with one step**

In [1]:
import torch

In [2]:
# input sentence: 2d matrix of shape (6, 3)
inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)

## Self-Attention

### Self-Attention No Weights

Self-attention refers to computing attention between different positions of a single input sequence. It produces **attention weights**, which weigh the importance of the different parts/elements of the input sequence relative to one another, i.e., which words play a more *important role* than other words in the sequence. The goal of this process is to calculate a **context vector** for each part/element of the sequence: an enriched embedding vector that incorporates information from all other parts/elements of the sequence. This is how LLMs are able to capture the relationships and relevance of different words to each other in the sentence without being explicitly programmed to do so.

The method of calculating self-attention starts as follows:

1.  Choose a current word/token as the **query vector**, i.e., $q_1=x_1$
2.  Calculate the **attention scores** between the query vector and every input word/token (including the query itself), i.e., $\omega_{1j}=x_j\cdot q_1$ for $j=1,\dots,n$
3.  Normalize the attention scores with the softmax function to obtain the **attention weight**, i.e., $\alpha_{1j}=softmax(\omega_{1})_j=\frac{e^{\omega_{1j}}}{\sum_{k=1}^{n}e^{\omega_{1k}}}$
4.  Repeat step 3 for all words/tokens, i.e., calculate the attention weights $\alpha_{1j}$ for every $j$
5. Once all attention weights $\alpha_{1j}$ have been calculated, the **context vector** is calculated, which incorporates the information from all of the words/tokens in the sequence. This is done by multiplying each input word/token by its corresponding attention weight and summing the results, i.e., $z_{1} = \sum_{j=1}^{n}\alpha_{1j} x_j$
6. Repeat step 1 with a different word/token as the query vector and exit when all context vectors for the input sequence have been computed.
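
The steps above can be sketched for all six tokens at once with matrix operations. This is a minimal sketch rather than the book's exact code: it assumes the same `inputs` tensor defined earlier (repeated here so the block is self-contained), and the names `attn_scores`, `attn_weights`, and `context_vecs` are illustrative.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89],  # Your     (x1)
     [0.55, 0.87, 0.66],  # journey  (x2)
     [0.57, 0.85, 0.64],  # starts   (x3)
     [0.22, 0.58, 0.33],  # with     (x4)
     [0.77, 0.25, 0.10],  # one      (x5)
     [0.05, 0.80, 0.55]]  # step     (x6)
)

# steps 1-2: row i holds the dot products of query x_i with every token x_j
attn_scores = inputs @ inputs.T                    # shape (6, 6)

# steps 3-4: softmax over each row turns scores into weights that sum to 1
attn_weights = torch.softmax(attn_scores, dim=-1)  # shape (6, 6)

# steps 5-6: each context vector is the weighted sum of all input tokens
context_vecs = attn_weights @ inputs               # shape (6, 3)

print(f"Row sums of attention weights: {attn_weights.sum(dim=-1)}")
print(f"Context vector for journey: {context_vecs[1]}")
```

Because no trainable weights are involved, each context vector stays in the same 3-dimensional space as the input embeddings.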



---


### Self-Attention with Weights

To implement self-attention with neural networks, you need trainable weights that can be updated and changed based on the training task.

*Importantly, the trainable weights of the neural network are not to be confused with the attention weights: the trainable weights define the connections of the neural network and are fixed once training ends, while the attention weights are computed from each input and define the connections/relationships within the sequence*.

The trainable weights of the network are represented by the following three matrices:
*   $W_q$
*   $W_k$
*   $W_v$

Generally, these three weight matrices are implemented as three separate fully connected (linear) layers and are used to project/transform the words/tokens into a new vector space (lower dimensional in this notebook), producing the following vectors:
*   Query
*   Key
*   Value
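
As a small illustration of what such a projection does, the sketch below checks that a bias-free `torch.nn.Linear` layer is a plain matrix multiplication with its stored weight. The tensor `x` is a hypothetical stand-in for the token embeddings, and the seed is arbitrary.

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
d_in, d_out = 3, 2

# one of the three projections; W_key and W_value follow the same pattern
W_query = torch.nn.Linear(d_in, d_out, bias=False)

x = torch.rand(6, d_in)   # hypothetical embeddings: 6 tokens, 3 dims each
queries = W_query(x)      # projected to shape (6, 2)

# nn.Linear stores its weight as (d_out, d_in), so the layer computes x @ W.T
print(W_query.weight.shape)
print(torch.allclose(queries, x @ W_query.weight.T))
```

The transposed storage is why a hand-written `torch.rand(d_in, d_out)` matrix and an `nn.Linear(d_in, d_out)` layer have different weight shapes even though both map 3 dimensions to 2.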

After doing so, these vectors are used to calculate self-attention using the same steps as above, with one additional step that accounts for the [linear transformation](https://en.wikipedia.org/wiki/Transformation_matrix) from the original embedding space to the new embedding space: the attention scores are divided by the square root of the key vector's dimension before normalization. This scaling factor keeps the dot products from growing with the embedding dimension, which would otherwise push the softmax toward a near one-hot output with very small gradients.

Accordingly, the method of calculating self-attention with trainable weights starts as follows:
1. Select the current word/token, i.e., $x_1$, and calculate the **Query** vector, i.e., $q_1=x_1W_q$ (each token is treated as a row vector, as in the code below)
2. Calculate the **Key** and **Value** matrices for all of the words/tokens in the sequence, i.e., $K=XW_k, \;$ $V=XW_v$
3. Calculate the **attention scores** for the current word/token, i.e., $\omega_1=q_1K^{T}$
4. Normalize the attention scores for the current word/token by dividing them by the square root of the key vector's dimension within the softmax function to obtain the **attention weights**, i.e., $\alpha_1=softmax(\omega_1/\sqrt{d_{k}})$
5. Calculate the **context vector** for the current word/token, which incorporates the information from all of the words/tokens in the sentence. This is done by multiplying the **attention weights** by the **Value** matrix, i.e., $z_{1} = \alpha_{1}V$
6. Repeat step 1 with a different word/token as the query vector and exit when all context vectors for the input sequence have been computed.
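
The effect of the scaling in step 4 can be seen in a dependency-free sketch. The scores and the value `d_k = 64` below are made-up numbers chosen only to make the effect visible; the `softmax` helper is written out by hand for illustration.

```python
import math

def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]

d_k = 64                  # hypothetical key dimension
scores = [8.0, 4.0, 0.0]  # hypothetical unscaled attention scores

unscaled = softmax(scores)
scaled = softmax([s / math.sqrt(d_k) for s in scores])

# without scaling, almost all of the weight lands on the largest score;
# with scaling, the weights are spread more evenly across the tokens
print(f"unscaled: {[round(w, 4) for w in unscaled]}")
print(f"scaled:   {[round(w, 4) for w in scaled]}")
```

With `d_k = 2`, as in the cell below, the scaling factor is only about 1.41, so the effect is mild; it matters more as the key dimension grows.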





In [16]:
# calculating one context vector


d_in = inputs.shape[1] #input/starting embedding dim is 3
d_out = 2 # set output embedding dim size as 2

# initializing the three weight matrices W_query, W_key, W_value
torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# 1. compute query vector for current token
x_2 = inputs[1, :] # select current token i.e. 'journey'
query_2 = x_2 @ W_query

# 2. compute all key and value matrices for all tokens
keys = inputs @ W_key
values = inputs @ W_value

# 3. compute attention scores for current token
attn_scores_2 = query_2 @ keys.T

# 4. transform attention scores into scaled attention weights for current token
d_k = keys.shape[-1] # embedding dimension for key vector
attn_weights_2 = torch.softmax(attn_scores_2/d_k**0.5, dim=-1)

# 5. calculate context vector for current token
context_vec_2 = attn_weights_2 @ values
print(f"Context Vector for journey: {context_vec_2}")

Context Vector for journey: tensor([0.3061, 0.8210])


### Self-Attention for Entire Input Sequence

In [17]:
# class for calculating self-attention for entire input sequence
class SelfAttention(torch.nn.Module):
  def __init__(self, d_in, d_out, qkv_bias=False):
    super().__init__()
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)

  def forward(self, X):
    queries = self.W_query(X)
    keys = self.W_key(X)
    values = self.W_value(X)
    attn_scores = queries @ keys.T
    attn_weight = torch.nn.functional.softmax(
        input=attn_scores/keys.shape[-1]**0.5,
        dim=-1
    )
    context_vec = attn_weight @ values
    return context_vec

In [27]:
torch.manual_seed(789)
self_attention = SelfAttention(d_in=3, d_out=2)
print(f"Context Matrix for all words:\n {self_attention(inputs).detach().numpy()}")

Context Matrix for all words:
 [[-0.07389025  0.07128991]
 [-0.07481073  0.0703093 ]
 [-0.07485619  0.07024166]
 [-0.07600163  0.06845011]
 [-0.07632761  0.06794281]
 [-0.07544428  0.06930492]]


## Causal-Attention

**Causal-Attention** (aka Masked-Attention) is about hiding future words in the input sequence.

*Importantly: while self-attention considers all of the tokens/words when predicting the next token/word in the sequence, causal-attention only considers previous tokens/words when predicting the next token/word in the sequence.*

Generally, both forms of attention are used to train transformer models. To achieve **Causal-Attention**, two things are generally used:
1.   [Masks](https://en.wikipedia.org/wiki/Triangular_matrix)
2.   [Dropout](https://docs.pytorch.org/docs/stable/generated/torch.nn.Dropout.html)

**Masks** (in most transformers) are generally lower-triangular matrices, meaning that the attention weights above the main diagonal are hidden. The book [Build a Large Language Model (From Scratch) by Sebastian Raschka](https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167) details two ways of constructing attention masks. The first way is to:
1.  create a mask with 0's above the diagonal
2.  multiply the mask with the attention weights
3.  renormalize the attention weights

This notebook refers to this method as Simple Masked Attention (SMA); it does not take advantage of the mathematical properties of the softmax function when dealing with $-\infty$ values.

The more efficient and generally accepted method is to create a mask with $-\infty$ values above the diagonal. This works because the softmax function is defined as $softmax(z)_i=\frac{e^{z_{i}}}{\sum_{j=1}^{K}e^{z_{j}}}$ and $\lim_{z\to -\infty} e^{z}=0$, so every masked position contributes 0 to both the numerator and the denominator. So instead of multiplying the mask by the attention weights and then having to renormalize the attention weights, we can just mask the attention scores with $-\infty$ values and then apply the softmax function to create the normalized attention weights in one step. This is generally represented mathematically as: $softmax(\frac{QK^{T}}{\sqrt{d_k}}+M)$, where $M$ is 0 on and below the diagonal and $-\infty$ above it.
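
A quick way to convince yourself that the two masking methods agree is to run both on the same scores and compare. This is a minimal sketch on random, hypothetical scores; the $\sqrt{d_k}$ scaling is left out for brevity since it does not affect the equivalence.

```python
import torch

torch.manual_seed(0)          # arbitrary seed, only for repeatability
n = 4
scores = torch.randn(n, n)    # hypothetical attention scores

# method 1: softmax first, zero out future positions, then renormalize rows
weights = torch.softmax(scores, dim=-1)
masked = weights * torch.tril(torch.ones(n, n))
renormalized = masked / masked.sum(dim=-1, keepdim=True)

# method 2: fill future positions with -inf, then apply softmax once
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
neg_inf_weights = torch.softmax(
    scores.masked_fill(future, float("-inf")), dim=-1
)

print(torch.allclose(renormalized, neg_inf_weights, atol=1e-6))
```

The two results match because the renormalization in method 1 divides out the contribution of the masked positions, which is exactly what the $-\infty$ entries remove from the softmax denominator in method 2.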

**Dropout** is generally the last piece of creating Causal-Attention. The goal of Causal-Attention is to respect the [causal](https://en.wikipedia.org/wiki/Causality) ordering of the sequence: masking hides future information with respect to the current input, or in other words, makes sure that the current prediction is based only on the previous information. Dropout is a regularization technique used to make sure that the relationships between tokens are *learned* and not only [memorized](https://arxiv.org/html/2406.03880v1) by the model. It does this by randomly zeroing attention scores and/or weights during training and scaling the surviving values by $\frac{1}{1-p}$, where $p$ is the dropout rate. In this case dropout is only applied to the attention weights.





In [30]:
# calculating attention weight matrix for all words
query = self_attention.W_query(inputs)
keys = self_attention.W_key(inputs)
att_score = query @ keys.T
att_weight = torch.nn.functional.softmax(input=(att_score/keys.shape[-1]**0.5), dim=-1)
print(f"Attention Weight Matrix for all words:\n {att_weight.detach().numpy()}")

Attention Weight Matrix for all words:
 [[0.19212602 0.1646463  0.16516064 0.15499417 0.17211477 0.15095802]
 [0.20412546 0.16588287 0.16621484 0.14957766 0.16645327 0.1477459 ]
 [0.20356156 0.16592424 0.16624875 0.14981547 0.16641727 0.14803267]
 [0.18688802 0.1666883  0.16683646 0.15710352 0.16609134 0.15639237]
 [0.18304484 0.16685854 0.16695702 0.1588405  0.16582507 0.15847409]
 [0.19347237 0.16633299 0.16656809 0.15418623 0.16656083 0.15287954]]


### Simple Masked Attention

In [36]:
#1. create a mask with 0's above the diagonal
context_length = 6 # input sequence context length
mask_ones = torch.tril(torch.ones((context_length, context_length))) # create mask
print(f"Simple Mask: \n {mask_ones.detach().numpy()}")

Simple Mask: 
 [[1. 0. 0. 0. 0. 0.]
 [1. 1. 0. 0. 0. 0.]
 [1. 1. 1. 0. 0. 0.]
 [1. 1. 1. 1. 0. 0.]
 [1. 1. 1. 1. 1. 0.]
 [1. 1. 1. 1. 1. 1.]]


In [38]:
#2. multiply the mask with the attention weights
mask_attn = att_weight*mask_ones
print(f"Masked Attention-Weights: \n {mask_attn.detach().numpy()}")

Masked Attention-Weights: 
 [[0.19212602 0.         0.         0.         0.         0.        ]
 [0.20412546 0.16588287 0.         0.         0.         0.        ]
 [0.20356156 0.16592424 0.16624875 0.         0.         0.        ]
 [0.18688802 0.1666883  0.16683646 0.15710352 0.         0.        ]
 [0.18304484 0.16685854 0.16695702 0.1588405  0.16582507 0.        ]
 [0.19347237 0.16633299 0.16656809 0.15418623 0.16656083 0.15287954]]


In [39]:
#3. renormalize the attention weights
row_sums=mask_attn.sum(dim=-1, keepdim=True)
mask_normalized_attn = mask_attn/row_sums
print(f"Normalized Masked Attention-Weights: \n {mask_normalized_attn.detach().numpy()}")

Normalized Masked Attention-Weights: 
 [[1.         0.         0.         0.         0.         0.        ]
 [0.551678   0.44832197 0.         0.         0.         0.        ]
 [0.3799672  0.30971354 0.31031927 0.         0.         0.        ]
 [0.27584285 0.24602847 0.24624716 0.23188154 0.         0.        ]
 [0.21751536 0.1982809  0.19839793 0.18875293 0.1970528  0.        ]
 [0.19347237 0.16633299 0.16656809 0.15418623 0.16656083 0.15287954]]


In [41]:
# context matrix from SMA
values = self_attention.W_value(inputs)
context_vec = mask_normalized_attn@ values
print(f"Context Matrix from SMA: \n {context_vec.detach().numpy()}")

Context Matrix from SMA: 
 [[-0.08721808  0.02858998]
 [-0.09906914  0.05009485]
 [-0.09994501  0.06334987]
 [-0.0982549   0.04894815]
 [-0.05144592  0.10984372]
 [-0.07544428  0.06930492]]


### Negative Infinity Masked Attention

In [43]:
#1. create attention score mask with -infs above the diagonal
mask = torch.triu(torch.ones((context_length, context_length)), diagonal=1)
# note: `mask` is reused here and now holds the masked attention scores, not the 0/1 mask
mask = att_score.masked_fill(mask.bool(), value=-torch.inf) # attention score mask
print(f"Masked Attention Scores: \n {mask.detach().numpy()}")

Masked Attention Scores: 
 [[0.2899089        -inf       -inf       -inf       -inf       -inf]
 [0.4656424  0.17225963       -inf       -inf       -inf       -inf]
 [0.45943564 0.17031771 0.17308104       -inf       -inf       -inf]
 [0.26415503 0.10239156 0.10364803 0.01864095       -inf       -inf]
 [0.21828783 0.08735328 0.08818767 0.01770909 0.0785667        -inf]
 [0.34078205 0.12703359 0.12903105 0.01979299 0.12896936 0.00775672]]


In [44]:
#2.Normalized Masked Attention-Weights
att_weight_2 = torch.nn.functional.softmax(input=(mask/keys.shape[-1]**0.5), dim=-1)
print(f"Normalized Masked Attention-Weights: \n {att_weight_2.detach().numpy()}")

Normalized Masked Attention-Weights: 
 [[1.         0.         0.         0.         0.         0.        ]
 [0.551678   0.44832197 0.         0.         0.         0.        ]
 [0.37996718 0.3097135  0.31031924 0.         0.         0.        ]
 [0.27584285 0.24602845 0.24624714 0.23188154 0.         0.        ]
 [0.2175154  0.19828095 0.19839796 0.18875296 0.19705284 0.        ]
 [0.19347237 0.16633299 0.16656809 0.15418623 0.16656083 0.15287954]]


In [46]:
# context matrix from -inf mask
values = self_attention.W_value(inputs)
context_vec = att_weight_2@ values
print(f"Context Matrix from -inf mask: \n {context_vec}")

Context Matrix from -inf mask: 
 tensor([[-0.0872,  0.0286],
        [-0.0991,  0.0501],
        [-0.0999,  0.0633],
        [-0.0983,  0.0489],
        [-0.0514,  0.1098],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


### Dropout

In [51]:
dropout = torch.nn.Dropout(0.5)
print(f"Dropout applied to -inf Attention-Weights: \n {dropout(att_weight_2).detach().numpy()}")

Dropout applied to -inf Attention-Weights: 
 [[0.         0.         0.         0.         0.         0.        ]
 [0.         0.89664394 0.         0.         0.         0.        ]
 [0.75993437 0.619427   0.6206385  0.         0.         0.        ]
 [0.5516857  0.         0.         0.         0.         0.        ]
 [0.4350308  0.3965619  0.39679593 0.         0.39410567 0.        ]
 [0.         0.         0.33313617 0.         0.         0.30575907]]
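
The surviving values in the output above are twice the original attention weights (for example, 0.44832197 becomes 0.89664394) because `torch.nn.Dropout` rescales the kept entries by $\frac{1}{1-p}$ during training. The sketch below illustrates this on a hypothetical constant weight matrix and also shows that dropout becomes a no-op in evaluation mode.

```python
import torch

torch.manual_seed(0)                 # arbitrary seed, only for repeatability
p = 0.5
dropout = torch.nn.Dropout(p)        # modules start in training mode

weights = torch.full((2, 4), 0.25)   # hypothetical attention weights
dropped = dropout(weights)

# every entry is either zeroed or scaled by 1 / (1 - p)
is_zero = dropped == 0
is_scaled = torch.isclose(dropped, weights / (1 - p))
print(dropped)
print(bool((is_zero | is_scaled).all()))

# in evaluation mode dropout passes its input through unchanged
dropout.eval()
print(torch.equal(dropout(weights), weights))
```

One consequence is that rows of the dropped-out matrix no longer sum to 1, and a row can be zeroed entirely when all of its nonzero weights are dropped.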


### Causal-Attention for Entire Input Sequence

In [52]:
class CausalAttention(torch.nn.Module):
  def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = torch.nn.Dropout(dropout)
    self.register_buffer(
        'mask',
        torch.triu(torch.ones((context_length, context_length)), diagonal=1)
        )
  def forward(self, X):
    batch, num_tokens, d_in = X.shape # batched input
    queries = self.W_query(X)
    keys = self.W_key(X)
    values = self.W_value(X)
    # attention score
    attn_scores = queries @ keys.transpose(1,2)
    # -inf masked attention scores (mask sliced to the actual sequence length)
    mask_attn_scores = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], value=-torch.inf)
    attn_weights = torch.nn.functional.softmax(
        input=mask_attn_scores/keys.shape[-1]**0.5,
        dim=-1
    )
    # dropout
    attn_weights=self.dropout(attn_weights)
    # context matrix
    context_vec = attn_weights @ values
    return context_vec

In [54]:
# batch input to have shape: (2, 6, 3)
batch = torch.stack((inputs, inputs), dim=0)

In [62]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in=3, d_out=2, context_length=context_length, dropout=0.5)
context_vecs = ca(batch)
print(f"Batched Context Matrices: \n\n {context_vecs.detach().numpy()}")

Batched Context Matrices: 

 [[[-0.90384054  0.44320962]
  [-0.4367989   0.21418986]
  [-0.48492774 -0.13410191]
  [-0.58335876  0.00813284]
  [-0.62186474 -0.05263354]
  [-0.14171308 -0.05048606]]

 [[ 0.          0.        ]
  [-1.1748701   0.0115522 ]
  [-0.7732556   0.00728327]
  [-0.9139531  -0.27685684]
  [-0.76786053 -0.07353682]
  [-0.6748546  -0.09838524]]]


## Multi-Head Attention

Multi-head attention divides the above attention mechanism into multiple heads that run in parallel. Each head works on its own slice of the projections and calculates its own context matrix, and the results are then combined into a single context matrix, which lets different heads capture different relationships in the sequence. The book [Build a Large Language Model (From Scratch) by Sebastian Raschka](https://www.amazon.com/Build-Large-Language-Model-Scratch/dp/1633437167) implements multi-head attention with weight splitting, in which the Query, Key and Value matrices are computed once and then split across the heads, so each head works on `d_out // num_heads` dimensions.
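
The weight-splitting idea comes down to a reshape, so it helps to trace the tensor shapes on their own. The sketch below uses a random tensor `x` as a hypothetical stand-in for the projected queries, keys, or values, with illustrative sizes (`d_out=4`, `num_heads=2`).

```python
import torch

torch.manual_seed(0)  # arbitrary seed, only for repeatability
batch_size, num_tokens, d_out, num_heads = 2, 6, 4, 2
head_dim = d_out // num_heads  # d_out must be divisible by num_heads

x = torch.rand(batch_size, num_tokens, d_out)  # stand-in for Q, K, or V

# split the last dimension into heads, then move the head axis forward
heads = x.view(batch_size, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)   # (batch, num_heads, num_tokens, head_dim)

# head 0 is the first head_dim features of every token
print(torch.equal(heads[:, 0], x[..., :head_dim]))

# each head gets its own (num_tokens, num_tokens) score matrix
scores = heads @ heads.transpose(2, 3)
print(scores.shape)  # (batch, num_heads, num_tokens, num_tokens)

# merging reverses the split: transpose back, then flatten the heads
merged = heads.transpose(1, 2).contiguous().view(batch_size, num_tokens, d_out)
print(torch.equal(merged, x))
```

The `.contiguous()` call is needed because `transpose` returns a non-contiguous view, and `view` requires a compatible memory layout.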

In [58]:
class MultiHeadAtt(torch.nn.Module):
  def __init__(self, d_in, d_out, dropout, conlen, num_heads, qkv_bias=False):
    super().__init__()
    self.d_out = d_out
    self.num_heads = num_heads
    self.head_dim = d_out//num_heads
    self.W_query = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_key = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.W_value = torch.nn.Linear(d_in, d_out, bias=qkv_bias)
    self.dropout = torch.nn.Dropout(dropout)
    self.out = torch.nn.Linear(d_out, d_out, bias=qkv_bias)
    self.register_buffer(
        'mask',
        torch.triu(torch.ones((conlen, conlen)), diagonal=1)
        )


  def forward(self, X):
    batch, num_tokens, d_in = X.shape
    queries = self.W_query(X)
    keys = self.W_key(X)
    values = self.W_value(X)

    # split d_out into heads: (batch, num_tokens, num_heads, head_dim)
    keys=keys.view(batch, num_tokens, self.num_heads, self.head_dim)
    values=values.view(batch, num_tokens, self.num_heads, self.head_dim)
    queries=queries.view(batch, num_tokens, self.num_heads, self.head_dim)

    # move the head axis forward: (batch, num_heads, num_tokens, head_dim)
    keys=keys.transpose(1, 2)
    values=values.transpose(1, 2)
    queries=queries.transpose(1, 2)

    # per-head attention scores: (batch, num_heads, num_tokens, num_tokens)
    attn_scores = queries @ keys.transpose(2,3)
    mask_attn_scores = attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], value=-torch.inf)
    attn_weights = torch.nn.functional.softmax(
        input=mask_attn_scores/keys.shape[-1]**0.5,
        dim=-1
    )
    attn_weights=self.dropout(attn_weights)

    # merge the heads back: (batch, num_tokens, d_out)
    context_vec = (attn_weights @ values).transpose(1,2)
    context_vec = context_vec.contiguous().view(
        batch, num_tokens, self.d_out
    )
    # output projection that mixes information across heads
    context_vec = self.out(context_vec)
    return context_vec

In [64]:
context_length = 6
multi = MultiHeadAtt(
    d_in=3,
    d_out=2,
    dropout=0.2,
    conlen=context_length,
    num_heads=2
)
context_vec = multi(batch)
# context_vec has shape (batch, num_tokens, d_out): index 0 selects a batch element, not a head
print(f"First batch element Context Matrix: \n{context_vec[0, :, :].detach().numpy()}")
print()
print(f"Second batch element Context Matrix: \n{context_vec[1, :, :].detach().numpy()}")

First batch element Context Matrix: 
[[0.28226054 0.04943101]
 [0.27218515 0.11528094]
 [0.27345967 0.13080472]
 [0.20599014 0.02699874]
 [0.0383919  0.12273508]
 [0.2399512  0.07629689]]

Second batch element Context Matrix: 
[[0.28226054 0.04943101]
 [0.2923506  0.0804213 ]
 [0.02250692 0.1599009 ]
 [0.18400052 0.07392472]
 [0.0769636  0.1050854 ]
 [0.19460025 0.08155501]]
