# Example using IQL and Flashbax in Matrax

This guide demonstrates how to use the Item Buffer for experience replay in reinforcement learning tasks. Specifically, we implement independent Q-learning ([original IQL paper](https://arxiv.org/pdf/1511.08779.pdf)), using a single Q-network shared by both agents, to solve two-player matrix games hosted in [Matrax](https://github.com/instadeepai/matrax?tab=readme-ov-file). The Item Buffer stores all experience data in a first-in-first-out (FIFO) queue and returns batches of uniformly sampled experience from it. It is intended as a simplified buffer in which each stored item is completely independent of the others, for example a full state-action-reward-next-state transition for Q-learning. Other buffers generally maintain a temporal relation between their data, which restricts their achievable size depending on whether data is added in batches or not. Because the Item Buffer has no such temporal relation, any desired buffer size is achievable regardless of whether data is added in batches, in sequences, or in batches of sequences.
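
To make the FIFO-plus-uniform-sampling behaviour concrete, here is a minimal pure-Python sketch. It is a conceptual illustration only and not how Flashbax implements the Item Buffer: the class `ToyItemBuffer` and its methods are hypothetical, and sampling with replacement is a choice made for this sketch.

```python
import random
from collections import deque


class ToyItemBuffer:
    """Conceptual stand-in for an item buffer (hypothetical, not the Flashbax API)."""

    def __init__(self, max_length, sample_batch_size, seed=0):
        # A deque with maxlen drops the oldest item once it is full (FIFO eviction)
        self.items = deque(maxlen=max_length)
        self.sample_batch_size = sample_batch_size
        self.rng = random.Random(seed)

    def add(self, batch_of_items):
        # Items are independent, so a batch is simply appended item by item
        self.items.extend(batch_of_items)

    def sample(self):
        # Uniform sampling (with replacement) over whatever is currently stored
        return [self.rng.choice(self.items) for _ in range(self.sample_batch_size)]


buffer = ToyItemBuffer(max_length=5, sample_batch_size=3)
buffer.add([{"obs": t, "reward": float(t)} for t in range(7)])  # 7 items into 5 slots

stored = [item["obs"] for item in buffer.items]
batch = buffer.sample()
print(stored)  # [2, 3, 4, 5, 6]: the two oldest items were evicted
print(len(batch))  # 3
```

Because no item depends on its neighbours, the capacity can be any number of items, which is the property the paragraph above contrasts with temporally ordered buffers.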

Before the example, we briefly describe the Matrax environment and how its games work:

### The Matrax Environment

*Two-player matrix games in JAX*

A matrix game is a two-player game where each player has a set of actions and a payoff matrix. Each payoff matrix is a two-dimensional array whose rows correspond to the actions of Player 1 and whose columns correspond to the actions of Player 2. The entry at row $i$ and column $j$ of Player 1's matrix is the reward Player 1 receives when Player 1 plays action $i$ and Player 2 plays action $j$. Similarly, the entry at row $i$ and column $j$ of Player 2's matrix is the reward Player 2 receives for that same joint action.
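
The indexing convention can be sketched in a few lines of plain Python. The two matrices below are hypothetical values chosen only to illustrate the lookup; they are not one of the registered Matrax games, and `joint_rewards` is an illustrative helper rather than part of the Matrax API.

```python
# Hypothetical 2x2 payoff matrices, chosen only to illustrate the indexing
payoff_player_1 = [[3, 0],
                   [5, 1]]
payoff_player_2 = [[3, 5],
                   [0, 1]]


def joint_rewards(action_1, action_2):
    """Return (reward for Player 1, reward for Player 2) for a joint action."""
    # Rows are indexed by Player 1's action, columns by Player 2's action
    return (payoff_player_1[action_1][action_2],
            payoff_player_2[action_1][action_2])


rewards = joint_rewards(1, 0)  # Player 1 plays action 1, Player 2 plays action 0
print(rewards)  # (5, 0)
```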

Maximum number of steps defaults to 500.

#### 🔻 Penalty Game
- **Shape (action space):** 3 $\times$ 3
- **Registered versions:** Penalty-{k}-{state}-v0
- **Valid arguments:** $k \in \{0, 25, 50, 75, 100\}$
- **Payoff matrix (for each agent):**
\begin{equation*}
\begin{bmatrix} 
-k & 0 & 10 \\
0 & 2 & 0 \\
10 & 0 & -k \\
\end{bmatrix}
\end{equation*}

#### 🧗‍♀️ Climbing Game
- **Shape (action space):** 3 $\times$ 3
- **Registered versions:** Climbing-{state}-v0
- **Payoff matrix (for each agent):**
\begin{equation*}
\begin{bmatrix} 
11 & -30 & 0 \\
-30 & 7 & 0 \\
0 & 6 & 5 \\
\end{bmatrix}
\end{equation*}

#### 🤝 No-Conflict Games
- **Shape (action space):** 2 $\times$ 2
- **Registered versions:** NoConflict-{id}-{state}-v0
- **Valid arguments:** $\texttt{id} \in \{0, 1, 2, ..., 20\}$
- **Payoff matrix:** controlled by $\texttt{id}$ as dictated [here](https://github.com/instadeepai/matrax/blob/main/matrax/games/no_conflict.py).

#### 💣 Conflict Games
- **Shape (action space):** 2 $\times$ 2
- **Registered versions:** Conflict-{id}-{state}-v0
- **Valid arguments:** $\texttt{id} \in \{0, 1, 2, ..., 56\}$
- **Payoff matrix:** controlled by $\texttt{id}$ as dictated [here](https://github.com/instadeepai/matrax/blob/main/matrax/games/conflict.py).

## Prerequisites

In [1]:
%%capture
!pip install flashbax matrax jumanji

In [3]:
import jax
import matrax
from jumanji.wrappers import AutoResetWrapper
import collections
import jax.numpy as jnp
import optax
import flax.linen as nn
import flashbax as fbx

For this example, we have chosen a No-Conflict game.

## Environment

In [4]:
# Instantiate a matrix game environment
ENV_NAME = "NoConflict-0-stateless-v0"

# Enable auto-resetting of the training environment
env = AutoResetWrapper(matrax.make(ENV_NAME))

# Leave evaluation environment without auto-resetting
eval_env = matrax.make(ENV_NAME)

NUM_ACTIONS = env.num_actions
NUM_AGENTS = env.num_agents
NUM_OBS = NUM_AGENTS  # in matrax, observations have shape (num_agents, num_agents)

This specific matrix game has 2 agents and 2 actions, so each agent's payoff matrix is 2x2 in size.

In [5]:
print("Number of agents:", env.num_agents)
print("Number of actions:", env.num_actions)

Number of agents: 2
Number of actions: 2


The following are the agents' payoff matrices:

In [6]:
print("Payoff Matrices:\n")
for agent in range(NUM_AGENTS):
    print(f"Agent {agent + 1} Payoff Matrix:")
    for row in env.payoff_matrix[agent]:
        print(f"[{row[0]:3.0f} {row[1]:3.0f}{' ' * 2}]")
    print("\n")

Payoff Matrices:

Agent 1 Payoff Matrix:
[  4   3  ]
[  2   1  ]


Agent 2 Payoff Matrix:
[  4   3  ]
[  2   1  ]




## Setup IQL Algorithm

### Q-Network

In [7]:
class Flatten(nn.Module):
    @nn.compact
    def __call__(self, x):
        batch_size = x.shape[0]  # assuming batch size is leading dim

        return x.reshape(batch_size, -1)


class QNetwork(nn.Module):
    num_actions: int

    @nn.compact
    def __call__(self, state):
        return nn.Sequential(
            [
                Flatten(),
                nn.Dense(20),
                nn.relu,
                nn.Dense(20),
                nn.relu,
                nn.Dense(self.num_actions),
            ]
        )(state)

In [8]:
SEED = 42

# Instantiate Q-network
q_network = QNetwork(NUM_ACTIONS)

# Create a single dummy observation (i.e., batch size is 1)
# We add num_agents to num_obs to account for the addition of one-hot-encoded agent IDs
dummy_obs = jnp.zeros((1, NUM_OBS + NUM_AGENTS), jnp.float32)

# Generate random key for initialising params
key = jax.random.PRNGKey(SEED)
key, subkey = jax.random.split(key)

# Initialise Q-network
q_network_params = q_network.init(subkey, dummy_obs)

In [9]:
# Store online and target parameters
QLearnParams = collections.namedtuple("Params", ["online", "target"])

# Q-learn-state
QLearnState = collections.namedtuple("LearnerState", ["count", "optim_state"])

### Action Selection

In [10]:
def select_greedy_action(q_values):
    """A function to select the action corresponding to the largest inputted Q-value for each agent."""

    action = jnp.argmax(q_values)

    return action

In [11]:
def select_random_action(subkey):
    """A function to select an action randomly for each agent."""

    action = jax.random.randint(subkey, shape=(), minval=0, maxval=NUM_ACTIONS)

    return action

In [12]:
EPSILON_DECAY_TIMESTEPS = 50_000  # decay epsilon over 50,000 timesteps
EPSILON_MIN = 0.1  # 10% exploration

In [13]:
def get_epsilon(num_timesteps):
    """A function to retrieve the value of epsilon, based on a (linear) decay rate and a minimum possible value."""

    epsilon = 1.0 - num_timesteps / EPSILON_DECAY_TIMESTEPS

    epsilon = jax.lax.select(epsilon < EPSILON_MIN, EPSILON_MIN, epsilon)

    return epsilon
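
As a quick check of the schedule, the following plain-Python mirror of `get_epsilon` evaluates it at a few timesteps. It uses `max` instead of `jax.lax.select` purely for illustration; `epsilon_at` is a hypothetical helper and not part of the training code.

```python
EPSILON_DECAY_TIMESTEPS = 50_000
EPSILON_MIN = 0.1


def epsilon_at(num_timesteps):
    """Plain-Python mirror of the linear decay with a lower bound."""
    return max(EPSILON_MIN, 1.0 - num_timesteps / EPSILON_DECAY_TIMESTEPS)


schedule = {ts: epsilon_at(ts) for ts in (0, 1_000, 25_000, 45_000, 60_000)}
for ts, eps in schedule.items():
    print(f"{ts:>6}: {eps:.2f}")
# Prints 1.00, 0.98, 0.50, 0.10 and 0.10 for the five timesteps
```

With `NUM_TIMESTEPS = 1001`, as used later in this guide, epsilon only decays to 0.98, so training-time behaviour stays almost fully random while evaluation is greedy.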

In [14]:
def select_epsilon_greedy_action(key, q_values, num_timesteps):
    """A function to perform epsilon-greedy action selection."""

    epsilon = get_epsilon(num_timesteps)

    key, subkey = jax.random.split(key)

    should_explore = jax.random.uniform(subkey, (1,))[0] < epsilon

    key, subkey = jax.random.split(key)

    action = jax.lax.select(
        should_explore, select_random_action(subkey), select_greedy_action(q_values)
    )

    return action

In [15]:
def q_learning_select_action(key, params, num_timesteps, obs, evaluation=False):
    """A function to perform greedy or epsilon-greedy action selection,
    based on whether we are evaluating or training a policy.
    """

    obs = jnp.expand_dims(obs, axis=0)  # add a batch dim as the leading axis

    q_values = q_network.apply(params.online, obs)[0]  # remove batch dim

    action = select_epsilon_greedy_action(key, q_values, num_timesteps)
    greedy_action = select_greedy_action(q_values)

    action = jax.lax.select(evaluation, greedy_action, action)

    return action

### Create the Item Buffer

Here we create the item buffer and instantiate it with a full environment transition consisting of all agent observations, actions, rewards, next state observations and a global done flag. The buffer is used to sample batches of these transitions to perform Q-learning for each agent.

In [16]:
# Instantiate the replay buffer
BATCH_SIZE = 64
q_learning_memory = fbx.make_item_buffer(
    max_length=50_000, min_length=64, sample_batch_size=BATCH_SIZE, add_batches=True
)

# Make a dummy observation and initialise the replay buffer
transition = {
    "obs": jnp.zeros(
        (NUM_AGENTS, NUM_AGENTS), dtype="float32"
    ),  # second dim won't always be num_agents
    "action": jnp.zeros(NUM_AGENTS, dtype="int32"),
    "reward": jnp.zeros(NUM_AGENTS, dtype="float32"),
    "next_obs": jnp.zeros((NUM_AGENTS, NUM_AGENTS), dtype="float32"),
    "done": 0.0,
}  # store in dictionary

q_learning_memory_state = q_learning_memory.init(transition)

### Loss Function

In [17]:
def compute_squared_error(pred, target):
    """A function to compute the squared error between a prediction and a target value."""

    squared_error = jax.numpy.square(pred - target)

    return squared_error

In [18]:
def compute_bellman_target(reward, done, next_q_values, gamma=0.99):
    """A function to compute the bellman target."""

    bellman_target = reward + gamma * (1.0 - done) * jax.numpy.max(next_q_values)

    return bellman_target

In [19]:
def q_learning_loss(q_values, action, reward, done, next_q_values):
    """Implementation of the Q-learning loss."""

    chosen_action_q_value = q_values[action]
    bellman_target = compute_bellman_target(reward, done, next_q_values)
    squared_error = compute_squared_error(chosen_action_q_value, bellman_target)

    return squared_error
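
Written out for a single transition $(o, a, r, o', d)$ of one agent, the three functions above compute the quantities below. The symbols are notation introduced here for illustration: $\theta$ denotes the online parameters, $\theta^-$ the target parameters (which, in this guide, produce the next-state Q-values), $d$ the done flag and $\gamma = 0.99$ the discount.

\begin{equation*}
y = r + \gamma \, (1 - d) \, \max_{a'} Q_{\theta^-}(o', a'), \qquad
\ell(\theta) = \left( Q_{\theta}(o, a) - y \right)^2
\end{equation*}

Since `jax.grad` differentiates with respect to its first argument by default, only $\theta$ receives gradients when the online parameters are passed first; $y$ acts as a fixed regression target.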

In [20]:
def compute_agent_mse(
    online_params, target_params, encoded_obs, actions, rewards, encoded_next_obs, dones
):
    """A function to compute a single agent's mean-squared error."""

    q_values = q_network.apply(online_params, encoded_obs)  # use the online parameters
    next_q_values = q_network.apply(
        target_params, encoded_next_obs
    )  # use the target parameters

    # vmap the loss calculation over the batch
    q_learning_loss_vmap = jax.vmap(q_learning_loss, in_axes=(0, 0, 0, 0, 0))
    squared_error = q_learning_loss_vmap(
        q_values, actions, rewards, dones, next_q_values
    )

    # Take the mean of the batch losses
    mean_squared_error = jnp.mean(squared_error)

    return mean_squared_error

In [21]:
def batched_q_learning_loss(
    online_params, target_params, obs, actions, rewards, next_obs, dones
):
    """A function to compute the current and next Q-values and the squared loss over a batch, for both agents."""

    # Add one-hot encoding with agent IDs for each batch
    agent_ids = jnp.repeat(
        jnp.expand_dims(jnp.identity(NUM_AGENTS), axis=0), repeats=BATCH_SIZE, axis=0
    )
    encoded_obs = jnp.concatenate((obs, agent_ids), axis=2, dtype="float32")
    encoded_next_obs = jnp.concatenate((next_obs, agent_ids), axis=2, dtype="float32")

    # vmap the loss computation over both agents
    compute_agent_mse_vmap = jax.vmap(
        compute_agent_mse, in_axes=(None, None, 1, 1, 1, 1, None)
    )
    agent_mean_squared_errors = compute_agent_mse_vmap(
        online_params,
        target_params,
        encoded_obs,
        actions,
        rewards,
        encoded_next_obs,
        dones,
    )

    # Take the mean over all agent MSEs as the loss value
    loss_value = jnp.mean(agent_mean_squared_errors)

    return loss_value  # returns a single value
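
The shapes involved in the agent-ID encoding can be checked in isolation. The sketch below assumes a small hypothetical batch of 4 and uses `jnp.broadcast_to` in place of `jnp.repeat`, which yields the same values here.

```python
import jax.numpy as jnp

NUM_AGENTS = 2
BATCH = 4  # hypothetical small batch, for illustration only

# Observations: one row per agent, shape (batch, num_agents, obs_dim)
obs = jnp.zeros((BATCH, NUM_AGENTS, NUM_AGENTS), dtype=jnp.float32)

# One identity matrix per batch element: row i is the one-hot ID of agent i
agent_ids = jnp.broadcast_to(
    jnp.identity(NUM_AGENTS), (BATCH, NUM_AGENTS, NUM_AGENTS)
)

# Append each agent's one-hot ID to that agent's observation
encoded_obs = jnp.concatenate((obs, agent_ids), axis=2)

print(encoded_obs.shape)  # (4, 2, 4)
print(encoded_obs[0])  # agent 0 -> [0, 0, 1, 0], agent 1 -> [0, 0, 0, 1]

# Slicing along axis 1 gives the per-agent batch that vmap with in_axes=1 maps over
agent_1_batch = encoded_obs[:, 1, :]
print(agent_1_batch.shape)  # (4, 4)
```

The one-hot suffix is what lets a single set of network parameters produce different Q-values for each agent.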

### Set up the Optimiser

In [22]:
LEARNING_RATE = 3e-4

# Initialise Q-network optimiser
OPTIMISER = optax.adam(learning_rate=LEARNING_RATE)

Q_LEARN_OPTIM_STATE = OPTIMISER.init(q_network_params)  # initial optim state

# Create Learn State
q_learning_learn_state = QLearnState(
    0, Q_LEARN_OPTIM_STATE
)  # count set to zero initially

# Add initial Q-network weights to QLearnParams object
q_learning_params = QLearnParams(
    online=q_network_params, target=q_network_params
)  # target equal to online

### Parameter Updates

In [23]:
UPDATE_TARGET_PERIOD = 100  # how often to update the target network with the current online network parameter values

In [24]:
def update_target_params(learn_state, online_weights, target_weights):
    """A function to update target params periodically."""

    target = jax.lax.cond(
        jax.numpy.mod(learn_state.count, UPDATE_TARGET_PERIOD) == 0,
        lambda x, y: x,
        lambda x, y: y,
        online_weights,
        target_weights,
    )

    params = QLearnParams(online_weights, target)

    return params

In [25]:
def q_learn(params, learner_state, batch):
    """A function to perform Q-learning updates to the main network parameters, and maybe to the target network parameters."""

    # Compute gradients
    grad_loss = jax.grad(batched_q_learning_loss)(
        params.online,
        params.target,
        batch.experience["obs"].astype("float32"),
        batch.experience["action"].astype("int32"),
        batch.experience["reward"].astype("float32"),
        batch.experience["next_obs"].astype("float32"),
        batch.experience["done"].astype("float32"),
    )

    # Get updates
    updates, opt_state = OPTIMISER.update(grad_loss, learner_state.optim_state)

    # Apply them
    new_weights = optax.apply_updates(params.online, updates)

    # Maybe update target network
    params = update_target_params(learner_state, new_weights, params.target)

    # Increment learner step counter
    learner_state = QLearnState(learner_state.count + 1, opt_state)

    return params, learner_state

### Set up the Training Loop

In [26]:
# Set the remaining training hyperparameters
NUM_TIMESTEPS = 1001  # the total number of timesteps to take during training
TRAINING_PERIOD = 4  # how often to train
EVALUATION_EPISODES = 32  # how many evaluation episodes to run
# Set NUM_ENVIRONMENTS = BATCH_SIZE to allow buffer to fill to the batch size on the first timestep
NUM_ENVIRONMENTS = BATCH_SIZE  # the number of training environments to step through in parallel for experience accumulation
EVALUATOR_PERIOD = 100  # how often to run evaluation

In [27]:
# Calculate the scanning values for the training loop
inner = jnp.array(TRAINING_PERIOD, dtype="int32")
middle = jnp.array(EVALUATOR_PERIOD / inner, dtype="int32")
outer = jnp.array((NUM_TIMESTEPS - 1) / (inner * middle), dtype="int32")
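
The arithmetic behind these three scan lengths is worth spelling out. The sketch below repeats it in plain Python with integer division, which matches the truncating `int32` cast above for these particular values; the derived counts follow from the hyperparameters defined in this guide.

```python
NUM_TIMESTEPS = 1001
TRAINING_PERIOD = 4
EVALUATOR_PERIOD = 100
UPDATE_TARGET_PERIOD = 100
NUM_ENVIRONMENTS = 64

inner = TRAINING_PERIOD  # environment steps per learner update
middle = EVALUATOR_PERIOD // inner  # learner updates per evaluation
outer = (NUM_TIMESTEPS - 1) // (inner * middle)  # evaluations inside the scan

env_steps = outer * middle * inner  # timesteps taken by each environment
learner_updates = outer * middle
evaluations = outer + 1  # plus one final evaluation after the scan
transitions_added = env_steps * NUM_ENVIRONMENTS

# The target network is synced when the learner count is a multiple of the period
target_syncs = [c for c in range(learner_updates) if c % UPDATE_TARGET_PERIOD == 0]

print(inner, middle, outer)  # 4 25 10
print(env_steps, learner_updates, evaluations)  # 1000 250 11
print(transitions_added, target_syncs)  # 64000 [0, 100, 200]
```

The 11 evaluations correspond to the 11 rows printed at the end of training.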

In [28]:
# vmap the environments
env_step_vmap = jax.vmap(env.step, in_axes=(0, 0))
env_reset_vmap = jax.vmap(env.reset, in_axes=0)
eval_env_step_vmap = jax.vmap(eval_env.step, in_axes=(0, 0))
eval_env_reset_vmap = jax.vmap(eval_env.reset, in_axes=0)

# vmap action selection across all agents and all environments
select_all_agent_actions = jax.vmap(
    q_learning_select_action, in_axes=(0, None, None, 0, None)
)
select_all_env_actions = jax.vmap(
    select_all_agent_actions, in_axes=(0, None, None, 0, None)
)

In [29]:
def action_step_store(obs, env_state, agent_memory_state, key, agent_params, ts):
    """A function to select an action, take a step and store the observation during agent training."""

    # Add a one-hot encoding with the agent IDs for each environment
    agent_ids = jnp.repeat(
        jnp.expand_dims(jnp.identity(NUM_AGENTS), axis=0),
        repeats=NUM_ENVIRONMENTS,
        axis=0,
    )
    encoded_obs = jnp.concatenate((obs, agent_ids), axis=2, dtype="float32")

    # Select actions for each agent
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, num=(NUM_ENVIRONMENTS, NUM_AGENTS))
    action = select_all_env_actions(subkeys, agent_params, ts, encoded_obs, False)

    # Step the environment(s)
    next_env_state, timestep = env_step_vmap(env_state, action)
    next_obs = timestep.observation.agent_obs

    # Add observations for each environment to the replay buffer
    transition = {
        "obs": obs * 1.0,
        "action": action,
        "reward": timestep.reward * 1.0,
        "next_obs": next_obs * 1.0,
        "done": 1 - timestep.discount,  # if terminated, done = True
    }
    # Add to the buffer
    agent_memory_state = q_learning_memory.add(agent_memory_state, transition)

    # Update obs and env state before next step
    obs = next_obs
    env_state = next_env_state

    # Increment the timestep
    ts = ts + 1

    return obs, env_state, agent_memory_state, key, ts


def action_step_store_scan(obs_env_mem_key_ts, _):
    """A scan-compatible version of action_step_store."""

    # Unpack the initial state
    obs, env_state, agent_memory_state, key, agent_params, ts = obs_env_mem_key_ts

    # Perform action, step, and store
    obs, env_state, agent_memory_state, key, ts = action_step_store(
        obs, env_state, agent_memory_state, key, agent_params, ts
    )

    # Re-pack the updated state
    obs_env_mem_key_ts = (obs, env_state, agent_memory_state, key, agent_params, ts)

    return obs_env_mem_key_ts, None

In [30]:
def learn(agent_params, agent_learner_state, key, agent_memory_state):
    """A function to perform a learning step during agent training."""

    # Generate a new key
    key, subkey = jax.random.split(key)
    # First sample memory and then pass the result to the learn function
    batch = q_learning_memory.sample(agent_memory_state, subkey)
    agent_params, agent_learner_state = q_learn(
        agent_params, agent_learner_state, batch
    )

    return agent_params, agent_learner_state, key


def learn_scan(learn_state, _):
    """A nested, scan-compatible version of learn and action_step_store."""

    # Unpack the initial state for learn
    (
        agent_params,
        agent_learner_state,
        key,
        agent_memory_state,
        obs,
        env_state,
        ts,
    ) = learn_state

    # Perform parameter updates
    agent_params, agent_learner_state, key = learn(
        agent_params, agent_learner_state, key, agent_memory_state
    )

    # Define the initial state for action_step_store_scan
    action_step_store_state = (
        obs,
        env_state,
        agent_memory_state,
        key,
        agent_params,
        ts,
    )
    # Perform scan for taking an action, stepping and storing in the buffer
    action_step_store_state, _ = jax.lax.scan(
        action_step_store_scan, action_step_store_state, xs=None, length=inner
    )
    # Unpack arguments from updated action_step_store_scan state
    obs, env_state, agent_memory_state, key, agent_params, ts = action_step_store_state

    # Repack the updated state for learn_scan
    learn_state = (
        agent_params,
        agent_learner_state,
        key,
        agent_memory_state,
        obs,
        env_state,
        ts,
    )

    return learn_state, None
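
The nesting pattern used here (a scan whose body runs another scan) can be hard to read with the full training state. The stripped-down sketch below keeps only a step counter as the carry; the function names are hypothetical and the bodies stand in for the real act, learn and evaluate logic.

```python
import jax
import jax.numpy as jnp

INNER, MIDDLE, OUTER = 4, 25, 10  # the scan lengths computed earlier in this guide


def step(carry, _):
    # Innermost body: stands in for "act, step the environment, store"
    return carry + 1, None


def learn_then_step(carry, _):
    # Middle body: a learner update would go here, followed by INNER steps
    carry, _ = jax.lax.scan(step, carry, xs=None, length=INNER)
    return carry, None


def evaluate_then_learn(carry, _):
    # Outer body: emit a per-iteration output, then run MIDDLE middle iterations
    steps_before = carry
    carry, _ = jax.lax.scan(learn_then_step, carry, xs=None, length=MIDDLE)
    return carry, steps_before


@jax.jit
def run():
    return jax.lax.scan(evaluate_then_learn, jnp.int32(0), xs=None, length=OUTER)


total_steps, eval_points = run()
print(int(total_steps))  # 1000
print(eval_points)  # step counts 0, 100, ..., 900 at which evaluation would run
```

Each scan returns its final carry plus the stacked per-iteration outputs, which is how the outer scan collects one evaluation result per iteration.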

In [31]:
def evaluate(ts, key, agent_params):
    """A function to perform evaluation on the learned policy during training."""

    # Since an episode is one step in this environment, we can evaluate the policy using a single step!

    # Reset the environments
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, num=EVALUATION_EPISODES)
    eval_env_state, eval_timestep = eval_env_reset_vmap(subkeys)
    eval_obs = eval_timestep.observation.agent_obs

    # Add a one-hot encoding with the agent IDs for the NN
    agent_ids = jnp.repeat(
        jnp.expand_dims(jnp.identity(NUM_AGENTS), axis=0),
        repeats=EVALUATION_EPISODES,
        axis=0,
    )
    encoded_eval_obs = jnp.concatenate((eval_obs, agent_ids), axis=2, dtype="float32")

    # Select actions for each agent
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, num=(EVALUATION_EPISODES, NUM_AGENTS))
    eval_action = select_all_env_actions(
        subkeys, agent_params, ts, encoded_eval_obs, True
    )

    # Step the environment(s)
    eval_env_state, eval_timestep = eval_env_step_vmap(eval_env_state, eval_action)
    eval_obs = eval_timestep.observation.agent_obs

    # Record the average return over all (single step) evaluation episodes
    evaluator_return = jnp.mean(eval_timestep.reward * 1.0, axis=0)

    return evaluator_return, ts, key


def evaluate_scan(evaluate_state, _):
    """A nested, scan-compatible version of evaluate, learn and action_step_store."""

    # Unpack the initial state for evaluate
    (
        evaluator_return,
        ts,
        key,
        agent_params,
        agent_learner_state,
        agent_memory_state,
        obs,
        env_state,
        ts,
    ) = evaluate_state

    # Evaluate the learned policy
    evaluator_return, ts, key = evaluate(ts, key, agent_params)

    # Define the shared initial state for learn_scan
    learn_state = (
        agent_params,
        agent_learner_state,
        key,
        agent_memory_state,
        obs,
        env_state,
        ts,
    )
    # Perform scan for learning and taking an action, stepping and storing in the buffer
    learn_state, _ = jax.lax.scan(learn_scan, learn_state, xs=None, length=middle)
    # Unpack arguments from updated learn_scan state
    (
        agent_params,
        agent_learner_state,
        key,
        agent_memory_state,
        obs,
        env_state,
        ts,
    ) = learn_state

    # Repack the updated state for evaluate_scan
    evaluate_state = (
        evaluator_return,
        ts,
        key,
        agent_params,
        agent_learner_state,
        agent_memory_state,
        obs,
        env_state,
        ts,
    )

    return evaluate_state, evaluator_return

In [32]:
def run_training_loop(agent_params, agent_learner_state, agent_memory_state, seed=42):
    """
    This function runs several episodes in an environment and periodically
    does some agent learning and evaluation.

    Args:
        agent_params: an object to store parameters that the agent uses.
        agent_learner_state: an object that stores the internal state
            of the agent learn function.
        agent_memory_state: an object that stores the internal state of the
            agent memory.
        seed: PRNG seed for reproducibility.

    Returns:
        evaluator_episode_returns: list of all the evaluator episode returns
        for each agent.
    """

    # JAX random number generator
    key = jax.random.PRNGKey(seed)

    # Reset the environment(s)
    key, subkey = jax.random.split(key)
    subkeys = jax.random.split(subkey, num=NUM_ENVIRONMENTS)
    env_state, timestep = env_reset_vmap(subkeys)
    obs = timestep.observation.agent_obs

    # Set the timestep counter to zero
    ts = 0

    # Initialise the average return variable for evaluation in evaluate_scan
    evaluator_return = jnp.array([0.0, 0.0])

    # Pack arguments into a state for scanning over evaluate
    evaluate_state = (
        evaluator_return,
        ts,
        key,
        agent_params,
        agent_learner_state,
        agent_memory_state,
        obs,
        env_state,
        ts,
    )
    # Perform scan over evaluate, learn and action_step_store
    evaluate_state, evaluator_returns = jax.lax.scan(
        evaluate_scan, evaluate_state, xs=None, length=outer
    )
    # Unpack arguments from the updated state
    (
        evaluator_return,
        ts,
        key,
        agent_params,
        agent_learner_state,
        agent_memory_state,
        obs,
        env_state,
        ts,
    ) = evaluate_state

    # Evaluate the final learned policy at the end of training
    final_evaluator_return, ts, key = evaluate(ts, key, agent_params)

    evaluator_returns = jnp.append(
        evaluator_returns, jnp.expand_dims(final_evaluator_return, axis=0), axis=0
    )

    return evaluator_returns

In [33]:
run_training_loop_jitted = jax.jit(run_training_loop)

## Train the agent and evaluate

Below, we print the results for the evaluation rounds for each of the two agents as training progresses.

In [34]:
# Run environment loop
evaluator_returns = run_training_loop_jitted(
    q_learning_params, q_learning_learn_state, q_learning_memory_state, seed=SEED
)

# Generate timesteps from 0 to NUM_TIMESTEPS with a step of EVALUATOR_PERIOD
timesteps = jnp.arange(0, NUM_TIMESTEPS, EVALUATOR_PERIOD)

# Create logs for each timestep and its corresponding average return
logs = [
    f"Timestep: {str(timestep).ljust(6, ' ')}\tAgent 1 Return: {avg_return[0]:.2f}\tAgent 2 Return: {avg_return[1]:.2f}"
    for timestep, avg_return in zip(timesteps, evaluator_returns)
]

# Print the logs
print(*logs, sep="\n")

Timestep: 0     	Agent 1 Return: 2.00	Agent 2 Return: 2.00
Timestep: 100   	Agent 1 Return: 2.00	Agent 2 Return: 2.00
Timestep: 200   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 300   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 400   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 500   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 600   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 700   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 800   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 900   	Agent 1 Return: 4.00	Agent 2 Return: 4.00
Timestep: 1000  	Agent 1 Return: 4.00	Agent 2 Return: 4.00


We see that both agents have converged to the optimal return of 4, the largest entry in their payoff matrices.

### A Note on IQL Performance in Matrix Games

In some Matrax environments, IQL often falls into local optima. One example is the Climbing Game. Agents are expected to cooperatively reach the highest reward of 11 by taking a joint action that risks a large punishment if either agent deviates from it. However, each agent acts independently to optimise its own return, so each is reluctant to keep taking an action that looks suboptimal from its own perspective, even though it may yield a higher joint reward.
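
A small calculation illustrates the effect. Using the Climbing Game payoff matrix listed earlier in this guide, and assuming (as a simplification of early epsilon-greedy training) that the other agent picks its action uniformly at random, the expected reward of each individual action is:

```python
# Climbing Game payoff matrix as listed earlier in this guide (shared by both agents)
climbing = [
    [11, -30, 0],
    [-30, 7, 0],
    [0, 6, 5],
]
n = 3

# Player 1 picks the row: average each row over a uniformly random column
row_values = [sum(climbing[i]) / n for i in range(n)]
# Player 2 picks the column: average each column over a uniformly random row
col_values = [sum(climbing[i][j] for i in range(n)) / n for j in range(n)]

best_row = max(range(n), key=lambda i: row_values[i])
best_col = max(range(n), key=lambda j: col_values[j])

print([round(v, 2) for v in row_values])  # [-6.33, -7.67, 3.67]
print([round(v, 2) for v in col_values])  # [-6.33, -5.67, 1.67]
print(best_row, best_col, climbing[best_row][best_col])  # 2 2 5
```

Under this assumption both agents individually prefer action 2, which gives a joint reward of 5 rather than the 11 available at the joint action (0, 0).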

See this paper for more details:

[Contrasting Centralized and Decentralized Critics in Multi-Agent Reinforcement Learning](https://arxiv.org/pdf/2102.04402.pdf#:~:text=Thus%2C%20for%20decentralized%20policy%20learning,less%20bias%20and%20more%20variance)