In [1]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import jit, grad
import jaxtest
import sampling
import timeit

jax_array = jnp.zeros((10000, 10000))
np_array = np.zeros((10000, 10000))

def np_add(s, t, a):
    a[s, t] = 1
    a[t, s] = 1

def np_sub(s, t, a):
    a[s, t] = 0
    a[t, s] = 0
    
def jx_add(s, t, a):
    a.at[s, t].set(1)
    a.at[t, s].set(1)
    return a
def jx_sub(s, t, a):
    a.at[s, t].set(0)
    a.at[t, s].set(0)
    return a

#Test the inplace numpy updates
print("Numpy times")
%time np_add(0, 1, np_array)
%timeit np_add(0, 1, np_array)

#Test the time to add the jax_array to the device

print("Uncompiled JAX times")
%time jx_add(0, 1, jax_array)
%timeit jx_add(0, 1, jax_array)

jx_add = jit(jx_add)
jx_sub = jit(jx_sub)



Original Array
[[0 1 0 1]
 [1 0 1 0]
 [0 1 0 0]
 [1 0 0 0]]
Updated Array
[[0 1 0 1]
 [1 0 1 0]
 [0 1 0 1]
 [1 0 1 0]]
Updated after sub
[[0 0 0 1]
 [0 0 0 0]
 [0 0 0 0]
 [1 0 0 0]]
Numpy times
CPU times: user 8 µs, sys: 7 µs, total: 15 µs
Wall time: 19.1 µs
386 ns ± 16.8 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)
Uncompiled JAX times
CPU times: user 331 ms, sys: 389 ms, total: 720 ms
Wall time: 942 ms
486 ms ± 16.8 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [3]:
print("JAX compile time")
%time jx_add(0, 1, jax_array).block_until_ready()
print("JAX runtime")
%timeit jx_add(0, 1, jax_array).block_until_ready()

JAX compile time
CPU times: user 113 ms, sys: 99.7 ms, total: 213 ms
Wall time: 222 ms
JAX runtime
262 ms ± 46.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [5]:
sampling.A.sum(axis=1)

array([1, 1, 2, 1, 0, 1, 0, 0, 0, 1])

In [1]:
import sampling
import numpy as np
A=sampling.A
for i in range(len(A)):
    print(sampling.dgr(i, A))

#A = np.ones((10000, 10000), dtype=np.int32)

1
1
2
1
0
1
0
0
0
1


In [11]:
%timeit sampling.dgrs(A)

76.6 ms ± 1.29 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [14]:
j_dgrs = jit(sampling.dgrs)
j_dgr = jit(sampling.dgr)

In [15]:
j_dgrs(A)
j_dgr

<CompiledFunction at 0x7fc5c6177a90>

In [12]:
%timeit j_dgrs(A).block_until_ready()

36.9 ms ± 1e+03 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
%timeit sampling.dgr(0, A)

10.4 µs ± 246 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [16]:
%timeit j_dgr(0, A).block_until_ready()

16.2 µs ± 380 ns per loop (mean ± std. dev. of 7 runs, 100000 loops each)


In [4]:
np.tril(sampling.G, -1)

array([[0., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 0., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 0., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 0., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 0., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 0., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 0., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 0., 0.],
       [1., 1., 1., 1., 1., 1., 1., 1., 1., 0.]])