# Lab 7: Transformer

The goal of this lab is to develop a class that implements multi-head attention. To do this, we will follow several steps:
1) Self-attention
2) Causal self-attention
3) Single head attention class
4) Multi-head attention class.

Packages that are being used in this notebook:

In [None]:
from importlib.metadata import version

print("torch version:", version("torch"))

import torch
import torch.nn as nn

Let us generate an arbitrary embedded sentence. Each word is embedded into a 3-dimensional vector representation.

In [None]:
inputs = torch.tensor(
  [[0.43, 0.15, 0.89], # Your     (x^1)
   [0.55, 0.87, 0.66], # journey  (x^2)
   [0.57, 0.85, 0.64], # starts   (x^3)
   [0.22, 0.58, 0.33], # with     (x^4)
   [0.77, 0.25, 0.10], # one      (x^5)
   [0.05, 0.80, 0.55]] # step     (x^6)
)

In [None]:
x_2 = inputs[1] # second input element
n_inputs = inputs.shape[0] # the number of inputs
d_in = inputs.shape[1] # the input embedding size, d=3
d_out = 2 # the output embedding size, d=2

## 1- Implementing self-attention with trainable weights

### Question 1: Implement the forward function in the following class.

You must compute:
- the keys
- the queries
- the values
- the attention scores
- the attention weights
- the context vector.

The forward function must return the context vector.
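For reference, the quantities listed above are usually written as follows for scaled dot-product attention. This is a notation sketch, not the required implementation: $X$ stands for `inputs`, the matrices $W_q$, $W_k$, $W_v$ are assumed to have shape $d_\text{in} \times d_\text{out}$ (`nn.Linear` stores their transposes), and $d_k = d_\text{out}$ is the key dimension.

```latex
\begin{aligned}
Q &= X W_q, \quad K = X W_k, \quad V = X W_v \\
S &= Q K^{\top} && \text{(attention scores)} \\
A &= \operatorname{softmax}\!\left(\frac{S}{\sqrt{d_k}}\right) && \text{(attention weights, softmax over each row)} \\
Z &= A V && \text{(context vectors)}
\end{aligned}
```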

In [None]:
class SelfAttention(nn.Module):

    def __init__(self, d_in, d_out, qkv_bias=False):
        super().__init__()
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

    def forward(self, x):
        # Complete the code


        return context_vec


Test the SelfAttention class

In [None]:
torch.manual_seed(789)
sa = SelfAttention(d_in, d_out)
print(sa(inputs))

## 2- Hiding future words with causal attention

### 2.1 Applying a causal attention mask

- We are converting the previous self-attention mechanism into a causal self-attention mechanism
- Causal self-attention ensures that the model's prediction for a certain position in a sequence is only dependent on the known outputs at previous positions, not on future positions
- In simpler words, this ensures that each next word prediction should only depend on the preceding words
- To achieve this, for each given token, we mask out the future tokens (the ones that come after the current token in the input text).
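To make the idea concrete, here is a small sketch of causal masking on made-up scores for 3 tokens. It deliberately uses a different route from the one asked for in Question 2: it zeroes the future positions after the softmax and then renormalizes each row. The names `toy_scores`, `causal`, and `masked_weights` are illustrative and do not appear elsewhere in this lab.

```python
import torch

# Hypothetical attention scores for 3 tokens (not the lab's data).
toy_scores = torch.tensor([[0.2, 0.5, 0.1],
                           [0.4, 0.3, 0.9],
                           [0.7, 0.6, 0.8]])

# Ordinary (non-causal) attention weights: each row sums to 1.
toy_weights = torch.softmax(toy_scores, dim=-1)

# Ones on and below the diagonal: the current token and the past tokens.
causal = torch.tril(torch.ones(3, 3))

# Zero out the future positions, then renormalize so rows sum to 1 again.
masked = toy_weights * causal
masked_weights = masked / masked.sum(dim=-1, keepdim=True)

print(masked_weights)
```

Row `i` of `masked_weights` only has nonzero entries in columns `0..i`, so token `i` can draw on itself and on earlier tokens only.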

In [None]:
# Reuse the query and key weight matrices of the
# SelfAttention object from the previous section for convenience
queries = sa.W_query(inputs)
keys = sa.W_key(inputs)
attn_scores = queries @ keys.T

print(attn_scores)

attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)

print(attn_weights)

### Question 2: We can mask the unnormalized attention scores above the diagonal with negative infinity before they enter the softmax function. Generate an upper triangular matrix of ones and use it to replace the attention scores with negative infinity. Compute the masked attention weights. Hence, you should compute two elements in the following cell: the mask and the masked attention weights.

In [None]:
context_length = attn_scores.shape[0]

# Complete the code

### 2.2 Masking additional attention weights with dropout

- In addition, we also apply dropout to reduce overfitting during training
- We apply the dropout mask after computing the attention weights, which is the more common choice
- In this example, we use a dropout rate of 50%, which means randomly masking out half of the attention weights (with the GPT model later, the dropout rate is 0.1 or 0.2).



Let us test the dropout

In [None]:
torch.manual_seed(673)
dropout = torch.nn.Dropout(0.5) # dropout rate of 50%
example = torch.ones(6, 6) # create a matrix of ones

print(dropout(example))

Important note:
- If we apply a dropout rate of 0.5 (50%), the non-dropped values are scaled accordingly by a factor of 1/0.5 = 2
- The scaling is calculated by the formula 1 / (1 - `dropout_rate`)
- Dropout is only applied during training, not during inference
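The points above can be checked with a short sketch. It assumes nothing beyond `torch.nn.Dropout`; the names `drop`, `ones`, `train_out`, and `eval_out` are illustrative.

```python
import torch

torch.manual_seed(0)
drop = torch.nn.Dropout(0.5)   # dropout rate of 50%
ones = torch.ones(4, 4)

drop.train()                   # training mode (the default for a new module)
train_out = drop(ones)         # entries are either dropped or rescaled by 1 / (1 - 0.5)

drop.eval()                    # inference mode: dropout leaves the input unchanged
eval_out = drop(ones)

print(train_out)
print(eval_out)
```

In training mode every surviving entry is scaled so that the expected value of each element stays the same as without dropout.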

### Question 3: Apply the dropout to the attention weights and print the result.

In [None]:
torch.manual_seed(673)
# Complete the code


## 3- Implementing a compact causal self-attention class

- Now, we are ready to implement a working implementation of self-attention, including the causal and dropout masks
- One more thing is to implement the code to handle batches consisting of more than one input so that our `CausalAttention` class supports the batch outputs produced by the data loader we implemented in chapter 2
- For simplicity, to simulate such batch input, we duplicate the input text example:

In [None]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch.shape) # 2 inputs with 6 tokens each, and each token has embedding dimension 3

### Question 4: in the following class, why is `keys.T` replaced with `keys.transpose(1, 2)`?

Answer: write down your answer here.

In [None]:
class CausalAttention(nn.Module):

    def __init__(self, d_in, d_out, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.d_out = d_out
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key   = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.dropout = nn.Dropout(dropout)
        # register_buffer => Tensor which is not a parameter, but should be part of the module's state.
        # Used for tensors that need to be on the same device as the module.
        self.register_buffer('mask', torch.triu(torch.ones(context_length, context_length), diagonal=1))

    def forward(self, x):
        b, num_tokens, d_in = x.shape # batch dimension b
        # For inputs where `num_tokens` exceeds `context_length`, this will result in errors
        # in the mask creation further below.
        # In practice, this is not a problem since the LLM ensures that inputs
        # do not exceed `context_length` before reaching this forward method.
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        attn_scores = queries @ keys.transpose(1, 2)
        # `:num_tokens` handles batches with fewer tokens than the supported context_length
        attn_scores.masked_fill_(  # _ ops are in-place
            self.mask.bool()[:num_tokens, :num_tokens], -torch.inf)
        attn_weights = torch.softmax(
            attn_scores / keys.shape[-1]**0.5, dim=-1
        )
        attn_weights = self.dropout(attn_weights)

        context_vec = attn_weights @ values
        return context_vec


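The effect of `register_buffer` in the class above can be seen in isolation with a minimal sketch. `MaskHolder` is a hypothetical helper written only for this illustration; it is not part of the lab's classes.

```python
import torch
import torch.nn as nn

class MaskHolder(nn.Module):  # hypothetical helper, for illustration only
    def __init__(self, context_length):
        super().__init__()
        # Same call as in CausalAttention: ones strictly above the diagonal.
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length), diagonal=1),
        )

holder = MaskHolder(4)

# A buffer is not a trainable parameter ...
n_params = len(list(holder.parameters()))
# ... but it is tracked by the module and saved in its state_dict.
buffer_names = [name for name, _ in holder.named_buffers()]
in_state_dict = "mask" in holder.state_dict()

# Buffers follow the module when it is converted or moved.
holder = holder.to(torch.float64)
mask_dtype = holder.mask.dtype

print(n_params, buffer_names, in_state_dict, mask_dtype)
```

The same mechanism applies to `.to(device)`, which is why the mask ends up on the same device as the weights without extra code.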

Let us test the CausalAttention class.

In [None]:
torch.manual_seed(123)

context_length = batch.shape[1]
ca = CausalAttention(d_in, d_out, context_length, 0.0)

context_vecs = ca(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

## 4- Extending single-head attention to multi-head attention

### 4.1 Stacking multiple single-head attention layers

- The self-attention implemented previously is also called single-head attention.
- We simply stack multiple single-head attention modules to obtain a multi-head attention module
- The main idea behind multi-head attention is to run the attention mechanism multiple times (in parallel) with different, learned linear projections.
- This allows the model to jointly attend to information from different representation subspaces at different positions.

### Question 5: in the following class, the heads must be concatenated in the "forward" function. Code this concatenation.

In [None]:
class MultiHeadAttentionWrapper(nn.Module):

    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        self.heads = nn.ModuleList(
            [CausalAttention(d_in, d_out, context_length, dropout, qkv_bias)
             for _ in range(num_heads)]
        )

    def forward(self, x):
        # Write the return line with a concatenation
        pass

Let us test the MultiHeadAttentionWrapper class.

In [None]:
torch.manual_seed(123)

context_length = batch.shape[1] # This is the number of tokens
d_in, d_out = 3, 2
mha = MultiHeadAttentionWrapper(
    d_in, d_out, context_length, 0.0, num_heads=2
)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)

- In the implementation above, the output embedding dimension is 4: we use `d_out=2` as the embedding dimension for the key, query, and value vectors as well as for each head's context vector, and with 2 attention heads the concatenated output has dimension 2*2=4

### 4.2 Implementing multi-head attention with weight splits

- While the above is an intuitive and fully functional implementation of multi-head attention (wrapping the single-head attention `CausalAttention` implementation from earlier), we can write a stand-alone class called `MultiHeadAttention` to achieve the same

- This is essentially a rewritten version of `MultiHeadAttentionWrapper` that is more efficient
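The "weight split" in the title of this section refers to reshaping the output of one projection into per-head slices. The following is a minimal sketch of that reshaping with illustrative sizes; `proj` stands in for the output of a single `d_in -> d_out` projection, and all names and sizes here are assumptions made for the example.

```python
import torch

torch.manual_seed(0)
b, num_tokens, num_heads, head_dim = 2, 6, 2, 3   # illustrative sizes
width = num_heads * head_dim                      # plays the role of d_out

# Stand-in for the output of one linear projection (e.g. the queries).
proj = torch.randn(b, num_tokens, width)

# (b, num_tokens, width) -> (b, num_tokens, num_heads, head_dim)
#                        -> (b, num_heads, num_tokens, head_dim)
heads = proj.view(b, num_tokens, num_heads, head_dim).transpose(1, 2)

# Head i corresponds to the i-th contiguous block of head_dim columns.
first_block = proj[:, :, :head_dim]
second_block = proj[:, :, head_dim:]

print(heads.shape)
print(torch.equal(heads[:, 0], first_block), torch.equal(heads[:, 1], second_block))
```

No data is copied by `view` or `transpose`; they only change how the same tensor is indexed.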

### Question 6: why is this second implementation (see the cell below) more efficient than `MultiHeadAttentionWrapper`?

Answer: write down your answer here.

In [None]:
class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (d_out % num_heads == 0), \
            "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        self.head_dim = d_out // num_heads # Reduce the projection dim to match desired output dim

        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.out_proj = nn.Linear(d_out, d_out)  # Linear layer to combine head outputs
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            "mask",
            torch.triu(torch.ones(context_length, context_length),
                       diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        # As in `CausalAttention`, for inputs where `num_tokens` exceeds `context_length`,
        # this will result in errors in the mask creation further below.
        # In practice, this is not a problem since the LLM (chapters 4-7) ensures that inputs
        # do not exceed `context_length` before reaching this forward method.

        keys = self.W_key(x) # Shape: (b, num_tokens, d_out)
        queries = self.W_query(x)
        values = self.W_value(x)

        # We implicitly split the matrix by adding a `num_heads` dimension
        # Unroll last dim: (b, num_tokens, d_out) -> (b, num_tokens, num_heads, head_dim)
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)

        # Transpose: (b, num_tokens, num_heads, head_dim) -> (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # Compute scaled dot-product attention (aka self-attention) with a causal mask
        attn_scores = queries @ keys.transpose(2, 3)  # Dot product for each head

        # Original mask truncated to the number of tokens and converted to boolean
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]

        # Use the mask to fill attention scores
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores / keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)

        # Shape: (b, num_tokens, num_heads, head_dim)
        context_vec = (attn_weights @ values).transpose(1, 2)

        # Combine heads, where self.d_out = self.num_heads * self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec) # optional projection

        return context_vec



Let us test the MultiHeadAttention class.

In [None]:
torch.manual_seed(123)

batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in, d_out, context_length, 0.0, num_heads=2)

context_vecs = mha(batch)

print(context_vecs)
print("context_vecs.shape:", context_vecs.shape)


- Both are fully functional implementations that can be used in a GPT class
- Note that we added a linear projection layer (`self.out_proj`) to the `MultiHeadAttention` class above, as used in the usual definition of the multi-head attention mechanism.