# Learning JAX: A Beginner's Guide

Welcome to JAX! JAX is a Python library for high-performance numerical computing, especially well-suited for machine learning research. It combines a familiar NumPy-like API with powerful transformations like Just-In-Time (JIT) compilation, automatic differentiation, and automatic vectorization.

## Installation

Before we start, you need to install JAX. The installation command depends on whether you want to use JAX with CPU, GPU, or TPU support.

**For CPU-only:**
```bash
pip install --upgrade "jax[cpu]"
```

**For NVIDIA GPU:**
Recent JAX releases install the CUDA and cuDNN libraries as pip dependencies, so you mainly need a compatible NVIDIA driver. Check the [official JAX installation guide](https://github.com/google/jax#installation) for the extra that matches your CUDA version.
An example command for CUDA 12 (replace with the correct one for your setup):
```bash
pip install --upgrade "jax[cuda12]"
```

**For Google Cloud TPU:**
On Cloud TPU VMs, the `tpu` extra pulls in the matching TPU runtime library; some hosted notebook environments ship with JAX pre-installed.
```bash
pip install --upgrade "jax[tpu]"
```

The cell below contains the CPU installation command, commented out. Uncomment and run it if you haven't installed JAX yet, or swap in a different command if you need another version.

In [4]:
# !pip install --upgrade "jax[cpu]" # Run this if you haven't installed JAX yet

## Quickstart: Key Concepts

Let's dive into the core ideas that make JAX unique and powerful.

### 1. JAX NumPy (`jax.numpy`)

JAX provides a NumPy-compatible API through `jax.numpy`, which is conventionally imported as `jnp`.
If you know NumPy, you're already halfway there! Most NumPy functions have a `jnp` equivalent.

In [5]:
import jax
import jax.numpy as jnp
import numpy as np # We'll use standard NumPy for comparison sometimes

# Create a JAX array
x_jnp = jnp.array([1.0, 2.0, 3.0])
print(f"JAX array: {x_jnp}, type: {type(x_jnp)}")

# Create a NumPy array for comparison
x_np = np.array([1.0, 2.0, 3.0])
print(f"NumPy array: {x_np}, type: {type(x_np)}")

# Basic operations look the same
y_jnp = jnp.arange(1, 4, dtype=jnp.float32)
sum_jnp = x_jnp + y_jnp
print(f"Sum of JAX arrays: {sum_jnp}")

prod_jnp = jnp.dot(x_jnp, y_jnp)
print(f"Dot product of JAX arrays: {prod_jnp}")

# JAX arrays are typically on a device (CPU by default here, could be GPU/TPU)
print(f"Device of x_jnp: {x_jnp.device}")

JAX array: [1. 2. 3.], type: <class 'jaxlib._jax.ArrayImpl'>
NumPy array: [1. 2. 3.], type: <class 'numpy.ndarray'>
Sum of JAX arrays: [2. 4. 6.]
Dot product of JAX arrays: 14.0
Device of x_jnp: TFRT_CPU_0
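To make the "halfway there" claim concrete, here is a small sketch that runs the same calls through NumPy and `jax.numpy` and converts arrays in both directions. The variable names are illustrative. One difference worth noticing is that JAX uses 32-bit floats by default, while NumPy uses 64-bit floats.

```python
import jax.numpy as jnp
import numpy as np

# The same call pattern works in both libraries.
np_values = np.linspace(0.0, 1.0, 5)
jnp_values = jnp.linspace(0.0, 1.0, 5)

# By default NumPy produces float64 and JAX produces float32.
print(f"NumPy dtype: {np_values.dtype}, JAX dtype: {jnp_values.dtype}")

# Convert in both directions.
from_numpy = jnp.asarray(np_values)     # NumPy -> JAX array
back_to_numpy = np.asarray(jnp_values)  # JAX -> NumPy array

# Reductions and elementwise math agree up to floating-point precision.
np_mean = np.mean(np_values ** 2)
jnp_mean = jnp.mean(jnp_values ** 2)
print(f"Means agree: {np.allclose(np_mean, jnp_mean, atol=1e-6)}")
```

Because of the precision difference, comparisons between the two libraries are best done with a tolerance (as above) rather than exact equality.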


#### Immutability
A key difference from NumPy is that JAX arrays are **immutable**. This means once a JAX array is created, its contents cannot be changed in-place. Operations that seem to modify an array actually return a *new* array.
This is crucial for JAX's functional programming paradigm and enables its powerful transformations.

In [6]:
x = jnp.array([1, 2, 3])
print(f"Original x: {x}")

# This would raise an error in JAX:
# x[0] = 10 

# To update an array, you use functional methods like .at[].set()
y = x.at[0].set(10)
print(f"Original x after trying to update (it's unchanged!): {x}")
print(f"New array y with the update: {y}")

# Other update operations:
z = x.at[1].add(5) # Adds 5 to the element at index 1
print(f"New array z (x[1]+5): {z}")

Original x: [1 2 3]
Original x after trying to update (it's unchanged!): [1 2 3]
New array y with the update: [10  2  3]
New array z (x[1]+5): [1 7 3]
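The `.at[]` syntax also accepts slices and index arrays, and it offers more operations than `set` and `add`. The following sketch, using made-up values, shows the error raised by in-place assignment along with two further functional updates.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3, 4])

# In-place assignment is rejected with a TypeError.
try:
    x[0] = 10
    assignment_failed = False
except TypeError:
    assignment_failed = True
print(f"In-place assignment failed: {assignment_failed}")

# Update a slice: multiply every element from index 1 onward by 2.
doubled_tail = x.at[1:].multiply(2)
print(f"doubled_tail: {doubled_tail}")  # [1 4 6 8]

# Update several positions at once with an index array.
zeroed = x.at[jnp.array([0, 2])].set(0)
print(f"zeroed: {zeroed}")  # [0 2 0 4]

# The original array is untouched by all of the above.
print(f"x: {x}")  # [1 2 3 4]
```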


### 2. Just-In-Time (JIT) Compilation (`jax.jit`)

JAX can compile your Python functions into highly optimized machine code using XLA (Accelerated Linear Algebra compiler). This is done using the `jax.jit` transformation.

**Benefits:**
- **Speed:** Compiled code often runs much faster, especially when a function chains many array operations or is called repeatedly with same-shaped inputs.
- **Kernel Fusion:** XLA can fuse multiple operations into a single kernel, reducing overhead.

In [7]:
import time

def slow_function(x):
  # A somewhat arbitrary computation
  return jnp.sum(jnp.sin(x) * jnp.cos(x) + jnp.tanh(x) / (jnp.exp(x) + 1e-6))

large_array = jnp.arange(1_000_000, dtype=jnp.float32)

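# Time the un-jitted function. Without jit, JAX dispatches each operation separately,
# and this first call also pays one-time setup costs for those operations.
# JAX execution is asynchronous; block_until_ready() ensures the computation finishes before we stop the timer.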
start_time = time.time()
result_uncompiled = slow_function(large_array).block_until_ready()
time_uncompiled = time.time() - start_time
print(f"Uncompiled function time: {time_uncompiled:.6f} seconds")

# Compile the function with jax.jit
fast_function = jax.jit(slow_function)

# Time the compiled function (first run includes compilation time)
start_time = time.time()
result_compiled_first = fast_function(large_array).block_until_ready()
time_compiled_first = time.time() - start_time
print(f"Compiled function (first run): {time_compiled_first:.6f} seconds")

# Time the compiled function (subsequent runs are faster)
start_time = time.time()
result_compiled_second = fast_function(large_array).block_until_ready()
time_compiled_second = time.time() - start_time
print(f"Compiled function (second run): {time_compiled_second:.6f} seconds")

# Check results are the same
print(f"Results are close: {jnp.allclose(result_uncompiled, result_compiled_second)}")

# You can also use @jax.jit as a decorator
@jax.jit
def decorated_function(x):
  # Same computation as slow_function; only the way jit is applied differs
  return jnp.sum(jnp.sin(x) * jnp.cos(x) + jnp.tanh(x) / (jnp.exp(x) + 1e-6))

start_time = time.time()
result_decorated = decorated_function(large_array).block_until_ready()
time_decorated = time.time() - start_time
print(f"Decorated JIT function (first run): {time_decorated:.6f} seconds")

Uncompiled function time: 0.366768 seconds
Compiled function (first run): 0.058543 seconds
Compiled function (second run): 0.009382 seconds
Results are close: True
Decorated JIT function (first run): 0.048020 seconds


**Note on JIT:**
- The first time a JIT-compiled function is called with specific input shapes and types, JAX traces and compiles it. This incurs some overhead.
- Subsequent calls with the *same* input shapes and types will use the cached, compiled version and be much faster.
- JIT works best with *pure functions*: functions whose output depends only on their inputs and that have no side effects. Side effects such as printing or modifying global variables run only while the function is being traced, not on later calls that reuse the compiled version.
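The caching behaviour described in these notes can be observed directly. The sketch below deliberately puts a Python side effect (appending to a list) inside a jitted function, purely to show when tracing happens; `trace_log` and `double_and_sum` are illustrative names, and this is not a pattern to use in real code.

```python
import jax
import jax.numpy as jnp

trace_log = []  # filled by a Python side effect, for illustration only

@jax.jit
def double_and_sum(x):
    # This line runs only while JAX traces the function,
    # not on calls that reuse the cached compiled version.
    trace_log.append(x.shape)
    return jnp.sum(x * 2.0)

double_and_sum(jnp.ones(3))   # traced and compiled for shape (3,)
double_and_sum(jnp.zeros(3))  # same shape and dtype: cached version reused
double_and_sum(jnp.ones(4))   # new shape: traced and compiled again

print(trace_log)  # [(3,), (4,)]
```

Three calls produce only two entries in the log, one per distinct input shape. The second call never executes the Python body at all.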

### 3. Automatic Differentiation (`jax.grad`)

JAX can automatically compute gradients (and higher-order derivatives) of your functions. This is the backbone of modern machine learning.

- `jax.grad(fun)`: Returns a new function that computes the gradient of `fun` with respect to its first argument. `fun` must return a scalar.
- `argnums` parameter: Selects which positional argument(s) to differentiate with respect to; pass an integer or a tuple of integers.
- `jax.value_and_grad(fun)`: Returns a function that computes both the value of `fun` and its gradient in one call, which avoids evaluating `fun` a second time when you need both.

In [8]:
# A simple scalar function: f(x) = x^3
def f(x):
  return x**3

# Get the gradient function (derivative)
grad_f = jax.grad(f)

x_val = 2.0
print(f"f({x_val}) = {f(x_val)}")
# Analytically, f'(x) = 3x^2, so f'(2.0) = 3 * (2.0)^2 = 12.0
print(f"f'({x_val}) = {grad_f(x_val)}")

# Function with multiple arguments: g(x, y) = x^2 * y
def g(x, y):
  return x**2 * y

# Gradient with respect to the first argument (x) by default
grad_g_wrt_x = jax.grad(g) # same as jax.grad(g, argnums=0)
# Analytically, dg/dx = 2xy
print(f"dg/dx(2.0, 3.0) = {grad_g_wrt_x(2.0, 3.0)}") # Expected: 2 * 2.0 * 3.0 = 12.0

# Gradient with respect to the second argument (y)
grad_g_wrt_y = jax.grad(g, argnums=1)
# Analytically, dg/dy = x^2
print(f"dg/dy(2.0, 3.0) = {grad_g_wrt_y(2.0, 3.0)}") # Expected: (2.0)^2 = 4.0

# Gradients with respect to multiple arguments (returns a tuple of gradients)
grad_g_wrt_xy = jax.grad(g, argnums=(0, 1))
gradients_xy = grad_g_wrt_xy(2.0, 3.0)
print(f"(dg/dx, dg/dy) at (2.0, 3.0) = {gradients_xy}")

# Compute value and gradient together
value_and_grad_f = jax.value_and_grad(f)
val, grad_val = value_and_grad_f(x_val)
print(f"Value of f({x_val}): {val}, Gradient of f({x_val}): {grad_val}")

# Gradients work with JIT too!
@jax.jit
def f_jit(x):
    return x**3

grad_f_jit = jax.grad(f_jit)
print(f"Gradient of JITted f'({x_val}) = {grad_f_jit(x_val)}")

f(2.0) = 8.0
f'(2.0) = 12.0
dg/dx(2.0, 3.0) = 12.0
dg/dy(2.0, 3.0) = 4.0
(dg/dx, dg/dy) at (2.0, 3.0) = (Array(12., dtype=float32, weak_type=True), Array(4., dtype=float32, weak_type=True))
Value of f(2.0): 8.0, Gradient of f(2.0): 12.0
Gradient of JITted f'(2.0) = 12.0
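Higher-order derivatives were mentioned above but not shown. Since `jax.grad` returns an ordinary function, it can be applied again to its own output. The sketch below reuses f(x) = x^3 and also differentiates a scalar-valued function of an array; `sum_of_squares` is an illustrative helper, not something defined earlier in this notebook.

```python
import jax
import jax.numpy as jnp

def f(x):
    return x ** 3

# Apply grad repeatedly to get higher-order derivatives.
d2f = jax.grad(jax.grad(f))  # f''(x) = 6x
d3f = jax.grad(d2f)          # f'''(x) = 6
print(f"f''(2.0) = {d2f(2.0)}, f'''(2.0) = {d3f(2.0)}")

# grad also handles array inputs, as long as the output is a scalar.
def sum_of_squares(v):
    return jnp.sum(v ** 2)

v = jnp.array([1.0, 2.0, 3.0])
grad_v = jax.grad(sum_of_squares)(v)  # analytically 2 * v
print(f"Gradient of sum_of_squares at {v}: {grad_v}")
```

The gradient of a function of an array has the same shape as that array, which is what makes `jax.grad` convenient for model parameters.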


### 4. Automatic Vectorization (`jax.vmap`)

`jax.vmap` is a transformation for automatically vectorizing functions. If you have a function that operates on a single data point, `vmap` can transform it into a function that efficiently operates on a batch (or multiple axes) of data points, without you needing to write explicit loops.

**Benefits:**
- **Efficiency:** The batch dimension is pushed down into the underlying array operations, so XLA sees one batched computation instead of a Python loop.
- **Simplicity:** Write code for a single instance, `vmap` handles batching.

In [9]:
# A function that operates on single vectors (e.g., a scaled dot product)
def scaled_dot_product(a, b, scale):
  return jnp.dot(a, b) * scale

vec1 = jnp.array([1., 2., 3.])
vec2 = jnp.array([0., 1., 0.])
s = 0.5

print(f"Single scaled_dot_product: {scaled_dot_product(vec1, vec2, s)}")

# Now, suppose we have batches of vectors 'a' and 'b', but a single 'scale'
batch_vec1 = jnp.array([[1., 2., 3.], [4., 5., 6.]]) # Shape (2, 3)
batch_vec2 = jnp.array([[0., 1., 0.], [1., 0., 1.]]) # Shape (2, 3)

# We want to apply scaled_dot_product to each pair of rows in the batches.
# `in_axes=(0, 0, None)` means:
# - Map over the 0-th axis of the first argument (batch_vec1)
# - Map over the 0-th axis of the second argument (batch_vec2)
# - The third argument (scale) is broadcast (not mapped over, treated as fixed)
batched_scaled_dot_product = jax.vmap(scaled_dot_product, in_axes=(0, 0, None))

result_vmap = batched_scaled_dot_product(batch_vec1, batch_vec2, s)
print(f"Batched result using vmap: {result_vmap}")

# Let's verify manually for the first element:
manual_first = scaled_dot_product(batch_vec1[0], batch_vec2[0], s)
print(f"Manual first element: {manual_first}")
assert jnp.allclose(result_vmap[0], manual_first)

# If all arguments should be batched along their first axis:
batch_scales = jnp.array([0.5, 2.0])
batched_all_args = jax.vmap(scaled_dot_product, in_axes=(0, 0, 0)) # or just in_axes=0
result_vmap_all = batched_all_args(batch_vec1, batch_vec2, batch_scales)
print(f"Batched result (all args batched): {result_vmap_all}")

Single scaled_dot_product: 1.0
Batched result using vmap: [1. 5.]
Manual first element: 1.0
Batched result (all args batched): [ 1. 20.]
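Transformations can be stacked. A common use is computing one gradient per example in a batch by wrapping `jax.grad` in `jax.vmap`. The sketch below uses a hypothetical `squared_error` loss with made-up data and checks the result against an explicit Python loop, which is the code `vmap` saves you from writing.

```python
import jax
import jax.numpy as jnp

def squared_error(w, x, y):
    # Loss for a single example: (w . x - y)^2
    return (jnp.dot(w, x) - y) ** 2

# grad differentiates with respect to w; vmap maps over examples (x, y)
# while keeping w fixed; jit compiles the combined function.
per_example_grad = jax.jit(
    jax.vmap(jax.grad(squared_error), in_axes=(None, 0, 0))
)

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # Shape (3, 2)
ys = jnp.array([0.0, 1.0, 2.0])                       # Shape (3,)

grads = per_example_grad(w, xs, ys)  # Shape (3, 2): one gradient per example
print(f"Per-example gradients:\n{grads}")

# Reference: the same computation with an explicit Python loop.
loop_grads = jnp.stack(
    [jax.grad(squared_error)(w, x, y) for x, y in zip(xs, ys)]
)
print(f"Matches loop version: {jnp.allclose(grads, loop_grads)}")
```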


### 5. Pseudorandom Numbers (`jax.random`)

JAX handles random numbers differently from NumPy to ensure reproducibility in its functional and parallel execution model. You must explicitly manage **PRNG keys**.

1.  Create a master key: `key = jax.random.PRNGKey(seed)`
2.  Split keys for use: `key, subkey = jax.random.split(key)`
    - Each time you need random numbers for an operation, pass it a fresh `subkey`.
    - `split` does not modify anything in place; it returns new keys. Rebinding the name `key` to one of them replaces the old key, which should not be used again.
    - Random numbers generated from different subkeys are statistically independent, while reusing a key reproduces exactly the same numbers.

In [10]:
key = jax.random.PRNGKey(0) # Create a master key with a seed
print(f"Initial key: {key}")

# Split the key to generate random numbers for one operation
key, subkey1 = jax.random.split(key)
random_matrix = jax.random.normal(subkey1, (2, 2))
print(f"\nSubkey1: {subkey1}")
print(f"Random matrix (using subkey1):\n{random_matrix}")
print(f"Key after first split: {key}")

# Split the key again for another independent random operation
key, subkey2 = jax.random.split(key) 
another_random_vector = jax.random.uniform(subkey2, (3,))
print(f"\nSubkey2: {subkey2}")
print(f"Another random vector (using subkey2): {another_random_vector}")
print(f"Key after second split: {key}")

# If you use the same subkey, you get the same random numbers
random_matrix_again = jax.random.normal(subkey1, (2, 2))
print(f"\nRandom matrix again (using same subkey1):\n{random_matrix_again}")
assert jnp.allclose(random_matrix, random_matrix_again)

# Common pattern in functions:
def generate_random_data(key, shape):
  key_data, key_noise = jax.random.split(key)
  clean_data = jax.random.uniform(key_data, shape)
  noise = jax.random.normal(key_noise, shape) * 0.1
  return clean_data + noise

key, subkey_for_func = jax.random.split(key)
my_data = generate_random_data(subkey_for_func, (2,3))
print(f"\nGenerated data from function:\n{my_data}")

Initial key: [0 0]

Subkey1: [ 928981903 3453687069]
Random matrix (using subkey1):
[[-2.4424558  -2.0356805 ]
 [ 0.20554423 -0.3535502 ]]
Key after first split: [1797259609 2579123966]

Subkey2: [1353695780 2116000888]
Another random vector (using subkey2): [0.10429037 0.34398758 0.13106728]
Key after second split: [4165894930  804218099]

Random matrix again (using same subkey1):
[[-2.4424558  -2.0356805 ]
 [ 0.20554423 -0.3535502 ]]

Generated data from function:
[[ 0.5087175  -0.14265323  0.48834768]
 [ 0.7908833   0.6513083   0.9378768 ]]
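When several independent random draws are needed at once, `jax.random.split` can produce more than two keys in a single call through its `num` argument. The following sketch (the seed and shapes are arbitrary) keeps the first key for later use and hands one subkey to each draw.

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(42)

# Split into 4 keys: keep the first for future splits, use the other 3 now.
key, *subkeys = jax.random.split(key, num=4)

# One independent draw per subkey.
draws = [jax.random.normal(k, (3,)) for k in subkeys]
for i, d in enumerate(draws):
    print(f"Draw {i}: {d}")

# Reusing a subkey reproduces the same numbers exactly.
repeat = jax.random.normal(subkeys[0], (3,))
print(f"Same subkey gives same draw: {jnp.array_equal(draws[0], repeat)}")
```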


## Summary of Key Concepts

- **`jax.numpy` (`jnp`)**: Your familiar NumPy API, but for JAX arrays (which are immutable).
- **`jax.jit`**: Compiles your Python functions (using `jnp`) into fast XLA code.
- **`jax.grad`**: Computes gradients of your functions.
- **`jax.vmap`**: Automatically vectorizes your functions to handle batches of data.
- **`jax.random`**: Explicit PRNG key management for reproducible random numbers.

## Next Steps & Further Learning

This notebook covered the very basics to get you started. JAX has many more powerful features. Here are some topics you might want to explore next in the official JAX documentation:

- **Working with Pytrees:** JAX functions often operate on nested structures of arrays (like lists, tuples, dicts of arrays). Understanding pytrees is key for more complex models.
- **Control Flow (`jax.lax`):** For branches or loops that depend on array values inside JIT-compiled functions, where a plain Python `if/else` on a traced value fails (e.g., `jax.lax.cond`, `jax.lax.scan`).
- **Parallel Programming (`jax.pmap`):** For distributing computations across multiple devices (e.g., multiple GPUs or TPU cores).
- **Stateful Computations:** How to manage state in a functional programming paradigm (often involves passing state explicitly through functions).
- **Advanced Automatic Differentiation:** Hessians, Jacobians, custom VJPs/JVPs.
- **Debugging in JAX:** Techniques and tools for finding issues in JAX code.
- **JAX - The Sharp Bits 🔪:** A great resource in the JAX documentation that highlights common pitfalls and how to avoid them.
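As a small preview of the control-flow item in the list above, here is a minimal sketch of `jax.lax.cond` and `jax.lax.scan` inside jitted functions. The function names `abs_via_cond` and `running_sum` are illustrative, and the official documentation covers these primitives in much more depth.

```python
import jax
import jax.numpy as jnp

@jax.jit
def abs_via_cond(x):
    # lax.cond picks a branch based on a traced value;
    # both branches must return the same shape and dtype.
    return jax.lax.cond(x >= 0, lambda v: v, lambda v: -v, x)

@jax.jit
def running_sum(xs):
    # lax.scan carries state through a loop over the leading axis of xs.
    def step(total, x):
        new_total = total + x
        return new_total, new_total  # (next carry, per-step output)
    final, partials = jax.lax.scan(step, jnp.zeros((), xs.dtype), xs)
    return final, partials

print(abs_via_cond(-3.0))  # 3.0

final, partials = running_sum(jnp.array([1.0, 2.0, 3.0]))
print(final, partials)     # 6.0 [1. 3. 6.]
```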

The official JAX GitHub repository and its documentation are excellent resources.

Happy JAXing!