# JAX 101


Click the image below to read the post online.

<a target="_blank" href="https://www.machinelearningnuggets.com/what-is-jax
"><img src="https://digitalpress.fra1.cdn.digitaloceanspaces.com/mhujhsj/2022/07/logo.png" alt="Open in ML Nuggets"></a>

JAX is an open source Python package for numerical computation in with accelearators and XLA.

## Installing JAX



pip install jax

### Why use JAX?

- faster than NumPy
- consumes less memory and is convenient to use
- Other data science packages are built on top of it, for example Flax, Haiku etc

## Setting up TPUs on Google Colab

Ensure you change the run time to TPUs

In [1]:
import jax

# import jax.tools.colab_tpu
# jax.tools.colab_tpu.setup_tpu()
jax.devices()

[cuda(id=0), cuda(id=1)]

In [2]:
import jax.numpy as jnp
import numpy as np

# Data Types in JAX

In [3]:
x = jnp.float32(1.25844)

In [4]:
x

Array(1.25844, dtype=float32)

In [5]:
type(x)

jaxlib.xla_extension.ArrayImpl

In [6]:
x = jnp.int32(45.25844)

In [7]:
x

Array(45, dtype=int32)

## Ways to Create JAX NumPy Arrays

A NumPy array is a multidimensional array-like data structure

### np.arange()

In [8]:
jnp.arange(10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [9]:
jnp.arange(0, 10)

Array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=int32)

In [10]:
jnp.arange(0, 10, 2)

Array([0, 2, 4, 6, 8], dtype=int32)

### Covert Python List to a NumPy Array

In [11]:
scores = [50, 60, 70, 30, 25, 70]

In [12]:
scores_array = jnp.array(scores)

In [13]:
scores_array

Array([50, 60, 70, 30, 25, 70], dtype=int32)

In [14]:
type(scores_array)

jaxlib.xla_extension.ArrayImpl

In [15]:
scores_array.ndim  # the dimension of the array

1

In [16]:
scores_array.size  # the number of items in the array

6

In [17]:
scores_array.dtype  # the type of data in the array

dtype('int32')

In [18]:
jnp.unique(scores_array)  # print unique items from the array

Array([25, 30, 50, 60, 70], dtype=int32)

In [19]:
scores_array

Array([50, 60, 70, 30, 25, 70], dtype=int32)

In [20]:
scores_array.devices()

{cuda(id=0)}

In [21]:
jnp.flip(scores_array)  # reverse the array

Array([70, 25, 30, 70, 60, 50], dtype=int32)

In [22]:
jnp.sort(scores_array)

Array([25, 30, 50, 60, 70, 70], dtype=int32)

In [23]:
scores_array

Array([50, 60, 70, 30, 25, 70], dtype=int32)

In [24]:
jnp.clip(scores_array, 20, 59)

Array([50, 59, 59, 30, 25, 59], dtype=int32)

# Part Two

### Joining Two Arrays

In [25]:
array_two = jnp.array([90, 26, 37, 77, 65, 55])

In [26]:
scores_array

Array([50, 60, 70, 30, 25, 70], dtype=int32)

In [27]:
concatenated = jnp.concatenate((scores_array, array_two))

jnp.concatenate((concatenated, jnp.array([100, 200.0, 300])))

Array([ 50.,  60.,  70.,  30.,  25.,  70.,  90.,  26.,  37.,  77.,  65.,
        55., 100., 200., 300.], dtype=float32)

### jnp.zeros()

In [28]:
jnp.zeros(5)

Array([0., 0., 0., 0., 0.], dtype=float32)

### jnp.ones()

In [29]:
o1 = jnp.ones(5)
o1, o1.ndim

(Array([1., 1., 1., 1., 1.], dtype=float32), 1)

In [30]:
jnp.eye(5)  # Return a 2-D array with ones on the diagonal
# and zeros elsewhere.

Array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)

In [31]:
jnp.identity(5)

Array([[1., 0., 0., 0., 0.],
       [0., 1., 0., 0., 0.],
       [0., 0., 1., 0., 0.],
       [0., 0., 0., 1., 0.],
       [0., 0., 0., 0., 1.]], dtype=float32)

