In [None]:
# Gradients: grad(f)(x)
# Hessian-vector products via grad-of-grad: grad(lambda x: jnp.vdot(grad(f)(x), v))(x)
# Full Jacobian matrices via the jacfwd and jacrev functions
# Vector-Jacobian products (VJPs) and Jacobian-vector products (JVPs)

In [1]:
# Core JAX APIs used throughout: jax.numpy, the grad/jit/vmap transformations,
# and the explicit-key random module.
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

# Root PRNG key (seed 0). Later cells mostly derive fresh subkeys from it with
# random.split instead of consuming it directly.
key = random.PRNGKey(0)

No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


In [2]:
# Apply `grad` repeatedly to obtain the 1st, 2nd and 3rd derivatives of tanh,
# each evaluated at x = 2.0.
grad_tanh = grad(jnp.tanh)
nth_derivative = grad_tanh
for order in range(1, 4):
    print(nth_derivative(2.0))
    nth_derivative = grad(nth_derivative)

0.070650816
-0.13621868
0.25265405


In [3]:
def sigmoid(x):
    """Logistic sigmoid expressed through tanh: 0.5 * (tanh(x / 2) + 1)."""
    half_tanh = jnp.tanh(x / 2)
    return 0.5 * (half_tanh + 1)


# Outputs probability of a label being true.
def predict(W, b, inputs):
    """Probability of a true label for each row of `inputs` under a logistic model.

    In this notebook `inputs` is a (num_examples, num_features) matrix, `W` a
    weight vector with one entry per feature and `b` a scalar bias.
    """
    logits = jnp.dot(inputs, W) + b
    return sigmoid(logits)


# Build a toy dataset: 4 examples with 3 features each, and one boolean label
# per example.
inputs = jnp.array([
    [0.52, 1.12, 0.77],
    [0.88, -1.08, 0.15],
    [0.52, 0.06, -1.30],
    [0.74, -2.49, 1.39],
])
targets = jnp.array([True, True, False, True])


# Training loss is the negative log-likelihood of the training examples.
def loss(W, b):
    """Negative log-likelihood of the module-level boolean `targets`.

    The probabilities come from `predict(W, b, inputs)`, using the module-level
    `inputs`.
    """
    probs = predict(W, b, inputs)
    # Probability assigned to each observed label: p where the target is True,
    # 1 - p where it is False (the booleans act as 1/0 in the arithmetic).
    p_true_term = probs * targets
    p_false_term = (1 - probs) * (1 - targets)
    log_likelihood = jnp.log(p_true_term + p_false_term)
    return -jnp.sum(log_likelihood)


# Initialize random model coefficients: split off one subkey per parameter and
# keep the remaining key for later cells.
key, W_key, b_key = random.split(key, num=3)
W = random.normal(W_key, shape=(3,))
b = random.normal(b_key, shape=())

In [4]:
# Differentiate `loss` with respect to the first positional argument (W):
W_grad = grad(loss, argnums=0)(W, b)
print('W_grad', W_grad)

# argnums=0 is the default, so omitting it yields the same gradient:
W_grad = grad(loss)(W, b)
print('W_grad', W_grad)

# An int passed positionally selects another argument, here b:
b_grad = grad(loss, 1)(W, b)
print('b_grad', b_grad)

# A tuple of indices returns one gradient per selected argument:
W_grad, b_grad = grad(loss, (0, 1))(W, b)
for name, value in (('W_grad', W_grad), ('b_grad', b_grad)):
    print(name, value)

W_grad [-0.16965583 -0.8774644  -1.4901346 ]
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245
W_grad [-0.16965583 -0.8774644  -1.4901346 ]
b_grad -0.29227245


In [5]:
def loss2(params_dict):
    """Same negative log-likelihood as `loss`, taking the parameters as a dict pytree.

    `params_dict` must contain the keys 'W' and 'b'. The function delegates to
    `loss` rather than duplicating its body, so the two cannot drift apart.
    `grad(loss2)` returns a dict with the same keys.
    """
    return loss(params_dict['W'], params_dict['b'])


print(grad(loss2)({'W': W, 'b': b}))

{'W': Array([-0.16965583, -0.8774644 , -1.4901346 ], dtype=float32), 'b': Array(-0.29227245, dtype=float32)}


In [6]:
from jax import value_and_grad

# value_and_grad returns (loss(W, b), (dloss/dW, dloss/db)) from one evaluation.
loss_and_grads = value_and_grad(loss, (0, 1))
loss_value, Wb_grad = loss_and_grads(W, b)
# The returned value should match a direct call to `loss`.
print('loss value', loss_value)
print('loss value', loss(W, b))

loss value 3.0519385
loss value 3.0519385


In [7]:
# Step size for the central finite-difference checks below. In float32 this
# step only reproduces the autodiff values to roughly 1e-3 here.
eps = 1e-4
half_step = eps / 2.

# Scalar check of the b gradient.
b_grad_numerical = (loss(W, b + half_step) - loss(W, b - half_step)) / eps
print('b_grad_numerical', b_grad_numerical)
print('b_grad_autodiff', grad(loss, 1)(W, b))

# Check the W gradient through a directional derivative along a random unit vector.
key, subkey = random.split(key)
vec = random.normal(subkey, W.shape)
unitvec = vec / jnp.sqrt(jnp.vdot(vec, vec))
W_step = half_step * unitvec
W_grad_numerical = (loss(W + W_step, b) - loss(W - W_step, b)) / eps
print('W_dirderiv_numerical', W_grad_numerical)
print('W_dirderiv_autodiff', jnp.vdot(grad(loss)(W, b), unitvec))

b_grad_numerical -0.29325485
b_grad_autodiff -0.29227245
W_dirderiv_numerical -0.2002716
W_dirderiv_autodiff -0.19909117


In [8]:
from jax.test_util import check_grads

# Compare JAX's derivatives of `loss` (forward and reverse mode by default)
# against numerical differences, up to 2nd order; check_grads raises if they
# disagree beyond its tolerances.
loss_args = (W, b)
check_grads(loss, loss_args, order=2)

In [9]:
def hvp(f, x, v):
    """Hessian-vector product H(x) @ v as the gradient of <grad f(x), v>.

    This is reverse-over-reverse differentiation. A later cell redefines `hvp`
    with an (f, primals, tangents) signature.
    """
    def grad_dot_v(point):
        return jnp.vdot(grad(f)(point), v)

    return grad(grad_dot_v)(x)

In [10]:
from jax import jacfwd, jacrev

# Isolate the function from the weight matrix to the predictions
def f(W):
    return predict(W, b, inputs)


# Forward- and reverse-mode Jacobians of f at W should agree up to float rounding.
for label, jacobian_fn in (("jacfwd result, with shape", jacfwd),
                           ("jacrev result, with shape", jacrev)):
    J = jacobian_fn(f)(W)
    print(label, J.shape)
    print(J)

jacfwd result, with shape (4, 3)
[[ 0.05981758  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188288  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
jacrev result, with shape (4, 3)
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]


In [11]:
def predict_dict(params, inputs):
    """`predict` with the parameters packed in a dict that has keys 'W' and 'b'."""
    W_param, b_param = params['W'], params['b']
    return predict(W_param, b_param, inputs)


# jacrev w.r.t. a dict of parameters returns a dict of Jacobians with the same keys.
J_dict = jacrev(predict_dict)({'W': W, 'b': b}, inputs)
# `predict` returns sigmoid probabilities, not logits, so label the Jacobians
# as being taken with respect to the predictions.
for name, jac in J_dict.items():
    print("Jacobian from {} to predictions is".format(name))
    print(jac)

Jacobian from W to logits is
[[ 0.05981757  0.12883787  0.08857603]
 [ 0.04015916 -0.04928625  0.00684531]
 [ 0.12188289  0.01406341 -0.3047072 ]
 [ 0.00140431 -0.00472531  0.00263782]]
Jacobian from b to logits is
[0.11503381 0.04563541 0.23439017 0.00189771]


In [12]:
def hessian(f):
    """Hessian of `f`: the forward-mode Jacobian of its reverse-mode Jacobian."""
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)


hessian_of_f = hessian(f)
H = hessian_of_f(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (4, 3, 3)
[[[ 0.02285465  0.04922541  0.03384247]
  [ 0.04922541  0.10602397  0.07289147]
  [ 0.03384247  0.07289147  0.05011288]]

 [[-0.03195215  0.03921401 -0.00544639]
  [ 0.03921401 -0.04812629  0.00668421]
  [-0.00544639  0.00668421 -0.00092836]]

 [[-0.01583708 -0.00182736  0.03959271]
  [-0.00182736 -0.00021085  0.00456839]
  [ 0.03959271  0.00456839 -0.09898177]]

 [[-0.00103524  0.00348343 -0.00194457]
  [ 0.00348343 -0.01172127  0.0065432 ]
  [-0.00194457  0.0065432  -0.00365263]]]


In [14]:
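# Jacobian-Vector products (JVPs, aka forward-mode autodiff)
# A JVP computes the product of a function's Jacobian with a vector without
# explicitly building the whole Jacobian. This is useful when the Jacobian is
# large or expensive to compute. JVPs are evaluated with forward-mode automatic
# differentiation.
# Given a function f, an input point x and a tangent vector v, jvp(f, (x,), (v,))
# returns a pair: the value of f at x, and the JVP of f at x applied to v.
# (x, v) \mapsto (f(x), \partial f(x) v)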
from jax import jvp

# Isolate the function from the weight matrix to the predictions
def f(W):
    return predict(W, b, inputs)


key, subkey = random.split(key)
# Tangent direction with the same shape as W.
v = random.normal(subkey, W.shape)

# Push forward the vector `v` along `f` evaluated at `W`:
# y = f(W) and u = (Jacobian of f at W) @ v.
primals_in = (W,)
tangents_in = (v,)
y, u = jvp(f, primals_in, tangents_in)

In [15]:
# Vector-Jacobian products (VJPs, aka reverse-mode autodiff)
from jax import vjp

# Isolate the function from the weight matrix to the predictions
def f(W):
    return predict(W, b, inputs)


y, vjp_fun = vjp(f, W)

key, subkey = random.split(key)
u = random.normal(subkey, y.shape)

# Pull back the covector `u` along `f` evaluated at `W`. The pullback returns one
# cotangent per primal argument (a 1-tuple here), so unpack it. `v` is then an
# array shaped like W, the reverse-mode counterpart of the JVP cell's `v`.
v, = vjp_fun(u)

In [16]:
from jax import vjp


def vgrad(f, x):
    """Gradient-like helper built from `vjp`: pull back a cotangent of all ones.

    For a real scalar-valued `f` this equals `grad(f)(x)`. For a real
    array-valued `f` it is the gradient of `sum(f(x))` with respect to `x`.

    The cotangent is built with `jnp.ones_like(y)` so it has exactly the dtype
    of the output `y` (e.g. float16 or complex). A default-dtype
    `jnp.ones(y.shape)` is rejected by the pullback whenever that dtype differs
    from `y.dtype`.
    """
    y, vjp_fn = vjp(f, x)
    return vjp_fn(jnp.ones_like(y))[0]


print(vgrad(lambda x: 3 * x**2, jnp.ones((2, 2))))


[[6. 6.]
 [6. 6.]]


In [19]:
from jax import jvp, grad

# forward-over-reverse
def hvp(f, primals, tangents):
    """Hessian-vector product: forward-mode JVP of the reverse-mode gradient of `f`.

    This redefines the earlier `hvp(f, x, v)` with a (primals, tangents)
    signature matching `jax.jvp`. Only the tangent output of the JVP is returned.
    """
    _, hessian_times_tangent = jvp(grad(f), primals, tangents)
    return hessian_times_tangent

def f(X):
    """Scalar test function: the sum of tanh(X)**2 over all elements of X."""
    squashed = jnp.tanh(X)
    return jnp.sum(squashed**2)


key, X_key, V_key = random.split(key, num=3)
X = random.normal(X_key, (30, 40))
V = random.normal(V_key, (30, 40))

# The matrix-free HVP should agree with contracting the fully materialized
# Hessian (shape X.shape + X.shape) against V over its last two axes.
ans1 = hvp(f, (X,), (V,))
ans2 = jnp.tensordot(hessian(f)(X), V, axes=2)

print(jnp.allclose(ans1, ans2, rtol=1e-4, atol=1e-4))

True


In [21]:
# reverse-over-forward
def hvp_revfwd(f, primals, tangents):
    """Hessian-vector product: reverse-mode gradient of the forward-mode directional derivative.

    Returns a tuple with one entry per primal, because `grad` is taken with
    respect to the whole `primals` tuple.
    """
    def directional_derivative(point):
        _, tangent_out = jvp(f, point, tangents)
        return tangent_out

    return grad(directional_derivative)(primals)


In [22]:
# reverse-over-reverse, only works for single arguments
def hvp_revrev(f, primals, tangents):
    """Hessian-vector product by reverse-differentiating <grad f(x), v>.

    `primals` and `tangents` must each be a 1-tuple (a single argument).
    """
    (x,) = primals
    (v,) = tangents

    def grad_dot_v(point):
        return jnp.vdot(grad(f)(point), v)

    return grad(grad_dot_v)(x)


print("Forward over reverse")
%timeit -n10 -r3 hvp(f, (X,), (V,))
print("Reverse over forward")
%timeit -n10 -r3 hvp_revfwd(f, (X,), (V,))
print("Reverse over reverse")
%timeit -n10 -r3 hvp_revrev(f, (X,), (V,))

print("Naive full Hessian materialization")
%timeit -n10 -r3 jnp.tensordot(hessian(f)(X), V, 2)

Forward over reverse
4.43 ms ± 536 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over forward
8.43 ms ± 5.04 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Reverse over reverse
10.5 ms ± 5.38 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)
Naive full Hessian materialization
49.5 ms ± 3.47 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [23]:
# Jacobian-Matrix and Matrix-Jacobian products

# Isolate the function from the weight matrix to the predictions. This rebinds
# `f`, which an earlier cell pointed at the tanh test function.
def f(W):
    return predict(W, b, inputs)

# Pull back the covectors `m_i` along `f`, evaluated at `W`, for all `i`.
# First, use a list comprehension to loop over rows in the matrix M.
def loop_mjp(f, x, M):
    """Matrix-Jacobian product computed one row of `M` at a time.

    Each row `mi` (a covector shaped like f(x)) is pulled back through `f` at `x`.
    The resulting cotangents are stacked, one row per covector.
    """
    _, vjp_fun = vjp(f, x)
    # The pullback returns a 1-tuple (one cotangent per primal argument). Take
    # the array itself instead of relying on vstack to coerce the tuple.
    return jnp.vstack([vjp_fun(mi)[0] for mi in M])

# Now, use vmap to build a single batched computation rather than an outer
# Python loop over vector-Jacobian products.
def vmap_mjp(f, x, M):
    """Matrix-Jacobian product as one pullback batched over the rows of `M`."""
    pullback = vjp(f, x)[1]
    (batched_cotangents,) = vmap(pullback)(M)
    return batched_cotangents

key = random.PRNGKey(0)
num_covecs = 128
U = random.normal(key, (num_covecs,) + y.shape)

loop_vs = loop_mjp(f, W, M=U)
print('Non-vmapped Matrix-Jacobian product')
%timeit -n10 -r3 loop_mjp(f, W, M=U)

print('\nVmapped Matrix-Jacobian product')
vmap_vs = vmap_mjp(f, W, M=U)
%timeit -n10 -r3 vmap_mjp(f, W, M=U)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Matrix-Jacobian Products should be identical'

Non-vmapped Matrix-Jacobian product
142 ms ± 1.24 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Matrix-Jacobian product
4.98 ms ± 49.3 µs per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [24]:
def loop_jmp(f, W, M):
    """Jacobian-matrix product via one JVP per row of `M`, stacked with vstack."""
    tangent_outputs = []
    for mi in M:
        # jvp returns (primal_out, tangent_out); only the tangent is needed.
        _, tangent_out = jvp(f, (W,), (mi,))
        tangent_outputs.append(tangent_out)
    return jnp.vstack(tangent_outputs)

def vmap_jmp(f, W, M):
    """Jacobian-matrix product: JVPs of `f` at `W` along every row of `M`, batched with vmap."""
    def push_forward(tangent):
        return jvp(f, (W,), (tangent,))[1]

    return vmap(push_forward)(M)

num_vecs = 128
S = random.normal(key, (num_vecs,) + W.shape)

loop_vs = loop_jmp(f, W, M=S)
print('Non-vmapped Jacobian-Matrix product')
%timeit -n10 -r3 loop_jmp(f, W, M=S)
vmap_vs = vmap_jmp(f, W, M=S)
print('\nVmapped Jacobian-Matrix product')
%timeit -n10 -r3 vmap_jmp(f, W, M=S)

assert jnp.allclose(loop_vs, vmap_vs), 'Vmap and non-vmapped Jacobian-Matrix products should be identical'

Non-vmapped Jacobian-Matrix product
448 ms ± 37.3 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)

Vmapped Jacobian-Matrix product
7.79 ms ± 1.06 ms per loop (mean ± std. dev. of 3 runs, 10 loops each)


In [25]:
from jax import jacrev as builtin_jacrev

def our_jacrev(f):
    """Reverse-mode Jacobian built from `vjp` + `vmap` (a minimal `jax.jacrev`).

    The returned function pulls back every standard-basis covector of the output
    space in one batched VJP. The result has shape `f(x).shape + x.shape`,
    matching `jax.jacrev` for a single array argument.

    The basis is built from `y.size` and reshaped to `y.shape`, so scalar and
    multi-dimensional outputs work too; `len(y)` fails on 0-d outputs and its
    identity would not match a multi-dimensional y. The basis also uses
    `y.dtype`, so the cotangents match the output dtype as `vjp` requires.
    """

    def jacfun(x):
        y, vjp_fun = vjp(f, x)
        # One standard-basis covector per output element, each shaped like y.
        basis = jnp.eye(y.size, dtype=y.dtype).reshape((y.size,) + y.shape)
        # Batched pullback: row i is the gradient of the i-th (flattened) output.
        J, = vmap(vjp_fun, in_axes=0)(basis)
        return J.reshape(y.shape + jnp.shape(x))

    return jacfun


reference_J = builtin_jacrev(f)(W)
candidate_J = our_jacrev(f)(W)
assert jnp.allclose(reference_J, candidate_J), 'Incorrect reverse-mode Jacobian results!'

In [26]:
from jax import jacfwd as builtin_jacfwd


def our_jacfwd(f):
    """Forward-mode Jacobian built from `jvp` + `vmap` (a minimal `jax.jacfwd`).

    The returned function pushes every standard-basis tangent of the input space
    through `f` in one batched JVP. The Jacobian has shape `f(x).shape + x.shape`,
    like `jax.jacfwd` for a single array argument.

    The basis is built from `x.size` with `x.dtype` and reshaped to `x.shape`,
    and the batch axis is moved last before unflattening. This handles scalar
    and multi-dimensional inputs and outputs, which `jnp.eye(len(x))` followed
    by a full transpose cannot.
    """

    def jacfun(x):
        x = jnp.asarray(x)
        _jvp = lambda s: jvp(f, (x, ), (s, ))[1]
        # One standard-basis tangent per input element, each shaped like x.
        basis = jnp.eye(x.size, dtype=x.dtype).reshape((x.size,) + x.shape)
        # Jt[i] is the derivative along the i-th (flattened) input element:
        # shape (x.size,) + f(x).shape.
        Jt = vmap(_jvp)(basis)
        # Move the input axis last, then unflatten it back to x.shape.
        J = jnp.moveaxis(Jt, 0, -1)
        return J.reshape(Jt.shape[1:] + x.shape)

    return jacfun


reference_J = builtin_jacfwd(f)(W)
candidate_J = our_jacfwd(f)(W)
assert jnp.allclose(reference_J, candidate_J), 'Incorrect forward-mode Jacobian results!'

In [27]:
def f(x):
    """Piecewise function written with ordinary Python control flow.

    Returns 2 * x**3 when x < 3. Otherwise it deliberately raises ValueError to
    jump into the except branch, which returns pi * x. The vjp(f, 4.) call below
    shows this Python-level branching still differentiates when `x` carries a
    concrete value.

    NOTE(review): the except clause would also catch a ValueError raised by the
    `if x < 3` test itself (presumably, e.g., the ambiguous truth value of a
    multi-element array), sending such inputs to the pi * x branch. Confirm
    this is intended.
    """
    try:
        if x < 3:
            return 2 * x**3
        else:
            raise ValueError
    except ValueError:
        return jnp.pi * x


# Linearize f at x = 4.0 (the pi * x branch), then jit-compile the pullback.
y, f_vjp = vjp(f, 4.0)
jitted_pullback = jit(f_vjp)
print(jitted_pullback(1.0))

(Array(3.1415927, dtype=float32, weak_type=True),)


In [28]:
# Complex numbers and differentiation
def f(z):
    """A C -> C function written as f(x + iy) = u(x, y) + i v(x, y).

    NOTE(review): `u` and `v` are looked up as globals at call time and are meant
    to be real-valued functions of (x, y). No such functions are defined at
    module level here; the globals named `u` and `v` from earlier cells are
    arrays/tuples. Treat this definition as illustrative only.
    """
    real_part = jnp.real(z)
    imag_part = jnp.imag(z)
    return u(real_part, imag_part) + v(real_part, imag_part) * 1j


def g(x, y):
    """The same map viewed as R^2 -> R^2: returns the pair (u(x, y), v(x, y)).

    NOTE(review): relies on globals `u` and `v` being real-valued functions of
    (x, y), which the notebook does not define at module level.
    """
    first = u(x, y)
    second = v(x, y)
    return first, second


In [30]:
def check(seed):
    """Numerically check JAX's JVP convention for a C -> C function and print the result.

    Builds f(x + iy) = u(x, y) + i v(x, y) with random linear u, v. It then
    checks that jvp(f)(z)(dx + i dy) == u_x dx + u_y dy + i (v_x dx + v_y dy),
    i.e. the ordinary R^2 -> R^2 Jacobian acting on (dx, dy).
    """
    key = random.PRNGKey(seed)

    # random coeffs for u and v
    key, subkey = random.split(key)
    a, b, c, d = random.uniform(subkey, (4, ))

    def fun(z):
        x, y = jnp.real(z), jnp.imag(z)
        return u(x, y) + v(x, y) * 1j

    def u(x, y):
        return a * x + b * y

    def v(x, y):
        return c * x + d * y

    # primal point
    key, subkey = random.split(key)
    x, y = random.uniform(subkey, (2, ))
    z = x + y * 1j

    # Tangent vector. Use distinct names: reassigning c, d here would silently
    # change the coefficients that v() closes over.
    key, subkey = random.split(key)
    dx, dy = random.uniform(subkey, (2, ))
    z_dot = dx + dy * 1j

    # check jvp
    _, ans = jvp(fun, (z, ), (z_dot, ))
    expected = (grad(u, 0)(x, y) * dx + grad(u, 1)(x, y) * dy +
                grad(v, 0)(x, y) * dx * 1j + grad(v, 1)(x, y) * dy * 1j)
    print(jnp.allclose(ans, expected))


for seed in (0, 1, 2):
    check(seed)

True
True
True


In [31]:
def check(seed):
    """Numerically check JAX's VJP convention for a C -> C function.

    With f(x + iy) = u(x, y) + i v(x, y) (random linear u, v) and cotangent
    w = cr + i ci, asserts that the pullback of w equals
    (u_x cr - v_x ci) - i (u_y cr - v_y ci). Raises AssertionError otherwise.
    """
    key = random.PRNGKey(seed)

    # random coeffs for u and v
    key, subkey = random.split(key)
    a, b, c, d = random.uniform(subkey, (4, ))

    def fun(z):
        x, y = jnp.real(z), jnp.imag(z)
        return u(x, y) + v(x, y) * 1j

    def u(x, y):
        return a * x + b * y

    def v(x, y):
        return c * x + d * y

    # primal point
    key, subkey = random.split(key)
    x, y = random.uniform(subkey, (2, ))
    z = x + y * 1j

    # Cotangent vector. Use distinct names: reassigning c, d here would silently
    # change the coefficients that v() closes over.
    key, subkey = random.split(key)
    cr, ci = random.uniform(subkey, (2, ))
    z_bar = jnp.array(cr + ci * 1j)  # for dtype control

    # check vjp
    _, fun_vjp = vjp(fun, z)
    ans, = fun_vjp(z_bar)
    expected = (grad(u, 0)(x, y) * cr + grad(v, 0)(x, y) * (-ci) +
                grad(u, 1)(x, y) * cr * (-1j) + grad(v, 1)(x, y) * (-ci) * (-1j))
    assert jnp.allclose(ans, expected, atol=1e-5, rtol=1e-5)


for seed in (0, 1, 2):
    check(seed)

In [32]:
def f(z):
    """Real-valued function of a complex input: Re(z)**2 + Im(z)**2."""
    re_z = jnp.real(z)
    im_z = jnp.imag(z)
    return re_z**2 + im_z**2


# For a real-valued f of complex z, grad returns the conjugate of
# (df/dx + i df/dy): here 6 - 8j at z = 3 + 4j.
z = 3. + 4j
grad(f)(z)

Array(6.-8.j, dtype=complex64)

In [33]:
# Same computation as the previous cell: the gradient of |z|**2 at z = 3 + 4i.
def f(z):
    """|z|**2 expressed through the real and imaginary parts of complex z."""
    parts = (jnp.real(z), jnp.imag(z))
    return parts[0]**2 + parts[1]**2


z = 3. + 4j
grad(f)(z)

Array(6.-8.j, dtype=complex64)

In [34]:
def f(z):
    """Holomorphic C -> C function: sin(z)."""
    sin_z = jnp.sin(z)
    return sin_z


# For holomorphic f, holomorphic=True makes grad return f'(z), here cos(3 + 4i).
z = 3. + 4j
grad(f, holomorphic=True)(z)

Array(-27.034946-3.8511534j, dtype=complex64, weak_type=True)

In [35]:
def f(z):
    """Complex conjugate of z: not holomorphic, so the holomorphic=True grad below is meaningless."""
    conj_z = jnp.conjugate(z)
    return conj_z


z = 3. + 4j
# holomorphic=True only promises the user that f is holomorphic; JAX does not verify it.
grad(f, holomorphic=True)(z)  # f is not actually holomorphic!

Array(1.-0.j, dtype=complex64, weak_type=True)

In [36]:
# A complex Hermitian matrix (A[j, i] == conj(A[i, j])).
A = jnp.array([
    [5., 2. + 3j, 5j],
    [2. - 3j, 7., 1. + 7j],
    [-5j, 1. - 7j, 12.],
])


def f(X):
    """sum((L - sin(L))**2), where L is the Cholesky factor of X."""
    chol = jnp.linalg.cholesky(X)
    residual = chol - jnp.sin(chol)
    return jnp.sum(residual**2)


grad(f, holomorphic=True)(A)

Array([[-0.7534186  +0.j      , -3.0509028 -10.940545j,
         5.9896846  +3.542303j],
       [-3.0509028 +10.940545j, -8.904491   +0.j      ,
        -5.1351523  -6.559373j],
       [ 5.9896846  -3.542303j, -5.1351523  +6.559373j,
         0.01320427 +0.j      ]], dtype=complex64)