flax.linen.scan module can be used to reduce the whole repeated or iterative computation graph into a single interation, therefore reduces the graph memory and compilation time.

In [1]:
# Auto-reload modules so we can edit the code and have it automatically reload
%load_ext autoreload
%autoreload 2

import timeit

import jax
import jaxlib
import flax.linen as nn
import jax.numpy as jnp

def measure_time(fn, params, input, num_runs=5):
  total_time = timeit.timeit(
    lambda: jax.tree.map(lambda x: x.block_until_ready(), fn(params, input)),
    number=num_runs
  )
  avg_time = total_time / num_runs
  if avg_time < 1e-3:
    return f"{avg_time * 1e6:.2f} Î¼s"
  elif avg_time < 1:
    return f"{avg_time * 1e3:.2f} ms"
  else:
    return f"{avg_time:.2f} s"

def compare_time(fn, params, input, num_runs=5):
  assert not isinstance(fn, jaxlib._jax.PjitFunction), "Function appears to be jitted"
  jit_fn = jax.jit(fn)
  print(f"Time for the regular apply: {measure_time(fn, params, input, num_runs)}")
  print(f"Time for the first JIT call: {measure_time(jit_fn, params, input, num_runs)}")
  print(f"Time for the subsequent JIT calls: {measure_time(jit_fn, params, input, num_runs)}")

BATCH_SIZE = 8
SEQ_LEN = 32
HIDDEN_DIM = 1024
LAYERS = 42
NUM_HEADS = 4
DTYPE = jnp.float32
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
dummy_input = jax.random.normal(rng, (1, SEQ_LEN, HIDDEN_DIM))
input = jax.random.normal(rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM))


# Basic iterative module

In [2]:
class SimpleBlock(nn.Module):
    hidden_dim: int
    num_heads: int
    dtype: jnp.dtype
    
    @nn.compact
    def __call__(self, x, attention_mask=None):
        out = {}
        out["input_hidden_state"] = x
        x += nn.SelfAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros,
            dtype=self.dtype,
            name="self_attention"
					)(x, attention_mask)
        out["+self_attention"] = x
        x += nn.Dense(features=self.hidden_dim)(x)
        out["+dense"] = x
        return x, out

class MyIterativeModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
        self.blocks = [
            SimpleBlock(
                name=f"layer_{i}",
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                dtype=self.dtype
						) 
            for i in range(self.num_layers)  # Use num_layers instead of hardcoded 3
        ]

    def __call__(self, x):
        out = {}
        for i, block in enumerate(self.blocks):
          x, layer_out = block(x)
          out[f"layer_{i}"] = layer_out
        return x, out

model = MyIterativeModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
params = model.init(init_rng, dummy_input)

compare_time(model.apply, params, input)

jax.tree.map(jnp.shape, params)

Time for the regular apply: 1.07 s
Time for the first JIT call: 1.23 s
Time for the subsequent JIT calls: 2.91 ms


{'params': {'layer_0': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'layer_1': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'layer_10': {'Dense_0': {'bias': (1024,), 'kernel': (1024, 1024)},
   'self_attention': {'key': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'out': {'bias': (1024,), 'kernel': (4, 256, 1024)},
    'query': {'bias': (4, 256), 'kernel': (1024, 4, 256)},
    'value': {'bias': (4, 256), 'kernel': (1024, 4, 256)}}},
  'layer_11': {'Dense_0': {'bias

# Scan Examples

In [3]:
# List of different scan configurations for nn.scan
scan_args_list = [
    # Configuration 1: 
    # - No parameter axes (shared parameters across layers)
    # - Broadcast parameters to all layers
    # - Don't split RNGs for parameters
    # - Input/Output along axis 0 (batch dimension)
    {
        "variable_axes": {},  # No parameter axes (shared params)
        "variable_broadcast": "params",  # Broadcast params to all layers
        "split_rngs": {"params": False},  # Don't split RNGs for params
        "in_axes": 0,  # Input along axis 0 (batch dim)
        "out_axes": 0,  # Output along axis 0 (batch dim)
    },
    # Configuration 2:
    # - Parameters along axis 0 (separate params per layer)
    # - Don't broadcast parameters
    # - Split RNGs for parameters
    # - Broadcast input to all layers
    # - Output along axis 0 (batch dimension)
    {
        "variable_axes": {"params": 0},  # Parameters along axis 0 (per layer)
        "variable_broadcast": False,  # Don't broadcast params
        "split_rngs": {"params": True},  # Split RNGs for params
        "in_axes": nn.broadcast,  # Broadcast input to all layers
        "out_axes": 0,  # Output along axis 0 (batch dim)
    },
]

In [4]:
IDX = 1
class MyScanModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
      self.blocks = nn.scan(
         SimpleBlock,
         length=self.num_layers,
         **scan_args_list[IDX],
			)(
				hidden_dim=self.hidden_dim,
				num_heads=self.num_heads,
				dtype=self.dtype
			)

    def __call__(self, x):
      x, scan_out = self.blocks(x)
      return x, scan_out

scan_model = MyScanModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
scan_params = scan_model.init(init_rng, dummy_input)

compare_time(scan_model.apply, scan_params, input)

jax.tree.map(jnp.shape, scan_params)

Time for the regular apply: 532.56 ms
Time for the first JIT call: 95.10 ms
Time for the subsequent JIT calls: 3.49 ms


{'params': {'blocks': {'Dense_0': {'bias': (42, 1024),
    'kernel': (42, 1024, 1024)},
   'self_attention': {'key': {'bias': (42, 4, 256),
     'kernel': (42, 1024, 4, 256)},
    'out': {'bias': (42, 1024), 'kernel': (42, 4, 256, 1024)},
    'query': {'bias': (42, 4, 256), 'kernel': (42, 1024, 4, 256)},
    'value': {'bias': (42, 4, 256), 'kernel': (42, 1024, 4, 256)}}}}}

# Practical Example: Distributed LLM inference with KV cachingm