## Multihead-Attention
Multihead-Attention (MHA) is the key part of the transformer architecture. It uses a mechanism called _self-attention_, which has been very successful in NLP tasks. I already introduced the transformer in my other [blog post](/posts/learning-the-transformer). In this blog post, we will take a deep dive into MHA and try to make it as general as we can. Let's start with an overview: the following figure shows the structure of the MHA block.

![MHA Overview](MHA.drawio.svg)

The input to the MHA block is duplicated into 3 vectors: the _query_, _key_ and _value_ vectors. Each of these is first passed through its own linear layer, and each layer has a specific task:

- Query Layer: transforms the input to query vectors, i.e. _what you're interested in_
- Key Layer: transforms the input to a set of keys to match the query vectors against
- Value Layer: transforms the input to value vectors, which are then combined, weighted by the scaled query-key scores, to compute the output of the MHA block
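
To make these three roles concrete, here is a minimal single-head sketch in plain JAX. The weight matrices `w_q`, `w_k`, `w_v` and all sizes are made up for illustration; this is not the implementation built in this post, just the idea behind it.

```python
import jax
import jax.numpy as jnp

seq_len, in_dim, embd_dim = 5, 8, 4  # illustrative sizes, not from the post
k_x, k_q, k_k, k_v = jax.random.split(jax.random.PRNGKey(0), 4)

x = jax.random.normal(k_x, (seq_len, in_dim))
# Hypothetical stand-ins for the query, key and value layers (no bias)
w_q = jax.random.normal(k_q, (in_dim, embd_dim))
w_k = jax.random.normal(k_k, (in_dim, embd_dim))
w_v = jax.random.normal(k_v, (in_dim, embd_dim))

q, k, v = x @ w_q, x @ w_k, x @ w_v

# Match every query against every key, scale, and normalise per query
scores = (q @ k.T) / jnp.sqrt(embd_dim)
weights = jax.nn.softmax(scores, axis=-1)  # each row sums to 1

# The output for each position is a weighted average of the value vectors
out = weights @ v
print(out.shape)
```

Each row of `weights` says how much one position "listens" to every other position, and the value vectors are what actually gets mixed into the output.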

This is a good starting point, so let's write this down in code.

_(By the way, most of this was already covered in my [previous blog post](/posts/learning-the-transformer) and this implementation takes heavy inspiration from already existing implementations such as the [MHA block from Equinox](https://github.com/patrick-kidger/equinox/blob/main/equinox/nn/_attention.py))_

One thing to note is the dimensionalities of the vectors, so let's start by defining the dimensions first. Here's the notation for this blog post:

- $L$: maximum sequence length
- $h$: number of heads
- $\{q,k,v\}_{embd}$: query, key or value embedding dimension

Furthermore, let's define the input to the MHA block.

- Query: $[L \times q_{in}]$, where $q_{in}$ is the query input dimension
- Key: $[L \times k_{in}]$, where $k_{in}$ is the key input dimension
- Value: $[L \times v_{in}]$, where $v_{in}$ is the value input dimension

Usually, the query, key and value input dimensions are the same, because all three typically come from the same input embeddings (with the positional embeddings added on top). We want our implementation to be general and make as few assumptions as possible about the use case, though, so we keep the three input dimensions separate.
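
As a quick illustration of why separate input dimensions are harmless, here is a small sketch in plain JAX. The input sizes and the random weight matrices are arbitrary assumptions standing in for real projection layers; the point is only that each projection maps its own input dimension to the embedding dimension, so the shapes afterwards no longer depend on where the inputs came from.

```python
import jax
import jax.numpy as jnp

seq_len, embd = 10, 32
in_dims = {"query": 16, "key": 24, "value": 8}  # deliberately different, arbitrary

rng = jax.random.PRNGKey(0)
projected = {}
for name, in_dim in in_dims.items():
    rng, x_key, w_key = jax.random.split(rng, 3)
    x = jax.random.normal(x_key, (seq_len, in_dim))  # [L x {q,k,v}_in]
    w = jax.random.normal(w_key, (in_dim, embd))     # hypothetical bias-free projection
    projected[name] = x @ w                          # [L x embd]

for name, arr in projected.items():
    print(name, arr.shape)
```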

```python
import jax
import jax.numpy as jnp
import equinox as eqx
from jaxtyping import Float, Array


query_input_dim = 16
query_embedding_dim = 32
key_input_dim = 16
key_embedding_dim = 32
value_input_dim = 16
value_embedding_dim = 32
num_heads = 4
max_seq_len = 10
batch_size = 2
key = jax.random.PRNGKey(42)
```

```python
# Version 1
class MultiheadAttention(eqx.Module):
    query_projection: eqx.nn.Linear
    key_projection: eqx.nn.Linear
    value_projection: eqx.nn.Linear

    query_input_dim: int = eqx.field(static=True)
    query_embedding_dim: int = eqx.field(static=True)

    key_input_dim: int = eqx.field(static=True)
    key_embedding_dim: int = eqx.field(static=True)

    value_input_dim: int = eqx.field(static=True)
    value_embedding_dim: int = eqx.field(static=True)

    num_heads: int = eqx.field(static=True)

    def __init__(self, query_embedding_dim, key_embedding_dim, value_embedding_dim, query_input_dim, key_input_dim, value_input_dim, num_heads, key):
        qkey, kkey, vkey = jax.random.split(key, 3)
        self.query_projection = eqx.nn.Linear(query_input_dim, num_heads * query_embedding_dim, key=qkey, use_bias=False)
        self.key_projection = eqx.nn.Linear(key_input_dim, num_heads * key_embedding_dim, key=kkey, use_bias=False)
        self.value_projection = eqx.nn.Linear(value_input_dim, num_heads * value_embedding_dim, key=vkey, use_bias=False)

        # parameters
        self.query_input_dim = query_input_dim
        self.query_embedding_dim = query_embedding_dim
        self.key_input_dim = key_input_dim
        self.key_embedding_dim = key_embedding_dim
        self.value_input_dim = value_input_dim
        self.value_embedding_dim = value_embedding_dim
        self.num_heads = num_heads

    def __call__(self, x: Float[Array, "max_seq_len input_dim"]):
        seq_len, _ = x.shape
        query = jax.vmap(self.query_projection)(x).reshape(seq_len, self.num_heads, self.query_embedding_dim)
        key = jax.vmap(self.key_projection)(x).reshape(seq_len, self.num_heads, self.key_embedding_dim)
        value = jax.vmap(self.value_projection)(x).reshape(seq_len, self.num_heads, self.value_embedding_dim)
        print(f"{query.shape=}")
        print(f"{key.shape=}")
        print(f"{value.shape=}")

key, subkey = jax.random.split(key)
mha = MultiheadAttention(query_embedding_dim, key_embedding_dim, value_embedding_dim, query_input_dim, key_input_dim, value_input_dim, num_heads, key)
x = jax.random.normal(subkey, (max_seq_len, query_input_dim))
mha(x)
```

```
query.shape=(10, 4, 32)
key.shape=(10, 4, 32)
value.shape=(10, 4, 32)
```


As mentioned in my previous blog post, an MHA block consists of multiple _heads_. But instead of looping over the heads one at a time, we can simply enlarge the query, key and value layers to cover all of the heads at once. Look at it this way: taking all of the heads into consideration, the **output** shape of, say, the query projection should be:

$$
    [L \times h \times q_{embd}]
$$

By making the query projection layer project from $q_{in}$ to $h \cdot q_{embd}$, we initially get a matrix of shape $[L \times (h \cdot q_{embd})]$. From there, we can simply reshape that matrix into our desired shape: $[L \times h \times q_{embd}]$.
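
If the reshape trick feels like magic, the following sketch checks it numerically in plain JAX. It uses a random matrix `w_fused` as a stand-in for the enlarged query layer (an assumption for illustration, not the Equinox layer itself) and compares the "project once, then reshape" route against an explicit loop over heads, where each head gets its own slice of the weights.

```python
import jax
import jax.numpy as jnp

L, h, q_in, q_embd = 10, 4, 16, 32  # same sizes as in the example above

x_key, w_key = jax.random.split(jax.random.PRNGKey(0))
x = jax.random.normal(x_key, (L, q_in))
# One fused weight matrix standing in for the enlarged query layer
w_fused = jax.random.normal(w_key, (q_in, h * q_embd))

# Fused route: project once, then reshape to [L x h x q_embd]
fused = (x @ w_fused).reshape(L, h, q_embd)

# Looped route: give each head its own slice of the weight matrix
w_heads = w_fused.reshape(q_in, h, q_embd)
looped = jnp.stack([x @ w_heads[:, i, :] for i in range(h)], axis=1)

print(fused.shape, jnp.allclose(fused, looped, atol=1e-4))
```

Both routes produce the same numbers (up to floating-point error), so the fused layer is just a more efficient way of running all heads at once.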

These are just the first steps for the query, key and value projections; they still have quite the journey ahead. We still need to matrix multiply, scale, (sometimes) mask and softmax them. Let's write a function that can do all of that in one go.