`flax.linen.scan` can be used to collapse a repeated (iterative) computation graph into a single iteration that is looped over, which reduces the graph's memory footprint and the compilation time.

In [1]:
# Auto-reload modules so we can edit the code and have it automatically reload
%load_ext autoreload
%autoreload 2

import timeit

import jax
import jaxlib
import flax.linen as nn
import jax.numpy as jnp


# Problem sizes shared by every benchmark cell below.
BATCH_SIZE = 16
SEQ_LEN = 32
HIDDEN_DIM = 512
LAYERS = 42
NUM_HEADS = 4
DTYPE = jnp.float32

# Derive one independent key per consumer. A JAX PRNG key should never be
# reused: drawing `dummy_input` and `input` from the same key would make the
# two samples statistically dependent.
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)  # `init_rng` seeds parameter init
rng, dummy_rng, input_rng = jax.random.split(rng, 3)

# Batch-of-one sample, used only to initialise parameters via `model.init`.
dummy_input = jax.random.normal(dummy_rng, (1, SEQ_LEN, HIDDEN_DIM))
# Full-batch sample fed to the timed `apply` calls. The name shadows the
# builtin `input`; it is kept because later cells refer to it.
input = jax.random.normal(input_rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM))




In [2]:

def measure_time(fn, params, input, num_runs=5):
    """Return the mean wall-clock time of `fn(params, input)` as a string.

    Args:
        fn: Callable invoked as `fn(params, input)`; may return any pytree.
        params: First positional argument forwarded to `fn`.
        input: Second positional argument forwarded to `fn`.
        num_runs: Number of timed invocations to average over.

    Returns:
        The average time per call, formatted in μs, ms or s depending on
        its magnitude.
    """
    def run_once():
        # Wait for the result so the timer covers the actual computation
        # rather than just the dispatch. `jax.block_until_ready` walks the
        # whole pytree and leaves non-array leaves (Python scalars, None)
        # alone instead of failing on them.
        jax.block_until_ready(fn(params, input))

    total_time = timeit.timeit(run_once, number=num_runs)
    avg_time = total_time / num_runs
    if avg_time < 1e-3:
        return f"{avg_time * 1e6:.2f} μs"
    if avg_time < 1:
        return f"{avg_time * 1e3:.2f} ms"
    return f"{avg_time:.2f} s"

def compare_time(fn, params, input, num_runs=5):
  """Jit `fn`, warm it up once, and print its steady-state latency.

  Args:
    fn: Un-jitted callable invoked as `fn(params, input)`.
    params: First positional argument forwarded to `fn`.
    input: Second positional argument forwarded to `fn`.
    num_runs: Number of timed invocations passed on to `measure_time`.
  """
  # Take the jitted-function type from the public `jax.jit` API rather than
  # from a private jaxlib module path, so the check does not depend on
  # jaxlib internals.
  jitted_type = type(jax.jit(lambda: None))
  assert not isinstance(fn, jitted_type), "Function appears to be jitted"
  jit_fn = jax.jit(fn)
  # print(f"Average time for the regular apply: {measure_time(fn, params, input, num_runs)}")
  # print(f"Average time for the first JIT call: {measure_time(jit_fn, params, input, num_runs)}")
  # Warm-up call: pays for tracing and compilation. Block on the result so
  # no work from this call is still in flight when the timed runs start.
  jax.block_until_ready(jit_fn(params, input))
  print(f"Average time for the subsequent JIT calls: {measure_time(jit_fn, params, input, num_runs)}")

