# Hybrid Zonotopes

In [None]:
# Generic TLT imports
# NOTE(review): later cells use `np` before any explicit `import numpy`;
# presumably one of these star imports provides it — confirm.
from pyspect import *
from pyspect.langs.ltl import *
# Hybrid Zonotope imports
from hz_reachability.hz_impl import HZImpl
from hz_reachability.systems.cars import *
from hz_reachability.shapes import HZShapes
from hz_reachability.spaces import EmptySpace

# Select the ContinuousLTL language for TLT; the Until/And/Not/BoundedSet
# constructions in the following cells are built under this selection.
TLT.select(ContinuousLTL)

## Environment definition

In [None]:
# Option 1: use an existing set template, or create your own (not implemented for HZ yet),
# e.g., state_space = ReferredSet('state_space')

# Option 2: wrap any custom HZ shape with the generic `Set` constructor.
shapes = HZShapes()
center, road_west, road_east, road_north, road_south = (
    Set(make_shape())
    for make_shape in (
        shapes.center,
        shapes.road_west,
        shapes.road_east,
        shapes.road_north,
        shapes.road_south,
    )
)

## Definitions

### Task and Dynamics

In [None]:
reach_dynamics = CircularBicycle5DLinearized()  # dynamics passed to HZImpl in the implementation cell

In [None]:
## CONSTANTS ##

LONVEL = (0.3, 0.6) # [m/s]
LATVEL = (0, 0.2)   # [m/s]

RBOUND = (0.8, 1.6) # [m]

RI = 1.0    # [m]
RM = 1.3    # [m]
RO = 1.6    # [m]

PHI_LOOKAHEAD = 2 # [rad]
TIME_HORIZON = 2 # [s]; comp time horizon

## TRAFFIC RULES ##

lonspeed = BoundedSet(v_phi=LONVEL)
latspeed = BoundedSet(v_r=LATVEL)

## ROUNDABOUT ##

lanes = And(BoundedSet(r=(RI, RO)), lonspeed, latspeed)

inner = And(BoundedSet(r=(RI, RM)), lanes)
outer = And(BoundedSet(r=(RM, RO)), lanes)

## TASKS ##

stay_inner = Until(inner, inner)
stay_outer = Until(outer, outer)
goto_inner = Until(lanes, inner)
goto_outer = Until(lanes, outer)

In [None]:
# Goal region: anywhere in the lanes with phi in (1.7, 2.1) [rad].
# NOTE(review): the HZ space's phi upper bound is +PHI_LOOKAHEAD = 2 rad, so
# part of this interval lies outside the computational domain — confirm.
terminal = And(lanes, BoundedSet(phi=(1.7, 2.1)))

# Obstacle occupying the inner lane (r in (RI, RM)) for phi in (1.0, 1.3).
OBS1 = BoundedSet(r=(RI, RM), phi=(1.0, 1.3))

# Avoid OBS1 while following goto_outer, until the terminal region is reached.
task = Until(And(Not(OBS1), goto_outer), terminal)

### Implementation

In [None]:
# Hybrid Zonotope implementation
# Explicit import: numpy is otherwise only imported explicitly in a later cell.
import numpy as np

# Computational domain; axes ordered (r, phi, v_r, v_phi, v_yaw) as named below.
# PHI_LOOKAHEAD and TIME_HORIZON come from the constants cell (single source of
# truth), so phi spans [-PHI_LOOKAHEAD / 2, +PHI_LOOKAHEAD].
max_bounds = np.array([1.8,       +PHI_LOOKAHEAD, +0.1,  +0.7,   +np.pi/5])
min_bounds = np.array([0.8, 0.5 * -PHI_LOOKAHEAD, -0.1,  +0.3,   -np.pi/5])

# Grid resolution per axis. Not passed to HZImpl; it is read by the grid-based
# conversion cells further down.
grid_shape = (11, 51, 5, 5, 5)

space = EmptySpace(min_bounds, max_bounds)
space.remove_redundant = True
impl = HZImpl(dynamics=reach_dynamics, space=space, time_horizon=TIME_HORIZON)
impl.set_axes_names('r', 'phi', 'v_r', 'v_phi', 'v_yaw')


## Solve

- `construct(task)`: Takes an LTL formula, a set, a lazy set, or an already-constructed TLT and turns it into a valid TLT object; in effect, it builds the computation graph for the given task.
- `realize(impl)`: Runs the actual computations.
- `out`: The final set in your chosen set implementation; here, a hybrid zonotope.

In [None]:
# Solve the problem - Find the states that can satisfy the task
# NOTE(review): this realizes `inner`, not `task` defined above; presumably a
# debugging shortcut — switch to `task` to solve the full specification.
out = TLT.construct(inner).realize(impl)
# Optional reduction of redundant generators / constraints:
# out = space.zono_op.redundant_gc_hz(out)
# out = space.zono_op.redundant_c_hz(out)

def shift_lines(s, n):
    """Return ``s`` with every line (split on newlines) indented by ``n`` spaces."""
    indent = ' ' * n
    return '\n'.join(indent + line for line in s.split('\n'))

def print_hz(hz):
    """Print every matrix of a hybrid zonotope with its shape header.

    For each attribute ``Gc, Gb, C, Ac, Ab, b`` of ``hz`` this prints a header
    ``name<d0xd1...>`` followed by the matrix text indented by two spaces.
    """
    for name in ('Gc', 'Gb', 'C', 'Ac', 'Ab', 'b'):
        mat = getattr(hz, name)
        shape = 'x'.join(str(d) for d in mat.shape)
        body = '\n'.join('  ' + line for line in str(mat).split('\n'))
        print(f'{name}<{shape}>', body, sep='\n')

def save_hz(hz, directory='.'):
    """Save each hybrid-zonotope matrix of ``hz`` to its own MATLAB file.

    Writes ``<directory>/<name>.mat`` for ``name`` in ``Gc, Gb, C, Ac, Ab, b``,
    each containing a single variable called ``name``. The directory (and any
    missing parents) is created first.
    """
    from pathlib import Path
    from scipy.io import savemat

    Path(directory).mkdir(parents=True, exist_ok=True)
    for name in ('Gc', 'Gb', 'C', 'Ac', 'Ab', 'b'):
        savemat(f'{directory}/{name}.mat', {name: getattr(hz, name)})

# save_hz(out, 'out-remove-redundant')
# print_hz(out)

In [None]:
import plotly.graph_objects as go

def add_vector(fig, vector, offset, color, name):
    """Add ``vector`` to ``fig`` as a 3-D line segment starting at ``offset``.

    The segment is drawn as a ``go.Scatter3d`` trace with markers at both ends.
    """
    x0, y0, z0 = offset[0], offset[1], offset[2]
    segment = dict(
        x=[x0, x0 + vector[0]],
        y=[y0, y0 + vector[1]],
        z=[z0, z0 + vector[2]],
    )
    fig.add_trace(go.Scatter3d(
        **segment,
        mode="lines+markers",
        marker=dict(size=4),
        line=dict(width=4, color=color),
        name=name,
    ))

def plot_3d_boolean_tessellation(fig, bool_grid, x_vals, y_vals, z_vals, voxel_size=1, color="blue", opacity=0.1):
    """
    Plots a 3D boolean tessellation grid using Plotly.

    Every True cell ``bool_grid[i, j, k]`` is drawn as a closed, axis-aligned
    cube (one ``go.Mesh3d`` trace per voxel) whose minimum corner is
    ``(x_vals[i], y_vals[j], z_vals[k])``.

    Parameters:
        fig (go.Figure): Figure the voxel traces are added to; its layout is updated too.
        bool_grid (np.ndarray): 3D numpy array of boolean values (shape: Nx, Ny, Nz).
        x_vals (np.ndarray): 1D array of X coordinates (same length as bool_grid.shape[0]).
        y_vals (np.ndarray): 1D array of Y coordinates (same length as bool_grid.shape[1]).
        z_vals (np.ndarray): 1D array of Z coordinates (same length as bool_grid.shape[2]).
        voxel_size (float): Edge length of each cube (optional).
        color (str): Color of the voxels (optional).
        opacity (float): Opacity of the voxels (optional).

    Raises:
        ValueError: if ``bool_grid`` is not 3-dimensional.
    """
    if bool_grid.ndim != 3:
        raise ValueError(f"bool_grid must be 3D, got shape {bool_grid.shape}")

    # Cube corner numbering (offsets from the minimum corner, in voxel units):
    #   0 (0,0,0)  1 (1,0,0)  2 (1,1,0)  3 (0,1,0)   <- bottom face
    #   4 (0,0,1)  5 (1,0,1)  6 (1,1,1)  7 (0,1,1)   <- top face
    # Two triangles per face, in the order bottom, top, front (y-min),
    # back (y-max), left (x-min), right (x-max), so all six faces are closed.
    tri_i = [0, 0, 4, 4, 0, 0, 3, 3, 0, 0, 1, 1]
    tri_j = [1, 2, 5, 6, 1, 5, 2, 6, 3, 7, 2, 6]
    tri_k = [2, 3, 6, 7, 5, 4, 6, 7, 7, 4, 6, 5]

    # np.argwhere yields the True cells in the same C order as nested i/j/k loops.
    for i, j, k in np.argwhere(bool_grid):
        x, y, z = x_vals[i], y_vals[j], z_vals[k]
        x1, y1, z1 = x + voxel_size, y + voxel_size, z + voxel_size
        fig.add_trace(go.Mesh3d(
            x=[x, x1, x1, x, x, x1, x1, x],
            y=[y, y, y1, y1, y, y, y1, y1],
            z=[z, z, z, z, z1, z1, z1, z1],
            i=tri_i,
            j=tri_j,
            k=tri_k,
            color=color,
            opacity=opacity,
        ))

    # Set figure layout
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X", range=[x_vals.min(), x_vals.max()]),
            yaxis=dict(title="Y", range=[y_vals.min(), y_vals.max()]),
            zaxis=dict(title="Z", range=[z_vals.min(), z_vals.max()])
        ),
        title="3D Boolean Tessellation Grid with Custom Coordinates",
        showlegend=False
    )

def plot_3d_boolean_isosurface(fig, bool_grid, x_vals, y_vals, z_vals, color="blue", opacity=0.5):
    """
    Draw the boundary of the True region of a 3D boolean grid as a Plotly isosurface.

    The grid is converted to 0.0 / 1.0 and a single surface is rendered at the
    0.5 threshold in one uniform color.

    Parameters:
        fig (go.Figure): The Plotly figure object to add the isosurface to.
        bool_grid (np.ndarray): 3D numpy array of boolean values (shape: Nx, Ny, Nz).
        x_vals (np.ndarray): 1D array of X coordinates (same length as bool_grid.shape[0]).
        y_vals (np.ndarray): 1D array of Y coordinates (same length as bool_grid.shape[1]).
        z_vals (np.ndarray): 1D array of Z coordinates (same length as bool_grid.shape[2]).
        color (str): Color of the isosurface (optional).
        opacity (float): Opacity of the isosurface (optional).
    """
    values = bool_grid.astype(float).flatten()

    n_x, n_y, n_z = len(x_vals), len(y_vals), len(z_vals)

    # Per-point coordinates in the same C order as `values`: x varies slowest, z fastest.
    xs = np.repeat(x_vals, n_y * n_z)
    ys = np.tile(np.repeat(y_vals, n_z), n_x)
    zs = np.tile(z_vals, n_x * n_y)

    fig.add_trace(go.Isosurface(
        x=xs,
        y=ys,
        z=zs,
        value=values,
        isomin=0.5,
        isomax=1.0,
        surface_count=1,
        colorscale=[[0, color], [1, color]],
        opacity=opacity
    ))

    def _axis(title, vals):
        # One scene axis spanning the full coordinate range.
        return dict(title=title, range=[vals.min(), vals.max()])

    fig.update_layout(
        scene=dict(
            xaxis=_axis("X", x_vals),
            yaxis=_axis("Y", y_vals),
            zaxis=_axis("Z", z_vals),
        ),
        title="3D Boolean Isosurface Representation",
        showlegend=False
    )

def bin2f(N, L=None):
    """
    Convert the non-negative integer N to a sequence of +1.0 / -1.0,
    corresponding to each bit in N's binary representation.

    Bit '0' -> -1.0
    Bit '1' -> +1.0

    Args:
        N (int): Non-negative integer to encode (int-like values such as
            NumPy integers are accepted).
        L (int, optional): Fixed output length. Shorter representations are
            left-padded with -1.0 (leading zero bits). Defaults to the minimal
            length ``N.bit_length()``, so ``bin2f(0) == []`` and
            ``bin2f(0, 0) == []``.

    Returns:
        list[float]: Signs in most-significant-bit-first order.

    Raises:
        ValueError: If N is negative or does not fit in L bits.
        TypeError: If N is not an integer.
    """
    import operator

    N = operator.index(N)
    if N < 0:
        raise ValueError(f'N must be non-negative, got {N}')

    n_bits = N.bit_length()
    width = n_bits if L is None else L
    if width < n_bits:
        raise ValueError(f'N = {N} needs {n_bits} bits, but L = {L}')

    # Walk bit positions from the most significant one down to bit 0;
    # positions above n_bits are zero bits and therefore become -1.0 padding.
    return [1.0 if (N >> pos) & 1 else -1.0 for pos in reversed(range(width))]

# Converting a hybrid zonotope to a grid value function

We have $\mathcal{Z} \in \mathcal{S}_\text{CZ}$, i.e., $\mathcal{Z}$ is a constrained zonotope. Specifically, we have
$$\begin{split}
    \mathcal{Z} &= \{ c + G \xi : A \xi = b, |\xi|_\infty \leq 1 \}                              \\
                &= \{ c + G \xi : \xi = A^\dagger b + N_A \eta, |\xi|_\infty \leq 1 \}           \\
                &= \{ c + G (A^\dagger b + N_A \eta) : |A^\dagger b + N_A \eta|_\infty \leq 1 \} \\
                &= \{ c + G A^\dagger b + G N_A \eta : |A^\dagger b + N_A \eta|_\infty \leq 1 \} \\
                &= \{ x = c' + G' \eta : |A^\dagger b + N_A \eta|_\infty \leq 1 \}               \\
                &= \{ x : \eta = G'^\dagger(x - c'), |A^\dagger b + N_A \eta|_\infty \leq 1 \}   \\
                &= \{ x : |A^\dagger b + N_A G'^\dagger(x - c')|_\infty \leq 1 \}                \\
\end{split}$$
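where $c' = c + G A^\dagger b$, $G' = G N_A$, and the columns of $N_A$ form a basis of the null space of $A$. The step $\eta = G'^\dagger(x - c')$ recovers $\eta$ from $x$ only if $G'$ has full column rank and $x - c'$ lies in the range of $G'$; otherwise the last line is not an exact description of $\mathcal{Z}$.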

If $T: \mathcal{S}_\text{HZ} \rightarrow \mathcal{S}_\text{HJ}$ 
and $T': \mathcal{S}_\text{CZ} \rightarrow \mathcal{S}_\text{HJ}$ 
then
$$\begin{split}
    T(\langle c, G_c, G_b, A_c, A_b, b \rangle) 
        &= \; \min_\delta \; T'(\langle c + G_b \delta, G_c, A_c, b - A_b \delta \rangle) \\
        &= \; \min_\delta \; \{
            x : |
                A_c^\dagger (b - A_b \delta)
                + N_{A_c} (G_c N_{A_c})^\dagger (
                    x
                    - (
                        c 
                        + G_c A_c^\dagger b 
                        + (G_b - G_c A_c^\dagger A_b) \delta
                    )
                )
            |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{ x : |
            A_c^\dagger (b - A_b \delta) 
            + N_{A_c} (G_c N_{A_c})^\dagger (
                x
                - c 
                - G_c A_c^\dagger b
                - (G_b - G_c A_c^\dagger A_b) \delta
            )
        |_\infty \leq 1 \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b - A_c^\dagger A_b \delta + H (x - m - D \delta) |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b + H (x - m) - A_c^\dagger A_b \delta - H D \delta |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b + H (x - m) - (A_c^\dagger A_b + H D) \delta|_\infty \leq 1
        \} \\
\end{split}$$
where $H = N_{A_c} (G_c N_{A_c})^\dagger$, $m = c + G_c A_c^\dagger b$, $D = G_b - G_c A_c^\dagger A_b$, and $\delta \in \{-1, 1\}^{n_b}$. Note that $(G_c N_{A_c})^\dagger \neq N_{A_c}^\dagger G_c^\dagger$ in general, so $H$ cannot be factored through $G_c^\dagger$.

In [None]:
# Source of the hybrid zonotope converted in the next cell:
#   debug = False -> a plane cut computed with the HZ implementation,
#   debug = True  -> a small hand-written 2-D hybrid zonotope for testing.
import numpy as np

debug = False

if not debug:

    # NOTE(review): `plane_cut` is called without a set argument and its result
    # replaces `out`; confirm it is meant to act on the solution computed above.
    # Also, the phi offset 5.3 lies outside the phi bounds [-1, 2] of `space`.
    out = impl.plane_cut(normal=[0., 1., 0, 0, 0],
                         offset=[0.0, 5.3, 0, 0, 0])

    c, Gc, Gb, Ac, Ab, b = out.astuple()

    # Take the sizes from the matrices that are actually used below. Reading
    # them from `out` before it was reassigned by the cut gave the sizes of the
    # previous set, which need not match Gc / Gb / Ac.
    nz, ng = Gc.shape
    nb = Gb.shape[1]
    nc = Ac.shape[0]

else:

    min_bounds = [-10, -10]
    max_bounds = [+10, +10]
    grid_shape = [101, 101]

    nz = len(grid_shape)
    ng = 3
    nb = ng
    nc = 1

    c = np.array([
        [0],
        [0],
        # [0],
    ]).reshape(nz, 1)
    Gc = np.array([
        # [+1, +0, +0, 0],
        # [+0, +1, +0, 0],
        # [+0, +0, +1, 0],
        [1.5, -1.5,  0.5],
        [1.0,  0.5, -1.0],
    ]).reshape(nz, ng)
    Gb = 2 * Gc.copy()
    Ac = np.array([
        [1, 1, 1],
    ]).reshape(nc, ng)
    Ab = Ac.copy()
    b = np.array([
        [1],
    ]).reshape(nc, 1)

print(
    f'{nz = }, {ng = }, {nb = }, {nc = }',
    '',
    f' { c.shape = }',
    f'{ Gc.shape = }',
    f'{ Gb.shape = }',
    f'{ Ac.shape = }',
    f'{ Ab.shape = }',
    f' { b.shape = }',
    sep='\n',
)

In [None]:
import itertools
import numpy as np
import hj_reachability.shapes as shp
from numpy.linalg import matrix_rank
from scipy.linalg import null_space, svd, pinv
from pyspect.plotting.levelsets import *

# Null-space basis and pseudo-inverses for the HZ -> value-function formula:
# for a fixed binary assignment delta, x lies in the constrained zonotope iff
#   | Aci b + H (x - m) - (Aci Ab + H D) delta |_inf <= 1
NA = null_space(Ac)     # orthonormal basis of ker(Ac), shape (ng, ng - rank(Ac))
NAi = pinv(NA)
Gci = pinv(Gc)
Aci = pinv(Ac)

GNi = pinv(Gc @ NA)     # (Gc NA)^+, shape (ng - rank(Ac), nz)

# H = NA (Gc NA)^+ maps x - m back to generator coordinates xi.
# NOTE: (Gc NA)^+ != NA^+ Gc^+ in general (e.g. Gc = [1, 0], Ac = [1, -1]
# gives a set twice as large), so H must not be built as NA @ NAi @ Gci.
# The formula is exact only when Gc @ NA has full column rank and the shifted
# point lies in its range.
H = NA @ GNi            # (ng, nz)
m = c + Gc @ Aci @ b    # (nz, 1)
D = Gb - Gc @ Aci @ Ab  # (nz, nb)

# Uniform sample points along every state axis (one 1-D array per axis).
coords = [
    np.linspace(lo, hi, num)
    for num, lo, hi in zip(grid_shape, min_bounds, max_bounds)
]
assert len(coords) == nz

# All grid points stacked along a leading axis: shape (nz, ..X).
x = np.array(np.meshgrid(*coords, indexing='ij'))

# For a fixed delta the membership test is |M - K delta|_inf <= 1 with
#   M = Aci b + H (x - m)    (ng, ..X)
#   K = Aci Ab + H D         (ng, nb)
M = (Aci @ b).reshape(-1, *(1,) * nz) + shp.tmul(H, x - m.reshape(-1, *(1,) * nz))
K = Aci @ Ab + H @ D

ONE = np.ones((ng,) + (1,) * nz)            # (ng, ..X), broadcastable ones
fixed = np.full((nb, *grid_shape), np.nan)  # (nb, ..X); NaN = binary not (yet) fixed

# Work in progress: try to determine, per grid point, binary generators whose
# value is forced (stored as -s in `fixed[i]`). NOTE(review): `fixed` is not
# read by the value-function loop that follows, and the conditions below are
# self-declared incomplete (see FIXME).

# xi, xj = 53, 53
# print((coords[0][xi], coords[1][xj]))

for s, i in itertools.product([1, -1], range(nb)):

    # unit vector
    # Signed unit vector: e = s * (i-th basis vector of the binary space).
    e = np.zeros((nb, 1)) # (nb, 1)
    e[i] = s

    print(e)

    # Lower/upper slabs of each constraint row: M - 1 <= K delta <= M + 1.
    # (Loop-invariant; recomputed on every (s, i) iteration.)
    Ml = (M - ONE) # (ng, ..X)
    Mu = (M + ONE) # (ng, ..X)

    for j in range(ng):

        k   = K[j].reshape(-1, 1) # (nb, 1)
        ml  = Ml[j] # (..X)
        mu  = Mu[j] # (..X)

        # print('k =', k.flatten())
        # print('ml =', ml[xi, xj])
        # print('mu =', mu[xi, xj])

        # Least-norm intersection of the hyperplane k.T p = ml (or mu) with the
        # hyperplane e.T p = 1, computed for every grid point at once.
        _K = np.block([
            [k.T], # (1, nb)
            [e.T], # (1, nb)
        ]) # (2, nb)
        _Ki = pinv(_K) # (nb, 2)

        _Ml = np.array([
            ml, 
            np.ones(grid_shape),
        ]) # (2, ..X)
        _Mu = np.array([
            mu, 
            np.ones(grid_shape),
        ]) # (2, ..X)

        pl = shp.tmul(_Ki, _Ml) # (nb, ..X)
        pu = shp.tmul(_Ki, _Mu) # (nb, ..X)

        # print('pl =', pl[:, xi, xj])
        # print('pu =', pu[:, xi, xj])

        # reminder: lb and ub normal is the same direction

        # NOTE: missing conditions on where plane cuts 
        # FIXME: some incomplete conditions

        # not parallel enough
        # Since e is a signed unit vector, k.T @ e == s * k[i]; this skips rows
        # with |K[j, i]| > sqrt(nb). NOTE(review): confirm this is the intended test.
        if np.sqrt(nb) < np.abs(k.T @ e):
            continue

        # we're looking for the negative case
        # NOTE(review): divides by K[j, i]; a zero entry yields inf/nan here.
        lb = s < ml/k[i] # (..X)
        ub = mu/k[i] < s # (..X)

        # Additionally require the intersection point to lie farther than
        # sqrt(nb) (the half-diagonal of [-1, 1]^nb) from the vertex e.
        lb = np.logical_and(lb, np.sqrt(nb) < np.linalg.norm(pl - e.reshape(-1, *[1] * nz), axis=0)) # (..X)
        ub = np.logical_and(ub, np.sqrt(nb) < np.linalg.norm(pu - e.reshape(-1, *[1] * nz), axis=0)) # (..X)

        # print('lb =', lb[xi, xj])
        # print('ub =', ub[xi, xj])
        # print()

        cond = np.logical_or(lb, ub)

        # Where the condition holds, binary i is recorded as fixed to -s.
        fixed[i] = np.where(cond, -s, fixed[i])

# Value function of the whole HZ: a grid point is inside (-1) if it belongs to
# the constrained zonotope of at least one binary assignment
# delta in {-1, +1}^nb, i.e. |M - K delta|_inf <= 1, and outside (+1) otherwise.
# Even when nb = 0, this loop runs once (2**0 == 1) with an empty delta,
# so that E = M produces the right output.
vf = np.full(tuple(len(axis) for axis in coords), np.inf)
for i in range(2 ** nb):
    delta = np.array(bin2f(i, nb)).reshape(nb, 1)
    E = M - shp.tmul(K, delta.reshape(-1, *(1,) * nz))
    _vf = np.where(np.abs(E).max(axis=0) <= 1, -1.0, +1.0)
    vf = np.minimum(vf, _vf)

print(vf.min(), vf.max())

# Plot the result. The 3-D view is disabled (fig_enabled=False), so the `or`
# falls through to the 2-D projection, whose return value is the cell output.
# NOTE(review): this presumably relies on plot3D_levelset returning a falsy
# value when disabled. Its arguments are still evaluated eagerly, so
# shp.project_onto(vf, 0, 1, 2) may fail when vf has fewer than 3 axes
# (e.g. debug = True, where nz == 2).
plot3D_levelset(shp.project_onto(vf, 0, 1, 2), 
                axes=(0,1,2),
                min_bounds=min_bounds[:3], 
                max_bounds=max_bounds[:3], 
                xtitle='x', ytitle='y', ztitle='z',
                fig_enabled=False) or \
plot2D_levelset(shp.project_onto(vf, 0, 1), 
                min_bounds=min_bounds[:2], 
                max_bounds=max_bounds[:2],
                xtitle='x', ytitle='y',
                fig_enabled=True)

In [None]:
# Scratch probe: for one grid point (xi, xj), intersect a constraint hyperplane
# of K with the plane e.T p = 1 and print the least-norm intersection and its
# distance to the vertex e, for three (row, vertex) combinations.
# NOTE(review): `xi` and `xj` are only defined in a commented-out line of the
# computation cell (`# xi, xj = 53, 53`), so as written this cell raises
# NameError. It also overwrites the global `m` used to build M.
# NOTE(review): ONE there has shape (ng, 1, ..., 1), so `_c - ONE` broadcasts
# to more than two axes; confirm that `[row, 0]` picks the intended entry.
# Only the final ','.join(...) expression is displayed; the earlier ones are discarded.
e = np.array([[0, -1, 0]]).T
k = K[1].reshape(-1, 1)
_c = M[:, xi, xj].reshape(-1, 1)
m = ((_c - ONE))[1, 0]
_K = np.block([
    [k.T],
    [e.T],
])
_M = np.block([[m, 1]]).T
print(_K)
print(_M)
p = pinv(_K) @ _M
print() 
print(np.linalg.norm(p - e))
','.join(f'{_:.3f}' for _ in (pinv(_K) @ _M).flatten())

print()
print()

e = np.array([[-1, 0, 0]]).T
k = K[0].reshape(-1, 1)
_c = M[:, xi, xj].reshape(-1, 1)
m = ((_c - ONE))[0, 0]
_K = np.block([
    [k.T],
    [e.T],
])
_M = np.block([[m, 1]]).T
print(_K)
print(_M)
p = pinv(_K) @ _M
print() 
print(np.linalg.norm(p - e))
','.join(f'{_:.3f}' for _ in (pinv(_K) @ _M).flatten())

print('dot:', -k.T @ e)



print()
print()

e = np.array([[0, 1, 0]]).T
k = K[1].reshape(-1, 1)
_c = M[:, xi, xj].reshape(-1, 1)
m = ((_c + ONE))[1, 0]
_K = np.block([
    [k.T],
    [e.T],
])
_M = np.block([[m, 1]]).T
print(_K)
print(_M)
p = pinv(_K) @ _M
print() 
print(np.linalg.norm(p - e))
','.join(f'{_:.3f}' for _ in (pinv(_K) @ _M).flatten())

$$A = U \Sigma V^T$$
$$\begin{split} 
    | Ax - c |_\infty \leq 1 \\
    | U \Sigma V^T x - c |_\infty \leq 1 \\
    c_i - 1 \leq (U \Sigma V^T x)_i \leq c_i + 1 \\
    c - \vec{1} \leq U \Sigma V^T x \leq c + \vec{1} \\
\end{split}$$

In [None]:
# Scratch check of one constraint row written as  m_lo <= k . delta <= m_hi.
# NOTE(review): `T1`, `S` and `Dp` are not defined anywhere in the cells shown,
# so this cell raises NameError. It also overwrites the global `ONE`.
ONE = np.ones((3, 1))
i = 1
print('m:', ((T1.reshape(3, 1) - ONE))[i] / S[i], '<=')
print('k:', Dp[i])
print('m:', '<=', ((T1.reshape(3, 1) + ONE))[i] / S[i])

# d0 at i=1 is outside when pos

In [None]:
# Step-by-step trace of the product NA @ NAi @ Gci applied to the raw grid
# points x, followed by the |.|_inf <= 1 membership test (the offsets m and
# Aci @ b are not included). Each intermediate is printed; `y` ends up holding
# the final boolean mask.
y = x.copy()

y = shp.tmul(Gci, y)
print('Gci', y, '---------------', sep='\n', end='\n\n')

y = shp.tmul(NAi, y)
print('NAi', y, '---------------', sep='\n', end='\n\n')

y = shp.tmul(NA, y)
print('NA', y, '---------------', sep='\n', end='\n\n')

y = np.abs(y)
print('abs', y, '---------------', sep='\n', end='\n\n')

y = np.max(y, axis=0)
print('max', y, '---------------', sep='\n', end='\n\n')

y = y <= 1
print('mask', y, '---------------', sep='\n', end='\n\n')

## Conversion

### HJ Setup

In [None]:
import hj_reachability as hj
import hj_reachability.shapes as shp

from pyspect.impls.hj_reachability import TVHJImpl
from hj_reachability.systems import Bicycle4D
from pyspect.plotting.levelsets import *

from math import pi

# Origin (X0, Y0, Z0) and side length (XN, YN, ZN) of the area; edit these to
# scale the domain up or down.
X0, XN = -1.2, 2.4
Y0, YN = -1.2, 2.4
Z0, ZN = -1.2, 2.4

# Active 2-D configuration on (x, y). NOTE: this rebinds `min_bounds` /
# `max_bounds`, which the grid-based conversion cells above also read.
min_bounds = np.array([X0, Y0])
max_bounds = min_bounds + np.array([XN, YN])
grid_space = (51, 51)

# Alternative configurations:
#   3-D: min [X0, Y0, Z0], max [XN+X0, YN+Y0, ZN+Z0], grid_space = (51, 51, 51)
#   4-D: min [X0, Y0, -pi, 1.0], max [XN+X0, YN+Y0, +pi, 0.0],
#        grid_space = (31, 31, 21, 11)  (as written, the 4th axis has min > max)

grid = hj.Grid.from_lattice_parameters_and_boundary_conditions(
    hj.sets.Box(min_bounds, max_bounds),
    grid_space,
)

# Placeholder dynamics: no system class is attached.
dynamics = dict(cls=None)

hj_impl = TVHJImpl(dynamics, grid, 3)
hj_impl.set_axes_names('t', 'x', 'y')

### Method

In [None]:
import jax
import jax.numpy as jnp

@jax.jit
def infmax_conv(VA, VB):
    """
    Infimal max-convolution of two 2-D grids, in index coordinates:

        VC[x] = min_y max(VA[y], VB[x - y])

    ``x`` ranges over the index grid of ``VA``. ``y`` ranges over all indices
    of ``VA``. Terms whose shifted index ``x - y`` falls outside ``VB`` are
    ignored (treated as +inf). For level-set arrays (non-positive inside),
    ``VC <= 0`` exactly where ``x`` can be written as ``y + z`` with ``y`` in
    the set of ``VA`` and ``z`` in the set of ``VB``.

    Args:
        VA: 2-D array.
        VB: 2-D array.

    Returns:
        Array with the shape of ``VA`` and a floating dtype (at least float32).
    """
    VA = jnp.asarray(VA)
    VB = jnp.asarray(VB)
    # A floating dtype is needed to represent the +inf "no term" value.
    dtype = jnp.promote_types(jnp.result_type(VA, VB), jnp.float32)
    VA = VA.astype(dtype)
    VB = VB.astype(dtype)

    na_x, na_y = VA.shape
    nb_x, nb_y = VB.shape

    # Every output index (i, j), flattened to shape (na_x * na_y, 2).
    ii, jj = jnp.meshgrid(jnp.arange(na_x), jnp.arange(na_y), indexing='ij')
    indices = jnp.stack([ii, jj], axis=-1).reshape(-1, 2)

    def infmax_conv_cell(carry, ij):
        """Compute VC[i, j] = min_(a, b) max(VA[a, b], VB[i - a, j - b])."""
        i, j = ij[0], ij[1]

        def body(k, current_min):
            # Decompose the flat loop counter k into an index (a, b) of VA.
            a = k // na_y
            b = k % na_y

            i2 = i - a
            j2 = j - b

            # Clip so the gather stays in bounds; invalid terms are masked to +inf.
            valid = (0 <= i2) & (i2 < nb_x) & (0 <= j2) & (j2 < nb_y)
            vb = VB[jnp.clip(i2, 0, nb_x - 1), jnp.clip(j2, 0, nb_y - 1)]
            val = jnp.where(valid, jnp.maximum(VA[a, b], vb), jnp.inf)
            return jnp.minimum(current_min, val)

        ## Loop over y
        cell = jax.lax.fori_loop(0, na_x * na_y, body, jnp.asarray(jnp.inf, dtype))
        # lax.scan expects the body to return a (carry, output) pair.
        return carry, cell

    ## Loop over x
    return jax.lax.scan(infmax_conv_cell, None, indices)[1].reshape(VA.shape)

In [None]:
from scipy.ndimage import convolve
# Imported without the `nb` alias: `nb` already holds the number of binary
# generators used by the conversion cells, and `import numba as nb` overwrote it.
import numba

def crop_mask(mask):
    """Crop an N-D boolean mask to the bounding box of its True entries.

    Args:
        mask (ndarray): N-D boolean mask.

    Returns:
        cropped_mask (ndarray): ``mask[slices]``. If ``mask`` has no True
            entry, the original mask object is returned unchanged.
        slices (tuple): One ``slice`` per axis selecting the bounding box, or
            ``slice(0, 0)`` on every axis when the mask is all False.
    """
    if not np.any(mask):
        # Nothing set: hand back the mask itself together with empty slices.
        return mask, (slice(0, 0),) * mask.ndim

    # np.nonzero gives, per axis, the coordinates of every True entry.
    slices = tuple(
        slice(axis_idx.min(), axis_idx.max() + 1)
        for axis_idx in np.nonzero(mask)
    )
    return mask[slices], slices

def generator(hz, i):
    """
    Grid value function of the i-th continuous generator segment of ``hz``,
    combined by pointwise maximum with its continuous equality constraints.

    Args:
        hz: 6-tuple unpacked as ``(Gc, Gb, C, Ac, Ab, b)``.
            NOTE(review): ``astuple()`` elsewhere in this notebook returns
            ``(c, Gc, Gb, Ac, Ab, b)``; confirm which order callers pass.
        i (int): Index of the continuous generator (column of ``Gc``).

    Returns:
        ndarray of shape ``grid.shape``; callers threshold it with ``<= 0``.

    Relies on the module-level ``grid`` and ``GEN_WIDTH``.
    """
    Gc, Gb, C, Ac, Ab, b = hz
    nc = Ac.shape[0]
    ng = Ac.shape[1]

    C = C.reshape(-1)
    b = b.reshape(-1)

    # --- Continuous Constraints ---

    # NOTE(review): Ac rows live in generator space (length ng) but are used as
    # normals on the state grid; confirm this is intended.
    data = np.full(grid.shape, -np.inf)
    for row in range(nc):  # not `i`: that would overwrite the generator index
        data = np.maximum(
            data,
            shp.hyperplane(grid, normal=Ac[row], offset=[0]*ng, const=b[row]),
        )

    # --- Continuous Generators ---

    # 1-D generator vector, so that C + g / C - g stay 1-D
    # (a (nz, 1) column would broadcast against C into an (nz, nz) matrix).
    g = Gc[:, i]
    data = np.maximum(
        data,
        shp.intersection(
            shp.cylinder(grid, r=GEN_WIDTH, c=C, axis=g),
            shp.hyperplane(grid, normal=+g, offset=C+g),
            shp.hyperplane(grid, normal=-g, offset=C-g),
        ),
    )

    return data

def hz2hj(hz):
    """
    Rasterize a hybrid zonotope onto the HJ grid as a level-set array.

    The mask of continuous generator 0 is dilated by the (cropped) mask of each
    further continuous generator via N-D convolution. Binary generators are not
    used here.

    Args:
        hz: 6-tuple unpacked as ``(Gc, Gb, C, Ac, Ab, b)``. Only ``Gc`` is read
            directly; the whole tuple is forwarded to ``generator``.

    Returns:
        ndarray: -0.5 on occupied grid cells, +0.5 elsewhere.
    """
    Gc, _Gb, _C, _Ac, _Ab, _b = hz
    n_gen = Gc.shape[1]

    ## Generators

    occupied = (generator(hz, 0) <= 0).astype(int)
    for idx in range(1, n_gen):
        # Crop the kernel to its bounding box so the convolution stays small.
        kernel, _slices = crop_mask((generator(hz, idx) <= 0).astype(int))
        occupied = convolve(occupied, kernel, mode='constant', cval=0) > 0

    return 0.5 - (occupied > 0)


In [None]:
# from hz_reachability.visualizer import ZonoVisualizer
# from hz_reachability.auxiliary_operations import ZonoOperations

# op = ZonoOperations()
# viz = ZonoVisualizer(op)

# out = op.redundant_c_hz(out)
# viz.vis_hz([out])

In [None]:
from scipy.signal import correlate
from scipy import ndimage as ndi
from tqdm import trange

# Euclidean norm of the grid spacings (the cell diagonal, assuming
# `grid.spacings` holds per-axis cell sizes); `generator` uses it as the radius
# of each generator's cylinder.
GEN_WIDTH = np.linalg.norm(grid.spacings)

## ##

# Which figure to show: 0 -> 2-D bitmap, 1 -> 3-D value function, 2 -> level sets.
fig_select = 2
fig_kwds = dict(fig_theme='Light')

# Test shape: the square [-0.5, 0.5]^2 on the HJ grid.
sq = shp.rectangle(grid, [-.5, -.5], [+.5, +.5])

# Only the plot whose `fig_enabled` is True is shown. The `or` chain presumably
# relies on disabled plot calls returning a falsy value — TODO confirm.
# NOTE(review): plot2D_bitmap / plot3D_valuefun are called without a data
# argument (it is commented out); confirm they accept that when disabled.
plot2D_bitmap(
    # M,
    **fig_kwds, 
    fig_enabled=fig_select==0,
) or \
plot3D_valuefun(
    # vf,
    min_bounds=min_bounds,
    max_bounds=max_bounds,
    **fig_kwds,
    fig_enabled=fig_select==1,
) or \
plot_levelsets(
    # shp.project_onto(vf, 1, 2),
    # `sq` resampled so that output index o reads input index (o0 - 1, o1 + 1),
    # i.e. shifted by one cell along each axis; out-of-range cells get 1.0.
    ndi.geometric_transform(sq, lambda coord: (coord[0] - 1,coord[1] + 1), cval=1.0),
    

    # (hz2hj(shapes.road_west()), dict(colorscale='blues')),
    # (hz2hj(shapes.road_east()), dict(colorscale='blues')),
    # (hz2hj(shapes.road_south()), dict(colorscale='blues')),
    # (hz2hj(shapes.road_north()), dict(colorscale='blues')),
    # (hz2hj(shapes.center().astuple()), dict(colorscale='greens')),

    # (hz2hj(out.astuple()), dict(colorscale='greens')),

    # axes = (0, 1, 2),
    min_bounds=min_bounds,
    max_bounds=max_bounds,
    # plot_func=plot3D_levelset,
    **fig_kwds,
    fig_enabled=fig_select==2,
    fig_width=500, fig_height=500,
)

In [None]:
def foo():
    """Unfinished sketch of a hyperplane test on a high-dimensional sample grid.

    NOTE(review): not runnable as written.
    - `a`, `axes`, `normal` and `offset` are not defined in the cells shown,
      so the tensordot line raises NameError.
    - `b` would resolve to the module-level global.
    - `x(i)` calls an ndarray.
    - `data` is never initialised before `+=`.
    - The statement `y[mask]` has no effect.
    """
    eps = 0.1
    Nx = 1
    Nd = 30

    # (Nd, Nx, ..., Nx)
    x = np.array(np.meshgrid(*[np.linspace(-1, 1, Nx)] * Nd))

    # 
    y = np.tensordot(a, x, ([0], [0])) - b
    
    mask = y <= eps

    for i, xi in enumerate(x):
        y[mask] 


    for i, k, m in zip(axes, normal, offset):
        data += k*x(i) - k*m
    return data