# Chapter 2: Building a Production-Ready Attention  Module

<div class="alert alert-block alert-success">

In Chapter 1, we built a simplified self-attention mechanism from first principles. While excellent for building intuition, that version was not flexible because each input vector had to act as its own query, key, and value.

In this chapter, we will upgrade our mechanism to a "production-ready" version by introducing **trainable weight matrices**. This is the key step that allows the attention mechanism to *learn* the complex relationships in data, making it an incredibly powerful component of modern LLMs.
</div>

<div class="alert alert-block alert-success">
Set up our environment with the necessary imports.
</div>

In [1]:
import torch
import numpy as np
import torch.nn as nn

## 2.1 Introducing Trainable Weights (Wq, Wk, Wv)

<div class="alert alert-block alert-success">
For consistency, we will continue to use the same 6-token sample sentence and its 3-dimensional embedding vectors from the previous chapter.
</div>

In [2]:
# Our sample input sentence as embedding vectors
inputs = torch.tensor(
    [[ 0.8938,  0.9003,  0.8978], # Your
     [ 0.7165,  0.3428,  0.2553], # journey
     [ 0.1042,  0.5163,  0.3753], # starts
     [ 0.0445,  0.3091,  0.9763], # with
     [ 0.1554,  0.1614,  0.2700], # one
     [ 0.8089,  0.9435,  0.5480]] # step
)

# Corresponding words
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

<div class="alert alert-block alert-success">

To make our attention mechanism more powerful and production-ready, we now introduce three dedicated, trainable **weight matrices**:

* **`W_query` (Wq)**
* **`W_key` (Wk)**
* **`W_value` (Wv)**

The purpose of these matrices is to **project** our input embeddings into three separate, specialized vectors. For each input token `x`, we will now calculate:

1.  A **query vector `q`** (calculated as `x @ W_query`): This vector is optimized for asking the right "question" to find relevant keys.
2.  A **key vector `k`** (calculated as `x @ W_key`): This vector is optimized to be effectively "found" by relevant queries.
3.  A **value vector `v`** (calculated as `x @ W_value`): This vector contains the rich information that the token will contribute to the final output.

Crucially, these matrices are **trainable parameters**. The model will learn the optimal values for these matrices during the training process, allowing it to master the complex art of understanding context in language.
</div>

<div class="alert alert-block alert-info">
    
To see how this projection works in practice, let's focus on a single input token and define the dimensions for our weight matrices. For this hands-on example, we will:

1.  Select the second input token ("journey") to be the **query** we analyze.
2.  Get its embedding dimension from the input tensor (`d_in`).
3.  Define a smaller output dimension (`d_out`) for the resulting query, key, and value vectors.
</div>

In [3]:
x_2 = inputs[1]
d_in = inputs.shape[1]
d_out = 2

<div class="alert alert-block alert-info">
    
Note that in GPT-like models, the input and output dimensions are usually the same. 

But for illustration purposes,  we are using a smaller output dimension here simply to make the matrix operations easier to track visually.
</div>

<div class="alert alert-block alert-success">
Next, we initialize the three weight matrices Wq, Wk and Wv
</div>

In [4]:
torch.manual_seed(100)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

In [5]:
print(W_query)

Parameter containing:
tensor([[0.1117, 0.8158],
        [0.2626, 0.4839],
        [0.6765, 0.7539]])


In [6]:
print(W_key)

Parameter containing:
tensor([[0.2627, 0.0428],
        [0.2080, 0.1180],
        [0.1217, 0.7356]])


In [7]:
print(W_value)

Parameter containing:
tensor([[0.7118, 0.7876],
        [0.4183, 0.9014],
        [0.9969, 0.7565]])


<div class="alert alert-block alert-info">
    
Note that we are setting requires_grad=False to reduce clutter in the outputs for illustration purposes. 

If we were to use the weight matrices for model training, we would set requires_grad=True to update these matrices during model training.
</div>

<div class="alert alert-block alert-success">
Next, we compute the query, key, and value vectors as shown earlier
</div>

In [8]:
query_2 = x_2 @ W_query
key_2 = x_2 @ W_key
value_2 = x_2 @ W_value

print(query_2)

tensor([0.3427, 0.9429])


<div class="alert alert-block alert-info">
    
As we can see based on the output for the query, this results in a 2-dimensional vector. 

This is because: we set the number of columns of the corresponding weight matrix, via d_out, to 2:
</div>

<div class="alert alert-block alert-success">

Even though our temporary goal is to only compute the one context vector z(2),  we still require the key and value vectors for all input elements. 

This is because they are involved in computing the attention weights with respect to the query q(2)
</div>

