**S01P04_sharp_bits_in_place_updates.ipynb**

Arz

2024 APR 05 (FRI)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# in-place updates

## NumPy: in-place

In [10]:
x = np.zeros((3, 3), dtype=np.float32)

print(x)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [12]:
x[:, 2] = 1

print(x)

[[0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]]
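In NumPy, an in-place write is visible through every name that shares the same buffer. This is one reason JAX avoids in-place mutation. The minimal sketch below uses a slice view `y` (a name introduced here for illustration) to show the effect:

```python
import numpy as np

x = np.zeros((3, 3), dtype=np.float32)

# Basic slicing returns a view that shares memory with x, not a copy.
y = x[:, 2]

# This write mutates the shared buffer in place...
x[:, 2] = 1

# ...so the view observes the change even though y was never assigned to.
print(y)  # [1. 1. 1.]
```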


## JAX: out-of-place

In [13]:
x = jnp.zeros((3, 3), dtype=jnp.float32)

print(x)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [14]:
# x[:, 2] = 1  # in-place assignment is forbidden: raises TypeError (JAX arrays are immutable)

# print(x)
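Running the commented-out assignment fails instead of mutating the array. The sketch below catches the error so the cell still runs. The `caught` flag is only an illustrative name, and the exact wording of the error message may vary between JAX versions:

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=jnp.float32)

caught = False
try:
    # JAX arrays are immutable, so item assignment is rejected.
    x[:, 2] = 1
except TypeError as err:
    caught = True
    print(type(err).__name__)  # TypeError

# x is unchanged: still all zeros.
print(x.sum())
```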

### 1) .at[index].set(value)

In [15]:
x = jnp.zeros((3, 3), dtype=jnp.float32)

print(x)

[[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [16]:
x = x.at[:, 2].set(1)

print(x)

[[0. 0. 1.]
 [0. 0. 1.]
 [0. 0. 1.]]
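The cell above rebinds the name `x`, so it looks like an in-place update. In fact `.at[index].set(value)` returns a new array and leaves its input untouched. A small sketch that keeps both arrays around (the name `y` is illustrative):

```python
import jax.numpy as jnp

x = jnp.zeros((3, 3), dtype=jnp.float32)

# .at[...].set(...) builds a new array with the update applied.
y = x.at[:, 2].set(1)

print(x.sum())  # 0.0 -> the original array is still all zeros
print(y.sum())  # 3.0 -> only the returned copy has the new column
```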


⚠️❓ However, inside jit-compiled code, if the input array x of x.at[index].set(value) is not reused afterwards, the compiler will optimize the array update to occur in-place.
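The sketch below shows the kind of code this note describes. `fill_columns` is a hypothetical helper, not part of JAX. Each intermediate array is consumed exactly once and then discarded, so XLA is free to reuse its buffer. From Python the semantics are still out-of-place: whether an update actually happens in place is a compiler decision that this code does not observe directly.

```python
import jax
import jax.numpy as jnp

@jax.jit
def fill_columns(x):
    # Each intermediate is never referenced again after the next update,
    # which lets the compiler apply these updates to one buffer in place.
    x = x.at[:, 0].set(1.0)
    x = x.at[:, 1].set(2.0)
    x = x.at[:, 2].set(3.0)
    return x

x = jnp.zeros((3, 3), dtype=jnp.float32)
y = fill_columns(x)

print(y)
print(x.sum())  # 0.0 -> the caller's array is not modified
```

Letting XLA reuse the caller's *input* buffer as well requires opting in. `jax.jit` provides a `donate_argnums` parameter for that, and a donated argument must not be used again after the call.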

### 2) other operations

In [22]:
x = jnp.zeros((3, 5), dtype=jnp.float32)

print(x)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]


In [23]:
x = x.at[:2, 1::3].add(7)

print(x)

[[0. 7. 0. 0. 7.]
 [0. 7. 0. 0. 7.]
 [0. 0. 0. 0. 0.]]
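Besides `.set` and `.add`, the `.at` property provides other out-of-place update methods. Among them are `.mul`, `.min`, and `.max`, which are shown below. The variable names in this sketch are illustrative. The last line shows a detail worth knowing: with an integer index array, `.add` accumulates over repeated indices.

```python
import jax.numpy as jnp

x = jnp.ones((2, 3), dtype=jnp.float32)

scaled = x.at[0, :].mul(5)     # multiply the selected entries by 5
raised = x.at[1, 1].max(4)     # elementwise maximum with 4
lowered = x.at[:, 0].min(-2)   # elementwise minimum with -2

# Index 0 appears twice, so it receives 1.0 twice.
counts = jnp.zeros(3).at[jnp.array([0, 0, 2])].add(1.0)

print(scaled)
print(raised)
print(lowered)
print(counts)  # [2. 0. 1.]
```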
