In [13]:
import numpy as np
import timeit
import jax
import flax
from flax import linen as nn
from jax import numpy as jnp
from flax import nnx

In [58]:
rngs = nnx.Rngs(0)

class MultiHeadAttention(nn.Module):  #@save
    E_q: int
    E_k: int
    E_v: int    
    E_total: int
    nheads: int
    
    def setup(self):
        self.query_proj = nnx.Linear(E_q, E_total, rngs=rngs)
        self.key_proj = nnx.Linear(E_k, E_total, rngs=rngs)
        self.value_proj = nnx.Linear(E_v, E_total, rngs=rngs)
        E_out = self.E_q
        self.out_proj = nnx.Linear(E_total, E_out, rngs=rngs)
        assert E_total % nheads == 0, "Embedding dim is not divisible by nheads"
        self.E_head = E_total // nheads
        self.E_q_last = self.E_q - 1
        self.E_k_last = self.E_k - 1
        self.E_v_last = self.E_v - 1        

    @nn.compact
    def __call__(self, queries, keys, values):
        query = self.query_proj(queries)
        key = self.key_proj(keys)
        value = self.value_proj(values)

        # Step 2. Split heads and prepare for SDPA
        # reshape query, key, value to separate by head
        # (N, L_t, E_total) -> (N, L_t, nheads, E_head) -> (N, nheads, L_t, E_head)
        query = jnp.reshape(query, (query.shape[0], query.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        key = jnp.reshape(key, (key.shape[0], key.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)
        # (N, L_s, E_total) -> (N, L_s, nheads, E_head) -> (N, nheads, L_s, E_head)
        value = jnp.reshape(value, (value.shape[0], value.shape[1], self.nheads, self.E_head)).transpose(0, 2, 1, 3)

        # Step 3. Run SDPA
        # (N, nheads, L_t, E_head)
        attn_output = jax.nn.dot_product_attention(query, key, value, is_causal=True)
        
        # (N, nheads, L_t, E_head) -> (N, L_t, nheads, E_head) -> (N, L_t, E_total)
        attn_output = jnp.transpose(attn_output, (0, 2, 1, 3))
        attn_output = attn_output.reshape(attn_output.shape[0], attn_output.shape[1], E_total)

        # Step 4. Apply output projection
        # (N, L_t, E_total) -> (N, L_t, E_out)
        attn_output = self.out_proj(attn_output)

        return attn_output

In [15]:
rkey = jax.random.key(1)
N = 512
E_q, E_k, E_v, E_total = 512, 512, 512, 512
E_out = E_q
nheads = 8

In [16]:
def zipf_sentence_lengths(alpha: float, batch_size: int):
    # generate fake corpus by unigram Zipf distribution
    # from wikitext-2 corpus, we get rank "." = 3, "!" = 386, "?" = 858
    sentence_lengths = np.empty(batch_size, dtype=int)
    for ibatch in range(batch_size):
        sentence_lengths[ibatch] = 1
        word = np.random.zipf(alpha)
        while word != 3 and word != 386 and word != 858:
            sentence_lengths[ibatch] += 1
            word = np.random.zipf(alpha)
    return jnp.asarray(sentence_lengths)

In [17]:
def gen_batch(N, E_q, E_k, E_v):
    # generate semi-realistic data using Zipf distribution for sentence lengths
    sentence_lengths = zipf_sentence_lengths(alpha=1.2, batch_size=N)

    # Note: the torch.jagged layout is a nested tensor layout that supports a single ragged
    # dimension and works with torch.compile. The batch items each have shape (B, S*, D)
    # where B = batch size, S* = ragged sequence length, and D = embedding dimension.
    max_l = max(sentence_lengths)
    query = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_q)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])
        # query = jnp.concatenate([
    #     jnp.pad(jax.random.uniform(rkey, (l.item(), E_q)), ((0, 0), (max_l - l, 0)))
    #     for l in sentence_lengths
    # ])

    key = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_k)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])

    value = jnp.stack([
        jnp.pad(jax.random.uniform(rkey, (l.item(), E_v)), ((0, max_l - l.item()), (0, 0)))
        for l in sentence_lengths
    ])

    # inds = [0]
    # for s in sentence_lengths:
    #     inds.append(inds[-1] + s.item())

    # slice_inds = jnp.stack([jnp.asarray(inds[:-1]), jnp.asarray(inds[1:])], 1)
    print(max_l)
    return query, key, value, sentence_lengths

