# Setup

Dependencies:
- System: python3, ffmpeg (for rendering animations)
- Python: jupyter, jax, numpy, matplotlib, plotly, tqdm, hj-reachability

Example setup for an Ubuntu system (Mac users, maybe `brew` instead of `sudo apt`; Windows users, learn to love [WSL](https://docs.microsoft.com/en-us/windows/wsl/install-win10)):
```
sudo apt install ffmpeg
/usr/bin/python3 -m pip install --upgrade pip
pip install --upgrade jupyter jax numpy matplotlib plotly tqdm hj-reachability
jupyter notebook  # from the directory of this notebook
```
Alternatively, view this notebook on [Google Colab](https://colab.research.google.com/github/StanfordASL/AA203-Examples/blob/master/Lecture-12/HJ%20Reachability%20--%20Pursuit%20Evasion.ipynb) and run a cell containing this command:
```
!pip install --upgrade hj-reachability
```

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

from IPython.display import HTML
import matplotlib.animation as anim
import matplotlib.pyplot as plt
from matplotlib import transforms
import plotly.graph_objects as go

import hj_reachability as hj

In [None]:
# Air3d system from hj_reachability; its state is used below as (x, y, heading) coordinates.
dynamics = hj.systems.Air3d()
# Lattice over the box [-6, 20] x [-10, 10] x [0, 2*pi] with shape (53, 41, 50);
# dimension 2 (the one spanning 0 to 2*pi) is declared periodic.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([-6., -10., 0.]),
                                                                           hi=np.array([20., 10., 2 * np.pi])),
                                                               (53, 41, 50),
                                                               periodic_dims=2)
# Terminal values: Euclidean norm of the first two state coordinates minus 5, i.e. negative
# inside a radius-5 disc around the origin of those two coordinates and positive outside.
terminal_values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5

# "very_high" accuracy preset, with the backwards-reachable-tube Hamiltonian postprocessor.
solver_settings = hj.SolverSettings.with_accuracy("very_high",
                                                  hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)

In [None]:
# 71 evenly spaced time points running backwards from 0 to -3.5 (spacing 0.05).
times = np.linspace(0, -3.5, 71)
# Solve over all requested times; the result is indexed along its first axis below
# (presumably one value array per entry of `times` -- confirm against hj.solve docs).
all_values = hj.solve(solver_settings, dynamics, grid, times, terminal_values)

In [None]:
# 3D isosurface of the last value array (all_values[-1]) over the flattened grid states.
# isomin == isomax == 0 with surface_count=1 requests a single surface at value 0.
go.Figure(data=go.Isosurface(x=grid.states[..., 0].ravel(),
                             y=grid.states[..., 1].ravel(),
                             z=grid.states[..., 2].ravel(),
                             value=all_values[-1].ravel(),
                             colorscale="jet",
                             isomin=0,
                             surface_count=1,
                             isomax=0))

In [None]:
# Value range over every time step, shared by all frames so their colors are comparable.
vmin = all_values.min()
vmax = all_values.max()
# One filled-contour level per integer between the rounded extremes (inclusive).
levels = np.linspace(round(vmin), round(vmax), num=round(vmax) - round(vmin) + 1)
fig = plt.figure(figsize=(13, 8))
plt.jet()


def render_frame(i, colorbar=False, theta_index=30):
    """Draw one animation frame: a 2D slice of the value function at time index `i`.

    The slice is taken at a fixed index along the third grid dimension and drawn as filled
    contours over the first two grid coordinates. The zero level set of the last value
    array (`all_values[-1]`) at the same slice is overlaid as a thick black line.

    Args:
        i: Index along the first (time) axis of `all_values`.
        colorbar: If True, also add a colorbar and a title showing the slice coordinate.
            Intended for the first call only, so that they are not added again per frame.
        theta_index: Index along the third grid dimension at which to slice (default 30).
    """
    xs = grid.coordinate_vectors[0]
    ys = grid.coordinate_vectors[1]
    # Transpose so that rows correspond to y and columns to x, as contour/contourf expect.
    plt.contourf(xs,
                 ys,
                 all_values[i, :, :, theta_index].T,
                 vmin=vmin,
                 vmax=vmax,
                 levels=levels)
    if colorbar:
        plt.colorbar()
        plt.title(f"Slice at θ_rel = {float(grid.coordinate_vectors[2][theta_index]):4.3f}", fontsize=20)
    # `levels` must be a sequence to request the specific level 0; an int is interpreted by
    # matplotlib as a *number* of automatically chosen levels, not as the level value.
    plt.contour(xs,
                ys,
                all_values[-1, :, :, theta_index].T,
                levels=[0],
                colors="black",
                linewidths=3)


# Draw the first frame once with the colorbar and title, then render one frame per time index.
render_frame(0, colorbar=True)
animation = anim.FuncAnimation(fig, render_frame, frames=all_values.shape[0], interval=50)
animation = HTML(animation.to_html5_video())
# Close the figure so that only the embedded video (not a static plot) is displayed.
plt.close()
animation

In [None]:
# Re-solve, this time just BRS (i.e., not BRT) computation: no Hamiltonian postprocessor is passed.
# This approach for pursuit-evasion policy computation is a bit hacky/memory-inefficient for now.
solver_settings = hj.SolverSettings.with_accuracy("very_high")
dynamics = hj.systems.Air3d()
# Larger domain than before in the first coordinate: [-10, 20] x [-10, 10] x [0, 2*pi],
# lattice shape (61, 41, 50), with dimension 2 periodic.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(lo=np.array([-10., -10., 0.]),
                                                                           hi=np.array([20., 10., 2 * np.pi])),
                                                               (61, 41, 50),
                                                               periodic_dims=2)
