In [1]:
# Enable IPython's autoreload extension in mode 2 so that edits to imported
# modules are picked up live, without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [2]:
# run for Jax GPU
import numpy as np
from agginc.jax import agginc, human_readable_dict
from sampler_gclusters import sampler_gclusters
from jax import random
import jax.numpy as jnp
# Sanity check: create a small array and display the device(s) holding it.
# `Array.devices()` returns a set of devices and is available across JAX
# releases; calling `Array.device()` as a method fails on recent versions,
# where `device` is a property rather than a method.
X = jnp.array([1, 2])
X.devices()

StreamExecutorGpuDevice(id=0, process_index=0, slice_index=0)

In [3]:
# generate data: X is uniform noise, Y is 0.5 * X plus independent uniform noise
key, subkey = random.split(random.PRNGKey(0))
subkeys = random.split(subkey, num=2)  # one key per random draw
sample_shape = (5000, 100)
X = random.uniform(subkeys[0], shape=sample_shape)
Y = 0.5 * X + random.uniform(subkeys[1], shape=sample_shape)

### hsicagginc <br>
### sample from Uniform distributions

In [4]:
# compile function: the first call with inputs of this shape triggers compilation
output, dictionary = agginc(
    "hsic",
    X,
    Y,
    R=200,
    return_dictionary=True,
)

# Numpy version (no compilation)
# output, dictionary = agginc("hsic", np.array(X), np.array(Y), return_dictionary=True)

In [None]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (5000, 100)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1) # different initialisation
key, subkey = random.split(key)
subkeys = random.split(subkey, num=2)
X = random.uniform(subkeys[0], shape=(5000, 100))
Y = random.uniform(subkeys[1], shape=(5000, 100)) + 0.5 * X 
# see speed.ipynb for detailed speed comparision between numpy, jax cpu and jax gpu 
%timeit output, dictionary = agginc("hsic", X, Y, R=200, return_dictionary=True)

In [6]:
# Evaluate the test on the freshly sampled X and Y; the result is inspected below.
output, dictionary = agginc(
    "hsic", X, Y,
    R=200,
    return_dictionary=True,
)

In [7]:
# output is a jax array consisting of either 0 or 1
# NOTE(review): presumably 1 means the test rejects independence and 0 that it
# does not; confirm against the agginc documentation.
output

Array(1, dtype=int32)

In [8]:
# to convert the jax array to a plain Python int use:
output.item()

1

In [9]:
# human_readable_dict(dictionary) # use to convert jax arrays to scalars
# dictionary

### hsicagginc <br>
### sample from three Gaussian clusters

In [10]:
# generate data: samples drawn around three Gaussian cluster centres
key = random.PRNGKey(0)
# NOTE(review): `subkey` is never used below; `key` itself is handed to the sampler.
key, subkey = random.split(key)

N = 500
L = 10
theta = jnp.pi / 12
cos_theta, sin_theta = jnp.cos(theta), jnp.sin(theta)
half_side = L / 2
height = L * jnp.sqrt(3) / 2
# means: the points (0, 0), (L, 0) and (L / 2, L * sqrt(3) / 2), i.e. an
# equilateral triangle of side L, rotated by theta about the origin
means = [
    jnp.array([0, 0]),
    jnp.array([L * cos_theta, L * sin_theta]),
    jnp.array([
        half_side * cos_theta - height * sin_theta,
        half_side * sin_theta + height * cos_theta,
    ]),
]
# Covariance matrix: diagonal, with variance `var` on each axis
var = 1
cov = jnp.array([[var, 0], [0, var]])

X, Y = sampler_gclusters(key, L, theta, means, cov, N, permute=False)
print(X.shape)
print(Y.shape)

(500, 2)
(500, 2)


In [11]:
# compile function: the first call on inputs of this new shape triggers compilation
hsic_kwargs = dict(R=200, return_dictionary=True)
output, dictionary = agginc("hsic", X, Y, **hsic_kwargs)

# Numpy version (no compilation)
# output, dictionary = agginc("hsic", np.array(X), np.array(Y), return_dictionary=True)

In [18]:
# Now the function runs fast for any inputs X and Y of the compiled shaped (500, 10)
# If the shape is changed, the function will need to be compiled again
key = random.PRNGKey(1)
key, subkey = random.split(key)

N = 500
L = 10
theta = jnp.pi / 12
means = [
    jnp.array([0,0]),
    jnp.array([L * jnp.cos(theta), L * jnp.sin(theta)]),
    jnp.array([(L / 2) * jnp.cos(theta) - (L * jnp.sqrt(3) / 2) * jnp.sin(theta), (L / 2) * jnp.sin(theta) + (L * jnp.sqrt(3) / 2) * jnp.cos(theta)])  
]
# Covariance matrix
var = 1
cov = jnp.array([[var,   0], 
                 [  0, var]])

X, Y = sampler_gclusters(key, L, theta, means, cov, N, permute=True)

# see section below for detailed speed comparision between jax cpu and jax gpu 
%timeit output, dictionary = agginc("hsic", X, Y, R=200, return_dictionary=True)

5.86 ms ± 218 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [19]:
# NOTE(review): this split only rebinds `key` and `subkey`; neither is passed
# to agginc in this cell, so confirm whether the line is still needed.
key, subkey = random.split(key)
# Evaluate the test on the permuted cluster samples; the result is inspected below.
output, dictionary = agginc(
    "hsic",
    X,
    Y,
    R=200,
    return_dictionary=True,
)

In [20]:
# output is a jax array consisting of either 0 or 1
# NOTE(review): presumably 0 means the test does not reject independence and
# 1 that it does; confirm against the agginc documentation.
output

Array(0, dtype=int32)

In [21]:
# to convert the jax array to a plain Python int use:
output.item()

0

In [17]:
# human_readable_dict(dictionary) # use to convert jax arrays to scalars
# dictionary