In [1]:
# Record the installed JAX version; the outputs in this notebook were produced with it.
import jax
jax.__version__

'0.8.0'

## `jax.Array` properties (formerly `DeviceArray`)

In [2]:
import numpy as np
import jax.numpy as jnp

In [3]:
# NumPy baseline: np.array builds an array from a plain Python list.
np.array([1, 42, 31337])

array([    1,    42, 31337])

In [4]:
# jax.numpy mirrors the NumPy constructor; note the int32 default dtype
# (64-bit types are only used once jax_enable_x64 is enabled, as done later on).
jnp.array([1, 42, 31337])

Array([    1,    42, 31337], dtype=int32)

In [5]:
# NumPy reductions accept plain Python lists (contrast with jnp.sum in the next cell).
np.sum([1, 42, 31337])

np.int64(31380)

In [6]:
# Unlike NumPy, jnp.sum rejects a Python list with a TypeError (message shown below);
# it expects an array or scalar argument.
try:
  jnp.sum([1, 42, 31337])
except TypeError as e:
  print(e)

sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [7]:
# Converting the list to a JAX array first makes the reduction work.
jnp.sum(jnp.array([1, 42, 31337]))

Array(31380, dtype=int32)

In [8]:
# One array used to inspect the NumPy-style attributes JAX arrays expose.
arr = jnp.array([1, 42, 31337])

In [9]:
# Number of dimensions (rank).
arr.ndim

1

In [10]:
arr.shape

(3,)

In [11]:
arr.dtype

dtype('int32')

In [12]:
# Total number of elements.
arr.size

3

In [13]:
# Bytes of array data: size * itemsize = 3 elements * 4 bytes (int32) = 12.
arr.nbytes

12

## Devices

In [14]:
import jax

In [15]:
# Devices of the default backend (the GPU on this machine).
jax.devices()

[CudaDevice(id=0)]

In [16]:
# Devices of an explicitly named backend.
jax.devices('cpu')

[CpuDevice(id=0)]

In [17]:
jax.device_count('gpu')

1

In [18]:
# Devices attached to this process; identical to jax.devices() here (single process).
jax.local_devices()

[CudaDevice(id=0)]

In [19]:
arr = jnp.array([1, 42, 31337])

In [20]:
# Without an explicit placement the array lands on the default device.
arr.device

CudaDevice(id=0)

In [21]:
# device_put returns a new array on the target device; `arr` itself is not moved
# (see In [23]).
arr_cpu = jax.device_put(arr, jax.devices('cpu')[0])

In [22]:
arr_cpu.device

CpuDevice(id=0)

In [23]:
arr.device

CudaDevice(id=0)

In [24]:
# device_get copies the data to the host and returns a NumPy array.
arr_host = jax.device_get(arr)

In [25]:
type(arr_host)

numpy.ndarray

In [26]:
arr_host

array([    1,    42, 31337], dtype=int32)

In [27]:
# This works: `arr` was created without an explicit device (it is "uncommitted"),
# so it may be combined with `arr_cpu`, which device_put committed to the CPU.
# Compare with In [30], where both operands are committed to different devices.
arr + arr_cpu

Array([    2,    84, 62674], dtype=int32)

In [28]:
# Placing the array explicitly with device_put commits it to the GPU.
arr_gpu = jax.device_put(arr, jax.devices('gpu')[0])

In [29]:
arr_gpu.device

CudaDevice(id=0)

In [30]:
# Two arrays committed to different devices cannot be mixed: JAX raises a
# ValueError rather than silently transferring one of them (message below).
try:
  arr_gpu + arr_cpu
except ValueError as e:
  print(e)

Received incompatible devices for jitted computation. Got argument x of add with shape int32[3] and device ids [0] on platform GPU and argument y of add with shape int32[3] and device ids [0] on platform CPU


## Asynchronous dispatch

In [31]:
import jax

In [32]:
# 1000x1000 matrix holding 0..999_999 in row-major order. jnp.arange generates the
# values directly instead of converting a million-element Python range one element
# at a time; values and the default (32-bit) integer dtype are unchanged.
a = jnp.arange(1_000_000).reshape((1000, 1000))

In [33]:
a.shape

(1000, 1000)

In [34]:
# Created on the default device (GPU), so the matmuls below run there.
a.device

CudaDevice(id=0)

In [35]:
# Without block_until_ready, %time measures the asynchronous dispatch of the matmul,
# not its completion on the device.
# NOTE(review): as the first jnp.dot call on this shape, it most likely also includes
# one-time tracing/compilation, which would explain why it is slower than the blocked
# call in the next cell. Run a warm-up call first for a fair comparison.
%time x = jnp.dot(a,a)

CPU times: user 70.9 ms, sys: 494 μs, total: 71.4 ms
Wall time: 82.4 ms


In [36]:
# block_until_ready waits for the device computation to finish, so this times
# dispatch plus execution.
%time x = jnp.dot(a,a).block_until_ready()

