In [3]:
import os
import sys

sys.path.append(os.path.abspath(os.path.join(os.getcwd(), os.pardir)))

import flax.linen as nn
import jax
import jax.numpy as jnp

from flax.linen import MultiHeadDotProductAttention as FlaxAttention
from tx.modules import Attention as TxAttention

In [4]:
from tx.models.gpt2 import PretrainedGPT2Model
from tx.network import GenerativeModel
from transformers import GPT2TokenizerFast


# Build the reference GPT-2 wrapper from the pretrained "gpt2" checkpoint:
# the converted parameters and the matching tokenizer are prepared first,
# then handed to GenerativeModel together with the model's tx config.
gpt2_variables = {"params": PretrainedGPT2Model.from_pretrained("gpt2").to_params()}
gpt2_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")

reference_gpt2 = GenerativeModel(
    config=PretrainedGPT2Model.tx_config,
    variables=gpt2_variables,
    tokenizer=gpt2_tokenizer,
)

## Compare Attention Implementations

In [None]:
# Toy dimensions shared by both attention implementations.
NUM_HEADS = 4
HEAD_DIM = 10
MODEL_DIM = NUM_HEADS * HEAD_DIM

# Use the aliases imported in the first cell (`FlaxAttention`, `TxAttention`).
# The previous spellings (`nn.MultiHeadDotProductAttention`, `TXAttention`)
# were never imported and raised NameError on a fresh kernel.
flax_attn = FlaxAttention(
    num_heads=NUM_HEADS, qkv_features=MODEL_DIM, out_features=MODEL_DIM
)
tx_attn = TxAttention(num_heads=NUM_HEADS, head_dim=HEAD_DIM, model_dim=MODEL_DIM)


def convert_params(tx_params, num_heads=None, head_dim=None, model_dim=None):
    """Re-express fused ``tx`` attention parameters in the Flax attention layout.

    Args:
        tx_params: Mapping with ``"c_attn"`` and ``"c_proj"`` entries, each
            holding a ``"kernel"`` and a ``"bias"`` array. The last axis of the
            ``c_attn`` arrays is split into three equal parts that are used as
            query, key and value, in that order.
        num_heads: Number of attention heads. Defaults to the notebook-level
            ``NUM_HEADS``.
        head_dim: Per-head feature size. Defaults to the notebook-level
            ``HEAD_DIM``.
        model_dim: Output feature size of ``c_proj``. Defaults to the
            notebook-level ``MODEL_DIM``.

    Returns:
        A dict with ``"query"``, ``"key"``, ``"value"`` and ``"out"`` entries.
        The q/k/v kernels have shape ``(in_features, num_heads, head_dim)`` and
        biases ``(num_heads, head_dim)``; the ``"out"`` kernel has shape
        ``(num_heads, head_dim, model_dim)`` and its bias is passed through
        unchanged.

    Note:
        The plain reshapes assume the tx module lays out each of q/k/v as
        contiguous per-head chunks -- TODO confirm against tx.modules.Attention.
    """
    # Fall back to the notebook constants so existing one-argument calls
    # behave exactly as before.
    num_heads = NUM_HEADS if num_heads is None else num_heads
    head_dim = HEAD_DIM if head_dim is None else head_dim
    model_dim = MODEL_DIM if model_dim is None else model_dim

    c_attn, c_proj = tx_params["c_attn"], tx_params["c_proj"]

    # The fused projection keeps query, key and value side by side on the last axis.
    kernels = jnp.split(c_attn["kernel"], 3, axis=-1)
    biases = jnp.split(c_attn["bias"], 3, axis=-1)

    flax_params = {
        name: {
            "kernel": jnp.reshape(kernel, (kernel.shape[0], num_heads, head_dim)),
            "bias": jnp.reshape(bias, (num_heads, head_dim)),
        }
        for name, kernel, bias in zip(("query", "key", "value"), kernels, biases)
    }
    flax_params["out"] = {
        "kernel": jnp.reshape(c_proj["kernel"], (num_heads, head_dim, model_dim)),
        "bias": c_proj["bias"],
    }

    return flax_params


# Toy input; its leading axis is the one the causal mask below is sized from.
ex_inputs = jnp.array([[1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1]])
seq_len = ex_inputs.shape[0]

# Initialise the tx module once, then express the same weights in the Flax
# layout so both modules run with identical parameters.
# `jax.random` is used directly because the `jr` alias was never imported.
tx_params = tx_attn.init(jax.random.PRNGKey(0), x=ex_inputs)["params"]
flax_params = convert_params(tx_params)

# Boolean causal mask sized from the input rather than a hard-coded length of 4.
causal_mask = nn.make_causal_mask(jnp.ones((seq_len,), dtype="bool"), dtype="bool")

