In [1]:
import os

# Expose only GPU 3 to this process. This runs before `jax` is imported so the
# restriction is already in place when the GPU backend starts up.
os.environ.update(CUDA_VISIBLE_DEVICES='3')

In [2]:
import jax
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [3]:
# Root PRNG key, seeded with 0 so every run draws the same random numbers;
# a later cell splits it to obtain independent subkeys.
key = jax.random.key(0)

In [4]:
# Derivatives of univariate functions are pretty simple.
# Second-order derivatives of multivariate functions are given by the Hessian matrix.

In [5]:
# JAX provides two transformations for computing full Jacobian matrices: jax.jacfwd (forward-mode autodiff) and jax.jacrev (reverse-mode autodiff).

In [6]:
def hessian(f):
    """Return a function that computes the dense Hessian of `f`.

    Built forward-over-reverse: `jax.jacrev` differentiates `f` once and
    `jax.jacfwd` differentiates that result again. `jax.jacrev` is used for the
    inner derivative instead of `jax.grad` so that `f` may return an array as
    well as a scalar; for scalar-valued `f` the result is unchanged.

    Args:
        f: Function to differentiate with respect to its first argument.

    Returns:
        A function mapping `x` to the Hessian of `f` at `x`. For an array-valued
        `f` the output axes come first, followed by two axes for `x`.
    """
    return jax.jacfwd(jax.jacrev(f))

In [7]:
def f(x):
  """Return the dot product of `x` with itself."""
  self_product = jnp.dot(x, x)
  return self_product

# The Hessian of dot(x, x) is 2 * I at every point; the output below shows it for a 3-vector.
hessian(f)(jnp.array([1., 2., 3.]))

Array([[2., 0., 0.],
       [0., 2., 0.],
       [0., 0., 2.]], dtype=float32)

In [8]:
# You can use jax.lax.stop_gradient() to prevent computation of gradients through part of the computation graph.

In [9]:
def td_loss(theta, s_tm1, r_t, s_t):
  """Negative half squared difference between a frozen target and the value of `s_tm1`.

  The target `r_t + value_fn(theta, s_t)` is wrapped in
  `jax.lax.stop_gradient`, so differentiating this function with respect to
  `theta` only propagates through `value_fn(theta, s_tm1)`.

  NOTE(review): `value_fn` is not defined in the visible part of this notebook;
  it has to exist before this function is called.

  Args:
    theta: Parameters passed to `value_fn`.
    s_tm1: First state; gradients flow through its value estimate.
    r_t: Term added to the value of `s_t` to form the target.
    s_t: Second state, used only inside the frozen target.

  Returns:
    `-0.5 * (stop_gradient(target) - value_fn(theta, s_tm1)) ** 2`.
  """
  prev_value = value_fn(theta, s_tm1)
  # Treat the bootstrapped target as a constant during differentiation.
  frozen_target = jax.lax.stop_gradient(r_t + value_fn(theta, s_t))
  td_error = frozen_target - prev_value
  return -0.5 * td_error ** 2

In [10]:
# You can use stop_gradient to implement differentiation of functions that use non-differentiable inner functions.

In [12]:
# Per example gradients.
# Write a function to calculate loss per example.
# Transform it with grad
# Apply vmap to make it process batches efficiently
# Jit the entire thing and voila, you have fast per example gradient computation.

In [13]:
# Per-example gradients of `td_loss`, compiled with jit.
perex_grads = jax.jit(
    jax.vmap(
        jax.grad(td_loss),        # gradient with respect to the first argument, theta
        in_axes=(None, 0, 0, 0),  # share theta; map over axis 0 of s_tm1, r_t, s_t
    )
)

## Hessian vector product

The trick is to avoid instantiating the full Hessian matrix when calculating $\frac{\partial^2 f}{\partial x^2} \cdot v$.

jax.grad is efficient at differentiating scalar-valued functions of vector-valued arguments.

In [14]:
from jax import jacfwd, jacrev

# Define a sigmoid function.
def sigmoid(x):
    """Logistic sigmoid of `x`, evaluated through the tanh form 0.5 * (tanh(x / 2) + 1)."""
    tanh_of_half = jnp.tanh(x / 2)
    return 0.5 * (tanh_of_half + 1)

# Outputs probability of a label being true.
def predict(W, b, inputs):
    """Apply `sigmoid` to the affine score `dot(inputs, W) + b`.

    Args:
        W: Weights combined with `inputs` via `jnp.dot`.
        b: Bias added to the dot product.
        inputs: Data passed as the left operand of the dot product.

    Returns:
        The sigmoid of the affine score.
    """
    scores = jnp.dot(inputs, W) + b
    return sigmoid(scores)

# Toy dataset: four examples with three features each.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])

# Draw random model coefficients, each from its own subkey.
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, (3,))
b = random.normal(b_key, ())


# Close over `b` and `inputs` so the predictions are a function of W alone.
def f(W):
    return predict(W, b, inputs)


