In [None]:
# Interactive figures via the `notebook` (nbagg) backend.
# NOTE(review): nbagg does not work in JupyterLab; `%matplotlib widget`
# (requires ipympl) is the usual replacement — confirm the target environment.
%matplotlib notebook
# Re-import edited modules (e.g. gcspy) before each cell runs, so library
# changes are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import numpy as np
import cvxpy as cp
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle

from gcspy.graph_of_convex_sets import GraphOfConvexSets
from gcspy.shortest_path_problem import ShortestPathProblem

In [None]:
# Build a grid_size x grid_size grid of random axis-aligned boxes spanning
# the unit square. Each box is a (lower_corner, upper_corner) pair centred on
# a grid node, with per-axis half-widths drawn uniformly from [0, box_side).
# Box index = outer * grid_size + inner, so box 0 is centred at (0, 0) and the
# last box at (1, 1). The fixed seed makes the boxes (and the graph) reproducible.
np.random.seed(1)
grid_size = 10
box_side = .1  # maximum half-width of a box along each axis
boxes = []
for i in np.linspace(0, 1, grid_size):
    for j in np.linspace(0, 1, grid_size):
        center = np.array([i, j])
        sides = np.random.rand(2) * box_side  # half-widths, not full sides
        boxes.append((center - sides, center + sides))

In [None]:
def plot_boxes(boxes, ax=None):
    """Draw each box as a light-cyan rectangle with a black edge.

    Parameters
    ----------
    boxes : iterable of (lower, upper) pairs
        Opposite corners of each axis-aligned box as 2-element arrays,
        lower-left corner first.
    ax : matplotlib Axes, optional
        Axes to draw on. Defaults to the current axes (``plt.gca()``), which
        preserves the original call style ``plot_boxes(boxes)``.
    """
    if ax is None:
        ax = plt.gca()
    for lower, upper in boxes:
        # Rectangle takes the lower-left corner plus (width, height).
        ax.add_patch(Rectangle(lower, *(upper - lower), fc='lightcyan', ec='k'))

# Visualise the random boxes that become the graph's vertices. Equal aspect
# keeps the geometry faithful, since path lengths are Euclidean.
fig, ax = plt.subplots()
plot_boxes(boxes)  # draws on the current axes, i.e. `ax`
ax.set(xlim=[-.2, 1.2], ylim=[-.2, 1.2], aspect='equal',
       xlabel='x', ylabel='y', title='Random boxes (graph vertices)');

In [None]:
def intersect(box1, box2):
    """Return True if two axis-aligned boxes overlap or merely touch.

    Each box is a ``(lower, upper)`` pair of corner arrays. The boxes share
    at least one point exactly when, on every axis, the larger of the two
    lower bounds does not exceed the smaller of the two upper bounds.
    """
    lower, upper = np.maximum(box1[0], box2[0]), np.minimum(box1[1], box2[1])
    return bool(np.all(lower <= upper))

# One vertex per box. Each vertex gets two 2-D point variables, both confined
# to that vertex's box; the edges built next treat the first as the point
# where the path enters the box and the second as where it leaves.
gcs = GraphOfConvexSets()
for idx, (lower, upper) in enumerate(boxes):
    vertex = gcs.add_vertex(idx)
    for _ in range(2):
        point = vertex.add_variable(2)
        vertex.add_constraint(point >= lower)
        vertex.add_constraint(point <= upper)
    
# Add a directed edge for every ordered pair of distinct overlapping boxes,
# so each overlapping pair is connected in both directions.
last_idx = len(boxes) - 1
for src_idx, src_box in enumerate(boxes):
    for dst_idx, dst_box in enumerate(boxes):
        if src_idx == dst_idx or not intersect(src_box, dst_box):
            continue
        src = gcs.get_vertex(src_idx)
        dst = gcs.get_vertex(dst_idx)
        edge = gcs.add_edge(src, dst)
        # Cost of the segment travelled inside the source box (entry -> exit).
        edge.add_length(cp.norm(src.variables[1] - src.variables[0], 2))
        # The last box is the path's target: no path edge leaves it, so its
        # own segment is charged on the edges that enter it.
        if dst_idx == last_idx:
            edge.add_length(cp.norm(dst.variables[1] - dst.variables[0], 2))
        # Leaving src and entering dst happen at the same point, which must
        # therefore lie in both boxes.
        edge.add_constraint(src.variables[1] == dst.variables[0])

In [None]:
# Shortest path from the first box (centred at the origin) to the last box
# (centred at (1, 1)). The target index is derived from `boxes` instead of
# hard-coding 99, so it stays in sync with grid_size and with the
# `len(boxes) - 1` target-length rule used when building the edges.
# NOTE(review): `relaxation=0` presumably selects the exact (non-relaxed)
# formulation — confirm against gcspy.
spp = ShortestPathProblem(gcs)
sol = spp.solve(gcs.get_vertex(0), gcs.get_vertex(len(boxes) - 1), relaxation=0)

In [None]:
# Extract the optimal trajectory as an ordered array of 2-D points.
# Nothing guarantees that `gcs.edges` lists the path's edges in travel order
# (they were added by ascending box index). Concatenating them in that order
# connects points out of sequence whenever the path visits a lower-indexed
# box after a higher-indexed one, so walk the path from the source instead.
active_edges = [e for e in gcs.edges if np.isclose(sol.y[e], 1)]

# NOTE(review): assumes `gcs.get_vertex` returns the same vertex objects that
# are stored as edge endpoints (the edge-building loop passes them to
# add_edge) — confirm against gcspy.
vertex = gcs.get_vertex(0)
traj = list(sol.x[vertex])
# Bounded by the number of active edges so a malformed solution cannot loop
# forever; a simple path ends at the target after that many steps.
for _ in range(len(active_edges)):
    out_edges = [e for e in active_edges if e.u is vertex]
    if not out_edges:
        break  # no active edge leaves this vertex: end of the path
    vertex = out_edges[0].v
    # Each vertex's point values are appended once, in path order.
    traj.extend(sol.x[vertex])
traj = np.array(traj)

# Overlay the optimal trajectory on the boxes. Equal aspect keeps Euclidean
# path lengths visually faithful.
fig, ax = plt.subplots()
plot_boxes(boxes)  # draws on the current axes, i.e. `ax`
ax.plot(*traj.T, c='b')
ax.set(xlim=[-.2, 1.2], ylim=[-.2, 1.2], aspect='equal',
       xlabel='x', ylabel='y', title='Shortest path through the boxes');