In [71]:
import jax
import jax.numpy as jnp

In [72]:
# A small integer array built from a Python sequence (int32 with default config).
arr = jnp.array((1, 2))
arr

Array([1, 2], dtype=int32)

In [73]:
# Concrete class behind a JAX array (jaxlib's ArrayImpl, per the output).
type(arr)

jaxlib.xla_extension.ArrayImpl

In [74]:
# JAX arrays are immutable, so `arr /= s` is really `arr = arr / s`: the name
# is rebound to a new (float) array, and its type stays ArrayImpl.
arr = arr / jnp.sum(arr)
type(arr)

jaxlib.xla_extension.ArrayImpl

In [75]:
# Placement info: is the array fully addressable from this process, and on
# which device(s) does it live (a single GPU in this run)?
arr.is_fully_addressable, arr.sharding.device_set

(True, {cuda(id=0)})

In [76]:
# Devices visible to this process (one CUDA GPU in this run).
jax.local_devices()

[cuda(id=0)]

# Computation follows data

1. Computations are executed on the same device where the input data resides.

## Uncommitted data

Uncommitted data can be moved to the device where the operation runs (here, the default device).

In [77]:
# No device is given, so arr1 is placed (uncommitted) on the default device.
arr1 = jnp.array((1,))
arr1.devices()

{cuda(id=0)}

In [78]:
# jax.default_device changes where new arrays are created but does not commit
# them: arr2 sits on the CPU yet stays uncommitted (movable).
with jax.default_device(jax.devices("cpu")[0]):
    arr2 = jnp.array((2,))

arr2.devices()

{CpuDevice(id=0)}

In [79]:
# Both operands are uncommitted, so JAX may move them: the result lands on the
# default device (cuda here) even though arr2 was created on the CPU.
uncommitted_res1 = jnp.add(arr1, arr2)
uncommitted_res1.devices()

{cuda(id=0)}

In [80]:
# arr2 itself was not relocated by the addition above.
arr2.devices()

{CpuDevice(id=0)}

In [81]:
# Operand order does not decide placement: with the CPU array first, the
# result still lands on the default device (cuda here).
uncommitted_res2 = jnp.add(arr2, arr1)
uncommitted_res2.devices()

{cuda(id=0)}

## Committed data

Committed data can't be moved implicitly: combining arrays committed to different devices raises an error.

In [82]:
# device_put with an explicit device commits the array to that device (the GPU).
arr3 = jax.device_put(1, device=jax.devices("gpu")[0])
arr3.devices()

{cuda(id=0)}

In [83]:
# Same, but committed to the CPU.
arr4 = jax.device_put(2, device=jax.devices("cpu")[0])
arr4.devices()

{CpuDevice(id=0)}

In [84]:
# Arrays committed to different devices are not moved implicitly; JAX raises a
# ValueError instead. Catch it so the message is shown and the notebook continues.
try:
    committed_res = arr3 + arr4
except ValueError as err:
    print(err)

Received incompatible devices for jitted computation. Got argument x1 of jax.numpy.add with shape int32[] and device ids [0] on platform GPU and argument x2 of jax.numpy.add with shape int32[] and device ids [0] on platform CPU


In [85]:
# First of two arrays committed to the same GPU.
arr5 = jax.device_put(1, device=jax.devices("gpu")[0])
arr5.devices()

{cuda(id=0)}

In [86]:
# Second array committed to the same GPU as arr5.
arr6 = jax.device_put(2, device=jax.devices("gpu")[0])
arr6.devices()

{cuda(id=0)}

In [87]:
# Both operands are committed to the same GPU, so the addition succeeds.
arr5 + arr6

Array(3, dtype=int32, weak_type=True)

## `jax.device_put`

The `jax.device_put()` function creates a copy of your data on the specified device and returns it. The original data is unchanged.

In [88]:
# arr7 is created on the CPU; device_put copies it to the GPU as arr8.
with jax.default_device(jax.devices("cpu")[0]):
    arr7 = jnp.array([[1, 2], [3, 4]])
arr8 = jax.device_put(arr7, device=jax.devices("gpu")[0])
arr7.devices(), arr8.devices()

({CpuDevice(id=0)}, {cuda(id=0)})

In [89]:
# `.at[idx]` alone only returns an _IndexUpdateRef helper (see output);
# nothing changes until a method such as .set() is called on it.
arr7.at[1, 1]

_IndexUpdateRef(Array([[1, 2],
       [3, 4]], dtype=int32), (1, 1))

In [90]:
# .set() returns a new array; rebinding arr7 keeps the updated version.
arr7 = arr7.at[1, 1].set(100)

In [91]:
# The CPU array now holds 100 at position [1, 1].
arr7

Array([[  1,   2],
       [  3, 100]], dtype=int32)

In [92]:
# The GPU copy made by device_put still has the original values.
arr8

Array([[1, 2],
       [3, 4]], dtype=int32)

In [93]:
# arr8 is still on the GPU.
arr8.devices()

{cuda(id=0)}

In [94]:
# Data of the first addressable shard; for this single-device array it is the whole array.
arr8.addressable_data(0)

Array([[1, 2],
       [3, 4]], dtype=int32)

In [95]:
# That shard lives on the same GPU as arr8.
arr8.addressable_data(0).devices()

{cuda(id=0)}

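## Asynchronous dispatch

JAX operations return before the computation has finished. Call `.block_until_ready()` to wait for, and therefore time, the actual work.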

In [96]:
# A 10000 x 10000 random-normal matrix created on the CPU, so the next timing
# cells exercise the CPU backend. (float32 with default config, i.e. ~400 MB.)
with jax.default_device(jax.devices("cpu")[0]):
    arr_cpu = jax.random.normal(jax.random.PRNGKey(0), (10000, 10000))

arr_cpu.devices()

{CpuDevice(id=0)}

In [97]:
%time x = jnp.dot(arr_cpu,arr_cpu)

CPU times: user 83.3 ms, sys: 160 ms, total: 243 ms
Wall time: 214 ms


In [98]:
# block_until_ready() waits for the result, so this times the full CPU matmul.
%time x = jnp.dot(arr_cpu,arr_cpu).block_until_ready()

CPU times: user 69.6 ms, sys: 200 ms, total: 269 ms
Wall time: 262 ms


In [99]:
# Same key and shape as arr_cpu, but created without a device context,
# so it lands on the default device (the GPU in this run).
arr_gpu = jax.random.normal(jax.random.PRNGKey(0), (10000, 10000))
arr_gpu.devices()

{cuda(id=0)}

In [100]:
%time y = jnp.dot(arr_gpu,arr_gpu)

CPU times: user 0 ns, sys: 500 µs, total: 500 µs
Wall time: 301 µs


In [101]:
# Waiting on the result reveals the real GPU compute time (tens of ms here),
# versus the sub-millisecond dispatch time measured without waiting.
%time y = jnp.dot(arr_gpu,arr_gpu).block_until_ready()

CPU times: user 433 µs, sys: 280 µs, total: 713 µs
Wall time: 52.5 ms


# Immutability

In [102]:
# Float range 0..4 (float32 with default config), used by the immutability cells.
arr = jnp.arange(0.0, 5.0)
arr

Array([0., 1., 2., 3., 4.], dtype=float32)

In [103]:
# JAX arrays are immutable, so item assignment raises TypeError; display its message.
try:
    arr[0] = 2
except TypeError as err:
    print(err)

'<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [104]:
# .set() returns a NEW array with the update applied; arr itself is not modified.
arr.at[0].set(100)

Array([100.,   1.,   2.,   3.,   4.], dtype=float32)

In [105]:
# Unchanged: the .at[...].set() call above did not modify arr.
arr

Array([0., 1., 2., 3., 4.], dtype=float32)

In [106]:
# Public (non-dunder) members of the `.at[idx]` helper object.
[name for name in dir(arr.at[0]) if not name.startswith("__")]

['add',
 'apply',
 'array',
 'divide',
 'get',
 'index',
 'max',
 'min',
 'mul',
 'multiply',
 'power',
 'set']

In [107]:
# The helper keeps a reference to the (unmodified) array it was created from.
arr.at[0].array

Array([0., 1., 2., 3., 4.], dtype=float32)

## Updates in NumPy vs. JAX

Unlike NumPy in-place operations such as `x[idx] += y`, JAX's `.at[idx]` updates apply every update when multiple indices refer to the same location. NumPy applies only the last update rather than all of them.

In [108]:
import numpy as np

In [109]:
# The same 0..9 integer range in both libraries, for a side-by-side comparison.
arr = jnp.arange(10)
np_arr = np.arange(10)
np_arr, arr

(array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
 Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32))

In [110]:
# Distinct indices: NumPy and JAX give the same result.
# NOTE: np_arr is modified in place and arr is rebound, so re-running this cell
# (without re-running the cell above) scales the first four values again.
np_arr[np.arange(4)] *= 10
arr = arr.at[jnp.arange(4)].multiply(10)
np_arr, arr

(array([ 0, 10, 20, 30,  4,  5,  6,  7,  8,  9]),
 Array([ 0, 10, 20, 30,  4,  5,  6,  7,  8,  9], dtype=int32))

In [111]:
# Duplicate indices: all four index entries point at position 2.
# NumPy's `x[idx] *= y` applies the update only once (20 -> 200).
# JAX's .at[idx].multiply applies every update (20 -> 20 * 10**4 = 200000).
# Work on copies so that re-running this cell does not keep compounding the values.
np_dup = np_arr.copy()
np_dup[np.full(4, 2)] *= 10
arr_dup = arr.at[jnp.full(4, 2)].multiply(10)
np_dup, arr_dup

(array([  0,  10, 200,  30,   4,   5,   6,   7,   8,   9]),
 Array([     0,     10, 200000,     30,      4,      5,      6,      7,
             8,      9], dtype=int32))

## Out-of-bounds `mode` in array indexing

Out-of-bounds indices don't raise an error in JAX. The `mode` argument controls what happens instead.

In [113]:
# Fresh 0..9 integer array for the out-of-bounds indexing examples.
arr = jnp.arange(0, 10)
arr

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [115]:
# Index 20 is out of bounds for this length-10 array, and no `mode` is given.
# Per the JAX "sharp bits" docs, out-of-bounds retrieval is clamped and
# out-of-bounds updates are dropped; check the output against your JAX version.
# Both results go in one tuple because only a cell's last expression is
# displayed. Previously the get() result was computed and silently discarded.
arr.at[20].get(), arr.at[20].set(30)

In [119]:
# mode="clip": the out-of-bounds index is clamped, returning the last element.
arr.at[20].get(mode="clip")

Array(9, dtype=int32)

In [120]:
# mode="drop" on a get returns a fill value instead of an element.
# With no fill_value given, this is the most negative int32 (see output).
arr.at[20].get(mode="drop")

Array(-2147483648, dtype=int32)

In [122]:
# mode="fill" with an explicit fill_value returns that value for out-of-bounds indices.
arr.at[20].get(mode="fill", fill_value=-1)

Array(-1, dtype=int32)

# Types

In [125]:
# JAX uses 32-bit types by default: requesting float64 only triggers a warning,
# and the array is created as float32 (see output).
# For real float64, enable x64 at startup: jax.config.update("jax_enable_x64", True)
jnp.array(1.0, dtype=jnp.float64)

  jnp.array(1.0,dtype=jnp.float64)


Array(1., dtype=float32)

## Weak types

In [138]:
# Explicit int32 in both libraries, so both arrays start from the same dtype.
np_arr = np.arange(3, dtype=np.int32)
arr = jnp.arange(3, dtype=jnp.int32)
np_arr.dtype, arr.dtype

(dtype('int32'), dtype('int32'))

## With Python scalars

In [149]:
# JAX treats the Python scalar `1` as weakly typed, so the int16 dtype is kept.
# NumPy's result depends on its version. NumPy 1.x promotes to the default
# integer type (int64 in the output below). NumPy >= 2 (NEP 50) also treats
# Python scalars as weak and keeps int16.
np_res = np.int16(1) + 1
jnp_res = jnp.int16(1) + 1
np_res.dtype, jnp_res.dtype

(dtype('int64'), dtype('int16'))

In [150]:
# np.array(1) is a strongly typed NumPy integer array, not a weak Python
# scalar, so int16 gets promoted. The result shows as int32, presumably because
# JAX's 64-bit mode is off (see the float64 cell above).
jnp.int16(1) + np.array(1)

Array(2, dtype=int32)

In [152]:
# Evaluated left to right: NumPy computes np.int16(1) + 1 first (int64 under
# NumPy 1.x), so JAX then sees a 64-bit NumPy value and returns int32 here.
# NOTE(review): under NumPy >= 2 (NEP 50), the NumPy step stays int16, so this
# likely gives int16 instead. Confirm for your NumPy version.
np.int16(1) + 1 + jnp.int16(1)

Array(3, dtype=int32)

In [153]:
# JAX goes first: jnp.int16(1) + 1 stays int16 (weak Python scalar), and adding
# another int16 keeps int16.
jnp.int16(1) + 1 + np.int16(1)

Array(3, dtype=int16)