**S01P08_sharp_bits_control_flow.ipynb**

Arz

2024 APR 06 (SAT)

reference:
https://jax.readthedocs.io/en/latest/notebooks/Common_Gotchas_in_JAX.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# control flow

## ok! Python control flow & *grad*

*grad* can be applied to functions that use regular Python control flow (`if`, `for`, `while`).

In [4]:
def f(x):
    if x < 3:
        return 3*x**2
    else:
        return -4*x

In [5]:
print(grad(f)(2.0))  # ok!
print(grad(f)(4.0))  # ok!

12.0
-4.0
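A Python `while` loop behaves the same way under *grad*. A minimal sketch, where the function `g` is a hypothetical example not taken from the reference: the number of iterations depends on the value of `x`. This works because under *grad* the primal value of `x` is still concrete.

```python
from jax import grad

def g(x):
    # Python while loop whose trip count depends on the value of x
    while x < 10:
        x = x * 2
    return x

slope = grad(g)(3.0)  # 3 -> 6 -> 12, so g(x) = 4x near x = 3 and the slope is 4.0
print(slope)
```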


## Python control flow & JIT

Combining Python control flow with JIT is more complicated and comes with more constraints.

### ex)

JIT traces a function with abstract values that carry only shape and dtype, not the concrete value. So control flow that branches on the value of a traced argument fails.
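To see what JIT actually works with, one can print inside a jitted function. A minimal sketch, where the function `double` is hypothetical: the `print` runs only during tracing and shows the tracer's shape and dtype. A second call with the same shape and dtype reuses the compiled code and prints nothing.

```python
import jax
import jax.numpy as jnp

@jax.jit
def double(x):
    # This print runs only while tracing: x is a tracer that carries
    # shape and dtype, but no concrete value
    print("tracing with", x.shape, x.dtype)
    return 2 * x

out1 = double(jnp.float32(2.0))  # traces, prints: tracing with () float32
out2 = double(jnp.float32(5.0))  # same shape/dtype: cached, nothing printed
```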

In [6]:
@jit
def f(x):
    if x < 3:
        return 3*x**2
    else:
        return -4*x

In [7]:
# f(2.0)  # forbidden: `x < 3` is a tracer, so `if` cannot turn it into a Python bool

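However, one can choose which arguments get traced: an argument listed in `static_argnums` is passed as a concrete Python value, so value-dependent control flow on it works.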

In [8]:
def f(x):
    if x < 3:
        return 3*x**2
    else:
        return -4*x

In [9]:
f_jit = jit(f, static_argnums=(0,))

In [10]:
f_jit(2.0)

Array(12., dtype=float32, weak_type=True)

### ex) involving a loop

In effect, the loop gets statically unrolled: with `n` static, tracing runs the Python `for` loop and records the operations of every iteration.

In [11]:
def f(x, n):
    y = 0
    for i in range(n):
        y += x[i]
    return y

In [12]:
f_jit = jit(f)
# f_jit(jnp.array([1, 2, 3]), 2)  # forbidden: range(n) needs a concrete int, but n is traced

In [13]:
f(jnp.array([1, 2, 3]), 2)

Array(3, dtype=int32)

In [14]:
f_jit = jit(f, static_argnums=(1,))

In [15]:
f_jit(jnp.array([1, 2, 3]), 2)

Array(3, dtype=int32)
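The unrolling can be checked by inspecting the traced program with `jax.make_jaxpr`. A sketch, assuming the number of jaxpr equations is a fair proxy for program size (the helper `num_eqns` is hypothetical): the count grows with `n` because each Python iteration is recorded separately. The last lines show `static_argnames`, which marks `n` static by name instead of position.

```python
import jax
import jax.numpy as jnp

def f(x, n):
    y = 0
    for i in range(n):
        y += x[i]
    return y

x = jnp.arange(8)

def num_eqns(n):
    # Trace f with n static; every Python iteration adds its own ops to the jaxpr
    jaxpr = jax.make_jaxpr(f, static_argnums=(1,))(x, n)
    return len(jaxpr.jaxpr.eqns)

print(num_eqns(2), num_eqns(5))  # the count grows with n: the loop is unrolled

# Equivalent way to mark n static, by name instead of position
f_jit = jax.jit(f, static_argnames="n")
result = f_jit(x, n=3)  # 0 + 1 + 2
```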

### ex) with argument-value-dependent shapes

Array shapes must be static under JIT, so a shape cannot depend on the value of a traced argument.

In [16]:
def f(size, value):
    return jnp.ones((size,))*value

In [17]:
f(3, 7)

Array([7., 7., 7.], dtype=float32)

**not ok**

In [18]:
f_jit = jit(f)

# f_jit(3, 7)  # forbidden: the shape (size,) would depend on the traced value of size

**ok!**

In [19]:
f_jit = jit(f, static_argnums=(0,))

f_jit(3, 7)

Array([7., 7., 7.], dtype=float32)

Whenever the static argument *size* takes a new value, the function is retraced and recompiled.

So if *size* changes frequently, this is not very efficient.
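The recompilation can be observed with a Python side effect, which only runs when JAX traces the function. A minimal sketch; the counter `trace_count` is a hypothetical helper:

```python
import jax
import jax.numpy as jnp

trace_count = 0

def f(size, value):
    global trace_count
    trace_count += 1  # Python side effect: runs only when JAX (re)traces f
    return jnp.ones((size,)) * value

f_jit = jax.jit(f, static_argnums=(0,))
f_jit(3, 7)  # first call: trace and compile
f_jit(3, 8)  # same static size, same dtype for value: cached, no retrace
f_jit(4, 7)  # new static size: retrace and recompile
print(trace_count)  # 2
```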

## structured control flow primitives

JAX offers more options for control flow through structured primitives in `jax.lax`.

Use these if you want to avoid recompilations while keeping the control flow traceable, and to avoid unrolling large loops.

- **lax.cond**: differentiable

- **lax.while_loop**: forward-mode-differentiable

- **lax.fori_loop**: forward-mode-differentiable in general; forward- and reverse-mode-differentiable if the endpoints are static.

- **lax.scan**: differentiable
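*lax.scan* is not demoed below, so here is a minimal sketch of a running sum (the names `step`, `final`, `ys` are illustrative). Its Python equivalent is a `for` loop over `xs` that threads a `carry` through each step and collects the per-step outputs into a stacked array.

```python
import jax.numpy as jnp
from jax import lax

def step(carry, x):
    # carry: running total so far; x: current element of xs
    total = carry + x
    return total, total  # (new carry, per-step output stacked into ys)

xs = jnp.array([1, 2, 3, 4], dtype=jnp.int32)
final, ys = lax.scan(step, jnp.zeros((), dtype=jnp.int32), xs)
# final is 10, ys is [1, 3, 6, 10]
```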

### lax.cond

**Python equivalent:**

In [20]:
def cond(pred, f_true, f_false, operand):
    if pred:
        return f_true(operand)
    else:
        return f_false(operand)

**demo**

In [21]:
operand = jnp.array([0.])

In [22]:
y = lax.cond(True, lambda x: x + 1, lambda x: x - 1, operand)
print(y)

y = lax.cond(False, lambda x: x + 1, lambda x: x - 1, operand)
print(y)

[1.]
[-1.]
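With *lax.cond*, the function that failed under `@jit` earlier can be written so that it traces. A sketch: both branches are traced, and the predicate `x < 3` is evaluated at run time, so it may be a traced value. Differentiation still works through it.

```python
import jax
from jax import lax

@jax.jit
def f(x):
    # The predicate may be a tracer; the branch is selected at run time
    return lax.cond(x < 3, lambda x: 3 * x**2, lambda x: -4 * x, x)

low, high = f(2.0), f(4.0)   # 12.0 and -16.0
slope = jax.grad(f)(2.0)     # d/dx 3x^2 at x = 2 is 12.0
```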


**related jax.lax functions**

- **lax.select**
- **lax.switch**

**related jax.numpy wrappers**

- **jnp.where**
- **jnp.piecewise**
- **jnp.select**

### lax.while_loop

**Python equivalent:**

In [23]:
def while_loop(cond_f, body_f, init_value):
    value = init_value
    while cond_f(value):
        value = body_f(value)
    return value

**demo**

In [25]:
init_value = 0
cond_f = lambda x: x < 7
body_f = lambda x: x + 1

lax.while_loop(cond_f, body_f, init_value)

Array(7, dtype=int32, weak_type=True)
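The forward-mode-only differentiability of *lax.while_loop* can be checked directly. A sketch, where the function `double_until_10` is hypothetical: `jax.jvp` works, while `jax.grad` raises an error. Catching it as `ValueError` is an assumption about the exception type in current JAX versions.

```python
import jax
from jax import lax

def double_until_10(x):
    # Trip count depends on the value of x, so it is only known at run time
    return lax.while_loop(lambda v: v < 10.0, lambda v: v * 2.0, x)

# Forward mode works: x = 3 doubles twice, so the output is 12 and the slope is 4
primal, tangent = jax.jvp(double_until_10, (3.0,), (1.0,))

# Reverse mode is not supported for lax.while_loop (assumed to raise ValueError)
try:
    jax.grad(double_until_10)(3.0)
    reverse_ok = True
except ValueError:
    reverse_ok = False
```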

### lax.fori_loop

**Python equivalent:**

In [26]:
def fori_loop(start, stop, body_f, init_value):
    value = init_value
    for i in range(start, stop):
        value = body_f(i, value)
    return value

**demo**

In [29]:
start = 0
stop = 10
body_f = lambda i, x: x + i
init_value = 0

lax.fori_loop(start, stop, body_f, init_value)

Array(45, dtype=int32, weak_type=True)
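With static endpoints, *lax.fori_loop* also supports reverse mode, as the primitive list states. A sketch, where the function `power5` is hypothetical: the endpoints `0` and `5` are Python ints, so the trip count is known at trace time and `jax.grad` works.

```python
import jax
from jax import lax

def power5(x):
    # Static endpoints (Python ints 0 and 5): multiply the accumulator by x five times
    return lax.fori_loop(0, 5, lambda i, acc: acc * x, 1.0)

value = power5(2.0)            # 2**5 = 32.0
slope = jax.grad(power5)(2.0)  # 5 * 2**4 = 80.0
```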

### summary

| construct       | JIT | grad      |
|-----------------|-----|-----------|
| if              | ❌   | ✔         |
| for             | ✔*  | ✔         |
| while           | ✔*  | ✔         |
| lax.cond        | ✔   | ✔         |
| lax.while_loop  | ✔   | forward   |
| lax.fori_loop   | ✔   | forward** |
| lax.scan        | ✔   | ✔         |

*: only if the loop condition does not depend on argument values; the loop is then unrolled.

**: forward and reverse mode if the endpoints are static.