**S02P04_tutorial_auto_differentiation_in_jax.ipynb**

Arz

2024 APR 11 (THU)

reference:
https://jax.readthedocs.io/en/latest/jax-101/04-advanced-autodiff.html

check:

- https://youtu.be/wG_nF1awSSY?si=xnaKSsOx8TtBQrNb
- https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

In [1]:
import numpy as np

In [2]:
import jax
import jax.numpy as jnp
from jax import lax
from jax import grad, jit, vmap
from jax import random

In [3]:
%xmode minimal

Exception reporting mode: Minimal


# higher-order derivatives

## ex) single variable

In [4]:
f = lambda x: x**3 + 2*x**2 - 3*x + 1

In [8]:
grad_f = grad(f)  # 3x^2 + 4x - 3

In [9]:
grad2_f = grad(grad_f)  # 6x + 4
grad3_f = grad(grad2_f)  # 6
grad4_f = grad(grad3_f)  # 0

**check case: x=1**

In [10]:
print(grad_f(1.))
print(grad2_f(1.))
print(grad3_f(1.))
print(grad4_f(1.))

4.0
10.0
6.0
0.0
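The check above evaluates each derivative at the single point x=1. Because `grad` only accepts scalar inputs, one way to check many points at once is to map the derivative functions with `vmap`. The sketch below compares them against the hand-derived formulas on an assumed grid of test points (`xs` is illustrative, not from the original).

```python
import jax.numpy as jnp
from jax import grad, vmap

# Same polynomial as above: f(x) = x^3 + 2x^2 - 3x + 1
f = lambda x: x**3 + 2*x**2 - 3*x + 1

grad_f = grad(f)            # 3x^2 + 4x - 3
grad2_f = grad(grad_f)      # 6x + 4
grad3_f = grad(grad2_f)     # 6

# Assumed test grid: -2, -1, 0, 1, 2
xs = jnp.linspace(-2.0, 2.0, 5)

# vmap maps each scalar-input derivative over the 1-D array of points;
# an output that does not depend on x (grad3_f) is broadcast to the batch size.
d1 = vmap(grad_f)(xs)
d2 = vmap(grad2_f)(xs)
d3 = vmap(grad3_f)(xs)

# Closed-form derivatives for comparison
expected_d1 = 3 * xs**2 + 4 * xs - 3
expected_d2 = 6 * xs + 4
expected_d3 = jnp.full_like(xs, 6.0)

print(d1, d2, d3)
```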


## ex) multivariable: Hessian

In [12]:
def hessian(f):
    return jax.jacfwd(jax.grad(f))

In [13]:
def f(x):
    return jnp.dot(x, x)

- f: $f(\mathbf{x}) = \mathbf{x}^\top \mathbf{x}$
- grad_f: $\nabla f(\mathbf{x}) = 2\mathbf{x}$
- hess_f: $\nabla^2 f(\mathbf{x}) = 2\mathbf{I}$

In [15]:
hessian(f)(jnp.arange(3.))

Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)
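For $\mathbf{x}^\top \mathbf{x}$ the Hessian is constant, so it is a weak test. The sketch below uses a hypothetical function $g(\mathbf{x}) = \sum_i x_i^3 + x_0 x_1$, whose Hessian is $\mathrm{diag}(6\mathbf{x})$ plus 1 in the (0,1) and (1,0) entries. It compares the hand-rolled forward-over-reverse `hessian` with JAX's built-in `jax.hessian` (which, to my understanding, is defined as `jacfwd(jacrev(f))`) and with a reverse-over-reverse variant.

```python
import jax
import jax.numpy as jnp

def hessian(f):
    # forward-over-reverse: forward-mode Jacobian of the reverse-mode gradient
    return jax.jacfwd(jax.grad(f))

def g(x):
    # hypothetical test function: sum of cubes plus a cross term
    return jnp.sum(x**3) + x[0] * x[1]

x = jnp.array([1.0, 2.0, 3.0])

H_custom = hessian(g)(x)
H_builtin = jax.hessian(g)(x)
H_reverse = jax.jacrev(jax.grad(g))(x)  # reverse-over-reverse, also valid

# Hand-derived Hessian: diag(6x) with 1 at (0, 1) and (1, 0)
H_expected = jnp.diag(6.0 * x).at[0, 1].add(1.0).at[1, 0].add(1.0)

print(H_custom)
```

All three compositions give the same matrix; the JAX documentation describes forward-over-reverse as usually the most efficient choice for Hessians, which is why both the custom `hessian` above and `jax.hessian` put `jacfwd` on the outside.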

# higher-order optimization

## ❓ ex) MAML: differentiating through gradient updates

In [17]:
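learning_rate = 0.1  # hypothetical inner-loop step size (not defined in the original)

def loss_function(params, data):  # hypothetical loss: MSE of a linear model
    x, y = data
    return jnp.mean((x @ params - y) ** 2)

def meta_loss_function(params, data):
    """Computes the loss after one step of SGD."""
    grads = grad(loss_function)(params, data)
    return loss_function(params - learning_rate * grads, data)

# meta_grads = grad(meta_loss_function)(params, data)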
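The commented-out `meta_grads` line needs concrete `params` and `data`. A self-contained sketch, assuming a hypothetical linear least-squares `loss_function`, an assumed `learning_rate` of 0.1, and random data, evaluates the full meta-gradient, which differentiates through the inner `grad` call and therefore involves second derivatives. For contrast, it also computes a first-order approximation (often called first-order MAML) that blocks that path with `jax.lax.stop_gradient`.

```python
import jax
import jax.numpy as jnp
from jax import grad, random

learning_rate = 0.1  # hypothetical inner-loop step size

def loss_function(params, data):
    # hypothetical loss: mean squared error of a linear model
    x, y = data
    return jnp.mean((x @ params - y) ** 2)

def meta_loss_function(params, data):
    """Loss after one SGD step; its gradient flows through the step itself."""
    grads = grad(loss_function)(params, data)
    return loss_function(params - learning_rate * grads, data)

def first_order_meta_loss(params, data):
    """Same inner step, but the inner gradient is treated as a constant."""
    grads = jax.lax.stop_gradient(grad(loss_function)(params, data))
    return loss_function(params - learning_rate * grads, data)

# Hypothetical random regression data: 8 samples, 3 features
key_x, key_y, key_p = random.split(random.PRNGKey(0), 3)
x = random.normal(key_x, (8, 3))
y = random.normal(key_y, (8,))
params = random.normal(key_p, (3,))
data = (x, y)

meta_grads = grad(meta_loss_function)(params, data)   # exact meta-gradient
fo_grads = grad(first_order_meta_loss)(params, data)  # first-order approximation

print(meta_grads, fo_grads)
```

For this quadratic loss the Jacobian of the inner update is $\mathbf{I} - \eta \mathbf{A}$ with $\mathbf{A} = \tfrac{2}{n} X^\top X$, so `meta_grads` should equal `(I - learning_rate * A) @ fo_grads`. The first-order version drops exactly that factor, trading accuracy for skipping the second-order computation.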