# Setting the Stage

JAX is a numerical computing library designed around composable function transformations. It is built on top of XLA (Accelerated Linear Algebra), a compiler that turns numerical computations into optimized code for CPUs, GPUs, and TPUs. Through its `jax.numpy` module, JAX is designed to act as a near drop-in replacement for NumPy, and it adds features that make it more powerful and flexible than NumPy.

At the core of JAX are three composable function transformations:
- `jax.vmap`: takes a Python function and vectorizes it, mapping it over an axis of its inputs.
- `jax.jit`: compiles a function with XLA into an optimized version (just-in-time compilation).
- `jax.grad`: automatically differentiates a Python function, returning a new function that computes its gradient.
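Because each transformation takes a function and returns a new function, they can be stacked. The sketch below uses a toy sum-of-squares function `f` (our own illustration, not from the text), whose gradient is $2x$, and applies all three transformations in turn:

```python
import jax
import jax.numpy as jnp


def f(x):
    # A scalar-valued toy function: the sum of squares of the entries of x.
    return jnp.sum(x**2)


# jax.grad returns a new function that computes the gradient of f, here 2x.
df = jax.grad(f)

# jax.jit compiles that gradient function with XLA.
df_fast = jax.jit(df)

# jax.vmap maps the compiled gradient over the leading (batch) axis of its input.
batched_df = jax.vmap(df_fast)

xs = jnp.arange(6.0).reshape(3, 2)  # three input vectors of length 2
print(batched_df(xs))  # each row is 2 times the corresponding row of xs
```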



## Installation

Installing JAX is typically as simple as running:

```bash
$ pip install jax
```

This installs a CPU-only build. For GPU or TPU support, and for a more detailed installation guide, see the [JAX documentation](https://jax.readthedocs.io/en/latest/installation.html).
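A quick way to confirm that an installation works is to check the version, list the devices JAX can see, and run a tiny computation. A minimal sketch; the exact version string and device list will depend on your installation:

```python
import jax
import jax.numpy as jnp

# Report the installed version and the devices JAX can use.
print(jax.__version__)
print(jax.devices())  # on a CPU-only install, this lists a single CPU device

# A tiny computation to confirm the backend runs end to end.
print(jnp.arange(3) + 1)
```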

## Using JAX as a Drop-In Replacement for NumPy

JAX includes most of the functionality you might be familiar with from NumPy, reimplemented on top of XLA and made compatible with JAX's function transformations. It is provided through the `jax.numpy` module, conventionally imported as `jnp`:
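The drop-in analogy has limits worth knowing early. One well-documented difference is that JAX arrays are immutable: NumPy-style in-place assignment raises a `TypeError`, and updates go through the `.at[...]` property instead, which returns a new array. A minimal sketch (the variable names are our own):

```python
import numpy as np
import jax.numpy as jnp

x_np = np.zeros(3)
x_np[0] = 1.0  # NumPy arrays can be modified in place

x = jnp.zeros(3)
raised = False
try:
    x[0] = 1.0  # JAX arrays are immutable, so this raises a TypeError
except TypeError as err:
    raised = True
    print("TypeError:", err)

# The functional equivalent returns a new, updated array and leaves x untouched.
y = x.at[0].set(1.0)
print(x)
print(y)
```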

```python
import jax.numpy as jnp
```

JAX can then be used in much the same way as NumPy. Take the simple example of a matrix-vector product,

```{math}
:label: matvec
y = Ax
```

which, for a concrete choice of $A$ and $x$, reads

$$
    \begin{bmatrix}
        0 & 1 \\
        2 & 3 \\
        4 & 5
    \end{bmatrix}
    \begin{bmatrix}
        1 \\ 1
    \end{bmatrix}
    =
    \begin{bmatrix}
        1 \\ 5 \\ 9
    \end{bmatrix}.
$$

In JAX, this is written exactly as it would be in NumPy, with `jnp` in place of `np`:

```python
A = jnp.arange(6).reshape(3, 2)
x = jnp.ones(2)

print(A @ x)
```

```
[1. 5. 9.]
```
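Writing the product out entrywise, $y_i = \sum_j A_{ij} x_j$, gives a useful cross-check on the result. This small sketch recomputes it with broadcasting instead of `@` (the name `y_broadcast` is our own):

```python
import jax.numpy as jnp

A = jnp.arange(6).reshape(3, 2)
x = jnp.ones(2)

# Broadcasting multiplies every row of A by x entrywise: (3, 2) * (2,) -> (3, 2).
# Summing over the last axis then gives y_i = sum_j A_ij x_j.
y_broadcast = jnp.sum(A * x, axis=1)

print(y_broadcast)
print(jnp.allclose(y_broadcast, A @ x))
```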


The strength of JAX lies in _composing_ functions, which lets complex operations be written in a form that closely resembles their mathematical counterparts. Continuing with the matrix-vector product example, suppose that instead of a single vector $x \in \mathbb{R}^d$ we have a set of $N$ vectors $x_1, \ldots, x_N$, collected in $X$. We are then interested in computing

$$
    y_i = Ax_{i} \quad \text{for } i  = 1, 2, \ldots, N,
$$

for a matrix $A \in \mathbb{R}^{m \times d}$. In NumPy, we would typically store these vectors as the rows of a matrix $X \in \mathbb{R}^{N \times d}$ and compute the matrix product $AX^T$. This leaves us with a matrix $Y \in \mathbb{R}^{m \times N}$ whose $i$-th column is $y_i = Ax_i$. Keeping track of these dimensions quickly becomes cumbersome, especially when several such operations are chained. One workaround is to compute the transpose $(AX^T)^T = XA^T$ directly, which stores each $y_i$ as a row, but this moves the code even further from the mathematical notation.

```python
import numpy as np

rng = np.random.default_rng(2002)
A = np.arange(6).reshape(3, 2)
X = rng.integers(-3, 3, size=(4, 2))


def apply_matrix(A: np.ndarray, X: np.ndarray) -> np.ndarray:
    return X @ A.T


Y = apply_matrix(A, X)
print(Y)
```

```
[[  2  10  18]
 [  2   0  -2]
 [  0   4   8]
 [ -3 -13 -23]]
```
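The two layouts discussed above hold the same numbers and differ only by a transpose. A short sketch that makes the shapes explicit, reusing the same random seed and data:

```python
import numpy as np

rng = np.random.default_rng(2002)
A = np.arange(6).reshape(3, 2)        # A in R^{m x d}, with m = 3 and d = 2
X = rng.integers(-3, 3, size=(4, 2))  # N = 4 vectors x_i stored as rows

Y_columns = A @ X.T  # shape (m, N): column i is A x_i
Y_rows = X @ A.T     # shape (N, m): row i is A x_i

print(Y_columns.shape, Y_rows.shape)
print(np.array_equal(Y_columns.T, Y_rows))
```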


A lesser-known alternative in NumPy is `np.einsum`, which evaluates expressions written in Einstein summation notation. It is a powerful tool, but its subscript strings can be difficult to read and write, especially for more complex operations.

```python
def apply_matrix_einsum(A: np.ndarray, X: np.ndarray) -> np.ndarray:
    return np.einsum("ij,Nj->Ni", A, X)


Y_einsum = apply_matrix_einsum(A, X)
print(np.allclose(Y, Y_einsum))
```

```
True
```
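`jax.numpy` mirrors `np.einsum` as `jnp.einsum`, so the same subscript string carries over unchanged. A minimal sketch reusing the data from above:

```python
import numpy as np
import jax.numpy as jnp

rng = np.random.default_rng(2002)
A = np.arange(6).reshape(3, 2)
X = rng.integers(-3, 3, size=(4, 2))

# "ij,Nj->Ni" contracts the shared index j and keeps the batch index N,
# so row N of the result is A @ X[N].
Y_jnp = jnp.einsum("ij,Nj->Ni", jnp.asarray(A), jnp.asarray(X))

print(np.allclose(X @ A.T, Y_jnp))
```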


In JAX, we can instead use `jax.vmap` to vectorize the matrix-vector product. We write the function for a single vector $x$, exactly as in the mathematical definition, and wrap it in a call to `jax.vmap`. The result is a new function that applies the matrix-vector product to every vector in $X$ in a single call and returns the results stacked into one array:

```python
from jax import vmap, Array


def apply_matrix_jax(A: Array, x: Array) -> Array:
    return A @ x


# Convert the NumPy arrays to JAX arrays
A_jax = jnp.array(A)
X_jax = jnp.array(X)

vectorized_apply_matrix = vmap(apply_matrix_jax, in_axes=(None, 0))
Y_vmap = vectorized_apply_matrix(A_jax, X_jax)

print(np.allclose(Y, Y_vmap))
```

```
True
```
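`jax.vmap` also accepts an `out_axes` argument that controls where the mapped axis appears in the output. Setting `out_axes=1` stacks the results as columns, recovering the $m \times N$ layout of $AX^T$ without a hand-written transpose. A sketch under the same data assumptions as above:

```python
import numpy as np
import jax.numpy as jnp
from jax import vmap

rng = np.random.default_rng(2002)
A = jnp.arange(6).reshape(3, 2)
X = jnp.asarray(rng.integers(-3, 3, size=(4, 2)))


def apply_matrix_jax(A, x):
    return A @ x


# Map over the rows of X (in_axes) and stack the outputs as columns (out_axes=1).
column_layout = vmap(apply_matrix_jax, in_axes=(None, 0), out_axes=1)
Y_cols = column_layout(A, X)

print(Y_cols.shape)  # one column per vector x_i
print(np.allclose(Y_cols, A @ X.T))
```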


Here, we define our function in the same way as Equation {eq}`matvec` and then _vectorize_ it with `jax.vmap`. The `in_axes` argument specifies which axis of each input should be mapped over. With `in_axes=(None, 0)`, `jax.vmap` maps over the first axis of the second argument (the vectors in $X$) while passing the first argument (the matrix $A$) unchanged to every call. The outputs are stacked along a new leading axis, so `Y_vmap` has shape $(N, m)$, matching the layout of `Y`.
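Other choices of `in_axes` express other batching patterns without changing `apply_matrix_jax`. In this sketch the stack of matrices `As` and vectors `xs` is hypothetical data of our own: `in_axes=(0, 0)` pairs the $n$-th matrix with the $n$-th vector, while nesting two calls to `vmap` applies every matrix to every vector:

```python
import jax.numpy as jnp
from jax import vmap


def apply_matrix_jax(A, x):
    return A @ x


# Hypothetical data: four different 3x2 matrices and four vectors of length 2.
As = jnp.arange(24.0).reshape(4, 3, 2)
xs = jnp.ones((4, 2))

# in_axes=(0, 0) maps over the leading axis of both arguments together.
paired = vmap(apply_matrix_jax, in_axes=(0, 0))
print(paired(As, xs).shape)  # (4, 3): one result per (matrix, vector) pair

# The inner vmap maps over vectors, the outer over matrices: all 4 x 4 combinations.
all_pairs = vmap(vmap(apply_matrix_jax, in_axes=(None, 0)), in_axes=(0, None))
print(all_pairs(As, xs).shape)  # (4, 4, 3)
```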

Note also the easy interplay between JAX and NumPy: we use NumPy to generate the matrix $A$ and the vectors $X$, convert them to JAX arrays, compute the matrix-vector products with JAX, and then return to NumPy to verify the result.
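The same round trip works when the transformations are composed. The sketch below assumes we want the vectorized function compiled as well, so it wraps it in `jax.jit`, then uses `np.asarray` to turn the JAX result back into an ordinary NumPy array:

```python
import numpy as np
import jax
import jax.numpy as jnp
from jax import vmap


def apply_matrix_jax(A, x):
    return A @ x


rng = np.random.default_rng(2002)
A_np = np.arange(6).reshape(3, 2)
X_np = rng.integers(-3, 3, size=(4, 2))

# Compose the transformations: vectorize first, then compile the batched function.
fast_apply = jax.jit(vmap(apply_matrix_jax, in_axes=(None, 0)))
Y_jax = fast_apply(jnp.asarray(A_np), jnp.asarray(X_np))

# np.asarray converts the JAX result into a NumPy ndarray.
Y_np = np.asarray(Y_jax)
print(type(Y_np))
print(np.array_equal(Y_np, X_np @ A_np.T))
```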