<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch02_2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import jax
import jax.numpy as jnp
from jax import make_jaxpr

def my_function(x, y):
    """Multiply ``x`` by ``y`` and add the constant 2."""
    product = x * y
    return product + 2


print(make_jaxpr(my_function)(3., 4.))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = mul a b; d:f32[] = add c 2.0 in (d,) }


In [2]:
global_list = list()


def my_function(x, y):
    """Compute ``x * y + 2`` while also recording the inputs in ``global_list``.

    The append is a Python side effect. When the function is traced by
    ``make_jaxpr``, the recorded values are tracers. The side effect leaves no
    trace in the printed jaxpr, which contains only the ``mul`` and ``add``.
    """
    global_list.extend((x, y))
    return x * y + 2


print(make_jaxpr(my_function)(3., 4.))

{ lambda ; a:f32[] b:f32[]. let c:f32[] = mul a b; d:f32[] = add c 2.0 in (d,) }


In [3]:
def log2_with_print(x):
    """Return log base 2 of ``x`` via the change-of-base formula ``ln(x) / ln(2)``.

    The ``print`` is a Python side effect. Under ``make_jaxpr`` it shows the
    tracer object rather than a concrete number, and it does not appear in
    the jaxpr.
    """
    print("printed x:", x)
    numerator = jnp.log(x)
    denominator = jnp.log(2.0)
    return numerator / denominator


print(jax.make_jaxpr(log2_with_print)(3.))

printed x: Traced<ShapedArray(float32[], weak_type=True)>with<DynamicJaxprTrace(level=1/0)>
{ lambda ; a:f32[]. let
    b:f32[] = log a
    c:f32[] = log 2.0
    d:f32[] = div b c
  in (d,) }


In [4]:
def square_if_gt_2(x):
    """Square ``x`` elementwise when it has more than two dimensions; otherwise return it unchanged.

    ``x.ndim`` is static shape information, available even on a tracer. The
    Python branch is therefore resolved at trace time, and only the taken branch
    is recorded in the jaxpr.
    """
    return x**2 if x.ndim > 2 else x


print(jax.make_jaxpr(square_if_gt_2)(jnp.array([1, 2, 3])))

{ lambda ; a:i32[3]. let  in (a,) }


In [5]:
def selu(x, alpha=1.67, lambda_=1.05):
    """Scaled exponential linear unit, applied elementwise.

    Returns ``lambda_ * x`` where ``x > 0`` and ``lambda_ * (alpha * exp(x) - alpha)``
    elsewhere. Both branches are evaluated for every element, and ``jnp.where``
    then picks one per element.
    """
    is_positive = x > 0
    negative_part = alpha * jnp.exp(x) - alpha
    return lambda_ * jnp.where(is_positive, x, negative_part)


# Benchmark input: the integers 0 .. 999999 (one million elements).
x = jnp.arange(1000000)
# block_until_ready() waits until the result is actually computed, so the timing
# measures the computation and not just the dispatch of the call.
%timeit selu(x).block_until_ready()


848 µs ± 78.9 µs per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [6]:
# Same function, compiled with jax.jit.
selu_jit = jax.jit(selu)

# Warm up: the first call traces and compiles selu_jit. Running it once here
# keeps compilation time out of the %timeit measurement below.
selu_jit(x).block_until_ready()

%timeit selu_jit(x).block_until_ready()

208 µs ± 69.7 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)


In [7]:
def f(x):
    """Return ``x`` when it is positive, otherwise ``2 * x``.

    The branch is a plain Python ``if``, so it needs a concrete value of ``x``.
    Under ``jax.jit`` the argument is an abstract tracer, and turning ``x > 0``
    into a Python ``bool`` fails with ``TracerBoolConversionError``.
    """
    if x > 0:
        return x
    else:
        return 2 * x


f_jit = jax.jit(f)
# The call below is EXPECTED to fail; that failure is the point of this cell.
# The error is caught and displayed instead of left uncaught, so
# "Restart & Run All" continues to the following cells.
# TracerBoolConversionError is a subclass of ConcretizationTypeError, so this
# also covers older JAX versions that raised the base class.
try:
    f_jit(10)
except jax.errors.ConcretizationTypeError as err:
    first_line = str(err).partition("\n")[0]
    print(f"{type(err).__name__}: {first_line}")

TracerBoolConversionError: ignored

In [8]:
# A while loop whose condition depends on the input n.

def g(x, n):
    """Return ``x + i``, where ``i`` is the smallest non-negative integer not less than ``n``.

    The counter is advanced by a Python ``while`` loop whose condition ``i < n``
    must be a concrete ``bool``. That works for ordinary Python values but not
    when ``n`` is a tracer under ``jax.jit``.
    """
    i = 0
    while i < n:
        i += 1
    return x + i