def print_cost_analysis(cost, show_utilization=False):
    """
    Pretty-print the cost analysis returned by:
        jax.jit(fn).lower(...).compile().cost_analysis()

    Args:
        cost: Dictionary of cost analysis metrics. A list/tuple of such
            dictionaries is also accepted (the first entry is used), as is
            None; both are treated as "no data" when empty.
        show_utilization: Whether to print per-layer utilization metrics
    """
    # Normalize the input so the rest of the function always sees a dict.
    if isinstance(cost, (list, tuple)):
        cost = cost[0] if cost else {}
    if cost is None:
        cost = {}

    print("=== JAX Cost Analysis ===")

    # Total memory accessed
    total_mem = cost.get('bytes accessedout{}', 0)
    print(f"Total estimated memory accessed: {total_mem / (1024 ** 2):.2f} MB")

    # Largest value among the 'bytes accessedout{...' entries. `default=0`
    # keeps this from raising when the backend reports no such entries.
    # NOTE(review): this is the largest reported bytes-accessed figure, which
    # is only a rough proxy for peak memory usage.
    peak_mem = max(
        (val for key, val in cost.items() if key.startswith('bytes accessedout{')),
        default=0,
    )
    print(f"Estimated peak memory usage: {peak_mem / (1024 ** 2):.2f} MB")

    # Total FLOPs
    flops = cost.get('flops', 0)
    print(f"Estimated FLOPs: {flops / 1e9:.3f} GFLOPs")

    # Transcendental operations
    trans = cost.get('transcendentals', 0)
    print(f"Transcendental ops (exp, sin, etc.): {int(trans)}")

    # Optimal execution time; only printed when the backend reports it.
    optimal_time = cost.get('optimal_seconds', None)
    if optimal_time is not None:
        print(f"Estimated optimal execution time: {optimal_time * 1e3:.3f} ms")

    if show_utilization:
        print("=== Per-Layer Utilization (if available) ===")
        utilizations = {
            k: v for k, v in cost.items() if k.startswith('utilization')
        }

        if utilizations:
            for key, value in sorted(utilizations.items()):
                print(f"{key}: {value}")
        else:
            print("No detailed utilization metrics found.")
    # Two blank lines separate consecutive reports.
    print()
    print()

# Basic iterative module

In [3]:
class SimpleBlock(nn.Module):
    """Residual self-attention followed by a residual dense layer.

    Attributes:
        hidden_dim: Output feature size of the dense layer.
        num_heads: Number of attention heads.
        dtype: Computation dtype passed to both sub-layers.
    """
    hidden_dim: int
    num_heads: int
    dtype: jnp.dtype

    @nn.compact
    def __call__(self, x, attention_mask=None):
        """Apply the block.

        Args:
            x: Input hidden state.
            attention_mask: Optional mask forwarded to the attention layer.

        Returns:
            A tuple `(x, out)`: the final hidden state, and a dict holding
            the hidden state at each stage of the block.
        """
        out = {"input_hidden_state": x}
        # Build new arrays with `x = x + ...` instead of `x += ...`, so a
        # mutable input (e.g. a NumPy array) is never modified in place and
        # the value recorded under "input_hidden_state" stays intact.
        x = x + nn.SelfAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros,
            dtype=self.dtype,
            name="self_attention",
        )(x, attention_mask)
        out["+self_attention"] = x
        # Use the same computation dtype as the attention layer; the
        # auto-generated parameter name ("Dense_0") is unaffected.
        x = x + nn.Dense(features=self.hidden_dim, dtype=self.dtype)(x)
        out["+dense"] = x
        return x, out

class MyIterativeModule(nn.Module):
    """Stack of `num_layers` separately-parameterised `SimpleBlock`s.

    The blocks are applied one after another in a plain Python loop, so each
    layer appears as its own sub-module ("layer_0", "layer_1", ...).
    """
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        layers = []
        for idx in range(self.num_layers):
            layers.append(
                SimpleBlock(
                    name=f"layer_{idx}",
                    hidden_dim=self.hidden_dim,
                    num_heads=self.num_heads,
                    dtype=self.dtype,
                )
            )
        self.blocks = layers

    def __call__(self, x):
        """Run every block in order.

        Returns:
            A tuple `(x, out)`: the final hidden state, and a dict mapping
            "layer_<i>" to the intermediate outputs of block `i`.
        """
        per_layer = {}
        for idx, block in enumerate(self.blocks):
            x, block_out = block(x)
            per_layer[f"layer_{idx}"] = block_out
        return x, per_layer

# Build the loop-based model and initialise its parameters from the
# batch-of-one dummy input.
model = MyIterativeModule(
    hidden_dim=HIDDEN_DIM,
    num_layers=LAYERS,
    num_heads=NUM_HEADS,
    dtype=DTYPE,
)
params = model.init(init_rng, dummy_input)

# Print the latency of the jitted forward pass on the full batch.
compare_time(model.apply, params, input)

# Final expression of the cell: shows the shape of every parameter leaf.
jax.tree.map(jnp.shape, params)

Average time for the subsequent JIT calls: 641.35 ms


