In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
import jaxtest
import sampling
import timeit

# Two 10000 x 10000 all-zero matrices used as benchmark inputs below:
# one JAX array and one NumPy array.
jax_array = jnp.zeros((10000, 10000))
np_array = np.zeros((10000, 10000))

def np_add(s, t, a):
    """Write 1 into ``a`` at (s, t) and at the mirrored position (t, s).

    The array ``a`` is modified in place; nothing is returned.
    """
    for row, col in ((s, t), (t, s)):
        a[row, col] = 1

def np_sub(s, t, a):
    """Write 0 into ``a`` at (s, t) and at the mirrored position (t, s).

    The array ``a`` is modified in place; nothing is returned.
    """
    for row, col in ((s, t), (t, s)):
        a[row, col] = 0
    
def jx_add(s, t, a):
    """Return a copy of ``a`` with entries (s, t) and (t, s) set to 1.

    JAX arrays are immutable: ``a.at[idx].set(v)`` leaves ``a`` untouched and
    returns a new array. Each result is therefore rebound to ``a`` so that
    the returned value actually carries both updates.

    Args:
        s: first index of the pair.
        t: second index of the pair.
        a: JAX array to update (the caller's array is not modified).

    Returns:
        A new JAX array equal to ``a`` except at (s, t) and (t, s), which are 1.
    """
    a = a.at[s, t].set(1)
    a = a.at[t, s].set(1)
    return a
def jx_sub(s, t, a):
    """Return a copy of ``a`` with entries (s, t) and (t, s) set to 0.

    JAX arrays are immutable: ``a.at[idx].set(v)`` leaves ``a`` untouched and
    returns a new array. Each result is therefore rebound to ``a`` so that
    the returned value actually carries both updates.

    Args:
        s: first index of the pair.
        t: second index of the pair.
        a: JAX array to update (the caller's array is not modified).

    Returns:
        A new JAX array equal to ``a`` except at (s, t) and (t, s), which are 0.
    """
    a = a.at[s, t].set(0)
    a = a.at[t, s].set(0)
    return a

#Test the inplace numpy updates
print("Numpy times")
%time np_add(0, 1, np_array)
%timeit np_add(0, 1, np_array)

#Test the time to add the jax_array to the device

print("Uncompiled JAX times")
%time jx_add(0, 1, jax_array)
%timeit jx_add(0, 1, jax_array)

jx_add = jit(jx_add)
jx_sub = jit(jx_sub)



Original Array
[[0 1 0 1]
 [1 0 1 0]
 [0 1 0 0]
 [1 0 0 0]]
Updated Array
[[0 1 0 1]
 [1 0 1 0]
 [0 1 0 1]
 [1 0 1 0]]
Updated after sub
[[0 0 0 1]
 [0 0 0 0]
 [0 0 0 0]
 [1 0 0 0]]
Numpy times
CPU times: user 13 µs, sys: 4 µs, total: 17 µs
Wall time: 26.7 µs
379 ns ± 0.287 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Uncompiled JAX times
CPU times: user 228 ms, sys: 352 ms, total: 580 ms
Wall time: 569 ms
512 ms ± 9.02 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [2]:
# First call to the jit-wrapped jx_add: the single-shot %time below is
# labelled as the compile time, so it is expected to include compilation.
print("JAX compile time")
%time jx_add(0, 1, jax_array).block_until_ready()
# Subsequent calls, averaged by %timeit; block_until_ready() waits for the
# result so the measurement covers the computation, not just the dispatch.
print("JAX runtime")
%timeit jx_add(0, 1, jax_array).block_until_ready()

JAX compile time
CPU times: user 114 ms, sys: 149 ms, total: 263 ms
Wall time: 258 ms
JAX runtime
251 ms ± 42.6 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
# Row sums of sampling.A (one value per row).
sampling.A.sum(axis=1)

array([1, 1, 2, 1, 0, 1, 0, 0, 0, 1])

In [4]:
# NOTE(review): sampling and numpy are already imported in the first cell;
# these re-imports are redundant when the notebook is run top to bottom.
import sampling
import numpy as np
A=sampling.A
# Print sampling.dgr(i, A) for every index i along the first axis of A.
for i in range(len(A)):
    print(sampling.dgr(i, A))

# Disabled alternative input: a 10000 x 10000 all-ones int32 matrix.
#A = np.ones((10000, 10000), dtype=np.int32)

1
1
2
1
0
1
0
0
0
1


In [5]:
# Baseline timing of the plain (not jit-wrapped) sampling.dgrs on A.
%timeit sampling.dgrs(A)

2.31 µs ± 148 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [6]:
# jit-wrapped versions of the two sampling helpers, timed against the plain
# versions in the following cells.
j_dgrs = jit(sampling.dgrs)
j_dgr = jit(sampling.dgr)

In [7]:
# Call j_dgrs once and discard the result (presumably a warm-up so that
# compilation is excluded from the timing below — TODO confirm).
j_dgrs(A)
# NOTE(review): this only displays the wrapper object; j_dgr is not called here.
j_dgr

<CompiledFunction at 0x7fa2c00814a0>

In [8]:
# Timing of the jit-wrapped dgrs; block_until_ready() waits for the result.
%timeit j_dgrs(A).block_until_ready()

4.97 µs ± 18.2 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [9]:
# Baseline timing of the plain (not jit-wrapped) sampling.dgr for index 0.
%timeit sampling.dgr(0, A)

1.81 µs ± 14.3 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [10]:
# Timing of the jit-wrapped dgr for index 0; block_until_ready() waits for the result.
%timeit j_dgr(0, A).block_until_ready()

5.79 µs ± 7.88 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [11]:
# Strictly lower-triangular part of sampling.G (k=-1 excludes the main diagonal).
np.tril(sampling.G, -1)

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]])

In [40]:
# Side length of the square matrix generated below.
d = 5
# d*(d-1)/2: the number of strictly-upper-triangular entries of a d x d matrix.
# NOTE(review): e is not used anywhere else in this cell.
e = int(d*(d-1)*0.5)
# Fixed PRNG key so the uniform draw below is reproducible.
key = jax.random.PRNGKey(129837129873)
a = jax.random.uniform(key, shape=(d, d))
# Average with the transpose so that a is symmetric.
a = (a + a.T)/2

In [41]:
# Display the symmetrized random matrix built in the previous cell.
a

DeviceArray([[0.44665873, 0.58879507, 0.46223068, 0.5640855 , 0.75001204],
             [0.58879507, 0.3331939 , 0.32173544, 0.73361325, 0.7287323 ],
             [0.46223068, 0.32173544, 0.53922486, 0.5217623 , 0.42365187],
             [0.5640855 , 0.73361325, 0.5217623 , 0.22474968, 0.19031078],
             [0.75001204, 0.7287323 , 0.42365187, 0.19031078, 0.14194787]],            dtype=float32)

In [50]:
# Scratch check: add a NumPy array to a JAX scalar and inspect the result type.
np.array([-3]) + jnp.sqrt(4)

DeviceArray([-1.], dtype=float32)