In [1]:
import jax
import jax.numpy as jnp
import numpy as np
import hj_reachability as hj

In [2]:
from hj_reachability import dynamics
from hj_reachability import sets

class Drone2D(dynamics.ControlAndDisturbanceAffineDynamics):
  """Two-state drone model with a scalar control and a zero disturbance.

  The second state component is treated as the rate of change of the first
  (the notebook's plots label them `z` and `v_z`). The control enters only the
  second component, scaled by `k`; the drift pulls that component down at
  rate `g`. The disturbance set is the single point `{0}` and its Jacobian is
  all zeros.

  NOTE(review): presumably the base class combines these pieces as
  `open_loop + control_jacobian @ u + disturbance_jacobian @ d` -- confirm
  against hj_reachability.
  """

  def __init__(self,
               k,
               g=9.8,
               u_bar=1.):
    """Store the model constants and register the control/disturbance sets.

    Args:
      k: Gain multiplying the control in the second state equation.
      g: Constant subtracted in the second state equation.
      u_bar: Bound on the control; the control set is `[-u_bar, u_bar]`.
    """
    self.k = k
    self.g = g

    # Symmetric one-dimensional control interval.
    control_lower = jnp.array([-u_bar])
    control_upper = jnp.array([u_bar])

    # Degenerate disturbance interval containing only zero.
    disturbance_lower = jnp.array([0.])
    disturbance_upper = jnp.array([0.])

    super().__init__(
        'max',  # control mode
        'min',  # disturbance mode
        sets.Box(control_lower, control_upper),
        sets.Box(disturbance_lower, disturbance_upper),
    )

  def open_loop_dynamics(self, state, time):
    """Return the control-free drift `[v, -g]` for a two-component state."""
    _, velocity = state
    drift = [velocity, -self.g]
    return jnp.array(drift)

  def control_jacobian(self, state, time):
    """Return the constant 2x1 control matrix `[[0], [k]]`."""
    first_row = [0.]
    second_row = [self.k]
    return jnp.array([first_row, second_row])

  def disturbance_jacobian(self, state, time):
    """Return the constant 2x1 all-zero disturbance matrix."""
    zero_row = [0.]
    return jnp.array([zero_row, [0.]])

In [4]:
# Sweep of values for the constant K; one value function is computed per entry
ks = np.linspace(6, 12, num=11)

# Computation grid for numerically solving the PDE:
# a 51 x 51 lattice over the box [-2, 2] x [-4, 4]
state_domain = hj.sets.Box(np.array([-2., -4.]), np.array([+2., +4.]))
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    state_domain,
    (51, 51))

# Implicit function l(x) for the failure set; it depends only on the
# first state coordinate
first_coordinate = grid.states[..., 0]
failure_values = 1.5 - jnp.abs(first_coordinate)

# Solver settings: 11 time samples running from 0 down to -5, with the
# backwards-reachable-tube postprocessor applied to the Hamiltonian
times = np.linspace(0, -5, num=11)
solver_settings = hj.SolverSettings.with_accuracy(
    'very_high',
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube)



In [5]:
# Compute the value function by solving the PDE for each K.
# Layout of `values`: (index into ks, index into times, *grid shape).
# The grid shape is taken from `failure_values` so it cannot drift out of sync
# with the lattice size; NaN marks entries that have not been solved yet.
values = np.full((len(ks), len(times), *failure_values.shape), fill_value=np.nan)

from tqdm import tqdm
# Wrap `ks` itself rather than the `enumerate` object: tqdm can then read the
# length and render a real progress bar instead of a bare iteration counter.
for i, k in enumerate(tqdm(ks)):
  # Deliberately NOT named `dynamics`: that name is the hj_reachability module
  # `Drone2D` inherits from, and rebinding it breaks re-running the class cell.
  drone_dynamics = Drone2D(k)
  values[i] = hj.solve(solver_settings, drone_dynamics, grid, times, failure_values)

0it [00:00, ?it/s]

100%|##########|  5.0000/5.0 [00:00<00:00,  8.81sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  8.78sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  8.58sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  8.37sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  8.09sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  8.02sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  7.78sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  7.64sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  7.51sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  7.06sim_s/s]
100%|##########|  5.0000/5.0 [00:00<00:00,  7.17sim_s/s]
11it [00:12,  1.12s/it]


In [None]:
import matplotlib.pyplot as plt
from ipywidgets import interact, FloatSlider

vbar = 1.5  # symmetric colour-scale limit: the colormap spans [-vbar, +vbar]
def plot_value_function(k, t):
  """Plot the precomputed value function for gain `k` at time `t`.

  Draws a colour map of `values[ki, ti]`, overlays its zero level set in
  black, and overlays the zero level set of `failure_values` in red.

  Args:
    k: Gain value; must match (within `np.isclose` tolerance) exactly one
      entry of `ks`.
    t: Time value; must match exactly one entry of `times`.

  Raises:
    ValueError: If `k` or `t` does not match exactly one precomputed sample
      (raised by `.item()` on a non-singleton match array).
  """
  ki = np.argwhere(np.isclose(ks, k)).item()
  ti = np.argwhere(np.isclose(times, t)).item()

  # Second grid axis goes on the horizontal plot axis, first on the vertical.
  horizontal = grid.coordinate_vectors[1]
  vertical = grid.coordinate_vectors[0]
  value_slice = values[ki, ti]

  plt.figure()
  plt.title(f'$V(x, {t})$ for k={k}')
  plt.xlabel('$v_z$ (m/s)')
  plt.ylabel('$z$ (m)')

  plt.pcolormesh(
      horizontal,
      vertical,
      value_slice,
      cmap='RdBu',
      vmin=-vbar, vmax=vbar
  )
  plt.colorbar()
  # `levels=[0]` requests exactly the zero level set. An integer `levels=0`
  # would instead ask matplotlib to choose the contour levels automatically.
  plt.contour(
      horizontal,
      vertical,
      value_slice,
      levels=[0],
      colors='k'
  )
  plt.contour(
      horizontal,
      vertical,
      failure_values,
      levels=[0],
      colors='r'
  )
  plt.show()

# Slider ranges and steps mirror the sampled values in `ks` (6 to 12 in steps
# of 0.6) and `times` (-5 to 0 in steps of 0.5).
slider_controls = dict(
    k=FloatSlider(value=12., min=6., max=12., step=0.6),
    t=FloatSlider(value=0., min=-5., max=0., step=0.5),
)
interact(plot_value_function, **slider_controls)