# Using JAXopt

In this exercise you will use JAXopt to solve a *batched* version of the first and second exercises.

## Resources

- [JAX documentation](https://jax.readthedocs.io/en/latest/notebooks/quickstart.html)
- [jax.numpy documentation](https://jax.readthedocs.io/en/latest/jax.numpy.html)
- [JAXopt documentation](https://jaxopt.github.io/stable/unconstrained.html)

---

Let's start by defining the criterion function from the solution to the second exercise. This time we use `jax.numpy` instead of `numpy`, and we parametrize the function by `b0`.

In [1]:
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

In [2]:
x0 = {"a": jnp.array(0.0), "b": jnp.zeros(3), "c": jnp.zeros((2, 2))}
b0 = jnp.arange(3, dtype="float64")


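def f(x, b0):
    """Sum of squared deviations of the entries of `x` from their targets.

    `x` is a dict with a scalar "a", a length-3 vector "b", and a 2x2 matrix "c";
    `b0` is the target for `x["b"]`.
    """
    value = (
        (x["a"] - jnp.pi) ** 2
        + jnp.sum((x["b"] - b0) ** 2)
        + jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    )
    return value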
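
Because the parameters are stored in a dictionary of arrays (a pytree), JAX transformations return results with the same structure. The following is a minimal sketch of that behavior, using a hypothetical `toy_criterion` with made-up keys `"u"` and `"v"` rather than the function `f` from the exercise:

```python
import jax
import jax.numpy as jnp


# Hypothetical toy criterion, for illustration only (not the exercise's `f`).
def toy_criterion(params, shift):
    return jnp.sum((params["u"] - shift) ** 2) + jnp.sum(params["v"] ** 2)


params = {"u": jnp.zeros(3), "v": jnp.ones((2, 2))}
shift = jnp.arange(3.0)

# jax.grad differentiates with respect to the first argument by default.
# The result is a dict with the same keys and array shapes as `params`.
grads = jax.grad(toy_criterion)(params, shift)

print(sorted(grads.keys()))
print(grads["u"].shape, grads["v"].shape)
```

The same principle is what allows an optimizer to accept and return dictionary-valued parameters without flattening them by hand.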



## Task 1: Optimize using `JAXopt`

- Create a solver instance of the class `LBFGS` for your problem. You need to make sure to
    - pass the function `f`,
    - set `tol=1e-6` for increased accuracy.
- Run the optimization using `solver.run`. You need to make sure to
    - pass the initial parameters,
    - pass additional arguments of `f`.
- Inspect the returned result. How do you access the optimized parameters?

In [3]:
from jaxopt import LBFGS

solver = ...

## Task 2: Batched optimization

Now you will optimize `f` not just for a single value of `b0`, but for many.

- Write a wrapper for `solver.run` that takes starting values for `x` and a single vector-valued parameter `b0`.
- Use `vmap` and `jit` to create a vectorized and jitted version of this wrapper that allows for array-valued `b0`.
- Execute your vectorized function on `b_arr`. This should perform 500 optimizations, one for each row of `b_arr`.
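
The pattern of sharing the starting values while mapping over a batch of parameters can be sketched as follows. This is only an illustration under stated assumptions: `toy_loss` and `solve_one` are hypothetical, and `solve_one` uses a fixed number of plain gradient-descent steps as a stand-in for a real solver call. The key part is `in_axes=(None, 0)`, which keeps the first argument fixed and maps over the leading axis of the second.

```python
import jax
import jax.numpy as jnp
from jax import jit, vmap


# Hypothetical toy loss, for illustration only.
def toy_loss(x, target):
    return jnp.sum((x - target) ** 2)


# Stand-in "solver": 60 gradient-descent steps for a single target vector.
def solve_one(x_init, target):
    grad_fn = jax.grad(toy_loss)

    def step(_, x):
        return x - 0.25 * grad_fn(x, target)

    return jax.lax.fori_loop(0, 60, step, x_init)


# Share x_init across the batch (None) and map over rows of targets (0).
solve_batch = jit(vmap(solve_one, in_axes=(None, 0)))

x_init = jnp.zeros(3)
targets = jnp.arange(12.0).reshape(4, 3)

solutions = solve_batch(x_init, targets)
print(solutions.shape)  # (4, 3): one solution per row of `targets`
```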

In [4]:
from jax import jit
from jax import vmap


b_arr = jnp.arange(1500, dtype="float64").reshape(500, 3)

## Optional: Speed comparison

Let's compare the speed of `jaxopt`'s batched optimization to a loop over scipy's `minimize`.

- Finish the loop in `batch_solve_scipy` using `method="L-BFGS-B"` in scipy's `minimize`.
- Time both functions (use `%timeit func()` in a notebook cell).
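
When timing jitted JAX code, two details tend to matter: the first call includes tracing and compilation, and JAX dispatches work asynchronously, so a call can return before the result is ready. A minimal sketch of a fair measurement, using a hypothetical `toy_batch` function in place of the batched solver, might look like this:

```python
import time

import jax
import jax.numpy as jnp


# Hypothetical jitted function, standing in for a batched solver.
@jax.jit
def toy_batch(x):
    return jnp.sum(jnp.sin(x) ** 2, axis=1)


x = jnp.ones((500, 3))

# Warm-up call: triggers tracing and compilation, which should not be timed.
toy_batch(x).block_until_ready()

# Timed call: block_until_ready() waits until the computation has finished.
start = time.perf_counter()
result = toy_batch(x).block_until_ready()
elapsed = time.perf_counter() - start

print(f"steady-state call: {elapsed:.2e} s")
```

In a notebook, the equivalent would be to run the function once and then use `%timeit toy_batch(x).block_until_ready()`.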

In [5]:
import numpy as np
from scipy.optimize import minimize


x0_numpy = np.zeros(8)
b_arr_numpy = np.array(b_arr)


def f_numpy(x, b0):
    a = x[0]
    b = x[1:4]
    c = x[4:].reshape(2, 2)

    value = (a - np.pi) ** 2 + np.sum((b - b0) ** 2) + np.sum((c - np.eye(2)) ** 2)
    return value

In [6]:
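def batch_solve_scipy(x0, b_arr):
    # Solve one optimization problem per row of b_arr and collect the results.
    results = []
    for b in b_arr:

        res = ...

        # Entries 1 to 3 of the flat parameter vector hold the optimized "b" block.
        b_opt = res.x[1:4]
        results.append(b_opt)

    return np.stack(results)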

### Timing

In [7]:
%timeit ...

