---

This Notebook was developed by [Haimin Hu](https://haiminhu.org/) for the RSS'24 paper [_Who Plays First? Optimizing the Order of Play in Stackelberg Games with Many Robots_](https://saferobotics.princeton.edu/research/who-plays-first).

Instructions:
* Run the cells to initiate closed-loop simulation for each method.
* The simulation results are automatically displayed as an animation within the Notebook.
* The demo was created for $N=4$ agents, but can be straightforwardly extended to more agents by changing the initial conditions and config files.

---

##### Who Plays First

In [1]:
# region: Imports and initialization
import os, sys, jax
import numpy as np
from IPython.display import display, HTML
from copy import deepcopy
if len(os.getcwd()) > 0:
  sys.path.insert(1, "../../")
  import wpf
  from wpf.STP.utils import *
  from wpf.STP import *

# Runtime setup: applies the JAX platform ('cpu') and x64 flags, then tells NumPy to
# suppress scientific notation when printing.
for _jax_flag, _jax_value in (('jax_platform_name', 'cpu'), ('jax_enable_x64', True)):
  jax.config.update(_jax_flag, _jax_value)
np.set_printoptions(suppress=True)

# Loads the configs. The default config is read a second time for the default iLQR
# solver instead of reusing `config`.
_default_cfg_path = "config/default.yaml"
config = load_config(_default_cfg_path)
config_target = load_config("config/target_reach.yaml")
dsolver = iLQR(load_config(_default_cfg_path))  # default solver

# Specifies (and creates, if missing) the folder to save figures.
fig_prog_folder = os.path.join(config.OUT_FOLDER, "progress_wpf")
os.makedirs(fig_prog_folder, exist_ok=True)

# Image handed to plot_trajectory as `image` when drawing the agents.
quad_img = plt.imread('asset/aircraft.png', format="png")

# Sets the initial and target states. Every state vector is ordered as (x, y, v, psi).
x_init = [
    np.array(_row) for _row in (
        #  x     y    v    psi
        (-2.0, 2.5, 0., -0.79),  # Initial state of R1.
        (2.5, 2.0, 0., 3.93),  # Initial state of R2.
        (2.8, -2.5, 0., 2.36),  # Initial state of R3.
        (-2.0, -2.8, 0., 0.79),  # Initial state of R4.
    )
]
# The per-robot names refer to the very same arrays that are stored in `x_init`.
x_init_R1, x_init_R2, x_init_R3, x_init_R4 = x_init

# Target states use the reference speed config.V_REF as their velocity entry.
targets = [
    np.array(_row) for _row in (
        #  x     y     v             psi
        (1.5, -1.5, config.V_REF, 2.36),  # Target state of R1.
        (-1.5, -1.5, config.V_REF, 0.79),  # Target state of R2.
        (-1.5, 1.5, config.V_REF, -0.79),  # Target state of R3.
        (1.5, 1.5, config.V_REF, 3.93),  # Target state of R4.
    )
]
target_R1, target_R2, target_R3, target_R4 = targets

# Initializes stats.
num_agent, sim_steps = config.N_AGENT, config.SIM_STEPS

# One (2, config.N) control sequence per agent: an all-zero reference and an independent
# copy that serves as the planner's initial guess.
zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]
init_control = [_u.copy() for _u in zero_control]

# Closed-loop logs, one column per simulation step: 4 state rows and 2 control rows
# for each agent.
state_hist = [np.zeros((4, sim_steps)) for _ in range(num_agent)]
ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]

# Per-step order of play and ATC-zone flags (appended to during the simulation).
order_hist, zonef_hist = [], []

# Initializes Branch-and-play.
# The problem instance bundles the (copied) current joint state, the initial control
# guess, the targets, and the config.
x_cur = deepcopy(x_init)
STP_instance = (x_cur, init_control, targets, config)

# Selects the branching callback. An unsupported setting is rejected explicitly;
# without the `else`, `branching_strategy` would stay undefined (NameError below) or,
# when the cell is re-run, silently keep the value from a previous run.
if config.BRANCHING == 'depthfirst':
  branching_strategy = wpf.depthfirst_cb
