<a href="https://colab.research.google.com/github/RCortez25/PhD/blob/main/LLM/4.%20Attention%20mechanism/3_Multi_head_attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Multihead attention wrapper

> In multi-head attention, the attention mechanism is divided into multiple heads, each operating independently.
>
> One stacks multiple single-head attention layers. That is, one creates multiple instances of the self-attention mechanism, each with its own weights, and then combines their outputs (here, by concatenating them).
>
> Even though this is computationally more expensive, it allows the LLM to capture more complex patterns, because each head can learn to focus on different relationships between tokens.

In [1]:
import torch
import torch.nn as nn

In [2]:
class MultiHeadAttentionWrapper(nn.Module):
    """Run several independent causal self-attention heads and join their outputs.

    Every head is a separate ``CausalAttention`` instance with its own query,
    key and value weights. ``forward`` feeds the same input to each head and
    concatenates the per-head results along the last axis, so the output's
    last dimension is ``num_heads * dimensions_out``.

    Args:
        dimensions_in: Size of each input embedding.
        dimensions_out: Output size produced by each individual head.
        context_length: Maximum number of tokens; passed on to every head.
        dropout: Dropout probability passed on to every head.
        num_heads: Number of single-head attention modules to create.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, dimensions_in, dimensions_out, context_length,
                 dropout, num_heads, qkv_bias=False):
        super().__init__()
        # nn.ModuleList registers each head as a submodule, so its weights are
        # visible to .parameters(), .to(device), state_dict(), etc.
        self.heads = nn.ModuleList()
        for _ in range(num_heads):
            self.heads.append(
                CausalAttention(dimensions_in, dimensions_out, context_length,
                                dropout, qkv_bias)
            )

    def forward(self, x):
        """Apply every head to ``x`` and concatenate the results on the last axis."""
        per_head_outputs = [head(x) for head in self.heads]
        return torch.cat(per_head_outputs, dim=-1)

> Let's reuse the inputs and the `CausalAttention` class from before to test the wrapper.

In [3]:
# Toy 3-dimensional embeddings, one row per token of the sentence
# "Your journey starts with one step".
inputs = torch.tensor(
    [
        [0.43, 0.15, 0.89],  # Your     (x^1)
        [0.55, 0.87, 0.66],  # journey  (x^2)
        [0.57, 0.85, 0.64],  # starts   (x^3)
        [0.22, 0.58, 0.33],  # with     (x^4)
        [0.77, 0.25, 0.10],  # one      (x^5)
        [0.05, 0.80, 0.55],  # step     (x^6)
    ]
)

# Build a batch of two identical sequences: add a leading batch axis, then
# copy the sequence along it.
batch = inputs.unsqueeze(0).repeat(2, 1, 1)

In [4]:
class CausalAttention(nn.Module):
    """Single-head causal (masked) scaled dot-product self-attention.

    Queries, keys and values are linear projections of the input. Each token
    may attend only to itself and to earlier tokens: score entries strictly
    above the main diagonal are set to ``-inf`` before the softmax.

    Args:
        dimension_inputs: Size of each input embedding.
        dimension_outputs: Size of the query/key/value projections, and hence
            of each output context vector.
        context_length: Maximum number of tokens the causal mask supports.
        dropout: Dropout probability applied to the attention weights.
        qkv_bias: Whether the query/key/value projections use a bias term.
    """

    def __init__(self, dimension_inputs, dimension_outputs, context_length,
                 dropout, qkv_bias=False):
        super().__init__()
        self.dimension_outputs = dimension_outputs
        # Trainable projection matrices, implemented as Linear layers
        self.W_q = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_k = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        self.W_v = nn.Linear(dimension_inputs, dimension_outputs, bias=qkv_bias)
        # Dropout applied to the attention weights
        self.dropout = nn.Dropout(dropout)
        # Upper-triangular (future-position) mask. It is registered as a buffer
        # rather than a parameter, so it follows the module on .to(device) but
        # is not trained.
        self.register_buffer("mask",
                             torch.triu(torch.ones(context_length, context_length),
                                        diagonal=1))

    def forward(self, input_vectors):
        """Compute the causal context vectors for ``input_vectors``.

        Args:
            input_vectors: Tensor of shape ``(..., num_tokens, dimension_inputs)``.
                Any number of leading batch dimensions is accepted, including
                none (a single unbatched sequence).

        Returns:
            Tensor of shape ``(..., num_tokens, dimension_outputs)``.

        Raises:
            ValueError: If ``num_tokens`` exceeds the ``context_length`` that
                the mask was built for.
        """
        number_of_tokens = input_vectors.shape[-2]
        max_tokens = self.mask.shape[-1]
        if number_of_tokens > max_tokens:
            # Without this check, slicing the mask would silently return a
            # smaller matrix, and masked_fill_ would fail with an opaque
            # broadcasting error.
            raise ValueError(
                f"Sequence has {number_of_tokens} tokens but the causal mask "
                f"only supports context_length={max_tokens}"
            )

        queries = self.W_q(input_vectors)
        keys = self.W_k(input_vectors)
        values = self.W_v(input_vectors)

        # Transpose the last two axes (not the fixed axes 1 and 2) so that both
        # unbatched 2-D inputs and extra leading batch dimensions work.
        attention_scores = queries @ keys.transpose(-2, -1)
        # Hide future tokens: entries above the diagonal become -inf, which
        # the softmax turns into zero weight.
        attention_scores.masked_fill_(
            self.mask.bool()[:number_of_tokens, :number_of_tokens], -torch.inf)

        # Scale by sqrt(d_k) before the softmax to keep the scores well-conditioned
        dimension_keys = keys.shape[-1]
        attention_weights = torch.softmax(attention_scores / (dimension_keys ** 0.5), dim=-1)
        attention_weights = self.dropout(attention_weights)

        # Each context vector is the attention-weighted sum of the value vectors
        return attention_weights @ values

In [7]:
# Smoke-test the wrapper on the toy batch: two heads, each projecting the
# 3-dimensional inputs down to 2 dimensions, with dropout disabled.
torch.manual_seed(123)  # fixes the weight initialisation of every head
context_length = batch.shape[1]  # tokens per sequence
dimension_inputs = 3
dimension_outputs = 2
oMultiHeadAttention = MultiHeadAttentionWrapper(
    dimensions_in=dimension_inputs,
    dimensions_out=dimension_outputs,
    context_length=context_length,
    dropout=0.0,
    num_heads=2,
)
context_vectors = oMultiHeadAttention(batch)
print(context_vectors)

tensor([[[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]],

        [[-0.4519,  0.2216,  0.4772,  0.1063],
         [-0.5874,  0.0058,  0.5891,  0.3257],
         [-0.6300, -0.0632,  0.6202,  0.3860],
         [-0.5675, -0.0843,  0.5478,  0.3589],
         [-0.5526, -0.0981,  0.5321,  0.3428],
         [-0.5299, -0.1081,  0.5077,  0.3493]]], grad_fn=<CatBackward0>)


In [6]:
# Display the output shape: (batch size, tokens, num_heads * dimension_outputs)
context_vectors.size()

torch.Size([2, 6, 4])

> The dimensions here represent:
* 2: The batch size (the number of input sequences in the batch)
* 6: The number of tokens, or input vectors, in each sequence
* 4: This is because we selected two heads (`num_heads=2`). Each head outputs vectors of dimension 2, so concatenating the outputs of the two heads gives 2 + 2 = 4.