<a href="https://colab.research.google.com/github/carolinewang01/jax-aht/blob/main/JaxAHT_Tutorial.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Welcome to JaxAHT!

In this tutorial, we will focus on introducing the core workflows for using the library.
We will:
- Demonstrate how to train teammates and ego agents, separately and as part of a single, unified workflow.
- Introduce the set of evaluation teammates and the evaluation framework.
- Visualize learned policies.

The project uses [Hydra](https://hydra.cc/) to manage algorithm and environment configurations, and [WandB](https://wandb.ai/) for logging.

Although the tutorial does not explicitly describe how to run the open-ended learning algorithms, these algorithm types may be run similarly to the MARL, teammate generation, and ego agent training algorithms.

Please see the project README for a full description of the project's design philosophy.
Our benchmark uses a multi-agent PPO implementation provided by [JaxMARL](https://github.com/FLAIROx/JaxMARL/tree/main).


# Install dependencies 📚

 ⚠️ Before beginning the tutorial, ensure you select a GPU or TPU from `Runtime > Change runtime type` ⚠️

In [11]:
%%shell
# clone repo and install packages
git clone https://github.com/carolinewang01/jax-aht.git
cd jax-aht
pip install --upgrade pip
pip install -r requirements.txt

Cloning into 'jax-aht'...
remote: Enumerating objects: 5722, done.
remote: Counting objects: 100% (191/191), done.
remote: Compressing objects: 100% (98/98), done.
remote: Total 5722 (delta 113), reused 145 (delta 93), pack-reused 5531 (from 1)
Receiving objects: 100% (5722/5722), 2.58 MiB | 6.05 MiB/s, done.
Resolving deltas: 100% (4037/4037), done.
Collecting scipy>=1.11.1 (from jax==0.5.3->jax[cuda12]==0.5.3->-r requirements.txt (line 2))
  Using cached scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (60 kB)
Using cached scipy-1.12.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (38.4 MB)
Installing collected packages: scipy
  Attempting uninstall: scipy
    Found existing installation: scipy 1.16.0
    Uninstalling scipy-1.16.0:
      Successfully uninstalled scipy-1.16.0
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the followi



In [1]:
# change the current working directory to jax-aht/ for the rest of this notebook
import os
os.chdir('jax-aht/')
os.getcwd()

# verify that the jax installation can find the GPU/TPU
import jax
jax.devices()

[CudaDevice(id=0)]

## Part 1.1: Training Teammates


In JaxAHT, teammates may be trained using either MARL algorithms or teammate generation algorithms. Each algorithm type has its own entry point, located at `marl/run.py` and `teammate_generation/run.py`, respectively.
For this tutorial, we will train teammates on Level-Based Foraging (LBF) using IPPO.


**Viewing Metrics:** We strongly recommend using WandB's UI to view the logged metrics. This requires a WandB account and setting the logger to online mode. For the purposes of this tutorial, you can see a preview of the logged metrics in the console. Note that the maximum return on LBF is 0.5.

In [5]:
%%shell
# train teammates on LBF using a MARL algorithm (IPPO w/parameter sharing)
PYTHONPATH=$(pwd) python marl/run.py task=lbf algorithm=ippo/lbf logger.mode=offline
# train using BRDiv instead. Note that we use the teammate_generation/ entry point instead of the marl/ entry point.
# PYTHONPATH=$(pwd) python teammate_generation/run.py task=lbf algorithm=brdiv/lbf logger.mode=offline

2025-07-08 00:38:39.127718: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751935119.147399    2689 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751935119.153491    2689 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
task:
  ENV_NAME: lbf
  ROLLOUT_LENGTH: 128
  ENV_KWARGS: {}
  TASK_NAME: lbf
algorithm:
  ALG: ippo
  TOTAL_TIMESTEPS: 500000.0
  ACTOR_TYPE: mlp
  NUM_CHECKPOINTS: 5
  GAMMA: 0.99
  GAE_LAMBDA: 0.95
  VF_COEF: 0.5
  MAX_GRAD_NORM: 1.0
  ANNEAL_LR: true
  SEED: 12345
  LR: 0.0001
  NUM_ENVS: 8
  UPDATE_EPOCHS: 15
  NUM_MINIBATCHES: 4
  CLIP_EPS: 0.03
  ENT_COEF: 0.01
  NUM_SEEDS: 3
  TRAIN_SEED: 20374
  ENV_NAME: lbf
  ENV_KWARGS: {
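
As a rough sanity check on the printed config, the number of PPO updates can be estimated from `TOTAL_TIMESTEPS`, `NUM_ENVS`, and `ROLLOUT_LENGTH`. The sketch below assumes the common convention that each update consumes one rollout of `ROLLOUT_LENGTH` steps from each of `NUM_ENVS` parallel environments, and that each update runs `UPDATE_EPOCHS` passes over `NUM_MINIBATCHES` minibatches. The codebase may count timesteps differently, so treat the numbers as an estimate only.

```python
# Values copied from the config printed above.
total_timesteps = int(500000.0)
num_envs = 8
rollout_length = 128
update_epochs = 15
num_minibatches = 4

# Assumption: one update consumes num_envs * rollout_length environment steps.
steps_per_update = num_envs * rollout_length
num_updates = total_timesteps // steps_per_update

# Assumption: each update runs update_epochs passes over num_minibatches minibatches.
gradient_steps_per_update = update_epochs * num_minibatches

print(f"environment steps per update: {steps_per_update}")
print(f"estimated number of updates: {num_updates}")
print(f"gradient steps per update: {gradient_steps_per_update}")
```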



## Part 1.2: Training an Ego Agent Against Pretrained Teammates

Here, we will train a PPO ego agent to collaborate with the IPPO teammates trained in the last section, using the ego agent training entry point at `ego_agent_training/run.py`.

Please take a moment to look over the entry point code, to understand the overall pipeline.

In [10]:
! cat ego_agent_training/run.py

'''Main entry point for running ego agent training algorithms against a fixed partner population.'''
import hydra
from omegaconf import OmegaConf

from common.plot_utils import get_metric_names
from common.wandb_visualizations import Logger
from evaluation.heldout_eval import run_heldout_evaluation, log_heldout_metrics
from ppo_ego import run_ego_training
from ego_agent_training.ppo_br import run_br_training


@hydra.main(version_base=None, config_path="configs", config_name="base_config_ego")
def run_training(cfg):
    '''Runs the ego agent training against a fixed partner population.'''
    print(OmegaConf.to_yaml(cfg, resolve=True))
    wandb_logger = Logger(cfg)

    if cfg["algorithm"]["ALG"] == "ppo_ego":
        ego_params, ego_policy, init_ego_params = run_ego_training(cfg, wandb_logger)
    elif cfg["algorithm"]["ALG"] == "ppo_br":
        ego_params, ego_policy, init_ego_params = run_br_training(cfg, wandb_logger)

    if cfg["run_heldout_eval"]:
        metric_names = get_me

### 1.2.1: Setting the Partner Config

The first step to training an ego agent is to specify what teammates the ego agent should be trained with.
The ego agent training code looks for a partner config within the algorithm block of the master config file, located at `ego_agent_training/configs/base_config_ego.yaml`.



In [5]:
! cat ego_agent_training/configs/base_config_ego.yaml

defaults:
  - task: overcooked-v1/cramped_room # task configs
  - algorithm@algorithm: ppo_ego/${task} # task-specific algorithm configs
  - hydra: hydra_simple
  - ../../evaluation/configs/global_heldout_settings
  - _self_         # Ensures that values in this file override imported ones if needed

ENV_NAME: ${task.ENV_NAME}
ENV_KWARGS: ${task.ENV_KWARGS}
ROLLOUT_LENGTH: ${task.ROLLOUT_LENGTH}
TASK_NAME: ${task.TASK_NAME}

run_heldout_eval: true

# partner configs are specified in each algorithm _base_ config file
# note that ppo_ego and ppo_br support different types of partner configs
algorithm:
  TRAIN_SEED: 12345
  NUM_EGO_TRAIN_SEEDS: 3
  NUM_EVAL_EPISODES: 20
  ENV_NAME: ${ENV_NAME}
  ENV_KWARGS: ${ENV_KWARGS}
  ROLLOUT_LENGTH: ${ROLLOUT_LENGTH}

label: "default_label"
name: ${TASK_NAME}/${algorithm.ALG}_${algorithm.EGO_ACTOR_TYPE}/${label}

# WandB Params
logger: 
  project: aht-benchmark
  entity: aht-project
  tags: 
    - ${algorithm.ALG}
    - ${TASK_NAME}
    - ${label}
 

For the `ppo_ego` algorithm, a default partner config is specified within the algorithm-specific config directory, at `ego_agent_training/configs/algorithm/ppo_ego/_base_.yaml`. This default partner config is automatically imported and merged into the master config. The path (and other partner config values) should be set by the user, either directly within the config or via the command line. We will take the latter approach in this tutorial.


In [6]:
! cat ego_agent_training/configs/algorithm/ppo_ego/_base_.yaml

# @package algorithm
# ^ tells hydra to place these value directly under algorithm key
ALG: ppo_ego
EGO_ACTOR_TYPE: s5
NUM_EGO_TRAIN_SEEDS: 1
TOTAL_TIMESTEPS: 1e7
NUM_CHECKPOINTS: 5
NUM_ENVS: 8
LR: 1.e-4 
UPDATE_EPOCHS: 15
NUM_MINIBATCHES: 4
GAMMA: 0.99
GAE_LAMBDA: 0.95
CLIP_EPS: 0.05
ENT_COEF: 0.01
VF_COEF: 0.5
MAX_GRAD_NORM: 1.0
ANNEAL_LR: true
partner_agent: # partner config for ppo_ego
  name: ippo
  path: null # Please set the path to a partner agent checkpoint. You will also need to set the actor_type, ckpt_key, and specify the partner idxs to load. 
  actor_type: mlp
  ckpt_key: final_params
  idx_list: [0]
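
Setting partner config values from the command line relies on dotted keys such as `algorithm.partner_agent.path` addressing nested fields of the merged config. The following is a minimal pure-Python sketch of that idea, not Hydra's implementation: `apply_overrides` is a hypothetical helper, values are kept as plain strings (Hydra parses types), and the base dictionary is a trimmed copy of the YAML above.

```python
import copy

def apply_overrides(config, overrides):
    """Return a copy of `config` with `a.b.c=value` style overrides applied."""
    merged = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = merged
        for name in parents:
            # Walk down the nested dictionaries, creating levels if missing.
            node = node.setdefault(name, {})
        node[leaf] = raw_value
    return merged

# Trimmed copy of the default ppo_ego config shown above.
base = {
    "algorithm": {
        "ALG": "ppo_ego",
        "TOTAL_TIMESTEPS": "1e7",
        "partner_agent": {"name": "ippo", "path": None, "actor_type": "mlp"},
    }
}

overrides = [
    "algorithm.partner_agent.path=results/lbf/ippo/default_label/2025-07-07_02-34-21/saved_train_run",
    "algorithm.TOTAL_TIMESTEPS=1e6",
]

cfg = apply_overrides(base, overrides)
print(cfg["algorithm"]["partner_agent"]["path"])
print(cfg["algorithm"]["TOTAL_TIMESTEPS"])
```

Only the overridden leaves change; sibling keys such as `actor_type` keep their defaults, which is why the partner path can be set without restating the rest of the partner config.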

In [7]:
# First, let's find the teammate checkpoint directory
import glob

checkpoint_base_dir = 'results/lbf/ippo/default_label/'
checkpoint_dir = sorted(glob.glob(os.path.join(checkpoint_base_dir, '*')))[-1]
checkpoint_dir = os.path.join(checkpoint_dir, "saved_train_run")
assert os.path.exists(checkpoint_dir), "Error: checkpoint directory not found."
print(f"Found checkpoint directory: {checkpoint_dir}")

Found checkpoint directory: results/lbf/ippo/default_label/2025-07-07_02-34-21/saved_train_run
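
The cell above picks the last entry of the sorted directory listing. This works under the assumption that run directories are named with zero-padded timestamps such as `2025-07-07_02-34-21`, so that alphabetical order matches chronological order. The short check below illustrates this with example names (the third one is made up for illustration); `parse_run_timestamp` is a hypothetical helper.

```python
from datetime import datetime

# Example run directory names; the last one is illustrative only.
run_dirs = ["2025-07-07_02-34-21", "2025-04-21_23-41-17", "2025-07-08_00-38-39"]

def parse_run_timestamp(name):
    """Parse a run directory name of the form YYYY-MM-DD_HH-MM-SS."""
    return datetime.strptime(name, "%Y-%m-%d_%H-%M-%S")

# Latest run according to plain string sorting.
latest_by_name = sorted(run_dirs)[-1]
# Latest run according to the parsed timestamps.
latest_by_time = max(run_dirs, key=parse_run_timestamp)

print(latest_by_name, latest_by_time)
```

If the results directory ever contains entries that do not follow this naming pattern, the string sort may select the wrong entry, so filtering by the timestamp format first would be safer.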


### 1.2.2: Train the ego agent!

Now let's train the ego agent, directly specifying the partner path as a command line argument. We reduce the training time to 1 million steps, so that the ego agent trains in a couple of minutes. We also turn off the evaluation against the heldout set, which is explained in the next section.

In [12]:
# Now let's train the ego agent, directly specifying the partner path from the command line
# total training timesteps is reduced to 1 million steps for this tutorial.
! PYTHONPATH=$(pwd) python ego_agent_training/run.py task=lbf algorithm=ppo_ego/lbf algorithm.partner_agent.path={checkpoint_dir} algorithm.TOTAL_TIMESTEPS=1e6 run_heldout_eval=false logger.mode=offline

2025-07-07 02:56:01.721169: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1751856961.743478   10427 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1751856961.750090   10427 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
task:
  ENV_NAME: lbf
  ROLLOUT_LENGTH: 128
  ENV_KWARGS: {}
  TASK_NAME: lbf
algorithm:
  ALG: ppo_ego
  EGO_ACTOR_TYPE: s5
  NUM_EGO_TRAIN_SEEDS: 3
  TOTAL_TIMESTEPS: 1000000.0
  NUM_CHECKPOINTS: 5
  NUM_ENVS: 8
  LR: 5.0e-05
  UPDATE_EPOCHS: 10
  NUM_MINIBATCHES: 4
  GAMMA: 0.99
  GAE_LAMBDA: 0.95
  CLIP_EPS: 0.1
  ENT_COEF: 0.0001
  VF_COEF: 0.5
  MAX_GRAD_NORM: 1.0
  ANNEAL_LR: false
  partner_agent:
    name: ippo
    path: res

# 2. Evaluation

In AHT research, ego agents are often evaluated based on the returns achieved during collaboration with a *heldout set* of evaluation agents, i.e., agents that the ego agent should not have seen during training.
The JaxAHT benchmark provides a heldout evaluation set for LBF and the five classic Overcooked tasks.
The rliable library is used to compute bootstrapped metrics across the heldout set by treating each heldout agent as a task.
In this section of the tutorial, we will introduce the heldout agent config, demonstrate how to download the agents, and visualize them.
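
To make the idea of bootstrapped metrics over "tasks" concrete, here is a minimal NumPy sketch of an interquartile mean (IQM) with a stratified bootstrap confidence interval. It is not rliable's API or implementation: `iqm` and `stratified_bootstrap_ci` are hypothetical helpers, and the score matrix is made-up illustrative data with one row per ego agent seed and one column per heldout agent.

```python
import numpy as np

def iqm(scores):
    """Interquartile mean: mean of all entries after trimming 25% from each end."""
    flat = np.sort(np.asarray(scores, dtype=float).ravel())
    cut = int(0.25 * flat.size)
    return float(flat[cut:flat.size - cut].mean())

def stratified_bootstrap_ci(scores, stat_fn, num_resamples=2000, alpha=0.05, seed=0):
    """Percentile CI, resampling runs with replacement independently per task."""
    scores = np.asarray(scores, dtype=float)
    num_runs, num_tasks = scores.shape
    rng = np.random.default_rng(seed)
    stats = np.empty(num_resamples)
    for b in range(num_resamples):
        # For each task (column), draw num_runs run indices with replacement.
        idx = rng.integers(0, num_runs, size=(num_runs, num_tasks))
        resampled = np.take_along_axis(scores, idx, axis=0)
        stats[b] = stat_fn(resampled)
    lower, upper = np.quantile(stats, [alpha / 2, 1 - alpha / 2])
    return float(lower), float(upper)

# Illustrative normalized returns: 3 ego seeds (rows) x 3 heldout agents (columns).
scores = np.array([
    [0.9, 0.4, 0.7],
    [0.8, 0.5, 0.6],
    [1.0, 0.3, 0.8],
])

point_estimate = iqm(scores)
ci_lower, ci_upper = stratified_bootstrap_ci(scores, iqm)
print(f"IQM: {point_estimate:.3f}, 95% CI: [{ci_lower:.3f}, {ci_upper:.3f}]")
```

Resampling within each column keeps every heldout agent represented in every bootstrap sample, which is the sense in which each heldout agent plays the role of a task.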

*Note that the ego agent training entry point evaluates the ego agent against the heldout set by default. We turned this off in the last section; to enable it, set the `run_heldout_eval` argument to `true`.*

## 2.1 Download the Heldout Set

The heldout evaluation set consists of both agents trained via RL and manually programmed heuristic agents. The heuristic agents are located under the `agents/lbf/` and `agents/overcooked/` directories, while the following code will download the RL agent checkpoints to an `eval_teammates/` directory.



In [6]:
# first, let's download the RL agents in the heldout set
! python download_eval_data.py

Starting download & extraction: https://drive.google.com/file/d/1pS0wvJDzOZa954RADF_j9ETR74THzh8I/view?usp=sharing -> results/
Downloading...
From: https://drive.google.com/uc?id=1pS0wvJDzOZa954RADF_j9ETR74THzh8I
To: /content/jax-aht/results/downloaded_gdrive_file.zip
100% 5.94k/5.94k [00:00<00:00, 27.4MB/s]
Downloaded results/downloaded_gdrive_file.zip (5942 bytes).
Unzipping results/downloaded_gdrive_file.zip to temporary directory /tmp/tmpbf001ng1...
Successfully unzipped results/downloaded_gdrive_file.zip to /tmp/tmpbf001ng1.
Processing and moving files from '/tmp/tmpbf001ng1/best_heldout_returns' to 'results/'...
Successfully moved 6 file(s) to results/.
Cleaning up temporary extraction directory: /tmp/tmpbf001ng1
Download completed successfully for best_returns_teammates.
Starting download & extraction: https://drive.google.com/file/d/1KjBV2GekKdRBiK6QSGe2vYx2ThXlG7X7/view?usp=sharing -> eval_teammates/
Downloading...
From (original): https://drive.google.com/uc?id=1KjBV2GekKdRBi

## 2.2. Heldout Agent Config

The evaluation teammate set is specified in the `global_heldout_settings.yaml` file, which is imported by other config files throughout the codebase as needed, to perform the heldout evaluation.

Notice that the teammate config format looks similar to the format used to specify the ego agent training teammates! This is because all teammate configs are parsed by functions within the `common/agent_loader_from_config.py` file. Please see the README for more details about the teammate config.

In [7]:
! cat evaluation/configs/global_heldout_settings.yaml

# @package _global_
# THIS SCRIPT CONTAINS THE PATHS FOR THE HELDOUT EVALUATION. DO NOT MODIFY THIS FILE.
global_heldout_settings:
  EVAL_SEED: 34957
  MAX_EPISODE_STEPS: ${task.ROLLOUT_LENGTH}
  NUM_EVAL_EPISODES: 64
  AGGREGATE_STAT: mean # choices: mean, iqm. both are computed with the bootstrapped CI using rliable.
  NORMALIZE_RETURNS: true # if true, use performance bounds to normalize returns

heldout_set:
  lbf: # DONE - DO NOT MODIFY. In LBF, results are much better with test mode = false.
    ippo_mlp: 
      path: "eval_teammates/lbf/ippo/2025-04-21_23-41-17/saved_train_run"
      actor_type: mlp
      ckpt_key: final_params
      idx_list: [0] # load in seed 0 only
      test_mode: false
      performance_bounds:
        percent_eaten: [[0.0, 100.0]]
        returned_episode_returns: [[0.0, 0.5]]
    ippo_mlp_s2c0:
      path: "eval_teammates/lbf/ippo/2025-04-21_23-41-17/saved_train_run"
      actor_type: mlp
      idx_list: [[2, 0]] # load in seed 2 ckpt 0
      test_mode: 
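
The `NORMALIZE_RETURNS` flag and the `performance_bounds` entries above suggest that raw returns are rescaled using a lower and an upper bound per metric. A plausible reading is min-max normalization, sketched below. This is an assumption about the config's meaning rather than the codebase's actual routine: `normalize_return` is a hypothetical helper, and taking the first pair of the nested list is also an assumption.

```python
def normalize_return(value, bounds):
    """Min-max normalize `value` given a (lower, upper) pair."""
    lower, upper = bounds
    if upper == lower:
        raise ValueError("performance bounds must have distinct lower and upper values")
    return (value - lower) / (upper - lower)

# Bounds copied from `returned_episode_returns` in the config above.
returned_episode_returns_bounds = [[0.0, 0.5]]
# Assumption: the nested list holds one (lower, upper) pair per loaded agent.
lbf_bounds = returned_episode_returns_bounds[0]

half_score = normalize_return(0.25, lbf_bounds)
full_score = normalize_return(0.5, lbf_bounds)
print(half_score, full_score)
```

Under this reading, the maximum LBF return of 0.5 maps to a normalized score of 1.0, which makes scores comparable across tasks with different return scales.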

## 2.3. Visualizing Agents

We provide functions to visualize agents at `evaluation/vis_episodes.py`.
Scripts to test the heuristic agents on LBF and Overcooked are also provided at `tests/test_lbf_agents.py` and `tests/test_overcooked_agents.py`.

In [5]:
import os
import numpy as np
from typing import Dict, Tuple

from IPython.display import HTML
import time
import jax

from envs import make_env
from agents.lbf import RandomAgent, SequentialFruitAgent

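def run_episode(env, agent0, agent1, key) -> Tuple[Dict[str, float], int, list]:
    """Run a single episode with two heuristic agents.

    Returns the per-agent total rewards, the number of steps taken,
    and the list of environment states visited during the episode.
    """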
    # Reset environment
    key, subkey = jax.random.split(key)
    obs, state = env.reset(subkey)

    # Initialize episode tracking
    done = {agent: False for agent in env.agents}
    done['__all__'] = False
    total_rewards = {agent: 0.0 for agent in env.agents}
    num_steps = 0

    # Initialize agent states
    agent0_state = agent0.init_agent_state(0)
    agent1_state = agent1.init_agent_state(1)

    # Initialize state sequence
    state_seq = []
    while not done['__all__']:
        # Get actions from both agents with their states
        key, act0_rng, act1_rng = jax.random.split(key, 3)

        action0, agent0_state = agent0.get_action(obs["agent_0"], state, agent0_state, act0_rng)
        action1, agent1_state = agent1.get_action(obs["agent_1"], state, agent1_state, act1_rng)

        actions = {"agent_0": action0, "agent_1": action1}

        # Step environment
        key, subkey = jax.random.split(key)
        obs, state, rewards, done, info = env.step(subkey, state, actions)
        state_seq.append(state)

        # Update rewards
        for agent in env.agents:
            total_rewards[agent] += rewards[agent]

        num_steps += 1

    print(f"Episode finished. Total states collected: {len(state_seq)}")
    return total_rewards, num_steps, state_seq

def main(num_episodes,
         max_steps=100,
         visualize=False,
         save_video=False):
    print("Initializing environment...")
    env = make_env(env_name="lbf", env_kwargs={"time_limit": max_steps})

    print("Initializing agents...")
    # choices: lexicographic, reverse_lexicographic, column_major, reverse_column_major, nearest_agent, farthest_agent
    agent0 = SequentialFruitAgent(grid_size=7, num_fruits=3, ordering_strategy='lexicographic') # boxed
    agent1 = SequentialFruitAgent(grid_size=7, num_fruits=3, ordering_strategy='lexicographic') # not boxed

    print("Agent 0:", agent0.get_name())
    print("Agent 1:", agent1.get_name())

    # Run multiple episodes
    key = jax.random.PRNGKey(0)

    returns = []
    state_seq_all = []
    for episode in range(num_episodes):
        print(f"\nEpisode {episode + 1}/{num_episodes}")
        key, subkey = jax.random.split(key)
        total_rewards, num_steps, ep_states = run_episode(env, agent0, agent1, subkey)
        state_seq_all.extend(ep_states)  # accumulate states across episodes for the animation
        print(f"Total states in sequence after episode: {len(state_seq_all)}")

        # Calculate episode return
        episode_return = np.mean(list(total_rewards.values()))
        returns.append(episode_return)

        print(f"\nEpisode {episode + 1} finished:")
        print(f"Total steps: {num_steps}")
        print(f"Mean episode return: {episode_return:.2f}")
        print("Episode returns per agent:")
        for agent in env.agents:
            print(f" {agent}: {total_rewards[agent]:.2f}")

    # Print statistics
    mean_return = np.mean(returns)
    std_return = np.std(returns)
    print(f"\nStatistics across {num_episodes} episodes:")
    print(f"Mean return: {mean_return:.2f} ± {std_return:.2f}")

    anim = env.animate(state_seq_all, interval=150)
    return anim

anim = main(num_episodes=5, max_steps=30)
HTML(anim.to_html5_video())

Initializing environment...
Environment initialized
Initializing agents...
Agents initialized
Agent 0: SequentialFruitAgent
Agent 1: SequentialFruitAgent

Episode 1/1
Episode finished. Total states collected: 13
Total states in sequence after episode: 13

Episode 1 finished:
Total steps: 13
Mean episode return: 0.50
Episode returns per agent:
 agent_0: 0.50
 agent_1: 0.50

Statistics across 1 episodes:
Mean return: 0.50 ± 0.00