elif config.BRANCHING == 'bestfirst':
  branching_strategy = wpf.bestfirst_cb
else:
  raise ValueError(
      "Unsupported config.BRANCHING: {!r} (expected 'depthfirst' or 'bestfirst').".format(
          config.BRANCHING
      )
  )

# Builds the solver from the branching strategy, the STP callbacks, and the settings
# taken from the config. The fifth positional argument is a no-op callback.
bnb = wpf.BranchAndPlay(
    num_agent, STP_instance, branching_strategy, stp_solve_cb, lambda *args: None, stp_branching,
    wpf.Settings(
        config.MAX_NODES, config.FEAS_TOL, config.MIN_GAP, config.MAX_ITER_BNP,
        verbose=config.VERBOSE_BNP
    ), stp_initializer
)

# Initializes the ATC Zone with an STP planner (default config) and an iLQR solver built
# from the target-reaching config.
zone = ATCZone(config, targets, STP(config), iLQR(config_target))
# endregion

In [None]:
# region: Receding horizon planning
# Closed-loop simulation: every step solves Branch-and-play from the current joint state,
# applies the first control of the selected plan, and shifts that plan to warm-start the
# next solve. The loop ends when all agents report reaching their target or after
# `sim_steps` steps.
dh = display(HTML('<pre>WPF starts.</pre>'), display_id=True)
k = 0
while not all(zone.is_reach_target(x_cur)) and k < sim_steps:

  # region: Solves BnP.
  if k > 0:
    # Re-poses the problem at the new state, passing along "Js_prev" from the last solve.
    Js_prev = _results.custom_statistics["Js_prev"]
    bnb.update_problem(instance_data=(x_cur, init_control, targets, Js_prev, config), mode="update")

  bnb.solve()
  _results = bnb.results
  _stime = sum(_results.custom_statistics["jax_compile_time"])
  # endregion

  # region: Locks the ordering when collision check fails.
  # Stacks the incumbent per-agent state trajectories along a new last axis and checks the
  # first config.COL_CHECK_N steps for pairwise collisions.
  _bnb_state = jnp.stack(_results.incumbent[0], axis=2)
  _bnb_collides = (
      zone.oz_planner.pairwise_collision_check(_bnb_state[:, :config.COL_CHECK_N, :])
  ).any()
  # Locking re-plans with the previously accepted ordering, so it needs one to exist. On
  # the first step `order_hist` is still empty (indexing it with [-1] would raise
  # IndexError); in that case the incumbent solution is used instead.
  if _bnb_collides and order_hist:
    _locked_order = order_hist[-1]
    _states, _controls = zone.plan_stp(x_cur, init_control, targets, _locked_order)
    x_new = [_states[ii][:, 1] for ii in range(num_agent)]
    order_hist.append(_locked_order)
    us_cur = _controls
  else:
    x_new = [_results.incumbent[0][ii][:, 1] for ii in range(num_agent)]
    order_hist.append(list(_results.incumbent_permutation))
    us_cur = _results.incumbent[1]
  # Warm start for the next step: shifts the plan one step to the left. The last column of
  # `init_control` keeps its previous value.
  for ii in range(num_agent):
    init_control[ii][:, :-1] = us_cur[ii][:, 1:]
  # endregion

  # region: Checks ATC zone.
  _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)
  zonef_hist.append(zone_flags)
  # The collision flag is evaluated on the state before this step's update.
  col_flag = zone.is_collision(x_cur)
  # endregion

  # region: Updates and reports stats.
  # Applies the first control returned by the zone check to each agent.
  for ii in range(num_agent):
    _xii, _ = dsolver.dynamics.integrate_forward(x_cur[ii], us_new[ii][:, 0])
    x_cur[ii] = _xii

  # Logs the post-update state together with the control that produced it.
  for ii in range(num_agent):
    state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]
  k += 1

  _info = [
      'step: ', k, ' | stime: ', '{:04.2f}'.format(_stime), ' | Objective: ',
      '{:03.1f}'.format(_results.global_ub), ' | Permutation: ', _results.incumbent_permutation,
      ' | Col: ', col_flag
  ]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion
