# JAX

JAX is a package for "composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more".
* Core features
    * [Automatic differentiation](#Automatic-differentiation)
    * [Vectorization with vmap](#Vectorization-with-jax.vmap)
    * [Just-in-time compilation](#Just-in-time-compilation-with-jax.jit)
    * [Gotchas](#Gotchas)
* Submodules
    * [Random number generation](#Random-number-generation)
* Use cases
    * [Gradient-based optimization](#Use-case:-Gradient-based-optimization)
    * [Differentiable polynomial regression model](#Use-case:-Differentiable-polynomial-regression-model)

In [1]:
import jax
import jax.numpy as np
import numpy as onp

## Automatic differentiation

`jax.grad` returns a function that evaluates the derivative of a scalar-valued function.

In [2]:
jax.grad(np.tanh)(2.0)



DeviceArray(0.07065082, dtype=float32)

Higher-order derivatives are obtained by nesting `jax.grad`.

In [3]:
jax.grad(jax.grad(np.tanh))(2.0)

DeviceArray(-0.13621867, dtype=float32)

The keyword argument `argnums` specifies the positional arguments with respect to which the function is differentiated.
`argnums` can also take a tuple to calculate multiple gradients at the same time.

In [4]:
def f(x, a, b):
    return a * x + b

df_dx = jax.grad(f, argnums=0)  # df/dx
df_dx(0.1, 2, 3)

DeviceArray(2., dtype=float32)
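
As a minimal sketch of the tuple form of `argnums`, the block below differentiates a copy of the same affine function with respect to all three arguments at once. The name `affine` is only used here for illustration, and the arguments are passed as floats because `jax.grad` rejects integer values for the arguments it differentiates.

```python
import jax

def affine(x, a, b):
    return a * x + b

# One gradient per entry of argnums, returned as a tuple in the same order.
grads = jax.grad(affine, argnums=(0, 1, 2))(0.1, 2.0, 3.0)
df_dx, df_da, df_db = grads
print(df_dx, df_da, df_db)  # analytically: df/dx = a, df/da = x, df/db = 1
```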

`jax.value_and_grad` evaluates the function value and the gradient in a single call.

In [5]:
jax.value_and_grad(f, argnums=0)(0.1, 2, 3)

(DeviceArray(3.2, dtype=float32), DeviceArray(2., dtype=float32))

Functions that use Python control flow can be differentiated as well.

In [6]:
def f(x):
    if x < 1:
        return x
    else:
        return x**2

jax.grad(f)(1.1)

DeviceArray(2.2, dtype=float32)
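
The Python `if` works with `jax.grad` because the branch is chosen from the concrete input value. Under `jax.jit` the same `if` generally fails, since the value is not known while the function is traced. The sketch below, with the illustrative names `piecewise` and `piecewise_where`, checks both branches and shows a `jnp.where` variant that can also be compiled. It imports `jax.numpy` as `jnp` to avoid clashing with the `np` alias used in this notebook.

```python
import jax
import jax.numpy as jnp

def piecewise(x):
    # Python branching: evaluated with the concrete value of x.
    if x < 1:
        return x
    else:
        return x ** 2

def piecewise_where(x):
    # Same function expressed with jnp.where, which does not need a concrete x.
    return jnp.where(x < 1, x, x ** 2)

left = jax.grad(piecewise)(0.5)    # branch x      -> derivative 1
right = jax.grad(piecewise)(1.1)   # branch x ** 2 -> derivative 2 * x
right_jit = jax.jit(jax.grad(piecewise_where))(1.1)
print(left, right, right_jit)
```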

### Jacobian
`jax.grad` only handles scalar-valued functions. For the Jacobian of a vector-valued function we need `jacfwd` (forward mode) or `jacrev` (reverse mode) instead. For scalar-valued functions all three agree.

In [7]:
def f(x):
    """Scalar-valued function with 3 inputs."""
    x1, x2, x3 = x
    return x1 + x2 * np.sin(x3)

x = np.ones(3)
print(jax.grad(f)(x))
print(jax.jacrev(f)(x))
print(jax.jacfwd(f)(x))


[1.         0.84147096 0.5403023 ]
[1.         0.84147096 0.5403023 ]
[1.         0.84147096 0.5403023 ]


In [8]:
def f(x):
    """Vector-valued function with 3 inputs and 2 outputs"""
    x1, x2, x3 = x
    return np.array([
        x1 * np.sin(x2),
        x1 * x2 * x3**2
    ])

x = np.ones(3)
print('f(x):\n', f(x))
print('df/dx(x):')
print(jax.jacfwd(f)(x))

f(x):
 [0.84147096 1.        ]
df/dx(x):
[[0.84147096 0.5403023  0.        ]
 [1.         1.         2.        ]]
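
`jacfwd` and `jacrev` return the same matrix and differ only in how it is computed. Forward mode tends to be cheaper when there are few inputs, reverse mode when there are few outputs. A small sketch on a copy of the function above (named `two_outputs` here for illustration) confirms that the results agree:

```python
import jax
import jax.numpy as jnp

def two_outputs(x):
    x1, x2, x3 = x
    return jnp.array([x1 * jnp.sin(x2), x1 * x2 * x3 ** 2])

x0 = jnp.ones(3)
J_fwd = jax.jacfwd(two_outputs)(x0)  # forward mode: built column by column
J_rev = jax.jacrev(two_outputs)(x0)  # reverse mode: built row by row
print(J_fwd.shape, jnp.allclose(J_fwd, J_rev))
```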


### Hessian
For the Hessian we simply nest two Jacobian transformations.

In [9]:
def hessian(f):
    return jax.jacfwd(jax.jacrev(f))

H = hessian(f)(x)
print('H, shape', H.shape)
print(H)

H, shape (2, 3, 3)
[[[ 0.          0.5403023   0.        ]
  [ 0.5403023  -0.84147096  0.        ]
  [ 0.          0.          0.        ]]

 [[ 0.          1.          2.        ]
  [ 1.          0.          2.        ]
  [ 2.          2.          2.        ]]]
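
JAX also ships a convenience wrapper, `jax.hessian`. As a hedged sketch, the block below compares it with the manual `jacfwd(jacrev(...))` composition on the scalar-valued function from the Jacobian section (renamed `scalar_fn` here for illustration).

```python
import jax
import jax.numpy as jnp

def scalar_fn(x):
    x1, x2, x3 = x
    return x1 + x2 * jnp.sin(x3)

x0 = jnp.ones(3)
H_builtin = jax.hessian(scalar_fn)(x0)
H_manual = jax.jacfwd(jax.jacrev(scalar_fn))(x0)
print(H_builtin.shape)
print(jnp.allclose(H_builtin, H_manual))
```

For a scalar-valued function of three inputs the result is a symmetric 3×3 matrix, in contrast to the `(2, 3, 3)` array obtained for the two-output function above.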


## Vectorization with jax.vmap
A function written for a whole batch has a Jacobian that couples all samples, with zero blocks for every cross-sample derivative. `jax.vmap` instead maps a function written for a single sample over the leading axis of its input.

In [10]:
def f(x):
    return x[:, 0] * np.sin(x[:, 1])

x = np.ones((3, 2))
jax.jacrev(f)(x)

DeviceArray([[[0.84147096, 0.5403023 ],
              [0.        , 0.        ],
              [0.        , 0.        ]],

             [[0.        , 0.        ],
              [0.84147096, 0.5403023 ],
              [0.        , 0.        ]],

             [[0.        , 0.        ],
              [0.        , 0.        ],
              [0.84147096, 0.5403023 ]]], dtype=float32)

In [11]:
def f(x):
    return x[0] * np.sin(x[1])

x = np.ones((5, 2))
y = jax.vmap(f)(x)
print('f(x):', y)

f(x): [0.84147096 0.84147096 0.84147096 0.84147096 0.84147096]
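
When a function takes several arguments, the `in_axes` parameter of `jax.vmap` selects which of them are mapped. The following sketch uses a hypothetical two-argument function `weighted`, maps over the samples and shares the weights across the batch:

```python
import jax
import jax.numpy as jnp

def weighted(x, w):
    # Written for a single sample x of shape (2,) and weights w of shape (2,).
    return jnp.dot(x, w)

xs = jnp.ones((5, 2))
w = jnp.array([0.5, 2.0])

# Map over axis 0 of xs; None passes w unchanged to every sample.
ys = jax.vmap(weighted, in_axes=(0, None))(xs, w)
print(ys.shape)
```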


## Just-in-time compilation with jax.jit
`jax.vmap` composes with the other transformations: wrapping `jax.jacfwd` yields per-sample Jacobians and Hessians, and the resulting function can in turn be passed to `jax.jit` for compilation.

In [12]:
J = jax.vmap(jax.jacfwd(f))(x)
print(J)

[[0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]
 [0.84147096 0.5403023 ]]


In [13]:
H = jax.vmap(jax.jacfwd(jax.jacrev(f)))(x)
print(H)

[[[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]

 [[ 0.          0.5403023 ]
  [ 0.5403023  -0.84147096]]]
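
A batched Jacobian of this kind can additionally be compiled with `jax.jit`. The block below is a minimal sketch with illustrative names (`per_sample`, `batched_jac`). It only checks that the compiled and uncompiled versions agree and makes no claim about the speed-up, which depends on the function and the hardware.

```python
import jax
import jax.numpy as jnp

def per_sample(x):
    return x[0] * jnp.sin(x[1])

batched_jac = jax.vmap(jax.jacfwd(per_sample))
batched_jac_jit = jax.jit(batched_jac)

xs = jnp.ones((5, 2))
J_plain = batched_jac(xs)
# The first call traces and compiles; later calls with the same
# shapes and dtypes reuse the compiled function.
J_jit = batched_jac_jit(xs)
print(jnp.allclose(J_plain, J_jit))
```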


## Gotchas
JAX arrays are immutable, so in-place index assignments have to be expressed with `index_add`, `index_update` and `index`, which return a modified copy of the array.

In [14]:
from jax.ops import index, index_add, index_update

In [15]:
try:
    x = np.ones((3, 5))
    x[::2, 3:] += 4
except TypeError as err:
    print(err)

'<class 'jax.interpreters.xla.DeviceArray'>' object does not support item assignment. JAX arrays are immutable; perhaps you want jax.ops.index_update or jax.ops.index_add instead?


In [16]:
x = np.ones((3, 5))
x = index_add(x, index[::2, 3:], 4.)
x

DeviceArray([[1., 1., 1., 5., 5.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 5., 5.]], dtype=float32)

In [17]:
x = np.ones((3, 5))
x = index_update(x, index[0, :], 4.)
x

DeviceArray([[4., 4., 4., 4., 4.],
             [1., 1., 1., 1., 1.],
             [1., 1., 1., 1., 1.]], dtype=float32)
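
Recent JAX releases express the same updates through the `.at[...]` property of arrays, and the `jax.ops.index_*` helpers have been removed there. A minimal sketch, assuming a JAX version that provides `.at`:

```python
import jax.numpy as jnp

x0 = jnp.ones((3, 5))
added = x0.at[::2, 3:].add(4.0)   # counterpart of index_add
updated = x0.at[0, :].set(4.0)    # counterpart of index_update
print(added)
print(updated)
```

Both calls return new arrays and leave `x0` unchanged.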

## Random number generation
JAX has its own pseudo-random number generation system that focuses on parallel usage.
A peculiarity compared to e.g. `numpy` is that calls to the random number generator don't modify the random key, so multiple calls with the same key yield the same random numbers.

In [18]:
key = jax.random.PRNGKey(42)

for i in range(2):
    x = jax.random.uniform(key, shape=(3,))
    print('key =', key, 'x =', x)


key = [ 0 42] x = [0.57414436 0.10015821 0.05946112]
key = [ 0 42] x = [0.57414436 0.10015821 0.05946112]


We need to explicitly create new keys by splitting an existing one.

In [19]:
for i in range(2):
    key, = jax.random.split(key, num=1)  # num new keys are created
    x = jax.random.uniform(key, shape=(3,))
    print('key =', key, 'x =', x)

key = [  64467757 2916123636] x = [0.05988741 0.19778168 0.13219142]
key = [2350016172 1168365246] x = [0.853312   0.68688035 0.85908866]
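
A common variant splits the key in two on every iteration, keeping one half for later splits and spending the other on the current draw. The names `root_key` and `subkey` below are illustrative, and no output values are shown because they depend on the seed and the JAX version.

```python
import jax

root_key = jax.random.PRNGKey(0)
draws = []
for _ in range(2):
    # Keep root_key for future splits and use subkey for this draw only.
    root_key, subkey = jax.random.split(root_key)
    draws.append(jax.random.uniform(subkey, shape=(3,)))
print(draws)
```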


## Use case: Gradient-based optimization

Here we use JAX to provide gradients to the SLSQP optimizer in `scipy.optimize` in order to
minimize the bivariate Rosenbrock function, subject to some non-linear (in)equality constraints and box bounds.

In [20]:
import scipy.optimize

def rosen(x):
    return np.sum(100 * (x[1] - x[0]**2)**2 + (1 - x[0])**2)

def ineq(x):
    return np.array([
        1 - x[0] - 2 * x[1],
        1 - x[0]**2 - x[1],
        1 - x[0]**2 + x[1]
    ])

def eq(x):
    return np.array([2 * x[0] + x[1] - 1])

res = scipy.optimize.minimize(
    rosen,
    jac=jax.grad(rosen),
    x0=np.array([0.5, 0]),
    bounds=[(0, 1), (-0.5, 2.0)],
    constraints=[
        dict(type='eq', fun=eq, jac=jax.jacfwd(eq)),
        dict(type='ineq', fun=ineq, jac=jax.jacfwd(ineq)) 
    ], 
    options={'ftol': 1e-9, 'disp': True},
    method='SLSQP',
)
res

Optimization terminated successfully.    (Exit mode 0)
            Current function value: 0.34271758794784546
            Iterations: 4
            Function evaluations: 5
            Gradient evaluations: 4


     fun: 0.34271758794784546
     jac: array([-0.82676458, -0.41372478])
 message: 'Optimization terminated successfully.'
    nfev: 5
     nit: 4
    njev: 4
  status: 0
 success: True
       x: array([0.41494475, 0.1701105 ])
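
JAX computes in float32 by default, whereas SciPy's optimizers work with float64 NumPy arrays. One possible way to make the hand-over explicit is a small wrapper that compiles the JAX function and converts its output. `as_float64` and `rosen2` are hypothetical names introduced for this sketch, not part of JAX or SciPy, and the constraints from the example above are left out to keep it short.

```python
import numpy as onp
import jax
import jax.numpy as jnp
import scipy.optimize

def rosen2(x):
    return 100.0 * (x[1] - x[0] ** 2) ** 2 + (1.0 - x[0]) ** 2

def as_float64(fn):
    """Hypothetical helper: jit-compile fn and return float64 NumPy output."""
    compiled = jax.jit(fn)

    def wrapped(x):
        return onp.asarray(compiled(jnp.asarray(x)), dtype=onp.float64)

    return wrapped

rosen_value = as_float64(rosen2)
rosen_grad = as_float64(jax.grad(rosen2))

start = onp.array([0.5, 0.0])
result = scipy.optimize.minimize(
    lambda x: float(rosen_value(x)),  # SciPy expects a plain scalar here
    x0=start,
    jac=rosen_grad,
    bounds=[(0, 1), (-0.5, 2.0)],
    method="SLSQP",
)
print(result.x, result.fun)
```

The values are still computed in single precision. The wrapper only changes the container type that SciPy receives.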

## Use case: Differentiable polynomial regression model
Sklearn models generally don't provide gradients $df(x)/dx$. Here we set up a multivariate polynomial regression model with JAX, into which the fitted coefficients (linreg, PLS, ...) from sklearn can be plugged.

In [21]:
import sklearn.preprocessing
import sklearn.linear_model
import sklearn.pipeline
import itertools as it

In [22]:
# toy data, x in R^3 -> y in R^1
onp.random.seed(42)
N = 20
X_data = onp.random.rand(N, 3)
x1, x2, x3 = X_data.T
Y_data = 1 + x1 + 0.2 * x1 * x2 + x1**2 - 0.2 * x2 * x3 - 0.1 * x3**2 + 0.1 * onp.random.randn(N)

In [23]:
model = sklearn.pipeline.Pipeline([
    ("polyfeat", sklearn.preprocessing.PolynomialFeatures(degree=2, include_bias=True, interaction_only=False)),
    ("linreg", sklearn.linear_model.LinearRegression(fit_intercept=False)),
])
model.fit(X_data, Y_data)
p = model["linreg"].coef_
print(f"model coefficients\n{p}")

model coefficients
[ 0.98178781  0.84156272  0.53594876 -0.40367642  1.24280767 -0.25484535
  0.78677898 -0.38437492 -0.17086278 -0.14827962]


In [24]:
def combinations(n_features, degree, interaction_only, include_bias):
    """Index tuples of all monomials up to `degree` in `n_features` variables."""
    comb = (it.combinations if interaction_only else it.combinations_with_replacement)
    start = int(not include_bias)
    return it.chain.from_iterable(comb(range(n_features), i) for i in range(start, degree + 1))

def polyn(x, p, degree=2, interaction_only=False, include_bias=True):
    """Multivariate polynomial - Numpy version"""
    combs = [c for c in combinations(len(x), degree, interaction_only, include_bias)]
    out = onp.array(p)
    for (i, c) in enumerate(combs):
        for j in c:
            out[i] *= x[j]
    return out.sum()

@jax.jit
def jpolyn(x, p, degree=2, interaction_only=False, include_bias=True):
    """Multivariate polynomial - JAX version"""
    combs = [c for c in combinations(len(x), degree, interaction_only, include_bias)]
    out = np.array(p)
    for (i, c) in enumerate(combs):
        for j in c:
            out = jax.ops.index_update(out, jax.ops.index[i], out[i] * x[j])
    return out.sum()

In [25]:
x = np.array([1, 2, 3], dtype=np.float32)
print("original     ", model.predict(x.reshape(1, -1))[0])
print("polyn (numpy)", polyn(x, p))
print("polyn (jax)  ", jpolyn(x, p))

original      0.8804797755714371
polyn (numpy) 0.8804796997954933
polyn (jax)   0.8804792


And now we have gradients for our polynomials.

In [26]:
jax.grad(jpolyn, argnums=0)(x, p)

DeviceArray([ 5.1778245, -1.7689846, -0.8483008], dtype=float32)
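
The per-element index updates can be avoided by precomputing an integer exponent matrix with one row per monomial. The block below is a sketch under stated assumptions. `exponent_matrix` and `poly_eval` are hypothetical helpers, the monomials are enumerated with `combinations_with_replacement` in the same way as in `combinations` above, and the coefficients are placeholder ones rather than the fitted values.

```python
import itertools as it
import jax
import jax.numpy as jnp
import numpy as onp

def exponent_matrix(n_features, degree):
    """One row of integer exponents per monomial, bias term first."""
    rows = []
    for d in range(degree + 1):
        for combo in it.combinations_with_replacement(range(n_features), d):
            row = [0] * n_features
            for j in combo:
                row[j] += 1
            rows.append(row)
    return onp.array(rows)

powers = exponent_matrix(n_features=3, degree=2)  # shape (10, 3)

@jax.jit
def poly_eval(x, coef):
    # Each monomial is prod_j x_j ** powers[i, j]; the model is their weighted sum.
    monomials = jnp.prod(x[None, :] ** powers, axis=1)
    return jnp.dot(coef, monomials)

x_test = jnp.array([1.0, 2.0, 3.0])
coef = jnp.ones(powers.shape[0])  # placeholder coefficients, not the fitted ones
value = poly_eval(x_test, coef)
grad = jax.grad(poly_eval)(x_test, coef)
print(value, grad)
```

With all coefficients set to one, the gradient can be checked by hand: the first component is $1 + 2x_1 + x_2 + x_3$, and the other two follow by symmetry.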