query, key, value, sentence_lengths = gen_batch(N, E_q, E_k, E_v)
print(query.shape)

175
(512, 175, 512)


In [59]:
mha = MultiHeadAttention(E_q, E_k, E_v, E_total, nheads)
params = mha.init(rkey, query, key, value)
mha.apply(params, query, key, value)

Array([[[ 0.49519572,  0.14322187,  0.8624785 , ..., -0.10200696,
         -0.34227815,  0.40329358],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ],
        [ 0.        ,  0.        ,  0.        , ...,  0.        ,
          0.        ,  0.        ]],

       [[ 0.34477958,  0.41772458,  0.6371734 , ..., -0.15534085,
         -0.28626376,  0.7203759 ],
        [ 0.16461064,  0.3462381 ,  0.53304756, ...,  0.03806954,
         -0.32316613,  0.49088332],
        [ 0.4150186 ,  0.48036662,  0.6404403 , ...,  0.36826617,
         -0.48098183,  0.6967878 ],
        ...,
        [ 0.        ,  0.        ,  0.        , ...,  

In [None]:
import jax
import jax.numpy as jnp
from functools import partial

# Example data: List of arrays of different lengths
ragged_lists = [
    [1, 2, 3],
    [4, 5],
    [6, 7, 8, 9],
    [10]
]

# Convert to concatenated format
data = jnp.array([x for sublist in ragged_lists for x in sublist])
start_indices = jnp.array([0, 3, 5, 9, 10])  # Where each subarray begins
lengths = jnp.diff(start_indices)  # Length of each subarray

def process_slice(start_idx, length, data):
    """Process a single slice of the data."""
    # Get the slice for this segment
    slice_data = jax.lax.dynamic_slice(data, (start_idx,), (length,))
    
    # Example operation: compute mean and subtract it from each element
    slice_mean = jnp.mean(slice_data)
    return slice_data - slice_mean

# Create a version that handles variable lengths by padding
@partial(jax.vmap, in_axes=(0, 0, None))
def process_all_slices(starts, lens, full_data):
    return process_slice(starts, lens, full_data)

# Process all segments
result = process_all_slices(start_indices[:-1], lengths, data)

# Helper function to convert back to list format for visualization
def get_segments(data, start_indices):
    return [data[start_indices[i]:start_indices[i+1]] 
            for i in range(len(start_indices)-1)]

# Example of more complex processing: compute cumsum within each segment
def cumsum_slice(start_idx, length, data):
    """Compute cumulative sum within a slice."""
    slice_data = jax.lax.dynamic_slice(data, (start_idx,), (length,))
    return jnp.cumsum(slice_data)

@partial(jax.vmap, in_axes=(0, 0, None))
def cumsum_all_slices(starts, lens, full_data):
    return cumsum_slice(starts, lens, full_data)

# Demonstrate usage
if __name__ == "__main__":
    print("Original data:", data)
    print("Start indices:", start_indices)
    print("Lengths:", lengths)
    
    # Process and show results
    mean_centered = process_all_slices(start_indices[:-1], lengths, data)
    print("\nMean-centered segments:")
    print(get_segments(mean_centered, start_indices))
    
    # Compute cumulative sums
    cumsum_result = cumsum_all_slices(start_indices[:-1], lengths, data)
    print("\nCumulative sums within segments:")
    print(get_segments(cumsum_result, start_indices))

In [83]:


add = lambda x: x + 1
dyn = lambda s, e, arr: add(jax.lax.dynamic_slice_in_dim(arr, s, e, axis=0))

arr = jnp.ones((200, 512))
start_inds = jnp.asarray([0, 50, 100, 150, 175])
end_inds = jnp.asarray([50, 100, 150, 175, 200])
r = jax.vmap(dyn, in_axes=(0, 0, None))(start_inds, end_inds, arr)


TracerIntegerConversionError: The __index__() method was called on traced array with shape int32[]
This BatchTracer with object id 134404026083392 was created on line:
  /tmp/ipykernel_889141/1341147848.py:7:4 (<module>)
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerIntegerConversionError