### 3. Hiding future words with causal attention

When predicting the next token in a sequence, we usually want the self-attention mechanism to consider only the tokens at or before the current position. Causal (or masked) attention enforces this by restricting the model to previous and current inputs when computing attention scores. This is in contrast to what we saw previously, where every position could attend to the entire input.

So, we need to implement a causal attention mask.

In [1]:
import warnings
warnings.filterwarnings('ignore')

import torch
import torch.nn as nn

In [25]:
from importlib.metadata import version

print("torch version:", version("torch"))

torch version: 2.5.1


In [7]:
# Input sequence
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your      (x^1)
     [0.55, 0.87, 0.66], # journey   (x^2)
     [0.57, 0.85, 0.64], # starts    (x^3)
     [0.22, 0.58, 0.33], # with      (x^4)
     [0.77, 0.25, 0.10], # one       (x^5)
     [0.05, 0.80, 0.55]  # step      (x^6)
    ]
)

d_in = 3
d_out = 2

In [9]:
# Same class as before
class SelfAttention_v2(nn.Module):
    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)
        attn_scores = queries @ keys.T 
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        context_vec = attn_weights @ values
        return context_vec

In [10]:
# Different outputs due to different initial weights
torch.manual_seed(789)
sa_v2 = SelfAttention_v2(d_in, d_out)
print(sa_v2(inputs))

tensor([[-0.0739,  0.0713],
        [-0.0748,  0.0703],
        [-0.0749,  0.0702],
        [-0.0760,  0.0685],
        [-0.0763,  0.0679],
        [-0.0754,  0.0693]], grad_fn=<MmBackward0>)


In [7]:
# Re-using the same weight matrices as previously for convenience
queries = sa_v2.W_query(inputs)
keys = sa_v2.W_key(inputs)
attn_scores = queries @ keys.T
attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[0.1921, 0.1646, 0.1652, 0.1550, 0.1721, 0.1510],
        [0.2041, 0.1659, 0.1662, 0.1496, 0.1665, 0.1477],
        [0.2036, 0.1659, 0.1662, 0.1498, 0.1664, 0.1480],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.1661, 0.1564],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.1585],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)


In [8]:
# We use PyTorch's tril function to create a mask with zeros above the diagonal
context_length = attn_scores.shape[0]
mask_simple = torch.tril(torch.ones(context_length, context_length))
print(mask_simple)

tensor([[1., 0., 0., 0., 0., 0.],
        [1., 1., 0., 0., 0., 0.],
        [1., 1., 1., 0., 0., 0.],
        [1., 1., 1., 1., 0., 0.],
        [1., 1., 1., 1., 1., 0.],
        [1., 1., 1., 1., 1., 1.]])


In [9]:
# Now multiply the mask with the attention weights 
mask_simple = attn_weights * mask_simple
print(mask_simple)

tensor([[0.1921, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2041, 0.1659, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.2036, 0.1659, 0.1662, 0.0000, 0.0000, 0.0000],
        [0.1869, 0.1667, 0.1668, 0.1571, 0.0000, 0.0000],
        [0.1830, 0.1669, 0.1670, 0.1588, 0.1658, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<MulBackward0>)


In [10]:
# Renormalise the attention weights to sum up to 1 again
row_sums = mask_simple.sum(dim=-1, keepdim=True)
mask_simple_norm = mask_simple / row_sums
print(mask_simple_norm)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<DivBackward0>)


We can make the above more efficient. The softmax function converts its inputs into a probability distribution; any negative infinity values in a row are assigned a probability of 0. So, instead of zeroing out weights and renormalising, we can fill the attention scores above the diagonal with -infinity before applying softmax.
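
To see why negative infinity works as a mask, here is a minimal sketch with made-up scores for a single query (the values are illustrative and not taken from the example above). The last two positions stand in for future tokens.

```python
import torch

# Hypothetical scores for one query; the last two positions are "future" tokens
scores = torch.tensor([0.5, 1.0, -torch.inf, -torch.inf])
weights = torch.softmax(scores, dim=-1)

print(weights)        # masked positions receive exactly zero weight
print(weights.sum())  # the remaining weights still sum to 1
```

Because the masked positions contribute nothing to the softmax denominator, no separate renormalisation step is needed.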

In [11]:
mask = torch.triu(torch.ones(context_length, context_length), diagonal=1)
masked = attn_scores.masked_fill(mask.bool(), -torch.inf)
print(masked)

tensor([[0.2899,   -inf,   -inf,   -inf,   -inf,   -inf],
        [0.4656, 0.1723,   -inf,   -inf,   -inf,   -inf],
        [0.4594, 0.1703, 0.1731,   -inf,   -inf,   -inf],
        [0.2642, 0.1024, 0.1036, 0.0186,   -inf,   -inf],
        [0.2183, 0.0874, 0.0882, 0.0177, 0.0786,   -inf],
        [0.3408, 0.1270, 0.1290, 0.0198, 0.1290, 0.0078]],
       grad_fn=<MaskedFillBackward0>)


In [12]:
# Softmax normalises each row over the unmasked positions only
attn_weights = torch.softmax(masked / keys.shape[-1]**0.5, dim=-1)
print(attn_weights)

tensor([[1.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4483, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.3800, 0.3097, 0.3103, 0.0000, 0.0000, 0.0000],
        [0.2758, 0.2460, 0.2462, 0.2319, 0.0000, 0.0000],
        [0.2175, 0.1983, 0.1984, 0.1888, 0.1971, 0.0000],
        [0.1935, 0.1663, 0.1666, 0.1542, 0.1666, 0.1529]],
       grad_fn=<SoftmaxBackward0>)
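
The two masking strategies should agree. The following sketch checks this on random scores; the tensor `scores` and its size are assumptions for illustration, and the scaling by the square root of the key dimension is omitted for brevity.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(4, 4)  # hypothetical attention scores
n = scores.shape[0]

# Approach 1: softmax, zero out future positions, renormalise each row
w = torch.softmax(scores, dim=-1) * torch.tril(torch.ones(n, n))
renormalised = w / w.sum(dim=-1, keepdim=True)

# Approach 2: fill future positions with -inf, then apply a single softmax
future = torch.triu(torch.ones(n, n), diagonal=1).bool()
masked_softmax = torch.softmax(scores.masked_fill(future, -torch.inf), dim=-1)

# Compare the two results up to floating-point tolerance
print(torch.allclose(renormalised, masked_softmax))
```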


### Masking additional attention weights with dropout

Dropout is a technique where randomly selected hidden layer units are ignored during training. This helps prevent overfitting by ensuring that a model does not become overly reliant on any specific set of hidden layer units. It is only used during training and then disabled after. 
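
As a small illustration of the training-only behaviour, the sketch below toggles a standalone `nn.Dropout` layer between training and evaluation mode. The tensor `x` and the rate of 0.5 are arbitrary choices for the demonstration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
drop = nn.Dropout(0.5)
x = torch.ones(2, 4)

drop.train()           # training mode: elements are randomly zeroed
train_out = drop(x)

drop.eval()            # evaluation mode: dropout passes the input through unchanged
eval_out = drop(x)

print(train_out)
print(eval_out)
```

Calling `model.eval()` on a full model switches every dropout layer inside it in the same way.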

In Transformer architectures, dropout is typically applied at one of two points: after calculating the attention weights or after applying the attention weights to the value vectors. We will do the former as it is more common. We will use a dropout rate of 50% for illustration (when training the model we will use a lower rate of 0.1 to 0.2).

In [None]:
# As an example - half the values are zeroed out
torch.manual_seed(123)
dropout = nn.Dropout(0.5)
example = torch.ones(6, 6)
print(dropout(example))

tensor([[2., 2., 2., 2., 2., 2.],
        [0., 2., 0., 0., 0., 0.],
        [0., 0., 2., 0., 2., 0.],
        [2., 2., 0., 0., 0., 2.],
        [2., 0., 0., 0., 0., 2.],
        [0., 2., 0., 0., 0., 0.]])


To compensate for the reduction in active elements, the remaining elements in the matrix are scaled up by a factor of 1 / (1 - dropout rate), in this case 1 / (1 - 0.5) = 2. This is done to maintain the overall balance of the attention weights, ensuring that the average influence of the attention mechanism remains consistent during both training and inference.
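
PyTorch's `nn.Dropout` rescales the surviving elements by 1 / (1 - p), which coincides with 1 / p only at p = 0.5. A quick sketch with a different, hypothetical rate of 0.2 makes the scaling factor easier to see.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
p = 0.2                        # hypothetical rate, chosen so that 1/p differs from 1/(1-p)
drop = nn.Dropout(p)           # modules start in training mode
out = drop(torch.ones(1000, 1000))

print(out.max())               # surviving elements equal 1 / (1 - p) = 1.25
print(out.mean())              # the mean stays close to 1
```

Keeping the expected value unchanged is what allows dropout to be switched off at inference time without rescaling anything.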

In [15]:
torch.manual_seed(123)
print(dropout(attn_weights))

tensor([[2.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.8966, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.0000, 0.6206, 0.0000, 0.0000, 0.0000],
        [0.5517, 0.4921, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.4350, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
        [0.0000, 0.3327, 0.0000, 0.0000, 0.0000, 0.0000]],
       grad_fn=<MulBackward0>)


We also want to ensure we can handle batches of more than one input. To simulate such inputs, we duplicate the input text example.

In [None]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) 
# (2, 6, 3) -> (batch_size, num_tokens, dimension)
# 3-d tensor containing 2 input texts with 6 tokens each & each token is a 3-d embedding vector

torch.Size([2, 6, 3])


We can modify the self-attention class to add the causal mask, dropout, and batch support described above.

In [None]:
class CausalAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout) # new dropout layer
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        ) 

    def forward(self, x):
        # batch, # tokens, dimension size
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        attn_scores.masked_fill_(
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf
        )
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec

The <i>register_buffer</i> in the init method is not strictly necessary but offers a few advantages. For instance, when we use the class in our LLM, buffers are automatically moved to the appropriate device (CPU/GPU), meaning we don't have to manually ensure these tensors are on the same device as the model parameters.
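
The following sketch illustrates how a buffer differs from a parameter. `WithBuffer` is a hypothetical toy module, and a dtype conversion is used as a stand-in for a device move so that the example runs without a GPU; `Module.to` handles both through the same mechanism.

```python
import torch
import torch.nn as nn

class WithBuffer(nn.Module):  # hypothetical toy module for illustration
    def __init__(self, context_length):
        super().__init__()
        self.linear = nn.Linear(2, 2)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

m = WithBuffer(3)
param_names = [name for name, _ in m.named_parameters()]
buffer_names = [name for name, _ in m.named_buffers()]
print(param_names)    # the mask is not a trainable parameter
print(buffer_names)   # but the module still tracks it

m = m.to(torch.float64)   # stand-in for a device move such as .to("cuda")
print(m.mask.dtype)       # the buffer followed the module
```

Buffers are also included in `state_dict()` by default, so the mask is saved and restored together with the weights.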

In PyTorch, operations with a trailing underscore (such as <i>masked_fill_</i>) are performed in-place, avoiding unnecessary memory copies.
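
A minimal sketch of the difference between the out-of-place and in-place variants, using a small made-up tensor and mask:

```python
import torch

scores = torch.zeros(2, 2)
mask = torch.tensor([[False, True], [False, False]])

# Out-of-place: returns a new tensor and leaves `scores` untouched
filled = scores.masked_fill(mask, -torch.inf)
untouched = scores.clone()

# In-place: overwrites `scores` and returns the same tensor object
result = scores.masked_fill_(mask, -torch.inf)

print(untouched)
print(scores)
print(result is scores)
```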

In [None]:
torch.manual_seed(123)
context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)
context_vecs = ca(batch)
print("context_vecs.shape:", context_vecs.shape)
# Each token is now represented by a 2-d embedding

context_vecs.shape: torch.Size([2, 6, 2])


### 4. Extending single-head attention to multi-head attention

Multi-head attention refers to dividing the attention mechanism into multiple "heads", each operating independently. This involves creating multiple instances of self-attention, each with its own weights, and then combining their outputs. This is computationally intensive, but it is crucial for complex pattern recognition. The attention mechanism runs multiple times in parallel, each time with different learned linear projections, that is, the results of multiplying the input data by the query, key, and value weight matrices.

First, we will build a multi-head attention module in an intuitive way by stacking multiple instances of the CausalAttention class above. Then we will implement the same module in a more complicated but more computationally efficient way. Below is the stacked version.

In [20]:
class MultiHeadAttentionWrapper(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(
                d_in, d_out, context_length, dropout, qkv_bias
            )
            for _ in range(num_heads)]
        )

    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

In [21]:
# An example with 2 attention heads
torch.manual_seed(123)
context_length = batch.shape[1] # number of tokens
d_in, d_out = 3, 2

mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)
context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)
context_vecs.shape: torch.Size([2, 6, 4])


Above, the first dimension is 2 since we have two input texts (which are duplicated, so the context vectors are the exact same for both). The second dimension refers to the 6 tokens in each input. The third refers to the four-dimensional embedding of each token. 

The issue with the above is that the two single-head attention modules are processed sequentially via <i>[head(x) for head in self.heads]</i> in the forward method. We want to process them in parallel.

### Multi-head attention with weight splits

We combine the two previous classes into a single class and add some other modifications. The new class splits the input into multiple heads by reshaping the projected query, key, and value tensors, and then combines the results after computing attention.

In [3]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
        "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout) 
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
            diagonal=1)
        ) 
        
    def forward(self, x):
        # batch, # tokens, dimension size
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a num_heads dimension
        # Then we unroll the last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # from shape (b, num_tokens, num_heads, head_dim)
        # to shape (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        attn_scores = queries @ keys.transpose(2, 3) # dot product for each head
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens] # masks truncated to num_tokens

        attn_scores.masked_fill_(mask_bool, -torch.inf) # use mask to fill attention scores

        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = (attn_weights @ values).transpose(1, 2) # (b, num_tokens, num_heads, head_dim)

        # combines heads where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(
            b, num_tokens, self.d_out
        )
        context_vec = self.out_proj(context_vec) # an optional linear projection

        return context_vec

This looks more complicated, but it implements the same concept as before. Although it adds reshaping and transposition of tensors, it is more efficient: we need only one matrix multiplication to compute the keys for all heads (keys = self.W_key(x)), and likewise for the queries and values. It is simply an integrated approach, starting with a multi-head layer and then internally splitting this layer into individual attention heads. The splitting of the query, key, and value tensors is achieved through tensor reshaping and transposing (.view and .transpose). The input is first transformed (via linear layers) and then reshaped to represent multiple heads.
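
To make the "one matrix multiplication" point concrete, the sketch below compares two separate per-head key projections with a single combined projection whose weight matrix stacks the per-head weights. All names and sizes here are illustrative assumptions rather than part of the classes above.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
d_in, head_dim, num_heads = 3, 2, 2
x = torch.rand(6, d_in)

# Separate per-head key projections, as in the wrapper approach
heads = [nn.Linear(d_in, head_dim, bias=False) for _ in range(num_heads)]
separate = torch.cat([h(x) for h in heads], dim=-1)

# One combined projection whose weight stacks the per-head weights row-wise
combined = nn.Linear(d_in, head_dim * num_heads, bias=False)
with torch.no_grad():
    combined.weight.copy_(torch.cat([h.weight for h in heads], dim=0))

# Both routes give the same keys up to floating-point tolerance
print(torch.allclose(separate, combined(x), atol=1e-6))
```

The combined layer produces all heads' keys in one call; reshaping its output afterwards recovers the per-head tensors.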

The key operation is to split the d_out dimension into <i>num_heads</i> and <i>head_dim</i>, where head_dim = d_out // num_heads. This split is achieved using the .view method: a tensor of shape (b, num_tokens, d_out) is reshaped to (b, num_tokens, num_heads, head_dim). The tensors are then transposed to bring the num_heads dimension before the num_tokens dimension, giving the shape (b, num_heads, num_tokens, head_dim). This is crucial for correctly aligning the queries, keys, and values across the different heads and performing batched matrix multiplications efficiently.
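
The shape changes can be traced step by step. This is a minimal sketch with arbitrary sizes, using a random tensor in place of the projected keys:

```python
import torch

b, num_tokens, num_heads, head_dim = 2, 6, 2, 4
d_out = num_heads * head_dim

keys = torch.rand(b, num_tokens, d_out)                  # stand-in for W_key(x)
split = keys.view(b, num_tokens, num_heads, head_dim)    # (2, 6, 2, 4)
per_head = split.transpose(1, 2)                         # (2, 2, 6, 4)
scores = per_head @ per_head.transpose(2, 3)             # (2, 2, 6, 6)

print(split.shape, per_head.shape, scores.shape)

# The second head of the first sequence is the second chunk of head_dim columns
print(torch.equal(per_head[0, 1], keys[0, :, head_dim:]))
```

The last line shows that .view does not mix values between heads: each head simply owns a contiguous slice of the d_out columns.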

After computing the attention weights and context vectors, the latter are transposed back to the shape (b, num_tokens, num_heads, head_dim) and then flattened to (b, num_tokens, d_out), which combines the outputs from all the heads. Finally, we added an output projection layer (self.out_proj) after combining the heads. This is not strictly necessary, but it is commonly used in many LLM architectures.
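
The call to .contiguous() before .view in the forward method matters, because a transposed tensor is a non-contiguous view of the original memory. The sketch below, with a random stand-in for the per-head context vectors, shows the failure and the fix.

```python
import torch

b, num_heads, num_tokens, head_dim = 1, 2, 3, 4
context = torch.rand(b, num_heads, num_tokens, head_dim)

back = context.transpose(1, 2)   # (b, num_tokens, num_heads, head_dim)
print(back.is_contiguous())      # transpose returns a non-contiguous view

view_failed = False
try:
    back.view(b, num_tokens, num_heads * head_dim)
except RuntimeError:
    view_failed = True           # .view cannot merge the last two dims here
print("view failed:", view_failed)

merged = back.contiguous().view(b, num_tokens, num_heads * head_dim)

# Each token's row is the concatenation of its per-head vectors
expected = torch.cat([context[0, 0, 0], context[0, 1, 0]])
print(torch.equal(merged[0, 0], expected))
```

An alternative is `.reshape`, which copies only when it has to; the explicit `.contiguous().view(...)` form makes the copy visible.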

In [5]:
# Suppose we have the following tensor 
# shape (b, num_heads, num_tokens, head_dim) = (1,2,3,4)
a = torch.tensor([[[[0.2745, 0.6584, 0.2775, 0.8573],
                    [0.8993, 0.0390, 0.9268, 0.7388],
                    [0.7179, 0.7058, 0.9156, 0.4340]],
                     
                    [[0.0772, 0.3565, 0.1479, 0.5331],
                     [0.4066, 0.2318, 0.4545, 0.9737],
                     [0.4606, 0.5159, 0.4220, 0.5786]]]])

# batched matrix multiplication between the tensor and a view of the tensor
# where we transposed the last two dimensions
print(a @ a.transpose(2, 3))

tensor([[[[1.3208, 1.1631, 1.2879],
          [1.1631, 2.2150, 1.8424],
          [1.2879, 1.8424, 2.0402]],

         [[0.4391, 0.7003, 0.5903],
          [0.7003, 1.3737, 1.0620],
          [0.5903, 1.0620, 0.9912]]]])


In [6]:
# Each head separately
first_head = a[0, 0, :, :]
first_res = first_head @ first_head.T
print("First head:\n", first_res)

second_head = a[0, 1, :, :]
second_res = second_head @ second_head.T
print("Second head:\n", second_res)

First head:
 tensor([[1.3208, 1.1631, 1.2879],
        [1.1631, 2.2150, 1.8424],
        [1.2879, 1.8424, 2.0402]])
Second head:
 tensor([[0.4391, 0.7003, 0.5903],
        [0.7003, 1.3737, 1.0620],
        [0.5903, 1.0620, 0.9912]])


In [27]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)
context_vecs = mha(batch)
print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

tensor([[[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]],

        [[0.3190, 0.4858],
         [0.2943, 0.3897],
         [0.2856, 0.3593],
         [0.2693, 0.3873],
         [0.2639, 0.3928],
         [0.2575, 0.4028]]], grad_fn=<ViewBackward0>)
context_vecs.shape: torch.Size([2, 6, 2])


We implemented multi-head attention with small embedding sizes and few attention heads to keep the output readable. For comparison, the smallest GPT-2 model (117 million parameters) has 12 attention heads and a context vector embedding size of 768. The largest has 25 attention heads and an embedding size of 1,600 (and in both cases d_in = d_out).
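
Using only the figures quoted above, the per-head dimension follows from the same divisibility rule that MultiHeadAttention asserts. The dictionary labels are just names chosen for this sketch.

```python
# (embedding size, number of heads) as quoted in the text above
configs = {"smallest GPT-2": (768, 12), "largest GPT-2": (1600, 25)}

head_dims = {}
for name, (emb_dim, num_heads) in configs.items():
    assert emb_dim % num_heads == 0   # same check as in MultiHeadAttention
    head_dims[name] = emb_dim // num_heads
    print(name, "head_dim =", head_dims[name])
```

Both configurations work out to 64 dimensions per head, so the larger model grows by adding heads rather than by widening each one.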