# Quick Benchmarks of jax.numpy 

This is not meant to be an exhaustive benchmark, and it does not cover every edge case. I just wanted to highlight some of the speedups possible when using jax.numpy as a drop-in replacement for numpy. Installation instructions are [here](https://jax.readthedocs.io/en/latest/installation.html). I found it very easy to install on an M2 Mac:

`pip install --upgrade "jax[cpu]"`
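After installing, a quick sanity check confirms that JAX imports and can run a computation. This is a minimal sketch; the exact device and backend strings printed vary by platform and JAX version.

```python
import jax
import jax.numpy as jnp

# List the devices JAX can use; with the "jax[cpu]" install this should be CPU only.
devices = jax.devices()
print(devices)
print(jax.default_backend())

# Run a tiny computation end to end: 0 + 1 + 4 + 9 + 16 = 30.
x = jnp.arange(5.0)
print(jnp.sum(x * x))
```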

##### These tests were all run locally on Apple M2 hardware (CPU only)

Feel free to reach out at tmartin at ucar dot edu if you have more questions.

In [1]:
import numpy as np
import jax
import jax.numpy as jnp
import jax.scipy.linalg as jsl

Some global variables

In [2]:
num_randoms = 1e7
matrix_size = (10000, 10000)

num_loops = 10

# numpy

In [3]:
%%timeit

for i in range(num_loops):
    random_numbers = np.random.uniform(low=-10, high=10, size=int(num_randoms))
    
    # numpy functions
    a = np.sin(random_numbers)
    b = np.abs(random_numbers)
    c = np.arctan(random_numbers)
    d = np.add(np.add(a, b), c)

2.36 s ± 6.16 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [4]:
%%timeit

# some matrix math
random_matrix = np.random.uniform(low=-10, high=10, size=matrix_size)
inv_matrix = np.linalg.inv(random_matrix)
det_matrix = np.linalg.det(random_matrix)

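(NumPy emitted a RuntimeWarning from `_umath_linalg.det` here: for a 10000 × 10000 matrix with entries in [-10, 10], the determinant is far outside the float64 range, so `det_matrix` overflows to ±inf.)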


11.9 s ± 542 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
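For matrices this large, `np.linalg.slogdet`, which returns the sign and the natural log of |det| separately, stays finite where `np.linalg.det` overflows. A minimal sketch, using a smaller 500 × 500 matrix so it runs quickly (the size and seed are arbitrary choices, not from the benchmark):

```python
import numpy as np

rng = np.random.default_rng(0)  # seed chosen arbitrarily for reproducibility
n = 500
m = rng.uniform(low=-10, high=10, size=(n, n))

# The true |det| is far larger than the float64 maximum (~1.8e308),
# so det() overflows to +inf or -inf; silence the overflow warning here.
with np.errstate(over="ignore"):
    det = np.linalg.det(m)

# slogdet() keeps the magnitude in log space, so nothing overflows.
sign, logabsdet = np.linalg.slogdet(m)

print(det)
print(sign, logabsdet)
```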


In [5]:
%%timeit

# peak-to-peak (max - min) of each row of a 1,000,000 x 500 matrix
ptp_matrix = np.random.uniform(low=-10, high=100, size=(1000000, 500))
peaks = np.ptp(ptp_matrix, axis=1)

2.43 s ± 24.4 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
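`ptp` ("peak to peak") is the range of values along an axis, i.e. the maximum minus the minimum. A small sketch checking that equivalence for both libraries (the 4 × 3 array is arbitrary):

```python
import numpy as np
import jax.numpy as jnp

x = np.array([[1.0, 5.0, 3.0],
              [-2.0, 0.0, 4.0],
              [7.0, 7.0, 7.0],
              [-1.0, -9.0, 2.0]])

# ptp along axis=1 is the per-row max minus the per-row min.
np_peaks = np.ptp(x, axis=1)
manual = x.max(axis=1) - x.min(axis=1)

# jnp.ptp follows the same semantics as np.ptp.
jax_peaks = jnp.ptp(jnp.asarray(x), axis=1)

# per-row ranges: 4, 6, 0, 11
print(np_peaks, jax_peaks)
```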


# jax

In [6]:
%%timeit

for i in range(num_loops):
    random_numbers = jax.random.uniform(jax.random.PRNGKey(0), shape=(int(num_randoms),), minval=-10, maxval=10)
    
    # jax functions
    a = jnp.sin(random_numbers)
    b = jnp.abs(random_numbers)
    c = jnp.arctan(random_numbers)
    d = jnp.add(jnp.add(a, b), c)

623 ms ± 10.1 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
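The cell above runs JAX op by op, and JAX dispatches work asynchronously, so `%%timeit` may stop the clock before the last result has actually been computed. A minimal sketch of a fairer setup, assuming the same element-wise workload: wrap it in `jax.jit` so XLA can fuse the operations, and call `block_until_ready()` so any timing includes the real computation. The function name `pipeline` and the smaller array size are illustrative choices.

```python
import jax
import jax.numpy as jnp

def pipeline(x):
    # Same element-wise work as the cell above: sin(x) + |x| + arctan(x).
    a = jnp.sin(x)
    b = jnp.abs(x)
    c = jnp.arctan(x)
    return jnp.add(jnp.add(a, b), c)

# jit traces the function once per input shape/dtype and compiles it with XLA,
# which can fuse the element-wise ops into a single pass over memory.
pipeline_jit = jax.jit(pipeline)

key = jax.random.PRNGKey(0)
x = jax.random.uniform(key, shape=(1_000_000,), minval=-10, maxval=10)

# The first call includes compilation, so run it once as a warm-up.
pipeline_jit(x).block_until_ready()

# In a notebook, the compiled version could then be timed with:
#   %timeit pipeline_jit(x).block_until_ready()
d_eager = pipeline(x).block_until_ready()
d_jit = pipeline_jit(x).block_until_ready()
```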


In [7]:
%%timeit

random_mat_jax = jax.random.uniform(jax.random.PRNGKey(0), shape=matrix_size, minval=-10, maxval=10)
inv_matrix = jsl.inv(random_mat_jax)
det_matrix = jsl.det(random_mat_jax)

6.2 s ± 110 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
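One caveat when comparing these timings: `np.random.uniform` returns float64 arrays, while JAX defaults to float32 unless 64-bit mode is enabled (for example with `jax.config.update("jax_enable_x64", True)` at startup), so the two cells are not doing identical work. The sketch below only checks the dtypes on a small array; the 3 × 3 shape is arbitrary.

```python
import numpy as np
import jax
import jax.numpy as jnp

np_mat = np.random.uniform(low=-10, high=10, size=(3, 3))
jax_mat = jax.random.uniform(jax.random.PRNGKey(0), shape=(3, 3), minval=-10, maxval=10)

# NumPy uses double precision by default; JAX uses single precision
# unless x64 mode has been enabled.
print(np_mat.dtype)
print(jax_mat.dtype)
```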


In [8]:
%%timeit

# peak-to-peak (max - min) of each row of a 1,000,000 x 500 matrix
ptp_matrix_jax = jax.random.uniform(jax.random.PRNGKey(0), shape=(1000000, 500), minval=-10, maxval=100)
peaks_jax = jnp.ptp(ptp_matrix_jax, axis=1)

1.11 s ± 12.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
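Because of asynchronous dispatch, JAX timings are only meaningful if the clock stops after the result is materialized. A minimal timing helper, sketched under the assumption that the benchmarked function returns a single JAX array; the name `time_jax`, the repeat count, and the smaller 10,000 × 500 matrix are illustrative and not part of JAX:

```python
import statistics
import time

import jax
import jax.numpy as jnp

def time_jax(fn, *args, repeats=5):
    """Hypothetical helper: return (mean, stdev) wall time in seconds for fn(*args).

    Runs one untimed warm-up call (which absorbs any jit compilation), then
    times `repeats` calls, each waiting for the result via block_until_ready().
    """
    fn(*args).block_until_ready()  # warm-up
    times = []
    for _ in range(repeats):
        start = time.perf_counter()
        fn(*args).block_until_ready()
        times.append(time.perf_counter() - start)
    return statistics.mean(times), statistics.stdev(times)

# Example: a jitted row-wise ptp on a smaller matrix than the benchmark above.
ptp_rows = jax.jit(lambda m: jnp.ptp(m, axis=1))
mat = jax.random.uniform(jax.random.PRNGKey(0), shape=(10_000, 500), minval=-10, maxval=100)
mean_s, std_s = time_jax(ptp_rows, mat)
print(f"{mean_s * 1e3:.2f} ms ± {std_s * 1e3:.2f} ms")
```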
