flax.linen.scan module can be used to reduce the whole repeated or iterative computation graph into a single interation, therefore reduces the graph memory and compilation time.

In [1]:
# Auto-reload modules so we can edit the code and have it automatically reload
%load_ext autoreload
%autoreload 2

import numpy as np
import timeit

import jax
import jaxlib
import flax.linen as nn
import jax.numpy as jnp
from flax import core


BATCH_SIZE = 12
SEQ_LEN = 32
HIDDEN_DIM = 512
LAYERS = 42
NUM_HEADS = 4
DTYPE = jnp.float32
rng = jax.random.PRNGKey(0)
rng, init_rng = jax.random.split(rng)
dummy_input = jax.random.normal(rng, (1, SEQ_LEN, HIDDEN_DIM))
input = jax.random.normal(rng, (BATCH_SIZE, SEQ_LEN, HIDDEN_DIM))


In [2]:
def measure_time(fn, params, input, num_runs=5):
  total_time = timeit.timeit(
    lambda: jax.tree.map(lambda x: x.block_until_ready(), fn(params, input)),
    number=num_runs
  )
  avg_time = total_time / num_runs
  if avg_time < 1e-3:
    return f"{avg_time * 1e6:.2f} μs"
  elif avg_time < 1:
    return f"{avg_time * 1e3:.2f} ms"
  else:
    return f"{avg_time:.2f} s"

def compare_time(fn, params, input, num_runs=5):
  assert not isinstance(fn, jaxlib._jax.PjitFunction), "Function appears to be jitted"
  jit_fn = jax.jit(fn)
  # print(f"Average time for the regular apply: {measure_time(fn, params, input, num_runs)}")
  # print(f"Average time for the first JIT call: {measure_time(jit_fn, params, input, num_runs)}")
  output = jit_fn(params, input)
  print(f"Average time for the subsequent JIT calls: {measure_time(jit_fn, params, input, num_runs)}")

def print_cost_analysis(cost, show_utilization=False):
    """
    Pretty-print the cost analysis dictionary returned by:
        jax.jit(fn).lower(...).compile().cost_analysis()
    
    Args:
        cost: Dictionary containing cost analysis metrics
        show_utilization: Whether to print per-layer utilization metrics
    """
    print("=== JAX Cost Analysis ===")

    # Total memory accessed
    total_mem = cost.get('bytes accessedout{}', 0)
    print(f"Total estimated memory accessed: {total_mem / (1024 ** 2):.2f} MB")

    # Find peak memory estimate across all ops
    peak_mem = max(
        val for key, val in cost.items() if key.startswith('bytes accessedout{')
    )
    print(f"Estimated peak memory usage: {peak_mem / (1024 ** 2):.2f} MB")

    # Total FLOPs
    flops = cost.get('flops', 0)
    print(f"Estimated FLOPs: {flops / 1e9:.3f} GFLOPs")

    # Transcendental operations
    trans = cost.get('transcendentals', 0)
    print(f"Transcendental ops (exp, sin, etc.): {int(trans)}")

    # Optimal execution time (TPU internal estimate)
    optimal_time = cost.get('optimal_seconds', None)
    if optimal_time is not None:
        print(f"Estimated optimal execution time: {optimal_time * 1e3:.3f} ms")

    if show_utilization:
        print("=== Per-Layer Utilization (if available) ===")
        utilizations = {
            k: v for k, v in cost.items() if k.startswith('utilization')
        }

        if utilizations:
            for key, value in sorted(utilizations.items()):
                print(f"{key}: {value}")
        else:
            print("No detailed utilization metrics found.")
    print(); print()
            

def pyloop_to_scan(params_pyloop):
  """
  Converts a PyLoop-style LLM checkpoint into a scanned-loop style.
  
  Args:
    params_pyloop: Original parameter dict with 'layer_0', 'layer_1', etc.

  Returns:
    params_scan: Updated dict with all layers stacked under 'layers'.
  """
  t = core.unfreeze(core.FrozenDict(params_pyloop))

  # Collect all layer keys like 'layer_0', 'layer_1', etc.
  layer_keys = [k for k in t if k.startswith("layer_")]
  layer_keys.sort(key=lambda x: int(x.split("_")[1]))  # sort by layer index

  depth = len(layer_keys)
  assert depth > 0, "No layers found in params"

  def stack(*values):
    return jnp.stack(values)

  # Stack all layers under one key: 'layers'
  t["blocks"] = jax.tree.map(stack, *[t.pop(k) for k in layer_keys])

  return core.freeze(t)

