# Hybrid Zonotopes

In [1]:
# Generic TLT imports
from pyspect import *
from pyspect.langs.ltl import *
# Hybrid Zonotope imports
from hz_reachability.hz_impl import HZImpl
from hz_reachability.systems.cars import CarLinearModel2D
from hz_reachability.shapes import HZShapes
from hz_reachability.spaces import ParkingSpace

# Select ContinuousLTL (from pyspect.langs.ltl) as the language TLT uses for the tasks below.
TLT.select(ContinuousLTL)

## Environment definition

In [2]:
# Option 1: use the existing set templates or create your own (not implemented for HZ yet),
# e.g., state_space = ReferredSet('state_space').

# Option 2: wrap any custom shape with the generic Set constructor.
shapes = HZShapes()
center, road_west, road_east, road_north, road_south = (
    Set(make_shape())
    for make_shape in (shapes.center,
                       shapes.road_west,
                       shapes.road_east,
                       shapes.road_north,
                       shapes.road_south)
)

## Definitions

### Task

In [3]:
# Example task: stay in road_east or road_north UNTIL you reach center.
task = Until(Or(road_east, road_north), center)

### Dynamics

In [4]:
# Car dynamics model from hz_reachability.systems.cars, passed to the HZ implementation below.
reach_dynamics = CarLinearModel2D()

### Implementation

In [5]:
# Hybrid zonotope implementation over the parking space, using the car dynamics above
# and a time horizon of 5.
impl = HZImpl(
    dynamics=reach_dynamics,
    space=ParkingSpace(),
    time_horizon=5,
)

## Solve

- `construct(task)`: takes an LTL formula, a set, a lazy set, or an already constructed TLT and turns it into a valid TLT object. In effect, it builds the compute graph for the given task.
- `realize(impl)`: runs the actual computations on that graph using the given implementation `impl`.
- `out`: the resulting set, represented in the chosen set implementation (here, a hybrid zonotope).

In [6]:
# Solve the problem: build the TLT compute graph for `task`, then evaluate it
# with the hybrid zonotope implementation to find the states that satisfy the task.
out = (
    TLT.construct(task)
       .realize(impl)
)

print(f'out = {out}')
print(f'Gc = \n{out.Gc}')
print(f'Gc = \n{out.Gc}')

out = <hz_reachability.sets.HybridZonotope object at 0x7f5374007fe0>
Gc = 
[[0.05 0.   0.   ... 0.   0.   0.  ]
 [0.   0.95 0.   ... 0.   0.   0.  ]]


## Conversion

### HJ Setup

In [7]:
import numpy as np
import hj_reachability as hj

