This notebook covers:
- The relationship between JAX and NumPy
- How JAX arrays behave compared to NumPy arrays
- The concept of immutability in JAX
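- Basic JAX operations and differences from standard NumPy
- Explicit PRNG keys, key splitting, and vectorizing with `jax.vmap`
- JIT compilation and asynchronous dispatch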


#### JAX looks like NumPy


In [None]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import random

# JAX arrays look like NumPy arrays
x_np = np.arange(10)
x_jax = jnp.arange(10)

print(f"NumPy: {type(x_np)}")
print(f"JAX:   {type(x_jax)}")

# Most NumPy math functions have a jnp counterpart with the same name and behavior
print(np.sin(x_np))
print(jnp.sin(x_jax))
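The cell above glosses over one difference worth knowing early: default precision. The sketch below, assuming JAX's default configuration (64-bit mode off), compares output dtypes and shows how arrays move between the two libraries; the variable names are illustrative.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.arange(10)
x_jax = jnp.arange(10)

# Same function name, different default precision: NumPy promotes the
# integer input to float64, while JAX uses float32 unless 64-bit mode
# (the `jax_enable_x64` config flag) is turned on.
print(np.sin(x_np).dtype)    # float64
print(jnp.sin(x_jax).dtype)  # float32 with the default config

# jnp functions also accept NumPy arrays and return jax.Array values.
y = jnp.sin(x_np)
print(isinstance(y, jax.Array))  # True

# np.asarray converts a JAX array back into a NumPy ndarray.
y_np = np.asarray(y)
print(type(y_np))  # <class 'numpy.ndarray'>
```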

#### JAX is immutable

In [None]:
# --- STANDARD NUMPY (Mutable) ---
x_np[0] = 99
print(f"Numpy mutation works: {x_np}")

# --- JAX (Immutable) ---
try:
    x_jax[0] = 99  # This raises a TypeError
except TypeError as e:
    print(f"\nError caught: {e}")

# The JAX way: `.at[...]` returns a NEW array with the update applied
y_jax = x_jax.at[0].set(99)

print(f"Original JAX array (unchanged): {x_jax}")
print(f"New JAX array (updated):        {y_jax}")
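`.at[...]` offers more than `.set()`. A minimal sketch of two other update methods (`add`, `multiply`) and of slice indexing; note how repeated indices behave differently from NumPy's fancy-index `+=`. Variable names are illustrative.

```python
import numpy as np
import jax.numpy as jnp

x = jnp.arange(5)

# Each .at[...] call returns a new array; x itself never changes.
added = x.at[1].add(10)          # [ 0 11  2  3  4]
scaled = x.at[2:].multiply(2)    # [0 1 4 6 8]
strided = x.at[::2].set(-1)      # [-1  1 -1  3 -1]

# Repeated indices: JAX applies every update, so x[0] becomes 3 ...
idx = jnp.array([0, 0, 0])
jax_acc = x.at[idx].add(1)       # [3 1 2 3 4]

# ... whereas NumPy's fancy-index += applies the update only once.
x_np = np.arange(5)
x_np[np.array([0, 0, 0])] += 1   # [1 1 2 3 4]

print(added, scaled, strided, jax_acc, x_np, sep="\n")
print(x)  # the original is still [0 1 2 3 4]
```

According to the JAX documentation, inside `jax.jit` the compiler can often perform such updates in place when the original array is not used afterwards, so the functional style does not necessarily cost a copy.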

#### Aliasing and `+=`: NumPy vs JAX

In [None]:
print("NumPy is mutable:")
numpy_array = np.array([10, 20])
numpy_array_new = numpy_array
numpy_array_new += 10
print(numpy_array_new)  # `numpy_array_new is numpy_array`, and it was updated
print(numpy_array)      # in place, so both are [20, 30]!


print("\nJAX is immutable:")
jax_array = jnp.array([10, 20])
jax_array_new = jax_array
jax_array_new += 10
print(jax_array_new)  # `jax_array_new` is rebound to a new array [20, 30], but...
print(jax_array)      # the original is unchanged: still [10, 20]!
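A quick way to confirm the difference is an identity check with `is`. This sketch repeats the two cases above with illustrative names:

```python
import numpy as np
import jax.numpy as jnp

# NumPy: += mutates the existing object, so the alias still points to it.
a_np = np.array([10, 20])
b_np = a_np
b_np += 10
print(b_np is a_np)  # True

# JAX: += builds a new array and rebinds the name, breaking the alias.
a = jnp.array([10, 20])
b = a
b += 10
print(b is a)        # False
print(a, b)          # [10 20] [20 30]
```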


#### Random numbers in JAX

In [None]:
seed = 0
key = jax.random.PRNGKey(seed)

print(f"Initial Key: {key}")

# Generating random numbers
# Notice: If we re-use 'key', we get the SAME numbers.
r1 = jax.random.normal(key, shape=(3,))
r2 = jax.random.normal(key, shape=(3,))

print(f"Random 1: {r1}")
print(f"Random 2: {r2}")
print("They are identical: the same key always produces the same numbers, so results are reproducible.")
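This notebook creates keys with `jax.random.PRNGKey` here and with `jax.random.key` further down. A minimal sketch of how the two relate, assuming the default PRNG implementation (threefry) and a JAX version that provides `jax.random.key` and `jax.random.key_data`:

```python
import numpy as np
import jax
import jax.numpy as jnp

# Legacy style (used above): a raw uint32 array of shape (2,).
legacy_key = jax.random.PRNGKey(0)
print(legacy_key.shape, legacy_key.dtype)  # (2,) uint32

# Newer typed-key style: a scalar array with a special key dtype.
typed_key = jax.random.key(0)
print(typed_key.shape)  # ()

# Both wrap the same underlying bits for the default implementation...
print(jax.random.key_data(typed_key))  # [0 0]

# ...so they produce the same samples.
print(jax.random.normal(legacy_key, (3,)))
print(jax.random.normal(typed_key, (3,)))
```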

#### Split the key

In [None]:
# Split the key into:
# 1. A key to use NOW (sub_key)
# 2. A key to save for LATER (key, which replaces the old, already-used key)
key, sub_key = jax.random.split(key)

r3 = jax.random.normal(sub_key, shape=(3,))
print(f"New Random 3: {r3}")
print("r3 differs from r1 and r2 because it was drawn with a fresh key.")
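In practice the split happens every time fresh randomness is needed, for example inside a loop. A minimal sketch of that pattern (loop length and shapes are arbitrary):

```python
import numpy as np
import jax

key = jax.random.key(0)
draws = []
for _ in range(3):
    # Carry `key` forward and consume `sub_key` exactly once.
    key, sub_key = jax.random.split(key)
    draws.append(jax.random.normal(sub_key, shape=(2,)))

for d in draws:
    print(d)

# split is itself deterministic: the same parent always yields the same children.
k1 = jax.random.split(jax.random.key(0))
k2 = jax.random.split(jax.random.key(0))
print(np.array_equal(jax.random.key_data(k1), jax.random.key_data(k2)))  # True
```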

#### Vectorized random numbers

In [None]:
key = random.key(42)
subkeys = random.split(key, 3)
sequence = np.stack([random.normal(subkey) for subkey in subkeys])
print("individually:", sequence)

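key = random.key(42)
# These values differ from the loop above: JAX does not guarantee that one
# shape=(3,) draw matches three separate draws ("sequential equivalence").
print("all at once: ", random.normal(key, shape=(3,)))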

In [None]:
# Same values as the loop, produced by one batched call over all subkeys
print("vectorized:", jax.vmap(random.normal)(subkeys))
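To check that the loop and `jax.vmap` agree, and to map a sampler that also takes a shape, a short sketch (the lambda wrapper is just one way to fix the `shape` argument):

```python
import numpy as np
import jax
from jax import random

key = random.key(42)
subkeys = random.split(key, 3)

# Loop: one scalar sample per subkey.
looped = np.stack([random.normal(k) for k in subkeys])

# vmap: the same computation, mapped over the leading axis of `subkeys`.
vectorized = jax.vmap(random.normal)(subkeys)
print(np.allclose(looped, vectorized))  # True

# Fixing extra arguments with a lambda: two samples per key.
per_key = jax.vmap(lambda k: random.normal(k, shape=(2,)))(subkeys)
print(per_key.shape)  # (3, 2)
```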

#### JIT compilation

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x_large = jax.random.normal(key, (1_000_000,))
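To see what `jax.jit` actually compiles, `jax.make_jaxpr` prints the traced program (a "jaxpr") of primitive operations. A short sketch; `selu` is redefined so the block runs on its own, and the exact printed form varies between JAX versions:

```python
import jax
import jax.numpy as jnp

def selu(x, alpha=1.67, lmbda=1.05):
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

# Trace `selu` for a float32 input of shape (3,) and print the primitive ops
# (comparison, exp, multiply, select, ...) that jax.jit would hand to XLA.
closed = jax.make_jaxpr(selu)(jnp.ones(3))
print(closed)
```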

In [None]:
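# 1. Eager ("interpreted") mode
# Each jnp operation is dispatched to the device separately, one at a time.
interp_time = %timeit -o selu(x_large).block_until_ready()

# 2. JIT compiled
# jax.jit traces the whole function and compiles it as one XLA computation,
# which lets XLA fuse the element-wise operations instead of running them one by one.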
selu_jit = jax.jit(selu)

# Warm-up call (compilation happens here)
selu_jit(x_large).block_until_ready()

# Benchmark the compiled version
jit_time = %timeit -o selu_jit(x_large).block_until_ready()

print(f"\nJIT is {interp_time.average / jit_time.average:.1f}x faster than interpreted mode.")
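Compilation is driven by tracing: JAX runs the Python function once with abstract inputs, then reuses the compiled result for later calls with the same input shapes and dtypes. A sketch that makes this visible by recording a Python side effect, which only happens during tracing (`trace_log` and `selu_counted` are illustrative names):

```python
import jax
import jax.numpy as jnp

trace_log = []  # illustrative helper: records each time JAX traces the function

@jax.jit
def selu_counted(x, alpha=1.67, lmbda=1.05):
    # This append is a Python side effect: it runs while JAX traces the
    # function, not when the compiled version executes.
    trace_log.append(x.shape)
    return lmbda * jnp.where(x > 0, x, alpha * jnp.exp(x) - alpha)

x = jnp.linspace(-2.0, 2.0, 5)
selu_counted(x).block_until_ready()            # first call: trace + compile
selu_counted(x).block_until_ready()            # same shape/dtype: cached, no retrace
print(trace_log)                               # [(5,)]
selu_counted(jnp.ones(7)).block_until_ready()  # new shape: traced again
print(trace_log)                               # [(5,), (7,)]
```

This is also why a warm-up call is needed before benchmarking: the first call with each new input shape pays the tracing and compilation cost.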

### ⚡ Crucial Concept: Asynchronous Dispatch

**Why did we use `.block_until_ready()`?**

If you are coming from NumPy, this behavior might be surprising. NumPy is **Synchronous**: when you calculate `np.dot(a, b)`, Python waits until the math is finished before moving to the next line.

JAX is **Asynchronous**.
1.  **The Host (Python):** Passes the instruction ("compute dot product") to the accelerator (GPU/TPU).
2.  **The Future:** JAX immediately returns a `jax.Array` to Python before its contents are computed. It behaves like a placeholder (a "Future"): your script keeps running *while* the GPU works in the background, and only waits when it actually needs the values.
3.  **The Benefit:** Python can queue up many operations ahead of the accelerator instead of waiting for each one to finish, so the device stays busy rather than idling on Python overhead.
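The returned `jax.Array` can be passed to further JAX operations right away; Python only waits when the values themselves are needed on the host. A minimal sketch (matrix size chosen arbitrarily):

```python
import numpy as np
import jax
import jax.numpy as jnp

x = jnp.ones((1000, 1000))

# The matmul is dispatched and `y` comes back as a jax.Array,
# possibly before the device has finished computing it.
y = x @ x
print(isinstance(y, jax.Array))  # True

# Anything that needs concrete values on the host waits for them:
y_host = np.asarray(y)   # blocks, then copies device -> host
corner = float(y[0, 0])  # also blocks
print(corner)            # 1000.0
```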

#### ☕ The Coffee Shop Analogy
* **Synchronous (NumPy):** You order coffee. You stand at the register and wait for the barista to brew it. You don't move until you have the cup. (Slow flow).
* **Asynchronous (JAX):** You order coffee. The cashier hands you a receipt immediately. You go sit down and check your email while the barista brews. (Fast flow).
* **`block_until_ready()`:** This is like the Health Inspector standing there with a stopwatch. They force you to wait until the coffee is physically on the counter to verify exactly how long the brewing took.

**⚠️ Benchmarking Warning:**
If you run `%timeit jnp.dot(a, b)` *without* blocking, you are mostly measuring how fast Python can "place the order" (typically microseconds), not how long the math takes (often milliseconds for large matrices).
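A rough sketch of the two measurements using `time.perf_counter` instead of `%timeit`, so it runs outside IPython. The matrix size is arbitrary, and the gap between the two numbers depends on the backend; it is usually largest on a GPU or TPU:

```python
import time
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.key(0))
a = jax.random.normal(key_a, (1000, 1000))
b = jax.random.normal(key_b, (1000, 1000))

jnp.dot(a, b).block_until_ready()  # warm-up, so compilation is not timed

start = time.perf_counter()
out = jnp.dot(a, b)                # returns once the work has been dispatched
dispatch_s = time.perf_counter() - start
out.block_until_ready()            # waits for the result to be computed
total_s = time.perf_counter() - start

print(f"dispatch only:      {dispatch_s * 1e3:.3f} ms")
print(f"dispatch + compute: {total_s * 1e3:.3f} ms")
```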
