# JAX's vmap

## Lesson Goals:

By the end of this lesson, you will understand how and where to use `jax`'s `vmap` operation.

## Core Concepts:

- `vmap`
- Softmax regression
- Gaussian PDF
- Neural network inference

## Concepts in Action:

- Easy: [lotka-volterra](../case_studies/lotka-volterra/README.md)

- Intermediate: [leaky_integrate_and_fire](../case_studies/leaky_integrate_and_fire/README.md)
 
- Advanced: [gaussian_mixture_model](../case_studies/gaussian_mixture_model/README.md)


In [None]:
import jax.numpy as jnp
from jax.scipy import stats
import numpy as np
from jax import vmap
np.random.seed(42)

# Vmap

`vmap` maps a function over an axis of an array (the first axis by default). Conceptually, it behaves like a `for` loop over that axis that stacks the results into a new array.
Consider the following exercises, which add 1 to every element of an array.


In [None]:
def custom_vmap(x, func):
    return np.asarray([func(_x) for _x in x])

def scalar_add(x):
    assert len(x.shape) == 0, "x should be a scalar"
    return x + 1

def simple_vmap_example():
    """
    Here, we only have one axis, so applying the function is a one-liner.
    TODO: use `custom_vmap` and `scalar_add`
    """
    vec = np.asarray([1, 2, 3, 4])
    added_to = custom_vmap(vec, scalar_add) 
    assert np.all(added_to == vec + 1)
    print("First-vmap application exercise passed!")
   
def less_simple_vmap_example():
    """
    Here, we have two axes to map over: the first axis has 3 elements, each of which is a vector of 5 scalars.
    Implement the scalar addition once again.
    
    Hint: think of this as a vmap-on-vmap situation
    """
       
    mat = np.random.random(size=(3, 5))
    
    delayed_add = lambda x: custom_vmap(x, scalar_add)
    added_to_mat = custom_vmap(mat, delayed_add) 
    assert np.all(added_to_mat == mat + 1)
    print("Second-vmap application exercise passed!")
 
    
    
simple_vmap_example()
less_simple_vmap_example()

Clearly, the `for` loop works, so what's the issue? In frameworks like `numpy`, you want to use vectorized operations, i.e. you'd just write `x + 1`, because NumPy then runs the whole operation in compiled loops (which can use SIMD instructions) instead of paying Python interpreter overhead for every element. If you time the two approaches on a large array, the vectorized operation is typically much faster.
From a speed perspective, vectorized operations are the way to go. Unfortunately, vectorized code often looks unnatural: the equations we see in the math and their implementation in Python can look very different.
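
As a rough illustration of the speed gap, here is a minimal timing sketch. The array size, the repeat count, and the helper names `loop_add` and `vectorized_add` are assumptions chosen for this example; the absolute timings depend on your machine, so none are quoted here.

```python
import timeit

import numpy as np


def scalar_add(x):
    return x + 1


def loop_add(x):
    # Python-level loop: one interpreter round-trip per element
    return np.asarray([scalar_add(_x) for _x in x])


def vectorized_add(x):
    # A single NumPy call: the loop runs in compiled code
    return x + 1


vec = np.arange(100_000, dtype=np.float64)

# Both approaches compute the same result...
assert np.array_equal(loop_add(vec), vectorized_add(vec))

# ...but they usually differ a lot in run time
loop_time = timeit.timeit(lambda: loop_add(vec), number=5)
vectorized_time = timeit.timeit(lambda: vectorized_add(vec), number=5)
print(f"for-loop:   {loop_time:.4f} s")
print(f"vectorized: {vectorized_time:.4f} s")
```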


# JAX's `vmap`

The solution is JAX's `vmap`, which combines the speed of vectorized code with the readability of per-example code. You write the function for a single example, and JAX transforms it into one that operates on a whole batch by adding a batch axis to the operations inside it; see [Jax - automatic vectorization](https://jax.readthedocs.io/en/latest/automatic-vectorization.html) for details. The process is similar to what happens when we `jit` a function (both work by tracing it), and the two are composable. Check out the [gaussian_mixture_model](../case_studies/gaussian_mixture_model/README.md) case study to see this in action (fair warning: there's quite a bit going on).
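
To make this concrete, the earlier "vmap-on-vmap" exercise can be sketched with `jax.vmap` itself and then wrapped in `jax.jit`. The names `add_matrix` and `add_matrix_jit` and the input values are made up for this sketch.

```python
import jax
import jax.numpy as jnp


def scalar_add(x):
    # Written for a single scalar. The assertion still holds under vmap,
    # because the function is traced with the mapped axis removed.
    assert x.ndim == 0, "x should be a scalar"
    return x + 1.0


mat = jnp.arange(15.0).reshape(3, 5)

# One vmap per axis: the outer vmap maps over rows, the inner one over entries
add_matrix = jax.vmap(jax.vmap(scalar_add))

# vmap and jit compose: the batched function can be compiled as a whole
add_matrix_jit = jax.jit(add_matrix)

assert jnp.allclose(add_matrix(mat), mat + 1.0)
assert jnp.allclose(add_matrix_jit(mat), mat + 1.0)
```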

It's important to note that although

> Performance-wise, automatically vectorized code written with vmap often lowers to an identical or near-identical sequence of XLA operations.

as per [when to use vmap](https://github.com/jax-ml/jax/discussions/18873), there are occasions where `vmap` is slower than hand-vectorized code. This is not the expected behavior, so if you run into such a case, consider opening an issue on GitHub.
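
One way to check this claim for your own function is to compare the traced programs with `jax.make_jaxpr`. The sketch below uses a small hypothetical layer (`per_row` and `batched` are names invented for this example) and prints both jaxprs for inspection; the exact printed form varies between JAX versions.

```python
import jax
import jax.numpy as jnp

W = jnp.array([[0.2, 0.4], [0.5, 0.3]])
X = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])


def per_row(x):
    # Written for a single row of shape (2,)
    return jnp.maximum(0.0, x @ W)


def batched(X):
    # Hand-vectorized version for a batch of shape (n, 2)
    return jnp.maximum(0.0, X @ W)


# Trace both versions and print the resulting programs for comparison
jaxpr_vmap = jax.make_jaxpr(jax.vmap(per_row))(X)
jaxpr_vectorized = jax.make_jaxpr(batched)(X)
print(jaxpr_vmap)
print(jaxpr_vectorized)

# The two versions also agree numerically
assert jnp.allclose(jax.vmap(per_row)(X), batched(X))
```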


# Simple Introduction

Here is a quick working introduction to `vmap`'s `in_axes` argument and what happens when you use it.

In [None]:
def vmap_on_all():
    def my_func(_x, _y, _z):
        assert len(_x.shape) == 0
        assert len(_y.shape) == 0
        assert len(_z.shape) == 0
        return _x + _y + _z

    x = jnp.asarray([1, 2, 3])
    y = jnp.asarray([1, 2, 3])
    z = jnp.asarray([1, 2, 3])
   
    # For each argument, we specify the axis to "map over"
    #   In this example, we essentially compute
    #   [(1 + 1 + 1), (2 + 2 + 2), (3 + 3 + 3)], which equals x * 3
    vmapped_fn_v1 = vmap(my_func, in_axes=(0, 0, 0))
    res = vmapped_fn_v1(x, y, z)
    assert jnp.all(
        res == 
        x * 3
    )
    
    # Alternatively, we can leave out `in_axes`: it defaults to 0,
    #   i.e. map over axis 0 of every argument
    vmapped_fn_v2 = vmap(my_func)
    res = vmapped_fn_v2(x, y, z)
    assert jnp.all(
        res == 
        x * 3
    )
    
def vmap_broadcast():
    def my_func(_x, _y, _z):
        assert len(_x.shape) == 0
        assert len(_y.shape) == 1 and _y.shape == (3,)
        assert len(_z.shape) == 1 and _z.shape == (3,)
        return _x * (_y + _z)

    x = jnp.asarray([1, 2, 3])
    y = jnp.asarray([1, 2, 3])
    z = jnp.asarray([1, 2, 3])
    
    # Here, we specify "None" for y and z, which means that they are passed in "as is"
    vmapped_fn = vmap(my_func, in_axes=(0, None, None))
    res = vmapped_fn(x, y, z)
    
    # x is mapped element by element, while y and z are passed in whole as vectors,
    #   so each output row is x_i * (y + z) and the result gains a dimension:
    #   [(1 * (2, 4, 6)), (2 * (2, 4, 6)), (3 * (2, 4, 6))]
    #   = ((2, 4, 6), (4, 8, 12), (6, 12, 18))
    assert jnp.all(
        res == 
        jnp.asarray([[2, 4, 6], [4, 8, 12], [6, 12, 18]])
    )
    assert res.shape == (3, 3)
    
    
    
vmap_on_all()
vmap_broadcast()
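
The examples above only map over axis 0. As a further sketch, `in_axes` can also name a different axis per argument, and `out_axes` controls where the mapped axis appears in the result. The functions `scale_and_sum` and `scale_vector` and all values are made up for this illustration.

```python
import jax.numpy as jnp
from jax import vmap


def scale_and_sum(column, weight):
    # column: vector of shape (2,); weight: scalar
    assert column.shape == (2,)
    assert weight.shape == ()
    return weight * jnp.sum(column)


mat = jnp.asarray([[1.0, 2.0, 3.0],
                   [4.0, 5.0, 6.0]])       # shape (2, 3)
weights = jnp.asarray([1.0, 10.0, 100.0])  # shape (3,)

# Map over axis 1 of `mat` (its columns) and axis 0 of `weights`
res = vmap(scale_and_sum, in_axes=(1, 0))(mat, weights)
assert res.shape == (3,)
assert jnp.allclose(res, jnp.asarray([5.0, 70.0, 900.0]))


def scale_vector(a, y):
    # a: scalar; y: vector of shape (3,)
    return a * y


a = jnp.asarray([1.0, 2.0])

# By default the mapped axis becomes axis 0 of the output
rows = vmap(scale_vector, in_axes=(0, None))(a, weights)
# With out_axes=1 the mapped axis becomes axis 1 instead
cols = vmap(scale_vector, in_axes=(0, None), out_axes=1)(a, weights)
assert rows.shape == (2, 3)
assert cols.shape == (3, 2)
assert jnp.allclose(cols, rows.T)
```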

# Real World Example: Batched Neural Network Inference

And on to some real-world examples! Throughout the rest of the notebook, each "core" function is implemented in two ways, `x` and `x_vmap`: `x` is the function written in hand-vectorized form, and `x_vmap` is the single-example form that we pass to `vmap`. Hopefully this illustrates how using `vmap` aligns our code more closely with the math.

Note: in the first example, the two implementations look the same, which is OK! This one is more of a warm-up.

In [None]:
# Define the neural network parameters
W1 = jnp.array([[0.2, 0.4], [0.5, 0.3]])  # Shape (2, 2)
W2 = jnp.array([0.6, 0.7])                # Shape (2,)

# Example batched input data
X_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # Shape (3, 2)

# Activation function
def relu(x):
    return jnp.maximum(0, x)

# Forward pass for a single example
def forward_pass_vmap(x, W1, W2):
    """
    TODO: Your code here
    1) Take the dot product between x and W1 (i.e. x @ W1)
    2) Apply the relu
    3) Return the dot product between the relu result and W2
    """
    assert x.shape == X_batch.shape[1:]
    return relu(x.dot(W1)).dot(W2)

def forward_pass(X, W1, W2):
    return relu(X.dot(W1)).dot(W2)
    

# Vectorized forward pass using vmap
batched_forward_pass = vmap(forward_pass_vmap, in_axes=(0, None, None))
vmap_batch_output = batched_forward_pass(X_batch, W1, W2)

vectorized_batch_output = forward_pass(X_batch, W1, W2)

# Compare with a tolerance rather than exact floating-point equality
assert jnp.allclose(
    vmap_batch_output,
    vectorized_batch_output
)


# Calculating the Gaussian PDF

`gaussian_pdf_v` is the single-sample function that we pass to `vmap`, while `gaussian_pdf` implements the same computation in vectorized form. Study how the two differ
and how this difference arises from the shape of the data that is passed in.

![](../assets/gaussian_pdf.png)
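
For reference, the code in this section appears to implement the standard multivariate normal density, written out below under that assumption (the image itself is not reproduced here). The symbols $k$, $\mu$ and $\Sigma$ correspond to `k`, `mu` and `Sigma` in the code. The second line is the identity behind the batched exponent: with $D$ the matrix whose rows are $d_i = x_i - \mu$, the quadratic form for row $i$ is a row-wise sum of an elementwise product.

```latex
p(x) = (2\pi)^{-k/2} \, \det(\Sigma)^{-1/2} \,
       \exp\!\left( -\tfrac{1}{2} (x - \mu)^{\top} \Sigma^{-1} (x - \mu) \right)

d_i^{\top} \Sigma^{-1} d_i = \sum_{j} \left( D \, \Sigma^{-1} \right)_{ij} D_{ij}
```

In the code, the first two factors of the density are `t1` and `t2`, and the argument of the exponential is `to_exp`.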

In [None]:
mu = np.array([0, 0])  # Mean vector
Sigma = np.array([[1, 0], [0, 1]])  # Covariance matrix
X = np.array([[1, 1], [2, 2], [3, 3]])  # Points at which to evaluate the PDF

# Shared constants, read from the enclosing scope to keep the example short
k = mu.shape[0]
t1 = (2 * jnp.pi) ** (-k / 2)
t2 = jnp.linalg.det(Sigma) ** (-0.5)
inv = jnp.linalg.inv(Sigma)

def gaussian_pdf_v(x_vec, mu_vec):
    """
    TODO: implement the single-sample equivalent of `to_exp` in `gaussian_pdf`,
          i.e. the term inside the exponential in the image above
    """
    diff = x_vec - mu_vec
    to_exp = -0.5 * diff.T @ inv @ diff
    return t1 * t2 * jnp.exp(to_exp)

def gaussian_pdf(x_mat, mu_mat) -> jnp.ndarray:
    diff = x_mat - mu_mat
    ###############################################################
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    return t1 * t2 * jnp.exp(to_exp)


vmapped_gaussian = vmap(gaussian_pdf_v, in_axes=(0, None))
vmap_gauss_res = vmapped_gaussian(X, mu)


print("VMapped-Gaussian PDF correct?", jnp.allclose(
    vmap_gauss_res, 
    stats.multivariate_normal.pdf(X, mu, Sigma)
))

print("Typical-Gaussian PDF correct?", jnp.allclose(
    gaussian_pdf(X, mu), 
    stats.multivariate_normal.pdf(X, mu, Sigma)
))

# Softmax Regression

![](../assets/softmax_regression.png)

In [None]:
# Example data
X = jnp.array([[1, 2], [2, 3], [3, 4]])  # Batch of inputs
W = jnp.array([[0.2, 0.8], [0.5, 0.1]])  # Weight matrix
b = jnp.array([0.1, -0.2])  # Bias vector

In [None]:
def softmax_regression(X, W, b):
    logits = jnp.dot(X, W) + b
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

# Calculate softmax probabilities for the batch of inputs
probabilities = softmax_regression(X, W, b)

In [None]:
def softmax_regression_v(x, W, b):
    """
    TODO: Implement the equivalent of the softmax_regression above on a single row `x` of `X`
    """
    logits = x.dot(W) + b
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits)

# Vectorize the single input calculation
vectorized_softmax_regression = vmap(softmax_regression_v, in_axes=(0, None, None))

# Calculate softmax probabilities using vmap
probabilities_vmap = vectorized_softmax_regression(X, W, b)
print(f"Vmapped Softmax Regression equal to vectorized?: {np.allclose(probabilities, probabilities_vmap)}")
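
As an additional cross-check, the per-row formulation can be compared against JAX's built-in `jax.nn.softmax`. This is a sketch rather than part of the exercise: the helper name `softmax_regression_row` is made up here, and the data is the same as above but written as floats.

```python
import jax
import jax.numpy as jnp

X = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])  # Batch of inputs
W = jnp.array([[0.2, 0.8], [0.5, 0.1]])              # Weight matrix
b = jnp.array([0.1, -0.2])                           # Bias vector


def softmax_regression_row(x, W, b):
    # Written for a single input row of shape (2,)
    assert x.shape == (2,)
    return jax.nn.softmax(x @ W + b)


# Map over the rows of X; W and b are shared across the batch
probs = jax.vmap(softmax_regression_row, in_axes=(0, None, None))(X, W, b)

# Hand-vectorized reference: softmax over the last axis of the batched logits
reference = jax.nn.softmax(X @ W + b, axis=-1)

assert probs.shape == (3, 2)
assert jnp.allclose(probs, reference)
# Each row is a probability distribution
assert jnp.allclose(probs.sum(axis=-1), 1.0)
```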

# Further Exercises:

## 1) Read up on [jax.pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html)

`pmap` is a parallel map: it runs a function across multiple devices (e.g. GPUs or TPU cores), which is useful for scaling a computation beyond a single device.
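
As a starting point for this exercise, here is a minimal sketch of `jax.pmap`. It assumes nothing about your hardware: the leading axis of the input is sized to `jax.local_device_count()`, so on a machine with a single device it simply runs with one shard. The function name `per_device` is made up for this example.

```python
import jax
import jax.numpy as jnp

# pmap expects the mapped axis to be no larger than the number of local devices
n_devices = jax.local_device_count()
x = jnp.arange(n_devices * 4.0).reshape(n_devices, 4)


def per_device(chunk):
    # Each device receives one slice of shape (4,)
    return chunk + 1.0


out = jax.pmap(per_device)(x)
assert out.shape == (n_devices, 4)
assert jnp.allclose(out, x + 1.0)
```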