# endregion

# region: Wraps up.
# Keeps only the first k columns of each log, i.e. the steps that were simulated.
state_hist = [_hist[:, :k] for _hist in state_hist]
ctrl_hist = [_hist[:, :k] for _hist in ctrl_hist]

# Plots the optimal trajectory.
if config.PLOT_RES:
  _plot_opts = dict(
      orders=order_hist, targets=targets, colors=['r', 'g', 'b', 'm'], xlim=(-3, 3),
      ylim=(-3, 3), figsize=(20, 20), fontsize=35, image=quad_img, plot_arrow=False, zone=zone,
      zone_flags=zonef_hist
  )
  plot_trajectory(state_hist, config, fig_prog_folder, **_plot_opts)

  # Builds the animation over the number of logged steps and shows it in the Notebook.
  _num_frames = state_hist[0].shape[1]
  img = make_animation(_num_frames, config, fig_prog_folder)
  display(img)
# endregion

##### First-come-first-served

In [None]:
# region: Initialization
# Specifies (and creates, if missing) the folder to save the figures of this method.
fig_prog_folder = os.path.join(config.OUT_FOLDER, "progress_fcfs")
os.makedirs(fig_prog_folder, exist_ok=True)

# Initializes stats: restarts from a copy of the initial states with fresh logs.
x_cur = deepcopy(x_init)
num_agent, sim_steps = config.N_AGENT, config.SIM_STEPS
zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]
init_control = [_u.copy() for _u in zero_control]
state_hist = [np.zeros((4, sim_steps)) for _ in range(num_agent)]
ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]
order_hist, zonef_hist = [], []
dh = display(HTML('<pre>FCFS starts.</pre>'), display_id=True)

# First-come-first-served bookkeeping: `perm` is filled slot by slot with agent indices,
# `assigned_flags` marks agents that already hold a slot, and `order` is the next free slot.
perm = [None for _ in range(num_agent)]
assigned_flags = [False for _ in range(num_agent)]
order = 0

# Initializes the ATC Zone.
zone = ATCZone(config, targets, STP(config), iLQR(config_target))
# endregion

# region: Receding horizon planning
# Closed-loop simulation with a first-come-first-served order of play: each step assigns
# the next free slot to agents newly flagged by zone.is_in_zone, plans STP under the
# current permutation, and applies the first control. The loop ends when all agents report
# reaching their target or after `sim_steps` steps.
k = 0
while not all(zone.is_reach_target(x_cur)) and k < sim_steps:

  # region: Determines the order on a first-come-first-served basis.
  if not all(assigned_flags):
    zone_flags = zone.is_in_zone(x_cur)
    # Agents flagged in the same step receive their slots in index order.
    for ii in range(num_agent):
      if not assigned_flags[ii] and zone_flags[ii]:
        perm[order] = ii
        assigned_flags[ii] = True
        order += 1
  # endregion

  # region: Plans STP.
  # NOTE(review): slots of agents that are not assigned yet are still None when `perm`
  # is handed to plan_stp -- confirm that the planner expects this.
  _states, _controls = zone.plan_stp(x_cur, init_control, targets, perm)
  x_new = [_states[ii][:, 1] for ii in range(num_agent)]
  # Stores a snapshot: `perm` is mutated in place above, so appending the list itself
  # would make every history entry show the final permutation.
  order_hist.append(list(perm))
  us_cur = _controls

  # Warm start for the next step: shifts the plan one step to the left. The last column of
  # `init_control` keeps its previous value.
  for ii in range(num_agent):
    init_control[ii][:, :-1] = us_cur[ii][:, 1:]
  # endregion

  # region: Checks ATC zone.
  _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)
  zonef_hist.append(zone_flags)
  # The collision flag is evaluated on the state before this step's update.
  col_flag = zone.is_collision(x_cur)
  # endregion

  # region: Updates and reports stats.
  # Applies the first control returned by the zone check to each agent.
  for ii in range(num_agent):
    _xii, _ = dsolver.dynamics.integrate_forward(x_cur[ii], us_new[ii][:, 0])
    x_cur[ii] = _xii

  # Logs the post-update state together with the control that produced it.
  for ii in range(num_agent):
    state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]
  k += 1

  _info = ['step: ', k, ' | Permutation: ', perm, ' | Col: ', col_flag]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion
