# JAX Fundamentals

## What is JAX?

**JAX** is a library for high-performance numerical computing that brings together three key capabilities:

- **J** - **JIT compilation** (Just-In-Time): Compiles Python functions into optimized XLA programs
- **A** - **Automatic differentiation**: Computes gradients of functions written with JAX operations automatically
- **X** - **XLA** (Accelerated Linear Algebra): Google's optimizing compiler that targets CPU/GPU/TPU

If you know NumPy, you already know most of JAX. The API is nearly identical, but JAX adds transformations that make it powerful for ML and scientific computing.
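Here is a minimal side-by-side sketch of that claim: the same calls work in both libraries. It also shows one difference worth knowing early. By default JAX creates 32-bit floats, while NumPy uses 64-bit. This assumes the `jax_enable_x64` flag is left at its default.

```python
import numpy as np
import jax.numpy as jnp

x_np = np.linspace(0.0, 1.0, 5)
x_jnp = jnp.linspace(0.0, 1.0, 5)

# Identical function names and semantics for everyday operations
np_result = np.sum(np.sin(x_np) ** 2)
jnp_result = jnp.sum(jnp.sin(x_jnp) ** 2)
print(np_result, jnp_result)

# Default dtypes differ (assuming jax_enable_x64 is not turned on)
print(x_np.dtype, x_jnp.dtype)  # float64 float32
```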

## Key Characteristics

1. **NumPy Compatible**: `jax.numpy` works like NumPy in most cases
2. **Immutable Arrays**: JAX arrays can't be modified in place (functional style)
3. **Hardware Accelerated**: Runs on GPU/TPU without code changes, given a JAX installation built for that hardware
4. **Function Transformations**: Compose `jit`, `grad`, `vmap` however you need
5. **Pure Functions**: Works best with functions that don't have side effects
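Because the transformations compose, you can stack them. The sketch below uses a made-up function `loss` and made-up inputs, chosen only for illustration. It takes the derivative with `grad`, maps it over a batch with `vmap`, and compiles the result with `jit`, giving one gradient per example.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar loss of a single parameter w for one input x
    return jnp.tanh(w * x) ** 2

# grad: d(loss)/dw  ->  vmap: over a batch of x (w shared)  ->  jit: compile it all
per_example_grads = jax.jit(jax.vmap(jax.grad(loss), in_axes=(None, 0)))

xs = jnp.array([0.5, 1.0, 2.0])
grads = per_example_grads(0.3, xs)
print(grads.shape)  # (3,)

# Hand-derived derivative for comparison: 2*tanh(wx)*(1 - tanh(wx)^2)*x
t = jnp.tanh(0.3 * xs)
expected = 2 * t * (1 - t ** 2) * xs
print(jnp.allclose(grads, expected, rtol=1e-4))
```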

---

## JAX as NumPy - Basic Operations

If you've used NumPy, this will look familiar. JAX arrays behave like NumPy arrays with one key difference: they're immutable. You can't modify them in place, but you get automatic GPU/TPU support in return.
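Beyond `.set()`, the `.at[]` property offers other functional updates. The array names in this short sketch are illustrative. `.add()` performs a scatter-add in which repeated indices accumulate. Every call returns a new array and leaves the original unchanged.

```python
import jax.numpy as jnp

counts = jnp.zeros(4)
indices = jnp.array([0, 2, 2, 3])

# Functional scatter-add: index 2 appears twice, so it receives 1.0 twice
updated = counts.at[indices].add(1.0)   # values: [1., 0., 2., 1.]

# The original array is untouched
print(counts, updated)

# Other update methods follow the same pattern and also return new arrays
scaled = updated.at[1:3].multiply(10.0)  # values: [1., 0., 20., 1.]
clipped = updated.at[:].min(1.5)         # values: [1., 0., 1.5, 1.]
print(scaled, clipped)
```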

In [11]:
# =============================================================================
# JAX AS NUMPY - COMPREHENSIVE DEMONSTRATION
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np

print("=" * 70)
print("ARRAY CREATION AND BASIC OPERATIONS")
print("=" * 70)

# -----------------------------------------------------------------------------
# Creating JAX Arrays
# -----------------------------------------------------------------------------
# JAX arrays are created just like NumPy arrays but are immutable
a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

print("\nüì¶ Array Creation:")
print(f"Array a: {a}")
print(f"Array b: {b}")
print(f"Array a shape: {a.shape}, dtype: {a.dtype}")

# -----------------------------------------------------------------------------
# Arithmetic Operations
# -----------------------------------------------------------------------------
print("\n➕ Arithmetic Operations:")
print(f"Sum (a + b):                    {a + b}")
print(f"Element-wise multiplication:    {a * b}")
print(f"Dot product:                    {jnp.dot(a, b)}")
print(f"Matrix multiplication:          {jnp.matmul(a, b)}")  # Same as dot for 1D

# -----------------------------------------------------------------------------
# Mathematical Functions
# -----------------------------------------------------------------------------
print("\nüìê Mathematical Functions:")
print(f"Sine of a:                      {jnp.sin(a)}")
print(f"Exponential of b:               {jnp.exp(b)}")
print(f"Logarithm of a:                 {jnp.log(a)}")
print(f"Square root of b:               {jnp.sqrt(b)}")
print(f"Power (a^2):                    {jnp.power(a, 2)}")

# -----------------------------------------------------------------------------
# Statistical Functions
# -----------------------------------------------------------------------------
print("\nüìä Statistical Functions:")
print(f"Mean of a:                      {jnp.mean(a)}")
print(f"Standard deviation of b:        {jnp.std(b)}")
print(f"Variance of a:                  {jnp.var(a)}")
print(f"Median of b:                    {jnp.median(b)}")
print(f"Maximum value in a:             {jnp.max(a)}")
print(f"Minimum value in b:             {jnp.min(b)}")

# -----------------------------------------------------------------------------
# Aggregation Functions
# -----------------------------------------------------------------------------
print("\nüî¢ Aggregation Functions:")
print(f"Sum of all elements in a:       {jnp.sum(a)}")
print(f"Product of all elements in b:   {jnp.prod(b)}")
print(f"Cumulative sum of a:            {jnp.cumsum(a)}")
print(f"Cumulative product of a:        {jnp.cumprod(a)}")

# -----------------------------------------------------------------------------
# Array Manipulation - Reshaping
# -----------------------------------------------------------------------------
print("\nüîÑ Array Reshaping:")
a_reshaped = a.reshape((3, 1))
print(f"Reshaped a to (3,1):\n{a_reshaped}")
print(f"Transpose of reshaped b:\n{b.reshape((3, 1)).T}")

# -----------------------------------------------------------------------------
# Array Manipulation - Stacking and Concatenation
# -----------------------------------------------------------------------------
print("\nüìö Stacking and Concatenation:")
print(f"Vertical stack (vstack):\n{jnp.vstack((a, b))}")
print(f"Horizontal stack (hstack):\n{jnp.hstack((a, b))}")
print(f"Concatenate (same as hstack for 1D):\n{jnp.concatenate((a, b))}")

# -----------------------------------------------------------------------------
# Array Query Operations
# -----------------------------------------------------------------------------
print("\nüîç Array Query Operations:")
print(f"Unique elements in b:           {jnp.unique(b)}")
print(f"Sorted a (descending):          {jnp.sort(a)[::-1]}")
print(f"Indices where a > 2:            {jnp.where(a > 2)}")
print(f"Boolean mask (a > 2):           {a > 2}")
print(f"Elements of a where a > 2:      {a[a > 2]}")

# -----------------------------------------------------------------------------
# Conversion Between JAX and NumPy
# -----------------------------------------------------------------------------
print("\nüîÑ JAX ‚Üî NumPy Conversion:")
numpy_array = np.array(a)
print(f"JAX to NumPy:                   {numpy_array} (type: {type(numpy_array)})")
jax_array = jnp.array(numpy_array)
print(f"NumPy to JAX:                   {jax_array} (type: {type(jax_array)})")

# -----------------------------------------------------------------------------
# IMMUTABILITY - Key Difference from NumPy
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("⚠️  JAX ARRAYS ARE IMMUTABLE")
print("=" * 70)
print("""
Unlike NumPy, JAX arrays cannot be modified in place.
Operations return NEW arrays rather than modifying existing ones.

❌ This will FAIL:
   a[1] = 10.0  # TypeError: JAX arrays are immutable

✅ Use this instead:
   a = a.at[1].set(10.0)  # Returns a new array with index 1 set to 10.0
""")

# Demonstrate immutable update
a_updated = a.at[1].set(10.0)
print(f"Original a:  {a}")
print(f"Updated a:   {a_updated}")
print("Notice: Original 'a' is unchanged!\n")

# -----------------------------------------------------------------------------
# SUMMARY
# -----------------------------------------------------------------------------
print("=" * 70)
print("KEY POINTS")
print("=" * 70)
print("""
✅ JAX provides a NumPy-compatible API (jax.numpy)
✅ Most NumPy operations work the same way in JAX
✅ JAX arrays are immutable - use .at[].set() for updates
✅ JAX automatically runs on GPU/TPU when available
✅ Seamlessly convert between JAX and NumPy arrays
✅ JAX is designed for high-performance numerical computing
✅ Well suited to machine learning, scientific computing, and simulations
""")

ARRAY CREATION AND BASIC OPERATIONS

📦 Array Creation:
Array a: [1. 2. 3.]
Array b: [4. 5. 6.]
Array a shape: (3,), dtype: float32

➕ Arithmetic Operations:
Sum (a + b):                    [5. 7. 9.]
Element-wise multiplication:    [ 4. 10. 18.]
Dot product:                    32.0
Matrix multiplication:          32.0

📐 Mathematical Functions:
Sine of a:                      [0.84147096 0.9092974  0.14112   ]
Exponential of b:               [ 54.598152 148.41316  403.4288  ]
Logarithm of a:                 [0.        0.6931472 1.0986123]
Square root of b:               [2.        2.236068  2.4494898]
Power (a^2):                    [1. 4. 9.]

📊 Statistical Functions:
Mean of a:                      2.0
Standard deviation of b:        0.8164966106414795
Variance of a:                  0.6666666865348816
Median of b:                    5.0
Maximum value in a:             3.0
Minimum value in b:             4.0

🔢 Aggregation Functions:
Sum of all elements in a:       6.0


# JIT Compilation in JAX

## What is JIT Compilation?

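**JIT (Just-In-Time) compilation** can make numerical code run much faster. The gain depends on the workload and hardware, and it can reach an order of magnitude or more. When you add `@jax.jit` to a function, JAX:
1. **Traces** your function on its first call to record which operations it performs, using abstract values that carry only shapes and dtypes
2. **Compiles** the recorded computation to an optimized program using XLA
3. **Caches** the compiled program, keyed on the input shapes and dtypes (plus any static arguments)
4. **Executes** the compiled program; later calls with matching shapes and dtypes skip steps 1-3 and reuse the cache

Without `jit`, JAX dispatches each operation to the device one at a time, with Python overhead between operations. With `jit`, XLA sees the whole function at once. It can fuse operations, avoid materializing intermediate arrays, and launch a single compiled program. That's where the speed comes from.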
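One way to watch the trace/compile/cache cycle is to count how often the Python body actually runs. In this sketch, `trace_count` is a hypothetical global used only for illustration. It increments only while JAX traces, so it changes only when a call brings a new input shape or dtype.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Hypothetical counter, for illustration only

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1  # Plain Python: runs only while JAX traces the function
    return jnp.sum(x * 2.0)

scaled_sum(jnp.ones(3))  # float32[3]: trace + compile + run
scaled_sum(jnp.ones(3))  # same shape and dtype: cached program, no trace
scaled_sum(jnp.ones(5))  # float32[5]: new shape, so JAX traces again
print(trace_count)  # 2
```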

## Why Use JIT?

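- **Speed**: Often several times faster for chains of numerical operations, sometimes much more
- **Hardware efficiency**: Produces one optimized program for the CPU, GPU, or TPU in use (JAX uses accelerators without `jit` too, but op by op)
- **Optimization**: XLA fuses operations and eliminates redundant computations
- **Scheduling**: XLA can schedule and parallelize independent operations inside the compiled program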

## When JIT Compilation FAILS or Behaves Unexpectedly

### ⚠️ 1. DATA-DEPENDENT CONTROL FLOW

This is probably the most common JIT gotcha. You can't use regular Python `if/else` statements when the condition depends on **array values**:

```python
# ❌ THIS RAISES AN ERROR UNDER JIT:
@jax.jit
def bad_function(x):
    if x > 0:  # ❌ Control flow depends on the VALUE of x
        return x * 2
    else:
        return x * 3
```

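**Why?** During tracing, JAX doesn't know the actual value of `x`, only its shape and dtype. Python's `if` needs a concrete `True` or `False`, so JAX raises an error (a `ConcretizationTypeError`; recent versions raise its subclass `TracerBoolConversionError`) instead of picking a branch.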

**Solution:** Use JAX's special control flow operations:
- `jnp.where(condition, true_val, false_val)` - for element-wise conditionals
- `jax.lax.cond(pred, true_fun, false_fun, operand)` - for conditionals on a scalar predicate
- `jax.lax.switch(index, branches, operand)` - for multiple branches
- `jax.lax.select(pred, on_true, on_false)` - for element-wise choice between same-shaped arrays

```python
# ✅ THIS WORKS:
@jax.jit
def good_function(x):
    return jnp.where(x > 0, x * 2, x * 3)  # ✅ JAX-compatible conditional
```
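For scalar predicates and multi-way branches, `jax.lax.cond` and `jax.lax.switch` take a function for each branch instead of precomputed values. The branch functions in this sketch are made up for illustration. Each branch is traced, and the predicate or index picks one at run time. All branches must return the same shape and dtype.

```python
import jax
import jax.numpy as jnp

@jax.jit
def piecewise(x):
    # x > 0 must be a scalar boolean; each branch is a function of the operand
    return jax.lax.cond(x > 0, lambda v: v * 2, lambda v: v * 3, x)

@jax.jit
def pick_activation(index, x):
    # lax.switch applies branches[index]; out-of-range indices are clamped
    branches = [jnp.tanh, jax.nn.relu, lambda v: v]
    return jax.lax.switch(index, branches, x)

print(piecewise(jnp.array(5.0)), piecewise(jnp.array(-5.0)))  # 10.0 -15.0
print(pick_activation(1, jnp.array([-1.0, 2.0])))              # [0. 2.] (relu branch)
```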

### ⚠️ 2. PYTHON SIDE EFFECTS

Side effects are operations that modify state outside the function or interact with the external world:

```python
import jax

counter = 0
my_list = []

# ❌ THESE DON'T WORK AS EXPECTED IN JIT:
@jax.jit
def has_side_effects(x):
    global counter
    print(f"Value is {x}")  # ❌ Runs only during tracing, and prints a tracer, not the value
    counter += 1            # ❌ Increments once per trace, not once per call
    my_list.append(x)       # ❌ Stores a leaked tracer, not a concrete array
    return x * 2
```

**Why?** JIT traces the function once per new input shape/dtype, then caches the compiled version. Python side effects only execute during tracing, not during every call!

**What happens:**
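- `print()` statements execute only while JAX traces (the first call, or any call with a new input shape/dtype), and they print abstract tracers rather than values
- Global variables are read at trace time and baked into the compiled code; later changes to them are not seen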
- File I/O, database calls, etc. won't work as expected
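When you need to see runtime values inside a jitted function, `jax.debug.print` is the JAX-aware alternative to `print`. It is staged into the compiled program, so it fires on every call with the concrete value. A minimal sketch (the function name is made up):

```python
import jax
import jax.numpy as jnp

@jax.jit
def double_and_log(x):
    print("Python print: runs only while tracing")  # trace-time side effect
    jax.debug.print("jax.debug.print: x = {}", x)    # runs on every call
    return x * 2

for value in (10.0, 20.0):
    out = double_and_log(jnp.array(value))
    out.block_until_ready()
```

With these two calls, the Python `print` line should appear once, while the `jax.debug.print` line appears for both values.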

### ⚠️ 3. DATA-DEPENDENT LOOPS

```python
# ❌ THIS FAILS:
@jax.jit
def bad_loop(x):
    for i in range(int(x)):  # ❌ Loop count depends on x's value
        x = x + 1
    return x
```

**Solution:** Use `jax.lax.fori_loop()` or `jax.lax.while_loop()` for dynamic loops.
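Here is a sketch of both loop primitives. `add_n_times` uses `fori_loop` with an iteration count that is only known at run time. `collatz_steps` is a hypothetical helper that reuses the Collatz rule from the examples below. It uses `while_loop` to count steps until the value reaches 1. In both cases the loop state must keep the same shape and dtype on every iteration.

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_n_times(x, n):
    # n is traced under jit, so Python's range(n) would fail; fori_loop accepts it
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + 1.0, x)

@jax.jit
def collatz_steps(n):
    def cond_fn(state):
        value, _ = state
        return value != 1  # keep looping until we reach 1

    def body_fn(state):
        value, steps = state
        value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return value, steps + 1

    _, steps = jax.lax.while_loop(cond_fn, body_fn, (n, jnp.int32(0)))
    return steps

print(add_n_times(jnp.array(1.0), 5))  # 6.0
print(collatz_steps(jnp.array(6)))     # 8  (6→3→10→5→16→8→4→2→1)
```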

### ⚠️ 4. SHAPE-CHANGING OPERATIONS

```python
# ❌ THIS FAILS:
@jax.jit
def dynamic_shape(x):
    return x[x > 0]  # ❌ Output length depends on how many values are positive
```

JIT requires every array shape to be known at compile time. An operation whose output size depends on array values breaks this requirement.
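The usual workaround is to keep output shapes fixed. This sketch shows two options. One is to mask values with `jnp.where` so the shape never changes. The other is to ask `jnp.nonzero` for a fixed number of results via its `size` argument, padding with `fill_value`. The size 3 and padding value -1 here are arbitrary choices.

```python
import jax
import jax.numpy as jnp

@jax.jit
def positives_masked(x):
    # Same shape as x: non-positive entries are replaced with 0
    return jnp.where(x > 0, x, 0.0)

@jax.jit
def first_positive_indices(x):
    # size is a static Python int, so the output shape is known at compile time
    (idx,) = jnp.nonzero(x > 0, size=3, fill_value=-1)
    return idx

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(positives_masked(x))        # [0. 2. 0. 4.]
print(first_positive_indices(x))  # [ 1  3 -1]
```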

### ⚠️ 5. IN-PLACE MUTATIONS

```python
# ❌ THIS FAILS (with or without jit):
@jax.jit
def mutate_array(x):
    x[0] = 10  # ❌ JAX arrays are immutable!
    return x
```

**Solution:** Use `.at[].set()` syntax:
```python
# ✅ THIS WORKS:
@jax.jit
def update_array(x):
    return x.at[0].set(10)  # ✅ Returns new array
```

## When to Use JIT

✅ **USE JIT FOR:**
- Pure functions (no side effects)
- Functions operating on JAX arrays
- Numerical computations (matrix operations, neural networks, simulations)
- Functions called repeatedly with similar input shapes
- Performance-critical code

❌ **DON'T USE JIT FOR:**
- Functions that rely on Python `print` or other side effects on every call (use `jax.debug.print` for runtime printing)
- Code that modifies global state
- Functions with Python control flow depending on array values
- Small, one-off computations (compilation overhead > benefit)
- Code interacting with external systems (files, databases, APIs)
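To see the compilation overhead separately from execution, JAX's ahead-of-time API lets you trigger tracing and compilation explicitly with `jax.jit(f).lower(...).compile()`. This is a rough timing sketch. The function `f` is an arbitrary example, and absolute times depend heavily on hardware.

```python
import time

import jax
import jax.numpy as jnp

def f(x):
    # Arbitrary chain of elementwise ops; sin^2 + cos^2 == 1 for every element
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

x = jnp.linspace(0.0, 10.0, 1000)

start = time.perf_counter()
compiled = jax.jit(f).lower(x).compile()  # trace + XLA compile, no execution
compile_seconds = time.perf_counter() - start

start = time.perf_counter()
result = compiled(x).block_until_ready()  # run the already-compiled program
run_seconds = time.perf_counter() - start

print(f"compile: {compile_seconds:.4f}s, run: {run_seconds:.6f}s, result: {float(result):.2f}")
```

If the compile time dwarfs the run time and the function is only called once, `jit` may not pay off.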

## Best Practices

1. **Keep functions pure**: Input → Output, no side effects
2. **Use JAX control flow**: `jnp.where()`, `jax.lax.cond()`, etc.
3. **Warm up the JIT**: Run once before timing to avoid compilation overhead
4. **Use `.block_until_ready()`**: JAX executes asynchronously by default
5. **Inspect with `jax.make_jaxpr()`**: See the intermediate representation
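6. **Static arguments**: Use `static_argnums` (or `static_argnames`) for hashable, non-array arguments that determine shapes or Python control flow; each new value triggers a recompile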
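Here is a sketch of `static_argnums`. The function `first_k_squared` is made up for illustration. The slice length `k` must be a concrete Python int at trace time, so it is marked static. Each distinct value of `k` compiles a separate version, and static arguments must be hashable.

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def first_k_squared(x, k):
    # k is a plain Python int at trace time, so x[:k] has a known shape
    return x[:k] ** 2

x = jnp.arange(5.0)
print(first_k_squared(x, 3))  # [0. 1. 4.]
print(first_k_squared(x, 2))  # [0. 1.]  (new k -> recompilation)
```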

---

In [12]:
# =============================================================================
# JIT COMPILATION - PRACTICAL EXAMPLES
# =============================================================================

import time
import jax
import jax.numpy as jnp

# -----------------------------------------------------------------------------
# EXAMPLE 1: JIT vs Non-JIT Performance Comparison
# -----------------------------------------------------------------------------
# The Collatz step: for a positive integer n, if n is even, divide it by 2;
# if n is odd, multiply it by 3 and add 1. The functions below apply ONE step
# to every element of an array.

print("=" * 70)
print("PERFORMANCE COMPARISON: WITH JIT vs WITHOUT JIT")
print("=" * 70)

# VERSION 1: WITHOUT JIT
def collatz_no_jit(x):
    """
    Collatz step WITHOUT JIT compilation.
    Uses jnp.where() for vectorized conditional logic.
    """
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

# VERSION 2: WITH JIT
@jax.jit
def collatz_with_jit(x):
    """
    Collatz step WITH JIT compilation.
    Same logic, but decorated with @jax.jit for optimization.
    """
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

# Create test arrays of different sizes
small_arr = jnp.arange(1, 1001)        # 1K elements
medium_arr = jnp.arange(1, 100001)     # 100K elements
large_arr = jnp.arange(1, 1000001)     # 1M elements

print("\nüîß Warming up JIT compiler (first call triggers compilation)...")
_ = collatz_with_jit(large_arr).block_until_ready()
print("‚úÖ JIT compilation complete! Compiled code is now cached.\n")

# Test function for timing
def benchmark_comparison(arr, label, iterations=10):
    """Benchmark both versions and compare results."""
    print(f"\nüìä {label} - {len(arr):,} elements, {iterations} iterations:")
    print("-" * 70)
    
    # Benchmark WITHOUT JIT
    start = time.time()
    for _ in range(iterations):
        result_no_jit = collatz_no_jit(arr).block_until_ready()
    time_no_jit = time.time() - start
    
    # Benchmark WITH JIT
    start = time.time()
    for _ in range(iterations):
        result_with_jit = collatz_with_jit(arr).block_until_ready()
    time_with_jit = time.time() - start
    
    # Calculate speedup
    speedup = time_no_jit / time_with_jit if time_with_jit > 0 else float('inf')
    
    # Display results side-by-side
    print(f"{'WITHOUT JIT:':20} {time_no_jit:8.6f} seconds")
    print(f"{'WITH JIT:':20} {time_with_jit:8.6f} seconds")
    print(f"{'SPEEDUP:':20} {speedup:8.2f}x faster")
    
    # Verify results match
    if jnp.allclose(result_no_jit, result_with_jit):
        print(f"✅ Results match! First 5 values: {result_with_jit[:5]}")
    else:
        print("⚠️  Warning: Results differ!")
    
    return speedup

# Run benchmarks for different array sizes
speedup_small = benchmark_comparison(small_arr, "SMALL ARRAY", iterations=100)
speedup_medium = benchmark_comparison(medium_arr, "MEDIUM ARRAY", iterations=50)
speedup_large = benchmark_comparison(large_arr, "LARGE ARRAY", iterations=10)

print("\n" + "=" * 70)
print("üìà PERFORMANCE SUMMARY")
print("=" * 70)
print(f"Small (1K):     {speedup_small:6.2f}x speedup")
print(f"Medium (100K):  {speedup_medium:6.2f}x speedup")
print(f"Large (1M):     {speedup_large:6.2f}x speedup")
print("\n✅ Key Insight: in this run, JIT speedup grew with array size (results vary by hardware)!")

# -----------------------------------------------------------------------------
# UNDERSTANDING ASYNCHRONOUS EXECUTION
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("üîÑ UNDERSTANDING .block_until_ready()")
print("=" * 70)
print("""
JAX executes operations ASYNCHRONOUSLY by default for maximum performance.

What this means:
1. JAX queues operations and returns control to Python immediately
2. Actual computation happens in the background on the accelerator
3. Without .block_until_ready(), you'd measure queue time, not compute time!

Example:
    result = collatz(arr)           # Returns instantly, not done yet!
    print(result)                   # NOW it blocks to get the value

For accurate timing, ALWAYS use .block_until_ready() after the computation:
    result = collatz(arr).block_until_ready()  # ✅ Waits for completion
""")

# -----------------------------------------------------------------------------
# UNDERSTANDING JAXPR - JAX's Intermediate Representation
# -----------------------------------------------------------------------------
print("=" * 70)
print("üîç JAXPR - JAX's Intermediate Representation")
print("=" * 70)
print("""
JAXPR is like assembly language for JAX. It shows the low-level operations
that XLA compiles into machine code. Think of it as a peek under the hood!
""")

sample_arr = jnp.arange(1, 11)
print(f"\nJAXPR for collatz_with_jit with input shape {sample_arr.shape}:")
print("-" * 70)
print(jax.make_jaxpr(collatz_with_jit)(sample_arr))
print()

print("What you see:")
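print("  • An input variable with its abstract type, e.g. a:i32[10] (int32 array, 10 elements)")
print("  • Because collatz_with_jit is already jitted, its body appears nested in a jit/pjit call")
print("  • Primitive operations such as rem, eq, div, mul, add, and select_n (from jnp.where);")
print("    exact names vary across JAX versions")
print("  • This gets sent to XLA for optimization and compilation!\n")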

# -----------------------------------------------------------------------------
# EXAMPLE 2: Why Python if/else Fails with JIT
# -----------------------------------------------------------------------------
print("=" * 70)
print("⚠️  DEMONSTRATION: Why Python Control Flow Breaks JIT")
print("=" * 70)

# ❌ INCORRECT: Using Python if/else with array values
@jax.jit
def broken_conditional(x):
    """
    This FAILS when JIT-compiled!
    During tracing, JAX doesn't know x's value, only its shape/dtype.
    Python's `if` needs a concrete True/False, so JAX raises an error
    instead of picking a branch.
    """
    if x > 0:  # ❌ Tries to turn an abstract tracer into a Python bool
        return x * 2
    else:
        return x * 3

print("\n❌ Testing broken_conditional with Python if/else:")
try:
    result_pos = broken_conditional(jnp.array(5.0))
    print(f"   broken_conditional(5.0) = {result_pos}")
except jax.errors.ConcretizationTypeError as e:
    # Recent JAX versions raise the subclass TracerBoolConversionError here
    print(f"   Error: {type(e).__name__}: {e}")
    print(f"   This happens because JAX can't determine the branch during tracing!")

# ✅ CORRECT: Using JAX-compatible control flow
@jax.jit
def correct_conditional(x):
    """
    Correct version using jnp.where().
    Evaluates BOTH branches and selects based on condition.
    Works with JIT because no Python control flow is needed!
    """
    return jnp.where(x > 0, x * 2, x * 3)

print("\n✅ Testing correct_conditional with jnp.where():")
result_pos = correct_conditional(jnp.array(5.0))
result_neg = correct_conditional(jnp.array(-5.0))
print(f"   correct_conditional(5.0)  = {result_pos}  ✓")
print(f"   correct_conditional(-5.0) = {result_neg} ✓")
print(f"   Both results are CORRECT!")

# Side-by-side comparison
print("\n" + "=" * 70)
print("COMPARISON: if/else vs jnp.where()")
print("=" * 70)
test_values = jnp.array([5.0, -3.0, 2.0, -8.0, 0.0])
correct_results = correct_conditional(test_values)
print(f"Input values:     {test_values}")
print(f"jnp.where() results: {correct_results}")
print(f"Expected (x>0 ? 2x : 3x): [10.0, -9.0, 4.0, -24.0, 0.0]")
print(f"Match: {jnp.allclose(correct_results, jnp.array([10.0, -9.0, 4.0, -24.0, 0.0]))}")

# -----------------------------------------------------------------------------
# EXAMPLE 3: Side Effects in JIT - Print Statements
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("üñ®Ô∏è  DEMONSTRATION: Side Effects Only Happen During Tracing")
print("=" * 70)

@jax.jit
def function_with_print(x):
    """
    Python print statements only execute while JAX traces the function
    (here, the first call). Later calls with the same shape/dtype reuse
    the cached compiled code and skip the print!
    """
    print(f"   🔍 TRACING: Inside function with x = {x}")  # x shows up as a tracer, not 10.0
    return x * 2

print("\n1️⃣  First call (triggers tracing and compilation):")
result1 = function_with_print(jnp.array(10.0))
print(f"   Result: {result1}\n")

print("2️⃣  Second call (uses cached compiled version):")
result2 = function_with_print(jnp.array(20.0))
print(f"   Result: {result2}")
print(f"   👆 Notice: The print INSIDE the function didn't execute!\n")

print("3️⃣  Third call (still using cached version):")
result3 = function_with_print(jnp.array(30.0))
print(f"   Result: {result3}")
print(f"   👆 Still no print - using cached compilation\n")

print("💡 Key Point: Python side effects (print, global vars, I/O) only run during tracing, not on every call!")

# -----------------------------------------------------------------------------
# EXAMPLE 4: When NOT to Use JIT - Compilation Overhead
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("⚖️  DEMONSTRATION: JIT Overhead vs Benefit")
print("=" * 70)

def simple_add_no_jit(x):
    """Simple addition without JIT."""
    return x + 1

@jax.jit
def simple_add_with_jit(x):
    """Simple addition with JIT."""
    return x + 1

# SMALL COMPUTATION - per-call savings are tiny, compilation is a one-time cost
print("\n📉 TINY ARRAYS (3 elements, 1000 iterations):")
small_arr = jnp.array([1.0, 2.0, 3.0])

# The first JIT call includes tracing + XLA compilation: this is the overhead
start = time.time()
_ = simple_add_with_jit(small_arr).block_until_ready()
compile_time = time.time() - start
_ = simple_add_no_jit(small_arr).block_until_ready()  # Warm up the non-JIT path too

start = time.time()
for _ in range(1000):
    _ = simple_add_no_jit(small_arr).block_until_ready()  # Block in both loops for a fair comparison
time_no_jit = time.time() - start

start = time.time()
for _ in range(1000):
    _ = simple_add_with_jit(small_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"  First JIT call (incl. compile): {compile_time:.6f} seconds")
print(f"  WITHOUT JIT: {time_no_jit:.6f} seconds")
print(f"  WITH JIT:    {time_with_jit:.6f} seconds")
speedup = time_no_jit / time_with_jit
print(f"  Speedup:     {speedup:.2f}x")
if speedup < 1.5:
    print(f"  ⚠️  For tiny computations, the one-time compile cost may outweigh the per-call savings!")

# LARGE COMPUTATION
print("\n📈 LARGE ARRAYS (1M elements, 100 iterations):")
large_arr = jnp.arange(1000000.0)

_ = simple_add_no_jit(large_arr).block_until_ready()   # Warm up
_ = simple_add_with_jit(large_arr).block_until_ready()  # Warm up (compiles for this shape)

start = time.time()
for _ in range(100):
    _ = simple_add_no_jit(large_arr).block_until_ready()
time_no_jit = time.time() - start

start = time.time()
for _ in range(100):
    _ = simple_add_with_jit(large_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"  WITHOUT JIT: {time_no_jit:.6f} seconds")
print(f"  WITH JIT:    {time_with_jit:.6f} seconds")
speedup = time_no_jit / time_with_jit
print(f"  Speedup:     {speedup:.2f}x")
if speedup > 1.5:
    print("  ✅ JIT provides a clear benefit here!")
else:
    print("  ⚠️  x + 1 is a single operation, so XLA has little to fuse;")
    print("     JIT gains usually come from longer chains of operations.")

# -----------------------------------------------------------------------------
# FINAL SUMMARY
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("üéØ KEY TAKEAWAYS - JIT COMPILATION")
print("=" * 70)
print("""
1. ✅ JIT can give large speedups (often several-fold, sometimes far more) for numerical code
2. ✅ Use jnp.where() for conditionals, NOT Python if/else with array values
3. ✅ Pure functions work best (no side effects like print, globals, I/O)
4. ✅ Warm up JIT before timing (the first call per input shape includes compilation)
5. ✅ Use .block_until_ready() for accurate timing (JAX is async by default)
6. ✅ JIT benefit tends to grow with array size and the number of fused operations
7. ✅ Inspect with jax.make_jaxpr() to see the traced representation
8. ❌ Python control flow (if/else) that depends on array VALUES breaks JIT
9. ❌ Side effects (print, globals, I/O) only happen during tracing
10. ❌ Don't JIT tiny one-off functions - compilation overhead isn't worth it

💡 Best Use Cases: Matrix operations, neural networks, scientific simulations,
   repeated computations on large arrays
""")

PERFORMANCE COMPARISON: WITH JIT vs WITHOUT JIT

🔧 Warming up JIT compiler (first call triggers compilation)...
✅ JIT compilation complete! Compiled code is now cached.


📊 SMALL ARRAY - 1,000 elements, 100 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.095370 seconds
WITH JIT:            0.020290 seconds
SPEEDUP:                 4.70x faster
✅ Results match! First 5 values: [ 4  1 10  2 16]

📊 MEDIUM ARRAY - 100,000 elements, 50 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.119402 seconds
WITH JIT:            0.023932 seconds
SPEEDUP:                 4.99x faster
✅ Results match! First 5 values: [ 4  1 10  2 16]

📊 LARGE ARRAY - 1,000,000 elements, 10 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.123240 seconds
WITH JIT:            0.008304 seconds
SPEEDUP:                14.84x faster

## Quick Reference: JIT-Compatible Control Flow

When you need conditionals in JIT-compiled functions, use these JAX operations:

| Scenario | ❌ Don't Use | ✅ Use Instead |
|----------|--------------|----------------|
| Element-wise conditional | `if x > 0: ...` | `jnp.where(x > 0, true_val, false_val)` |
| Scalar conditional | `if x > 0: ...` | `jax.lax.cond(x > 0, true_fn, false_fn, operand)` |
| Multiple branches | `if/elif/else` | `jax.lax.switch(index, branches, operand)` |
| Dynamic loops | `for i in range(int(x)): ...` | `jax.lax.fori_loop(start, end, body_fn, init)` |
| While loops | `while condition: ...` | `jax.lax.while_loop(cond_fn, body_fn, init)` |
| Array updates | `arr[i] = val` | `arr.at[i].set(val)` |

**Why?** During JIT tracing, JAX works with abstract values (shapes/types), not actual data. Python control flow needs concrete values, which aren't available during tracing. JAX's control flow ops are designed to work with abstract values!

# Automatic Differentiation in JAX

Automatic differentiation (autodiff) lets JAX compute derivatives of functions built from JAX operations. You write the forward pass, and JAX figures out the gradients. This is what makes training neural networks possible without deriving backprop by hand.

## Key Functions

- **`jax.grad()`** - Gradient of scalar output w.r.t. first argument
- **`jax.value_and_grad()`** - Returns both function value and gradient
- **`jax.jacobian()`** - Jacobian matrix for vector-valued functions
- **`jax.hessian()`** - Hessian matrix (second derivatives)
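In training code you often want extra outputs, such as metrics or predictions, alongside the loss. `jax.value_and_grad` supports this with `has_aux=True`: the function returns `(loss, aux)`, and only `loss` is differentiated. A sketch with made-up data:

```python
import jax
import jax.numpy as jnp

def loss_with_metrics(w, x, y):
    err = x @ w - y
    loss = jnp.mean(err ** 2)
    return loss, {"max_abs_err": jnp.max(jnp.abs(err))}  # (value, auxiliary data)

w = jnp.zeros(2)
x = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
y = jnp.array([1.0, 2.0, 3.0])

(loss, metrics), grads = jax.value_and_grad(loss_with_metrics, has_aux=True)(w, x, y)
print(loss, metrics["max_abs_err"], grads)
```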

## How It Works

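`jax.grad` uses **reverse-mode autodiff** (the generalization of backpropagation); forward mode is available through `jax.jvp` and `jax.jacfwd`. JAX traces your function to record its primitive operations, then applies each primitive's differentiation rule. In reverse mode, it walks backward through the recorded computation to compute derivatives.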
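The two modes are also exposed directly. This sketch uses an arbitrary test function and computes the same derivative three ways. Forward mode with `jax.jvp` pushes a tangent through the function. Reverse mode with `jax.vjp` pulls a cotangent back. `jax.grad` wraps the reverse-mode path for scalar outputs.

```python
import jax
import jax.numpy as jnp

def f(x):
    return jnp.sin(x) * x ** 2

x = jnp.float32(1.5)

# Forward mode: derivative along the tangent direction 1.0
_, d_forward = jax.jvp(f, (x,), (jnp.ones_like(x),))

# Reverse mode: pull a cotangent of 1.0 back through f
y, vjp_fn = jax.vjp(f, x)
(d_reverse,) = vjp_fn(jnp.ones_like(y))

d_grad = jax.grad(f)(x)

# Hand-derived: f'(x) = cos(x) * x^2 + 2x * sin(x)
expected = jnp.cos(x) * x ** 2 + 2 * x * jnp.sin(x)
print(d_forward, d_reverse, d_grad, expected)
```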

---

In [14]:
# =============================================================================
# AUTOMATIC DIFFERENTIATION - BASICS
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np

print("=" * 70)
print("COMPUTING GRADIENTS WITH JAX")
print("=" * 70)

# -----------------------------------------------------------------------------
# Example 1: Basic Gradient - Single Variable
# -----------------------------------------------------------------------------
print("\n1️⃣  BASIC GRADIENT")
print("-" * 70)

def f(x):
    """Simple quadratic function: f(x) = x^2"""
    return x ** 2

# Create gradient function
df_dx = jax.grad(f)

# Evaluate at different points
x = 3.0
print(f"Function: f(x) = x^2")
print(f"f({x}) = {f(x)}")
print(f"f'({x}) = {df_dx(x)}")
print(f"Expected: 2*{x} = {2*x} ✓")

# -----------------------------------------------------------------------------
# Example 2: Multi-variable Function
# -----------------------------------------------------------------------------
print("\n2️⃣  MULTI-VARIABLE GRADIENT")
print("-" * 70)

def g(x, y):
    """Function with two inputs: g(x,y) = x^2 + 3xy + y^2"""
    return x**2 + 3*x*y + y**2

# grad() computes gradient w.r.t. FIRST argument by default
dg_dx = jax.grad(g)

# To get gradient w.r.t. other arguments, use argnums
dg_dy = jax.grad(g, argnums=1)

x, y = 2.0, 1.0
print(f"Function: g(x,y) = x^2 + 3xy + y^2")
print(f"g({x}, {y}) = {g(x, y)}")
print(f"∂g/∂x at ({x},{y}) = {dg_dx(x, y)}")
print(f"∂g/∂y at ({x},{y}) = {dg_dy(x, y)}")
print(f"Expected: ∂g/∂x = 2x + 3y = {2*x + 3*y} ✓")
print(f"Expected: ∂g/∂y = 3x + 2y = {3*x + 2*y} ✓")

# -----------------------------------------------------------------------------
# Example 3: value_and_grad - Get Both Function Value and Gradient
# -----------------------------------------------------------------------------
print("\n3️⃣  VALUE AND GRADIENT TOGETHER")
print("-" * 70)

def loss(params):
    """Typical ML loss function"""
    return jnp.sum(params ** 2)

# This is super useful for optimization - get loss and gradient in one call
val_and_grad_fn = jax.value_and_grad(loss)

params = jnp.array([1.0, 2.0, 3.0])
value, gradient = val_and_grad_fn(params)

print(f"Parameters: {params}")
print(f"Loss value: {value}")
print(f"Gradient: {gradient}")
print(f"Expected gradient: 2*params = {2*params} ✓")

# -----------------------------------------------------------------------------
# Example 4: Gradients of Vector Functions - Jacobian
# -----------------------------------------------------------------------------
print("\n4️⃣  JACOBIAN (for vector-valued functions)")
print("-" * 70)

def vector_func(x):
    """Returns a vector: [x^2, x^3, sin(x)]"""
    return jnp.array([x**2, x**3, jnp.sin(x)])

# Jacobian computes all partial derivatives
jacobian_fn = jax.jacobian(vector_func)

x = 2.0
jac = jacobian_fn(x)
print(f"Function: f(x) = [x^2, x^3, sin(x)]")
print(f"Jacobian at x={x}:")
print(f"  df₁/dx = 2x = {jac[0]:.4f}")
print(f"  df₂/dx = 3x^2 = {jac[1]:.4f}")
print(f"  df₃/dx = cos(x) = {jac[2]:.4f}")

# -----------------------------------------------------------------------------
# Example 5: Second Derivatives - Hessian
# -----------------------------------------------------------------------------
print("\n5️⃣  HESSIAN (second derivatives)")
print("-" * 70)

def h(x):
    """Scalar function of a 2-vector for the Hessian demo: f(x, y) = x^2 + y^2"""
    return x[0]**2 + x[1]**2

# Hessian: matrix of second partial derivatives (needs a scalar-valued function)
hessian_fn = jax.hessian(h)

x = jnp.array([1.0, 2.0])
hess = hessian_fn(x)
print(f"Function: f(x,y) = x^2 + y^2")
print(f"Hessian matrix at {x}:")
print(hess)
print("(Second derivatives: diagonal is 2, off-diagonal is 0)")

# -----------------------------------------------------------------------------
# Example 6: Higher-Order Derivatives
# -----------------------------------------------------------------------------
print("\n6️⃣  HIGHER-ORDER DERIVATIVES")
print("-" * 70)

def cubic(x):
    """f(x) = x^3"""
    return x ** 3

# First derivative
first_deriv = jax.grad(cubic)
# Second derivative (gradient of gradient)
second_deriv = jax.grad(first_deriv)
# Third derivative
third_deriv = jax.grad(second_deriv)

x = 2.0
print(f"Function: f(x) = x^3")
print(f"f({x}) = {cubic(x)}")
print(f"f'({x}) = {first_deriv(x)} (expected: 3x^2 = {3*x**2})")
print(f"f''({x}) = {second_deriv(x)} (expected: 6x = {6*x})")
print(f"f'''({x}) = {third_deriv(x)} (expected: 6)")

# -----------------------------------------------------------------------------
# Example 7: Gradient with Respect to Multiple Arguments
# -----------------------------------------------------------------------------
print("\n7️⃣  GRADIENTS W.R.T. MULTIPLE ARGUMENTS")
print("-" * 70)

def model(weights, bias, x):
    """Simple linear model: y = w*x + b"""
    return jnp.dot(weights, x) + bias

# Get gradients w.r.t. first two arguments (weights and bias)
grad_fn = jax.grad(model, argnums=(0, 1))

w = jnp.array([1.0, 2.0, 3.0])
b = 0.5
x = jnp.array([1.0, 1.0, 1.0])

grad_w, grad_b = grad_fn(w, b, x)
print(f"Model: y = w·x + b")
print(f"Gradient w.r.t. weights: {grad_w}")
print(f"Gradient w.r.t. bias: {grad_b}")
print("(dy/dw = x, dy/db = 1)")

# -----------------------------------------------------------------------------
# Example 8: Practical Example - Gradient Descent
# -----------------------------------------------------------------------------
print("\n8️⃣  GRADIENT DESCENT OPTIMIZATION")
print("-" * 70)

def loss_fn(params, x, y):
    """MSE loss for linear regression"""
    prediction = params[0] * x + params[1]
    return jnp.mean((prediction - y) ** 2)

# Training data: y = 2x + 1 with some noise
np.random.seed(42)
x_data = jnp.array(np.linspace(0, 10, 20))
y_data = 2 * x_data + 1 + np.random.randn(20) * 0.5

# Initialize parameters
params = jnp.array([0.0, 0.0])  # [slope, intercept]

# Training loop
learning_rate = 0.01
grad_fn = jax.grad(loss_fn)

print("Training linear regression with gradient descent...")
print(f"True parameters: slope=2.0, intercept=1.0")
print(f"Initial params: {params}")

for step in range(100):
    grads = grad_fn(params, x_data, y_data)
    params = params - learning_rate * grads
    
    if step % 20 == 0:
        loss = loss_fn(params, x_data, y_data)
        print(f"Step {step:3d}: loss={loss:.4f}, params={params}")

final_loss = loss_fn(params, x_data, y_data)
print(f"\nFinal params: slope={params[0]:.3f}, intercept={params[1]:.3f}")
print(f"Final loss: {final_loss:.4f}")

print("\n" + "=" * 70)
print("KEY POINTS - AUTOMATIC DIFFERENTIATION")
print("=" * 70)
print("""
✅ jax.grad() computes exact derivatives (up to floating-point rounding), not finite-difference approximations
✅ Works with any differentiable function built from JAX operations - no manual chain rule needed
✅ Uses efficient reverse-mode autodiff (backpropagation)
✅ Computes higher-order derivatives by composing grad()
✅ value_and_grad() returns both the function value and the gradient
✅ jacobian() and hessian() handle vector-valued functions and second derivatives
✅ The argnums parameter controls which arguments to differentiate
✅ Combine with JIT for fast gradient computations
""")

COMPUTING GRADIENTS WITH JAX

1️⃣  BASIC GRADIENT
----------------------------------------------------------------------
Function: f(x) = x^2
f(3.0) = 9.0
f'(3.0) = 6.0
Expected: 2*3.0 = 6.0 ✓

2️⃣  MULTI-VARIABLE GRADIENT
----------------------------------------------------------------------
Function: g(x,y) = x^2 + 3xy + y^2
g(2.0, 1.0) = 11.0
∂g/∂x at (2.0,1.0) = 7.0
∂g/∂y at (2.0,1.0) = 8.0
Expected: ∂g/∂x = 2x + 3y = 7.0 ✓
Expected: ∂g/∂y = 3x + 2y = 8.0 ✓

3️⃣  VALUE AND GRADIENT TOGETHER
----------------------------------------------------------------------
Parameters: [1. 2. 3.]
Loss value: 14.0
Gradient: [2. 4. 6.]
Expected gradient: 2*params = [2. 4. 6.] ✓

4️⃣  JACOBIAN (for vector-valued functions)
----------------------------------------------------------------------
Function: f(x) = [x^2, x^3, sin(x)]
Jacobian at x=2.0:
  df₁/dx = 2x = 4.0000
  df₂/dx = 3x^2 = 12.0000
  df₃/dx = cos(x) = -0.4161

5️⃣  HESSIAN (second derivatives)
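
The Jacobian example above uses a scalar input, so its "Jacobian" is just a vector of three derivatives. The sketch below uses a vector input (the functions `f` and `g` are hypothetical, chosen only for illustration) to show the full matrix and the two modes JAX offers: `jax.jacrev` (which `jax.jacobian` aliases) builds the Jacobian row by row from vector-Jacobian products, while `jax.jacfwd` builds it column by column from Jacobian-vector products. As a rule of thumb, reverse mode tends to be cheaper when there are fewer outputs than inputs, forward mode in the opposite case. The Hessian is the Jacobian of the gradient, and `jax.hessian` is implemented as `jacfwd(jacrev(...))`.

```python
import jax
import jax.numpy as jnp

def f(v):
    # Hypothetical f: R^2 -> R^2, f(x, y) = [x * y, x + y^2]
    x, y = v[0], v[1]
    return jnp.array([x * y, x + y ** 2])

v = jnp.array([2.0, 3.0])

J_rev = jax.jacrev(f)(v)  # reverse mode (same as jax.jacobian)
J_fwd = jax.jacfwd(f)(v)  # forward mode
# Both equal [[df1/dx, df1/dy], [df2/dx, df2/dy]] = [[y, x], [1, 2y]] = [[3, 2], [1, 6]]

def g(v):
    # Hypothetical scalar function g(x, y) = x^2 * y
    return v[0] ** 2 * v[1]

H = jax.hessian(g)(v)  # [[2y, 2x], [2x, 0]] = [[6, 4], [4, 0]]
```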

# Vectorization with vmap

`vmap` is automatic batching. You write code that works on a single example, then `vmap` transforms it to work on batches - no loops needed.
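
To make "write single-example code, let `vmap` batch it" concrete, here is a minimal sketch with a hypothetical `dot` helper written for one pair of 1-D vectors. Calling it directly on two `(3, 2)` batches would attempt a matrix product and fail on the mismatched shapes; `jax.vmap(dot)` instead applies it along axis 0 of both arguments and matches an explicit Python loop.

```python
import jax
import jax.numpy as jnp

def dot(a, b):
    # Written for ONE pair of 1-D vectors
    return jnp.dot(a, b)

a_batch = jnp.arange(6.0).reshape(3, 2)  # three vectors of length 2
b_batch = jnp.ones((3, 2))

# vmap maps over axis 0 of every argument by default
batched = jax.vmap(dot)(a_batch, b_batch)

# The equivalent Python loop: call dot once per row and stack the results
looped = jnp.stack([dot(a, b) for a, b in zip(a_batch, b_batch)])

print(batched)  # [1. 5. 9.]
```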

## Why vmap Matters

Normally you'd write batched operations manually (lots of index wrangling). With `vmap`, you write single-example code and JAX handles the batching. This is particularly useful with `grad` for computing per-sample gradients.
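
A minimal sketch of per-sample gradients, assuming a hypothetical linear model with squared error (`sample_loss`, `w`, `xs`, `ys` are illustrative names): `jax.grad` differentiates the loss for one example, and `jax.vmap` maps that gradient over the batch. The batch gradient of the mean loss is the average of these rows; the per-sample version keeps one row per example, which is what per-example techniques such as gradient clipping in differentially private training need.

```python
import jax
import jax.numpy as jnp

def sample_loss(w, x, y):
    # Squared error of a linear model on ONE example
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])
ys = jnp.array([0.0, 0.0, 1.0])

# in_axes=(None, 0, 0): share w, map over rows of xs and entries of ys
per_example = jax.vmap(jax.grad(sample_loss), in_axes=(None, 0, 0))(w, xs, ys)
# rows: [2, 0], [0, -2], [-2, -2]  (each is 2 * residual * x)

def mean_loss(w):
    # Average loss over the batch, built from the same single-example loss
    return jnp.mean(jax.vmap(sample_loss, in_axes=(None, 0, 0))(w, xs, ys))

# The gradient of the mean loss equals the mean of the per-example gradients
print(jnp.allclose(jax.grad(mean_loss)(w), per_example.mean(axis=0)))  # True
```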

## How It Works

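`jax.vmap(fun, in_axes=0, out_axes=0)` transforms a function so that it maps over the specified axes of its inputs and stacks the results along the specified axis of its outputs. The key parameter is `in_axes`: one entry per positional argument giving the axis to map over, or `None` to pass that argument through unbatched.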
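
A short sketch of how `in_axes` and `out_axes` change what gets mapped, using hypothetical helpers `norm` and `repeat3`: the same single-vector function yields row norms with `in_axes=0` and column norms with `in_axes=1`, and `out_axes` moves the batch dimension in the result.

```python
import jax
import jax.numpy as jnp

def norm(v):
    # Euclidean norm of ONE 1-D vector
    return jnp.sqrt(jnp.sum(v ** 2))

m = jnp.array([[3.0, 0.0],
               [4.0, 1.0]])

row_norms = jax.vmap(norm, in_axes=0)(m)  # maps over rows: [3, 0] and [4, 1]
col_norms = jax.vmap(norm, in_axes=1)(m)  # maps over columns: [3, 4] and [0, 1]
print(col_norms)  # [5. 1.]

def repeat3(s):
    # One scalar -> a vector of length 3
    return s * jnp.ones(3)

scales = jnp.array([1.0, 2.0])
print(jax.vmap(repeat3)(scales).shape)              # (2, 3): batch axis first
print(jax.vmap(repeat3, out_axes=1)(scales).shape)  # (3, 2): batch axis second
```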

---

In [15]:
# =============================================================================
# VECTORIZATION WITH VMAP
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np
import time

print("=" * 70)
print("AUTOMATIC BATCHING WITH VMAP")
print("=" * 70)

# -----------------------------------------------------------------------------
# Example 1: Basic vmap - Single Input Batching
# -----------------------------------------------------------------------------
print("\n1️⃣  BASIC VMAP")
print("-" * 70)

def square(x):
    """Operates on a SINGLE number"""
    return x ** 2

# Create vectorized version
vectorized_square = jax.vmap(square)

# Apply to batch
batch = jnp.array([1.0, 2.0, 3.0, 4.0, 5.0])
result = vectorized_square(batch)

print("Function for single input: square(x) = x^2")
print(f"Input batch: {batch}")
print(f"Output: {result}")
print("vmap automatically applies the function to each element!")

# -----------------------------------------------------------------------------
# Example 2: Multiple Inputs with Different Batching
# -----------------------------------------------------------------------------
print("\n2️⃣  VMAP WITH MULTIPLE INPUTS")
print("-" * 70)

def weighted_sum(x, weight):
    """Scale one input by a weight: x * weight (an elementwise product)"""
    return x * weight

# Batch over first argument, keep second fixed
# in_axes=(0, None) means: map over axis 0 of x, don't map over weight
batched_fn = jax.vmap(weighted_sum, in_axes=(0, None))

x_batch = jnp.array([1.0, 2.0, 3.0])
weight = 2.5

result = batched_fn(x_batch, weight)
print(f"Batch: {x_batch}")
print(f"Fixed weight: {weight}")
print(f"Result: {result}")
print("in_axes=(0, None) batches first arg, keeps second fixed")

# Batch over both arguments
batched_fn_both = jax.vmap(weighted_sum, in_axes=(0, 0))
weights_batch = jnp.array([1.0, 2.0, 3.0])
result_both = batched_fn_both(x_batch, weights_batch)
print(f"\nWith batched weights: {weights_batch}")
print(f"Result: {result_both}")

# -----------------------------------------------------------------------------
# Example 3: Matrix-Vector Product with vmap
# -----------------------------------------------------------------------------
print("\n3️⃣  MATRIX-VECTOR PRODUCTS")
print("-" * 70)

def matvec(matrix, vector):
    """Single matrix-vector product"""
    return jnp.dot(matrix, vector)

# Batch over vectors (multiple vectors, same matrix)
batch_matvec = jax.vmap(matvec, in_axes=(None, 0))

matrix = jnp.array([[1.0, 2.0], [3.0, 4.0]])
vectors = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])

results = batch_matvec(matrix, vectors)
print(f"Matrix:\n{matrix}")
print(f"Vector batch (shape {vectors.shape}):\n{vectors}")
print(f"Results (shape {results.shape}):\n{results}")

# -----------------------------------------------------------------------------
# Example 4: Per-Sample Gradients (The Killer App!)
# -----------------------------------------------------------------------------
print("\n4️⃣  PER-SAMPLE GRADIENTS")
print("-" * 70)

def loss_single_sample(params, x, y):
    """Loss for ONE training example"""
    prediction = jnp.dot(params, x)
    return (prediction - y) ** 2

# Create function that computes gradient for one sample
grad_fn_single = jax.grad(loss_single_sample)

# Vectorize it to compute per-sample gradients for entire batch!
# in_axes=(None, 0, 0) means: same params, batch over x and y
per_sample_grads = jax.vmap(grad_fn_single, in_axes=(None, 0, 0))

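# Test data. Note: params predict every target exactly (all residuals are 0),
# so every per-sample gradient below is zero; change y_batch to see nonzero rows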
params = jnp.array([1.0, 2.0, 3.0])
x_batch = jnp.array([
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
    [1.0, 1.0, 1.0]
])
y_batch = jnp.array([1.0, 2.0, 3.0, 6.0])

grads = per_sample_grads(params, x_batch, y_batch)
print(f"Parameters: {params}")
print(f"Batch size: {len(x_batch)}")
print(f"Per-sample gradients (shape {grads.shape}):")
print(grads)
print("\nEach row is the gradient for one training example!")
print("This is critical for differential privacy and some RL algorithms.")

# -----------------------------------------------------------------------------
# Example 5: Performance Comparison - Loop vs vmap
# -----------------------------------------------------------------------------
print("\n5️⃣  PERFORMANCE: LOOP vs VMAP")
print("-" * 70)

def compute_single(x):
    """Some computation on a single input"""
    return jnp.sum(jnp.sin(x) ** 2 + jnp.cos(x) ** 2)

# Manual loop version
def loop_version(batch):
    results = []
    for x in batch:
        results.append(compute_single(x))
    return jnp.array(results)

# vmap version
vmap_version = jax.vmap(compute_single)

# Generate test batch
batch_size = 1000
test_batch = jnp.ones((batch_size, 100))

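# Warm up: the first call traces and compiles the batched operations
_ = vmap_version(test_batch).block_until_ready()

# Benchmark loop (block_until_ready waits for JAX's asynchronous dispatch,
# so the timer measures the actual computation)
start = time.perf_counter()
for _ in range(100):
    _ = loop_version(test_batch).block_until_ready()
time_loop = time.perf_counter() - start

# Benchmark vmap
start = time.perf_counter()
for _ in range(100):
    _ = vmap_version(test_batch).block_until_ready()
time_vmap = time.perf_counter() - start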

print(f"Batch size: {batch_size}")
print(f"Loop version:  {time_loop:.4f} seconds")
print(f"vmap version:  {time_vmap:.4f} seconds")
print(f"Speedup: {time_loop/time_vmap:.2f}x faster with vmap!")

# -----------------------------------------------------------------------------
# Example 6: Nested vmap - Batching Over Multiple Dimensions
# -----------------------------------------------------------------------------
print("\n6️⃣  NESTED VMAP")
print("-" * 70)

def pairwise_distance(x1, x2):
    """Distance between two vectors"""
    return jnp.sqrt(jnp.sum((x1 - x2) ** 2))

# Compute pairwise distances between all vectors in two sets
# Inner vmap: one x1 against every x2 (maps over axis 0 of the second argument)
# Outer vmap: repeats that for every x1 (maps over axis 0 of the first argument)
pairwise_distances = jax.vmap(
    jax.vmap(pairwise_distance, in_axes=(None, 0)),
    in_axes=(0, None),
)

set1 = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
set2 = jnp.array([[1.0, 1.0], [2.0, 2.0]])

distances = pairwise_distances(set1, set2)
print(f"Set 1 (shape {set1.shape}):\n{set1}")
print(f"Set 2 (shape {set2.shape}):\n{set2}")
print(f"Pairwise distances (shape {distances.shape}):\n{distances}")
print("Each row shows distances from one point in set1 to all points in set2")

# -----------------------------------------------------------------------------
# Example 7: vmap with Different Output Axes
# -----------------------------------------------------------------------------
print("\n7️⃣  CONTROLLING OUTPUT AXES")
print("-" * 70)

def create_matrix(scale):
    """Create a 2x2 matrix based on scale"""
    return jnp.array([[scale, 0], [0, scale]])

# Default: out_axes=0 (stack along first dimension)
vmap_default = jax.vmap(create_matrix)
scales = jnp.array([1.0, 2.0, 3.0])
result_default = vmap_default(scales)

print(f"Scales: {scales}")
print(f"Output shape with out_axes=0 (default): {result_default.shape}")
print(f"Result:\n{result_default}")

# With out_axes=1: stack along second dimension
vmap_axis1 = jax.vmap(create_matrix, out_axes=1)
result_axis1 = vmap_axis1(scales)
print(f"\nOutput shape with out_axes=1: {result_axis1.shape}")

# -----------------------------------------------------------------------------
# Example 8: Combining JIT and vmap
# -----------------------------------------------------------------------------
print("\n8️⃣  JIT + VMAP COMBO")
print("-" * 70)

def expensive_computation(x):
    """Some complex computation"""
    result = x
    for _ in range(10):
        result = jnp.sin(result) + jnp.cos(result)
    return result

# Combine transformations: JIT the vmapped function
fast_batch_compute = jax.jit(jax.vmap(expensive_computation))

batch = jnp.ones(1000)

# Warm up
_ = fast_batch_compute(batch)

# Benchmark
start = time.time()
for _ in range(1000):
    _ = fast_batch_compute(batch).block_until_ready()
time_taken = time.time() - start

print("Combined jax.jit(jax.vmap(function)) for maximum performance")
print(f"Time for 1000 iterations on batch of 1000: {time_taken:.4f}s")
print("Both JIT compilation and automatic batching working together!")

print("\n" + "=" * 70)
print("KEY POINTS - VMAP")
print("=" * 70)
print("""
✅ vmap automatically batches functions - no manual loops
✅ Write single-example code, vmap handles the batching
✅ in_axes controls which dimensions to map over (0, None, etc.)
✅ Much faster than Python loops (one vectorized call instead of many small ones)
✅ Perfect for per-sample gradients: vmap(grad(loss))
✅ Can nest vmap for multi-dimensional batching
✅ Combine with JIT for maximum performance
✅ out_axes controls how outputs are stacked
✅ Cleaner code: intent is clear, no index juggling
""")

AUTOMATIC BATCHING WITH VMAP

1️⃣  BASIC VMAP
----------------------------------------------------------------------
Function for single input: square(x) = x^2
Input batch: [1. 2. 3. 4. 5.]
Output: [ 1.  4.  9. 16. 25.]
vmap automatically applies the function to each element!

2️⃣  VMAP WITH MULTIPLE INPUTS
----------------------------------------------------------------------
Batch: [1. 2. 3.]
Fixed weight: 2.5
Result: [2.5 5.  7.5]
in_axes=(0, None) batches first arg, keeps second fixed

With batched weights: [1. 2. 3.]
Result: [1. 4. 9.]

3️⃣  MATRIX-VECTOR PRODUCTS
----------------------------------------------------------------------
Matrix:
[[1. 2.]
 [3. 4.]]
Vector batch (shape (3, 2)):
[[1. 0.]
 [0. 1.]
 [1. 1.]]
Results (shape (3, 2)):
[[1. 3.]
 [2. 4.]
 [3. 7.]]

4️⃣  PER-SAMPLE GRADIENTS
----------------------------------------------------------------------
Parameters: [1. 2. 3.]
Batch size: 4
Per-sample gradients (shape (4, 3)):
[[0. 0. 0.]
 [0. 0. 0.]
 [0. 

# Putting It All Together

Let's see how JIT, autodiff, and vmap work together. We'll build and train a neural network from scratch using just JAX - no high-level frameworks, just the core transformations.

---

In [16]:
# =============================================================================
# COMPLETE EXAMPLE - TRAINING A NEURAL NETWORK
# =============================================================================

import jax
import jax.numpy as jnp
import numpy as np

print("=" * 70)
print("NEURAL NETWORK TRAINING - JAX PRIMITIVES ONLY")
print("=" * 70)

# -----------------------------------------------------------------------------
# Network Architecture and Initialization
# -----------------------------------------------------------------------------

def init_network_params(layer_sizes, key):
    """Initialize neural network parameters with random values."""
    params = []
    keys = jax.random.split(key, len(layer_sizes) - 1)
    
    for layer_key, n_in, n_out in zip(keys, layer_sizes[:-1], layer_sizes[1:]):
        # He (Kaiming) initialization: std = sqrt(2 / n_in), suited to ReLU layers
        weight_key, _ = jax.random.split(layer_key)
        scale = jnp.sqrt(2.0 / n_in)
        W = scale * jax.random.normal(weight_key, (n_in, n_out))
        b = jnp.zeros(n_out)
        params.append({'W': W, 'b': b})
    
    return params

def relu(x):
    """ReLU activation function."""
    return jnp.maximum(0, x)

def forward_pass(params, x):
    """Forward pass through the network (for a SINGLE example)."""
    activation = x
    
    # Hidden layers with ReLU
    for layer in params[:-1]:
        activation = jnp.dot(activation, layer['W']) + layer['b']
        activation = relu(activation)
    
    # Output layer (no activation)
    final_layer = params[-1]
    output = jnp.dot(activation, final_layer['W']) + final_layer['b']
    
    return output

# Batch the forward pass using vmap
batched_forward = jax.vmap(forward_pass, in_axes=(None, 0))

# -----------------------------------------------------------------------------
# Loss Function and Gradient Computation
# -----------------------------------------------------------------------------

def mse_loss(params, x_batch, y_batch):
    """Mean squared error loss for the entire batch."""
    predictions = batched_forward(params, x_batch)
    return jnp.mean((predictions - y_batch) ** 2)

# Create loss and gradient function (JIT compiled for speed)
loss_and_grad = jax.jit(jax.value_and_grad(mse_loss))

# -----------------------------------------------------------------------------
# Training Loop
# -----------------------------------------------------------------------------

def train_network(params, x_train, y_train, num_epochs, learning_rate):
    """Train the network using gradient descent."""
    for epoch in range(num_epochs):
        loss, grads = loss_and_grad(params, x_train, y_train)
        
        # Update parameters
        params = [
            {
                'W': layer['W'] - learning_rate * grad_layer['W'],
                'b': layer['b'] - learning_rate * grad_layer['b']
            }
            for layer, grad_layer in zip(params, grads)
        ]
        
        if epoch % 100 == 0:
            print(f"Epoch {epoch:4d}: Loss = {loss:.6f}")
    
    return params

# -----------------------------------------------------------------------------
# Generate Training Data
# -----------------------------------------------------------------------------
print("\n📊 Generating synthetic training data...")

# Create a simple non-linear function: y = sin(x1) + cos(x2)
np.random.seed(42)
n_samples = 200
x_train = np.random.randn(n_samples, 2).astype(np.float32)
y_train = (np.sin(x_train[:, 0]) + np.cos(x_train[:, 1])).reshape(-1, 1).astype(np.float32)

x_train = jnp.array(x_train)
y_train = jnp.array(y_train)

print(f"Training data: {x_train.shape[0]} samples")
print(f"Input features: {x_train.shape[1]}")
print(f"Output dimension: {y_train.shape[1]}")

# -----------------------------------------------------------------------------
# Initialize and Train Network
# -----------------------------------------------------------------------------
print("\n🧠 Initializing network: 2 → 16 → 16 → 1")

layer_sizes = [2, 16, 16, 1]  # Input, hidden1, hidden2, output
key = jax.random.PRNGKey(0)
params = init_network_params(layer_sizes, key)

print(f"Total parameters: {sum(p['W'].size + p['b'].size for p in params)}")

print("\n🏋️  Training network...")
params = train_network(params, x_train, y_train, num_epochs=500, learning_rate=0.01)

# -----------------------------------------------------------------------------
# Evaluate Trained Network
# -----------------------------------------------------------------------------
print("\n📈 Evaluating trained network...")

final_loss, _ = loss_and_grad(params, x_train, y_train)
print(f"Final training loss: {final_loss:.6f}")

# Test on a few examples
test_inputs = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]])
predictions = batched_forward(params, test_inputs)
expected = jnp.array([[jnp.sin(x[0]) + jnp.cos(x[1])] for x in test_inputs])

print("\nTest predictions vs expected:")
for i, (inp, pred, exp) in enumerate(zip(test_inputs, predictions, expected)):
    print(f"  Input: {inp}, Predicted: {pred[0]:.4f}, Expected: {exp[0]:.4f}")

# -----------------------------------------------------------------------------
# Demonstrate Per-Sample Gradients
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("COMPUTING PER-SAMPLE GRADIENTS")
print("=" * 70)

def loss_single(params, x, y):
    """Loss for a single example."""
    pred = forward_pass(params, x)
    return jnp.sum((pred - y) ** 2)

# Create per-sample gradient function
per_sample_grad_fn = jax.vmap(jax.grad(loss_single), in_axes=(None, 0, 0))

# Compute per-sample gradients for a small batch
small_batch_x = x_train[:5]
small_batch_y = y_train[:5]

per_sample_grads = per_sample_grad_fn(params, small_batch_x, small_batch_y)

print(f"Computed per-sample gradients for {len(small_batch_x)} examples")
print(f"First layer weight gradients shape: {per_sample_grads[0]['W'].shape}")
print("(5 samples, each with its own gradient)")

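# Compare with the batch gradient: mse_loss averages the per-sample losses,
# so its gradient equals the mean of the per-sample gradients
avg_grad = jax.grad(mse_loss)(params, small_batch_x, small_batch_y)
print(f"\nAverage gradient (first weight): {avg_grad[0]['W'][0, 0]:.6f}")
print(f"Mean of per-sample gradients:    {jnp.mean(per_sample_grads[0]['W'][:, 0, 0]):.6f}")
print("Per-sample gradients (first weight):")
for i in range(5):
    print(f"  Sample {i}: {per_sample_grads[0]['W'][i, 0, 0]:.6f}")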

print("\n" + "=" * 70)
print("WHAT WE JUST DID")
print("=" * 70)
print("""
1. Built a neural network from scratch using JAX primitives
2. Used vmap to batch the forward pass (no manual loop)
3. Used jax.grad to compute gradients automatically
4. Used jax.jit to compile the loss+gradient computation
5. Trained with vanilla gradient descent (no frameworks!)
6. Computed per-sample gradients with vmap(grad(...))

Key JAX features in action:
✅ JIT compilation - fast training loop
✅ Automatic differentiation - no manual backprop
✅ vmap - automatic batching and per-sample gradients
✅ Composability - combine transformations seamlessly

This is the foundation of how JAX-based ML libraries (Flax, Haiku, Optax) work!
""")

NEURAL NETWORK TRAINING - JAX PRIMITIVES ONLY

📊 Generating synthetic training data...
Training data: 200 samples
Input features: 2
Output dimension: 1

🧠 Initializing network: 2 → 16 → 16 → 1
Total parameters: 337

🏋️  Training network...
Epoch    0: Loss = 4.492578
Epoch  100: Loss = 0.138198
Epoch  200: Loss = 0.075353
Epoch  300: Loss = 0.054473
Epoch  400: Loss = 0.045023

📈 Evaluating trained network...
Final training loss: 0.039874

Test predictions vs expected:
  Input: [0. 0.], Predicted: 1.0070, Expected: 1.0000
  Input: [1. 0.], Predicted: 1.7170, Expected: 1.8415
  Input: [0. 1.], Predicted: 0.4573, Expected: 0.5403

COMPUTING PER-SAMPLE GRADIENTS


# Summary and Next Steps

## What You've Learned

This notebook covered the three core JAX transformations:

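1. **JIT Compilation** (`jax.jit`) - Compiles functions with XLA into optimized code for CPU/GPU/TPU; fusing operations and removing per-operation Python overhead often gives large speedups, though the gain depends on the workload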
2. **Automatic Differentiation** (`jax.grad`, `jax.jacobian`, etc.) - Computes exact derivatives without manual calculus
3. **Automatic Batching** (`jax.vmap`) - Transforms single-example code to work on batches efficiently

These three transformations compose freely. You can JIT a vmapped gradient function, or vmap a JIT-compiled function - they just work together.
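
A small sketch of that claim, assuming a hypothetical `cube` function: nesting the transformations in either order gives the same per-point derivatives. Wrapping the outermost call in `jit` is a common choice, since it compiles the whole batched computation at once.

```python
import jax
import jax.numpy as jnp

def cube(x):
    # Scalar-to-scalar function: f(x) = x^3, so f'(x) = 3x^2
    return x ** 3

xs = jnp.array([1.0, 2.0, 3.0])

# Compile a batched derivative...
a = jax.jit(jax.vmap(jax.grad(cube)))(xs)
# ...or batch a compiled derivative: the values agree
b = jax.vmap(jax.jit(jax.grad(cube)))(xs)

print(a)  # 3x^2 at each point: 3, 12, 27
```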

## Key Takeaways

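- JAX arrays are **immutable** - use `x.at[idx].set(value)` instead of in-place assignment
- Use **JAX control flow** (`jnp.where`, `jax.lax.cond`) inside JIT-compiled functions when a condition depends on traced values; Python `if/else` on such values fails during tracing
- **Pure functions** work best - avoid side effects such as `print`, which runs only while a function is being traced under `jit` (use `jax.debug.print` instead)
- `vmap` produces much cleaner code than manual batching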
- Combine transformations for maximum power: `jax.jit(jax.vmap(jax.grad(...)))`
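
A compact sketch of the first three takeaways, with hypothetical functions `relu_or_square` and `clip_negative`: `.at[].set()` returns a new array and leaves the original untouched, `jax.lax.cond` and `jnp.where` replace Python branching on traced values, and `jax.debug.print` prints on every call of a jitted function.

```python
import jax
import jax.numpy as jnp

# Immutability: .at[...].set(...) returns an updated copy
x = jnp.zeros(3)
y = x.at[1].set(5.0)
print(x, y)  # [0. 0. 0.] [0. 5. 0.]

@jax.jit
def relu_or_square(v, use_square):
    # A Python `if use_square:` would fail here because use_square is traced;
    # lax.cond selects a branch at run time instead
    return jax.lax.cond(use_square,
                        lambda u: u ** 2,
                        lambda u: jnp.maximum(u, 0.0),
                        v)

@jax.jit
def clip_negative(v):
    # A plain print() would run only while tracing; jax.debug.print runs on every call
    jax.debug.print("v = {}", v)
    return jnp.where(v < 0, 0.0, v)  # elementwise select, no Python branching

v = jnp.array([-1.0, 2.0])
print(relu_or_square(v, True))   # [1. 4.]
print(relu_or_square(v, False))  # [0. 2.]
print(clip_negative(v))          # [0. 2.]
```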

## What's Next

- **Random Numbers**: JAX uses explicit random keys (no global state)
- **pytrees**: JAX's way of handling nested structures (dicts, lists of arrays)
- **pmap**: Parallel map for multi-GPU/TPU computation
- **High-level libraries**: Flax (neural networks), Optax (optimizers), Haiku
- **Scan and while loops**: `jax.lax.scan` and `jax.lax.while_loop` for efficient sequential operations inside compiled code
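
A brief sketch touching three of these topics, assuming a toy parameter dictionary (`params`, `grads`, and `step` are hypothetical names): explicit key splitting, `jax.tree_util.tree_map` applying an update to every leaf of a nested structure, and `jax.lax.scan` carrying state through a sequence.

```python
import jax
import jax.numpy as jnp

# Random numbers: the key is an explicit value; split it rather than reuse it
key = jax.random.PRNGKey(0)
key, w_key = jax.random.split(key)
params = {"W": jax.random.normal(w_key, (2, 2)), "b": jnp.zeros(2)}

# pytrees: tree_map applies a function leaf by leaf across nested containers.
# `grads` is a hypothetical stand-in (all ones) with the same structure as params.
grads = jax.tree_util.tree_map(jnp.ones_like, params)
params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
print(params["b"])  # [-0.1 -0.1]

# lax.scan: thread a carry through a sequence inside traced/compiled code
def step(carry, x):
    total = carry + x
    return total, total  # (new carry, output for this step)

final, running = jax.lax.scan(step, jnp.array(0.0), jnp.array([1.0, 2.0, 3.0]))
print(final, running)  # 6.0 [1. 3. 6.]
```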

## Where JAX Shines

- Training neural networks (especially with custom architectures)
- Scientific computing and numerical optimization
- Differential equations and physics simulations
- Anywhere you need fast numerical code with gradients
- Research where you want control over the training loop

JAX gives you NumPy's simplicity with the performance and capabilities of a modern ML framework. That's the sweet spot.

---