<div class="alert alert-block alert-success">
We can obtain all keys and values via matrix multiplication:
</div>

In [9]:
keys = inputs @ W_key
values = inputs @ W_value
queries = inputs @ W_query

print("keys.shape:", keys.shape)

print("values.shape:", values.shape)

print("queries.shape:", queries.shape)

keys.shape: torch.Size([6, 2])
values.shape: torch.Size([6, 2])
queries.shape: torch.Size([6, 2])


<div class="alert alert-block alert-info">
As we can tell from the outputs, we successfully projected the 6 input tokens from a 3D onto a 2D embedding space:
</div>

## 2.2 Scaling Attention Scores to create Attention Weights and Context Vectors

<div class="alert alert-block alert-success">
First, let's compute the attention score ω22
</div>

In [10]:
keys_2 = keys[1]
attn_score_22 = query_2.dot(keys_2)
print(attn_score_22)

tensor(0.3438)


<div class="alert alert-block alert-success">
Again, we can generalize this computation to all attention scores via matrix multiplication:
</div>

In [11]:
attn_scores_2 = query_2 @ keys.T # All attention scores for given query
print(attn_scores_2)

tensor([0.9411, 0.3438, 0.3838, 0.7801, 0.2483, 0.6807])


<div class="alert alert-block alert-success">
    
We compute the attention weights by scaling the attention scores and using the softmax function we used earlier. 

The difference to earlier is that we now scale the attention scores by dividing them by the square root of the embedding dimension of the keys. 
</div>

In [12]:
d_k = keys.shape[-1]
attn_weights_2 = torch.softmax(attn_scores_2 / d_k**0.5, dim=-1)
print("Attention weights for the second input:", attn_weights_2)
print("Embedding dimension for the keys:", d_k)

Attention weights for the second input: tensor([0.2143, 0.1405, 0.1445, 0.1912, 0.1313, 0.1782])
Embedding dimension for the keys: 2


### Why divide by the square root of the embedding dimension?

<div class="alert alert-block alert-warning">

<b>Reason 1: For stability in learning</b>

The softmax function is sensitive to the magnitudes of its inputs. When the inputs are large, the differences between the exponential values of each input become much more pronounced. This causes the softmax output to become "peaky," where the highest value receives almost all the probability mass, and the rest receive very little.ery sharp softmax distribution, making the model overly confident in one particular "key." Such sharp distributions can make learning unstable,
</div>

In [13]:
# Define the tensor
tensor = torch.tensor([0.1, -0.2, 0.3, -0.2, 0.5])

# Apply softmax without scaling
softmax_result = torch.softmax(tensor, dim=-1)
print("Softmax without scaling:", softmax_result)

# Multiply the tensor by 8 and then apply softmax
scaled_tensor = 8 * tensor
softmax_scaled_result = torch.softmax(scaled_tensor, dim=-1)
print("Softmax after scaling by 8:", softmax_scaled_result)

Softmax without scaling: tensor([0.1925, 0.1426, 0.2351, 0.1426, 0.2872])
Softmax after scaling by 8: tensor([0.0326, 0.0030, 0.1615, 0.0030, 0.8000])


<div class="alert alert-block alert-warning">
In attention mechanisms, particularly in transformers, if the dot products between query and key vectors become too large (like multiplying by 8 in this example), the attention scores can become very large. This results in a very sharp softmax distribution, making the model overly confident in one particular "key." Such sharp distributions can make learning unstable,
</div>

### But, why by the square root?

<div class="alert alert-block alert-warning">
    
<b>Reason 2: To make the variance of the dot product stable</b>

The dot product of  Q and K increases the variance because multiplying two random numbers increases the variance.

The increase in variance grows with the dimension. 

Dividing by sqrt (dimension) keeps the variance close to 1
    
</div>

In [14]:
# Function to compute variance before and after scaling
def compute_variances(dim, num_trials=1000):
    dot_products = []
    scaled_dot_products = []

    # Generate multiple random vectors and compute the products
    for _ in range(num_trials):
        q = np.random.randn(dim)
        k = np.random.randn(dim)

        # Compuute dot product
        dot_product = np.dot(q, k)
        dot_products.append(dot_product)

        # Scale the dot product by sqrt(dim)
        scaled_dot_product = dot_product / np.sqrt(dim)
        scaled_dot_products.append(scaled_dot_product)

    # Calculate the variance of the dot produucts
    variance_before_scaling = np.var(dot_products)
    variance_after_scaling = np.var(scaled_dot_products)

    return variance_before_scaling, variance_after_scaling

torch.manual_seed(100)

