In [None]:
import os

# Must be set before JAX is imported so XLA sees it when the GPU backend
# initializes: lets the XLA client preallocate (up to) the full device memory,
# which the ~300^3 value grid built below relies on.
os.environ['XLA_PYTHON_CLIENT_MEM_FRACTION'] = '1.0'

import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt

import plotly.io as pio
# Open plotly figures in the web browser instead of rendering them inline.
pio.renderers.default = "browser"

import plotly.graph_objects as go

import hj_reachability as hj

# Project-local dynamics model used for the reachability computation.
from dyn_sys.DvNonlinearCar import DvNonlinearCar


In [None]:
dynamics = DvNonlinearCar()

# Box half-widths. x1, x2 and the target radius are specified in degrees and
# converted to radians below; x3_lim is used as-is (its unit is not stated here).
X1_HALF_WIDTH_DEG = 180
X2_HALF_WIDTH_DEG = 35
TARGET_RADIUS_DEG = 5


def _deg_to_rad(angle_deg):
    """Return `angle_deg` converted to radians, computed as `angle_deg * pi / 180`."""
    return angle_deg * jnp.pi / 180


x1_lim = _deg_to_rad(X1_HALF_WIDTH_DEG)
x2_lim = _deg_to_rad(X2_HALF_WIDTH_DEG)
x3_lim = 0.4
radius = _deg_to_rad(TARGET_RADIUS_DEG)

# Grid points per dimension; on a 4 GB VRAM GPU roughly 300^3 is the practical limit.
grid_size = (300, 300, 300)

# Symmetric box [-lim, +lim] in every state dimension.
# NOTE(review): x1 spans a full turn ([-pi, pi]); if x1 is a wrapping angle,
# `periodic_dims=0` may be intended here -- confirm.
upper_corner = np.array([x1_lim, x2_lim, x3_lim])
state_box = hj.sets.Box(-upper_corner, upper_corner)
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(state_box, grid_size)

# Start/end times for the solve; target_time < time, so integration runs backwards.
time = 0.
target_time = -0.7


In [None]:
# Backward reachable tube computed once per entry of gamma_list.
gamma_list = [0, 1]

# The solver settings and the initial value function are the same for every
# gamma, so build them once instead of on each loop iteration.
solver_settings = hj.SolverSettings.with_accuracy(
    "very_high",
    hamiltonian_postprocessor=hj.solver.backwards_reachable_tube,
)
# Initial value function: Euclidean distance from the origin minus `radius`,
# i.e. negative inside the target ball, zero on its boundary, positive outside.
# NOTE(review): the norm mixes x1/x2 (radians) with x3 (unit not stated) -- confirm.
values = jnp.linalg.norm(grid.states[..., :3], axis=-1) - radius

target_values = []
for gamma in gamma_list:
    # FIXME: `gamma` is never used -- it is passed neither to `dynamics` nor to
    # the solver -- so every iteration solves exactly the same problem and the
    # per-gamma plots below cannot differ. Wire gamma into the dynamics once
    # DvNonlinearCar's parameter for it is confirmed.
    result = hj.step(solver_settings, dynamics, grid, time, values, target_time)
    # Wrapped in a one-element list because the plotting cells read `result[0]`.
    target_values.append([result])

In [None]:
# Index of the x3 slice to plot (middle of the x3 axis).
# NOTE: this name shadows the built-in `slice`; it is kept because the 3-D
# surface cell below reads it.
slice = grid_size[-1] // 2
x1_coords, x2_coords, x3_coords = grid.coordinate_vectors

# 2-D value slices over (x1, x2). Transposed so rows follow x2, which is what
# contourf(x, y, Z) expects (assumes the value arrays are indexed (x1, x2, x3)).
value_slices = [np.asarray(result[0][:, :, slice]).T for result in target_values]

# Use one set of filled-contour levels for every panel so the single shared
# colorbar is correct for all of them (otherwise each contourf auto-scales to
# its own data range and the colorbar only describes the first panel).
vmin = min(float(s.min()) for s in value_slices)
vmax = max(float(s.max()) for s in value_slices)
shared_levels = np.linspace(vmin, vmax, 21) if vmax > vmin else None

# squeeze=False keeps `axes` 2-D even when there is a single result.
fig, axes = plt.subplots(1, len(value_slices), figsize=(13, 8), squeeze=False)
filled = None
for ax, gamma, z in zip(axes[0], gamma_list, value_slices):
    # Explicit cmap instead of plt.jet(), which mutates global rc state.
    filled = ax.contourf(x1_coords, x2_coords, z, levels=shared_levels, cmap="jet")
    # levels=[0] draws exactly the zero level set (the set boundary); an int
    # `levels` would instead ask MaxNLocator to pick "nice" levels.
    ax.contour(x1_coords, x2_coords, z,
               levels=[0],
               colors="black",
               linewidths=3)
    ax.set_title(f"gamma = {gamma}")
    ax.set_xlabel("x1")
    ax.set_ylabel("x2")
fig.suptitle(f"Value function slice at x3 = {float(x3_coords[slice]):.3f} (index {slice})")
fig.colorbar(filled, ax=axes, orientation='horizontal', fraction=0.02, pad=0.1)
plt.show()


In [None]:
def _slice_surface_figure(value_grid, title):
    """Build an interactive 3-D surface of `value_grid[:, :, slice]` over (x1, x2).

    Args:
        value_grid: value array for one gamma, sliced along its last axis at
            the module-level index `slice`.
        title: figure title.

    Returns:
        A plotly Figure with x1/x2 on the horizontal axes and the value on z.
    """
    surface = go.Surface(
        x=grid.coordinate_vectors[0],
        y=grid.coordinate_vectors[1],
        z=value_grid[:, :, slice].T,
        colorscale="jet",
        showscale=True,
    )
    figure = go.Figure(data=surface)
    figure.update_layout(
        title=title,
        scene=dict(xaxis_title='x1', yaxis_title='x2', zaxis_title='value'),
    )
    return figure


for i, results in enumerate(target_values):
    fig = _slice_surface_figure(results[0], f"gamma = {gamma_list[i]}")
    fig.show()

In [None]:
# Subsampling stride for the isosurface. 1 keeps every grid point (with the
# default grid_size that is 300^3 = 27M points per trace, which can be very
# slow to serialize and render in the browser); 2-4 trades detail for speed.
ISO_STRIDE = 1

# Grid coordinates do not depend on gamma: subsample and flatten them once.
states_sub = grid.states[::ISO_STRIDE, ::ISO_STRIDE, ::ISO_STRIDE]
iso_x = np.asarray(states_sub[..., 0]).ravel()
iso_y = np.asarray(states_sub[..., 1]).ravel()
iso_z = np.asarray(states_sub[..., 2]).ravel()

for gamma, result in zip(gamma_list, target_values):
    value_sub = result[0][::ISO_STRIDE, ::ISO_STRIDE, ::ISO_STRIDE]
    fig = go.Figure(data=go.Isosurface(x=iso_x,
                                       y=iso_y,
                                       z=iso_z,
                                       value=np.asarray(value_sub).ravel(),
                                       # Draw only the zero level set of the value function.
                                       isomin=0,
                                       isomax=0,
                                       surface_count=1,
                                       colorscale="jet"))
    # The z axis carries the x3 state coordinate (iso_z), not the value.
    fig.update_layout(title=f"gamma = {gamma}",
                      scene=dict(
                          xaxis_title='x1',
                          yaxis_title='x2',
                          zaxis_title='x3',
                      ))
    fig.show()