In [1]:
import jax

# Name of the default XLA backend, queried through the public API.
# `jax.lib.xla_bridge` is a deprecated internal module that newer JAX
# releases no longer ship, so it is avoided here.
jax.default_backend()

'gpu'

In [2]:
import jax

# One line of output: the default device list, then the CPU and GPU lists.
# Passing None selects the default backend, same as calling jax.devices().
print(*(jax.devices(backend) for backend in (None, "cpu", "gpu")))

[cuda(id=0)] [CpuDevice(id=0)] [cuda(id=0)]


In [3]:
# (number of JAX processes, index of this process)
(jax.process_count(), jax.process_index())

(1, 0)

In [4]:
import numpy as np
import jax.numpy as jnp

In [6]:
def f(x):
    """Benchmark kernel: elementwise (x + x*x + 3) * (x*x + x*x.T).

    All products are elementwise (``*``), not matrix products. ``x`` must
    expose ``.T`` and broadcast against its own transpose, e.g. a square
    2-D array.

    Returns an array-like of the broadcast shape of ``x`` and ``x.T``.
    """
    polynomial_term = x + x * x + 3
    transpose_term = x * x + x * x.T
    return polynomial_term * transpose_term


# Fixed seed so every Restart & Run All benchmarks the same input matrix.
SEED = 0
N = 3000  # the benchmark input is an N x N float32 matrix

x = np.random.default_rng(SEED).standard_normal((N, N), dtype=np.float32)

# Place one copy of the data on each device. The NumPy array is handed to
# device_put directly, so it is not first materialised on the default device.
jax_x_cpu = jax.device_put(x, device=jax.devices("cpu")[0])
jax_x_gpu = jax.device_put(x, device=jax.devices("gpu")[0])

# NOTE(review): recent JAX releases appear to deprecate the `backend=`
# argument of jax.jit — confirm against the installed version.
jax_f_cpu = jax.jit(f, backend="cpu")
jax_f_gpu = jax.jit(f, backend="gpu")

# Warm-up: the first call compiles each function. Block on the results so no
# asynchronously dispatched work is still running when the timing cells start.
jax_f_cpu(jax_x_cpu).block_until_ready()
jax_f_gpu(jax_x_gpu).block_until_ready();

In [9]:
# Baseline: plain (un-jitted) f applied to the NumPy float32 array x.
%timeit -n100 f(x)

30.4 ms ± 217 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
# Un-jitted f on the JAX array placed on the CPU device; block_until_ready() waits for the result.
%timeit -n100 f(jax_x_cpu).block_until_ready()

30.8 ms ± 828 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
# f jitted with backend="cpu", on the CPU-placed array; block_until_ready() waits for the result.
%timeit -n100 jax_f_cpu(jax_x_cpu).block_until_ready()

7.46 ms ± 167 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [15]:
# Un-jitted f on the JAX array placed on the GPU device; block_until_ready() waits for the result.
%timeit -n100 f(jax_x_gpu).block_until_ready()

971 µs ± 552 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [14]:
# f jitted with backend="gpu", on the GPU-placed array; block_until_ready() waits for the result.
%timeit -n100 jax_f_gpu(jax_x_gpu).block_until_ready()

71.6 µs ± 11.3 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