def output_comparison(original_output, scan_output, threshold=1e-6, seed=42):
    """Compare outputs between original and scanned models with detailed formatting."""
    # Input validation
    assert len(original_output) == 2 and len(scan_output) == 2, "Output must be a tuple of two elements"
    assert original_output[0].shape == scan_output[0].shape, "Output shapes must match"
    
    # Final output comparison
    original_last = original_output[0]
    scan_last = scan_output[0]
    diff = jnp.abs(original_last - scan_last)
    
    # Summary statistics
    summary = {
        'Max diff': float(diff.max()),
        'Min diff': float(diff.min()),
        'Mean diff': float(diff.mean()),
        'Std diff': float(diff.std()),
        'Within threshold': float((diff < threshold).mean() * 100)
    }
    
    # Formatting constants
    SEP = "=" * 80
    ROW = "-" * 80
    COL1 = 20
    COL2 = 20
    COL3 = 20
    
    # Print final comparison header
    print(f"\n{SEP}")
    print(f"{' FINAL OUTPUT COMPARISON ':.^{len(SEP)}}")
    print(SEP)
    print(f"{'Metric':<{COL1}} | {'Value':>{COL2}} | {'Threshold':>{COL3}}")
    print(ROW)
    
    # Print final comparison rows
    for metric, value in summary.items():
        if metric == 'Within threshold':
            print(f"{metric:<{COL1}} | {f'{value:.2f}%':>{COL2}} | {'':>{COL3}}")
        else:
            status = "✅" if value < threshold else "❌"
            print(f"{metric:<{COL1}} | {value:>{COL2}.6f} | {status:>{COL3}}")
    print(f"{SEP}\n")

    # Layer-wise comparison
    for lyr in range(len(original_output[1])):
        lyr_key = f"layer_{lyr}"
        
        # Layer header
        print(f"\n{SEP}")
        print(f"{f' LAYER {lyr} COMPARISON ':.^{len(SEP)}}")
        print(SEP)
        
        for key in original_output[1][lyr_key].keys():
            orig_val = original_output[1][lyr_key][key]
            scan_val = scan_output[1][key][lyr]
            diff = jnp.abs(orig_val - scan_val)
            
            # Key header
            print(f"\n{key.upper()}:")
            print(ROW)
            print(f"{'Sampled idx':<5} | {'Original':>15}   | {'Scan':>15} | {'Diff':>15} | {'Status':>5}")
            print(ROW)
            
            # Random samples
            rng = np.random.default_rng(seed)
            batch_idx = rng.integers(0, orig_val.shape[0])
            seq_idx = rng.integers(0, orig_val.shape[1])
            dims = rng.choice(orig_val.shape[2], 3, False)
            
            for dim in dims:
                status = "✅" if diff[batch_idx, seq_idx, dim] < threshold else "❌"
                print(f"{dim:<5}       | {orig_val[batch_idx, seq_idx, dim]:>15.6f}   | "
                      f"{scan_val[batch_idx, seq_idx, dim]:>15.6f} | "
                      f"{diff[batch_idx, seq_idx, dim]:>15.6f} | {status:>5}")
            
            # Statistics footer
            print(ROW)
            print(f"{'Stats:':<5}      | {'Mean:':>15}   | {'':>15} | {float(diff.mean()):>15.6f} | {'':>5}")
            print(f"{'':<5}       | {'Std:':>15}   | {'':>15} | {float(diff.std()):>15.6f} | {'':>5}")
            print(f"{'':<5}       | {'Within threshold:':>15} | {'':>15} | {f'{(diff < threshold).mean() * 100:.2f}%':>15} | {'✅' if (diff < threshold).mean() > 0.99 else '❌':>5}")
            

# Basic iterative module

