In [1]:
import numpy as np
import jax
import jax.numpy as jnp
from jax import grad, value_and_grad, jit, random, jacfwd, jacrev
import optax

In [2]:
# Root PRNG key for the notebook; later cells split it to obtain fresh subkeys.
key = jax.random.PRNGKey(42)

Computing derivatives and gradients of both basic and user-defined functions is really easy!

In [3]:
# `grad` turns a scalar-valued function into a function that evaluates its derivative.
grad_sin = grad(jnp.sin)

# sin(pi) prints as a tiny nonzero number rather than exactly 0 because of
# floating-point rounding; the derivative printed next is cos(pi) = -1.
print(f"Value {jnp.sin(jnp.pi)}")
print(f"Derivative value {grad_sin(jnp.pi)}")

Value -8.742277657347586e-08
Derivative value -1.0


Let's try differentiating the loss function of a very simple binary classifier with respect to its parameters.

In [4]:
def predict(W, b, inputs):
    """Return the logits of the labels for a batch of inputs.

    Args:
        W: weight array, contracted with `inputs` via `jnp.dot`.
        b: bias added to the product (broadcast by the usual array rules).
        inputs: batch of feature vectors.

    Returns:
        The affine map `inputs . W + b`.
    """
    linear_part = jnp.dot(inputs, W)
    return linear_part + b

# Toy dataset: five examples with three features each ...
inputs = jnp.array([
    [0.52,  1.12,  0.77],
    [0.88, -1.08,  0.15],
    [0.52,  0.06, -1.30],
    [0.74, -2.49,  1.39],
    [0.76, -0.89, -1.01],
])
# ... and one integer class label (0 or 1) per example.
targets = jnp.array([1, 1, 0, 1, 0])

def loss(params_dict):
    """Cross-entropy loss of the toy dataset, summed over all examples.

    Args:
        params_dict: mapping with the entries 'W' and 'b', which are passed
            to `predict` together with the notebook-level `inputs`.

    Returns:
        The sum of the per-example softmax cross entropies between the
        predicted logits and the notebook-level integer `targets`.
    """
    logits = predict(params_dict['W'], params_dict['b'], inputs)
    per_example = optax.softmax_cross_entropy_with_integer_labels(logits, targets)
    return jnp.sum(per_example)

In [5]:
# Draw the initial model coefficients from a standard normal distribution.
# The key is split three ways: a new root key plus one subkey per parameter.
key, W_key, b_key = random.split(key, 3)
W = random.normal(W_key, shape=(3, 2))
# NOTE(review): b is a scalar, so `predict` adds the same offset to both logits.
# Softmax cross entropy does not change under such a shift, which is consistent
# with the zero gradient for 'b' printed further down. If a per-class bias is
# intended, a shape of (2,) would be needed -- confirm before changing.
b = random.normal(b_key, shape=())

In [6]:
# Let's see the value of the parameters: the weight matrix W and the bias b
# that `predict` adds to the logits.
print(f"W: {W}")
print(f"b {b}")

W: [[-1.8231415  -0.472541  ]
 [ 0.41849834 -1.598711  ]
 [ 1.1073328   1.028033  ]]
b -1.2226542234420776


In [7]:
# Raw logits for the whole toy dataset at the randomly initialized parameters.
print(f"The logits {predict(W, b, inputs)}")

The logits [[-0.84932333 -2.4673467 ]
 [-3.1128972   0.24232256]
 [-3.5851107  -2.900741  ]
 [-2.0746472   3.837422  ]
 [-4.099111   -1.197246  ]]


In [8]:
# The value of the loss at the current parameters, passed as a dict with the keys 'W' and 'b'
print(loss({'W': W, 'b': b}))

5.8840284


In [9]:
# The gradient of the loss. `grad` differentiates with respect to the first
# argument, so the result is a dict with the same keys ('W', 'b') as the input.
print(grad(loss)({'W': W, 'b': b}))

{'W': Array([[-0.60045743,  0.60045743],
       [ 1.6953038 , -1.6953039 ],
       [ 2.472938  , -2.4729378 ]], dtype=float32), 'b': Array(0., dtype=float32)}


In [10]:
# Or we could have computed both the value and the gradient in one shot:
# `value_and_grad(loss)` returns a (value, gradient) pair from a single call.
loss_value, Wb_grad = value_and_grad(loss)({'W': W, 'b': b})
print('loss value', loss_value)
print('gradient value', Wb_grad)

