Setting CPU or GPU as a device in JAX

https://jax.readthedocs.io/en/latest/faq.html#controlling-data-and-computation-placement-on-devices

https://github.com/google/jax/issues/2851

In [1]:
#import os

#os.environ["JAX_PLATFORM_NAME"] = "cpu"

In [2]:
import jax as J
import jax.numpy as jnp

In [3]:
cpus  = J.devices("cpu")
cpus

[<jaxlib.xla_extension.Device at 0x1b64d3d28b0>]

In [4]:
x_cpu = J.device_put(jnp.ones(shape=(1000, 5000)), cpus[0])

In [6]:
x_cpu.device_buffer.device()

<jaxlib.xla_extension.Device at 0x1b64d3d28b0>

In [7]:
gpus = J.devices("gpu")

x_gpu = J.device_put(jnp.ones(shape=(100, 50)), gpus[0])

x_gpu.device_buffer.device()

GpuDevice(id=0, process_index=0)

In [9]:
from jax import jit

def f(x):
    return jnp.sin(x) + jnp.cos(x)

In [10]:
f_cpu = J.jit(f, device=cpus[0])
f_gpu = J.jit(f, device=gpus[0])

In [11]:
%timeit f_cpu(x_cpu)

10.3 ms ± 133 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [12]:
%timeit f_cpu(x_cpu)

10.4 ms ± 22.9 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [13]:
%timeit f_gpu(x_gpu)

43.9 µs ± 567 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [14]:
%timeit f_gpu(x_gpu)

44.1 µs ± 1.01 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
