# DIFFERENTIABILITY

In this notebook we present the differences between forward-mode and reverse-mode automatic differentiation (AD), as well as the JAX routines that compute them.

## JAX imports

## Beginner
### Prerequisites
NumPy

### Imports

In [1]:
import jax
from jax import grad, jvp, vjp, jacfwd, jacrev, linearize
import jax.numpy as jnp

import numpy as np

### Example

We will consider a "canonical" JAX example function:

In [2]:
def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)

### Theoretical perspective

The goal here is less to present a comprehensive mathematical study of AD, and more to give an intuition for the concepts behind JAX's implementation. I'll be following this [presentation](https://www.youtube.com/watch?v=zqaJeKZXS1U), where you can find further references.

#### Forward mode differentiation

Consider a function $f \colon \mathbb{R}^n \to \mathbb{R}^m$. Provided it is differentiable at a point $x$, its directional Taylor expansion along a direction $y$ is:
$$
    f(x + \epsilon y) = f(x) + \epsilon J_x[f] y + o(\epsilon)
$$
Pairing $x$ with the tangent $y$ into a "dual" number $x \triangleright y$ (think of $x + \epsilon y$ with $\epsilon^2 = 0$), the Taylor identity defines a lifted map $\tilde{f}$ on dual numbers:
$$
    \tilde{f}(x \triangleright y) = f(x)\triangleright J_x[f]y
$$
Composing lifted maps then directly expresses the forward chain rule:
$$
\begin{aligned}
    \tilde{f} \circ \tilde{g}(x \triangleright y)
        &= \tilde{f}(g(x) \triangleright J_x[g]y)\\
        &= f \circ g(x) \triangleright J_{g(x)}[f]\,J_x[g]\,y
\end{aligned}
$$
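To make the dual-number picture concrete, here is a minimal forward-mode sketch by operator overloading. The `Dual` class and the lifted helpers `d_sin`/`d_cos` are hypothetical illustrations, not JAX's actual implementation, but each lifted primitive computes exactly $\tilde{f}(x \triangleright y) = f(x) \triangleright J_x[f]y$, and evaluating `my_fun` on dual numbers composes them as in the chain rule above. The result is compared with `jax.jvp` for primals `(0.5, 0.75)` and tangents `(1.0, 1.0)`.

```python
import math
from dataclasses import dataclass

import jax
import jax.numpy as jnp


@dataclass
class Dual:
    """Hypothetical helper: a pair (primal value x, tangent y)."""
    x: float
    y: float

    def __add__(self, other):
        # sum rule: tangents add
        return Dual(self.x + other.x, self.y + other.y)

    def __mul__(self, other):
        # product rule: (a, da) * (b, db) = (a * b, a * db + b * da)
        return Dual(self.x * other.x, self.x * other.y + self.y * other.x)


def d_sin(d):
    # lifted sin: (x, y) -> (sin(x), cos(x) * y)
    return Dual(math.sin(d.x), math.cos(d.x) * d.y)


def d_cos(d):
    # lifted cos: (x, y) -> (cos(x), -sin(x) * y)
    return Dual(math.cos(d.x), -math.sin(d.x) * d.y)


def my_fun_dual(x, y):
    # same expression as my_fun, built by composing lifted primitives
    return x * d_sin(y) + d_cos(x)


def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)


out = my_fun_dual(Dual(0.5, 1.0), Dual(0.75, 1.0))
primal, tangent = jax.jvp(my_fun, (0.5, 0.75), (1.0, 1.0))
print(out.x, out.y)
print(primal, tangent)
```

Both lines print the same pair (up to float32 rounding on the JAX side): the primal value and the directional derivative along $(1, 1)$.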

#### Reverse mode differentiation

On the other hand, what can we say about the vector-Jacobian product (VJP) $y^\top J_x[f]$, where $y \in \mathbb{R}^m$ now lives in the output space? Its chain rule reads
$$
    y^\top J_x[f\circ g] = (y^\top J_{g(x)}[f]) J_{x}[g]
$$
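A quick numerical check of this reverse chain rule, using hypothetical example functions `g` and `f`: the two pullbacks are applied in reverse order (output side first) and agree with the pullback of the composition. Note that `jax.vjp` has to evaluate `g` and `f` forward first; the pullbacks it returns keep the intermediate values they need.

```python
import jax
import jax.numpy as jnp


def g(x):
    # hypothetical g: R^3 -> R^2
    return jnp.array([x[0] * x[1], jnp.sin(x[2])])


def f(z):
    # hypothetical f: R^2 -> R (scalar output)
    return jnp.exp(z[0]) + z[0] * z[1]


x = jnp.array([0.3, -1.2, 2.0])

# Forward pass: evaluate g, then f at g(x); each pullback stores what it needs.
gx, g_pullback = jax.vjp(g, x)
_, f_pullback = jax.vjp(f, gx)

# Backward pass: first y^T J_{g(x)}[f], then multiply by J_x[g].
(u_f,) = f_pullback(1.0)     # shape (2,)
(u_fg,) = g_pullback(u_f)    # shape (3,)

# Reference: pullback of the composition f o g
_, fg_pullback = jax.vjp(lambda x: f(g(x)), x)
(direct,) = fg_pullback(1.0)
print(u_fg, direct)
```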

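So in forward mode the tangents are "carried" alongside the primal computation, from input to output. In reverse mode the product is evaluated from the output back to the input: $y^\top J_{g(x)}[f]$ is computed before $J_x[g]$ is applied, so the intermediate values ($g(x)$ and everything computed inside $f$ and $g$) must first be produced by a forward pass and kept in memory until the backward pass uses them.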

#### Duality

The JVP $y \mapsto J_x[f]y$ is linear in the tangent $y$ (the Jacobian depends only on $x$), and the VJP $y \mapsto y^\top J_x[f]$ is its transpose. One can therefore be retrieved from the other by transposition: JAX obtains a VJP by linearizing the JVP at $x$ and transposing the resulting linear function.

This allowed the JAX developers to get VJP rules essentially for free from JVP rules (which are easier to write): only the linear primitives need an explicit transposition rule.
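A minimal sketch of this transposition using the public helpers `jax.linearize` and `jax.linear_transpose`. It mirrors the idea rather than the exact internal code path of `jax.vjp`: linearize `my_fun` to get its JVP as a linear function of the tangents, transpose that function, and compare with `jax.vjp`.

```python
import jax
import jax.numpy as jnp


def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)


x, y = 0.5, 0.75

# Forward mode, linearized once: f_jvp(dx, dy) = J_x[f] @ (dx, dy)
primal_out, f_jvp = jax.linearize(my_fun, x, y)

# Transposing the linear map gives ct -> ct * J_x[f], i.e. a VJP.
# The primals only tell linear_transpose the input shapes and dtypes.
f_vjp_from_jvp = jax.linear_transpose(f_jvp, x, y)
ct_x, ct_y = f_vjp_from_jvp(1.0)

# Reference: JAX's own reverse mode
_, f_vjp = jax.vjp(my_fun, x, y)
ref_x, ref_y = f_vjp(1.0)
print(ct_x, ct_y)
print(ref_x, ref_y)
```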

#### Complexity

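In practice the cost of each method depends on the input and output sizes. If $f \colon \mathbb{R}^n \to \mathbb{R}^m$, its Jacobian lives in $\mathbb{R}^{m \times n}$: one JVP yields one column and one VJP yields one row, each at a small constant multiple of the cost of evaluating $f$. Building the full Jacobian therefore takes $n$ JVPs in forward mode and $m$ VJPs in reverse mode, the latter with the additional constraint that the intermediate values must be stored. Forward mode is preferable when $n \ll m$, reverse mode when $m \ll n$ (e.g. a scalar loss, $m = 1$).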
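To see where the $n$ versus $m$ counts come from, the sketch below builds the Jacobian of a small, hypothetical $f \colon \mathbb{R}^3 \to \mathbb{R}^2$ both ways: one JVP per input basis vector (columns) and one VJP per output basis vector (rows), batched with `jax.vmap`. This is roughly what `jacfwd` and `jacrev` do, though their actual implementations may differ in details.

```python
import jax
import jax.numpy as jnp


def f(x):
    # hypothetical map f: R^3 -> R^2 (n = 3 inputs, m = 2 outputs)
    return jnp.array([x[0] * x[1], jnp.sin(x[2]) + x[0]])


x = jnp.array([1.0, 2.0, 3.0])
n, m = x.shape[0], f(x).shape[0]

# Forward mode: the JVP along e_j is column j of J, so n JVPs are needed.
cols = jax.vmap(lambda v: jax.jvp(f, (x,), (v,))[1])(jnp.eye(n))  # shape (n, m)
J_fwd = cols.T

# Reverse mode: the VJP along e_i is row i of J, so m VJPs are needed.
_, f_pullback = jax.vjp(f, x)
J_rev = jax.vmap(lambda u: f_pullback(u)[0])(jnp.eye(m))  # shape (m, n)

print(J_fwd.shape, J_rev.shape)  # (2, 3) (2, 3)
```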

#### Usage

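The usage of both functions follows directly from the derivation: `jvp` takes the primals and the tangents and returns the pair $(f(x), J_x[f]y)$ in one call, while `vjp` returns $f(x)$ together with a function (the "pullback") that maps a cotangent $y$ to $y^\top J_x[f]$; the intermediate values it needs are stored inside that function: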

In [3]:
print(jvp(my_fun, (0.5, 0.75), (1.0, 1.0)))



(Array(1.2184019, dtype=float32, weak_type=True), Array(0.56805766, dtype=float32, weak_type=True))


In [4]:
val, my_fun_vjp = vjp(my_fun, 0.5, 0.75)
print(val)
print(my_fun_vjp)

1.2184019
Partial(_HashableCallableShim(functools.partial(<function _vjp_pullback_wrapper at 0x7e8d831e6340>, 'my_fun', ...)), ...)

In [5]:
print(my_fun_vjp(1.0))

(Array(0.20221323, dtype=float32, weak_type=True), Array(0.36584443, dtype=float32, weak_type=True))


### Jacobians

Consequently, full Jacobians can be computed with either mode: `jacfwd` builds them from JVPs and `jacrev` from VJPs, and both give the same result:

In [6]:
jacfwd(my_fun, argnums=(0, 1))(0.5, 1.0)

(Array(0.3620454, dtype=float32, weak_type=True),
 Array(0.27015114, dtype=float32, weak_type=True))

In [7]:
jacrev(my_fun, argnums=(0, 1))(0.5, 1.0)

(Array(0.3620454, dtype=float32, weak_type=True),
 Array(0.27015114, dtype=float32, weak_type=True))

### Questions:

#### Q1: 
Compute the derivative of $\ln(1 + \exp(x))$ at $x = 100$. What do you think is happening?
```python
def log1p(x):
    return jnp.log(1 + jnp.exp(x))
```

#### Q2:
Compute the Hessian of the norm function: $(\sum_{i=1}^n x_i^2)^{\frac 1 2}$

## Intermediate
### Prerequisites
- Beginner autodiff
- Beginner loops

### Imports

In [8]:
from functools import partial

from jax import custom_jvp, custom_vjp, jvp, vjp, make_jaxpr
from jax.lax import scan, while_loop
import jax.numpy as jnp

import numpy as np

In this section we will see why we sometimes need to write differentiation rules by hand, and how to do so.

### Example

We will look at two examples: the first is the `log1p` function from the beginner section (used here through its reciprocal), the second is the iterative implementation of the square root from the controlflow.loops notebook.

#### LOG1P

This example will help us highlight some problems that can happen when relying on AD:

In [9]:
def log1p(x):
    return 1.0 / jnp.log(1 + jnp.exp(x))


val, log1p_vjp = vjp(log1p, 100.0)
print(val)
print(log1p_vjp(1.0))

0.0
(Array(nan, dtype=float32, weak_type=True),)


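The exact derivative, $-\frac{\sigma(x)}{\ln(1 + e^x)^2}$ with $\sigma$ the logistic sigmoid, is tiny at $x = 100$ (about $-10^{-4}$), and since the primal already evaluates to $0$ in float32 we would expect a gradient of (nearly) $0$, not `nan`. So what's happening?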

In [10]:
make_jaxpr(log1p_vjp)(1.0)

{ lambda a:f32[] b:f32[] c:f32[]; d:f32[]. let
    e:f32[] = mul d a
    f:f32[] = mul e 1.0
    g:f32[] = neg f
    h:f32[] = div g b
    i:f32[] = mul h c
  in (i,) }

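Oh! It's an indeterminate form: in float32, $e^{100}$ overflows to `inf`, so the pullback divides by $1 + e^x = \infty$ and then multiplies by $e^x = \infty$. This evaluates $0 \cdot \infty$ (the $\frac{\infty}{\infty}$ hidden in $\frac{e^x}{1 + e^x}$), which is `nan`. We need to resolve it by hand: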

In [11]:
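log1p = custom_jvp(log1p)


@log1p.defjvp
def _log1p_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents

    ex = jnp.exp(x)
    primal_out = 1.0 / jnp.log(1 + ex)

    # d/dx [1 / log(1 + e^x)] = -primal_out**2 * e^x / (1 + e^x)
    #                         = -primal_out**2 / (1 + e^-x)
    # Writing it with e^-x avoids the inf / inf form when e^x overflows.
    ex_inv = 1 / ex
    tangent_out = -(primal_out**2) / (1 + ex_inv) * x_dot
    return primal_out, tangent_out


# alternatively, with one rule per argument:
# log1p.defjvps(lambda x_dot, primal_out, x: ...)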

Let's see what this gives us:

In [12]:
val, log1p_vjp = vjp(log1p, 100.0)
print(val)
print(log1p_vjp(1.0))

0.0
(Array(0., dtype=float32, weak_type=True),)
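The custom rule removes the `nan`, but the primal is still affected by the overflow: `jnp.exp(100.0)` is `inf`, so the function returns `0.0` instead of about `0.01`, and the gradient inherits that. Below is a sketch of a fully stable variant. It swaps in a different technique than the notebook's: $\ln(1 + e^x)$ is computed with `jnp.logaddexp(0.0, x)` and the derivative uses `jax.nn.sigmoid`; `inv_softplus` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


@jax.custom_jvp
def inv_softplus(x):
    # 1 / log(1 + exp(x)), with log(1 + exp(x)) computed stably as logaddexp(0, x)
    return 1.0 / jnp.logaddexp(0.0, x)


@inv_softplus.defjvp
def _inv_softplus_jvp(primals, tangents):
    (x,) = primals
    (x_dot,) = tangents
    primal_out = inv_softplus(x)
    # d/dx [1 / softplus(x)] = -sigmoid(x) / softplus(x)**2 = -sigmoid(x) * primal_out**2
    tangent_out = -jax.nn.sigmoid(x) * primal_out**2 * x_dot
    return primal_out, tangent_out


val, pullback = jax.vjp(inv_softplus, 100.0)
(grad_x,) = pullback(1.0)
print(val, grad_x)
```

Here `val` is about `0.01` and the gradient about `-1e-4`, matching the analytical value, because no intermediate quantity overflows.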


#### SQRT

We will consider the while loop implementation of the square root:

In [13]:
def while_loop_sqrt(x, x0=1.0, n_iter=10):

    def cond(carry):
        i, _val = carry
        return i < n_iter

    def body(carry):
        i, val = carry
        return i + 1, 0.5 * (val + x / val)

    _, res = while_loop(cond, body, (0, x0))
    return res

In [14]:
val, while_loop_sqrt_vjp = vjp(while_loop_sqrt, 100.0)
print(val)
print(while_loop_sqrt_vjp(1.0))

10.0


ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop with dynamic start/stop values. Try using lax.scan, or using fori_loop with static start/stop.

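So we can't use reverse mode through a `while_loop`, that's a bummer!

As explained in the beginner section, reverse mode needs to store the intermediate values of every iteration. The number of iterations of a `while_loop` is only known at runtime, while XLA requires buffer sizes to be known at compile time, so JAX does not allow reverse-mode differentiation through `while_loop`s, which could run for an unbounded number of steps. `lax.scan` (or `fori_loop` with static start/stop values) has a fixed trip count and can be differentiated in reverse mode. Some work has also been planned in JAX to support bounded-size while loops by implementing them with `scan`.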
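A minimal sketch of the `scan` alternative: the same Newton (Heron) iteration, written with `lax.scan` and a static `n_iter`, so that reverse mode knows how many intermediate values to store. `scan_sqrt` is a hypothetical name, and its gradient is obtained by differentiating through every iteration rather than through a hand-written derivative rule.

```python
import jax
from jax.lax import scan


def scan_sqrt(x, x0=1.0, n_iter=10):
    def body(val, _):
        # one Newton step for val**2 = x; no per-step output
        return 0.5 * (val + x / val), None

    # static trip count: n_iter is a Python int known at trace time
    res, _ = scan(body, x0, None, length=n_iter)
    return res


val, scan_sqrt_vjp = jax.vjp(scan_sqrt, 100.0)
(grad_x,) = scan_sqrt_vjp(1.0)
print(val, grad_x)
```

Differentiating through the unrolled iterations converges to $\frac{1}{2\sqrt{x}} = 0.05$ here, because the Newton map has zero derivative at its fixed point; the memory cost, however, grows linearly with `n_iter`.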

In [15]:
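@partial(custom_jvp, nondiff_argnums=(1, 2))
def custom_jvp_while_loop_sqrt(x, x0=1.0, n_iter=10):
    sqrt_x = while_loop_sqrt(x, x0, n_iter)
    return sqrt_x


@custom_jvp_while_loop_sqrt.defjvp
def _custom_jvp_while_loop_sqrt(x0, n_iter, primals, tangents):
    # the non-differentiable arguments (x0, n_iter) are passed first
    (x,) = primals
    (x_dot,) = tangents
    sqrt_x = while_loop_sqrt(x, x0, n_iter)
    # use the analytical derivative d sqrt(x)/dx = 1 / (2 sqrt(x)) instead of
    # differentiating through the loop; the rule is linear in x_dot, so JAX
    # can transpose it to obtain the VJP
    return sqrt_x, x_dot / (2 * sqrt_x)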

In [16]:
custom_jvp_while_loop_sqrt(5.0)

Array(2.236068, dtype=float32, weak_type=True)

In [17]:
val, while_loop_sqrt_vjp = vjp(custom_jvp_while_loop_sqrt, 100.0)
print(val)
print(while_loop_sqrt_vjp(1.0))

10.0
(Array(0.05, dtype=float32, weak_type=True),)


### Questions:

#### Q1: 
Implement the VJP rule for `log1p`

#### Q2:
Can you think of other cases where it could be useful to implement custom gradients? Can the `sqrt` approach be generalised?

## Advanced
### Prerequisites
- Intermediate autodiff 
- Beginner randomness

Here we reproduce a small part of the [FFJORD](https://arxiv.org/pdf/1810.01367.pdf) paper, which aims at an efficient way to estimate the trace of the Jacobian of a given bijective function.

### Imports

In [32]:
from jax import make_jaxpr, jvp, vjp, jit, linearize, vmap
from jax.lax import associative_scan, scan
from jax.random import normal, PRNGKey

import jax.numpy as jnp
import numpy as np

### Example

We will first present the naive (and exact) way of computing the trace of a Jacobian using JAX primitives, then discuss the complexity of the operation and introduce the [Hutchinson trick](https://www.tandfonline.com/doi/abs/10.1080/03610919008812866) as a way to reduce it.

In [33]:
D = 5
a = np.random.randn(D, D)
b = np.random.randn(D)


def affine_function(x):
    return a @ x + b


print(np.trace(a))

0.6597556874713932


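How do we compute the trace of the Jacobian using JVP or VJP operations only?
Let $(e_i)_{i=1}^d$ be the canonical basis of $\mathbb{R}^d$ (here $d = D$). Then for any $M \in \mathbb{R}^{d \times d}$,
$$
\mathrm{tr}(M) = \sum_{i=1}^d e_i^\top M e_i
$$
or, tautologically (as an identity, not as a way of computing it),
$$
\mathrm{tr}(M) = \mathrm{tr}(M I_d)
$$
Each product $M e_i$ is one JVP, so we can push all the columns of $I_d$ through the linearized function at once with `vmap` and take the trace of the result.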

In [34]:
x = np.random.randn(D)

In [37]:
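def compute_jac_trace(f, x):
    d = x.shape[0]
    eis = jnp.eye(d)  # canonical basis vectors e_1, ..., e_d as rows
    _, lin_jvp = linearize(f, x)  # lin_jvp(v) = J_x[f] v
    temp = vmap(lin_jvp)(eis)  # row i is J e_i, so temp = J^T
    return jnp.trace(temp)  # tr(J^T) = tr(J)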

In [38]:
compute_jac_trace(affine_function, x)

Array(0.6597557, dtype=float32)

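This requires $d$ JVPs, one per basis vector: assuming a single JVP costs $O(d)$, the exact trace costs $O(d^2)$, even though only the $d$ diagonal entries are used.

Instead one can use the Hutchinson trick:

$\mathrm{tr}(M) = \mathbb{E}\left[\epsilon^\top M \epsilon\right]$ for any random vector $\epsilon$ with zero mean and identity covariance, for instance $\epsilon \sim \mathcal{N}(0, I_d)$. Each sample $\epsilon^\top M \epsilon$ needs a single JVP, so a Monte Carlo estimate with $k$ samples costs $O(kd)$ instead of $O(d^2)$.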
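The identity follows in one line, using only $\mathbb{E}[\epsilon_i \epsilon_j] = \delta_{ij}$ (zero mean and identity covariance) and linearity of the expectation:
$$
\mathbb{E}\left[\epsilon^\top M \epsilon\right]
= \sum_{i,j} M_{ij}\,\mathbb{E}[\epsilon_i \epsilon_j]
= \sum_{i,j} M_{ij}\,\delta_{ij}
= \sum_{i} M_{ii}
= \mathrm{tr}(M)
$$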

### Questions:

#### Q1:
Implement the Hutchinson trick to estimate the trace of the Jacobian of the affine function using Jacobian-vector products.

#### Q2:
Assume that $x(a)$ is the fixed point of a function $f(x, a)$, i.e., $x(a) = f(x(a), a)$. How would you compute the derivative of $x(a)$ with respect to $a$?
Take inspiration from the square root example above.