g_jit = jax.jit(g)
# The call below is EXPECTED to fail. With both arguments traced, `i < n` is a
# tracer and cannot drive a Python `while`. The error is caught and displayed
# so "Restart & Run All" continues to the following cells.
try:
    g_jit(10, 20)
except jax.errors.ConcretizationTypeError as err:
    first_line = str(err).partition("\n")[0]
    print(f"{type(err).__name__}: {first_line}")

TracerBoolConversionError: ignored

In [9]:
@jax.jit
def loop_body(prev_i):
    return prev_i + 1


def g_inner_jitted(x, n):
    i = 0
    while i < n:
        i = loop_body(i)
    return x + i


g_inner_jitted(10, 20)


Array(30, dtype=int32, weak_type=True)

In [10]:
# Mark positional argument 0 (`x`) as static. jax.jit then sees its concrete
# value at trace time, so the Python `if` inside `f` works. Per the jax.jit docs,
# a new static value causes a retrace/recompile.
f_jit_correct = jax.jit(f, static_argnums=(0,))
print(f_jit_correct(10))


10


In [11]:
# Mark `n` as static by name, so the Python `while i < n` loop in `g` compares
# against a concrete value during tracing.
g_jit_correct = jax.jit(g, static_argnames=('n',))
print(g_jit_correct(10, 20))

30


In [12]:
from functools import partial

@partial(jax.jit, static_argnames=('n',))
def g_jit_decorated(x, n):
    """Return ``x`` plus the smallest non-negative integer not less than ``n``, jitted with ``n`` static.

    ``partial`` pre-binds the ``static_argnames`` option, so ``jax.jit`` can be
    applied as a decorator. Because ``n`` is static, the Python loop below runs
    at trace time on concrete values.
    """
    steps = 0
    while steps < n:
        steps += 1
    return x + steps


print(g_jit_decorated(10, 20))


30


In [13]:
# Compare the jitted `g` (with `n` static) against the plain Python `g`.
# NOTE(review): in the recorded timings the jitted call is slower. Presumably
# per-call dispatch overhead dominates for such a tiny computation.
print("g jitted:")
%timeit g_jit_correct(10, 20).block_until_ready()


print("g:")
# The plain `g` returns an ordinary Python int here, so there is nothing to block on.
%timeit g(10, 20)

g jitted:
289 µs ± 101 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
g:
1.4 µs ± 519 ns per loop (mean ± std. dev. of 7 runs, 1000000 loops each)


In [14]:
def unjitted_loop_body(prev_i):
    """Plain (not jitted) increment: return ``prev_i + 1``."""
    incremented = prev_i + 1
    return incremented


def g_inner_jitted_partial(x, n):
    """Anti-pattern demo: jit a freshly built ``partial`` on every loop iteration.

    Returns ``x`` plus the number of increment steps needed to reach ``n`` from 0,
    but wraps the loop body anew on each pass.
    """
    counter = 0
    while counter < n:
        # Don't do this!
        # Every iteration creates a new `partial` object with a different hash,
        # so JAX cannot reuse its cached compilation.
        step = jax.jit(partial(unjitted_loop_body))
        counter = step(counter)
    return x + counter


def g_inner_jitted_lambda(x, n):
    """Anti-pattern demo: jit a freshly created ``lambda`` on every loop iteration.

    Returns ``x`` plus the number of increment steps needed to reach ``n`` from 0,
    but wraps the loop body anew on each pass.
    """
    counter = 0
    while counter < n:
        # Don't do this!
        # A `lambda` is also a new function object with a different hash on
        # each iteration, so JAX cannot reuse its cached compilation.
        step = jax.jit(lambda v: unjitted_loop_body(v))
        counter = step(counter)
    return x + counter


def g_inner_jitted_normal(x, n):
    """Call ``jax.jit`` inside the loop, but always on the same function object.

    Returns ``x`` plus the number of increment steps needed to reach ``n`` from 0.
    """
    counter = 0
    while counter < n:
        # This is fine!
        # jax.jit receives the same function every time, so JAX can find the
        # already cached, compiled version again.
        step = jax.jit(unjitted_loop_body)
        counter = step(counter)
    return x + counter


# Time the three loop variants. The partial and lambda versions hand jax.jit a
# new function object on each iteration, so nothing is reused from the cache.
# The last version reuses the cached compilation.
print("jit called in a loop with partials:")
%timeit g_inner_jitted_partial(10, 20).block_until_ready()


print("jit called in a loop with lambdas:")
%timeit g_inner_jitted_lambda(10, 20).block_until_ready()


print("jit called in a loop with caching:")
%timeit g_inner_jitted_normal(10, 20).block_until_ready()


jit called in a loop with partials:
474 ms ± 8.38 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with lambdas:
632 ms ± 71.6 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
jit called in a loop with caching:
11.5 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
