# Taking gradients with jax.grad

In [1]:
import jax
import jax.numpy as jnp
from jax import grad

# grad() takes a function and returns a new function that evaluates its derivative.
grad_tanh = grad(jnp.tanh)
# Derivative of tanh evaluated at x = 2.0.
print(grad_tanh(2.0))

0.070650816


Since jax.grad() operates on functions, you can apply it to its own output to differentiate as many times as you like:

In [2]:
# Second derivative of tanh at x = 2.0: grad applied to the output of grad.
print(grad(grad(jnp.tanh))(2.0))
# Third derivative of tanh at x = 2.0: one more level of nesting.
print(grad(grad(grad(jnp.tanh)))(2.0))

-0.13621868
0.25265405


JAX’s autodiff makes it easy to compute higher-order derivatives, because the functions that compute derivatives are themselves differentiable. Thus, higher-order derivatives are as easy as stacking transformations. This can be illustrated in the single-variable case:

The derivative of $f(x) = x^3 + 2x^2 - 3x + 1$ can be computed as:

In [4]:
def f(x):
  """Evaluate the cubic polynomial x**3 + 2*x**2 - 3*x + 1 at x."""
  return x**3 + 2*x**2 - 3*x + 1

# First derivative of f, obtained by transforming f with jax.grad.
dfdx = jax.grad(f)

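The higher-order derivatives of $f$ are $f'(x) = 3x^2 + 4x - 3$, $f''(x) = 6x + 4$, $f'''(x) = 6$, and $f^{(4)}(x) = 0$. Computing any of these in JAX is as easy as chaining the jax.grad() function: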

In [6]:
# Each higher-order derivative is grad applied to the previous derivative.
d2fdx = jax.grad(dfdx)
d3fdx = jax.grad(d2fdx)
d4fdx = jax.grad(d3fdx)

# Evaluate the first through fourth derivatives at x = 1.
for derivative in (dfdx, d2fdx, d3fdx, d4fdx):
  print(derivative(1.))

4.0
10.0
6.0
0.0


# Computing gradients in a linear logistic regression

The next example shows how to compute gradients with jax.grad() in a linear logistic regression model. First, the setup:

In [8]:
# Root PRNG key for this section; a fixed seed keeps the random draws reproducible.
key = jax.random.key(seed=0)

def sigmoid(x):
  """Logistic sigmoid of x, expressed through tanh as 0.5 * (tanh(x / 2) + 1)."""
  half_x = x / 2
  return 0.5 * (jnp.tanh(half_x) + 1)

def predict(W, b, inputs):
  """Return the probability of a label being true for the given inputs."""
  # Linear score (inputs . W + b), squashed to a probability by the sigmoid.
  logits = jnp.dot(inputs, W) + b
  return sigmoid(logits)

# Toy dataset: four examples (rows), each with three features (columns).
inputs = jnp.array([
    [0.52,  1.12,  0.77],
    [0.88, -1.08,  0.15],
    [0.52,  0.06, -1.30],
    [0.74, -2.49,  1.39],
])
# One boolean label per example.
targets = jnp.array([True, True, False, True])

def loss(W, b):
  """Training loss: negative log-likelihood of the training examples.

  Reads the module-level `inputs` and `targets`.
  """
  p_true = predict(W, b, inputs)
  p_false = 1 - p_true
  # For each example, keep the probability assigned to its observed label.
  label_probs = p_true * targets + p_false * (1 - targets)
  return -jnp.log(label_probs).sum()

# Split the key three ways: a fresh `key` for later use plus one subkey per parameter.
key, W_key, b_key = jax.random.split(key, num=3)
# Random initial coefficients: a length-3 weight vector and a scalar bias.
W = jax.random.normal(W_key, shape=(3,))
b = jax.random.normal(b_key, shape=())

Use the jax.grad() function with its argnums argument to differentiate a function with respect to positional arguments.

In [9]:
# Differentiate `loss` with respect to the first positional argument (W):
W_grad = grad(loss, argnums=0)(W, b)
print(f'{W_grad=}')

# Since argnums=0 is the default, this does the same thing:
W_grad = grad(loss)(W, b)
print(f'{W_grad=}')

# But you can choose different values too, and drop the keyword.
# Here argnums=1 selects the second positional argument (b):
b_grad = grad(loss, 1)(W, b)
print(f'{b_grad=}')

# A tuple of indices returns a tuple of gradients, one per selected argument:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
print(f'{W_grad=}')
print(f'{b_grad=}')

W_grad=Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32)
W_grad=Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32)
b_grad=Array(-0.69001764, dtype=float32)
W_grad=Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32)
b_grad=Array(-0.69001764, dtype=float32)


# Differentiating with respect to nested lists, tuples, and dicts

Due to JAX’s PyTree abstraction (see Working with pytrees), differentiating with respect to standard Python containers just works, so use tuples, lists, and dicts (and arbitrary nesting) however you like.

Continuing the previous example:

In [10]:
def loss2(params_dict):
  """Negative log-likelihood with the parameters supplied as a dict.

  `params_dict` must provide the weights under 'W' and the bias under 'b'.
  Reads the module-level `inputs` and `targets`.
  """
  weights, bias = params_dict['W'], params_dict['b']
  p_true = predict(weights, bias, inputs)
  p_false = 1 - p_true
  label_probs = p_true * targets + p_false * (1 - targets)
  return -jnp.log(label_probs).sum()

# The gradient comes back as a dict with the same keys as the input.
params = {'W': W, 'b': b}
print(grad(loss2)(params))

{'W': Array([-0.43314594, -0.7354604 , -1.2598921 ], dtype=float32), 'b': Array(-0.69001764, dtype=float32)}


You can create custom pytree nodes so that your own container types work not just with jax.grad() but also with other JAX transformations (jax.jit(), jax.vmap(), and so on).

# Evaluating a function and its gradient using jax.value_and_grad

Another convenient function is jax.value_and_grad(), which efficiently computes both a function’s value and its gradient in one pass.

Continuing the previous examples:

In [11]:
# One call returns the scalar loss together with the gradients for W and b.
loss_and_grads = jax.value_and_grad(loss, argnums=(0, 1))
loss_value, Wb_grad = loss_and_grads(W, b)
print('loss value', loss_value)
# Calling `loss` directly for comparison with the value above.
print('loss value', loss(W, b))

loss value 2.9729187
loss value 2.9729187


# Checking against numerical differences

A great thing about derivatives is that they’re straightforward to check with finite differences.

Continuing the previous examples:

In [12]:
# Step size for the central finite-difference estimates.
eps = 1e-4
half_step = eps / 2.

# Scalar check of b_grad: perturb b by half a step in each direction.
loss_b_plus = loss(W, b + half_step)
loss_b_minus = loss(W, b - half_step)
b_grad_numerical = (loss_b_plus - loss_b_minus) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check of W_grad: compare directional derivatives along a random unit vector.
key, subkey = jax.random.split(key)
vec = jax.random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
offset = half_step * unitvec
W_grad_numerical = (loss(W + offset, b) - loss(W - offset, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))

b_grad_numerical -0.6890297
b_grad_autodiff -0.69001764
W_dirderiv_numerical 1.3041496
W_dirderiv_autodiff 1.3006743


JAX provides a simple convenience function that does essentially the same thing, but checks up to any order of differentiation that you like:

In [14]:
from jax.test_util import check_grads

# Check the derivatives of `loss` at the point (W, b), for every order up to `order`.
check_grads(loss, (W, b), order=2)  # check up to 2nd order derivatives