loss value 5.8840284
gradient value {'W': Array([[-0.60045743,  0.60045743],
       [ 1.6953038 , -1.6953039 ],
       [ 2.472938  , -2.4729378 ]], dtype=float32), 'b': Array(0., dtype=float32)}


# Jacobians and Hessians

Remember that vectors are rows and the Jacobians are transposed compared to the math notation.

In [11]:
# Isolate the function from the weight matrix to the predictions
def fun(W):
    """Map a weight matrix to the elementwise sigmoid of the logits.

    The bias `b` and the dataset `inputs` are taken from the notebook scope,
    so the weight matrix is the only argument being differentiated.
    """
    return jax.nn.sigmoid(predict(W, b, inputs))

# So what is the value
print(f"Value {fun(W)}\n")

# Full Jacobian of `fun` at W, computed with forward-mode differentiation
J = jacfwd(fun)(W)
print("jacfwd result, with shape", J.shape)
print(J)

# The same Jacobian computed with reverse-mode differentiation;
# rebinding J replaces the forward-mode result
J = jacrev(fun)(W)
print("jacrev result, with shape", J.shape)
print(J)

Value [[0.29957482 0.07817923]
 [0.04257838 0.560286  ]
 [0.0269852  0.05211694]
 [0.11158551 0.97890544]
 [0.01631676 0.23196553]]