# Compute the Jacobian of the predictions with respect to W in both modes.
for banner, jacobian_of in (
    ("jacfwd result, with shape", jacfwd),
    ("jacrev result, with shape", jacrev),
):
    J = jacobian_of(f)(W)
    print(banner, J.shape)
    print(J)

jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140433 -0.00472538  0.00263786]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140433 -0.00472538  0.00263786]]


 jacfwd uses fwd-mode autodiff, which is more efficient for tall jacobians (more outputs than inputs)

 jacrev uses reverse mode autodiff, which is more efficient for wide jacobians (more inputs than outputs)

For matrices that are near-square, jax.jacfwd() probably has an edge over jax.jacrev().

In [15]:
def predict_dict(params, inputs):
    """Run `predict` with the weights and bias stored in a dict.

    Args:
        params: Mapping with entries 'W' and 'b'.
        inputs: Forwarded unchanged to `predict`.

    Returns:
        Whatever `predict` returns for those arguments.
    """
    weights = params['W']
    bias = params['b']
    return predict(weights, bias, inputs)

# Differentiate with respect to the whole parameter dict (the first argument);
# the result is a dict holding one Jacobian per entry.
# NOTE: `predict` ends in a sigmoid, so these are Jacobians of its outputs
# rather than of raw pre-sigmoid logits, despite the printed label.
J_dict = jacrev(predict_dict)(dict(W=W, b=b), inputs)
for name, jacobian in J_dict.items():
    print("Jacobian from {} to logits is".format(name))
    print(jacobian)

Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015914 -0.04928622  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140433 -0.00472538  0.00263786]]
Jacobian from b to logits is
[0.11503381 0.04563539 0.23439017 0.00189774]


In [16]:
# Inspect the raw result: a dict with the same keys as the params dict, each holding a Jacobian array.
J_dict

{'W': Array([[ 0.05981757,  0.12883787,  0.08857603],
        [ 0.04015914, -0.04928622,  0.00684531],
        [ 0.12188289,  0.01406341, -0.3047072 ],
        [ 0.00140433, -0.00472538,  0.00263786]], dtype=float32),
 'b': Array([0.11503381, 0.04563539, 0.23439017, 0.00189774], dtype=float32)}

In [17]:
# Using a composition of two of these functions gives us a way to compute dense Hessian matrices:

def hessian(f):
    """Return a function computing the dense Hessian of `f`, forward-over-reverse.

    `jacrev` differentiates `f` once (reverse mode) and `jacfwd` differentiates
    that result again (forward mode).
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# `f` maps W (3 entries) to 4 predictions, so the result stacks one 3x3
# Hessian per prediction, giving shape (4, 3, 3).
H = hessian(f)(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195214  0.03921399 -0.00544639]
  [ 0.03921399 -0.04812626  0.0066842 ]
  [-0.00544639  0.0066842  -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103525  0.00348348 -0.0019446 ]
  [ 0.00348348 -0.01172145  0.0065433 ]
  [-0.0019446   0.0065433  -0.00365269]]]


In [18]:
# Forward-over-reverse (jacfwd applied to jacrev) is typically the most efficient way to compute the Hessian.

## Jacobian-vector products (fwd-mode autodiff)

In [20]:
# jax.grad is built on reverse mode autodiff

In [21]:
# JAX's jvp is a way to calculate df(x) * v given a function f, a point x, and a vector v.

In [22]:
# Memory cost is independent of the depth of the computation

In [23]:
# The FLOP cost of a JVP is about 3x the cost of evaluating the original function.

In [24]:
# Builds jacobian matrices one column at a time

## Vector-jacobian products (rev-mode autodiff)

In [25]:
# Builds jacobian matrix one row at a time

In [26]:
# JAX's vjp is a way to calculate v * df(x)

In [27]:
# The FLOP cost of a VJP is about 3x the cost of evaluating the original function.

In [28]:
# With this, we can get the gradient of a scalar-output, vector-argument function in just one call. This is how jax.grad is implemented.

In [29]:
# There's a cost, though: while the FLOPs are friendly, memory scales with the depth of the computation.

In [30]:
# Hessian vector products
# forward-over-reverse
def hvp(f, primals, tangents):
  """Hessian-vector product of `f`, computed forward-over-reverse.

  Pushes `tangents` through the gradient function of `f` with a JVP, so the
  full Hessian is never materialized.

  Uses the fully qualified `jax.jvp` / `jax.grad`: the bare name `jvp` is not
  imported anywhere in this notebook and would raise NameError on first call.

  Args:
    f: Scalar-valued function to differentiate (a requirement of `jax.grad`).
    primals: Tuple holding the point at which to evaluate, e.g. `(x,)`.
    tangents: Tuple holding the vector to multiply by, e.g. `(v,)`, matching
      `primals` in structure.

  Returns:
    The tangent output of the JVP, i.e. the Hessian of `f` at the primal point
    applied to the tangent vector.
  """
  return jax.jvp(jax.grad(f), primals, tangents)[1]


In [31]:
# You can use jax.custom_jvp and jax.custom_vjp to define custom differentiation rules.

In [32]:
# Gradient clipping can be implemented as a custom grad function.