In [28]:
import numpy as np
from jax import numpy as jnp
from jax import jit

In [29]:
def impure_print_side_effect(x):
  print("Executing function")  # This is a side-effect
  return x

print ("First call: ", jit(impure_print_side_effect)(4.))
print ("Second call: ", jit(impure_print_side_effect)(5.))
print ("Third call, different type: ", jit(impure_print_side_effect)(jnp.array([5.])))

Executing function
First call:  4.0
Second call:  5.0
Executing function
Third call, different type:  [5.]


In [30]:
g = 0.
def impure_uses_globals(x):
  return x + g

print ("First call: ", jit(impure_uses_globals)(4.))
g = 10
print ("Second call: ", jit(impure_uses_globals)(5.))
print ("Third call, different type: ", jit(impure_uses_globals)(jnp.array([4.])))

First call:  4.0
Second call:  5.0
Third call, different type:  [14.]


In [31]:
g = 0.
def impure_saves_global(x):
  global g
  g = x
  return x

print ("First call: ", jit(impure_saves_global)(4.))
print ("Saved global: ", g)

First call:  4.0
Saved global:  Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>


In [32]:
def pure_uses_internal_state(x):
  state = dict(even=0, odd=0)
  for i in range(10):
    state['even' if i % 2 == 0 else 'odd'] += x
  return state['even'] + state['odd']

print(jit(pure_uses_internal_state)(5.))

50.0


In [41]:
from jax import xla_computation
comp = xla_computation(pure_uses_internal_state)
print(comp(1.0).as_hlo_text())

HloModule xla_computation_pure_uses_internal_state, entry_computation_layout={(f32[])->(f32[])}

ENTRY main.12 {
  Arg_0.1 = f32[] parameter(0)
  add.2 = f32[] add(Arg_0.1, Arg_0.1)
  add.4 = f32[] add(add.2, Arg_0.1)
  add.6 = f32[] add(add.4, Arg_0.1)
  add.8 = f32[] add(add.6, Arg_0.1)
  add.3 = f32[] add(Arg_0.1, Arg_0.1)
  add.5 = f32[] add(add.3, Arg_0.1)
  add.7 = f32[] add(add.5, Arg_0.1)
  add.9 = f32[] add(add.7, Arg_0.1)
  add.10 = f32[] add(add.8, add.9)
  ROOT tuple.11 = (f32[]) tuple(add.10)
}




In [44]:
from jax import make_jaxpr,lax

# lax.fori_loop
array = jnp.arange(10)
print(lax.fori_loop(0, 10, lambda i,x: x+array[i], 0)) # expected result 45
iterator = iter(range(10))
print(lax.fori_loop(0, 10, lambda i,x: x+next(iterator), 0)) # unexpected result 0

# lax.scan
def func11(arr, extra):
    ones = jnp.ones(arr.shape)
    def body(carry, aelems):
        ae1, ae2 = aelems
        return (carry + ae1 * ae2 + extra, carry)
    return lax.scan(body, 0., (arr, ones))
make_jaxpr(func11)(jnp.arange(16), 5.)
# make_jaxpr(func11)(iter(range(16)), 5.) # throws error

# lax.cond
array_operand = jnp.array([0.])
lax.cond(True, lambda x: x+1, lambda x: x-1, array_operand)
iter_operand = iter(range(10))
# lax.cond(True, lambda x: next(x)+1, lambda x: next(x)-1, iter_operand) # throws error

45
0
