In [9]:
import jax

import numpy as np
from jax import numpy as jnp

In [2]:
# NumPy reductions accept a plain Python list and convert it implicitly.
np.sum([1, 42, 31337])

np.int64(31380)

In [3]:
# Unlike NumPy, jax.numpy does not implicitly convert Python lists:
# the call raises a TypeError, which is caught and shown on purpose.
try:
    jnp.sum([1, 42, 31337])
except TypeError as err:
    print(err)

sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [4]:
# Converting the list to a JAX array first gives jnp.sum an input it accepts.
jnp.sum(jnp.array([1, 42, 31337]))

2024-10-08 12:36:10.571917: W external/xla/xla/service/gpu/nvptx_compiler.cc:930] The NVIDIA driver's CUDA version is 12.3 which is older than the PTX compiler version 12.6.77. Because the driver is older than the PTX compiler version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


Array(31380, dtype=int32)

In [7]:
# Basic metadata of a JAX array: rank, shape, element type, element count
# and total size in bytes (3 x int32 = 12 bytes).
arr = jnp.array([1, 42, 31337])
arr.ndim, arr.shape, arr.dtype, arr.size, arr.nbytes

(1, (3,), dtype('int32'), 3, 12)

In [10]:
# Devices local to this host process (a single CUDA GPU in the recorded run).
jax.local_devices()

[CudaDevice(id=0)]

In [11]:
# Number of devices local to this host process.
jax.local_device_count()

1

In [12]:
# Devices of the default backend; here that is the GPU.
jax.devices()

[CudaDevice(id=0)]

In [13]:
# A backend can be selected by name: the CPU device is available even
# though the default backend is the GPU.
jax.devices('cpu')

[CpuDevice(id=0)]

In [16]:
# Count only the devices of the 'gpu' backend.
jax.device_count('gpu')

1

In [20]:
# Sample array used by the device-transfer cells that follow.
arr = jnp.array([1, 42, 31337])

In [23]:
# device_get fetches the data to the host; the result is displayed as a
# NumPy array rather than a jax Array.
jax.device_get(arr)

array([    1,    42, 31337], dtype=int32)

In [26]:
# device_put places a copy of the array on an explicitly chosen device,
# here the first CPU device.
arr_cpu = jax.device_put(arr, jax.devices('cpu')[0])
arr_cpu

Array([    1,    42, 31337], dtype=int32)

In [27]:
# The object returned by device_get is a plain numpy.ndarray.
arr_host = jax.device_get(arr)
type(arr_host)

numpy.ndarray

In [29]:
# Show the host-side copy; values and dtype match the device array.
arr_host

array([    1,    42, 31337], dtype=int32)

In [31]:
# An array created without an explicit placement lands on the default
# device (the GPU in the recorded run).
arr = jnp.array([1, 42, 31337])
arr.device

CudaDevice(id=0)

In [32]:
# After an explicit device_put the copy reports the CPU device.
arr_cpu = jax.device_put(arr, jax.devices('cpu')[0])
arr_cpu.device

CpuDevice(id=0)

In [33]:
# `arr` was created without an explicit placement (uncommitted in JAX
# terms), so it can be combined with the explicitly CPU-placed copy.
arr + arr_cpu

Array([    2,    84, 62674], dtype=int32)

In [34]:
# Explicitly place (commit) a copy on the first GPU device.
arr_gpu = jax.device_put(arr, jax.devices('gpu')[0])

In [35]:
# Both operands were explicitly placed on different devices, so JAX
# refuses to combine them; the ValueError is caught and shown on purpose.
try:
    arr_gpu + arr_cpu
except ValueError as err:
    print(err)

Received incompatible devices for jitted computation. Got argument x of _add with shape int32[3] and device ids [0] on platform GPU and argument y of _add with shape int32[3] and device ids [0] on platform CPU


### Working with asynchronous dispatch

In [36]:
a = jnp.array(range(1_000_000)).reshape((1_000, 1_000))
a.device

CudaDevice(id=0)

In [61]:
# Without blocking, this mostly measures the dispatch: control returns to
# Python before the computation on the device has finished.
%time x = jnp.dot(a, a)

CPU times: user 676 μs, sys: 125 μs, total: 801 μs
Wall time: 858 μs


In [62]:
# block_until_ready() waits for the result, so the wall time now includes
# the actual computation.
%time x = jnp.dot(a, a).block_until_ready()

CPU times: user 1.53 ms, sys: 285 μs, total: 1.81 ms
Wall time: 3.67 ms


