In [1]:
import jax
import jax.numpy as jnp
import numpy as np

import matplotlib.pyplot as plt
from ipywidgets import interact

from cbfax.dynamics import *
from cbfax.cbf import *
from cbfax.plotting import plot_halfspace


In [2]:
def plot_cbf(barrier_func, th, xlim=(-10, 10), ylim=(-10, 10), resolution=400):
    """Draw a barrier function's safe set over an (x, y) grid at a fixed heading.

    ``barrier_func`` is evaluated (via ``jax.vmap``) at the states
    ``[x, y, th]`` for every point of a ``resolution`` x ``resolution`` grid
    spanning ``xlim`` x ``ylim``. On the current matplotlib axes it draws:

    * a red/green fill for ``h < 0`` / ``h >= 0``,
    * a faint filled contour of the value of ``h``,
    * the ``h = 0`` boundary as a black line (only if the grid crosses it).

    Args:
        barrier_func: callable taking a length-3 state ``[x, y, th]`` and
            returning a scalar ``h``.
        th: heading used for every grid point.
        xlim: ``(min, max)`` range of the x grid and of the x axis.
        ylim: ``(min, max)`` range of the y grid and of the y axis.
        resolution: number of grid samples along each axis (default 400).
    """
    xs = np.linspace(xlim[0], xlim[1], resolution)
    ys = np.linspace(ylim[0], ylim[1], resolution)
    X, Y = np.meshgrid(xs, ys)

    # Batch of states [x, y, th], one row per grid point.
    states = jnp.stack([X.ravel(), Y.ravel(), jnp.full(X.size, th)], axis=1)
    # Reshape back to the meshgrid's own shape rather than a hard-coded size.
    H = np.asarray(jax.vmap(barrier_func)(states)).reshape(X.shape)

    # Safe (green) / unsafe (red) regions. Explicit levels keep the two-colour
    # fill independent of matplotlib's automatic level choice for 0/1 data.
    plt.contourf(X, Y, (H >= 0).astype(float), levels=[0, 0.5, 1],
                 alpha=0.2, colors=['#ff9999', '#99ff99'])
    plt.contourf(X, Y, H, alpha=0.1)
    # Draw the zero level set only when it lies strictly inside the data range;
    # otherwise matplotlib warns "No contour levels were found".
    # nanmin/nanmax so isolated NaNs from barrier_func do not suppress it.
    if np.nanmin(H) < 0 < np.nanmax(H):
        plt.contour(X, Y, H, levels=[0], colors='black')

    plt.xlim(xlim)
    plt.ylim(ylim)
    plt.xlabel('x')
    plt.ylabel('y')

    plt.grid(True)
    plt.axhline(0, color='black', linewidth=0.5)
    plt.axvline(0, color='black', linewidth=0.5)

In [None]:
dynamics = SimpleCar(wheelbase=2)
state = jnp.ones(dynamics.state_dim)
time = 0.

# Linear class-K function handed to get_cbf_constraint_rd1 as `alpha`.
alpha = lambda x: 1 * x
radius = 1.

# Added under the square root so the heading CBF (and its JAX gradient) stays
# finite at the origin -- the interactive x/y sliders default to 0.
_NORM_EPS = 1e-6


def obstacle_cbf(x):
    """h(x) = |p|^2 - radius^2 for position p = x[:2]; h >= 0 outside the disc.

    Uses sum(p**2) instead of jnp.linalg.norm(p)**2: same value, but the JAX
    gradient of jnp.linalg.norm is NaN at p = 0.
    """
    return jnp.sum(x[:2] ** 2, axis=0) - radius**2


def heading_cbf(x):
    """h(x) = cos(angle between p = x[:2] and heading x[2]) + 0.95.

    Negative only when the heading points within about 18 degrees of straight
    at the origin. |p| is regularised with _NORM_EPS so that h and its
    gradient are finite at p = 0, where a plain division by |p| is 0/0 = NaN.
    """
    heading = jnp.array([jnp.cos(x[2]), jnp.sin(x[2])])
    dist = jnp.sqrt(jnp.sum(x[:2] ** 2, axis=0) + _NORM_EPS)
    return jnp.dot(x[:2], heading) / dist + 0.95


cbfs = [obstacle_cbf, heading_cbf]
@interact
def interactive_plot(x=(-5, 5, 0.2), y=(-5, 5, 0.2), th=(-np.pi, np.pi, 0.1), v=(0.1, 1., 0.1), delta=(-1.,1., 0.1)):
    """Interactive view of the CBFs at a chosen car state and control.

    Left panel: every barrier in ``cbfs`` over the (x, y) plane at heading
    ``th``, with the car drawn as an arrow of length ``v``.
    Right panel: for every barrier, the control half-space returned by
    ``get_cbf_constraint_rd1`` (dynamics linearised at the current state and
    control), with the current control ``(v, delta)`` marked.

    The keyword defaults are ipywidgets ``(min, max, step)`` slider specs.
    """
    state = jnp.array([x, y, th])
    control = jnp.array([v, delta])
    lin_dyn = dynamics.linearized_dynamics(state, control, time)

    plt.figure(figsize=(10, 5))

    plt.subplot(1, 2, 1)
    for cbf in cbfs:
        plot_cbf(cbf, th, xlim=(-5, 5), ylim=(-5, 5))
    # Fixed data-space scaling: quiver's default autoscale normalises a single
    # arrow to a constant length, so the v slider would have no visible effect.
    plt.quiver(x, y, v * np.cos(th), v * np.sin(th),
               angles='xy', scale_units='xy', scale=1)
    plt.axis("equal")
    plt.title(f"CBFs over (x, y), th = {th:.2f}")

    plt.subplot(1, 2, 2)
    for cbf in cbfs:
        linear, constant = get_cbf_constraint_rd1(state, time, cbf, alpha, lin_dyn)
        plot_halfspace(linear, constant, ">=", xlim=(-5, 5), ylim=(-5, 5))
    # The scatter puts control[0] on the horizontal axis, control[1] on the vertical.
    plt.scatter(control[:1], control[1:], color="black", label="Control")
    plt.xlabel('v')
    plt.ylabel('delta')
    plt.title("CBF constraints in (v, delta)")
    # Without a legend the "Control" label above is never shown.
    plt.legend()

interactive(children=(FloatSlider(value=0.0, description='x', max=5.0, min=-5.0, step=0.2), FloatSlider(value=…