## Imports

In [None]:
# !pip install dynamaxsys==0.0.5
# !git clone https://github.gatech.edu/ACDS/MPPI-Generic.git
# !pip install tqdm
# !pip install hj-reachability

In [None]:
from typing import Callable, List, Literal, Tuple

import jax, jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
from matplotlib.patches import Ellipse, Circle, FancyArrow
from scipy.stats import chi2

# using the dynamaxsys library to import dynamical systems implemented in JAX: https://github.com/UW-CTRL/dynamaxsys
from dynamaxsys.unicycle import Unicycle
from dynamaxsys.base import get_discrete_time_dynamics
from dynamaxsys.utils import linearize

## Visualize Trajectories

In [None]:
# Time discretisation and number of closed-loop simulation steps.
dt = 0.1
sim_steps = 150

# Initial unicycle state (x, y, heading) and planar goal position (x, y).
start_state = np.array((-2.0, -1.0, 0.0))
goal = np.array((5.0, 0.0))

# Circular obstacles: centre coordinates plus one shared radius.
obstacles = [np.array(centre) for centre in ((1.0, 0.0), (3.0, -0.5))]
obstacle_radius = 0.5
v_max = 0.5  # speed bound for HJ tube

# MPPI sampling parameters.
horizon = 40
n_samples = 100
lam = 2.0
noise_scale = 0.2

# Safety-penalty weights.
collision_penalty = 5e3
kappa_cbf = 1.0   # CBF gain
kappa_scbf = 5.0  # Smooth-CBF gain

np.random.seed(42)  # overall reproducibility

# ============================================================
# ------------------ UNICYCLE DYNAMICS -----------------------
# ============================================================
def unicycle_dynamics(state, control, dt):
    """Advance a unicycle one forward-Euler step.

    Args:
        state: Length-3 sequence ``(x, y, heading)``.
        control: Length-2 sequence ``(forward speed, turn rate)``.
        dt: Step size used to scale the increments.

    Returns:
        A new length-3 NumPy array holding the propagated state.
    """
    px, py, heading = state
    speed, turn_rate = control

    next_x = px + speed * np.cos(heading) * dt
    next_y = py + speed * np.sin(heading) * dt
    next_heading = heading + turn_rate * dt
    return np.array([next_x, next_y, next_heading])

# ============================================================
# ------------------ SAFETY FILTERS --------------------------
# ============================================================
class SafetyFilter:
    """Base safety filter that never penalises anything."""

    def penalty(self, state: np.ndarray, t: float) -> float:
        """Return the safety cost for ``state`` at time ``t``; always zero here."""
        no_cost = 0.0
        return no_cost

class ApproxHJSafety(SafetyFilter):
    """Inflated obstacle tube: radius ``r + v_max * max(T - t, 0)``.

    A state whose (x, y) position lies strictly inside any inflated disc is
    charged a fixed penalty; every other state costs nothing.
    """

    def __init__(self, obstacles, radius, v_max, horizon_T, penalty=1e6):
        """Store the tube parameters.

        Args:
            obstacles: Iterable of ``(x, y)`` obstacle centres.
            radius: Obstacle radius before inflation.
            v_max: Speed bound that sets how fast the tube grows with time-to-go.
            horizon_T: Time ``T`` at which the inflation reaches zero.
            penalty: Cost returned when the state is inside the tube.
        """
        self.obs = obstacles
        self.r = radius
        self.vmax = v_max
        self.T = horizon_T
        self.pen = penalty

    def penalty(self, state, t):
        """Return ``self.pen`` if ``state`` is inside any inflated obstacle, else 0.0."""
        x, y, _ = state
        # Clamp time-to-go at zero: for t > T the tube must not shrink below
        # the obstacle radius itself (or become negative).
        time_to_go = max(self.T - t, 0.0)
        tube = self.r + self.vmax * time_to_go
        inside = any(np.hypot(x - ox, y - oy) < tube for ox, oy in self.obs)
        return self.pen if inside else 0.0

class CBFSafety(SafetyFilter):
    """1/h soft penalty (standard reciprocal CBF), summed over all obstacles."""

    def __init__(self, obstacles, radius, kappa=1.0, penalty=1e6):
        """Store the barrier parameters.

        Args:
            obstacles: Iterable of ``(x, y)`` obstacle centres.
            radius: Obstacle radius; stored squared as ``self.r2``.
            kappa: Gain multiplying each ``1 / h`` term.
            penalty: Cost returned when the state is on or inside an obstacle.
        """
        self.obs = obstacles
        self.r2 = radius ** 2
        self.kappa = kappa
        self.pen = penalty

    def penalty(self, state, t):
        """Return the barrier cost of ``state``.

        With ``h = squared distance to centre - radius**2`` per obstacle:
        returns ``self.pen`` as soon as any ``h <= 0``; otherwise returns the
        sum of ``kappa / h`` over every obstacle (0.0 if there are none).
        """
        x, y, _ = state
        total = 0.0
        for ox, oy in self.obs:
            h = (x - ox) ** 2 + (y - oy) ** 2 - self.r2
            if h <= 0:
                return self.pen
            # Accumulate instead of returning here, so that obstacles after the
            # first one also contribute to the cost.
            total += self.kappa / h
        return total

