In [4]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

import functools
from typing import Callable, Optional, Tuple, TypeVar
from jaxtyping import Float, Array, PRNGKeyArray, Bool

import jax
import jax.numpy as jnp
import jax.random as jr
import flax.linen as nn


In [6]:
Dtype = TypeVar("Dtype", bound=jnp.dtype)
Shape = Tuple[int, ...]


In [3]:
from tx.models.gpt2 import PretrainedGPT2Model
from tx.network import GenerativeModel
from transformers import GPT2TokenizerFast


reference_gpt2 = GenerativeModel(
    config=PretrainedGPT2Model.tx_config,
    variables={"params": PretrainedGPT2Model.from_pretrained("gpt2").to_params()},
    tokenizer=GPT2TokenizerFast.from_pretrained("gpt2"),
)

In [7]:
InitFn = Callable[[PRNGKeyArray, Shape, Dtype], Array]


class FlaxAttentionImplementation(nn.Module):
    num_heads: int
    qkv_features: Optional[int] = None
    out_features: Optional[int] = None
    kernel_init: InitFn = nn.initializers.xavier_uniform()
    bias_init: InitFn = nn.initializers.zeros_init()
    use_bias: bool = True

    @nn.compact
    def __call__(
        self,
        inputs_q: Float[Array, "... QL"],  # QL = query length
        mask: Optional[Array] = None,
    ):
        features = self.out_features or inputs_q.shape[-1]
        qkv_features = self.qkv_features or inputs_q.shape[-1]
        head_dim = qkv_features // self.num_heads
        # inputs_kv = inputs_kv if inputs_kv is not None else inputs_q

        dense = functools.partial(
            nn.linear.DenseGeneral,
            axis=-1,
            features=(self.num_heads, head_dim),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
        )

        # project inputs_q to multi-headed q/k/v
        # dimensions are then [batch..., length, n_heads, n_features_per_head]

        query, key, value = (
            dense(name="query")(inputs_q),
            dense(name="key")(inputs_q),
            dense(name="value")(inputs_q),
        )

        # calculate attention matrix
        depth = query.shape[-1]
        query = query / jnp.sqrt(depth)

        # attn weight shape is (batch..., num_heads, q_length, kv_length)
        attn_scores = jnp.einsum("...qhd,...khd->...hqk", query, key)

        # apply attention mask
        if mask is not None:
            big_neg = jnp.finfo(jnp.float32).min
            attn_scores = jnp.where(mask, attn_scores, big_neg)

        # normalize the attention weights
        attn_weights = jax.nn.softmax(attn_scores)

        # return weighted sum over values for each query position
        z = jnp.einsum("...hqk,...khd->...qhd", attn_weights, value)

        # back to the original inputs dimensions
        out = nn.linear.DenseGeneral(
            features=features,
            axis=(-2, -1),
            kernel_init=self.kernel_init,
            bias_init=self.bias_init,
            use_bias=self.use_bias,
            name="out",  # type: ignore[call-arg]
        )(z)

        return out


In [8]:
from typing import List, Tuple
from jaxtyping import Array, Float

from functools import partial

import jax
import jax.numpy as jnp
import flax.linen as nn
import flax.struct as struct


class TXAttention(nn.Module):
    num_heads: int
    head_dim: int
    model_dim: int
    init_range: float = 0.02
    use_bias: bool = True

    intermediates: List[str] = struct.field(default_factory=list)

    def intermediate(self, name: str, value: Array) -> bool:
        if name in self.intermediates:
            return self.sow("intermediates", name, value)
        return False

    @nn.compact
    def __call__(self, x: Float[Array, "seq embed"]) -> Float[Array, "seq embed"]:
        """
        References:
        - `flax.linen.attention`.
        """
        init_dense = partial(
            nn.DenseGeneral,
            kernel_init=jax.nn.initializers.normal(stddev=self.init_range),
            bias_init=jax.nn.initializers.zeros,
            use_bias=self.use_bias,
        )

        # Apply a linear transformation to the input tensor.
        hidden_states = init_dense(name="c_attn", features=3 * self.model_dim)(x)

        # Split the hidden states into query, key, and value.
        query, key, value = self._split_outputs(hidden_states)
        query_length, key_length = query.shape[-3], key.shape[-3]
        print(query.shape, key.shape, value.shape)
        self._qkv_intermediates((query, key, value))

        # Compute the attention weights.
        query = query / jnp.sqrt(query.shape[-1])
        scores = jnp.einsum("...qhd,...khd->...hqk", query, key)
        self.intermediate("scores", scores)

        # Apply the causal mask to the attention weights.
        mask = nn.make_causal_mask(jnp.ones((x.shape[0],), dtype="bool"))[
            :, :query_length, :key_length
        ]
        big_neg = jnp.finfo(jnp.float32).min
        scores = jnp.where(mask, scores, big_neg)

        # Normalize the attention weights
        weights = jax.nn.softmax(scores)
        self.intermediate("weights", weights)

        # Apply the attention pattern to the value tensor.
        z = jnp.einsum("...hqk,...khd->...qhd", weights, value)
        self.intermediate("z", z)

        # Apply a linear transformation to the attention output.
        merged_z = self._merge_heads(z)
        output = init_dense(name="c_proj", features=self.model_dim)(merged_z)
        return output

    def _qkv_intermediates(self, qkv: Tuple[Array, Array, Array]) -> bool:
        ret_vals = []
        for name, value in zip(("query", "key", "value"), qkv):
            ret_vals.append(self.intermediate(name, value))

        return all(ret_vals)

    def _split_outputs(self, states: Array):
        return map(self._split_heads, jnp.split(states, 3, axis=-1))

    def _split_heads(self, states: Array):
        return states.reshape((states.shape[0], self.num_heads, self.head_dim))

    def _merge_heads(self, states: Array):
        return states.reshape((states.shape[0], self.model_dim))


