In [1]:
# Enable IPython's autoreload extension; mode 2 reloads imported modules before
# each cell is executed, so edits to the local modules are picked up live.
%load_ext autoreload
%autoreload 2

In [2]:
import numpy as np
from hsicfuse import hsicfuse
from sampler_gclusters import sampler_gclusters
from jax import random
import jax.numpy as jnp
# Sanity check: create a small array and display the device JAX placed it on.
X = jnp.array([1, 2])
# Calling `X.device()` as a method was deprecated in favour of `X.devices()` and is
# not callable in recent JAX releases. `devices()` returns the set of devices
# holding the array; this array lives on a single device, so unwrap that element.
next(iter(X.devices()))

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [3]:
# Generate data: two uniform samples of shape (500, 10), each drawn with its own key.
key, subkey = random.split(random.PRNGKey(0))
subkeys = random.split(subkey, num=2)
sample_shape = (500, 10)
x_key, y_key = subkeys
X = random.uniform(x_key, shape=sample_shape)
# Y uses a different key than X and is shifted by 1.
Y = 1 + random.uniform(y_key, shape=sample_shape)

### hsicfuse <br> 
### sample from Uniform distributions

In [4]:
# compile function: this first call compiles hsicfuse for inputs of shape (500, 10)
key, subkey = random.split(key)  # fresh subkey for the test
# with return_p_val=True the call returns both the test output and the p-value
output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

In [5]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(500, 10))
Y = random.uniform(subkeys[1], shape=(500, 10)) + X
# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

2.18 s ± 6.13 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Run the test once outside %timeit so that output and p_value hold the result for
# the current X and Y (%timeit does not keep variables assigned by the statement it times).
output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

In [7]:
# output is a scalar jax array (dtype int32) holding either 0 or 1
# NOTE(review): presumably 1 means the test rejects independence — confirm against hsicfuse's documentation
output

Array(1, dtype=int32)

In [8]:
# .item() converts the scalar jax array to a plain Python int
output.item()

1

In [9]:
# p-value is a scalar jax array (dtype float32)
p_value

Array(0.00049975, dtype=float32)

In [10]:
# to convert it to a Python float use:
p_value.item()

0.0004997501382604241

### hsicfuse <br> 
### sample from three Gaussian clusters

In [16]:
# generate data
key = random.PRNGKey(0)
key, subkey = random.split(key)

N = 500  # forwarded to the sampler; presumably the number of samples — confirm
L = 10  # distance of the second and third cluster means from the first one (the origin)
theta = jnp.pi / 12  # rotation angle of the triangle of means about the origin
# means: vertices of an equilateral triangle with side L, rotated by theta
means = [
    jnp.array([0,0]),
    jnp.array([L * jnp.cos(theta), L * jnp.sin(theta)]),
    jnp.array([
        (L / 2) * jnp.cos(theta) - (L * jnp.sqrt(3) / 2) * jnp.sin(theta),
        (L / 2) * jnp.sin(theta) + (L * jnp.sqrt(3) / 2) * jnp.cos(theta),
    ]),
]
# Covariance matrix: var on the diagonal, zero off-diagonal
var = 1
cov = jnp.array([[var,   0],
                 [  0, var]])

# Sample with `subkey`, not `key`: `key` is split again in the next cell, and a
# JAX PRNG key must not be both consumed by a sampler and split afterwards.
X, Y = sampler_gclusters(subkey, L, theta, means, cov, N, permute=False)

In [17]:
# compile function: this first call compiles hsicfuse for the shapes of the
# X and Y returned by the sampler
key, subkey = random.split(key)  # fresh subkey for the test
# with return_p_val=True the call returns both the test output and the p-value
output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

In [24]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1)
key, subkey = random.split(key)

N = 500
L = 10
theta = jnp.pi / 12
means = [
    jnp.array([0,0]),
    jnp.array([L * jnp.cos(theta), L * jnp.sin(theta)]),
    jnp.array([(L / 2) * jnp.cos(theta) - (L * jnp.sqrt(3) / 2) * jnp.sin(theta), (L / 2) * jnp.sin(theta) + (L * jnp.sqrt(3) / 2) * jnp.cos(theta)])  
]
# Covariance matrix
var = 1
cov = jnp.array([[var,   0], 
                 [  0, var]])

X, Y = sampler_gclusters(key, L, theta, means, cov, N, permute=True)

# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

984 ms ± 18.9 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [25]:
# Run the test once outside %timeit so that output and p_value hold the result for
# the current X and Y (%timeit does not keep variables assigned by the statement it times).
output, p_value = hsicfuse(X, Y, subkey, return_p_val=True)

In [26]:
# output is a scalar jax array (dtype int32) holding either 0 or 1
# NOTE(review): presumably 0 means the test does not reject independence — confirm against hsicfuse's documentation
output

Array(0, dtype=int32)

In [27]:
# .item() converts the scalar jax array to a plain Python int
output.item()

0

In [28]:
# p-value is a scalar jax array (dtype float32)
p_value

Array(0.24787606, dtype=float32)

In [29]:
# to convert it to a Python float use:
p_value.item()

0.24787606298923492