tx_out = tx_attn.apply({"params": tx_params}, ex_inputs)
flax_out = flax_attn.apply(
    {"params": flax_params},
    ex_inputs,
    ex_inputs,
    mask=causal_mask,
)

# Displayed as the cell result: True when both implementations agree.
jnp.allclose(tx_out, flax_out, atol=1e-6)


In [None]:

# `tx_as_flax` was not defined by any earlier cell, so derive it here from the
# tx parameters before inspecting the shape of every converted leaf.
tx_as_flax = convert_params(tx_params)
jax.tree_util.tree_map(lambda a: a.shape, tx_as_flax)


In [None]:
# Run the tx attention module on the example input and display the result.
tx_variables = {"params": tx_params}
tx_output = tx_attn.apply(tx_variables, ex_inputs)
tx_output

In [None]:
# Call the Flax module the same way the comparison cell does: explicit
# key/value input plus a causal mask, so this output is comparable with
# `tx_output` instead of being an unmasked forward pass.
flax_output = flax_attn.apply(
    {"params": flax_params},
    ex_inputs,
    ex_inputs,
    mask=nn.make_causal_mask(
        jnp.ones((ex_inputs.shape[0],), dtype="bool"),
        dtype="bool",
    ),
)
flax_output

In [None]:
# `tx_as_flax` was not defined by any earlier cell; convert the tx parameters
# inline. The call mirrors the comparison cell (explicit key/value input and
# a causal mask) so the result is comparable with the tx module's output.
tx_as_flax_output = flax_attn.apply(
    {"params": convert_params(tx_params)},
    ex_inputs,
    ex_inputs,
    mask=nn.make_causal_mask(
        jnp.ones((ex_inputs.shape[0],), dtype="bool"),
        dtype="bool",
    ),
)
tx_as_flax_output

In [None]:
# test_params = reference_gpt2.variables["params"]["block_0"]["attn"]


# def transform_params(params):
#     attn_kernel, attn_bias = params["c_attn"]["kernel"], params["c_attn"]["bias"]
#     proj_kernel, proj_bias = params["c_proj"]["kernel"], params["c_proj"]["bias"]

#     def transform_kernel(a):
#         a = jnp.transpose(a)
#         a = jnp.reshape(a, (768, 12, 64))
#         return a

#     def transform_bias(b):
#         return jnp.reshape(b, (12, 64))

#     q_kernel, k_kernel, v_kernel = map(
#         transform_kernel, jnp.split(attn_kernel, 3, axis=-1)
#     )
#     q_bias, k_bias, v_bias = map(transform_bias, jnp.split(attn_bias, 3, axis=-1))

#     out_kernel = transform_kernel(proj_kernel)
#     out_kernel = jnp.transpose(out_kernel, (1, 2, 0))
#     # out_bias = jnp.reshape(proj_bias, (12, 64))

#     # return {
#     #     "query": {"kernel": q_kernel, "bias": q_bias},
#     #     "key": {"kernel": k_kernel, "bias": k_bias},
#     #     "value": {"kernel": v_kernel, "bias": v_bias},
#     #     "out": {"kernel": out_kernel, "bias": proj_bias},
#     # }
#     return {
#         "query": {"kernel": jnp.ones_like(q_kernel), "bias": jnp.zeros_like(q_bias)},
#         "key": {"kernel": jnp.ones_like(k_kernel), "bias": jnp.zeros_like(k_bias)},
#         "value": {"kernel": jnp.ones_like(v_kernel), "bias": jnp.zeros_like(v_bias)},
#         "out": {"kernel": jnp.ones_like(out_kernel), "bias": jnp.zeros_like(proj_bias)},
#     }


# def block_params(params):
#     # return {
#     #     "c_attn": {
#     #         "kernel": params["c_attn"]["kernel"],
#     #         "bias": params["c_attn"]["kernel"],
#     #     },
#     #     "c_proj": {
#     #         "kernel": params["c_proj"]["kernel"],
#     #         "bias": params["c_proj"]["bias"],
#     #     },
#     # }
#     return {
#         "query": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "key": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "value": {
#             "kernel": jnp.ones((768, 12, 64)),
#             "bias": jnp.zeros((12, 64)),
#         },
#         "proj": {
#             "kernel": jnp.ones_like(params["c_proj"]["kernel"]),
#             "bias": jnp.zeros_like(params["c_proj"]["bias"]),
#         },
#     }


In [None]:
# flax_variables = {"params": transform_params(test_params)}
# flax_attn.apply(flax_variables, ex_inputs)


In [None]:
# tx_attn.apply({"params": block_params(test_params)}, ex_inputs)