class HJImpl(ContinuousLTL.Impl):
    """Hamilton-Jacobi (level-set) implementation of the ContinuousLTL operations.

    Sets are value functions sampled on ``grid``. The operations below follow the
    sub-zero level-set convention: ``empty`` is +inf everywhere, ``union`` is a
    pointwise min, ``intersect`` a pointwise max and ``complement`` a negation.
    A time-varying value function (a "tube") has shape
    ``self.timeline.shape + self.grid.shape``; anything of another rank is treated
    as time-invariant (see ``is_invariant``).
    """

    solver_settings = hj.SolverSettings.with_accuracy("low")

    def __init__(self, dynamics, grid, time_horizon):
        """Store the dynamics and grid, and build a timeline from 0 to ``time_horizon``."""
        self.grid = grid
        self.ndim = grid.ndim
        self.dynamics = dynamics
        self.timeline = self.new_timeline(time_horizon)

    def new_timeline(self, target_time, start_time=0, time_step=0.2):
        """Return evenly spaced times from ``start_time`` towards ``target_time``.

        The exclusive ``np.arange`` stop is nudged by 1e-5 past ``target_time`` so
        that ``target_time`` itself is included when it lies on the step grid.
        ``time_step`` must be positive; its sign is flipped for a backward timeline.
        """
        assert time_step > 0
        is_forward = target_time >= start_time
        target_time += 1e-5 if is_forward else -1e-5
        time_step *= 1 if is_forward else -1
        return np.arange(start_time, target_time, time_step)

    def set_axes_names(self, *args):
        """Name the grid axes; exactly one name per grid dimension is required."""
        assert len(args) == self.ndim
        self._axes_names = tuple(args)

    def axis(self, name: str) -> int:
        """Return the index of the axis called ``name``."""
        assert name in self._axes_names, f'Axis ({name=}) does not exist.'
        return self._axes_names.index(name)

    def axis_name(self, i: int) -> str:
        """Return the name of axis ``i``."""
        assert i < len(self._axes_names), f'Axis ({i=}) does not exist.'
        return self._axes_names[i]

    def axis_is_periodic(self, i: int) -> bool:
        """Return whether axis ``i`` is periodic on the grid."""
        assert i < len(self._axes_names), f'Axis ({i=}) does not exist.'
        return bool(self.grid._is_periodic_dim[i])

    def plane_cut(self, normal, offset, axes=None):
        """Return the value function ``-sum_j k_j * (x[axes_j] - m_j)`` over the grid.

        It is non-positive where ``sum_j k_j * (x[axes_j] - m_j) >= 0``, i.e. on the
        side of the plane that ``normal`` points to. ``axes`` defaults to all grid
        dimensions; ``normal``, ``offset`` and ``axes`` are zipped pairwise.
        """
        data = np.zeros(self.grid.shape)
        axes = axes or list(range(self.grid.ndim))
        x = lambda i: self.grid.states[..., i]
        for i, k, m in zip(axes, normal, offset):
            data -= k*x(i) - k*m
        return data

    def empty(self):
        """Return the empty set: +inf everywhere, so no state is inside."""
        return np.ones(self.grid.shape)*np.inf

    def complement(self, vf):
        """Return the complement of ``vf`` by negating it."""
        return np.asarray(-vf)

    def intersect(self, vf1, vf2):
        """Return the intersection of two sets (pointwise max)."""
        return np.maximum(vf1, vf2)

    def union(self, vf1, vf2):
        """Return the union of two sets (pointwise min)."""
        return np.minimum(vf1, vf2)

    def _solve(self, mode, target, constraints):
        """Shared HJ solve used by ``reach`` and ``avoid``; ``mode`` is passed to the dynamics."""
        # NOTE(review): the return value of `with_mode` is discarded, so this assumes
        # it configures `self.dynamics` in place — confirm against the dynamics class.
        self.dynamics.with_mode(mode)
        # The solve runs over `-self.timeline`; time-varying inputs are flipped along
        # the time axis before the solve and the result is flipped back afterwards.
        if not self.is_invariant(target):
            target = np.flip(target, axis=0)
        if not self.is_invariant(constraints):
            constraints = np.flip(constraints, axis=0)
        vf = hj.solve(self.solver_settings,
                      self.dynamics,
                      self.grid,
                      -self.timeline,
                      target,
                      constraints)
        return np.flip(np.asarray(vf), axis=0)

    def reach(self, target, constraints=None):
        """Return the tube solved with the dynamics in 'reach' mode."""
        return self._solve('reach', target, constraints)

    def avoid(self, target, constraints=None):
        """Return the tube solved with the dynamics in 'avoid' mode."""
        return self._solve('avoid', target, constraints)

    def project_onto(self, vf, *idxs, keepdims=False, union=True):
        """Reduce ``vf`` onto the axes ``idxs`` (negative indices allowed).

        All other axes are min-reduced when ``union`` is True (the set's projection)
        and max-reduced otherwise.
        """
        idxs = [len(vf.shape) + i if i < 0 else i for i in idxs]
        dims = [i for i in range(len(vf.shape)) if i not in idxs]
        if union:
            return vf.min(axis=tuple(dims), keepdims=keepdims)
        else:
            return vf.max(axis=tuple(dims), keepdims=keepdims)

    def is_invariant(self, vf):
        """Return True if ``vf`` is None or does not have the rank of a tube (time + grid axes)."""
        return (True if vf is None else
                len(vf.shape) != len(self.timeline.shape + self.grid.shape))

    def make_tube(self, vf):
        """Return ``vf`` as a tube, stacking a time-invariant ``vf`` once per timeline entry."""
        return (vf if not self.is_invariant(vf) else
                np.concatenate([vf[np.newaxis, ...]] * len(self.timeline)))

In [8]:
from hj_reachability.systems import Bicycle4D
from pyspect.plotting.levelsets import *

from math import pi

# Define the origin (X0, Y0) and size (XN, YN) of the area; keeping them as
# constants makes it easy to scale the area up or down later on.
X0, XN = -1.2, 2.4
Y0, YN = -1.2, 2.4

# State order matches the axis names set below: (x, y, vx, vy).
min_bounds = np.array([   X0,    Y0, +0.3, +0.3])
max_bounds = np.array([XN+X0, YN+Y0, +0.8, +0.8])
GRID_SHAPE = (31, 31, 15, 15)

# No axis is periodic: positions and the vx/vy axes are bounded intervals, not
# angles, so values at one boundary must not wrap around to the other.
grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(hj.sets.Box(min_bounds, max_bounds),
                                                               GRID_SHAPE)

# NOTE(review): `reach_dynamics` is the hz_reachability CarLinearModel2D built for the
# HZ backend, while hj.solve presumably expects an hj_reachability Dynamics object
# (Bicycle4D is imported in this cell but unused). Confirm this is the intended model.
impl = HJImpl(reach_dynamics, grid, 3)
impl.set_axes_names('x', 'y', 'vx', 'vy')

### Method

In [None]:
import hj_reachability.shapes as shp

# NOTE(review): `out` is the hybrid zonotope from the HZ solve, whose printed Gc has 2 rows,
# while `grid` is 4-D (x, y, vx, vy). Confirm that `out.C` (presumably the zonotope center)
# has one entry per grid dimension before using it as a point on this grid.
vf = shp.point(grid, out.C)