# For dimension 5:
variance_before_scaling_5, variance_after_scaling_5 = compute_variances(dim=5)
print(f"Variance before scaling (dim=5): {variance_before_scaling_5}")
print(f"Variance after scaling (dim=5): {variance_after_scaling_5}")

# For dimension 20:
variance_before_scaling_20, variance_after_scaling_20 = compute_variances(dim=20)
print(f"Variance before scaling (dim=20): {variance_before_scaling_20}")
print(f"Variance after scaling (dim=20): {variance_after_scaling_20}")

Variance before scaling (dim=5): 4.4993367243501625
Variance after scaling (dim=5): 0.8998673448700324
Variance before scaling (dim=20): 21.491231475308616
Variance after scaling (dim=20): 1.0745615737654306


<div class="alert alert-block alert-success">
    
We now compute the context vector as a weighted sum over the value vectors. 

Here, the attention weights serve as a weighting factor that weighs the respective importance of each value vector. 

We can use matrix multiplication to obtain the output in one step:
</div>

In [15]:
context_vec_2 = attn_weights_2 @ values
print("Context vector for the second input:", context_vec_2)

Context vector for the second input: tensor([1.1783, 1.3425])


<div class="alert alert-block alert-success">
    
So far, we only computed a single context vector, z(2). 

In the next section, we will generalize the code to compute all context vectors in the input sequence, z(1)to z (T)
</div>

## 2.3 Implementing a Compact Self Attention Python Class

<div class="alert alert-block alert-success">
    
In the previous sections, we have gone through a lot of steps to compute the self-attention outputs. 

This was mainly done for illustration purposes so we could go through one step at a time. 

In practice, with the LLM implementation in the next chapter in mind, it is helpful to organize this code into a Python class as follows:
</div>

In [16]:
class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        queries = x @ self.W_query
        keys = x @ self.W_key
        values = x @ self.W_value

        attn_scores = queries @ keys.T #omega
        attn_weights = torch.softmax(
            attn_scores / d_out**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
    

<div class="alert alert-block alert-warning">
In this PyTorch code, SelfAttention_v1 is a class derived from nn.Module, which is a fundamental building block of PyTorch models, which provides necessary functionalities for model layer creation and management.    
</div>

<div class="alert alert-block alert-warning">

The __init__ method initializes trainable weight matrices (W_query, W_key, and W_value) for queries, keys, and values, each transforming the input dimension d_in to an output dimension d_out.

During the forward pass, using the forward method, we compute the attention scores (attn_scores) by multiplying queries and keys, normalizing these scores using softmax.

Finally, we create a context vector by weighting the values with these normalized attention scores.
</div>

In [17]:
torch.manual_seed(100)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[1.2705, 1.4457],
        [1.1783, 1.3425],
        [1.1593, 1.3236],
        [1.1985, 1.3688],
        [1.1366, 1.2980],
        [1.2373, 1.4083]], grad_fn=<MmBackward0>)


<div class="alert alert-block alert-info">
Since inputs contains six embedding vectors, we get a matrix storing the six context vectors, as shown in the above result. 
</div>

<div class="alert alert-block alert-info">
As a quick check, notice how the second row ([1.1783, 1.3425]) matches the contents of context_vec_2 in the previous section.
</div>

<div class="alert alert-block alert-warning">
We can improve the SelfAttention_v1 implementation further by utilizing PyTorch's nn.Linear layers, which effectively perform matrix multiplication when the bias units are disabled. 
</div>

<div class="alert alert-block alert-warning">
Additionally, a significant advantage of using nn.Linear instead of manually implementing nn.Parameter(torch.rand(...)) is that nn.Linear has an optimized weight initialization scheme, contributing to more stable and effective model training.
</div>

In [18]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.T #omega
        attn_weights = torch.softmax(
            attn_scores / d_out**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec
    

<div class="alert alert-block alert-success">
You can use the SelfAttention_v2 similar to SelfAttention_v1:
</div>

In [19]:
torch.manual_seed(100)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[0.2303, 0.6761],
        [0.2381, 0.6888],
        [0.2268, 0.6639],
        [0.2287, 0.6709],
        [0.2335, 0.6777],
        [0.2288, 0.6711]], grad_fn=<MmBackward0>)


<div class="alert alert-block alert-info">
Note that SelfAttention_v1 and SelfAttention_v2 give different outputs because they use different initial weights for the weight matrices since nn.Linear uses a more sophisticated weight initialization scheme.
</div>

## 2.4 Causal Attention: Masking for Autoregressive Models

<div class="alert alert-block alert-success">

In a language model that generates text one token at a time (an <b>autoregressive</b> model), there's a strict rule: to predict the next token, the model is only allowed to see the tokens that came before it. It cannot "see into the future."