## Compare Attention Implementations

In [11]:
NUM_HEADS = 4
HEAD_DIM = 10
MODEL_DIM = NUM_HEADS * HEAD_DIM

flax_attn = nn.MultiHeadDotProductAttention(
    num_heads=NUM_HEADS, qkv_features=MODEL_DIM, out_features=MODEL_DIM
)
tx_attn = TXAttention(num_heads=NUM_HEADS, head_dim=HEAD_DIM, model_dim=MODEL_DIM)


def convert_params(tx_params):
    c_attn, c_proj = tx_params["c_attn"], tx_params["c_proj"]

    qkv_kernel = jnp.split(c_attn["kernel"], 3, axis=-1)
    reshape_kernel = lambda a: jnp.reshape(a, (qkv_kernel[0].shape[0], NUM_HEADS, HEAD_DIM))
    q_kernel, k_kernel, v_kernel = tuple(map(reshape_kernel, qkv_kernel))
    o_kernel = jnp.reshape(c_proj["kernel"], (NUM_HEADS, HEAD_DIM, MODEL_DIM))

    qkv_bias = jnp.split(c_attn["bias"], 3, axis=-1)
    reshape_bias = lambda a: jnp.reshape(a, (NUM_HEADS, HEAD_DIM))
    q_bias, k_bias, v_bias = tuple(map(reshape_bias, qkv_bias))
    o_bias = c_proj["bias"]

    flax_params = {}
    flax_params["query"] = {"kernel": q_kernel, "bias": q_bias}
    flax_params["key"] = {"kernel": k_kernel, "bias": k_bias}
    flax_params["value"] = {"kernel": v_kernel, "bias": v_bias}
    flax_params["out"] = {"kernel": o_kernel, "bias": o_bias}

    return flax_params


ex_inputs = jnp.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
tx_params = tx_attn.init(jr.PRNGKey(0), x=ex_inputs)["params"]
flax_params = convert_params(tx_params)

tx_out = tx_attn.apply({"params": tx_params}, ex_inputs)
flax_out = flax_attn.apply(
    {"params": flax_params},
    ex_inputs,
    ex_inputs,
    mask=nn.make_causal_mask(
        jnp.ones((ex_inputs.shape[0],), dtype="bool"),
        dtype="bool",
    )[:, :4, :4],
)

jnp.allclose(tx_out, flax_out, atol=1e-6)


(4, 3, 12) (4, 3, 12) (4, 3, 12)
(4, 3, 12) (4, 3, 12) (4, 3, 12)


Array(True, dtype=bool)

In [7]:

jax.tree_util.tree_map(lambda a: a.shape, tx_as_flax)


{'key': {'kernel': (3, 2, 4)},
 'out': {'kernel': (2, 4, 8)},
 'query': {'kernel': (3, 2, 4)},
 'value': {'kernel': (3, 2, 4)}}

In [8]:
tx_output = tx_attn.apply({"params": tx_params}, ex_inputs)
tx_output

Array([[ 3.5190952e-04, -3.6897347e-04,  3.5196170e-04, -2.6260878e-04,
        -2.1131843e-04,  3.2719318e-04, -5.1394856e-04,  5.3117721e-04],
       [ 3.9654266e-04, -1.9495170e-04,  3.9853103e-04, -5.1042713e-05,
        -6.5240380e-04,  2.3290190e-04, -4.9521885e-04,  4.6091323e-04],
       [ 5.2868749e-04, -2.5989593e-04,  5.3130480e-04, -6.8043766e-05,
        -8.6979882e-04,  3.1042617e-04, -6.6020573e-04,  6.1448303e-04],
       [ 2.9878583e-04,  2.6193861e-04,  3.0228859e-04,  8.6668959e-05,
        -5.8025989e-04,  3.4064267e-04, -4.2032293e-04,  6.4078835e-04]],      dtype=float32)