In [63]:
# CPU-resident copy of the matrix for a CPU vs. GPU timing comparison.
a_cpu = jax.device_put(a, jax.devices('cpu')[0])
a_cpu.device

CpuDevice(id=0)

In [65]:
# Same matrix product on the CPU copy, again waiting for the result so the
# timing is comparable with the blocking GPU measurement.
%time x = jnp.dot(a_cpu, a_cpu).block_until_ready()

CPU times: user 306 ms, sys: 0 ns, total: 306 ms
Wall time: 11 ms


### Immutability

In [66]:
# The same data as a JAX array and as a NumPy array; reading an element
# works the same way for both.
a_jnp = jnp.array(range(10))
a_np = np.array(range(10))
a_np[5], a_jnp[5]

(np.int64(5), Array(5, dtype=int32))

In [67]:
# NumPy arrays are mutable: in-place item assignment works.
a_np[5] = 100
a_np[5]

np.int64(100)

In [68]:
# JAX arrays are immutable: item assignment raises a TypeError, which is
# caught and shown on purpose.
try:
    a_jnp[5] = 100
except TypeError as err:
    print(err)

'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [69]:
# Functional update: .at[5].set(100) returns a new array, which is then
# rebound to the same name.
a_jnp = a_jnp.at[5].set(100)
a_jnp[5]

Array(100, dtype=int32)

### Out-of-bounds Indexing

In [70]:
# Start again from a fresh 10-element array for the indexing examples.
a_jnp = jnp.array(range(10))
a_jnp

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [71]:
# Out-of-bounds read: no error is raised; the index is clamped and the
# last element (9) is returned.
a_jnp[42]

Array(9, dtype=int32)

In [73]:
# With mode='drop' an out-of-bounds read yields a fill value instead of a
# clamped element: the int32 minimum in the recorded output.
a_jnp.at[42].get(mode='drop')

Array(-2147483648, dtype=int32)

In [74]:
# mode='fill' lets the value returned for out-of-bounds reads be chosen.
a_jnp.at[42].get(mode='fill', fill_value=-1)

Array(-1, dtype=int32)

In [75]:
# An out-of-bounds write is silently dropped by default: the array comes
# back unchanged and no error is raised.
a_jnp = a_jnp.at[42].set(100)
a_jnp

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [77]:
# With mode='clip' the out-of-bounds index is clamped to the valid range,
# so the last element is overwritten.
a_jnp = a_jnp.at[42].set(100, mode='clip')
a_jnp

Array([  0,   1,   2,   3,   4,   5,   6,   7,   8, 100], dtype=int32)

### Precision

In [78]:
# bfloat16 is available as a JAX array dtype.
xb16 = jnp.array(range(10), dtype=jnp.bfloat16)
xb16.dtype

dtype(bfloat16)

In [79]:
# 10 elements x 2 bytes per bfloat16 value = 20 bytes.
xb16.nbytes

20

In [80]:
# float16 (half precision) is supported as well.
x16 = jnp.array(range(10), dtype=jnp.float16)
x16.dtype

dtype('float16')

### Type promotion in `jax.lax`

In [81]:
# jnp.add promotes the int and float operands to a common type; the result
# is a (weakly typed) float32.
jnp.add(42, 42.0)

Array(84., dtype=float32, weak_type=True)

In [82]:
from jax import lax

In [83]:
# jax.lax does not promote mixed int/float operands the way jnp.add does.
# The exception type depends on the JAX version: the dtype check raises a
# TypeError, while the recorded run below failed with a ValueError during
# lowering. Catch both so the cell shows the message instead of crashing,
# and print the exception type to make clear which case occurred.
try:
    lax.add(42, 42.0)
except (TypeError, ValueError) as e:
    print(f"{type(e).__name__}: {e}")

ValueError: Cannot lower jaxpr with verifier errors:
	op requires the same element type for all operands and results
		at loc("jit(add)/jit(main)/add"(callsite("<module>"("/tmp/ipykernel_58098/2249353911.py":2:4) at callsite("_run_code"("<frozen runpy>":88:4) at "_run_module_as_main"("<frozen runpy>":198:11)))))
Define JAX_DUMP_IR_TO to dump the module.

In [84]:
# Making the first operand an explicit float32 gives both arguments a
# floating-point type, so lax.add succeeds and returns float32.
lax.add(jnp.float32(42), 42.0)

Array(84., dtype=float32)