class SCBFSafety(SafetyFilter):
    """Smooth-CBF: fixed penalty inside the radius, inverse-square term outside."""

    def __init__(self, obstacles, radius, kappa=5.0, penalty=1e6):
        """Store the barrier parameters.

        Args:
            obstacles: Iterable of ``(x, y)`` obstacle centres.
            radius: Obstacle radius.
            kappa: Gain multiplying each ``(radius / distance)**2`` term.
            penalty: Cost returned when the state is strictly inside an obstacle.
        """
        self.obs = obstacles
        self.r = radius
        self.kappa = kappa
        self.pen = penalty

    def penalty(self, state, t):
        """Return the smooth barrier cost of ``state``.

        Returns ``self.pen`` as soon as the position is strictly inside any
        obstacle; otherwise returns the sum of ``kappa * (r / dist)**2`` over
        every obstacle (0.0 if there are none).
        """
        x, y, _ = state
        total = 0.0
        for ox, oy in self.obs:
            dist = np.hypot(x - ox, y - oy)
            if dist < self.r:
                return self.pen
            # Accumulate instead of returning here, so that obstacles after the
            # first one also contribute to the cost.
            total += self.kappa * (self.r / dist) ** 2
        return total

class OptHJSafety(SafetyFilter):
    """Placeholder optimal HJ: signed distance to nearest obstacle."""

    def __init__(self, obstacles, radius, penalty=1e6):
        """Keep the obstacle centres, their shared radius and the violation cost."""
        self.obs = obstacles
        self.r = radius
        self.pen = penalty

    def value(self, x, y):
        """Return the smallest ``distance-to-centre - radius`` over all obstacles."""
        clearances = (np.hypot(x - cx, y - cy) - self.r for cx, cy in self.obs)
        return min(clearances)

    def penalty(self, state, t):
        """Return 0.0 when the value at the state's (x, y) is positive, else ``self.pen``."""
        px, py, _ = state
        if self.value(px, py) > 0:
            return 0.0
        return self.pen

class CombinedSafety(SafetyFilter):
    """Aggregate filter whose penalty is the sum of its members' penalties."""

    def __init__(self, filters):
        """Remember the member filters to be queried on every call."""
        self.filters = filters

    def penalty(self, state, t):
        """Query every member with ``(state, t)`` and add up the results."""
        total = 0
        for member in self.filters:
            total = total + member.penalty(state, t)
        return total

# ============================================================
# ------------------ MPPI CONTROLLER -------------------------
# ============================================================
class MPPI:
    """Sampling-based MPPI-style controller over 2-D control sequences.

    Each call to :meth:`step` perturbs the nominal control sequence with
    Gaussian noise, scores every perturbed sequence by rolling out the
    dynamics, and returns the exponentially weighted average of the first
    control of each sample.
    """

    def __init__(self, dynamics, dt, goal, horizon, n_samples,
                 lam, noise_scale, safety: "SafetyFilter | None"):
        """Configure the controller.

        Args:
            dynamics: Callable ``(state, control, dt) -> next_state``.
            dt: Step size passed to ``dynamics``.
            goal: Target compared against the first two state components.
            horizon: Number of controls in each sampled sequence.
            n_samples: Number of sequences sampled per step.
            lam: Temperature dividing the costs inside the exponential weights.
            noise_scale: Multiplier on the standard-normal control noise.
            safety: Object with ``penalty(state, t)``; ``None`` selects the
                no-op ``SafetyFilter``.
        """
        self.dyn    = dynamics
        self.dt     = dt
        self.goal   = goal
        self.H      = horizon
        self.N      = n_samples
        self.lam    = lam
        self.noise  = noise_scale
        # Explicit None check so a user-supplied filter is never swapped out
        # merely because it evaluates as falsy.
        self.safety = SafetyFilter() if safety is None else safety
        self.u_nom  = np.zeros((horizon, 2))

    def _rollout_cost(self, x0, u_seq):
        """Return the accumulated cost of applying ``u_seq`` starting from ``x0``.

        Per step the cost is the Euclidean distance of ``x[:2]`` to the goal
        plus the safety penalty evaluated at time ``k * dt``.
        """
        x = x0.copy()
        cost = 0.0
        for k, u in enumerate(u_seq):
            x = self.dyn(x, u, self.dt)
            cost += np.linalg.norm(x[:2] - self.goal)
            cost += self.safety.penalty(x, k * self.dt)
        return cost

    def step(self, x0):
        """Sample rollouts from ``x0`` and return the weighted first control.

        Side effect: the nominal sequence is shifted by one step and the
        returned control is appended at its end.
        """
        noise    = self.noise * np.random.randn(self.N, self.H, 2)
        rollouts = self.u_nom[None, :, :] + noise
        costs    = np.array([self._rollout_cost(x0, u) for u in rollouts])

        finite = np.isfinite(costs)
        if finite.any():
            # Subtract the best cost before exponentiating. The normalised
            # weights are mathematically unchanged, but exp() can no longer
            # underflow to all zeros when every rollout carries a large penalty.
            w = np.exp(-(costs - costs[finite].min()) / self.lam)
            # Rollouts with non-finite cost get zero weight.
            w = np.where(finite, w, 0.0)
            w = w / w.sum()
        else:
            # No usable cost information: fall back to a plain average.
            w = np.full(len(costs), 1.0 / len(costs))

        u0 = np.tensordot(w, rollouts[:, 0], axes=1)
        self.u_nom = np.vstack([self.u_nom[1:], u0])
        return u0

