# Demo

In [1]:
import numpy as np
from ksdagg import ksdagg, human_readable_dict # jax version
# from ksdagg.np import ksdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
# Sanity check: create a small array and display the device JAX placed it on.
# NOTE(review): newer JAX releases expose this as `X.devices()` / the `X.device`
# attribute rather than a `device()` method — confirm against the installed version.
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [3]:
# Generate data: 5000 one-dimensional samples from a Gamma distribution whose
# shape parameter is perturbed (5 + perturbation), together with the score of
# the unperturbed Gamma(shape=5, scale=5) model evaluated at those samples.
perturbation = 0.5
rs = np.random.RandomState(0)


def score_gamma(x, k, theta):
    """Return d/dx log p(x) for a Gamma density with shape ``k`` and scale ``theta``."""
    return (k - 1) / x - 1 / theta


X = rs.gamma(5 + perturbation, 5, (5000, 1))
# The score is evaluated on the NumPy samples first; both arrays are then
# handed over to JAX.
score_X = jnp.array(score_gamma(X, 5, 5))
X = jnp.array(X)

In [4]:
# Compile the function: the first call of the JAX version of ksdagg on inputs
# of this shape triggers compilation.
output, dictionary = ksdagg(X, score_X, return_dictionary=True)
# Numpy version (no compilation); needs `from ksdagg.np import ksdagg` instead of the JAX import:
# output, dictionary = ksdagg(np.array(X), np.array(score_X), return_dictionary=True)

In [8]:
# Now the function runs fast for any inputs X and score_X of the compiled shape (5000, 1).
# If the shape is changed, the function will need to be compiled again.
perturbation = 0.5
rs = np.random.RandomState(1)  # different seed than before, same shapes


def score_gamma(x, k, theta):
    """Return d/dx log p(x) for a Gamma density with shape ``k`` and scale ``theta``."""
    return (k - 1) / x - 1 / theta


X = rs.gamma(5 + perturbation, 5, (5000, 1))
# The score is evaluated on the NumPy samples first; both arrays are then
# handed over to JAX.
score_X = jnp.array(score_gamma(X, 5, 5))
X = jnp.array(X)
# See the "Speed comparison" section (or notebook) for a detailed speed comparison between numpy, jax cpu and jax gpu.
# NOTE(review): JAX dispatches computations asynchronously; if ksdagg does not
# synchronise on its result internally, %timeit may under-report the runtime.
# Consider blocking on the result (block_until_ready) — TODO confirm.
%timeit output, dictionary = ksdagg(X, score_X, return_dictionary=True)

In [9]:
# Time the already-compiled function on the new inputs.
# NOTE(review): JAX dispatches computations asynchronously; if ksdagg does not
# synchronise on its result internally, %timeit may under-report the runtime — TODO confirm.
%timeit output, dictionary = ksdagg(X, score_X, return_dictionary=True)

65.4 ms ± 271 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [10]:
# Run the test once more outside %timeit: names assigned inside a %timeit
# statement are not kept in the notebook namespace.
output, dictionary = ksdagg(X, score_X, return_dictionary=True)

In [11]:
# Test outcome as a JAX integer scalar (presumably 1 = reject, 0 = fail to reject — confirm against the ksdagg documentation).
output

Array(1, dtype=int32)

In [12]:
# Convert the JAX scalar into a plain Python number.
output.item()

1

In [14]:
# Convert the JAX arrays stored in `dictionary` to plain Python scalars.
# The return value is not used here, so the conversion appears to happen in place — TODO confirm.
human_readable_dict(dictionary)
dictionary

