# Profiling in JAX

## Goals:

- Use the `Perfetto` profiler interface

## Concepts:

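- Recording a trace with `jax.profiler.trace`
- Labeling regions of a trace with `jax.profiler.TraceAnnotation`
- Comparing a `vmap`-based implementation against a manually batched one
- Seeing how Python loops inside `jax.jit` are unrolled at trace time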

## Acknowledgments:

This example is based on [JAX - Neural_Network_and_Data_Loading](https://jax.readthedocs.io/en/latest/notebooks/Neural_Network_and_Data_Loading.html).



In [1]:
import jax.numpy as jnp
from jax import grad, jit, vmap
import jax
from jax import random
import numpy as np


In [2]:
NUM_SAMPLES = 20_000
NUM_DIMS = 300


means = np.random.rand(NUM_DIMS)
sigma = np.random.rand(NUM_DIMS, NUM_DIMS)
sigma = sigma @ sigma.T
X = np.random.multivariate_normal(mean=means, cov=sigma, size=NUM_SAMPLES)


# Two implementations of the multivariate Gaussian PDF:
# a per-sample version (batched below with vmap) and a manually batched version.

@jax.jit
def gaussian_pdf_v(x_vec, mu_vec, Sigma):
    k = mu_vec.shape[0]
    t1 = (2 * jnp.pi) ** (-k / 2)
    t2 = jnp.linalg.det(Sigma) ** (-0.5)
    inv = jnp.linalg.inv(Sigma)
    diff = x_vec - mu_vec
    ###############################################################
    to_exp = -0.5 * diff @ inv @ diff
    ###############################################################
    return t1 * t2 * jnp.exp(to_exp)


vmapped_gaussian = vmap(gaussian_pdf_v, in_axes=(0, None, None))

@jax.jit
def gaussian_pdf(x_mat, mu_mat, Sigma) -> jax.Array:
    k = mu_mat.shape[0]
    t1 = (2 * jnp.pi) ** (-k / 2)
    t2 = jnp.linalg.det(Sigma) ** (-0.5)
    inv = jnp.linalg.inv(Sigma)
    diff = x_mat - mu_mat
    ###############################################################
    to_exp = -0.5 * jnp.sum(diff @ inv * diff, axis=1)
    ###############################################################
    return t1 * t2 * jnp.exp(to_exp)
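The two functions above differ only in the marked line: the per-sample version computes `diff @ inv @ diff` for one row, while the batched version computes `jnp.sum(diff @ inv * diff, axis=1)` for all rows at once. A minimal sketch of a sanity check that the two expressions agree is shown here. The helper names `quad_form_single` and `quad_form_batched`, the small shapes, and the tolerances are illustrative assumptions, not part of the original notebook.

```python
import jax
import jax.numpy as jnp


def quad_form_single(diff, inv):
    # Quadratic form for one sample: diff^T @ inv @ diff -> scalar
    return diff @ inv @ diff


def quad_form_batched(diffs, inv):
    # Same quadratic form for every row of diffs at once -> shape (n,)
    return jnp.sum(diffs @ inv * diffs, axis=1)


# Small hypothetical inputs so the check runs quickly on any backend.
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
diffs = jax.random.normal(k1, (5, 3))
a = jax.random.normal(k2, (3, 3))
inv = a @ a.T  # symmetric matrix standing in for the inverse covariance

# vmap maps over the rows of diffs and broadcasts inv unchanged.
per_row = jax.vmap(quad_form_single, in_axes=(0, None))(diffs, inv)
batched = quad_form_batched(diffs, inv)

same = bool(jnp.allclose(per_row, batched, rtol=1e-4, atol=1e-4))
print(same)
```

Checking equivalence before profiling helps ensure that any timing difference seen in the trace comes from how the work is expressed rather than from the two versions computing different things.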

In [3]:
print(means.shape, sigma.shape, X.shape)

(300,) (300, 300) (20000, 300)


In [4]:
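with jax.profiler.trace("/tmp/tensorboard/gauss"):
    for i in range(10):
        # JAX dispatches work asynchronously, so block until the result is ready;
        # otherwise the annotated host span may end before the computation finishes.
        with jax.profiler.TraceAnnotation("manually-optimized"):
            normal_res = gaussian_pdf(X, means, sigma).block_until_ready()

        with jax.profiler.TraceAnnotation("vmapped_res"):
            vmap_gauss_res = vmapped_gaussian(X, means, sigma).block_until_ready()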


2024-05-24 15:57:27.342782: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [5]:
!tensorboard --logdir=/tmp/tensorboard/gauss

  pid, fd = os.forkpty()


2024-05-24 15:57:31.410821: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'


In [2]:
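with jax.profiler.trace("/tmp/tensorboard/unroll"):
    @jax.jit
    def slow_block():
        container = jnp.ones(10)
        # Python loops run while JAX traces the function, so jit unrolls
        # them into one long sequence of operations.
        for i in range(10):
            container = container + container
            for j in range(100):
                container = container + container ** 2
        return container

    # First call: includes tracing and compiling the unrolled loops.
    with jax.profiler.TraceAnnotation("loop-unrolled"):
        result = slow_block().block_until_ready()

    # Second call: reuses the already compiled function.
    with jax.profiler.TraceAnnotation("post-unroll"):
        result = slow_block().block_until_ready()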

2024-05-24 16:12:09.934481: W external/xla/xla/service/gpu/nvptx_compiler.cc:760] The NVIDIA driver's CUDA version is 12.2 which is older than the ptxas CUDA version (12.4.131). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.
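One way to see why the first call is expensive, without opening the profiler, is to count the operations JAX records when it traces the function. The sketch below is an illustration under assumptions rather than part of the original notebook: the function takes an argument `x` instead of creating the array internally, and the `rolled` variant uses `jax.lax.fori_loop`, which the notebook does not use, purely as a contrast.

```python
import jax
import jax.numpy as jnp


def unrolled(x):
    # Same nested Python loops as slow_block: every iteration is traced separately.
    for _ in range(10):
        x = x + x
        for _ in range(100):
            x = x + x ** 2
    return x


def rolled(x):
    # Contrast (assumption, not in the notebook): structured loops keep the
    # loop body as a single traced sub-computation.
    def inner(_, c):
        return c + c ** 2

    def outer(_, c):
        c = c + c
        return jax.lax.fori_loop(0, 100, inner, c)

    return jax.lax.fori_loop(0, 10, outer, x)


x = jnp.ones(10)

# make_jaxpr only traces the function; nothing is compiled or executed here.
n_unrolled = len(jax.make_jaxpr(unrolled)(x).jaxpr.eqns)
n_rolled = len(jax.make_jaxpr(rolled)(x).jaxpr.eqns)

print("top-level equations, unrolled:", n_unrolled)
print("top-level equations, rolled:  ", n_rolled)
```

The unrolled version records at least one operation for every addition and power across all iterations, so the compiler receives a much longer program. That is consistent with the `loop-unrolled` span being dominated by tracing and compilation, while `post-unroll` only runs the compiled result.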


In [None]:
!tensorboard --logdir=/tmp/tensorboard/unroll

  pid, fd = os.forkpty()


2024-05-24 16:12:13.465097: W tensorflow/core/common_runtime/gpu/gpu_device.cc:2251] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...

NOTE: Using experimental fast data loading logic. To disable, pass
    "--load_fast=false" and report issues on GitHub. More details:
    https://github.com/tensorflow/tensorboard/issues/4784

Serving TensorBoard on localhost; to expose to the network, use a proxy or pass --bind_all
TensorBoard 2.16.2 at http://localhost:6006/ (Press CTRL+C to quit)
Illegal Content-Security-Policy for script-src: 'unsafe-inline'
Illegal Content-Security-Policy for connect-src: data:
Illegal Content-Security-Policy for connect-src: www.gstatic.com
Illegal Content-Security-Policy for script-src-elem: 'unsafe-inline'