### jnp.linspace()

In [32]:
jnp.linspace(10, 50, 6)  # Return evenly spaced numbers over a
# specified interval.
# start,stop, num=5,

Array([10., 18., 26., 34., 42., 50.], dtype=float32)

In [33]:
jnp.linspace(10, 15, 5)

Array([10.  , 11.25, 12.5 , 13.75, 15.  ], dtype=float32)

## Generating random numbers with JAX

In [34]:
print(np.random.random())
print(np.random.random())
print(np.random.random())

0.978791172571429
0.9898967073552775
0.30208440392693026


In [35]:
seed = 98
key = jax.random.PRNGKey(seed)

In [36]:
key

Array([ 0, 98], dtype=uint32)

In [37]:
jax.random.uniform(key)

Array(0.3756802, dtype=float32)

In [38]:
jax.random.uniform(key)

Array(0.3756802, dtype=float32)

In [39]:
key, subkey = jax.random.split(key)

In [40]:
subkey

Array([3614062411, 3294896607], dtype=uint32)

In [41]:
jax.random.uniform(subkey)

Array(0.95996785, dtype=float32)

In [42]:
jax.random.uniform(subkey)

Array(0.95996785, dtype=float32)

# Part 3

# Checking Documentation

In [43]:
help(jnp.linspace)

Help on function linspace in module jax._src.numpy.lax_numpy:

linspace(start: 'ArrayLike', stop: 'ArrayLike', num: 'int' = 50, endpoint: 'bool' = True, retstep: 'bool' = False, dtype: 'DTypeLike | None' = None, axis: 'int' = 0) -> 'Array | tuple[Array, Array]'
    Return evenly spaced numbers over a specified interval.
    
    LAX-backend implementation of :func:`numpy.linspace`.
    
    *Original docstring below.*
    
    Returns `num` evenly spaced samples, calculated over the
    interval [`start`, `stop`].
    
    The endpoint of the interval can optionally be excluded.
    
    .. versionchanged:: 1.16.0
        Non-scalar `start` and `stop` are now supported.
    
    .. versionchanged:: 1.20.0
        Values are rounded towards ``-inf`` instead of ``0`` when an
        integer ``dtype`` is specified. The old behavior can
        still be obtained with ``np.linspace(start, stop, num).astype(int)``
    
    Parameters
    ----------
    start : array_like
        The starting

In [44]:
jnp.linspace?

