# Profiling in JAX

## Lesson Goals:

By the end of this lesson, you will have hands-on experience profiling JAX code and reading the results in the TensorBoard interface. You'll identify long compilation steps and confirm for yourself that `vmap` performs comparably to hand-vectorized operations.

**Note**: this notebook assumes that you've worked through [exe_04_vmap](./exe_04_vmap.ipynb), so please work through it first (or at least view its solutions).

## Core Concepts:

- `vmap` vs. vectorized operations
- `fori_loop` and, more generally, [JAX loop primitives](./exe_03_loop_primitives.ipynb)


In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
from jax import random
import numpy as np


## Vmap vs. Vectorized

![](../assets/gaussian_pdf.png)
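
For reference, here is the multivariate Gaussian density that both implementations below compute, for $k$ dimensions, mean $\mu$ and covariance $\Sigma$. The constants `t1` and `t2` in the code are the first two factors, and `to_exp` is the argument of the exponential:

$$
p(x) = (2\pi)^{-k/2}\,|\Sigma|^{-1/2}\,\exp\!\left(-\tfrac{1}{2}(x-\mu)^\top \Sigma^{-1}(x-\mu)\right)
$$

For a batch stored as the rows of $X \in \mathbb{R}^{N \times k}$, let $D = X - \mathbf{1}\mu^\top$ with rows $d_i$. Row $i$ of $D\Sigma^{-1}$ is $d_i^\top \Sigma^{-1}$, so the per-sample quadratic forms are row sums of an elementwise product:

$$
q_i = d_i^\top \Sigma^{-1} d_i = \sum_{j=1}^{k} \left(D\Sigma^{-1}\right)_{ij} D_{ij}
$$

This is the identity behind `jnp.sum(diff @ inv * diff, axis=1)` in the hand-vectorized version.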

In [None]:
NUM_SAMPLES = 20_000
NUM_DIMS = 300


means = np.random.rand(NUM_DIMS)
sigma = np.random.rand(NUM_DIMS, NUM_DIMS)
sigma = sigma @ sigma.T
X = np.random.multivariate_normal(mean=means, cov=sigma, size=NUM_SAMPLES)

k = means.shape[0]
t1 = (2 * jnp.pi) ** (-k / 2)
t2 = jnp.linalg.det(sigma) ** (-0.5)
inv = jnp.linalg.inv(sigma)

@jax.jit
def gaussian_pdf_v(x_vec, mu_vec):
    # TODO: Reimplement the Gaussian PDF below
    #   Note: you already did this in exe_04_vmap :) 
    diff = x_vec - mu_vec
    to_exp = -0.5 * diff.T @ inv @ diff
    return t1 * t2 * jnp.exp(to_exp)


# Wrap the vmapped function in jit too, so both versions run as one compiled program
vmapped_gaussian = jax.jit(vmap(gaussian_pdf_v, in_axes=(0, None)))

@jax.jit
def gaussian_pdf(x_mat, mu_mat) -> jax.Array:
    diff = x_mat - mu_mat
    ###############################################################
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    return t1 * t2 * jnp.exp(to_exp)
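
Before comparing timings, it helps to confirm that the two formulations compute the same values. The sketch below is a hypothetical, self-contained check: the function `compare_pdfs` and its small, well-conditioned covariance are assumptions, not part of the exercise. With `NUM_DIMS = 300`, the normalizing constant and the exponent can underflow or overflow in JAX's default 32-bit floats, which makes a value comparison on the full problem less informative. Everything lives inside one function so that running it does not overwrite notebook globals such as `X` or `inv`.

```python
import jax
import jax.numpy as jnp
import numpy as np


def compare_pdfs(num_dims=4, num_samples=64, seed=0):
    """Evaluate a vmapped and a hand-vectorized Gaussian PDF on a small problem."""
    rng = np.random.default_rng(seed)
    a = rng.random((num_dims, num_dims))
    # Assumption of this sketch: adding num_dims * I keeps the covariance well conditioned.
    cov = a @ a.T + num_dims * np.eye(num_dims)
    mu = rng.random(num_dims)
    xs = rng.multivariate_normal(mean=mu, cov=cov, size=num_samples)

    cov_inv = jnp.linalg.inv(cov)
    norm = (2 * jnp.pi) ** (-num_dims / 2) * jnp.linalg.det(cov) ** (-0.5)

    def pdf_single(x, m):
        # One sample: quadratic form d^T Sigma^{-1} d.
        d = x - m
        return norm * jnp.exp(-0.5 * d @ cov_inv @ d)

    def pdf_batched(x_mat, m):
        # All samples at once: row-wise sum of (D Sigma^{-1}) * D.
        d = x_mat - m
        return norm * jnp.exp(-0.5 * jnp.sum((d @ cov_inv) * d, axis=1))

    vmapped = jax.jit(jax.vmap(pdf_single, in_axes=(0, None)))
    return vmapped(xs, mu), jax.jit(pdf_batched)(xs, mu)


res_vmap, res_manual = compare_pdfs()
print(res_vmap.shape, bool(jnp.allclose(res_vmap, res_manual)))
```

The `in_axes=(0, None)` argument maps over the rows of the samples while broadcasting the single mean vector, mirroring how `vmapped_gaussian` is built above.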

In [None]:
with jax.profiler.trace("/tmp/tensorboard/gauss"):
    # Run 10 iterations: the first call of each function also traces and compiles it,
    #   so the later iterations show the steady-state runtime.
    for i in range(10):
        # JAX dispatches work asynchronously; block_until_ready() keeps the actual
        #   computation inside its annotated region.
        with jax.profiler.TraceAnnotation("manually-optimized"):
            normal_res = gaussian_pdf(X, means).block_until_ready()

        with jax.profiler.TraceAnnotation("vmapped-res"):
            vmap_gauss_res = vmapped_gaussian(X, means).block_until_ready()

### Viewing the trace

Launch TensorBoard with `tensorboard --logdir=/tmp/tensorboard/gauss` and open the URL it prints (the profiling views require the `tensorboard-plugin-profile` package). To reach the trace page:

1) Select the run you want (red box, on the right)
2) Select the `trace_viewer` tool (orange box)

![](../assets/tb_profiler.png)

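Then search for `manually-optimized` and `vmapped-res` in the text-input box (blue square) to locate each annotated region. Skip the first iteration, which includes compilation, and compare how long the two regions take in the remaining iterations. Congratulations!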
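
For a rough wall-clock cross-check without TensorBoard, a minimal sketch follows. It uses a stand-in workload (`work`) and a hypothetical helper (`time_call`) rather than the Gaussian functions, so treat it as an illustration of the measurement pattern only. Because JAX dispatches computations asynchronously, a timer that stops before `block_until_ready()` mostly measures dispatch, and the first call additionally pays for tracing and compilation.

```python
import time

import jax
import jax.numpy as jnp


@jax.jit
def work(x):
    # Hypothetical stand-in workload: a matmul followed by a reduction.
    return jnp.tanh(x @ x.T).sum()


def time_call(fn, *args, block=True):
    """Wall-clock seconds for one call; optionally wait for the device result."""
    start = time.perf_counter()
    out = fn(*args)
    if block:
        out.block_until_ready()  # wait for the asynchronous computation to finish
    return time.perf_counter() - start


x = jnp.ones((512, 512))

first = time_call(work, x)  # includes tracing and XLA compilation
steady = min(time_call(work, x) for _ in range(5))  # reuses the cached executable
dispatch = min(time_call(work, x, block=False) for _ in range(5))  # mostly dispatch

print(f"first call:                {first * 1e3:.2f} ms")
print(f"steady state:              {steady * 1e3:.2f} ms")
print(f"without block_until_ready: {dispatch * 1e3:.2f} ms")
```

Taking the minimum over repeated calls is one common way to reduce timing noise; the profiler trace is still the better tool for seeing where the time actually goes.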


## Loop Unrolling

In [None]:
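@jax.jit
def slow_block(init_values):
    # Python `for` loops run at trace time, so jit inlines every iteration:
    #   100 outer and 10,000 inner iterations become one very large XLA program
    #   that is slow to compile.
    container = init_values
    for i in range(100):
        container = container + container
        for j in range(100):
            container = container + 5 / container
    return container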

In [None]:
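init_values = jnp.asarray(np.random.rand(5))

# fori_loop traces its body once and keeps it as a single loop primitive,
#   so the body functions below do not need their own jit.
def body_fun(_, container):
    # Inner loop body: one iteration of `container + 5 / container`
    return container + 5 / container

def slow_block_equiv(_, container):
    # Outer loop body: double, then run the inner loop 100 times
    container = container + container
    return jax.lax.fori_loop(0, 100, body_fun, container)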
    
with jax.profiler.trace("/tmp/tensorboard/unroll"):
    # As before, run several iterations so the one-time compilation cost in the
    #   first iteration can be told apart from the steady-state runtime.
    for i in range(10):
        with jax.profiler.TraceAnnotation("python-for-loop"):
            result = slow_block(init_values).block_until_ready()

        with jax.profiler.TraceAnnotation("for-loop-primitive"):
            result2 = jax.lax.fori_loop(0, 100, slow_block_equiv, init_values).block_until_ready()


In [None]:
print(jnp.allclose(result, result2))

### Diagnosis

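As before, run `tensorboard` via `tensorboard --logdir=/tmp/tensorboard/unroll`, open the URL, and search the trace viewer for the two annotated regions. Expect the first iteration of the unrolled Python-loop version to take far longer than its later iterations, because jit must trace and compile the fully inlined loop; the `fori_loop` version traces its body only once, so its first iteration should be much cheaper.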
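
Besides the trace, you can see the unrolling directly by counting the top-level equations in each version's jaxpr with `jax.make_jaxpr`. The sketch below is a scaled-down, hypothetical copy of the two loops: it uses 10 × 10 iterations instead of 100 × 100, and the names `unrolled`, `rolled`, `inner_step` and `outer_step` are made up so they do not clash with the notebook's functions. Exact equation counts depend on the JAX version, so only the relative sizes matter.

```python
import jax
import jax.numpy as jnp

# Hypothetical small loop counts so the unrolled version stays cheap to trace.
OUTER, INNER = 10, 10


def unrolled(container):
    # Plain Python loops: every iteration is traced and inlined into the jaxpr.
    for _ in range(OUTER):
        container = container + container
        for _ in range(INNER):
            container = container + 5 / container
    return container


def inner_step(_, container):
    return container + 5 / container


def outer_step(_, container):
    container = container + container
    return jax.lax.fori_loop(0, INNER, inner_step, container)


def rolled(container):
    # fori_loop: the body is traced once and kept as a single loop primitive.
    return jax.lax.fori_loop(0, OUTER, outer_step, container)


x = jnp.ones(5)
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x).jaxpr.eqns)
print("unrolled top-level equations: ", n_unrolled)
print("fori_loop top-level equations:", n_rolled)
print("same result:", bool(jnp.allclose(unrolled(x), rolled(x))))
```

The unrolled program grows with the product of the loop counts, while the `fori_loop` program stays roughly constant in size, which is why compilation time diverges so sharply at 100 × 100.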