# JAX's vmap

## Lesson Goals:

By the end of this lesson, you will understand how and where to use `jax`'s `vmap` operation.

## Core Concepts:

- `vmap`
- Softmax regression
- Gaussian PDF
- Neural network inference

## Concepts in Action:

- Easy: [lotka-volterra](../case_studies/lotka-volterra/README.md)

- Intermediate: [leaky_integrate_and_fire](../case_studies/leaky_integrate_and_fire/README.md)
 
- Advanced: [gaussian_mixture_model](../case_studies/gaussian_mixture_model/README.md)
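Before the exercises, here is a minimal sketch of what `vmap` does. It takes a function written for a single example and returns a function that maps it over a batch axis. The function `dot` and the arrays are illustrative only and are not part of the lesson's exercises.

```python
import jax.numpy as jnp
from jax import vmap

# A function written for ONE example: the dot product of two vectors.
def dot(a, b):
    return jnp.dot(a, b)

xs = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # batch of 3 vectors, shape (3, 2)
w = jnp.array([10.0, 1.0])                            # one shared vector, shape (2,)

# in_axes=(0, None): map over the leading axis of `xs`, reuse `w` for every row.
batched_dot = vmap(dot, in_axes=(0, None))
out = batched_dot(xs, w)

# The equivalent explicit Python loop, for comparison.
loop_out = jnp.stack([dot(x, w) for x in xs])

print(out)                          # [12. 34. 56.]
print(jnp.allclose(out, loop_out))  # True
```

`None` in `in_axes` marks an argument that is broadcast unchanged to every call. This is the usual choice for parameters such as weights.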


In [None]:
import jax.numpy as jnp
from jax.scipy import stats
import numpy as np
from jax import vmap
np.random.seed(42)

# Example: Batched Neural Network Inference

In [None]:
# Define the neural network parameters
W1 = jnp.array([[0.2, 0.4], [0.5, 0.3]])  # Shape (2, 2)
W2 = jnp.array([0.6, 0.7])                # Shape (2,)

# Activation function
def relu(x):
    return jnp.maximum(0, x)

# Single forward pass
def forward_pass(x, W1, W2):
    """
    TODO: Your code here
    1) Take the dot product between W1 and x
    2) apply the relu
    3) return the dot product between W2 and the relu result
    """
    raise NotImplementedError

# Vectorized forward pass using vmap

batched_forward_pass = ... # TODO: Your code here! vmap(forward_pass, ...)

# Example batched input data
X_batch = jnp.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])  # Shape (3, 2)

# Compute the output for the batch
batch_output = batched_forward_pass(X_batch, W1, W2)
print(batch_output)
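`in_axes` also tells `vmap` *which* axis holds the batch. `X_batch` above stores one example per row, so its batch axis is 0. The hypothetical `norm` function below shows a sketch of how the same single-vector function produces per-row or per-column results, depending on `in_axes`.

```python
import jax.numpy as jnp
from jax import vmap

def norm(v):
    # Euclidean length of a single vector
    return jnp.sqrt(jnp.sum(v ** 2))

M = jnp.array([[3.0, 0.0, 1.0],
               [4.0, 2.0, 0.0]])  # shape (2, 3)

row_norms = vmap(norm, in_axes=0)(M)  # one result per row    -> shape (2,)
col_norms = vmap(norm, in_axes=1)(M)  # one result per column -> shape (3,)

print(col_norms)  # [5. 2. 1.]
```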


# Calculating the Gaussian PDF

![](../assets/gaussian_pdf.png)
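For reference, the standard density of a $k$-dimensional multivariate normal with mean $\boldsymbol{\mu}$ and covariance $\Sigma$ is given below. In the code that follows, `t1` corresponds to $(2\pi)^{-k/2}$, `t2` corresponds to $\det(\Sigma)^{-1/2}$, and `to_exp` is the argument of the exponential.

$$
p(\mathbf{x}) = (2\pi)^{-k/2}\,\det(\Sigma)^{-1/2}\,\exp\!\left(-\tfrac{1}{2}\,(\mathbf{x}-\boldsymbol{\mu})^\top \Sigma^{-1} (\mathbf{x}-\boldsymbol{\mu})\right)
$$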

In [None]:
mu = np.array([0, 0])  # Mean vector
Sigma = np.array([[1, 0], [0, 1]])  # Covariance matrix
X = np.array([[1, 1], [2, 2], [3, 3]])  # Points at which to evaluate the PDF (one per row)

# Constants precomputed from mu and Sigma and read from the enclosing scope
# (instead of being passed as arguments) to keep the example code short
k = mu.shape[0]
t1 = (2 * jnp.pi) ** (-k / 2)
t2 = jnp.linalg.det(Sigma) ** (-0.5)
inv = jnp.linalg.inv(Sigma)

def gaussian_pdf_v(x_vec, mu_vec, Sigma):
    """
    TODO: implement the single-sample equivalent of `to_exp` from `gaussian_pdf` below,
          i.e. the exponent in the formula above, for one vector `x_vec`
    """
    diff = x_vec - mu_vec
    to_exp = ... # TODO: Your code here!
    
    return t1 * t2 * jnp.exp(to_exp)

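def gaussian_pdf(x_mat, mu_vec, Sigma) -> jnp.ndarray:
    # Sigma is accepted to match the calls below; like gaussian_pdf_v,
    # this uses the precomputed `inv`, `t1`, and `t2`
    diff = x_mat - mu_vec  # mu_vec is broadcast across the rows of x_mat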
    ###############################################################
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    return t1 * t2 * jnp.exp(to_exp)


vmapped_gaussian = vmap(gaussian_pdf_v, in_axes=(0, None, None))
vmap_gauss_res = vmapped_gaussian(X, mu, Sigma)


print("VMapped-Gaussian PDF correct?", jnp.allclose(
    vmap_gauss_res, 
    stats.multivariate_normal.pdf(X, mu, Sigma)
))

print("Typical-Gaussian PDF correct?", jnp.allclose(
    gaussian_pdf(X, mu, Sigma), 
    stats.multivariate_normal.pdf(X, mu, Sigma)
))
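`in_axes` can map over any argument, not only the first. As an illustration that is separate from the exercise, the sketch below evaluates `jax.scipy.stats.multivariate_normal.pdf` at a single point and maps over a stack of three candidate means. The point and the covariance are shared across all calls.

```python
import jax.numpy as jnp
from jax import vmap
from jax.scipy import stats

x = jnp.array([0.0, 0.0])                                # one fixed evaluation point
mus = jnp.array([[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]])    # three candidate means
cov = jnp.eye(2)                                         # shared identity covariance

# Map over the leading axis of the SECOND argument; x and cov are broadcast.
pdf_per_mean = vmap(stats.multivariate_normal.pdf, in_axes=(None, 0, None))
dens = pdf_per_mean(x, mus, cov)                         # shape (3,)

# Cross-check against an explicit loop over the means.
loop_dens = jnp.stack([stats.multivariate_normal.pdf(x, m, cov) for m in mus])
print(jnp.allclose(dens, loop_dens))  # True
```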

# Softmax Regression

![](../assets/softmax_regression.png)
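The code below uses the following notation (which may differ from the notation in the figure). For a single input row $\mathbf{x}$ with $d$ features, a weight matrix $W$ of shape $d \times C$, and a bias $\mathbf{b}$ with $C$ entries, softmax regression computes logits and then normalizes them:

$$
\mathbf{z} = \mathbf{x} W + \mathbf{b}, \qquad p_j = \frac{\exp(z_j)}{\sum_{l=1}^{C} \exp(z_l)}, \quad j = 1, \dots, C
$$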

In [None]:
# Example data
X = jnp.array([[1, 2], [2, 3], [3, 4]])  # Batch of inputs
W = jnp.array([[0.2, 0.8], [0.5, 0.1]])  # Weight matrix
b = jnp.array([0.1, -0.2])  # Bias vector

In [None]:
def softmax_regression(X, W, b):
    logits = jnp.dot(X, W) + b
    exp_logits = jnp.exp(logits)
    probabilities = exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)
    return probabilities

# Calculate softmax probabilities for the batch of inputs
probabilities = softmax_regression(X, W, b)
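`softmax_regression` exponentiates the raw logits, and this can overflow for large inputs. A common remedy is to subtract each row's maximum logit before exponentiating; softmax is unchanged when the same constant is added to every logit. The variant below sketches that remedy. Its name, `stable_softmax_regression`, and the `*_ex` arrays are hypothetical and are not part of the lesson. The result is cross-checked against `jax.nn.softmax`.

```python
import jax
import jax.numpy as jnp

def stable_softmax_regression(X, W, b):
    logits = jnp.dot(X, W) + b
    # Shift by the per-row max: same result, but exp() cannot overflow.
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

X_ex = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
W_ex = jnp.array([[0.2, 0.8], [0.5, 0.1]])
b_ex = jnp.array([0.1, -0.2])

p_stable = stable_softmax_regression(X_ex, W_ex, b_ex)
p_builtin = jax.nn.softmax(jnp.dot(X_ex, W_ex) + b_ex, axis=-1)
print(jnp.allclose(p_stable, p_builtin))  # True

# Logits in the thousands would overflow exp() in float32 in the unshifted
# version; the shifted version stays finite.
big_X = X_ex * 1000.0
print(jnp.all(jnp.isfinite(stable_softmax_regression(big_X, W_ex, b_ex))))  # True
```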

In [None]:
def single_softmax_regression(x, W, b):
    """
    TODO: Implement the equivalent of the softmax_regression above on a single row of x
    """
    raise NotImplementedError

# Vectorize the single input calculation
vectorized_softmax_regression = vmap(single_softmax_regression, in_axes=(0, None, None))

# Calculate softmax probabilities using vmap
probabilities_vmap = vectorized_softmax_regression(X, W, b)
print(f"Vmapped Softmax Regression equal to vectorized?: {np.allclose(probabilities, probabilities_vmap)}")
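`vmap` also composes with functions that are already batched. The sketch below evaluates several hypothetical models (stacked weight matrices and bias vectors) on the same input batch. It maps over the leading axis of the weights and biases while sharing `X`. The batched function is repeated as `softmax_regression_demo` so that the cell runs on its own.

```python
import jax.numpy as jnp
from jax import vmap

def softmax_regression_demo(X, W, b):
    # Same as softmax_regression above: X is (N, d), W is (d, C), b is (C,) -> (N, C)
    logits = jnp.dot(X, W) + b
    exp_logits = jnp.exp(logits)
    return exp_logits / jnp.sum(exp_logits, axis=-1, keepdims=True)

X_ex = jnp.array([[1.0, 2.0], [2.0, 3.0], [3.0, 4.0]])
W_ex = jnp.array([[0.2, 0.8], [0.5, 0.1]])
b_ex = jnp.array([0.1, -0.2])

# Three hypothetical models stacked along axis 0.
Ws = jnp.stack([W_ex, 2.0 * W_ex, jnp.zeros_like(W_ex)])  # shape (3, 2, 2)
bs = jnp.stack([b_ex, b_ex, b_ex])                        # shape (3, 2)

# Share X (None); map over the leading axis of Ws and bs (0, 0).
per_model = vmap(softmax_regression_demo, in_axes=(None, 0, 0))
probs = per_model(X_ex, Ws, bs)  # shape (3, 3, 2): (models, examples, classes)
print(probs.shape)  # (3, 3, 2)
```

The third model has zero weights, so every example gets the same probabilities, `softmax(b_ex)`.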

# Further Exercises: 

## 1) Read up on [jax.pmap](https://jax.readthedocs.io/en/latest/_autosummary/jax.pmap.html)

`pmap` maps a function in parallel across multiple devices (for example, several GPUs or TPU cores), which makes it useful for scaling a computation beyond a single device.
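The following minimal `jax.pmap` sketch can serve as a starting point for that reading. Like `vmap`, `pmap` maps over a leading axis, but each slice runs on a separate device, so that axis can be no larger than the number of available devices. On a CPU-only machine `jax.local_device_count()` is typically 1, so there the example runs as a one-device map.

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()  # number of local devices (often 1 on CPU)

# Leading axis = one slice per device.
xs = jnp.arange(n_dev * 3, dtype=jnp.float32).reshape(n_dev, 3)

# Each device receives one row and computes its squared norm independently.
sq_norms = jax.pmap(lambda row: jnp.sum(row ** 2))(xs)
print(sq_norms.shape)  # (n_dev,)
```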