# Demo

In [1]:
# run for Jax GPU
import numpy as np
from agginc.jax import human_readable_dict
from agginc.jax import agginc
# from agginc.np import agginc
from jax import random
import jax.numpy as jnp
# Sanity check that JAX is using the GPU: create a small array and display the set of
# devices it lives on. `devices()` is used instead of the deprecated `device()` method,
# which is no longer callable in recent JAX releases.
X = jnp.array([1, 2])
X.devices()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

## MMDAggInc

In [2]:
# Two samples of shape (5000, 100): X is uniform on [0, 1) and Y is uniform on [2, 3),
# each drawn with its own PRNG subkey.
sample_shape = (5000, 100)
key, subkey = random.split(random.PRNGKey(0))
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=sample_shape)
Y = 2 + random.uniform(subkeys[1], shape=sample_shape)

In [3]:
# First call: the function is compiled for inputs of this shape, so this run includes
# the one-off compilation cost.
output, dictionary = agginc("mmd", X, Y, return_dictionary=True)
# NumPy version (no compilation); use together with `from agginc.np import agginc`:
# output, dictionary = agginc("mmd", np.array(X), np.array(Y), return_dictionary=True)

In [4]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (5000, 100)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 2
# see speed.ipynb for detailed speed comparision between numpy, jax cpu and jax gpu 
%timeit output, dictionary = agginc("mmd", X, Y, return_dictionary=True)

22.9 ms ± 17.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [5]:
# Run the MMD test once more with a plain assignment so that `output` and `dictionary`
# hold the results for the new data and can be inspected in the cells below.
output, dictionary = agginc("mmd", X, Y, return_dictionary=True)

In [6]:
# `output` is a JAX array holding either 0 or 1; here it is 1, matching
# 'MMDAggInc test reject': True in the dictionary shown further below.
output 

Array(1, dtype=int32)

In [7]:
# Convert the JAX array to a plain Python int with `.item()`:
output.item()

1

In [8]:
# Convert the JAX arrays stored in `dictionary` to plain scalars. The return value of
# the call is not used; the dictionary itself is displayed afterwards.
human_readable_dict(dictionary)
dictionary

