# JAX

JAX is a package for "composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more".
* Core features
    * [Automatic differentiation](#Automatic-differentiation)
    * [Vectorization with vmap](#Vectorization-with-jax.vmap)
    * [Just-in-time compilation](#Just-in-time-compilation-with-jax.jit)
* Submodules
    * [Random number generation](#Random-number-generation)
* Use cases
    * [Gradient-based optimization](#Use-case:-Gradient-based-optimization)


In [1]:
import jax
import jax.numpy as np  # JAX's NumPy-compatible API (nowadays usually imported as jnp)
import numpy as onp     # original NumPy

## Automatic differentiation

`jax.grad` takes a scalar-valued function and returns a new function that evaluates its derivative (gradient).

In [2]:
jax.grad(np.tanh)(2.0)

DeviceArray(0.07065082, dtype=float32)
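As a sanity check, the result can be compared against the closed-form derivative d/dx tanh(x) = 1 - tanh(x)^2. A minimal sketch:

```python
import jax
import jax.numpy as np  # the notebook's alias for jax.numpy

x = 2.0
autodiff = jax.grad(np.tanh)(x)  # derivative via automatic differentiation
analytic = 1 - np.tanh(x) ** 2   # closed form: sech^2(x)
print(autodiff, analytic)
```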

Higher-order derivatives are obtained by simply nesting `jax.grad`.

In [3]:
jax.grad(jax.grad(np.tanh))(2.0)

DeviceArray(-0.13621867, dtype=float32)
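Nesting can be automated with a small helper; `nth_derivative` below is a hypothetical name, not part of JAX. The second derivative is compared with the analytic value -2 tanh(x) (1 - tanh(x)^2):

```python
import jax
import jax.numpy as np

def nth_derivative(f, n):
    """Hypothetical helper: n-th derivative of a scalar function by nesting jax.grad."""
    for _ in range(n):
        f = jax.grad(f)
    return f

x = 2.0
t = np.tanh(x)
s = 1 - t ** 2                    # first derivative of tanh
second_analytic = -2 * t * s      # d/dx (1 - tanh^2) = -2 tanh(x) (1 - tanh(x)^2)
second_autodiff = nth_derivative(np.tanh, 2)(x)
print(second_autodiff, second_analytic)
```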

The keyword argument `argnums` specifies the argument(s) with respect to which the function is differentiated (default: the first).
`argnums` can also be a tuple to calculate several gradients at the same time.

In [4]:
def foo(x, a, b):
    return a * x + b

dfoo_dx = jax.grad(foo, argnums=0)
dfoo_dx(0.1, 2, 3)

DeviceArray(2., dtype=float32)
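A minimal sketch of the tuple form of `argnums`, reusing `foo` from above. One point to be aware of: every argument being differentiated must be floating point, which is why `a` and `b` are passed as `2.0` and `3.0` here (in the cell above only `x` is differentiated, so the integers are fine).

```python
import jax

def foo(x, a, b):
    return a * x + b

# Differentiate w.r.t. all three arguments at once; the result is a tuple of gradients.
# Passing the integer 2 instead of 2.0 would make jax.grad raise a TypeError,
# because a (now differentiated) input would not be floating point.
grads = jax.grad(foo, argnums=(0, 1, 2))(0.1, 2.0, 3.0)
print(grads)  # (d/dx, d/da, d/db) = (a, x, 1)
```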

`jax.value_and_grad` returns the function value and the gradient in a single call.

In [5]:
val_grad_foo = jax.value_and_grad(foo, argnums=0)
val_grad_foo(0.1, 2, 3)

(DeviceArray(3.2, dtype=float32), DeviceArray(2., dtype=float32))
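Returning both quantities is convenient in optimization loops, where the loss is usually logged while the gradient drives the update. A minimal sketch of plain gradient descent on a hypothetical quadratic loss (`loss`, the step size and the number of steps are illustrative assumptions):

```python
import jax

def loss(w):
    # Hypothetical toy loss with its minimum at w = 3.
    return (w - 3.0) ** 2

loss_and_grad = jax.value_and_grad(loss)

w = 0.0
learning_rate = 0.1  # assumed step size
for step in range(100):
    value, g = loss_and_grad(w)  # one call yields both the loss and its gradient
    w = w - learning_rate * g    # gradient-descent update
print(w, value)
```

With this step size the distance to the minimum shrinks by a factor of 0.8 per step, so 100 steps are more than enough.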

### Jacobian
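`jax.grad` only works for scalar-valued functions. To differentiate vector-valued functions we need the full Jacobian, computed with `jax.jacfwd` (forward mode) or `jax.jacrev` (reverse mode). For a scalar-valued function, all three give the same result: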

In [6]:
def f(x):
    """Scalar-valued function with 3 inputs."""
    x1, x2, x3 = x
    return x1 + x2 * np.sin(x3)

x = np.ones(3)
print(jax.grad(f)(x))
print(jax.jacrev(f)(x))
print(jax.jacfwd(f)(x))


[1.         0.84147096 0.5403023 ]
[1.         0.84147096 0.5403023 ]
[1.         0.84147096 0.5403023 ]


In [7]:
def f(x):
    """Vector-valued function with 3 inputs and 2 outputs"""
    x1, x2, x3 = x
    return np.array([
        x1 * np.sin(x2),
        x1 * x2 * x3**2
    ])

x = np.ones(3)
print('f(x):\n', f(x))
print('df/dx(x):')
print(jax.jacfwd(f)(x))

f(x):
 [0.84147096 1.        ]
df/dx(x):
[[0.84147096 0.5403023  0.        ]
 [1.         1.         2.        ]]
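For the vector-valued `f`, `jax.grad` is not an option: it raises a `TypeError` because the output is not a scalar. A small sketch (re-defining `f` so it runs on its own) that checks this and confirms that forward and reverse mode give the same 2x3 Jacobian:

```python
import jax
import jax.numpy as np

def f(x):
    """Vector-valued function with 3 inputs and 2 outputs (as above)."""
    x1, x2, x3 = x
    return np.array([x1 * np.sin(x2), x1 * x2 * x3**2])

x = np.ones(3)

# jax.grad only accepts scalar-valued functions ...
grad_failed = False
try:
    jax.grad(f)(x)
except TypeError as err:
    grad_failed = True
    print('jax.grad failed:', err)

# ... whereas forward- and reverse-mode Jacobians agree (shape: outputs x inputs).
J_fwd = jax.jacfwd(f)(x)
J_rev = jax.jacrev(f)(x)
print(J_fwd.shape, np.allclose(J_fwd, J_rev))
```

Both modes compute the same matrix but at different cost: forward mode builds it one input column at a time, reverse mode one output row at a time, so `jacfwd` tends to pay off for tall Jacobians (more outputs than inputs) and `jacrev` for wide ones.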


### Hessian
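The Hessian is the Jacobian of the Jacobian, so we compose the two Jacobian transforms; forward-over-reverse (`jacfwd(jacrev(f))`) is typically the most efficient combination. Applied to the vector-valued `f` from above, this yields one 3x3 Hessian per output component.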

In [8]:
def hessian(f):
    return jax.jacfwd(jax.jacrev(f))

H = hessian(f)(x)
print('H, shape', H.shape)
print(H)

H, shape (2, 3, 3)
[[[ 0.          0.5403023   0.        ]
  [ 0.5403023  -0.84147096  0.        ]
  [ 0.          0.          0.        ]]

 [[ 0.          1.          2.        ]
  [ 1.          0.          2.        ]
  [ 2.          2.          2.        ]]]
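JAX also ships a ready-made `jax.hessian`, which uses the same forward-over-reverse composition. A quick check (redefining `f` from In [7] so the block is self-contained) that it matches the hand-written `hessian` and that each output's Hessian is symmetric:

```python
import jax
import jax.numpy as np

def f(x):
    x1, x2, x3 = x
    return np.array([x1 * np.sin(x2), x1 * x2 * x3**2])

def hessian(f):
    # forward-over-reverse, as in the cell above
    return jax.jacfwd(jax.jacrev(f))

x = np.ones(3)
H_manual = hessian(f)(x)
H_builtin = jax.hessian(f)(x)
print(H_builtin.shape, np.allclose(H_manual, H_builtin))

# Each Hessian is symmetric in the two input axes (the last two axes here).
print(np.allclose(H_builtin, np.swapaxes(H_builtin, -1, -2)))
```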


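## Vectorization with jax.vmap

Differentiating a function that operates on a whole batch at once yields the Jacobian with respect to every entry of the batch. Here each output `f(x)[i]` depends only on row `i`, so most entries are zero: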

In [9]:
def f(x):
    return x[:, 0] * np.sin(x[:, 1])

x = np.ones((3, 2))
jax.jacrev(f)(x)

DeviceArray([[[0.84147096, 0.5403023 ],
              [0.        , 0.        ],
              [0.        , 0.        ]],

             [[0.        , 0.        ],
              [0.84147096, 0.5403023 ],
              [0.        , 0.        ]],

             [[0.        , 0.        ],
              [0.        , 0.        ],
              [0.84147096, 0.5403023 ]]], dtype=float32)

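`jax.vmap` avoids this: we write `f` for a single row `[x0, x1]`, and `vmap` maps it over the leading axis of the input.

In [10]: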
def f(x):
    return x[0] * np.sin(x[1])

x = np.ones((5, 2))
y = jax.vmap(f)(x)
print('f(x):', y)

f(x): [0.84147096 0.84147096 0.84147096 0.84147096 0.84147096]
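Conceptually, `jax.vmap(f)` behaves like a Python loop over the leading axis whose results are stacked, but it produces a single vectorized computation instead. The sketch below checks that equivalence on arbitrary test data and uses `in_axes` to batch only one of two arguments (`scaled_f` is a hypothetical helper; `None` marks an argument that is shared across the batch):

```python
import jax
import jax.numpy as np

def f(x):
    """Acts on a single row x = [x0, x1]."""
    return x[0] * np.sin(x[1])

x = np.arange(10.0).reshape(5, 2)

batched = jax.vmap(f)(x)                  # vectorized over the leading axis
looped = np.stack([f(row) for row in x])  # explicit Python loop, same result
print(np.allclose(batched, looped))

def scaled_f(x, scale):
    # Hypothetical helper with an extra, non-batched parameter.
    return scale * f(x)

# x is mapped over axis 0, scale is broadcast to every row.
scaled = jax.vmap(scaled_f, in_axes=(0, None))(x, 2.0)
print(scaled)
```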


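`jax.vmap` composes with the differentiation transforms: wrapping `jacfwd`/`jacrev` in `vmap` gives per-row Jacobians and Hessians. Instead of the mostly-zero Jacobian of In [9], we get one gradient of length 2 and one 2x2 Hessian per row.

In [11]:
J = jax.vmap(jax.jacfwd(f))(x)
print(J)

[[0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]]


In [12]:
H = jax.vmap(jax.jacfwd(jax.jacrev(f)))(x)
print(H)

[[[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]]

## Just-in-time compilation with jax.jit

`jax.jit` traces a Python function once for each combination of input shapes and dtypes, compiles the traced computation with XLA, and reuses the compiled code on later calls with matching inputs. It composes with `grad` and `vmap` like the other transformations.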

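A minimal sketch of `jax.jit` in combination with the transformations above, using the single-row `f` from the vmap section. `jax.hessian` is JAX's built-in forward-over-reverse Hessian, and `batched_hessian` is an illustrative name. The compiled function should agree with the uncompiled version from In [12]:

```python
import jax
import jax.numpy as np

def f(x):
    """Single-row function from the vmap section."""
    return x[0] * np.sin(x[1])

# Transformations compose: per-row Hessians, vectorized, then compiled.
batched_hessian = jax.jit(jax.vmap(jax.hessian(f)))

x = np.ones((5, 2))
H_jit = batched_hessian(x)    # first call traces and compiles for shape (5, 2)
H_again = batched_hessian(x)  # same shape and dtype: reuses the compiled code
H_eager = jax.vmap(jax.jacfwd(jax.jacrev(f)))(x)  # uncompiled, as in In [12]
print(H_jit.shape, np.allclose(H_jit, H_eager))
```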

## Random number generation
JAX has its own pseudorandom number generator (PRNG), designed to be reproducible and to work well with parallel and vectorized code.
A peculiarity compared to e.g. `numpy` is that JAX's random functions do not modify the key they are given (there is no hidden global state), so repeated calls with the same key yield the same random numbers.

In [13]:
key = jax.random.PRNGKey(42)

for i in range(2):
    x = jax.random.uniform(key, shape=(3,))
    print('key =', key, 'x =', x)


key = [ 0 42] x = [0.57414436 0.10015821 0.05946112]
key = [ 0 42] x = [0.57414436 0.10015821 0.05946112]


To get fresh random numbers, we need to explicitly derive new keys by splitting an existing one with `jax.random.split`.

In [14]:
for i in range(2):
    key, = jax.random.split(key, num=1)  # num new keys are created
    x = jax.random.uniform(key, shape=(3,))
    print('key =', key, 'x =', x)

key = [  64467757 2916123636] x = [0.05988741 0.19778168 0.13219142]
key = [2350016172 1168365246] x = [0.853312   0.68688035 0.85908866]
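The cell above replaces `key` by its single child and then uses that key both for sampling and for the next split. The pattern usually recommended in JAX's documentation splits off a separate subkey for each draw, so a key consumed by a sampler is never split again. A minimal sketch (`samples`, `keys` and `batch` are illustrative names), which also shows that `split` can create several keys at once and that `vmap` can map over them:

```python
import jax

key = jax.random.PRNGKey(42)

samples = []
for i in range(2):
    # Keep `key` only for further splitting; consume `subkey` for sampling.
    key, subkey = jax.random.split(key)
    samples.append(jax.random.uniform(subkey, shape=(3,)))

# Several independent keys in one call, then one draw per key via vmap.
keys = jax.random.split(key, num=4)
batch = jax.vmap(lambda k: jax.random.normal(k, shape=(2,)))(keys)
print(batch.shape)
```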


## Use case: Gradient-based optimization

Here we use JAX to provide the gradient of the objective and the Jacobians of the constraints to the SLSQP optimizer in `scipy.optimize`, in order to minimize the bivariate Rosenbrock function subject to some nonlinear (in)equality constraints and box bounds.

In [15]:
import scipy.optimize

def rosen(x):
    return np.sum(100 * (x[1] - x[0]**2)**2 + (1 - x[0])**2)

def ineq(x):
    return np.array([
        1 - x[0] - 2 * x[1],
        1 - x[0]**2 - x[1],
        1 - x[0]**2 + x[1]
    ])

def eq(x):
    return np.array([2 * x[0] + x[1] - 1])

res = scipy.optimize.minimize(
    rosen,
    jac=jax.grad(rosen),
    x0=np.array([0.5, 0]),
    bounds=[(0, 1), (-0.5, 2.0)],
    constraints=[
        dict(type='eq', fun=eq, jac=jax.jacfwd(eq)),
        dict(type='ineq', fun=ineq, jac=jax.jacfwd(ineq))
    ],
    options={'ftol': 1e-9, 'disp': True},
    method='SLSQP',
)
res

Optimization terminated successfully.    (Exit mode 0)
            Current function value: 0.34271758794784546
            Iterations: 4
            Function evaluations: 5
            Gradient evaluations: 4


     fun: 0.34271758794784546
     jac: array([-0.82676458, -0.41372478])
 message: 'Optimization terminated successfully.'
    nfev: 5
     nit: 4
    njev: 4
  status: 0
 success: True
       x: array([0.41494475, 0.1701105 ])
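With `jac=jax.grad(rosen)`, SciPy evaluates the objective and its gradient through two separate functions, and both run op by op without compilation. A possible refinement, sketched below, passes `jac=True` so that `scipy.optimize.minimize` expects a single function returning `(value, gradient)`, and compiles that function with `jax.jit`. The wrappers `fun_and_jac` and `as_numpy` are illustrative names; they convert JAX's float32 results into float64 NumPy arrays for SciPy, which does not add precision beyond float32.

```python
import numpy as onp
import scipy.optimize
import jax
import jax.numpy as np

def rosen(x):
    return np.sum(100 * (x[1] - x[0]**2)**2 + (1 - x[0])**2)

def ineq(x):
    return np.array([1 - x[0] - 2 * x[1], 1 - x[0]**2 - x[1], 1 - x[0]**2 + x[1]])

def eq(x):
    return np.array([2 * x[0] + x[1] - 1])

# Objective value and gradient, compiled together once.
rosen_val_grad = jax.jit(jax.value_and_grad(rosen))

def fun_and_jac(x):
    # Illustrative wrapper: SciPy passes float64 NumPy arrays; return (float, float64 array).
    value, grad = rosen_val_grad(np.asarray(x))
    return float(value), onp.asarray(grad, dtype=onp.float64)

def as_numpy(f):
    # Illustrative wrapper: hand SciPy plain float64 NumPy arrays.
    return lambda x: onp.asarray(f(np.asarray(x)), dtype=onp.float64)

res = scipy.optimize.minimize(
    fun_and_jac,
    x0=onp.array([0.5, 0.0]),
    jac=True,  # fun returns (value, gradient)
    bounds=[(0, 1), (-0.5, 2.0)],
    constraints=[
        dict(type='eq', fun=as_numpy(eq), jac=as_numpy(jax.jacfwd(eq))),
        dict(type='ineq', fun=as_numpy(ineq), jac=as_numpy(jax.jacfwd(ineq))),
    ],
    options={'ftol': 1e-9},
    method='SLSQP',
)
print(res.x, res.fun)
```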