In [None]:
import jax
import jax.numpy as jnp

## Higher-order derivatives

In [None]:
def f(x):
  """Evaluate the cubic polynomial x^3 + 2x^2 - 3x + 1 at `x`."""
  return x**3 + 2*x**2 - 3*x + 1

In [None]:
# jax.grad returns an ordinary function, so it can be applied repeatedly:
# n nested applications give the n-th derivative of f.
dfdx = jax.grad(f)
d2fdx = jax.grad(jax.grad(f))
d3fdx = jax.grad(jax.grad(jax.grad(f)))
d4fdx = jax.grad(jax.grad(jax.grad(jax.grad(f))))

In [None]:
# Evaluate the first through fourth derivatives at x = 2, one per line.
for derivative in (dfdx, d2fdx, d3fdx, d4fdx):
  print(derivative(2.))

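JAX provides two transformations for computing the Jacobian of a function: `jax.jacfwd`, which uses forward-mode automatic differentiation, and `jax.jacrev`, which uses reverse-mode. Differentiating twice gives the Hessian: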

In [None]:
def hessian(f):
  """Return a function that computes the dense Hessian of `f`.

  The inner derivative is taken in reverse mode (`jax.jacrev`) and the outer
  one in forward mode (`jax.jacfwd`). For a scalar-valued `f`, `jax.jacrev(f)`
  equals `jax.grad(f)`, so the result matches `jax.jacfwd(jax.grad(f))`.
  Unlike `jax.grad`, `jax.jacrev` also accepts array-valued outputs, so this
  works for vector-valued functions too.

  Args:
    f: Function to differentiate twice with respect to its first argument.

  Returns:
    A function with the same arguments as `f` that returns the second
    derivatives, with shape `output_shape + input_shape + input_shape`.
  """
  return jax.jacfwd(jax.jacrev(f))

In [None]:
def f(x):
  """Return the dot product of `x` with itself."""
  self_product = jnp.dot(x, x)
  return self_product

In [None]:
# Build the Hessian function first, then evaluate it at a sample point.
hessian_of_f = hessian(f)
hessian_of_f(jnp.array([1.0, 2.0, 3.0]))

## Stopping gradients

In [None]:
# Linear value function and its initial parameter vector.
def value_fn(theta, state):
  """Return the dot product of the parameters `theta` with `state`."""
  return jnp.dot(theta, state)

theta = jnp.array([0.1, -0.1, 0.0])

In [None]:
# A single example transition: two 3-element state vectors and a reward.
s_tm1 = jnp.array([1.0, 2.0, -1.0])
s_t = jnp.array([2.0, 1.0, 0.0])
r_t = jnp.array(1.0)  # 0-d (scalar) array

In [None]:
def td_loss(theta, s_tm1, r_t, s_t):
  """Squared TD error for one transition, differentiable through every term.

  Both the prediction `value_fn(theta, s_tm1)` and the bootstrapped target
  `r_t + value_fn(theta, s_t)` depend on `theta`, so `jax.grad` of this loss
  propagates gradients through the target as well as the prediction.
  """
  prediction = value_fn(theta, s_tm1)
  bootstrapped_target = r_t + value_fn(theta, s_t)
  td_error = bootstrapped_target - prediction
  return td_error ** 2

In [None]:
# Differentiate the loss with respect to its first argument (theta) and
# evaluate that gradient on the example transition.
td_update = jax.grad(td_loss, argnums=0)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

In [None]:
# Show the gradient of the TD loss with respect to theta computed above.
print(delta_theta)

In [None]:
def td_loss(theta, s_tm1, r_t, s_t):
  """Squared TD error for one transition with the target held constant.

  The bootstrapped target `r_t + value_fn(theta, s_t)` is wrapped in
  `jax.lax.stop_gradient`, so differentiating this loss with respect to
  `theta` only propagates through the prediction `value_fn(theta, s_tm1)`.
  """
  prediction = value_fn(theta, s_tm1)
  # Treat the target as a constant during differentiation.
  frozen_target = jax.lax.stop_gradient(r_t + value_fn(theta, s_t))
  return (frozen_target - prediction) ** 2

# Recompute the update: gradient of the current td_loss with respect to its
# first argument (theta), evaluated on the example transition.
td_update = jax.grad(td_loss, argnums=0)
delta_theta = td_update(theta, s_tm1, r_t, s_t)

delta_theta

## Straight-through estimator using stop_gradient

In [None]:
def f(x):
  """Round `x` to the nearest integer value using `jnp.round`."""
  rounded = jnp.round(x)  # non-differentiable
  return rounded

In [None]:
def straight_through_f(x):
  """Return `f(x)` in the forward pass with an identity gradient.

  The result is `(x - stop_gradient(x)) + stop_gradient(f(x))`: the first
  term evaluates to zero but still carries a gradient of one with respect to
  `x`, while the second term supplies the value and contributes no gradient.
  """
  # Forward value of f, hidden from differentiation.
  forward_value = jax.lax.stop_gradient(f(x))
  # Exactly-zero expression (Sterbenz lemma) whose gradient is exactly one.
  zero_with_unit_grad = x - jax.lax.stop_gradient(x)
  return zero_with_unit_grad + forward_value

In [None]:
# Forward pass: the straight-through wrapper should return the same value as f,
# because its extra `x - stop_gradient(x)` term evaluates to zero.
print("f(x): ", f(3.2))
print("straight_through_f(x):", straight_through_f(3.2))

In [None]:
# Backward pass: f is piecewise constant, whereas the wrapper's only
# differentiable path is the bare `x` term, which has slope one.
print("grad(f)(x):", jax.grad(f)(3.2))
print("grad(straight_through_f)(x):", jax.grad(straight_through_f)(3.2))

## Per-example gradients

In [None]:
# Per-example gradients, built inside-out:
#   1. jax.grad  - gradient of td_loss with respect to theta (first argument);
#   2. jax.vmap  - map over axis 0 of the three transition arguments while
#                  sharing a single theta (in_axes=None);
#   3. jax.jit   - compile the vectorized gradient function.
perex_grads = jax.jit(
    jax.vmap(
        jax.grad(td_loss),
        in_axes=(None, 0, 0, 0),
    )
)

In [None]:
# Test it: build a batch of two identical transitions by stacking each input
# along a new leading axis.
batched_s_tm1 = jnp.stack([s_tm1] * 2)
batched_r_t = jnp.stack([r_t] * 2)
batched_s_t = jnp.stack([s_t] * 2)

In [None]:
# Evaluate the per-example gradients on the two-element batch: theta is shared
# across examples while the transition arrays are mapped over their leading axis.
perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

In [None]:
# Step 1: the gradient of the TD loss with respect to its first argument, theta,
# evaluated on a single (unbatched) transition.
dtdloss_dtheta = jax.grad(td_loss, argnums=0)

dtdloss_dtheta(theta, s_tm1, r_t, s_t)

In [None]:
# Step 2: with vmap's default in_axes=0 every argument is mapped over its
# leading axis, so theta has to be given a batch dimension as well.
batched_theta = jnp.stack([theta] * 2)
almost_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=0)

almost_perex_grads(batched_theta, batched_s_tm1, batched_r_t, batched_s_t)

In [None]:
# Share a single theta across the batch (in_axes=None) and map only over the
# leading axis of the transition arrays, so no batched copy of theta is needed.
# This version is not wrapped in jax.jit.
inefficient_perex_grads = jax.vmap(dtdloss_dtheta, in_axes=(None, 0, 0, 0))

inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

In [None]:
# JIT-compile the vmapped gradient function; the first call below traces and
# compiles it.
perex_grads = jax.jit(inefficient_perex_grads)

perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t)

In [None]:
# Compare the un-jitted and jitted versions. block_until_ready() waits for the
# result, so the measurement covers the computation and not just its dispatch.
%timeit inefficient_perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()
%timeit perex_grads(theta, batched_s_tm1, batched_r_t, batched_s_t).block_until_ready()