# JAX Advanced Concepts: PyTree, LAX, XLA, and Scan

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Ziaeemehr/workshop_hpcpy/blob/main/notebooks/jax/advanced_concepts.ipynb)

This notebook covers advanced JAX concepts that are essential for high-performance computing:
- **PyTree**: Working with nested data structures
- **LAX**: Low-level operations for performance
- **XLA**: Understanding JAX's compilation backend
- **Scan**: Efficient loops and sequential operations

In [None]:
import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
from jax import lax
from jax.tree_util import tree_map, tree_flatten, tree_unflatten, tree_structure
import numpy as np
import matplotlib.pyplot as plt
from functools import partial
import time

# Setup for Google Colab or local environment
import os
import sys

# Check if running on Google Colab
try:
    from google.colab import drive
    IN_COLAB = True
    print("Running on Google Colab")
except ImportError:
    IN_COLAB = False
    print("Running locally")

# Clone repository if on Colab and not already cloned
if IN_COLAB:
    if not os.path.exists('/content/workshop_hpcpy'):
        print("Cloning workshop_hpcpy repository...")
        os.system('git clone https://github.com/Ziaeemehr/workshop_hpcpy.git /content/workshop_hpcpy')
    
    # Change to notebook directory
    os.chdir('/content/workshop_hpcpy/notebooks/jax')
    print(f"Working directory: {os.getcwd()}")

# Part 1: PyTree - Working with Nested Structures

PyTrees are a core abstraction in JAX for handling nested containers of arrays. They enable JAX transformations to work seamlessly with complex data structures.

## 1.1 What is a PyTree?

A PyTree is any nested structure made of:
- Lists
- Tuples
- Dictionaries
- Named tuples
- Custom classes (registered as PyTree nodes)

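Anything that is not one of these containers is a leaf, typically a JAX array or a Python scalar. `None` is the exception: JAX treats it as an empty container, so it contributes no leaves.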
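A quick way to see what counts as a leaf is `jax.tree_util.tree_leaves`. The sketch below uses an arbitrary illustrative structure. It shows that containers are traversed, scalars and arrays become leaves, and `None` entries are dropped.

```python
import jax.numpy as jnp
from jax.tree_util import tree_leaves, tree_structure

# Arbitrary nested example: dict -> list -> tuple, with two None entries
example = {"a": [1, None, (2.0, jnp.ones(2))], "b": None}

leaves = tree_leaves(example)
print(len(leaves))                 # 3: the int 1, the float 2.0 and the array
print(tree_structure(example))     # the container layout, with leaves as placeholders

# A bare array is itself a PyTree with a single leaf
single = tree_leaves(jnp.zeros(3))
print(len(single))                 # 1
```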

In [None]:
# Examples of PyTrees
pytree1 = [1, 2, 3]  # List
pytree2 = (jnp.array([1.0, 2.0]), jnp.array([3.0, 4.0]))  # Tuple of arrays
pytree3 = {'a': jnp.array([1.0]), 'b': jnp.array([2.0])}  # Dictionary
pytree4 = {'weights': jnp.ones((3, 3)), 'bias': jnp.zeros(3)}  # Neural network parameters
pytree5 = [{'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}, 
           {'W': jnp.ones((2, 1)), 'b': jnp.zeros(1)}]  # Multiple layers

print("PyTree examples:")
print(f"List: {pytree1}")
print(f"Tuple of arrays: {pytree2}")
print(f"Dictionary: {pytree3}")
print(f"Network parameters: {pytree4}")
print(f"Nested (list of dicts): {pytree5}")

## 1.2 Tree Operations

JAX provides utilities to manipulate PyTrees.

In [None]:
# tree_flatten: Convert PyTree to flat list and structure
params = {'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]), 
          'bias': jnp.array([0.5, 1.5])}

leaves, treedef = tree_flatten(params)
print("Original PyTree:")
print(params)
print("\nFlattened leaves:")
for i, leaf in enumerate(leaves):
    print(f"Leaf {i}: {leaf}")
print(f"\nTree structure: {treedef}")

# tree_unflatten: Reconstruct PyTree from leaves and structure
reconstructed = tree_unflatten(treedef, leaves)
print("\nReconstructed PyTree:")
print(reconstructed)

In [None]:
# tree_map: Apply function to all leaves
params = {'weights': jnp.array([[1.0, 2.0], [3.0, 4.0]]), 
          'bias': jnp.array([0.5, 1.5])}

# Double all parameters
doubled = tree_map(lambda x: 2 * x, params)
print("Original:")
print(params)
print("\nDoubled:")
print(doubled)

# Add two PyTrees element-wise
params2 = {'weights': jnp.ones((2, 2)), 'bias': jnp.ones(2)}
sum_params = tree_map(lambda x, y: x + y, params, params2)
print("\nSum of two PyTrees:")
print(sum_params)
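Multi-argument `tree_map` only works when all inputs share the same structure. `tree_structure` (imported at the top but not used yet) compares container layouts while ignoring leaf values. The sketch below uses a hypothetical `safe_add` helper to make that check explicit before combining two PyTrees.

```python
import jax.numpy as jnp
from jax.tree_util import tree_map, tree_structure

params = {'weights': jnp.ones((2, 2)), 'bias': jnp.zeros(2)}
same_layout = {'weights': jnp.zeros((2, 2)), 'bias': jnp.ones(2)}
renamed = {'W': jnp.ones((2, 2)), 'b': jnp.zeros(2)}

# Structure comparison looks only at the containers (here: dict keys)
same = tree_structure(params) == tree_structure(same_layout)
different = tree_structure(params) == tree_structure(renamed)
print(same, different)  # True False

def safe_add(a, b):
    """Hypothetical helper: add two PyTrees leaf by leaf after checking structure."""
    if tree_structure(a) != tree_structure(b):
        raise ValueError("PyTree structures differ")
    return tree_map(lambda x, y: x + y, a, b)

total = safe_add(params, same_layout)
print(total)
```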

## 1.3 PyTrees with JAX Transformations

PyTrees work seamlessly with `grad`, `jit`, `vmap`, etc.

In [None]:
# Example: Gradient of a function with PyTree parameters
def loss_fn(params, x, y):
    """Simple linear model loss."""
    pred = jnp.dot(x, params['weights']) + params['bias']
    return jnp.mean((pred - y) ** 2)

# Initialize parameters
params = {
    'weights': jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    'bias': jnp.array([0.0, 0.0])
}

# Sample data
x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
y = jnp.array([[1.0, 2.0], [3.0, 4.0]])

# Compute gradients (returns PyTree with same structure)
grads = grad(loss_fn)(params, x, y)

print("Parameters:")
print(params)
print("\nGradients (same structure):")
print(grads)

In [None]:
# Example: Gradient descent with PyTree parameters
@jit
def update(params, x, y, learning_rate):
    """Perform one gradient descent step."""
    grads = grad(loss_fn)(params, x, y)
    # Use tree_map to update all parameters
    return tree_map(lambda p, g: p - learning_rate * g, params, grads)

# Training loop
params = {
    'weights': jnp.array([[0.1, 0.2], [0.3, 0.4]]),
    'bias': jnp.array([0.0, 0.0])
}

for i in range(5):
    loss = loss_fn(params, x, y)
    params = update(params, x, y, 0.1)
    print(f"Step {i}: Loss = {loss:.4f}")

print("\nFinal parameters:")
print(params)
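`vmap` also accepts PyTrees, and its `in_axes` argument can say per input whether to map. A sketch, assuming a single-example version of the linear loss above, that computes per-example gradients by sharing `params` (`None`) and mapping over the batch axis of `x` and `y` (`0`). Because the batch loss is a mean, its gradient should equal the mean of the per-example gradients.

```python
import jax.numpy as jnp
from jax import grad, vmap
from jax.tree_util import tree_map

def loss_single(params, x, y):
    """Squared error for ONE example: x and y have shape (2,)."""
    pred = jnp.dot(x, params['weights']) + params['bias']
    return jnp.mean((pred - y) ** 2)

def batch_loss(params, xs, ys):
    preds = jnp.dot(xs, params['weights']) + params['bias']
    return jnp.mean((preds - ys) ** 2)

params = {'weights': jnp.array([[0.1, 0.2], [0.3, 0.4]]),
          'bias': jnp.array([0.0, 0.0])}
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # 3 examples
ys = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])

# None: broadcast params to every example; 0: map over the leading axis
per_example = vmap(grad(loss_single), in_axes=(None, 0, 0))(params, xs, ys)
print(per_example['weights'].shape, per_example['bias'].shape)  # (3, 2, 2) (3, 2)

# Averaging per-example gradients recovers the full-batch gradient
mean_grads = tree_map(lambda g: g.mean(axis=0), per_example)
batch_grads = grad(batch_loss)(params, xs, ys)
```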

## 1.4 Custom PyTree Nodes

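NamedTuples are PyTrees out of the box. Other custom classes must be registered as PyTree nodes, for example with `jax.tree_util.register_pytree_node_class`, before JAX transformations can see inside them.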

In [None]:
from typing import NamedTuple

class MLPParams(NamedTuple):
    """Parameters for a simple MLP."""
    w1: jnp.ndarray
    b1: jnp.ndarray
    w2: jnp.ndarray
    b2: jnp.ndarray

# NamedTuples are automatically PyTrees in JAX
params = MLPParams(
    w1=jnp.ones((10, 5)),
    b1=jnp.zeros(5),
    w2=jnp.ones((5, 2)),
    b2=jnp.zeros(2)
)

# Apply tree_map
scaled = tree_map(lambda x: 0.5 * x, params)
print("Original w1 sum:", params.w1.sum())
print("Scaled w1 sum:", scaled.w1.sum())

# Count parameters
def count_params(pytree):
    return sum(x.size for x in tree_flatten(pytree)[0])

print(f"\nTotal parameters: {count_params(params)}")
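The NamedTuple above needs no registration. For an ordinary class, a minimal sketch with a hypothetical `Linear` class is shown below. It uses `register_pytree_node_class`, which expects a `tree_flatten` method returning `(children, aux_data)` and a `tree_unflatten` classmethod that rebuilds the object. Once registered, the class works with `tree_map`, `jit`, and `grad`.

```python
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class, tree_leaves, tree_map

@register_pytree_node_class
class Linear:
    """Hypothetical layer class, used only to illustrate registration."""
    def __init__(self, w, b):
        self.w = w
        self.b = b

    def tree_flatten(self):
        # children: the array leaves; aux_data: static metadata (none here)
        return (self.w, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*children)

layer = Linear(jnp.ones((3, 2)), jnp.zeros(2))

halved = tree_map(lambda x: 0.5 * x, layer)   # rebuilt as a new Linear
n_params = sum(x.size for x in tree_leaves(layer))
print(type(halved).__name__, n_params)        # Linear 8

@jax.jit
def apply_sum(layer, x):
    return jnp.sum(x @ layer.w + layer.b)

# The gradient has the same structure as the input: another Linear
g = jax.grad(apply_sum)(layer, jnp.ones(3))
print(type(g).__name__, g.w.shape, g.b.shape)
```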

# Part 2: LAX - Low-Level Operations

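The `jax.lax` module provides lower-level operations than `jax.numpy`; in fact, most `jax.numpy` functions are implemented on top of `lax`. `lax` operations map more directly onto XLA operations and are stricter. For example, they do not perform implicit type promotion. Switching to `lax` rarely makes code faster on its own, but it gives explicit control and provides the control flow primitives needed when a branch or loop depends on traced values.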

## 2.1 Control Flow with LAX

LAX provides functional control flow operations that work with JAX transformations.
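Why these are needed: under `jit`, a comparison such as `x > 0` produces a tracer, not a Python `bool`, so a plain `if` cannot decide which branch to take. In the sketch below (the function names are illustrative), the Python `if` version fails with a `ConcretizationTypeError` (or a subclass of it, depending on the JAX version). The `lax.cond` version stages both branches into the compiled program and chooses one at run time.

```python
import jax
import jax.numpy as jnp
from jax import lax

@jax.jit
def abs_python_if(x):
    # Python `if` needs a concrete bool, but under jit `x > 0` is a tracer
    if x > 0:
        return x
    return -x

@jax.jit
def abs_lax_cond(x):
    # lax.cond compiles both branches and selects one when the program runs
    return lax.cond(x > 0, lambda v: v, lambda v: -v, x)

try:
    abs_python_if(-3.0)
    python_if_failed = False
except jax.errors.ConcretizationTypeError as err:
    python_if_failed = True
    print("Python if under jit failed:", type(err).__name__)

result = abs_lax_cond(-3.0)
print(result)  # 3.0
```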

In [None]:
# lax.cond: Functional if-else
def f_true(x):
    return x + 1

def f_false(x):
    return x - 1

x = 5.0
result_true = lax.cond(True, f_true, f_false, x)
result_false = lax.cond(False, f_true, f_false, x)

print(f"x = {x}")
print(f"cond(True, ...): {result_true}")
print(f"cond(False, ...): {result_false}")

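# More practical example: ReLU with lax.cond
@jit
def relu_cond(x):
    # Both branches must return values with the same shape and dtype
    return lax.cond(x > 0, lambda x: x, lambda x: jnp.zeros_like(x), x)

# Note: lax.cond needs a scalar predicate. For arrays (and for ReLU in general),
# jnp.maximum(x, 0.0) or jnp.where is simpler and vectorizes naturally.
print(f"\nReLU(-2.0) = {relu_cond(-2.0)}")
print(f"ReLU(3.0) = {relu_cond(3.0)}")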

In [None]:
# lax.switch: Multi-way branch (like switch/case)
def option_0(x):
    return x ** 2

def option_1(x):
    return x ** 3

def option_2(x):
    return jnp.sqrt(x)

branches = [option_0, option_1, option_2]
x = 4.0

for i in range(3):
    result = lax.switch(i, branches, x)
    print(f"Branch {i}(x={x}): {result}")

In [None]:
# lax.while_loop: Functional while loop
def cond_fun(val):
    i, total = val
    return i < 10

def body_fun(val):
    i, total = val
    return i + 1, total + i

init_val = (0, 0)
final_i, final_total = lax.while_loop(cond_fun, body_fun, init_val)

print(f"Sum of 0 to 9: {final_total}")
print(f"Expected: {sum(range(10))}")
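For loops with a known trip count, `lax.fori_loop(lower, upper, body_fun, init_val)` is often more convenient than `while_loop`. Its `body_fun` receives `(i, carry)`. When the bounds are concrete Python ints, JAX implements it with `scan`, which supports reverse-mode differentiation. `lax.while_loop` supports only forward-mode differentiation. A sketch with illustrative function names:

```python
import jax
import jax.numpy as jnp
from jax import lax

# Same sum as the while_loop above: 0 + 1 + ... + 9
total = lax.fori_loop(0, 10, lambda i, acc: acc + i, 0)
print(total)  # 45

def cube(x):
    # Multiply the accumulator by x three times: 1 * x * x * x
    return lax.fori_loop(0, 3, lambda i, acc: acc * x, jnp.ones_like(x))

# Concrete bounds -> scan under the hood -> jax.grad works
d_cube = jax.grad(cube)(2.0)
print(d_cube)  # 3 * 2**2 = 12.0
```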

## 2.2 LAX Operations

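Beyond control flow, `lax` exposes primitive element-wise operations such as `select` and `clamp`. They mirror `jnp.where` and `jnp.clip`, but with stricter shape and dtype rules.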

In [None]:
# lax.select: Vectorized conditional (like np.where)
x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = lax.select(x > 3, x, -x)
print(f"x: {x}")
print(f"select(x > 3, x, -x): {result}")

# lax.clamp: Clamp values to range
x = jnp.array([-1.0, 0.5, 2.0, 5.0])
clamped = lax.clamp(0.0, x, 3.0)  # min, x, max
print(f"\nx: {x}")
print(f"clamp(0, x, 3): {clamped}")
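The sketch below puts each `lax` primitive next to its `jax.numpy` counterpart. `jnp.where` broadcasts a scalar fallback value, while `lax.select` expects `on_true` and `on_false` with matching shapes and dtypes, so the fallback is broadcast explicitly here. `lax.clamp(min, x, max)` and `jnp.clip(x, min, max)` compute the same result with different argument orders.

```python
import jax.numpy as jnp
from jax import lax

x = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

w = jnp.where(x > 3, x, 0.0)                   # scalar 0.0 is broadcast
s = lax.select(x > 3, x, jnp.zeros_like(x))    # fallback broadcast by hand
print(w, s)

y = jnp.array([-1.0, 0.5, 2.0, 5.0])
clamped = lax.clamp(0.0, y, 3.0)   # clamp(min, x, max)
clipped = jnp.clip(y, 0.0, 3.0)    # clip(x, min, max)
print(clamped, clipped)
```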

In [None]:
# Comparison: cumulative sum with a Python loop vs lax.scan (results only; timings appear in Part 4)
def cumsum_loop(arr):
    """Cumulative sum using Python loop."""
    result = jnp.zeros_like(arr)
    total = 0
    for i in range(len(arr)):
        total = total + arr[i]
        result = result.at[i].set(total)
    return result

def cumsum_scan(arr):
    """Cumulative sum using lax.scan."""
    def body(carry, x):
        new_carry = carry + x
        return new_carry, new_carry
    
    _, result = lax.scan(body, 0.0, arr)
    return result

arr = jnp.arange(10.0)
print(f"Array: {arr}")
print(f"Cumsum (loop): {cumsum_loop(arr)}")
print(f"Cumsum (scan): {cumsum_scan(arr)}")
print(f"Cumsum (jnp): {jnp.cumsum(arr)}")

# Part 3: XLA - Accelerated Linear Algebra

XLA is JAX's compilation backend. Understanding XLA helps you write more efficient code.

## 3.1 Understanding JIT Compilation

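When you use `@jit`, JAX traces your function with abstract values (shapes and dtypes, not concrete numbers), hands the resulting program to XLA for compilation, and caches the compiled executable. Later calls with the same input shapes, dtypes, and static arguments reuse the cache; a new shape triggers a retrace and recompile.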
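One way to observe tracing is to put a Python side effect inside the jitted function. The side effect runs only while JAX traces the function, not on cached calls. In the sketch below, the global counter is purely illustrative. `jax.make_jaxpr` prints the intermediate program that `jit` traces before XLA compiles it.

```python
import jax
import jax.numpy as jnp

trace_count = 0

@jax.jit
def double(x):
    global trace_count
    trace_count += 1   # Python side effect: runs only during tracing
    return 2 * x

double(jnp.ones(3))   # traced and compiled for float32[3]
double(jnp.ones(3))   # cache hit: no retrace
double(jnp.ones(4))   # new shape -> traced and compiled again
print(trace_count)    # 2

# The traced program (a jaxpr) that jit hands to XLA
print(jax.make_jaxpr(lambda x: 2 * x)(jnp.ones(3)))
```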

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

selu_jit = jit(selu)

x = jnp.arange(1000000.0)

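# First call: tracing + compilation + execution
start = time.perf_counter()
result = selu_jit(x).block_until_ready()
compile_time = time.perf_counter() - start

# Second call: execution only (the compiled program is cached)
start = time.perf_counter()
result = selu_jit(x).block_until_ready()
cached_time = time.perf_counter() - start

# Non-JIT version: each jnp op is dispatched separately (warm up once so the
# first-call overhead of eager mode is not counted)
selu(x).block_until_ready()
start = time.perf_counter()
result = selu(x).block_until_ready()
no_jit_time = time.perf_counter() - start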

print(f"First JIT call (compile + execute): {compile_time*1000:.2f} ms")
print(f"Second JIT call (cached): {cached_time*1000:.2f} ms")
print(f"No JIT: {no_jit_time*1000:.2f} ms")
print(f"\nSpeedup (cached): {no_jit_time/cached_time:.1f}x")

## 3.2 Static vs Traced Arguments

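Arguments that determine shapes or Python-level control flow (such as an `axis` or a loop count) should be marked static, meaning their concrete values are known at compile time, rather than traced. Input shapes themselves are always static under `jit`.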

In [None]:
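# Branching on shapes is fine under jit: shapes are static at trace time.
# Each new input shape simply triggers one extra trace + compile.
@jit
def normalize_by_ndim(x):
    if x.ndim == 1:
        return x / jnp.linalg.norm(x)
    else:
        return x / jnp.linalg.norm(x, axis=1, keepdims=True)

x1 = jnp.array([1.0, 2.0, 3.0])
x2 = jnp.array([[1.0, 2.0], [3.0, 4.0]])
print(normalize_by_ndim(x1))  # compiled for shape (3,)
print(normalize_by_ndim(x2))  # recompiled for shape (2, 2)

# Problem: passing a configuration value as a traced argument.
# `axis` becomes a tracer, but jnp.linalg.norm needs a concrete Python int.
@jit
def normalize_traced_axis(x, axis):
    return x / jnp.linalg.norm(x, axis=axis, keepdims=True)

try:
    normalize_traced_axis(x2, 0)
except Exception as e:
    print(f"Error: {type(e).__name__}")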

In [None]:
# Solution: mark `axis` as static; jit then compiles one version per distinct axis value
from functools import partial

@partial(jit, static_argnums=(1,))
def normalize_axis(x, axis):
    return x / jnp.linalg.norm(x, axis=axis, keepdims=True)

x = jnp.array([[1.0, 2.0], [3.0, 4.0]])
result_axis0 = normalize_axis(x, 0)
result_axis1 = normalize_axis(x, 1)

print("Original:")
print(x)
print("\nNormalized (axis=0):")
print(result_axis0)
print("\nNormalized (axis=1):")
print(result_axis1)
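`static_argnames` is the keyword-based counterpart of `static_argnums`. In the sketch below, a hypothetical power-iteration helper uses a static `n_iter` to drive a Python loop, which is unrolled at trace time. Each new `n_iter` value compiles a separate version, and static arguments must be hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnames=("n_iter",))
def power_iteration(A, n_iter):
    """Hypothetical helper: approximate the dominant eigenvector of A."""
    v = jnp.ones(A.shape[0])
    for _ in range(n_iter):   # needs a concrete int, hence static
        v = A @ v
        v = v / jnp.linalg.norm(v)
    return v

A = jnp.diag(jnp.array([3.0, 1.0]))
v = power_iteration(A, n_iter=20)
print(v)  # close to [1, 0]
```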

## 3.3 XLA Optimization Tips

Understanding what XLA does helps you write faster code.

In [None]:
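# Fusion: under jit, XLA merges chains of element-wise ops into a single kernel,
# so splitting an expression into named intermediates costs nothing extra.
def separate_ops(x):
    a = x + 1
    b = a * 2
    c = b - 3
    return c

def single_expression(x):
    return (x + 1) * 2 - 3

separate_jit = jit(separate_ops)
single_jit = jit(single_expression)

x = jnp.arange(1000000.0)

def time_ms(f, x, n_runs=10):
    """Average wall time per call in ms, after one warm-up call (compilation)."""
    f(x).block_until_ready()
    start = time.perf_counter()
    for _ in range(n_runs):
        f(x).block_until_ready()
    return (time.perf_counter() - start) / n_runs * 1000

print(f"No jit (3 separate dispatches): {time_ms(separate_ops, x):.2f} ms")
print(f"jit, separate statements:       {time_ms(separate_jit, x):.2f} ms")
print(f"jit, single expression:         {time_ms(single_jit, x):.2f} ms")
print("Both jitted versions trace to the same program, so their timings should be close")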
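To confirm that the way the Python code is written does not matter, compare the traced programs. The sketch below (function names are illustrative) shows that both styles produce identical jaxprs with three primitives. It also prints the start of XLA's optimized HLO via `jit(...).lower(...).compile().as_text()`. On most backends the element-wise ops appear inside a fusion there, although the exact text varies with backend and JAX version.

```python
import jax
import jax.numpy as jnp

def step_by_step(x):
    a = x + 1
    b = a * 2
    return b - 3

def one_liner(x):
    return (x + 1) * 2 - 3

x = jnp.arange(8.0)

# Both functions trace to the same primitives: add, mul, sub
jaxpr_a = jax.make_jaxpr(step_by_step)(x)
jaxpr_b = jax.make_jaxpr(one_liner)(x)
print(jaxpr_a)
same_program = str(jaxpr_a) == str(jaxpr_b)
print(same_program)  # True

# Optimized HLO after XLA's passes (backend-dependent output)
compiled = jax.jit(one_liner).lower(x).compile()
print(compiled.as_text()[:800])
```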

# Part 4: Scan - Efficient Sequential Operations

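`lax.scan` is crucial for efficient sequential computations in JAX. It behaves like a functional for-loop whose body is traced and compiled once, whereas a Python loop inside `jit` is unrolled into one copy of the body per iteration.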

## 4.1 Basic Scan Usage

`scan(f, init, xs)` applies `f` sequentially to elements of `xs`, carrying state.
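The semantics can be pinned down with a pure-Python model. The helper below is a simplified sketch: it ignores `length`, `reverse`, and PyTree-structured `xs`. It performs the same carry-and-stack computation as `lax.scan`, here checked on an illustrative recurrence `carry = 2 * carry + x`.

```python
import jax.numpy as jnp
from jax import lax

def scan_reference(f, init, xs):
    """Simplified pure-Python model of lax.scan(f, init, xs)."""
    carry = init
    ys = []
    for x in xs:                 # iterate over the leading axis of xs
        carry, y = f(carry, x)   # f returns (new_carry, per-step output)
        ys.append(y)
    return carry, jnp.stack(ys)  # outputs stacked along a new leading axis

def step(carry, x):
    new_carry = 2 * carry + x
    return new_carry, new_carry

xs = jnp.array([1.0, 2.0, 3.0])
ref_carry, ref_ys = scan_reference(step, 0.0, xs)
scan_carry, scan_ys = lax.scan(step, 0.0, xs)
print(ref_ys, scan_ys)  # both [ 1.  4. 11.]
```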

In [None]:
# Simple example: cumulative sum
def body_fn(carry, x):
    """Body function for scan.
    
    Args:
        carry: State being carried through
        x: Current element from input sequence
    
    Returns:
        new_carry: Updated state
        output: Value to collect
    """
    new_carry = carry + x
    output = new_carry
    return new_carry, output

# Compute cumulative sum
xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
init_carry = 0.0
final_carry, outputs = lax.scan(body_fn, init_carry, xs)

print(f"Input: {xs}")
print(f"Cumulative sum: {outputs}")
print(f"Final carry (total sum): {final_carry}")

## 4.2 Scan vs For Loop

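Why use scan instead of a for loop? A Python loop over JAX operations dispatches every step separately (and, under `jit`, is unrolled, so compile time grows with the number of iterations). `scan` compiles the body once and runs the whole loop on the device.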

In [None]:
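# Fibonacci with a Python loop over JAX scalars: every step is dispatched separately
def fib_loop(n):
    result = []
    a, b = jnp.int32(0), jnp.int32(1)
    for _ in range(n):
        result.append(a)
        a, b = b, a + b
    return jnp.stack(result)

# Fibonacci with scan: the body is compiled once. `length` must be a concrete
# Python int, so `n` is marked static under jit.
@partial(jit, static_argnums=0)
def fib_scan(n):
    def body(carry, _):
        a, b = carry
        return (b, a + b), a

    _, result = lax.scan(body, (jnp.int32(0), jnp.int32(1)), None, length=n)
    return result

n = 20
print(f"First {n} Fibonacci numbers:")
print(f"Loop: {fib_loop(n)}")
print(f"Scan: {fib_scan(n)}")

# Benchmark. int32 overflows (wraps around) after F(46), so for n=1000 the
# values themselves are meaningless; only the timing matters here.
n_large = 1000
fib_scan(n_large).block_until_ready()  # warm-up: compile for this n

start = time.perf_counter()
_ = fib_loop(n_large).block_until_ready()
loop_time = time.perf_counter() - start

start = time.perf_counter()
_ = fib_scan(n_large).block_until_ready()
scan_time = time.perf_counter() - start

print(f"\nFor n={n_large}:")
print(f"Loop: {loop_time*1000:.2f} ms")
print(f"Scan: {scan_time*1000:.2f} ms")
print(f"Speedup: {loop_time/scan_time:.1f}x")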

## 4.3 Practical Example: RNN

Scan is essential for implementing recurrent neural networks.

In [None]:
# Simple RNN cell
@jit
def rnn_cell(carry, x, W_hh, W_xh, b):
    """Single RNN step."""
    h = carry
    h_new = jnp.tanh(jnp.dot(W_hh, h) + jnp.dot(W_xh, x) + b)
    return h_new, h_new

# Process sequence with RNN
@jit
def rnn_forward(params, h0, xs):
    """Forward pass through RNN.
    
    Args:
        params: Dictionary with W_hh, W_xh, b
        h0: Initial hidden state
        xs: Input sequence (time_steps, input_dim)
    
    Returns:
        outputs: Hidden states at each time step
    """
    def body(carry, x):
        return rnn_cell(carry, x, params['W_hh'], params['W_xh'], params['b'])
    
    _, outputs = lax.scan(body, h0, xs)
    return outputs

# Initialize RNN parameters
key = random.PRNGKey(0)
hidden_dim = 4
input_dim = 3

key, *subkeys = random.split(key, 4)
params = {
    'W_hh': random.normal(subkeys[0], (hidden_dim, hidden_dim)) * 0.1,
    'W_xh': random.normal(subkeys[1], (hidden_dim, input_dim)) * 0.1,
    'b': jnp.zeros(hidden_dim)
}

# Generate input sequence
time_steps = 10
key, subkey = random.split(key)
xs = random.normal(subkey, (time_steps, input_dim))
h0 = jnp.zeros(hidden_dim)

# Forward pass
outputs = rnn_forward(params, h0, xs)

print(f"Input shape: {xs.shape}")
print(f"Output shape: {outputs.shape}")
print(f"\nFirst 3 hidden states:")
print(outputs[:3])
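Because `scan` is differentiable, training the RNN only requires wrapping a loss in `jax.grad`. Gradients then flow back through every time step (backpropagation through time). The sketch below is self-contained and assumes a hypothetical mean-squared-error loss against synthetic random targets. It checks that the gradient PyTree mirrors the parameter dictionary.

```python
import jax
import jax.numpy as jnp
from jax import lax, random

def rnn_forward(params, h0, xs):
    def body(h, x):
        h_new = jnp.tanh(params['W_hh'] @ h + params['W_xh'] @ x + params['b'])
        return h_new, h_new
    _, hs = lax.scan(body, h0, xs)
    return hs

def sequence_loss(params, h0, xs, targets):
    """Hypothetical regression loss on every hidden state."""
    return jnp.mean((rnn_forward(params, h0, xs) - targets) ** 2)

hidden_dim, input_dim, time_steps = 4, 3, 10
k_hh, k_xh, k_x, k_t = random.split(random.PRNGKey(0), 4)
params = {
    'W_hh': random.normal(k_hh, (hidden_dim, hidden_dim)) * 0.1,
    'W_xh': random.normal(k_xh, (hidden_dim, input_dim)) * 0.1,
    'b': jnp.zeros(hidden_dim),
}
xs = random.normal(k_x, (time_steps, input_dim))
targets = random.normal(k_t, (time_steps, hidden_dim))  # synthetic targets
h0 = jnp.zeros(hidden_dim)

# grad differentiates through all scan steps
grads = jax.jit(jax.grad(sequence_loss))(params, h0, xs, targets)
for name, g in grads.items():
    print(name, g.shape)
```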

## 4.4 Bidirectional Scan

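`lax.scan` can also run in reverse with `reverse=True`. It then iterates from the last element to the first, but the stacked outputs keep the original indexing, so `outputs[i]` still lines up with `xs[i]`. Combining a forward and a reverse scan gives a simple bidirectional pass.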

In [None]:
# Forward scan
def cumsum_forward(xs):
    def body(carry, x):
        new_carry = carry + x
        return new_carry, new_carry
    _, outputs = lax.scan(body, 0.0, xs)
    return outputs

# Reverse scan
def cumsum_reverse(xs):
    def body(carry, x):
        new_carry = carry + x
        return new_carry, new_carry
    _, outputs = lax.scan(body, 0.0, xs, reverse=True)
    return outputs

xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
print(f"Input: {xs}")
print(f"Forward cumsum: {cumsum_forward(xs)}")
print(f"Reverse cumsum: {cumsum_reverse(xs)}")

# Bidirectional: combine forward and reverse
def bidirectional_sum(xs):
    forward = cumsum_forward(xs)
    reverse = cumsum_reverse(xs)
    return forward + reverse

print(f"Bidirectional sum: {bidirectional_sum(xs)}")
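Two properties of the cell above can be checked directly. First, `reverse=True` gives the same result as flipping `xs`, scanning forward, and flipping the outputs back. Second, the forward and reverse partial sums both include `xs[i]`, so `forward + reverse` equals the total plus `xs[i]`. A sketch:

```python
import jax.numpy as jnp
from jax import lax

def running_total(carry, x):
    new_carry = carry + x
    return new_carry, new_carry

xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])

_, fwd = lax.scan(running_total, 0.0, xs)
_, rev = lax.scan(running_total, 0.0, xs, reverse=True)
_, flipped = lax.scan(running_total, 0.0, xs[::-1])

print(jnp.allclose(rev, flipped[::-1]))  # True
# Removing the double-counted xs[i] leaves the total at every position
print(fwd + rev - xs)                    # every entry equals xs.sum() == 15.0
```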

## 4.5 Scan with Multiple Carries

You can carry multiple pieces of state.
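The running-statistics cell uses Welford's online update. Its carry `(count, mean, M2)` starts at $n_0 = 0$, $\mu_0 = 0$, $M_{2,0} = 0$ and evolves per sample $x_k$ as follows. The variance is the population variance (dividing by $n_k$), which matches `xs.var()`:

$$
\begin{aligned}
n_k &= n_{k-1} + 1, \\
\delta_k &= x_k - \mu_{k-1}, \\
\mu_k &= \mu_{k-1} + \frac{\delta_k}{n_k}, \\
M_{2,k} &= M_{2,k-1} + \delta_k \, (x_k - \mu_k), \\
\sigma_k^2 &= \frac{M_{2,k}}{n_k}.
\end{aligned}
$$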

In [None]:
# Example: Running mean and variance
def running_stats(xs):
    """Running mean and population variance via Welford's online algorithm."""
    def body(carry, x):
        count, mean, M2 = carry
        count = count + 1
        delta = x - mean
        mean = mean + delta / count
        delta2 = x - mean
        M2 = M2 + delta * delta2
        variance = M2 / count
        return (count, mean, M2), (mean, variance)
    
    init_carry = (0, 0.0, 0.0)
    _, (means, variances) = lax.scan(body, init_carry, xs)
    return means, variances

# Test with random data
key = random.PRNGKey(42)
xs = random.normal(key, (100,)) * 2.0 + 5.0
means, variances = running_stats(xs)

# Plot results
plt.figure(figsize=(12, 5))

plt.subplot(1, 2, 1)
plt.plot(means)
plt.axhline(xs.mean(), color='r', linestyle='--', label='True mean')
plt.title('Running Mean')
plt.xlabel('Sample')
plt.ylabel('Mean')
plt.legend()
plt.grid(True, alpha=0.3)

plt.subplot(1, 2, 2)
plt.plot(variances)
plt.axhline(xs.var(), color='r', linestyle='--', label='True variance')
plt.title('Running Variance')
plt.xlabel('Sample')
plt.ylabel('Variance')
plt.legend()
plt.grid(True, alpha=0.3)

plt.tight_layout()
plt.show()

print(f"Final mean: {means[-1]:.3f} (true: {xs.mean():.3f})")
print(f"Final variance: {variances[-1]:.3f} (true: {xs.var():.3f})")

## 4.6 Advanced: Associative Scan

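When the combining operation is associative, `lax.associative_scan` can compute all prefix results with a parallel prefix algorithm. Its sequential depth grows as O(log n) instead of O(n), which pays off mainly on parallel hardware such as GPUs and TPUs.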

In [None]:
# Example: Parallel prefix sum
from jax.lax import associative_scan

def add(a, b):
    return a + b

xs = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0])

# Regular scan
def regular_scan_sum(xs):
    def body(carry, x):
        return carry + x, carry + x
    _, result = lax.scan(body, 0.0, xs)
    return result

# Associative scan (can be parallelized)
def parallel_scan_sum(xs):
    return associative_scan(add, xs)

result_regular = regular_scan_sum(xs)
result_parallel = parallel_scan_sum(xs)

print(f"Input: {xs}")
print(f"Regular scan: {result_regular}")
print(f"Parallel scan: {result_parallel}")
print(f"\nResults match: {jnp.allclose(result_regular, result_parallel)}")

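# Benchmark for large arrays: jit both versions and compile once before timing
large_xs = jnp.arange(100000.0)
regular_jit = jit(regular_scan_sum)
parallel_jit = jit(parallel_scan_sum)
regular_jit(large_xs).block_until_ready()
parallel_jit(large_xs).block_until_ready()

start = time.perf_counter()
_ = regular_jit(large_xs).block_until_ready()
regular_time = time.perf_counter() - start

start = time.perf_counter()
_ = parallel_jit(large_xs).block_until_ready()
parallel_time = time.perf_counter() - start

print(f"\nFor {len(large_xs)} elements:")
print(f"Regular scan: {regular_time*1000:.2f} ms")
print(f"Parallel scan: {parallel_time*1000:.2f} ms")
print(f"Speedup: {regular_time/parallel_time:.1f}x (largest gains are expected on GPU/TPU)")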
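`associative_scan` is not limited to sums. A sketch of a less obvious case: the first-order linear recurrence $h_t = a_t h_{t-1} + b_t$ (an exponential moving average has this form). Each step is an affine map $h \mapsto a h + b$. Composing two such maps is associative, so the `(a, b)` pairs can be combined with `associative_scan` and then checked against a sequential `lax.scan`. The coefficient values are arbitrary.

```python
import jax.numpy as jnp
from jax import lax

def combine(left, right):
    """Compose two affine maps h -> a*h + b; `left` is the earlier one."""
    a_l, b_l = left
    a_r, b_r = right
    return a_r * a_l, a_r * b_l + b_r

a = jnp.array([0.5, 0.9, 1.1, 0.7])
b = jnp.array([1.0, -2.0, 0.5, 3.0])

# Prefix compositions applied to h_{-1} = 0 leave h_t in the b component
_, h_parallel = lax.associative_scan(combine, (a, b))

def step(h, ab):
    a_t, b_t = ab
    h = a_t * h + b_t
    return h, h

_, h_sequential = lax.scan(step, jnp.float32(0.0), (a, b))
print(h_parallel, h_sequential)
```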

## Summary

### PyTree
- Abstraction for nested data structures
- Works with all JAX transformations
- Use `tree_map`, `tree_flatten`, `tree_unflatten`
- Essential for managing model parameters

### LAX
- Low-level operations closer to XLA
- Functional control flow: `cond`, `switch`, `while_loop`
- Required for control flow that depends on traced values inside `jit`

### XLA
- JAX's compilation backend
- Fuses operations automatically
- Use `static_argnums`/`static_argnames` for arguments that determine shapes or Python control flow
- Understand compilation overhead

### Scan
- Efficient sequential operations
- Essential for RNNs and sequential models
- Compiles the loop body once, whereas a Python loop under `jit` is unrolled
- Use `reverse=True` to iterate from the end of the sequence
- `associative_scan` for parallelizable operations

These concepts are fundamental for high-performance computing with JAX!