In [1]:
# J - JIT compilation - Just In Time
# A - Automatic differentiation
# X - XLA (Accelerated linear algebra)

# JAX as NumPy

import jax
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])
b = jnp.array([4.0, 5.0, 6.0])

print("Array a:", a)
print("Array b:", b)
print("Sum of a and b:", a + b)
print("Dot product of a and b:", jnp.dot(a, b))
print("Element-wise multiplication of a and b:", a * b)
print("Sine of a:", jnp.sin(a))
print("Exponential of b:", jnp.exp(b))
print("Mean of a:", jnp.mean(a))
print("Standard deviation of b:", jnp.std(b))
print("Reshaped a to (3,1):", a.reshape((3, 1)))
print("Transpose of b reshaped to (3,1):", b.reshape((3, 1)).T)
print("Stacked arrays a and b vertically:\n", jnp.vstack((a, b)))
print("Stacked arrays a and b horizontally:\n", jnp.hstack((a, b)))
print("Concatenated arrays a and b:", jnp.concatenate((a, b)))
print("Maximum value in a:", jnp.max(a))
print("Minimum value in b:", jnp.min(b))    
print("Sum of all elements in a:", jnp.sum(a))
print("Cumulative sum of a:", jnp.cumsum(a))
print("Unique elements in b:", jnp.unique(b))
print("Sorted a:", jnp.sort(a))
print("Where a > 2:", jnp.where(a > 2))
import numpy as np
print("Convert JAX array a to NumPy array:", np.array(a))   
print("Convert NumPy array back to JAX array:", jnp.array(np.array(a)))
# This script demonstrates basic usage of JAX as a NumPy replacement.
# It covers array creation, arithmetic operations, mathematical functions,
# statistical functions, reshaping, stacking, concatenation, and conversion
# between JAX arrays and NumPy arrays.
# JAX is designed for high-performance numerical computing and can
# leverage GPU/TPU acceleration, automatic differentiation, and JIT compilation.
# It provides a NumPy-like API for ease of use.
# JAX arrays are immutable, meaning that operations on them return new arrays
# rather than modifying the original arrays in place.
# Example: a[1] = 10.0  (raises a TypeError; use a.at[1].set(10.0), which returns a new array)

Array a: [1. 2. 3.]
Array b: [4. 5. 6.]
Sum of a and b: [5. 7. 9.]
Dot product of a and b: 32.0
Element-wise multiplication of a and b: [ 4. 10. 18.]
Sine of a: [0.84147096 0.9092974  0.14112   ]
Exponential of b: [ 54.598152 148.41316  403.4288  ]
Mean of a: 2.0
Standard deviation of b: 0.8164966
Reshaped a to (3,1): [[1.]
 [2.]
 [3.]]
Transpose of b reshaped to (3,1): [[4. 5. 6.]]
Stacked arrays a and b vertically:
 [[1. 2. 3.]
 [4. 5. 6.]]
Stacked arrays a and b horizontally:
 [1. 2. 3. 4. 5. 6.]
Concatenated arrays a and b: [1. 2. 3. 4. 5. 6.]
Maximum value in a: 3.0
Minimum value in b: 4.0
Sum of all elements in a: 6.0
Cumulative sum of a: [1. 3. 6.]
Unique elements in b: [4. 5. 6.]
Sorted a: [1. 2. 3.]
Where a > 2: (Array([2], dtype=int32),)
Convert JAX array a to NumPy array: [1. 2. 3.]
Convert NumPy array back to JAX array: [1. 2. 3.]
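The immutability note at the end of the cell above can be checked directly. The sketch below (variable names such as `assignment_failed` are illustrative) shows that item assignment raises a `TypeError`, while the `.at[...]` methods return updated copies and leave the original array untouched:

```python
import jax.numpy as jnp

a = jnp.array([1.0, 2.0, 3.0])

# In-place item assignment is rejected because JAX arrays are immutable
try:
    a[1] = 10.0
    assignment_failed = False
except TypeError as err:
    assignment_failed = True
    print("TypeError:", err)

# Functional updates return NEW arrays; `a` itself is never modified
b = a.at[1].set(10.0)  # replace the element at index 1
c = a.at[0].add(5.0)   # add to the element at index 0

print("a:", a)  # a: [1. 2. 3.]
print("b:", b)  # b: [ 1. 10.  3.]
print("c:", c)  # c: [6. 2. 3.]
```

The same `.at[...]` syntax works inside `@jax.jit` functions, as section 5 below shows.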


# JIT Compilation in JAX - Deep Dive

## What is JIT Compilation?

**JIT (Just-In-Time) compilation** is how JAX turns a Python function into a single optimized program. When you decorate a function with `@jax.jit`, JAX:
1. **Traces** the function with abstract values (shapes and dtypes only) to record the sequence of operations it performs
2. **Compiles** that recorded computation to optimized machine code using XLA (Accelerated Linear Algebra)
3. **Caches** the compiled executable, keyed on the input shapes and dtypes (and any static arguments)
4. **Executes** the compiled code, both on the first call and on every later call whose inputs match a cached entry, without re-tracing

Think of it like this: without `jit`, Python dispatches each `jnp` operation to the device (CPU/GPU/TPU) one at a time; with `jit`, the whole function is compiled as one unit, so XLA can optimize across operations and the per-operation Python overhead disappears.
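The trace/compile/cache cycle can be observed with a Python-level counter: because updating the counter is a Python side effect, it runs only while JAX traces. In this sketch (the names `trace_count` and `scaled_sum` are illustrative), a second call with the same shape and dtype reuses the cached executable, while a new shape forces a retrace:

```python
import jax
import jax.numpy as jnp

trace_count = 0  # illustrative counter, incremented only during tracing

@jax.jit
def scaled_sum(x):
    global trace_count
    trace_count += 1         # Python side effect: runs at trace time only
    return jnp.sum(x * 2.0)  # staged into the compiled XLA program

scaled_sum(jnp.ones(3))   # shape (3,), float32: trace + compile
scaled_sum(jnp.zeros(3))  # same shape/dtype: cached, no retrace
scaled_sum(jnp.ones(4))   # new shape (4,): trace + compile again
print("traces:", trace_count)  # traces: 2
```

The global counter is used here only as a probe; in real code this kind of side effect is exactly what section 2 below warns against.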

## Why Use JIT?

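- **Speed**: often much faster for numerical code, especially functions built from many small operations (the actual speedup depends on the workload and hardware)
- **Hardware acceleration**: the same compiled code runs on whichever backend JAX is using (CPU, GPU, or TPU)
- **Optimization**: XLA fuses operations and eliminates redundant computations
- **Parallelization**: XLA exploits parallelism within a device (e.g. vectorized and multithreaded kernels); spreading work across multiple devices requires explicit APIs such as sharding or `jax.pmap`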

## When JIT Compilation FAILS or Behaves Unexpectedly

### ⚠️ 1. DATA-DEPENDENT CONTROL FLOW (The Key Issue!)

**This is the most common JIT pitfall.** You CANNOT use regular Python `if/else` statements that depend on **array values**:

```python
# ❌ THIS RAISES AN ERROR UNDER JIT:
@jax.jit
def bad_function(x):
    if x > 0:  # ❌ Control flow depends on the VALUE of x
        return x * 2
    else:
        return x * 3
```

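**Why?** During tracing, JAX doesn't know the actual value of `x`; it only knows its shape and dtype. Python's `if` needs a concrete `True`/`False`, so instead of picking a branch, JAX raises a `ConcretizationTypeError` (in recent versions, the more specific `TracerBoolConversionError`).

**Solution:** Use JAX's structured control flow operations:
- `jnp.where(condition, true_val, false_val)` - element-wise selection (both results are computed)
- `jax.lax.cond(pred, true_fun, false_fun, *operands)` - choose between two functions based on a scalar predicate
- `jax.lax.switch(index, branches, *operands)` - choose among several functions by integer index
- `jax.lax.select(pred, on_true, on_false)` - lower-level element-wise selection; unlike `jnp.where` it does not broadcast, so `on_true` and `on_false` must have the same shape and dtype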

```python
# ✅ THIS WORKS:
@jax.jit
def good_function(x):
    return jnp.where(x > 0, x * 2, x * 3)  # ✅ JAX-compatible conditional
```
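For scalar predicates, `jax.lax.cond` and `jax.lax.switch` keep the branching inside the compiled program. A minimal sketch (the function names are illustrative): every branch must return the same shape and dtype, and `switch` clamps an out-of-range index into the valid range.

```python
import jax
import jax.numpy as jnp

@jax.jit
def scale_scalar(x):
    # Both branches are traced; at runtime the predicate selects which one runs
    return jax.lax.cond(x > 0, lambda v: v * 2, lambda v: v * 3, x)

@jax.jit
def pick_op(index, x):
    branches = [
        lambda v: v + 1,   # index 0
        lambda v: v * 10,  # index 1
        lambda v: -v,      # index 2 (also used for any index > 2, due to clamping)
    ]
    return jax.lax.switch(index, branches, x)

print(scale_scalar(5.0))   # 10.0
print(scale_scalar(-5.0))  # -15.0
print(pick_op(1, 3.0))     # 30.0
print(pick_op(7, 3.0))     # -3.0
```

`jnp.where` computes both results element-wise, whereas `lax.cond` chooses between whole computations. One caveat: under `jax.vmap` with a batched predicate, `cond` is converted to a select, so both branches then execute.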

### ⚠️ 2. PYTHON SIDE EFFECTS

Side effects are operations that modify state outside the function or interact with the external world:

```python
# ❌ THESE DON'T WORK AS EXPECTED IN JIT:
@jax.jit
def has_side_effects(x):
    print(f"Value is {x}")  # ❌ Print only happens during tracing!
    global counter
    counter += 1  # ❌ Modifying global state
    my_list.append(x)  # ❌ Modifying external data structures
    return x * 2
```

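**Why?** JIT traces the function once per new input signature (shapes and dtypes), then reuses the cached compiled version. Python side effects run only while tracing; they are not part of the compiled program.

**What happens:**
- `print()` runs only when the function is traced (the first call, and again whenever a new shape or dtype forces a retrace), and it prints abstract tracer objects rather than concrete values
- `counter += 1` increments once per trace, not once per call; global values that are *read* inside the function are baked into the compiled code as constants at trace time
- File I/O, database calls, etc. happen at trace time only, not on every call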
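To see runtime values or keep a counter inside a jitted function, the side effect has to become part of the compiled program or part of the function's inputs and outputs. A minimal sketch (the names `double_and_log` and `step` are illustrative): `jax.debug.print` is staged into the compiled code so it fires on every call, and explicit state threading replaces the `global counter` pattern above.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double_and_log(x):
    print("Python print at trace time, x =", x)               # shows a tracer, once per trace
    jax.debug.print("jax.debug.print at runtime, x = {}", x)  # runs on every call
    return x * 2

@jax.jit
def step(counter, x):
    # Pure alternative to `global counter`: state goes in, new state comes out
    return counter + 1, x * 2

double_and_log(jnp.float32(1.0))  # both messages appear (this call traces)
double_and_log(jnp.float32(2.0))  # only the jax.debug.print message appears

counter = jnp.int32(0)
x = jnp.float32(1.0)
for _ in range(3):
    counter, x = step(counter, x)
print(counter, x)  # 3 8.0
```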

### ⚠️ 3. DATA-DEPENDENT LOOPS

```python
# ❌ THIS FAILS:
@jax.jit
def bad_loop(x):
    for i in range(int(x)):  # ❌ Loop count depends on x's value
        x = x + 1
    return x
```

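**Solution:** Use `jax.lax.fori_loop()` for a loop with known start and end bounds (which may be traced values) or `jax.lax.while_loop()` for a data-dependent stopping condition. Loops over a static Python range, such as `for i in range(3)`, are fine: they are simply unrolled during tracing.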
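Applied to `bad_loop` above, a `fori_loop` version might look like the sketch below (the name `add_one_n_times` and the separate `n` argument are illustrative). The loop body receives the iteration index and the current carry and returns the new carry, and the trip count `n` may be a traced value:

```python
import jax
import jax.numpy as jnp

@jax.jit
def add_one_n_times(x, n):
    # body_fun(i, carry) -> new carry; the carry must keep the same shape/dtype
    def body_fun(i, val):
        return val + 1

    # n is traced under jit, so JAX implements this loop as a while_loop internally
    return jax.lax.fori_loop(0, n, body_fun, x)

print(add_one_n_times(jnp.float32(2.0), 3))  # 5.0
```

When both bounds are concrete Python ints, `fori_loop` is implemented with `scan` and supports reverse-mode differentiation; with traced bounds, as here, it becomes a `while_loop`, which does not.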

### ⚠️ 4. SHAPE-CHANGING OPERATIONS

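```python
# ❌ THESE FAIL:
@jax.jit
def dynamic_shape(x):
    if x.sum() > 0:    # ❌ Data-dependent branch...
        return x[:10]  # ...whose branches return arrays of different shapes
    return x

@jax.jit
def keep_positive(x):
    return x[x > 0]  # ❌ Boolean mask: output length depends on x's VALUES
```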

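JIT requires every array shape to be known at compile time, and dynamic shapes break this requirement. Swapping the `if` for `jax.lax.cond` does not help here either, because all branches of `cond` must return arrays of the same shape and dtype.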
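Two common ways around dynamic shapes are sketched below (the function names are illustrative). If the size is known when the function is called, mark it static with `static_argnums` so `x[:k]` has a concrete shape at trace time; otherwise, keep the shape fixed and mask values instead of dropping them:

```python
from functools import partial

import jax
import jax.numpy as jnp

@partial(jax.jit, static_argnums=1)
def first_k(x, k):
    # k is a concrete Python int at trace time, so x[:k] has a known shape.
    # Each new value of k triggers a separate trace and compilation.
    return x[:k]

@jax.jit
def zero_non_positive(x):
    # Fixed-shape alternative to x[x > 0]: keep every slot, zero the rejected ones
    return jnp.where(x > 0, x, 0.0)

x = jnp.array([-1.0, 2.0, -3.0, 4.0])
print(first_k(x, 2))         # [-1.  2.]
print(zero_non_positive(x))  # [0. 2. 0. 4.]
```

When the positions themselves are needed, `jnp.nonzero` accepts a `size` argument that pads its result to a fixed length, which makes it usable under `jit`.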

### ⚠️ 5. IN-PLACE MUTATIONS

```python
# ❌ THIS FAILS:
@jax.jit
def mutate_array(x):
    x[0] = 10  # ❌ JAX arrays are immutable!
    return x
```

**Solution:** Use `.at[].set()` syntax:
```python
# ✅ THIS WORKS:
@jax.jit
def update_array(x):
    return x.at[0].set(10)  # ✅ Returns new array
```

## When to Use JIT

✅ **USE JIT FOR:**
- Pure functions (no side effects)
- Functions operating on JAX arrays
- Numerical computations (matrix operations, neural networks, simulations)
- Functions called repeatedly with similar input shapes
- Performance-critical code

❌ **DON'T USE JIT FOR:**
- Functions that rely on Python `print` for debugging (inside jitted code, use `jax.debug.print` instead)
- Code that modifies global state
- Functions with Python control flow depending on array values
- Small, one-off computations (compilation overhead > benefit)
- Code interacting with external systems (files, databases, APIs)

## Best Practices

1. **Keep functions pure**: Input → Output, no side effects
2. **Use JAX control flow**: `jnp.where()`, `jax.lax.cond()`, etc.
3. **Warm up the JIT**: Run once before timing to avoid compilation overhead
4. **Use `.block_until_ready()`**: JAX executes asynchronously by default
5. **Inspect with `jax.make_jaxpr()`**: See the intermediate representation
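6. **Static arguments**: mark arguments that determine shapes or Python control flow (e.g. a Python `int` length or a mode string) with `static_argnums`/`static_argnames`; they must be hashable, and each new value triggers a recompile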
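Practices 3 and 4 can be combined into a small timing helper. This is a minimal sketch, assuming a jitted function whose output is a JAX array or a pytree of arrays; the helper name `time_jitted` and the example function `poly` are hypothetical:

```python
import time

import jax
import jax.numpy as jnp

def time_jitted(fn, *args, repeats=10):
    """Return the mean seconds per call of fn(*args), excluding compilation."""
    jax.block_until_ready(fn(*args))      # warm-up: trace + compile once
    start = time.perf_counter()
    for _ in range(repeats):
        jax.block_until_ready(fn(*args))  # wait for async dispatch to finish
    return (time.perf_counter() - start) / repeats

@jax.jit
def poly(x):
    return 3.0 * x**3 + 2.0 * x**2 + x + 1.0

x = jnp.linspace(-1.0, 1.0, 1_000_000)
print(f"{time_jitted(poly, x) * 1e3:.3f} ms per call")
```

`time.perf_counter()` is used instead of `time.time()` because it is a monotonic, high-resolution clock intended for measuring short intervals.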

---

In [10]:
# =============================================================================
# JIT COMPILATION - PRACTICAL EXAMPLES
# =============================================================================

import time
import jax
import jax.numpy as jnp

# -----------------------------------------------------------------------------
# EXAMPLE 1: Basic JIT Compilation with Collatz Conjecture
# -----------------------------------------------------------------------------
# The Collatz conjecture: Take any positive integer n. If n is even, divide it
# by 2. If n is odd, multiply it by 3 and add 1. Repeat the process.
# We use jnp.where() instead of if/else because we need JAX-compatible conditionals!

@jax.jit
def collatz(x):
    """
    Compute one step of the Collatz sequence.
    
    Args:
        x: JAX array of integers
    
    Returns:
        Next value(s) in the Collatz sequence
    
    Note: Uses jnp.where() for vectorized conditional logic instead of if/else
    This allows JIT compilation to work correctly!
    """
    return jnp.where(x % 2 == 0, x // 2, 3 * x + 1)

# Create a large array to see the performance benefit
arr = jnp.arange(1, 1000001)

# IMPORTANT: Warm up the JIT compiler
# First call triggers compilation (slow), subsequent calls use cached version (fast)
print("Warming up JIT compiler...")
_ = collatz(arr).block_until_ready()
print("JIT compilation complete! Compiled code is now cached.\n")

# Now measure the actual execution time (excluding compilation)
start = time.time()
result = collatz(arr).block_until_ready()
end = time.time()
print(f"Time taken for JIT-compiled Collatz computation: {end - start:.6f} seconds")
print(f"First 10 results: {result[:10]}")
print(f"Array size: {len(result):,} elements\n")

# -----------------------------------------------------------------------------
# UNDERSTANDING ASYNCHRONOUS EXECUTION
# -----------------------------------------------------------------------------
# JAX executes operations asynchronously by default to maximize performance.
# This means JAX queues operations and returns control to Python immediately,
# while the actual computation happens in the background on the accelerator.
#
# .block_until_ready() forces Python to wait until the computation completes.
# Without it, you'd measure only the time to dispatch the work, not the actual computation time!
#
# Example without block_until_ready():
# result = collatz(arr)  # Returns immediately, computation not done yet!
# print(result)  # NOW it blocks to print, but timing would be wrong

# -----------------------------------------------------------------------------
# UNDERSTANDING JAXPR - JAX's Intermediate Representation
# -----------------------------------------------------------------------------
# JAXPR is like assembly language for JAX - it shows the low-level operations
# that XLA will compile and optimize. It's useful for understanding what JIT
# actually does with your code.

print("=" * 70)
print("JAXPR (JAX's intermediate representation) for the Collatz function:")
print("=" * 70)
print(jax.make_jaxpr(collatz)(arr))
print()

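# The JAXPR shows:
# - Input parameters with their dtypes and shapes (e.g. i32[1000000])
# - Primitive operations (e.g. remainder, comparisons, select_n for jnp.where,
#   division, multiplication, addition), some wrapped in nested pjit calls
# - How data flows through the computation
# This is what gets lowered to XLA and compiled into machine code!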

# -----------------------------------------------------------------------------
# EXAMPLE 2: Demonstrating Why Python if/else Fails
# -----------------------------------------------------------------------------
print("=" * 70)
print("DEMONSTRATION: Why Python control flow breaks JIT")
print("=" * 70)

# ❌ THIS RAISES AN ERROR UNDER JIT - Python's `if` needs a concrete bool!
@jax.jit
def broken_conditional(x):
    """
    This function cannot be JIT-compiled.
    During tracing, x is an abstract tracer (shape and dtype only), so
    `x > 0` is also a tracer. Python's `if` must convert it to True/False,
    which is impossible without a concrete value, so JAX raises an error
    instead of silently picking a branch.
    """
    if x > 0:  # `x > 0` is a traced boolean array, not a Python bool
        return x * 2
    else:
        return x * 3

# Calling the function triggers tracing, which fails at the `if` statement
try:
    result_pos = broken_conditional(jnp.array(5.0))
    print(f"broken_conditional(5.0) = {result_pos}")
except jax.errors.ConcretizationTypeError as e:
    # Recent JAX versions raise TracerBoolConversionError, a subclass of
    # ConcretizationTypeError
    print(f"Error: {type(e).__name__}")
    print("JAX can't convert a traced value to a Python bool during tracing!\n")

# ✅ CORRECT VERSION using jnp.where()
@jax.jit
def correct_conditional(x):
    """
    Correct version using JAX-compatible control flow.
    jnp.where() evaluates BOTH branches and selects based on condition.
    This works with JIT because no Python control flow is needed!
    """
    return jnp.where(x > 0, x * 2, x * 3)

# Test the correct version
print("✅ Correct version using jnp.where():")
result_pos = correct_conditional(jnp.array(5.0))
result_neg = correct_conditional(jnp.array(-5.0))
print(f"correct_conditional(5.0) = {result_pos}")
print(f"correct_conditional(-5.0) = {result_neg}")
print("Both results are correct!\n")

# -----------------------------------------------------------------------------
# EXAMPLE 3: Side Effects in JIT - Print Statements
# -----------------------------------------------------------------------------
print("=" * 70)
print("DEMONSTRATION: Side effects only happen during tracing")
print("=" * 70)

@jax.jit
def function_with_print(x):
    """
    Python print statements run only while JAX traces the function: on the
    first call, and again whenever a new input shape/dtype forces a retrace.
    Calls that hit the cache run the compiled code, which contains no prints!
    """
    print(f"🔍 TRACING: Inside function with x = {x}")
    return x * 2

print("First call (triggers tracing and compilation):")
result1 = function_with_print(jnp.array(10.0))
print(f"Result: {result1}\n")

print("Second call (uses cached compiled version):")
result2 = function_with_print(jnp.array(20.0))
print(f"Result: {result2}")
print("👆 Notice: The print inside the function didn't execute!\n")

print("Third call with same shape:")
result3 = function_with_print(jnp.array(30.0))
print(f"Result: {result3}")
print("👆 Still no print - using cached compilation\n")

# -----------------------------------------------------------------------------
# EXAMPLE 4: When NOT to use JIT (Compilation Overhead)
# -----------------------------------------------------------------------------
print("=" * 70)
print("DEMONSTRATION: JIT overhead vs benefit")
print("=" * 70)

def simple_add_no_jit(x):
    return x + 1

@jax.jit
def simple_add_with_jit(x):
    return x + 1

# Small array: both versions are dominated by per-call dispatch overhead
small_arr = jnp.array([1.0, 2.0, 3.0])

start = time.time()
for _ in range(1000):
    _ = simple_add_no_jit(small_arr).block_until_ready()  # Wait for results, as in the JIT loop
time_no_jit = time.time() - start

# Warm up JIT
_ = simple_add_with_jit(small_arr).block_until_ready()

start = time.time()
for _ in range(1000):
    _ = simple_add_with_jit(small_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"Small array (3 elements), 1000 iterations:")
print(f"  Without JIT: {time_no_jit:.6f} seconds")
print(f"  With JIT:    {time_with_jit:.6f} seconds")
print(f"  Speedup:     {time_no_jit/time_with_jit:.2f}x")
print()

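# Large array: per-call overhead is amortized over much more work.
# Note: a single `x + 1` gives XLA nothing to fuse (jnp ops already run as
# compiled kernels), so expect bigger JIT gains when many ops can be fused.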
large_arr = jnp.arange(1000000.0)

start = time.time()
for _ in range(100):
    _ = simple_add_no_jit(large_arr).block_until_ready()  # Fair comparison: wait for results
time_no_jit = time.time() - start

# Warm up JIT
_ = simple_add_with_jit(large_arr).block_until_ready()

start = time.time()
for _ in range(100):
    _ = simple_add_with_jit(large_arr).block_until_ready()
time_with_jit = time.time() - start

print(f"Large array (1M elements), 100 iterations:")
print(f"  Without JIT: {time_no_jit:.6f} seconds")
print(f"  With JIT:    {time_with_jit:.6f} seconds")
print(f"  Speedup:     {time_no_jit/time_with_jit:.2f}x")
print("\n✅ Key Takeaway: JIT pays off when a function does enough work (ideally many fusible ops) to outweigh its overhead!")

# -----------------------------------------------------------------------------
# SUMMARY OF JIT COMPILATION
# -----------------------------------------------------------------------------
print("\n" + "=" * 70)
print("KEY TAKEAWAYS")
print("=" * 70)
print("""
1. ✅ USE jnp.where() / jax.lax.cond() for conditionals, NOT Python if/else on array values
2. ✅ Pure functions work best (no side effects like print, global variables)
3. ✅ Warm up JIT before timing (first call includes compilation overhead)
4. ✅ Use .block_until_ready() for accurate timing (JAX is async by default)
5. ✅ JIT is best for large computations called repeatedly
6. ✅ Inspect JAXPR with jax.make_jaxpr() to understand what gets compiled
7. ❌ Avoid Python control flow that depends on array VALUES
8. ❌ Side effects (print, globals, I/O) only happen during tracing
9. ❌ Don't JIT tiny one-off computations - compilation overhead isn't worth it
""")

Warming up JIT compiler...
JIT compilation complete! Compiled code is now cached.

Time taken for JIT-compiled Collatz computation: 0.001199 seconds
First 10 results: [ 4  1 10  2 16  3 22  4 28  5]
Array size: 1,000,000 elements

JAXPR (JAX's intermediate representation) for the Collatz function:
let _where = { lambda ; a:bool[1000000] b:i32[1000000] c:i32[1000000]. let
    d:i32[1000000] = select_n a c b
  in (d,) } in
{ lambda ; e:i32[1000000]. let
    f:i32[1000000] = pjit[
      name=collatz
      jaxpr={ lambda ; e:i32[1000000]. let
          g:i32[1000000] = pjit[
            name=remainder
            jaxpr={ lambda ; e:i32[1000000] h:i32[]. let
                i:i32[] = convert_element_type[new_dtype=int32 weak_type=False] h

## Quick Reference: JIT-Compatible Control Flow

When you need conditionals in JIT-compiled functions, use these JAX operations:

| Scenario | ❌ Don't Use | ✅ Use Instead |
|----------|--------------|----------------|
| Element-wise conditional | `if x > 0: ...` | `jnp.where(x > 0, true_val, false_val)` |
| Scalar conditional | `if x > 0: ...` | `jax.lax.cond(x > 0, true_fn, false_fn, operand)` |
| Multiple branches | `if/elif/else` | `jax.lax.switch(index, branches, operand)` |
| Dynamic loops | `for i in range(int(x)): ...` | `jax.lax.fori_loop(start, end, body_fn, init)` |
| While loops | `while condition: ...` | `jax.lax.while_loop(cond_fn, body_fn, init)` |
| Array updates | `arr[i] = val` | `arr.at[i].set(val)` |

**Why?** During JIT tracing, JAX works with abstract values (shapes/types), not actual data. Python control flow needs concrete values, which aren't available during tracing. JAX's control flow ops are designed to work with abstract values!
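As a worked example of the `while_loop` row, the sketch below counts how many Collatz steps it takes to reach 1, reusing the same `jnp.where` update as the `collatz` function earlier. The function name `collatz_steps` is illustrative, and it assumes a positive integer input (for `n <= 0` the loop would never terminate). The carry is a `(value, steps)` tuple with fixed int32 dtypes so the loop state keeps a consistent type:

```python
import jax
import jax.numpy as jnp

@jax.jit
def collatz_steps(n):
    """Count Collatz steps from n down to 1 (assumes n is a positive integer)."""

    def cond_fn(state):
        value, _ = state
        return value != 1  # keep looping until we reach 1

    def body_fn(state):
        value, steps = state
        next_value = jnp.where(value % 2 == 0, value // 2, 3 * value + 1)
        return next_value, steps + 1  # same structure and dtypes as init

    init = (jnp.asarray(n, dtype=jnp.int32), jnp.int32(0))
    _, steps = jax.lax.while_loop(cond_fn, body_fn, init)
    return steps

print(collatz_steps(6))   # 8
print(collatz_steps(27))  # 111
```

Because `while_loop` supports batching, `jax.vmap(collatz_steps)` can process a whole array of starting values; the batched loop keeps iterating until every element has reached 1.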