# Demo

In [1]:
# run for Jax GPU
import numpy as np
from mmdagg.jax import human_readable_dict
from mmdagg.jax import mmdagg
# Numpy version: use the import below instead of the Jax one above
# from mmdagg.np import mmdagg
from jax import random
import jax.numpy as jnp
X = jnp.array([1, 2])
# Display the device holding X to check which backend Jax is using.
# Array.devices() is used because the Array.device() method is deprecated in
# newer Jax releases; X lives on a single device, so take the only entry.
next(iter(X.devices()))

In [2]:
# generate data: two samples of shape (5000, 100), where Y is a uniform
# sample shifted by one
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, 2)  # one subkey per sample
X = random.uniform(subkeys[0], (5000, 100))
Y = 1 + random.uniform(subkeys[1], (5000, 100))

In [3]:
# compile function: the first call with inputs of a given shape compiles it
output, dictionary = mmdagg(X, Y, return_dictionary=True)
# Numpy version (no compilation), requires `from mmdagg.np import mmdagg`
# output, dictionary = mmdagg(np.array(X), np.array(Y), return_dictionary=True)

In [4]:
# Now the function runs fast for any inputs X and Y of the compiled shape (5000, 100)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 1
# see speed.ipynb for a detailed speed comparison between numpy, jax cpu and jax gpu
%timeit output, dictionary = mmdagg(X, Y, return_dictionary=True)

493 ms ± 3.54 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
# run the test once more, outside of %timeit, so that output and dictionary
# correspond to the X and Y generated in the previous cell
output, dictionary = mmdagg(X, Y, return_dictionary=True)

In [5]:
# output is a jax array consisting of either 0 or 1
# NOTE(review): 1 presumably means that the test rejects (this run also reports
# 'MMDAgg test reject': True in the dictionary) - confirm against the mmdagg docs
output

Array(1, dtype=int32)


In [6]:
# to convert it to a Python int use:
output.item()

1


In [7]:
# use human_readable_dict to convert the jax arrays in the dictionary to scalars;
# its return value is discarded here, so it appears to update the dictionary in place
human_readable_dict(dictionary)
dictionary

{'MMDAgg test reject': True,
 'Single test 1.1': {'Bandwidth': 41.81019592285156,
  'Kernel laplace': True,
  'MMD': 0.7198492288589478,
  'MMD quantile': 0.00036044654552824795,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.061969008296728134},
 'Single test 1.10': {'Bandwidth': 232.77978515625,
  'Kernel laplace': True,
  'MMD': 0.43212345242500305,
  'MMD quantile': 0.00021708192070946097,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.061969008296728134},
 'Single test 1.2': {'Bandwidth': 50.5980339050293,
  'Kernel laplace': True,
  'MMD': 0.7591387033462524,
  'MMD quantile': 0.0003822108847089112,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.061969008296728134},
 'Single test 1.3': {'Bandwidth': 61.232940673828125,
  'Kernel laplace': True,
  'MMD': 0.7708603143692017,
  'MMD quantile': 0.00038869131822139025,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value thresho

# Speed comparison

## Imports

Run only one of the next three cells, depending on whether you want to use Numpy on CPU, Jax on CPU or Jax on GPU.

The CPU used is a 24-core AMD Ryzen Threadripper 3960X at 3.8 GHz, with 128 GB of RAM.

The GPU used is an NVIDIA RTX A5000 with 24 GB of memory.

In [1]:
# run for Numpy CPU
import numpy as np
from mmdagg.np import mmdagg
# jax is still imported: the data is generated with jax.random below and
# converted to Numpy arrays when timing the Numpy version
from jax import random
import jax.numpy as jnp
# reload edited modules automatically before executing code
%load_ext autoreload
%autoreload 2

In [1]:
# run for Jax CPU
import os
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
from mmdagg.jax import mmdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
X = jnp.array([1, 2])
X.device()

2023-02-07 12:18:06.087432: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


CpuDevice(id=0)

In [1]:
# run for Jax GPU
import numpy as np
from mmdagg.jax import mmdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

## MMDAgg Runtimes

In [2]:
# generate the benchmark data: two samples of shape (5000, 100), where Y is a
# uniform sample shifted by one
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, 2)  # one subkey per sample
X = random.uniform(subkeys[0], (5000, 100))
Y = 1 + random.uniform(subkeys[1], (5000, 100))

In [None]:
# run for Jax CPU and Jax GPU to compile the function before timing it
# Do not run for Numpy CPU (the Numpy version needs no compilation)
output, dictionary = mmdagg(X, Y, return_dictionary=True)

In [2]:
# Numpy CPU
# X and Y were generated with jax.random, so they are converted to Numpy
# arrays inside the timed statement
%timeit mmdagg(np.array(X), np.array(Y), return_dictionary=True)

43.1 s ± 69.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [2]:
# Jax CPU
%timeit mmdagg(X, Y, return_dictionary=True)

14.9 s ± 51.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [2]:
# Jax GPU
%timeit mmdagg(X, Y, return_dictionary=True) 

495 ms ± 4.56 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
