In [1]:
import torch

# Humandoid MARL
from Humanoid_MARL import envs
from Humanoid_MARL.envs.base_env import GymWrapper, VectorGymWrapper
from Humanoid_MARL.utils.visual import save_video, save_rgb_image
from Humanoid_MARL.utils.torch_utils import save_models, load_models
from Humanoid_MARL.agent.ppo.train_torch import eval_unroll, get_agent_actions
from Humanoid_MARL.agent.ppo.agent import Agent
from Humanoid_MARL.envs.torch_wrapper import TorchWrapper
from IPython.display import HTML
from brax.io import html
import jax
from Humanoid_MARL import envs

In [2]:

config = {
        'num_timesteps': 100_000_000,
        'eval_frequency': 10,
        'episode_length': 1000,
        'unroll_length': 10,
        'num_minibatches': 32,
        'num_update_epochs': 8,
        'discounting': 0.97,
        'learning_rate': 3e-4,
        'entropy_cost': 1e-3,
        'num_envs': 2048,
        'batch_size': 512,
        'env_name': "humanoid",
        'render' : False,
        'device' : 'cuda',
        'model_path' : "../models/20240213_160524_ppo_humanoid.pt",
        'video_length' : 300,
    }
env = envs.create(
        config['env_name'], batch_size=None, episode_length=config['episode_length'], backend="generalized"
    )
env = GymWrapper(env, get_jax_state=True)
env = TorchWrapper(env, device=config['device'], get_jax_state=True)

# env warmup
observation = env.reset()
action = torch.zeros(env.action_space.shape[0] * env.num_agents).to(config['device'])
env.step(action)
jax_states = []
num_steps = 1000
# agents = load_models(config['model_path'], Agent, device=config['device'])

state_dicts = torch.load(config['model_path'], map_location=config['device'])
network_arch = state_dicts["network_arch"]
agents = []

for i in range(len(state_dicts) - 1):  # Subtract 1 to exclude the network_arch entry
    agent = Agent(**network_arch).to(config['device'])
    agent.load_state_dict(state_dicts[f'agent_{i}'])
    agent.eval()
    agents.append(agent)

In [3]:
eval_reward = 0.0
episodes = torch.zeros((), device = config['device'])
for i in range(num_steps):
    print(f"{i} / {num_steps}")
    logits, action = get_agent_actions(agents, observation, env.obs_dims)
    jax_state, observation, reward, done, info = env.step(Agent.dist_postprocess(action[0]))
    episodes += torch.sum(done)
    jax_states.append(jax_state)
    print(f"{i} | {info} | DONE [{done}] | Reward [{reward}]")
    eval_reward += reward
print(f"Total Reward | {eval_reward / episodes}")

0 / 1000
0 | {'distance_from_origin': tensor(0.9374, device='cuda:0'), 'first_obs': tensor([ 1.3964e+00,  9.9022e-01, -5.0478e-03, -4.1752e-03,  8.2115e-03,
        -3.0075e-03, -2.3696e-04,  5.0182e-03, -2.1919e-03,  8.7368e-03,
        -7.9229e-03,  3.3246e-03, -6.0711e-03,  8.0665e-03,  7.3675e-03,
        -9.5078e-03, -1.2503e-03,  8.0097e-03,  5.7708e-03, -1.8930e-03,
         2.9378e-03, -5.4378e-03, -1.0107e-04,  4.8649e-03, -4.7188e-03,
         2.7836e-03,  2.4877e-03,  1.9663e-03, -1.1979e-03, -1.6890e-03,
        -4.4962e-03, -8.0390e-03,  8.2597e-03,  4.6945e-03, -8.8630e-03,
         6.0610e-03, -1.7541e-04, -5.4629e-03, -2.6865e-03,  8.5640e-03,
        -9.5905e-03,  5.7331e-03, -1.2208e-03,  7.3980e-04,  8.2583e-03,
         2.2379e+00,  8.4696e-04,  8.9562e-02,  8.4696e-04,  2.2228e+00,
        -1.3639e-02,  8.9562e-02, -1.3639e-02,  4.5343e-02,  8.9075e+00,
         9.3039e-02,  7.3657e-05,  1.1370e-02,  7.3657e-05,  8.8489e-02,
         1.2817e-04,  1.1370e-02,  1.281

In [4]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 

In [5]:
HTML(html.render(env.sys, [jax_state.pipeline_state for jax_state in jax_states])) 