# endregion

# region: Wraps up.
# Keeps only the first k columns of each log, i.e. the steps that were simulated.
state_hist = [_hist[:, :k] for _hist in state_hist]
ctrl_hist = [_hist[:, :k] for _hist in ctrl_hist]

# Plots the optimal trajectory.
if config.PLOT_RES:
  _plot_opts = dict(
      orders=order_hist, targets=targets, colors=['r', 'g', 'b', 'm'], xlim=(-3, 3),
      ylim=(-3, 3), figsize=(20, 20), fontsize=35, image=quad_img, plot_arrow=False, zone=zone,
      zone_flags=zonef_hist
  )
  plot_trajectory(state_hist, config, fig_prog_folder, **_plot_opts)

  # Builds the animation over the number of logged steps and shows it in the Notebook.
  _num_frames = state_hist[0].shape[1]
  img = make_animation(_num_frames, config, fig_prog_folder, name="rollout_fcfs.gif")
  display(img)
# endregion

##### Nash ILQ Game

In [None]:
# region: Initialization
# Loads the configs. The horizon settings (N, T) of the ILQGame config are copied into the
# default and target-reaching configs so all of them use the same horizon.
# NOTE: this rebinds the Notebook-level `config` and `config_target`, so cells executed
# after this one see these objects.
config = load_config("config/ilqgame.yaml")
config_ilqr = load_config("config/default.yaml")
config_ilqr.N, config_ilqr.T = config.N, config.T
config_target = load_config("config/target_reach.yaml")
config_target.N, config_target.T = config.N, config.T
quad_img = plt.imread('asset/aircraft.png', format="png")

# Specifies (and creates, if missing) the folder to save figures. Only the ILQGame folder
# is created here; this cell never writes to a "progress_wpf" folder.
fig_prog_folder = os.path.join(config.OUT_FOLDER, "progress_ilq")
os.makedirs(fig_prog_folder, exist_ok=True)

# Initializes stats.
x_cur = deepcopy(x_init)
num_agent = config.N_AGENT
zero_control = [np.zeros((2, config.N)) for _ in range(num_agent)]
init_control = deepcopy(zero_control)
sim_steps = config.SIM_STEPS
state_hist = [np.zeros((4, sim_steps)) for _ in range(num_agent)]
ctrl_hist = [np.zeros((2, sim_steps)) for _ in range(num_agent)]
order_hist = []
zonef_hist = []
dh = display(HTML('<pre>ILQGame starts.</pre>'), display_id=True)

# Order-of-play bookkeeping. Within this cell only `perm` is read (for logging and
# plotting), and it is never filled in, so it stays all-None.
perm = [None] * num_agent
assigned_flags = [False] * num_agent
order = 0

# Initializes the ATC Zone.
zone = ATCZone(config, targets, STP(config_ilqr), iLQR(config_target))

# Sets up the iLQGame solver. The joint system is built from one dynamics object (taken
# from a throwaway iLQR instance) repeated once per agent.
dummy_ilqr = iLQR(config_ilqr)
jnt_sys = ProductMultiPlayerDynamicalSystem([dummy_ilqr.dynamics] * num_agent)
solver = ILQGame(config, jnt_sys, verbose=False)
# endregion

