# Static Environment

In [None]:
# Reload edited Python modules (e.g. the st_gaussian_prm package) automatically
# before each cell runs, so library changes are picked up without a kernel restart.
%load_ext autoreload
%autoreload 2

## Load Instance 

In [None]:
from st_gaussian_prm.utils import load_instance, load_config

# Read the four config sections from the static-environment YAML file and build
# the solver for this instance.
(
    instance_config,
    solver_config,
    experiment_config,
    apf_config,
) = load_config("./configs/static_config.yaml")
solver = load_instance(instance_config, solver_config, **experiment_config)

# Short aliases for solver/instance attributes used throughout the notebook.
gaussian_prm = solver.gaussian_prm
agent_radius = instance_config["agent_radius"]
starts_idx, goals_idx = solver.starts_idx, solver.goals_idx
num_agents = solver.num_agents


In [None]:
# Quick summary of the loaded instance, one value per line: number of agents,
# start node indices, goal node indices, then the solver's starts_agent_count
# and goals_agent_count.
print(
    num_agents,
    starts_idx,
    goals_idx,
    solver.starts_agent_count,
    solver.goals_agent_count,
    sep="\n",
)

### Visualize instance

In [None]:
import matplotlib.pyplot as plt
fig, ax = gaussian_prm.visualize_roadmap()

# Overlay the start Gaussians (red colormap) and the goal Gaussians
# (visualize_gaussian's default colormap) on the roadmap axes.
for node_idx in starts_idx:
    gaussian_prm.gaussian_nodes[node_idx].visualize_gaussian(ax, cmap="Reds")
for node_idx in goals_idx:
    gaussian_prm.gaussian_nodes[node_idx].visualize_gaussian(ax)

plt.show()

Tested with DRRT: for 12 agents and 5 goals, it did not find a solution within 2 hours.

## Run Solver

In [None]:
solution = solver.solve()
assert solution["success"], "solver failed."

# Unpack the fields of the solver's result dict. candid_* are presumably the
# candidate start/goal node indices chosen by the solver — confirm in the solver.
timestep, paths, g_nodes = solution["timestep"], solution["paths"], solution["g_nodes"]
candid_starts_idx, candid_goals_idx = solution["starts_idx"], solution["goals_idx"]

# Report how the solution's node occupancy compares with node capacity.
num_violation, max_violation_percentage = solver.eval_capacity(paths)
print("num violation: ", num_violation)
print("max violation percentage: ", max_violation_percentage)

## Plot Gaussian trajectory

In [None]:
import os

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import matplotlib.colors as colors

cmap = plt.colormaps['coolwarm']
num_steps = solution["timestep"]
# Colour encodes the timestep. Clamp vmax so that a single-step solution does not
# give a degenerate (vmin == vmax) normalisation.
norm = colors.Normalize(vmin=0, vmax=max(num_steps - 1, 1))

fig, ax = gaussian_prm.visualize_g_nodes()

# Draw every agent's Gaussian node at every timestep as a filled confidence
# ellipse. The loop names `t` / `node` deliberately avoid `timestep` / `g_nodes`,
# which hold the solver output unpacked in the "Run Solver" cell.
for t in range(num_steps):
    color = cmap(norm(t))
    for path in solution["paths"]:
        node = gaussian_prm.gaussian_nodes[path[t]]
        x, y = node.get_confidence_ellipse().exterior.xy
        ax.fill(x, y, facecolor=color)

ax.set_xticks([])
ax.set_yticks([])

sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])  # Required for the colorbar of a bare ScalarMappable
cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', pad=0.01)
cbar.set_label('Timestep')

# Save before showing, so the written file never depends on backend-specific
# close-on-show behaviour. Also make sure the output directory exists.
os.makedirs("./solutions", exist_ok=True)
fig.savefig("./solutions/gaussian_trajectory.pdf", bbox_inches="tight")
plt.show()


## Plot maximum capacity violation ratio per node

In [None]:
import os
from collections import Counter, defaultdict

import matplotlib.pyplot as plt
from matplotlib.cm import ScalarMappable
import matplotlib.colors as colors
import numpy as np

# Choose a red-based colormap for violation visualization
cmap = plt.colormaps['Reds']
norm = colors.Normalize(vmin=0, vmax=3)  # Adjust vmax depending on expected violation ratio

fig, ax = gaussian_prm.visualize_g_nodes()

# Step 1: for each timestep, count the agents on every occupied node and divide
# that count by the node's capacity. This assumes get_capacity returns the max
# number of agents the node can hold, so a ratio above 1 means the capacity is
# exceeded. Keep the largest ratio each node reaches over the whole solution.
# The loop variable is `t` so the global `timestep` from "Run Solver" survives.
violation_by_node = defaultdict(float)

for t in range(solution["timestep"]):
    occupancy = Counter(path[t] for path in solution["paths"])
    for pos, count in occupancy.items():
        capacity = gaussian_prm.gaussian_nodes[pos].get_capacity(agent_radius)
        # NOTE(review): an occupied node with zero capacity is reported as 0 here,
        # i.e. not as a violation — confirm this is intended.
        violation_ratio = count / capacity if capacity > 0 else 0
        violation_by_node[pos] = max(violation_by_node[pos], violation_ratio)

# Step 2: Visualize each visited node by its max violation
for pos, max_violation in violation_by_node.items():
    node = gaussian_prm.gaussian_nodes[pos]
    x, y = node.get_confidence_ellipse().exterior.xy
    ax.fill(x, y, facecolor=cmap(norm(max_violation)), edgecolor='black', linewidth=0.5)