CPU times: user 368 μs, sys: 223 μs, total: 591 μs
Wall time: 1.4 ms


In [37]:
# Converting to a NumPy array needs the finished result copied to the host, so this
# also waits for the computation (an implicit block).
%time x = np.asarray(jnp.dot(a,a))

CPU times: user 1.55 ms, sys: 559 μs, total: 2.11 ms
Wall time: 2.92 ms


In [38]:
# Same matrix on the CPU backend, for comparison with the GPU timings above.
a_cpu = jax.device_put(a, jax.devices('cpu')[0])

In [39]:
a_cpu.device

CpuDevice(id=0)

In [40]:
# "user" CPU time well above "Wall time" in the output indicates the CPU backend
# spread the matmul across several threads.
%time x = jnp.dot(a_cpu,a_cpu).block_until_ready()

CPU times: user 104 ms, sys: 18.6 ms, total: 122 ms
Wall time: 9.58 ms


## GPU diagnostics

In [41]:
# CUDA toolkit (compiler) version. The "CUDA Version" in the nvidia-smi output below
# comes from the driver, so the two can differ (12.8 vs 13.0 on this machine).
!nvcc --version

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2025 NVIDIA Corporation
Built on Fri_Feb_21_20:23:50_PST_2025
Cuda compilation tools, release 12.8, V12.8.93
Build cuda_12.8.r12.8/compiler.35583870_0


  pid, fd = os.forkpty()


In [42]:
# Optional: install/upgrade a CUDA-enabled JAX build.
# NOTE(review): the `-f .../jax_releases.html` form below is the legacy install method;
# the current JAX installation guide uses CUDA extras such as
# `%pip install -U "jax[cuda12]"` (use %pip so it targets this kernel) -- confirm there.
#!pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_releases.html

In [43]:
gpu_info = !nvidia-smi
gpu_info = '\n'.join(gpu_info)
if gpu_info.find('failed') >= 0:
  print('Select the Runtime > "Change runtime type" menu to enable a GPU accelerator, ')
  print('and then re-execute this cell.')
else:
  print(gpu_info)

Wed Dec 10 12:45:24 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.105.08             Driver Version: 580.105.08     CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA GeForce RTX 4080        On  |   00000000:01:00.0 Off |                  N/A |
|  0%   40C    P2             33W /  400W |   12231MiB /  16376MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

## Immutability

In [44]:
import numpy as np
import jax.numpy as jnp

In [45]:
# The same values as a JAX array and as a NumPy array, to compare mutability.
a_jnp = jnp.array(range(10))
a_np  = np.array(range(10))

In [46]:
a_jnp

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [47]:
# Reading an element looks the same in both libraries.
a_np[5], a_jnp[5]

(np.int64(5), Array(5, dtype=int32))

In [48]:
# NumPy arrays are mutable: in-place item assignment succeeds.
a_np[5] = 100

In [49]:
a_np[5]

np.int64(100)

In [50]:
# JAX arrays are immutable, so in-place item assignment raises a TypeError
# (the message below points to the .at[] API).
try:
  a_jnp[5] = 100
except TypeError as e:
  print(e)

JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html


In [51]:
# Functional update: .at[idx].set(val) returns a new array, which is rebound to a_jnp.
a_jnp = a_jnp.at[5].set(100)

In [52]:
a_jnp[5]

Array(100, dtype=int32)

In [53]:
# Out-of-bounds indexing does not raise in JAX (NumPy would raise IndexError):
# index 42 is clamped to the last position, which holds 9.
a_jnp[42]

Array(9, dtype=int32)

In [54]:
# .at[].get() with the default mode gives the same clamped result here.
a_jnp.at[42].get()

Array(9, dtype=int32)

In [55]:
a_jnp.at[42].get(mode='clip')

Array(9, dtype=int32)

In [56]:
# mode='drop' returns a fill value for out-of-bounds reads; for int32 the default
# shown below is the minimum representable value, -2**31.
a_jnp.at[42].get(mode='drop')

Array(-2147483648, dtype=int32)

In [57]:
# fill_value chooses the out-of-bounds result explicitly.
a_jnp.at[42].get(mode='fill', fill_value=-1)

Array(-1, dtype=int32)

In [58]:
# With the default mode, an out-of-bounds update is dropped: the array is unchanged.
a_jnp = a_jnp.at[42].set(100)
a_jnp

Array([  0,   1,   2,   3,   4, 100,   6,   7,   8,   9], dtype=int32)

In [59]:
# mode='clip' clamps the index instead, so the last element is overwritten.
a_jnp = a_jnp.at[42].set(100, mode='clip')
a_jnp

Array([  0,   1,   2,   3,   4, 100,   6,   7,   8, 100], dtype=int32)

## Working with float64

In [60]:
# Enable 64-bit dtypes (off by default, which is why earlier arrays were int32).
# This only works reliably on startup: the JAX docs recommend setting the flag before
# creating any arrays. Here it is toggled mid-notebook for demonstration and switched
# back off in the jax.numpy & jax.lax section.
from jax import config
config.update("jax_enable_x64", True)