The self-attention mechanism we've built so far violates this rule! Currently, every token can "see" and gather context from all other tokens in the sequence, including those that appear later. This is like giving the model the answers to the test during training.

To fix this, we must implement <b>causal attention</b>. The goal is to "mask" or hide these connections to future tokens.
</div>

<div class="alert alert-block alert-success">
First, let's recalculate the full attention weights matrix to see the problem clearly. Notice in the output below how every token attends to every other token (all weights are non-zero). For example, "journey" (the second row) is currently paying attention to "starts", "with", "one", and "step".
</div>

In [20]:
# Our sample input sentence as embedding vectors
inputs = torch.tensor(
    [[ 0.8938,  0.9003,  0.8978], # Your
     [ 0.7165,  0.3428,  0.2553], # journey
     [ 0.1042,  0.5163,  0.3753], # starts
     [ 0.0445,  0.3091,  0.9763], # with
     [ 0.1554,  0.1614,  0.2700], # one
     [ 0.8089,  0.9435,  0.5480]] # step
)

# Corresponding words
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

In [21]:
# Reusing the sa_v2 object from the previous section
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

print(attn_weights)

tensor([[0.1694, 0.1562, 0.1660, 0.1839, 0.1634, 0.1611],
        [0.1763, 0.1609, 0.1632, 0.1700, 0.1584, 0.1712],
        [0.1581, 0.1642, 0.1701, 0.1770, 0.1737, 0.1568],
        [0.1646, 0.1593, 0.1677, 0.1814, 0.1677, 0.1592],
        [0.1676, 0.1639, 0.1665, 0.1709, 0.1657, 0.1654],
        [0.1647, 0.1595, 0.1677, 0.1810, 0.1677, 0.1594]],
       grad_fn=<SoftmaxBackward0>)


<div class="alert alert-block alert-success">

As you can see, the attention matrix is fully populated. To enforce causality, we will create a <b>look-ahead mask</b>. This is an upper-triangular matrix that will hide all the positions that correspond to future tokens.

We will apply this mask to the attention scores *before* the softmax step by replacing the scores for future positions with negative infinity (`-inf`). When softmax is applied, $e^{-\infty}$ becomes zero, ensuring no attention is paid to those future tokens.
</div>

In [22]:
context_length = attn_weights.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


<div class="alert alert-block alert-success">
Now, we can multiply this mask with the attention weights to zero out the values above the diagonal:
</div>

In [23]:
masked_simple = attn_weights * mask_simple
print(masked_simple)

tensor([[0.1694, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1763, 0.1609, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.1581, 0.1642, 0.1701, 0.0000, 0.0000, 0.0000],
        [0.1646, 0.1593, 0.1677, 0.1814, 0.0000, 0.0000],
        [0.1676, 0.1639, 0.1665, 0.1709, 0.1657, 0.0000],
        [0.1647, 0.1595, 0.1677, 0.1810, 0.1677, 0.1594]],
       grad_fn=<MulBackward0>)


<div class="alert alert-block alert-info">
As we can see, the elements above the diagonal are successfully zeroed out
</div>

<div class="alert alert-block alert-success">

The third step is to renormalize the attention weights to sum up to 1 again in each row. 

We can achieve this by dividing each element in each row by the sum in each row:
</div>

In [24]:
row_sums = masked_simple.sum(dim=-1, keepdim=True)
masked_simple_norm = masked_simple / row_sums
print(masked_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5229, 0.4771, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3212, 0.3334, 0.3454, 0.0000, 0.0000, 0.0000],
        [0.2446, 0.2367, 0.2492, 0.2695, 0.0000, 0.0000],
        [0.2008, 0.1964, 0.1995, 0.2048, 0.1986, 0.0000],
        [0.1647, 0.1595, 0.1677, 0.1810, 0.1677, 0.1594]],
       grad_fn=<DivBackward0>)


<div class="alert alert-block alert-info"> 
The result is an attention weight matrix where the attention weights above the diagonal are zeroed out and where the rows sum to 1.
</div>

<div class="alert alert-block alert-success">

While we could be technically done with implementing causal attention at this point, we can take advantage of a mathematical property of the softmax function. 

We can implement the computation of the masked attention weights more efficiently in fewer steps.
</div>

<div class="alert alert-block alert-success">

The softmax function converts its inputs into a probability distribution. 

When negative infinity values (-∞) are present in a row, the softmax function treats them as zero probability. 

(Mathematically, this is because e^-∞ approaches 0.)


We can implement this more efficient masking "trick" by creating a mask with 1's above the diagonal and then replacing these 1's with negative infinity (-inf) values:
</div>

