# Jax

## Define a problem in Jax

This example is not intended to cover all the features of Jax.
For more details and tutorials on Jax, please refer to **[Jax's documentation](https://jax.readthedocs.io/)**.
In this example, we solve a constrained problem given by

$$
\begin{aligned}
\underset{x_1, x_2 \in \mathbb{R}}{\text{minimize}} \quad & x_1^2 + x_2^2 \\
\text{subject to} \quad & x_1 \geq 0 \\
& x_1 + x_2 = 1 \\
& x_1 - x_2 \geq 1
\end{aligned}
$$

The solution of this problem is $x_1=1$ and $x_2=0$.
For the purposes of this tutorial, however, we start from the initial guess $x_1=500$ and $x_2=5$ (the `x0` passed to `JaxProblem` below).
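As a quick check of this solution, use the equality constraint to eliminate $x_2 = 1 - x_1$. The remaining inequality and the objective then depend on $x_1$ alone:

$$
\begin{aligned}
x_1 - x_2 \geq 1 &\iff 2x_1 - 1 \geq 1 \iff x_1 \geq 1, \\
f(x_1) &= x_1^2 + (1 - x_1)^2 = 2x_1^2 - 2x_1 + 1, \qquad f'(x_1) = 4x_1 - 2 .
\end{aligned}
$$

Since $f'(x_1) \geq 2 > 0$ on the feasible set $x_1 \geq 1$, the objective is increasing there, so the minimum lies on the boundary $x_1 = 1$, giving $x_2 = 0$ and an optimal objective value of $1$. The bound $x_1 \geq 0$ is inactive at the solution.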

The problem functions are written using Jax functions as follows:

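```python
import jax
import jax.numpy as jnp

# Use 64-bit floats so derivatives and constraint values are computed in double precision
jax.config.update("jax_enable_x64", True)

# minimize x^2 + y^2 subject to x>=0, x+y=1, x-y>=1.

# Objective: f(x) = x_1^2 + x_2^2
jax_obj = lambda x: jnp.sum(x ** 2)
# Constraint values c(x) = [x_1 + x_2, x_1 - x_2]; their bounds are set later via cl and cu
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])
```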

modOpt automatically generates and jit-compiles the objective gradient, the constraint Jacobian,
and the objective Hessian, as well as the Lagrangian, its gradient, and its Hessian.
Users therefore do not need to differentiate, jit-compile, or wrap these functions themselves.
Once the problem functions are defined as Jax functions,
create a `JaxProblem` object for modOpt by passing the above functions
along with the other problem constants, such as the initial guess,
the variable bounds, and the constraint bounds.
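To make concrete what kind of functions are being generated, here is a minimal sketch that builds equivalent derivatives directly with `jax.grad`, `jax.jacfwd`, and `jax.hessian`. It is not modOpt's internal code: the Lagrangian sign convention $f(x) - \lambda^T c(x)$ and the multiplier values `lam` are assumptions chosen only for illustration.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

jax_obj = lambda x: jnp.sum(x ** 2)
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])

# Jit-compiled derivatives of the kind modOpt is described as generating
obj_grad = jax.jit(jax.grad(jax_obj))     # objective gradient, shape (2,)
con_jac = jax.jit(jax.jacfwd(jax_con))    # constraint Jacobian, shape (2, 2)
obj_hess = jax.jit(jax.hessian(jax_obj))  # objective Hessian, shape (2, 2)

# Lagrangian under an ASSUMED sign convention L(x, lam) = f(x) - lam^T c(x);
# modOpt's internal convention may differ.
lagrangian = lambda x, lam: jax_obj(x) - jnp.dot(lam, jax_con(x))
lag_grad = jax.jit(jax.grad(lagrangian))     # gradient with respect to x only
lag_hess = jax.jit(jax.hessian(lagrangian))  # Hessian with respect to x only

x0 = jnp.array([500., 5.])
lam = jnp.array([1., 1.])  # hypothetical multiplier values

print(obj_grad(x0))       # equals 2 * x0, i.e. [1000, 10]
print(con_jac(x0))        # constant matrix [[1, 1], [1, -1]] for these linear constraints
print(lag_grad(x0, lam))  # 2 * x0 - J^T lam, i.e. [998, 10]
```

Because both constraints are linear, the Lagrangian Hessian here equals the objective Hessian, $2I$, for any choice of multipliers.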

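```python
import numpy as np
import modopt as mo

# x0 is the initial guess; xl/xu bound the variables (x_1 >= 0, x_2 unbounded).
# cl/cu bound the nc=2 constraints: cl = cu = 1 enforces x_1 + x_2 = 1,
# and cl = 1, cu = inf enforces x_1 - x_2 >= 1.
prob = mo.JaxProblem(x0=np.array([500., 5.]), nc=2, jax_obj=jax_obj, jax_con=jax_con,
                     xl=np.array([0., -np.inf]), xu=np.array([np.inf, np.inf]),
                     cl=np.array([1., 1.]), cu=np.array([1., np.inf]),
                     name='quadratic_jax', order=1)
```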
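The bound arrays encode the problem statement: an equality constraint is expressed by setting its lower and upper bounds equal, and a one-sided inequality by an infinite upper bound. The sketch below checks this encoding outside modOpt using only NumPy and JAX; `is_feasible` is a hypothetical helper written for this illustration, not part of modOpt. It confirms that the known solution is feasible while the initial guess is not.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

# Same problem data as passed to JaxProblem
jax_obj = lambda x: jnp.sum(x ** 2)
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])
xl, xu = np.array([0., -np.inf]), np.array([np.inf, np.inf])
cl, cu = np.array([1., 1.]), np.array([1., np.inf])

def is_feasible(x, tol=1e-8):
    """Hypothetical helper: True if x satisfies both the variable and constraint bounds."""
    c = np.asarray(jax_con(jnp.asarray(x)))
    in_var_bounds = np.all(x >= xl - tol) and np.all(x <= xu + tol)
    in_con_bounds = np.all(c >= cl - tol) and np.all(c <= cu + tol)
    return bool(in_var_bounds and in_con_bounds)

x_star = np.array([1., 0.])  # known solution
x0 = np.array([500., 5.])    # initial guess used in this tutorial

print(is_feasible(x_star), float(jax_obj(x_star)))  # True 1.0
print(is_feasible(x0), np.asarray(jax_con(x0)))     # False [505. 495.]
```

The initial guess violates the equality constraint ($505 \neq 1$), so the optimizer must move a long way to reach the feasible set.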



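Once your problem is wrapped for modOpt, import your preferred optimizer
from modOpt and solve it following the standard procedure: create the optimizer with the problem
and its solver options, optionally check the first derivatives, call `solve()`, and print the results.
Here we use the `SLSQP` optimizer from the SciPy library.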

```python
# Set up your preferred optimizer (SLSQP) with the Problem object
# Pass in the options for your chosen optimizer
optimizer = mo.SLSQP(prob, solver_options={'maxiter':20})

# Check first derivatives at the initial guess, if needed
optimizer.check_first_derivatives(prob.x0)

# Solve your optimization problem
optimizer.solve()

# Print results of optimization
optimizer.print_results()
```


```
----------------------------------------------------------------------------
Derivative type | Calc norm  | FD norm    | Abs error norm | Rel error norm
----------------------------------------------------------------------------

Gradient        | 1.0000e+03 | 1.0000e+03 | 1.5473e-05     | 1.5472e-08
Jacobian        | 2.0000e+00 | 2.0000e+00 | 5.0495e-09     | 2.5248e-09
----------------------------------------------------------------------------
```
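The derivative check above compares the automatically generated derivatives against finite-difference approximations at `x0`. The following is a minimal sketch of such a comparison, assuming a central-difference scheme with step `h=1e-6` and 2-norms (Frobenius for the Jacobian); `central_fd_jacobian` is a hypothetical helper, and modOpt's actual scheme and step size may differ.

```python
import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

jax_obj = lambda x: jnp.sum(x ** 2)
jax_con = lambda x: jnp.array([x[0] + x[1], x[0] - x[1]])

def central_fd_jacobian(fun, x, h=1e-6):
    """Hypothetical helper: central finite-difference Jacobian of fun at x.
    Rows index outputs and columns index inputs; a scalar output gives one row."""
    x = np.asarray(x, dtype=float)
    f0 = np.atleast_1d(np.asarray(fun(x)))
    jac = np.zeros((f0.size, x.size))
    for j in range(x.size):
        step = np.zeros_like(x)
        step[j] = h
        f_plus = np.atleast_1d(np.asarray(fun(x + step)))
        f_minus = np.atleast_1d(np.asarray(fun(x - step)))
        jac[:, j] = (f_plus - f_minus) / (2 * h)
    return jac

x0 = np.array([500., 5.])
grad_ad = np.asarray(jax.grad(jax_obj)(x0))   # automatic-differentiation gradient
jac_ad = np.asarray(jax.jacfwd(jax_con)(x0))  # automatic-differentiation Jacobian
grad_fd = central_fd_jacobian(jax_obj, x0).ravel()
jac_fd = central_fd_jacobian(jax_con, x0)

for name, ad, fd in [("Gradient", grad_ad, grad_fd), ("Jacobian", jac_ad, jac_fd)]:
    abs_err = np.linalg.norm(ad - fd)
    rel_err = abs_err / np.linalg.norm(fd)
    print(f"{name:<9} calc norm {np.linalg.norm(ad):.4e}  FD norm {np.linalg.norm(fd):.4e}  "
          f"abs err {abs_err:.4e}  rel err {rel_err:.4e}")
```

The calc norms printed here, $\lVert 2x_0 \rVert \approx 1000.05$ and $\lVert J \rVert_F = 2$, appear consistent with the table's 1.0000e+03 and 2.0000e+00. The error columns depend on the differencing scheme and step size, so they are not expected to reproduce modOpt's values exactly.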


	Solution from Scipy SLSQP:
	----------------------------------------------------------------------------------------------------
	Problem                  : quadratic_jax
	Solver                   : scipy-slsqp
	Success                  : True
	Message                  : Optimization terminated successfully
	Status                   : 0
	Total time               : 0.0031862258911132812
	Objective                : 1.0000000068019972
	Gradient norm            : 2.000000006801997
	Total function evals     : 2
	Total gradient evals  