# Hybrid Zonotopes

### HZ-HJ idea


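Let $\mathcal{Z} \in \mathcal{S}_\text{CZ}$ be a constrained zonotope, let $A^\dagger$ denote the Moore–Penrose pseudoinverse of $A$, and let $N_A$ be a matrix whose columns form a basis of the null space of $A$. Then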
$$\begin{split}
    \mathcal{Z} &= \{ c + G \xi : A \xi = b, |\xi|_\infty \leq 1 \}                              \\
                &= \{ c + G \xi : \xi = A^\dagger b + N_A \eta, |\xi|_\infty \leq 1 \}           \\
                &= \{ c + G (A^\dagger b + N_A \eta) : |A^\dagger b + N_A \eta|_\infty \leq 1 \} \\
                &= \{ c + G A^\dagger b + G N_A \eta : |A^\dagger b + N_A \eta|_\infty \leq 1 \} \\
                &= \{ x = c' + G' \eta : |A^\dagger b + N_A \eta|_\infty \leq 1 \}               \\
                &= \{ x : \eta = G'^\dagger(x - c'), |A^\dagger b + N_A \eta|_\infty \leq 1 \}   \\
                &= \{ x : |A^\dagger b + N_A G'^\dagger(x - c')|_\infty \leq 1 \}                \\
\end{split}$$
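where $c' = c + G A^\dagger b$ and $G' = G N_A$. The step that recovers $\eta = G'^\dagger (x - c')$ is exact only when $G'$ has full column rank. Otherwise $G'^\dagger (x - c')$ is merely the minimum-norm solution, and other solutions in $\eta + \ker G'$ could also satisfy the constraint.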

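Let $T: \mathcal{S}_\text{HZ} \rightarrow \mathcal{S}_\text{HJ}$ 
and $T': \mathcal{S}_\text{CZ} \rightarrow \mathcal{S}_\text{HJ}$ 
map hybrid zonotopes (HZ) and constrained zonotopes (CZ), respectively, to Hamilton–Jacobi (HJ) level-set representations.
Fixing the binary factors $\delta$ turns a hybrid zonotope into a constrained zonotope, so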
$$\begin{split}
    T(\langle c, G_c, G_b, A_c, A_b, b \rangle) 
        &= \; \min_\delta \; T'(\langle c + G_b \delta, G_c, A_c, b - A_b \delta \rangle) \\
        &= \; \min_\delta \; \{
            x : |
                A_c^\dagger (b - A_b \delta)
                + N_{A_c} (G_c N_{A_c})^\dagger (
                    x
                    - (
                        c 
                        + G_c A_c^\dagger b 
                        + (G_b - G_c A_c^\dagger A_b) \delta
                    )
                )
            |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{ x : |
            A_c^\dagger (b - A_b \delta) 
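            + N_{A_c} (G_c N_{A_c})^\dagger (
                x
                - c 
                - G_c A_c^\dagger b
                - (G_b - G_c A_c^\dagger A_b) \delta
            )
        |_\infty \leq 1 \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b - A_c^\dagger A_b \delta + H (x - m - D \delta) |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b + H (x - m) - A_c^\dagger A_b \delta - H D \delta |_\infty \leq 1
        \} \\
        &= \; \min_\delta \; \{
            x : |A_c^\dagger b + H (x - m) - (A_c^\dagger A_b + H D) \delta|_\infty \leq 1
        \} \\
\end{split}$$
where $H = N_{A_c} (G_c N_{A_c})^\dagger$, $m = c + G_c A_c^\dagger b$, $D = G_b - G_c A_c^\dagger A_b$, and $\delta \in \{-1, 1\}^{n_b}$. In general $(G_c N_{A_c})^\dagger \neq N_{A_c}^\dagger G_c^\dagger$, so $H$ must be formed from the pseudoinverse of the product $G_c N_{A_c}$. Taking the minimum over $\delta$ of the level-set functions corresponds to the union, over all binary assignments $\delta$, of the constrained zonotopes obtained for each fixed $\delta$.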

### Implementation

In [None]:
from math import pi
import time

from pyspect import *
from pyspect.langs.ltl import *

TLT.select(ContinuousLTL)

# Reachability backend(s) to instantiate in the cells below: one of 'HZ',
# 'TVHZ', 'TVHJ', or a tuple of them to build several side by side.
IMPL = ('TVHJ','TVHZ')

# Alternative configuration (inactive):
# TIME_HORIZON    = 8
# TIME_STEP       = 0.2
# MAX_ACCEL       = 1.0 # [m/s^2]

# axis_names = ['x', 'v']
# max_bounds = [250,  16] # 250 m, 16 m/s ~= 58 km/h
# min_bounds = [  0,   4] #   0 m,  4 m/s ~= 14 km/h
# grid_shape = ( 71,  71)


TIME_HORIZON    = 40
TIME_STEP       = 0.5
MAX_ACCEL       = 1.0 # [m/s^2]

# State space: position x [m] and velocity v [m/s]; grid_shape is the number
# of grid points per axis.
axis_names = [ 'x', 'v']
max_bounds = [+100, +20] # +100 m, +20 m/s ~= 72 km/h
min_bounds = [-100, -20] # -100 m, -20 m/s
grid_shape = (  91,  91)

# Width of the goal time window used by the goal/until experiment further
# below; previously tried values: 1.6, 0.8.
GOAL_TIMEWIDTH = 2.1

# Alternative configuration (inactive):
# axis_names = ['x', 'v']
# max_bounds = [ 50,   3] # 50 m, 3 m/s ~= 11 km/h
# min_bounds = [  0, 0.3] #  0 m, 0.3 m/s ~= 1 km/h
# grid_shape = ( 51,  61)

# Alternative configuration (inactive):
# axis_names = ['x', 'y']
# max_bounds = [+5., +5.]
# min_bounds = [-5., -5.]
# grid_shape = ( 51,  51)

In [None]:
target = TLT(BoundedSet(x=(-50,  +50)))

# Two temporal specifications over the same target region:
#   phi1 = Always(target)
#   phi2 = Always(Eventually(target))
phi1, phi2 = Always(target), Always(Eventually(target))

In [None]:
# From string s, shift lines to the right by n spaces
def shift_lines(s, n):
    return '\n'.join([' '*n + l for l in s.split('\n')])

def print_hz(hz):
    """Print each matrix of a hybrid zonotope, labelled with its shape.

    The matrices are printed in the order Gc, Gb, C, Ac, Ab, b. Each one gets
    a ``Name<rows x cols>`` header followed by the matrix indented by two spaces.
    """
    for label in ('Gc', 'Gb', 'C', 'Ac', 'Ab', 'b'):
        matrix = getattr(hz, label)
        dims = 'x'.join(str(extent) for extent in matrix.shape)
        print(f'{label}<{dims}>', shift_lines(str(matrix), 2), sep='\n')

def save_hz(hz, directory='.'):
    """Save each matrix of a hybrid zonotope to its own MATLAB ``.mat`` file.

    Creates ``directory`` (including parents) if needed. It then writes
    ``Gc.mat``, ``Gb.mat``, ``C.mat``, ``Ac.mat``, ``Ab.mat`` and ``b.mat``, each
    holding a single variable named after the matrix.
    """
    from scipy.io import savemat
    from pathlib import Path

    Path(directory).mkdir(parents=True, exist_ok=True)
    for label in ('Gc', 'Gb', 'C', 'Ac', 'Ab', 'b'):
        savemat(f'{directory}/{label}.mat', {label: getattr(hz, label)})

# save_hz(out, 'out-remove-redundant')
# print_hz(out)

In [None]:
import plotly.graph_objects as go

# Function to add vectors to the plot
def add_vector(fig, vector, offset, color, name):
    fig.add_trace(go.Scatter3d(
        x=[offset[0], offset[0] + vector[0]], y=[offset[1], offset[1] + vector[1]], z=[offset[2], offset[2] + vector[2]],
        mode="lines+markers",
        marker=dict(size=4),
        line=dict(width=4, color=color),
        name=name
    ))

def plot_3d_boolean_tessellation(fig, bool_grid, x_vals, y_vals, z_vals, voxel_size=1, color="blue", opacity=0.1):
    """
    Plots a 3D boolean tessellation grid using Plotly, drawing one cube per True voxel.

    Each cube is anchored at its (x, y, z) grid coordinate and extends by
    ``voxel_size`` in the positive direction of every axis.

    Parameters:
        fig (go.Figure): The Plotly figure object the voxels are added to.
        bool_grid (np.ndarray): 3D numpy array of boolean values (shape: Nx, Ny, Nz).
        x_vals (np.ndarray): 1D array of X coordinates (same length as bool_grid.shape[0]).
        y_vals (np.ndarray): 1D array of Y coordinates (same length as bool_grid.shape[1]).
        z_vals (np.ndarray): 1D array of Z coordinates (same length as bool_grid.shape[2]).
        voxel_size (float): Size of each cube (optional).
        color (str): Color of the voxels (optional).
        opacity (float): Opacity of the voxels (optional).

    Raises:
        ValueError: if ``bool_grid`` is not three-dimensional.
    """
    if bool_grid.ndim != 3:
        raise ValueError(f'bool_grid must be 3D, got shape {bool_grid.shape}')

    # Cube corners (offsets in units of voxel_size), matching the x/y/z lists below:
    #   0:(0,0,0) 1:(1,0,0) 2:(1,1,0) 3:(0,1,0) 4:(0,0,1) 5:(1,0,1) 6:(1,1,1) 7:(0,1,1)
    # Two triangles per face, in order: bottom (z=0), top (z=1), front (y=0),
    # back (y=1), left (x=0), right (x=1).
    tri_i = [0, 0, 4, 4, 0, 0, 3, 3, 0, 0, 1, 1]
    tri_j = [1, 2, 5, 6, 1, 5, 2, 6, 3, 7, 2, 6]
    tri_k = [2, 3, 6, 7, 5, 4, 6, 7, 7, 4, 6, 5]

    # nonzero() visits the True voxels in the same C order as nested i/j/k loops.
    for ix, iy, iz in zip(*bool_grid.nonzero()):
        x, y, z = x_vals[ix], y_vals[iy], z_vals[iz]
        fig.add_trace(go.Mesh3d(
            x=[x, x+voxel_size, x+voxel_size, x, x, x+voxel_size, x+voxel_size, x],
            y=[y, y, y+voxel_size, y+voxel_size, y, y, y+voxel_size, y+voxel_size],
            z=[z, z, z, z, z+voxel_size, z+voxel_size, z+voxel_size, z+voxel_size],
            i=tri_i,
            j=tri_j,
            k=tri_k,
            color=color,
            opacity=opacity
        ))

    # Set figure layout
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X", range=[x_vals.min(), x_vals.max()]),
            yaxis=dict(title="Y", range=[y_vals.min(), y_vals.max()]),
            zaxis=dict(title="Z", range=[z_vals.min(), z_vals.max()])
        ),
        title="3D Boolean Tessellation Grid with Custom Coordinates",
        showlegend=False
    )

def plot_3d_boolean_isosurface(fig, bool_grid, x_vals, y_vals, z_vals, color="blue", opacity=0.5):
    """
    Plots a 3D boolean isosurface representation using Plotly.

    True cells are rendered as the single isosurface at value 0.5 of the grid
    converted to floats (True -> 1.0, False -> 0.0).

    Parameters:
        fig (go.Figure): The Plotly figure object to add the isosurface to.
        bool_grid (np.ndarray): 3D numpy array of boolean values (shape: Nx, Ny, Nz).
        x_vals (np.ndarray): 1D array of X coordinates (same length as bool_grid.shape[0]).
        y_vals (np.ndarray): 1D array of Y coordinates (same length as bool_grid.shape[1]).
        z_vals (np.ndarray): 1D array of Z coordinates (same length as bool_grid.shape[2]).
        color (str): Color of the isosurface (optional).
        opacity (float): Opacity of the isosurface (optional).
    """
    # Local import: the notebook's only module-level numpy import is in a later cell.
    import numpy as np

    # Convert boolean grid to numerical values (1 for True, 0 for False)
    numerical_grid = bool_grid.astype(float)

    # 'ij' indexing + C-order flattening lines the coordinates up with
    # numerical_grid.flatten(), element for element.
    X, Y, Z = np.meshgrid(x_vals, y_vals, z_vals, indexing='ij')

    fig.add_trace(go.Isosurface(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=numerical_grid.flatten(),
        isomin=0.5,  # Threshold for rendering the isosurface
        isomax=1.0,
        surface_count=1,  # Single surface rendering
        colorscale=[[0, color], [1, color]],
        opacity=opacity
    ))

    # Set figure layout
    fig.update_layout(
        scene=dict(
            xaxis=dict(title="X", range=[x_vals.min(), x_vals.max()]),
            yaxis=dict(title="Y", range=[y_vals.min(), y_vals.max()]),
            zaxis=dict(title="Z", range=[z_vals.min(), z_vals.max()])
        ),
        title="3D Boolean Isosurface Representation",
        showlegend=False
    )

def bin2f(N, L=None):
    """
    Convert the non-negative integer N to a list of +1.0 / -1.0 signs,
    one per bit of N's binary representation, most-significant bit first.

    Bit '0' -> -1.0
    Bit '1' -> +1.0

    If L is given and N has fewer than L bits, the result is left-padded with
    -1.0 (i.e. leading zero bits) to length L. It is never truncated.
    N <= 0 contributes no bits, so bin2f(0) == [] and bin2f(0, L) == [-1.0] * L.
    """
    digits = format(N, 'b') if N > 0 else ''
    if L is not None and len(digits) < L:
        digits = digits.rjust(L, '0')
    return [1.0 if digit == '1' else -1.0 for digit in digits]


def IfThenElse(C, A, B):
    """Encode 'if C then A else B' as the conjunction (C -> A) and (not C -> B)."""
    then_branch = Implies(C, A)
    else_branch = Implies(Not(C), B)
    return And(then_branch, else_branch)

def SwitchCase(D, *args):
    """Encode a switch statement from condition/formula pairs.

    ``SwitchCase(D, C1, A1, C2, A2, ...)`` builds the conjunction of
    ``Implies(Ck, Ak)`` for every pair plus ``Implies(Not(Or(C1, C2, ...)), D)``
    for the default case.

    Parameters:
        D: formula for the default case (no condition holds).
        *args: alternating conditions and formulas ``C1, A1, C2, A2, ...``.

    Returns:
        ``D`` itself when no pairs are given, otherwise the conjunction above.

    Raises:
        ValueError: if ``args`` has odd length (a condition without a formula).
            Previously the trailing condition was silently dropped from the
            cases but still excluded from the default.
    """
    if not args:
        return D
    if len(args) % 2:
        raise ValueError(
            f'SwitchCase expects condition/formula pairs, got {len(args)} '
            'arguments after the default'
        )
    conditions = args[0::2]
    formulas = args[1::2]
    cases = [Implies(C, A) for C, A in zip(conditions, formulas)]
    default = Implies(Not(Or(*conditions)), D)
    return And(*cases, default)


In [None]:
## Instantiate Impls ##

# Build the reachability backend(s) selected by IMPL (a single name or a tuple).
# Every branch also assigns `impl`, so when IMPL is a tuple `impl` refers to the
# last backend built here (impl_hj if 'TVHJ' is selected).
# NOTE(review): the star imports below (hz_reachability.* and
# hj_reachability.systems) may each provide a `DoubleIntegrator`; every branch
# uses whichever binding is current when it runs, so branch order matters.
# Confirm.

if 'HZ' == IMPL or (isinstance(IMPL, tuple) and 'HZ' in IMPL):

    from hz_reachability.hz_impl import *
    from hz_reachability.systems.cars import *
    from hz_reachability.systems.integrators import *
    from hz_reachability.spaces import EmptySpace
    
    # Obstacle-free state space spanning the configured bounds.
    space = EmptySpace(min_bounds, max_bounds)

    dynamics = DoubleIntegrator(max_accel=MAX_ACCEL, dt=TIME_STEP)

    # Hybrid-zonotope backend (HZImpl).
    impl_hz = HZImpl(dynamics, space, axis_names, time_horizon=TIME_HORIZON, time_step=TIME_STEP)
    # impl.enable_reduce = True

    impl = impl_hz

if 'TVHZ' == IMPL or (isinstance(IMPL, tuple) and 'TVHZ' in IMPL):

    from hz_reachability.hz_impl import *
    from hz_reachability.systems.cars import *
    from hz_reachability.systems.integrators import *
    from hz_reachability.spaces import EmptySpace
    
    space = EmptySpace(min_bounds, max_bounds)

    dynamics = DoubleIntegrator(max_accel=MAX_ACCEL, dt=TIME_STEP)

    # Same setup as above but with TVHZImpl; its results are later handled either
    # as a single HZ or as a list of HZs (see the plotting cell).
    impl_hz = TVHZImpl(dynamics, space, axis_names, time_horizon=TIME_HORIZON, time_step=TIME_STEP)
    # impl.enable_reduce = True

    impl = impl_hz

if 'TVHJ' == IMPL or (isinstance(IMPL, tuple) and 'TVHJ' in IMPL):

    import hj_reachability as hj
    from pyspect.impls.hj_reachability import TVHJImpl
    from hj_reachability.systems import *

    # TVHJImpl receives the dynamics as a class plus constructor kwargs rather
    # than an instance (presumably it instantiates it itself — confirm).
    dynamics = dict(cls=DoubleIntegrator,
                    min_accel=-MAX_ACCEL,
                    max_accel=+MAX_ACCEL)

    # Grid-based HJ backend over the configured bounds and grid resolution.
    impl_hj = TVHJImpl(dynamics, axis_names, min_bounds, max_bounds, grid_shape, TIME_HORIZON, time_step=TIME_STEP)

    impl = impl_hj

In [None]:
# Experiment selector: the literal True/False guards pick which single branch of
# this if/elif chain runs. As written, the `elif True` branch (HJ vs. HZ
# realization of phi2) is the active one.
# NOTE(review): several branches use `np`, which this notebook only imports in a
# later cell; this relies on an earlier import (e.g. via a star import or a
# previous run) — confirm.
if False:
    # Realize phi1 with both backends and report wall-clock time for each.

    t0 = time.time()
    Phi1HJ = TLT(phi1).realize(impl_hj)
    tf = time.time()

    print('HJ Time:', tf - t0)
    
    t0 = time.time()
    Phi1HZ = TLT(phi1).realize(impl_hz)
    tf = time.time()

    print('HZ Time:', tf - t0)

    # Report the HZ sizes (dim, ng, nb, nc); for a list result take the
    # element-wise maximum over all entries.
    if isinstance(Phi1HZ, list):
        nz, ng, nb, nc = \
            np.array([[_out.dim, _out.ng, _out.nb, _out.nc]
                    for _out in Phi1HZ]).max(axis=0)
    else:
        nz,ng,nb,nc = Phi1HZ.dim, Phi1HZ.ng, Phi1HZ.nb, Phi1HZ.nc

    print(
        f'{nz = !s}, {ng = !s}, {nb = !s}, {nc = !s}',
        # '',
        # 'c = ',  c,
        # 'Gc = ', Gc,
        # 'Gb = ', Gb,
        # 'Ac = ', Ac,
        # 'Ab = ', Ab,
        # 'b = ',  b,
        sep='\n',
    )

elif True:
    # Realize phi2 with both backends and report wall-clock time for each.

    t0 = time.time()
    Phi2HJ = TLT(phi2).realize(impl_hj)
    tf = time.time()

    print('HJ Time:', tf - t0)

    t0 = time.time()
    Phi2HZ = TLT(phi2).realize(impl_hz)
    tf = time.time()

    print('HZ Time:', tf - t0)

    # Report the HZ sizes (dim, ng, nb, nc); for a list result take the
    # element-wise maximum over all entries.
    if isinstance(Phi2HZ, list):
        nz, ng, nb, nc = \
            np.array([[_out.dim, _out.ng, _out.nb, _out.nc]
                    for _out in Phi2HZ]).max(axis=0)
    else:
        nz,ng,nb,nc = Phi2HZ.dim, Phi2HZ.ng, Phi2HZ.nb, Phi2HZ.nc

    print(
        f'{nz = !s}, {ng = !s}, {nb = !s}, {nc = !s}',
        # '',
        # 'c = ',  c,
        # 'Gc = ', Gc,
        # 'Gb = ', Gb,
        # 'Ac = ', Ac,
        # 'Ab = ', Ab,
        # 'b = ',  b,
        sep='\n',
    )

elif False:
    # Realize phi2 with the HJ backend only.


    t0 = time.time()
    Phi2HJ = TLT(phi2).realize(impl_hj)
    tf = time.time()

    print('HJ Time:', tf - t0)


elif False:
    # Stay in `city` until reaching `goal`, a slab in x just below the upper
    # x bound (width GOAL_TIMEWIDTH * max velocity).

    GOAL_S_LB = max_bounds[0] - GOAL_TIMEWIDTH * max_bounds[1]
    GOAL_S_UB = max_bounds[0] - 1e-1
    print(f'{GOAL_S_LB=}, {GOAL_S_UB=}')
    # NOTE: Don't put the target region on the space boundary, close is fine.

    city = Not(EMPTY)
    # city = BoundedSet(x=(200,300), v=(10, 20))
    
    # NOTE(review): both disjuncts below are identical.
    goal = Or(
        BoundedSet(x=(GOAL_S_LB, GOAL_S_UB), v=(-5, +5)),
        BoundedSet(x=(GOAL_S_LB, GOAL_S_UB), v=(-5, +5)),
    )

    if IMPL == 'TVHJ':
        # Do BRS, not BRT (TVHJ trick)
        goal = And(goal, BoundedSet(t=(TIME_HORIZON - GOAL_TIMEWIDTH,
                                       TIME_HORIZON - 1e-1)))

    task = Until(city, goal)
    out = TLT(task).realize(impl)

elif False:
    # Half-space test: the first `task` is immediately overwritten by a box
    # built from four axis-aligned half-spaces in x and y.
    task = HalfSpaceSet(normal=[+1], offset=[-2], axes=[0])

    task = And(
        And(HalfSpaceSet(normal=[+1], offset=[-2], axes=['x']),
            HalfSpaceSet(normal=[-1], offset=[-1], axes=['x'])),
        And(HalfSpaceSet(normal=[+1], offset=[-3], axes=['y']),
            HalfSpaceSet(normal=[-1], offset=[+1], axes=['y'])),
    )

    # task = BoundedSet(x=(-2, -1), y=(-3, 1))

    out = TLT(task).realize(impl)

elif False:
    # Plane-cut test: `out` is reassigned twice; only the intersection remains.

    sl1 = impl.plane_cut(normal=[+1, +1],
                         offset=[0., 0.])
    sl2 = impl.plane_cut(normal=[-1, +2.4],
                         offset=[0., 0.])

    out = sl1
    out = sl2
    out = impl.intersect(sl1, sl2)
    
elif False:
    # Hand-built hybrid zonotope example (nz state dims, ng generators, one
    # equality constraint); the binary generators/constraints mirror the
    # continuous ones (Gb = 2 * Gc, Ab = Ac).

    nz = len(grid_shape)
    ng = 3
    nb = ng
    nc = 1

    c = np.array([
        [0],
        [0],
    ]).reshape(nz, 1)
    Gc = np.array([
        # [2.0, -2.0, +0.5], # Used g_3
        # [1.0,  0.5, -1.0],

        # [2.0, -2.0, +0.0], # Unused g_3
        # [1.0,  0.5, +0.0],

        [1.5, -1.5,  0.5], # Example from paper 
        [1.0,  0.5, -1.0],
    ]).reshape(nz, ng)
    Gb = 2 * Gc.copy()
    Ac = np.array([
        [1, 1, 1],
    ]).reshape(nc, ng)
    Ab = Ac.copy()
    b = np.array([
        [1],
    ]).reshape(nc, 1)

    out = HybridZonotope(Gc, Gb, c, Ac, Ab, b)

# Single-backend runs: print the `_approx` attribute of the realized task.
# NOTE(review): `task` is only assigned by some experiment branches above; the
# currently active branch does not define it.
if not isinstance(IMPL, tuple):
    print('Approx:', TLT(task)._approx)

# Placeholder for optional HZ redundancy removal (currently disabled).
if IMPL == 'HZ':

    # out = space.zono_op.redundant_gc_hz(out)
    # out = space.zono_op.redundant_c_hz(out)
    
    pass

# Report the HZ sizes (dim, ng, nb, nc) of `out`; for a list result take the
# element-wise maximum over all entries.
if IMPL in ('HZ', 'TVHZ'):

    if isinstance(out, list):
        nz, ng, nb, nc = \
            np.array([[_out.dim, _out.ng, _out.nb, _out.nc]
                    for _out in out]).max(axis=0)
    else:
        nz,ng,nb,nc = out.dim, out.ng, out.nb, out.nc

    print(
        f'{nz = !s}, {ng = !s}, {nb = !s}, {nc = !s}',
        # '',
        # 'c = ',  c,
        # 'Gc = ', Gc,
        # 'Gb = ', Gb,
        # 'Ac = ', Ac,
        # 'Ab = ', Ab,
        # 'b = ',  b,
        sep='\n',
    )

In [None]:
import itertools
import numpy as np
from tqdm import tqdm
import hj_reachability.shapes as shp
from numpy.linalg import matrix_rank
from scipy.linalg import null_space, svd, pinv
from scipy.ndimage import distance_transform_edt
from pyspect.plotting.levelsets import *

# Camera preset for the 3D level-set plots below; EYE_LO_SE comes from the
# pyspect.plotting.levelsets star import (name suggests a low, south-east view —
# confirm).
eye = EYE_LO_SE
# hz2hj rasterization strategy: True enumerates all 2**nb binary assignments;
# False uses the per-state branch-analysis shortcut.
traditional_method = False

def _hz2hj_branch_analysis(M, K, shape):
    """Pre-classify, per grid state, which sign each binary factor may take.

    A grid state is inside when  -1 <= M - K @ delta <= +1  for some
    delta in {-1, +1}^nb. For every constraint row j and binary factor i, this
    looks at where the bounding hyperplanes  k_j @ delta = M_j +/- 1  cross the
    delta_i axis. When that crossing forces a sign, the sign is recorded.

    Parameters:
        M (array): (ng, ..X) state-dependent offsets.
        K (np.ndarray): (ng, nb) coefficients of the binary factors.
        shape (tuple): grid shape ..X.

    Returns:
        np.ndarray: (nb, ..X) array ``fixed`` with entries
            *   0: both signs of delta_i are still possible,
            *  +1: delta_i is fixed to +1,
            *  -1: delta_i is fixed to -1,
            * nan: the state is treated as degenerate (no admissible sign).
    """
    ng, nb = K.shape
    nz = len(shape)
    ones = np.ones(shape)
    fixed = np.zeros((nb, *shape))  # (nb, ..X)

    for j in range(ng):
        # Row j reads  ml <= k @ delta <= mu.
        k = K[j]         # (nb,)
        mu = M[j] + 1    # (..X)
        ml = M[j] - 1    # (..X)

        # Right-hand sides of the stacked systems [k; e_i] @ delta = [m; 1].
        Mu = np.array([mu, ones])  # (2, ..X)
        Ml = np.array([ml, ones])  # (2, ..X)

        for i in range(nb):
            e = np.zeros((nb,))
            e[i] = 1

            d = k @ e  # coefficient of delta_i in row j
            if d == 0:
                continue  # delta_i does not enter this row

            # Ki @ [m; 1] is the minimum-norm point p with k @ p = m and p_i = 1.
            Ki = pinv(np.array([k, e]))  # (nb, 2)
            e = e.reshape(-1, *[1] * nz)  # (nb, 1, ...) for broadcasting

            # Upper bound  k @ delta <= mu; the delta_i axis crosses it at v.
            # Only states whose p lies farther than sqrt(nb) from e are classified.
            ub = np.zeros(shape)
            v = mu / d
            p = shp.tmul(Ki, Mu)  # (nb, ..X)
            is_outside = np.sqrt(nb) < np.linalg.norm(p - e, axis=0)
            if d > 0:    # on the axis, the bound holds for delta_i <= v
                ub[is_outside & (v < -1)] = np.nan
                ub[is_outside & (-1 <= v) & (v < +1)] = -1
            elif d < 0:  # on the axis, the bound holds for delta_i >= v
                ub[is_outside & (-1 < v) & (v <= +1)] = +1
                ub[is_outside & (+1 < v)] = np.nan

            # Lower bound  k @ delta >= ml.
            lb = np.zeros(shape)
            v = ml / d
            p = shp.tmul(Ki, Ml)  # (nb, ..X)
            is_outside = np.sqrt(nb) < np.linalg.norm(p - e, axis=0)
            if d > 0:    # on the axis, the bound holds for delta_i >= v
                lb[is_outside & (-1 <= v) & (v < +1)] = +1
                lb[is_outside & (+1 < v)] = np.nan
            elif d < 0:  # on the axis, the bound holds for delta_i <= v
                lb[is_outside & (v < -1)] = np.nan
                lb[is_outside & (-1 < v) & (v <= +1)] = -1

            # Combine: a nan, or upper/lower bounds forcing opposite signs, marks
            # the state degenerate; otherwise adopt the single forced sign.
            fixed[i, np.isnan(ub)] = np.nan
            fixed[i, np.isnan(lb)] = np.nan
            fixed[i, (ub * lb) == -1] = np.nan
            forced = np.abs(ub + lb) == 1
            fixed[i, forced] = ub[forced] + lb[forced]
            # NOTE(review): a later row j can overwrite fixed[i] set by an earlier
            # row (including nan); the exact membership check afterwards decides.

    return fixed


def _hz2hj_vf_from_branches(M, K, fixed, shape):
    """Build the -1/+1 array using the sign pre-classification ``fixed``.

    States sharing the same pre-classification are processed together, and
    every sign combination of their free binary factors is checked exactly. The
    result is the union of the states that pass for at least one combination.
    States marked degenerate (nan) in ``fixed`` stay outside (+1).
    """
    nb = fixed.shape[0]
    nz = len(shape)
    vf = np.ones(shape)

    remaining = ~np.isnan(fixed).any(axis=0)  # (..X) non-degenerate, not yet visited
    while remaining.any():
        idx = np.unravel_index(np.argmax(remaining), remaining.shape)
        xib = fixed[(..., *idx)].reshape(nb, *[1] * nz)  # (nb, 1, ...)

        group = np.all(fixed == xib, axis=0)  # (..X) states with this classification
        remaining[group] = False

        free = np.where(xib == 0)[0]  # indices of unconstrained binary factors
        for signs in itertools.product([-1, 1], repeat=len(free)):
            _xib = xib.copy()
            if len(free) > 0:
                _xib[free] = np.array(signs).reshape(-1, *[1] * nz)

            E = M[:, group] - shp.tmul(K, _xib.reshape(-1, 1))  # (ng, |group|)

            # Evaluate each sign choice on the whole group. Narrowing `group`
            # in place here would drop states that only a later choice admits.
            inside = np.zeros(shape, dtype=bool)
            inside[group] = np.max(np.abs(E), axis=0) <= 1
            vf = np.minimum(vf, np.where(inside, -1, +1))

    return vf


def _hz2hj_vf_enumerate(M, K, shape):
    """Build the -1/+1 array by checking all 2**nb binary assignments.

    With nb == 0 the loop runs once with an empty delta, i.e. it checks M alone.
    """
    nb = K.shape[1]
    nz = len(shape)
    vf = np.ones(shape)
    for signs in itertools.product([-1.0, 1.0], repeat=nb):
        delta = np.array(signs, dtype=float).reshape(nb, 1)
        E = M - shp.tmul(K, delta.reshape(-1, *[1] * nz))
        inside = np.max(np.abs(E), axis=0) <= 1
        vf = np.minimum(vf, np.where(inside, -1, +1))
    return vf


def hz2hj(c, Gc, Gb, Ac, Ab, b, *, shape=None, lower=None, upper=None, traditional=None):
    """Rasterize a hybrid zonotope into a -1/+1 level-set array on a grid.

    A grid point x is reported inside the hybrid zonotope <c, Gc, Gb, Ac, Ab, b>
    when, for some delta in {-1, +1}^nb,

        | Ac^+ b + H (x - m) - (Ac^+ Ab + H D) delta |_inf <= 1,

    where
        H = N (Gc N)^+,  m = c + Gc Ac^+ b,  D = Gb - Gc Ac^+ Ab,
    N is a null-space basis of Ac, and ^+ is the Moore-Penrose pseudoinverse.
    The test is exact only when Gc N has full column rank.

    Parameters:
        c (np.ndarray): (nz, 1) center.
        Gc (np.ndarray): (nz, ng) continuous generators.
        Gb (np.ndarray): (nz, nb) binary generators.
        Ac (np.ndarray): (nc, ng) constraint matrix of the continuous factors.
        Ab (np.ndarray): (nc, nb) constraint matrix of the binary factors.
        b (np.ndarray): (nc, 1) constraint offset.
        shape (sequence of int, optional): grid points per axis; defaults to
            the notebook-level ``grid_shape``.
        lower, upper (sequence of float, optional): grid bounds per axis;
            default to ``min_bounds`` / ``max_bounds``.
        traditional (bool, optional): True enumerates all 2**nb binary
            assignments. False uses the branch-analysis shortcut, which prunes
            sign choices before the exact check and may therefore miss states
            (never adds them) — compare with True when in doubt. Defaults to
            the notebook-level ``traditional_method``.

    Returns:
        np.ndarray: array of shape ``shape`` holding -1.0 inside and +1.0
        outside (including states the shortcut marks degenerate).

    Raises:
        AssertionError: if the input shapes are inconsistent.
    """
    shape = tuple(grid_shape if shape is None else shape)
    lower = min_bounds if lower is None else lower
    upper = max_bounds if upper is None else upper
    if traditional is None:
        traditional = traditional_method

    nz = c.shape[0]
    ng = Gc.shape[1]
    nb = Gb.shape[1]
    nc = b.shape[0]

    assert  c.shape == (nz,  1), '(hz2hj) Wrong shape: c'
    assert Gc.shape == (nz, ng), '(hz2hj) Wrong shape: Gc'
    assert Gb.shape == (nz, nb), '(hz2hj) Wrong shape: Gb'
    assert Ac.shape == (nc, ng), '(hz2hj) Wrong shape: Ac'
    assert Ab.shape == (nc, nb), '(hz2hj) Wrong shape: Ab'
    assert  b.shape == (nc,  1), '(hz2hj) Wrong shape: b'

    NA = null_space(Ac)   # (ng, ng - rank Ac) basis of ker(Ac)
    Aci = pinv(Ac)        # (ng, nc)
    GNi = pinv(Gc @ NA)   # pseudoinverse of the product, not NA^+ Gc^+

    H = NA @ GNi              # (ng, nz)
    m = c + Gc @ Aci @ b      # (nz, 1)
    D = Gb - Gc @ Aci @ Ab    # (nz, nb)

    coords = [
        np.linspace(lo, hi, n)
        for n, lo, hi in zip(shape, lower, upper)
    ]
    assert len(coords) == nz, '(hz2hj) Grid dimension does not match state dimension'

    x = np.array(np.meshgrid(*coords, indexing='ij'))  # (nz, ..X)

    # Constraint vector g(x, delta) = M(x) - K delta; x is inside iff
    # |g|_inf <= 1 for some delta.
    M = (+ (Aci @ b).reshape(-1, *[1] * nz)
         + shp.tmul(H, x - m.reshape(-1, *[1] * nz)))  # (ng, ..X)
    K = Aci @ Ab + H @ D  # (ng, nb)

    if traditional:
        return _hz2hj_vf_enumerate(M, K, shape)

    fixed = _hz2hj_branch_analysis(M, K, shape)
    return _hz2hj_vf_from_branches(M, K, fixed, shape)

# Value functions (one time slice each) collected for 2D comparison plots.
vfs_2d = []

if isinstance(IMPL, tuple):
    # Side-by-side HJ vs. HZ comparison. These must be the results realized by
    # the active branch of the experiment cell above, which is phi2 (it assigns
    # Phi2HJ / Phi2HZ). Phi1HJ / Phi1HZ are only assigned by the disabled phi1
    # branch, so reading them here raised NameError on a fresh run.
    _PhiHJ = Phi2HJ
    _PhiHZ = Phi2HZ

    vfhj = _PhiHJ

    vfs_2d += [
        (vfhj[0], dict(colorscale='greens')),
    ]

    if not isinstance(_PhiHZ, list):
        # A single HZ result: repeat its rasterization for every time step.
        # NOTE(review): `impl` is the last backend built (impl_hj when 'TVHJ' is
        # in IMPL); confirm its N matches the HJ time axis.
        print('TVHZ: Result is constant')
        vfhz = np.array([hz2hj(*_PhiHZ.astuple())] * impl.N)
    
    else:
        # One HZ per entry: rasterize each onto the grid.
        vfhz = np.array([hz2hj(*_out.astuple())
                       for _out in tqdm(_PhiHZ)])

    vfs_2d += [
        (vfhz[0], dict(colorscale='blues')),
    ]

    plot_levelsets(
        (shp.project_onto(vfhj, 0, 1, 2), dict(colorscale='greens', opacity=1.0)),
        (shp.project_onto(vfhz, 0, 1, 2), dict(colorscale='blues', opacity=1.0)),
        plot_func=plot3D_levelset,
        min_bounds=[           0, *min_bounds[:2]],
        max_bounds=[TIME_HORIZON, *max_bounds[:2]],
        xtitle='x [m]', ytitle='v [m/s]',
        eye=eye,
    ).show()

if IMPL == 'HZ':

    vfs_2d += [
        (hz2hj(*out.astuple()), dict(colorscale='blues')),
    ]

if IMPL == 'TVHZ':

    if not isinstance(out, list):
        print('TVHZ: Result is constant')
        vf = np.array([hz2hj(*out.astuple())] * impl.N)
    
    else:
        vf = np.array([hz2hj(*_out.astuple())
                       for _out in tqdm(out)])

    plot3D_levelset(
        vf,
        min_bounds=[           0, *min_bounds[:2]],
        max_bounds=[TIME_HORIZON, *max_bounds[:2]],
        xtitle='x [m]', ytitle='v [m/s]',
        colorscale='blues',
        eye=eye,
    ).show()

    # NOTE(review): `trg` is not assigned anywhere in this notebook section.
    vfs_2d += [
        (hz2hj(*trg.astuple()), dict(colorscale='rdylgn')),
        (vf[0], dict(colorscale='blues')),
    ]

if IMPL == 'TVHJ':

    plot_levelsets(
        (shp.project_onto(out, 0, 1, 2), dict(colorscale='greens')),
        # (shp.project_onto(_goal, 0, 1, 2), dict(colorscale='blues', opacity=0.6)),
        # (shp.project_onto(_city, 0, 1, 2), dict(colorscale='reds', opacity=0.6)),
        plot_func=plot3D_levelset,
        min_bounds=[           0, *min_bounds[:2]],
        max_bounds=[TIME_HORIZON, *max_bounds[:2]],
        xtitle='x [m]', ytitle='v [m/s]',
        eye=eye,
    ).show()

    vfs_2d += [
        # (trg[0], dict(colorscale='rdylgn')),
        (out[0], dict(colorscale='greens')),
    ]


In [None]:
# 2D comparison of the first time slice (index 0) of the HJ (green) and HZ
# (blue) value functions. vfhj / vfhz are assigned by the previous cell only
# when IMPL is a tuple. As the cell's last expression, the returned figure is
# displayed by Jupyter.
plot_levelsets(
    (vfhj[0], dict(colorscale='greens', name="HJ")),
    (vfhz[0], dict(colorscale='blues', name="HZ")),
    min_bounds=min_bounds[:2],
    max_bounds=max_bounds[:2],
    xtitle='x [m]', ytitle='v [m/s]',
    # showlegend=True,
)

In [None]:
if IMPL in ('HZ', '-TVHZ'):
    # Draw the hybrid zonotope result with the hz_reachability matplotlib visualizer.
    # NOTE(review): '-TVHZ' looks like a deliberately disabled option, so in
    # practice this cell draws only when IMPL == 'HZ' — confirm.
    from hz_reachability.visualizer import ZonoVisualizer
    from hz_reachability.auxiliary_operations import ZonoOperations
    import matplotlib.pyplot as plt

    op = ZonoOperations()
    viz = ZonoVisualizer(op)

    # 'HZ' draws the single result; otherwise draw only the first entry of a
    # list result (or the result itself when it is not a list).
    if IMPL != 'HZ' and isinstance(out, list):
        zonos_to_draw = out[:1]
    else:
        zonos_to_draw = [out]
    viz.vis_hz(zonos_to_draw)

    # Fixed square viewport with unit ticks.
    plt.xlim(-5, +5)
    plt.ylim(-5, +5)
    plt.grid(False)
    tick_positions = range(-5, 6)
    plt.xticks(tick_positions)
    plt.yticks(tick_positions)

    plt.show()