# Automatic differentiation with JAX

## Main features

- NumPy-like API (```jax.numpy```)
- Auto-vectorization
- Auto-parallelization (SPMD paradigm)
- Auto-differentiation
- XLA backend and JIT support
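
The sketch below shows auto-vectorization and auto-differentiation working together. It assumes a toy objective ```loss``` that is hypothetical and used only for illustration. ```jax.vmap``` turns a per-point gradient into a batched one without writing a Python loop.

```python
import jax
import jax.numpy as jnp

def loss(w, x):
    # Hypothetical scalar objective of parameters w for a single data point x
    return jnp.sum((w * x) ** 2)

w = jnp.array([1.0, 2.0, 3.0])
xs = jnp.arange(12.0).reshape(4, 3)  # batch of 4 data points

# Per-example gradients w.r.t. w: map over axis 0 of xs, share w across the batch
per_example_grads = jax.vmap(jax.grad(loss), in_axes=(None, 0))(w, xs)
print(per_example_grads.shape)
```

Analytically, the gradient of this loss with respect to ```w``` is ```2 * w * x**2```, so each row can be checked by hand.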

## How to compute the gradient of your objective?

- Define the objective as a standard Python function
- Call ```jax.grad``` on it, and voilà!
- Do not forget to wrap these functions with ```jax.jit``` to speed them up
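
Here is a minimal sketch of these three steps. It uses a toy objective ```objective```, which is hypothetical and was chosen because its gradient $2x$ is known in closed form.

```python
import jax
import jax.numpy as jnp

def objective(x):
    # Plain Python function of a JAX array
    return jnp.sum(x ** 2)

# jax.grad returns a new function that computes d(objective)/dx
grad_objective = jax.grad(objective)
# jax.jit compiles it with XLA on the first call; later calls reuse the compiled code
fast_grad = jax.jit(grad_objective)

x = jnp.array([1.0, -2.0, 3.0])
g = fast_grad(x)
print(g)  # equals 2 * x
```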

```python
import jax
import jax.numpy as jnp
```

- By default, JAX uses single-precision numbers (```float32```)
- You can enable double precision (```float64```) manually
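
A quick way to check which precision is active is sketched below. It assumes the x64 flag is switched on at program start, since the flag is global to the process. An explicit ```dtype``` still selects single precision when it is wanted.

```python
import jax
import jax.numpy as jnp

# Global switch; set it at startup, before any computations are run
jax.config.update("jax_enable_x64", True)

x64 = jnp.ones(3)                     # default dtype now follows the x64 flag
x32 = jnp.ones(3, dtype=jnp.float32)  # explicit dtype still gives single precision
print(x64.dtype, x32.dtype)
```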

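```python
# Enable 64-bit floats globally; the old `from jax.config import config` import
# no longer works in recent JAX versions, so use the jax.config object directly
jax.config.update("jax_enable_x64", True)
```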

```python
@jax.jit
def f(x, A, b):
    # Least-squares objective ||Ax - b||^2
    res = A @ x - b
    return res @ res

# Gradient with respect to the first argument x
gradf = jax.grad(f, argnums=0, has_aux=False)
```
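
For reference, this least-squares objective has closed-form derivatives. The correctness checks in the following cells compare JAX's results against them:

$$
f(x) = \|Ax - b\|_2^2 = (Ax - b)^\top (Ax - b), \qquad
\nabla f(x) = 2A^\top (Ax - b), \qquad
\nabla^2 f(x) = 2A^\top A.
$$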

## Random numbers in JAX 

- JAX emphasizes the reproducibility of runs
- The analogue of a random seed is an explicit PRNG key, which is **a required argument** of every function that generates random values
- More details and references on the design of the ```random``` submodule are [here](https://github.com/google/jax/blob/master/design_notes/prng.md)
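
Passing the same key twice with the same shape reproduces exactly the same numbers. That is the reproducibility guarantee, but it also means independent draws need independent keys. The minimal sketch below uses ```jax.random.split``` to derive one subkey per array, assuming independent ```x```, ```A``` and ```b``` are wanted.

```python
import jax
import jax.numpy as jnp

n = 5
key = jax.random.PRNGKey(0)
# Split the root key into independent subkeys; do not reuse one key for two draws
key_x, key_A, key_b = jax.random.split(key, 3)
x = jax.random.normal(key_x, (n,))
A = jax.random.normal(key_A, (n, n))
b = jax.random.normal(key_b, (n,))

# Same key and same shape give identical samples
same_1 = jax.random.normal(key, (n,))
same_2 = jax.random.normal(key, (n,))
```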

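```python
n = 1000
# Note: reusing PRNGKey(0) with the same shape makes x and b identical;
# jax.random.split would give independent draws
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (n, n))
b = jax.random.normal(jax.random.PRNGKey(0), (n, ))
```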

```python
# Compare autodiff against the analytical gradient 2 A^T (A x - b)
print("Check correctness", jnp.linalg.norm(gradf(x, A, b) - 2 * A.T @ (A @ x - b)))
print("Compare speed")
print("Analytical gradient")
%timeit 2 * A.T @ (A @ x - b)
print("Grad function")
%timeit gradf(x, A, b).block_until_ready()
jit_gradf = jax.jit(gradf)
print("Jitted grad function")
%timeit jit_gradf(x, A, b).block_until_ready()
```

```
Check correctness 9.188704584401416e-11
Compare speed
Analytical gradient
1.66 ms ± 4.66 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Grad function
946 µs ± 1.34 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
Jitted grad function
316 µs ± 2.03 µs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
```
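
JAX dispatches computations asynchronously, which is why each ```%timeit``` above calls ```.block_until_ready()```. The sketch below reproduces the same kind of measurement outside IPython. It relies on a hypothetical helper ```time_fn``` built on ```time.perf_counter```. The helper makes a warm-up call so that compilation time is excluded.

```python
import time
import jax
import jax.numpy as jnp

def time_fn(fn, *args, repeats=10):
    """Hypothetical helper: mean wall-clock seconds per call of fn(*args)."""
    # Warm-up call triggers JIT compilation, so it is not counted below
    fn(*args).block_until_ready()
    start = time.perf_counter()
    for _ in range(repeats):
        # Wait for the asynchronously dispatched result before moving on
        fn(*args).block_until_ready()
    return (time.perf_counter() - start) / repeats

@jax.jit
def f(x, A, b):
    res = A @ x - b
    return res @ res

n = 200
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))

jit_gradf = jax.jit(jax.grad(f))
mean_seconds = time_fn(jit_gradf, x, A, b)
print(f"{mean_seconds * 1e6:.1f} µs per call")
```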


```python
# jax.hessian(f) is implemented as jax.jacfwd(jax.jacrev(f)) (forward-over-reverse)
hess_func = jax.jit(jax.hessian(f))
print("Check correctness", jnp.linalg.norm(2 * A.T @ A - hess_func(x, A, b)))
print("Time for hessian")
%timeit hess_func(x, A, b).block_until_ready()
hess_emul_func = jax.jit(jax.jacfwd(jax.jacrev(f)))
print("Emulate hessian and check correctness", 
      jnp.linalg.norm(hess_func(x, A, b) - hess_emul_func(x, A, b)))
print("Time of emulating hessian")
%timeit hess_emul_func(x, A, b).block_until_ready()
```

```
Check correctness 0.0
Time for hessian
97.6 ms ± 3.78 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Emulate hessian and check correctness 0.0
Time of emulating hessian
96.2 ms ± 7.96 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
```
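
The structure behind ```jax.jacfwd(jax.jacrev(f))``` can be made explicit with auto-vectorization. The minimal sketch below uses the same least-squares ```f``` on a smaller problem. It builds the Hessian one basis direction at a time. Each direction costs one forward-over-reverse Hessian-vector product, and ```jax.vmap``` batches the $n$ products. This is why the explicit Hessian becomes expensive as $n$ grows. The helper ```hessian_via_hvps``` is hypothetical and was written for illustration.

```python
import jax
import jax.numpy as jnp

def f(x, A, b):
    res = A @ x - b
    return res @ res

def hessian_via_hvps(g, x):
    # Hypothetical helper: JVP of the gradient along e_i gives H @ e_i
    def hvp_col(e):
        return jax.jvp(jax.grad(g), (x,), (e,))[1]
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    # vmap runs all n HVPs as one vectorized computation;
    # H is symmetric, so stacking the columns as rows is fine
    return jax.vmap(hvp_col)(basis)

n = 50
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))

H = hessian_via_hvps(lambda x: f(x, A, b), x)
H_ref = jax.hessian(f)(x, A, b)
```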


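## Forward mode vs. backward mode: $m \ll n$

Here $f: \mathbb{R}^n \to \mathbb{R}^m$ is the scalar least-squares objective, so $m = 1$ and $n = 1000$.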

```python
fmode_f = jax.jit(jax.jacfwd(f))
bmode_f = jax.jit(jax.jacrev(f))
print("Check correctness", jnp.linalg.norm(fmode_f(x, A, b) - bmode_f(x, A, b)))
print("Forward mode")
%timeit fmode_f(x, A, b).block_until_ready()
print("Backward mode")
%timeit bmode_f(x, A, b).block_until_ready()
```

```
Check correctness 1.3948679012180167e-10
Forward mode
46.2 ms ± 1.13 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Backward mode
1.24 ms ± 64.8 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
```
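
Reverse mode wins here because of how each mode scales. For a scalar objective, a single reverse pass (a VJP with cotangent $1$) yields the full gradient. Forward mode instead needs one JVP per input direction, i.e. $n$ passes. The minimal sketch below makes both sets of passes explicit on a small problem.

```python
import jax
import jax.numpy as jnp

def f(x, A, b):
    res = A @ x - b
    return res @ res

n = 20
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))
g = lambda x: f(x, A, b)

# Reverse mode: one VJP with cotangent 1.0 returns the whole gradient
y, vjp_fn = jax.vjp(g, x)
grad_rev = vjp_fn(jnp.ones_like(y))[0]

# Forward mode: each JVP yields one directional derivative e_i^T grad f, so n passes
eye = jnp.eye(n, dtype=x.dtype)
grad_fwd = jnp.stack([jax.jvp(g, (x,), (e,))[1] for e in eye])
```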


## Forward mode vs. backward mode: $m \geq n$

```python
def fvec(x, A, b):
    # Softmax of Ax + b; subtracting max(y) keeps exp() from overflowing
    y = A @ x + b
    return jnp.exp(y - jnp.max(y)) / jnp.sum(jnp.exp(y - jnp.max(y)))
```
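
This map also has a closed-form Jacobian. With $s = \mathrm{softmax}(Ax + b)$, the chain rule gives $J = (\mathrm{diag}(s) - s s^\top) A$, an $m \times n$ matrix. The minimal sketch below checks that formula against ```jax.jacfwd``` on small, arbitrarily chosen sizes. The helper ```fvec_jacobian``` is hypothetical.

```python
import jax
import jax.numpy as jnp

def fvec(x, A, b):
    y = A @ x + b
    return jnp.exp(y - jnp.max(y)) / jnp.sum(jnp.exp(y - jnp.max(y)))

def fvec_jacobian(x, A, b):
    # Hypothetical helper: d softmax(y)/dy = diag(s) - s s^T, and dy/dx = A
    s = fvec(x, A, b)
    return (jnp.diag(s) - jnp.outer(s, s)) @ A

m, n = 30, 8
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (m, n))
b = jax.random.normal(kb, (m,))

J_manual = fvec_jacobian(x, A, b)
J_auto = jax.jacfwd(fvec)(x, A, b)
```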

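```python
# Raises when called: jax.grad requires a scalar-valued function
grad_fvec = jax.jit(jax.grad(fvec))
# jax.jacobian is an alias of jax.jacrev (reverse mode)
jac_fvec = jax.jacobian(fvec)
fmode_fvec = jax.jit(jax.jacfwd(fvec))
bmode_fvec = jax.jit(jax.jacrev(fvec))
```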

```python
n = 1000
m = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(0), (m, ))
```

```python
J = jac_fvec(x, A, b)
print(J.shape)
grad_fvec(x, A, b)
```

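```
(1000, 1000)
TypeError: Gradient only defined for scalar-output functions. Output had shape: (1000,).
```

As expected, ```jax.grad``` rejects ```fvec``` because its output is a vector. Jacobians of vector-valued functions come from ```jax.jacobian```, ```jax.jacfwd```, or ```jax.jacrev```.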
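
Sometimes only a weighted combination of the outputs is needed. In that case there is no need to build the full Jacobian. ```jax.vjp``` returns a function that computes $u^\top J$ in one reverse pass. Differentiating the scalar $u^\top f(x)$ with ```jax.grad``` gives the same vector. The minimal sketch below assumes an arbitrary random weight vector ```u```.

```python
import jax
import jax.numpy as jnp

def fvec(x, A, b):
    y = A @ x + b
    return jnp.exp(y - jnp.max(y)) / jnp.sum(jnp.exp(y - jnp.max(y)))

m, n = 30, 8
kx, kA, kb, ku = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (m, n))
b = jax.random.normal(kb, (m,))
u = jax.random.normal(ku, (m,))

# Vector-Jacobian product: one reverse pass, the m x n Jacobian is never formed
_, vjp_fn = jax.vjp(lambda x: fvec(x, A, b), x)
uJ_vjp = vjp_fn(u)[0]

# Equivalent: jax.grad is fine once the output is reduced to the scalar u^T fvec(x)
uJ_grad = jax.grad(lambda x: u @ fvec(x, A, b))(x)

# Reference via the explicit Jacobian
uJ_ref = u @ jax.jacrev(fvec)(x, A, b)
```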

```python
print("Check correctness", jnp.linalg.norm(fmode_fvec(x, A, b) - bmode_fvec(x, A, b)))
print("Check shape", fmode_fvec(x, A, b).shape, bmode_fvec(x, A, b).shape)
print("Time forward mode")
%timeit fmode_fvec(x, A, b).block_until_ready()
print("Time backward mode")
%timeit bmode_fvec(x, A, b).block_until_ready()
```

```
Check correctness 7.941016085443863e-16
Check shape (1000, 1000) (1000, 1000)
Time forward mode
55.9 ms ± 3.51 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time backward mode
56.3 ms ± 3.1 ms per loop (mean ± std. dev. of 7 runs, 10 loops each)
```


```python
n = 10
m = 1000
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (m, n))
b = jax.random.normal(jax.random.PRNGKey(0), (m, ))
```

```python
print("Check correctness", jnp.linalg.norm(fmode_fvec(x, A, b) - bmode_fvec(x, A, b)))
print("Check shape", fmode_fvec(x, A, b).shape, bmode_fvec(x, A, b).shape)
print("Time forward mode")
%timeit fmode_fvec(x, A, b).block_until_ready()
print("Time backward mode")
%timeit bmode_fvec(x, A, b).block_until_ready()
```

```
Check correctness 7.297678299520367e-16
Check shape (1000, 10) (1000, 10)
Time forward mode
219 µs ± 13.4 µs per loop (mean ± std. dev. of 7 runs, 1000 loops each)
Time backward mode
10.4 ms ± 428 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
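
These timings follow from how each mode assembles the Jacobian of $f: \mathbb{R}^n \to \mathbb{R}^m$. Forward mode pushes $n$ tangent vectors through JVPs, giving one column each. Reverse mode pulls $m$ cotangent vectors through VJPs, giving one row each. With $m = n$ the costs are similar. With $n = 10 \ll m = 1000$, forward mode needs only 10 passes. The minimal sketch below builds both versions with ```jax.vmap```. The helpers ```jac_forward``` and ```jac_reverse``` are hypothetical and handle only 1-D arrays. They are not the library's implementation.

```python
import jax
import jax.numpy as jnp

def fvec(x, A, b):
    y = A @ x + b
    return jnp.exp(y - jnp.max(y)) / jnp.sum(jnp.exp(y - jnp.max(y)))

def jac_forward(g, x):
    # n JVPs, one per input basis vector; each gives a Jacobian column
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    cols = jax.vmap(lambda v: jax.jvp(g, (x,), (v,))[1])(basis)  # shape (n, m)
    return cols.T

def jac_reverse(g, x):
    # m VJPs, one per output basis vector; each gives a Jacobian row
    y, vjp_fn = jax.vjp(g, x)
    basis = jnp.eye(y.shape[0], dtype=y.dtype)
    return jax.vmap(lambda u: vjp_fn(u)[0])(basis)  # shape (m, n)

m, n = 30, 8
kx, kA, kb = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (m, n))
b = jax.random.normal(kb, (m,))
g = lambda x: fvec(x, A, b)

J_fwd = jac_forward(g, x)
J_rev = jac_reverse(g, x)
```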


## Hessian-by-vector product 

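```python
def hvp(f, x, z, *args):
    # Close over the extra arguments so we differentiate w.r.t. x only
    def g(x):
        return f(x, *args)
    # Forward-over-reverse: the JVP of the gradient in direction z is H(x) z
    return jax.jvp(jax.grad(g), (x,), (z,))[1]
```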
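
The same product can also be computed reverse-over-reverse, by differentiating the scalar $\nabla f(x)^\top z$. Because the Hessian is symmetric, both routes give $\nabla^2 f(x)\, z$. The JAX Autodiff Cookbook discusses both options and favors forward-over-reverse, as used in ```hvp```. In the minimal sketch below, ```hvp_rev_rev``` is a hypothetical helper included only for comparison.

```python
import jax
import jax.numpy as jnp

def f(x, A, b):
    res = A @ x - b
    return res @ res

def hvp_fwd_rev(g, x, z):
    # Forward-over-reverse, same technique as hvp above
    return jax.jvp(jax.grad(g), (x,), (z,))[1]

def hvp_rev_rev(g, x, z):
    # Hypothetical alternative: gradient of <grad g(x), z> equals H^T z = H z
    return jax.grad(lambda x: jnp.vdot(jax.grad(g)(x), z))(x)

n = 40
kx, kA, kb, kz = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))
z = jax.random.normal(kz, (n,))
g = lambda x: f(x, A, b)

hz_fr = hvp_fwd_rev(g, x, z)
hz_rr = hvp_rev_rev(g, x, z)
hz_ref = 2 * A.T @ (A @ z)
```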

```python
n = 3000
# Note: x, b and z share PRNGKey(0) and shape, so they are identical arrays
x = jax.random.normal(jax.random.PRNGKey(0), (n, ))
A = jax.random.normal(jax.random.PRNGKey(0), (n, n))
b = jax.random.normal(jax.random.PRNGKey(0), (n, ))
z = jax.random.normal(jax.random.PRNGKey(0), (n, ))
```

```python
print("Check correctness", jnp.linalg.norm(2 * A.T @ (A @ z) - hvp(f, x, z, A, b)))
print("Time for hvp by hands")
%timeit (2 * A.T @ (A @ z)).block_until_ready()
print("Time for hvp via jvp, NO jit")
%timeit hvp(f, x, z, A, b).block_until_ready()
# Build the jitted function once; f is a Python callable, so it must be static
jit_hvp = jax.jit(hvp, static_argnums=0)
jit_hvp(f, x, z, A, b).block_until_ready()  # compile before timing
print("Time for hvp via jvp, WITH jit")
%timeit jit_hvp(f, x, z, A, b).block_until_ready()
```

```
Check correctness 8.374868878600283e-10
Time for hvp by hands
43.1 ms ± 923 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time for hvp via jvp, NO jit
31 ms ± 352 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
Time for hvp via jvp, WITH jit
8.61 ms ± 318 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
```
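
The sketch below compares the two routes on a smaller problem. One route forms the explicit Hessian with ```jax.hessian``` and multiplies it by ```z```. The other is the matrix-free forward-over-reverse product. The explicit route stores an $n \times n$ matrix and costs roughly $n$ Hessian-vector products to build. The matrix-free route needs a single pass and never materializes $\nabla^2 f(x)$. This is a minimal sketch with arbitrary sizes.

```python
import jax
import jax.numpy as jnp

def f(x, A, b):
    res = A @ x - b
    return res @ res

n = 100
kx, kA, kb, kz = jax.random.split(jax.random.PRNGKey(0), 4)
x = jax.random.normal(kx, (n,))
A = jax.random.normal(kA, (n, n))
b = jax.random.normal(kb, (n,))
z = jax.random.normal(kz, (n,))
g = lambda x: f(x, A, b)

# Explicit: materialize the n x n Hessian, then a dense matvec
Hz_explicit = jax.hessian(g)(x) @ z
# Matrix-free: one JVP of the gradient in direction z
Hz_hvp = jax.jvp(jax.grad(g), (x,), (z,))[1]
```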


## Summary

- JAX is a simple and extensible tool for problems where autodiff is crucial
- JIT compilation is key to fast JAX code
- Input/output dimensions determine whether forward or backward mode is cheaper
- Computing a Hessian-vector product directly is faster than forming the explicit Hessian matrix and multiplying it by a vector