In [25]:
print(attn_scores)

tensor([[ 0.0922, -0.0228,  0.0634,  0.2085,  0.0415,  0.0216],
        [ 0.2010,  0.0713,  0.0916,  0.1489,  0.0496,  0.1594],
        [-0.1554, -0.1025, -0.0527,  0.0042, -0.0225, -0.1672],
        [-0.0104, -0.0570,  0.0158,  0.1269,  0.0154, -0.0578],
        [ 0.0275, -0.0039,  0.0178,  0.0549,  0.0114,  0.0091],
        [-0.0108, -0.0557,  0.0150,  0.1227,  0.0148, -0.0567]],
       grad_fn=<MmBackward0>)


In [26]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
attn_scores_masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(attn_scores_masked)

tensor([[ 0.0922,    -inf,    -inf,    -inf,    -inf,    -inf],
        [ 0.2010,  0.0713,    -inf,    -inf,    -inf,    -inf],
        [-0.1554, -0.1025, -0.0527,    -inf,    -inf,    -inf],
        [-0.0104, -0.0570,  0.0158,  0.1269,    -inf,    -inf],
        [ 0.0275, -0.0039,  0.0178,  0.0549,  0.0114,    -inf],
        [-0.0108, -0.0557,  0.0150,  0.1227,  0.0148, -0.0567]],
       grad_fn=<MaskedFillBackward0>)


<div class="alert alert-block alert-success">
Now, all we need to do is apply the softmax function along the last dimension to these masked results to normalize the scores for each query, and we are done.
</div>

In [27]:
attn_weights = torch.softmax(attn_scores_masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5229, 0.4771, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3212, 0.3334, 0.3454, 0.0000, 0.0000, 0.0000],
        [0.2446, 0.2367, 0.2492, 0.2695, 0.0000, 0.0000],
        [0.2008, 0.1964, 0.1995, 0.2048, 0.1986, 0.0000],
        [0.1647, 0.1595, 0.1677, 0.1810, 0.1677, 0.1594]],
       grad_fn=<SoftmaxBackward0>)


<div class="alert alert-block alert-info">
As we can see based on the output, the values in each row sum to 1, and no further normalization is necessary.
</div>

<div class="alert alert-block alert-warning">
Masking in Transformers sets scores for future tokens to a large negative value, making their influence in the softmax calculation effectively zero. 

The softmax function then recalculates attention weights only among the unmasked tokens. 

This process ensures no information leakage from masked tokens, focusing the model solely on the intended data.
</div>

<div class="alert alert-block alert-warning">
We could now use the modified attention weights to compute the context vectors via context_vec = attn_weights @ values.

However, in the next section, we first cover another minor tweak to the causal attention mechanism that is useful for reducing overfitting when training LLMs.
</div>

## 2.5 Masking Attention Weights With Dropout

<div class="alert alert-block alert-success">

A common challenge when training large neural networks is <b>overfitting</b>. This happens when a model learns the training data <i>too well</i> —including its noise and specific quirks— and then fails to generalize to new, unseen examples. To combat overfitting, we use techniques called <b>regularization</b>.

<b>Dropout</b> is one of the most effective and widely used regularization methods in deep learning.

#### How Dropout Works
During each training step, dropout randomly sets a fraction of neuron activations or weights to zero. This forces the network to learn more robust and redundant representations because it cannot become too reliant on any single connection or feature, as it might be "dropped out" at any moment.

<div class="alert alert-block alert-info">
  <b>Analogy: A Team of Experts</b><br>
  Imagine a team of experts solving a problem. If the same experts are always available, they might start relying too heavily on each other's specific knowledge. Dropout is like telling some experts to randomly take a break for each task. The remaining team members are forced to learn how to cover for them and develop a wider range of skills, making the whole team more robust.
</div>

By applying dropout to the attention weights, we prevent the attention mechanism from "overfitting" on specific token-to-token relationships it sees in the training data. It encourages the model to learn more diverse ways of gathering context.

To see how it works in practice, let's first apply dropout to a simple matrix of ones. We'll use a high dropout rate of 50% (`p=0.5`) for this demonstration to make its effect obvious. Note that dropout is **only active during training** and is automatically disabled during evaluation (`model.eval()`).

When we train the GPT model in later chapters, we will use a lower dropout rate, such as 0.1 or 0.2.
</div>

<div class="alert alert-block alert-success">
In the following code, we apply PyTorch's dropout implementation first to a 6×6 tensor consisting of ones for illustration purposes:
</div>

In [28]:
example = torch.ones(6, 6)
print(example)

tensor([[1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.],
        [1., 1., 1., 1., 1., 1.]])


