# Setup

Dependencies:
- System: python3
- Python: jupyter, ipywidgets, numpy, scipy, jax, jaxlib, matplotlib

Example setup for an Ubuntu system (Mac users: you may need `brew` instead of `sudo apt`; Windows users: learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):
```
/usr/bin/python3 -m pip install --upgrade pip
pip install --upgrade jupyter numpy scipy matplotlib jax jaxlib
jupyter nbextension enable --py widgetsnbextension  # necessary for interactive sliders to show up
jupyter notebook  # from the directory of this notebook
```
Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/AA203-Examples/blob/master/Lecture-5/Free%20Final%20Time%20Problem.ipynb).

In [None]:
import numpy as np
from scipy.optimize import minimize, Bounds
from scipy.integrate import solve_bvp

## [Closed-form solution](https://www.wolframcloud.com/obj/schmrlng/Published/Optimal%20Control%20Example%20%28Free%20Final%20Time%29.nb)

In [None]:
def free_final_time_analytical(a=1.0, b=1.0):
    """Return the closed-form optimal final time ``(1800 * b / a) ** (1 / 5)``.

    ``a`` and ``b`` are the two scalar weights of the free-final-time example.
    The derivation is worked out at
    https://www.wolframcloud.com/obj/schmrlng/Published/Optimal%20Control%20Example%20%28Free%20Final%20Time%29.nb
    """
    weight_ratio = 1800 * b / a
    # Fifth root of the weighted ratio.
    return weight_ratio**0.2


# Closed-form optimum for the default weights a = b = 1.
free_final_time_analytical()

## Two-point boundary value problem solver

In [None]:
def free_final_time_2pbvp(a=1.0, b=1.0, N=20):
    """Indirect method: solve the optimality conditions as a two-point BVP.

    The unknown final time ``tf`` is carried as a fifth, constant component of
    the state/costate vector ``(x1, x2, p1, p2, tf)``, and the right-hand side
    is scaled by ``tf`` so the problem is posed on the fixed mesh
    ``linspace(0, 1, N + 1)``.

    Returns the ``scipy.integrate.solve_bvp`` result object; the last row of
    its ``y`` attribute holds ``tf``.
    """
    mesh = np.linspace(0, 1, N + 1)

    def rhs(s, z):
        # Rows of z are (x1, x2, p1, p2, tf), one column per mesh point.
        _, vel, p_pos, p_vel, tf = z
        frozen = np.zeros_like(s)  # p1 and tf have zero derivative
        derivs = np.array([vel, -p_vel / b, frozen, -p_pos, frozen])
        # Multiplying by tf accounts for the rescaled independent variable.
        return tf * derivs

    def residuals(z_start, z_end):
        pos_0, vel_0, _, p_vel_0, tf_0 = z_start
        pos_f, vel_f, _, _, _ = z_end
        # Free-final-time condition, written with the values at s = 0.
        free_time = a * tf_0 - p_vel_0**2 / (2 * b)
        return np.array([pos_0 - 10, vel_0, pos_f, vel_f, free_time])

    # Initial guess: x1 ramps linearly from 10 to 0, tf = 1, everything else 0.
    guess = np.zeros((5, N + 1))
    guess[0] = np.linspace(10, 0, N + 1)
    guess[4] = 1.0

    return solve_bvp(rhs, residuals, mesh, guess)

In [None]:
# Solve the BVP with the default weights (a = b = 1) and N = 20 mesh intervals.
free_final_time_2pbvp()

In [None]:
# Last row of the solution array is the tf component at every mesh node.
free_final_time_2pbvp().y[-1]

## Exploring indirect single shooting (see also HW1P5)

In [None]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
from jax.experimental.ode import odeint

from ipywidgets import interact

