**S02P04_tutorial_auto_differentiation_in_jax.ipynb**

Arz

2024 APR 11 (THU)

reference:
https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

check:

- https://youtu.be/wG_nF1awSSY?si=xnaKSsOx8TtBQrNb
- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
# use compact ("Minimal") exception tracebacks for the rest of the notebook
%xmode minimal

Exception reporting mode: Minimal


# higher-order derivatives

## ex) single variable

In [4]:
def f(x):
    """Cubic test polynomial f(x) = x^3 + 2x^2 - 3x + 1.

    Defined with ``def`` instead of assigning a lambda (PEP 8, E731) so the
    function carries a real name in reprs and tracebacks.
    """
    return x**3 + 2*x**2 - 3*x + 1

In [5]:
# first derivative: f'(x) = 3x^2 + 4x - 3
grad_f = jax.grad(f)

In [6]:
# each jax.grad differentiates the previous derivative one more time
grad2_f = jax.grad(grad_f)   # f''(x)   = 6x + 4
grad3_f = jax.grad(grad2_f)  # f'''(x)  = 6
grad4_f = jax.grad(grad3_f)  # f''''(x) = 0

**check case: x=1**

In [7]:
# f', f'', f''', f'''' evaluated at x = 1.0, one value per line
print(*[derivative(1.) for derivative in (grad_f, grad2_f, grad3_f, grad4_f)], sep="\n")

4.0
10.0
6.0
0.0


## ex) multivariable: Hessian

In [8]:
def hessian(f):
    """Return a function that computes the Hessian of scalar-valued ``f``.

    Forward-over-reverse: ``jax.grad`` (reverse mode) produces the gradient
    function, and ``jax.jacfwd`` (forward mode) takes the Jacobian of that
    gradient, i.e. the matrix of second derivatives.
    """
    gradient_fn = jax.grad(f)
    return jax.jacfwd(gradient_fn)

In [9]:
def f(x):
    """Return ``jnp.dot(x, x)``; for a 1-D array this is the squared norm x^T x."""
    squared_norm = jnp.dot(x, x)
    return squared_norm

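- $f(\mathbf{x}) = \mathbf{x}^\top \mathbf{x}$
- $\nabla f(\mathbf{x}) = 2\mathbf{x}$
- $\nabla^2 f(\mathbf{x}) = 2I$ (constant, so the Hessian is $2I$ at every point)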

In [10]:
# Hessian of x^T x evaluated at x = (0, 1, 2); expected 2*I
hessian(f)(jnp.array([0., 1., 2.]))

Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

# higher-order optimization

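## ❓ ex) MAML: differentiating through gradient updates

`loss_function`, `learning_rate`, `params`, and `data` are placeholders that this notebook never defines. The cell below therefore only defines the meta-loss, and the call that differentiates through the SGD step is left commented out.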

In [11]:
def meta_loss_function(params, data, loss_fn=None, lr=None):
    """Compute the loss after one step of SGD (the MAML inner step).

    Differentiating this function w.r.t. ``params`` differentiates *through*
    the inner gradient update, which makes the outer objective higher-order.

    Args:
        params: parameters of the loss; must support ``params - lr*grads``
            (e.g. a JAX array). NOTE(review): pytree params would need
            jax.tree_util to apply the update.
        data: batch passed unchanged to the loss.
        loss_fn: callable ``loss_fn(params, data) -> scalar``. Defaults to
            the global ``loss_function`` (backward-compatible with the
            original signature; that global is not defined in this notebook).
        lr: inner-loop step size. Defaults to the global ``learning_rate``.

    Returns:
        The loss evaluated at the parameters after one SGD step.
    """
    if loss_fn is None:
        loss_fn = loss_function
    if lr is None:
        lr = learning_rate
    grads = grad(loss_fn)(params, data)
    return loss_fn(params - lr*grads, data)

# usage, once `params` and `data` exist:
# meta_grads = grad(meta_loss_function)(params, data)

# stopping gradients

- ex) when multiple loss functions are used and each one should update only its own portion of the network.

## ex) TD(0) RL: update

In [12]:
def value_function(theta, state):
    """Linear state-value estimate V_theta(s) = jnp.dot(theta, state).

    Defined with ``def`` instead of assigning a lambda (PEP 8, E731) so the
    function has a proper name in reprs and tracebacks (it is later passed
    through jax.grad and jax.vmap).
    """
    return jnp.dot(theta, state)

In [13]:
# weights of the linear value function, one per state component
theta = jnp.asarray([0.1, -0.1, 0.0])

In [14]:
# one example transition s_{k-1} -> s_k that yields reward r_k
s_prev = jnp.asarray([1.0, 2.0, -1.0])  # s_{k-1}
s_curr = jnp.asarray([2.0, 1.0, 0.0])   # s_k
r_curr = jnp.asarray(1.0)               # r_k, a 0-d array

In [15]:
def td_loss(theta, s_prev, s_curr, r_curr):
    """Negated half squared TD error for one transition s_prev -> s_curr.

    The sign is flipped so that ``jax.grad`` of this function is used
    directly as the parameter update (``delta_theta``). Nothing is blocked
    here, so the gradient also flows through the bootstrapped target
    r_curr + V(s_curr), which itself depends on ``theta``.
    """
    td_target = r_curr + value_function(theta, s_curr)
    td_error = td_target - value_function(theta, s_prev)
    return -0.5 * td_error**2

In [16]:
# gradient w.r.t. the first argument (theta) only
td_update = grad(td_loss)
delta_theta = td_update(theta, s_prev, s_curr, r_curr)

In [17]:
# gradient also flowed through the target here, so this is not the TD(0) update
delta_theta

Array([-1.2,  1.2, -1.2], dtype=float32)

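❓ The update above is not a TD(0) update: `jax.grad` also differentiated through `target`, which depends on `theta`. To force JAX to ignore the dependency of the target on `theta`, wrap the target in **jax.lax.stop_gradient**: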

In [21]:
def td_loss(theta, s_prev, s_curr, r_curr):
    """Negated half squared TD error with the bootstrapped target frozen.

    Redefines the earlier ``td_loss``. Wrapping the target
    r_curr + V(s_curr) in ``lax.stop_gradient`` makes ``jax.grad`` treat it
    as a constant, so the gradient flows only through V(s_prev).
    """
    frozen_target = lax.stop_gradient(r_curr + value_function(theta, s_curr))
    td_error = frozen_target - value_function(theta, s_prev)
    return -0.5 * td_error**2

In [22]:
# same call as before, now using the stop_gradient version of td_loss
td_update = jax.grad(td_loss)
delta_theta = td_update(theta, s_prev, s_curr, r_curr)

In [23]:
# semi-gradient TD(0) update: only V(s_prev) contributes to the gradient
delta_theta

Array([ 1.2,  2.4, -1.2], dtype=float32)

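### validation

By hand, the semi-gradient TD(0) update for a linear value function is $\Delta\theta = \delta_k \, \nabla_\theta V_\theta(s_{k-1})$, with TD error $\delta_k = r_k + V_\theta(s_k) - V_\theta(s_{k-1})$. It should match the autodiff result above.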

In [24]:
# hand-derived TD(0) update: (r_k + V(s_k) - V(s_{k-1})) * grad_theta V(s_{k-1})
v_grad = grad(value_function)(theta, s_prev)
delta_theta = (r_curr + value_function(theta, s_curr) - value_function(theta, s_prev))*v_grad

# turn the visual comparison into a real check against autodiff
# (assumes td_loss is the stop_gradient version defined earlier in the notebook)
assert jnp.allclose(delta_theta, grad(td_loss)(theta, s_prev, s_curr, r_curr))

In [25]:
# hand-derived update; should equal the stop_gradient autodiff result
delta_theta

Array([ 1.2,  2.4, -1.2], dtype=float32)

# straight-through estimator

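- a trick for defining a "gradient" of a function that is otherwise non-differentiable: the forward pass returns the function's value, while the backward pass treats the function as the identity (gradient 1).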

using **jax.lax.stop_gradient**

In [26]:
def f(x):
    """Round ``x`` elementwise with ``jnp.round``.

    Piecewise constant, so ``jax.grad`` of it returns 0.0 -- useless as a
    learning signal, which motivates the straight-through estimator.
    """
    rounded = jnp.round(x)
    return rounded

In [27]:
def straight_through_f(x):
    """Straight-through estimator for ``f``.

    Forward: returns the value of ``f(x)`` (plus an exact zero).
    Backward: ``jax.grad`` sees a derivative of 1, i.e. ``f`` is treated as
    the identity when differentiating.
    """
    # f(x) as a constant: contributes its value but no gradient
    forward_value = lax.stop_gradient(f(x))
    # x minus a gradient-free copy of itself: a finite float minus itself is
    # exactly 0.0, yet its derivative w.r.t. x is exactly 1
    zero_with_unit_grad = x - lax.stop_gradient(x)
    return zero_with_unit_grad + forward_value

In [29]:
# forward values: both equal round(3.2) = 3.0
print("f(x):", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

# gradients: 0 through the rounding itself, 1 through the straight-through version
print("grad(f)(x):", grad(f)(3.2))
print("grad(straight_through_f)(x):", grad(straight_through_f)(3.2))

f(x): 3.0
straight_through_f(x): 3.0
grad(f)(x): 0.0
grad(straight_through_f)(x): 1.0


# per-example gradients

In [30]:
# per-example gradients of td_loss (whichever version is currently defined):
#   theta                  -> in_axes None: shared by every example
#   s_prev, s_curr, r_curr -> in_axes 0: mapped over their leading axis
perex_grads = jax.jit(
    jax.vmap(jax.grad(td_loss), in_axes=(None, 0, 0, 0))
)

In [38]:
# batch of three transitions built from the single example above
s_prevs = jnp.stack([
    s_prev,
    s_prev.at[1:].set(3),  # [1., 3., 3.]
    s_prev,
])
s_currs = jnp.stack([
    s_curr,
    s_curr,
    s_curr.at[2].set(7),   # [2., 1., 7.]; theta[2] is 0, so V(s_curr) is unchanged
])
r_currs = jnp.stack([r_curr, jnp.array(-1.), r_curr])

perex_grads(theta, s_prevs, s_currs, r_currs)

Array([[ 1.2      ,  2.4      , -1.2      ],
       [-0.6999999, -2.1      , -2.1      ],
       [ 1.2      ,  2.4      , -1.2      ]], dtype=float32)