# Coding Attention Mechanisms


![Alt text](../../assests/mental-model-attn.png)
A mental model of the three main stages of coding an LLM, pretraining the LLM on a general text dataset, and finetuning it on a labeled dataset.


![Alt text](../../assests/four-variants.png)
This figure depicts the different attention mechanisms we will code in this section, starting with a simplified version of self-attention before adding the trainable weights. The causal attention mechanism adds a mask to self-attention that allows the LLM to generate one word at a time. Multi-head attention organizes the attention mechanism into multiple heads, allowing the model to capture various aspects of the input data in parallel.

In [1]:
from importlib.metadata import version

print("torch version:", version("torch"))

torch version: 2.0.1


## The problem with modeling long sequences

- Translating a text word by word isn't feasible due to the differences in grammatical structures between the source and target languages:

![Alt text](../../assests/before-llms.png)



- Prior to the introduction of transformer models, encoder-decoder RNNs were commonly used for machine translation tasks

- In this setup, the encoder processes a sequence of tokens from the source language, using a hidden state—a kind of intermediate layer within the neural network—to generate a condensed representation of the entire input sequence:

![Alt text](../../assests/rnn-seq-seq.png)

The main limitation of `encoder-decoder RNNs` is that the decoder can't directly access earlier hidden states from the encoder during the decoding phase. Consequently, it relies solely on the current hidden state, which must encapsulate all relevant information. This can lead to a loss of context, especially in complex sentences where dependencies might span long distances.
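
To make this bottleneck concrete, the following minimal sketch runs PyTorch's `nn.RNN` over a short and a long sequence. It is not a full encoder-decoder model, and the sizes (`hidden_size=4`, sequence lengths 5 and 50) are arbitrary values assumed for illustration. The only point is that the final hidden state handed to a decoder has the same fixed size regardless of the input length.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Illustrative sizes (assumptions, not taken from the text above)
embedding_dim, hidden_size = 3, 4
encoder = nn.RNN(input_size=embedding_dim, hidden_size=hidden_size, batch_first=True)

short_seq = torch.rand(1, 5, embedding_dim)   # batch of 1, 5 tokens
long_seq = torch.rand(1, 50, embedding_dim)   # batch of 1, 50 tokens

# nn.RNN returns (hidden states for every position, final hidden state)
all_states_short, final_short = encoder(short_seq)
all_states_long, final_long = encoder(long_seq)

print("per-token states:", all_states_short.shape, all_states_long.shape)
print("final state:     ", final_short.shape, final_long.shape)
```

A decoder that only receives `final_long` has to recover everything about 50 tokens from 4 numbers, whereas the per-token states in `all_states_long` are exactly what an attention mechanism would let it consult.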


## Capturing data dependencies with attention mechanisms

Before transformer LLMs, it was common to use RNNs for language modeling tasks such as language translation, as mentioned previously. RNNs work fine for translating short sentences but don't work well for longer texts as they don't have direct access to previous words in the input.

One major shortcoming in this approach is that the RNN must remember the entire encoded input in a single hidden state before passing it to the decoder. 


Through an attention mechanism, the text-generating decoder segment of the network is capable of selectively accessing all input tokens, implying that certain input tokens hold more significance than others in the generation of a specific output token:

![Alt text](../../assests/attention.png)


Self-attention in transformers is a technique designed to enhance input representations by enabling each position in a sequence to engage with and determine the relevance of every other position within the same sequence.

Self-attention is a mechanism that allows each position in the input sequence to attend to all positions in the same sequence when computing its representation. Self-attention is a key component of contemporary LLMs based on the transformer architecture, such as the GPT series.

![Alt text](../../assests/self-attention.png)


## Attending to different parts of the input with self-attention

Self-attention serves as the cornerstone of every LLM based on the transformer architecture.

**THE "SELF" IN SELF-ATTENTION**
In self-attention, the `self` refers to the mechanism's ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image. This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence.
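
The distinction can be sketched with raw dot-product scores. The tensors below are random placeholders, `source` and `target` are hypothetical names, and no trainable weights are involved; this only illustrates where the two sides of each score come from.

```python
import torch

torch.manual_seed(0)

source = torch.rand(6, 3)  # 6 tokens of one sequence, 3-dimensional embeddings
target = torch.rand(4, 3)  # 4 tokens of a different sequence

# Self-attention: both sides of every score come from the same sequence
self_scores = source @ source.T     # one score per pair of source tokens

# Attention between two sequences: one side from each sequence
cross_scores = target @ source.T    # one score per (target token, source token) pair

print(self_scores.shape, cross_scores.shape)
```

In the self-attention case the score matrix is square, with one row and one column per token of the single input sequence.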


**A simple self-attention mechanism without trainable weights**

![Alt text](../../assests/goal-self-attention.png)


- Suppose we are given an input sequence $x^{(1)}$ to $x^{(T)}$
  - The input is a text (for example, a sentence like "Your journey starts with one step") that has already been converted into token embeddings.
  - For instance, $x^{(1)}$ is a `d-dimensional` vector representing the word "Your", and so forth.
- `GOAL:` compute context vectors $z^{(i)}$ for each input sequence element $x^{(i)}$ in $x^{(1)}$ to $x^{(T)}$ (where $z$ and $x$ have the same dimension).
  - A context vector $z^{(i)}$ is a weighted sum over the inputs $x^{(1)}$ to $x^{(T)}$
  - The context vector is "context"-specific to a certain input.
    - Instead of $x^{(i)}$ as a placeholder for an arbitrary input token, let's consider the second input, $x^{(2)}$
    - And to continue with a concrete example, instead of the placeholder $z^{(i)}$, we consider the second output context vector, $z^{(2)}$
    - The second context vector, $z^{(2)}$, is a weighted sum over all inputs $x^{(1)}$ to $x^{(T)}$ weighted with respect to the second input element, $x^{(2)}$.
    - The attention weights are the weights that determine how much each of the input elements contributes to the weighted sum when computing $z^{(2)}$.
    - In short, think of $z^{(2)}$ as a modified version of $x^{(2)}$ that also incorporates information about all other input elements that are relevant to a given task at hand.


- By convention, the unnormalized attention weights are referred to as `"attention scores"` whereas the normalized attention scores, which sum to 1, are referred to as `"attention weights"`.

In self-attention, context vectors play a crucial role. Their purpose is to create enriched representations of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence. This is essential in LLMs, which need to understand the relationship and relevance of words in a sentence to each other.
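
In equation form, the context vector for the second input can be written as shown here. The symbol $\alpha_{2j}$ is assumed notation for the attention weight of input $x^{(j)}$ with respect to the query $x^{(2)}$:

$$
z^{(2)} = \sum_{j=1}^{T} \alpha_{2j}\, x^{(j)}, \qquad \sum_{j=1}^{T} \alpha_{2j} = 1
$$

The three steps that follow compute the unnormalized scores, normalize them into these weights, and then evaluate this sum.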


**Step 1:** compute unnormalized attention scores $w$

- Suppose we use the second input token as the query, that is, $q^{(2)}$ = $x^{(2)}$. We then compute the unnormalized attention scores via dot products:

  - $w_{21}$ = $x^{(1)}q^{(2)T}$
  - $w_{22}$ = $x^{(2)}q^{(2)T}$
  - $w_{23}$ = $x^{(3)}q^{(2)T}$
  - ...
  - $w_{2T}$ = $x^{(T)}q^{(2)T}$
- Above, $w$ stands in for the Greek letter omega ($\omega$), which is used to symbolize the unnormalized attention scores.
  - The subscript "21" in $w_{21}$ means that input sequence element 2 was used as a query against input sequence element 1.

- Consider the input sentence, which has already been embedded into 3-dimensional vectors. We choose a small embedding dimension:

In [3]:
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
    [0.55, 0.87, 0.66], # journey  (x^2)
    [0.57, 0.85, 0.64], # starts   (x^3)
    [0.22, 0.58, 0.33], # with     (x^4)
    [0.77, 0.25, 0.10], # one      (x^5)
    [0.05, 0.80, 0.55]] # step     (x^6)
)

In [11]:
inputs.ndim, inputs.dtype

(2, torch.float32)

- (In this section, we follow the common machine learning and deep learning convention where training examples are represented as rows and feature values as columns; in the case of the tensor shown above, each row represents a word, and each column represents an embedding dimension)


- The primary objective of this section is to demonstrate how the context vector $z^{(2)}$ is calculated using the second input sequence, $x^{(2)}$, as a query.

![Alt text](../../assests/attention-scores.png)


- The figure above depicts the initial step in this process, which involves calculating the attention scores $w$ between $x^{(2)}$ and all input elements (including $x^{(2)}$ itself) through a dot product operation.

- We use input sequence element `2`, $x^{(2)}$, as an example to compute context vector $z^{(2)}$; later in this section, we will generalize this to compute all context vectors.

- The first step is to compute the unnormalized attention scores by computing the dot product between the query $x^{(2)}$ and every input token:

In [20]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [18]:
query = inputs[1]

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query)

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


`Side note:` a dot product is essentially a shorthand for multiplying two vectors element-wise and summing the resulting products:

In [25]:
res = 0.

for idx, element in enumerate(inputs[0]):
    res += element * query[idx]

print(res)
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


`Beyond viewing the dot product operation as a mathematical tool that combines two vectors to yield a scalar value, the dot product is a measure of similarity because it quantifies how much two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which elements in a sequence attend to each other: the higher the dot product, the higher the similarity and attention score between two elements.`
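
As a small illustration of this similarity interpretation, the sketch below compares a reference vector against three hand-picked vectors (all values are made up for the example). It also prints the cosine similarity, because the raw dot product grows with vector length as well as with alignment.

```python
import torch
import torch.nn.functional as F

reference = torch.tensor([1.0, 0.0])
same_direction = torch.tensor([2.0, 0.0])   # aligned, but twice as long
orthogonal = torch.tensor([0.0, 3.0])       # at a right angle
opposite = torch.tensor([-1.0, 0.0])        # pointing the other way

for name, vec in [("same direction", same_direction),
                  ("orthogonal", orthogonal),
                  ("opposite", opposite)]:
    dot = torch.dot(reference, vec)
    cos = F.cosine_similarity(reference, vec, dim=0)
    print(f"{name:>14}: dot={dot.item():+.1f}, cosine={cos.item():+.1f}")
```

The aligned vector yields the largest dot product, the orthogonal one yields zero, and the opposite one yields a negative value.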

**Step 2:** Normalize the unnormalized attention scores ("omegas", $w$) so that they sum to 1.

- We computed the attention scores in the previous step. In this step, we normalize the attention scores to obtain the attention weights.


![Alt text](../../assests/attn-weights.png)

- The figure above depicts a simple way to normalize the unnormalized attention scores to sum up to 1 (a convention, useful for interpretation, and important for training stability).

In [28]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


- However, in practice, using the `softmax` function for normalization, which is better at handling extreme values and has more desirable gradient properties during training, is common and recommended.

- Here is a naive implementation of a softmax function for scaling, which also normalizes the vector elements such that they sum up to 1:

In [29]:
def softmax_naive(logits):
    return torch.exp(logits) / torch.exp(logits).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


- As the output shows, the softmax function also meets the objective and normalizes the attention weights such that they sum to `1`.

- The softmax function ensures that the attention weights are always positive. This makes the output interpretable as probabilities or relative importance, where higher weights indicate greater importance.

- The naive implementation above can suffer from numerical instability issues for large or small input values due to overflow and underflow issues.


- Hence, in practice, it's recommended to use the PyTorch implementation of softmax instead, which has been highly optimized for performance:

In [30]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
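
To see the overflow problem of the naive version, the following sketch feeds both implementations deliberately large scores. `softmax_stable` is a hypothetical helper written for this example; it uses the common trick of subtracting the maximum before exponentiating and is not a copy of PyTorch's internal implementation.

```python
import torch

def softmax_naive(logits):
    return torch.exp(logits) / torch.exp(logits).sum(dim=0)

def softmax_stable(logits):
    # Subtracting the maximum does not change the result mathematically,
    # but it keeps every exponent <= 0, so exp() cannot overflow
    shifted = logits - logits.max()
    exps = torch.exp(shifted)
    return exps / exps.sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

print("naive: ", softmax_naive(large_scores))   # exp(1000) overflows to inf in float32
print("stable:", softmax_stable(large_scores))
print("torch: ", torch.softmax(large_scores, dim=0))
```

The naive version divides infinity by infinity and returns `nan` values, while the shifted version agrees with `torch.softmax`.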


**Step 3:** compute the context vector $z^{(2)}$ by multiplying the embedded input tokens, $x^{(i)}$, with the attention weights and summing the resulting vectors:

![Alt text](../../assests/sum-attn-weights.png)


In [31]:
query = inputs[1] # 2nd input token is the query

context_vec_2 = torch.zeros(query.shape)
for i, x_i in enumerate(inputs):
    context_vec_2 += attn_weights_2[i] * x_i

print(context_vec_2)

tensor([0.4419, 0.6515, 0.5683])


In [34]:
context_vec_2.shape, context_vec_2.ndim, context_vec_2.dtype

(torch.Size([3]), 1, torch.float32)
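
The three steps for a single query can also be written without explicit loops. The sketch below is one possible vectorized form; it redefines `inputs` so that it runs on its own.

```python
import torch

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

query = inputs[1]                                      # 2nd input token is the query

attn_scores_2 = inputs @ query                         # step 1: six dot products at once
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)   # step 2: normalize to sum to 1
context_vec_2 = attn_weights_2 @ inputs                # step 3: weighted sum of the rows

print(context_vec_2)
```

Multiplying the weight vector by the input matrix is the same weighted sum that the loop accumulates row by row.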

### Computing attention weights for all input tokens

**Generalize to all input sequence tokens:**

- Above, we computed the attention weights and context vector for input 2 (as illustrated in the highlighted row in the figure below).
- Next, we are generalizing this computation to compute all attention weights and context vectors.

![Alt text](../../assests/weights-all.png)


- In self-attention, the process starts with the calculation of attention scores, which are subsequently normalized to derive attention weights that total 1.

- These attention weights are then utilized to generate the context vectors through a weighted summation of the inputs.

![Alt text](../../assests/weighted-sum.png)


- Apply previous **step 1** to all pairwise elements to compute the unnormalized attention score matrix:

In [47]:
inputs

tensor([[0.4300, 0.1500, 0.8900],
        [0.5500, 0.8700, 0.6600],
        [0.5700, 0.8500, 0.6400],
        [0.2200, 0.5800, 0.3300],
        [0.7700, 0.2500, 0.1000],
        [0.0500, 0.8000, 0.5500]])

In [40]:
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)
        
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [45]:
attn_scores.shape, attn_scores.ndim

(torch.Size([6, 6]), 2)

In [46]:
attn_scores[1, 2]

tensor(1.4754)

- We can achieve the same as above more efficiently via matrix multiplication:

In [48]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


In [50]:
attn_scores = torch.matmul(inputs, inputs.T)
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


- Similar to **step 2** previously, we normalize each row so that the values in each row sum to 1:

In [54]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [55]:
attn_weights[0].sum()

tensor(1.0000)

In PyTorch, the `dim` parameter in functions like `torch.softmax` specifies the dimension of the input tensor along which the function is computed. By setting `dim=-1`, we instruct the softmax function to apply the normalization along the last dimension of the `attn_scores` tensor. If `attn_scores` is a `2D` tensor (for example, with a shape of [rows, columns]), `dim=-1` normalizes across the columns so that the values in each row (summing over the column dimension) sum up to 1.
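
The effect of `dim` is easiest to see on a small non-square tensor. The values in `scores` below are made up for illustration.

```python
import torch

scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_wise = torch.softmax(scores, dim=-1)  # normalize within each row
col_wise = torch.softmax(scores, dim=0)   # normalize within each column

print("dim=-1 row sums:   ", row_wise.sum(dim=-1))
print("dim=0  column sums:", col_wise.sum(dim=0))
```

With `dim=-1` each of the two rows sums to 1, while with `dim=0` each of the three columns does. For attention weights, the row-wise variant is the one we want, because each row belongs to one query.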

- Quick verification that the values in each row indeed sum to 1:

In [65]:
row_2_sum = sum( [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)

print("All row sums:", attn_weights.sum(dim=-1))

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


- Apply previous `step 3` to compute all context vectors:

In [66]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


- As a sanity check, the previously computed context vector $z^{(2)}$ = $[0.4419, 0.6515, 0.5683]$ can be found in the 2nd row above:

In [67]:
print("Previous 2nd context vector:", context_vec_2)

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])
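
For reference, the whole weight-free mechanism fits into a few lines. `simple_self_attention` is a hypothetical helper name introduced for this sketch; it bundles the three steps from above and redefines `inputs` so that it runs on its own.

```python
import torch

def simple_self_attention(x):
    """Self-attention without trainable weights for a (num_tokens, dim) tensor."""
    attn_scores = x @ x.T                               # step 1: pairwise dot products
    attn_weights = torch.softmax(attn_scores, dim=-1)   # step 2: normalize each row
    return attn_weights @ x                             # step 3: weighted sums of inputs

inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

all_context_vecs = simple_self_attention(inputs)
print(all_context_vecs.shape)
```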


This concludes the code walkthrough of a simple self-attention mechanism. In the next section, we will add trainable weights, enabling the LLM to learn from data and improve its performance on specific tasks.

## Implementing self-attention with trainable weights

In this section, we are implementing the self-attention mechanism that is used in the original transformer architecture, the GPT models, and most other popular LLMs. 

- This self-attention mechanism is also called `scaled dot-product attention`.
- The figure below provides a mental model illustrating how this self-attention mechanism fits into the broader context of implementing an LLM.


![Alt text](../../assests/self-attn-trainable-weights.png)


- The overall idea is similar to before:
  - We want to compute context vectors as weighted sums over the input vectors specific to a certain input element.
  - For the above, we need attention weights.
- As you will notice, there are only slight differences compared to the basic attention mechanism introduced earlier:
  - The most notable difference is the introduction of weight matrices that are updated during model training.
  - These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce "good" context vectors.







### Computing the attention weights step by step

We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices $W_{q}$, $W_{k}$, and $W_{v}$. These three matrices are used to project the embedded input tokens, $x^{(i)}$, into `query`, `key`, and `value` vectors.


![Alt text](../../assests/qkv-1.png)


- Each projection is a matrix multiplication of an embedded input token $x^{(i)}$ with the corresponding weight matrix:
  - Query vector: $q^{(i)} = x^{(i)}W_q$
  - Key vector: $k^{(i)} = x^{(i)}W_k$
  - Value vector: $v^{(i)} = x^{(i)}W_v$

- The embedding dimensions of the input $x$ and the query vector $q$ can be the same or different, depending on the model's design and specific implementation.

- In GPT models, the input and output dimensions are usually the same, but for illustration purposes, to better follow the computations, we choose different input and output dimensions here:


In [68]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

In [69]:
x_2

tensor([0.5500, 0.8700, 0.6600])

- Below, we initialize the three weight matrices; note that we are setting `requires_grad=False` to reduce clutter in the outputs for illustration purposes, but if we were to use the weight matrices for model training, we would set `requires_grad=True` to update these matrices during model training.

In [76]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [77]:
W_query.shape

torch.Size([3, 2])

- Next we compute the query, key, and value vectors:

In [80]:
query_2 = x_2 @ W_query # _2 because it's with respect to the 2nd input element
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)

tensor([0.4306, 1.4551])


**WEIGHT PARAMETERS VS ATTENTION WEIGHTS**

Note that in the weight matrices `W`, the term `"weight"` is short for `"weight parameters,"` the values of a neural network that are optimized during training. This is not to be confused with the attention weights. Attention weights determine the extent to which a context vector depends on the different parts of the input, i.e., to what extent the network focuses on different parts of the input.

In summary, weight parameters are the fundamental, learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.
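
The difference can be sketched in code: the weight parameters stay the same across inputs, while the attention weights are recomputed for every input. The random matrices and the two random "sequences" below are placeholders assumed for illustration, and the helper `attention_weights` is hypothetical.

```python
import torch

torch.manual_seed(123)
d_in, d_out = 3, 2

# Weight parameters: fixed once training is done
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)

def attention_weights(x):
    # Attention weights: derived from the input on every call
    queries, keys = x @ W_query, x @ W_key
    return torch.softmax(queries @ keys.T, dim=-1)

sequence_a = torch.rand(4, d_in)
sequence_b = torch.rand(4, d_in)

weights_a = attention_weights(sequence_a)
weights_b = attention_weights(sequence_b)

print("Same attention weights for both inputs?", torch.allclose(weights_a, weights_b))
```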

- We can obtain all keys and values via matrix multiplication:

In [81]:
keys = inputs @ W_key
values =  inputs @ W_value

print("Keys.shape:", keys.shape)
print("values.shape:", values.shape)

Keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


In [83]:
values.ndim, keys.ndim

(2, 2)

- As we can see from the output above, we successfully projected the `6` input tokens from a 3-dimensional onto a 2-dimensional embedding space.

- In the next step, **step 2**, we compute the unnormalized attention scores by computing the dot product between the query and each key vector:


![Alt text](../../assests/qkv-2.png)


In [84]:
keys

tensor([[0.3669, 0.7646],
        [0.4433, 1.1419],
        [0.4361, 1.1156],
        [0.2408, 0.6706],
        [0.1827, 0.3292],
        [0.3275, 0.9642]])

- First, let's compute the attention score $w_{22}$

In [85]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


- We can generalize this computation to all attention scores via matrix multiplication. Since we have `6` inputs, we have `6` attention scores for the given query vector:

In [86]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


The **3rd step** goes from the attention scores to the attention weights, as illustrated in the figure below:

![Alt text](../../assests/qkv-3.png)


After computing the attention scores $w$, the next step is to normalize these scores using the softmax function to obtain the attention weights $\alpha$.

We compute the attention weights by scaling the attention scores and applying the softmax function we used earlier. The difference from earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys.

In [89]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


**THE RATIONALE BEHIND SCALED DOT-PRODUCT ATTENTION**

`The reason for scaling by the square root of the embedding dimension is to improve the training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than a thousand for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate. The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.`
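
The saturation effect can be observed with a small experiment. In the sketch below, standard-normal random vectors stand in for real queries and keys (an assumption for illustration), and the dimensions 4 and 1024 are arbitrary choices. For each dimension, it records how much weight the softmax puts on its single largest entry, with and without scaling.

```python
import torch

torch.manual_seed(0)

results = {}
for d_k in (4, 1024):
    query = torch.randn(d_k)
    keys = torch.randn(6, d_k)
    scores = keys @ query                        # dot products grow with d_k

    unscaled = torch.softmax(scores, dim=-1)
    scaled = torch.softmax(scores / d_k**0.5, dim=-1)
    results[d_k] = (unscaled.max().item(), scaled.max().item())

    print(f"d_k={d_k:>4}  max weight unscaled={results[d_k][0]:.3f}  "
          f"scaled={results[d_k][1]:.3f}")
```

Dividing by a value greater than 1 can only flatten the softmax output, so the scaled maximum is never larger than the unscaled one. The gap typically becomes pronounced at large `d_k`, where the unscaled softmax tends to put almost all weight on one key.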

- The final step is to compute the context vector for query vector 2:

![Alt text](../../assests/qkv-4.png)

In this final step of the self-attention computation, we compute the context vector by combining all value vectors via the attention weights.

Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. We can use matrix multiplication to obtain the output in one step:

In [90]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


**WHY QUERY, KEY, AND VALUE?**

The terms `"key,"` `"query,"` and `"value"` in the context of attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to `store`, `search`, and `retrieve` information.

A `"query"` is analogous to a search query in a database. It represents the current item (e.g., a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The `"key"` is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match with the query.

The `"value"` in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
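
The database analogy can be made concrete with a toy "soft lookup". All keys, values, and the query below are hypothetical numbers chosen so that the query clearly matches the first key. Unlike a dictionary lookup, which returns exactly one stored value, attention returns a weighted blend of all values.

```python
import torch

# Hypothetical toy "database": three well-separated keys and their values
keys = 10.0 * torch.eye(3)
values = torch.tensor([[1.0, 0.0],
                       [0.0, 1.0],
                       [1.0, 1.0]])

query = keys[0]                                  # a query that matches the first key

attn_scores = query @ keys.T                     # similarity of the query to each key
attn_weights = torch.softmax(attn_scores, dim=-1)
retrieved = attn_weights @ values                # blend of values, dominated by the match

print(attn_weights)
print(retrieved)
```

Because the first score is far larger than the others, nearly all of the weight lands on the first key and the result is essentially `values[0]`. With less clear-cut scores, the output would mix several values.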

### Implementing a compact Self Attention class

- Putting it all together, we can implement the self-attention mechanism as follows:

In [101]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))
        
    
    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
    

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


In [98]:
d_in, d_out, inputs.shape

(3, 2, torch.Size([6, 3]))

- During the forward pass, implemented in the `forward` method, we compute the attention scores `(attn_scores)` by multiplying queries and keys, and then normalize these scores using `softmax`. Finally, we create the context vectors by weighting the values with the resulting attention weights.

- Since `inputs` contains six embedding vectors, the output above is a matrix storing the six context vectors.

- As a quick check, notice how the second row `([0.3061, 0.8210])` matches the contents of `context_vec_2` in the previous section.

In [102]:
context_vec_2

tensor([0.3061, 0.8210])

In self-attention, we transform the input vectors in the input matrix `X` with the three weight matrices, $W_{q}$, $W_{k}$, and $W_{v}$. Then, we compute the attention weight matrix based on the resulting queries $(Q)$ and keys $(K)$. Using the attention weights and values $(V)$, we then compute the context vectors $(Z)$.


![Alt text](../../assests/qkv-5.png)

As shown in the figure above, self-attention involves the trainable weight matrices $W_{q}$, $W_{k}$, and $W_{v}$. These matrices transform input data into `queries`, `keys`, and `values`, which are crucial components of the attention mechanism. As the model is exposed to more data during training, it adjusts these trainable weights.
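
In matrix notation, the computation performed by the `forward` method can be summarized as follows, where $d_k$ denotes the embedding dimension of the keys and the softmax is applied to each row separately, matching `dim=-1` in the code:

$$
Q = XW_q, \quad K = XW_k, \quad V = XW_v, \qquad Z = \text{softmax}\!\left(\frac{QK^{\top}}{\sqrt{d_k}}\right)V
$$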

- We can improve the `SelfAttention_v1` implementation further by utilizing PyTorch's `nn.Linear` layers, which effectively perform matrix multiplication when the bias units are disabled.

- Another big advantage of using `nn.Linear` over our manual `nn.Parameter(torch.rand(...))` approach is that `nn.Linear` has a preferred weight initialization scheme, which leads to more stable model training.

In [106]:
import torch
import torch.nn as nn   


class SelfAttention_v2(nn.Module):
    def __init__(self, input_dimension, output_dimension, qkv_bias=False):
        super().__init__()
        self.output_dimension = output_dimension
        
        self.weight_query = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        self.weight_key   = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        self.weight_value = nn.Linear(input_dimension, output_dimension, bias=qkv_bias)
        
    def forward(self, x: torch.Tensor):
        queries = self.weight_query(x)
        keys = self.weight_key(x)
        values = self.weight_value(x)
        
        attention_scores = queries @ keys.T
        attention_weights = torch.softmax(attention_scores / keys.shape[-1]**0.5, dim=-1)
        context_vector = attention_weights @ values
        return context_vector


# You can use the SelfAttention_v2 similar to SelfAttention_v1:
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


- Note that `SelfAttention_v1` and `SelfAttention_v2` give different outputs because they use different initial weights for the weight matrices since `nn.Linear` uses a more sophisticated weight initialization scheme.
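
The following sketch checks the relationship between a bias-free `nn.Linear` layer and a plain matrix multiplication. The input `x` is a random placeholder, and the sizes mirror `d_in=3` and `d_out=2` from above.

```python
import torch
import torch.nn as nn

torch.manual_seed(789)

d_in, d_out = 3, 2
layer = nn.Linear(d_in, d_out, bias=False)
x = torch.rand(6, d_in)

# nn.Linear stores its weight with shape (d_out, d_in), i.e. transposed
# relative to the (d_in, d_out) matrices used in SelfAttention_v1
print(layer.weight.shape)

manual = x @ layer.weight.T
print(torch.allclose(layer(x), manual))
```

Calling the layer gives the same result as multiplying by the transposed weight, which is why the stored shape is `[2, 3]` rather than `[3, 2]`.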

In [122]:
sa_v2.named_parameters()

<generator object Module.named_parameters at 0x12654cf40>

In [124]:
# Get all parameters as an iterator
for name, param in sa_v2.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")
    print(f"Values: {param.data}")  # Actual weight values

Parameter: weight_query.weight, Shape: torch.Size([2, 3])
Values: tensor([[ 0.3161,  0.4568,  0.5118],
        [-0.1683, -0.3379, -0.0918]])
Parameter: weight_key.weight, Shape: torch.Size([2, 3])
Values: tensor([[ 0.4058, -0.4704,  0.2368],
        [ 0.2134, -0.2601, -0.5105]])
Parameter: weight_value.weight, Shape: torch.Size([2, 3])
Values: tensor([[ 0.2526, -0.1415, -0.1962],
        [ 0.5191, -0.0852, -0.2043]])


In [125]:
# Get all parameters as an iterator
for name, param in sa_v1.named_parameters():
    print(f"Parameter: {name}, Shape: {param.shape}")
    print(f"Values: {param.data}")  # Actual weight values

Parameter: W_query, Shape: torch.Size([3, 2])
Values: tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
Parameter: W_key, Shape: torch.Size([3, 2])
Values: tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
Parameter: W_value, Shape: torch.Size([3, 2])
Values: tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


**EXERCISE 3.1 COMPARING SELFATTENTION_V1 AND SELFATTENTION_V2**

Your task is to correctly assign the weights from an instance of SelfAttention_v2 to an instance of SelfAttention_v1. To do this, you need to understand the relationship between the weights in both versions. (Hint: nn.Linear stores the weight matrix in a transposed form.) After the assignment, you should observe that both instances produce the same outputs.

In [140]:
state_dict = sa_v2.state_dict()
state_dict

OrderedDict([('weight_query.weight',
              tensor([[ 0.3161,  0.4568,  0.5118],
                      [-0.1683, -0.3379, -0.0918]])),
             ('weight_key.weight',
              tensor([[ 0.4058, -0.4704,  0.2368],
                      [ 0.2134, -0.2601, -0.5105]])),
             ('weight_value.weight',
              tensor([[ 0.2526, -0.1415, -0.1962],
                      [ 0.5191, -0.0852, -0.2043]]))])

In [142]:
query_weights = state_dict['weight_query.weight']
key_weights = state_dict['weight_key.weight']
value_weights = state_dict['weight_value.weight']

In [145]:
query_weights.shape

torch.Size([2, 3])

In [155]:
# nn.Linear stores its weight as (d_out, d_in), so transpose it to get (d_in, d_out)
queries = inputs @ query_weights.T
keys = inputs @ key_weights.T
values = inputs @ value_weights.T

# No random operations happen here, so no manual seed is needed
attn_scores = queries @ keys.T # omega
attn_weights = torch.softmax(
    attn_scores / keys.shape[-1]**0.5, dim=-1
)
context_vec = attn_weights @ values
context_vec

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]])
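
The manual computation above reproduces the `sa_v2` result by hand. The weight transfer the exercise asks for could look roughly like the sketch below. The two class bodies are assumptions reconstructed from the parameter names printed earlier (`W_query` and friends for v1, `weight_query.weight` and friends for v2), and `inputs` is a random stand-in rather than the notebook's tensor.

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    # Assumed definition: raw parameter matrices of shape (d_in, d_out)
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries, keys, values = x @ self.W_query, x @ self.W_key, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    # Assumed definition: nn.Linear layers, whose .weight has shape (d_out, d_in)
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.weight_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.weight_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.weight_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries, keys, values = self.weight_query(x), self.weight_key(x), self.weight_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

torch.manual_seed(123)
inputs = torch.rand(6, 3)  # hypothetical stand-in for the notebook's inputs
sa_v1 = SelfAttention_v1(d_in=3, d_out=2)
sa_v2 = SelfAttention_v2(d_in=3, d_out=2)

# Copy the transposed nn.Linear weights into the v1 parameters
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.weight_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.weight_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.weight_value.weight.T)

outputs_match = torch.allclose(sa_v1(inputs), sa_v2(inputs))
print(outputs_match)
```

The transpose is needed because `nn.Linear` computes `x @ weight.T`, while the v1 sketch multiplies by its parameter matrix directly.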

## Hiding future words with causal attention

Causal attention, also known as `masked attention`, is a specialized form of self-attention. It restricts a model to only consider previous and current inputs in a sequence when processing any given token. This is in contrast to the standard self-attention mechanism, which allows access to the entire input sequence at once.

- When computing attention scores, the causal attention mechanism ensures that the model only factors in tokens that occur at or before the current token in the sequence.


- In causal attention, the `attention weights` above the diagonal are masked, so that for any given input the LLM cannot use future tokens when computing the context vectors from the attention weights.

![Alt text](../../assests/casual-attn.png)

- For example, for the word `"journey"` in the second row, we only keep the attention weights for the word before it (`"Your"`) and for the current position (`"journey"`).
  
- As seen in the figure above, we mask out the `attention weights` above the diagonal, and we normalize the `non-masked attention weights`, such that the `attention weights` sum to 1 in each row.

### Applying a causal attention mask

- In this section, we are converting the previous self-attention mechanism into a causal self-attention mechanism.
  
- Causal self-attention ensures that the model's prediction for a certain position in a sequence is only dependent on the known outputs at previous positions, not on future positions.
  
- In simpler words, this ensures that each next word prediction should only depend on the preceding words.
  
- To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text):


![Alt text](../../assests/casual-attn-mask.png)


- As the figure above shows, one way to obtain the masked attention weight matrix in causal attention is to apply the `softmax function` to the attention scores, zero out the elements above the diagonal, and normalize the resulting matrix.


- To illustrate and implement causal self-attention, let's work with the attention scores and weights from the previous section:

In [157]:
queries = sa_v2.weight_query(inputs)
keys = sa_v2.weight_key(inputs)
# values = sa_v2.weight_value(inputs)

attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


The simplest way to mask out future attention weights is by creating a mask via PyTorch's `tril` function with elements below the main diagonal (including the diagonal itself) set to `1` and above the main diagonal set to `0`:

In [158]:
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [161]:
mask_simple.shape

torch.Size([6, 6])

Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:

In [168]:
masked_simple = attn_weights*mask_simple
print(masked_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


- However, applying the mask after softmax, as above, disrupts the probability distribution that softmax created.

- Softmax ensures that the values in each row sum to 1; after masking, every row except the last sums to less than 1.

- Masking after softmax therefore requires an extra step that re-normalizes each row to sum to 1 again.

- To make sure that the rows sum to 1, we can normalize the attention weights as follows:

In [169]:
row_sums = masked_simple.sum(dim=-1, keepdims=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


- While we are technically done with coding the causal attention mechanism now, let's briefly look at a more efficient approach to achieve the same as above.

- So, instead of zeroing out attention weights above the diagonal and renormalizing the results, we can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function:

![Alt text](../../assests/masked-attn-scores.png)


A more efficient way to obtain the masked attention weight matrix in causal attention is to mask the attention scores with negative infinity values before applying the softmax function.

- The softmax function converts its inputs into a probability distribution. When negative infinity values ($-\infty$) are present in a row, the softmax function assigns them zero probability. (Mathematically, this is because $e^{-\infty}$ approaches 0.)

- We can implement this more efficient masking "trick" by creating a mask with `1's` above the diagonal and then replacing these `1's` with negative infinity `(-inf)` values:

In [170]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


- Now, all we need to do is apply the `softmax function` to these masked results, and we are done:

In [171]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


- As we can see from the output above, the values in each row sum to `1`, and no further normalization is necessary.

- We could now use the modified attention weights to compute the context vectors via `context_vec = attn_weights @ values`
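
As a quick sanity check, the sketch below compares the two masking routes described above on a small random score matrix. The random `attn_scores` and the value of `d_k` are illustrative assumptions, not the notebook's tensors.

```python
import torch

torch.manual_seed(0)
context_length, d_k = 6, 2  # hypothetical sizes
attn_scores = torch.randn(context_length, context_length)  # hypothetical scores

# Route 1: softmax first, zero out future positions, then renormalize each row
weights = torch.softmax(attn_scores / d_k**0.5, dim=-1)
lower = torch.tril(torch.ones(context_length, context_length))
masked = weights * lower
route1 = masked / masked.sum(dim=-1, keepdim=True)

# Route 2: fill future positions with -inf, then apply softmax once
future = torch.triu(torch.ones(context_length, context_length), diagonal=1).bool()
route2 = torch.softmax(attn_scores.masked_fill(future, -torch.inf) / d_k**0.5, dim=-1)

routes_agree = torch.allclose(route1, route2, atol=1e-6)
row_sums = route2.sum(dim=-1)     # each row should sum to 1
future_weights = route2[future]   # weights at masked positions
print(routes_agree, row_sums, future_weights.abs().max())
```

Both routes should give the same matrix, because dividing the surviving softmax terms by their sum is the same as running softmax over only the unmasked scores. The second route just skips the separate renormalization step.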

### Masking additional attention weights with dropout

`Dropout` in deep learning is a technique where randomly selected hidden layer units are ignored during training, effectively "dropping" them out. This method helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. 

- It's important to emphasize that dropout is only used during training and is disabled afterward.

- In the transformer architecture, including models like `GPT`, dropout in the attention mechanism is typically applied at one of two points: 
  - after calculating the attention weights or 
  - after applying the attention weights to the value vectors.

- Here, we will apply the dropout mask after computing the attention weights because it's more common.

- Furthermore, in this specific example, we use a dropout rate of `50%`, which means randomly masking out about half of the attention weights. (When we train the GPT model later, we will use a lower dropout rate, such as `0.1` or `0.2`.)

![Alt text](../../assests/dropout-masked.png)


- Using the causal attention mask `(upper left)`, we apply an additional dropout mask `(upper right)` to zero out additional attention weights and reduce `overfitting` during training.

- If we apply a dropout rate of `0.5 (50%)`, the non-dropped values will be scaled accordingly by a factor of `1/0.5 = 2`

- The scaling is calculated by the formula `1 / (1 - dropout_rate)`

In [172]:
torch.manual_seed(123)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
example = torch.ones(6, 6) # create a matrix of ones

print(dropout(example))

tensor([[2., 2., 0., 2., 2., 0.],
        [0., 0., 0., 2., 0., 2.],
        [2., 2., 2., 2., 0., 2.],
        [0., 2., 2., 0., 0., 2.],
        [0., 2., 0., 2., 0., 2.],
        [0., 2., 2., 2., 2., 0.]])


When applying dropout to an attention weight matrix with a rate of `50%`, each element is set to zero independently with probability 0.5, so on average half of the elements are dropped (15 of the 36 entries in the example above). To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of `1/0.5 = 2`. This scaling keeps the expected value of each attention weight unchanged, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
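
The scaling argument can be checked empirically with a small sketch. The tensor size and the seed below are arbitrary choices for illustration.

```python
import torch

torch.manual_seed(123)
p = 0.5
dropout = torch.nn.Dropout(p)
x = torch.ones(1000, 1000)  # large matrix of ones so the averages are stable

# Training mode: elements are dropped at random and survivors are scaled by 1 / (1 - p)
dropout.train()
dropped = dropout(x)
kept_fraction = (dropped != 0).float().mean().item()  # close to 1 - p
train_mean = dropped.mean().item()                    # close to the original mean of 1
kept_value = dropped.max().item()                     # surviving ones become 1 / (1 - p)

# Evaluation mode: dropout is disabled and the input passes through unchanged
dropout.eval()
eval_out = dropout(x)
print(kept_fraction, train_mean, kept_value, torch.equal(eval_out, x))
```

Calling `.eval()` on a module (or on the whole model) is how dropout gets switched off after training.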

- Now, let's apply dropout to the attention weight matrix itself:

In [173]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.7599, 0.6194, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4921, 0.4925, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3966, 0.0000, 0.3775, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.3331, 0.3084, 0.3331, 0.0000]],
       grad_fn=<MulBackward0>)
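
The individual steps (score computation, causal mask, softmax, dropout, context vectors) can be combined into a single function. The following is only a minimal sketch under stated assumptions: `causal_attention` is a hypothetical helper, and the layers and `x` are freshly initialized stand-ins rather than the notebook's `sa_v2` and `inputs`.

```python
import torch
import torch.nn as nn

def causal_attention(x, w_query, w_key, w_value, dropout):
    # Hypothetical helper: single-sequence causal self-attention with dropout
    queries, keys, values = w_query(x), w_key(x), w_value(x)
    attn_scores = queries @ keys.T
    n = attn_scores.shape[0]
    # True above the diagonal marks the future positions to hide
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    attn_scores = attn_scores.masked_fill(future, -torch.inf)
    attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
    attn_weights = dropout(attn_weights)  # no-op when dropout is in eval mode
    return attn_weights @ values

torch.manual_seed(123)
d_in, d_out = 3, 2
w_query = nn.Linear(d_in, d_out, bias=False)
w_key = nn.Linear(d_in, d_out, bias=False)
w_value = nn.Linear(d_in, d_out, bias=False)
dropout = nn.Dropout(0.5)
x = torch.rand(6, d_in)  # hypothetical stand-in for the six input tokens

dropout.train()
train_out = causal_attention(x, w_query, w_key, w_value, dropout)

dropout.eval()
eval_out = causal_attention(x, w_query, w_key, w_value, dropout)

# Changing only the last three tokens must not affect the first three context vectors
x_changed = x.clone()
x_changed[3:] = torch.rand(3, d_in)
eval_out_changed = causal_attention(x_changed, w_query, w_key, w_value, dropout)
prefix_unchanged = torch.allclose(eval_out[:3], eval_out_changed[:3])
print(train_out.shape, prefix_unchanged)
```

The final check illustrates the point of the causal mask: a token's context vector depends only on that token and the ones before it.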


Having gained an understanding of causal attention and dropout masking, we will develop a concise Python class in the following section. This class is designed to facilitate the efficient application of these two techniques.

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class GPTDatasetV1(Dataset):
    def __init__(self, txt, tokenizer, max_length, stride):
        self.input_ids = []
        self.target_ids = []
        
        # Tokenize the entire text
        token_ids = tokenizer.encode(txt)
        
        # Use a sliding window to chunk the book into overlapping sequences of max_length
        for i in range(0, len(token_ids) - max_length, stride):
            input_chunk = token_ids[i:i + max_length]
            target_chunk = token_ids[i + 1: i + max_length + 1]
            self.input_ids.append(torch.tensor(input_chunk))
            self.target_ids.append(torch.tensor(target_chunk))
    
    # Return the total number of rows in the dataset
    def __len__(self):
        return len(self.input_ids)

    # Return a single row from the dataset
    def __getitem__(self, idx):
        return self.input_ids[idx], self.target_ids[idx]