In [3]:
class SimpleBlock(nn.Module):
    hidden_dim: int
    num_heads: int
    dtype: jnp.dtype
    
    @nn.compact
    def __call__(self, x, attention_mask=None):
        out = {}
        out["input_hidden_state"] = x
        x = out["pre_attention_norm"] = nn.RMSNorm(dtype=self.dtype)(x)
        jax.debug.print("pre_attention_norm dtype={dtype}", dtype=x.dtype)
        x += nn.SelfAttention(
            num_heads=self.num_heads,
            kernel_init=nn.initializers.xavier_uniform(),
            bias_init=nn.initializers.zeros,
            dtype=self.dtype,
            name="self_attention"
					)(x, attention_mask)
        jax.debug.print("self_attention dtype={dtype}", dtype=x.dtype)
        out["+self_attention"] = x
        x = out["pre_ffw_norm"] = nn.RMSNorm(dtype=self.dtype)(x)
        jax.debug.print("pre_ffw_norm dtype={dtype}", dtype=x.dtype)
        x += nn.Dense(features=self.hidden_dim)(x)
        jax.debug.print("+dense dtype={dtype}", dtype=x.dtype)
        out["+dense"] = x
        return x, out

class MyIterativeModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
        self.blocks = [
            SimpleBlock(
                name=f"layer_{i}",
                hidden_dim=self.hidden_dim,
                num_heads=self.num_heads,
                dtype=self.dtype
						) 
            for i in range(self.num_layers)  # Use num_layers instead of hardcoded 3
        ]

    def __call__(self, x):
        out = {}
        for i, block in enumerate(self.blocks):
          x, layer_out = block(x)
          out[f"layer_{i}"] = layer_out
        return x, out

model = MyIterativeModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
params = model.init(init_rng, dummy_input)

jax.tree.map(jnp.shape, params)

pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attent

