In [2]:
import numpy as np
import jax
import jax.numpy as jnp
from flax import nnx

import torch

import transformerlib
from transformerlib import multi_head_sdpa, make_causal_attn_mask

## Ensure SDPA was implemented properly by testing conformance with PyTorch implementation:


In [10]:
def _test_sdpa_function():
    # Batch=5, num_heads=4, seq_len=9, head_dim=4
    sample_q = np.random.randn(5, 9, 4, 4).astype(np.float32)
    sample_k = np.random.randn(5, 9, 4, 4).astype(np.float32)
    sample_v = np.random.randn(5, 9, 4, 4).astype(np.float32)
    sample_mask = make_causal_attn_mask(sample_q.shape[-2])
    ones_mask = np.ones((1, 1, sample_q.shape[-2], sample_k.shape[-2]), dtype=bool)

    pytorch_output = torch.nn.functional.scaled_dot_product_attention(
        query=torch.from_numpy(sample_q),
        key=torch.from_numpy(sample_k),
        value=torch.from_numpy(sample_v),
        attn_mask=torch.from_numpy(np.array(sample_mask)),
        # attn_mask=torch.from_numpy(sample_mask),
    )
    pytorch_output = pytorch_output.permute(0, 2, 1, 3).flatten(start_dim=-2)
    print("Pytorch output:")
    print(pytorch_output)
    print()

    jax_test = multi_head_sdpa(
        query=jnp.array(sample_q),
        key=jnp.array(sample_k),
        value=jnp.array(sample_v),
        mask=sample_mask,
        rngs=nnx.Rngs(dropout=jax.random.PRNGKey(0)),
        dropout_p=0.0,
        mask_value=float("-inf"),
    )
    print("Jax output:")
    print(jax_test[1])
    print()

    # See if they agree
    pt_output_np = pytorch_output.detach().numpy()
    jax_output_np = jax_test[1].block_until_ready()

    print("Outputs are consistent:")
    print(np.allclose(pt_output_np, jax_output_np, atol=1e-5))
    print("difference norm:")
    print(np.linalg.norm(pt_output_np - jax_output_np))


_test_sdpa_function()

Pytorch output:
tensor([[[-5.4639e-01, -2.4829e+00, -1.1164e-01,  1.1420e-01, -1.1732e+00,
           2.0667e+00,  3.1358e-01,  9.4439e-02,  1.6393e-01, -2.2072e+00,
           1.3147e+00,  9.1219e-01, -9.2440e-01, -1.0022e-01, -1.7523e+00,
          -1.6578e+00, -8.4128e-01, -1.1413e+00, -6.8456e-01, -5.8002e-01,
           1.0832e+00, -8.7145e-01,  1.1184e-01, -2.5475e-01, -5.4909e-01,
          -1.5671e+00, -1.8968e-01,  1.1013e+00, -2.7438e-02,  6.0801e-01,
           7.9748e-01, -1.0337e+00, -1.4015e+00,  1.2510e-01,  6.2605e-01,
           2.2658e+00],
         [-4.6226e-01, -2.0925e+00,  1.7030e-02,  1.5555e-01, -4.9066e-01,
           1.3839e+00,  5.1146e-01,  2.0277e-01,  9.7654e-01,  1.7557e-01,
           1.3824e+00,  1.5429e+00, -7.1960e-01,  3.3479e-02, -1.1538e+00,
          -9.0561e-01, -7.3632e-01, -4.1154e-01, -1.1444e+00,  1.6132e+00,
           9.9912e-01, -8.5900e-01,  8.5806e-03, -1.3314e-01, -6.5644e-01,
          -5.0175e-01,  1.8850e-01, -1.3602e+00, -9.7960e-01

## Test transformer implementation


In [6]:
rngs = nnx.Rngs(params=jax.random.key(0), dropout=jax.random.key(1))
transformer = transformerlib.Transformer(
    num_layers=10,
    d_model=512,
    num_heads=16,
    d_feedforward=1024,
    attn_dropout_p=0.0,
    rngs=rngs,
)


def _test_transformer_mask():
    # Ensures that no future tokens are attended to
    # Batch=1, seq_len=5, d_embedding=16
    sample_seq = np.random.randn(1, 5, 512).astype(np.float32)

    # Poison one element of the input sequence
    sample_seq[:, 3, :] = -1e10
    # If masking is done properly, elements 0, 1, and 2 should be fine

    sample_seq = jnp.array(sample_seq)
    transformer_output = transformer(sample_seq, use_causal_mask=True)

    print(transformer_output)

    # Try with pytorch for comparison
    sample_seq = torch.from_numpy(np.array(sample_seq))
    transformer_pytorch = torch.nn.TransformerEncoder(
        torch.nn.TransformerEncoderLayer(
            d_model=sample_seq.shape[-1],
            nhead=4,
            dim_feedforward=128,
            dropout=0.0,
            batch_first=True,
        ),
        num_layers=2,
    )
    mask = torch.nn.Transformer.generate_square_subsequent_mask(sample_seq.shape[1])
    mask = mask.to(sample_seq.device)
    transformer_pytorch_output = transformer_pytorch(sample_seq, mask=mask)
    print("Pytorch output:")
    print(transformer_pytorch_output)

In [7]:
_test_transformer_mask()

[[[-0.64218295 -1.5010592  -0.30352145 ...  0.57103145  0.9077463
    0.48531556]
  [-0.76753986 -1.3241218  -0.61255056 ...  0.7193657   1.0658777
    0.3473599 ]
  [-0.73136145 -1.2820809  -0.4903155  ...  0.6091205   1.2307248
    0.45945734]
  [-0.7353072  -1.3378686  -0.60222465 ...  0.525545    1.1315358
    0.46509817]
  [-0.7069002  -1.3818822  -0.6843058  ...  0.6023262   1.0768098
    0.4751909 ]]]
Pytorch output:
tensor([[[-0.1905,  0.3575,  1.1162,  ...,  0.6089,  0.6590,  0.2776],
         [-0.6990, -0.3251,  2.1994,  ...,  0.4689,  0.1280,  3.0869],
         [-0.5854,  0.3504, -0.2713,  ...,  0.3459, -1.7292,  1.2380],
         [ 0.8965, -0.1283,  1.5940,  ..., -0.1220, -0.0269, -0.3301],
         [-1.2585,  0.0236,  1.4659,  ..., -0.0803,  0.1859,  0.2346]]],
       grad_fn=<NativeLayerNormBackward0>)
