<a href="https://colab.research.google.com/github/RCortez25/PhD/blob/main/LLM/4.%20Attention%20mechanism/3_Multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multihead attention wrapper

> In multi-head attention, the attention mechanism is divided into multiple heads, each operating independently.
>
> The simplest implementation stacks several single-head attention layers: one creates multiple instances of the self-attention mechanism, each with its own weights, and then combines their outputs.
>
> Although this is computationally more expensive than a single head, it allows the LLM to capture more complex patterns.

In [1]:
import torch
import torch.nn as nn

In [2]:
class MultiHeadAttentionWrapper(nn.Module):
    # The number of attention heads must be given. This is the number of
    # single-head self-attention mechanisms to create
    def __init__(self, dimensions_in, dimensions_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        # Create a list containing all single self-attention mechanisms. These
        # are created using the CausalAttention class
        self.heads = nn.ModuleList(
            [CausalAttention(dimensions_in, dimensions_out, context_length,
                             dropout, qkv_bias) for _ in range(num_heads)]
        )

    # Method for concatenating all single self-attention heads along the
    # columns
    def forward(self, x):
        return torch.cat([head(x) for head in self.heads], dim=-1)

> Let's repeat the inputs and the `CausalAttention` class we had before so that we can test the wrapper.

In [3]:
inputs = torch.tensor(
    [[0.43, 0.15, 0.89], # Your     (x^1)
     [0.55, 0.87, 0.66], # journey  (x^2)
     [0.57, 0.85, 0.64], # starts   (x^3)
     [0.22, 0.58, 0.33], # with     (x^4)
     [0.77, 0.25, 0.10], # one      (x^5)
     [0.05, 0.80, 0.55]] # step     (x^6)
)

batch = torch.stack([inputs, inputs], dim=0)

In [4]:
class CausalAttention(nn.Module):
    def __init__(self, dimension_inputs, dimension_outputs, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.dimension_outputs = dimension_outputs
        # Initialize the matrices using Linear layers
        self.W_q = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_k = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_v = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        # Initialize the dropout layer
        self.dropout = nn.Dropout(dropout)
        # Register the causal mask as a buffer so that it is moved to CPU or
        # GPU together with the model
        self.register_buffer("mask",
                             torch.triu(torch.ones(context_length, context_length),
                                        diagonal=1))

    # Method to calculate the context vector
    def forward(self, input_vectors):
        # Obtain the relevant dimensions
        batch_size, number_of_tokens, dimension_inputs = input_vectors.shape
        queries = self.W_q(input_vectors)
        keys = self.W_k(input_vectors)
        values = self.W_v(input_vectors)

        attention_scores = queries @ keys.transpose(1, 2)
        attention_scores.masked_fill_(
            self.mask.bool()[:number_of_tokens, :number_of_tokens], -torch.inf)

        # Calculate attention weights
        dimension_keys = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)
        # Apply dropout
        attention_weights = self.dropout(attention_weights)

        # Calculate and return the context vectors
        context_vectors = attention_weights @ values

        return context_vectors

In [5]:
# Test the wrapper
torch.manual_seed(123)
context_length = batch.shape[1] # Number of tokens
dimension_inputs = 3
dimension_outputs = 2
oMultiHeadAttention = MultiHeadAttentionWrapper(dimension_inputs,
                                                dimension_outputs,
                                                context_length,
                                                dropout=0.0, num_heads=2)
context_vectors = oMultiHeadAttention(batch)
print(context_vectors)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


In [6]:
context_vectors.shape

torch.Size([2, 6, 4])

> The dimensions here represent:
* 2: The batch size, that is, the number of input sequences in the batch
* 6: The number of tokens, or input vectors, in each sequence
* 4: This is because we selected two heads (`num_heads=2`). Since the output dimension of each head is 2, concatenating the 2 heads gives 2 + 2 = 4.
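
The last dimension can be checked in isolation. The following is a minimal sketch, assuming two made-up head outputs (`head_1` and `head_2` are random placeholders, not the outputs of the cell above), of how concatenation along `dim=-1` produces the size 4.

```python
import torch

torch.manual_seed(0)
# Two hypothetical head outputs, each (batch_size, number_of_tokens, dimensions_out)
head_1 = torch.randn(2, 6, 2)
head_2 = torch.randn(2, 6, 2)

# Concatenating along the last dimension places the head outputs side by side
combined = torch.cat([head_1, head_2], dim=-1)
print(combined.shape)  # torch.Size([2, 6, 4])

# The first two columns come from head 1, the last two from head 2
print(torch.equal(combined[..., :2], head_1))  # True
print(torch.equal(combined[..., 2:], head_2))  # True
```

This is the same operation the wrapper's `forward` method applies to the real head outputs.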

# Multi-head attention with weight splits

> The problem with the wrapper is that it is not very efficient: each head performs its own matrix multiplications, so with `num_heads=2` the inputs are multiplied by two separate sets of weight matrices. We can achieve the same result with a single matrix multiplication per projection by computing all heads at once and then splitting the result. Let's combine the two classes `CausalAttention` and `MultiHeadAttentionWrapper` and add some other modifications.

In [23]:
class MultiHeadAttention(nn.Module):
    def __init__(self, dimension_inputs, dimension_outputs, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        assert (dimension_outputs % num_heads == 0), \
            "Dimension outputs must be divisible by number of heads" # 1
        self.dimension_outputs = dimension_outputs
        self.num_heads = num_heads
        self.head_dimension = dimension_outputs // num_heads # 2

        # Initialize weight matrices
        self.W_query = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_key = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_value = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        # Initialize a linear layer to combine head outputs
        self.output_projection = nn.Linear(dimension_outputs, dimension_outputs) # 3
        # Initialize the dropout layer
        self.dropout = nn.Dropout(dropout)
        self.register_buffer("mask",
                             torch.triu(torch.ones(context_length, context_length),
                                        diagonal=1))

    # Method for calculating the context vectors
    def forward(self, input_vectors):
        batch_size, number_of_tokens, dimension_inputs = input_vectors.shape

        # Calculate keys, queries, and values
        queries = self.W_query(input_vectors)
        keys = self.W_key(input_vectors)
        values = self.W_value(input_vectors)

        # Split the obtained matrices by adding a "number of heads" dimension,
        # that is, the output dimension is split so that
        # (batch_size, number_of_tokens, dimension_outputs) becomes
        # (batch_size, number_of_tokens, num_heads, head_dimension)
        keys = keys.view(batch_size, number_of_tokens, self.num_heads, self.head_dimension)
        values = values.view(batch_size, number_of_tokens, self.num_heads, self.head_dimension)
        queries = queries.view(batch_size, number_of_tokens, self.num_heads, self.head_dimension)

        # Transpose to (batch_size, num_heads, number_of_tokens, head_dimension)
        # so that the matrix products below are computed per head
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)
        queries = queries.transpose(1, 2)

        # Calculate attention scores
        attention_scores = queries @ keys.transpose(2, 3)

        # Mask future tokens
        attention_scores.masked_fill_(
            self.mask.bool()[:number_of_tokens, :number_of_tokens], -torch.inf)

        # Calculate attention weights
        attention_weights = torch.softmax(attention_scores / (keys.shape[-1] ** 0.5), dim=-1)

        # Dropout
        attention_weights = self.dropout(attention_weights)

        # Calculate context vectors
        context_vectors = attention_weights @ values

        # Swap dimensions
        context_vectors = context_vectors.transpose(1, 2)

        # Combine context vectors for each head into a single representation.
        # contiguous() copies the transposed tensor into contiguous memory,
        # which view() requires
        context_vectors = context_vectors.contiguous().view(batch_size, number_of_tokens, self.dimension_outputs)

        # Apply the output projection (optional)
        context_vectors = self.output_projection(context_vectors)

        return context_vectors

> Comments:
>
> * 1: Use `assert` to ensure that the output dimension is divisible by the number of heads
* 2: Calculate the dimension of each head
* 3: Linear layer that combines the outputs of the heads
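
To see why a single matrix multiplication is enough, here is a minimal sketch. The per-head weights `W_head_1` and `W_head_2` are hypothetical random matrices introduced only for this illustration; they are not taken from the classes above. Stacking them column-wise and multiplying once gives the same numbers as multiplying per head and concatenating.

```python
import torch

torch.manual_seed(0)
x = torch.randn(2, 6, 3)  # (batch_size, number_of_tokens, dimension_inputs)

# Hypothetical per-head weights, each (dimension_inputs, head_dimension)
W_head_1 = torch.randn(3, 2)
W_head_2 = torch.randn(3, 2)

# Wrapper style: one matrix multiplication per head, then concatenate
separate = torch.cat([x @ W_head_1, x @ W_head_2], dim=-1)

# Weight-split style: one multiplication with the stacked weight matrix
W_all = torch.cat([W_head_1, W_head_2], dim=1)  # (3, 4)
fused = x @ W_all

# The two results agree up to floating-point error
print(torch.allclose(separate, fused, atol=1e-6))  # True
```

The `view` call in `MultiHeadAttention.forward` then plays the role of cutting `fused` back into its per-head pieces.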

# Simple example of multihead attention

## Step 1: Inputs

Say we have a batch containing 1 sequence of 3 tokens, each in a 6-dimensional space, and the input sequence is "The cat sleeps". That is

`batch_size=1` \
`number_of_tokens=3` \
`dimension_inputs=6`

In [8]:
x = torch.tensor([[
    [1.0, 2.0, 3.0, 4.0, 5.0, 6.0], # The
    [6.0, 5.0, 4.0, 3.0, 2.0, 1.0], # cat
    [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]  # sleeps
]])

## Step 2: Decide the output dimension and the number of heads

`dimension_outputs=6` \
`num_heads=2`

Recall that the output dimension is the dimension of the context vector for each input token. Typically, in GPT models `dimension_inputs=dimension_outputs`.

Large GPT models use many heads (the largest GPT-3 model uses 96), but in this example we'll keep it simple.

This decision gives

`head_dimension = dimension_outputs/num_heads = 6/2 = 3`
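
The arithmetic of this step can be written as a small sketch. The helper `head_dimension_for` is hypothetical (it is not part of the notebook), and the values 768 and 12 are purely illustrative.

```python
def head_dimension_for(dimension_outputs: int, num_heads: int) -> int:
    """Return the per-head dimension, or raise if the split is not even."""
    if dimension_outputs % num_heads != 0:
        raise ValueError("dimension_outputs must be divisible by num_heads")
    return dimension_outputs // num_heads


print(head_dimension_for(6, 2))     # 3, the case used in this example
print(head_dimension_for(768, 12))  # 64, an illustrative larger configuration
```

Integer division is used so that the result can be passed directly to `view` as a shape.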

## Step 3: Initialize trainable weight matrices

Initialize `W_query`, `W_key`, and `W_value`. Their dimensions must be

$$dimension\_inputs\times dimension\_outputs$$

so, in this case one has

$$6\times6$$

This is required for the matrix multiplication with the inputs to be well defined.

In [9]:
W_query = nn.Linear(6, 6)
W_key = nn.Linear(6, 6)
W_value = nn.Linear(6, 6)

print(W_query)
print(W_key)
print(W_value)

Linear(in_features=6, out_features=6, bias=True)
Linear(in_features=6, out_features=6, bias=True)
Linear(in_features=6, out_features=6, bias=True)


Print the actual values of the matrices. Note that the layers also contain a bias term, which is not important in this case.

In [10]:
print("W_query weights:\n", W_query.weight)
print("W_query bias:\n", W_query.bias)

print("\nW_key weights:\n", W_key.weight)
print("W_key bias:\n", W_key.bias)

print("\nW_value weights:\n", W_value.weight)
print("W_value bias:\n", W_value.bias)

W_query weights:
 Parameter containing:
tensor([[-0.3453, -0.1171, -0.2875,  0.0270, -0.0762, -0.2190],
        [-0.0371,  0.3868, -0.0322,  0.0130, -0.0637,  0.0642],
        [ 0.3638,  0.2496,  0.1449,  0.0887,  0.0963,  0.1577],
        [-0.0528, -0.3794, -0.2525,  0.3485,  0.0244, -0.3307],
        [ 0.0644,  0.3373, -0.3858, -0.2748, -0.1626,  0.0164],
        [-0.0952, -0.0449, -0.3980,  0.1912,  0.3583,  0.2495]],
       requires_grad=True)
W_query bias:
 Parameter containing:
tensor([-0.2891, -0.3291,  0.1695,  0.0092,  0.1674, -0.3989],
       requires_grad=True)

W_key weights:
 Parameter containing:
tensor([[-0.0243,  0.2879,  0.1894,  0.0149,  0.0803, -0.0387],
        [-0.2245, -0.1543, -0.2486,  0.3391,  0.2247,  0.1428],
        [-0.3131,  0.3150,  0.1280,  0.2824, -0.1606,  0.0865],
        [ 0.3986,  0.2746,  0.3274, -0.0857,  0.3110, -0.3198],
        [ 0.0353, -0.2299, -0.0952, -0.1046,  0.0305,  0.3716],
        [ 0.2021, -0.0017,  0.2898, -0.2092,  0.2104, -0.0379]

## Step 4: Calculate Keys, Queries, and Values matrices

$$Keys = inputs * W_{key}$$
$$Queries = inputs * W_{query}$$
$$Values = inputs * W_{value}$$

The resulting dimensions will be

$$1\times3\times6$$

In [11]:
Keys = W_key(x)
print("Keys:\n", Keys)

Queries = W_query(x)
print("\nQueries:\n", Queries)

Values = W_value(x)
print("\nValues:\n", Values)

Keys:
 tensor([[[ 1.2777,  2.1052,  1.2342,  1.2710,  1.3911,  1.4051],
         [ 2.1467, -1.4555,  0.5085,  5.1667, -1.0620,  2.4676],
         [ 0.4384,  0.1270,  0.0256,  0.9534,  0.1451,  0.8025]]],
       grad_fn=<ViewBackward0>)

Queries:
 tensor([[[-3.3182e+00,  4.2931e-01,  3.2498e+00, -2.0282e+00, -2.0650e+00,
           2.2758e+00],
         [-4.3869e+00,  1.2290e+00,  4.7963e+00, -2.4509e+00, -4.3620e-01,
          -1.2468e+00],
         [-1.3072e+00,  1.8372e-03,  1.2705e+00, -6.3332e-01, -2.3778e-01,
          -1.3795e-01]]], grad_fn=<ViewBackward0>)

Values:
 tensor([[[ 1.1584,  1.9865, -1.2399,  2.4898, -4.1935,  3.7342],
         [ 1.3962,  3.1158, -2.7011,  0.1129, -2.2644, -0.2995],
         [ 0.1818,  0.7535, -0.8222,  0.5391, -0.8618,  0.7727]]],
       grad_fn=<ViewBackward0>)
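
One detail worth checking is what `nn.Linear` actually computes. The sketch below uses a fresh layer and random input (`layer` and `x_demo` are illustrative names, not the tensors printed above). It shows that the layer stores its weight as `(out_features, in_features)` and applies `x @ W.T + b`, so a transpose is needed when reproducing the product by hand.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
layer = nn.Linear(6, 6)        # same shape as W_query, W_key, and W_value
x_demo = torch.randn(1, 3, 6)  # (batch_size, number_of_tokens, dimension_inputs)

# Manual version of the layer: multiply by the transposed weight, add the bias
manual = x_demo @ layer.weight.T + layer.bias

print(manual.shape)                                      # torch.Size([1, 3, 6])
print(torch.allclose(layer(x_demo), manual, atol=1e-6))  # True
```

Because input and output dimensions are both 6 here, the transpose does not change the shape, only which entries are multiplied together.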


## Step 5: Unroll the last dimension of the Keys, Queries, Values matrices to include the number of heads and the dimension of heads.


(batch_size, number_of_tokens, dimension_outputs) ->
(batch_size, number_of_tokens, num_heads, head_dimension)

that is

(1, 3, 6) ->
(1, 3, 2, 3)


In [12]:
reshaped_queries = Queries.view(1, 3, 2, 3)
reshaped_keys = Keys.view(1, 3, 2, 3)
reshaped_values = Values.view(1, 3, 2, 3)

In [13]:
print("Reshaped queries:\n", reshaped_queries)
print("\nReshaped keys:\n", reshaped_keys)
print("\nReshaped values:\n", reshaped_values)

Reshaped queries:
 tensor([[[[-3.3182e+00,  4.2931e-01,  3.2498e+00],
          [-2.0282e+00, -2.0650e+00,  2.2758e+00]],

         [[-4.3869e+00,  1.2290e+00,  4.7963e+00],
          [-2.4509e+00, -4.3620e-01, -1.2468e+00]],

         [[-1.3072e+00,  1.8372e-03,  1.2705e+00],
          [-6.3332e-01, -2.3778e-01, -1.3795e-01]]]], grad_fn=<ViewBackward0>)

Reshaped keys:
 tensor([[[[ 1.2777,  2.1052,  1.2342],
          [ 1.2710,  1.3911,  1.4051]],

         [[ 2.1467, -1.4555,  0.5085],
          [ 5.1667, -1.0620,  2.4676]],

         [[ 0.4384,  0.1270,  0.0256],
          [ 0.9534,  0.1451,  0.8025]]]], grad_fn=<ViewBackward0>)

Reshaped values:
 tensor([[[[ 1.1584,  1.9865, -1.2399],
          [ 2.4898, -4.1935,  3.7342]],

         [[ 1.3962,  3.1158, -2.7011],
          [ 0.1129, -2.2644, -0.2995]],

         [[ 0.1818,  0.7535, -0.8222],
          [ 0.5391, -0.8618,  0.7727]]]], grad_fn=<ViewBackward0>)


The first group of two rows in each tensor corresponds to the first token, and each row corresponds to one head. For instance, for the reshaped queries matrix:

[-3.3182e+00,  4.2931e-01,  3.2498e+00], \
[-2.0282e+00, -2.0650e+00,  2.2758e+00]

corresponds to the first token "The", with one row per head. Then

[-4.3869e+00,  1.2290e+00,  4.7963e+00], \
[-2.4509e+00, -4.3620e-01, -1.2468e+00]

corresponds to the second token "cat", again with one row per head. And so on.

Lastly, within each head every token is now represented in 3 dimensions, hence the 3 elements of each row, because the head dimension equals 3.

Recall that each head attends to the tokens separately.
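
To make the split easier to follow, here is a minimal sketch on a tensor filled with consecutive integers instead of random values (`t` and `heads` are illustrative names). Each token's 6 features are cut into 2 consecutive chunks of 3, one per head.

```python
import torch

# 1 sequence, 3 tokens, 6 features, filled with 0..17 so positions are easy to track
t = torch.arange(18).view(1, 3, 6)
heads = t.view(1, 3, 2, 3)  # (batch_size, number_of_tokens, num_heads, head_dimension)

print(t[0, 0])         # tensor([0, 1, 2, 3, 4, 5])
print(heads[0, 0, 0])  # tensor([0, 1, 2])  -> token 0, head 0
print(heads[0, 0, 1])  # tensor([3, 4, 5])  -> token 0, head 1
```

No data is moved by `view`; only the way the same memory is indexed changes.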

## Step 6: Group by heads

For performing calculations we need to reshape the matrices to be of the following format

(batch_size, number_of_tokens, num_heads, head_dimension) ->
(batch_size, num_heads, number_of_tokens, head_dimension)

that is

(1, 3, 2, 3) ->
(1, 2, 3, 3)

Since the dimensions are indexed 0, 1, 2, 3 and we want to interchange dimensions 1 and 2, we transpose the tensor along these two dimensions.

In [14]:
transposed_queries = reshaped_queries.transpose(1, 2) # This means interchanging indexes 1 and 2
transposed_keys = reshaped_keys.transpose(1, 2)
transposed_values = reshaped_values.transpose(1, 2)

In [15]:
print("Transposed queries:\n", transposed_queries)
print("\nTransposed keys:\n", transposed_keys)
print("\nTransposed values:\n", transposed_values)

Transposed queries:
 tensor([[[[-3.3182e+00,  4.2931e-01,  3.2498e+00],
          [-4.3869e+00,  1.2290e+00,  4.7963e+00],
          [-1.3072e+00,  1.8372e-03,  1.2705e+00]],

         [[-2.0282e+00, -2.0650e+00,  2.2758e+00],
          [-2.4509e+00, -4.3620e-01, -1.2468e+00],
          [-6.3332e-01, -2.3778e-01, -1.3795e-01]]]],
       grad_fn=<TransposeBackward0>)

Transposed keys:
 tensor([[[[ 1.2777,  2.1052,  1.2342],
          [ 2.1467, -1.4555,  0.5085],
          [ 0.4384,  0.1270,  0.0256]],

         [[ 1.2710,  1.3911,  1.4051],
          [ 5.1667, -1.0620,  2.4676],
          [ 0.9534,  0.1451,  0.8025]]]], grad_fn=<TransposeBackward0>)

Transposed values:
 tensor([[[[ 1.1584,  1.9865, -1.2399],
          [ 1.3962,  3.1158, -2.7011],
          [ 0.1818,  0.7535, -0.8222]],

         [[ 2.4898, -4.1935,  3.7342],
          [ 0.1129, -2.2644, -0.2995],
          [ 0.5391, -0.8618,  0.7727]]]], grad_fn=<TransposeBackward0>)


Now, in the result, the first block of each tensor contains the 3 tokens for the first head, and the second block contains the 3 tokens for the second head.

That is, for the queries matrix

[[-3.3182e+00,  4.2931e-01,  3.2498e+00], \
[-4.3869e+00,  1.2290e+00,  4.7963e+00], \
[-1.3072e+00,  1.8372e-03,  1.2705e+00]]

corresponds to the 3 tokens for the first head, and

[[-2.0282e+00, -2.0650e+00,  2.2758e+00], \
[-2.4509e+00, -4.3620e-01, -1.2468e+00], \
[-6.3332e-01, -2.3778e-01, -1.3795e-01]]

corresponds to the 3 tokens for the second head.
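
The regrouping can also be traced on an integer-filled tensor. This is a small sketch with illustrative names (`heads`, `by_head`), showing which rows end up together after `transpose(1, 2)`.

```python
import torch

heads = torch.arange(18).view(1, 3, 2, 3)  # (batch, tokens, heads, head_dim)
by_head = heads.transpose(1, 2)            # (batch, heads, tokens, head_dim)

print(by_head.shape)  # torch.Size([1, 2, 3, 3])

# All three tokens as seen by head 0
print(by_head[0, 0])
# tensor([[ 0,  1,  2],
#         [ 6,  7,  8],
#         [12, 13, 14]])

# transpose only changes the strides, so the result is not contiguous in memory
print(by_head.is_contiguous())  # False
```

The non-contiguous layout is the reason the class calls `contiguous()` before `view` when it later merges the heads again.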

## Step 7: Find attention scores (self-attention)

$$queries * keys.transpose(2,3)$$

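We need to transpose the last two dimensions of the keys (indexes 2 and 3) so that, within each head, the queries of shape (3, 3) are multiplied by the transposed keys, giving one score for every pair of tokens.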

In [16]:
attention_scores = transposed_queries @ transposed_keys.transpose(2, 3)
print(attention_scores)

tensor([[[[  0.6750,  -6.0953,  -1.3173],
          [  2.9018,  -8.7669,  -1.6448],
          [ -0.0983,  -2.1627,  -0.5404]],

         [[ -2.2529,  -2.6703,  -0.4068],
          [ -5.4737, -15.2760,  -3.4004],
          [ -1.3295,  -3.3600,  -0.7490]]]], grad_fn=<UnsafeViewBackward0>)


Once again, each big block corresponds to one head, and each block contains that head's attention scores. Recall that each row corresponds to a word (the query) and each column also corresponds to a word (the key).

So the first element 0.6750 corresponds to the attention score between "The" and "The". The element -6.0953 is the attention score between "The" and "cat". The element -1.6448 corresponds to the attention score between "cat" and "sleeps". All these correspond to the first head, and the same applies to the second big block, which corresponds to the second head.
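
As a quick check of this reading, the sketch below uses random stand-ins `q` and `k` (not the tensors printed above) and confirms that entry `[b, h, i, j]` of the batched product is the dot product of query `i` and key `j` inside head `h`.

```python
import torch

torch.manual_seed(0)
q = torch.randn(1, 2, 3, 3)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 2, 3, 3)

scores = q @ k.transpose(2, 3)  # (batch, heads, tokens, tokens)
print(scores.shape)             # torch.Size([1, 2, 3, 3])

# Score between token 0 (query) and token 2 (key) in head 1, computed by hand
manual = torch.dot(q[0, 1, 0], k[0, 1, 2])
print(torch.allclose(scores[0, 1, 0, 2], manual, atol=1e-6))  # True
```

The batched `@` runs this dot product for every pair of tokens in every head at once.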

## Step 8: Calculate attention weights

Recall that in causal attention one masks future tokens. First, we mask the future tokens, then we divide by the square root of the head dimension (the scaling in "scaled dot-product attention"), and finally we apply softmax.

In [17]:
mask = torch.triu(torch.ones(attention_scores.size(-2), attention_scores.size(-1)), diagonal=1).bool()
attention_scores_masked = attention_scores.masked_fill(mask, -torch.inf)
print("Masked attention scores:\n", attention_scores_masked)

Masked attention scores:
 tensor([[[[  0.6750,     -inf,     -inf],
          [  2.9018,  -8.7669,     -inf],
          [ -0.0983,  -2.1627,  -0.5404]],

         [[ -2.2529,     -inf,     -inf],
          [ -5.4737, -15.2760,     -inf],
          [ -1.3295,  -3.3600,  -0.7490]]]], grad_fn=<MaskedFillBackward0>)


In [18]:
head_dimension = transposed_keys.size(-1) # Get the head dimension
attention_scores_scaled = attention_scores_masked / (head_dimension ** 0.5)
print("Scaled attention scores:\n", attention_scores_scaled)

Scaled attention scores:
 tensor([[[[ 0.3897,    -inf,    -inf],
          [ 1.6754, -5.0615,    -inf],
          [-0.0567, -1.2487, -0.3120]],

         [[-1.3007,    -inf,    -inf],
          [-3.1602, -8.8196,    -inf],
          [-0.7676, -1.9399, -0.4324]]]], grad_fn=<DivBackward0>)


In [19]:
attention_weights = torch.softmax(attention_scores_scaled, dim=-1)
print("Attention weights:\n", attention_weights)

Attention weights:
 tensor([[[[1.0000, 0.0000, 0.0000],
          [0.9988, 0.0012, 0.0000],
          [0.4812, 0.1461, 0.3727]],

         [[1.0000, 0.0000, 0.0000],
          [0.9965, 0.0035, 0.0000],
          [0.3693, 0.1144, 0.5163]]]], grad_fn=<SoftmaxBackward0>)


After this, one could implement dropout, but for the sake of simplicity we will leave it as is.

Now we have attention weights for each head.
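
Two properties of these weights are worth verifying: every row sums to 1, and every masked (future) position receives a weight of exactly 0. The following sketch checks both on random scores; `scores`, `mask`, and `weights` here are illustrative stand-ins for the tensors above.

```python
import torch

torch.manual_seed(0)
scores = torch.randn(1, 2, 3, 3)                        # (batch, heads, tokens, tokens)
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()  # True above the diagonal

# Same order as above: mask, scale by sqrt(head_dimension), softmax
masked = scores.masked_fill(mask, -torch.inf)
weights = torch.softmax(masked / (3 ** 0.5), dim=-1)

# Each row is a probability distribution over the visible tokens
print(torch.allclose(weights.sum(dim=-1), torch.ones(1, 2, 3)))  # True

# Future positions get zero weight because exp(-inf) = 0
print(weights.masked_select(mask).abs().max().item())  # 0.0
```

Masking with `-inf` before the softmax, rather than zeroing weights afterwards, is what keeps each row normalized without a second renormalization step.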

## Step 9: Calculate context vectors

$$attention\_weights * values$$

Recall that the goal of the attention mechanism is to calculate the context matrix.

In [20]:
context_vectors = attention_weights @ transposed_values
print("Context vectors:\n", context_vectors)

Context vectors:
 tensor([[[[ 1.1584,  1.9865, -1.2399],
          [ 1.1587,  1.9879, -1.2416],
          [ 0.8291,  1.6919, -1.2977]],

         [[ 2.4898, -4.1935,  3.7342],
          [ 2.4816, -4.1868,  3.7202],
          [ 1.2108, -2.2525,  1.7438]]]], grad_fn=<UnsafeViewBackward0>)


Once again, there are 2 heads (big blocks), where each row represents the context vector for each input token, and these context vectors live in a 3D (columns) space.
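
Steps 7 to 9 can be cross-checked against PyTorch's built-in fused attention. This is a sketch that assumes PyTorch 2.0 or newer, where `torch.nn.functional.scaled_dot_product_attention` is available; `q`, `k`, and `v` are random stand-ins for the transposed queries, keys, and values.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
q = torch.randn(1, 2, 3, 3)  # (batch, heads, tokens, head_dim)
k = torch.randn(1, 2, 3, 3)
v = torch.randn(1, 2, 3, 3)

# Manual version: scores -> causal mask -> scale -> softmax -> weighted sum
mask = torch.triu(torch.ones(3, 3), diagonal=1).bool()
scores = (q @ k.transpose(2, 3)).masked_fill(mask, -torch.inf)
weights = torch.softmax(scores / (q.size(-1) ** 0.5), dim=-1)
manual = weights @ v

# Built-in version with a causal mask (assumes PyTorch >= 2.0)
builtin = F.scaled_dot_product_attention(q, k, v, is_causal=True)

print(torch.allclose(manual, builtin, atol=1e-5))
```

Note also that the first token can only attend to itself, so its context vector equals its own value vector, which matches the first row of each block in the output above.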

## Step 10: Reformat context vectors

Now we have to swap dimensions back so that the result matches the required output shape. Earlier we performed the following:

(batch_size, number_of_tokens, num_heads, head_dimension) ->
(batch_size, num_heads, number_of_tokens, head_dimension)

Now we perform the opposite operation:

(batch_size, num_heads, number_of_tokens, head_dimension) ->
(batch_size, number_of_tokens, num_heads, head_dimension)

In [21]:
final_context_vectors = context_vectors.transpose(1, 2)
print("Final context vectors:\n", final_context_vectors)

Final context vectors:
 tensor([[[[ 1.1584,  1.9865, -1.2399],
          [ 2.4898, -4.1935,  3.7342]],

         [[ 1.1587,  1.9879, -1.2416],
          [ 2.4816, -4.1868,  3.7202]],

         [[ 0.8291,  1.6919, -1.2977],
          [ 1.2108, -2.2525,  1.7438]]]], grad_fn=<TransposeBackward0>)


Now each of the 3 big groups represents one token, and each row within a group is that token's representation in one head, in a 3D space.

## Step 11: Combine heads

Finally, we combine the representation of each head into a single representation.

In [22]:
combined_context_vectors = final_context_vectors.reshape(1, 3, 6)
print("Combined context vectors:\n", combined_context_vectors)

Combined context vectors:
 tensor([[[ 1.1584,  1.9865, -1.2399,  2.4898, -4.1935,  3.7342],
         [ 1.1587,  1.9879, -1.2416,  2.4816, -4.1868,  3.7202],
         [ 0.8291,  1.6919, -1.2977,  1.2108, -2.2525,  1.7438]]],
       grad_fn=<UnsafeViewBackward0>)


So we have 1 batch of 3 rows, one for each token, and they live in a `dimension_outputs=6` space as stated in the beginning.
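
The example uses `reshape` while the class uses `contiguous().view(...)`. The sketch below, on an integer-filled stand-in tensor (`ctx` is an illustrative name), suggests why: after the transpose, a plain `view` is rejected because the memory layout is no longer contiguous, whereas the other two options give the same result.

```python
import torch

ctx = torch.arange(18.0).view(1, 2, 3, 3)  # (batch, heads, tokens, head_dim)
swapped = ctx.transpose(1, 2)              # (batch, tokens, heads, head_dim)

try:
    swapped.view(1, 3, 6)  # rejected: the transposed tensor is not contiguous
    view_worked = True
except RuntimeError:
    view_worked = False

a = swapped.contiguous().view(1, 3, 6)  # copy into contiguous memory, then view
b = swapped.reshape(1, 3, 6)            # reshape copies only when it has to

print(view_worked)        # False
print(torch.equal(a, b))  # True
print(a[0, 0])            # token 0: head 0 values (0, 1, 2), then head 1 values (9, 10, 11)
```

Either form places each token's head outputs side by side, which is the combined representation shown above.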

With this, we have covered multi-head attention, which is at the heart of the LLM. Now we're ready to start building the LLM block by block.