# DIFFERENTIABILITY

In this notebook we will present the differences between forward-mode and reverse-mode automatic differentiation (AD), as well as the JAX routines that expose them.

## JAX imports

## Beginner
### Prerequisites
Numpy

### Imports

In [3]:
from jax import grad, jvp, vjp, jacfwd, jacrev
import jax.numpy as jnp

import numpy as np

### Example

We will consider a "canonical" JAX example function:

In [19]:
def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)

### Theoretical perspective

The goal here is less to present a comprehensive mathematical study of AD, and more to give an intuition for the concepts that shaped JAX's implementation. I'll be following this [presentation](https://www.youtube.com/watch?v=zqaJeKZXS1U), so you can find references embedded within it.

#### Forward mode differentiation

Let's consider a function $f \colon \mathbb{R}^n \to  \mathbb{R}^m$. Provided it is differentiable at a point $x$, its Taylor expansion in a direction $y$ is given by:
$$
    f(x + \epsilon y + o(\epsilon)) = f(x) + \epsilon J_x[f] y + o(\epsilon)
$$
Let's define the concatenation of $x$ and of the "dual" number $y$ as $x \triangleright y$, and let $\tilde{f}$ be the map that propagates both parts. The Taylor identity can then be written as
$$
    \tilde{f}(x \triangleright y) = f(x)\triangleright J_x[f]y
$$
This directly expresses the forward chain rule as a composition of $\tilde{f}$:
$$
\begin{align}
    \tilde{f} \circ \tilde{g}(x \triangleright y) 
        &= \tilde{f}(g(x) \triangleright J_x[g]y)\\
        &= f \circ g(x) \triangleright J_{g(x)}[f]J_x[g]y
\end{align}
$$
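The forward chain rule can be checked numerically. The sketch below is only an illustration: the functions `f` and `g` and the point and tangent are arbitrary choices, not part of the original notebook. It pushes a pair $x \triangleright y$ through `g` and then through `f` with two `jvp` calls, and compares the result with a single `jvp` on the composition.

```python
import jax
import jax.numpy as jnp

# Hypothetical example functions, chosen only for illustration.
def g(x):
    return jnp.sin(x) * x

def f(u):
    return jnp.exp(u) + u ** 2

x = jnp.array([0.5, 1.5])    # primal point
y = jnp.array([1.0, -1.0])   # tangent ("dual") direction

# Push (x, y) through g, then push the resulting pair through f.
g_out, g_tangent = jax.jvp(g, (x,), (y,))
f_out, f_tangent = jax.jvp(f, (g_out,), (g_tangent,))

# Same computation in one go on the composition f(g(x)).
fg_out, fg_tangent = jax.jvp(lambda x: f(g(x)), (x,), (y,))

print(jnp.allclose(f_out, fg_out), jnp.allclose(f_tangent, fg_tangent))
```

The tangent is computed alongside the value at each stage, so nothing from `g` has to be stored once `f` has consumed its output.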

#### Reverse mode differentiation

On the other hand, what can we say about $y^\top J_x[f]$? The chain rule for this is now written
$$
    y^\top J_x[f\circ g] = (y^\top J_{g(x)}[f]) J_{x}[g]
$$

So to compute the forward-mode derivative you need to "carry" your tangents along as you compute the output, whereas for reverse mode you need to remember the intermediate values of the functions involved (here $g(x)$) before the product can be evaluated.
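As a small illustration of the reverse chain rule (again with arbitrary, hypothetical `f` and `g`), the sketch below first runs both functions forward to obtain their pullbacks, then applies the pullbacks in the opposite order. The result is compared with a single `vjp` on the composition.

```python
import jax
import jax.numpy as jnp

# Hypothetical example functions, chosen only for illustration.
def g(x):
    return jnp.sin(x) * x

def f(u):
    return jnp.exp(u) + u ** 2

x = jnp.array([0.5, 1.5])
v = jnp.array([1.0, -2.0])   # output cotangent, plays the role of y^T

# Forward sweep: evaluate g, then f, keeping a pullback for each.
g_out, g_vjp = jax.vjp(g, x)
f_out, f_vjp = jax.vjp(f, g_out)

# Backward sweep: apply the pullbacks in reverse order.
(u_bar,) = f_vjp(v)      # v^T J_{g(x)}[f]
(x_bar,) = g_vjp(u_bar)  # (v^T J_{g(x)}[f]) J_x[g]

# Reference: a single vjp on the composition.
_, fg_vjp = jax.vjp(lambda x: f(g(x)), x)
(x_bar_ref,) = fg_vjp(v)

print(jnp.allclose(x_bar, x_bar_ref))
```

Note that `f_vjp` can only be built once `g_out` is known, which is why the intermediate value has to be kept around.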

#### Duality

Provided the Jacobians don't depend on the dual number $y$, both operations are dual to each other, so that one can be retrieved from the other via a transposition (TODO I still need to understand the exact implementation).

This allowed the JAX developers to obtain VJP rules for free from JVP rules, which are easier to implement.
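The transposition can be tried out directly with `jax.linear_transpose`. The sketch below is an illustration under assumptions, not a description of JAX internals: the function `f` and the inputs are made up. It fixes the primal point, treats the JVP as a linear map of the tangent, transposes that map, and compares the result with `vjp`.

```python
import jax
import jax.numpy as jnp

# Hypothetical function R^3 -> R^2, chosen only for illustration.
def f(x):
    return jnp.stack([x[0] * jnp.sin(x[1]), jnp.cos(x[2]) + x[0]])

x = jnp.array([0.5, 0.75, 1.0])
v = jnp.array([1.0, -2.0])   # output cotangent

# With x fixed, t -> J_x[f] t is a linear map of the tangent t.
def f_jvp_at_x(t):
    return jax.jvp(f, (x,), (t,))[1]

# Transposing that linear map gives v -> J_x[f]^T v.
f_jvp_transposed = jax.linear_transpose(f_jvp_at_x, x)
(from_transpose,) = f_jvp_transposed(v)

# Reference: the VJP returned by jax.vjp.
_, f_vjp = jax.vjp(f, x)
(from_vjp,) = f_vjp(v)

print(jnp.allclose(from_transpose, from_vjp))
```

The second argument of `linear_transpose` is only used for its shape and dtype.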

#### Complexity

In practice the cost of each method depends on the input and output sizes: if $f \colon \mathbb{R}^n \to  \mathbb{R}^m$, its Jacobian lies in $\mathbb{R}^{m \times n}$. Forward mode recovers it one column per JVP, so it needs $O(n)$ passes, while reverse mode recovers it one row per VJP, so it needs $O(m)$ passes, with the additional constraint that the intermediate values need to be remembered.
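To make the pass counts concrete, here is a minimal sketch that assembles the same Jacobian both ways. The function `f` with $n = 3$ inputs and $m = 2$ outputs is a made-up example. Forward mode uses one `jvp` per input basis vector, reverse mode one pullback call per output basis vector.

```python
import jax
import jax.numpy as jnp

# Hypothetical function R^3 -> R^2, chosen only for illustration.
def f(x):
    return jnp.stack([x[0] * jnp.sin(x[1]), jnp.cos(x[2]) + x[0]])

x = jnp.array([0.5, 0.75, 1.0])
n, m = 3, 2

# Forward mode: n JVPs, each against a basis vector, each giving one column.
columns = [jax.jvp(f, (x,), (e,))[1] for e in jnp.eye(n)]
jac_by_columns = jnp.stack(columns, axis=1)

# Reverse mode: one forward pass, then m pullback calls, each giving one row.
_, f_vjp = jax.vjp(f, x)
rows = [f_vjp(e)[0] for e in jnp.eye(m)]
jac_by_rows = jnp.stack(rows, axis=0)

print(jac_by_columns.shape, jnp.allclose(jac_by_columns, jac_by_rows))
```

For a scalar loss ($m = 1$) over many parameters, a single reverse pass is enough, which is why reverse mode dominates in machine learning.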

#### Usage

The usage of both methods follows directly from the derivation: `jvp` returns the output value together with the directional derivative, while `vjp` returns the output value and a function that maps an output cotangent to the input cotangents:

In [29]:
print(jvp(my_fun, (0.5, 0.75), (1., 1.)))

(DeviceArray(1.2184019, dtype=float32), DeviceArray(0.56805766, dtype=float32))


In [31]:
val, my_fun_vjp = vjp(my_fun, 0.5, 0.75)
print(val)
print(my_fun_vjp)

1.2184019
Partial(functools.partial(<function _vjp_pullback_wrapper at 0x7f5743ae98c0>, [dtype('float32')], (*, PyTreeDef(tuple, [*,*]))), Partial(functools.partial(<function vjp.<locals>.unbound_vjp at 0x7f573b0f2710>, [(ShapedArray(float32[]), *)], { lambda  ; c d.
  let e = mul c 0.681638777256012
      f = mul d 0.731688916683197
      g = mul 0.5 f
      h = add_any e g
      i = mul c 0.4794255495071411
      j = neg i
      k = add_any h j
  in (k,) }), ()))


In [34]:
print(my_fun_vjp(1.))

(DeviceArray(0.20221323, dtype=float32), DeviceArray(0.36584446, dtype=float32))
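The two results are consistent with each other: the JVP along the direction $(1, 1)$ is the sum of the two partial derivatives returned by the VJP for the cotangent $1$. A short check, reusing `my_fun` from above:

```python
import jax
import jax.numpy as jnp

def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)

# Forward mode: directional derivative along (1, 1).
_, tangent_out = jax.jvp(my_fun, (0.5, 0.75), (1.0, 1.0))

# Reverse mode: both partial derivatives from a single pullback call.
_, my_fun_vjp = jax.vjp(my_fun, 0.5, 0.75)
dx, dy = my_fun_vjp(1.0)

# The directional derivative along (1, 1) is dx * 1 + dy * 1.
print(tangent_out, dx + dy)
```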


### Jacobians

As a consequence it is also possible to compute Jacobians, using one or the other method:

In [36]:
jacfwd(my_fun, argnums=(0, 1))(0.5, 1.)

(DeviceArray(0.36204547, dtype=float32),
 DeviceArray(0.27015114, dtype=float32))

In [37]:
jacrev(my_fun, argnums=(0, 1))(0.5, 1.)

(DeviceArray(0.36204547, dtype=float32),
 DeviceArray(0.27015114, dtype=float32))
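Because `my_fun` returns a scalar, its Jacobian is just its gradient, so `grad` (imported above but not yet used) should agree with both routines. A minimal check:

```python
import jax
import jax.numpy as jnp

def my_fun(x, y):
    return x * jnp.sin(y) + jnp.cos(x)

grads = jax.grad(my_fun, argnums=(0, 1))(0.5, 1.0)
jac_fwd = jax.jacfwd(my_fun, argnums=(0, 1))(0.5, 1.0)
jac_rev = jax.jacrev(my_fun, argnums=(0, 1))(0.5, 1.0)

# All three are tuples (d/dx, d/dy) and should match numerically.
print(grads)
```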

### Questions:

#### Q1: 
Compute the derivative of $\ln(1 + \exp(x))$ at $x = 100$. What do you think is happening?
```python
def log1p(x):
    return jnp.log(1 + jnp.exp(x))
```

#### Q2:
Compute the Hessian of the norm function: $(\sum_{i=1}^n x_i^2)^{\frac 1 2}$

## Intermediate
### Prerequisites
- Beginner autodiff
- Beginner loops

### Imports

In [79]:
from functools import partial

from jax import custom_jvp, custom_vjp, jvp, vjp, make_jaxpr
from jax.lax import scan, while_loop
import jax.numpy as jnp

import numpy as np

In this section we will understand why sometimes we need to implement our AD routines manually and how to do so.

### Example

We will take two different examples. The first one is the reciprocal of the `log1p` function from the beginner section (we keep the name `log1p`), the second is the iterative implementation of the square root function in the controlflow.loops notebook.

#### LOG1P

This example will help us highlight some problems that can happen when relying on AD:

In [69]:
def log1p(x):
    return 1. / jnp.log(1 + jnp.exp(x))

val, log1p_vjp = vjp(log1p, 100.)
print(val)
print(log1p_vjp(1.))

0.0
(DeviceArray(nan, dtype=float32),)


Clearly the derivative should be $0.$, so what's happening?

In [62]:
make_jaxpr(log1p_vjp)(1.)

{ lambda  ; a.
  let b = mul a 0.0
      c = mul b 1.0
      d = neg c
      e = div d inf
      f = mul e inf
  in (f,) }

Oh! It's an indeterminate form: the last line multiplies $0$ by $\infty$, which gives `nan`. We need to resolve it by hand:

In [72]:
log1p = custom_jvp(log1p)
@log1p.defjvp
def _log1p_jvp(primals, tangents):
    x, = primals
    x_dot, = tangents  # unpack the single input tangent
    
    ex = jnp.exp(x)
    primal_out = 1. / jnp.log(1 + ex)
    
    # d/dx [1 / log(1 + e^x)] = -primal_out^2 / (1 + e^{-x}),
    # written so that no inf / inf or 0 * inf appears for large x
    ex_inv = 1 / ex
    tangent_out = -x_dot * primal_out ** 2 / (1 + ex_inv)
    return primal_out, tangent_out

# otherwise: 
# log1p.defjvps(lambda x_dot, primal_out, x: ...)
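For reference, the closed form that a hand-written rule for this function should reproduce follows from the chain rule. Writing $p(x) = \frac{1}{\ln(1 + e^x)}$ for the primal output (the symbol $p$ is introduced here only for convenience):

$$
\begin{align}
    \frac{d}{dx} \frac{1}{\ln(1 + e^x)}
        &= -\frac{1}{\ln^2(1 + e^x)} \cdot \frac{e^x}{1 + e^x}\\
        &= -\frac{p(x)^2}{1 + e^{-x}}
\end{align}
$$

The second form divides numerator and denominator of $\frac{e^x}{1 + e^x}$ by $e^x$, so for large $x$ the only quantities involved are $p(x) \to 0$ and $e^{-x} \to 0$, and the derivative tends to $0$ without any $\frac{\infty}{\infty}$ or $0 \times \infty$ step.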

Let's see what this gives us:

In [73]:
val, log1p_vjp = vjp(log1p, 100.)
print(val)
print(log1p_vjp(1.))

0.0
(array(0., dtype=float32),)


#### SQRT

We will consider the while loop implementation of the square root:

In [74]:
def while_loop_sqrt(x, x0=1., n_iter=10):
    
    def cond(carry):
        i, _val = carry
        return i < n_iter
    
    def body(carry):
        i, val = carry
        return i+1, 0.5 * (val + x / val)
    
    _, res = while_loop(cond, body, (0,  x0))
    return res

In [77]:
val, while_loop_sqrt_vjp = vjp(while_loop_sqrt, 100.)
print(val)
print(while_loop_sqrt_vjp(1.))

10.0


ValueError: Reverse-mode differentiation does not work for lax.while_loop or lax.fori_loop. Try using lax.scan instead.

So we can't use the backward-mode for a while loop, that's a bummer!

As explained in the beginner section, reverse mode needs to store the intermediate values. XLA can't allocate that memory dynamically, so JAX doesn't allow reverse-mode differentiation through a `while_loop`, whose number of iterations could grow indefinitely. Loops of bounded size are instead meant to be written with the `scan` syntax, as the error message suggests.
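As a sketch of what the error message suggests, and assuming a fixed number of iterations is acceptable, the same Newton iteration can be written with `scan`, which does support reverse mode. The name `scan_sqrt` is hypothetical and not part of the original notebook.

```python
import jax
import jax.numpy as jnp
from jax.lax import scan

def scan_sqrt(x, x0=1.0, n_iter=10):
    # Fix dtypes up front so the scan carry has the same type on every step.
    x = jnp.asarray(x, dtype=jnp.float32)
    init = jnp.asarray(x0, dtype=jnp.float32)

    def body(val, _):
        # One Newton step; nothing is emitted per iteration.
        return 0.5 * (val + x / val), None

    # The trip count is static, so the intermediate values have a known size.
    res, _ = scan(body, init, xs=None, length=n_iter)
    return res

val = scan_sqrt(100.0)
grad_x = jax.grad(scan_sqrt)(100.0)  # reverse mode through the loop
print(val, grad_x)
```

The price is that every one of the `n_iter` iterates is stored for the backward pass, even though only the converged value matters for the derivative of a square root.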

In [87]:
@partial(custom_jvp, nondiff_argnums=(1, 2))
def custom_jvp_while_loop_sqrt(x, x0=1., n_iter=10):
    sqrt_x = while_loop_sqrt(x, x0, n_iter)
    return sqrt_x

@custom_jvp_while_loop_sqrt.defjvp
def _custom_jvp_while_loop_sqrt(x0, n_iter, primals, tangents):
    x, = primals
    x_dot, = tangents
    sqrt_x = while_loop_sqrt(x, x0, n_iter)
    return sqrt_x, x_dot / (2 * sqrt_x)
    

In [88]:
custom_jvp_while_loop_sqrt(5.)

DeviceArray(2.236068, dtype=float32)

In [89]:
val, while_loop_sqrt_vjp = vjp(custom_jvp_while_loop_sqrt, 100.)
print(val)
print(while_loop_sqrt_vjp(1.))

10.0
(DeviceArray(0.05, dtype=float32),)


### Questions:

#### Q1: 
Implement the vjp rule for `log1p`

#### Q2:
Can you think of other cases where it could be useful to implement a custom gradient? Can the `sqrt` approach be generalised?

## Advanced
### Prerequisites
- Intermediate loops 

### Imports

In [57]:
from jax import make_jaxpr, jvp, vjp, jit
from jax.lax import associative_scan, scan

import jax.numpy as jnp
import numpy as np

We now present an additional primitive which can be useful when dealing with associative binary operations (such as summation): `associative_scan`, also known as [prefix sum](https://en.wikipedia.org/wiki/Prefix_sum). It works by applying the operation recursively to subsets of the inputs (a divide-and-conquer strategy) instead of applying it sequentially. This has the benefit of being easily parallelisable, and it is natively implemented in JAX.

### Example

In practice it implements a parallelised version (see for example the [Wikipedia](https://en.wikipedia.org/wiki/Prefix_sum) article, or ask Fatemeh, she's an expert at it) of the following algorithm:
```python
def my_sequential_associative_scan(binary_op, xs):
    res = np.copy(xs)
    val = xs[0]
    for i, x in enumerate(xs[1:]):
        val = binary_op(val, x)
        res[i+1] = val
    return res
```

so that the cumulative sum would be for example written as 
```python
my_sequential_associative_scan(lambda x, y: x + y, np.arange(10))
```

Using JAX this would actually be written in the following way:

In [52]:
def associative_cumulative_sum(xs):
    return associative_scan(lambda x, y: x + y, xs)

print(associative_cumulative_sum(np.arange(10.)))

[ 0.  1.  3.  6. 10. 15. 21. 28. 36. 45.]
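Addition is only one possible operator: anything associative works, including operators on tuples. As an illustration (the recurrence and the data below are made-up examples), the linear recurrence $h_t = a_t h_{t-1} + b_t$ with $h_{-1} = 0$ can be computed by scanning over the affine maps $h \mapsto a_t h + b_t$, since composing affine maps is associative.

```python
import jax.numpy as jnp
import numpy as np
from jax.lax import associative_scan

def combine(left, right):
    # Compose two affine maps h -> a * h + b, applying `left` first.
    a_l, b_l = left
    a_r, b_r = right
    return a_l * a_r, a_r * b_l + b_r

a = jnp.array([0.5, 0.9, 1.1, 0.7, 0.3])
b = jnp.array([1.0, 2.0, -1.0, 0.5, 0.0])

# The offset of each prefix composition is h_t (because h_{-1} = 0).
_, h_parallel = associative_scan(combine, (a, b))

# Sequential reference implementation of the same recurrence.
h = 0.0
h_sequential = []
for a_t, b_t in zip(np.asarray(a), np.asarray(b)):
    h = a_t * h + b_t
    h_sequential.append(h)

print(np.allclose(h_parallel, h_sequential))
```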


Let's look at the code generated:

In [53]:
make_jaxpr(associative_cumulative_sum)(np.arange(10.))

{ lambda  ; a.
  let b = slice[ limit_indices=(9,)
                 start_indices=(0,)
                 strides=(2,) ] a
      c = slice[ limit_indices=(10,)
                 start_indices=(1,)
                 strides=(2,) ] a
      d = add b c
      e = slice[ limit_indices=(4,)
                 start_indices=(0,)
                 strides=(2,) ] d
      f = slice[ limit_indices=(5,)
                 start_indices=(1,)
                 strides=(2,) ] d
      g = add e f
      h = slice[ limit_indices=(1,)
                 start_indices=(0,)
                 strides=(2,) ] g
      i = slice[ limit_indices=(2,)
                 start_indices=(1,)
                 strides=(2,) ] g
      j = add h i
      k = slice[ limit_indices=(0,)
                 start_indices=(0,)
                 strides=(1,) ] j
      l = slice[ limit_indices=(2,)
                 start_indices=(2,)
                 strides=(2,) ] g
      m = add k l
      n = slice[ limit_indices=(1,)
                 start_indic

It looks just like when we were doing Python loops! Is it bad? No, because the depth of the generated graph only grows as $O(\log_2(n))$.

And, as with `scan`, it is closed under differentiation:

In [62]:
print(jvp(associative_cumulative_sum, (np.arange(5.),), (np.ones(5),)))

(DeviceArray([ 0.,  1.,  3.,  6., 10.], dtype=float32), DeviceArray([1., 2., 3., 4., 5.], dtype=float32))


In [65]:
val, cumsum_bwd = vjp(associative_cumulative_sum, np.arange(5.))
print(val)
print(cumsum_bwd(np.ones(5)))

[ 0.  1.  3.  6. 10.]
(DeviceArray([5., 4., 3., 2., 1.], dtype=float32),)
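The pullback output has a simple interpretation: the Jacobian of a cumulative sum is a lower-triangular matrix of ones, so its transpose applied to a cotangent is a cumulative sum running from the end. The sketch below checks this with the `reverse=True` option of `associative_scan`. The cotangent values are arbitrary.

```python
import jax
import jax.numpy as jnp
from jax.lax import associative_scan

def associative_cumulative_sum(xs):
    return associative_scan(jnp.add, xs)

xs = jnp.arange(5.0)
ct = jnp.array([1.0, 10.0, 100.0, 1000.0, 10000.0])  # arbitrary cotangent

# Pullback of the cumulative sum applied to ct.
_, cumsum_bwd = jax.vjp(associative_cumulative_sum, xs)
(from_vjp,) = cumsum_bwd(ct)

# Cumulative sum of ct taken from the last element backwards.
reverse_cumsum = associative_scan(jnp.add, ct, reverse=True)

print(from_vjp, reverse_cumsum)
```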


### Questions:

#### Q1: 
Compare the speed of the `associative_scan` and `scan` implementations of the cumulative sum on GPU and CPU (use the device flag in the `jit` function).

#### Q2:
Implement your own parallel version of `associative_scan` using JAX primitives.