# 51 evenly spaced time points running backwards from 0 to -5 (spacing 0.1).
times = np.linspace(0, -5, 51)
# Norm of the first two state coordinates minus 5 (negative inside a radius-5 disc).
terminal_values = jnp.linalg.norm(grid.states[..., :2], axis=-1) - 5

all_values = hj.solve(solver_settings, dynamics, grid, times, terminal_values)
# Apply grid.grad_values to each entry along the first axis of all_values; both arrays are kept in memory.
all_grad_values = jax.vmap(grid.grad_values)(all_values)

In [None]:
def relative_state(x):
    """Convert a joint state into the state of agent "b" relative to agent "a".

    Args:
        x: Six values `(xa, ya, qa, xb, yb, qb)`: planar position and heading of each agent.

    Returns:
        Array of three values: the position offset `(xb - xa, yb - ya)` rotated by `-qa`
        (i.e. expressed in the frame aligned with agent a's heading), followed by the
        heading difference `qb - qa` wrapped into `[0, 2*pi)`.
    """
    xa, ya, qa, xb, yb, qb = x
    cos_a, sin_a = jnp.cos(qa), jnp.sin(qa)
    # Rotation by -qa: maps world-frame offsets into agent a's body frame.
    world_to_body = jnp.array([[cos_a, sin_a], [-sin_a, cos_a]])
    offset = world_to_body @ jnp.array([xb - xa, yb - ya])
    heading = jnp.mod(qb - qa, 2 * jnp.pi)
    return jnp.concatenate([offset, heading[None]])


def joint_dynamics(x, u, relative_dynamics=dynamics):
    """Time derivative of the six-dimensional joint state.

    Args:
        x: Joint state; entries 2 and 5 are the headings of the evader and pursuer.
        u: Pair of inputs; `u[0]` is the evader's heading rate, `u[1]` the pursuer's.
        relative_dynamics: Object providing `evader_speed` and `pursuer_speed`
            (defaults to the `dynamics` object that exists when this function is defined).

    Returns:
        Array of six derivatives: (x, y, heading) rates for the evader, then the pursuer.
    """
    evader_heading, pursuer_heading = x[2], x[5]
    evader_speed = relative_dynamics.evader_speed
    pursuer_speed = relative_dynamics.pursuer_speed
    evader_rates = [evader_speed * jnp.cos(evader_heading), evader_speed * jnp.sin(evader_heading), u[0]]
    pursuer_rates = [pursuer_speed * jnp.cos(pursuer_heading), pursuer_speed * jnp.sin(pursuer_heading), u[1]]
    return jnp.array(evader_rates + pursuer_rates)


@jax.jit
def joint_step(joint_state, dt, t):
    """Advance the joint state by one explicit Euler step using the precomputed value functions.

    Args:
        joint_state: Six-element joint state accepted by `relative_state`.
        dt: Step length.
        t: Step counter; `t * dt` is passed as the time argument to the dynamics.

    Returns:
        Tuple `(next_joint_state, value)`, where `value` is the value function at the
        selected time index, interpolated at the current relative state.
    """
    rel_state = relative_state(joint_state)
    # Interpolate every stored value array (one per time index) at the current relative state.
    values_over_time = jax.vmap(grid.interpolate, in_axes=(0, None))(all_values, rel_state)
    # Find the time horizon at which the value function is minimized (excluding any immediate period
    # where it's increasing, in which case we should be reasoning about the "next" min).
    # Index 0 is always masked, along with every index whose value exceeds its predecessor's.
    is_rising = jnp.concatenate([np.array([True]), jnp.diff(values_over_time) > 0])
    candidates = jnp.where(is_rising, np.inf, values_over_time)
    horizon = jnp.argmin(candidates)
    # Because index 0 is masked, argmin yields 0 only when no unmasked candidate remains;
    # fall back to the last stored time index (the maximum horizon) in that case.
    horizon = jnp.where(horizon == 0, -1, horizon)
    value = grid.interpolate(all_values[horizon], rel_state)
    grad_value = grid.interpolate(all_grad_values[horizon], rel_state)
    opt_control, opt_disturbance = dynamics.optimal_control_and_disturbance(rel_state, t * dt, grad_value)
    inputs = jnp.concatenate([opt_control, opt_disturbance])
    next_joint_state = joint_state + joint_dynamics(joint_state, inputs) * dt
    return next_joint_state, value


