In [1]:
import inspect

In [2]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random, jit, lax, vmap, pmap, grad, value_and_grad

In [3]:
# JAX version that produced the outputs recorded in this notebook.
jax.__version__

'0.4.1'

In [4]:
# Seed both random number generators so the notebook is reproducible under
# Restart & Run All (the original NumPy Generator was unseeded).
# NumPy's Generator keeps internal state between calls; JAX instead threads an
# explicit key through each random call (see `jax.random.normal(key, ...)` below).
SEED = 0
rng = np.random.default_rng(SEED)
key = jax.random.PRNGKey(SEED)

In [5]:
# Build a JAX array from a Python list; the result here has dtype int32 (see output).
ary = jnp.array([1, 2, 3, 4])
ary

Array([1, 2, 3, 4], dtype=int32)

In [6]:
# Show the class hierarchy of JAX's array type. The output lists only the JAX
# Array class and `object`, i.e. it is not a subclass of `np.ndarray`.
inspect.getmro(type(ary))

(jaxlib.xla_extension.Array, object)

## Differences between Jax and Numpy

NumPy functions such as `np.sum` accept regular Python lists; `jnp.sum` does not and raises a `TypeError` instead.

In [7]:
# NumPy functions accept a plain Python list as input.
np.sum([1, 2, 3, 4])

10

In [8]:
# jnp.sum rejects a plain Python list with a TypeError (message printed below);
# convert the list with jnp.array(...) first.
try:
    jnp.sum([1, 2, 3, 4])
except TypeError as err:
    print(err)

sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


NumPy arrays are mutable; JAX arrays are immutable. To "change" a JAX array, create a new one using the `.at[...]` property and its `set` method, as shown below.

In [9]:
# Small NumPy array used to contrast mutability with JAX below.
nary = np.arange(5)
nary

array([0, 1, 2, 3, 4])

In [10]:
# JAX counterpart of the NumPy array above (dtype int32 here, see output).
jary = jnp.arange(5)
jary

Array([0, 1, 2, 3, 4], dtype=int32)

In [11]:
# NumPy arrays are mutable: item assignment modifies the array in place.
nary[1] = 100
nary

array([  0, 100,   2,   3,   4])

In [12]:
# JAX arrays are immutable: item assignment raises a TypeError whose message
# (printed below) points to the `.at[idx].set(y)` alternative.
try:
    jary[1] = 100
except TypeError as err:
    print(err)

'<class 'jaxlib.xla_extension.Array'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [13]:
# Functional update: `.at[1].set(100)` returns a NEW array with index 1 replaced.
jary2 = jary.at[1].set(100)
jary2

Array([  0, 100,   2,   3,   4], dtype=int32)

In [14]:
# The original array is unchanged by the `.at[...].set(...)` call above.
jary

Array([0, 1, 2, 3, 4], dtype=int32)

The `.at[...]` property also supports other element-wise updates at specific indices, such as `add`; these likewise return a new array.

In [15]:
# `.at[3].add(10)` returns a new array with 10 added at index 3; `jary` is untouched.
jary3 = jary.at[3].add(10)
jary3

Array([ 0,  1,  2, 13,  4], dtype=int32)

## Op Execution
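NumPy runs ops synchronously on the CPU. It is not strictly single-threaded, though: `np.matmul` calls into a BLAS library that can use several threads. This is why the `user` CPU time in the NumPy `%time` output below (37.9 s) is several times the wall time (6.22 s). JAX ops will choose an accelerator if one is available and then run asynchronously. To time a JAX op, we therefore need to force it to complete. There are two ways to do this: convert the result into a NumPy array, which forces the op to run (much like forcing a lazy iterator with `list(it)`), or call the `.block_until_ready()` API.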

Without either of these — i.e. timing the plain JAX op and letting it complete asynchronously before the results are returned — the measured time is comparable to NumPy on both CPU-only and GPU machines.

Converting the JAX result to a NumPy array gives a much faster measured time, especially on GPU machines. However, this still includes the overhead of copying the result from the GPU to the CPU and converting it, so it is a bit slower than the next method.

With the `.block_until_ready()` API there is no extra copy or conversion; presumably the result simply stays on the GPU. This is the fastest method.

Of course, all of these speedups are much more pronounced on a GPU machine than on my MacBook.

#### Results
| Array Op | CPU-only Time | GPU Time |
|----------|---------------|----------|
| Numpy matmul | 5.66s | 18.5s |
| Async Jax matmul | 5.03s | 11.9s |
| Convert Jax matmul | 5.25s | 809ms |
| Blocking API Jax matmul | 4.87s | 569ms |


In [16]:
# 10,000 x 10,000 float32 matrix for the NumPy matmul benchmark.
# Sample directly in float32: drawing float64 and then calling .astype(np.float32)
# temporarily allocates an extra 10_000 * 10_000 * 8 B = 800 MB array and copies it.
nary = rng.standard_normal(size=(10_000, 10_000), dtype=np.float32)
print(nary.shape, nary.dtype)

(10000, 10000) float32


In [17]:
# Matching 10,000 x 10,000 float32 matrix generated with JAX's RNG.
# JAX dispatches work asynchronously, so block until the random matrix is actually
# materialized here; otherwise its generation cost can leak into the first timed
# matmul below. block_until_ready() returns the same array.
jary = jax.random.normal(key, shape=(10_000, 10_000), dtype=jnp.float32).block_until_ready()
print(jary.shape, jary.dtype)

(10000, 10000) float32


In [18]:
# Time the NumPy matmul. The `%time` line is not the last line of the cell, so the
# 10,000 x 10,000 result is not displayed; the print gives a short confirmation instead.
%time np.matmul(nary, nary)
print("Done.")

CPU times: user 37.9 s, sys: 812 ms, total: 38.7 s
Wall time: 6.22 s
Done.


In [19]:
# Time the bare JAX matmul with no blocking call. Because JAX dispatches asynchronously,
# this may not measure the full computation; the result is displayed after %time reports.
# NOTE(review): this is also the first matmul on these inputs, so the timing may include
# one-off compilation/dispatch overhead — consider a warm-up call before timing.
%time jnp.matmul(jary, jary)

CPU times: user 33.3 s, sys: 712 ms, total: 34 s
Wall time: 5.55 s


Array([[ -90.649445 ,   13.972144 ,  -95.644165 , ...,   23.926046 ,
         133.53964  ,   53.143925 ],
       [  44.89999  ,  -33.37181  ,   94.96751  , ..., -100.38489  ,
         -56.939598 , -217.22066  ],
       [ -45.11518  , -185.57526  , -189.5532   , ..., -213.17264  ,
          11.18911  ,   18.810167 ],
       ...,
       [  57.29258  ,   89.31357  ,   96.97245  , ...,  -26.702862 ,
          32.321487 ,  159.42407  ],
       [  92.361145 ,   29.874641 ,  -63.73034  , ...,   41.16841  ,
         154.73874  ,   85.76169  ],
       [   7.0634165,   81.06201  ,  145.52756  , ...,   73.708145 ,
         -43.597656 ,  -10.067655 ]], dtype=float32)

In [20]:
# Converting to a NumPy array forces JAX to finish the computation and copy the
# result to host memory, so both are included in the timing.
%time np.asarray(jnp.matmul(jary, jary))

CPU times: user 34.1 s, sys: 810 ms, total: 34.9 s
Wall time: 6.1 s


array([[ -90.649445 ,   13.972144 ,  -95.644165 , ...,   23.926046 ,
         133.53964  ,   53.143925 ],
       [  44.89999  ,  -33.37181  ,   94.96751  , ..., -100.38489  ,
         -56.939598 , -217.22066  ],
       [ -45.11518  , -185.57526  , -189.5532   , ..., -213.17264  ,
          11.18911  ,   18.810167 ],
       ...,
       [  57.29258  ,   89.31357  ,   96.97245  , ...,  -26.702862 ,
          32.321487 ,  159.42407  ],
       [  92.361145 ,   29.874641 ,  -63.73034  , ...,   41.16841  ,
         154.73874  ,   85.76169  ],
       [   7.0634165,   81.06201  ,  145.52756  , ...,   73.708145 ,
         -43.597656 ,  -10.067655 ]], dtype=float32)

In [21]:
# Wait for the result with block_until_ready() instead of converting it, so the
# timing covers the computation but no NumPy conversion.
%time jnp.matmul(jary, jary).block_until_ready()

CPU times: user 35.4 s, sys: 723 ms, total: 36.1 s
Wall time: 5.72 s


Array([[ -90.649445 ,   13.972144 ,  -95.644165 , ...,   23.926046 ,
         133.53964  ,   53.143925 ],
       [  44.89999  ,  -33.37181  ,   94.96751  , ..., -100.38489  ,
         -56.939598 , -217.22066  ],
       [ -45.11518  , -185.57526  , -189.5532   , ..., -213.17264  ,
          11.18911  ,   18.810167 ],
       ...,
       [  57.29258  ,   89.31357  ,   96.97245  , ...,  -26.702862 ,
          32.321487 ,  159.42407  ],
       [  92.361145 ,   29.874641 ,  -63.73034  , ...,   41.16841  ,
         154.73874  ,   85.76169  ],
       [   7.0634165,   81.06201  ,  145.52756  , ...,   73.708145 ,
         -43.597656 ,  -10.067655 ]], dtype=float32)