In [1]:
import numpy as np
from jax import jit
from jax import lax
from jax import random
import jax
import jax.numpy as jnp

In [2]:
def impure_print_side_effect(x):
  """Return ``x`` unchanged after printing a message.

  The ``print`` is a deliberate Python side effect: under ``jit`` it runs only
  while JAX traces the function, not when a cached compilation is reused.
  """
  banner = "Executing function"
  print(banner)  # This is a side-effect
  return x

# The side-effects appear during the first run: JAX executes the Python
# function once (tracing) to build the compiled computation.
print ("First call: ", jit(impure_print_side_effect)(4.))

# Subsequent runs with parameters of same type and shape may not show the side-effect
# This is because JAX now invokes a cached compilation of the function
# (even though `jit(...)` is applied again here, the cache is reused: no
# "Executing function" line is printed for this call).
print ("Second call: ", jit(impure_print_side_effect)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# (a scalar float above vs. a shape-(1,) array here), so the print fires again.
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [3]:
g = 0.
def impure_uses_globals(x):
  """Return ``x`` plus the module-level global ``g``.

  Impure: the result depends on ``g``, which is read while the function runs
  in Python, so a jitted version keeps the value ``g`` had when it was traced.
  """
  offset = g  # global read; captured at trace time under jit
  return x + offset

# JAX captures the value of the global during the first run
print ("First call: ", jit(impure_uses_globals)(4.))
g = 10.  # Update the global

# Subsequent runs may silently use the cached value of the globals
# (the recorded output is 5.0, i.e. the stale g=0. was used, not 15.0)
print ("Second call: ", jit(impure_uses_globals)(5.))

# JAX re-runs the Python function when the type or shape of the argument changes
# This will end up reading the latest value of the global
# (the recorded output is [14.], i.e. 4. plus the updated g=10.)
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [4]:
g = 0.
def impure_saves_global(x):
  """Store ``x`` into the module-level global ``g`` and return it.

  Impure: under ``jit`` the value stored in ``g`` is the tracer standing in
  for ``x``, not a concrete number.
  """
  global g
  g = x
  return g

# JAX runs once the transformed function with special Traced values for arguments
print ("First call: ", jit(impure_saves_global)(4.))
# The assignment inside the jitted function stored the tracer rather than the
# concrete 4.0, so the global now holds a Traced<...> object (see output).
print ("Saved global: ", g)  # Saved global has an internal JAX value

First call:  4.0
Saved global:  Traced<~float32[]>with<DynamicJaxprTrace>


In [5]:
def pure_uses_internal_state(x):
  """Add ``x`` ten times, tallying even- and odd-indexed steps separately.

  Pure despite the mutable dict: the state is created inside the call and
  never escapes, so ``jit`` can trace it safely.
  """
  tallies = {'even': 0, 'odd': 0}
  for step in range(10):
    bucket = 'odd' if step % 2 else 'even'
    tallies[bucket] += x
  return tallies['even'] + tallies['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


In [6]:
import jax.numpy as jnp
from jax import make_jaxpr

# lax.fori_loop
# Pure body: reads the array by the loop index, giving 0 + 1 + ... + 9.
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
# Impure body: the result 0 shows next(iterator) was evaluated once while the
# body was traced (yielding 0) instead of advancing on every iteration.
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    """Scan over ``arr`` paired with ones, accumulating ``a * 1 + extra``.

    Returns ``lax.scan``'s ``(final_carry, stacked_outputs)``; each step emits
    the carry value *before* it is updated.
    """
    weights = jnp.ones(arr.shape)

    def body(carry, pair):
        lhs, rhs = pair
        next_carry = carry + lhs * rhs + extra
        return next_carry, carry

    return lax.scan(body, 0., (arr, weights))
# Not the cell's last expression, so this jaxpr is built but not displayed.
make_jaxpr(func11)(jnp.arange(16), 5.)
# Passing a Python iterator instead of an array fails (uncomment to see):
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
# Works with an array operand; the result is not displayed because it is not
# the last expression of the cell.
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
# `iter_operand` is only used by the commented-out call below; presumably it
# fails because an iterator is not a valid lax.cond operand.
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

45
0


In [7]:
# Shorten exception reports to the final error line for the demos below.
%xmode Minimal

Exception reporting mode: Minimal


In [8]:
jax_array = jnp.zeros((3,3), dtype=jnp.float32)

# In place update of JAX's array will yield an error!
# The error is caught and printed so that "Restart & Run All" does not halt
# at this intentional failure.
try:
  jax_array[1, :] = 1.0
except TypeError as err:
  print(f"TypeError: {err}")

TypeError: JAX arrays are immutable and do not support in-place item assignment. Instead of x[idx] = y, use x = x.at[idx].set(y) or another .at[] method: https://docs.jax.dev/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [9]:
# Augmented assignment (+=) on an immutable JAX array rebinds the name to a new
# array; on a NumPy array it mutates the shared buffer in place.
jax_array = jnp.array([10, 20])
jax_array_new = jax_array
jax_array_new += 10
print(jax_array_new)  # `jax_array_new` is rebound to a new value [20, 30], but...
print(jax_array)      # the original value is unmodified as [10, 20] !

numpy_array = np.array([10, 20])
numpy_array_new = numpy_array
numpy_array_new += 10
print(numpy_array_new)  # `numpy_array_new is numpy_array`, and it was updated
print(numpy_array)      # in-place, so both are [20, 30] !

[20 30]
[10 20]
[20 30]
[20 30]


In [10]:
# Functional update: `.at[idx].set(value)` returns a new array with row 1
# overwritten, leaving `jax_array` itself untouched.
jax_array = jnp.zeros(shape=(3, 3), dtype=jnp.float32)
updated_array = jax_array.at[1, :].set(1.0)
print("updated array:\n", updated_array)

updated array:
 [[0. 0. 0.]
 [1. 1. 1.]
 [0. 0. 0.]]


In [11]:
# The `.at[1, :].set(1.0)` call produced a new array; `jax_array` was not modified.
print("original array unchanged:\n", jax_array)

original array unchanged:
 [[0. 0. 0.]
 [0. 0. 0.]
 [0. 0. 0.]]


In [12]:
jax_array = jnp.ones((5, 6))
print("original array:")
print(jax_array)

# Add 7 to every other row (rows 0, 2, 4) from column 3 onward; returns a new array.
new_jax_array = jax_array.at[::2, 3:].add(7.)
print("new array post-addition:", new_jax_array, sep="\n")

original array:
[[1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 1. 1. 1.]]
new array post-addition:
[[1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]
 [1. 1. 1. 1. 1. 1.]
 [1. 1. 1. 8. 8. 8.]]


In [13]:
# NumPy raises IndexError for an out-of-bounds index. Catch and print it so
# "Restart & Run All" can continue past this intentional failure.
try:
  np.arange(10)[11]
except IndexError as err:
  print(f"IndexError: {err}")

IndexError: index 11 is out of bounds for axis 0 with size 10

In [14]:
# Unlike NumPy, JAX does not raise here: the output is the last element (9),
# i.e. the out-of-bounds index is clamped to the array bounds.
jnp.arange(10)[11]

Array(9, dtype=int32)

In [15]:
# `.at[idx].get()` with its default mode behaves the same way: the
# out-of-bounds index 11 returns the last element (9.).
jnp.arange(10.0).at[11].get()

Array(9., dtype=float32)

In [16]:
# mode='fill' returns `fill_value` (here NaN) for out-of-bounds indices
# instead of clamping to the last element.
jnp.arange(10.0).at[11].get(mode='fill', fill_value=jnp.nan)

Array(nan, dtype=float32)

In [17]:
def permissive_sum(x):
  """Sum ``x`` after converting it to a JAX array.

  Accepts any array-like, including a Python list; when traced, each list
  element becomes a separate input that is converted individually.
  """
  as_array = jnp.array(x)
  return as_array.sum()

x = list(range(10))
permissive_sum(x)

Array(45, dtype=int32)

In [18]:
# Trace `permissive_sum` on the Python list: every list element becomes a
# separate input with its own convert_element_type/broadcast_in_dim ops.
make_jaxpr(permissive_sum)(x)

{ lambda ; a:i32[] b:i32[] c:i32[] d:i32[] e:i32[] f:i32[] g:i32[] h:i32[] i:i32[]
    j:i32[]. let
    k:i32[] = convert_element_type[new_dtype=int32 weak_type=False] a
    l:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] k
    m:i32[] = convert_element_type[new_dtype=int32 weak_type=False] b
    n:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] m
    o:i32[] = convert_element_type[new_dtype=int32 weak_type=False] c
    p:i32[1] = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      sharding=None
    ] o
    q:i32[] = convert_element_type[new_dtype=int32 weak_type=False] d
    r[35m:i32[1][39m = broadcast_in_dim[
      broadcast_dimensions=()
      shape=(1,)
      

In [19]:
# Same sum, computed outside of any traced function by converting the list
# explicitly before summing.
jnp.sum(jnp.array(x))

Array(45, dtype=int32)

In [20]:
def nansum(x):
  """Sum the non-NaN entries of ``x`` using boolean-mask indexing.

  Works eagerly, but the mask makes the result shape depend on the data, so
  tracing this under ``jax.jit`` fails with ``NonConcreteBooleanIndexError``.
  """
  # Boolean mask selecting non-NaN values, applied directly as an index.
  return x[~jnp.isnan(x)].sum()

In [21]:
# A float array containing one NaN; called eagerly (no jit), the boolean-mask
# version works and skips the NaN.
x = jnp.array([1, 2, jnp.nan, 3, 4])
print(nansum(x))

10.0


In [22]:
# Boolean-mask indexing needs concrete values, so tracing `nansum` under jit
# fails. Catch and print the error so "Restart & Run All" keeps going.
try:
  jax.jit(nansum)(x)
except jax.errors.NonConcreteBooleanIndexError as err:
  print(f"NonConcreteBooleanIndexError: {err}")

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got bool[5]

See https://docs.jax.dev/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError

In [24]:
@jax.jit
def nansum_2(x):
  """Sum the non-NaN entries of ``x`` in a jit-compatible way.

  Rather than dropping NaNs (which would change the shape), replace them with
  0 via ``jnp.where`` so the array keeps a static shape.
  """
  is_nan = jnp.isnan(x)
  return jnp.where(is_nan, 0, x).sum()

# `nansum_2` is already decorated with @jax.jit, so call it directly rather
# than wrapping it in a second, redundant jax.jit.
print(nansum_2(x))

10.0
