# Demo

In [1]:
import numpy as np
from mmdfuse import mmdfuse
from jax import random
import jax.numpy as jnp
# create a small array and display the device JAX placed it on
# NOTE(review): newer JAX releases may expose this as `X.devices()` instead of `X.device()` — confirm against the installed version
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [2]:
# generate data: two samples of shape (500, 10);
# Y is drawn like X (from its own subkey) and then shifted by 1
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X, Y = (
    random.uniform(subkeys[0], shape=(500, 10)),
    random.uniform(subkeys[1], shape=(500, 10)) + 1,
)

In [3]:
# compile function: this first call compiles mmdfuse for inputs of shape (500, 10)
# a fresh subkey is split off from key and passed to the test
key, subkey = random.split(key)
output, p_value = mmdfuse(X, Y, subkey, return_p_val=True)

In [4]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10)) + 1
# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, p_value = mmdfuse(X, Y, subkey, return_p_val=True)

53.3 ms ± 45.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
# run the test once outside %timeit so that output and p_value hold the results for the new X and Y
output, p_value = mmdfuse(X, Y, subkey, return_p_val=True)

In [6]:
# output is a scalar jax array equal to either 0 or 1
# (presumably 1 means the test rejects the null hypothesis — confirm in the mmdfuse documentation)
output

Array(1, dtype=int32)

In [7]:
# to convert it to a Python int use:
output.item()

1

In [8]:
# p-value is a scalar jax array (float32)
p_value

Array(0.00049975, dtype=float32)

In [9]:
# to convert it to a Python float use:
p_value.item()

0.0004997501382604241

# Speed comparison

## Imports

Run only one of the next two cells, depending on whether to use Jax CPU or Jax GPU.

In [1]:
# run for Jax CPU
import os
# hide all CUDA devices so that JAX falls back to the CPU; this is set before jax is imported
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
from mmdfuse import mmdfuse
from jax import random
import jax.numpy as jnp
# reload edited modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2
# create a small array and display the device JAX placed it on (a CPU device is expected here)
X = jnp.array([1, 2])
X.device()

2023-05-23 14:30:48.087927: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


CpuDevice(id=0)

In [1]:
# run for Jax GPU
import numpy as np
from mmdfuse import mmdfuse
from jax import random
import jax.numpy as jnp
# reload edited modules automatically before each cell is executed
%load_ext autoreload
%autoreload 2
# create a small array and display the device JAX placed it on (a GPU device is expected here)
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

## MMD-FUSE Runtimes

In [2]:
# generate data: two samples of shape (500, 10), with Y shifted by 1
key = random.PRNGKey(0)
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10)) + 1
# subkey was consumed by the split above; draw a fresh one so that the
# mmdfuse calls in the following cells do not reuse a key that was already split
key, subkey = random.split(key)

In [3]:
# run for Jax CPU and Jax GPU to compile the function
# (so that the timing cells below measure calls made after this first one)
# Do not run for Numpy CPU
output, p_value = mmdfuse(X, Y, subkey, return_p_val=True)

In [4]:
# Jax CPU
%timeit mmdfuse(X, Y, subkey, return_p_val=True)

2.95 s ± 105 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax GPU
%timeit mmdfuse(X, Y, subkey, return_p_val=True)

54 ms ± 23.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