[0;31mSignature:[0m
[0mjnp[0m[0;34m.[0m[0mlinspace[0m[0;34m([0m[0;34m[0m
[0;34m[0m    [0mstart[0m[0;34m:[0m [0;34m'ArrayLike'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mstop[0m[0;34m:[0m [0;34m'ArrayLike'[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mnum[0m[0;34m:[0m [0;34m'int'[0m [0;34m=[0m [0;36m50[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mendpoint[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mTrue[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mretstep[0m[0;34m:[0m [0;34m'bool'[0m [0;34m=[0m [0;32mFalse[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0mdtype[0m[0;34m:[0m [0;34m'DTypeLike | None'[0m [0;34m=[0m [0;32mNone[0m[0;34m,[0m[0;34m[0m
[0;34m[0m    [0maxis[0m[0;34m:[0m [0;34m'int'[0m [0;34m=[0m [0;36m0[0m[0;34m,[0m[0;34m[0m
[0;34m[0m[0;34m)[0m [0;34m->[0m [0;34m'Array | tuple[Array, Array]'[0m[0;34m[0m[0;34m[0m[0m
[0;31mDocstring:[0m
Return evenly spaced numbers over a specified interv

JAX NumPy Operations

In [45]:
matrix = jnp.arange(17, 33)

In [46]:
matrix = matrix.reshape(4, 4)

In [47]:
matrix

Array([[17, 18, 19, 20],
       [21, 22, 23, 24],
       [25, 26, 27, 28],
       [29, 30, 31, 32]], dtype=int32)

In [48]:
try:
    jnp.sum([1, 2, 3])
except TypeError as e:
    print(f"TypeError: {e}")

TypeError: sum requires ndarray or scalar arguments, got <class 'list'> at position 0.


In [49]:
matrix.shape

(4, 4)

In [50]:
matrix.ndim

2

In [51]:
jnp.max(matrix)

Array(32, dtype=int32)

In [52]:
matrix

Array([[17, 18, 19, 20],
       [21, 22, 23, 24],
       [25, 26, 27, 28],
       [29, 30, 31, 32]], dtype=int32)

In [53]:
jnp.argmax(matrix)

Array(15, dtype=int32)

In [54]:
jnp.min(matrix)

Array(17, dtype=int32)

In [55]:
jnp.argmin(matrix)

Array(0, dtype=int32)

In [56]:
jnp.sum(matrix)

Array(392, dtype=int32)

In [57]:
jnp.sqrt(matrix)

Array([[4.1231055, 4.2426405, 4.3588986, 4.4721355],
       [4.5825753, 4.6904154, 4.795831 , 4.898979 ],
       [5.       , 5.0990195, 5.196152 , 5.2915025],
       [5.3851647, 5.4772253, 5.5677643, 5.656854 ]], dtype=float32)

In [58]:
matrix.transpose()

Array([[17, 21, 25, 29],
       [18, 22, 26, 30],
       [19, 23, 27, 31],
       [20, 24, 28, 32]], dtype=int32)

In [59]:
matrix.flatten()

Array([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],      dtype=int32)

In [60]:
matrix.ravel()

Array([17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32],      dtype=int32)

In [61]:
matrix2 = jnp.arange(1, 17).reshape(4, 4)

In [62]:
matrix + matrix2

Array([[18, 20, 22, 24],
       [26, 28, 30, 32],
       [34, 36, 38, 40],
       [42, 44, 46, 48]], dtype=int32)

In [63]:
matrix = np.arange(17, 33)
matrix = matrix.reshape(4, 4)
matrix2 = np.arange(1, 17).reshape(4, 4)

In [64]:
%timeit matrix * matrix2

747 ns ± 2.87 ns per loop (mean ± std. dev. of 7 runs, 1,000,000 loops each)


In [65]:
%timeit jnp.dot(jnp.arange(17,33).reshape(4,4), jnp.arange(1,17).reshape(4,4)).block_until_ready()

720 µs ± 6.71 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [66]:
matrix / matrix2

array([[17.        ,  9.        ,  6.33333333,  5.        ],
       [ 4.2       ,  3.66666667,  3.28571429,  3.        ],
       [ 2.77777778,  2.6       ,  2.45454545,  2.33333333],
       [ 2.23076923,  2.14285714,  2.06666667,  2.        ]])

In [67]:
matrix % matrix2

array([[0, 0, 1, 0],
       [1, 4, 2, 0],
       [7, 6, 5, 4],
       [3, 2, 1, 0]])

## Device put

In [68]:
from jax import device_put
import numpy as np

size = 5000
x = np.random.normal(size=(size, size)).astype(np.float32)
x = device_put(x)

2024-06-03 00:17:42.420344: W external/tsl/tsl/framework/bfc_allocator.cc:482] Allocator (GPU_0_bfc) ran out of memory trying to allocate 95.37MiB (rounded to 100000000)requested by op 
2024-06-03 00:17:42.420436: W external/tsl/tsl/framework/bfc_allocator.cc:494] *___________________________________________________________________________________________________


XlaRuntimeError: RESOURCE_EXHAUSTED: Out of memory while trying to allocate 100000000 bytes.

In [None]:
x

Array([[-0.79447454,  1.2171433 ,  1.1587267 , ..., -1.1149006 ,
         0.5576706 ,  0.65717304],
       [ 0.06404939, -0.6359419 , -1.6473819 , ..., -0.79355246,
        -0.31298083, -1.437376  ],
       [-0.33799973,  0.19041944,  0.2353706 , ...,  1.4031185 ,
        -2.2811759 , -0.64421403],
       ...,
       [ 0.5736245 ,  0.15064004, -0.7230288 , ..., -0.62920284,
         0.18096147,  0.12478388],
       [ 0.5002141 , -0.16310322, -0.01618477, ..., -0.74241334,
         0.39063206, -0.38532573],
       [-0.46491575,  0.8566468 ,  0.9910643 , ..., -0.05014795,
        -1.3024176 , -0.2784468 ]], dtype=float32)

# Indexing & Broadcasting in Numpy

In [None]:
matrix = jnp.arange(1, 17)

In [None]:
matrix

Array([ 1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16],      dtype=int32)

In [None]:
matrix[0]

Array(1, dtype=int32)

## Out-of-Bounds Indexing

In [None]:
# Out-of-Bounds Indexing
print(jnp.shape(matrix))
matrix[20]

(16,)


Array(16, dtype=int32)

In [None]:
matrix[2]

Array(3, dtype=int32)

In [None]:
matrix[2:6]

Array([3, 4, 5, 6], dtype=int32)

In [None]:
matrix[12:]

Array([13, 14, 15, 16], dtype=int32)

### Indexing Two Dimensional Array

In [None]:
matrix = matrix.reshape(4, 4)
matrix

Array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12],
       [13, 14, 15, 16]], dtype=int32)

In [None]:
matrix[0]

Array([1, 2, 3, 4], dtype=int32)

In [None]:
matrix[0:2]

Array([[1, 2, 3, 4],
       [5, 6, 7, 8]], dtype=int32)

In [None]:
matrix[1:3, 1:3]  # [startrow:endrow, startcolumn:endcolumn]

Array([[ 6,  7],
       [10, 11]], dtype=int32)

In [None]:
matrix[2:4, 1:2]

Array([[10],
       [14]], dtype=int32)

In [None]:
matrix[2:4, 2:4]

Array([[11, 12],
       [15, 16]], dtype=int32)

In [None]:
matrix[2:, 2:]

Array([[11, 12],
       [15, 16]], dtype=int32)

### Broadcasting in NumPy

In [None]:
scores = [50, 60, 70, 30, 25]

In [None]:
scores_array = jnp.array(scores)

In [None]:
scores_array

Array([50, 60, 70, 30, 25], dtype=int32)

In [None]:
scores_array[0:3]

Array([50, 60, 70], dtype=int32)

### JAX arrays are immutable

In [None]:
scores_array[0:3] = [20, 40, 90]

TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [None]:
new_scores_array = scores_array.at[0:3].set([20, 40, 90])

In [None]:
new_scores_array

Array([20, 40, 90, 30, 25], dtype=int32)

In [None]:
scores_array

Array([50, 60, 70, 30, 25], dtype=int32)

## Using jit() to speed up functions

In [None]:
def test_fn(sample_rate=3000, frequency=3):
    x = jnp.arange(sample_rate)
    y = jnp.sin(2 * jnp.pi * frequency * (frequency / sample_rate))
    return jnp.dot(x, y)

In [None]:
%timeit test_fn()

249 µs ± 333 ns per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [None]:
test_fn_jit = jax.jit(test_fn)
%timeit test_fn_jit().block_until_ready()

69.3 µs ± 1.86 µs per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## How JIT works

By default JAX executes operations one at a time, in sequence.

Using a just-in-time (JIT) compilation decorator, sequences of operations can be optimized together and run at once.

Not all JAX code can be JIT compiled, as it requires array shapes to be static & known at compile time.

The fact that all JAX operations are expressed in terms of XLA allows JAX to use the XLA compiler to execute blocks of code very efficiently.

In [None]:
@jax.jit
def f(x, y):
    print("Running f():")
    print(f"  x = {x}")
    print(f"  y = {y}")
    result = jnp.dot(x + 1, y + 1)
    print(f"  result = {result}")
    return result


x = np.random.randn(3, 4)
y = np.random.randn(4)
f(x, y)

Running f():
  x = Traced<ShapedArray(float32[3,4])>with<DynamicJaxprTrace(level=1/0)>
  y = Traced<ShapedArray(float32[4])>with<DynamicJaxprTrace(level=1/0)>
  result = Traced<ShapedArray(float32[3])>with<DynamicJaxprTrace(level=1/0)>


Array([6.6637564, 5.8385086, 4.8083467], dtype=float32)

In [None]:
x2 = np.random.randn(3, 4)
y2 = np.random.randn(4)
f(x2, y2)

Array([1.9138228, 2.2620401, 5.1799603], dtype=float32)

In [None]:
from jax import make_jaxpr


def f(x, y):
    return jnp.dot(x + 1, y + 1)


make_jaxpr(f)(x, y)

{ [34m[22m[1mlambda [39m[22m[22m; a[35m:f32[3,4][39m b[35m:f32[4][39m. [34m[22m[1mlet
    [39m[22m[22mc[35m:f32[3,4][39m = add a 1.0
    d[35m:f32[4][39m = add b 1.0
    e[35m:f32[3][39m = dot_general[
      dimension_numbers=(([1], [0]), ([], []))
      preferred_element_type=float32
    ] c d
  [34m[22m[1min [39m[22m[22m(e,) }

In [None]:
@jax.jit
def f(boolean, x):
    return -x if boolean else x


f(True, 1)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[]..
The error occurred while tracing the function f at /tmp/ipykernel_23106/3541786643.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument boolean.
See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.TracerBoolConversionError

In [None]:
from functools import partial


@partial(jax.jit, static_argnums=(0,))
def f(boolean, x):
    return -x if boolean else x


f(True, 1)

Array(-1, dtype=int32, weak_type=True)

In [None]:
f(False, 1)

Array(1, dtype=int32, weak_type=True)

## Taking derivatives with grad()

In [None]:
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x)))


x_small = jnp.arange(6.0)
derivative_fn = jax.grad(sum_logistic)
print(derivative_fn(x_small))

[0.25       0.19661194 0.10499357 0.04517666 0.01766271 0.00664806]


In [None]:
@jax.jit
def sum_logistic(x):
    return jnp.sum(1.0 / (1.0 + jnp.exp(-x))), (x + 1)


x_small = jnp.arange(6.0)
derivative_fn = jax.grad(sum_logistic, has_aux=True)
print(derivative_fn(x_small))

(Array([0.25      , 0.19661194, 0.10499357, 0.04517666, 0.01766271,
       0.00664806], dtype=float32), Array([1., 2., 3., 4., 5., 6.], dtype=float32))


In [None]:
arcsinh = jax.grad(jax.numpy.arcsinh)
print(arcsinh(0.9))

0.7432942


## Auto-vectorization with vmap()

In [None]:
mat = jax.random.normal(key, (150, 100))
batched_x = jax.random.normal(key, (10, 100))


def apply_matrix(v):
    return jnp.dot(mat, v)

In [None]:
def naively_batched_apply_matrix(v_batched):
  return jnp.stack([apply_matrix(v) for v in v_batched])

print('Naively batched')
%timeit naively_batched_apply_matrix(batched_x).block_until_ready()

Naively batched
3.82 ms ± 78.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
@jax.jit
def batched_apply_matrix(v_batched):
  return jnp.dot(v_batched, mat.T)

print('Manually batched')
%timeit batched_apply_matrix(batched_x).block_until_ready()

Manually batched
87.5 µs ± 520 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


In [None]:
@jax.jit
def vmap_batched_apply_matrix(v_batched):
  return jax.vmap(apply_matrix)(v_batched)

print('Auto-vectorized with vmap')
%timeit vmap_batched_apply_matrix(batched_x).block_until_ready()

Auto-vectorized with vmap
86 µs ± 410 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)


## Parallelization with pmap

In [None]:
x = np.arange(5)
w = np.array([2.0, 3.0, 4.0])


def convolve(x, w):
    output = []
    for i in range(1, len(x) - 1):
        output.append(jnp.dot(x[i - 1 : i + 2], w))
    return jnp.array(output)


convolve(x, w)

DeviceArray([11., 20., 29.], dtype=float32)

In [None]:
n_devices = jax.local_device_count()
xs = np.arange(5 * n_devices).reshape(-1, 5)
ws = np.stack([w] * n_devices)

xs

array([[ 0,  1,  2,  3,  4],
       [ 5,  6,  7,  8,  9],
       [10, 11, 12, 13, 14],
       [15, 16, 17, 18, 19],
       [20, 21, 22, 23, 24],
       [25, 26, 27, 28, 29],
       [30, 31, 32, 33, 34],
       [35, 36, 37, 38, 39]])

In [None]:
jax.vmap(convolve)(xs, ws)

DeviceArray([[ 11.,  20.,  29.],
             [ 56.,  65.,  74.],
             [101., 110., 119.],
             [146., 155., 164.],
             [191., 200., 209.],
             [236., 245., 254.],
             [281., 290., 299.],
             [326., 335., 344.]], dtype=float32)

In [None]:
jax.pmap(convolve)(xs, ws)

ShardedDeviceArray([[ 11.,  20.,  29.],
                    [ 56.,  65.,  74.],
                    [101., 110., 119.],
                    [146., 155., 164.],
                    [191., 200., 209.],
                    [236., 245., 254.],
                    [281., 290., 299.],
                    [326., 335., 344.]], dtype=float32)

## Debugging NaNs

In [None]:
jnp.divide(0.0, 0.0)

Array(nan, dtype=float32, weak_type=True)

In [None]:
from jax.config import config

config.update("jax_debug_nans", True)
jnp.divide(0.0, 0.0)

  from jax.config import config


ImportError: cannot import name 'config' from 'jax.config' (/home/j2y/envs/tiny/lib/python3.11/site-packages/jax/config.py)

## Double (64bit) precision

In [None]:
x = jnp.float64(1.25844)

  return asarray(x, dtype=self.dtype)


In [None]:
x

Array(1.25844, dtype=float32)

In [None]:
# set this config at the begining of the program
from jax.config import config

config.update("jax_enable_x64", True)
x = jnp.float64(1.25844)
x

ImportError: cannot import name 'config' from 'jax.config' (/home/j2y/envs/tiny/lib/python3.11/site-packages/jax/config.py)

## PyTrees

In [None]:
example_trees = [
    [1, "a", object()],
    (1, (2, 3), ()),
    [1, {"k1": 2, "k2": (3, 4)}, 5],
    {"a": 2, "b": (2, 3)},
    jnp.array([1, 2, 3]),
    None,
]

In [None]:
# Let's see how many leaves they have:
for pytree in example_trees:
    leaves = jax.tree.leaves(pytree)
    print(f"{repr(pytree):<45} has {len(leaves)} leaves: {leaves}")

[1, 'a', <object object at 0x7ca98c4de670>]   has 3 leaves: [1, 'a', <object object at 0x7ca98c4de670>]
(1, (2, 3), ())                               has 3 leaves: [1, 2, 3]
[1, {'k1': 2, 'k2': (3, 4)}, 5]               has 5 leaves: [1, 2, 3, 4, 5]
{'a': 2, 'b': (2, 3)}                         has 3 leaves: [2, 2, 3]
Array([1, 2, 3], dtype=int32)                 has 1 leaves: [Array([1, 2, 3], dtype=int32)]
None                                          has 0 leaves: []


## Where to go from here
Follow us on [LinkedIn](https://www.linkedin.com/company/mlnuggets), [Twitter](https://twitter.com/ml_nuggets), [GitHub](https://github.com/mlnuggets) and subscribe to our [blog](https://www.machinelearningnuggets.com/#/portal) so that you don't miss a new issue.

In [None]:
jnp.arange(10)[11]

Array(9, dtype=int32)

In [None]:
jnp.arange(10.0).at[11].get()

Array(9., dtype=float32)

In [None]:
jnp.arange(10).at[11].get(mode="fill", fill_value=jnp.nan)

  out = np.array(c).astype(eqn.params['new_dtype'])


Array(-2147483648, dtype=int32)

In [None]:
jnp.sum(jnp.array([1, 2, 3]))

Array(6, dtype=int32)