# Profiling in JAX

## Goals:

- Use the `tensorboard` profiler interface

## Concepts:

- identifying slowdowns
- diagnosing long compilations

In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
from jax import random
import numpy as np


## Vmap vs. Vectorized

In [2]:
NUM_SAMPLES = 20_000
NUM_DIMS = 300
SEED = 0

# Seeded generator so the profiled inputs are identical across kernel restarts.
rng = np.random.default_rng(SEED)

means = rng.random(NUM_DIMS)
# A @ A.T is symmetric positive semi-definite, so it is usable as a covariance.
sigma_factor = rng.random((NUM_DIMS, NUM_DIMS))
sigma = sigma_factor @ sigma_factor.T
X = rng.multivariate_normal(mean=means, cov=sigma, size=NUM_SAMPLES)


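# Two implementations of the same multivariate-normal PDF: one written for a single sample (batched below with vmap) and one vectorized by hand. Both take the data, mean, and covariance as explicit arguments.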

@jax.jit
def gaussian_pdf_v(x_vec, mu_vec, Sigma):
    """Density of the multivariate normal N(mu_vec, Sigma) at a single point.

    Written for one sample at a time; batch it by vmapping over ``x_vec``.

    Args:
        x_vec: Point at which to evaluate the density, shape ``(k,)``.
        mu_vec: Mean vector, shape ``(k,)``.
        Sigma: Covariance matrix, shape ``(k, k)``; expected to be positive
            definite.

    Returns:
        Scalar density value, or NaN when ``det(Sigma)`` is not positive.
    """
    k = mu_vec.shape[0]
    # Build the normalising constant in log space. Forming
    # (2*pi)**(-k/2) * det(Sigma)**(-0.5) directly under/overflows in float32
    # already for moderate k, which yields 0 * inf = nan or a spurious 0.
    sign, logdet = jnp.linalg.slogdet(Sigma)
    log_norm = -0.5 * (k * jnp.log(2 * jnp.pi) + logdet)
    inv = jnp.linalg.inv(Sigma)
    diff = x_vec - mu_vec
    ###############################################################
    to_exp = -0.5 * diff @ inv @ diff
    ###############################################################
    pdf = jnp.exp(log_norm + to_exp)
    # A non-positive determinant means Sigma is not a valid covariance.
    return jnp.where(sign > 0, pdf, jnp.nan)


# Batch the single-sample PDF over the leading axis of the first argument only;
# the mean and covariance are shared by every sample.
vmapped_gaussian = jax.vmap(gaussian_pdf_v, in_axes=(0, None, None))

@jax.jit
def gaussian_pdf(x_mat, mu_mat, Sigma) -> jax.Array:
    """Density of the multivariate normal N(mu_mat, Sigma) for a batch of points.

    Hand-vectorized counterpart of the single-sample version: the quadratic
    form is evaluated for all rows at once.

    Args:
        x_mat: Points to evaluate, one per row, shape ``(n, k)``.
        mu_mat: Mean vector, shape ``(k,)`` (despite the name, it is indexed
            as a vector and broadcast against the rows of ``x_mat``).
        Sigma: Covariance matrix, shape ``(k, k)``; expected to be positive
            definite.

    Returns:
        Array of shape ``(n,)`` with one density per row of ``x_mat``; NaN
        when ``det(Sigma)`` is not positive.
    """
    k = mu_mat.shape[0]
    # Build the normalising constant in log space. Forming
    # (2*pi)**(-k/2) * det(Sigma)**(-0.5) directly under/overflows in float32
    # already for moderate k, which yields 0 * inf = nan or a spurious 0.
    sign, logdet = jnp.linalg.slogdet(Sigma)
    log_norm = -0.5 * (k * jnp.log(2 * jnp.pi) + logdet)
    inv = jnp.linalg.inv(Sigma)
    diff = x_mat - mu_mat
    ###############################################################
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    pdf = jnp.exp(log_norm + to_exp)
    # A non-positive determinant means Sigma is not a valid covariance.
    return jnp.where(sign > 0, pdf, jnp.nan)

In [3]:
# Sanity-check the shapes of the mean, covariance and sample matrix.
print(*(arr.shape for arr in (means, sigma, X)))

(300,) (300, 300) (20000, 300)


In [4]:
# JAX dispatches work asynchronously, so each result is blocked on inside its
# annotation; otherwise the span could close before the computation finishes
# and would not reflect the real runtime. The first iteration also pays for
# jit compilation if the functions have not been called yet.
with jax.profiler.trace("/tmp/tensorboard/gauss"):
    for i in range(10):
        with jax.profiler.TraceAnnotation("manually-optimized"):
            normal_res = gaussian_pdf(X, means, sigma).block_until_ready()

        with jax.profiler.TraceAnnotation("vmapped_res"):
            vmap_gauss_res = vmapped_gaussian(X, means, sigma).block_until_ready()
            


In [5]:
# Serve the trace written to /tmp/tensorboard/gauss. The command runs in the
# foreground (TensorBoard prints "Press CTRL+C to quit"), so interrupt the
# cell before running anything further.
!tensorboard --logdir=/tmp/tensorboard/gauss

  pid, fd = os.forkpty()


Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
E0525 08:47:40.217046 13807677440 _internal.py:97] 127.0.0.1 - - [25/May/2024 08:47:40] code 400, message Bad HTTP/0.9 request type ('\x16\x03\x01\x02\x8c\x01\x00\x02\x88\x03\x03´§Üp_l<.Ï¨h\x9aÇá\x9a\x99ðãd\x96u6\x14h2tï_\x06\x0f4}')
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'
I0000 00:00:1716652070.343982 19499005 trace_events.cc:233] Storing 378250 as LevelDb table fast file

## Loop-Unrolling

In [3]:

@jax.jit
def body_fun(_, container):
    """Single inner-loop update: ``container + 5 / container``.

    Takes ``(index, carry)`` so it can serve as a ``jax.lax.fori_loop`` body;
    the index is ignored.
    """
    correction = 5 / container
    return container + correction

@jax.jit
def inner(_, container):
    """One outer-loop step: double ``container``, then apply ``body_fun`` 100 times.

    Takes ``(index, carry)`` so it can serve as a ``jax.lax.fori_loop`` body;
    the index is ignored.
    """
    doubled = container + container
    return jax.lax.fori_loop(0, 100, body_fun, doubled)



In [4]:
# Fixed seed so every profiled variant starts from the same values on each run.
init_values = jnp.asarray(np.random.default_rng(0).random(5))

@jax.jit
def slow_block():
    """Run the doubling/update recurrence with plain Python loops.

    Reads the module-level ``init_values``. Because the loops are ordinary
    Python ``for`` statements, all 10 x 100 update steps are written out
    individually when the function is traced under ``jit``.
    """
    acc = init_values
    for _outer_step in range(10):
        acc = acc + acc
        for _inner_step in range(100):
            acc = acc + 5 / acc
    return acc

# Every result is blocked on inside its annotation: JAX dispatches work
# asynchronously, so without the block a span could close before the
# computation finishes and the three spans would not be comparable.
with jax.profiler.trace("/tmp/tensorboard/unroll"):

    # First call of the jitted function: includes tracing and compilation.
    with jax.profiler.TraceAnnotation("first_call"):
        result = slow_block().block_until_ready()

    # Second call: the same jitted function called again.
    with jax.profiler.TraceAnnotation("second_call"):
        result2 = slow_block().block_until_ready()

    # Same recurrence expressed with the fori_loop primitive.
    with jax.profiler.TraceAnnotation("for-loop-primitive"):
        result3 = jax.lax.fori_loop(0, 10, inner, init_values).block_until_ready()


In [None]:
# Check that both slow_block calls agree with each other and with the
# fori_loop version.
print(*(jnp.allclose(lhs, rhs) for lhs, rhs in ((result, result2), (result2, result3))))

In [None]:
# Serve the trace written to /tmp/tensorboard/unroll. The command runs in the
# foreground (TensorBoard prints "Press CTRL+C to quit"), so interrupt the
# cell before running anything further.
!tensorboard --logdir=/tmp/tensorboard/unroll

  pid, fd = os.forkpty()


Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'
I0000 00:00:1716654466.316824 19566746 trace_events.cc:233] Storing 1448821 as LevelDb table fast file: /tmp/tensorboard/unroll/plugins/profile/2024_05_25_09_26_24/ians-MacBook-Pro-2.local.SSTABLE with 7877 events dropped.
I0000 00:00:1716654466.443116 19566746 trace_events.cc:345] Loaded 32411 events after filtering 0 events from Le