flax.linen.scan module can be used to reduce the whole repeated or iterative computation graph into a single interation, therefore reduces the graph memory and compilation time.

In [1]:
# Auto-reload modules so we can edit the code and have it automatically reload
%load_ext autoreload
%autoreload 2

import timeit

import jax
import flax.linen as nn
import jax.numpy as jnp

# Basic iterative module

In [2]:

class SimpleBlock(nn.Module):
    hidden_dim: int
    num_heads: int
    dtype: jnp.dtype
    
    @nn.compact
    def __call__(self, x, attention_mask=None):
        out = {}
        out["input_hidden_state"] = x
        x += nn.SelfAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros,
            dtype=self.dtype,
            name="self_attention"
					)(x, attention_mask)
        out["+self_attention"] = x
        x += nn.Dense(features=self.hidden_dim)(x)
        out["+dense"] = x
        return x, out

class MyIterativeModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
        self.blocks = [
            SimpleBlock(
                name=f"block_{i}",
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                dtype=self.dtype
						) 
            for i in range(self.num_layers)  # Use num_layers instead of hardcoded 3
        ]

    def __call__(self, x):
        out = {}
        for i, block in enumerate(self.blocks):
          x, layer_out = block(x)
          out[f"block_{i}"] = layer_out
        return x, out

BATCH_SIZE = 8
SEQ_LEN = 32
HIDDEN_DIM = 1024
LAYERS = 4
NUM_HEADS = 4
DTYPE = jnp.float32

model = MyIterativeModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
dummy_input = jax.random.normal(rng, (1, SEQ_LEN, HIDDEN_DIM))
params = model.init(init_rng, dummy_input)

jax.tree.map(jnp.shape, params)

{'params': {'block_0': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'block_1': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'block_2': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'block_3': {'Dense_0': {'bias':

In [4]:
input = jax.random.normal(rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM))
# JIT compile the model's apply function for faster execution
jit_apply = jax.jit(model.apply)

def measure_time(fn, params, input, num_runs=10):
  total_time = timeit.timeit(
    lambda: jax.tree.map(lambda x: x.block_until_ready(), fn(params, input)),
    number=num_runs
  )
  avg_time = total_time / num_runs
  if avg_time < 1e-3:
    return f"{avg_time * 1e6:.2f} μs"
  elif avg_time < 1:
    return f"{avg_time * 1e3:.2f} ms"
  else:
    return f"{avg_time:.2f} s"

print(f"Average time for regular apply: {measure_time(model.apply, params, input, num_runs=10)}")
print(f"Time for first JIT call: {measure_time(jit_apply, params, input, num_runs=1)}")
print(f"Average time for subsequent JIT calls: {measure_time(jit_apply, params, input, num_runs=10)}")


Average time for regular apply: 70.92 ms
Time for first JIT call: 637.10 ms
Average time for subsequent JIT calls: 488.12 μs


# Scan Example

In [None]:
class MyScanModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
        self.blocks = nn.scan(
            SimpleBlock,
            variable_axes={},  # no scan over params
            variable_broadcast="params",  # share params
            split_rngs={"params": False},  # same rng for all
            in_axes=0,  # assume input is (layers, ...)
            out_axes=0,
            length=self.num_layers,
		)(
			hidden_dim=self.hidden_dim,
			num_heads=self.num_heads,
			dtype=self.dtype
		)
    def __call__(self, x):
           return self.blocks(x)

scan_model = MyScanModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
scan_params = scan_model.init(init_rng, dummy_input)

jax.tree.map(jnp.shape, scan_params)

ValueError: not enough values to unpack (expected 2, got 1)

# Practical Example: Distributed LLM inference with KV cachingm