In [17]:
from typing import Optional

import equinox as eqx
import jax
import jax.numpy as jnp
from icecream import ic
from jaxtyping import Array, Float, Bool
import math

# Width of the raw feature vectors handed to the query/key/value projections.
query_input_dim = key_input_dim = value_input_dim = 16

# Per-head projection width; the Linear layers below emit num_heads * this many features.
query_embedding_dim = key_embedding_dim = value_embedding_dim = 32

num_heads = 4      # number of attention heads
max_seq_len = 10   # rows of the demo input sequence
batch_size = 2     # NOTE(review): not referenced by the cells visible here
output_dim = 32    # feature width produced by the final output projection

# Root PRNG key; later cells split it rather than using it directly.
key = jax.random.PRNGKey(42)

In [18]:
def dot_product_attention(query_projection: Array, key_projection: Array, value_projection: Array, mask: Optional[Array] = None) -> Array:
    """Compute per-head scaled dot-product attention weights.

    Note that this returns the softmax-normalised attention *weights* only;
    multiplying them with the values is left to the caller.

    Args:
        query_projection: Projected queries, laid out as
            ``(query_seq_len, num_heads, head_dim)`` (heads on axis 1).
        key_projection: Projected keys, laid out as
            ``(kv_seq_len, num_heads, head_dim)``.
        value_projection: Accepted for interface compatibility; it is not
            used when computing the weights.
        mask: Optional boolean mask where ``True`` marks positions that may be
            attended to. Either ``(query_seq_len, kv_seq_len)`` (shared by all
            heads) or anything broadcastable to
            ``(query_seq_len, num_heads, kv_seq_len)``. ``None`` disables
            masking.

    Returns:
        Attention weights of shape ``(query_seq_len, num_heads, kv_seq_len)``
        that sum to one along the last axis.

    Raises:
        ValueError: If ``mask`` cannot be broadcast to the weights' shape.
    """
    # Map over the head axis so each head gets its own (query_seq, kv_seq) score matrix.
    scores = jax.vmap(
        lambda q, k: q @ k.T,
        in_axes=(1, 1),
        out_axes=1
    )(query_projection, key_projection)

    # Scale by sqrt(head_dim) to keep the logits in a range where softmax is well-behaved.
    attention_weights = scores / jnp.sqrt(key_projection.shape[-1])

    if mask is not None:
        mask = jnp.asarray(mask).astype(bool)
        if mask.ndim == 2:
            # A (query_seq, kv_seq) mask is shared across heads.
            mask = mask[:, None, :]
        mask = jnp.broadcast_to(mask, attention_weights.shape)
        # Use the most negative finite value rather than -inf so a fully
        # masked row yields a uniform distribution instead of NaNs.
        attention_weights = jnp.where(
            mask, attention_weights, jnp.finfo(attention_weights.dtype).min
        )

    ic(attention_weights.shape)
    attention_weights = jax.nn.softmax(attention_weights, axis=-1)

    return attention_weights

