# Auto-diff with JAX

https://github.com/google/jax

JAX is a Google research project, developed by the former developers of [Autograd](https://github.com/hips/autograd), that combines Autograd-style automatic differentiation with [XLA](https://www.tensorflow.org/xla), a compiler for accelerated linear algebra. It is based on three pillars:
- `grad`: automatic differentiation
- `jit`: just-in-time compilation
- `vmap`: automatic vectorization
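
The three transformations compose freely. The following is a minimal sketch of that idea, not taken from the original notebook: the function `square_plus_sin` and the sample points `xs` are illustrative assumptions.

```python
import jax
import jax.numpy as jnp

# Illustrative scalar function (an assumption made for this sketch only).
def square_plus_sin(x):
    return x ** 2 + jnp.sin(x)

# grad: derivative of a scalar-valued function.
d_square_plus_sin = jax.grad(square_plus_sin)

# vmap: apply the scalar derivative to every element of a batch.
batched_derivative = jax.vmap(d_square_plus_sin)

# jit: compile the batched derivative with XLA.
fast_batched_derivative = jax.jit(batched_derivative)

xs = jnp.linspace(0.0, 1.0, 5)
derivatives = fast_batched_derivative(xs)

# Analytical derivative, for comparison: 2x + cos(x).
expected = 2 * xs + jnp.cos(xs)
print(jnp.max(jnp.abs(derivatives - expected)))
```

The order of composition used here (`grad`, then `vmap`, then `jit`) is one reasonable choice; other orders are possible.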

## Automatic differentiation in JAX

JAX augments NumPy and Python code with function transformations that simplify operations common in machine learning programs, such as differentiation and vectorization. JAX's augmented NumPy lives at `jax.numpy`. With a few exceptions, you can think of `jax.numpy` as directly interchangeable with `numpy`. As a general rule, use `jax.numpy` whenever you plan to apply any of JAX's transformations.

The function `df = jax.grad(f, argnums=0)` takes a callable `f` and returns another callable, `df`, which evaluates the gradient of `f` with respect to the argument(s) at the position(s) given by `argnums`. The function `f` must return a scalar. For more information, check out the [documentation](https://jax.readthedocs.io/en/latest/jax.html?highlight=grad#jax.grad).

**Example**

We consider the function:
$$
f(x) = x \sin(x^2)
$$

and we compute $f'(x_0)$ for $x_0 = 0.13$.

In [4]:
import numpy as np
import jax.numpy as jnp
import jax

func = lambda x : x * jnp.sin(x ** 2)
x0 = 0.13
dfunc_AD = jax.grad(func, argnums=0)
df_AD = dfunc_AD(x0)

# analytical derivative
dfunc = lambda x : np.sin(x**2)+2 * x**2 * np.cos(x**2)
df_ex = dfunc(x0)

print('df (ex): %f' % df_ex)
print('df (AD): %f' % df_AD)

print('err (AD): %e' % (abs(df_AD - df_ex)/abs(df_ex))) # the relative error is tiny

df (ex): 0.050694
df (AD): 0.050694
err (AD): 7.348529e-08


Evaluate the execution times of the functions `func` and `dfunc_AD`.

In [32]:
%timeit dfunc_AD(x0)

3.27 ms ± 296 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


In [33]:
%timeit func(x0)

15.4 µs ± 19 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


### Speed it up with JIT

Compile the functions `func` and `dfunc_AD` using the [just-in-time compilation](https://en.wikipedia.org/wiki/Just-in-time_compilation) utility `jax.jit`. 

With `f_jit = jax.jit(f)` a callable `f` is compiled into `f_jit`.

Then, check that the compiled functions return the same results as the original ones. Finally, evaluate the execution times and compare them with the previous results.

In [34]:
dfunc_AD_jit = jax.jit(dfunc_AD)

In [35]:
%timeit dfunc_AD_jit(x0) # now it is as fast as a plain function evaluation; this is the best standard currently available

6.42 µs ± 19.8 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)
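
The check requested above (compiled and uncompiled derivatives agree) can be sketched as follows. This is a minimal sketch outside the notebook, so it uses the standard-library `timeit` module instead of the `%timeit` magic. It also calls `block_until_ready()`, because JAX dispatches work asynchronously and a timing that does not wait for the result may under-report the cost. The variable `n_calls` is an assumption of this sketch.

```python
import timeit

import jax
import jax.numpy as jnp

func = lambda x: x * jnp.sin(x ** 2)
dfunc_AD = jax.grad(func)
dfunc_AD_jit = jax.jit(dfunc_AD)

x0 = 0.13

# The first call traces and compiles; keep it out of the timed region.
df_jit = dfunc_AD_jit(x0).block_until_ready()

# Check: compiled and uncompiled derivatives agree.
df_AD = dfunc_AD(x0)
assert jnp.allclose(df_AD, df_jit)

# Time the compiled derivative, waiting for the result on each call.
n_calls = 1000  # illustrative number of repetitions
elapsed = timeit.timeit(
    lambda: dfunc_AD_jit(x0).block_until_ready(), number=n_calls
)
print('jit: %.2f µs per call' % (1e6 * elapsed / n_calls))
```

Absolute timings depend on the hardware and JAX version, so they will differ from the figures reported in the cells above.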


In [36]:
# for neural networks you will only use reverse mode (many inputs, few outputs): it is the only way to scale to deep neural networks (backpropagation)
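
The remark on reverse mode can be illustrated with a function of many inputs and a single output. The sketch below is an illustration under stated assumptions: the function `loss` and the input vector `w` are hypothetical stand-ins for a network loss and its parameters. `jax.grad` uses reverse mode, while `jax.jacfwd` computes the same derivatives in forward mode.

```python
import jax
import jax.numpy as jnp

# Hypothetical scalar "loss" with many inputs: sum of w_i * sin(w_i^2).
def loss(w):
    return jnp.sum(w * jnp.sin(w ** 2))

w = jnp.linspace(0.1, 1.0, 1000)

# Reverse mode: one backward pass returns all 1000 partial derivatives.
grad_reverse = jax.grad(loss)(w)

# Forward mode gives the same values, but needs one tangent direction per input.
grad_forward = jax.jacfwd(loss)(w)

# Analytical gradient, component by component.
expected = jnp.sin(w ** 2) + 2 * w ** 2 * jnp.cos(w ** 2)
print(jnp.max(jnp.abs(grad_reverse - expected)))
```

With one output and many inputs, the cost of forward mode grows with the number of inputs, which is why reverse mode is the usual choice for training neural networks.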

In [37]:
def func(x, y):
    return x * jnp.sin(x ** 2) + x * y

x0 = 0.13
y0 = 1.0
grad_x = jax.grad(func, argnums=0)(x0, y0) # argnums selects the argument to differentiate with respect to; x0 and y0 must be floats
grad_y = jax.grad(func, argnums=1)(x0, y0)

float(grad_x), float(grad_y)

(1.0506943464279175, 0.12999999523162842)

In [38]:
# this is how to use JAX to compute the partial derivatives of a multivariable function at a point
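
Both partial derivatives can also be obtained in a single call by passing a tuple to `argnums`. The sketch below reuses `func`, `x0` and `y0` from the cell above. The use of `jax.value_and_grad` and the integer-input check are additions of this sketch, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

def func(x, y):
    return x * jnp.sin(x ** 2) + x * y

x0, y0 = 0.13, 1.0

# A tuple of indices returns a tuple of partial derivatives.
grad_x, grad_y = jax.grad(func, argnums=(0, 1))(x0, y0)

# value_and_grad also returns the function value from the same evaluation.
value, (gx, gy) = jax.value_and_grad(func, argnums=(0, 1))(x0, y0)

# Integer inputs are rejected: grad needs floating-point (or complex) inputs.
try:
    jax.grad(func, argnums=0)(1, 1.0)
    int_input_rejected = False
except TypeError:
    int_input_rejected = True

print(float(grad_x), float(grad_y), int_input_rejected)
```

Using `value_and_grad` avoids a second evaluation of the function when both its value and its gradient are needed, for example in an optimization loop.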