---

This Notebook was developed by [Haimin Hu](https://haiminhu.org/) for the RSS'24 paper [_Blending Data-Driven Priors in Dynamic Games_](https://kl-games.github.io).

Instructions:
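* Run the cells in order to initiate closed-loop simulation for each method. The first code cell also performs the problem setup (solvers, model parameters, initial state) that the later cells rely on.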
* The simulation results are automatically displayed as an animation within the Notebook.
* The reference policy uses a pre-trained neural game solver. If you want to train your own model, please refer to [Neural NOD](https://arxiv.org/pdf/2406.09810) for the training code.

---

##### KLGame (overtaking ref. policy only)

In [None]:
import os
import pickle
import numpy as np
from copy import deepcopy
from IPython.display import display, HTML

import jax
from flax import serialization
import imageio.v2 as imageio
from matplotlib import cm
from matplotlib import pyplot as plt
from matplotlib.transforms import Affine2D
from IPython.display import Image

from iLQR.utils import *
from iLQGame.utils import *
from iLQR import Dynamics
from iLQGame import ILQSolver, KLGameSolver, ExplicitMLP, ProductMultiPlayerDynamicalSystem

# region: Problem setup
# Builds every notebook-global object used by the simulation cells below: configs, track,
# two-player dynamics, the iLQGame/KLGame solvers, and the restored MLP parameters.

# Selects how the ado's executed control is obtained in the simulation loops: with "ilq" the ado
# re-plans with its own ILQ solver and overrides the first-step ado control of the ego's plan;
# any other value keeps the ado control coming from the ego's plan.
ado_mode = "ilq"  # "orig", "ilq"
# NOTE: pickle.load can execute arbitrary code; only load model files from a trusted source.
with open('model/mlp_params.pkl', 'rb') as f:
  model_params_tmp = pickle.load(f)

# Passed as `ref_vel_scaling` to the ego and ado Dynamics objects below.
EGO_REF_VEL_SCALE, ADO_REF_VEL_SCALE = 0.7, 0.7

# Runs JAX on the CPU with 64-bit floating point enabled.
jax.config.update('jax_platform_name', 'cpu')
jax.config.update('jax_enable_x64', True)
# Plot styling: serif fonts with Times New Roman as the first choice, and extra title padding.
plt.rcParams['font.family'] = 'serif'
plt.rcParams['font.serif'] = ['Times New Roman'] + plt.rcParams['font.serif']
plt.rcParams['axes.titlepad'] = 40

# Loads the config and track file.
config = load_config("config/default.yaml")
config_cautious = load_config("config/cautious.yaml")
config_nn = get_nn_config(config)
track_name = "thunderhill_track"
track = load_track_variable_width(track_name)

# Number of closed-loop (receding-horizon) steps simulated by each cell.
itr_receding = config.MAX_ITER_RECEDING

# Ego ellipsoidal footprint
# Semi-axes are half the vehicle length/width; ego_Q holds their squares on the diagonal and
# ego_q is an offset of half the wheelbase along the first axis. These four names are not
# referenced again in the visible cells.
ego_a = config.LENGTH / 2.0
ego_b = config.WIDTH / 2.0
ego_q = np.array([config.WHEELBASE / 2., 0])[:, np.newaxis]
ego_Q = np.diag([ego_a * ego_a, ego_b * ego_b])

# Specifies the folder to save figures.
# All three simulation cells write their per-step frames into this same folder.
fig_prog_folder_zoom_in = os.path.join(config.OUT_FOLDER, "progress_zoom_in_MLP")
os.makedirs(fig_prog_folder_zoom_in, exist_ok=True)
# Car images drawn at the ego (supra) and ado (benz) poses in every frame.
supra = plt.imread('tracks/supra.png', format="png")
benz = plt.imread('tracks/benz.png', format="png")

# Creates subsystems and the joint system.
ego_idx = 0
ado_idx = 1
ego_subsys = Dynamics(ego_idx, config, ref_vel_scaling=EGO_REF_VEL_SCALE)
ado_subsys = Dynamics(ado_idx, config, ref_vel_scaling=ADO_REF_VEL_SCALE)
jnt_sys = ProductMultiPlayerDynamicalSystem([ego_subsys, ado_subsys])

# Loads the dataset.
# NOTE: same pickle caveat as above; only load trusted dataset files.
with open("dataset/dataset.pkl", 'rb') as handle:
  ds = pickle.load(handle)

# Sets up the iLQGame solver.
# solver_klg      : KLGame solver used for the ego's blended plan.
# solver          : ILQ solver that hosts the learned cost-parameter model (solve_parametric).
# solver_cautious : ILQ solver built from the cautious config.
# solver_ado      : ILQ solver for the ado, built from a copy of the default config with
#                   modified W_BLOCK / W_INNER_ADO / W_OUTER_ADO weights.
# solver_neutral  : ILQ solver on a copy of the default config after set_neutral_weights().
solver_klg = KLGameSolver(track, config, jnt_sys, verbose=False)
solver = ILQSolver(track, config, jnt_sys, verbose=False)
solver_cautious = ILQSolver(track, config_cautious, jnt_sys, verbose=False)
# Must run before the MLP is built: the model below reads solver._data_min / solver._data_max.
# presumably these are populated by normalize_nn_input — verify against ILQSolver.
solver.normalize_nn_input(ds['data'].T)
config_eval = deepcopy(config)
config_eval.W_BLOCK = 20.
config_eval.W_INNER_ADO = 0.
config_eval.W_OUTER_ADO = 0.5
solver_ado = ILQSolver(track, config_eval, jnt_sys, verbose=False)
solver_neutral = ILQSolver(track, deepcopy(config), jnt_sys, verbose=False)
solver_neutral.set_neutral_weights()

# Loads NN params
solver.cost_param_model = ExplicitMLP(
    features=config_nn.network_features, cutoff=True, config=config, is_softmax=True,
    px_rel_min=solver._data_min[0, 0], px_rel_range=solver._data_max[0, 0] - solver._data_min[0, 0],
    py_rel_min=solver._data_min[1, 0], py_rel_range=solver._data_max[1, 0] - solver._data_min[1, 0]
)
# Initializes the model on a zero input only to obtain a parameter structure, then restores the
# pickled weights into that structure.
# NOTE(review): `jnp` is not imported explicitly in this cell; presumably it is provided by one
# of the wildcard imports above — verify.
key = jax.random.PRNGKey(config_nn.random_seed)
model_params = solver.cost_param_model.init(key, jnp.zeros((config_nn.network_dim_in,)))
model_params = serialization.from_state_dict(model_params, model_params_tmp)
# endregion

# region: Specifies the initial state.
# Queries the track centerline at 2280 (ego) and 2325 (ado); presumably these are arc-length
# positions in the same units as track.length, which would start the ado ahead of the ego.
pos0_ego, psi0_ego = track.interp([2280], mode='center')  # track.length = 2673.5
pos0_ado, psi0_ado = track.interp([2325], mode='center')
# Joint state layout: four entries for the ego followed by four for the ado, each being
# [x, y, 0, psi]. The third entry starts at zero; the plots color the trajectory by it on a
# [0, config.V_MAX] scale, so it is presumably the speed — TODO confirm against Dynamics.
x_init = np.array([
    pos0_ego[0], pos0_ego[1], 0., psi0_ego[0], pos0_ado[0], pos0_ado[1], 0., psi0_ado[0]
])
# endregion

# region: Main simulation loop.
# Closed-loop, receding-horizon simulation of KLGame with a single (overtaking) reference policy.
# Every step plans from the current state, executes only the first control of the plan, and saves
# one plot frame into `fig_prog_folder_zoom_in`.

# Initializes the simulation.
x_cur = deepcopy(x_init)
state_hist = np.zeros((solver.dim_x, itr_receding))
control_hist = np.zeros((solver.dim_u_ss, itr_receding, solver.num_players))
# Identity matrices tiled over the config.N planning steps and the 2 players; handed to the KLGame
# solver together with the reference controls.
_identity_matrix = np.eye(config.DIM_U)
rSigmas = np.tile(_identity_matrix[:, :, np.newaxis, np.newaxis], (1, 1, config.N, 2))
theta_ego_hist, theta_ado_hist = [], []
t_total = 0.
_ot_flag = False
# Run-local handle to the ado's solver. Switching this handle (instead of rebinding the global
# `solver_ado`) keeps the switch from leaking into re-runs of this cell or into later cells.
ado_solver = solver_ado
# Sets the font size before the first figure is created so that every frame, including the first
# one, is rendered with the same font size.
plt.rcParams.update({'font.size': 25})
dh = display(HTML('<pre>KLGame planning starts.</pre>'), display_id=True)

for i in range(itr_receding):

  # region: Ego Planning
  # Updates the nominal control signal for warmstart of the next planning cycle: the previous
  # plan is shifted forward by one step and the final step is left at zero.
  controls_ws = np.zeros((solver.dim_u_ss, config.N, solver.num_players))
  if i > 0:
    controls_ws[:, :-1, :] = controls[:, 1:, :]

  # Plans mode trajectories using learned game policy.
  # -> Overtake
  states_ov, controls_ov, _, _, _, _, _ = solver.solve_parametric(model_params, x_cur, controls_ws)

  # Solves KLGame with the overtaking controls as the only reference.
  ref_mus = [controls_ov]
  states, controls, t_process, status, thetas, mode = solver_klg.solve(
      x_cur, controls_ws, ref_mus, [rSigmas], [config.REG_EGO, config.REG_ADO]
  )
  # endregion

  # region: Ado Planning
  # Progress is the first-step theta of each player normalized by the track length.
  progress = np.asarray(thetas[0, 0, :]) / track.length
  _ot_flag = progress[ego_idx] > progress[ado_idx] + 0.001
  if _ot_flag:
    # Once the ego is ahead, the ado plans with the neutral solver for the rest of this run.
    ado_solver = solver_neutral

  if ado_mode == "ilq":
    # The ado re-plans on its own; only its first-step control overrides the KLGame plan.
    _, controls_ado, _, _, _ = ado_solver.solve(x_cur, controls_ws)
    controls = deepcopy(controls)
    controls[:, 0, ado_idx] = controls_ado[:, 0, ado_idx]
  # endregion

  # region: Executes the control, records states and controls, computes the progress.
  x_cur, _ = solver.dynamics.integrate_forward(x_cur, controls[:, 0, :])
  x_cur = np.asarray(x_cur)

  state_hist[:, i] = x_cur
  control_hist[:, i, :] = controls[:, 0, :]
  theta_ego_hist.append(thetas[0, 0, ego_idx])
  theta_ado_hist.append(thetas[0, 0, ado_idx])

  if i > 0:  # Updates computation time: excludes JAX compilation time at the first time step.
    t_total += t_process
  # endregion

  # region: Plots the current progress.
  plt.figure(figsize=(15, 15))
  track.plot_track(N_pts=1000, plot_raceline=False)
  plt.plot(states[0, 1:], states[1, 1:], linewidth=6, c='red', alpha=1)  # ego plan
  # plt.plot(states[4, 1:], states[5, 1:], linewidth=2, c='b')  # ado plan
  plt.plot(states_ov[0, 1:], states_ov[1, 1:], linewidth=3, linestyle='--', c='orange', alpha=1)

  # Car images are rotated about each car's (post-step) position by its heading in degrees.
  transform_data = Affine2D().rotate_deg_around(*(x_cur[0], x_cur[1]),
                                                x_cur[3] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      supra, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[0] - 1., x_cur[0] + 4., x_cur[1] - 1.,
              x_cur[1] + 1.], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ego car figure
  transform_data = Affine2D().rotate_deg_around(*(x_cur[4], x_cur[5]),
                                                x_cur[7] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      benz, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[4] - 1.2, x_cur[4] + 4., x_cur[5] - 1.3,
              x_cur[5] + 1.3], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ado car figure
  # Ego positions visited so far, colored by the third state entry on a [0, V_MAX] scale.
  sc = plt.scatter(
      state_hist[0, :i + 1], state_hist[1, :i + 1], s=400, c=state_hist[2, :i + 1], cmap=cm.jet,
      vmin=0, vmax=config.V_MAX, edgecolor='none', marker='o'
  )  # trajectory history
  plt.axis('equal')
  # The view is a 60 x 60 window centered on the first state of the ego's plan.
  plt.xlim([states[0, 0] - 30., states[0, 0] + 30.])
  plt.ylim([states[1, 0] - 30., states[1, 0] + 30.])
  plt.title("step: " + str(i) + " | mode: " + str(mode))
  plt.xticks([])
  plt.yticks([])
  plt.savefig(os.path.join(fig_prog_folder_zoom_in, str(i) + ".png"), dpi=50)
  plt.close()
  # endregion

  # region: Reports simulation status.
  _info = [
      'step: ', i, ' | ego prog: ', '{:04.2f}'.format(progress[ego_idx]), ' | ado prog: ',
      '{:04.2f}'.format(progress[ado_idx]), ' | stime: ', '{:04.2f}'.format(t_process), ' | mode: ',
      mode
  ]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion

# region: Wraps up the simulation.
# Reports timing, then stitches the saved per-step frames into a GIF and shows it inline.
plt.close('all')  # the keyword is lowercase; any other string is treated as a figure label
num_steps = i + 1  # `i` is the last loop index, so this is the number of simulated steps
# The first step is excluded from t_total, hence the "- 1"; the max() guards a one-step run.
num_timed_steps = max(num_steps - 1, 1)
print("Avg. computation time is {:.3f} s per planning cycle.".format(t_total / num_timed_steps))
print("Completed a race on", track_name, "in ", round(num_steps * ego_subsys.dt, 2), 's')

# Makes animations.
gif_path_zoom_in = os.path.join(config.OUT_FOLDER, 'rollout_zoom_in_KLG.gif')
with imageio.get_writer(gif_path_zoom_in, mode='I', loop=0) as writer_zoom_in:
  # Frames 0 .. num_steps - 1 were saved by the loop; include all of them, the last one too.
  for j in range(num_steps):
    filename = os.path.join(fig_prog_folder_zoom_in, str(j) + ".png")
    image = imageio.imread(filename)
    writer_zoom_in.append_data(image)
# Reads the GIF through a context manager so the file handle is closed deterministically.
with open(gif_path_zoom_in, 'rb') as gif_file:
  img = Image(gif_file.read())
display(img)
# endregion
# endregion

##### Multi-modal KLGame (overtaking and car-following)

In [None]:
# region: Main simulation loop.
# Closed-loop, receding-horizon simulation of multi-modal KLGame with two reference policies
# (overtaking and time-optimal/car-following). Every step plans from the current state, executes
# only the first control of the plan, and saves one plot frame into `fig_prog_folder_zoom_in`.

# Initializes the simulation.
x_cur = deepcopy(x_init)
state_hist = np.zeros((solver.dim_x, itr_receding))
control_hist = np.zeros((solver.dim_u_ss, itr_receding, solver.num_players))
# Identity matrices tiled over the config.N planning steps and the 2 players; one copy is handed
# to the KLGame solver per reference policy.
_identity_matrix = np.eye(config.DIM_U)
rSigmas = np.tile(_identity_matrix[:, :, np.newaxis, np.newaxis], (1, 1, config.N, 2))
theta_ego_hist, theta_ado_hist = [], []
t_total = 0.
_ot_flag = False
# Run-local handle to the ado's solver. Switching this handle (instead of rebinding the global
# `solver_ado`) keeps the switch from leaking into re-runs of this cell or into later cells.
ado_solver = solver_ado
# Sets the font size before the first figure is created so that every frame, including the first
# one, is rendered with the same font size.
plt.rcParams.update({'font.size': 25})
dh = display(HTML('<pre>KLGame planning starts.</pre>'), display_id=True)

for i in range(itr_receding):

  # region: Ego Planning
  # Updates the nominal control signal for warmstart of the next planning cycle: the previous
  # plan is shifted forward by one step and the final step is left at zero.
  controls_ws = np.zeros((solver.dim_u_ss, config.N, solver.num_players))
  if i > 0:
    controls_ws[:, :-1, :] = controls[:, 1:, :]

  # Plans modal trajectories using reference game policies.
  # -> Time-optimal: cautious solver until the ego is ahead, neutral solver afterwards.
  if not _ot_flag:
    states_to, controls_to, _, _, _ = solver_cautious.solve(x_cur, controls_ws)
  else:
    states_to, controls_to, _, _, _ = solver_neutral.solve(x_cur, controls_ws)
  # -> Overtake
  states_ov, controls_ov, _, _, _, _, _ = solver.solve_parametric(model_params, x_cur, controls_ws)

  # Solves KLGame with both reference policies.
  ref_mus = [controls_ov, controls_to]
  states, controls, t_process, status, thetas, mode = solver_klg.solve(
      x_cur, controls_ws, ref_mus, [rSigmas, rSigmas], [config.REG_EGO, config.REG_ADO]
  )
  # endregion

  # region: Ado Planning
  # Progress is the first-step theta of each player normalized by the track length.
  progress = np.asarray(thetas[0, 0, :]) / track.length
  _ot_flag = progress[ego_idx] > progress[ado_idx] + 0.001
  if _ot_flag:
    # Once the ego is ahead, the ado plans with the neutral solver for the rest of this run.
    ado_solver = solver_neutral

  if ado_mode == "ilq":
    # The ado re-plans on its own; only its first-step control overrides the KLGame plan.
    _, controls_ado, _, _, _ = ado_solver.solve(x_cur, controls_ws)
    controls = deepcopy(controls)
    controls[:, 0, ado_idx] = controls_ado[:, 0, ado_idx]
  # endregion

  # region: Executes the control, records states and controls, computes the progress.
  x_cur, _ = solver.dynamics.integrate_forward(x_cur, controls[:, 0, :])
  x_cur = np.asarray(x_cur)

  state_hist[:, i] = x_cur
  control_hist[:, i, :] = controls[:, 0, :]
  theta_ego_hist.append(thetas[0, 0, ego_idx])
  theta_ado_hist.append(thetas[0, 0, ado_idx])

  if i > 0:  # Updates computation time: excludes JAX compilation time at the first time step.
    t_total += t_process
  # endregion

  # region: Plots the current progress.
  plt.figure(figsize=(15, 15))
  track.plot_track(N_pts=1000, plot_raceline=False)
  plt.plot(states[0, 1:], states[1, 1:], linewidth=6, c='red', alpha=1)  # ego plan
  # plt.plot(states[4, 1:], states[5, 1:], linewidth=2, c='b')  # ado plan
  # Reference plans: overtaking in orange, time-optimal in blue.
  plt.plot(states_ov[0, 1:], states_ov[1, 1:], linewidth=3, linestyle='--', c='orange', alpha=1)
  plt.plot(states_to[0, 1:], states_to[1, 1:], linewidth=3, linestyle='--', c='blue', alpha=1)

  # Car images are rotated about each car's (post-step) position by its heading in degrees.
  transform_data = Affine2D().rotate_deg_around(*(x_cur[0], x_cur[1]),
                                                x_cur[3] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      supra, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[0] - 1., x_cur[0] + 4., x_cur[1] - 1.,
              x_cur[1] + 1.], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ego car figure
  transform_data = Affine2D().rotate_deg_around(*(x_cur[4], x_cur[5]),
                                                x_cur[7] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      benz, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[4] - 1.2, x_cur[4] + 4., x_cur[5] - 1.3,
              x_cur[5] + 1.3], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ado car figure
  # Ego positions visited so far, colored by the third state entry on a [0, V_MAX] scale.
  sc = plt.scatter(
      state_hist[0, :i + 1], state_hist[1, :i + 1], s=400, c=state_hist[2, :i + 1], cmap=cm.jet,
      vmin=0, vmax=config.V_MAX, edgecolor='none', marker='o'
  )  # trajectory history
  plt.axis('equal')
  # The view is a 60 x 60 window centered on the first state of the ego's plan.
  plt.xlim([states[0, 0] - 30., states[0, 0] + 30.])
  plt.ylim([states[1, 0] - 30., states[1, 0] + 30.])
  plt.title("step: " + str(i) + " | mode: " + str(mode))
  plt.xticks([])
  plt.yticks([])
  plt.savefig(os.path.join(fig_prog_folder_zoom_in, str(i) + ".png"), dpi=50)
  plt.close()
  # endregion

  # region: Reports simulation status.
  _info = [
      'step: ', i, ' | ego prog: ', '{:04.2f}'.format(progress[ego_idx]), ' | ado prog: ',
      '{:04.2f}'.format(progress[ado_idx]), ' | stime: ', '{:04.2f}'.format(t_process), ' | mode: ',
      mode
  ]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion

# region: Wraps up the simulation.
# Reports timing, then stitches the saved per-step frames into a GIF and shows it inline.
plt.close('all')  # the keyword is lowercase; any other string is treated as a figure label
num_steps = i + 1  # `i` is the last loop index, so this is the number of simulated steps
# The first step is excluded from t_total, hence the "- 1"; the max() guards a one-step run.
num_timed_steps = max(num_steps - 1, 1)
print("Avg. computation time is {:.3f} s per planning cycle.".format(t_total / num_timed_steps))
print("Completed a race on", track_name, "in ", round(num_steps * ego_subsys.dt, 2), 's')

# Makes animations.
# NOTE(review): this cell writes to the same GIF path as the single-reference KLGame cell and
# therefore overwrites its animation — confirm that this is intended.
gif_path_zoom_in = os.path.join(config.OUT_FOLDER, 'rollout_zoom_in_KLG.gif')
with imageio.get_writer(gif_path_zoom_in, mode='I', loop=0) as writer_zoom_in:
  # Frames 0 .. num_steps - 1 were saved by the loop; include all of them, the last one too.
  for j in range(num_steps):
    filename = os.path.join(fig_prog_folder_zoom_in, str(j) + ".png")
    image = imageio.imread(filename)
    writer_zoom_in.append_data(image)
# Reads the GIF through a context manager so the file handle is closed deterministically.
with open(gif_path_zoom_in, 'rb') as gif_file:
  img = Image(gif_file.read())
display(img)
# endregion
# endregion

##### Multi-modal reference policy (baseline)

In [None]:
# region: Main simulation loop.
# Closed-loop, receding-horizon simulation of the baseline: the ego executes the overtaking
# reference plan unless that plan fails the safety check somewhere along the horizon, in which
# case it falls back to the time-optimal plan. No KLGame blending is involved.
import time  # local import: used only to time the ego planning in this cell

# Initializes the simulation.
x_cur = deepcopy(x_init)
state_hist = np.zeros((solver.dim_x, itr_receding))
control_hist = np.zeros((solver.dim_u_ss, itr_receding, solver.num_players))
# Kept for parity with the KLGame cells; neither name is used by the baseline loop below.
_identity_matrix = np.eye(config.DIM_U)
rSigmas = np.tile(_identity_matrix[:, :, np.newaxis, np.newaxis], (1, 1, config.N, 2))
theta_ego_hist, theta_ado_hist = [], []
t_total = 0.
_ot_flag = False
# Run-local handle to the ado's solver. Switching this handle (instead of rebinding the global
# `solver_ado`) keeps the switch from leaking into re-runs of this cell or into later cells.
ado_solver = solver_ado
# Sets the font size before the first figure is created so that every frame, including the first
# one, is rendered with the same font size.
plt.rcParams.update({'font.size': 25})
dh = display(HTML('<pre>Baseline planning starts.</pre>'), display_id=True)

for i in range(itr_receding):

  # region: Ego Planning
  # Updates the nominal control signal for warmstart of the next planning cycle: the previous
  # plan is shifted forward by one step and the final step is left at zero.
  controls_ws = np.zeros((solver.dim_u_ss, config.N, solver.num_players))
  if i > 0:
    controls_ws[:, :-1, :] = controls[:, 1:, :]

  # Wall-clock timer for the whole ego planning cycle (both reference solves + policy switch).
  t_start = time.perf_counter()

  # Plans modal trajectories using reference game policies.
  # -> Time-optimal: cautious solver until the ego is ahead, neutral solver afterwards.
  # `thetas` from this solve is the fallback used for the progress computation when the ado does
  # not re-plan below. Assumes the returned thetas depend only on the state, not on the solver's
  # cost weights — TODO confirm.
  if not _ot_flag:
    states_to, controls_to, _, _, thetas = solver_cautious.solve(x_cur, controls_ws)
  else:
    states_to, controls_to, _, _, thetas = solver_neutral.solve(x_cur, controls_ws)
  # -> Overtake
  states_ov, controls_ov, _, _, _, _, _ = solver.solve_parametric(model_params, x_cur, controls_ws)

  # Switches policies based on collision checking: the overtaking plan is used only if every
  # step of its horizon passes `is_safe`.
  controls = controls_ov
  mode = 'ov'
  for k in range(solver.horizon):
    obs_tuple_cur = get_perfect_obs_two_player(states_ov[:, k], controls_ov[:, k, :], solver)
    if not is_safe(config, obs_tuple_cur):
      controls = controls_to
      mode = 'to'
      break
  controls = np.asarray(controls)
  t_process = time.perf_counter() - t_start
  # endregion

  # region: Ado Planning
  if ado_mode == "ilq":
    # The ado re-plans on its own; only its first-step control overrides the ego's plan.
    _, controls_ado, _, _, thetas = ado_solver.solve(x_cur, controls_ws)
    controls = deepcopy(controls)
    controls[:, 0, ado_idx] = controls_ado[:, 0, ado_idx]

  # Progress is the first-step theta of each player normalized by the track length.
  progress = np.asarray(thetas[0, 0, :]) / track.length
  _ot_flag = progress[ego_idx] > progress[ado_idx] + 0.001
  if _ot_flag:
    # Once the ego is ahead, the ado plans with the neutral solver for the rest of this run.
    ado_solver = solver_neutral
  # endregion

  # region: Executes the control, records states and controls, computes the progress.
  x_cur, _ = solver.dynamics.integrate_forward(x_cur, controls[:, 0, :])
  x_cur = np.asarray(x_cur)

  state_hist[:, i] = x_cur
  control_hist[:, i, :] = controls[:, 0, :]
  theta_ego_hist.append(thetas[0, 0, ego_idx])
  theta_ado_hist.append(thetas[0, 0, ado_idx])

  if i > 0:  # Updates computation time: excludes JAX compilation time at the first time step.
    t_total += t_process
  # endregion

  # region: Plots the current progress.
  plt.figure(figsize=(15, 15))
  track.plot_track(N_pts=1000, plot_raceline=False)
  # plt.plot(states[0, 1:], states[1, 1:], linewidth=6, c='red', alpha=1)  # ego plan
  # plt.plot(states[4, 1:], states[5, 1:], linewidth=2, c='b')  # ado plan
  # The selected reference plan is drawn opaque and the other one faded.
  if mode == 'ov':
    plt.plot(states_ov[0, 1:], states_ov[1, 1:], linewidth=3, linestyle='--', c='orange', alpha=1)
    plt.plot(states_to[0, 1:], states_to[1, 1:], linewidth=3, linestyle='--', c='blue', alpha=0.3)
  else:
    plt.plot(states_ov[0, 1:], states_ov[1, 1:], linewidth=3, linestyle='--', c='orange', alpha=0.3)
    plt.plot(states_to[0, 1:], states_to[1, 1:], linewidth=3, linestyle='--', c='blue', alpha=1)

  # Car images are rotated about each car's (post-step) position by its heading in degrees.
  transform_data = Affine2D().rotate_deg_around(*(x_cur[0], x_cur[1]),
                                                x_cur[3] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      supra, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[0] - 1., x_cur[0] + 4., x_cur[1] - 1.,
              x_cur[1] + 1.], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ego car figure
  transform_data = Affine2D().rotate_deg_around(*(x_cur[4], x_cur[5]),
                                                x_cur[7] / np.pi * 180) + plt.gca().transData
  plt.imshow(
      benz, transform=transform_data, interpolation='none', origin='lower',
      extent=[x_cur[4] - 1.2, x_cur[4] + 4., x_cur[5] - 1.3,
              x_cur[5] + 1.3], alpha=1.0, zorder=10.0, clip_on=True
  )  # plot ado car figure
  # Ego positions visited so far, colored by the third state entry on a [0, V_MAX] scale.
  sc = plt.scatter(
      state_hist[0, :i + 1], state_hist[1, :i + 1], s=400, c=state_hist[2, :i + 1], cmap=cm.jet,
      vmin=0, vmax=config.V_MAX, edgecolor='none', marker='o'
  )  # trajectory history
  plt.axis('equal')
  # The view is a 60 x 60 window centered on the first state of the overtaking plan.
  plt.xlim([states_ov[0, 0] - 30., states_ov[0, 0] + 30.])
  plt.ylim([states_ov[1, 0] - 30., states_ov[1, 0] + 30.])
  plt.title("step: " + str(i) + " | mode: " + str(mode))
  plt.xticks([])
  plt.yticks([])
  plt.savefig(os.path.join(fig_prog_folder_zoom_in, str(i) + ".png"), dpi=50)
  plt.close()
  # endregion

  # region: Reports simulation status.
  _info = [
      'step: ', i, ' | ego prog: ', '{:04.2f}'.format(progress[ego_idx]), ' | ado prog: ',
      '{:04.2f}'.format(progress[ado_idx]), ' | mode: ', mode
  ]
  _info = [str(_item) for _item in _info]
  dh.update(''.join(_info))
  # endregion

# region: Wraps up the simulation.
# Reports timing, then stitches the saved per-step frames into a GIF and shows it inline.
plt.close('all')  # the keyword is lowercase; any other string is treated as a figure label
num_steps = i + 1  # `i` is the last loop index, so this is the number of simulated steps
# The average is taken over all steps but the first, hence the "- 1"; the max() guards a
# one-step run against a division by zero.
num_timed_steps = max(num_steps - 1, 1)
print("Avg. computation time is {:.3f} s per planning cycle.".format(t_total / num_timed_steps))
print("Completed a race on", track_name, "in ", round(num_steps * ego_subsys.dt, 2), 's')

# Makes animations.
# NOTE(review): this cell writes to the same GIF path as the KLGame cells and therefore
# overwrites their animation — confirm that this is intended.
gif_path_zoom_in = os.path.join(config.OUT_FOLDER, 'rollout_zoom_in_KLG.gif')
with imageio.get_writer(gif_path_zoom_in, mode='I', loop=0) as writer_zoom_in:
  # Frames 0 .. num_steps - 1 were saved by the loop; include all of them, the last one too.
  for j in range(num_steps):
    filename = os.path.join(fig_prog_folder_zoom_in, str(j) + ".png")
    image = imageio.imread(filename)
    writer_zoom_in.append_data(image)
# Reads the GIF through a context manager so the file handle is closed deterministically.
with open(gif_path_zoom_in, 'rb') as gif_file:
  img = Image(gif_file.read())
display(img)
# endregion
# endregion