In [None]:
import jax.numpy as jnp
import numpy as np

from jax import grad, jit, vmap, pmap

from jax import lax  # JAX's low level API, just anagram for XLA

from jax import make_jaxpr
from jax import random
from jax import device_put
import matplotlib.pyplot as plt

In [None]:
# NumPy baseline: evaluate 2*sin(x)*cos(x) on 1000 points in [0, 10] and plot it.
x_np = np.linspace(0, 10, 1000)
y_np = 2 * np.sin(x_np) * np.cos(x_np)
plt.plot(x_np, y_np)

In [None]:
# Same computation with jax.numpy: the API mirrors NumPy, and the JAX arrays are passed straight to matplotlib.
x_jnp = jnp.linspace(0, 10, 1000)
y_jnp = 2 * jnp.sin(x_jnp) * jnp.cos(x_jnp)
plt.plot(x_jnp, y_jnp)

In [None]:
# NumPy arrays are mutable: in-place item assignment works and modifies x.
# (The same assignment on a JAX array in the next cell raises a TypeError instead.)
size = 10
index = 0
value = 23

x = np.arange(size)
print(x)
x[index] = value
print(x)

In [None]:
# JAX arrays are immutable: in-place item assignment raises
# "TypeError: JAX arrays are immutable and do not support in-place item assignment.
# Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method".
# The error is caught and printed so that "Restart & Run All" does not stop here.
size = 10
index = 0
value = 23

x = jnp.arange(size)  # immutable arrays
print(x)
try:
    x[index] = value
except TypeError as e:
    print(f"TypeError: {e}")
print(x)  # unchanged

In [None]:
# Functional update: .at[index].set(value) returns a NEW array; x itself is left unchanged.
y = x.at[index].set(value)
print(x)
print(y)

In [None]:
seed = 0
key = random.PRNGKey(seed) # Create a legacy PRNG key given an integer seed.

# JAX random functions take the key explicitly instead of using a hidden global state.
x = random.normal(key, (10,))
print(type(x), x)

In [None]:
seed = 0
key = random.key(seed) # New-style typed PRNG key; recommended over random.PRNGKey.

x = random.normal(key, (10,))
print(type(x), x)

In [None]:
size = 3000

# JAX creates arrays on its default device (an accelerator when one is available),
# so there is no explicit ".to(device)" step as in PyTorch.
x_jnp = random.normal(key, (size, size), dtype=jnp.float32)
x_np = np.random.normal(size=(size, size)).astype(np.float32)

# JAX dispatches work asynchronously; block_until_ready() waits for the result so %timeit measures the compute.
%timeit jnp.dot(x_jnp, x_jnp.T).block_until_ready() # GPU or TPU
%timeit np.dot(x_np, x_np.T) # CPU
%timeit jnp.dot(x_np, x_np.T).block_until_ready() # GPU or TPU with host-to-device transfer on every call

x_np_device = device_put(x_np) # NumPy array -> default device, transferred once up front
%timeit jnp.dot(x_np_device, x_np_device.T).block_until_ready() # GPU

# block_until_ready() -> asynchronous dispatch

# jit()

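`jit` compiles a function with XLA the first time it is called with a given set of input shapes/dtypes, caches the compiled version, and reuses it on later calls -> speeeeed 🚀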

In [None]:
def visualize_fn(fn, l=-10, r=10, n=1000):
    """Plot ``fn`` sampled on ``n`` evenly spaced points of the interval [l, r].

    Args:
        fn: callable applied to a 1-D NumPy grid; its output is plotted against the grid.
        l: left end of the interval (inclusive).
        r: right end of the interval (inclusive).
        n: number of sample points.
    """
    grid = np.linspace(l, r, num=n)
    values = fn(grid)
    plt.plot(grid, values)
    plt.show()

In [None]:
def selu(x, alpha=1.67, lmbda=1.05):
    """Scaled exponential linear unit (SELU), applied elementwise.

    Returns ``lmbda * x`` where ``x > 0`` and ``lmbda * (alpha * exp(x) - alpha)`` elsewhere.
    The defaults are rounded versions of the SELU constants (alpha ~ 1.6733, lambda ~ 1.0507).
    """
    negative_branch = alpha * jnp.exp(x) - alpha
    return lmbda * jnp.where(x > 0, x, negative_branch)


silu_jit = jit(selu)

visualize_fn(silu_jit)

data = random.normal(key, (1_000_000,))

print("non-jit version:")
%timeit selu(data).block_until_ready()

print("jit version:")
%timeit silu_jit(data).block_until_ready()

# grad()

Differentiation can be:

- manual
- symbolic
- numeric
- automatic! ❤️

In [None]:
# automatic!

def sum_log(x):
    """Sum of the logistic sigmoid over all elements of ``x``.

    Despite the name, no logarithm is involved: this is sum(1 / (1 + exp(-x))).
    """
    sigmoid = 1.0 / (1.0 + jnp.exp(-x))
    return jnp.sum(sigmoid)

x = jnp.arange(3.)
loss = sum_log

# By default grad calculates the derivative of a fn w.r.t. 1st parameter!
# Here we bundled inputs into a 1st param so it doesn't matter.
# grad(loss) is a new function; for this loss each entry is sigmoid(x_i) * (1 - sigmoid(x_i)).
grad_loss = grad(loss)
print(x)
print(grad_loss(x))

In [None]:
# Numeric diff (to double check that autodiff works correctly).
# A finite difference is an expression of the form f(x + b) - f(x + a); here we use the
# central difference (f(x + eps*e_i) - f(x - eps*e_i)) / (2*eps) along each unit vector e_i,
# whose truncation error is O(eps**2).
def finite_differences(f, x, eps=1e-3):
    """Approximate the gradient of a scalar function ``f`` at ``x`` by central differences.

    Args:
        f: function mapping a 1-D array to a scalar.
        x: 1-D array at which the gradient is estimated.
        eps: step size (default 1e-3). Very small steps amplify float32 round-off.

    Returns:
        1-D array with one estimated partial derivative per entry of ``x``.
    """
    basis = jnp.eye(len(x))  # row i is the unit vector e_i
    return jnp.array([(f(x + eps * v) - f(x - eps * v)) / (2 * eps) for v in basis])

# Should closely match grad_loss(x) printed in the previous cell.
print(finite_differences(loss, x))

In [None]:
x = 1.

f = lambda x: x**2 + x + 4
visualize_fn(f, l=-1, r=2, n=100)

# grad composes: each application differentiates the previous function once more.
dfdx = grad(f) # 2*x + 1
d2fdx2 = grad(dfdx) # 2
d3fdx3 = grad(d2fdx2) # 0

print(f"f(x) = {f(x)} -> ", f"dfdx(x) = {dfdx(x)} -> ", f"d2fdx2(x) = {d2fdx2(x)} ->", f"d3fdx3(x) = {d3fdx3(x)}")

In [None]:
# JAX autodiff engine is very powerful ("advanced" example)

from jax import jacfwd, jacrev

f = lambda x, y: x**2 + y**2

# df/dx = 2x
# df/dy = 2y
# J = [df/dx, df/dy]  (for a scalar f the Jacobian is just the gradient)

# d2f/dx2 = 2
# d2f/dy2 = 2
# d2f/dxdy = 0
# d2f/dydx = 0
# H = [[d2f/dx2, d2f/dxdy], [d2f/dydx, d2f/dy2]]

def hessian(f, argnums=(0, 1)):
    """Return a jitted function that computes the Hessian of ``f`` w.r.t. ``argnums``.

    Uses forward-over-reverse differentiation, ``jacfwd(jacrev(f))``.
    With a tuple ``argnums`` the result is a nested tuple where ``H[i][j]`` is the
    derivative w.r.t. ``argnums[j]`` of df/d``argnums[i]``; with an int ``argnums``
    it is a single array.
    """
    return jit(jacfwd(jacrev(f, argnums=argnums), argnums=argnums))

print(f"Jacobian = {jacrev(f, argnums=(0, 1))(1., 1.)}")
print(f"Full Hessian = {hessian(f)(1., 1.)}")

In [None]:
# Edge case |x|, how does JAX handle it?
# |x| is not differentiable at 0, yet grad still returns a number there; compare the
# value JAX picks at x = 0 with the derivative at a nearby point.

f = lambda x: abs(x)
visualize_fn(f)

print(f"f(-1) = {f(-1)}, f(1) = {f(1)}")
dfdx = grad(f)
print(f"dfdx(0.) = {dfdx(0.)}")
print(f"dfdx(0.001) = {dfdx(0.001)}")

# vmap() 101

Write your functions as if you were dealing with a single datapoint!

In [None]:
# NOTE(review): the same `key` is reused for W and batch_x; splitting it (random.split) would give independent draws.
W = random.normal(key, (150, 100))
batch_x = random.normal(key, (10, 100))

def apply_matrix(x):
    return jnp.dot(x, W.T) # (10, 100) @ (100, 150) -> (10, 150)

apply_matrix(batch_x)

In [None]:
def naively_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` to each row of ``batched_x`` in a Python loop, then stack the results."""
    per_example = [apply_matrix(row) for row in batched_x]
    return jnp.stack(per_example)

print("Naively batched")
# One apply_matrix call per example, dispatched from Python.
%timeit naively_batched_apply_matrix(batch_x).block_until_ready()

In [None]:
@jit
def batched_apply_matrix(batched_x):
    """Apply ``W`` to the whole batch at once: (batch, d) @ W.T -> (batch, W.shape[0]).

    Must operate on the ``batched_x`` argument; reading a global ``x`` here would
    silently compute something unrelated to the input batch.
    """
    return jnp.dot(batched_x, W.T)

print("Manually batched")
# batched_apply_matrix is jitted, so the first %timeit loop also includes compilation.
%timeit batched_apply_matrix(batch_x).block_until_ready()

In [None]:
def apply_matrix(x):
    """Per-example version: multiply the global matrix ``W`` by one input vector ``x``."""
    return jnp.dot(W, x)

# Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
@jit
def vmap_batched_apply_matrix(batched_x):
    """Map ``apply_matrix`` over the leading axis of ``batched_x``; compiled with jit."""
    per_example_fn = vmap(apply_matrix)
    return per_example_fn(batched_x)

print("Auto-vectorized")
# vmap adds the batch axis automatically; compare with the naive and manual timings above.
%timeit vmap_batched_apply_matrix(batch_x).block_until_ready()

In [None]:
# Expected (10, 150): one 150-dim output per example.
vmap_batched_apply_matrix(batch_x).block_until_ready().shape

In [None]:
def apply_matrix(x):
    """Per-example version: multiply the global matrix ``W`` by one input vector ``x``."""
    return jnp.dot(W, x)

# Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
@jit
def vmap_batched_apply_matrix(batched_x):
    """Map over axis 0 of the input and stack the results along axis 0 of the output."""
    # `(0)` is just the int 0 (not a tuple), so passing 0 directly is equivalent.
    return vmap(apply_matrix, in_axes=0, out_axes=0)(batched_x)

# in_axes=0 / out_axes=0 are vmap's defaults, so this should time the same as the previous version.
%timeit vmap_batched_apply_matrix(batch_x).block_until_ready()

In [None]:
# Expected (10, 150), same as with the default axes.
vmap_batched_apply_matrix(batch_x).block_until_ready().shape

In [None]:
def apply_matrix(x):
    """Per-example version: multiply the global matrix ``W`` by one input vector ``x``."""
    return jnp.dot(W, x)

# Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
@jit
def vmap_batched_apply_matrix(batched_x):
    """Vectorize ``apply_matrix`` over the leading axis of ``batched_x``."""
    vectorized = vmap(apply_matrix)
    return vectorized(batched_x)

# Expected (10, 150).
vmap_batched_apply_matrix(batch_x).block_until_ready().shape

In [None]:
# (10, 100) -> (1, 10, 100): add a leading axis so the next cell can vmap twice.
jnp.expand_dims(batch_x, 0).shape

In [None]:
def apply_matrix(x):
    """Per-example version: multiply the global matrix ``W`` by one input vector ``x``."""
    return jnp.dot(W, x)

# Note: we can arbitrarily compose JAX transforms! Here jit + vmap.
@jit
def vmap_batched_apply_matrix(batched_x):
    """Apply ``apply_matrix`` over two leading batch axes via nested vmap."""
    inner = vmap(apply_matrix, in_axes=0)  # maps over the examples
    outer = vmap(inner, in_axes=0)         # maps over the extra leading axis
    return outer(batched_x)

# Expected (1, 10, 150): both leading axes are preserved.
vmap_batched_apply_matrix(jnp.expand_dims(batch_x, 0)).block_until_ready().shape

In [None]:
# Example 1: lax is stricter

print(jnp.add(1, 1.0))  # jax.numpy API implicitly promotes mixed types
# jax.lax API requires explicit type promotion, so mixing int and float raises a TypeError.
# The error is caught and printed so that "Restart & Run All" does not stop here.
try:
    print(lax.add(1, 1.0))
except TypeError as e:
    print(f"TypeError: {e}")

In [None]:
# Example 2: lax is more powerful (but as a tradeoff less user-friendly)

x = jnp.array([1, 2, 1])
y = jnp.ones(10)

result1 = jnp.convolve(x, y)

# lax.conv_general_dilated works on (batch, channels, length) arrays, hence the reshapes,
# and requires matching dtypes, hence the explicit cast of the integer x.
result2 = lax.conv_general_dilated(
    x.reshape(1, 1, 3).astype(float), # explicit float
    y.reshape(1, 1, 10),
    window_strides=(1,),
    padding=[(len(y) - 1, len(y) - 1)] # padding='full' numpy (equivalent of np.convolve mode='full')
)

# Same values; result2 keeps the extra batch/channel axes.
# NOTE(review): lax convolutions do not flip the kernel; the results agree here because y is symmetric (all ones).
print(result1)
print(result2)


## How does JIT actually work?

In [None]:
def norm(x):
    """Standardize each column of ``x``: subtract the column mean, divide by the column std.

    Uses the population std (ddof=0). Never modifies ``x``, even when it is a
    mutable NumPy array.
    """
    x = x - x.mean(0)  # out-of-place: `x -= ...` would mutate a NumPy argument in place
    return x / x.std(0)

norm_compiled = jit(norm)

x = random.normal(key, (10_000, 100), dtype=jnp.float32)

# The first call of norm_compiled inside %timeit also pays for tracing + compilation.
%timeit norm(x).block_until_ready()
%timeit norm_compiled(x).block_until_ready()

In [None]:
def get_negative(x):
    """Return only the negative entries of ``x`` (boolean-mask indexing: output size depends on the data)."""
    negative_mask = x < 0
    return x[negative_mask]

x = random.normal(key, (10,), dtype=jnp.float32)
print(get_negative(x))  # eager (no jit): boolean masking works; the output length depends on the data

In [None]:
# Boolean-mask indexing has a data-dependent output shape, which jit cannot trace.
# The error is caught and printed so that "Restart & Run All" does not stop here.
try:
    print(jit(get_negative)(x))
except IndexError as e:  # expected: NonConcreteBooleanIndexError (an IndexError subclass)
    print(f"{type(e).__name__}: {e}")

This error occurs when a program attempts to use non-concrete boolean indices in a traced indexing operation. Under JIT compilation, JAX arrays must have static shapes (i.e. shapes that are known at compile-time) and so boolean masks must be used carefully. Some logic implemented via boolean masking is simply not possible in a jax.jit() function; in other cases, the logic can be re-expressed in a JIT-compatible way, often using the three-argument version of where().

In [None]:
def get_negative(x):
    """jit-compatible variant: keep the negative entries of ``x`` and replace the rest with 0.

    The three-argument ``jnp.where`` keeps the output shape equal to ``x.shape``,
    so it can be traced under jit.
    """
    return jnp.where(x < 0, x, 0)

x = random.normal(key, (10,), dtype=jnp.float32)
# jnp.where keeps the output shape static, so this compiles under jit.
print(jit(get_negative)(x))

In [None]:
# Printing inside a jitted function shows what JAX sees while tracing:
# abstract tracers (shape/dtype only), not the concrete values.
@jit
def f(x, y):
    print("Running f():")
    print(f"x = {x}")
    print(f"y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"result = {result}")
    return result

x = np.random.randn(3, 4)
y = np.random.randn(4)
print(f(x, y))

x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
print(f(x2, y2)) # Same shapes/dtypes -> the cached compiled fn is reused, so the prints inside f don't run again.

# Note: any time we get the same shapes and types we just call the compiled fn!

This is by design. JAX’s goal is to compile pure functions — functions without side effects (like printing, file I/O, modifying global variables).

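If you put a `print`, logging, or any other side effect inside a `@jit` function, it runs only while the function is being traced (the first call for each new combination of input shapes/dtypes or static arguments), not on calls that reuse the cached compiled version.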

**Reusing the cached compiled function makes subsequent calls fast. How much faster than NumPy depends on the workload and hardware, so measure it (e.g. with `%timeit`).**

In [None]:
def f(x, y):
    """Compute ``dot(x + 1, y + 1)``; used below to inspect its jaxpr."""
    x_shifted = x + 1
    y_shifted = y + 1
    return jnp.dot(x_shifted, y_shifted)

# make_jaxpr shows the traced program (jaxpr) for these input shapes/dtypes.
print(make_jaxpr(f)(x, y))

In [None]:
# Python control flow on a traced argument fails under jit: `neg` is an abstract
# tracer, so `if neg` cannot be turned into a concrete bool.
@jit
def f(x, neg):
    return -x if neg else x

# The error is caught and printed so that "Restart & Run All" does not stop here.
try:
    f(1, True)
except TypeError as e:  # expected: a tracer-to-bool concretization error (TypeError subclass)
    print(f"{type(e).__name__}: {e}")

In [None]:
from functools import partial

# static_argnums=(1,) makes `neg` a compile-time constant: the Python `if` works, and
# jit retraces (re-running the print) whenever `neg` takes a value not seen before.
@partial(jit, static_argnums=(1,))
def f(x, neg):
    print(x)
    return -x if neg else x

print(f(1, True))
print(f(2, True))   # same static neg and same x shape/dtype -> cached, the inner print does not run
print(f(2, False))  # new static value -> retrace
print(f(23, False))

In [None]:
@jit
def f(x):
    """Flatten ``x`` to 1-D; fails under jit (see below)."""
    print("expand dim")
    return x.reshape(jnp.array(x.shape).prod())


x = jnp.ones((2, 3))
# Inside jit, jnp.array(x.shape).prod() is a traced (abstract) value, but reshape needs
# a concrete int, so JAX raises a TypeError. It is caught and printed so that
# "Restart & Run All" does not stop here.
try:
    f(x)
except TypeError as e:
    print(f"{type(e).__name__}: {e}")

🚫 Inside `jit`, `x.shape` is a static Python tuple, but any `jax.numpy` operation on it (here `jnp.array(x.shape).prod()`) is traced. It produces an abstract tracer, not a concrete number.  
`reshape` needs concrete, static sizes, so passing it a tracer raises an error. Do static shape arithmetic with NumPy or plain Python instead, as in the next cell.

In [None]:
# Workaround: using numpy instead of jax.numpy
# x.shape is a static Python tuple, so the NumPy product is computed eagerly at trace
# time and reshape receives a concrete integer.


@jit
def f(x):
    return x.reshape(np.array(x.shape).prod())


x = jnp.ones((2, 3))
f(x)

### Pure functions
JAX is designed to work only on pure functions!
Pure function? Informal definition:

1. All the input data is passed through the function parameters, all the results are output through the function results.
2. A pure function will always return the same result if invoked with the same inputs.

In [None]:
def impure_print_side_effect(x):
    """Return ``x`` unchanged, printing a message as a side effect (impure)."""
    print("Execution function")
    result = x
    return result


# First call: traced, so the print inside the function runs.
print("First call: ", jit(impure_print_side_effect)(4.))

# Same input type and shape: the cached compiled version is reused, so the print does not run.
print("Second call: ", jit(impure_print_side_effect)(5.))

# New input shape: retrace, so the print runs again.
print("Third call: ", jit(impure_print_side_effect)(jnp.array([1.])))

In [None]:
g = 0.

def impure_use_global(x):
    """Add the module-level global ``g`` to ``x`` (impure: reads hidden state)."""
    print("Execution function")
    return x + g


# Traced with g = 0.; that value is baked into the compiled function as a constant.
print("First call: ", jit(impure_use_global)(4.))

g = 10.

# Subsequent runs may silently use the cached value of the globals
print("Second call: ", jit(impure_use_global)(5.))

# This will end up reading the latest value of the global after recompile
print("Third call: ", jit(impure_use_global)(jnp.array([4.])))

In [None]:
def pure_use_internal_state(x):
    """Add ``x`` ten times, tallying even and odd iterations in a local dict.

    The dict is created and used only inside the function, so the function is still pure.
    """
    state = {"even": 0, "odd": 0}
    for i in range(10):
        bucket = "odd" if i % 2 else "even"
        state[bucket] += x
    return state["even"] + state["odd"]

# The Python loop is unrolled at trace time; the local dict never escapes, so jit is safe here.
print(jit(pure_use_internal_state)(5.))

In [None]:
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i, x: x + array[i], 0))  # expected result 45

iterator = iter(range(10))
# The loop body is traced once, so next(iterator) is evaluated a single time (yielding 0)
# and that constant is reused in every iteration.
print(lax.fori_loop(0, 10, lambda i, x: x + next(iterator), 0))  # unexpected result 0

""" The semantics of fori_loop are given by this Python implementation:
def fori_loop(lower, upper, body_fun, init_val):
  val = init_val
  for i in range(lower, upper):
    val = body_fun(i, val)
  return val
"""

### In-Place Updates

In [None]:
jax_array = jnp.zeros((3, 3), dtype=jnp.float32)
# .at[1, :].set returns a modified copy; jax_array itself stays all zeros.
updated_array = jax_array.at[1, :].set(1.0)

print(f"Original: {jax_array}")
print(f"Updated: {updated_array}")

In [None]:
print("Original array:")
jax_array = jnp.ones((5, 6))
print(jax_array)

print("Updated array:")
# Add 7 to rows 0, 2, 4 in columns 3..5; returns a new array and leaves jax_array unchanged.
new_jax_array = jax_array.at[::2, 3:].add(7.)
print(new_jax_array)

### Out-of-Bounds Indexing

In [None]:
# NumPy raises on out-of-bounds reads; the exception message is printed.
try:
    np.arange(10)[11]
except Exception as e:
    print(f"Exception {e}")

In [None]:
# JAX behavior
# 1) updates at out-of-bounds indices are skipped
# 2) retrievals result in index being clamped
# No error is raised in either case, which is easy to trip over: treat out-of-bounds
# indexing as undefined behavior and do not rely on it.

print(jnp.arange(10).at[11].add(23))  # example of 1)
print(jnp.arange(10)[11])  # example of 2)

### Non-array inputs
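`jax.numpy` functions reject Python lists and tuples as inputs. This is by design, for performance: when traced, each list element would become a separate input (see the `make_jaxpr` output below). Convert explicitly with `jnp.array(...)`.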

In [None]:
# NumPy accepts a plain Python list.
print(np.sum([1, 2, 3]))

In [None]:
# jax.numpy rejects a plain Python list with a TypeError.
try:
    jnp.sum([1, 2, 3])
except TypeError as e:
    print(f"TypeError: {e}")

In [None]:
# Converting explicitly with jnp.array works; the except branch is not expected to trigger here.
try:
    print(jnp.sum(jnp.array([1, 2, 3])))
except TypeError as e:
    print(f"TypeError: {e}")

In [None]:
def permissive_sum(x):
    """Sum ``x`` after explicitly converting it (e.g. a Python list) to a JAX array."""
    as_array = jnp.array(x)
    return jnp.sum(as_array)

x = list(np.arange(10))
# The jaxpr takes each list element as a separate scalar input, which is the overhead jnp avoids by rejecting lists.
print(make_jaxpr(permissive_sum)(x))

### random numbers

In [None]:
# NumPy's RNG is a hidden global state: every draw advances it.
print(np.random.random())
print(np.random.random())

np.random.seed(seed)

# rng_state[2:] skips the algorithm name and the large key array, leaving the position counter and Gaussian cache.
rng_state = np.random.get_state()
print(rng_state[2:])

_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])


_ = np.random.uniform()
rng_state = np.random.get_state()
print(rng_state[2:])
# Mersenne Twister PRNG is known to have a number of problems (NumPy's imp of PRNG)

In [None]:
key = random.PRNGKey(seed)
print(key)

# Using a key does not change it: the key printed after each draw is the same,
# so reusing it reproduces the same sample.
print(random.normal(key, shape=(1,)))
print(key)

print(random.normal(key, shape=(1,))) # same result!!
print(key)


In [None]:
# Solution? -> Split every time you need a pseudorandom number.

print(f"Key = {key}")
key, subkey = random.split(key)
normal_pseudorandom = random.normal(subkey, shape=(1,))
# Raw strings: "\-" is not a valid escape sequence (SyntaxWarning on recent Python); output is unchanged.
print(r"    \---SPLIT --> new key   ", key)
print(r"             \--> new subkey", subkey, "--> normal", normal_pseudorandom)
# Note1: you can also split into more subkeys and not just 1
# Note2: key, subkey no difference it's only a convention

In [None]:
# Why this design?
# Well...think...with current design can the code be:
# 1) reproducible?
# 2) parallelizable?
# 3) vectorisable?

# Seed NumPy's single global RNG; every np.random call below draws from it in sequence.
np.random.seed(seed)

def bar():
    """Draw one uniform sample from NumPy's global RNG."""
    return np.random.uniform()
def baz():
    """Draw one uniform sample from the same global RNG as ``bar``."""
    return np.random.uniform()
def foo():
    """Return ``bar() + 2 * baz()``; the result depends on the order of the two draws."""
    first = bar()
    second = baz()
    return first + 2 * second

# Reproducible only because bar() is guaranteed to draw before baz(); running them
# concurrently would make which draw goes where nondeterministic.
print(foo())
# What if we want to parallelize this code? NumPy assumes too much. 2) is violated.

In [None]:
print("Numpy:")
np.random.seed(seed)
print("individuality: ", np.stack([np.random.uniform() for _ in range(3)]))

np.random.seed(seed)
print("all at once: ", np.random.uniform(size=3))

# JAX
print("JAX:")
key = random.PRNGKey(seed)
subkeys = random.split(key, 3)
# Use the same distribution (uniform) as the batched call below and the NumPy example
# above, so the only difference between the two JAX results is how the keys are used.
sequences = np.stack([random.uniform(subkey) for subkey in subkeys])
print("individuality: ", sequences)


key = random.PRNGKey(seed)
print("all at once: ", random.uniform(key=key, shape=(3,)))

NumPy's global random state is a **fundamental limitation** that JAX's PRNG design was created to address. It violates the principles of **parallelizability** and **functional purity**.

**Key differences**:
1. **The JAX "individuality" and "all at once" results are DIFFERENT**, as they should be: they come from different keys (three split subkeys vs. the original key).
2. **But both are REPRODUCIBLE** — same `key` always gives same result.
3. **No hidden dependencies** — you can call `random.normal(other_key)` anywhere without affecting these results.

---

## 📋 Summary: Why JAX's Approach is Superior

| Aspect | NumPy | JAX |
|--------|-------|-----|
| **Global State** | ❌ Mutable global RNG | ✅ Explicit immutable keys |
| **Reproducibility** | ❌ Depends on call order | ✅ Same key = same result |
| **Parallelization** | ❌ Race conditions | ✅ Safe (no shared state) |
| **Functional Purity** | ❌ Side effects | ✅ Pure functions |
| **Debugging** | ❌ Hard to trace state | ✅ Explicit key flow |

---

JAX's design:

1. **Immutability**: Data structures shouldn't change
2. **Reproducibility**: Same inputs → same outputs  
3. **Parallelizability**: No shared mutable state

NumPy's global RNG state breaks **all three** of these principles, making it unsuitable for modern ML workflows that require:
- **Distributed training** (multiple GPUs/nodes)
- **Reproducible experiments**
- **Functional programming** patterns

JAX's explicit PRNG keys fix this fundamental design flaw, enabling **scalable, reproducible, and parallelizable** random number generation.

### Control Flow

In [None]:
# Python control flow + grad() -> everything is ok
def f(x):
    """Piecewise function: 3*x**2 when x < 3, otherwise -4*x (plain Python `if` on x)."""
    return 3. * x**2 if x < 3 else -4. * x

x = np.linspace(-10, 10, 1000)
y = [f(n) for n in x]
plt.plot(x, y)
plt.show()

# Without jit, grad evaluates f on concrete values, so the Python `if` works:
# d/dx 3x^2 = 6x -> 12 at x=2, and d/dx -4x = -4 at x=4.
print(grad(f)(2.))
print(grad(f)(4.))

In [None]:
# Python control flow + jit() -> issues

# "The tradeoff is that with higher levels of abstraction we gain a more general view
# of the Python code (and thus save on re-compilations),
# but we require more constraints on the Python code to complete the trace."

# Example 1: conditioning on value (same function as in the above cell)
# Solution (recall: we already have seen this)
# Marking x as static makes its value part of the cache key: the `if` works,
# but every new value of x triggers a recompile.

f_jit = jit(f, static_argnums=(0,))
x = 2.0

print(make_jaxpr(f_jit, static_argnums=(0,))(x))
print(f_jit(x))

In [None]:
# Example 2: range depends on value again

def f(x, n):
    """Sum the first ``n`` elements of ``x`` with a Python loop (unrolled when traced)."""
    y = 0.
    for i in range(n):
        y += x[i]
    return y

# n is static, so range(n) is a concrete Python range during tracing.
f_jit = jit(f, static_argnums=(1,))
# n must not exceed the array length: JAX clamps out-of-bounds reads instead of raising,
# so a larger n would silently keep re-adding the last element.
x = (jnp.array([2., 3., 4.]), 3)

print(make_jaxpr(f_jit, static_argnums=(1,))(*x))
print(f_jit(*x))

In [None]:
 # Even "better" (it's less readable) solution is to use low level API

def f_fori(x, n):
    """Sum the first ``n`` elements of ``x`` with ``lax.fori_loop`` (``n`` may be a traced value)."""
    def add_element(i, acc):
        return acc + x[i]
    return lax.fori_loop(0, n, add_element, 0.)

f_fori_jit = jit(f_fori)

# No static_argnums here: fori_loop accepts a traced trip count n, so n is not part of the cache key.
print(make_jaxpr(f_fori_jit)(*x))
print(f_fori_jit(*x))

In [None]:
# Example 3: this is not problematic (it'll only cache a single branch)

def log2_if_rank_2(x):
    """Return log2(x) for rank-2 inputs; return any other input unchanged.

    ``x.ndim`` is static under tracing, so the Python ``if`` picks one branch at trace time.
    """
    if x.ndim != 2:
        return x
    return jnp.log(x) / jnp.log(2.0)

# Each jaxpr contains only the branch selected at trace time for that input rank.
print(make_jaxpr(log2_if_rank_2)(jnp.array([1, 2, 3])))
print(make_jaxpr(log2_if_rank_2)(jnp.array([[1, 2, 3]])))

### NaNs

In [None]:
jnp.divide(0., 0.)  # 0/0 gives nan without any error by default

from jax import config

# Enable JAX's NaN checking so computations producing NaN raise an error.
# This is a global setting and stays on for all following cells.
config.update("jax_debug_nans", True)

In [None]:
# dtype is requested explicitly for the random draw.
x = random.uniform(key, (1000,), dtype=jnp.float32)
print(x.dtype)