# Jax's vmap

## Goals:

- Understand the `vmap` operation and how it can make your code cleaner and more interpretable

## Concepts:

- `vmap`

In [1]:
import jax.numpy as jnp
from jax.scipy import stats
import numpy as np
from jax import vmap
# Fix NumPy's global RNG so the random arrays generated later are reproducible.
np.random.seed(seed=42)


# Calculating the Gaussian PDF... again, this time with `vmap`

In [2]:
import numpy as np

def gaussian_pdf_v(x, mu, Sigma):
    """Evaluate the multivariate Gaussian density at a single point.

    Written for one point so it can be batched over many points with ``vmap``.

    Args:
        x: The point to evaluate, shape (k,).
        mu: Mean vector, shape (k,).
        Sigma: Covariance matrix, shape (k, k).

    Returns:
        Scalar density value N(x; mu, Sigma).
    """
    # Dimensions of the data
    k = mu.shape[0]

    # Calculate determinant and inverse of the covariance matrix
    Sigma_det = jnp.linalg.det(Sigma)
    Sigma_inv = jnp.linalg.inv(Sigma)

    # Calculate the normalization factor 1 / sqrt((2*pi)^k * det(Sigma))
    normalization_factor = 1 / jnp.sqrt((2 * jnp.pi) ** k * Sigma_det)
    x_mu = x - mu

    ###############################################################
    # Hand-batched equivalent (no vmap), for a (b, k) array of points:
    #     -0.5 * jnp.einsum('bi,ij,bj->b', x_mu, Sigma_inv, x_mu)
    # Single-point quadratic form (x - mu)^T Sigma^{-1} (x - mu).
    # The matrix needs two distinct indices ("ij"): a repeated index such as
    # "jj" would make einsum take only the diagonal of Sigma_inv, which is
    # correct only for diagonal covariance matrices.
    to_exp = -0.5 * jnp.einsum("i,ij,j->", x_mu, Sigma_inv, x_mu)
    ###############################################################

    # Compute the Gaussian PDF
    return normalization_factor * jnp.exp(to_exp)

# Example usage: evaluate the density at several points with one vmapped call.
mu = np.array([0.0, 0.0])  # Mean vector
Sigma = np.array([[1.0, 0.0], [0.0, 1.0]])  # Covariance matrix (identity)
X = np.array([[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]])  # Points to evaluate, one per row

# Map over the rows of X (axis 0); mu and Sigma are shared by every call (None).
vmapped_gaussian = vmap(gaussian_pdf_v, in_axes=(0, None, None))
vmap_gauss_res = vmapped_gaussian(X, mu, Sigma)

# NOTE: with an identity covariance the off-diagonal entries of Sigma^{-1} are
# zero, so this comparison cannot catch mistakes in how the quadratic form
# handles cross terms; a correlated Sigma such as [[1.0, 0.5], [0.5, 1.0]]
# gives a stronger check.
print("Multivariate Normal correct?", jnp.allclose(
    vmap_gauss_res,
    stats.multivariate_normal.pdf(X, mu, Sigma)
))

2024-04-26 16:11:53.447035: W external/xla/xla/service/gpu/nvptx_compiler.cc:718] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Multivariate Normal correct? True


# Applying learned kernels to a batch of input data

In [3]:

# Sizes for a batch of channels-last images and a small bank of learned filters.
BATCH_SIZE = 128
NUM_CHANNELS = 3
HEIGHT = 32
WIDTH = 32
LEARNED_FILTERS = 5

# Every filter covers a whole image, so images and filters share the
# trailing (height, width, channels) shape.
IMAGE_SHAPE = (HEIGHT, WIDTH, NUM_CHANNELS)
images = jnp.asarray(np.random.rand(BATCH_SIZE, *IMAGE_SHAPE))
learned_kernels = jnp.asarray(np.random.rand(LEARNED_FILTERS, *IMAGE_SHAPE))

In [4]:
def predict_images(x, kernels):
    """Score every image in a batch against every learned filter.

    Args:
        x: Batch of images with axes (batch, height, width, channels).
        kernels: Filter bank with axes (filters, height, width, channels).

    Returns:
        Array of shape (batch, filters): for each image/filter pair, the sum
        over height, width and channels of their elementwise product.
    """
    # Contract the shared (h, w, c) axes; keep the batch (n) and filter (k) axes.
    subscripts = "nhwc,khwc->nk"
    return jnp.einsum(subscripts, x, kernels)

# Hand-batched version: a single einsum scores the whole batch of images at once.
pred_batch = predict_images(x=images, kernels=learned_kernels)

In [5]:
def predict_single(x, kernels):
    """Score one image against every learned filter.

    Args:
        x: A single image with axes (height, width, channels).
        kernels: Filter bank with axes (filters, height, width, channels).

    Returns:
        Array of shape (filters,): for each filter, the sum over height, width
        and channels of its elementwise product with the image.
    """
    # No batch axis here: vmap adds it when this function is mapped over images.
    subscripts = "hwc,khwc->k"
    return jnp.einsum(subscripts, x, kernels)

# Lift the single-image predictor to a batch: map over axis 0 of the images,
# pass the same kernels to every call, and stack the results along axis 0.
vmapped_predict = vmap(predict_single, in_axes=(0, None), out_axes=0)
pred_single_batch = vmapped_predict(images, learned_kernels)

In [6]:
# The vmapped single-image predictor should agree with the hand-batched einsum.
batched_matches_vmapped = jnp.allclose(pred_batch, pred_single_batch)
print(batched_matches_vmapped)

True


# 