{'MMDAggInc test reject': True,
 'Single test 1': {'Bandwidth': 9.370773315429688,
  'Kernel Gaussian': True,
  'MMD': 1.6377350091934204,
  'MMD quantile': 0.002928956877440214,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold': 0.037924136966466904},
 'Single test 10': {'Bandwidth': 43.93260955810547,
  'Kernel Gaussian': True,
  'MMD': 0.3712867200374603,
  'MMD quantile': 0.0006533270934596658,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold': 0.037924136966466904},
 'Single test 2': {'Bandwidth': 11.125825881958008,
  'Kernel Gaussian': True,
  'MMD': 1.6792677640914917,
  'MMD quantile': 0.0029786520171910524,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold': 0.037924136966466904},
 'Single test 3': {'Bandwidth': 13.209583282470703,
  'Kernel Gaussian': True,
  'MMD': 1.6340610980987549,
  'MMD quantile': 0.0028756719548255205,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold

## HSICAggInc

In [9]:
# Paired samples for the HSIC test, each of shape (5000, 100): X is uniform on [0, 1)
# and Y is uniform on [2, 3), drawn with separate PRNG subkeys.
sample_shape = (5000, 100)
key, subkey = random.split(random.PRNGKey(0))
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=sample_shape)
Y = 2 + random.uniform(subkeys[1], shape=sample_shape)

In [10]:
# First call: the function is compiled for inputs of this shape, so this run includes
# the one-off compilation cost.
output, dictionary = agginc("hsic", X, Y, return_dictionary=True)
# NumPy version (no compilation); use together with `from agginc.np import agginc`:
# output, dictionary = agginc("hsic", np.array(X), np.array(Y), return_dictionary=True)

In [11]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (5000, 100)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 2
# see speed.ipynb for detailed speed comparision between numpy, jax cpu and jax gpu 
%timeit output, dictionary = agginc("hsic", X, Y, return_dictionary=True)

18 ms ± 6.97 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Run the HSIC test once more with a plain assignment so that `output` and `dictionary`
# hold the results for the new data and can be inspected in the cells below.
output, dictionary = agginc("hsic", X, Y, return_dictionary=True)

In [13]:
# `output` is a JAX array holding either 0 or 1; here it is 0, matching
# 'HSICAggInc test reject': False in the dictionary shown further below.
output

Array(0, dtype=int32)

In [14]:
# Convert the JAX array to a plain Python int with `.item()`:
output.item()

0

In [15]:
# Convert the JAX arrays stored in `dictionary` to plain scalars. The return value of
# the call is not used; the dictionary itself is displayed afterwards.
human_readable_dict(dictionary)
dictionary

{'HSICAggInc test reject': False,
 'Single test 1': {'Bandwidth X': 1.019727349281311,
  'Bandwidth Y': 1.016467571258545,
  'HSIC': -2.967349537743788e-14,
  'HSIC quantile': 5.4557981667580774e-14,
  'Kernel Gaussian': True,
  'Reject': False,
  'p-value': 0.9201596975326538,
  'p-value threshold': 0.011976033449172974},
 'Single test 10': {'Bandwidth X': 16.315637588500977,
  'Bandwidth Y': 2.03293514251709,
  'HSIC': 6.821696274528222e-08,
  'HSIC quantile': 6.499256528513797e-07,
  'Kernel Gaussian': True,
  'Reject': False,
  'p-value': 0.3772455155849457,
  'p-value threshold': 0.011976033449172974},
 'Single test 11': {'Bandwidth X': 1.019727349281311,
  'Bandwidth Y': 4.06587028503418,
  'HSIC': -2.3380247737847526e-10,
  'HSIC quantile': 9.816001345086534e-10,
  'Kernel Gaussian': True,
  'Reject': False,
  'p-value': 0.6946107745170593,
  'p-value threshold': 0.011976033449172974},
 'Single test 12': {'Bandwidth X': 2.039454698562622,
  'Bandwidth Y': 4.06587028503418,
  'HS

## KSDAggInc

In [16]:
# Draw 5000 one-dimensional samples with NumPy's gamma sampler using shape parameter
# 5 + perturbation and scale 5, then evaluate score_gamma at the samples with the
# unperturbed parameters k = 5 and theta = 5.
perturbation = 0.5
rs = np.random.RandomState(0)
X = rs.gamma(5 + perturbation, 5, (5000, 1))


def score_gamma(x, k, theta):
    """Return (k - 1) / x - 1 / theta.

    This is the derivative with respect to x of the log-density of a gamma
    distribution with shape k and scale theta.
    """
    return (k - 1) / x - 1 / theta


score_X = score_gamma(X, 5, 5)
# Convert the NumPy arrays to JAX arrays before passing them to agginc.
X = jnp.array(X)
score_X = jnp.array(score_X)

In [17]:
# First call: the function is compiled for inputs of this shape, so this run includes
# the one-off compilation cost.
output, dictionary = agginc("ksd", X, score_X, return_dictionary=True)
# NumPy version (no compilation); use together with `from agginc.np import agginc`:
# output, dictionary = agginc("ksd", np.array(X), np.array(score_X), return_dictionary=True)

In [18]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (5000, 100)
# If the shape is changed, the function will need to be compiled again
perturbation = 0.5
rs = np.random.RandomState(1) # different initialisation
X = rs.gamma(5 + perturbation, 5, (5000, 1))
score_gamma = lambda x, k, theta : (k - 1) / x - 1 / theta
score_X = score_gamma(X, 5, 5)
X = jnp.array(X)
score_X = jnp.array(score_X)
# see speed.ipynb for detailed speed comparision between numpy, jax cpu and jax gpu 
%timeit output, dictionary = agginc("ksd", X, score_X, return_dictionary=True)

21.5 ms ± 31.9 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [19]:
# Run the KSD test once more with a plain assignment so that `output` and `dictionary`
# hold the results for the new data and can be inspected in the cells below.
output, dictionary = agginc("ksd", X, score_X, return_dictionary=True)

In [20]:
# `output` is a JAX array holding either 0 or 1; here it is 1, matching
# 'KSDAggInc test reject': True in the dictionary shown further below.
output

Array(1, dtype=int32)

In [21]:
# Convert the JAX array to a plain Python int with `.item()`:
output.item()

1

In [22]:
# Convert the JAX arrays stored in `dictionary` to plain scalars. The return value of
# the call is not used; the dictionary itself is displayed afterwards.
human_readable_dict(dictionary)
dictionary

{'KSDAggInc test reject': True,
 'Single test 1': {'Bandwidth': 1.0,
  'KSD': 4.0504713979316875e-05,
  'KSD quantile': 0.00029450401780195534,
  'Kernel IMQ': True,
  'Reject': False,
  'p-value': 0.38522952795028687,
  'p-value threshold': 0.017964035272598267},
 'Single test 10': {'Bandwidth': 96.67843627929688,
  'KSD': 4.25097505285521e-06,
  'KSD quantile': 2.747115956935886e-07,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold': 0.017964035272598267},
 'Single test 2': {'Bandwidth': 1.661851406097412,
  'KSD': 4.388344314065762e-05,
  'KSD quantile': 8.466501458315179e-05,
  'Kernel IMQ': True,
  'Reject': False,
  'p-value': 0.113772451877594,
  'p-value threshold': 0.017964035272598267},
 'Single test 3': {'Bandwidth': 2.7617499828338623,
  'KSD': 4.6408065827563405e-05,
  'KSD quantile': 2.3729142412776127e-05,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.0019960079807788134,
  'p-value threshold': 0.017964035272598267