<a href="https://colab.research.google.com/github/athahibatullah/llm-from-scratch/blob/main/03%20-%20Coding%20Attention%20Mechanisms/ch03.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# **Coding Attention Mechanisms**

Now, we will look at an integral part of the LLM architecture itself: attention mechanisms.

The figure below depicts the different attention mechanisms we will code in this chapter. These attention variants build on each other, and the goal is to arrive at a compact and efficient implementation of multi-head attention that we can then plug into the LLM architecture we will code in the next chapter.

<img src="https://drive.google.com/uc?id=1eY2OWPT3QfDvouOPbSP0YfV4_PzfF7EG">

# **3.1 The Problem with Modeling Long Sequences**

Before we dive into the self-attention mechanism, let's look at the problem with pre-LLM architectures that do not include attention mechanisms.

Suppose we want to build a language translation model. We can't simply translate a text word by word, keeping the same word order, because the source and target languages usually have different grammatical structures.

<img src="https://drive.google.com/uc?id=1WxOhVuKOqN-nR_RWcI8Z6E-5CscsQt0p">

To address this problem, it is common to use a deep neural network with two submodules, an encoder and a decoder. The job of the encoder is to first read in and process the entire text, and the decoder then produces the translated text.

Before the advent of transformers, recurrent neural networks (RNNs) were the most popular encoder-decoder architecture for language translation. An RNN is a type of neural network where outputs from previous steps are fed as inputs to the current step, making it well suited for sequential data like text.

In an encoder-decoder RNN, the input text is fed into the encoder, which processes it sequentially. The encoder updates its hidden state (the internal values at the hidden layers) at each step, trying to capture the entire meaning of the input sentence in the final hidden state. The decoder then takes this final hidden state and starts generating the translated sentence one word at a time. It also updates its hidden state at each step, which is supposed to carry the context necessary for the next-word prediction.

<img src="https://drive.google.com/uc?id=1ouD6Ma4e4ScqF3Sr5UW7cJWsKDWnkXWj">

The key idea here is that the encoder processes the entire input text into a hidden state (memory cell). The decoder then takes in this hidden state to produce the output. The hidden state can be thought of as analogous to an embedding vector.

The big limitation of encoder-decoder RNNs is that the RNN can't directly access earlier hidden states from the encoder during the decoding phase. It relies solely on the current hidden state, which encapsulates all relevant information. This can lead to a loss of context, especially in complex sentences where dependencies might span long distances. It is not essential to understand RNNs to build an LLM, but the limitation of the RNNs is the motivation behind the design of attention mechanisms.
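The bottleneck described above can be made concrete with a small sketch. It uses PyTorch's built-in `torch.nn.RNN` as a stand-in encoder; the tensor sizes and the random input are made up for illustration, and no decoder is implemented. The point is only that the encoder produces one hidden state per token, while a plain encoder-decoder RNN passes just the last one on.

```python
import torch

torch.manual_seed(0)

# Hypothetical toy setup: 1 sentence, 6 tokens, 3-dim embeddings, 4-dim hidden state
x = torch.rand(1, 6, 3)
encoder = torch.nn.RNN(input_size=3, hidden_size=4, batch_first=True)

# all_states holds the hidden state after every token,
# final_state only the hidden state after the last token
all_states, final_state = encoder(x)

print(all_states.shape)   # torch.Size([1, 6, 4])
print(final_state.shape)  # torch.Size([1, 1, 4])

# A plain encoder-decoder RNN hands only final_state to the decoder,
# so the five earlier hidden states are never accessed directly.
same = torch.allclose(all_states[:, -1, :], final_state[0])
print(same)  # True
```

Attention mechanisms, introduced next, give the decoder access to something like `all_states` rather than only `final_state`.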

# **3.2 Capturing Data Dependencies with Attention Mechanisms**

Although RNNs work fine for translating short sentences, they don't work well for longer texts as they don't have direct access to previous words in the input. One major shortcoming in this approach is that the RNN must remember the entire encoded input in a single hidden state before passing it to the decoder.

Hence, researchers developed the Bahdanau attention mechanism for RNNs in 2014, which modifies the encoder-decoder RNN such that the decoder can selectively access different parts of the input sequence at each decoding step:

<img src="https://drive.google.com/uc?id=1XB4eRzWnkYccOMgoiuUIcTeXfHu49gWI">

As the image shows, the text-generating decoder part of the network can access all input tokens selectively. This means that some input tokens are more important than others for generating a given output token. The importance is determined by the attention weights.

Three years later, researchers found that RNN architectures are not required for building deep neural networks for natural language processing and proposed the original transformer architecture, which includes a self-attention mechanism inspired by the Bahdanau attention mechanism.

Self-attention is a mechanism that allows each position in the input sequence to consider the relevancy of, or "attend to," all other positions in the same sequence when computing the representation of a sequence. Self-attention is a key component to LLMs based on the transformer architecture.

<img src="https://drive.google.com/uc?id=1eOc8sBHNafdxp-rzQpGttDBvqXDaB2kl">

# **3.3 Attending to Different Parts of the Input with Self-Attention**

The "self" in self-attention refers to the mechanism's ability to compute attention weights by relating different positions within a single input sequence. It assesses and learns the relationships and dependencies between various parts of the input itself, such as words in a sentence or pixels in an image.

This is in contrast to traditional attention mechanisms, where the focus is on the relationships between elements of two different sequences, such as in sequence-to-sequence models where the attention might be between an input sequence and an output sequence.

## **3.3.1 A Simple Self-Attention Mechanism without Trainable Weights**

Let's begin by implementing a simplified variant of self-attention, free from any trainable weights, as summarized in the figure below. The goal is to illustrate a few key concepts in self-attention before adding trainable weights.

<img src="https://drive.google.com/uc?id=1EzLfE00OGTZmrB2N1-aBu-OKVCMTcTkJ">

The figure above shows an input sequence, denoted as x, consisting of T elements represented as x1 to xT. This sequence typically represents text, such as a sentence, that has already been transformed into token embeddings.

In self-attention, our goal is to calculate context vectors zi for each element xi in the input sequence. A context vector can be interpreted as an enriched embedding vector.

To illustrate this concept using the figure above, let's focus on the embedding vector of the second input element, x2 ("journey"), and the corresponding context vector, z2. This enhanced context vector, z2, is an embedding that contains information about x2 and all other input elements, x1 to xT.

Context vectors play a crucial role in self-attention. Their purpose is to create enriched representations of each element in an input sequence (like a sentence) by incorporating information from all other elements in the sequence. This is essential for LLMs, which need to capture the relationships between words and their relevance to one another in order to understand context. Later, we will add trainable weights that help an LLM learn to construct these context vectors. For now, let's implement a simplified self-attention mechanism to compute the attention weights and the resulting context vector one step at a time.

Consider the following input sentence, which has already been embedded into three-dimensional vectors.

In [2]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

The first step of implementing self-attention is to compute the intermediate values ω, referred to as attention scores (figure below). Due to space constraints, the figure displays the values in truncated form. For example, 0.87 is truncated to 0.8.

<img src="https://drive.google.com/uc?id=1ZNs_Nw-pfGhBtfqbrsUlSfitAWtZj4-3">

As in the figure above (which shows truncated values), we determine the attention scores between the query token and each input token by computing the dot product of the query, x2, with every input token, including x2 itself:

In [2]:
query = inputs[1]  # 2nd input token is the query

attn_scores_2 = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    attn_scores_2[i] = torch.dot(x_i, query) # dot product (transpose not necessary here since they are 1-dim vectors)

print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


How to calculate the dot product:

dot(a, b) = a1 × b1 + a2 × b2 + a3 × b3 + ... + an × bn

For example, using our earlier input vectors:

dot(input_x1, query_x2) = dot([0.43, 0.15, 0.89], [0.55, 0.87, 0.66])

                        = (0.43 × 0.55) + (0.15 × 0.87) + (0.89 × 0.66)
                        = 0.9544


In [3]:
res = 0.

for idx, element in enumerate(inputs[0]):
    res += inputs[0][idx] * query[idx]

print(res)
print(torch.dot(inputs[0], query))

tensor(0.9544)
tensor(0.9544)


The dot product is a measure of similarity because it quantifies how closely two vectors are aligned: a higher dot product indicates a greater degree of alignment or similarity between the vectors. In the context of self-attention mechanisms, the dot product determines the extent to which each element in a sequence focuses on, or "attends to," any other element: the higher the dot product, the higher the similarity and attention score between two elements.

In the next step, we normalize each of the attention scores we computed previously. The main goal behind the normalization is to obtain attention weights that sum up to 1. This normalization is a convention that is useful for interpretation and for maintaining training stability in an LLM.
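As a quick illustration of the dot product as a similarity measure, the sketch below compares "journey" with "starts" (a nearly parallel vector) and with "one" (a vector pointing in a rather different direction), using the embeddings defined earlier. The cosine similarity is included as an extra reference point that is not part of the walkthrough: unlike the raw dot product, it ignores vector length.

```python
import torch
import torch.nn.functional as F

journey = torch.tensor([0.55, 0.87, 0.66])  # x^2
starts  = torch.tensor([0.57, 0.85, 0.64])  # x^3
one     = torch.tensor([0.77, 0.25, 0.10])  # x^5

# Raw dot products, as used for the attention scores
score_starts = torch.dot(journey, starts)
score_one = torch.dot(journey, one)
print(score_starts, score_one)

# Cosine similarity divides out the vector lengths, leaving only the alignment
cos_starts = F.cosine_similarity(journey, starts, dim=0)
cos_one = F.cosine_similarity(journey, one, dim=0)
print(cos_starts, cos_one)
```

Both measures rank "starts" above "one" here, which matches the attention scores 1.4754 and 0.7070 computed above.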

In [4]:
attn_weights_2_tmp = attn_scores_2 / attn_scores_2.sum()

print("Attention weights:", attn_weights_2_tmp)
print("Sum:", attn_weights_2_tmp.sum())

Attention weights: tensor([0.1455, 0.2278, 0.2249, 0.1285, 0.1077, 0.1656])
Sum: tensor(1.0000)


As the output shows, the attention weights now sum to 1.

<img src="https://drive.google.com/uc?id=1jvu08NA_G2vb1r9SHeyKMP59UkZoiNei">

In practice, it's more common and advisable to use the softmax function for normalization. This approach is better at managing extreme values and offers more favorable gradient properties during training. The following is a basic implementation of the softmax function for normalizing the attention scores:

In [5]:
def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

attn_weights_2_naive = softmax_naive(attn_scores_2)

print("Attention weights:", attn_weights_2_naive)
print("Sum:", attn_weights_2_naive.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)


As the output shows, the softmax function also meets the objective and normalizes the attention weights such that they sum to 1.

In addition, the softmax function ensures that the attention weights are always positive. This makes the output interpretable as probabilities or relative importance, where higher weights indicate greater importance.

Note that this naive softmax implementation (softmax_naive) may encounter numerical instability problems, such as overflow and underflow, when dealing with large or small input values. Therefore, in practice, it's advisable to use the PyTorch implementation of softmax, which has been extensively optimized for performance:

In [6]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)

print("Attention weights:", attn_weights_2)
print("Sum:", attn_weights_2.sum())

Attention weights: tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
Sum: tensor(1.)
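To see the overflow problem mentioned above, the following sketch feeds deliberately large scores into the naive implementation. It also includes `softmax_stable`, a hypothetical helper that applies the common max-subtraction trick; this is shown as one well-known way to stabilize softmax, not as a description of PyTorch's internals.

```python
import torch

def softmax_naive(x):
    return torch.exp(x) / torch.exp(x).sum(dim=0)

def softmax_stable(x):
    # Subtracting the maximum does not change the result mathematically,
    # but it keeps every exponent <= 0, so exp() cannot overflow
    shifted = x - x.max()
    return torch.exp(shifted) / torch.exp(shifted).sum(dim=0)

large_scores = torch.tensor([1000.0, 1001.0, 1002.0])

naive = softmax_naive(large_scores)    # exp(1000) overflows to inf, inf / inf gives nan
stable = softmax_stable(large_scores)
builtin = torch.softmax(large_scores, dim=0)

print("naive:  ", naive)
print("stable: ", stable)
print("builtin:", builtin)
```

For the small attention scores in this chapter, all three versions agree; the difference only shows up with extreme values.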


After computing the normalized attention weights, the last step is to calculate the context vector z2 by multiplying the embedded input tokens, xi, with the corresponding attention weights and then summing the resulting vectors. Thus, the context vector z2 is the weighted sum of all input vectors, obtained by multiplying each input vector by its corresponding attention weight:

In [14]:
query = inputs[1] # 2nd input token is the query
context_vec_2 = torch.zeros(query.shape)
for i,x_i in enumerate(inputs):
    print("attention weights: {} | embedded input token: {}".format(attn_weights_2[i], x_i))
    context_vec_2 += attn_weights_2[i]*x_i

print(context_vec_2)

attention weights: 0.13854756951332092 | embedded input token: tensor([0.4300, 0.1500, 0.8900])
attention weights: 0.2378913015127182 | embedded input token: tensor([0.5500, 0.8700, 0.6600])
attention weights: 0.23327402770519257 | embedded input token: tensor([0.5700, 0.8500, 0.6400])
attention weights: 0.12399158626794815 | embedded input token: tensor([0.2200, 0.5800, 0.3300])
attention weights: 0.10818186402320862 | embedded input token: tensor([0.7700, 0.2500, 0.1000])
attention weights: 0.15811361372470856 | embedded input token: tensor([0.0500, 0.8000, 0.5500])
tensor([0.4419, 0.6515, 0.5683])
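The loop above can also be written as a single vector-matrix product, which previews the matrix form used for all tokens later on. This is a minimal sketch that recomputes everything from the `inputs` tensor so that it runs on its own.

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

query = inputs[1]

# Step 1: attention scores, (6, 3) @ (3,) -> (6,)
attn_scores_2 = inputs @ query
# Step 2: normalize with softmax
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
# Step 3: weighted sum of the input vectors, (6,) @ (6, 3) -> (3,)
context_vec_2 = attn_weights_2 @ inputs

print(context_vec_2)
```

The result matches the context vector produced by the loop.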


<img src="https://drive.google.com/uc?id=1dwjwcsif32LI_glFql5zSkLWYKhJP0D4">

Next, we will generalize the procedure for computing context vectors to calculate all context vectors simultaneously.

## **3.3.2 Computing Attention Weights for All Input Tokens**

So far, we have computed attention weights and the context vector for input 2, as shown in the highlighted row in the figure below. Now let's extend this computation to calculate attention weights and context vectors for all inputs.

<img src="https://drive.google.com/uc?id=1LNedYJwe7f5b3sRMVdkINtCffWb55Msa">

We follow the same three steps as before, except that we make a few modifications in the code to compute all context vectors instead of only the second one, z2.

In [15]:
attn_scores = torch.empty(6, 6)

for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        attn_scores[i, j] = torch.dot(x_i, x_j)

print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])


<img src="https://drive.google.com/uc?id=1QEHg5PWEmc_bBkojiWHGDY2xKvEJLvho">

Each element in the tensor represents an attention score between each pair of inputs.

When computing the preceding attention score tensor, we used for loops in Python. However, for loops are generally slow, and we can achieve the same results using matrix multiplication:

In [16]:
attn_scores = inputs @ inputs.T
print(attn_scores)

tensor([[0.9995, 0.9544, 0.9422, 0.4753, 0.4576, 0.6310],
        [0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865],
        [0.9422, 1.4754, 1.4570, 0.8296, 0.7154, 1.0605],
        [0.4753, 0.8434, 0.8296, 0.4937, 0.3474, 0.6565],
        [0.4576, 0.7070, 0.7154, 0.3474, 0.6654, 0.2935],
        [0.6310, 1.0865, 1.0605, 0.6565, 0.2935, 0.9450]])
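As a sanity check, the sketch below confirms that the nested loops and the matrix multiplication give the same scores. It also checks that the score matrix is symmetric, which holds in this simplified variant because the dot product of x_i and x_j equals that of x_j and x_i; this will no longer be true once separate query and key projections are introduced.

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

n = inputs.shape[0]

# Loop version: one dot product per (i, j) pair
loop_scores = torch.empty(n, n)
for i, x_i in enumerate(inputs):
    for j, x_j in enumerate(inputs):
        loop_scores[i, j] = torch.dot(x_i, x_j)

# Matrix version: all pairs at once
matmul_scores = inputs @ inputs.T

print(torch.allclose(loop_scores, matmul_scores))      # True
print(torch.allclose(matmul_scores, matmul_scores.T))  # True
```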


Next, let's normalize these attention scores using the softmax function so that the values in each row sum to 1:

In [17]:
attn_weights = torch.softmax(attn_scores, dim=-1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In PyTorch, the dim parameter in functions like torch.softmax specifies the dimension of the input tensor along which the function will be computed. By setting dim=-1, we are instructing the softmax function to apply the normalization along the last dimension of the attn_scores tensor. If attn_scores is a two-dimensional tensor (for example, with a shape of [rows, columns]), it will normalize across the columns so that the values in each row (summing over the column dimension) sum up to 1.
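The effect of `dim` is easiest to see on a small tensor. The following sketch uses a made-up 2x3 score matrix and applies softmax along each dimension in turn.

```python
import torch

# Hypothetical scores with shape [rows=2, columns=3]
scores = torch.tensor([[1.0, 2.0, 3.0],
                       [1.0, 1.0, 1.0]])

row_wise = torch.softmax(scores, dim=-1)  # normalize within each row
col_wise = torch.softmax(scores, dim=0)   # normalize within each column

print(row_wise)
print("Row sums:   ", row_wise.sum(dim=-1))  # one value per row, each equal to 1
print(col_wise)
print("Column sums:", col_wise.sum(dim=0))   # one value per column, each equal to 1
```

For attention, each row corresponds to one query token, so `dim=-1` is the variant that turns each query's scores into weights that sum to 1.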

We can verify that the rows indeed all sum to 1:

In [18]:
row_2_sum = sum([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
print("Row 2 sum:", row_2_sum)

print("All row sums:", attn_weights.sum(dim=-1))

Row 2 sum: 1.0
All row sums: tensor([1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000])


In the last step, we use these attention weights to calculate all context vectors via matrix multiplication:

In [19]:
all_context_vecs = attn_weights @ inputs
print(all_context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])


We can double-check that the code is correct by comparing the second row with the context vector z2 that we computed in section 3.3.1:

In [20]:
print("Previous 2nd context vector:", context_vec_2)

Previous 2nd context vector: tensor([0.4419, 0.6515, 0.5683])


The output is exactly the same as the second row above, which confirms the context vector z2.

This concludes the code walkthrough of a simple self-attention mechanism. Next, we will add trainable weights, enabling the LLM to learn from data and improve its performance on specific tasks.

# **3.4 Implementing Self-Attention with Trainable Weights**

Our next step will be to implement the self-attention mechanism used in the original transformer architecture, the GPT models, and most other popular LLMs. This self-attention mechanism is also called scaled dot-product attention. The figure below shows how this self-attention mechanism fits into the broader context of implementing an LLM.

<img src="https://drive.google.com/uc?id=1Ii-dv8sAL3SEQXFvCG95CZiCXjMBCyhY">


The self-attention mechanism with trainable weights builds on the previous concepts: we want to compute context vectors as weighted sums over the input vectors specific to a certain input element. There are only slight differences compared to our simplified self-attention earlier.

The most notable difference is the introduction of weight matrices that are updated during model training. These trainable weight matrices are crucial so that the model (specifically, the attention module inside the model) can learn to produce "good" context vectors.

We will tackle the implementation of the self-attention mechanism in two parts.
* First, we will code it step by step as before.
* Second, we will organize the code into a compact Python class that can be imported into the LLM architecture.

## **3.4.1 Computing the Attention Weights Step by Step**

We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices Wq, Wk, and Wv. These matrices are used to project the embedded input tokens, xi, into query, key, and value vectors, respectively, as illustrated in the figure below.

<img src="https://drive.google.com/uc?id=1KVfx4qwvYCe9mAC6mj3Xp_yK_STyhGwE">

Earlier, we defined the second input element x2 as the query when we computed the simplified attention weights to compute the context vector z2. Then we generalized this to compute all context vectors z1 ... zT for the six-word input sentence "Your journey starts with one step."

Similarly, we start here by computing only one context vector, z2, for illustration purposes. We will then modify this code to calculate all context vectors.

Let's begin by defining a few variables:

In [3]:
x_2 = inputs[1] # second input element
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

In GPT-like models, the input and output dimensions are usually the same, but to better follow the computation, we'll use different input (d_in=3) and output (d_out=2) dimensions here.

Next, we initialize the three weight matrices Wq, Wk, and Wv.

In [13]:
torch.manual_seed(123)

W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

print("W_query: {}".format(W_query))
print("W_key: {}".format(W_key))
print("W_value: {}".format(W_value))

W_query: Parameter containing:
tensor([[0.2961, 0.5166],
        [0.2517, 0.6886],
        [0.0740, 0.8665]])
W_key: Parameter containing:
tensor([[0.1366, 0.1025],
        [0.1841, 0.7264],
        [0.3153, 0.6871]])
W_value: Parameter containing:
tensor([[0.0756, 0.1966],
        [0.3164, 0.4017],
        [0.1186, 0.8274]])


We set requires_grad=False to reduce clutter in the outputs, but if we were to use the weight matrices for model training, we would set requires_grad=True to update these matrices during model training.
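To illustrate what the flag changes, the sketch below creates one trainable and one frozen weight matrix and runs a backward pass. The sum over the query vector is only a placeholder loss chosen for illustration, not something used in the chapter.

```python
import torch

torch.manual_seed(123)

d_in, d_out = 3, 2
x_2 = torch.tensor([0.55, 0.87, 0.66])  # "journey"

# nn.Parameter sets requires_grad=True unless told otherwise
W_trainable = torch.nn.Parameter(torch.rand(d_in, d_out))
W_frozen = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

query_trainable = x_2 @ W_trainable
query_frozen = x_2 @ W_frozen

print(query_trainable.requires_grad)  # True: part of the autograd graph
print(query_frozen.requires_grad)     # False: no gradient will be tracked

# Placeholder loss, only to trigger backpropagation
loss = query_trainable.sum()
loss.backward()

print(W_trainable.grad)  # populated, so an optimizer could update this matrix
print(W_frozen.grad)     # None
```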

Next, we compute the query, key, and value vectors:

In [5]:
query_2 = x_2 @ W_query # _2 because it's with respect to the 2nd input element
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)

tensor([0.4306, 1.4551])


The output for the query results in a two-dimensional vector since we set the number of columns of the corresponding weight matrix, via d_out, to 2.

**Weight parameters vs. attention weights**

In the weight matrices W, the term "weight" is short for "weight parameters," the values of a neural network that are optimized during training. This is not to be confused with the attention weights. Attention weights determine the extent to which a context vector depends on the different parts of the input.

In summary, weight parameters are the fundamental, learned coefficients that define the network's connections, while attention weights are dynamic, context-specific values.
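The distinction can be sketched in a few lines: the weight parameters below stay fixed, while the attention weights they produce change with the input. The two random "sentences" and the helper `attention_weights` are made up for this illustration, and the softmax is applied to the raw scores without any scaling.

```python
import torch

torch.manual_seed(123)

d_in, d_out = 3, 2

# Weight parameters: the same matrices are reused for every input
W_query = torch.rand(d_in, d_out)
W_key = torch.rand(d_in, d_out)

def attention_weights(x):
    # Hypothetical helper: attention weights for all tokens in x
    queries = x @ W_query
    keys = x @ W_key
    return torch.softmax(queries @ keys.T, dim=-1)

# Two made-up inputs with 4 tokens each
sentence_a = torch.rand(4, d_in)
sentence_b = torch.rand(4, d_in)

weights_a = attention_weights(sentence_a)
weights_b = attention_weights(sentence_b)

# Same parameters, different attention weights
print(weights_a)
print(weights_b)
```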

Even though our temporary goal is to compute only the one context vector, z2, we still require the key and value vectors for all input elements, as they are involved in computing the attention weights with respect to the query q2.

We can obtain all keys and values via matrix multiplication:

In [6]:
keys = inputs @ W_key
values = inputs @ W_value

print("keys.shape:", keys.shape)
print("values.shape:", values.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])


We successfully projected the six input tokens from a three-dimensional onto a two-dimensional embedding space.

The second step is to compute the attention scores:

<img src="https://drive.google.com/uc?id=1pQN0FaxIuyl_KCeR54jx9c-CDi0qksyd">

The attention score computation is a dot-product computation similar to what we used in the simplified self-attention mechanism in section 3.3. The new aspect here is that we are not directly computing the dot-product between the input elements but using the query and key obtained by transforming the inputs via the respective weight matrices.

First, let's compute the attention score ω22:

In [7]:
keys_2 = keys[1] # Python starts index at 0
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(1.8524)


Again, we can generalize this computation to all attention scores via matrix multiplication:

In [8]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)

tensor([1.2705, 1.8524, 1.8111, 1.0795, 0.5577, 1.5440])


Now that we have the attention scores, we need to turn them into attention weights. As before, we use the softmax function, but this time we first scale the attention scores by dividing them by the square root of the embedding dimension of the keys (taking the square root is mathematically the same as exponentiating by 0.5).

<img src="https://drive.google.com/uc?id=1XeZfgHyyNZbAu9-oWRmN727g8q0Hy5gE">

In [14]:
d_k = keys.shape[1]
print(d_k)
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print(attn_weights_2)

2
tensor([0.1500, 0.2264, 0.2199, 0.1311, 0.0906, 0.1820])


**The rationale behind scaled dot-product attention**

The reason for scaling by the square root of the embedding dimension is to improve the training performance by avoiding small gradients. For instance, when scaling up the embedding dimension, which is typically greater than 1000 for GPT-like LLMs, large dot products can result in very small gradients during backpropagation due to the softmax function applied to them. As dot products increase, the softmax function behaves more like a step function, resulting in gradients nearing zero. These small gradients can drastically slow down learning or cause training to stagnate.

The scaling by the square root of the embedding dimension is the reason why this self-attention mechanism is also called scaled dot-product attention.
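The effect of the scaling can be checked empirically. The sketch below assumes query and key entries drawn from a standard normal distribution, which is a simplification (trained projections will not follow this exactly), and uses a made-up helper `dot_product_std`. Under this assumption, the spread of the dot products grows roughly with the square root of the dimension, and dividing by that square root brings it back to about 1.

```python
import torch

torch.manual_seed(0)

def dot_product_std(d_k, n_samples=10_000):
    # Random stand-ins for query and key vectors (assumption: unit-variance entries)
    q = torch.randn(n_samples, d_k)
    k = torch.randn(n_samples, d_k)
    dots = (q * k).sum(dim=-1)  # one dot product per sample
    return dots.std().item(), (dots / d_k**0.5).std().item()

for d_k in (2, 64, 1024):
    raw_std, scaled_std = dot_product_std(d_k)
    print(f"d_k={d_k:5d}  raw std={raw_std:6.2f}  scaled std={scaled_std:5.2f}")

# Larger scores push softmax toward a one-hot (step-like) output
scores = torch.tensor([0.1, -0.2, 0.3, -0.4, 0.5])
soft = torch.softmax(scores, dim=-1)
peaked = torch.softmax(scores * 32, dim=-1)
print(soft)
print(peaked)
```

In the second part, multiplying the scores by 32 mimics the unscaled case at a dimension of 1024: almost all of the weight lands on a single position, which is the regime where the softmax gradients become very small.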

Now, the final step is to compute the context vectors:

<img src="https://drive.google.com/uc?id=1FlP3au_su_2eX_to_uAyPCFlSi0Y-18_">

Similar to when we computed the context vector as a weighted sum over the input vectors, we now compute the context vector as a weighted sum over the value vectors. Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. Also as before, we can use matrix multiplication:

In [10]:
context_vec_2 = attn_weights_2 @ values
print(context_vec_2)

tensor([0.3061, 0.8210])


So far, we've only computed a single context vector, z2. Next, we will generalize the code to compute all context vectors in the input sequence, z1 to zT.

**Why query, key, and value?**

The terms "key", "query", and "value" in the context of the attention mechanisms are borrowed from the domain of information retrieval and databases, where similar concepts are used to store, search, and retrieve information.

A query is analogous to a search query in a database. It represents the current item (e.g. a word or token in a sentence) the model focuses on or tries to understand. The query is used to probe the other parts of the input sequence to determine how much attention to pay to them.

The key is like a database key used for indexing and searching. In the attention mechanism, each item in the input sequence (e.g., each word in a sentence) has an associated key. These keys are used to match the query.

The value in this context is similar to the value in a key-value pair in a database. It represents the actual content or representation of the input items. Once the model determines which keys (and thus which parts of the input) are most relevant to the query (the current focus item), it retrieves the corresponding values.
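The database analogy can be sketched by putting a regular dictionary lookup next to a "soft" lookup. The keys, values, and query below are made up for illustration. A dictionary returns exactly one value for a matching key, whereas attention returns a weighted mix of all values, with the weights determined by how well each key matches the query.

```python
import torch

# Hard lookup: the query must match one key exactly
table = {"cat": 1.0, "dog": 2.0, "bird": 3.0}
print(table["dog"])  # 2.0

# Soft lookup: every key is compared with the query via a dot product
keys = torch.tensor([[1.0, 0.0, 0.0],   # "cat"
                     [0.0, 1.0, 0.0],   # "dog"
                     [0.0, 0.0, 1.0]])  # "bird"
values = torch.tensor([[1.0], [2.0], [3.0]])
query = torch.tensor([0.0, 10.0, 0.0])  # points strongly at the "dog" key

match_scores = keys @ query                  # one score per key
weights = torch.softmax(match_scores, dim=0) # how much of each value to retrieve
result = weights @ values                    # weighted mix of all values

print(weights)
print(result)  # very close to 2.0, the value stored under "dog"
```

With a less sharply matching query, the weights would be spread more evenly and the result would blend several values.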


## **3.4.2 Implementing a Compact Self-Attention Python Class**

Now let's organize everything we have implemented so far into a Python class:

In [18]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value

        attn_scores = queries @ keys.T # omega
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )

        context_vec = attn_weights @ values
        return context_vec

In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, which is a fundamental building block of PyTorch models that provides necessary functionalities for model layer creation and management.

The __init__ method initializes trainable weight matrices (W_query, W_key, and W_value) for queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out.

During the forward pass, using the forward method, we compute the attention scores (attn_scores) by multiplying queries and keys, then scale them by the square root of the key dimension and normalize them using softmax. Finally, we create the context vectors by weighting the values with these normalized attention weights.

We can use this class as follows:

In [17]:
torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.2996, 0.8053],
        [0.3061, 0.8210],
        [0.3058, 0.8203],
        [0.2948, 0.7939],
        [0.2927, 0.7891],
        [0.2990, 0.8040]], grad_fn=<MmBackward0>)


Since inputs contains six embedding vectors, this results in a matrix storing the six context vectors.

We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. An additional advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.
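The claim that a bias-free `nn.Linear` is effectively a matrix multiplication can be verified directly. One detail worth knowing is that `nn.Linear` stores its weight with shape `(out_features, in_features)`, so the equivalent manual product uses the transposed weight. The input below is random and only serves as an example.

```python
import torch
import torch.nn as nn

torch.manual_seed(789)

d_in, d_out = 3, 2
x = torch.rand(6, d_in)  # made-up batch of six 3-dim token embeddings

linear = nn.Linear(d_in, d_out, bias=False)
print(linear.weight.shape)  # torch.Size([2, 3])

out_layer = linear(x)              # calling the layer
out_matmul = x @ linear.weight.T   # the same computation written by hand

print(torch.allclose(out_layer, out_matmul))  # True
```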

<img src="https://drive.google.com/uc?id=1M2BCfz7LSWmp_TtG4VGU9MZrTDN-pYkF">

In [16]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


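Compared with SelfAttention_v1, the main change is in the __init__ method, where the query, key, and value weights are now created as nn.Linear layers. Accordingly, the forward method calls these layers instead of multiplying the input by the weight matrices directly. Note that SelfAttention_v1 and SelfAttention_v2 produce different outputs because nn.Linear uses a different weight initialization scheme than torch.rand.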
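One way to convince ourselves that the two classes implement the same computation is to copy the weights from a `SelfAttention_v2` instance into a `SelfAttention_v1` instance and compare the outputs. The sketch below repeats both class definitions so that it runs on its own; the only subtlety is the transpose, since `nn.Linear` stores its weight as `(d_out, d_in)` while `SelfAttention_v1` expects `(d_in, d_out)`.

```python
import torch
import torch.nn as nn

class SelfAttention_v1(nn.Module):
    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys, queries, values = x @ self.W_key, x @ self.W_query, x @ self.W_value
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys, queries, values = self.W_key(x), self.W_query(x), self.W_value(x)
        attn_weights = torch.softmax(queries @ keys.T / keys.shape[-1]**0.5, dim=-1)
        return attn_weights @ values

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], [0.55, 0.87, 0.66], [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33], [0.77, 0.25, 0.10], [0.05, 0.80, 0.55]]
)

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(3, 2)
sa_v1 = SelfAttention_v1(3, 2)

# Copy the v2 weights into v1, transposing from (d_out, d_in) to (d_in, d_out)
with torch.no_grad():
    sa_v1.W_query.copy_(sa_v2.W_query.weight.T)
    sa_v1.W_key.copy_(sa_v2.W_key.weight.T)
    sa_v1.W_value.copy_(sa_v2.W_value.weight.T)

out_v1 = sa_v1(inputs)
out_v2 = sa_v2(inputs)
print(torch.allclose(out_v1, out_v2, atol=1e-6))  # True
```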



Next, we will make enhancements to the self-attention mechanism, focusing specifically on incorporating causal and multi-head elements. The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence, which is crucial for tasks like language modeling, where each word prediction should only depend on previous words.

The multi-head component involves splitting the attention mechanism into multiple "heads." Each head learns different aspects of the data, allowing the model to simultaneously attend to information from different representation subspaces at different positions. This improves the model's performance in complex tasks.