In [1]:
# Reload edited local modules (e.g. hsic, sampler_gclusters) automatically before each cell runs
%load_ext autoreload
%autoreload 2

In [4]:
import numpy as np
from hsic import hsic, human_readable_dict
from sampler_gclusters import sampler_gclusters
from jax import random
import jax.numpy as jnp
# Sanity check: which device (CPU or GPU) does JAX place arrays on?
# A throwaway name is used so the data variable `X` is not clobbered, and `.devices()`
# is used because `Array.device()` is deprecated in newer JAX (`device` became a property).
device_probe = jnp.array([1, 2])
device_probe.devices()

StreamExecutorGpuDevice(id=1, process_index=0, slice_index=0)

In [5]:
# Generate data used only to trigger compilation of hsic in the next cell:
# two independent (500, 10) Uniform samples, with Y shifted by +1.
key, subkey = random.split(random.PRNGKey(0))
x_key, y_key = random.split(subkey, num=2)
n_samples, n_features = 500, 10
X = random.uniform(x_key, shape=(n_samples, n_features))
Y = 1 + random.uniform(y_key, shape=(n_samples, n_features))

### `hsic` test on samples from Uniform distributions

In [6]:
# compile function: the first call traces and compiles hsic for inputs of shape (500, 10);
# later calls with inputs of the same shape reuse the compiled version
key, subkey = random.split(key)
output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

In [7]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10))
# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

20.7 ms ± 872 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [6]:
# Run once more outside %timeit: assignments made inside %timeit are not kept in the
# notebook namespace, so this call produces the `output` and `dictionary` inspected below.
key, subkey = random.split(key)
output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

In [7]:
# output is a scalar (0-d) JAX int32 array holding either 0 or 1; here it is 1, matching
# 'HSIC test reject': True in the dictionary (so presumably 1 means "reject independence")
output

Array(1, dtype=int32)

In [8]:
# convert the 0-d JAX array to a plain Python int:
output.item()

1

In [9]:
# human_readable_dict converts the JAX arrays stored in `dictionary` to Python scalars.
# Its return value is discarded and `dictionary` itself is displayed next, so it presumably
# modifies the dict in place (the output below shows plain scalars) — confirm in hsic.py.
human_readable_dict(dictionary) # use to convert jax arrays to scalars
dictionary

{'Bandwidth': 1,
 'HSIC': 0.0560261607170105,
 'HSIC quantile': 0.05598384141921997,
 'HSIC test reject': True,
 'Kernel gaussian': True,
 'p-value': 0.03048475831747055,
 'p-value threshold': 0.05000000074505806}

### `hsic` test on samples from three Gaussian clusters

In [10]:
# generate data: N samples from three Gaussian clusters
key = random.PRNGKey(0)
key, subkey = random.split(key)

N = 500              # number of samples
L = 10               # distance between cluster means (side of the triangle below)
theta = jnp.pi / 12  # rotation angle applied to the triangle of means
# means: vertices (0, 0), (L, 0), (L/2, L*sqrt(3)/2) of an equilateral triangle with
# side L, each rotated by theta about the origin
cos_t, sin_t = jnp.cos(theta), jnp.sin(theta)
half_height = L * jnp.sqrt(3) / 2
means = [
    jnp.array([0,0]),
    jnp.array([L * cos_t, L * sin_t]),
    jnp.array([(L / 2) * cos_t - half_height * sin_t, (L / 2) * sin_t + half_height * cos_t]),
]
# Covariance matrix
var = 1
cov = jnp.array([[var,   0],
                 [  0, var]])

# Sample with the fresh `subkey`, not `key`: `key` is split again in the next cell,
# and reusing a JAX PRNG key correlates the random streams.
X, Y = sampler_gclusters(subkey, L, theta, means, cov, N, permute=False)

In [11]:
# compile function: the first call traces and compiles hsic for the shapes of the
# cluster samples X and Y generated above
key, subkey = random.split(key)
output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

In [12]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1)
key, subkey = random.split(key)

N = 500
L = 10
theta = jnp.pi / 12
means = [
    jnp.array([0,0]),
    jnp.array([L * jnp.cos(theta), L * jnp.sin(theta)]),
    jnp.array([(L / 2) * jnp.cos(theta) - (L * jnp.sqrt(3) / 2) * jnp.sin(theta), (L / 2) * jnp.sin(theta) + (L * jnp.sqrt(3) / 2) * jnp.cos(theta)])  
]
# Covariance matrix
var = 1
cov = jnp.array([[var,   0], 
                 [  0, var]])

X, Y = sampler_gclusters(key, L, theta, means, cov, N, permute=False)

# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

21.9 ms ± 797 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Run once more outside %timeit: assignments made inside %timeit are not kept in the
# notebook namespace, so this call produces the `output` and `dictionary` inspected below.
key, subkey = random.split(key)
output, dictionary = hsic(X, Y, subkey, return_dictionary=True)

In [14]:
# output is a scalar (0-d) JAX int32 array holding either 0 or 1; here it is 1, matching
# 'HSIC test reject': True in the dictionary (so presumably 1 means "reject independence")
output

Array(1, dtype=int32)

In [15]:
# convert the 0-d JAX array to a plain Python int:
output.item()

1

In [16]:
# human_readable_dict converts the JAX arrays stored in `dictionary` to Python scalars.
# Its return value is discarded and `dictionary` itself is displayed next, so it presumably
# modifies the dict in place (the output below shows plain scalars) — confirm in hsic.py.
human_readable_dict(dictionary) # use to convert jax arrays to scalars
dictionary

{'Bandwidth': 1,
 'HSIC': 0.014159763231873512,
 'HSIC quantile': 0.00620127422735095,
 'HSIC test reject': True,
 'Kernel gaussian': True,
 'p-value': 0.0004997501382604241,
 'p-value threshold': 0.05000000074505806}