{'params': {'layer_0': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'RMSNorm_0': {'scale': (512,)},
   'RMSNorm_1': {'scale': (512,)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias': (512,), 'kernel': (4, 128, 512)},
    'query': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'value': {'bias': (4, 128), 'kernel': (512, 4, 128)}}},
  'layer_1': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'RMSNorm_0': {'scale': (512,)},
   'RMSNorm_1': {'scale': (512,)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias': (512,), 'kernel': (4, 128, 512)},
    'query': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'value': {'bias': (4, 128), 'kernel': (512, 4, 128)}}},
  'layer_10': {'Dense_0': {'bias': (512,), 'kernel': (512, 512)},
   'RMSNorm_0': {'scale': (512,)},
   'RMSNorm_1': {'scale': (512,)},
   'self_attention': {'key': {'bias': (4, 128), 'kernel': (512, 4, 128)},
    'out': {'bias'

# Scan Examples

In [4]:
# List of different scan configurations for nn.scan
scan_args_list = [
    # Configuration 1: 
    # - No parameter axes (shared parameters across layers)
    # - Broadcast parameters to all layers
    # - Don't split RNGs for parameters
    # - Input/Output along axis 0 (batch dimension)
    {
        "variable_axes": {},  # No parameter axes (shared params)
        "variable_broadcast": "params",  # Broadcast params to all layers
        "split_rngs": {"params": False},  # Don't split RNGs for params
        "in_axes": 0,  # Input along axis 0 (batch dim)
        "out_axes": 0,  # Output along axis 0 (batch dim)
    },
    # Configuration 2:
    # - Parameters along axis 0 (separate params per layer)
    # - Don't broadcast parameters
    # - Split RNGs for parameters
    # - Broadcast input to all layers
    # - Output along axis 0 (batch dimension)
    {
        "variable_axes": {"params": 0},  # Parameters along axis 0 (per layer)
        "variable_broadcast": False,  # Don't broadcast params
        "split_rngs": {"params": True},  # Split RNGs for params
        "in_axes": nn.broadcast,  # Broadcast input to all layers
        "out_axes": 0,  # Output along axis 0 (batch dim)
    },
    {
        "variable_axes": {"params": 0},  
        "variable_broadcast": False,
        "split_rngs": {"params": False},
        "in_axes": nn.broadcast,
        "out_axes": 0,
		}
]

In [5]:
IDX = 1
class MyScanModule(nn.Module):
    hidden_dim: int
    num_layers: int
    num_heads: int
    dtype: jnp.dtype
    
    def setup(self):
      self.blocks = nn.scan(
         SimpleBlock,
         length=self.num_layers,
         **scan_args_list[IDX],
			)(
				hidden_dim=self.hidden_dim,
				num_heads=self.num_heads,
				dtype=self.dtype
			)

    def __call__(self, x):
      x, scan_out = self.blocks(x)
      return x, scan_out

scan_model = MyScanModule(hidden_dim=HIDDEN_DIM, num_layers=LAYERS, num_heads=NUM_HEADS, dtype=DTYPE)
scan_params = scan_model.init(init_rng, dummy_input)

jax.tree.map(jnp.shape, scan_params)

pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32


{'params': {'blocks': {'Dense_0': {'bias': (42, 512),
    'kernel': (42, 512, 512)},
   'RMSNorm_0': {'scale': (42, 512)},
   'RMSNorm_1': {'scale': (42, 512)},
   'self_attention': {'key': {'bias': (42, 4, 128),
     'kernel': (42, 512, 4, 128)},
    'out': {'bias': (42, 512), 'kernel': (42, 4, 128, 512)},
    'query': {'bias': (42, 4, 128), 'kernel': (42, 512, 4, 128)},
    'value': {'bias': (42, 4, 128), 'kernel': (42, 512, 4, 128)}}}}}

pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32


In [6]:
# pyloop_to_scan
new_params = {"params": pyloop_to_scan(params['params'])}

# compare new_params and scan_params
print(f"RMSNorm_0 scale: {abs(new_params['params']['blocks']['RMSNorm_0']['scale'] - scan_params['params']['blocks']['RMSNorm_0']['scale']).mean()}")
print(f"RMSNorm_1 scale: {abs(new_params['params']['blocks']['RMSNorm_1']['scale'] - scan_params['params']['blocks']['RMSNorm_1']['scale']).mean()}")
print(f"self_attention key kernel: {abs(new_params['params']['blocks']['self_attention']['key']['kernel'] - scan_params['params']['blocks']['self_attention']['key']['kernel']).mean()}")
print(f"self_attention key bias: {abs(new_params['params']['blocks']['self_attention']['key']['bias'] - scan_params['params']['blocks']['self_attention']['key']['bias']).mean()}")
print(f"self_attention query kernel: {abs(new_params['params']['blocks']['self_attention']['query']['kernel'] - scan_params['params']['blocks']['self_attention']['query']['kernel']).mean()}")
print(f"self_attention query bias: {abs(new_params['params']['blocks']['self_attention']['query']['bias'] - scan_params['params']['blocks']['self_attention']['query']['bias']).mean()}")
print(f"self_attention value kernel: {abs(new_params['params']['blocks']['self_attention']['value']['kernel'] - scan_params['params']['blocks']['self_attention']['value']['kernel']).mean()}")
print(f"self_attention value bias: {abs(new_params['params']['blocks']['self_attention']['value']['bias'] - scan_params['params']['blocks']['self_attention']['value']['bias']).mean()}")
print(f"self_attention out kernel: {abs(new_params['params']['blocks']['self_attention']['out']['kernel'] - scan_params['params']['blocks']['self_attention']['out']['kernel']).mean()}")
print(f"self_attention out bias: {abs(new_params['params']['blocks']['self_attention']['out']['bias'] - scan_params['params']['blocks']['self_attention']['out']['bias']).mean()}")
print(f"Dense_0 kernel: {abs(new_params['params']['blocks']['Dense_0']['kernel'] - scan_params['params']['blocks']['Dense_0']['kernel']).mean()}")
print(f"Dense_0 bias: {abs(new_params['params']['blocks']['Dense_0']['bias'] - scan_params['params']['blocks']['Dense_0']['bias']).mean()}")

# compare params and new_params
print(f"RMSNorm_0 scale layer 0: {abs(params['params']['layer_0']['RMSNorm_0']['scale'] - new_params['params']['blocks']['RMSNorm_0']['scale'][0]).mean()}")
print(f"RMSNorm_1 scale layer 0: {abs(params['params']['layer_0']['RMSNorm_1']['scale'] - new_params['params']['blocks']['RMSNorm_1']['scale'][0]).mean()}")
print(f"self_attention key kernel layer 0: {abs(params['params']['layer_0']['self_attention']['key']['kernel'] - new_params['params']['blocks']['self_attention']['key']['kernel'][0]).mean()}")
print(f"self_attention key bias layer 0: {abs(params['params']['layer_0']['self_attention']['key']['bias'] - new_params['params']['blocks']['self_attention']['key']['bias'][0]).mean()}")
print(f"self_attention query kernel layer 0: {abs(params['params']['layer_0']['self_attention']['query']['kernel'] - new_params['params']['blocks']['self_attention']['query']['kernel'][0]).mean()}")
print(f"self_attention query bias layer 0: {abs(params['params']['layer_0']['self_attention']['query']['bias'] - new_params['params']['blocks']['self_attention']['query']['bias'][0]).mean()}")
print(f"self_attention value kernel layer 0: {abs(params['params']['layer_0']['self_attention']['value']['kernel'] - new_params['params']['blocks']['self_attention']['value']['kernel'][0]).mean()}")
print(f"self_attention value bias layer 0: {abs(params['params']['layer_0']['self_attention']['value']['bias'] - new_params['params']['blocks']['self_attention']['value']['bias'][0]).mean()}")
print(f"self_attention out kernel layer 0: {abs(params['params']['layer_0']['self_attention']['out']['kernel'] - new_params['params']['blocks']['self_attention']['out']['kernel'][0]).mean()}")
print(f"self_attention out bias layer 0: {abs(params['params']['layer_0']['self_attention']['out']['bias'] - new_params['params']['blocks']['self_attention']['out']['bias'][0]).mean()}")
print(f"Dense_0 kernel layer 0: {abs(params['params']['layer_0']['Dense_0']['kernel'] - new_params['params']['blocks']['Dense_0']['kernel'][0]).mean()}")
print(f"Dense_0 bias layer 0: {abs(params['params']['layer_0']['Dense_0']['bias'] - new_params['params']['blocks']['Dense_0']['bias'][0]).mean()}")


self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32


pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attent

# Analysis

In [7]:
# For original model
compare_time(model.apply, params, input)
cost = jax.jit(model.apply).lower(params, input).compile().cost_analysis()
print_cost_analysis(cost, show_utilization=False)

# For scanned model
compare_time(scan_model.apply, scan_params, input)
scan_cost = jax.jit(scan_model.apply).lower(scan_params, input).compile().cost_analysis()
print_cost_analysis(scan_cost, show_utilization=False)

pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attent

In [8]:
# Output consistency check

original_output = model.apply(params, input)
scan_output = scan_model.apply(scan_params, input)
converted_scan_output = scan_model.apply(new_params, input)

pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attent

In [9]:
output_comparison(original_output, converted_scan_output, threshold=1e-10)

self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32


pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attention_norm dtype=float32
self_attention dtype=float32
pre_ffw_norm dtype=float32
+dense dtype=float32
pre_attent

In [10]:
# compare function
# output_comparison(original_output, scan_output)

# Practical Example: Distributed LLM inference with KV caching

In [11]:
class SimpleCausalAttention(nn.Module):
    num_heads: int
    
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, cache=None, decode=False):
        B, T, D = x.shape
        H = self.num_heads
        DH = D // H

        # Project inputs
        q = nn.Dense(D, dtype=self.dtype)(x)
        k = nn.Dense(D, dtype=self.dtype)(x)
        v = nn.Dense(D, dtype=self.dtype)(x)

        # Reshape for multi-head attention
        q = q.reshape(B, T, H, DH).transpose(0, 2, 1, 3)  # (B, H, T, DH)
        k = k.reshape(B, T, H, DH).transpose(0, 2, 1, 3)
        v = v.reshape(B, T, H, DH).transpose(0, 2, 1, 3)

        if decode:
            # In generation mode: update cache
            assert cache is not None, "Cache must be provided during decoding"
            index = cache["index"]
            one_hot_indices = jax.nn.one_hot(index, T + 1, dtype=jnp.int32)
            k = cache["k"].at[:, :, index:index+T].set(k)
            v = cache["v"].at[:, :, index:index+T].set(v)
            index += T
            cache = {"k": k, "v": v, "index": index}
            k = k[:, :, :index]
            v = v[:, :, :index]

        else:
            # During prefilling, use full sequence
            pass

        # Scaled dot-product attention
        attn_weights = jnp.einsum("bhqd,bhkd->bhqk", q, k) / jnp.sqrt(DH)
        attn_weights = jnp.tril(attn_weights)
        attn_weights = jax.nn.softmax(attn_weights)
        attn_out = jnp.einsum("bhqk,bhvd->bhqd", attn_weights, v)

        # Reshape back
        out = attn_out.transpose(0, 2, 1, 3).reshape(B, T, D)
        out = nn.Dense(D, dtype=self.dtype)(out)

        return out, cache
    
class SimpleBlock(nn.Module):
    hidden_dim: int
    num_heads: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x, cache=None, decode=False):
        # Self-Attention
        attn_out, cache = SimpleCausalAttention(
            num_heads=self.num_heads,
            dtype=self.dtype,
            name="self_attention"
        )(x, cache=cache, decode=decode)

        x = x + attn_out

        # Feed-forward
        x = x + nn.Sequential([
            nn.Dense(self.hidden_dim),
            nn.gelu,
            nn.Dense(self.hidden_dim)
        ])(x)

        return x, cache
    
ScannedBlock = nn.scan(
    SimpleBlock,
    variable_axes={"params": 0},
    variable_broadcast=False,
    split_rngs={"params": True},
    in_axes=(nn.broadcast, None),  # x broadcasted, cache not scanned
    out_axes=(0, None),  # x stacked, cache not stacked
    length=LAYERS
)

class SimpleLLM(nn.Module):
    vocab_size: int
    hidden_dim: int
    num_layers: int
    num_heads: int
    max_len: int
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, input_ids, cache=None, decode=False):
        embed = nn.Embed(num_embeddings=self.vocab_size, features=self.hidden_dim)
        x = embed(input_ids)

        # Stack blocks
        blocks = ScannedBlock(
            hidden_dim=self.hidden_dim,
            num_heads=self.num_heads,
            dtype=self.dtype
        )
        x, cache = blocks(x, cache=cache, decode=decode)

        # Final output logits
        logits = nn.Dense(self.vocab_size)(x)

        return logits, cache

