**S01P05_sharp_bits_out_of_bounds_indexing.ipynb**

Arz

2024 APR 05 (FRI)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# out-of-bounds indexing

## NumPy: raises an error

In [4]:
# np.arange(3)[3]  # raises IndexError (left commented out so the notebook keeps running)
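a minimal sketch of the NumPy behavior, catching the exception instead of letting it stop the notebook. the variable names `x` and `message` are only illustrative.

```python
import numpy as np

x = np.arange(3)  # valid indices are 0, 1, 2

message = None
try:
    x[3]  # index 3 is one past the end
except IndexError as err:
    # NumPy checks bounds on the host and raises immediately
    message = str(err)

print(message)
```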

## JAX: does something other than raising an error

raising an error from code running on an accelerator can be difficult or impossible, so JAX has to pick some non-error behavior for out-of-bounds indexing.

### 1) array index retrieval operations:

**index is clamped to the bounds of the array**

- e.g. NumPy-style indexing, gather-like primitives

In [5]:
jnp.arange(3)[3]

Array(2, dtype=int32)

the index 3 is clamped to 2, so the last element is returned.

In [6]:
jnp.arange(3).at[3].get()

Array(2, dtype=int32)
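the same clamping can be requested explicitly. this is a small sketch that assumes the documented `mode="clip"` option of `.at[].get()` and uses an index array (`idx`, an illustrative name) with two out-of-bounds entries.

```python
import jax.numpy as jnp

x = jnp.arange(10, 13)            # values 10, 11, 12
idx = jnp.array([0, 2, 3, 100])   # the last two indices are out of bounds

# mode="clip" clamps every index into the valid range [0, 2]
clipped = x.at[idx].get(mode="clip")
print(clipped)  # [10 12 12 12]
```

both out-of-bounds indices map to the last element, no matter how far past the end they are.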

**return NaN instead if out-of-bounds** (via `mode="fill"`)

In [10]:
jnp.arange(3.0).at[3].get(mode="fill", fill_value=jnp.nan)

Array(nan, dtype=float32)
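a hedged sketch of how `mode="fill"` could be used to spot which lookups were out of bounds: fill with NaN, then test with `jnp.isnan`. this only works when NaN is not a legitimate value in the data; `idx` and `mask` are illustrative names.

```python
import jax.numpy as jnp

x = jnp.arange(3.0)          # values 0., 1., 2.
idx = jnp.array([0, 2, 3])   # index 3 is out of bounds

# out-of-bounds lookups return fill_value instead of a clamped element
filled = x.at[idx].get(mode="fill", fill_value=jnp.nan)

# True where the lookup fell outside the array (assuming x itself has no NaN)
mask = jnp.isnan(filled)
print(filled)
print(mask)
```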

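⚠️❓out-of-bounds index retrievals are clamped, whereas out-of-bounds index updates are skipped. these two behaviors are not inverses of each other, so reverse-mode automatic differentiation (which turns index updates into index retrievals and vice versa) will not preserve the semantics of out-of-bounds indexing.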

thus it may be a good idea to think of out-of-bounds indexing in JAX as a case of **undefined behavior**.
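a minimal sketch of why retrieval and update are not inverses of each other, assuming the documented `mode="clip"` and `mode="drop"` options of `.at[]`. the modes are spelled out explicitly so the result does not depend on defaults.

```python
import jax.numpy as jnp

x = jnp.arange(3.0)     # values 0., 1., 2.
idx = jnp.array([3])    # out of bounds

# retrieval: the index is clamped, so this reads the last element
read = x.at[idx].get(mode="clip")

# update: the out-of-bounds write is dropped, so nothing changes
written = x.at[idx].set(10.0, mode="drop")

print(read)     # [2.]
print(written)  # [0. 1. 2.]
```

reading at index 3 touches the last element, but writing at index 3 touches nothing. that asymmetry is what reverse-mode differentiation cannot preserve.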