def simulate(controller, seed=0):
    """Run ``controller`` in closed loop and return the visited states.

    Seeds NumPy's global RNG with ``seed``, starts from a copy of
    ``start_state`` and applies ``sim_steps`` controls through
    ``unicycle_dynamics`` with step ``dt``.

    Returns:
        Array stacking the initial state followed by one state per step.
    """
    np.random.seed(seed)
    history = [start_state.copy()]
    for _ in range(sim_steps):
        current = history[-1]
        command = controller.step(current)
        history.append(unicycle_dynamics(current, command, dt))
    return np.array(history)

# ============================================================
# ------------------ BUILD CONTROLLERS -----------------------
# ============================================================
# Planning horizon expressed in time units rather than steps.
T_horizon_sec = dt * horizon

# One instance of each individual safety filter, all using the same obstacles.
hj_filter = ApproxHJSafety(obstacles, obstacle_radius, v_max, T_horizon_sec,
                           penalty=collision_penalty)
cbf_filter = CBFSafety(obstacles, obstacle_radius, kappa=kappa_cbf,
                       penalty=collision_penalty)
scbf_filter = SCBFSafety(obstacles, obstacle_radius, kappa=kappa_scbf,
                         penalty=collision_penalty)
opt_hj_filter = OptHJSafety(obstacles, obstacle_radius, penalty=collision_penalty)

# Every controller shares these MPPI settings and differs only in its filter.
_mppi_args = (unicycle_dynamics, dt, goal, horizon, n_samples, lam, noise_scale)
_safety_by_label = {
    "Vanilla": None,
    "HJ": hj_filter,
    "HJ+CBF": CombinedSafety([hj_filter, cbf_filter]),
    "HJ+OptHJ": CombinedSafety([hj_filter, opt_hj_filter]),
    "HJ+SCBF": CombinedSafety([hj_filter, scbf_filter]),
}
controllers = {
    label: MPPI(*_mppi_args, safety)
    for label, safety in _safety_by_label.items()
}

# Roll out each controller with the same seed.
trajectories = {
    label: simulate(ctrl, seed=42)
    for label, ctrl in controllers.items()
}

# Overlay every controller's (x, y) path on a single set of axes.
fig, ax = plt.subplots(figsize=(8, 4))

for label, traj in trajectories.items():
    ax.plot(traj[:, 0], traj[:, 1], label=label)

ax.plot(start_state[0], start_state[1], marker='o', linestyle='', markersize=8, label='Start')
ax.plot(goal[0], goal[1], marker='x', linestyle='', markersize=8, label='Goal')

# Draw obstacles as Circle patches so their size is obstacle_radius in data
# units. Scatter marker sizes are given in points**2 and would not match the
# geometry used by the safety filters. Only the first patch gets a legend entry.
for idx, centre in enumerate(obstacles):
    ax.add_patch(Circle((centre[0], centre[1]), obstacle_radius,
                        color='C7', alpha=0.3,
                        label='Obstacles' if idx == 0 else '_nolegend_'))

ax.set_title("Unicycle MPPI: Safety Filter Comparison")
ax.set_xlabel("x position (m)")
ax.set_ylabel("y position (m)")
ax.axis('equal')
ax.grid(True)
ax.legend()
plt.show()

## Compare Results

### Success Rate

### Collision Rate