In [1]:
import torch
import yaml
import os

# Humandoid MARL
from Humanoid_MARL.envs.base_env import GymWrapper
from Humanoid_MARL.utils.torch_utils import save_models, load_model_central_critic_agent
from Humanoid_MARL.algorithms.ant_mappo import AgentMAPPO
from Humanoid_MARL.training.train_ant_mappo import get_agent_actions
from Humanoid_MARL.envs.torch_wrapper import TorchWrapper
from IPython.display import HTML, clear_output
from brax.io import html
import jax
from Humanoid_MARL import envs
from Humanoid_MARL.utils.utils import load_reward_config, load_config

central_agent_config = ["mappo"]
independent_agent_config = ["ippo"]

In [2]:
env_name = "ants"
algo = "mappo"
config = load_config(env_name, algo)
env = envs.create(config['env_name'], auto_reset=False, **config['env_config'])
env = GymWrapper(env, get_jax_state=True)
env = TorchWrapper(env, device=config['train_config']['device'], get_jax_state=True)

model = "20240507_171850_ppo_ants_44236800.pt"
model_path = os.path.join("../models/", model)
# env warmup
observation = env.reset()
action = torch.zeros(env.action_space.shape[0] * env.num_agents).to(config['train_config']['device'])
env.step(action)
agents = load_model_central_critic_agent(model_path, AgentMAPPO, device=config['train_config']['device'])
jax_states = []
num_steps = 1000

eval_reward = 0.0
episodes = torch.zeros((), device = config['train_config']['device'])
for i in range(num_steps):
    print(f"{i} / {num_steps}")
    observation = torch.unsqueeze(observation, 0).to(config['train_config']['device'])
    logits, action = get_agent_actions(agents, observation, env.obs_dims)
    if config['agent_config'].get("freeze_idx"):
        action[:,config['agent_config'].get("freeze_idx") * 8:(config['agent_config'].get("freeze_idx") * 8) + 8] = torch.ones_like(action[:,config['agent_config'].get("freeze_idx") * 8:(config['agent_config'].get("freeze_idx") * 8) + 8]) * 0
    jax_state, observation, reward, done, info = env.step(AgentMAPPO.dist_postprocess(action[0]))
    episodes += torch.sum(done)
    jax_states.append(jax_state)
    print(f"{i} | {info} | DONE [{done}] | Reward [{reward}]")
    print(f"{i} | Action {action}")
    eval_reward += reward
    if done:
        observation = env.reset()
        print(f"Episode Done")
        print(f"Total Reward | {eval_reward / episodes}")
print(f"Total Reward | {eval_reward / episodes}")

Model loaded from ../models/20240507_171850_ppo_ants_44236800.pt
0 / 1000
0 | {'angle_penalty': tensor([-0.0153, -0.0256], device='cuda:0'), 'distance_from_origin': tensor([0.1117, 2.0287], device='cuda:0'), 'forward_reward': tensor([0., 0.], device='cuda:0'), 'reward_chase': tensor([-0.9695,  0.9695], device='cuda:0'), 'reward_ctrl': tensor([-0., -0.], device='cuda:0'), 'reward_forward': tensor([0., -0.], device='cuda:0'), 'reward_survive': tensor([0., 0.], device='cuda:0'), 'reward_tag': tensor([0., -0.], device='cuda:0'), 'stand_up_reward': tensor([0., 0.], device='cuda:0'), 'steps': tensor(2., device='cuda:0'), 'truncation': tensor(0., device='cuda:0'), 'wall_penalty': tensor([-7.0545e-28, -1.0351e-19], device='cuda:0'), 'x_position': tensor([0.0899, 2.0285], device='cuda:0'), 'x_velocity': tensor([ 0.3063, -0.2816], device='cuda:0'), 'y_position': tensor([-0.0663, -0.0249], device='cuda:0'), 'y_velocity': tensor([-0.0173,  0.0141], device='cuda:0'), 'z_position': tensor([0.8362, 0

In [5]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 

In [6]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 