# Introduction to JAX and SciPy

- [JAX documentation](https://jax.readthedocs.io/en/latest/)
- [SciPy documentation](https://docs.scipy.org/doc/scipy/)

### Imports for this lesson

In [4]:
import jax
import jax.numpy as jnp
import numpy as np
import time
from scipy.optimize import minimize, Bounds

## JAX

- [JAX tutorial](https://jax.readthedocs.io/en/latest/jax-101/index.html)

`JAX` exposes a `numpy`-like API through `jax.numpy` and can be thought of as *differentiable numpy*. As soon as we create a simple `JAX` array, we notice some differences w.r.t. `numpy`:
- the type is `jax.Array`, displayed as `Array` (older JAX versions used `DeviceArray`);
- the default floating-point `dtype` is `float32`, not `float64` as in `numpy`.

In [3]:
# create a simple jax array
x = jnp.array([1,2,3,5,7.5])
x

Array([1. , 2. , 3. , 5. , 7.5], dtype=float32)
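Both differences can be checked directly. The following is a minimal sketch. It assumes a recent JAX release, where arrays are instances of `jax.Array`, and that 64-bit mode is left at its default (off):

```python
import jax
import jax.numpy as jnp
import numpy as np

values = [1, 2, 3, 5, 7.5]
x_jax = jnp.array(values)  # JAX array
x_np = np.array(values)    # NumPy array built from the same values

# JAX arrays are instances of jax.Array (older releases showed DeviceArray)
print(type(x_jax), isinstance(x_jax, jax.Array))

# the default floating-point precision differs
print(x_jax.dtype, x_np.dtype)  # float32 float64

# converting back to NumPy keeps the float32 dtype chosen by JAX
x_back = np.asarray(x_jax)
print(x_back.dtype)  # float32
```

If double precision is needed, 64-bit mode can be enabled at startup with `jax.config.update("jax_enable_x64", True)`.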

Another important difference from `numpy` concerns modifying an array. `JAX` arrays are immutable, so in-place item assignment is not allowed.

In [6]:
# numpy
x = np.array([1,2,3])
print(f"Before:{x}")
x[0] = 4
print(f"After:{x}")

Before:[1 2 3]
After:[4 2 3]


In [9]:
# JAX 
x = jnp.array([1,2,3])
print(f"Before:{x}")
# raises a TypeError: JAX arrays are immutable, so item assignment is not supported
x[0] = 4
print(f"After:{x}")

Before:[1 2 3]


TypeError: '<class 'jaxlib.xla_extension.ArrayImpl'>' object does not support item assignment. JAX arrays are immutable. Instead of ``x[idx] = y``, use ``x = x.at[idx].set(y)`` or another .at[] method: https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.ndarray.at.html

In [14]:
x = jnp.array([1,2,3])
print(f"Before:{x}")
# correct syntax
x = x.at[0].set(4)
print(f"After:{x}")

Before:[1 2 3]
After:[4 2 3]
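The `.at[]` property supports more than `set`. The sketch below shows a few of its update methods (`set`, `add`, `multiply`, `max`), including a slice index. It also confirms that each call returns a new array and leaves the original unchanged:

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

y = x.at[0].set(4)       # like x[0] = 4
z = x.at[1].add(10)      # like x[1] += 10
w = x.at[2].multiply(5)  # like x[2] *= 5
m = x.at[0].max(7)       # like x[0] = max(x[0], 7)
s = x.at[1:].set(0)      # slices work too: like x[1:] = 0

print(x)  # [1 2 3]  (unchanged)
print(y, z, w, m, s)
```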


### Differentiation with JAX

#### 1. Gradient of a function of one argument

In [15]:
def square_norm(x):
  return jnp.sum(x**2)

In [23]:
# dtype matters: jax.grad requires floating-point inputs (integer arrays raise a TypeError)
x = jnp.array([1.,2.,3.,4.])
square_norm(x)

Array(30., dtype=float32)

In [22]:
# the gradient of norm(x)^2 is 2x
grad_square_norm = jax.grad(square_norm)
grad_square_norm(x)

Array([2., 4., 6., 8.], dtype=float32)
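The sketch below illustrates two related points. First, `jax.value_and_grad` returns the function value and its gradient in a single call. Second, `jax.grad` rejects integer inputs, which is why the array above is built from floats:

```python
import jax
import jax.numpy as jnp

def square_norm(x):
    return jnp.sum(x**2)

x = jnp.array([1., 2., 3., 4.])

# function value and gradient in one call
value, g = jax.value_and_grad(square_norm)(x)
print(value)  # 30.0
print(g)      # [2. 4. 6. 8.]

# integer inputs are rejected by jax.grad with a TypeError
try:
    jax.grad(square_norm)(jnp.array([1, 2, 3, 4]))
    int_input_rejected = False
except TypeError as err:
    int_input_rejected = True
    print("TypeError:", err)
```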

#### 2. Gradient of a function of two arguments

In [24]:
def square_distance(x, y):
  return jnp.sum((x-y)**2)

In [31]:
# gradient w.r.t. the first argument (default)
grad_x_square_distance = jax.grad(square_distance)
x = jnp.array([1.,2.,3.,4.])
y = jnp.array([1.5,2.5,3.5,4.5])
grad_x_square_distance(x,y)

Array([-1., -1., -1., -1.], dtype=float32)

In [33]:
# gradient w.r.t. the second argument
grad_y_square_distance = jax.grad(square_distance, argnums = 1)
x = jnp.array([1.,2.,3.,4.])
y = jnp.array([1.5,2.5,3.5,4.5])
grad_y_square_distance(x,y)

Array([1., 1., 1., 1.], dtype=float32)

In [38]:
# gradient w.r.t. both arguments
grad_xy_square_distance = jax.grad(square_distance, argnums=(0, 1))
x = jnp.array([1.,2.,3.,4.])
y = jnp.array([1.5,2.5,3.5,4.5])
grad_xy = grad_xy_square_distance(x, y)
print(grad_xy)
# the output is a tuple: (gradient w.r.t. x, gradient w.r.t. y)
print(type(grad_xy))

(Array([-1., -1., -1., -1.], dtype=float32), Array([1., 1., 1., 1.], dtype=float32))
<class 'tuple'>
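For $d(x, y) = \|x - y\|^2$ the gradients are $\nabla_x d = 2(x - y)$ and $\nabla_y d = -2(x - y)$. This matches the outputs above, since $x - y = -0.5$ componentwise. A short check, which also unpacks the tuple returned when `argnums=(0, 1)`:

```python
import jax
import jax.numpy as jnp

def square_distance(x, y):
    return jnp.sum((x - y)**2)

x = jnp.array([1., 2., 3., 4.])
y = jnp.array([1.5, 2.5, 3.5, 4.5])

# with argnums=(0, 1) the result is a tuple that can be unpacked directly
gx, gy = jax.grad(square_distance, argnums=(0, 1))(x, y)

# compare with the analytic gradients of ||x - y||^2
print(jnp.allclose(gx, 2 * (x - y)), jnp.allclose(gy, -2 * (x - y)))  # True True
```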


### Just-In-Time (JIT) compilation with JAX

In [39]:
# hyperbolic tangent definition
def tanh(x):
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))
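This definition works for moderate inputs, but it is a naive formula. In `float32`, `jnp.exp(x)` overflows to `inf` once `x` exceeds roughly 88, and `inf / inf` gives `nan`. The sketch below compares it with the built-in `jnp.tanh`. This is only a side note on numerics. The timing comparison that follows is unaffected, because its inputs lie in $[0, 1)$.

```python
import jax.numpy as jnp

def tanh(x):
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

# moderate inputs: agrees with jnp.tanh up to float32 rounding
xs = jnp.linspace(-5.0, 5.0, 11)
max_err = jnp.max(jnp.abs(tanh(xs) - jnp.tanh(xs)))
print(max_err)

# large input: exp(100.) overflows float32 to inf, so the ratio is inf/inf = nan
big = jnp.array([100.0])
print(tanh(big), jnp.tanh(big))  # [nan] [1.]
```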

In [64]:
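# create a pseudo-random number generator (PRNG) key from the seed 42,
# then sample 1,000,000 values uniformly in [0, 1)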
key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, (1000000, ))
x

Array([0.92596364, 0.12574375, 0.14114332, ..., 0.37422693, 0.46216547,
       0.4711089 ], dtype=float32)
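JAX random numbers are driven by an explicit key rather than a global seed. The same key always produces the same numbers, and fresh randomness comes from splitting the key with `jax.random.split`. A minimal sketch:

```python
import jax

key = jax.random.PRNGKey(42)

# the same key reproduces exactly the same sample
a = jax.random.uniform(key, (3,))
b = jax.random.uniform(key, (3,))
print((a == b).all())  # True

# split the key to get a new, independent subkey for the next draw
key, subkey = jax.random.split(key)
c = jax.random.uniform(subkey, (3,))
print(a, c)
```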

In [72]:
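print("Non-jit tanh: ")
# JAX dispatches work asynchronously: block_until_ready() waits for the result,
# so the timing covers the actual computation
%time tanh(x).block_until_ready()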

Non-jit tanh: 
CPU times: user 18.9 ms, sys: 0 ns, total: 18.9 ms
Wall time: 7.38 ms


Array([0.72870654, 0.12508513, 0.14021347, ..., 0.3576834 , 0.43184748,
       0.43909487], dtype=float32)

In [79]:
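# compile tanh with XLA; jax.jit returns a new, jitted function
tanh_jit = jax.jit(tanh)

# First call: may include tracing and compilation time
print("Jit tanh, first call:")
%time tanh_jit(x).block_until_ready()

# Second call: reuses the compiled version for the same input shape and dtype
print("Jit tanh, second call:")
%time tanh_jit(x).block_until_ready()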

Jit tanh, first call:
CPU times: user 7.86 ms, sys: 0 ns, total: 7.86 ms
Wall time: 2.51 ms
Jit tanh, second call:
CPU times: user 3.3 ms, sys: 770 µs, total: 4.07 ms
Wall time: 1.97 ms


Array([0.72870654, 0.12508513, 0.14021347, ..., 0.3576834 , 0.43184748,
       0.43909487], dtype=float32)
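`%time` measures a single run, so the numbers above are noisy. Whether a given call includes compilation also depends on what ran earlier in the session. The following is a minimal sketch of a more careful comparison, using the `time` module imported at the top. It runs each function once as a warm-up, which keeps JIT compilation out of the timing, and then keeps the best of several timed runs. The helper `best_time` is hypothetical and written only for this example. Timings depend on hardware, so no expected values are shown.

```python
import time
import jax
import jax.numpy as jnp

def tanh(x):
    return (jnp.exp(x) - jnp.exp(-x)) / (jnp.exp(x) + jnp.exp(-x))

tanh_jit = jax.jit(tanh)
x = jax.random.uniform(jax.random.PRNGKey(42), (1_000_000,))

def best_time(f, x, repeats=10):
    """Hypothetical helper: best wall-clock time of f(x) over several runs, in seconds."""
    f(x).block_until_ready()  # warm-up run: excludes compilation from the timing
    best = float("inf")
    for _ in range(repeats):
        t0 = time.perf_counter()
        f(x).block_until_ready()  # wait for the asynchronous computation to finish
        best = min(best, time.perf_counter() - t0)
    return best

t_eager = best_time(tanh, x)
t_jit = best_time(tanh_jit, x)
print(f"eager: {t_eager * 1e3:.2f} ms, jit: {t_jit * 1e3:.2f} ms")

# both versions compute the same values (up to float32 rounding)
print(jnp.allclose(tanh(x), tanh_jit(x), atol=1e-6))
```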

## SciPy

[Optimization in SciPy](https://docs.scipy.org/doc/scipy/tutorial/optimize.html)

### Unconstrained optimization

To learn how to minimize a function using `scipy.optimize.minimize`, we consider the problem of minimizing the *Rosenbrock function*:
$$f(x) = \sum_{i = 1}^{n-1} \left[ 100 (x_{i+1} - x_{i}^2)^2 + (1 - x_i)^2 \right],$$
where $x = [x_1, \dots, x_n] \in \mathbb{R}^n$. This function has a global minimum at $\tilde{x} = [1, \dots, 1]$, where $f(\tilde{x}) = 0$.

In [5]:
def rosenbrock(x):
    """The Rosenbrock function"""
    return jnp.sum(100.0*(x[1:]-x[:-1]**2.0)**2.0 + (1.-x[:-1])**2.0)
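The claim that $\tilde{x} = [1, \dots, 1]$ is a minimizer can be checked numerically: both the function and its gradient vanish there. The sketch below also compares the JAX version at an arbitrary point with `scipy.optimize.rosen`, SciPy's own NumPy implementation of the same function:

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import rosen

def rosenbrock(x):
    """The Rosenbrock function"""
    return jnp.sum(100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1. - x[:-1])**2.0)

ones = jnp.ones(5)
print(rosenbrock(ones))            # 0.0
print(jax.grad(rosenbrock)(ones))  # a vector of zeros

# cross-check against SciPy's reference implementation at another point
p = np.array([1.8, 0.7, 0.8, 1.9, 1.2])
print(float(rosenbrock(p)), rosen(p))  # both approximately 1395.92
```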

In [7]:
# unconstrained minimization without gradient using Nelder-Mead
x0 = jnp.array([1.8, 0.7, 0.8, 1.9, 1.2])
res = minimize(rosenbrock, x0, method='nelder-mead', options={'xatol': 1e-8, 'disp': True})
res.x

Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 395
         Function evaluations: 650


array([1., 1., 1., 1., 1.])

In [8]:
# unconstrained minimization with jitted gradient using BFGS
x0 = jnp.array([1.8, 0.7, 0.8, 1.9, 1.2])
rosenbrock_grad = jax.jit(jax.grad(rosenbrock))
res = minimize(rosenbrock, x0, method='BFGS', 
               jac=rosenbrock_grad, options={'gtol': 1e-8, 'disp': True})
res.x

Optimization terminated successfully.
         Current function value: 0.000000
         Iterations: 28
         Function evaluations: 34
         Gradient evaluations: 34


array([0.99999998, 0.99999998, 1.00000001, 1.00000002, 1.00000004])
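BFGS needs both $f$ and $\nabla f$ at each trial point. `minimize` also accepts `jac=True`, which means that `fun` returns the pair `(value, gradient)`. With `jax.value_and_grad`, both come from a single jitted evaluation. This sketch makes two assumptions. First, `fun_and_grad` and `rosen_value_and_grad` are hypothetical names written for this example. Second, JAX runs in its default `float32` mode, which limits the accuracy attainable near the minimum.

```python
import jax
import jax.numpy as jnp
import numpy as np
from scipy.optimize import minimize

def rosenbrock(x):
    """The Rosenbrock function"""
    return jnp.sum(100.0*(x[1:] - x[:-1]**2.0)**2.0 + (1. - x[:-1])**2.0)

# one jitted function returning (f(x), grad f(x))
rosen_value_and_grad = jax.jit(jax.value_and_grad(rosenbrock))

def fun_and_grad(x):
    """Hypothetical helper: evaluate f and its gradient, returned as NumPy float64."""
    value, grad = rosen_value_and_grad(x)
    return float(value), np.asarray(grad, dtype=np.float64)

x0 = np.array([1.8, 0.7, 0.8, 1.9, 1.2])
# jac=True: SciPy reads the gradient from the second element of fun's output
res = minimize(fun_and_grad, x0, method='BFGS', jac=True)
print(res.x, res.nit)
```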

### Constrained optimization

Here we solve the following 2-dimensional constrained optimization problem, whose objective is the Rosenbrock function for $n = 2$ (with indices starting from 0):

\begin{equation}
\begin{aligned}
\min_{x_0,x_1} \quad & 100(x_1 - x_0^2)^2 + (1 - x_0)^2,\\
\textrm{s.t.} \quad & x_0 + 2 x_1 \leq 1,\\
& x_0^2 + x_1 \leq 1, \\
& x_0^2 - x_1 \leq 1, \\
& 2x_0 + x_1 = 1, \\
& 0 \leq x_0 \leq 1, \quad -0.5 \leq x_1 \leq 2.
\end{aligned}
\end{equation}

In [17]:
x0 = np.array([0.5, 0])

# inequality constraints, written as fun(x) >= 0 (SciPy's 'ineq' convention)
ineq_cons = {'type': 'ineq',
             'fun' : lambda x: np.array([1 - x[0] - 2*x[1],
                                         1 - x[0]**2 - x[1],
                                         1 - x[0]**2 + x[1]]),
             'jac' : lambda x: np.array([[-1.0, -2.0],
                                         [-2*x[0], -1.0],
                                         [-2*x[0], 1.0]])}
# equality constraint, written as fun(x) == 0 (SciPy's 'eq' convention)
eq_cons = {'type': 'eq',
           'fun' : lambda x: np.array([2*x[0] + x[1] - 1]),
           'jac' : lambda x: np.array([2.0, 1.0])}

# 0 <= x_0 <= 1 and -0.5 <= x_1 <= 2
bounds = Bounds([0, -0.5], [1.0, 2.0])

res = minimize(rosenbrock, x0, method='SLSQP', jac=rosenbrock_grad,
               constraints=[eq_cons, ineq_cons], options={'ftol': 1e-9, 'disp': True},
               bounds=bounds)
res.x

Optimization terminated successfully    (Exit mode 0)
            Current function value: 0.34271758794784546
            Iterations: 4
            Function evaluations: 5
            Gradient evaluations: 4


array([0.41494474, 0.17011051])
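It is good practice to check that the returned point is actually feasible. The self-contained sketch below re-solves the same problem and evaluates each constraint at `res.x`. Under SciPy's conventions, `'ineq'` requires `fun(x) >= 0` and `'eq'` requires `fun(x) == 0`. To keep the example independent of JAX, it swaps in SciPy's NumPy functions `rosen` and `rosen_der` for the JAX objective and gradient.

```python
import numpy as np
from scipy.optimize import minimize, Bounds, rosen, rosen_der

# same constraints as above, in SciPy's dictionary form
ineq_cons = {'type': 'ineq',
             'fun': lambda x: np.array([1 - x[0] - 2*x[1],
                                        1 - x[0]**2 - x[1],
                                        1 - x[0]**2 + x[1]]),
             'jac': lambda x: np.array([[-1.0, -2.0],
                                        [-2*x[0], -1.0],
                                        [-2*x[0], 1.0]])}
eq_cons = {'type': 'eq',
           'fun': lambda x: np.array([2*x[0] + x[1] - 1]),
           'jac': lambda x: np.array([2.0, 1.0])}
bounds = Bounds([0, -0.5], [1.0, 2.0])

res = minimize(rosen, np.array([0.5, 0.0]), method='SLSQP', jac=rosen_der,
               constraints=[eq_cons, ineq_cons], bounds=bounds,
               options={'ftol': 1e-9})

x = res.x
print("solution:", x)
print("inequality values (should be >= 0):", ineq_cons['fun'](x))
print("equality residual (should be ~0):", eq_cons['fun'](x))
print("within bounds:", np.all(x >= bounds.lb) and np.all(x <= bounds.ub))
```

At this solution none of the three inequalities is active: all of them are strictly positive. The minimizer is therefore determined by the equality constraint $2x_0 + x_1 = 1$ and the objective.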