ax.set_xticks([])
ax.set_yticks([])

# Step 3: Add colorbar
sm = ScalarMappable(norm=norm, cmap=cmap)
sm.set_array([])
cbar = fig.colorbar(sm, ax=ax, orientation='horizontal', pad=0.01)
cbar.set_label('Max Capacity Violation Ratio')

# Save before showing and make sure the output directory exists.
os.makedirs("./solutions", exist_ok=True)
fig.savefig("./solutions/gaussian_violation_max.pdf", bbox_inches="tight")
plt.show()


## APF

The APF parameters (`apf_config`, loaded from `./configs/static_config.yaml`) should be tuned for each planning instance.

In [None]:
# Refuse to run APF on top of a failed (or never-run) solve.
assert solution["success"]

from st_gaussian_prm.solvers.micro.apf import APF

# Run the APF micro-level solver on the solver's node paths and Gaussian nodes,
# configured by the `apf_config` section of the YAML config. The result is
# presumably one trajectory per agent, as the plotting cells below consume it.
trajectories = APF(
    solution["paths"],
    solution["g_nodes"],
    **apf_config,
).solve()

## Visualize Per-agent Paths

In [None]:
# Static path
from matplotlib import pyplot as plt

fig, ax = gaussian_prm.visualize_map()

# One colour per trajectory, sampled evenly from the rainbow colormap. The palette
# is sized by len(trajectories), the list actually indexed below. It is named
# `agent_colors` so it does not shadow the `matplotlib.colors` module imported in
# earlier cells.
cmap = plt.colormaps["rainbow"]
num_trajectories = len(trajectories)
agent_colors = [cmap(i / num_trajectories) for i in range(num_trajectories)]

for color, path in zip(agent_colors, trajectories):
    x_coords = [loc[0] for loc in path]
    y_coords = [loc[1] for loc in path]
    ax.plot(x_coords, y_coords, '-', color=color, linewidth=0.8, alpha=0.5)

# Start nodes use visualize()'s default styling; goal nodes get a blue edge.
for start in starts_idx:
    gaussian_prm.gaussian_nodes[start].visualize(ax)

for goal in goals_idx:
    gaussian_prm.gaussian_nodes[goal].visualize(ax, edgecolor="b")

plt.show()

In [None]:
# Animate path
import os

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
from matplotlib.patches import Circle

from IPython.display import Video

# Keep every `speed`-th trajectory sample as an animation frame.
speed = 20

def animate_solution(agent_radius, paths, fig, ax, fig_path=".", speed=speed):
    """
    Animate agents moving along their trajectories and save the result as an MP4.

    One filled circle of radius ``agent_radius`` is drawn per trajectory. Each
    circle is moved through every ``speed``-th sample, and the final sample is
    always included so the animation ends at the agents' last positions.
    Trajectories shorter than the longest one hold their last position.

    Args:
        agent_radius: Radius of each agent's circle, in map coordinates.
        paths: Sequence of per-agent trajectories. Each is a non-empty sequence of
            positions whose first two entries are x and y.
        fig, ax: Matplotlib figure/axes to draw on (e.g. from ``visualize_map``).
        fig_path: Directory that ``apf_solution.mp4`` is written to. It is
            created if missing.
        speed: Subsampling stride (positive int). Defaults to the notebook-level
            ``speed`` value at the time this cell was run.

    Returns:
        Path of the written video file.

    Raises:
        ValueError: If ``paths`` is empty, any trajectory is empty, or ``speed < 1``.
    """
    if len(paths) == 0 or any(len(traj) == 0 for traj in paths):
        raise ValueError("animate_solution needs at least one non-empty trajectory.")
    if speed < 1:
        raise ValueError(f"speed must be a positive integer, got {speed!r}.")

    cmap = plt.colormaps["tab10"]
    agents = []
    for i, traj in enumerate(paths):
        start = traj[0]
        circle = Circle((start[0], start[1]), radius=agent_radius, color=cmap(i % 10))
        ax.add_patch(circle)
        agents.append(circle)

    # Sample indices 0, speed, 2*speed, ... and always finish on the last sample,
    # so the final positions are shown even when the length is not a multiple of speed.
    num_samples = max(len(traj) for traj in paths)
    frame_indices = list(range(0, num_samples, speed))
    if frame_indices[-1] != num_samples - 1:
        frame_indices.append(num_samples - 1)

    def init():
        return agents

    def update(sample_idx):
        for agent, traj in zip(agents, paths):
            # Agents with shorter trajectories stay at their final position.
            loc = traj[min(sample_idx, len(traj) - 1)]
            agent.set_center((loc[0], loc[1]))
        return agents

    anim = FuncAnimation(fig, update, frames=frame_indices,
                         init_func=init, blit=True, interval=100)
    os.makedirs(fig_path, exist_ok=True)
    out_file = f"{fig_path}/apf_solution.mp4"
    anim.save(out_file, writer='ffmpeg', fps=24)
    # Close the figure we drew on, not whatever pyplot considers "current".
    plt.close(fig)
    return out_file

fig_path = "solutions"
fig, ax = gaussian_prm.visualize_map()
animate_solution(agent_radius, trajectories, fig, ax, fig_path=fig_path)
# Build the video path from fig_path, using the same file name animate_solution
# writes, so the displayed video cannot drift from the saved one if the output
# directory changes.
Video(filename=f"{fig_path}/apf_solution.mp4", embed=True)