In [3]:
import jax
import numpy as np
import jax.numpy as jnp
from jax import random, jit, lax, vmap, pmap, grad, value_and_grad

In [4]:
rng = np.random.default_rng()
key = jax.random.PRNGKey(0)

## Gradients

The `grad` transform takes a function and returns a new function that computes its gradient at any point. The original function must be scalar-valued, i.e., it can take multi-dimensional tensors as input, but it must output a scalar.

By default, if $f$ is a function of multiple arguments, `grad` calculates the gradient w.r.t. the first argument:
$$
f(x, y) \\
grad(f) = \frac{\partial f}{\partial x}
$$

But I can change this with the `argnums` parameter, which specifies the index (or a tuple of indices) of the arguments I want to calculate the gradients w.r.t.
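To make the effect of `argnums` concrete, here is a minimal sketch on a toy two-argument function. The function `f` and the names `df_dx`, `df_dy`, and `df_dxy` are made up for illustration only.

```python
import jax.numpy as jnp
from jax import grad

def f(x, y):
    # Toy scalar-valued function of two scalar arguments (illustrative only).
    return x**2 * y + jnp.sin(y)

# Default: differentiate w.r.t. the first argument (argnums=0).
df_dx = grad(f)
# argnums=1 differentiates w.r.t. the second argument instead.
df_dy = grad(f, argnums=1)
# A tuple of indices returns a tuple of gradients, one per index.
df_dxy = grad(f, argnums=(0, 1))

gx = df_dx(3.0, 2.0)    # analytically 2*x*y
gy = df_dy(3.0, 2.0)    # analytically x**2 + cos(y)
gxy = df_dxy(3.0, 2.0)  # tuple (df/dx, df/dy)
```

Passing a tuple is what the perceptron cells in this notebook do to get the gradients for `w` and `b` in one call.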

Applying `grad` multiple times gives me the second-, third-, and higher-order derivatives.

$$
grad(grad(f)) = \frac{\partial^2 f}{\partial x^2}
$$
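A small sketch of nesting `grad` on a scalar-input, scalar-output function, where repeated application is well defined. The polynomial `cubic` is an arbitrary example chosen so the derivatives are easy to check by hand.

```python
from jax import grad

def cubic(x):
    # Toy polynomial: x^3 + 2x^2 (illustrative only).
    return x**3 + 2.0 * x**2

d1 = grad(cubic)  # first derivative:  3x^2 + 4x
d2 = grad(d1)     # second derivative: 6x + 4
d3 = grad(d2)     # third derivative:  6

first = d1(2.0)
second = d2(2.0)
third = d3(2.0)
```

Each call to `grad` returns an ordinary Python function, which is why it can be fed straight back into `grad`.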

Of course this will only work if the gradient itself is a scalar. E.g., consider this vector input scalar output function -

$$
f(\mathbf w, b, \mathbf x) = \mathbf w^T \mathbf x + b \\
\frac{\partial f}{\partial \mathbf w} = \begin{bmatrix}
\frac{\partial f}{\partial w_1} \\
\frac{\partial f}{\partial w_2}
\end{bmatrix}
$$

Applying `grad` a second time will not work in this case, because the first gradient is vector-valued. To calculate the Hessian I'll have to use `jacfwd` or `jacrev`, which compute the Jacobian of a vector-valued function. Mathematically both do the same thing, but the implementations differ: `jacfwd` uses forward-mode automatic differentiation and `jacrev` uses reverse-mode.
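Here is a minimal sketch of computing a Hessian by composing `jacfwd` with `jacrev`. The perceptron is linear in `w`, so its Hessian would be all zeros; I am using a made-up nonlinear function `quad` instead so the result is non-trivial.

```python
import jax.numpy as jnp
from jax import grad, jacfwd, jacrev

def quad(w):
    # Toy scalar-valued function of a 2-vector: w1^2 * w2 + w2^3 (illustrative only).
    return w[0]**2 * w[1] + w[1]**3

w0 = jnp.array([1.0, 2.0])

# Gradient: [2*w1*w2, w1^2 + 3*w2^2], a vector of shape (2,).
g = grad(quad)(w0)

# Hessian: Jacobian of the gradient, shape (2, 2).
# jacrev computes the gradient (reverse-mode), jacfwd differentiates it (forward-mode).
H = jacfwd(jacrev(quad))(w0)
```

JAX also ships `jax.hessian`, which wraps this same `jacfwd(jacrev(f))` composition.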

In [5]:
def perceptron(w, b, x):
    return w.T @ x + b

In [6]:
perceptron_ = grad(perceptron, argnums=(0, 1))

In [7]:
w = jnp.array([1., 2.])
b = 0.5
x = jnp.array([1.5, 2.5])

In [8]:
perceptron(w, b, x)

Array(7., dtype=float32)

In [9]:
dw, db = perceptron_(w, b, x)
print(dw)
print(db)

[1.5 2.5]
1.0


A useful function for debugging is `value_and_grad`, which returns a function that computes not just the gradient, but also the value of the original function.

In [10]:
perceptron_and_perceptron_ = value_and_grad(perceptron, argnums=(0, 1))

In [11]:
perceptron_and_perceptron_(w, b, x)

(Array(7., dtype=float32),
 (Array([1.5, 2.5], dtype=float32), Array(1., dtype=float32, weak_type=True)))

In [12]:
try:
    grad(grad(perceptron))(w, b, x)
except Exception as err:
    print(f"{type(err)}\n{err}")

<class 'TypeError'>
Gradient only defined for scalar-output functions. Output had shape: (2,).


In typical DL scenarios, the loss function is what gets differentiated. However, instead of just returning the loss value, it is convenient for the loss function to also return the result of the forward pass, e.g., the logits calculated. This helps in calculating training metrics. To enable this, pass `has_aux=True` to `grad` or `value_and_grad`. The function must then return a pair `(value, aux)`, where only the first element is differentiated.

In the example below, `grad(f, has_aux=True)` returns the gradient of the first output as its first return value; the second return value (the auxiliary value) is passed through as an ordinary computation and is not differentiated.

In [19]:
def f(x):
    val = x**2
    aux = [x**3, x**4]
    return val, aux

In [20]:
dx, aux = grad(f, has_aux=True)(2.)
print(dx, aux)

4.0 [Array(8., dtype=float32, weak_type=True), Array(16., dtype=float32, weak_type=True)]


In [25]:
(val, aux), dx = value_and_grad(f, has_aux=True)(3.)
print(val)
print(aux)
print(dx)

9.0
[Array(27., dtype=float32, weak_type=True), Array(81., dtype=float32, weak_type=True)]
6.0
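To tie this back to the DL use case, here is a rough sketch of a loss function that returns its forward-pass output as the auxiliary value. The linear model, the squared-error loss, and the names `loss_fn`, `w`, `x`, and `y` are all assumptions made for illustration, not a recommended training setup.

```python
import jax.numpy as jnp
from jax import value_and_grad

def loss_fn(w, x, y):
    # Forward pass of a toy linear model (illustrative only), shape (n,).
    logits = x @ w
    # Scalar mean squared error; this is the value that gets differentiated.
    loss = jnp.mean((logits - y) ** 2)
    # The logits ride along as the auxiliary output and are not differentiated.
    return loss, logits

w = jnp.array([1.0, 2.0])
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
y = jnp.array([0.0, 0.0])

# Returns ((value, aux), gradient w.r.t. the first argument).
(loss, logits), dw = value_and_grad(loss_fn, has_aux=True)(w, x, y)
```

With the logits available from the same call, training metrics can be computed without running the forward pass a second time.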