{'KSDAgg test reject': True,
 'Single test 1': {'Bandwidth': 1.0,
  'KSD': 0.00012235062604304403,
  'KSD quantile': 9.613344445824623e-05,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.004997501149773598,
  'p-value threshold': 0.011993973515927792},
 'Single test 10': {'Bandwidth': 96.67843627929688,
  'KSD': 0.0004186165751889348,
  'KSD quantile': 9.040639270097017e-06,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.011993973515927792},
 'Single test 2': {'Bandwidth': 1.661851406097412,
  'KSD': 0.00012240404612384737,
  'KSD quantile': 4.81025199405849e-05,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.011993973515927792},
 'Single test 3': {'Bandwidth': 2.7617499828338623,
  'KSD': 0.00015008114860393107,
  'KSD quantile': 2.4249460693681613e-05,
  'Kernel IMQ': True,
  'Reject': True,
  'p-value': 0.0004997501382604241,
  'p-value threshold': 0.011993973515927792

# Speed comparison

## Imports

Run only one of the next three cells, depending on whether you want to use Numpy CPU, Jax CPU or Jax GPU.

The CPU used is an AMD Ryzen Threadripper 3960X (24 cores, 3.8 GHz) with 128 GB of RAM.

The GPU used is an NVIDIA RTX A5000 with 24 GB of memory.

In [1]:
# run for Numpy CPU
import numpy as np
from ksdagg.np import ksdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2

In [1]:
# run for Jax CPU
import os
# Hide all CUDA devices from this process so that JAX falls back to the CPU.
# This is set before jax is imported below.
os.environ['CUDA_VISIBLE_DEVICES'] = ''
import numpy as np
from ksdagg.jax import ksdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
# Sanity check: create a small array and display the device JAX placed it on
# (a CPU device is expected here).
# NOTE(review): newer JAX releases expose this as `X.devices()` / the `X.device`
# attribute rather than a `device()` method — confirm against the installed version.
X = jnp.array([1, 2])
X.device()

2023-02-09 18:42:13.293411: E external/org_tensorflow/tensorflow/compiler/xla/stream_executor/cuda/cuda_driver.cc:266] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


CpuDevice(id=0)

In [1]:
# run for Jax GPU
import numpy as np
from ksdagg.jax import ksdagg
from jax import random
import jax.numpy as jnp
%load_ext autoreload
%autoreload 2
# Sanity check: create a small array and display the device JAX placed it on
# (a GPU device is expected here).
# NOTE(review): newer JAX releases expose this as `X.devices()` / the `X.device`
# attribute rather than a `device()` method — confirm against the installed version.
X = jnp.array([1, 2])
X.device()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

## KSDAgg runtimes

In [5]:
# Benchmark data: 5000 one-dimensional samples from a Gamma distribution whose
# shape parameter is perturbed (5 + perturbation), together with the score of
# the unperturbed Gamma(shape=5, scale=5) model evaluated at those samples.
perturbation = 0.5
rs = np.random.RandomState(0)


def score_gamma(x, k, theta):
    """Return d/dx log p(x) for a Gamma density with shape ``k`` and scale ``theta``."""
    return (k - 1) / x - 1 / theta


X = rs.gamma(5 + perturbation, 5, (5000, 1))
# The score is evaluated on the NumPy samples first; both arrays are then
# handed over to JAX.
score_X = jnp.array(score_gamma(X, 5, 5))
X = jnp.array(X)

In [3]:
# Warm-up call: run for Jax CPU and Jax GPU to compile the function before timing it.
# Do not run for Numpy CPU (the Numpy version has no compilation step).
output, dictionary = ksdagg(X, score_X, return_dictionary=True)

In [7]:
# Numpy CPU
# Note: the np.array(...) conversions of the JAX arrays are part of the timed
# statement, so their cost is included in the measurement.
%timeit ksdagg(np.array(X), np.array(score_X), return_dictionary=True) 

12.5 s ± 40.5 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax CPU
# NOTE(review): JAX dispatches computations asynchronously; if ksdagg does not
# synchronise on its result internally, %timeit may under-report the runtime.
# Consider blocking on the result (block_until_ready) — TODO confirm.
%timeit ksdagg(X, score_X, return_dictionary=True)

1.47 s ± 11.7 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
# Jax GPU
# NOTE(review): JAX dispatches computations asynchronously; if ksdagg does not
# synchronise on its result internally, %timeit may under-report the runtime.
# Consider blocking on the result (block_until_ready) — TODO confirm.
%timeit ksdagg(X, score_X, return_dictionary=True) 

21.5 ms ± 33.7 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
