# Speed comparison

## Imports

Run only one of the next three cells, depending on whether you want to benchmark NumPy on CPU, JAX on CPU, or JAX on GPU.

The CPU used is an AMD Ryzen Threadripper 3960X (24 cores, 3.8 GHz) with 128 GB of RAM.

The GPU used is an NVIDIA RTX A5000 with 24 GB of memory.

In [1]:
# run for Numpy CPU
import numpy as np
from agginc.np import agginc
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2

In [1]:
# run for Jax CPU
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
from agginc.jax import agginc
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
X = jnp.array([1, 2])
X.device()

2023-02-07 12:18:06.087432: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


CpuDevice(id=0)

In [1]:
# run for Jax GPU
import numpy as np
from agginc.jax import agginc
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

## MMDAggInc

In [2]:
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 2
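Here X and Y are both uniform on the unit hypercube, with Y shifted by 2 in every coordinate, so the two samples clearly come from different distributions. To give a feel for the incomplete U-statistics that, as we understand it, give AggInc its name, the following is a minimal NumPy sketch of an MMD² estimate with a single Gaussian kernel. It averages the standard MMD core function only over the pairs (i, i + r) with 1 ≤ r ≤ R, which costs O(nRd) instead of the O(n²d) of the complete U-statistic. The helper names `gaussian_kernel` and `incomplete_mmd2`, the bandwidth of 1 and R = 10 are hypothetical choices for illustration; this is not the `agginc` implementation, which also aggregates over several bandwidths and returns a test decision rather than a raw statistic.

```python
import numpy as np

def gaussian_kernel(a, b, bandwidth):
    # Row-wise Gaussian kernel: k(a_i, b_i) = exp(-||a_i - b_i||^2 / (2 bandwidth^2))
    sq_dist = np.sum((a - b) ** 2, axis=1)
    return np.exp(-sq_dist / (2 * bandwidth**2))

def incomplete_mmd2(X, Y, bandwidth=1.0, R=10):
    """Incomplete U-statistic for MMD^2 over the pairs (i, i + r), 1 <= r <= R.

    Assumes paired samples, so both inputs are truncated to the same length n.
    """
    n = min(len(X), len(Y))
    X, Y = X[:n], Y[:n]
    R = min(R, n - 1)  # need at least one pair for every offset r
    total, count = 0.0, 0
    for r in range(1, R + 1):
        i = np.arange(n - r)
        j = i + r
        # MMD core h(z_i, z_j) with z = (x, y)
        h = (gaussian_kernel(X[i], X[j], bandwidth)
             + gaussian_kernel(Y[i], Y[j], bandwidth)
             - gaussian_kernel(X[i], Y[j], bandwidth)
             - gaussian_kernel(X[j], Y[i], bandwidth))
        total += h.sum()
        count += h.size
    return total / count

rng = np.random.default_rng(0)
X_small = rng.uniform(size=(500, 5))
Y_shifted = rng.uniform(size=(500, 5)) + 2  # different distribution
Z_same = rng.uniform(size=(500, 5))         # same distribution as X_small
print(incomplete_mmd2(X_small, Y_shifted))  # large
print(incomplete_mmd2(X_small, Z_same))     # close to 0
```

With the shift of 2 the estimate is well above 1, whereas for two samples from the same distribution it stays close to 0.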

In [None]:
# run for Jax CPU and Jax GPU to compile the function
# Do not run for Numpy CPU
output, dictionary = agginc("mmd", X, Y, return_dictionary=True)
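The warm-up call matters because JAX compiles a jitted function on its first call for given input shapes, and that one-off cost would otherwise end up in the timings. Outside IPython, a timing loop in the spirit of `%timeit -r 7 -n 1` with explicit warm-up calls can be sketched with the standard library alone; `benchmark` is a hypothetical helper, not part of `agginc`. One caveat: JAX dispatches work asynchronously, so when timing a function that returns JAX arrays it is safer to call `.block_until_ready()` on the result inside the timed function; the notebook does not say whether `agginc` already does this.

```python
import statistics
import time

def benchmark(fn, *args, warmup=1, repeat=7, number=1, **kwargs):
    """Time fn(*args, **kwargs) in the spirit of `%timeit -r repeat -n number`.

    The warm-up calls run first and are not timed, so one-off costs such as
    JIT compilation are excluded. Returns (mean, std, per-loop times) in seconds.
    """
    for _ in range(warmup):
        fn(*args, **kwargs)  # e.g. triggers compilation of a jitted function
    per_loop = []
    for _ in range(repeat):
        start = time.perf_counter()
        for _ in range(number):
            fn(*args, **kwargs)
        per_loop.append((time.perf_counter() - start) / number)
    std = statistics.stdev(per_loop) if len(per_loop) > 1 else 0.0
    return statistics.mean(per_loop), std, per_loop

# Usage with a cheap stand-in workload; in this notebook one would instead time
# lambda: agginc("mmd", X, Y, return_dictionary=True)
mean, std, runs = benchmark(sum, range(100_000))
print(f"{mean * 1e3:.3f} ms ± {std * 1e6:.1f} µs per loop")
```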

In [4]:
# Numpy CPU
%timeit agginc("mmd", np.array(X), np.array(Y), return_dictionary=True) 

4.49 s ± 21.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax CPU
%timeit agginc("mmd", X, Y, return_dictionary=True) 

844 ms ± 6.78 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Jax GPU
%timeit agginc("mmd", X, Y, return_dictionary=True) 

23 ms ± 295 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


## HSICAggInc

In [2]:
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 2
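In this cell X and Y are drawn independently (shifting Y by 2 does not create any dependence), so the HSIC independence test is benchmarked under the null hypothesis. For intuition about the statistic being computed, here is a minimal NumPy sketch of the biased (V-statistic) HSIC estimate trace(KHLH)/n² with Gaussian kernels, where H is the centering matrix. The names `gaussian_gram` and `hsic_biased` and the unit bandwidths are hypothetical; this is the complete, quadratic-time estimate, not the incomplete, aggregated test run by `agginc`.

```python
import numpy as np

def gaussian_gram(A, bandwidth):
    # n x n Gaussian Gram matrix exp(-||a_i - a_j||^2 / (2 bandwidth^2))
    sq_norms = np.sum(A**2, axis=1)
    sq_dist = sq_norms[:, None] + sq_norms[None, :] - 2 * A @ A.T
    return np.exp(-np.maximum(sq_dist, 0.0) / (2 * bandwidth**2))

def hsic_biased(X, Y, bandwidth_x=1.0, bandwidth_y=1.0):
    """Biased (V-statistic) HSIC estimate trace(K H L H) / n^2."""
    n = X.shape[0]
    K = gaussian_gram(X, bandwidth_x)
    L = gaussian_gram(Y, bandwidth_y)
    H = np.eye(n) - np.full((n, n), 1.0 / n)  # centering matrix
    return np.trace(K @ H @ L @ H) / n**2

rng = np.random.default_rng(0)
X_small = rng.uniform(size=(200, 5))
Y_indep = rng.uniform(size=(200, 5)) + 2           # independent of X_small
Y_dep = X_small + 0.1 * rng.normal(size=(200, 5))  # strongly dependent on X_small
print(hsic_biased(X_small, Y_indep))  # small
print(hsic_biased(X_small, Y_dep))    # noticeably larger
```

The complete estimate needs n × n Gram matrices, which is presumably the quadratic cost that the incomplete U-statistics in `agginc` are meant to avoid for large n.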

In [None]:
# run for Jax CPU and Jax GPU to compile the function
# Do not run for Numpy CPU
output, dictionary = agginc("hsic", X, Y, return_dictionary=True)

In [4]:
# Numpy CPU
%timeit agginc("hsic", np.array(X), np.array(Y), return_dictionary=True)

2.82 s ± 3.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax CPU
%timeit agginc("hsic", X, Y, return_dictionary=True)

539 ms ± 6.58 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [11]:
# Jax GPU
%timeit agginc("hsic", X, Y, return_dictionary=True) 

18 ms ± 10.5 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## KSDAggInc

In [2]:
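perturbation = 0.5
rs = np.random.RandomState(0)
# samples from Gamma(shape=5 + perturbation, scale=5)
X = rs.gamma(5 + perturbation, 5, (5000, 1))
# score (derivative of the log-density) of the model Gamma(shape=k, scale=theta)
def score_gamma(x, k, theta):
    return (k - 1) / x - 1 / theta
# model scores at the samples, for the unperturbed model Gamma(5, 5)
score_X = score_gamma(X, 5, 5)
X = jnp.array(X)
score_X = jnp.array(score_X)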
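In this cell the samples come from a Gamma distribution with shape 5 + 0.5 and scale 5, while the score is evaluated for the model Gamma(5, 5), so the KSD test compares the data against a slightly misspecified model. The score formula follows from the Gamma log-density log p(x) = (k - 1) log x - x/θ - log Γ(k) - k log θ, whose derivative in x is (k - 1)/x - 1/θ; the normalizing constant drops out, which is why KSD needs only the score. A quick standard-library sanity check of that derivative against a central finite difference (the helper names are hypothetical):

```python
import math

def gamma_logpdf(x, k, theta):
    # Log-density of Gamma(shape=k, scale=theta), the parametrization used by rs.gamma
    return (k - 1) * math.log(x) - x / theta - math.lgamma(k) - k * math.log(theta)

def score_gamma(x, k, theta):
    # Same formula as the notebook's score: d/dx log p(x)
    return (k - 1) / x - 1 / theta

def central_difference(f, x, eps=1e-5):
    # Numerical derivative f'(x) with O(eps**2) truncation error
    return (f(x + eps) - f(x - eps)) / (2 * eps)

for x in (1.0, 10.0, 30.0):
    numeric = central_difference(lambda t: gamma_logpdf(t, 5, 5), x)
    print(f"x={x}: analytic {score_gamma(x, 5, 5):.6f}, numeric {numeric:.6f}")
```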

In [3]:
# run for Jax CPU and Jax GPU to compile the function
# Do not run for Numpy CPU
output, dictionary = agginc("ksd", X, score_X, return_dictionary=True)

In [6]:
# Numpy CPU
%timeit agginc("ksd", np.array(X), np.array(score_X), return_dictionary=True) 

3.77 s ± 18.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax CPU
%timeit agginc("ksd", X, score_X, return_dictionary=True) 

590 ms ± 10.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax GPU
%timeit agginc("ksd", X, score_X, return_dictionary=True) 

21.5 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
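The mean timings above can be condensed into speed-up factors relative to the NumPy CPU baseline. The short script below only re-uses the reported means (it does not re-run anything), and the dictionary keys are arbitrary labels:

```python
# Mean timings in seconds, copied from the %timeit outputs above
timings = {
    "mmd":  {"numpy_cpu": 4.49, "jax_cpu": 0.844, "jax_gpu": 0.023},
    "hsic": {"numpy_cpu": 2.82, "jax_cpu": 0.539, "jax_gpu": 0.018},
    "ksd":  {"numpy_cpu": 3.77, "jax_cpu": 0.590, "jax_gpu": 0.0215},
}

def speedups(timings, baseline="numpy_cpu"):
    """Speed-up factor of every backend relative to the baseline backend."""
    return {
        test: {backend: t[baseline] / seconds for backend, seconds in t.items()}
        for test, t in timings.items()
    }

for test, factors in speedups(timings).items():
    print(f"{test:>5}: " + ", ".join(f"{b} x{f:.1f}" for b, f in factors.items()))
```

On the hardware described at the top of the notebook, this gives roughly 5.2x to 6.4x for JAX on the CPU and roughly 157x to 195x for JAX on the GPU.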
