# Profiling in JAX

## Lesson Goals:

By the end of this lesson, you will have a hands-on understanding of how to profile your JAX code and use the `tensorboard` interface. You'll identify long compilation steps and prove to yourself that `vmap` performs comparably to hand-vectorized operations.

## Core Concepts:

- `vmap` vs. Vectorized operations
- `fori_loop` and, more generally, [JAX loop primitives](./exe_03_loop_primitives.ipynb)


In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
from jax import random
import numpy as np


## Vmap vs. Vectorized

![](../assets/gaussian_pdf.png)

In [None]:
# Problem size: how many points to draw and how many dimensions each has.
NUM_SAMPLES, NUM_DIMS = 20_000, 300

# Random mean vector with one entry per dimension.
means = np.random.random_sample(NUM_DIMS)

# A random square matrix multiplied by its own transpose is symmetric, which
# makes it usable as the covariance argument below.
_factor = np.random.random_sample((NUM_DIMS, NUM_DIMS))
sigma = np.matmul(_factor, _factor.T)

# Draw the sample matrix: one point per row.
X = np.random.multivariate_normal(means, sigma, NUM_SAMPLES)

@jax.jit
def gaussian_pdf_v(x_vec, mu_vec, Sigma):
    """Exercise stub: Gaussian PDF evaluated for a single sample vector.

    This function is wrapped with `vmap` over axis 0 of its first argument
    (see `vmapped_gaussian`), so `x_vec` is one slice of the sample matrix
    while `mu_vec` and `Sigma` are passed through unchanged.
    """
    # TODO: Reimplement the Gaussian PDF (`gaussian_pdf`, defined below) so
    #     that it works on one sample vector at a time.
    raise NotImplementedError


# Batch the single-sample function: map over axis 0 of the first argument
# only; the mean and covariance are shared by every mapped call.
vmapped_gaussian = vmap(gaussian_pdf_v, in_axes=(0, None, None))

@jax.jit
def gaussian_pdf(x_mat, mu_mat, Sigma) -> jnp.ndarray:
    """Evaluate a multivariate Gaussian density at every row of `x_mat`.

    Args:
        x_mat: Sample points, one per row (assumed shape `(n, k)`).
        mu_mat: Mean vector; its length defines the dimensionality `k`.
        Sigma: Covariance matrix (assumed shape `(k, k)`).

    Returns:
        One density value per row of `x_mat`. The value is NaN when the
        determinant of `Sigma` is not strictly positive.
    """
    k = mu_mat.shape[0]
    # Build the normalisation constant in log space. Forming
    # (2*pi)**(-k/2) and det(Sigma)**(-0.5) directly underflows/overflows in
    # float32 once k reaches a few hundred, which yields 0, inf or NaN even
    # when the density itself is representable.
    sign, log_abs_det = jnp.linalg.slogdet(Sigma)
    # A non-positive determinant has no real inverse square root.
    log_det = jnp.where(sign > 0, log_abs_det, jnp.nan)
    log_norm = -0.5 * (k * jnp.log(2 * jnp.pi) + log_det)
    inv = jnp.linalg.inv(Sigma)
    diff = x_mat - mu_mat
    ###############################################################
    # Row-wise quadratic form diff_i^T @ inv @ diff_i without a Python loop.
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    return jnp.exp(log_norm + to_exp)

In [None]:
# Record a profiler trace into the directory that TensorBoard is pointed at.
with jax.profiler.trace("/tmp/tensorboard/gauss"):
    for _ in range(10):
        # JAX dispatches work asynchronously: without blocking on the result,
        # each annotation would only cover the dispatch, not the computation.
        with jax.profiler.TraceAnnotation("manually-optimized"):
            normal_res = gaussian_pdf(X, means, sigma).block_until_ready()

        with jax.profiler.TraceAnnotation("vmapped-res"):
            vmap_gauss_res = vmapped_gaussian(X, means, sigma).block_until_ready()

In [None]:
!tensorboard --logdir=/tmp/tensorboard/gauss

## Loop-Unrolling

In [None]:
@jax.jit
def slow_block(init_values):
    """Run ten rounds of "double, then 100 steps of v + 5 / v".

    The loops are plain Python `for` statements inside a jitted function, so
    every iteration is traced as its own set of operations.
    """
    outer_rounds = 10
    inner_steps = 100

    state = init_values
    for _ in range(outer_rounds):
        # Double the running value once per outer round.
        state = state + state
        for _ in range(inner_steps):
            state = state + 5 / state
    return state

In [None]:
# Five random starting values, reused by every call profiled below.
init_values = jnp.asarray(np.random.rand(5))

@jax.jit
def slow_block_equiv():
    # TODO: Implement the equivalent of the `slow_block` function above.
    #     Note how the function is being called below: it is handed to
    #     `jax.lax.fori_loop(0, 10, ...)` together with `init_values`.
    raise NotImplementedError
    
# Record a profiler trace into the directory that TensorBoard is pointed at.
with jax.profiler.trace("/tmp/tensorboard/unroll"):

    # First call of the jitted function: includes tracing and compilation.
    with jax.profiler.TraceAnnotation("first_call"):
        result = slow_block(init_values).block_until_ready()

    # Same input again, so the already-compiled function can be reused.
    with jax.profiler.TraceAnnotation("second_call"):
        result2 = slow_block(init_values).block_until_ready()

    # Block on this result as well; otherwise asynchronous dispatch lets the
    # annotation close before the loop has actually finished running.
    with jax.profiler.TraceAnnotation("for-loop-primitive"):
        result3 = jax.lax.fori_loop(
            0, 10, slow_block_equiv, init_values
        ).block_until_ready()


In [None]:
# Check that the repeated call matches the first one, and that the loop
# primitive matches the repeated call.
calls_agree = jnp.allclose(result, result2)
primitive_agrees = jnp.allclose(result2, result3)
print(calls_agree, primitive_agrees)

In [None]:
!tensorboard --logdir=/tmp/tensorboard/unroll