In [1]:
import jax
import jax.numpy as np

from jax import random
from jax import grad, jit, vmap

from IPython import display
from matplotlib import pyplot as plt

In [3]:
# PRNG key built from seed 0; it is passed explicitly to the random draw later on.
key = random.PRNGKey(0)
key

DeviceArray([0, 0], dtype=uint32)

# 2.5.1. A Simple Example

In [13]:
# Float vector [0., 1., 2., 3.] used as the evaluation point for the gradients below.
x = np.arange(4.)
x

DeviceArray([0., 1., 2., 3.], dtype=float32)

In [21]:
def f(x):
    """Return twice the dot product of ``x`` with itself."""
    self_dot = np.dot(x, x)
    return 2 * self_dot

In [23]:
# Forward value of f at x.
y = f(x)
y

DeviceArray(28., dtype=float32)

In [24]:
# grad(f) builds the gradient function of f; calling it evaluates the gradient at x.
grad(f)(x)

DeviceArray([ 0.,  4.,  8., 12.], dtype=float32)

In [25]:
# Element-wise check against the analytic gradient of 2 * dot(x, x), which is 4 * x.
grad(f)(x) == 4 * x

DeviceArray([ True,  True,  True,  True], dtype=bool)

In [26]:
def f2(x):
    """Return the sum of the entries of ``x``, using the array's own ``sum`` method."""
    total = x.sum()
    return total

In [27]:
# Forward value of f2 at x.
y2 = f2(x)
y2

DeviceArray(6., dtype=float32)

In [28]:
# f2 is a plain sum, so its gradient is 1 for every entry of x.
grad(f2)(x)

DeviceArray([1., 1., 1., 1.], dtype=float32)

# 2.5.2. Backward for Non-Scalar Variables

In [33]:
def f(x):
    """Return the dot product of ``x`` with itself."""
    self_dot = np.dot(x, x)
    return self_dot
def f2(x):
    """Return the dot product of ``x`` with itself, reduced with ``.sum()``.

    For a 1-D ``x`` the dot product is already a scalar, so the final sum
    leaves the value unchanged.
    """
    self_dot = np.dot(x, x)
    return self_dot.sum()

In [37]:
# Element-wise comparison of the two gradients at x.
grad(f)(x) == grad(f2)(x)

DeviceArray([ True,  True,  True,  True], dtype=bool)

# 2.5.3. Detaching Computation

In [45]:
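# TODO: add a detaching example; in JAX, jax.lax.stop_gradient makes its argument act as a constant during differentiation.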

# 2.5.4. Computing the Gradient of Python Control Flow

In [48]:
def f(a):
    """Scale ``a`` by a value-dependent constant using plain Python control flow.

    ``a`` is doubled once, then doubled repeatedly until the norm of the
    result is at least 1000. If the sum of that result is positive it is
    returned as is; otherwise it is returned multiplied by 100.

    Args:
        a: A nonzero JAX array (the notebook passes a scalar). It must carry
            a concrete value, because the loop and branch conditions are
            evaluated by Python.

    Returns:
        ``k * a`` for a constant ``k`` determined by the value of ``a``.

    Raises:
        ValueError: If ``a * 2`` has no nonzero entry. Doubling zero can
            never reach the norm threshold, so the loop would not terminate.
    """
    b = a * 2
    # Test the entries themselves rather than the norm: the norm squares its
    # input and could underflow to zero for tiny but nonzero values.
    if np.abs(b).sum() == 0:
        raise ValueError("f(a) is undefined for a == 0: doubling never reaches a norm of 1000")
    while np.linalg.norm(b) < 1000:
        b = b * 2
    # A non-positive sum (NaN included, since the comparison is then False)
    # selects the scaled branch.
    return b if b.sum() > 0 else 100 * b

In [51]:
# Draw a single normally distributed scalar from the seeded key.
a = random.normal(key)
a

DeviceArray(-0.20584235, dtype=float32)

In [53]:
# Forward value of the control-flow function at a.
d = f(a)
d

DeviceArray(-168626.05, dtype=float32)

In [52]:
# Differentiate through the Python while/if in f; the gradient follows the branches taken for this value of a.
grad(f)(a)

DeviceArray(819200., dtype=float32)

In [55]:
# f only multiplies a by constants, so d == k * a and the gradient should equal d / a.
# Compare with a tolerance: d / a goes through floating-point rounding and need not
# reproduce k bit-for-bit.
np.isclose(grad(f)(a), d / a)

DeviceArray(True, dtype=bool)