In [5]:
import torch

# Humandoid MARL
from Humanoid_MARL import envs
from Humanoid_MARL.envs.base_env import GymWrapper, VectorGymWrapper
from Humanoid_MARL.utils.torch_utils import save_models, load_models
from Humanoid_MARL.agent.ppo.train_torch import Agent, eval_unroll, get_agent_actions
from Humanoid_MARL.envs.torch_wrapper import TorchWrapper
from IPython.display import HTML, clear_output
from brax.io import html
import jax
from Humanoid_MARL import envs
import mediapy as media

In [6]:
config = {
        'num_timesteps': 100_000_000,
        'eval_frequency': 10,
        'episode_length': 1000,
        'unroll_length': 10,
        'num_minibatches': 32,
        'num_update_epochs': 8,
        'discounting': 0.97,
        'learning_rate': 3e-4,
        'entropy_cost': 1e-3,
        'num_envs': 2048,
        'batch_size': 512,
        'env_name': "humanoids",
        'render' : True,
        'device' : 'cuda',
        'model_path' : "../models/20240223_124515_ppo_humanoid_copy_2.pt",
        'video_length' : 300,
    }

env = envs.create(
        config['env_name'],
        auto_reset=False,
    )

env = GymWrapper(env, get_jax_state=True)
env = TorchWrapper(env, device=config['device'], get_jax_state=True)

# env warmup
observation = env.reset()
action = torch.zeros(env.action_space.shape[0] * env.num_agents).to(config['device'])
env.step(action)
agents = load_models(config['model_path'], Agent, device=config['device'])
jax_states = []
num_steps = 1000

eval_reward = 0.0
episodes = torch.zeros((), device = config['device'])
for i in range(num_steps):
    print(f"{i} / {num_steps}")
    logits, action = get_agent_actions(agents, observation, env.obs_dims)
    jax_state, observation, reward, done, info = env.step(Agent.dist_postprocess(action[0]))
    episodes += torch.sum(done)
    jax_states.append(jax_state)
    print(f"{i} | {info} | DONE [{done}] | Reward [{reward}]")
    eval_reward += reward
print(f"Total Reward | {eval_reward / episodes}")

Models loaded from ../models/20240223_124515_ppo_humanoid_copy_2.pt
0 / 1000
0 | {'distance_from_origin': tensor([0.9494, 1.7068], device='cuda:0'), 'forward_reward': tensor([ 0.0016, -0.0186], device='cuda:0'), 'reward_alive': tensor([5., 5.], device='cuda:0'), 'reward_chase': tensor([0., 0.], device='cuda:0'), 'reward_linvel': tensor([ 0.0016, -0.0186], device='cuda:0'), 'reward_quadctrl': tensor([-0.4830, -0.6752], device='cuda:0'), 'standup_reward': tensor([0., 0.], device='cuda:0'), 'steps': tensor(2., device='cuda:0'), 'truncation': tensor(0., device='cuda:0'), 'x_position': tensor([0.0197, 1.0185], device='cuda:0'), 'x_velocity': tensor([ 0.0011, -0.0124], device='cuda:0'), 'y_position': tensor([2.7883e-04, 1.0026e+00], device='cuda:0'), 'y_velocity': tensor([ 0.0108, -0.0013], device='cuda:0'), 'z_position': tensor([1.4036, 1.3875], device='cuda:0')} | DONE [0.0] | Reward [tensor([4.5186, 4.3062], device='cuda:0')]
1 / 1000
1 | {'distance_from_origin': tensor([0.9438, 1.7036], 

In [7]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 

In [8]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 