In [None]:
def indirect_single_shooting_error(initial_costate_and_final_time, a=1.0, b=1.0):
    """Residual norm of the indirect single-shooting conditions.

    Integrates the state/costate ODE forward from the fixed initial state
    ``(10, 0)`` and the guessed initial costate over ``[0, final_time]``, then
    measures how far the endpoint is from satisfying the terminal conditions.

    Args:
        initial_costate_and_final_time: length-3 array ``(p1(0), p2(0), tf)``.
        a: weight multiplying ``final_time`` in the free-final-time residual.
        b: control weight; the control used in the dynamics is ``-p2 / b``.

    Returns:
        Scalar Euclidean norm of
        ``(x1(tf), x2(tf), a * tf - p2(tf)**2 / (2 * b))``.
    """
    initial_state = jnp.array([10., 0.])
    initial_costate = initial_costate_and_final_time[:-1]
    final_time = initial_costate_and_final_time[-1]

    def shooting_ode(state_costate, t):
        (_, x2), (p1, p2) = state_costate
        # Return the derivative with the same (state, costate) structure as
        # the input, as odeint expects, instead of a flat length-4 array that
        # only lines up after internal flattening.
        return jnp.array([x2, -p2 / b]), jnp.array([0., -p1])

    times = jnp.array([0, final_time])
    states, costates = odeint(shooting_ode, (initial_state, initial_costate), times)
    # Index -1 selects the values at final_time; costates[-1, 1] is p2(tf).
    free_time_residual = -costates[-1, 1]**2 / (2 * b) + a * final_time
    return jnp.linalg.norm(jnp.append(states[-1], free_time_residual))

In [None]:
# Map the shooting residual over a 2-D grid of (p1(0), p2(0), tf) triples and JIT-compile it.
vv_indirect_single_shooting_error = jax.jit(jax.vmap(jax.vmap(indirect_single_shooting_error)))
# 100 x 100 grid of initial-costate guesses on [-4, 8] x [-4, 8].
X, Y = np.meshgrid(np.linspace(-4, 8, 100), np.linspace(-4, 8, 100))


@interact
def plot_slice(final_time=(0.1, 8.0)):
    # Filled contour plot of log(1 + shooting error) over the (X, Y) grid for
    # one fixed final_time; the (0.1, 8.0) default is handed to `interact`.
    plt.figure(figsize=(12, 10))
    guesses = np.stack([X, Y, final_time * np.ones_like(X)], -1)
    # The log(1 + .) transform is what gets plotted, not the raw error.
    log_error = np.log(1 + vv_indirect_single_shooting_error(guesses))
    plt.contourf(X, Y, log_error, levels=40, vmin=0, vmax=10)
    plt.colorbar()

## Direct single shooting

In [None]:
def free_final_time_shooting(a=1.0, b=1.0, N=20):
    """Direct single shooting over the final time and ``N`` constant controls.

    The decision vector is ``[tf, u_0, ..., u_{N-1}]``: each control is held
    for a step of length ``tf / N`` and the state is rolled out from
    ``(10, 0)``. The cost ``0.5 * (a * tf**2 + b * dt * sum(u**2))`` is
    minimised subject to ``tf >= 0`` and the rollout ending at the origin.

    Returns the ``scipy.optimize.minimize`` result; ``.x[0]`` is the final
    time and ``.x[1:]`` are the controls.
    """
    start = np.array([10., 0.])
    target = np.zeros(2)

    def split(decision):
        # First entry is tf, the remaining N entries are the controls.
        return decision[0], decision[1:]

    def cost(decision):
        tf, controls = split(decision)
        step = tf / N
        control_effort = b * step * np.sum(controls**2)
        return 0.5 * (a * tf**2 + control_effort)

    def terminal_constraint(decision):
        tf, controls = split(decision)
        step = tf / N
        pos, vel = start
        for accel in controls:
            # Position is updated first so that it uses the pre-update velocity.
            pos = pos + vel * step + 0.5 * accel * step**2
            vel = vel + accel * step
        return np.array([pos, vel]) - target

    # Initial guess: tf = 1 and all controls zero.
    initial_guess = np.zeros(N + 1)
    initial_guess[0] = 1.0

    # Only tf is bounded (below, by zero); the controls are unbounded.
    lower = np.full(N + 1, -np.inf)
    lower[0] = 0.0
    upper = np.full(N + 1, np.inf)

    return minimize(cost,
                    initial_guess,
                    bounds=Bounds(lower, upper),
                    constraints={'type': 'eq', 'fun': terminal_constraint},
                    options={'maxiter': 1000})

In [None]:
# Run the direct shooting optimisation with the default weights and N = 20 controls.
free_final_time_shooting()

In [None]:
# First decision variable is the final time tf.
free_final_time_shooting().x[0]