<a href="https://colab.research.google.com/github/yblee110/jax-flax-book/blob/main/ch02_4_%EC%9E%90%EB%8F%99_%EB%AF%B8%EB%B6%84.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install jax==0.4.26
import jax
print(jax.__version__)

0.4.26


In [None]:
import jax

def f(x):
    """Cubic polynomial f(x) = x**3 - 2*x**2 + 3*x - 4."""
    return x**3 - 2 * x**2 + 3 * x - 4


# Each jax.grad differentiates the previous derivative once more:
# f'(x) = 3x^2 - 4x + 3,  f''(x) = 6x - 4,  f'''(x) = 6,  f''''(x) = 0.
dfdx = jax.grad(f)
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

In [None]:
import jax
import jax.numpy as jnp


def hessian(f):
    """Build a function that returns the Hessian matrix of scalar-valued ``f``.

    Forward-over-reverse: ``jax.grad`` yields the gradient (reverse mode) and
    ``jax.jacfwd`` takes the Jacobian of that gradient (forward mode), which is
    the matrix of second derivatives.
    """
    gradient_fn = jax.grad(f)
    hessian_fn = jax.jacfwd(gradient_fn)
    return hessian_fn


def f(x):
    """Dot product of ``x`` with itself (squared norm for a 1-D ``x``)."""
    return jnp.dot(x, x)


# For x . x the Hessian is 2 * identity.
hessian(f)(jnp.array([1., 2., 3.]))


Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

In [None]:
# Meta-gradient example: differentiate *through* one SGD step.
# A tiny linear-regression problem is defined here so the example is runnable
# on a fresh kernel (Restart & Run All would otherwise stop at this cell).

def loss_fn(params, data):
    """Mean squared error of the linear model ``inputs @ params``.

    Args:
        params: weight vector of shape ``(d,)``.
        data: pair ``(inputs, targets)`` with shapes ``(n, d)`` and ``(n,)``.
    """
    inputs, targets = data
    residuals = inputs @ params - targets
    return jnp.mean(residuals ** 2)


lr = 0.01  # inner-loop SGD step size
params = jnp.array([0.5, -0.5])
data = (jnp.array([[1., 2.], [3., 4.]]), jnp.array([1., 2.]))


def meta_loss_fn(params, data):
    """Loss after one SGD step on ``loss_fn``.

    Taking ``jax.grad`` of this function back-propagates through the inner
    ``jax.grad(loss_fn)`` call, i.e. it involves second derivatives of
    ``loss_fn``.
    """
    grads = jax.grad(loss_fn)(params, data)
    return loss_fn(params - lr * grads, data)


meta_grads = jax.grad(meta_loss_fn)(params, data)
meta_grads


NameError: name 'params' is not defined

In [None]:
# Linear value function and its initial weights.
def value_fn(theta, state):
    """Linear value estimate: dot product of the weights with the state features."""
    return jnp.dot(theta, state)


theta = jnp.array([0.1, -0.1, 0.0])

In [None]:
# Example transition: previous state, reward received, next state.
s_tm1, r_t, s_t = (
    jnp.array([1., 2., -1.]),
    jnp.array(1.),
    jnp.array([2., 1., 0.]),
)

In [None]:
def td_loss(theta, s_tm1, r_t, s_t):
    """Squared one-step TD error (no discount factor) for a single transition.

    Both the prediction ``value_fn(theta, s_tm1)`` and the bootstrapped target
    ``r_t + value_fn(theta, s_t)`` depend on ``theta``, so ``jax.grad``
    differentiates through both of them.
    """
    prediction = value_fn(theta, s_tm1)
    bootstrap_target = r_t + value_fn(theta, s_t)
    td_error = bootstrap_target - prediction
    return td_error ** 2


# Gradient w.r.t. theta (the first argument) on the example transition.
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta


Array([ 2.4, -2.4,  2.4], dtype=float32)

In [None]:
def td_loss_with_stop_gradient(theta, s_tm1, r_t, s_t):
    """Squared one-step TD error with the bootstrapped target held constant.

    ``jax.lax.stop_gradient`` blocks gradient flow through
    ``r_t + value_fn(theta, s_t)``, so only the prediction
    ``value_fn(theta, s_tm1)`` contributes to the gradient w.r.t. ``theta``.
    """
    prediction = value_fn(theta, s_tm1)
    bootstrap_target = r_t + value_fn(theta, s_t)
    frozen_target = jax.lax.stop_gradient(bootstrap_target)
    return (frozen_target - prediction) ** 2


# Gradient w.r.t. theta; compare with the td_loss gradient above.
td_update = jax.grad(td_loss_with_stop_gradient)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta


Array([-2.4, -4.8,  2.4], dtype=float32)

In [None]:
def f(x):
    """Round to the nearest integer; ``jax.grad`` of this is 0 (not useful)."""
    return jnp.round(x)


def straight_through_f(x):
    """Straight-through estimator for ``f``.

    Returns the same value as ``f(x)`` but with derivative 1 w.r.t. ``x``:
    ``x - stop_gradient(x)`` is numerically zero yet has gradient 1, and
    ``stop_gradient(f(x))`` supplies the value without contributing ``f``'s
    own gradient.
    """
    zero_with_unit_grad = x - jax.lax.stop_gradient(x)
    forward_value = jax.lax.stop_gradient(f(x))
    return zero_with_unit_grad + forward_value


print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))


f(x):  3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0


In [None]:
# Semi-gradient TD update w.r.t. theta (argument 0), under the name used by the
# vmap/jit cells below.
dtdloss_dtheta = jax.grad(td_loss_with_stop_gradient, argnums=0)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)


Array([-2.4, -4.8,  2.4], dtype=float32)

In [None]:
# With the default in_axes=0, vmap maps the leading axis of *every* argument,
# so theta has to be batched as well.
almost_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=0)

# A batch of two copies of the same example transition.
batched_s_tm1, batched_r_t, batched_s_t = (
    jnp.stack([arr, arr]) for arr in (s_tm1, r_t, s_t)
)

batched_theta = jnp.stack([theta, theta])
almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)


Array([[-2.4, -4.8,  2.4],
       [-2.4, -4.8,  2.4]], dtype=float32)

In [None]:
# in_axes=None for theta: one shared parameter vector is broadcast across the
# batch, while the three transition arguments are mapped over axis 0.
# Not jit-compiled yet (hence "inefficient"); see the next cell.
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None,) + (0,) * 3)

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)


Array([[-2.4, -4.8,  2.4],
       [-2.4, -4.8,  2.4]], dtype=float32)

In [None]:
# jit-compile the per-example gradient function. This first call traces and
# compiles it; later calls with the same shapes reuse the compiled version.
perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)


Array([[-2.4, -4.8,  2.4],
       [-2.4, -4.8,  2.4]], dtype=float32)

In [None]:
# Compare the un-jitted vmap against the jitted version. perex_grads was already
# called once in the previous cell, so its compilation is not part of this timing.
# .block_until_ready() waits for the result to be computed, because JAX
# dispatches work asynchronously.
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()

7.62 ms ± 1.18 ms per loop (mean ± std. dev. of 7 runs, 100 loops each)
93.9 µs ± 2.6 µs per loop (mean ± std. dev. of 7 runs, 10000 loops each)
