In [4]:
import argparse
import json
import numpy as np

from ray.rllib.algorithms import Algorithm
from ray.rllib.utils.typing import AgentID
from train import algorithm_config, get_checkpoint_dir, get_policy_mapping_fn
from typing import Any, Callable, Iterable



def get_actions(
    agent_ids: Iterable[AgentID],
    algorithm: Algorithm,
    policy_mapping_fn: Callable[[AgentID], str],
    observations: dict[AgentID, Any],
    states: dict[AgentID, Any]) -> tuple[dict[AgentID, Any], dict[AgentID, Any]]:
    """
    Get actions for the given agents.

    Agents that have no entry in ``observations`` are skipped (no action is
    computed for them), and agents that have no entry in ``states`` are
    treated as stateless.

    Parameters
    ----------
    agent_ids : Iterable[AgentID]
        Agent IDs for which to get actions
    algorithm : Algorithm
        RLlib algorithm instance with trained policies
    policy_mapping_fn : Callable(AgentID) -> str
        Function mapping agent IDs to policy IDs
    observations : dict[AgentID, Any]
        Observations for each agent
    states : dict[AgentID, Any]
        States for each agent; updated in place for agents with a
        non-empty state

    Returns
    -------
    actions : dict[AgentID, Any]
        Actions for each agent that had an observation
    states : dict[AgentID, Any]
        Updated states for each agent (the same dict object that was passed in)
    """
    actions = {}
    for agent_id in agent_ids:
        # Without an observation there is nothing to feed the policy
        # (presumably the agent is already done), so leave it out of `actions`.
        if agent_id not in observations:
            continue

        policy_id = policy_mapping_fn(agent_id)
        state = states.get(agent_id)
        if state:
            # With a state passed in, the call yields (action, new_state, extras).
            actions[agent_id], states[agent_id], _ = algorithm.compute_single_action(
                observations[agent_id],
                state,
                policy_id=policy_id
            )
        else:
            # Empty or missing state: the call yields only the action.
            actions[agent_id] = algorithm.compute_single_action(
                observations[agent_id],
                policy_id=policy_id
            )

    return actions, states

def visualize(
    algorithm: Algorithm,
    policy_mapping_fn: Callable[[AgentID], str],
    num_episodes: int = 10) -> list[np.ndarray]:
    """
    Visualize trajectories from trained agents.

    Parameters
    ----------
    algorithm : Algorithm
        RLlib algorithm instance with trained policies
    policy_mapping_fn : Callable(AgentID) -> str
        Function mapping agent IDs to policy IDs
    num_episodes : int, default=10
        Number of episodes to visualize

    Returns
    -------
    frames : list[np.ndarray]
        Frames obtained from ``env.get_frame()``: one before every step,
        plus one after the final step of each episode
    """
    frames = []
    env = algorithm.env_creator(algorithm.config.env_config)

    # Close the environment even if a rollout fails or is interrupted,
    # so the environment is not left open.
    try:
        for episode in range(num_episodes):
            print('\n', '-' * 32, '\n', 'Episode', episode, '\n', '-' * 32)

            episode_rewards = {agent_id: 0.0 for agent_id in env.get_agent_ids()}
            terminations, truncations = {'__all__': False}, {'__all__': False}
            observations, infos = env.reset()
            # Per-agent state comes from the policy the agent is mapped to.
            states = {
                agent_id: algorithm.get_policy(policy_mapping_fn(agent_id)).get_initial_state()
                for agent_id in env.get_agent_ids()
            }
            while not terminations['__all__'] and not truncations['__all__']:
                frames.append(env.get_frame())
                actions, states = get_actions(
                    env.get_agent_ids(), algorithm, policy_mapping_fn, observations, states)
                observations, rewards, terminations, truncations, infos = env.step(actions)
                for agent_id in rewards:
                    episode_rewards[agent_id] += rewards[agent_id]

            # Capture the frame after the last step of the episode as well.
            frames.append(env.get_frame())
            print('Rewards:', episode_rewards)
    finally:
        env.close()

    return frames




In [5]:


# Run parameters, set directly in the notebook rather than parsed with argparse.
algo = 'PPO'
framework = 'torch'
lstm = False  # set to True to enable LSTM
env = 'MultiGrid-Empty-8x8-v0'
env_config = {}  # custom environment settings; merged over args_env_config below
num_agents = 2
num_episodes = 10
max_steps = 20
load_dir = None  # directory to load a pre-trained model from, if any
gif = None  # path to save the rollout as a GIF, if wanted

# The functions defined earlier (get_actions, visualize) are used unchanged.

# Build the algorithm configuration; entries in env_config take precedence
# over the args_env_config defaults.
args_env_config = {'render_mode': 'human'}
config = algorithm_config(
    algo=algo,
    framework=framework,
    lstm=lstm,
    env=env,
    env_config=args_env_config | env_config,
    num_agents=num_agents,
    num_episodes=num_episodes,
    max_steps=max_steps,
    load_dir=load_dir,
    gif=gif,
    num_workers=0,
    num_gpus=0,
)
algorithm = config.build()


def policy_mapping_fn(agent_id, *args, **kwargs):
    """Default mapping, used unless a checkpoint supplies its own."""
    return f'policy_{agent_id}'


# Restore trained weights and the matching policy mapping when a checkpoint is found.
checkpoint = get_checkpoint_dir(load_dir)
if checkpoint:
    print(f"Loading checkpoint from {checkpoint}")
    algorithm.restore(checkpoint)
    policy_mapping_fn = get_policy_mapping_fn(checkpoint, num_agents)

# Roll out the episodes and, if requested, write the collected frames to a GIF.
frames = visualize(algorithm, policy_mapping_fn, num_episodes=num_episodes)
if gif:
    from array2gif import write_gif
    filename = gif
    if not filename.endswith('.gif'):
        filename = f'{gif}.gif'
    print(f"Saving GIF to {filename}")
    write_gif(np.array(frames), filename, fps=10)




  logger.warn(
  logger.warn(f"{pre} is not within the observation space.")


`UnifiedLogger` will be removed in Ray 2.7.
  return UnifiedLogger(config, logdir, loggers=None)
The `JsonLogger interface is deprecated in favor of the `ray.tune.json.JsonLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `CSVLogger interface is deprecated in favor of the `ray.tune.csv.CSVLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
The `TBXLogger interface is deprecated in favor of the `ray.tune.tensorboardx.TBXLoggerCallback` interface and will be removed in Ray 2.7.
  self._loggers.append(cls(self.config, self.logdir, self.trial))
2023-09-20 16:40:09,623	INFO tensorboardx.py:48 -- pip install "ray[tune]" to see TensorBoard files.
  logger.warn(
  logger.warn(
  logger.warn(


: 

: 