In [61]:
import jax
import jax.numpy as jnp

In [62]:
# With x64 enabled, the requested float64 dtype is kept.
# This will not work on the TPU backend (no float64 support); use CPU or GPU.
x = jnp.array(range(10), dtype=jnp.float64)
x.dtype

dtype('float64')

In [63]:
# The float64 array lives on the GPU (default device).
x.device

CudaDevice(id=0)

In [64]:
# Copy to the CPU; the dtype is preserved (see In [66]).
xc = jax.device_put(x, jax.devices('cpu')[0])

In [65]:
xc.device

CpuDevice(id=0)

In [66]:
xc.dtype

dtype('float64')

In [67]:
# Two 16-bit float formats: bfloat16 and IEEE float16.
xb16 = jnp.array(range(10), dtype=jnp.bfloat16)
xb16.dtype

dtype(bfloat16)

In [68]:
# 2 bytes per element: 10 * 2 = 20 bytes.
xb16.nbytes

20

In [69]:
x16 = jnp.array(range(10), dtype=jnp.float16)
x16.dtype

dtype('float16')

In [70]:
# Same storage size as bfloat16.
x16.nbytes

20

In [71]:
# Mixing bfloat16 with float16 promotes to float32 under JAX's type-promotion rules.
xb16+x16

Array([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], dtype=float32)

In [72]:
# Same-type arithmetic stays in bfloat16.
xb16+xb16

Array([0, 2, 4, 6, 8, 10, 12, 14, 16, 18], dtype=bfloat16)

## jax.numpy & jax.lax

In [73]:
# Switch 64-bit mode back off so the examples below use the default 32-bit dtypes.
config.update("jax_enable_x64", False)

In [74]:
import jax.numpy as jnp
from jax import lax
from jax import random

In [75]:
# jnp.add promotes mixed int/float inputs. The result is weakly typed
# (weak_type=True) because both inputs are Python scalars.
jnp.add(42, 42.0)

Array(84., dtype=float32, weak_type=True)

In [76]:
jnp.add(42.0, 42.0)

Array(84., dtype=float32, weak_type=True)

In [77]:
# lax.add is the strict primitive: it performs no implicit type promotion, so an
# int plus a float raises a TypeError (jnp.add, by contrast, promotes).
try:
  lax.add(42, 42.0)
except TypeError as err:
  print(err)

lax.add requires arguments to have the same dtypes, got int32, float32. (Tip: jnp.add is a similar function that does automatic type promotion on inputs).


In [78]:
# lax.add succeeds once the dtypes agree: the Python float adopts the float32 of
# jnp.float32(42), and the result is not weakly typed (see output).
lax.add(jnp.float32(42), 42.0)

Array(84., dtype=float32)

In [79]:
def random_augmentation(image, augmentations, rng_key):
  """Apply one augmentation, picked uniformly at random, to ``image``.

  Args:
    image: Value handed unchanged to the selected augmentation.
    augmentations: Sequence of callables taking ``image``. lax.switch requires every
      branch to return outputs of the same structure, shape and dtype.
    rng_key: JAX PRNG key used to draw the branch index.

  Returns:
    ``augmentations[k](image)`` for a random ``k`` in ``[0, len(augmentations))``.
  """
  num_choices = len(augmentations)
  branch = random.randint(rng_key, shape=(), minval=0, maxval=num_choices)
  return lax.switch(branch, augmentations, image)

In [80]:
# Toy stand-ins for real image augmentations: each just shifts the values by a
# distinct constant, so the branch lax.switch picked can be read off the output.
def add_noise_func(x):
  return x + 10

def horizontal_flip_func(x):
  return x + 1

def rotate_func(x):
  return x + 2

def adjust_colors_func(x):
  return x + 3

augmentations = [add_noise_func, horizontal_flip_func, rotate_func, adjust_colors_func]


In [81]:
# Toy "image": 100 values 0..99.
image = jnp.array(range(100))

In [82]:
# With this key the output equals image + 1, i.e. branch 1 (horizontal_flip_func) was chosen.
random_augmentation(image, augmentations, random.PRNGKey(211))

Array([  1,   2,   3,   4,   5,   6,   7,   8,   9,  10,  11,  12,  13,
        14,  15,  16,  17,  18,  19,  20,  21,  22,  23,  24,  25,  26,
        27,  28,  29,  30,  31,  32,  33,  34,  35,  36,  37,  38,  39,
        40,  41,  42,  43,  44,  45,  46,  47,  48,  49,  50,  51,  52,
        53,  54,  55,  56,  57,  58,  59,  60,  61,  62,  63,  64,  65,
        66,  67,  68,  69,  70,  71,  72,  73,  74,  75,  76,  77,  78,
        79,  80,  81,  82,  83,  84,  85,  86,  87,  88,  89,  90,  91,
        92,  93,  94,  95,  96,  97,  98,  99, 100], dtype=int32)