# Automatic differentiation

In this exercise, you will use automatic differentiation in JAX, together with estimagic, to solve the optimization problem from the previous exercise.

## Resources

- JAX NumPy API reference: https://jax.readthedocs.io/en/latest/jax.numpy.html
- JAX autodiff cookbook: https://jax.readthedocs.io/en/latest/notebooks/autodiff_cookbook.html

```python
import jax
import jax.numpy as jnp
import estimagic as em

# Use 64-bit floats; JAX defaults to 32-bit.
jax.config.update("jax_enable_x64", True)
```
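The `jax_enable_x64` line matters because JAX creates 32-bit floats by default, so without it the criterion and its gradient would be computed in single precision. A minimal check, assuming a fresh Python session in which no arrays have been created yet:

```python
import jax
import jax.numpy as jnp

# JAX creates float32 arrays by default; this flag switches the default to float64.
# It should run at startup, before any arrays are created.
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64
```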

## Task 1: Switch to JAX

- Use the code from exercise 2, task 2, and convert the criterion function and the parameters to JAX. Hint: look at the `jax.numpy` documentation.

```python
def criterion(x):
    first = (x["a"] - jnp.pi) ** 4
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


# With x64 enabled, jnp.ones already returns float64 arrays.
start_params = {
    "a": 1.0,
    "b": jnp.ones(3),
    "c": jnp.ones((2, 2)),
}
```

```python
criterion(start_params)
```

```
DeviceArray(25.0352401, dtype=float64)
```
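Every term of the criterion is a sum of even powers, so its smallest possible value is 0, reached at `a = pi`, `b = [0, 1, 2]` and `c` equal to the 2-by-2 identity matrix. A quick sanity check of the JAX version; the `optimum` dict is a hypothetical name introduced only for this check:

```python
import jax
import jax.numpy as jnp

# Enable float64 so the comparison is not limited by float32 rounding.
jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 4
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


start_params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}

# Hypothetical dict holding the known minimizer of every term.
optimum = {
    "a": jnp.pi,
    "b": jnp.arange(3, dtype=jnp.float64),
    "c": jnp.eye(2),
}

start_value = criterion(start_params)  # about 25.035, matching the notebook output
optimum_value = criterion(optimum)     # every term vanishes here
print(start_value, optimum_value)
```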

## Task 2: Gradient

- Compute the gradient of the criterion (the whole function). Hint: look at the autodiff cookbook listed under Resources.

```python
# jax.grad returns a new function; its output has the same dict structure as the params.
gradient = jax.grad(criterion)
gradient(start_params)
```

```
{'a': DeviceArray(-39.28896575, dtype=float64, weak_type=True),
 'b': DeviceArray([ 2.,  0., -2.], dtype=float64),
 'c': DeviceArray([[0., 2.],
              [2., 0.]], dtype=float64)}
```
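One way to double-check `jax.grad` is to derive the gradient by hand: the partial derivatives are `4 * (a - pi) ** 3` for `a`, `2 * (b - [0, 1, 2])` for `b` and `2 * (c - I)` for `c`. The sketch below compares the two leaf by leaf; `analytic_gradient` is a hypothetical helper written only for this check:

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 4
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


def analytic_gradient(x):
    # Hand-derived partial derivatives of each term of the criterion.
    return {
        "a": 4 * (x["a"] - jnp.pi) ** 3,
        "b": 2 * (x["b"] - jnp.arange(3)),
        "c": 2 * (x["c"] - jnp.eye(2)),
    }


start_params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}

auto = jax.grad(criterion)(start_params)
manual = analytic_gradient(start_params)

# Both results are dicts with the same keys, so tree_map can pair up their leaves.
matches = jax.tree_util.tree_map(
    lambda g, h: bool(jnp.allclose(g, h)), auto, manual
)
print(matches)  # {'a': True, 'b': True, 'c': True}
```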

## Task 3: Minimize

- Use estimagic to minimize the criterion:
    - pass the gradient function you computed above to the `minimize` call
    - use the `"scipy_lbfgsb"` algorithm

```python
res = em.minimize(
    criterion=criterion,
    derivative=gradient,  # JAX gradient from Task 2
    params=start_params,
    algorithm="scipy_lbfgsb",
)

res.params
```

```
{'a': 3.129255066950777,
 'b': DeviceArray([-4.86427302e-06,  1.00000000e+00,  1.99999782e+00], dtype=float64),
 'c': DeviceArray([[ 1.00000000e+00, -4.86427302e-06],
              [-4.86427302e-06,  1.00000000e+00]], dtype=float64)}
```
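For comparison, here is a sketch of the same minimization without estimagic, calling SciPy's L-BFGS-B directly. It assumes SciPy is installed and is only an illustration, not a description of how estimagic works internally. SciPy optimizes over a flat 1-D array, so `jax.flatten_util.ravel_pytree` turns the params dict into a vector, and `jax.value_and_grad` supplies the value and the gradient together (`jac=True`). The names `flat_value_and_grad`, `fun_and_jac` and `best_params` are hypothetical:

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax.flatten_util import ravel_pytree
from scipy.optimize import minimize

jax.config.update("jax_enable_x64", True)


def criterion(x):
    first = (x["a"] - jnp.pi) ** 4
    second = jnp.sum((x["b"] - jnp.arange(3)) ** 2)
    third = jnp.sum((x["c"] - jnp.eye(2)) ** 2)
    return first + second + third


start_params = {"a": 1.0, "b": jnp.ones(3), "c": jnp.ones((2, 2))}

# Flatten the params dict into one vector of length 1 + 3 + 4 = 8 and keep
# `unravel`, which maps a flat vector back to the original dict structure.
flat_start, unravel = ravel_pytree(start_params)

# Value and gradient of the criterion as a function of the flat vector,
# compiled with jit on the first call.
flat_value_and_grad = jax.jit(
    jax.value_and_grad(lambda flat: criterion(unravel(flat)))
)


def fun_and_jac(flat):
    # SciPy passes NumPy arrays in and expects a float and a NumPy array back.
    value, grad = flat_value_and_grad(jnp.asarray(flat))
    return float(value), np.array(grad, dtype=np.float64)


# jac=True tells SciPy that fun_and_jac returns (value, gradient).
result = minimize(fun_and_jac, np.array(flat_start), jac=True, method="L-BFGS-B")
best_params = unravel(jnp.asarray(result.x))
print(best_params)
```

As in the estimagic run, expect `a` to end up near, but not exactly at, pi: the quartic term is very flat around its minimum, so its gradient becomes tiny before `a` reaches pi, while the quadratic terms for `b` and `c` converge much more tightly.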