# File: notebooks/01_fundamentals/04_control_flow_scan.ipynb

## JAX Fundamentals: Control Flow and Scan

Welcome to the fourth notebook in the JAX-NSL series! In this notebook, we'll explore JAX's approach to control flow operations and the powerful `scan` function. These are essential for implementing recurrent neural networks, iterative algorithms, and any computation that involves loops or conditional logic while maintaining compatibility with JAX's transformations like `jit`, `grad`, and `vmap`.

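Python `if`/`for`/`while` statements cannot branch or loop on traced values inside `jit`-compiled functions, so JAX provides structured control-flow primitives that can be traced, compiled, and differentiated. We'll cover `lax.cond`, `lax.switch`, `lax.while_loop`, `lax.fori_loop`, and most importantly, `lax.scan` for efficient sequential computations.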

## Setting Up the Environment

```python
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap, lax
from jax import random
import numpy as np
import matplotlib.pyplot as plt
from typing import Callable, Tuple, Any
import functools

# Enable 64-bit precision
jax.config.update("jax_enable_x64", True)

print(f"JAX version: {jax.__version__}")
```

## Basic Control Flow Operations

### Conditional Operations with lax.cond

```python
# Simple conditional logic
def simple_conditional(x, threshold=0.0):
    """Return x^2 when x exceeds ``threshold``, otherwise -x^2.

    The choice is made with ``lax.cond``, so it works on traced values
    and can be differentiated.
    """
    def squared(value):
        # Taken when x > threshold
        return value**2

    def negated_square(value):
        # Taken when x <= threshold
        return -value**2

    is_above = x > threshold
    return lax.cond(is_above, squared, negated_square, x)

# Test conditional on values below, at, and above the default threshold (0.0)
x_values = jnp.array([-2.0, -1.0, 0.0, 1.0, 2.0])

print("Simple conditional results:")
for x in x_values:
    result = simple_conditional(x)
    print(f"f({x:4.1f}) = {result:6.2f}")

# Gradient computation works through conditionals
grad_conditional = grad(simple_conditional)
print("\nGradients:")
for x in x_values:
    if x != 0.0:  # Skip x == 0, the point where the two branches meet
        grad_result = grad_conditional(x)
        print(f"f'({x:4.1f}) = {grad_result:6.2f}")
```

### Multiple Conditions with lax.switch

```python
# Multiple branch conditional
def piecewise_function(x):
    """Evaluate a three-piece function of a scalar ``x`` with ``lax.switch``.

    Pieces:
        x < -1        -> x**3
        -1 <= x <= 1  -> x
        x > 1         -> x**2

    Args:
        x: Scalar input (Python number or 0-d array).

    Returns:
        The value of the selected piece at ``x``.
    """

    # Define branch functions
    def branch_0(x):  # x < -1
        return x**3

    def branch_1(x):  # -1 <= x <= 1
        return x

    def branch_2(x):  # x > 1
        return x**2

    # Pick the branch index with nested jnp.where. The previous nested
    # lax.cond passed no operand to the inner call while its branches
    # expected one argument, which raises a TypeError during tracing.
    index = jnp.where(x < -1.0, 0, jnp.where(x <= 1.0, 1, 2))

    branches = [branch_0, branch_1, branch_2]
    return lax.switch(index, branches, x)

# Test piecewise function
# 9 evenly spaced points from -2 to 2 (step 0.5): this covers all three
# branches and both boundary values, -1 and 1.
x_test = jnp.linspace(-2, 2, 9)
print("Piecewise function results:")
for x in x_test:
    result = piecewise_function(x)
    print(f"f({x:4.1f}) = {result:6.2f}")
```

## Loop Operations

### Fixed Number of Iterations with lax.fori_loop

```python
# Simple iteration example: compute x^n using repeated multiplication
def power_via_loop(base, exponent):
    """Raise ``base`` to a non-negative integer ``exponent`` by repeated multiplication.

    The multiplication is driven by ``lax.fori_loop`` running ``exponent`` times.
    """
    def multiply_by_base(_, running_product):
        # The loop index is not needed; each pass multiplies once more.
        return running_product * base

    initial_product = 1.0
    return lax.fori_loop(0, exponent, multiply_by_base, initial_product)

# Test power computation: compare the loop result with Python's ** operator
base, exp = 2.0, 5
result = power_via_loop(base, exp)
print(f"{base}^{exp} = {result} (expected: {base**exp})")

# More complex example: numerical integration using trapezoidal rule
def trapezoidal_integration(func, a, b, n_intervals):
    """Approximate the integral of ``func`` over [a, b] with the trapezoidal rule.

    The interval is split into ``n_intervals`` equal pieces and the
    ``n_intervals + 1`` nodes are accumulated inside ``lax.fori_loop``.
    """
    step = (b - a) / n_intervals

    def endpoint_weight(_):
        return 0.5

    def interior_weight(_):
        return 1.0

    def accumulate(k, running_total):
        # The first and last nodes count half, every interior node counts fully.
        is_endpoint = (k == 0) | (k == n_intervals)
        weight = lax.cond(is_endpoint, endpoint_weight, interior_weight, None)
        node = a + k * step
        return running_total + weight * func(node) * step

    return lax.fori_loop(0, n_intervals + 1, accumulate, 0.0)

# Test integration
def test_function(x):
    """Integrand for the demo: x squared."""
    return x**2

# 1000 intervals on [0, 1]
integral_result = trapezoidal_integration(test_function, 0.0, 1.0, 1000)
analytical_result = 1.0/3.0  # Integral of x^2 from 0 to 1
print(f"Numerical integral: {integral_result:.6f}")
print(f"Analytical result: {analytical_result:.6f}")
print(f"Error: {abs(integral_result - analytical_result):.2e}")
```

### Dynamic Loops with lax.while_loop

```python
# Newton-Raphson method using while_loop
def newton_raphson(func, dfunc, x0, tolerance=1e-8, max_iters=50):
    """Solve ``func(x) = 0`` with Newton-Raphson steps inside ``lax.while_loop``.

    Args:
        func: Function whose root is sought.
        dfunc: Derivative of ``func``.
        x0: Initial guess; integer inputs are promoted to floating point.
        tolerance: Iteration stops once ``|func(x)| <= tolerance``.
        max_iters: Upper bound on the number of Newton steps.

    Returns:
        ``(x, error, iterations)`` where ``error`` is ``|func(x)|`` evaluated
        at the returned ``x`` and ``iterations`` is the number of Newton
        steps taken (0 if ``x0`` already satisfies the tolerance).
    """
    # The loop carry must keep one dtype across iterations. A Newton step
    # always produces a float, so promote integer starting points up front.
    x0 = jnp.asarray(x0)
    if not jnp.issubdtype(x0.dtype, jnp.inexact):
        x0 = x0.astype(float)

    def cond_fun(state):
        _, fx, iteration = state
        return (jnp.abs(fx) > tolerance) & (iteration < max_iters)

    def body_fun(state):
        x, fx, iteration = state

        # Newton step, reusing func(x) carried over from the previous pass
        x_new = x - fx / dfunc(x)

        # Carry the residual at the NEW point so the stopping test and the
        # reported error always refer to the iterate that is returned.
        return x_new, func(x_new), iteration + 1

    # Initial state: (x, func(x), iteration). Evaluating func(x0) here means
    # no step is taken when x0 is already a root, which also avoids a 0/0
    # when dfunc(x0) is zero there.
    initial_state = (x0, func(x0), 0)
    final_x, final_fx, final_iter = lax.while_loop(cond_fun, body_fun, initial_state)

    return final_x, jnp.abs(final_fx), final_iter

# Example: Find square root of 2 by solving x^2 - 2 = 0
def f(x):
    """Function whose positive root is sqrt(2)."""
    return x**2 - 2.0

def df(x):
    """Derivative of ``f``."""
    return 2 * x

# Start from x0 = 1.0 and use the default tolerance and iteration limit
sqrt_2, error, iterations = newton_raphson(f, df, x0=1.0)
print(f"Square root of 2: {sqrt_2:.10f}")
print(f"True value: {jnp.sqrt(2.0):.10f}")
print(f"Error: {error:.2e}")
print(f"Iterations: {iterations}")
```

## The Scan Operation

### Basic Scan Usage

```python
# Simple cumulative sum using scan
def cumulative_sum_scan(arr):
    """Running sum of ``arr`` computed with ``lax.scan``.

    Returns ``(total, partial_sums)``: the sum of all elements and the
    array of prefix sums.
    """
    def accumulate(running_total, element):
        # The updated total is both the next carry and this step's output.
        updated = running_total + element
        return updated, updated

    total, partial_sums = lax.scan(accumulate, 0.0, arr)
    return total, partial_sums

# Test cumulative sum
arr = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
final_sum, cumsum_result = cumulative_sum_scan(arr)

print(f"Array: {arr}")
print(f"Cumulative sum: {cumsum_result}")
print(f"Final sum: {final_sum}")
# Reference value; note that it is computed with jnp.cumsum
print(f"NumPy cumsum: {jnp.cumsum(arr)}")
```

### Scan with State: Running Statistics

```python
# Compute running mean and variance using scan
def running_statistics(data):
    """Running mean and sample variance of ``data`` via Welford's online update.

    Returns a ``(means, variances)`` pair with one entry per element of
    ``data``; the variance is reported as 0 until two samples have been seen.
    """
    def step(carry, sample):
        n_prev, mean_prev, m2_prev = carry

        n = n_prev + 1
        diff_before = sample - mean_prev
        mean_now = mean_prev + diff_before / n
        diff_after = sample - mean_now
        m2_now = m2_prev + diff_before * diff_after

        def no_variance_yet(_):
            return 0.0

        def sample_variance(_):
            # Unbiased estimate: divide by n - 1
            return m2_now / (n - 1)

        variance_now = lax.cond(n < 2, no_variance_yet, sample_variance, None)

        return (n, mean_now, m2_now), (mean_now, variance_now)

    # Carry layout: (count, mean, M2)
    start = (0.0, 0.0, 0.0)
    _, history = lax.scan(step, start, data)
    return history

# Test with random data: 100 standard-normal samples
key = random.PRNGKey(42)
data = random.normal(key, (100,))

running_stats = running_statistics(data)
running_means, running_vars = running_stats

# The last running value covers all samples, so it should match the batch statistic
print(f"Final running mean: {running_means[-1]:.4f}")
print(f"True mean: {jnp.mean(data):.4f}")
print(f"Final running variance: {running_vars[-1]:.4f}")
# ddof=1 gives the sample variance (n - 1 denominator)
print(f"True variance: {jnp.var(data, ddof=1):.4f}")
```

### Scan for Recurrent Neural Networks

```python
# Simple RNN cell implementation using scan
def simple_rnn_cell(params, hidden_state, input_x):
    """One step of a vanilla RNN: h_t = tanh(W_hh @ h_{t-1} + W_xh @ x_t + b).

    ``params`` is the triple ``(W_hh, W_xh, b)``. The new hidden state is
    returned twice, as ``(carry, output)``, which is the shape ``lax.scan``
    expects from its step function.
    """
    recurrent_weights, input_weights, bias = params

    pre_activation = recurrent_weights @ hidden_state + input_weights @ input_x + bias
    next_hidden = jnp.tanh(pre_activation)
    return next_hidden, next_hidden

def rnn_forward(params, inputs, initial_hidden):
    """Run the RNN cell over ``inputs`` along the leading axis using ``lax.scan``.

    Returns ``(all_hidden, final_hidden)``: the hidden state after every
    step, followed by the hidden state after the last step.
    """
    # Close over params so the step function has scan's (carry, x) signature
    final_hidden, all_hidden = lax.scan(
        lambda hidden, step_input: simple_rnn_cell(params, hidden, step_input),
        initial_hidden,
        inputs,
    )
    return all_hidden, final_hidden

# Initialize RNN parameters
key = random.PRNGKey(123)
hidden_size = 4
input_size = 3
seq_length = 10

# Derive one independent subkey per random draw. A key that has been split
# should not also be used directly, so `key` itself is replaced as well.
key, w_hh_key, w_xh_key, inputs_key = random.split(key, 4)

# Random parameters, scaled down by 0.1
W_hh = random.normal(w_hh_key, (hidden_size, hidden_size)) * 0.1
W_xh = random.normal(w_xh_key, (hidden_size, input_size)) * 0.1
b = jnp.zeros(hidden_size)
params = (W_hh, W_xh, b)

# Random input sequence
inputs = random.normal(inputs_key, (seq_length, input_size))
initial_hidden = jnp.zeros(hidden_size)

# Forward pass
hidden_states, final_hidden = rnn_forward(params, inputs, initial_hidden)

print(f"Input sequence shape: {inputs.shape}")
print(f"Hidden states shape: {hidden_states.shape}")
print(f"Final hidden state: {final_hidden}")

# Test gradient computation through RNN
def rnn_loss(params, inputs, initial_hidden, targets):
    """Sum of squared hidden states over the whole sequence.

    ``targets`` is accepted but is not used by this loss.
    """
    all_hidden = rnn_forward(params, inputs, initial_hidden)[0]
    squared = all_hidden**2
    return jnp.sum(squared)

# Compute gradients
# grad differentiates with respect to the first argument (params) by default
grad_rnn = grad(rnn_loss)
targets = jnp.zeros_like(hidden_states)  # Dummy targets
grads = grad_rnn(params, inputs, initial_hidden, targets)

# grads has the same (W_hh, W_xh, b) tuple structure as params
print(f"Gradient shapes: W_hh={grads[0].shape}, W_xh={grads[1].shape}, b={grads[2].shape}")
```

## Advanced Scan Patterns

### Scan with Multiple Sequences

```python
# Process multiple input sequences simultaneously
def multi_sequence_scan(sequences):
    """Scan over three aligned sequences that update one shared scalar state.

    ``sequences`` must unpack into exactly three arrays. Returns
    ``(final_state, outputs1, outputs2)`` where ``outputs1`` is twice the
    state after each step and ``outputs2`` is its sine.
    """
    def step(carry, step_inputs):
        a, b, c = step_inputs  # One element from each sequence

        # Weighted contribution of the three inputs
        updated = carry + a * 0.5 + b * 0.3 + c * 0.2

        doubled = updated * 2
        sine = jnp.sin(updated)
        return updated, (doubled, sine)

    first, second, third = sequences

    # A tuple of arrays is scanned in lockstep along the leading axis
    last_state, (doubled_all, sine_all) = lax.scan(step, 0.0, (first, second, third))

    return last_state, doubled_all, sine_all

# Test with multiple sequences
key = random.PRNGKey(0)
n_steps = 20

# One independent subkey per sequence; the parent key is not used for sampling
seq1_key, seq2_key, seq3_key = random.split(key, 3)

seq1 = random.normal(seq1_key, (n_steps,))
seq2 = random.normal(seq2_key, (n_steps,))
seq3 = random.normal(seq3_key, (n_steps,))

final_state, out1, out2 = multi_sequence_scan((seq1, seq2, seq3))
print(f"Final state: {final_state:.4f}")
print(f"Output 1 shape: {out1.shape}")
print(f"Output 2 shape: {out2.shape}")
```

### Scanning in Reverse Order

```python
# Scan in reverse order
def reverse_cumsum(arr):
    """Cumulative sum from right to left.

    Element ``i`` of the result is ``arr[i] + arr[i + 1] + ... + arr[-1]``.

    Args:
        arr: Array to accumulate along its leading axis.

    Returns:
        Array of suffix sums, aligned with ``arr``.
    """

    def scan_fun(carry, x):
        new_carry = carry + x
        return new_carry, new_carry

    # reverse=True makes scan consume arr from the last element to the first
    # while still storing each output at its input's position, so no
    # explicit flips (and the copies they make) are needed.
    _, reverse_outputs = lax.scan(scan_fun, 0.0, arr, reverse=True)

    return reverse_outputs

# Test reverse cumulative sum
arr = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
reverse_result = reverse_cumsum(arr)

print(f"Original array: {arr}")
print(f"Reverse cumsum: {reverse_result}")
# Reference: flip, take the ordinary cumsum, flip back
print(f"Expected: {jnp.flip(jnp.cumsum(jnp.flip(arr)))}")
```

## Performance Optimizations

### Scan vs Python Loops

```python
import time

def python_loop_cumsum(arr):
    """Cumulative sum built element by element in a plain Python loop.

    Serves as the baseline for the scan-based version.
    """
    running_total = 0.0
    partial_sums = []
    for element in arr:
        running_total = running_total + element
        partial_sums += [running_total]
    return jnp.array(partial_sums)

def scan_cumsum(arr):
    """Cumulative sum expressed as a ``lax.scan``, suitable for ``jit``."""
    def add_next(running_total, element):
        updated = running_total + element
        return updated, updated

    partial_sums = lax.scan(add_next, 0.0, arr)[1]
    return partial_sums

# JIT compile the scan version
scan_cumsum_jit = jit(scan_cumsum)

# Generate test data
key = random.PRNGKey(42)
large_array = random.normal(key, (10000,))

# Time both implementations on the SAME input so the speedup ratio is
# meaningful. The Python loop is too slow for the full array, so both
# versions run on the first 1000 elements.
bench_array = large_array[:1000]

# Warm up with the exact benchmark input: jit compiles per input shape, so
# warming up on a different-length slice would leave the compilation inside
# the timed loop.
scan_cumsum_jit(bench_array).block_until_ready()

# Benchmark
n_trials = 100

# Python loop version (run eagerly, element by element)
start = time.perf_counter()
for _ in range(n_trials):
    # JAX dispatches work asynchronously; block so the computation itself
    # is timed, not just the dispatch.
    result_loop = python_loop_cumsum(bench_array).block_until_ready()
end = time.perf_counter()
time_loop = end - start

# Scan version with JIT
start = time.perf_counter()
for _ in range(n_trials):
    result_scan = scan_cumsum_jit(bench_array).block_until_ready()
end = time.perf_counter()
time_scan = end - start

print(f"Python loop time: {time_loop:.4f}s")
print(f"Scan + JIT time: {time_scan:.4f}s")
print(f"Speedup: {time_loop / time_scan:.1f}x")
```

### Memory Efficient Scan

```python
# Scan over a long sequence chunk by chunk, carrying the state across chunks
def memory_efficient_scan(init_state, inputs, chunk_size=1000):
    """Apply an exponential moving average over ``inputs`` in fixed-size chunks.

    Each chunk is processed with its own ``lax.scan``; the carry from one
    chunk seeds the next, so the result equals a single scan over the
    whole sequence.

    Args:
        init_state: Initial value of the moving average.
        inputs: Sequence to process along its leading axis.
        chunk_size: Maximum number of elements per scan; must be positive.

    Returns:
        ``(final_state, outputs)`` where ``outputs`` holds the moving
        average after every element. Empty ``inputs`` yields ``init_state``
        and an empty output array.

    Raises:
        ValueError: If ``chunk_size`` is not positive.
    """
    if chunk_size <= 0:
        raise ValueError(f"chunk_size must be positive, got {chunk_size}")

    # Defined once, outside the chunk loop, so it is not re-created per chunk
    def ema_step(carry, x):
        carry = carry * 0.99 + x * 0.01  # Exponential moving average
        return carry, carry

    n_total = len(inputs)

    state = init_state
    all_outputs = []

    # max(n_total, 1) guarantees at least one (possibly empty) chunk, so an
    # empty input produces an empty output instead of failing in concatenate.
    for start_idx in range(0, max(n_total, 1), chunk_size):
        # Slicing clamps at the end of the array, so the last chunk may be shorter
        chunk = inputs[start_idx:start_idx + chunk_size]

        state, chunk_outputs = lax.scan(ema_step, state, chunk)
        all_outputs.append(chunk_outputs)

    return state, jnp.concatenate(all_outputs)

# Test with large sequence
key = random.PRNGKey(0)
large_sequence = random.normal(key, (50000,))

# 50000 elements in chunks of 5000 -> 10 chunks
final_state, outputs = memory_efficient_scan(0.0, large_sequence, chunk_size=5000)
print(f"Processed sequence length: {len(large_sequence)}")
print(f"Final state: {final_state:.6f}")
print(f"Output shape: {outputs.shape}")
```

## Differentiating Through Scan

### Gradient Flow Through Time

```python
# Example: Gradient flow through recurrent computation
def recurrent_computation(params, sequence):
    """Scalar loss of a linear recurrence driven by ``sequence``.

    The state follows ``s_t = params[0] * s_{t-1} + params[1] * x_t + params[2]``
    starting from 0, and the loss is the sum of ``s_t ** 2`` over all steps.
    """
    state_coeff, input_coeff, bias = params[0], params[1], params[2]

    def advance(prev_state, current_input):
        next_state = state_coeff * prev_state + input_coeff * current_input + bias
        return next_state, next_state**2  # Squared state is the per-step output

    _, squared_states = lax.scan(advance, 0.0, sequence)

    return jnp.sum(squared_states)

# Test gradient computation
params = jnp.array([0.9, 0.1, 0.05])  # [state_coeff, input_coeff, bias]
sequence = jnp.array([1.0, 0.5, -0.5, 1.5, -1.0])

# Compute loss and gradients
loss = recurrent_computation(params, sequence)
# grad differentiates with respect to the first argument (params) by default
grads = grad(recurrent_computation)(params, sequence)

print(f"Loss: {loss:.4f}")
print(f"Gradients: {grads}")

# Test stability of gradients
def gradient_norm_vs_sequence_length():
    """Print the gradient norm of ``recurrent_computation`` for several sequence lengths.

    Uses the module-level ``params`` and random normal sequences drawn
    from one fixed key.
    """
    rng = random.PRNGKey(42)
    loss_gradient = grad(recurrent_computation)

    for length in (10, 50, 100, 500, 1000):
        sample_sequence = random.normal(rng, (length,))
        grad_norm = jnp.linalg.norm(loss_gradient(params, sample_sequence))
        print(f"Length {length:4d}: Gradient norm = {grad_norm:.4f}")

gradient_norm_vs_sequence_length()
```

## Summary

In this notebook, we've explored JAX's control flow and scan operations:

**Key Concepts:**

1. **Conditional Logic**: Using `lax.cond` and `lax.switch` for differentiable conditionals
2. **Loops**: `lax.fori_loop` for fixed iterations, `lax.while_loop` for dynamic loops
3. **Scan Operation**: `lax.scan` for efficient sequential processing with state
4. **RNN Implementation**: Using scan for recurrent neural networks
5. **Performance**: JIT compilation and memory efficiency considerations

**Practical Applications:**
- Recurrent neural networks and sequence models  
- Iterative algorithms (Newton-Raphson, gradient descent)
- Online statistics computation
- Numerical integration and differential equations
- Time series processing

**Best Practices:**
- Prefer `lax.scan` to Python loops inside `jit`: Python loops are unrolled at trace time, which inflates compile time for long sequences
- Minimize state size in scan operations
- Consider chunking for very long sequences
- Test gradient flow through recurrent computations

**Performance Benefits:**
- JIT compilation of control flow operations
- Automatic differentiation through conditionals, `scan`, and fixed-length `fori_loop` (`while_loop` supports forward-mode differentiation only)
- Vectorization and parallelization support
- Memory-efficient processing of long sequences

**Next Steps:**
- The next notebook will cover linear algebra operations
- We'll explore matrix operations, decompositions, and solvers
- Understanding scan enables efficient implementation of iterative linear algebra algorithms

Control flow and scan are fundamental for implementing sophisticated algorithms that maintain JAX's functional programming paradigm while enabling automatic differentiation and compilation.