# Session 27 🐍

☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️☀️

***

# 214. JAX
`JAX` is a high-performance numerical computation library, primarily for scientific computing and machine learning research. It is developed by Google.

But that description sells it short. Think of it as NumPy on steroids, fused with a compiler, and designed for the modern hardware landscape (GPUs/TPUs).

Its superpower is composable function transformations, which allow you to write elegant, pure Python code that can be automatically differentiated, compiled, and executed in parallel across accelerators.

***

# 215. The Foundation: NumPy-like API
If you know NumPy, you already know most of JAX's basic API. The `jax.numpy` module (conventionally imported as `jnp`) mirrors NumPy's interface closely, so you mostly write `jnp` where you would write `np`.

```python
import jax.numpy as jnp
import numpy as np

# They look the same
x_np = np.array([1., 2., 3.])
x_jnp = jnp.array([1., 2., 3.])

# You can do similar operations
y_np = np.sin(x_np)
y_jnp = jnp.sin(x_jnp)

print("NumPy: ", y_np)
print("JAX:   ", y_jnp)
```

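This familiarity is a huge advantage. Simple NumPy code can often be ported to JAX by replacing `np` with `jnp`, but not everything carries over unchanged: JAX arrays are immutable and randomness is handled explicitly, as the next two sections show.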
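As a quick illustration of such a port, the NumPy and JAX versions below differ only in the module prefix (softplus is just an arbitrary example function chosen here, not something from the original). One difference to keep in mind: JAX defaults to 32-bit floats unless 64-bit mode is enabled, so results match NumPy's float64 only up to float32 precision.

```python
import numpy as np
import jax.numpy as jnp

# Example function written against NumPy
def softplus_np(x):
    return np.log1p(np.exp(x))

# The same body with np swapped for jnp
def softplus_jnp(x):
    return jnp.log1p(jnp.exp(x))

x = np.linspace(-3.0, 3.0, 7)   # a NumPy float64 array
y_np = softplus_np(x)
y_jax = softplus_jnp(x)          # jnp functions also accept NumPy arrays as input

print(y_jax.dtype)               # float32 by default (64-bit mode disabled)
y_back = np.asarray(y_jax)       # convert the JAX array back to a NumPy array
print(np.allclose(y_np, y_back, atol=1e-5))
```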

***

# 216. Immutability
JAX arrays are immutable: you cannot write `x_jnp[0] = 10`. Instead, you use the `.at[...]` indexed-update syntax, which returns a new array with the change applied. This is fundamental to JAX's functional programming model.

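```python
# This would raise a TypeError: JAX arrays do not support item assignment
# x_jnp[0] = 10

# Correct JAX way: .at[...].set(...) returns a NEW array
new_x_jnp = x_jnp.at[0].set(10)

print(x_jnp)      # original values, unchanged
print(new_x_jnp)  # first element replaced by 10
```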
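A short sketch of what the failed assignment looks like, plus a few other `.at[...]` update methods (`add`, `multiply`, and `set` with an index array). In every case the original array is left untouched; the array values are arbitrary examples.

```python
import jax.numpy as jnp

x = jnp.array([1.0, 2.0, 3.0])

# In-place item assignment raises a TypeError
assignment_failed = False
try:
    x[0] = 10.0
except TypeError as err:
    assignment_failed = True
    print("TypeError:", str(err).splitlines()[0])

y = x.at[1].add(5.0)                   # -> [1., 7., 3.]
z = x.at[:2].multiply(10.0)            # -> [10., 20., 3.]
w = x.at[jnp.array([0, 2])].set(0.0)   # -> [0., 2., 0.]
print(x)  # still [1., 2., 3.]
```

Inside a `jit`-compiled function, XLA can often turn such updates into true in-place writes when the old array is no longer used, so the functional style does not necessarily cost a copy.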

***

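# 217. PRNG for Randomness
NumPy's classic `np.random` functions manage random state implicitly through a hidden global generator. JAX instead requires you to handle it explicitly using a PRNG (Pseudo-Random Number Generator) key, which keeps results reproducible and safe to use in parallel and compiled code. A key is never updated in place; you `split` it to obtain fresh keys.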

```python
from jax import random

key = random.PRNGKey(42)  # Create a key from the seed 42
key, subkey = random.split(key)  # Split: keep `key` for later, consume `subkey` now
random_numbers = random.normal(subkey, shape=(3,))
```
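Two more key-handling patterns, sketched below: `random.split(key, n)` produces several subkeys at once, and `random.fold_in(key, i)` derives a key deterministically from an integer such as a training step. Both are standard `jax.random` functions; the variable names and the step value 7 are illustrative.

```python
import jax.numpy as jnp
from jax import random

key = random.PRNGKey(0)

# Split into 1 key to carry forward plus 3 subkeys to consume now
key, *subkeys = random.split(key, 4)
draws = [random.uniform(k, shape=(2,)) for k in subkeys]

# Derive a per-step key from an integer; the same step always yields the same key
step_key_a = random.fold_in(key, 7)
step_key_b = random.fold_in(key, 7)
same_step = bool(jnp.all(step_key_a == step_key_b))
print(same_step)  # True

# Different subkeys give different numbers
print(draws[0], draws[1])
```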

***

# 218. The Magic: Composable Transformations
This is the heart of JAX. It provides functions that take a Python function and return a new, transformed function.

***

## 218-1. Gradients: grad
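The `grad` function performs automatic differentiation. You give it a function that takes arrays and returns a scalar, and it gives you back a new function that computes its gradient. By default, the gradient is taken with respect to the first positional argument; the `argnums` parameter selects other arguments.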

```python
from jax import grad

def loss_fn(params, data):
    return jnp.sum(params ** 2)  # A simple quadratic loss (data is unused here)

# Create the gradient function; it differentiates w.r.t. the first argument (params)
grad_fn = grad(loss_fn)

params = jnp.array([1.0, 2.0, 3.0])
data = jnp.ones(3)  # Dummy data

gradients = grad_fn(params, data)
print(gradients)  # [2. 4. 6.] (the gradient of sum(x_i^2) is 2*x_i)
```
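Since loss functions usually take more than one argument, here is a sketch of two related options: `argnums` to differentiate with respect to several arguments, and `jax.value_and_grad`, which returns the function value together with the gradient in one call. The function `weighted_sum` is a made-up example.

```python
import jax
import jax.numpy as jnp

def weighted_sum(w, x):
    return jnp.sum(w * x)  # d/dw = x, d/dx = w

w = jnp.array([1.0, 2.0, 3.0])
x = jnp.array([4.0, 5.0, 6.0])

# Gradients with respect to both arguments
dw, dx = jax.grad(weighted_sum, argnums=(0, 1))(w, x)
print(dw, dx)  # dw equals x, dx equals w

# Value and gradient (w.r.t. the first argument) in one pass
value, dw_only = jax.value_and_grad(weighted_sum)(w, x)
print(value)  # 1*4 + 2*5 + 3*6 = 32
```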

***

## 218-2. Just-In-Time Compilation: jit
The `jit` function compiles your function using XLA (Accelerated Linear Algebra), the compiler also used by TensorFlow and for Google's TPUs.

Without `jit`, JAX dispatches each operation from Python one at a time. With `jit`, JAX traces the function once, and XLA analyzes the resulting operations, fuses them together, and generates optimized machine code for CPUs, GPUs, or TPUs. This can give large speedups, especially on accelerators and for functions built from many small operations.

```python
from jax import jit

def slow_fn(x):
    return x * x + x * 2.0  # Some element-wise operations

# Compile it!
fast_fn = jit(slow_fn)

# The first call traces and compiles the function; later calls with the same
# input shapes and dtypes reuse the compiled code and skip that overhead
result = fast_fn(jnp.ones((1000, 1000)))
```
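Because JAX dispatches work asynchronously, a fair timing has to wait with `.block_until_ready()` before stopping the clock. Below is a rough sketch using `time.perf_counter` (a tool such as `%timeit` gives more reliable numbers); the `time_call` helper is hypothetical, and the actual timings depend on your hardware.

```python
import time
import jax
import jax.numpy as jnp

def slow_fn(x):
    return x * x + x * 2.0

fast_fn = jax.jit(slow_fn)
x = jnp.ones((1000, 1000))

def time_call(fn, arg):
    # Hypothetical helper: run fn once and wait for the async result
    start = time.perf_counter()
    out = fn(arg).block_until_ready()
    return out, time.perf_counter() - start

_, t_first = time_call(fast_fn, x)       # includes tracing + compilation
out, t_second = time_call(fast_fn, x)    # reuses the cached compiled code
out_eager, t_eager = time_call(slow_fn, x)

print(f"eager: {t_eager:.4f}s  jit first: {t_first:.4f}s  jit second: {t_second:.4f}s")
```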

***

## 218-3. Automatic Vectorization: vmap
The `vmap` function automatically adds a batch dimension to your function. It "vectorizes" a function written for a single example so it can efficiently process a batch of examples.

It removes the need to thread batch dimensions through your code by hand or to write explicit Python loops, which keeps the code simpler and often lets `jit` generate more efficient code.

```python
import jax.numpy as jnp
from jax import random, vmap

# A function that works on a single vector
def predict_single(params, input_vec):
    return jnp.dot(params, input_vec)

# Split the key so inputs and params are drawn independently (never reuse a key)
key, key_inputs, key_params = random.split(key, 3)
batched_input = random.normal(key_inputs, (100, 5))  # a batch of 100 input vectors
params = random.normal(key_params, (5,))

# Naive way: slow Python loop, one call per example
predictions_loop = jnp.array([predict_single(params, x) for x in batched_input])

# JAX way: automatic vectorization
# in_axes=(None, 0): don't batch params, batch over axis 0 of the input
predict_batched = vmap(predict_single, in_axes=(None, 0))
predictions_vmap = predict_batched(params, batched_input)

# Both have shape (100,) and agree up to floating-point rounding,
# but vmap runs a single batched computation and can be jit-compiled.
```
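`in_axes` also controls which axis gets mapped. In the sketch below (toy arrays made up for illustration), the default maps axis 0 of every argument, pairing up rows, while `in_axes=(1, 1)` pairs up columns instead.

```python
import jax.numpy as jnp
from jax import vmap

def dot(a, b):
    return jnp.dot(a, b)

A = jnp.arange(6.0).reshape(3, 2)  # [[0, 1], [2, 3], [4, 5]]
B = jnp.ones((3, 2))

# Default in_axes=0 for every argument: dot product of matching rows
row_dots = vmap(dot)(A, B)                  # -> [1., 5., 9.]

# Map over axis 1 instead: dot product of matching columns
col_dots = vmap(dot, in_axes=(1, 1))(A, B)  # -> [6., 9.]

print(row_dots, col_dots)
```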

***

## 218-4. Parallelization across devices: `pmap`
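The `pmap` function is like `vmap`, but instead of vectorizing along a batch dimension, it splits that dimension across multiple devices (e.g., multiple GPUs or TPU cores) and runs the function on each of them in parallel; the size of the mapped axis can be at most the number of available devices. Cross-device communication is done with collective operations such as `jax.lax.psum` and `jax.lax.pmean`, which refer to the mapped axis by its `axis_name`; for example, `pmean` can average gradients in data-parallel training.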
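A minimal `pmap` sketch that assumes nothing beyond the devices JAX can see, so it also runs on a single CPU (where the mapped axis simply has size 1). Each device gets one row, and `jax.lax.psum` sums across devices through the named axis; the name `"devices"` and the normalization task are arbitrary choices for illustration.

```python
import jax
import jax.numpy as jnp

n_devices = jax.local_device_count()

# One row of data per device: shape (n_devices, 4)
x = jnp.arange(n_devices * 4, dtype=jnp.float32).reshape(n_devices, 4)

def normalize(row):
    # Sum of this device's row, then summed across all devices
    total = jax.lax.psum(jnp.sum(row), axis_name="devices")
    return row / total

out = jax.pmap(normalize, axis_name="devices")(x)
print(out.shape, out.sum())  # all entries together sum to 1
```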

***

# 219. The Power of Composition
The true genius of JAX is that these transformations are composable. You can `jit` a `vmap` of a `grad`, as in the example below, or nest them in other orders. This composability is what makes JAX so powerful and expressive for research.

```python
# Per-example gradients: grad differentiates loss_fn w.r.t. params,
# vmap maps it over axis 0 of the data batch (params are shared via None),
# and jit compiles the whole pipeline.
per_example_grad = jit(vmap(grad(loss_fn), in_axes=(None, 0)))
```
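Because `loss_fn` above ignores its data, every example would get the same gradient. Here is a sketch with a loss that does depend on the data, a squared error for a linear model (the model, weights, and toy batch are made up). Averaging the per-example gradients recovers the gradient of the mean loss.

```python
import jax.numpy as jnp
from jax import grad, jit, vmap

def squared_error(w, x, y):
    # Loss for ONE example: (w . x - y)^2
    return (jnp.dot(w, x) - y) ** 2

w = jnp.array([1.0, -1.0])
xs = jnp.array([[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]])  # 3 examples
ys = jnp.array([0.0, 0.0, 1.0])

# Share w, map over axis 0 of both xs and ys
per_example = jit(vmap(grad(squared_error), in_axes=(None, 0, 0)))
g = per_example(w, xs, ys)  # shape (3, 2): one gradient per example
print(g)  # [[2, 0], [0, -2], [-2, -2]]

# Gradient of the mean loss equals the mean of the per-example gradients
def mean_loss(w, xs, ys):
    return jnp.mean(vmap(squared_error, in_axes=(None, 0, 0))(w, xs, ys))

batch_grad = grad(mean_loss)(w, xs, ys)
```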

***

# 220. JAX's Ecosystem: Libraries Built on JAX
For complex models, you rarely write everything in raw JAX. Instead, you use libraries that sit on top of it and handle neural network specifics:

`Flax:` A high-level, flexible neural network library developed by Google. It is widely used for new projects and provides modules (layers), serialization, and training utilities, while optimizers usually come from Optax. It plays roughly the role for JAX that Keras plays for TensorFlow.

`Haiku:` A neural network library from DeepMind, modeled on the Sonnet library (from TensorFlow). You write object-oriented modules, and `hk.transform` turns them into a pair of pure `init`/`apply` functions, cleanly separating model definition from parameters and state. Its maintainers now recommend Flax for new projects.

`Optax:` A library for gradient processing and optimization (e.g., Adam, SGD). It is commonly used with both Flax and Haiku.

`RLax:` A library of reinforcement learning building blocks.
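To show what an optimizer library like Optax packages up, here is a sketch of plain gradient descent written directly in JAX, with parameters stored in a dictionary (a pytree). This is not Optax's API; with Optax you would replace the manual update with an optimizer such as `optax.sgd` or `optax.adam`. The linear model, learning rate, and toy data are made up for illustration.

```python
import jax
import jax.numpy as jnp

def mse(params, x, y):
    pred = x @ params["w"] + params["b"]
    return jnp.mean((pred - y) ** 2)

@jax.jit
def sgd_step(params, x, y):
    learning_rate = 0.1
    loss, grads = jax.value_and_grad(mse)(params, x, y)
    # Apply p <- p - lr * g to every leaf of the params pytree
    new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
    return new_params, loss

# Toy data from y = 2x + 1
x = jnp.array([[0.0], [1.0], [2.0], [3.0]])
y = 2.0 * x[:, 0] + 1.0

params = {"w": jnp.zeros((1,)), "b": jnp.array(0.0)}
initial_loss = mse(params, x, y)
for _ in range(300):
    params, loss = sgd_step(params, x, y)

print(params["w"], params["b"])  # close to 2 and 1
```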

***

# Some Exercises

**1.** Create a 1D JAX array x with the values [5, 10, 15, 20].

Try to change the third element (value 15) to a 99 using standard Python assignment (x[2] = 99).

Observe and understand the error.

Now, do it correctly using the .at[].set() method and store the result in a new variable x_new.

Print both x and x_new to confirm the original array was unchanged.

___

**2.** Create a PRNG key with a seed of 0.

Use this key to generate a random array of shape (2, 3) from a normal distribution (mean=0, std=1). Store it as random_array_1.

Now, generate another random array of the same shape from the same distribution using the same original key. Store it as random_array_2.

Check if random_array_1 and random_array_2 are identical. They should be.

Now, split the original key to get a new subkey. Use this subkey to generate a new random array, random_array_3.

Verify that random_array_1 and random_array_3 are different.

---

**3.** Define a simple function f(x) that returns the sum of the cubes of its elements: $f(x) = \sum_i x_i^3$.

Define a point x = jnp.array([1.0, 2.0, 3.0, 4.0]).

Manually compute what the gradient of f(x) should be at point x (hint: $\frac{\partial f}{\partial x_i} = 3x_i^2$).

Now, use jax.grad to create a gradient function df_dx = grad(f).

Call df_dx(x) and verify the result matches your manual calculation.

---

**4.** Create a large random vector v of size 10,000,000.

Write a function slow_function(x) that performs a non-trivial element-wise operation (e.g., x * x + jnp.sin(x) * jnp.cos(x) ** 2).

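Time how long it takes to execute result = slow_function(v). JAX dispatches work asynchronously, so call `.block_until_ready()` on the result before stopping the timer.

Now, create a JIT-compiled version: fast_function = jit(slow_function).

Time the first call to fast_function(v) (this includes compilation time).

Time a second call to fast_function(v) and compare it with the uncompiled version. The size of the speedup depends on your hardware.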

***

**5.** Define a function predict(weights, bias, x) that takes weights, a bias, and a single input vector, and returns a single prediction (e.g., a dot product plus a bias: result = jnp.dot(weights, x) + bias).

Create a batch of 1000 input vectors, batch_inputs, of shape (1000, 5).

Create weights of shape (5,) and a scalar bias.

Compute the predictions for the entire batch using a naive Python for loop. Time it.

Now, use vmap to create a batched version of predict. Hint: You'll need to specify which axes to map over (e.g., in_axes=(None, None, 0) for (weights, bias, x)).

Compute the predictions on the entire batch using the vmap'd function. Time it and compare the speed and result to the loop method.

***

**6.** Define a function f(x) that takes a 3-element vector and returns a 2-element vector (e.g., f(x) = jnp.array([x[0] * x[1], x[1] ** 2 + x[2]])).

Define an input point x = jnp.array([1.0, 2.0, 3.0]).

The Jacobian is a matrix $J$ where $J_{ij} = \frac{\partial f_i}{\partial x_j}$.

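Compute the Jacobian of f at point x by combining vmap and grad in a single line of code, without any explicit loops.
Hint: grad only works on scalar-valued functions, so grad(f) fails for this f. Instead, grad(lambda x: f(x)[i]) gives the gradient of the i-th output (row i of $J$); vmap that over i = jnp.arange(2). You can check your answer against jax.jacrev(f)(x).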

***

**7.** Define a simple mean-squared-error loss function mse(params, input, target) for a single data point. Assume params is a dictionary containing {'weights': w, 'bias': b}.

Your goal is to create a function that calculates the per-example gradients of the loss with respect to params for an entire batch of inputs and targets, as fast as possible. (Averaging them over the batch axis gives the gradient of the mean loss.)

Construct this function by composing:
- a) grad to get the gradient for a single example.
- b) vmap to vectorize this gradient function over a batch.
- c) jit to compile the entire thing for maximum speed.

Apply this function to a dummy batch of data and print the resulting gradients.

***

**8.** Install Flax: `pip install flax`

Import flax.linen as nn.

Define a simple multi-layer perceptron as a class MLP(nn.Module) with one hidden layer of 64 units (followed by a ReLU activation, nn.relu) and an output layer of 10 units (e.g., for classification). Use @nn.compact and the nn.Dense module.

Initialize a PRNG key.

Create a dummy input batch of shape (32, 784) (e.g., mimicking 32 MNIST images).

Use the init method to initialize the parameters of the model.

Use the apply method to perform a forward pass of the dummy batch through the network.

***

#                                                        🌞 https://github.com/AI-Planet 🌞