def joint_trajectory(evader_state, pursuer_state, dt=1 / 30, T=5):
    """Simulate the joint evader/pursuer system by repeatedly calling `joint_step`.

    Args:
        evader_state: Initial evader state; concatenated first into the joint state.
        pursuer_state: Initial pursuer state; concatenated second.
        dt: Step length passed to `joint_step`.
        T: Total duration; `int(T / dt)` steps are taken.

    Returns:
        Tuple `(joint_states, values)` of equal length: the joint state at the start of
        each step (the state after the final step is not included) and the value returned
        by `joint_step` for that state.
    """
    current = np.concatenate([evader_state, pursuer_state])
    visited = []
    values = []
    for step in range(int(T / dt)):
        # Record the state before stepping so that visited[k] pairs with values[k].
        visited.append(current)
        current, value = joint_step(current, dt, step)
        values.append(value)
    return np.array(visited), np.array(values)


def animate_joint_trajectory(evader_state, pursuer_state, dt=1 / 30, T=5, animation_time_scale_factor=2):
    """Simulate a pursuit-evasion trajectory and return it as an embedded HTML5 video.

    Args:
        evader_state: Initial evader state, forwarded to `joint_trajectory`.
        pursuer_state: Initial pursuer state, forwarded to `joint_trajectory`.
        dt: Simulation step length, forwarded to `joint_trajectory`.
        T: Simulation duration, forwarded to `joint_trajectory`.
        animation_time_scale_factor: Playback speed-up; the frame interval is
            `1000 * dt / animation_time_scale_factor` milliseconds.

    Returns:
        An `IPython.display.HTML` object wrapping the rendered video.
    """
    joint_states, _ = joint_trajectory(evader_state, pursuer_state, dt, T)
    # Columns 0/3 hold the two agents' x positions, columns 1/4 their y positions.
    all_x = joint_states[:, [0, 3]]
    all_y = joint_states[:, [1, 4]]
    margin = 3

    fig = plt.figure(figsize=(10, 8))
    ax = fig.gca()
    ax.set_xlim(np.min(all_x) - margin, np.max(all_x) + margin)
    ax.set_ylim(np.min(all_y) - margin, np.max(all_y) + margin)
    ax.set_aspect("equal", adjustable="box")

    # Both agents are drawn as the same small triangle pointing along +x before rotation.
    triangle_pts = np.array([[-.2, -.2], [1., 0], [-.2, .2]])
    evader = ax.add_patch(plt.Polygon(triangle_pts))
    pursuer = ax.add_patch(plt.Polygon(triangle_pts, color="orange"))
    # Radius-5 disc that is re-centered on the evader every frame.
    evader_radius = ax.add_patch(plt.Circle([0, 0], 5, alpha=0.5))

    def pose_transform(px, py, heading):
        # Rotate the marker to its heading, move it to its position, then map to display coordinates.
        return transforms.Affine2D().rotate(heading).translate(px, py) + ax.transData

    def render_frame(i):
        xa, ya, qa, xb, yb, qb = joint_states[i]
        evader.set_transform(pose_transform(xa, ya, qa))
        pursuer.set_transform(pose_transform(xb, yb, qb))
        evader_radius.set_center([xa, ya])
        return [evader, pursuer, evader_radius]

    frame_interval_ms = 1000 * dt / animation_time_scale_factor
    video = anim.FuncAnimation(fig,
                               render_frame,
                               joint_states.shape[0],
                               interval=frame_interval_ms,
                               blit=True).to_html5_video()
    animation = HTML(video)
    plt.close()
    return animation

In [None]:
# Evader at the origin with heading 0; pursuer starts at (15, -6) with heading 3*pi/4.
animate_joint_trajectory(evader_state=np.zeros(3), pursuer_state=np.array([15., -6., 3 * np.pi / 4]))

In [None]:
# Evader at the origin with heading 0; pursuer starts closer, at (5, -6), with heading 3*pi/4.
animate_joint_trajectory(evader_state=np.zeros(3), pursuer_state=np.array([5., -6., 3 * np.pi / 4]))

In [None]:
# Evader at the origin with heading 0; pursuer starts at (10, 0) with heading 3.77 rad; simulated with T=7.
animate_joint_trajectory(evader_state=np.zeros(3), pursuer_state=np.array([10, 0, 3.77]), T=7)