In [None]:
import numpy as np
import matplotlib.pyplot as plt
import matplotlib as mpl

import os

from navground import core, sim
from navground.sim.ui.video import display_video_from_run

import perdiver.perdiver as perdiver
from perdiver.navground_io import parser, run_navground
from perdiver.distances import *

# Output folder for the figure saved by this notebook: plots/stability,
# relative to the current working directory.
plots_dir = os.path.join("plots", "stability")
# exist_ok=True keeps this cell re-runnable once the folder already exists.
os.makedirs(plots_dir, exist_ok=True)

In [None]:
# Simulation parameters, given as a command-line style token list to the
# project parser.
args = parser.parse_args([
    '--scenario', 'Cross',
    '--side', '10',
    '--num_steps', '500',
    '--time_step', '0.1',
    '--num_agents', '12',
    '--max_speed', '1.66',
    # Optimal-speed range is given as a (min, max) pair. Passing
    # '--optimal_speed_min' twice would keep only the last value and never
    # set the maximum (presumably an argparse parser with the default
    # "store" action).
    # NOTE(review): assumes the parser defines '--optimal_speed_max' and that
    # 0.1-0.15 is the intended range -- confirm against perdiver.navground_io.
    '--optimal_speed_min', '0.1',
    '--optimal_speed_max', '0.15',
    '--radius', '0.4',
    '--safety_margin', '0.1',
    '--epsilon', '30',
    '--time_delay', '5',
])

# One simulation result per behavior name; `args.behavior` is overwritten
# before each call, so after the loop it holds the last behavior in the list.
behavior_list = ["ORCA"]
runs = {}
for behavior in behavior_list:
    args.behavior = behavior
    runs[behavior] = run_navground(args)

In [None]:
from perdiver.perdiver import get_matching_diagram, plot_matching_diagram, plot_timesteps_cross

# Weight forwarded to the distance function below; stored on `args` next to
# the simulation parameters.
args.weight = 1

# Entry 0 of the ORCA result, with its poses and twists converted to arrays
# that are indexed by step below.
run = runs["ORCA"][0]
ps, twists = np.array(run.poses), np.array(run.twists)

# Three panels in a row: panels 0 and 1 each show one start step, panel 2
# overlays the two matching diagrams.
fig, ax = plt.subplots(1, 3, figsize=(15, 5))
match_diagram_list = []
for i, initial_step in enumerate((150, 153)):
    args.start_step = initial_step
    end_step = args.start_step + args.epsilon
    # X / vel_X are taken at the start step, Y / vel_Y `epsilon` steps later.
    X, Y = ps[args.start_step], ps[end_step]
    vel_X, vel_Y = twists[args.start_step], twists[end_step]
    # Not used within this cell; kept in case other cells rely on it.
    X_len = X.shape[0]-1
    # Draw the pair of timesteps on this panel, then hide the axis ticks and
    # label the panel with its start step.
    panel = ax[i]
    plot_timesteps_cross(run, [args.start_step, end_step], args.side, panel)
    panel.set_xticks([])
    panel.set_yticks([])
    panel.set_title(f"Start timestep:{initial_step}", fontsize=20)
    # Distance matrices for both timesteps, and the matching diagram between them.
    Dist_X = distances_2Dtorus_weighted_velocities(X, vel_X, args.weight, args.side)
    Dist_Y = distances_2Dtorus_weighted_velocities(Y, vel_Y, args.weight, args.side)
    match_diagram_list.append(get_matching_diagram(Dist_X, Dist_Y))

# Overlay both diagrams on the last panel: the first through the project
# helper in blue, the second as red "X" markers.
first_diagram, match_diagram = match_diagram_list
overlay_ax = ax[2]
plot_matching_diagram(first_diagram, overlay_ax, color="blue")
overlay_ax.scatter(match_diagram[:,0], match_diagram[:,1], color="red", marker="X")
overlay_ax.set_title("Matching diagrams", fontsize=20)
# Write the figure to the plots directory.
plt.savefig(os.path.join(plots_dir, "stability.png"))