# region: Receding horizon planning
# Closed-loop simulation with the ILQGame planner: each step solves the game from the
# current joint state, overrides the first control through the shielding logic when the
# planned trajectories fail the collision check, and applies the first control. The loop
# ends when all agents report reaching their target or after `sim_steps` steps.
k = 0
while not all(zone.is_reach_target(x_cur)) and k < sim_steps:

  # region: Plans ILQGame.
  _states, _controls, _, _ = solver.solve(x_cur, init_control, targets)
  x_new = [_states[ii][:, 1] for ii in range(num_agent)]
  # Stores a snapshot of `perm`, as the other methods do for their order history.
  order_hist.append(list(perm))
  us_cur = _controls

  # Warm start for the next step: shifts the plan one step to the left. The last column of
  # `init_control` keeps its previous value.
  for ii in range(num_agent):
    init_control[ii][:, :-1] = us_cur[ii][:, 1:]
  # endregion

  # region: Shielding
  # Stacks the planned per-agent state trajectories along a new last axis and checks the
  # first config.COL_CHECK_N steps for pairwise collisions.
  _states_cc = jnp.stack(_states, axis=2)
  if (zone.oz_planner.pairwise_collision_check(_states_cc[:, :config.COL_CHECK_N, :])).any():
    # Works on a copy so the solver's returned controls are not modified in place.
    us_cur = deepcopy(us_cur)
    # Backup plan: STP with the fixed ordering 0, 1, ..., num_agent - 1.
    _xs_bk, _us_bk = zone.plan_stp(x_cur, init_control, targets, list(range(num_agent)))
    _xs_bk_cc = jnp.stack(_xs_bk, axis=2)
    if (zone.oz_planner.pairwise_collision_check(_xs_bk_cc[:, :config.COL_CHECK_N, :])).any():
      # The backup plan fails the check too: every agent except agent 0 gets
      # -v / dt (v is state entry 2) written into all rows of its first control.
      # NOTE(review): presumably a brake-to-stop command -- confirm that every control
      # channel (not only the acceleration) should receive this value, and that
      # dsolver.dynamics.dt matches the time step of the dynamics used for integration below.
      for ii in range(num_agent):
        if ii > 0:
          us_cur[ii][:, 0] = -x_cur[ii][2] / dsolver.dynamics.dt
    else:
      # The backup plan passes the check: applies its first control.
      # NOTE(review): `x_new` still holds the ILQGame prediction here -- confirm whether
      # check_zone should receive the backup plan's next states instead.
      for ii in range(num_agent):
        us_cur[ii][:, 0] = _us_bk[ii][:, 0]
  # endregion

  # region: Checks ATC zone.
  _, us_new, zone_flags = zone.check_zone(x_cur, x_new, us_cur, zero_control)
  zonef_hist.append(zone_flags)
  # The collision flag is evaluated on the state before this step's update.
  col_flag = zone.is_collision(x_cur)
  # endregion

  # region: Updates and reports stats.
  # Applies the first control returned by the zone check to each agent, using the
  # `integrate_forward_norev` variant of the game solver's subsystem dynamics.
  for ii in range(num_agent):
    _xii, _ = solver.dynamics._subsystem.integrate_forward_norev(x_cur[ii], us_new[ii][:, 0])
    x_cur[ii] = _xii

  # Logs the post-update state together with the control that produced it.
  for ii in range(num_agent):
    state_hist[ii][:, k], ctrl_hist[ii][:, k] = x_cur[ii], us_new[ii][:, 0]
  k += 1

  # Reports progress through the display handle only, like the other methods; no extra
  # line is printed per step.
  _info = ['step: ', k, ' | Permutation: ', perm, ' | Col: ', col_flag]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion
# endregion

# region: Wraps up.
# Keeps only the first k columns of each log, i.e. the steps that were simulated.
state_hist = [_hist[:, :k] for _hist in state_hist]
ctrl_hist = [_hist[:, :k] for _hist in ctrl_hist]

# Plots the optimal trajectory.
if config.PLOT_RES:
  _plot_opts = dict(
      orders=order_hist, targets=targets, colors=['r', 'g', 'b', 'm'], xlim=(-3, 3),
      ylim=(-3, 3), figsize=(20, 20), fontsize=35, image=quad_img, plot_arrow=False, zone=zone,
      zone_flags=zonef_hist
  )
  plot_trajectory(state_hist, config, fig_prog_folder, **_plot_opts)

  # Builds the animation over the number of logged steps and shows it in the Notebook.
  _num_frames = state_hist[0].shape[1]
  img = make_animation(_num_frames, config, fig_prog_folder, name="rollout_ilq.gif")
  display(img)
# endregion