**S01P03_sharp_bits_pure_functions.ipynb**

Arz

2024 APR 04 (THU)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [33]:
%xmode minimal

Exception reporting mode: Minimal


# pure functions

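JAX transformations such as `jit` and `grad` are designed for functionally pure Python functions. Applied to impure functions, they can give results that differ from what you expect.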

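- pure function: all input data comes in through the function's parameters and all results go out through its return values. It has no side effects and does not read or write external state, so the same inputs always produce the same outputs.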

## ex) impure: print() 

In [3]:
def impure_print(x):
    print("  see if I get printed.")
    return x

In [4]:
impure_print_jit = jit(impure_print)

In [5]:
print("1st run:")
print(" ", impure_print_jit(1))

print("2nd run:")
print(" ", impure_print_jit(2))

print("but, ...")  # new input type/shape (array of shape (1,)) forces a retrace
print(" ", impure_print_jit(jnp.array([1])))

1st run:
  see if I get printed.
  1
2nd run:
  2
but, ...
  see if I get printed.
  [1]
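The Python `print` runs only while JAX traces the function. That happens on the first call and again when an input with a new type or shape forces a retrace. The second call reuses the cached compiled code, so `print` never executes. If you want output on every call, the staged alternative is `jax.debug.print`. Below is a minimal sketch. The function name `debug_print_jit` is hypothetical.

```python
import jax
from jax import jit

@jit
def debug_print_jit(x):
    # jax.debug.print is staged into the compiled computation,
    # so it runs on every call, not only while tracing
    jax.debug.print("  x = {}", x)
    return x

print(debug_print_jit(1))
print(debug_print_jit(2))  # prints again, although no retrace happens
```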


## ex) impure: using global variables

In [6]:
g = 9.81

In [7]:
def impure_global_var_use(x):
    return x + g

In [8]:
impure_global_var_use_jit = jit(impure_global_var_use)

In [9]:
print("1st run:")
print(" ", impure_global_var_use_jit(1))

1st run:
  10.81


In [10]:
# update the global variable, g
g = 1.62

In [11]:
print("run after g update:")
print(" ", impure_global_var_use_jit(1))

run after g update:
  10.81


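⚠️ The result still uses g = 9.81. During the first call, tracing read g once and baked its value into the compiled code as a constant. Later calls with the same input type and shape reuse that cached code, so they never see the update to g.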

In [12]:
print("run after type change:")
print(" ", impure_global_var_use_jit(jnp.array([1])))

run after type change:
  [2.62]
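The last call sees g = 1.62 only because the new input shape forced a retrace. A pure version passes g explicitly as an argument, which turns it into a traced input of the compiled function. Below is a sketch. The names `pure_global_var_use` and `pure_global_var_use_jit` are hypothetical.

```python
from jax import jit

def pure_global_var_use(x, g):
    # g is now an input, traced like x, instead of a constant captured at trace time
    return x + g

pure_global_var_use_jit = jit(pure_global_var_use)

print(pure_global_var_use_jit(1, 9.81))
print(pure_global_var_use_jit(1, 1.62))  # same compiled code, new value of g
```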


## ex) impure: assigning to a global variable leaks a traced object

In [13]:
g = 9.81

In [14]:
def impure_jax_traces_global_var(x):
    global g
    g = x
    return x

In [15]:
print("1st run:")
print(" ", jit(impure_jax_traces_global_var)(1))
print("g:", g)

1st run:
  1
g: Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
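The assignment `g = x` runs while JAX traces the function. As a result, `g` is left holding the tracer (an abstract placeholder for `x`) rather than the value 1. Using such a leaked tracer later typically raises an `UnexpectedTracerError`. A pure alternative threads the state through arguments and return values instead of a global. Below is a minimal sketch. The accumulator `accumulate` is a hypothetical example, not part of the original notebook.

```python
from jax import jit

@jit
def accumulate(total, x):
    # pure: state comes in as an argument and leaves as a return value;
    # nothing outside the function is read or written
    new_total = total + x
    return new_total

total = 0.0
for x in [1.0, 2.0, 3.0]:
    total = accumulate(total, x)  # the caller threads the state between calls

print("total:", total)  # total: 6.0
```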


## ex) pure: using internal state

A function can be functionally pure even if it uses stateful objects internally, as long as it does not read or write external state.

In [17]:
def pure_internal_state_use(x):
    internal_state = dict(even=0, odd=0)
    for i in range(3):
        internal_state["even" if i%2 == 0 else "odd"] += x
    return (internal_state["even"], internal_state["odd"])

In [18]:
print("1st run:")
print(" ", jit(pure_internal_state_use)(3))

1st run:
  (Array(6, dtype=int32, weak_type=True), Array(3, dtype=int32, weak_type=True))
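The dict and the Python `for` loop exist only during tracing. The loop is unrolled and each dict update becomes a plain addition on the input, so the compiled function touches no external state. One way to check this is to inspect the jaxpr. The sketch below re-defines the function from the cell above so that the block is self-contained.

```python
import jax

def pure_internal_state_use(x):
    internal_state = dict(even=0, odd=0)
    for i in range(3):
        internal_state["even" if i % 2 == 0 else "odd"] += x
    return (internal_state["even"], internal_state["odd"])

# the jaxpr holds only arithmetic on the input: no dict, no loop
closed_jaxpr = jax.make_jaxpr(pure_internal_state_use)(3)
print(closed_jaxpr)
```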


## ex) impure: using Python iterators

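Avoid Python iterators in any function you want to `jit`. Also keep them out of the functions and operands you pass to control-flow primitives such as `lax.fori_loop`, `lax.scan`, and `lax.cond`.

An iterator carries hidden mutable state, because each `next()` call advances it. That is incompatible with JAX's functional programming model and leads to unexpected results or errors.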

In [19]:
from jax import make_jaxpr

### ex) lax.fori_loop

**case: ok**

In [21]:
numbers = jnp.arange(10)

In [22]:
y = lax.fori_loop(0, 10, lambda i, x: x + numbers[i], 0)

print(y)

45


As expected: 0 + 1 + ... + 9 = 45.

**case: not ok**: using iterators

In [25]:
numbers = iter(range(10))

In [26]:
y = lax.fori_loop(0, 10, lambda i, x: x + next(numbers), 0)

print(y)

0


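Wrong, unexpected answer. JAX traces the loop body only once, so `next(numbers)` runs a single time during tracing and returns the Python int 0. That constant is baked into the loop, so every iteration adds 0.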
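If the values come from an iterator, one fix is to materialize them into an array once, outside any traced code, and then index into that array inside the loop body. Below is a sketch. The variable name `numbers_iter` is hypothetical.

```python
import jax.numpy as jnp
from jax import lax

numbers_iter = iter(range(10))
# consume the iterator once, outside of any traced function
numbers = jnp.array(list(numbers_iter))

# the body now reads numbers[i], which depends on the traced loop index i
y = lax.fori_loop(0, 10, lambda i, x: x + numbers[i], 0)
print(y)  # 45
```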

### ex) lax.scan

In [27]:
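def f(numbers, extra):
    ones = jnp.ones(numbers.shape)
    def body(carry, aes):
        ae1, ae2 = aes  # one element of `numbers` and one of `ones` per step
        # new carry: running sum of ae1*ae2 + extra; per-step output: the previous carry
        return (carry + ae1*ae2 + extra, carry)
    # scan over (numbers, ones) starting from carry 0; returns (final carry, stacked outputs)
    return lax.scan(body, 0, (numbers, ones))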

**case: ok**

In [29]:
numbers = jnp.arange(10)

In [30]:
make_jaxpr(f)(numbers, 3)

{ lambda ; a:i32[10] b:i32[]. let
    c:f32[10] = broadcast_in_dim[broadcast_dimensions=() shape=(10,)] 1.0
    d:f32[] e:f32[10] = scan[
      jaxpr={ lambda ; f:i32[] g:f32[] h:i32[] i:f32[]. let
          j:f32[] = convert_element_type[new_dtype=float32 weak_type=False] h
          k:f32[] = mul j i
          l:f32[] = add g k
          m:f32[] = convert_element_type[new_dtype=float32 weak_type=False] f
          n:f32[] = add l m
        in (n, g) }
      length=10
      linear=(False, False, False, False)
      num_carry=1
      num_consts=1
      reverse=False
      unroll=1
    ] b 0.0 a c
  in (d, e) }
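The jaxpr shows a float32 running sum carried through 10 steps. To check the numbers, you can run the computation and compare it with a plain Python loop. The sketch below re-defines `f` from the cell above with one assumed change: the initial carry is a zero of the same float dtype as the per-step result, so the carry type is fixed from the start.

```python
import jax.numpy as jnp
from jax import lax

def f(numbers, extra):
    ones = jnp.ones(numbers.shape)
    def body(carry, aes):
        ae1, ae2 = aes
        return (carry + ae1 * ae2 + extra, carry)
    # assumption: start from a float zero matching the dtype of `ones`
    return lax.scan(body, jnp.zeros((), dtype=ones.dtype), (numbers, ones))

final, history = f(jnp.arange(10), 3)

# reference: the same recurrence as an ordinary Python loop
carry, ref_history = 0.0, []
for n in range(10):
    ref_history.append(carry)
    carry = carry + n * 1.0 + 3

print(final)    # 75.0
print(history)  # carry value before each step
```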

**case: not ok**: using iterators

In [31]:
numbers = iter(range(10))

In [35]:
# make_jaxpr(f)(numbers, 3)  # forbidden: raises TypeError (an iterator is not a valid JAX type)

### ex) lax.cond

**case: ok**

In [36]:
operands = jnp.array([0])

In [37]:
lax.cond(True, lambda x: x + 1, lambda x: x - 1, operands)

Array([1], dtype=int32)
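The predicate above is a concrete Python `True`. `lax.cond` also works when the predicate is a traced value, for example inside `jit`, where a Python `if` cannot branch on it. JAX traces both branches, so each must accept the operand and return outputs of the same shape and dtype. Below is a sketch. The function name `step` is hypothetical.

```python
import jax.numpy as jnp
from jax import jit, lax

@jit
def step(pred, operands):
    # pred is traced here; both branches are traced and compiled,
    # and the branch selected by pred at run time is executed
    return lax.cond(pred, lambda x: x + 1, lambda x: x - 1, operands)

print(step(True, jnp.array([0])))   # [1]
print(step(False, jnp.array([0])))  # [-1]
```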

**case: not ok**: using iterators

In [38]:
operands = iter(range(10))

In [40]:
# lax.cond(True, lambda x: x + 1, lambda x: x - 1, operands)  # forbidden: raises TypeError (an iterator is not a valid JAX type)