In [22]:
# Version 1
class MultiheadAttention(eqx.Module):
    """Multi-head scaled dot-product attention built from bias-free linear layers.

    Each of the query/key/value inputs is projected row-by-row to
    ``num_heads * embedding_dim`` features, split into heads, combined with
    ``dot_product_attention`` and finally mapped to ``output_dim`` features.
    """

    query_projection: eqx.nn.Linear
    key_projection: eqx.nn.Linear
    value_projection: eqx.nn.Linear

    query_input_dim: int = eqx.field(static=True)
    query_embedding_dim: int = eqx.field(static=True)

    key_input_dim: int = eqx.field(static=True)
    key_embedding_dim: int = eqx.field(static=True)

    value_input_dim: int = eqx.field(static=True)
    value_embedding_dim: int = eqx.field(static=True)

    output: eqx.nn.Linear
    num_heads: int = eqx.field(static=True)
    output_dim: int = eqx.field(static=True)

    def __init__(self, query_embedding_dim, key_embedding_dim, value_embedding_dim, query_input_dim, key_input_dim, value_input_dim, num_heads, output_dim, key):
        """Create the four projection layers.

        Args:
            query_embedding_dim: Per-head width of the projected queries.
            key_embedding_dim: Per-head width of the projected keys; must equal
                ``query_embedding_dim`` because queries and keys are combined
                with a dot product.
            value_embedding_dim: Per-head width of the projected values.
            query_input_dim: Feature width of the query input.
            key_input_dim: Feature width of the key input.
            value_input_dim: Feature width of the value input.
            num_heads: Number of attention heads.
            output_dim: Feature width of the final output.
            key: JAX PRNG key; split four ways to initialise the layers.

        Raises:
            ValueError: If ``query_embedding_dim != key_embedding_dim``.
        """
        # Fail early with a clear message instead of an opaque matmul shape
        # error on the first forward pass.
        if query_embedding_dim != key_embedding_dim:
            raise ValueError(
                "query_embedding_dim and key_embedding_dim must match for "
                f"dot-product attention, got {query_embedding_dim} and {key_embedding_dim}"
            )

        qkey, kkey, vkey, okey = jax.random.split(key, 4)
        self.query_projection = eqx.nn.Linear(query_input_dim, num_heads * query_embedding_dim, key=qkey, use_bias=False)
        self.key_projection = eqx.nn.Linear(key_input_dim, num_heads * key_embedding_dim, key=kkey, use_bias=False)
        self.value_projection = eqx.nn.Linear(value_input_dim, num_heads * value_embedding_dim, key=vkey, use_bias=False)

        self.output = eqx.nn.Linear(num_heads * value_embedding_dim, output_dim, key=okey, use_bias=False)

        # parameters
        self.query_input_dim = query_input_dim
        self.query_embedding_dim = query_embedding_dim
        self.key_input_dim = key_input_dim
        self.key_embedding_dim = key_embedding_dim
        self.value_input_dim = value_input_dim
        self.value_embedding_dim = value_embedding_dim
        self.num_heads = num_heads
        self.output_dim = output_dim

    def __call__(
        self,
        x: Float[Array, "seq_len query_input_dim"],
        key_: Optional[Float[Array, "kv_seq_len key_input_dim"]] = None,
        value_: Optional[Float[Array, "kv_seq_len value_input_dim"]] = None,
    ) -> Float[Array, "seq_len output_dim"]:
        """Apply attention to a single (unbatched) sequence.

        Args:
            x: Query input of shape ``(seq_len, query_input_dim)``.
            key_: Optional key input of shape ``(kv_seq_len, key_input_dim)``.
                Defaults to ``x`` (self-attention).
            value_: Optional value input of shape
                ``(kv_seq_len, value_input_dim)``. Defaults to ``key_``.

        Returns:
            Array of shape ``(seq_len, output_dim)``.

        Raises:
            ValueError: If the key and value inputs have different lengths.
        """
        if key_ is None:
            key_ = x
        if value_ is None:
            value_ = key_
        if key_.shape[0] != value_.shape[0]:
            raise ValueError(
                "key and value inputs must have the same sequence length, "
                f"got {key_.shape[0]} and {value_.shape[0]}"
            )

        seq_len, _ = x.shape
        kv_seq_len = key_.shape[0]

        # Project every row, then split the feature axis into (heads, per-head width).
        query = jax.vmap(self.query_projection)(x).reshape(seq_len, self.num_heads, self.query_embedding_dim)
        key = jax.vmap(self.key_projection)(key_).reshape(kv_seq_len, self.num_heads, self.key_embedding_dim)
        value = jax.vmap(self.value_projection)(value_).reshape(kv_seq_len, self.num_heads, self.value_embedding_dim)
        ic(query.shape)
        ic(key.shape)
        ic(value.shape)

        scaled_attention_weights = dot_product_attention(query, key, value, None)

        # Per head: (seq_len, kv_seq_len) @ (kv_seq_len, value_embedding_dim).
        qkv_matmul = jax.vmap(
            lambda head_weights, head_values: head_weights @ head_values,
            in_axes=(1, 1),
            out_axes=1
        )(scaled_attention_weights, value)

        ic(qkv_matmul.shape)

        # Merge the head axis back into the feature axis before the output layer.
        concatenation = qkv_matmul.reshape(seq_len, -1)

        ic(concatenation.shape)

        output = jax.vmap(self.output)(concatenation)
        ic(output.shape)
        return output

# Advance the notebook-level key: the new `key` seeds the model parameters,
# while `subkey` is reserved for sampling the demo input.
key, subkey = jax.random.split(key)

mha = MultiheadAttention(
    query_embedding_dim=query_embedding_dim,
    key_embedding_dim=key_embedding_dim,
    value_embedding_dim=value_embedding_dim,
    query_input_dim=query_input_dim,
    key_input_dim=key_input_dim,
    value_input_dim=value_input_dim,
    num_heads=num_heads,
    output_dim=output_dim,
    key=key,
)

# One unbatched sequence of max_seq_len rows, each query_input_dim wide.
x = jax.random.normal(subkey, shape=(max_seq_len, query_input_dim))
output = mha(x)

ic| query.shape: (10, 4, 32)


ic| key.shape: (10, 4, 32)
ic| value.shape: (10, 4, 32)
ic| attention_weights.shape: (10, 4, 10)
ic| qkv_matmul.shape: (10, 4, 32)
ic| concatenation.shape: (10, 128)
ic| output.shape: (10, 32)
