# JAX Introduction: Arrays, Differentiation, and JIT

Welcome! This notebook is designed to help you understand what JAX is and why it is important for modern scientific computing and machine learning.

---

## 1. What is JAX?

JAX is a numerical computing library from Google that brings together:
- **NumPy-like** array programming (with accelerator support and composable function transformations)
- **Automatic differentiation (autograd)**
- **Just-In-Time (JIT) compilation** for performance
- Built-in support for **parallelism** (across CPUs, GPUs, TPUs)

> JAX is especially popular in research for its clean, functional programming style and speed.

---

## 2. JAX vs NumPy: Arrays and API

JAX arrays, created through `jax.numpy` (conventionally imported as `jnp`), behave almost exactly like NumPy arrays, but with some crucial differences:

- Immutable (cannot be changed in-place)
- Run on CPU, GPU, or TPU without code changes (JAX places arrays on the default device)
- Compatible with JAX's transformations (`grad`, `jit`, etc.)

### Exercise 1: Basic Array Operations

In [1]:
import jax
import jax.numpy as jnp


In [2]:
# Create an array
x = jnp.array([1.0, 2.0, 3.0])
print("x:", x)

# Elementwise operations
print("x + 2:", x + 2)
print("sin(x):", jnp.sin(x))
print("Sum:", x.sum())

# JAX arrays are immutable!
try:
    x[0] = 100  # This will raise an error
except Exception as e:
    print("Error:", e)

x: [1. 2. 3.]
x + 2: [3. 4. 5.]
sin(x): [0.84147096 0.9092974  0.14112   ]
Sum: 6.0
Error: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html


### Mutating or Replacing Parts of an Array in JAX

JAX arrays are **immutable**—you cannot change their contents in-place like NumPy arrays.  
Instead, you use functions that **return a new array** with your changes.

#### Example: Using `jax.numpy`'s `.at[].set()` and `.at[].add()`

In [17]:
x = jnp.array([1, 2, 3, 4, 5])

# Set the value at index 2 to 99
y = x.at[2].set(99)
print("Original x:", x)
print("After set:", y)  # [ 1  2 99  4  5]

# Add 10 to indices 1 and 3 (correct way: pass a 1D array of indices)
ind = jnp.array([1, 3])
z = x.at[ind].add(10)
print("After add:", z)  # [ 1 12  3 14  5]

Original x: [1 2 3 4 5]
After set: [ 1  2 99  4  5]
After add: [ 1 12  3 14  5]



## 3. Automatic Differentiation (`jax.grad`)

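`jax.grad` takes a Python function that returns a scalar and gives back a new function that computes its gradient with respect to the first argument. This works even for complicated functions.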

### Example: Gradient of a Scalar Function

In [3]:
def f(x):
    return x**2 + 3.0 * x

dfdx = jax.grad(f)

print("f(2.0) =", f(2.0))
print("df/dx at x=2.0 =", dfdx(2.0))

f(2.0) = 10.0
df/dx at x=2.0 = 7.0
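
The same machinery extends to functions of several arguments and to higher derivatives. The following is a minimal sketch, not part of the original notebook: the two-argument function `h` is a hypothetical example, `argnums` selects which argument to differentiate, and `jax.value_and_grad` returns the value together with the gradient.

```python
import jax
import jax.numpy as jnp

# Hypothetical two-argument function, used only for illustration.
def h(x, y):
    return x**2 * y + jnp.sin(y)

# argnums picks the argument to differentiate with respect to.
dh_dx = jax.grad(h, argnums=0)   # analytically: 2*x*y
dh_dy = jax.grad(h, argnums=1)   # analytically: x**2 + cos(y)

# value_and_grad returns the function value and the gradients together.
value, (gx, gy) = jax.value_and_grad(h, argnums=(0, 1))(3.0, 2.0)

# grad composes with itself: the second derivative of x**2 + 3x is 2.
def f(x):
    return x**2 + 3.0 * x

d2fdx2 = jax.grad(jax.grad(f))

print("h(3, 2) =", value)
print("dh/dx =", gx, " dh/dy =", gy)
print("d2f/dx2 at x=2.0 =", d2fdx2(2.0))
```

Because `grad` returns an ordinary function, it can be passed to `grad` again or to other transformations.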


### Exercise 2: Try It Yourself

- Define a function `g(x) = sin(x) + x**3`
- Use `jax.grad` to compute its derivative at `x=1.0`
- Print both the function value and its derivative

## 4. JIT Compilation for Speed (`jax.jit`)

JAX can "compile" your functions to run much faster using XLA, especially for repeated calls.

### Example: JIT-accelerated function


In [8]:
 
import numpy as np 
# Define a normalization function
def norm(X):
    X = X - X.mean(0)
    return X / X.std(0)

# JIT-compiled version
norm_compiled = jax.jit(norm)

# Prepare a large random array
np.random.seed(1701)
X = jnp.array(np.random.rand(1000, 10))

# Check correctness
print("Are results close?", np.allclose(norm(X), norm_compiled(X), atol=1E-6))

# Time the uncompiled (eager) version
print("Eager (not-JIT) execution:")
%timeit norm(X).block_until_ready()  # block_until_ready() ensures accurate timing

# Time the compiled (JIT) version (first call includes compilation)
print("JIT-compiled execution (after compilation):")
_ = norm_compiled(X).block_until_ready()  # warm-up to compile
%timeit norm_compiled(X).block_until_ready()

Are results close? True
Eager (not-JIT) execution:
271 μs ± 4.19 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
JIT-compiled execution (after compilation):
182 μs ± 19.4 μs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)



### Note:
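- The **first call** to a JIT-compiled function is slower because JAX traces and compiles it. Later calls with the same input shapes and dtypes reuse the compiled code (here about 182 μs per call, versus 271 μs for the eager version). Calling with a new shape or dtype triggers another compilation.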
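
Compilation is cached per input shape and dtype. One way to observe this is to record when Python actually runs the function body, which only happens while JAX traces it. The sketch below is illustrative: `scaled_sum` and `trace_log` are hypothetical names, and the deliberate side effect is there only to make tracing visible.

```python
import jax
import jax.numpy as jnp

trace_log = []  # records every time Python actually executes the function body

@jax.jit
def scaled_sum(x):
    trace_log.append(x.shape)   # side effect: runs only while JAX traces
    return (2.0 * x).sum()

scaled_sum(jnp.ones(3))    # first call: traced and compiled for shape (3,)
scaled_sum(jnp.zeros(3))   # same shape and dtype: compiled code is reused
scaled_sum(jnp.ones(5))    # new shape: traced and compiled again

print(trace_log)           # [(3,), (5,)]
```

Three calls produce only two traces, so keeping input shapes stable avoids repeated compilation cost.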


# JAX: Understanding `vmap` and `pmap`

JAX provides **powerful tools for vectorization and parallelization**:  
- `vmap` for **automatic batching** on a single device (CPU, GPU, or TPU)
- `pmap` for **parallelization across multiple devices** (multiple GPUs or TPUs)

This notebook gives a detailed explanation and practical examples of both.

---

## 1. `vmap`: Automatic Vectorization (Batching)

### What does `vmap` do?

`vmap` takes a function written for **single examples** and automatically lifts it to operate on **batches of examples**—without writing explicit loops.

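- **Why use it?**: Vectorized code usually runs much faster than a Python for-loop, because the whole batch executes as a few array operations, and it is often easier to read.
- **Analogy**: Like NumPy broadcasting, but it works for any function, including ones built with other JAX transforms such as `grad`.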

### Basic Example: Computing Gradients for a Batch

Suppose you want to compute the gradient of a function for many inputs:

In [9]:


def f(x):
    return x**2 + 3.0 * x

dfdx = jax.grad(f)
x_batch = jnp.linspace(-2, 2, 5)

# Without vmap: a Python for-loop calls dfdx once per element
grads = jnp.array([dfdx(x) for x in x_batch])

# With vmap: no loop needed!
vmap_grads = jax.vmap(dfdx)(x_batch)

print("Manual loop:", grads)
print("vmap:", vmap_grads)

Manual loop: [-1.  1.  3.  5.  7.]
vmap: [-1.  1.  3.  5.  7.]


### How does batching work?

- By default, `vmap` applies the function along the **first axis** of array arguments ("leading batch dimension").
- You can control which axis is batched via the `in_axes` and `out_axes` arguments.

#### Example with multiple arguments:


In [10]:
def my_func(x, y):
    return x * y + 2

xs = jnp.arange(5)
ys = jnp.arange(5)

# vmap over both xs and ys simultaneously
result = jax.vmap(my_func)(xs, ys)
print(result)  # [2 3 6 11 18]

[ 2  3  6 11 18]


In [11]:
# vmap over only xs, keep y scalar
result = jax.vmap(my_func, in_axes=(0, None))(xs, 10)
print(result)  # [2 12 22 32 42]

[ 2 12 22 32 42]
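
The examples above batch over the leading axis only. As a hedged sketch of the `in_axes` and `out_axes` arguments with a non-leading batch axis (the function `scale` and the arrays `V` and `w` are hypothetical, chosen for illustration):

```python
import jax
import jax.numpy as jnp

# Hypothetical per-example function: scale one vector by one scalar weight.
def scale(v, w):
    return v * w

# Four vectors of length 3, stored with the batch on axis 1 (one per column).
V = jnp.arange(12.0).reshape(3, 4)
w = jnp.array([1.0, 10.0, 100.0, 1000.0])   # one weight per example

# in_axes=(1, 0): batch over axis 1 of V and axis 0 of w.
# The default out_axes=0 puts the batch dimension first in the result.
out_default = jax.vmap(scale, in_axes=(1, 0))(V, w)
print(out_default.shape)       # (4, 3)

# out_axes=1 puts the batch dimension back on axis 1, matching V's layout.
out_same_layout = jax.vmap(scale, in_axes=(1, 0), out_axes=1)(V, w)
print(out_same_layout.shape)   # (3, 4)
```

`in_axes` takes one entry per positional argument, while `out_axes` describes where the batch dimension appears in the output.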


### Advanced: Nested `vmap` for 2D batching

Sometimes you want to apply a function to **all pairs** from two arrays, for example:  
- Given vectors `A` (length M) and `B` (length N), compute `add(a, b)` for every combination of `a` in `A` and `b` in `B`, producing an M×N output.

You could do this with nested Python loops, but JAX's nested `vmap` lets you express this efficiently and concisely:


In [12]:
def add(x, y):
    return x + y

A = jnp.arange(3)    # shape (3,)
B = jnp.arange(4)    # shape (4,)

# Outer vmap (in_axes=(0, None)): maps over elements of A, passing all of B each time
# Inner vmap (in_axes=(None, 0)): maps over elements of B for one fixed element of A
batched_add = jax.vmap(
    jax.vmap(add, in_axes=(None, 0)),  # vectorize over B (y): x fixed, y varies
    in_axes=(0, None)                  # vectorize over A (x): x varies, y fixed
)

result = batched_add(A, B)
print(result)

[[0 1 2 3]
 [1 2 3 4]
 [2 3 4 5]]


#### How does this work?

- The **inner vmap** (`in_axes=(None, 0)`) says: "For a fixed `x`, apply `add(x, y)` for all `y` in `B`" (vectorizing over `B`).
- The **outer vmap** (`in_axes=(0, None)`) says: "For all `x` in `A`, apply the inner vmap with that `x` and all of `B`".

This gives you a 2D array where `result[i, j] = A[i] + B[j]` for all `i, j`.

#### Generalization

You can nest `vmap` as many times as needed, enabling efficient, readable code for batched/broadcasted computations without writing explicit loops.
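
The same nesting pattern applies to any per-pair function, not just addition. Here is a minimal sketch, assuming a hypothetical `sq_dist` function and small example point sets `P` and `Q`, that computes all pairwise squared distances and checks the result against an explicit broadcasting version:

```python
import jax
import jax.numpy as jnp

# Hypothetical per-pair function: squared Euclidean distance between two points.
def sq_dist(p, q):
    return jnp.sum((p - q) ** 2)

P = jnp.array([[0.0, 0.0], [1.0, 0.0], [0.0, 2.0]])   # 3 points in 2D
Q = jnp.array([[0.0, 0.0], [3.0, 4.0]])               # 2 points in 2D

# Inner vmap varies q over Q; outer vmap varies p over P.
pairwise = jax.vmap(jax.vmap(sq_dist, in_axes=(None, 0)), in_axes=(0, None))
D = pairwise(P, Q)   # shape (3, 2), D[i, j] = squared distance from P[i] to Q[j]

# The same computation written with explicit broadcasting, for comparison.
D_broadcast = jnp.sum((P[:, None, :] - Q[None, :, :]) ** 2, axis=-1)

print(D)
print(bool(jnp.allclose(D, D_broadcast)))
```

The `vmap` version keeps `sq_dist` written for a single pair, which is often easier to check than the index bookkeeping in the broadcasting version.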

## 2. `pmap`: Parallelization Across Devices

### What does `pmap` do?

`pmap` transforms a function for **parallel execution across multiple devices** (e.g., multiple GPUs/TPUs in a single machine).

- **Use case**: Large-scale parallelism, model/data parallel training, hardware acceleration.
- **Not needed** for single-device code—use `vmap` instead!

### Basic Example: Parallel addition

In [13]:
def add_one(x):
    return x + 1

# One input element per available device (a single element on a one-device machine)
x = jnp.arange(jax.device_count())
print("Device count:", jax.device_count())

# pmap distributes computation across devices
y = jax.pmap(add_one)(x)
print("Parallel result:", y)

Device count: 1
Parallel result: [1]


**On a machine with 4 GPUs, this will run on all 4 in parallel.**

### Key points about `pmap`:

- The **input array must be split** so that each device gets one chunk. Example: if you have 4 devices, input shape should have leading dimension 4.
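- The output is a sharded `jax.Array` (called `DeviceArray` in older JAX releases): one chunk per device.
- Inside a `pmap`-ed function, devices exchange data through collective operations such as `jax.lax.pmean` or `jax.lax.psum`; JAX handles the underlying communication.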
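
As a sketch of the shape requirement, the snippet below builds an input whose leading dimension equals the device count, so that each device receives one chunk. The names `per_device` and `double` are hypothetical; on a single-device machine the code still runs, with a leading dimension of 1.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()   # 1 on a CPU-only machine
per_device = 3                     # illustrative chunk size per device

# The leading dimension must equal the number of devices.
data = jnp.arange(n_dev * per_device, dtype=jnp.float32).reshape(n_dev, per_device)

def double(chunk):
    # Runs once per device, on that device's chunk of shape (per_device,).
    return chunk * 2

out = jax.pmap(double)(data)
print(out.shape)   # (n_dev, per_device)
```

If the leading dimension is larger than the number of available devices, `pmap` raises an error, which is why real datasets are typically reshaped to `(n_dev, batch_per_device, ...)` first.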

---


### Advanced: Collective operations with `pmap`

`pmap` supports device communication primitives (e.g., `jax.lax.pmean` for averaging):


In [14]:
import jax
import jax.numpy as jnp

def mean_across_devices(x):
    # Compute mean of x across devices
    return jax.lax.pmean(x, axis_name='i')

x = jnp.arange(jax.device_count()).astype(float)
out = jax.pmap(mean_across_devices, axis_name='i')(x)
print(out)

[0.]



## `vmap` vs. `pmap`: When to use which?

| Feature       | `vmap`                      | `pmap`                     |
|---------------|-----------------------------|----------------------------|
| Batching      | Yes (single device)         | Yes (across devices)       |
| Parallelism   | Implicit SIMD, single device| Explicit, multi-device     |
| Use case      | Data/model batching         | Distributed training/inference |
| Syntax        | Simple, works like a loop   | Needs device-aware shapes  |

---

## Assignment

1. **Write a function** that computes the norm of a vector. Use `vmap` to apply it to a batch of vectors.
2. **If you have multiple GPUs/TPUs**, write a function that multiplies by 2 and use `pmap` to apply over all devices.
3. **(Bonus)** Try using `vmap` and `grad` together to compute gradients for a batch of inputs.

---


# Pytrees and Pure Functional Programming in JAX: An Intuitive Introduction

JAX is designed for high-performance, reliable, and composable scientific computing. Two key ideas make this possible:

- **Pytrees**: Flexible, tree-like data structures.
- **Pure Functional Programming**: Functions with no side effects.

Let's understand these concepts in an intuitive, practical way.

---

## 1. What is a PyTree?

A **pytree** is any nested structure of lists, tuples, dicts, and other containers (including custom ones), where the "leaves" are regular JAX arrays or scalars.

**Think of it like a tree made of Python containers, with arrays at the ends.**

### Examples

In [18]:
# Simple pytree: a list of arrays
a = [jnp.ones(3), jnp.zeros(2)]

# Nested pytree: dicts, lists, tuples, arrays
b = {'params': (jnp.ones(2), [jnp.zeros(3), jnp.eye(2)]), 'lr': 0.01}

# Even more complex
c = (jnp.array(1.0), [{'a': jnp.array([1, 2])}, (jnp.zeros(1),)])

JAX can **recursively traverse** these objects, find all the arrays, and do things with them (mapping, flattening, etc).
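
A minimal sketch of that traversal, using the `jax.tree_util` functions `tree_flatten`, `tree_map`, and `tree_unflatten` on a pytree like `b` above (the variable names are illustrative):

```python
import jax
import jax.numpy as jnp

tree = {'params': (jnp.ones(2), [jnp.zeros(3), jnp.eye(2)]), 'lr': 0.01}

# Flatten: a flat list of leaves plus a treedef describing the containers.
leaves, treedef = jax.tree_util.tree_flatten(tree)
print(len(leaves))   # 4 leaves: lr, ones(2), zeros(3), eye(2)

# Map: apply a function to every leaf while keeping the structure unchanged.
doubled = jax.tree_util.tree_map(lambda leaf: leaf * 2, tree)
print(doubled['params'][0])

# Unflatten: rebuild the original structure from a list of leaves.
rebuilt = jax.tree_util.tree_unflatten(treedef, leaves)
print(jax.tree_util.tree_structure(rebuilt) == treedef)
```

Dictionary leaves are visited in sorted key order, so `'lr'` comes before `'params'` in the flattened list.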

### Why Pytrees?

- **Flexibility**: You can organize your model parameters, optimizer states, batches, etc. however you like.
- **Compatibility**: JAX's core APIs work with any pytree structure, not just flat arrays.
- **Ease of Use**: You can use familiar Python containers, no need for new data types.
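
To illustrate the compatibility point, here is a hedged sketch in which `jax.grad` differentiates with respect to a dict of parameters and returns gradients in the same dict structure. The linear model, `loss`, and the data values are hypothetical.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear model with parameters stored in a dict (a pytree).
params = {'w': jnp.array([1.0, -2.0]), 'b': jnp.array(0.5)}

def loss(params, x, y):
    pred = jnp.dot(x, params['w']) + params['b']
    return (pred - y) ** 2

x = jnp.array([3.0, 1.0])
y = 2.0

# grad differentiates with respect to the first argument; the result is a
# pytree with the same structure as params.
grads = jax.grad(loss)(params, x, y)
print(grads)

# One gradient-descent step, applied leaf by leaf.
lr = 0.1
new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grads)
print(new_params)
```

Because the gradients mirror the parameter structure, the update step never needs to know how the parameters are organized.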

---

## 2. Pure Functional Programming

JAX encourages **pure functions**: given the same inputs, always return the same outputs, with **no side effects** (no modifying globals, printing, or changing inputs in-place).

### Why does JAX care about purity?

- **Transformations**: JAX traces a function and reuses the traced computation, so `jit`, `grad`, `vmap`, etc. are only guaranteed to behave correctly if the function has no side effects.
- **Reproducibility**: Pure functions are easier to debug and test.
- **Parallelism**: Pure functions can be run in parallel safely.


## 2.1. Examples: Pure vs. Impure Functions in Python (and JAX)

A **pure function** always returns the same output for the same input, and does not cause (or depend upon) any side effects.  
An **impure function** might change things outside itself, or depend on external/global state.



In [22]:

### Example 1: Mathematical function (Pure vs. Impure)


# Pure: always gives same output for same inputs, no side effects
def square(x):
    return x * x

# Impure: modifies external (global) variable
result = 0
def impure_square(x):
    global result
    result = x * x  # changes external state!
    return result


### Example 2: Modifying input arguments


# Pure: creates a new list, does not modify input
def append_pure(xs, x):
    return xs + [x]

# Impure: modifies the input list in-place
def append_impure(xs, x):
    xs.append(x)
    return xs


### Example 3: Printing, randomness, I/O


# Pure: no printing or randomness
def add(a, b):
    return a + b

# Impure: prints a message (side effect)
def add_print(a, b):
    print(f"Adding {a} and {b}")
    return a + b

# Impure: uses randomness (output can change each call)
import random
def add_random(a):
    return a + random.randint(0, 10)

### Example 4: JAX and side effects

 

# Pure: always same result, no print, no mutation
def jax_pure_fn(x):
    return jnp.sin(x) * 2

# Impure: prints inside the function (under jit, the print runs only while tracing)
def jax_impure_fn(x):
    print("Calculating!")
    return jnp.sin(x) * 2
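
Randomness does not have to be impure. In JAX the random state is passed explicitly as a key, so a function that uses random numbers can still depend only on its arguments. The sketch below uses `jax.random.PRNGKey`, `jax.random.split`, and `jax.random.normal`; the function `add_noise` is a hypothetical counterpart to `add_random` above.

```python
import jax

key = jax.random.PRNGKey(0)   # explicit random state

def add_noise(key, a):
    # Pure: the output depends only on the arguments (key, a).
    return a + jax.random.normal(key)

# Same key and same input: identical output on every call.
print(add_noise(key, 1.0) == add_noise(key, 1.0))   # True

# For a fresh random number, split the key and use the new subkey.
key, subkey = jax.random.split(key)
print(add_noise(subkey, 1.0))
```

Splitting keys instead of mutating a hidden global generator is what keeps random code reproducible under `jit` and `vmap`.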

 




> **Key Point:**  
> JAX assumes your functions are pure when it applies transformations like `jit`, `grad`, and `vmap`.  
> Side effects such as `print` run only while the function is traced, and global state is read once at trace time, so impure code can silently give stale or wrong results.
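
A small sketch of how impurity goes wrong under `jit`: the function below reads a global variable, and the value it saw during tracing stays baked into the compiled code. The names `scale` and `scaled` are hypothetical.

```python
import jax

scale = 2.0   # global state read by the function below

@jax.jit
def scaled(x):
    print("tracing scaled")   # side effect: runs only while tracing
    return x * scale          # the value of `scale` is baked in at trace time

first = scaled(1.0)    # prints "tracing scaled"; result uses scale = 2.0
scale = 10.0           # changing the global does not trigger a retrace
second = scaled(1.0)   # no print; the cached computation still multiplies by 2.0

print(first, second)   # 2.0 2.0
```

Passing `scale` as an explicit argument makes the function pure and removes the stale-value problem.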

---