jacfwd result, with shape (5, 2, 3, 2)
[[[[ 0.10911146  0.        ]
   [ 0.2350093   0.        ]
   [ 0.1615689   0.        ]]

  [[ 0.          0.03747498]
   [ 0.          0.08071535]
   [ 0.          0.0554918 ]]]


 [[[ 0.03587361  0.        ]
   [-0.0440267   0.        ]
   [ 0.00611482  0.        ]]

  [[ 0.          0.21680173]
   [ 0.         -0.26607487]
   [ 0.          0.03695484]]]


 [[[ 0.01365364  0.        ]
   [ 0.00157542  0.        ]
   [-0.0341341   0.        ]]

  [[ 0.          0.02568839]
   [ 0.          0.00296405]
   [ 0.         -0.06422099]]]


 [[[ 0.0733593   0.        ]
   [-0.24684411  0.        ]
   [ 0.13779652  0.        ]]

  [[ 0.          0.01528069]
   [ 0.         -0.05141746]
   [ 0.          0.02870292]]]


 [[[ 0.01219839  0.        ]
   [-0.01428496  0.        ]
   [-0.01621102  0.        ]]

  [[ 0.          0.

In [12]:
# Timing the forward mode
%timeit -n10 jacfwd(fun)(W)

3.62 ms ± 216 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [13]:
# Timing the reverse mode
%timeit -n10 jacrev(fun)(W)

4.54 ms ± 252 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)


Composing these two transformations gives us a way to compute the Hessian.

In [14]:
def hessian(f):
    """Build a function that evaluates the Hessian of `f`.

    The reverse-mode Jacobian of `f` is differentiated once more in
    forward mode (forward-over-reverse).

    Args:
        f: function to differentiate twice with respect to its first argument.

    Returns:
        A function with the same arguments as `f` that returns the array of
        second derivatives.
    """
    first_derivative = jacrev(f)
    return jacfwd(first_derivative)

# Evaluate the Hessian of `fun` at W and print it together with its shape.
H = hessian(fun)(W)
print("hessian, with shape", H.shape)
print(H)

hessian, with shape (5, 2, 3, 2, 3, 2)
[[[[[[ 2.27434319e-02  0.00000000e+00]
     [ 4.89858575e-02  0.00000000e+00]
     [ 3.36777791e-02  0.00000000e+00]]

    [[ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]]]


   [[[ 4.89858538e-02  0.00000000e+00]
     [ 1.05508007e-01  0.00000000e+00]
     [ 7.25367591e-02  0.00000000e+00]]

    [[ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]]]


   [[[ 3.36777754e-02  0.00000000e+00]
     [ 7.25367516e-02  0.00000000e+00]
     [ 4.98690195e-02  0.00000000e+00]]

    [[ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]]]]



  [[[[ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]
     [ 0.00000000e+00  0.00000000e+00]]

    [[ 0.00000000e+00  1.64400339e-02]
     [ 0.00000000e+00  3.54093052e-02]
     [ 0.00000000e+00  2.43438967e-02]]]



This shape makes sense: if we start with a function $f : \mathbb{R}^n \to \mathbb{R}^m$, then at a point $x \in \mathbb{R}^n$ we expect to get the shapes

* $f(x) \in \mathbb{R}^m$, the value of $f$ at $x$,
* $\partial f(x) \in \mathbb{R}^{m \times n}$, the Jacobian matrix at $x$,
* $\partial^2 f(x) \in \mathbb{R}^{m \times n \times n}$, the Hessian at $x$.

# JVP and VJP

Mathematically, given a function $f : \mathbb{R}^n \to \mathbb{R}^m$, the Jacobian of $f$ evaluated at an input point $x \in \mathbb{R}^n$, denoted $\partial f(x)$, is often thought of as a matrix in $\mathbb{R}^m \times \mathbb{R}^n$:

$\qquad \partial f(x) \in \mathbb{R}^{m \times n}$.

But we can also think of $\partial f(x)$ as a linear map, which maps the tangent space of the domain of $f$ at the point $x$ (which is just another copy of $\mathbb{R}^n$) to the tangent space of the codomain of $f$ at the point $f(x)$ (a copy of $\mathbb{R}^m$):

$\qquad \partial f(x) : \mathbb{R}^n \to \mathbb{R}^m$.

This map is called the pushforward map. The Jacobian matrix is just the matrix for this linear map in a standard basis.

We call that mapping, from $(x, v)$ pairs to output tangent vectors, the *Jacobian-vector product*.

In [15]:
from jax import jvp

def f(W):
    """Logits of the toy dataset as a function of the weight matrix only.

    The bias `b` and the dataset `inputs` are taken from the notebook scope.
    """
    return predict(W, b, inputs)

# Draw a random tangent vector with the same shape as W
key, subkey = random.split(key)
v = random.normal(subkey, shape=W.shape)
print(f"Tangent vector of shape (3,2): {v}")

# Push forward the vector `v` along `f` evaluated at `W`;
# `y` receives f(W) and `u` the resulting output tangent
y, u = jvp(f, primals=(W,), tangents=(v,))

Tangent vector of shape (3,2): [[-0.5727056   0.08604095]
 [-0.3807549  -0.76770204]
 [ 0.8285016   0.6423712 ]]


In [16]:
# `y` is the primal output f(W) and `u` the pushed-forward tangent; both come
# from the `jvp` call in the previous cell and have one row per input example.
print(f"Point in the range, batch of five: {y}")
print(f"Pushforward, batch of five: {u}")

Point in the range, batch of five: [[-0.8493234  -2.4673462 ]
 [-3.1128972   0.24232256]
 [-3.5851107  -2.900741  ]
 [-2.0746472   3.8374214 ]
 [-4.0991116  -1.197246  ]]
Pushforward, batch of five: [[-0.08630621 -0.3204592 ]
 [ 0.03150962  1.00119   ]
 [-1.3977042  -0.8364033 ]
 [ 1.6758946   2.8681443 ]
 [-0.93317103  0.09985101]]


Where forward-mode gives us back a function for evaluating Jacobian-vector products, reverse-mode is a way to get back a function for evaluating vector-Jacobian products.

In [17]:
from jax import vjp

# Evaluate `f` at W and obtain the function that pulls output covectors back
y, vjp_fun = vjp(f, W)
print(f"Output shape {y.shape}")

# Sample a random covector with the same shape as the output of `f`
key, subkey = random.split(key)
u = random.normal(subkey, shape=y.shape)
print(u)

Output shape (5, 2)
[[ 1.209467   -0.31576344]
 [ 1.5915456   1.2171068 ]
 [-0.06524966  0.62372786]
 [ 1.2134361  -0.27994543]
 [-0.18560699 -0.34305805]]


In [18]:
# Pull back the covector `u` along `f` evaluated at `W`.
# `vjp_fun` returns a tuple with one entry per input of `f`, so `v` is a
# 1-tuple holding an array with the same shape as `W`.
v = vjp_fun(u)
print(f"Pullback vector {v}")

Pullback vector (Array([[ 2.7524345 ,  0.7633117 ],
       [-3.224447  , -0.62832105],
       [ 3.1289852 , -0.9140535 ]], dtype=float32),)


1. Use %timeit to compare jacfwd and jacrev for a very wide function (n >> m).
2. Use %timeit to compare jacfwd and jacrev for a very tall function (n << m).
3. Write a routine to differentiate a holomorphic, complex-valued function of a complex variable.