In [2]:
import numpy as np
import timeit
import jax
import flax
from flax import linen as nn
from jax import numpy as jnp
from flax import nnx

In [23]:
rngs = nnx.Rngs(0)

class MultiHeadAttention(nn.Module):  #@save
    E_q: int
    E_k: int
    E_v: int    
    E_total: int
    nheads: int
    
    def setup(self):
        self.query_proj = nnx.Linear(E_q, E_total, rngs=rngs)
        self.key_proj = nnx.Linear(E_k, E_total, rngs=rngs)
        self.value_proj = nnx.Linear(E_v, E_total, rngs=rngs)
        E_out = self.E_q
        self.out_proj = nnx.Linear(E_total, E_out, rngs=rngs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.E_q_last = self.E_q - 1
        self.E_k_last = self.E_k - 1
        self.E_v_last = self.E_v - 1        

    @nn.compact
    def __call__(self, queries, keys, values):
        query = self.query_proj(queries)
        key = self.key_proj(keys)
        value = self.value_proj(values)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = jnp.reshape(query, (query.shape[0], query.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = jnp.reshape(key, (key.shape[0], key.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = jnp.reshape(value, (value.shape[0], value.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = jax.nn.dot_product_attention(query, key, value, is_causal=True)
        
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = jnp.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], E_total)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [24]:
rkey = jax.random.key(1)
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

In [25]:
def zipf_sentence_lengths(alpha: float, batch_size: int):
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return jnp.asarray(sentence_lengths)

In [26]:
def gen_batch(N, E_q, E_k, E_v):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    max_l = max(sentence_lengths)
    query = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_q)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])
        # query = jnp.concatenate([
    #     jnp.pad(jax.random.uniform(rkey, (l.item(), E_q)), ((0, 0), (max_l - l, 0)))
    #     for l in sentence_lengths
    # ])

    key = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_k)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])

    value = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_v)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])

    # inds = [0]
    # for s in sentence_lengths:
    #     inds.append(inds[-1] + s.item())

    # slice_inds = jnp.stack([jnp.asarray(inds[:-1]), jnp.asarray(inds[1:])], 1)
    print(max_l)
    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v)
print(query.shape)

160
(512, 160, 512)


In [27]:
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads)
params = mha.init(rkey, query, key, value)
mha.apply(params, query, key, value)

Array([[[ 0.12568846,  0.4754376 ,  0.97867215, ..., -0.07529405,
         -0.32287282,  0.3695725 ],
        [ 0.1775682 ,  0.43692213,  0.5709628 , ...,  0.10295839,
         -0.42616335,  0.6151868 ],
        [ 0.20226231,  0.46459487,  0.5802369 , ..., -0.10768469,
         -0.08904526,  0.6365539 ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ]],

       [[ 0.2697184 ,  0.41408727,  0.67076516, ..., -0.01556746,
         -0.6889983 ,  0.5870678 ],
        [ 0.26947126,  0.3672987 ,  1.149145  , ..., -0.00326963,
         -0.41819423,  0.4805088 ],
        [ 0.48358738,  0.4095194 ,  0.94281244, ..., -0.49481627,
         -0.23489867,  0.5999225 ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  

In [15]:
def benchmark(func, params, query, key, value):
    #torch.cuda.synchronize()
    begin = timeit.default_timer()
    output = mha.apply(params, query, key, value)
    #orch.cuda.synchronize()
    end = timeit.default_timer()
    return output, (end - begin)

jit_model = jax.jit(mha.apply, backend='gpu').lower(params, query, key, value)
#compiled_model = jit_model.compile()

output, time_padded = benchmark(mha, params, query, key, value)

avg = 0

for i in range(100):
    _, t = benchmark(mha, params, query, key, value)
    avg += t
print(t/100)

TraceContextError: Cannot call RngStream from a different trace level