In [9]:
flax_output = flax_attn.apply({"params": flax_params}, ex_inputs)
flax_output

Array([[-0.18964374,  0.16998684,  0.00492778, -0.39047906,  0.12759988,
        -0.03431374,  0.2705934 ,  0.649871  ],
       [-0.15222657,  0.15960044,  0.0206019 , -0.3847821 ,  0.09544811,
        -0.04913256,  0.21884091,  0.64352965],
       [-0.17099951,  0.15805584,  0.01345147, -0.3938407 ,  0.10670467,
        -0.04940178,  0.24964438,  0.65235615],
       [-0.18512593,  0.17309962, -0.02557041, -0.4040893 ,  0.14441134,
        -0.01572113,  0.27858272,  0.6612222 ]], dtype=float32)

In [10]:
tx_as_flax_output = flax_attn.apply({"params": tx_as_flax}, ex_inputs)
tx_as_flax_output

Array([[ 2.9841467e-04,  2.6236032e-04,  3.0198149e-04,  8.6865257e-05,
        -5.7999179e-04,  3.4076214e-04, -4.2007797e-04,  6.4062094e-04],
       [ 2.9800236e-04,  2.6269123e-04,  3.0171062e-04,  8.7038155e-05,
        -5.7970901e-04,  3.4104238e-04, -4.1987709e-04,  6.4042286e-04],
       [ 2.9793353e-04,  2.6288204e-04,  3.0160198e-04,  8.7080378e-05,
        -5.7957176e-04,  3.4096732e-04, -4.1976344e-04,  6.4043072e-04],
       [ 2.9878583e-04,  2.6193861e-04,  3.0228859e-04,  8.6668959e-05,
        -5.8025989e-04,  3.4064267e-04, -4.2032293e-04,  6.4078835e-04]],      dtype=float32)

In [11]:
# test_params = reference_gpt2.variables["params"]["block_0"]["attn"]


# def transform_params(params):
#     attn_kernel, attn_bias = params["c_attn"]["kernel"], params["c_attn"]["bias"]
#     proj_kernel, proj_bias = params["c_proj"]["kernel"], params["c_proj"]["bias"]

#     def transform_kernel(a):
#         a = jnp.transpose(a)
#         a = jnp.reshape(a, (768, 12, 64))
#         return a

#     def transform_bias(b):
#         return jnp.reshape(b, (12, 64))

#     q_kernel, k_kernel, v_kernel = map(
#         transform_kernel, jnp.split(attn_kernel, 3, axis=-1)
#     )
#     q_bias, k_bias, v_bias = map(transform_bias, jnp.split(attn_bias, 3, axis=-1))

#     out_kernel = transform_kernel(proj_kernel)
#     out_kernel = jnp.transpose(out_kernel, (1, 2, 0))
#     # out_bias = jnp.reshape(proj_bias, (12, 64))

#     # return {
#     #     "query": {"kernel": q_kernel, "bias": q_bias},
#     #     "key": {"kernel": k_kernel, "bias": k_bias},
#     #     "value": {"kernel": v_kernel, "bias": v_bias},
#     #     "out": {"kernel": out_kernel, "bias": proj_bias},
#     # }
#     return {
#         "query": {"kernel": jnp.ones_like(q_kernel), "bias": jnp.zeros_like(q_bias)},
#         "key": {"kernel": jnp.ones_like(k_kernel), "bias": jnp.zeros_like(k_bias)},
#         "value": {"kernel": jnp.ones_like(v_kernel), "bias": jnp.zeros_like(v_bias)},
#         "out": {"kernel": jnp.ones_like(out_kernel), "bias": jnp.zeros_like(proj_bias)},
#     }


# def block_params(params):
#     # return {
#     #     "c_attn": {
#     #         "kernel": params["c_attn"]["kernel"],
#     #         "bias": params["c_attn"]["kernel"],
#     #     },
#     #     "c_proj": {
#     #         "kernel": params["c_proj"]["kernel"],
#     #         "bias": params["c_proj"]["bias"],
#     #     },
#     # }
#     return {
#         "query": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "key": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "value": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "proj": {
#             "kernel": jnp.ones_like(params["c_proj"]["kernel"]),
#             "bias": jnp.zeros_like(params["c_proj"]["bias"]),
#         },
#     }


In [12]:
# flax_variables = {"params": transform_params(test_params)}
# flax_attn.apply(flax_variables, ex_inputs)


In [13]:
# tx_attn.apply({"params": block_params(test_params)}, ex_inputs)
