# Coding Attention Mechanisms

This chapter covers
- Exploring the reasons for using attention mechanisms in neural networks
- Introducing a basic self-attention framework and progressing to an enhanced self-attention mechanism
- Implementing a causal attention module that allows LLMs to generate one token at a time
- Masking randomly selected attention weights with dropout to reduce overfitting
- Stacking multiple causal attention modules into a multi-head attention module

## 3.3 Attending to different parts of the input with self-attention

In this section, we implement a simplified variant of self-attention, free from any trainable weights.

In self-attention, our goal is to calculate context vectors $z^{(i)}$ for each element $x^{(i)}$ in the input sequence. A context vector can be interpreted as an enriched embedding vector.

First, let us calculate the attention scores for a single input token. In this simplified variant, the input embeddings themselves act as both queries and keys, so each score is the dot product between the query token's embedding and the embedding of every token in the sequence.

In [17]:
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

query = inputs[1]  # Taking the 2nd input token as the query vector

attn_scores_2 = query @ inputs.T
print(attn_scores_2)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])


In [18]:
print(query @ inputs.T)

tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
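
The matrix expression above packs six dot products into one call. As a minimal sketch (the name `scores_loop` is illustrative and not part of the original notebook), the same scores can be computed one at a time:

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)
query = inputs[1]

# Compute each attention score as an explicit dot product with the query
scores_loop = torch.empty(inputs.shape[0])
for i, x_i in enumerate(inputs):
    scores_loop[i] = torch.dot(x_i, query)

# The loop and the matrix form should agree up to floating-point error
print(torch.allclose(scores_loop, query @ inputs.T))
```

The matrix form is usually preferred in practice because it avoids a Python-level loop.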


Normalize the attention scores using softmax to get the attention weights.

In [19]:
attn_weights_2 = torch.softmax(attn_scores_2, dim=0)
print(attn_weights_2)

tensor([0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581])
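
To make the normalization step concrete, here is a minimal sketch of softmax written by hand. The helper name `softmax_naive` is an assumption for illustration, and the scores are copied from the printed output above.

```python
import torch

def softmax_naive(x):
    # Exponentiate, then divide by the sum so the entries add up to 1
    return torch.exp(x) / torch.exp(x).sum(dim=0)

scores = torch.tensor([0.9544, 1.4950, 1.4754, 0.8434, 0.7070, 1.0865])
weights = softmax_naive(scores)

print(weights)
print(weights.sum())
```

For real workloads, `torch.softmax` is the safer choice, since a hand-written version like this one can overflow for large inputs.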


Now, we compute the context vector $z^{(2)}$ by multiplying each input token embedding by its attention weight and summing the results.

In [20]:
context_vector_2 = torch.zeros(inputs[1].shape)

for i, x_i in enumerate(inputs):
    context_vector_2 += attn_weights_2[i] * x_i

print(context_vector_2)

tensor([0.4419, 0.6515, 0.5683])
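
The weighted sum in the loop above is the same operation as a vector-matrix product. A short sketch, assuming the same `inputs` tensor as before (the names `z2_loop` and `z2_matmul` are illustrative):

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)
attn_weights_2 = torch.softmax(inputs[1] @ inputs.T, dim=0)

# Explicit weighted sum over the token embeddings
z2_loop = torch.zeros(inputs.shape[1])
for w, x in zip(attn_weights_2, inputs):
    z2_loop += w * x

# Same computation expressed as a vector-matrix product
z2_matmul = attn_weights_2 @ inputs

print(torch.allclose(z2_loop, z2_matmul))
```

This equivalence is what allows the computation to be generalized to all tokens with a single matrix multiplication.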


Next, we extend this computation to calculate attention weights and context vectors for all inputs.

In [24]:
attn_scores = inputs @ inputs.T
attn_weights = torch.softmax(attn_scores, dim=1)
print(attn_weights)

tensor([[0.2098, 0.2006, 0.1981, 0.1242, 0.1220, 0.1452],
        [0.1385, 0.2379, 0.2333, 0.1240, 0.1082, 0.1581],
        [0.1390, 0.2369, 0.2326, 0.1242, 0.1108, 0.1565],
        [0.1435, 0.2074, 0.2046, 0.1462, 0.1263, 0.1720],
        [0.1526, 0.1958, 0.1975, 0.1367, 0.1879, 0.1295],
        [0.1385, 0.2184, 0.2128, 0.1420, 0.0988, 0.1896]])


In [23]:
context_vecs = attn_weights @ inputs
print(context_vecs)

tensor([[0.4421, 0.5931, 0.5790],
        [0.4419, 0.6515, 0.5683],
        [0.4431, 0.6496, 0.5671],
        [0.4304, 0.6298, 0.5510],
        [0.4671, 0.5910, 0.5266],
        [0.4177, 0.6503, 0.5645]])
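
A quick sanity check helps clarify what `dim=1` does in the softmax call above: it normalizes each row, so every token's attention weights sum to 1, while the columns generally do not. The sketch below assumes the same `inputs` tensor; the names `row_sums` and `col_sums` are illustrative.

```python
import torch

inputs = torch.tensor(
  [[0.43, 0.15, 0.89],
   [0.55, 0.87, 0.66],
   [0.57, 0.85, 0.64],
   [0.22, 0.58, 0.33],
   [0.77, 0.25, 0.10],
   [0.05, 0.80, 0.55]]
)

attn_weights = torch.softmax(inputs @ inputs.T, dim=1)

# Each row holds the weights one token assigns to all tokens
row_sums = attn_weights.sum(dim=1)
col_sums = attn_weights.sum(dim=0)
print("Row sums:", row_sums)
print("Column sums:", col_sums)

# One matrix multiplication yields a context vector per token
context_vecs = attn_weights @ inputs
print(context_vecs.shape)
```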


## 3.4 Attention mechanism with trainable weights

We will implement the self-attention mechanism step by step by introducing the three trainable weight matrices $W_q$, $W_k$, and $W_v$. These three matrices are used to project the embedded input tokens, $x^{(i)}$, into query, key, and value vectors. The query vector $q^{(2)}$ is obtained via matrix multiplication between the input $x^{(2)}$ and the weight matrix $W_q$. Similarly, we obtain the key and value vectors via matrix multiplication involving the weight matrices $W_k$ and $W_v$.

In [26]:
d_in = d_out = inputs.shape[1]

torch.manual_seed(123)
W_query = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_key   = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)
W_value = torch.nn.Parameter(torch.rand(d_in, d_out), requires_grad=False)

# Let us calculate the context vector for the 2nd input token 
query_2 = inputs[1] @ W_query
keys = inputs @ W_key # Key vectors for all input tokens
values = inputs @ W_value # Value vectors for all input tokens

attn_scores_2 = query_2 @ keys.T
attn_weights_2 = torch.softmax(attn_scores_2 / (d_out**0.5), dim=0)
context_vector_2 = attn_weights_2 @ values
print(context_vector_2)

tensor([0.6864, 1.0577, 1.1389])
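
The cell above sets `d_out` equal to `d_in`, which hides the fact that the projections can change the embedding size. The following sketch uses `d_out = 2` and random stand-in embeddings purely for illustration (both are assumptions, not values from the notebook) to show the resulting shapes:

```python
import torch

torch.manual_seed(0)
d_in, d_out = 3, 2            # d_out = 2 is an illustrative choice
x = torch.rand(6, d_in)       # stand-in for six token embeddings

W_query = torch.rand(d_in, d_out)
W_key   = torch.rand(d_in, d_out)
W_value = torch.rand(d_in, d_out)

query_2 = x[1] @ W_query      # shape: (d_out,)
keys    = x @ W_key           # shape: (6, d_out)
values  = x @ W_value         # shape: (6, d_out)

# Scaled dot-product attention for the second token
attn_weights_2 = torch.softmax(query_2 @ keys.T / d_out**0.5, dim=0)
context_vector_2 = attn_weights_2 @ values

print(query_2.shape, keys.shape, values.shape, context_vector_2.shape)
```

The context vector lives in the `d_out`-dimensional value space rather than in the original input space.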


Now, we generate context vectors for all of the input tokens by implementing a compact self-attention Python class.

In [28]:
import torch.nn as nn

class SelfAttention_v1(nn.Module):

    def __init__(self, d_in, d_out):
        super().__init__()
        self.W_query = nn.Parameter(torch.rand(d_in, d_out))
        self.W_key   = nn.Parameter(torch.rand(d_in, d_out))
        self.W_value = nn.Parameter(torch.rand(d_in, d_out))

    def forward(self, x):
        keys = x @ self.W_key
        queries = x @ self.W_query
        values = x @ self.W_value
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)
sa_v1 = SelfAttention_v1(d_in, d_out)
print(sa_v1(inputs))

tensor([[0.6692, 1.0276, 1.1106],
        [0.6864, 1.0577, 1.1389],
        [0.6860, 1.0570, 1.1383],
        [0.6738, 1.0361, 1.1180],
        [0.6711, 1.0307, 1.1139],
        [0.6783, 1.0441, 1.1252]], grad_fn=<MmBackward0>)


We can streamline the implementation above using PyTorch's Linear layers, which are equivalent to a matrix multiplication if we disable the bias units.

Another big advantage of using nn.Linear over our manual nn.Parameter(torch.rand(...)) approach is that nn.Linear has a preferred weight initialization scheme, which leads to more stable model training.

In [29]:
class SelfAttention_v2(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        
        attn_scores = queries @ keys.T
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[ 0.4162, -0.4953,  0.0470],
        [ 0.4154, -0.4957,  0.0476],
        [ 0.4155, -0.4957,  0.0476],
        [ 0.4173, -0.5006,  0.0483],
        [ 0.4178, -0.4996,  0.0477],
        [ 0.4166, -0.4996,  0.0483]], grad_fn=<MmBackward0>)
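
The claim that a bias-free `nn.Linear` is equivalent to a matrix multiplication can be checked directly. This is a small sketch with arbitrary sizes and random data chosen for illustration. Note that `nn.Linear` stores its weight with shape `(out_features, in_features)`, so the transpose is needed.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
linear = nn.Linear(3, 2, bias=False)
x = torch.rand(6, 3)

# The layer computes x @ W^T, where W has shape (out_features, in_features)
out_linear = linear(x)
out_matmul = x @ linear.weight.T

print(linear.weight.shape)
print(torch.allclose(out_linear, out_matmul, atol=1e-6))
```

This transposed storage is also why the outputs of SelfAttention_v1 and SelfAttention_v2 differ: besides the different initialization, the weights would need to be transposed when copied from one class to the other.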


## 3.5 Hiding future words with causal attention

The causal aspect involves modifying the attention mechanism to prevent the model from accessing future information in the sequence, which is crucial for tasks like language modeling, where each word prediction should only depend on previous words.

In [34]:
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs) 
attn_scores = queries @ keys.T

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1607, 0.1599, 0.1613, 0.1679, 0.1971, 0.1531],
        [0.1606, 0.1579, 0.1590, 0.1721, 0.1907, 0.1597],
        [0.1606, 0.1580, 0.1591, 0.1720, 0.1907, 0.1595],
        [0.1637, 0.1618, 0.1623, 0.1704, 0.1767, 0.1650],
        [0.1631, 0.1629, 0.1638, 0.1676, 0.1836, 0.1590],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<SoftmaxBackward0>)


In [None]:
context_length = attn_scores.shape[0]
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[-0.2382,    -inf,    -inf,    -inf,    -inf,    -inf],
        [-0.3270, -0.3568,    -inf,    -inf,    -inf,    -inf],
        [-0.3218, -0.3506, -0.3384,    -inf,    -inf,    -inf],
        [-0.1822, -0.2026, -0.1975, -0.1128,    -inf,    -inf],
        [-0.1385, -0.1401, -0.1312, -0.0915,  0.0668,    -inf],
        [-0.2439, -0.2754, -0.2699, -0.1499, -0.0948, -0.2134]],
       grad_fn=<MaskedFillBackward0>)


In [36]:
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5043, 0.4957, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3362, 0.3307, 0.3330, 0.0000, 0.0000, 0.0000],
        [0.2487, 0.2458, 0.2465, 0.2589, 0.0000, 0.0000],
        [0.1939, 0.1937, 0.1947, 0.1993, 0.2183, 0.0000],
        [0.1631, 0.1602, 0.1607, 0.1722, 0.1778, 0.1660]],
       grad_fn=<SoftmaxBackward0>)
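
Filling future positions with negative infinity before the softmax is one way to obtain causal attention weights. As a hedged sketch, an alternative is to apply softmax first, zero out the future positions, and renormalize each row. Both approaches should give the same result. The random `scores` tensor below is a placeholder for already scaled attention scores.

```python
import torch

torch.manual_seed(0)
n = 4
scores = torch.rand(n, n)  # placeholder attention scores

# Approach 1: mask future positions with -inf, then apply softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
weights_inf = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

# Approach 2: softmax first, zero out future positions, renormalize rows
weights_full = torch.softmax(scores, dim=-1)
kept = weights_full * torch.tril(torch.ones(n, n))
weights_renorm = kept / kept.sum(dim=-1, keepdim=True)

print(torch.allclose(weights_inf, weights_renorm, atol=1e-6))
```

The negative-infinity version needs only one normalization pass, which is why it is the more common choice.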


Dropout is a regularization technique in deep learning where random neurons are temporarily "dropped" (ignored) during training. This prevents the model from becoming too reliant on specific neurons, helping reduce overfitting. Dropout is only applied during training, not during inference. 

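With a dropout rate of 50%, roughly half of the attention weights are set to zero, and the remaining weights are scaled up by a factor of 1/(1 - 0.5) = 2 to compensate. This scaling is crucial to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both the training and inference phases.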

In [40]:
dropout = torch.nn.Dropout(0.5)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [1.0086, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.6725, 0.0000, 0.6661, 0.0000, 0.0000, 0.0000],
        [0.4975, 0.4916, 0.0000, 0.5178, 0.0000, 0.0000],
        [0.3879, 0.3875, 0.3895, 0.0000, 0.4367, 0.0000],
        [0.0000, 0.0000, 0.3214, 0.0000, 0.3556, 0.3320]],
       grad_fn=<MulBackward0>)
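
The rescaling is easiest to see on a matrix of ones. In this minimal sketch (the tensor `example` is illustrative), every surviving entry becomes 1 / (1 - p) = 2 in training mode, and switching the module to evaluation mode turns dropout into a no-op:

```python
import torch

torch.manual_seed(0)
p = 0.5
dropout = torch.nn.Dropout(p)
example = torch.ones(6, 6)

# A freshly created module is in training mode, so dropout is active:
# entries are either zeroed or scaled by 1 / (1 - p)
dropped = dropout(example)
print(dropped)

# In evaluation mode the input passes through unchanged
dropout.eval()
unchanged = dropout(example)
print(torch.equal(unchanged, example))
```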


Now, let us implement a compact causal attention class.

In [41]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, qkv_bias=False):
        super().__init__()
        self.d_in = d_in
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        batch_size, num_tokens, d_in = x.shape  # Use a local name; do not overwrite self.d_in

        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(self.mask.bool()[:num_tokens, :num_tokens], -torch.inf) 
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

torch.manual_seed(123)

batch = torch.stack((inputs, inputs), dim=0) # Imitating the batches produced by our Dataloader in Chapter 2.ipynb
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.3326,  0.5659, -0.3132],
         [ 0.3456,  0.5650, -0.2237],
         [ 0.3440,  0.5604, -0.2000],
         [ 0.3103,  0.4941, -0.1606],
         [ 0.2430,  0.4287, -0.1643],
         [ 0.2648,  0.4316, -0.1375]],

        [[ 0.3326,  0.5659, -0.3132],
         [ 0.3456,  0.5650, -0.2237],
         [ 0.3440,  0.5604, -0.2000],
         [ 0.3103,  0.4941, -0.1606],
         [ 0.2430,  0.4287, -0.1643],
         [ 0.2648,  0.4316, -0.1375]]], grad_fn=<UnsafeViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 3])
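
One way to confirm that the masking really hides future tokens is to change the last token and check that the context vectors of all earlier positions stay the same. The sketch below uses a simplified functional version of causal attention (the function `causal_attention`, the random weights, and the random inputs are assumptions for illustration, without dropout or batching):

```python
import torch

def causal_attention(x, W_q, W_k, W_v):
    # x: (num_tokens, d_in); weight matrices: (d_in, d_out)
    queries, keys, values = x @ W_q, x @ W_k, x @ W_v
    scores = queries @ keys.T
    n = x.shape[0]
    future = torch.triu(torch.ones(n, n), diagonal=1).bool()
    scores = scores.masked_fill(future, -torch.inf)
    weights = torch.softmax(scores / keys.shape[-1]**0.5, dim=-1)
    return weights @ values

torch.manual_seed(0)
x = torch.rand(6, 3)
W_q, W_k, W_v = torch.rand(3, 3), torch.rand(3, 3), torch.rand(3, 3)

out = causal_attention(x, W_q, W_k, W_v)

# Replace the last token and recompute
x_changed = x.clone()
x_changed[-1] = torch.rand(3)
out_changed = causal_attention(x_changed, W_q, W_k, W_v)

# Positions before the changed token should be unaffected
print(torch.allclose(out[:-1], out_changed[:-1]))
```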


## 3.6 Extending single-head attention to multi-head attention

The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections. This allows the model to jointly attend to information from different representation subspaces at different positions.

In [43]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList( [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias) for _ in range(num_heads)] )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)


torch.manual_seed(123)

context_length = batch.shape[1]
d_in, d_out = 3, 3
mha = MultiHeadAttentionWrapper(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[ 0.3326,  0.5659, -0.3132,  0.0752,  0.4566,  0.2729],
         [ 0.3456,  0.5650, -0.2237,  0.0313,  0.5977,  0.3053],
         [ 0.3440,  0.5604, -0.2000,  0.0178,  0.6413,  0.3138],
         [ 0.3103,  0.4941, -0.1606,  0.0089,  0.5729,  0.2785],
         [ 0.2430,  0.4287, -0.1643,  0.0071,  0.5566,  0.2514],
         [ 0.2648,  0.4316, -0.1375,  0.0023,  0.5363,  0.2508]],

        [[ 0.3326,  0.5659, -0.3132,  0.0752,  0.4566,  0.2729],
         [ 0.3456,  0.5650, -0.2237,  0.0313,  0.5977,  0.3053],
         [ 0.3440,  0.5604, -0.2000,  0.0178,  0.6413,  0.3138],
         [ 0.3103,  0.4941, -0.1606,  0.0089,  0.5729,  0.2785],
         [ 0.2430,  0.4287, -0.1643,  0.0071,  0.5566,  0.2514],
         [ 0.2648,  0.4316, -0.1375,  0.0023,  0.5363,  0.2508]]],
       grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 6])
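
The final dimension of 6 comes from concatenating the outputs of two heads with `d_out = 3` each. A small sketch with random stand-in tensors (not actual attention outputs) illustrates the shape arithmetic:

```python
import torch

torch.manual_seed(0)
num_heads, d_out = 2, 3

# Stand-ins for per-head outputs of shape (batch, num_tokens, d_out)
head_outputs = [torch.rand(2, 6, d_out) for _ in range(num_heads)]

# Concatenating along the last dimension gives num_heads * d_out features
combined = torch.cat(head_outputs, dim=-1)
print(combined.shape)
```

Each head's output occupies its own contiguous slice of the last dimension, so the first three features belong to the first head and the last three to the second.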


While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head CausalAttention implementation from earlier), we can write a stand-alone class called MultiHeadAttention to achieve the same result.

Instead of keeping a separate set of weight matrices for each head, we create single W_query, W_key, and W_value weight matrices and then split the projected queries, keys, and values into separate tensors for each attention head:

In [44]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Optional Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim) 
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)
        
        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2) 
        
        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec

torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])
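
The `view` and `transpose` calls in the forward method can be hard to follow. Here is a minimal sketch on a tiny tensor with recognizable entries (all sizes are illustrative assumptions) that traces how the last dimension is split into heads and merged back:

```python
import torch

b, num_tokens, num_heads, head_dim = 1, 3, 2, 2
d_out = num_heads * head_dim

# Tensor with recognizable entries, shape (b, num_tokens, d_out)
x = torch.arange(b * num_tokens * d_out, dtype=torch.float32)
x = x.view(b, num_tokens, d_out)

# Split the last dimension into heads, then move the head axis forward:
# (b, num_tokens, d_out) -> (b, num_heads, num_tokens, head_dim)
heads = x.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)
print(heads.shape)

# Head 0 holds the first head_dim features of every token
print(torch.equal(heads[0, 0], x[0, :, :head_dim]))

# Batched matmul works on the last two dims, separately per batch and head
scores = heads @ heads.transpose(2, 3)
print(scores.shape)

# Merging reverses the split; contiguous() is needed before view
merged = heads.transpose(1, 2).contiguous().view(b, num_tokens, d_out)
print(torch.equal(merged, x))
```

Because all heads are processed in one batched matrix multiplication, this layout avoids the Python loop over heads used in MultiHeadAttentionWrapper.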
