In [73]:
import functools as ft
import math
import warnings
from typing import Literal, Optional, Tuple, Union, overload, Callable

import equinox as eqx
import jax
import jax.lax as lax
import jax.numpy as jnp
import jax.random as jrandom
from equinox.nn import Dropout, Linear, State, StateIndex
from jaxtyping import Array, Bool, Float, PRNGKeyArray
import torch

In [104]:
def repeat_kv(x: Array, n_rep: int) -> Array:
    """torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
    slen, n_kv_heads, head_dim = x.shape
    if n_rep == 1:
        return x
    return jnp.repeat(x, n_rep, axis=1).reshape(slen, n_kv_heads * n_rep, head_dim)


In [105]:
def dot_product_attention_weights(
    query: Float[Array, "q_seq qk_size"],
    key: Float[Array, "kv_seq qk_size"],
    mask: Optional[Bool[Array, "q_seq kv_seq"]] = None,
) -> Float[Array, "q_seq kv_seq"]:
    query = query / math.sqrt(query.shape[-1])
    logits = jnp.einsum("sd,Sd->sS", query, key)
    if mask is not None:
        if mask.shape != logits.shape:
            raise ValueError(
                f"mask must have shape (query_seq_length, "
                f"kv_seq_length)=({query.shape[0]}, "
                f"{key.shape[0]}). Got {mask.shape}."
            )
        logits = jnp.where(mask, logits, jnp.finfo(logits.dtype).min)

    return jax.nn.softmax(logits, axis=-1)


def dot_product_attention(
    query: Float[Array, "q_seq qk_size"],
    key_: Float[Array, "kv_seq qk_size"],
    value: Float[Array, "kv_seq v_size"],
    mask: Optional[Bool[Array, "q_seq kv_seq"]] = None,
    dropout: Optional[Dropout] = None,
    *,
    key: Optional[PRNGKeyArray] = None,
    inference: Optional[bool] = None,
) -> Float[Array, "q_seq v_size"]:
    weights = dot_product_attention_weights(query, key_, mask)
    if dropout is not None:
        weights = dropout(weights, key=key, inference=inference)
    attn = jnp.einsum("sS,Sd->sd", weights, value)
    return attn


def vmapped_attention(
    query_heads: Float[Array, "q_seq qk_size"],
    key_heads: Float[Array, "kv_seq qk_size"],
    value_heads: Float[Array, "kv_seq v_size"],
    dropout: Optional[Dropout] = None,
    inference: Optional[bool] = None,
    mask: Optional[Float[Array, "q_seq kv_seq"]] = None,
    keys: Optional[PRNGKeyArray] = None,
):
    attn_fn = ft.partial(
        dot_product_attention, dropout=dropout, inference=inference, key=keys, mask=mask
    )
    # Batch `keys` down its first axis as it is passed as a keyword argument.
    dpa = jax.vmap(
        lambda q, k, v: attn_fn(q, k, v),
        in_axes=(1, None, None),
        out_axes=1,
    )(query_heads, key_heads, value_heads)
    return dpa


class MultiheadAttention(eqx.Module):

    query_proj: Linear
    key_proj: Linear
    value_proj: Linear
    output_proj: Linear
    dropout: Dropout
    autoregressive_index: Optional[StateIndex]

    num_heads: int = eqx.field(static=True)
    query_size: int = eqx.field(static=True)
    key_size: int = eqx.field(static=True)
    value_size: int = eqx.field(static=True)
    output_size: int = eqx.field(static=True)

    state_length: Optional[int] = eqx.field(static=True)
    qk_size: int = eqx.field(static=True)
    vo_size: int = eqx.field(static=True)
    use_query_bias: bool = eqx.field(static=True)
    use_key_bias: bool = eqx.field(static=True)
    use_value_bias: bool = eqx.field(static=True)
    use_output_bias: bool = eqx.field(static=True)

    kv_multihead_dim: int = eqx.field(static=True)

    def __init__(
        self,
        num_heads: int,
        query_size: int,
        *,
        key_size: Optional[int] = None,
        value_size: Optional[int] = None,
        output_size: Optional[int] = None,
        kv_multihead_dim: Optional[int] = None,
        state_length: Optional[int] = None,
        qk_size: Optional[int] = None,
        vo_size: Optional[int] = None,
        use_query_bias: bool = False,
        use_key_bias: bool = False,
        use_value_bias: bool = False,
        use_output_bias: bool = False,
        dropout_p: float = 0.0,
        inference: bool = False,
        key: PRNGKeyArray,
        **kwargs,
    ):
        super().__init__(**kwargs)
        qkey, kkey, vkey, okey = jrandom.split(key, 4)

        if key_size is None:
            key_size = query_size
        if value_size is None:
            value_size = query_size
        if qk_size is None:
            qk_size = query_size // num_heads
        if vo_size is None:
            vo_size = query_size // num_heads
        if output_size is None:
            output_size = query_size

        def _make_autoregressive_cache(**_):
            if state_length is None:
                raise ValueError(
                    "Cannot use autoregressive decoding without specifying "
                    "`MultiheadAttention(..., state_length=...)`."
                )
            if kv_multihead_dim:
                key_shape = state_length, num_heads, qk_size
                value_shape = state_length, num_heads, vo_size
            else:
                key_shape = state_length, qk_size
                value_shape = state_length, vo_size
            if jax.config.jax_enable_x64:  # pyright: ignore
                _int = jnp.int64
            else:
                _int = jnp.int32
            return jnp.empty(key_shape), jnp.empty(value_shape), jnp.zeros((), _int)

        query_proj_out_size = qk_size
        key_proj_out_size = qk_size
        value_proj_out_size = vo_size

        kv_multihead_dim = num_heads if kv_multihead_dim is None else kv_multihead_dim

        query_proj_out_size = query_proj_out_size * num_heads
        key_proj_out_size = key_proj_out_size * kv_multihead_dim
        value_proj_out_size = value_proj_out_size * kv_multihead_dim

        self.query_proj = Linear(
            query_size, query_proj_out_size, use_bias=use_query_bias, key=qkey
        )
        self.key_proj = Linear(
            key_size, key_proj_out_size, use_bias=use_key_bias, key=kkey
        )
        self.value_proj = Linear(
            value_size, value_proj_out_size, use_bias=use_value_bias, key=vkey
        )

        self.output_proj = Linear(
            vo_size * num_heads, output_size, use_bias=use_output_bias, key=okey
        )
        self.dropout = Dropout(dropout_p, inference=inference)
        self.autoregressive_index = StateIndex(_make_autoregressive_cache()) if state_length is not None else None

        self.num_heads = num_heads
        self.query_size = query_size

        self.key_size = key_size
        self.value_size = value_size
        self.output_size = output_size
        self.qk_size = qk_size
        self.vo_size = vo_size
        self.use_query_bias = use_query_bias
        self.use_key_bias = use_key_bias
        self.use_value_bias = use_value_bias
        self.use_output_bias = use_output_bias
        self.state_length = state_length
        self.kv_multihead_dim = kv_multihead_dim
        

    def __call__(
        self,
        query: Float[Array, "q_seq q_size"],
        key_: Float[Array, "kv_seq k_size"],
        value: Float[Array, "kv_seq v_size"],
        mask: Union[
            None,
            Bool[Array, "q_seq kv_seq"],
            Bool[Array, "num_heads q_seq kv_seq"],
            Literal["causal"],
        ] = None,
        state: Optional[State] = None,
        *,
        key: Optional[PRNGKeyArray] = None,
        inference: Optional[bool] = None,
        deterministic: Optional[bool] = None,
        process_heads: Optional[
            Callable[
                [
                    Float[Array, "query_size num_heads qk_size"],
                    Float[Array, "key_size num_heads qk_size"],
                    Float[Array, "value_size num_heads vo_size"],
                ],
                Tuple[
                    Float[Array, "query_size num_heads qk_size"],
                    Float[Array, "key_size num_heads qk_size"],
                    Float[Array, "value_size num_heads vo_size"],
                ],
            ]
        ] = None,
    ) -> Union[
        Float[Array, "q_seq o_size"], Tuple[Float[Array, "q_seq o_size"], State]
    ]:
        if deterministic is not None:
            inference = deterministic
            warnings.warn(
                "MultiheadAttention()(deterministic=...) is deprecated "
                "in favour of MultiheadAttention()(inference=...)"
            )

        query_seq_length, _ = query.shape
        kv_seq_length, _ = key_.shape
        kv_seq_length2, _ = value.shape
        if kv_seq_length != kv_seq_length2:
            # query length can be different
            raise ValueError("key and value must both be sequences of equal length.")
        del kv_seq_length2

        # query_heads = self._project(self.query_proj, self.query_multihead, query)
        # key_heads = self._project(self.key_proj, self.key_multihead, key_)
        # value_heads = self._project(self.value_proj, self.value_multihead, value)

        query_heads = self._new_project(
            self.query_proj, self.num_heads, query
        )
        key_heads = self._new_project(self.key_proj, self.kv_multihead_dim, key_)
        value_heads = self._new_project(self.value_proj, self.kv_multihead_dim, value)

        if process_heads is not None:
            q_shape, k_shape, v_shape = (
                query_heads.shape,
                key_heads.shape,
                value_heads.shape,
            )
            query_heads, key_heads, value_heads = process_heads(
                query_heads, key_heads, value_heads
            )

            if (
                query_heads.shape != q_shape
                or key_heads.shape != k_shape
                or value_heads.shape != v_shape
            ):
                raise ValueError(
                    "process_heads must not change the shape of the heads."
                )

        if state is None:
            causal_mask_offset = 0
        else:
            key_state, value_state, index = state.get(self.autoregressive_index)
            key_state = lax.dynamic_update_slice_in_dim(
                key_state, key_heads, index, axis=0
            )
            value_state = lax.dynamic_update_slice_in_dim(
                value_state, value_heads, index, axis=0
            )
            causal_mask_offset = index
            index = index + kv_seq_length
            state = state.set(
                self.autoregressive_index, (key_state, value_state, index)
            )
            key_heads = key_state
            value_heads = value_state
            kv_seq_length = self.state_length

        if mask == "causal":
            query_indices = jnp.arange(query_seq_length)[:, None]
            kv_indices = jnp.arange(kv_seq_length)[None, :]
            mask = kv_indices <= query_indices + causal_mask_offset
        if state is not None:
            # Also mask out the latter parts of the state we haven't written into yet.
            unwritten_mask = jnp.arange(self.state_length) < index  # pyright: ignore
            if mask is None:
                mask = jnp.broadcast_to(
                    unwritten_mask, (query_seq_length, self.state_length)
                )
            else:
                mask = mask & unwritten_mask

        keys = None if key is None else jax.random.split(key, self.num_heads)
        if self.kv_multihead_dim == self.num_heads:
            # Normal multi-head attention
            attn_fn = ft.partial(
                dot_product_attention, dropout=self.dropout, inference=inference
            )
            print(query_heads.shape, key_heads.shape, value_heads.shape)
            in_axes = (1, 1, 1, 0 if mask is not None and mask.ndim == 3 else None)
            attn = jax.vmap(
                attn_fn, in_axes=in_axes, out_axes=1, axis_size=self.num_heads
            )(query_heads, key_heads, value_heads, mask, key=keys)
        else:
            pt_vmapped_attention = ft.partial(
                vmapped_attention,
                dropout=self.dropout,
                inference=inference,
                mask=mask,
                keys=keys,
            )
            attn = jax.vmap(pt_vmapped_attention, in_axes=(None, 1, 1), out_axes=1)(
                query_heads, key_heads, value_heads
            )

            attn = jnp.sum(attn, axis=1)
            # Taking the mean over the d dimension
            attn = attn / self.kv_multihead_dim

        attn = attn.reshape(query_seq_length, self.num_heads * self.vo_size)
        out = jax.vmap(self.output_proj)(attn)

        if state is None:
            return out
        else:
            return out, state

    def _project(self, proj, multihead, x):
        seq_length, _ = x.shape
        projection = jax.vmap(proj)(x)
        if multihead:
            _, projection_size = projection.shape
            size_per_head = projection_size // self.num_heads
            projection = projection.reshape(seq_length, self.num_heads, size_per_head)
        return projection

    def _new_project(self, proj: eqx.Module, multihead: int | None, x: Array) -> Array:
        seq_length, _ = x.shape
        projection = jax.vmap(proj)(x)

        if multihead is not None:
            _, projection_size = projection.shape
            size_per_head = projection_size // multihead
            projection = projection.reshape(seq_length, multihead, size_per_head)

        return projection


def self_attention(
    num_heads: int,
    size: int,
    *,
    multiquery: bool = False,
    state_length: Optional[int] = None,
    key: PRNGKeyArray,
):
    return MultiheadAttention(
        num_heads=num_heads,
        query_size=size,
        state_length=state_length,
        key_multihead=not multiquery,
        value_multihead=not multiquery,
        key=key,
    )


In [106]:
num_heads = 6
query_size = 12
kv_multihead_dim = 6
key = jax.random.PRNGKey(33)

mha = MultiheadAttention(num_heads=num_heads, query_size=query_size, kv_multihead_dim=kv_multihead_dim, key=key)
seq = 8

In [107]:
query = jax.random.uniform(key=jax.random.PRNGKey(20), shape=(seq, query_size))
key_ = jax.random.uniform(key=jax.random.PRNGKey(21), shape=(seq, query_size))
value = jax.random.uniform(key=jax.random.PRNGKey(22), shape=(seq, query_size))

In [108]:
val = mha(query, key_, value)
val.shape

(8, 6, 2) (8, 6, 2) (8, 6, 2)


(8, 12)

(8, 12, 2)

In [121]:
num_heads = 6
query_size = 12
kv_multihead_dim = 4
key = jax.random.PRNGKey(33)

mha = MultiheadAttention(num_heads=num_heads, query_size=query_size, kv_multihead_dim=kv_multihead_dim, key=key)
seq = 8

In [122]:
query = jax.random.uniform(key=jax.random.PRNGKey(20), shape=(seq, query_size))
key_ = jax.random.uniform(key=jax.random.PRNGKey(21), shape=(seq, query_size))
value = jax.random.uniform(key=jax.random.PRNGKey(22), shape=(seq, query_size))

In [123]:
val = mha(query, key_, value)
val.shape

(8, 12)

In [124]:
q = jnp.ones(shape=(seq, num_heads, 2))
k = jnp.ones(shape=(seq, kv_multihead_dim, 2))
v = jnp.ones(shape=(seq, kv_multihead_dim, 2))

repeat_kv(k, 2).shape

(8, 8, 2)