def init_cache(model, params, batch_size, max_seq_len):
    @jax.jit
    def _init_cache():
        dummy_input = jnp.ones((batch_size, 1), dtype=jnp.int32)
        _, initial_cache = model.apply(params, dummy_input, decode=True, mutable=["cache"])
        return initial_cache
    return _init_cache()

def generate_tokens(model, params, tokenizer, prompt, max_new_tokens=30):
    tokenized = tokenizer(prompt, return_tensors="np")
    input_ids = tokenized["input_ids"]

    # JIT once
    @jax.jit
    def forward_step(input_ids, cache):
        logits, new_cache = model.apply(params, input_ids, cache=cache, decode=True)
        return logits, new_cache

    # Initialize cache
    cache = init_cache(model, params, input_ids.shape[0], max_new_tokens)
    print(jax.tree.map(jnp.shape, cache))

    # Prefill context
    logits, cache = forward_step(input_ids, cache)

    # Generate new tokens
    generated_ids = []
    current_id = input_ids[:, -1:]

    for _ in range(max_new_tokens):
        logits, cache = forward_step(current_id, cache)
        current_id = jnp.argmax(logits[:, -1:], axis=-1)
        generated_ids.append(current_id)

    return tokenizer.decode(jnp.concatenate(generated_ids, axis=-1)[0])

In [12]:
VOCAB_SIZE = 10000
HIDDEN_DIM = 512
NUM_LAYERS = 42
NUM_HEADS = 4
MAX_LEN = 1024
DTYPE = jnp.float32

model = SimpleLLM(vocab_size=VOCAB_SIZE, hidden_dim=HIDDEN_DIM, num_layers=NUM_LAYERS, num_heads=NUM_HEADS, max_len=MAX_LEN, dtype=DTYPE)
params = model.init(init_rng, jnp.ones((1, 32), dtype=jnp.int32))

jax.tree.map(jnp.shape, params)



ValueError: Tuple arity mismatch: 0 != 2; tuple: ().