# 📘 Notebook 2: JIT Compilation - Making Your Code Blazing Fast

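Welcome to the "speed" chapter of JAX! This notebook teaches you how to speed up numerical code with JIT compilation. The benchmark run recorded in this notebook measures roughly 3-15x on a simple element-wise function; functions that chain more operations can gain more.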

## 🎯 What You'll Learn (30-40 minutes)

By the end of this notebook, you'll understand:
- ✅ What JIT (Just-In-Time) compilation is and why it's powerful
- ✅ How to use `@jax.jit` to accelerate code
- ✅ **Critical pitfalls** that break JIT (and how to avoid them)
- ✅ When JIT helps (and when it doesn't)
- ✅ How to debug JIT issues

## 🤔 What is JIT Compilation?

### The Problem: Python is Slow
Regular Python code is **interpreted** - executed line by line at runtime. This is flexible but slow for numerical computations.

### The Solution: JIT Compilation
**JIT (Just-In-Time) compilation** converts your Python function to **optimized machine code** that runs directly on your hardware.

### How JIT Works (4 Steps):
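1. **Trace**: The first time you call a JIT function with a given input shape and dtype, JAX traces it with abstract placeholder values to record the operations it performs
2. **Compile**: JAX compiles the traced operations to optimized machine code using XLA
3. **Cache**: The compiled version is cached, keyed on the input shapes and dtypes
4. **Execute**: Future calls with matching shapes and dtypes reuse the fast compiled version; a new shape or dtype triggers another trace and compile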

```python
@jax.jit  # This decorator enables JIT compilation
def fast_function(x):
    return x ** 2 + 2 * x + 1

# First call: Traces, compiles, executes (slow)
result = fast_function(5.0)

# Second call: Uses cached compiled version (FAST!)
result = fast_function(10.0)
```
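The caching step can be observed directly. The sketch below is a minimal illustration, not part of the original notebook: `trace_log` is a hypothetical helper list that records the input shape each time Python actually runs the function body, which only happens while tracing.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: one entry per trace of the function body

@jax.jit
def poly(x):
    trace_log.append(x.shape)  # Python side effect, so it runs only while tracing
    return x ** 2 + 2 * x + 1

poly(jnp.ones(3))   # new shape (3,): trace + compile
poly(jnp.zeros(3))  # same shape and dtype: cached executable is reused
poly(jnp.ones(4))   # new shape (4,): traced and compiled again

print(trace_log)  # [(3,), (4,)]
```

Three calls produce only two traces, because the second call matches the shape and dtype of the first.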

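### The Speedup
How much JIT helps depends on your hardware, the array size, and how many operations XLA can fuse. In the run recorded in this notebook's output (one element-wise Collatz step):
- **Small arrays (1K elements)**: about 2.7x faster
- **Medium arrays (100K elements)**: about 4.7x faster
- **Large arrays (1M elements)**: about 15x faster
- **Complex operations**: functions that chain many operations can gain considerably more, because XLA fuses them into fewer kernels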
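One way to measure such speedups on your own machine is sketched below. `time_call` and `step` are illustrative names (assumptions, not part of any JAX API); the helper warms up once so compile time is excluded, and calls `.block_until_ready()` so the timing covers the actual computation.

```python
import time
import jax
import jax.numpy as jnp

def time_call(fn, x, repeats=20):
    """Average wall-clock seconds per call, waiting for each result."""
    fn(x).block_until_ready()  # warm-up call: keeps compile time out of the measurement
    start = time.perf_counter()
    for _ in range(repeats):
        fn(x).block_until_ready()
    return (time.perf_counter() - start) / repeats

def step(x):
    # One Collatz step, applied element-wise
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

x = jnp.arange(1, 100_001)
t_eager = time_call(step, x)
t_jit = time_call(jax.jit(step), x)
print(f"eager: {t_eager:.2e}s  jit: {t_jit:.2e}s  ratio: {t_eager / t_jit:.1f}x")
```

The printed ratio will differ between CPUs, GPUs, and JAX versions, so treat any single number as a local measurement rather than a guarantee.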

## ⚠️ JIT Pitfalls (Very Important!)

JIT is powerful but has **strict requirements**. Violating them is the #1 source of JAX errors!

### Pitfall #1: Data-Dependent Control Flow ❌
**This FAILS with a tracer error:**
```python
@jax.jit
def broken(x):
    if x > 0:  # ❌ Control flow depends on x's VALUE
        return x * 2
    else:
        return x * 3
```

**Why?** During tracing, JAX doesn't know x's value - only its shape and type! Python's `if` needs a concrete `True` or `False`, so JAX raises an error instead of guessing a branch.

**The Fix:**
```python
@jax.jit
def works(x):
    return jnp.where(x > 0, x * 2, x * 3)  # ✅ JAX-compatible
```
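For a scalar predicate, `jax.lax.cond` is another option alongside `jnp.where`. The sketch below first confirms that the Python `if` version raises during tracing, then shows a `lax.cond` version; the function names are illustrative, and the exact exception class may differ between JAX versions, so the example catches `Exception` broadly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def broken(x):
    if x > 0:  # needs a concrete bool, but x is an abstract tracer here
        return x * 2
    return x * 3

try:
    broken(jnp.array(5.0))
    raised = False
except Exception as err:  # a tracer-to-bool conversion error in recent JAX versions
    raised = True
    print("Raised:", type(err).__name__)

@jax.jit
def scalar_branch(x):
    # lax.cond(pred, true_fun, false_fun, *operands): the branch is chosen at run time
    return jax.lax.cond(x > 0, lambda v: v * 2, lambda v: v * 3, x)

print(scalar_branch(jnp.array(5.0)), scalar_branch(jnp.array(-5.0)))  # 10.0 -15.0
```

`jnp.where` computes both branches for every element, while `lax.cond` selects one branch based on a single scalar predicate, so `jnp.where` is the natural fit for element-wise conditions on arrays.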

### Pitfall #2: Side Effects Don't Work ❌
**This behaves unexpectedly:**
```python
@jax.jit
def broken(x):
    print(f"x = {x}")  # ❌ Only prints during tracing (and shows a tracer, not the value)
    global counter
    counter += 1  # ❌ Global state is updated only during tracing, not on every call
    return x * 2
```

**Why?** JIT traces once per input shape/dtype, then reuses the compiled code. Side effects only happen during tracing!
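Two common workarounds are sketched here, under the assumption that you want a value printed on every call and a counter that really advances. `jax.debug.print` runs at execution time rather than trace time, and state can be passed in and returned instead of living in a global. The names `doubled` and `step` are illustrative.

```python
import jax
import jax.numpy as jnp

@jax.jit
def doubled(x):
    print("trace-time print:", x)            # runs only while tracing, shows a tracer
    jax.debug.print("runtime x = {x}", x=x)  # runs on every call with the real value
    return x * 2

doubled(jnp.array(10.0))
doubled(jnp.array(20.0))  # only the jax.debug.print line appears for this call

@jax.jit
def step(counter, x):
    # State goes in as an argument and comes back as a result: no globals needed
    return counter + 1, x * 2

counter = 0
for v in (1.0, 2.0, 3.0):
    counter, y = step(counter, v)

print(int(counter), float(y))  # 3 6.0
```

Threading state through arguments keeps the function pure, which is what lets JAX cache and reuse the compiled version safely.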

### Pitfall #3: Dynamic Shapes ❌
JIT requires every array shape to be known at compile time. Operations whose output shape depends on the data, such as boolean-mask indexing `x[x > 0]`, fail under JIT.
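A minimal sketch of the problem and two ways around it follows. The function names are illustrative. The first fix keeps the shape fixed by masking with `jnp.where`; the second marks a size argument as static with `static_argnames`, so JAX compiles one version per distinct value of `n`.

```python
from functools import partial

import jax
import jax.numpy as jnp

@jax.jit
def sum_positive_broken(x):
    return x[x > 0].sum()  # result shape of x[x > 0] depends on the values in x

try:
    sum_positive_broken(jnp.array([1.0, -2.0, 3.0]))
    failed = False
except Exception as err:  # boolean-mask indexing needs concrete values
    failed = True
    print("Raised:", type(err).__name__)

@jax.jit
def sum_positive(x):
    # Same result, but every intermediate array keeps the shape of x
    return jnp.where(x > 0, x, 0.0).sum()

@partial(jax.jit, static_argnames=("n",))
def first_n_squares(n):
    # n is a compile-time constant here, so jnp.arange(n) has a known shape
    return jnp.arange(n) ** 2

print(sum_positive(jnp.array([1.0, -2.0, 3.0])))  # 4.0
print(first_n_squares(n=4))                       # [0 1 4 9]
```

Static arguments trade flexibility for recompilation: each new value of `n` triggers a fresh trace and compile, so this suits arguments that take only a few distinct values.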

### Pitfall #4: In-Place Mutations ❌
Remember from Notebook 1? JAX arrays are immutable, so `arr[i] = val` raises an error. Use `arr.at[i].set(val)` instead, which returns an updated copy.
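As a short illustration (with `zero_first` as a hypothetical function name), the sketch below shows item assignment failing and the `.at[]` syntax producing a new array while leaving the original untouched.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

try:
    x[0] = 0.0  # item assignment is not supported on JAX arrays
    mutated = True
except TypeError:
    mutated = False

@jax.jit
def zero_first(arr):
    return arr.at[0].set(0.0)  # returns a new array with index 0 replaced

y = zero_first(x)
z = x.at[1].add(10.0)  # other update methods follow the same pattern

print(x)  # [1. 2. 3.]
print(y)  # [0. 2. 3.]
print(z)  # [ 1. 12.  3.]
```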

## 🎓 What's in This Notebook?

1. **Performance comparisons** - JIT vs non-JIT (see the speedup!)
2. **Understanding async execution** - Why `.block_until_ready()` matters
3. **JAXPR inspection** - Peek under the hood
4. **Common pitfalls** - Data-dependent control flow, side effects
5. **Debugging JIT issues** - How to fix errors
6. **When NOT to use JIT** - Compilation overhead vs benefit
7. **Quick reference guide** - JAX-compatible control flow

## 🚀 Prerequisites

Before starting this notebook, you should:
- ✅ Complete Notebook 1 (JAX Basics)
- ✅ Understand what a function is
- ✅ Know basic Python control flow (if/else, loops)

## 💡 Key Takeaway

**JIT = Tracing → Compiling → Caching → Fast Execution**

The first call is slow (compilation), but subsequent calls with the same input shapes and dtypes are **blazing fast**!

Let's see it in action! 🔥

In [1]:
# =============================================================================
# JIT COMPILATION - PRACTICAL EXAMPLES
# =============================================================================

import time
import jax
import jax.numpy as jnp

# -----------------------------------------------------------------------------
# EXAMPLE 1: JIT vs Non-JIT Performance Comparison
# -----------------------------------------------------------------------------
# The Collatz conjecture: Take any positive integer n. If n is even, divide it
# by 2. If n is odd, multiply it by 3 and add 1. Repeat the process.

print("=" * 70)
print("PERFORMANCE COMPARISON: WITH JIT vs WITHOUT JIT")
print("=" * 70)

# VERSION 1: WITHOUT JIT
def collatz_no_jit(x):
    """
    Collatz step WITHOUT JIT compilation.
    Uses jnp.where() for vectorized conditional logic.
    """
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

# VERSION 2: WITH JIT
@jax.jit
def collatz_with_jit(x):
    """
    Collatz step WITH JIT compilation.
    Same logic, but decorated with @jax.jit for optimization.
    """
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

# Create test arrays of different sizes
small_arr = jnp.arange(1, 1001)        # 1K elements
medium_arr = jnp.arange(1, 100001)     # 100K elements
large_arr = jnp.arange(1, 1000001)     # 1M elements

print("\n🔧 Warming up JIT compiler (first call triggers compilation)...")
_ = collatz_with_jit(large_arr).block_until_ready()
print("✅ JIT compilation complete! Compiled code is now cached.\n")

# Test function for timing
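def benchmark_comparison(arr, label, iterations=10):
    """Benchmark both versions and compare results."""
    print(f"\n📊 {label} - {len(arr):,} elements, {iterations} iterations:")
    print("-" * 70)

    # Warm up both versions for THIS input shape so compile time isn't measured
    # (JIT compiles separately for each new shape, not just the first one seen)
    collatz_no_jit(arr).block_until_ready()
    collatz_with_jit(arr).block_until_ready()

    # Benchmark WITHOUT JIT
    start = time.perf_counter()
    for _ in range(iterations):
        result_no_jit = collatz_no_jit(arr).block_until_ready()
    time_no_jit = time.perf_counter() - start

    # Benchmark WITH JIT
    start = time.perf_counter()
    for _ in range(iterations):
        result_with_jit = collatz_with_jit(arr).block_until_ready()
    time_with_jit = time.perf_counter() - start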
    
    # Calculate speedup
    speedup = time_no_jit / time_with_jit if time_with_jit > 0 else float('inf')
    
    # Display results side-by-side
    print(f"{'WITHOUT JIT:':20} {time_no_jit:8.6f} seconds")
    print(f"{'WITH JIT:':20} {time_with_jit:8.6f} seconds")
    print(f"{'SPEEDUP:':20} {speedup:8.2f}x faster")
    
    # Verify results match
    if jnp.allclose(result_no_jit, result_with_jit):
        print(f"✅ Results match! First 5 values: {result_with_jit[:5]}")
    else:
        print(f"⚠️  Warning: Results differ!")
    
    return speedup

# Run benchmarks for different array sizes
speedup_small = benchmark_comparison(small_arr, "SMALL ARRAY", iterations=100)
speedup_medium = benchmark_comparison(medium_arr, "MEDIUM ARRAY", iterations=50)
speedup_large = benchmark_comparison(large_arr, "LARGE ARRAY", iterations=10)

print("\n" + "=" * 70)
print("📈 PERFORMANCE SUMMARY")
print("=" * 70)
print(f"Small (1K):     {speedup_small:6.2f}x speedup")
print(f"Medium (100K):  {speedup_medium:6.2f}x speedup")
print(f"Large (1M):     {speedup_large:6.2f}x speedup")
print("\n✅ Key Insight: In this benchmark, JIT speedup increases with array size!")

# -----------------------------------------------------------------------------
# UNDERSTANDING ASYNCHRONOUS EXECUTION
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("🔄 UNDERSTANDING .block_until_ready()")
print("=" * 70)
print("""
JAX executes operations ASYNCHRONOUSLY by default for maximum performance.

What this means:
1. JAX queues operations and returns control to Python immediately
2. Actual computation happens in the background on the accelerator
3. Without .block_until_ready(), you'd measure queue time, not compute time!

Example:
    result = collatz(arr)           # Returns instantly, not done yet!
    print(result)                   # NOW it blocks to get the value

For accurate timing, ALWAYS use .block_until_ready() after the computation:
    result = collatz(arr).block_until_ready()  # ✅ Waits for completion
""")

# -----------------------------------------------------------------------------
# UNDERSTANDING JAXPR - JAX's Intermediate Representation
# -----------------------------------------------------------------------------
print("=" * 70)
print("🔍 JAXPR - JAX's Intermediate Representation")
print("=" * 70)
print("""
A jaxpr is JAX's intermediate representation: the list of primitive operations
recorded during tracing, which JAX then lowers for XLA to compile into machine
code. Think of it as a peek under the hood!
""")

sample_arr = jnp.arange(1, 11)
print(f"\nJAXPR for collatz_with_jit with input shape {sample_arr.shape}:")
print("-" * 70)
print(jax.make_jaxpr(collatz_with_jit)(sample_arr))
print()

print("What you see:")
print("  • Input parameters (a:i32[10]) - integer array with 10 elements")
print("  • Primitive operations such as rem, eq, select_n, mul, add (exact output varies by JAX version)")
print("  • Data flow through the computation")
print("  • This gets sent to XLA for optimization and compilation!\n")

# -----------------------------------------------------------------------------
# EXAMPLE 2: Why Python if/else Fails with JIT
# -----------------------------------------------------------------------------
print("=" * 70)
print("⚠️  DEMONSTRATION: Why Python Control Flow Breaks JIT")
print("=" * 70)

# ❌ INCORRECT: Using Python if/else with array values
@jax.jit
def broken_conditional(x):
    """
    This raises an error when JIT-compiled!
    During tracing, JAX doesn't know x's value, only its shape/type,
    so `x > 0` cannot be converted to a Python bool.
    """
    if x > 0:  # ❌ Tries to convert an abstract tracer to a bool
        return x * 2
    else:
        return x * 3

print("\n❌ Testing broken_conditional with Python if/else:")
try:
    result_pos = broken_conditional(jnp.array(5.0))
    print(f"   broken_conditional(5.0)  = {result_pos}")

    result_neg = broken_conditional(jnp.array(-5.0))
    print(f"   broken_conditional(-5.0) = {result_neg}")

    # Only reached if tracing was skipped, e.g. under jax.disable_jit()
    print(f"\n⚠️  No error was raised: is JIT disabled?")
    print(f"   Got {result_pos} and {result_neg}")
except Exception as e:
    print(f"   Error: {type(e).__name__}: {e}")
    print(f"   This happens because JAX can't determine the branch during tracing!")

# ✅ CORRECT: Using JAX-compatible control flow
@jax.jit
def correct_conditional(x):
    """
    Correct version using jnp.where().
    Evaluates BOTH branches and selects based on condition.
    Works with JIT because no Python control flow is needed!
    """
    return jnp.where(x > 0, x * 2, x * 3)

print("\n✅ Testing correct_conditional with jnp.where():")
result_pos = correct_conditional(jnp.array(5.0))
result_neg = correct_conditional(jnp.array(-5.0))
print(f"   correct_conditional(5.0)  = {result_pos}  ✓")
print(f"   correct_conditional(-5.0) = {result_neg} ✓")
print(f"   Both results are CORRECT!")

# Side-by-side comparison
print("\n" + "=" * 70)
print("COMPARISON: if/else vs jnp.where()")
print("=" * 70)
test_values = jnp.array([5.0, -3.0, 2.0, -8.0, 0.0])
correct_results = correct_conditional(test_values)
print(f"Input values:     {test_values}")
print(f"jnp.where() results: {correct_results}")
print(f"Expected (x>0 ? 2x : 3x): [10.0, -9.0, 4.0, -24.0, 0.0]")
print(f"Match: {jnp.allclose(correct_results, jnp.array([10.0, -9.0, 4.0, -24.0, 0.0]))}")

# -----------------------------------------------------------------------------
# EXAMPLE 3: Side Effects in JIT - Print Statements
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("🖨️  DEMONSTRATION: Side Effects Only Happen During Tracing")
print("=" * 70)

@jax.jit
def function_with_print(x):
    """
    Print statements only execute while tracing (the first call for a given
    input shape/dtype). Later calls use cached compiled code without prints!
    """
    print(f"   🔍 TRACING: Inside function with x = {x}")
    return x * 2

print("\n1️⃣  First call (triggers tracing and compilation):")
result1 = function_with_print(jnp.array(10.0))
print(f"   Result: {result1}\n")

print("2️⃣  Second call (uses cached compiled version):")
result2 = function_with_print(jnp.array(20.0))
print(f"   Result: {result2}")
print(f"   👆 Notice: The print INSIDE the function didn't execute!\n")

print("3️⃣  Third call (still using cached version):")
result3 = function_with_print(jnp.array(30.0))
print(f"   Result: {result3}")
print(f"   👆 Still no print - using cached compilation\n")

print("💡 Key Point: Side effects (print, global vars, I/O) only happen while tracing!")

# -----------------------------------------------------------------------------
# EXAMPLE 4: When NOT to Use JIT - Compilation Overhead
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("⚖️  DEMONSTRATION: JIT Overhead vs Benefit")
print("=" * 70)

def simple_add_no_jit(x):
    """Simple addition without JIT."""
    return x + 1

@jax.jit
def simple_add_with_jit(x):
    """Simple addition with JIT."""
    return x + 1

# SMALL COMPUTATION - JIT overhead dominates
print("\n📉 TINY ARRAYS (3 elements, 1000 iterations):")
small_arr = jnp.array([1.0, 2.0, 3.0])

start = time.time()
for _ in range(1000):
    _ = simple_add_no_jit(small_arr).block_until_ready()  # Block so we time compute, not dispatch
time_no_jit = time.time() - start

_ = simple_add_with_jit(small_arr).block_until_ready()  # Warm up

start = time.time()
for _ in range(1000):
    _ = simple_add_with_jit(small_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"  WITHOUT JIT: {time_no_jit:.6f} seconds")
print(f"  WITH JIT:    {time_with_jit:.6f} seconds")
speedup = time_no_jit / time_with_jit
print(f"  Speedup:     {speedup:.2f}x")
if speedup < 1.5:
    print(f"  ⚠️  JIT overhead isn't worth it for tiny computations!")

# LARGE COMPUTATION - JIT benefit is clear
print("\n📈 LARGE ARRAYS (1M elements, 100 iterations):")
large_arr = jnp.arange(1000000.0)

start = time.time()
for _ in range(100):
    _ = simple_add_no_jit(large_arr).block_until_ready()  # Block so we time compute, not dispatch
time_no_jit = time.time() - start

_ = simple_add_with_jit(large_arr).block_until_ready()  # Warm up

start = time.time()
for _ in range(100):
    _ = simple_add_with_jit(large_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"  WITHOUT JIT: {time_no_jit:.6f} seconds")
print(f"  WITH JIT:    {time_with_jit:.6f} seconds")
speedup = time_no_jit / time_with_jit
print(f"  Speedup:     {speedup:.2f}x")
if speedup >= 1.5:
    print(f"  ✅ JIT provides a clear benefit at this size!")
else:
    print(f"  ⚠️  Little benefit here: a single x + 1 is already one compiled op.")

# -----------------------------------------------------------------------------
# FINAL SUMMARY
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("🎯 KEY TAKEAWAYS - JIT COMPILATION")
print("=" * 70)
print("""
1. ✅ JIT gave roughly 3x-15x speedup in these benchmarks; gains vary with hardware and workload
2. ✅ Use jnp.where() for conditionals, NOT Python if/else with array values
3. ✅ Pure functions work best (no side effects like print, globals, I/O)
4. ✅ Warm up JIT before timing (first call includes compilation overhead)
5. ✅ Use .block_until_ready() for accurate timing (JAX is async by default)
6. ✅ JIT benefit generally increases with array size and computation complexity
7. ✅ Inspect with jax.make_jaxpr() to see the traced representation
8. ❌ Python control flow (if/else) that depends on array VALUES breaks JIT
9. ❌ Side effects (print, globals, I/O) only happen during tracing
10. ❌ Don't JIT tiny functions - compilation overhead isn't worth it

💡 Best Use Cases: Matrix operations, neural networks, scientific simulations,
   repeated computations on large arrays
""")

PERFORMANCE COMPARISON: WITH JIT vs WITHOUT JIT

🔧 Warming up JIT compiler (first call triggers compilation)...
✅ JIT compilation complete! Compiled code is now cached.


📊 SMALL ARRAY - 1,000 elements, 100 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.157027 seconds
WITH JIT:            0.057830 seconds
SPEEDUP:                 2.72x faster
✅ Results match! First 5 values: [ 4  1 10  2 16]

📊 MEDIUM ARRAY - 100,000 elements, 50 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.175127 seconds
WITH JIT:            0.037560 seconds
SPEEDUP:                 4.66x faster
✅ Results match! First 5 values: [ 4  1 10  2 16]

📊 LARGE ARRAY - 1,000,000 elements, 10 iterations:
----------------------------------------------------------------------
WITHOUT JIT:         0.200553 seconds
WITH JIT:            0.013129 seconds
SPEEDUP:                15.28x faster
✅ Results match! First 5 values: [ 4  1 10  2 16]

## Quick Reference: JIT-Compatible Control Flow

When you need conditionals in JIT-compiled functions, use these JAX operations:

| Scenario | ❌ Don't Use | ✅ Use Instead |
|----------|--------------|----------------|
| Element-wise conditional | `if x > 0: ...` | `jnp.where(x > 0, true_val, false_val)` |
| Scalar conditional | `if x > 0: ...` | `jax.lax.cond(x > 0, true_fn, false_fn, operand)` |
| Multiple branches | `if/elif/else` | `jax.lax.switch(index, branches, operand)` |
| Dynamic loops | `for i in range(int(x)): ...` | `jax.lax.fori_loop(start, end, body_fn, init)` |
| While loops | `while condition: ...` | `jax.lax.while_loop(cond_fn, body_fn, init)` |
| Array updates | `arr[i] = val` | `arr.at[i].set(val)` |

**Why?** During JIT tracing, JAX works with abstract values (shapes/types), not actual data. Python control flow needs concrete values, which aren't available during tracing. JAX's control flow ops are designed to work with abstract values!
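To make the loop and multi-branch rows of the table concrete, here is a small sketch of `jax.lax.switch`, `jax.lax.fori_loop`, and `jax.lax.while_loop` inside jitted functions. The function names are illustrative, and the explicit `int32` conversions are a choice made here to keep the loop carry types consistent.

```python
import jax
import jax.numpy as jnp

@jax.jit
def pick_scale(index, x):
    # lax.switch(index, branches, *operands): runs branches[index] at run time
    branches = [lambda v: v * 1.0, lambda v: v * 2.0, lambda v: v * 3.0]
    return jax.lax.switch(index, branches, x)

@jax.jit
def sum_below(n):
    # fori_loop(lower, upper, body_fun, init_val); body_fun takes (i, carry)
    return jax.lax.fori_loop(0, n, lambda i, acc: acc + i, jnp.int32(0))

@jax.jit
def collatz_steps(n):
    # while_loop(cond_fun, body_fun, init_val); the carry is (value, steps)
    def cond_fn(carry):
        value, _ = carry
        return value != 1

    def body_fn(carry):
        value, steps = carry
        value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return value, steps + 1

    start = jnp.asarray(n, dtype=jnp.int32)
    _, steps = jax.lax.while_loop(cond_fn, body_fn, (start, jnp.int32(0)))
    return steps

print(pick_scale(1, jnp.array(5.0)))  # 10.0
print(sum_below(5))                   # 10  (0 + 1 + 2 + 3 + 4)
print(collatz_steps(6))               # 8   (6, 3, 10, 5, 16, 8, 4, 2, 1)
```

In each case the loop bound or branch index is a traced value, so a single compiled function handles every input without retracing.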