In [29]:
torch.manual_seed(100)
dropout = torch.nn.Dropout(0.5) #A
example = torch.ones(6, 6) #B
print(dropout(example))

tensor([[0., 2., 2., 2., 0., 0.],
        [0., 2., 0., 0., 2., 2.],
        [0., 0., 0., 2., 2., 2.],
        [0., 2., 2., 2., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [2., 0., 2., 2., 0., 0.]])


<div class="alert alert-block alert-info">

When applying dropout to an attention weight matrix with a rate of 50%, half of the elements in the matrix are randomly set to zero. 

To compensate for the reduction in active elements, the values of the remaining elements in the matrix are scaled up by a factor of
1/(1-0.5) =2. 

This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.
</div>

<div class="alert alert-block alert-success">
Now, let's apply dropout to the attention weight matrix itself:
</div>

In [30]:
torch.manual_seed(100)
print(dropout(attn_weights))

tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.9542, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.4733, 0.4984, 0.5391, 0.0000, 0.0000],
        [0.4016, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3293, 0.0000, 0.3354, 0.3620, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


<div class="alert alert-block alert-info">
As we can see above, the resulting attention weight matrix now has additional elements zeroed out and the remaining ones rescaled.
</div>

<div class="alert alert-block alert-warning">

Having gained an understanding of causal attention and dropout masking, we will develop a concise Python class in the following section. 
    
This class is designed to facilitate the efficient application of these two techniques.
</div>

## 2.6 Implementing a Compact Causal Attention Class

<div class="alert alert-block alert-success">

In this section, we will now incorporate the causal attention and dropout modifications into the SelfAttention Python class we developed in section 2.3. 

This class will then serve as a template for developing multi-head attention in the upcoming section.
</div>

<div class="alert alert-block alert-success">

Before we begin, one more thing is to ensure that the code can handle batches consisting of more than one input. 

This will ensure that the CausalAttention class supports the batch outputs produced by the data loader we implemented earlier.
</div>

<div class="alert alert-block alert-success">
For simplicity, to simulate such batch inputs, we duplicate the input text example:
</div>

In [31]:
# Our sample input sentence as embedding vectors
inputs = torch.tensor(
    [[ 0.8938,  0.9003,  0.8978], # Your
     [ 0.7165,  0.3428,  0.2553], # journey
     [ 0.1042,  0.5163,  0.3753], # starts
     [ 0.0445,  0.3091,  0.9763], # with
     [ 0.1554,  0.1614,  0.2700], # one
     [ 0.8089,  0.9435,  0.5480]] # step
)

# Corresponding words
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 

torch.Size([2, 6, 3])


<div class="alert alert-block alert-info">
This results in a 3D tensor consisting of 2 input texts with 6 tokens each, where each token is a 3-dimensional embedding vector.
</div>

<div class="alert alert-block alert-success">
The following CausalAttention class is similar to the SelfAttention class we implemented earlier, except that we now added the dropout and causal mask components as highlighted in the following code.
</div>

<div class="alert alert-block alert-info">

<li><b>Step 1:</b> Compared to the previous SelfAttention_v1 class, we added a dropout layer.</li>
    
<li><b>Step 2:</b> The register_buffer call is also a new addition (more information is provided in the following text).</li>

<li><b>Step 3:</b>  We transpose dimensions 1 and 2, keeping the batch dimension at the first position (0).</li>

<li><b>Step 4:</b> In PyTorch, operations with a trailing underscore are performed in-place, avoiding unnecessary memory copies</li>
</div>

In [32]:
class CausalAttention(nn.Module):
    
    def __init__(self, d_in, d_out, context_length, dropout_rate, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout_rate)
        self.register_buffer('causal_mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch_size, num_tokens, d_im = x.shape
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.causal_mask.bool()[:num_tokens,:num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim = 1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec       
        

<div class="alert alert-block alert-warning">

The use of register_buffer in PyTorch is not strictly necessary for all use cases but offers several advantages here. 

For instance, when we use the CausalAttention class in our LLM, buffers are automatically moved to the appropriate device (CPU or GPU) along with our model, which will be relevant when training the LLM in future chapters. 

This means we don't need to manually ensure these tensors are on the same device as your model parameters, avoiding device mismatch errors.
</div>

<div class="alert alert-block alert-success">
We can use the CausalAttention class as follows, similar to SelfAttention previously:
</div>

In [33]:
torch.manual_seed(100)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)

context_vecs.shape: torch.Size([2, 6, 2])


In [34]:
print(context_vecs)

tensor([[[0.0755, 0.2087],
         [0.1384, 0.3551],
         [0.1526, 0.3997],
         [0.1702, 0.5931],
         [0.2069, 0.7170],
         [0.6627, 1.7936]],

        [[0.0755, 0.2087],
         [0.1384, 0.3551],
         [0.1526, 0.3997],
         [0.1702, 0.5931],
         [0.2069, 0.7170],
         [0.6627, 1.7936]]], grad_fn=<UnsafeViewBackward0>)


<div class="alert alert-block alert-info">
As we can see, the output is a `(batch_size, num_tokens, d_out)` tensor. Each of the original input tokens has now been transformed into a new context-aware embedding vector of size `d_out`.
</div>

<div class="alert alert-block alert-info">
    
  <b>An Interesting Observation: Deterministic Outputs</b><br>
  You might notice that the output context vectors for both items in our batch are identical. This is the expected behavior if the two input sentences in the <code>batch</code> tensor were also identical.
  <br><br>
  This serves as a good sanity check, confirming that our attention module is <b>deterministic</b>. With a fixed set of weights (which we have, since the model is untrained), the same input will always produce the exact same output. The transformation is complex, but it is not random.
</div>

<div class="alert alert-block alert-warning">
Now that we have a complete single-head causal attention module, the next logical step is to scale this concept up. We will now build a `MultiHeadAttention` module, which runs several of these attention mechanisms in parallel to capture an even richer variety of relationships within the text.
</div>

## 2.7 Extending Single-Head Attention to Multi-Head Attention

<div class="alert alert-block alert-success">
    
So far, we have built a single self-attention mechanism, also known as an "attention head." While powerful, a single head can be a bottleneck; it might learn to focus on one type of relationship (for example, the relationship between a subject and its verb), but struggle to capture other kinds of linguistic patterns simultaneously.

To make the model more powerful, the original transformer architecture introduced **Multi-Head Attention**.

<div class="alert alert-block alert-info">
  <b>Analogy: A Committee of Experts</b><br>
  Instead of having one attention mechanism (one 'expert') trying to understand all the different relationships in a sentence, Multi-Head Attention creates a <b>committee of experts</b> (multiple attention 'heads'). 

  Each head works completely independently and in parallel. During training, each head can learn to specialize in detecting a different type of feature. For example:
  <ul>
    <li>One head might learn to track syntactic dependencies.</li>
    <li>Another might learn to track semantic similarity.</li>
    <li>A third might learn to track the proximity of words.</li>
  </ul>
  The outputs of all these specialist heads are then concatenated to form a final, rich representation of the token's context.
</div>

In our example below, we will create a `MultiHeadAttention` module with **two heads**. Each head will produce a **2-dimensional** context vector (`d_out=2`). After concatenating the outputs of both heads, we will get a final **4-dimensional** context vector for each token (`2 heads * 2 dimensions/head = 4 dimensions`).
</div>

<div class="alert alert-block alert-success">
In code, we can achieve this by implementing a simple MultiHeadAttentionWrapper class that stacks multiple instances of our previously implemented CausalAttention module:    
</div>

In [35]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

<div class="alert alert-block alert-success">
To illustrate further with a concrete example, we can use the MultiHeadAttentionWrapper class similar to the CausalAttention class before:
</div>

In [36]:
# Our sample input sentence as embedding vectors
inputs = torch.tensor(
    [[ 0.8938,  0.9003,  0.8978], # Your
     [ 0.7165,  0.3428,  0.2553], # journey
     [ 0.1042,  0.5163,  0.3753], # starts
     [ 0.0445,  0.3091,  0.9763], # with
     [ 0.1554,  0.1614,  0.2700], # one
     [ 0.8089,  0.9435,  0.5480]] # step
)

# Corresponding words
words = ['Your', 'journey', 'starts', 'with', 'one', 'step']

batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 

torch.Size([2, 6, 3])


In [37]:
torch.manual_seed(100)
context_length =  batch.shape[1] # This is the number of tokens = 6
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.0755,  0.2087,  0.0093, -0.1200],
         [ 0.1384,  0.3551,  0.0133, -0.2169],
         [ 0.1526,  0.3997,  0.0085, -0.3357],
         [ 0.1702,  0.5931,  0.0877, -0.4786],
         [ 0.2069,  0.7170,  0.1177, -0.6183],
         [ 0.6627,  1.7936,  0.0402, -1.3122]],

        [[ 0.0755,  0.2087,  0.0093, -0.1200],
         [ 0.1384,  0.3551,  0.0133, -0.2169],
         [ 0.1526,  0.3997,  0.0085, -0.3357],
         [ 0.1702,  0.5931,  0.0877, -0.4786],
         [ 0.2069,  0.7170,  0.1177, -0.6183],
         [ 0.6627,  1.7936,  0.0402, -1.3122]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


<div class="alert alert-block alert-info">

The first dimension of the resulting context_vecs tensor is 2 since we have two input texts (the input texts are duplicated, which is why the context vectors are exactly the same for those). 

The second dimension refers to the 6 tokens in each input. The third dimension refers to the 4-dimensional embedding of each token.
    
</div>

<div class="alert alert-block alert-success">
    
In this section, we implemented a MultiHeadAttentionWrapper that combined multiple single-head attention modules. 

However, note that these are processed sequentially via [head(x) for head in self.heads] in the forward method. 

We can improve this implementation by processing the heads in parallel. 

One way to achieve this is by computing the outputs for all attention heads simultaneously via matrix multiplication, as we will explore in the next section.
</div>

## 2.8 Implementing Multi-Head Attention With Weight Splits

<div class="alert alert-block alert-success">
    
Instead of maintaining two separate classes, MultiHeadAttentionWrapper and CausalAttention, we can combine both of these concepts into a single MultiHeadAttention class. 

Also, in addition to just merging the MultiHeadAttentionWrapper with the CausalAttention code, we will make some other modifications to implement multi-head attention more efficiently.
</div>

<div class="alert alert-block alert-success">
    
In the MultiHeadAttentionWrapper, multiple heads are implemented by creating a list of CausalAttention objects (self.heads), each representing a separate attention head.

The CausalAttention class independently performs the attention mechanism, and the results from each head are concatenated.

In contrast, the following MultiHeadAttention class integrates the multi-head functionality within a single class. 

It splits the input into multiple heads by reshaping the projected query, key, and value tensors and then combines the results from these heads after computing attention.
</div>

<div class="alert alert-block alert-success">
Let's take a look at the MultiHeadAttention class before we discuss it further:
</div>

In [42]:
class MultiHeadAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired ouput dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out) # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('causal_mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape

        # Project inputs into Q, K, V
        queries = self.W_query(x)
        keys = self.W_key(x)
        values = self.W_value(x)

        # Reshape for multi-head attention by splitting the d_out dimension
        # (batch_size, num_tokens, d_out) -> (batch_size, num_tokens, num_heads, head_dim)
        queries = queries.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        keys = keys.view(batch_size, num_tokens, self.num_heads, self.head_dim)
        values = values.view(batch_size, num_tokens, self.num_heads, self.head_dim)

        # Transpose to bring the 'num_heads' dimension forward for batch matrix multiplication
        # (batch_size, num_tokens, num_heads, head_dim) -> (batch_size, num_heads, num_tokens, head_dim)
        queries = queries.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention 
        attn_scores = queries @ keys.transpose(-2, -1) # Dot product for each head

        # Apply the causal mask
        causal_mask_bool = self.causal_mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill(causal_mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Apply attention weights to values and reverse the transpose
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads back into a single tensor: self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(batch_size, num_tokens, self.d_out)

        ## Apply the final linear projection
        context_vec = self.out_proj(context_vec) # Optional projection

        return context_vec
        

<div class="alert alert-block alert-info">

Instead of creating separate `CausalAttention` objects for each head (which is inefficient), a more common and optimized approach is to use one large linear layer for all queries, keys, and values, and then **"split"** the resulting matrix into multiple heads for parallel processing. This is achieved through clever tensor reshaping.

The `forward` pass of this new `MultiHeadAttention` class follows these key steps:

1.  **Project into Q, K, V:** The input `x` is first passed through the three large `W_query`, `W_key`, and `W_value` linear layers to get the initial `queries`, `keys`, and `values` tensors of shape `(batch_size, num_tokens, d_out)`.

2.  **Split into Heads:** We then reshape each of these tensors to add the `num_heads` dimension. The `.view()` method unrolls the `d_out` dimension into `num_heads` and `head_dim`. The new shape is `(batch_size, num_tokens, num_heads, head_dim)`.

3.  **Transpose for Calculation:** To perform matrix multiplication across all heads at once, we need the `num_heads` dimension to be second. We use `.transpose(1, 2)` to swap the `num_tokens` and `num_heads` dimensions, resulting in a shape of `(batch_size, num_heads, num_tokens, head_dim)`.

4.  **Compute Attention:** With the tensors in the correct shape, we can now compute the scaled dot-product attention just like before. This calculation happens independently for all heads in parallel.

5.  **Combine Heads:** Finally, we reverse the process. We transpose the dimensions back and then use `.view()` to flatten the `num_heads` and `head_dim` dimensions back into the single `d_out` dimension. This concatenated result is then passed through a final linear projection layer (`out_proj`).
</div>