# Welcome to JaxAHT!

*Please open this notebook in Colab.

In this tutorial, we will focus on introducing the core workflows for using the library.
We will:
- Demonstrate how to train teammates and ego agents, separately and as part of a single, unified workflow.
- Introduce the set of evaluation teammates and the evaluation framework.
- Visualize learned policies.

The project uses [Hydra](https://hydra.cc/) to manage algorithm and environment configurations, and [WandB](https://wandb.ai/) for logging.

Although the tutorial does not explicitly describe how to run the open-ended learning algorithms, these algorithm types may be run similarly to the MARL, teammate generation, and ego agent training algorithms.

Please see the project README for a full description of the project's design philosophy.
Our benchmark uses a multi-agent PPO implementation provided by [JaxMARL](https://github.com/FLAIROx/JaxMARL/tree/main).


# Install dependencies 📚

 ⚠️ Before beginning the tutorial, ensure you select a GPU or TPU from `Runtime > Change runtime type` ⚠️

 Also make sure that you select a runtime with Python 3.11.

 If no such runtime is available, you can use the following script to install Python 3.11:

In [None]:
# !wget https://github.com/korakot/kora/releases/download/v0.11/py311.sh
# !bash ./py311.sh -b -f -p /usr/local
# !python -m ipykernel install --name "py311" --user

In [None]:
%%shell
# clone repo and install packages
git clone https://github.com/carolinewang01/jax-aht.git
# if you need authentication, one way is to use
# git clone https://<GithubId>:<ghp_xxxxGithubTokenxxxx>@github.com/carolinewang01/jax-aht.git

cd jax-aht
pip install --upgrade pip
pip install -e .
# pip install numpy==1.25.* --upgrade # forcefully downgrade numpy; necessary for colab only

Use the following cell to force-downgrade NumPy and SciPy to the pinned versions.

You will be prompted to restart the session after numpy 1.25.2 is installed.

In [None]:
%pip uninstall -y numpy
%pip install numpy==1.25.2 scipy==1.12.0
import numpy
print(numpy.__version__)

In [None]:
# Now check to make sure that we get 1.25.2
import numpy
print(numpy.__version__) # 1.25.2
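
Printing the version works for a single package. As a hedged alternative, the sketch below compares several pinned versions at once using the standard library's `importlib.metadata`. The helper name `check_pins` is made up for illustration, and the pins are the ones from the install cell above.

```python
from importlib.metadata import PackageNotFoundError, version


def check_pins(pins):
    """Return {package: (expected, installed)} for every pin that does not match.

    `installed` is None when the package is not installed at all.
    """
    mismatches = {}
    for name, expected in pins.items():
        try:
            installed = version(name)
        except PackageNotFoundError:
            installed = None
        if installed != expected:
            mismatches[name] = (expected, installed)
    return mismatches


# An empty dict means every pinned package has the expected version.
print(check_pins({"numpy": "1.25.2", "scipy": "1.12.0"}))
```

If the result is not empty after restarting the session, rerunning the install cell is the first thing to try.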

In [None]:
# change current working directory to jax-aht/ for the rest of this notebook
import os

path = os.getcwd()
if not path.endswith("jax-aht"):
    # if we are using the notebook from the repository directly
    if os.getcwd().endswith("jax-aht/docs"):
        os.chdir("..")
    # if we cloned the repository
    else:
        os.chdir("jax-aht")

print(os.getcwd()) # <path>/jax-aht

# verify that the jax installation can find the GPU/TPU
import jax
jax.devices()

## Part 1.1: Training Teammates


In JaxAHT, teammates may be trained using either MARL algorithms or teammate generation algorithms. Each algorithm type has its own entry point, located at `marl/run.py` and `teammate_generation/run.py`, respectively.
For this tutorial, we will train teammates on Level-Based Foraging (LBF) using IPPO.


**Viewing Metrics:**    We strongly recommend using WandB's UI to view the logged metrics. This requires a WandB account and setting the logger to online mode. For the purposes of this tutorial, you can see a preview of the logged metrics in the console. Note that the maximum return on LBF is 0.5.

In [None]:
%%bash
# train teammates on LBF using a MARL algorithm (IPPO w/parameter sharing)
PYTHONPATH=$(pwd) python marl/run.py task=lbf algorithm=ippo/lbf logger.mode=offline
# train using BRDiv instead. Note that we use the teammate_generation/ entry point instead of the marl/ entry point.
# PYTHONPATH=$(pwd) python teammate_generation/run.py task=lbf algorithm=brdiv/lbf logger.mode=offline

## Part 1.2: Training an Ego Agent Against Pretrained Teammates

Here, we will train a PPO ego agent to collaborate with the IPPO teammates trained in the last section, using the ego agent training entry point at `ego_agent_training/run.py`.

Please take a moment to look over the entry point code, to understand the overall pipeline.

In [None]:
! cat ego_agent_training/run.py

### 1.2.1: Setting the Partner Config

The first step to training an ego agent is to specify what teammates the ego agent should be trained with.
The ego agent training code looks for a partner config within the algorithm block of the master config file, located at `ego_agent_training/configs/base_config_ego.yaml`.



In [None]:
! cat ego_agent_training/configs/base_config_ego.yaml

For the `ppo_ego` algorithm, a default partner config is specified within the algorithm-specific config directory, at `ego_agent_training/configs/algorithm/ppo_ego/_base_.yaml`. This default partner config is automatically imported and merged into the master config. The user should set the path (and any other partner config values), either directly within the config or via the command line. We will take the latter approach in this tutorial.


In [None]:
! cat ego_agent_training/configs/algorithm/ppo_ego/_base_.yaml

In [None]:
# First, let's find the teammate checkpoint directory
import glob

checkpoint_base_dir = 'results/lbf/ippo/default_label/'
checkpoint_dir = sorted(glob.glob(os.path.join(checkpoint_base_dir, '*')))[-1]
checkpoint_dir = os.path.join(checkpoint_dir, "saved_train_run")
assert os.path.exists(checkpoint_dir), "Error: checkpoint directory not found."
print(f"Found checkpoint directory: {checkpoint_dir}")
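
The cell above takes the last entry of a sorted glob, which selects the newest run only if run directory names happen to sort chronologically (an assumption about the naming scheme). A hedged alternative is to pick the most recently modified subdirectory instead. The helper name `latest_subdir` and the run names in the demo are made up for illustration.

```python
import os
import tempfile
from pathlib import Path


def latest_subdir(base_dir):
    """Return the most recently modified subdirectory of base_dir, or None."""
    subdirs = [p for p in Path(base_dir).iterdir() if p.is_dir()]
    if not subdirs:
        return None
    return max(subdirs, key=lambda p: p.stat().st_mtime)


# Demo on a throwaway directory; the run names here are hypothetical.
with tempfile.TemporaryDirectory() as tmp:
    for name, mtime in [("run_b", 1_000), ("run_a", 2_000)]:
        run_dir = Path(tmp, name)
        run_dir.mkdir()
        os.utime(run_dir, (mtime, mtime))  # set access and modification times
    newest_name = latest_subdir(tmp).name
    nothing = latest_subdir(Path(tmp, "run_a"))  # run_a has no subdirectories

print(newest_name)  # run_a
```

In the demo, a name-based sort would have chosen `run_b`, while the modification-time rule chooses `run_a`. Which rule is appropriate depends on how the results directories are actually named.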

### 1.2.2: Train the ego agent!

Now let's train the ego agent, directly specifying the partner path as a command line argument. We reduce the training time to 1 million steps, so that the ego agent trains in a couple of minutes. We also turn off the evaluation against the heldout set, which is explained in the next section.

In [None]:
# Now let's train the ego agent, directly specifying the partner path from the command line
# total training timesteps is reduced to 1 million steps for this tutorial.
! PYTHONPATH=$(pwd) python ego_agent_training/run.py task=lbf algorithm=ppo_ego/lbf algorithm.partner_agent.path={checkpoint_dir} algorithm.TOTAL_TIMESTEPS=1e6 run_heldout_eval=false logger.mode=offline
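
The dotted `key=value` arguments in the command above are Hydra overrides: each dotted key addresses a nested entry in the composed config. The sketch below mimics that idea on a plain dictionary. It is not Hydra's implementation, it does not cover config group selection such as `task=lbf`, and the base config values and helper names (`parse_value`, `apply_overrides`) are made up for illustration.

```python
import copy


def parse_value(raw):
    """Best-effort conversion of an override string to bool, int, float, or str."""
    lowered = raw.lower()
    if lowered in ("true", "false"):
        return lowered == "true"
    for cast in (int, float):
        try:
            return cast(raw)
        except ValueError:
            continue
    return raw


def apply_overrides(config, overrides):
    """Apply 'a.b.c=value' strings to a nested dict and return a new dict."""
    result = copy.deepcopy(config)
    for item in overrides:
        dotted_key, _, raw_value = item.partition("=")
        *parents, leaf = dotted_key.split(".")
        node = result
        for part in parents:
            node = node.setdefault(part, {})
        node[leaf] = parse_value(raw_value)
    return result


# Hypothetical defaults; the real values live in the repository's YAML files.
base = {
    "algorithm": {"TOTAL_TIMESTEPS": 30000000.0, "partner_agent": {"path": None}},
    "run_heldout_eval": True,
}
cfg = apply_overrides(
    base,
    [
        "algorithm.partner_agent.path=results/example_run",  # placeholder path
        "algorithm.TOTAL_TIMESTEPS=1e6",
        "run_heldout_eval=false",
    ],
)
print(cfg["algorithm"]["TOTAL_TIMESTEPS"])  # 1000000.0
```

The same values could instead be edited directly in the YAML config; command-line overrides simply leave the files on disk unchanged.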

# 2. Evaluation

In AHT research, ego agents are often evaluated based on the returns achieved during collaboration with a *heldout set* of evaluation agents, i.e., agents that were not seen during training.
The JaxAHT benchmark provides a heldout evaluation set for LBF and the five classic Overcooked tasks.
The rliable library is used to compute bootstrapped metrics across the heldout set, by treating each heldout agent as a task.
In this section of the tutorial, we will introduce the heldout agent config, demonstrate how to download the agents, and visualize them.
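
To make the bootstrapping idea concrete, here is a minimal pure-Python sketch of a stratified percentile bootstrap over a runs × heldout-agents score matrix, where each heldout agent plays the role of a task. It does not use rliable's API and is not the benchmark's evaluation code. The function name `stratified_bootstrap_ci` and the example scores are hypothetical.

```python
import random
from statistics import mean


def stratified_bootstrap_ci(scores, num_resamples=2000, alpha=0.05, seed=0):
    """Percentile bootstrap CI for the mean return across heldout agents.

    scores[r][t] is the return of training run r when paired with heldout
    agent t. Runs are resampled with replacement separately for each agent.
    Returns (point_estimate, lower, upper).
    """
    rng = random.Random(seed)
    num_runs = len(scores)
    num_tasks = len(scores[0])
    estimates = []
    for _ in range(num_resamples):
        task_means = []
        for t in range(num_tasks):
            resampled = [scores[rng.randrange(num_runs)][t] for _ in range(num_runs)]
            task_means.append(mean(resampled))
        estimates.append(mean(task_means))
    estimates.sort()
    lo_idx = int((alpha / 2) * num_resamples)
    hi_idx = min(num_resamples - 1, int((1 - alpha / 2) * num_resamples))
    point = mean(mean(row) for row in scores)
    return point, estimates[lo_idx], estimates[hi_idx]


# Made-up returns: 3 training runs (rows) x 2 heldout agents (columns).
example_scores = [[0.10, 0.50], [0.20, 0.40], [0.30, 0.45]]
point, lower, upper = stratified_bootstrap_ci(example_scores)
print(f"mean return {point:.3f}, 95% CI [{lower:.3f}, {upper:.3f}]")
```

Resampling runs separately for each heldout agent keeps every agent represented in every bootstrap sample, so the interval reflects run-to-run variation rather than which agents happened to be drawn.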

*Note that the ego agent training entry point evaluates the ego agent against the heldout set by default. We turned this off in the last section; to re-enable it, set the `run_heldout_eval` argument to true.

## 2.1. Download the Heldout Set


The heldout evaluation set consists of both RL-trained agents and manually programmed heuristic agents. The heuristic agents are located under the `agents/lbf/` and `agents/overcooked/` directories, while the following code will download the RL agent checkpoints to an `eval_teammates/` directory.



In [None]:
# first, let's download the RL agents in the heldout set
! python download_eval_data.py

## 2.2. Heldout Agent Config

The evaluation teammate set is specified in the `evaluation/configs/global_heldout_settings.yaml` file, which other config files throughout the codebase import as needed to perform the heldout evaluation.

Notice that the teammate config format looks similar to the format used to specify the ego agent training teammates! This is because all teammate configs are parsed by functions within the `common/agent_loader_from_config.py` file. Please see the README for more details about the teammate config.

In [None]:
! cat evaluation/configs/global_heldout_settings.yaml

## 2.3. Visualizing Agents

We provide functions to visualize agents at `evaluation/vis_episodes.py`.
Scripts to test the heuristic agents on LBF and Overcooked are also provided at `tests/test_lbf_agents.py` and `tests/test_overcooked_agents.py`.

In [None]:
import os
import numpy as np
from typing import Dict, Tuple

from IPython.display import HTML
import time
import jax

from envs import make_env
from agents.lbf import RandomAgent, SequentialFruitAgent

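def run_episode(env, agent0, agent1, key) -> Tuple[Dict[str, float], int, list]:
    """Run a single episode with two heuristic agents.

    Returns the per-agent total rewards, the number of steps taken,
    and the list of environment states visited (used for animation).
    """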
    # Reset environment
    key, subkey = jax.random.split(key)
    obs, state = env.reset(subkey)

    # Initialize episode tracking
    done = {agent: False for agent in env.agents}
    done['__all__'] = False
    total_rewards = {agent: 0.0 for agent in env.agents}
    num_steps = 0

    # Initialize agent states
    agent0_state = agent0.init_agent_state(0)
    agent1_state = agent1.init_agent_state(1)

    # Initialize state sequence
    state_seq = []
    while not done['__all__']:
        # Get actions from both agents with their states
        key, act0_rng, act1_rng = jax.random.split(key, 3)

        action0, agent0_state = agent0.get_action(obs["agent_0"], state, agent0_state, act0_rng)
        action1, agent1_state = agent1.get_action(obs["agent_1"], state, agent1_state, act1_rng)

        actions = {"agent_0": action0, "agent_1": action1}

        # Step environment
        key, subkey = jax.random.split(key)
        obs, state, rewards, done, info = env.step(subkey, state, actions)
        state_seq.append(state)

        # Update rewards
        for agent in env.agents:
            total_rewards[agent] += rewards[agent]

        num_steps += 1

    print(f"Episode finished. Total states collected: {len(state_seq)}")
    return total_rewards, num_steps, state_seq

def main(num_episodes,
         max_steps=100,
         visualize=False,
         save_video=False):
    print("Initializing environment...")
    env = make_env(env_name="lbf", env_kwargs={"time_limit": max_steps})

    print("Initializing agents...")
    # choices: lexicographic, reverse_lexicographic, column_major, reverse_column_major, nearest_agent, farthest_agent
    agent0 = SequentialFruitAgent(grid_size=7, num_fruits=3, ordering_strategy='lexicographic') # boxed
    agent1 = SequentialFruitAgent(grid_size=7, num_fruits=3, ordering_strategy='lexicographic') # not boxed

    print("Agent 0:", agent0.get_name())
    print("Agent 1:", agent1.get_name())

    # Run multiple episodes
    key = jax.random.PRNGKey(0)

    returns = []
    state_seq_all = []
    for episode in range(num_episodes):
        print(f"\nEpisode {episode + 1}/{num_episodes}")
        key, subkey = jax.random.split(key)
        total_rewards, num_steps, ep_states = run_episode(env, agent0, agent1, subkey)
        state_seq_all.extend(ep_states)  # accumulate states across episodes for the animation
        print(f"Total states in sequence after episode: {len(state_seq_all)}")

        # Calculate episode return
        episode_return = np.mean(list(total_rewards.values()))
        returns.append(episode_return)

        print(f"\nEpisode {episode + 1} finished:")
        print(f"Total steps: {num_steps}")
        print(f"Mean episode return: {episode_return:.2f}")
        print("Episode returns per agent:")
        for agent in env.agents:
            print(f" {agent}: {total_rewards[agent]:.2f}")

    # Print statistics
    mean_return = np.mean(returns)
    std_return = np.std(returns)
    print(f"\nStatistics across {num_episodes} episodes:")
    print(f"Mean return: {mean_return:.2f} ± {std_return:.2f}")

    anim = env.animate(state_seq_all, interval=150)
    return anim

anim = main(num_episodes=5, max_steps=30)
HTML(anim.to_html5_video())