{'params': {'layer_0': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias': (512,), 'kernel': (4, 128, 512)},
    'query': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'value': {'bias': (4, 128), 'kernel': (512, 4, 128)}}},
  'layer_1': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias': (512,), 'kernel': (4, 128, 512)},
    'query': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'value': {'bias': (4, 128), 'kernel': (512, 4, 128)}}},
  'layer_10': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias': (512,), 'kernel': (4, 128, 512)},
    'query': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'value': {'bias': (4, 128), 'kernel': (512, 4, 128)}}},
  'layer_11': {'Dense_0': {'bias': (512,), 'kernel': (51

# Scan Examples

In [4]:
# Keyword-argument sets for `nn.scan`, selected by index further below.
#
# Reading guide for the axis arguments: `in_axes` / `out_axes` describe the
# axis of the scanned inputs / per-step outputs that corresponds to the scan
# (layer) dimension -- axis 0 here is the layer axis, not the batch axis.
scan_args_list = [
    # Configuration 0: one set of parameters shared by every layer.
    dict(
        variable_axes={},  # no collection is stacked along the scan axis
        variable_broadcast="params",  # same params reused at every step
        split_rngs={"params": False},  # every step sees the same params RNG
        in_axes=0,  # scanned inputs are sliced along axis 0
        out_axes=0,  # per-step outputs are stacked along axis 0
    ),
    # Configuration 1: separate parameters for every layer.
    dict(
        variable_axes={"params": 0},  # params get a leading per-layer axis
        variable_broadcast=False,  # nothing is broadcast across steps
        split_rngs={"params": True},  # each layer is initialised with its own RNG
        in_axes=nn.broadcast,  # extra inputs are passed unchanged to each step
        out_axes=0,  # per-step outputs are stacked along axis 0
    ),
]

In [5]:
# Index into `scan_args_list` selecting the scan configuration to use.
IDX = 1


class MyScanModule(nn.Module):
    """`SimpleBlock` repeated `num_layers` times through `nn.scan`.

    The scan keyword arguments come from `scan_args_list[IDX]`.
    """
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype

    def setup(self):
        scan_kwargs = scan_args_list[IDX]
        # `nn.scan` lifts the module class; the lifted class is then
        # instantiated with the usual `SimpleBlock` fields.
        scanned_block_cls = nn.scan(
            SimpleBlock,
            length=self.num_layers,
            **scan_kwargs,
        )
        self.blocks = scanned_block_cls(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dtype=self.dtype,
        )

    def __call__(self, x):
        """Return `(final_hidden_state, scan_outputs)` from the scanned block."""
        carry, stacked_out = self.blocks(x)
        return carry, stacked_out

# Build the scan-based model and initialise its parameters from the
# batch-of-one dummy input.
scan_model = MyScanModule(
    hidden_dim=HIDDEN_DIM,
    num_layers=LAYERS,
    num_heads=NUM_HEADS,
    dtype=DTYPE,
)
scan_params = scan_model.init(init_rng, dummy_input)

# Print the latency of the jitted forward pass on the full batch.
compare_time(scan_model.apply, scan_params, input)

# Final expression of the cell: shows the shape of every parameter leaf.
jax.tree.map(jnp.shape, scan_params)

Average time for the subsequent JIT calls: 603.53 ms


{'params': {'blocks': {'Dense_0': {'bias': (42, 512),
    'kernel': (42, 512, 512)},
   'self_attention': {'key': {'bias': (42, 4, 128),
     'kernel': (42, 512, 4, 128)},
    'out': {'bias': (42, 512), 'kernel': (42, 4, 128, 512)},
    'query': {'bias': (42, 4, 128), 'kernel': (42, 512, 4, 128)},
    'value': {'bias': (42, 4, 128), 'kernel': (42, 512, 4, 128)}}}}}

# Cost Analysis

In [7]:
# Loop-based model: lower and compile ahead of time, then report the
# compiler's cost estimate.
cost = (
    jax.jit(model.apply)
    .lower(params, input)
    .compile()
    .cost_analysis()
)
print_cost_analysis(cost, show_utilization=False)

# Scan-based model: same procedure, for comparison.
scan_cost = (
    jax.jit(scan_model.apply)
    .lower(scan_params, input)
    .compile()
    .cost_analysis()
)
print_cost_analysis(scan_cost, show_utilization=False)

=== JAX Cost Analysis ===
Total estimated memory accessed: 579.16 MB
Estimated peak memory usage: 579.16 MB
Estimated FLOPs: 57.880 GFLOPs
Transcendental ops (exp, sin, etc.): 2752512


=== JAX Cost Analysis ===
Total estimated memory accessed: 148.77 MB
Estimated peak memory usage: 210.00 MB
Estimated FLOPs: 1.378 GFLOPs
Transcendental ops (exp, sin, etc.): 65536




# Practical Example: Distributed LLM inference with KV caching