In [1]:
import jax
import jax.numpy as jnp
# Report the installed JAX version so outputs below can be tied to it.
print(f"Using jax {jax.__version__}")

Using jax 0.4.23


In [2]:
# Create a 1x3 array filled with zeros (dtype float32 by default, as printed below).
a = jnp.zeros(shape=(1, 3))

In [3]:
# Show the array alongside its current dtype (float32 from the previous cell).
print(a, a.dtype)

# JAX arrays are immutable: astype returns a new array rather than converting
# in place, so the name `a` is rebound to the int32 version for later cells.
# NOTE(review): because `a` is overwritten, re-running this cell prints int32 twice.
a = a.astype(jnp.int32)
print(a, a.dtype)

[[0. 0. 0.]] float32
[[0 0 0]] int32


In [4]:
# Inspect which device(s) hold `a`.
# JAX placed the array on its default device without being asked. That default
# is an accelerator when one is available (here a GPU, cuda(id=0)); on a
# CPU-only machine it would be the CPU. PyTorch, by contrast, needs the device
# to be specified explicitly.
print(a.devices())

{cuda(id=0)}


In [5]:
# Moving data between host and device:
#   jax.device_get: jax.Array     -> numpy.ndarray (fetched to the host)
#   jax.device_put: numpy.ndarray -> jax.Array (default device, since none is given)
a_cpu = jax.device_get(a)
print(type(a_cpu))

a_gpu = jax.device_put(a_cpu)
print(f"Device put: {type(a_gpu)} on {a_gpu.devices()}")

<class 'numpy.ndarray'>
Device put: <class 'jaxlib.xla_extension.ArrayImpl'> on {cuda(id=0)}


In [11]:
# numpy.ndarray + jax.Array -> jax.Array
# Mixing a host NumPy array with a JAX array needs no explicit transfer: JAX
# converts the NumPy operand itself, and the result is a jax.Array on the
# default device (the GPU here, per the output below).
a_sum = a_cpu + a_gpu  # compute once instead of re-evaluating the sum per query
print(a_sum.__class__, a_sum.devices())

<class 'jaxlib.xla_extension.ArrayImpl'> {cuda(id=0)}
