In [None]:
import pandas as pd
import numpy as np
import functools

import jax
import numpy as np
from etils import epath
from tqdm import tqdm

from brax.training.agents.ppo import train as ppo
from brax.training.agents.ppo import networks as ppo_networks
from mujoco_playground.config import locomotion_params

from mujoco_playground import registry
from mujoco_playground import wrapper, wrapper_torch

In [3]:
# Disable JAX/XLA GPU memory preallocation for this kernel.
# `!export VAR=...` runs in a throw-away subshell, so the variable never reached
# this Python process and had no effect. Setting os.environ does reach it.
# JAX reads this variable when it initializes its GPU backend. This cell therefore
# has to run before the first JAX computation (ideally before `import jax`).
import os

os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"

In [None]:
# Expert PPO checkpoints to evaluate. All entries share one layout:
# the checkpoint folder is expert_checkpoints/<env name>, and the policy MLP has
# two hidden layers of 256 units. The evaluation loop later adds the
# episode_rewards_mean / episode_rewards_std keys to each entry.
path_model = [
    {
        'env': env_name,
        "model": "PPO",
        "checkpoint_path": f"expert_checkpoints/{env_name}",
        "policy_hidden_layer_sizes": (256, 256),
    }
    for env_name in (
        'H1JoystickGaitTracking',
        'H1InplaceGaitTracking',
        'Go1JoystickRoughTerrain',
        'Go1JoystickFlatTerrain',
        'Go1Handstand',
        'Go1Getup',
        'Go1Footstand',
        'G1JoystickRoughTerrain',
        'G1JoystickFlatTerrain',
    )
]

In [None]:
N_EVAL_EPISODES = 10  # episodes rolled out per expert policy


def _latest_checkpoint(checkpoint_dir):
    """Return the newest checkpoint sub-directory of `checkpoint_dir`.

    Checkpoint sub-directories are expected to be named by an integer
    (presumably the training step). Anything else in the folder is ignored.

    Raises:
        FileNotFoundError: if no integer-named sub-directory exists.
    """
    ckpt_dir = epath.Path(checkpoint_dir).resolve()
    numbered = [c for c in ckpt_dir.glob("*") if c.is_dir() and c.name.isdigit()]
    if not numbered:
        raise FileNotFoundError(f"No numbered checkpoint directories found in {ckpt_dir}")
    return max(numbered, key=lambda c: int(c.name))


def _load_expert_policy(env_name, policy_hidden_layer_sizes, checkpoint_dir, randomizer):
    """Restore the PPO expert for `env_name` from its latest checkpoint.

    ``ppo.train`` is called with ``num_timesteps=0`` and a restore path.
    Presumably this just rebuilds the networks and loads the checkpointed
    parameters, without further training.

    Args:
        env_name: registry name of the environment.
        policy_hidden_layer_sizes: policy MLP hidden sizes. They must match the
            architecture stored in the checkpoint.
        checkpoint_dir: folder containing integer-named checkpoint directories.
        randomizer: domain-randomization fn passed to ``ppo.train``.

    Returns:
        A jitted deterministic inference fn ``(obs, rng) -> (action, extras)``.
    """
    ppo_params = locomotion_params.brax_ppo_config(env_name)
    ppo_training_params = dict(ppo_params)
    ppo_training_params["num_timesteps"] = 0

    # Build the network-factory kwargs without mutating the config. Apply the
    # hidden-size override whether or not the config defines a network_factory
    # (the old code indexed it unconditionally, before checking that it exists).
    if "network_factory" in ppo_params:
        del ppo_training_params["network_factory"]
        factory_kwargs = dict(ppo_params.network_factory)
    else:
        factory_kwargs = {}
    factory_kwargs["policy_hidden_layer_sizes"] = policy_hidden_layer_sizes
    network_factory = functools.partial(ppo_networks.make_ppo_networks, **factory_kwargs)

    train_fn = functools.partial(
        ppo.train,
        **ppo_training_params,
        network_factory=network_factory,
        randomization_fn=randomizer,
    )
    make_inference_fn, params, _metrics = train_fn(
        environment=registry.load(env_name),
        eval_env=registry.load(env_name),
        wrap_env_fn=wrapper.wrap_for_brax_training,
        restore_checkpoint_path=_latest_checkpoint(checkpoint_dir),
        seed=1,
    )
    return jax.jit(make_inference_fn(params, deterministic=True))


def eval_expert(env, n_episodes, jit_inference_fn, episode_length, seed=12345):
    """Roll out `jit_inference_fn` in `env` and return the return of each episode.

    Args:
        env: environment exposing ``reset(rng)`` and ``step(state, action)``. The
            returned state must have ``obs``, ``reward`` and ``done``.
        n_episodes: number of episodes to roll out.
        jit_inference_fn: policy ``(obs, rng) -> (action, extras)``.
        episode_length: maximum number of steps per episode.
        seed: PRNG seed. Each episode gets its own reset key derived from it.

    Returns:
        np.ndarray of shape ``(n_episodes,)`` holding the summed reward per episode.
    """
    jit_reset = jax.jit(env.reset)
    jit_step = jax.jit(env.step)
    rng = jax.random.PRNGKey(seed)
    episode_rewards = []
    for _episode in tqdm(range(n_episodes)):
        # Split a fresh reset key per episode. Reusing one key would start every
        # episode from the same initial state and make the std meaningless.
        rng, reset_rng = jax.random.split(rng)
        state = jit_reset(reset_rng)
        episode_reward = 0.0
        for _step in range(episode_length):
            rng, act_rng = jax.random.split(rng)
            action, _ = jit_inference_fn(state.obs, act_rng)
            state = jit_step(state, action)
            # Scalars convert directly; no need to round-trip through torch.
            episode_reward += float(state.reward)
            if bool(state.done):
                break
        episode_rewards.append(episode_reward)
    return np.asarray(episode_rewards)


for p in path_model:
    print("-" * 100)
    print(f"ENV: {p['env']}")
    print("-" * 100)
    print()

    env = registry.load(p["env"])
    env_cfg = registry.get_default_config(p["env"])
    randomizer = registry.get_domain_randomizer(p["env"])

    # ------------- EXPERT EVALUATION
    jit_inference_fn = _load_expert_policy(
        p["env"], p["policy_hidden_layer_sizes"], p["checkpoint_path"], randomizer
    )
    episode_rewards = eval_expert(
        env, N_EVAL_EPISODES, jit_inference_fn, env_cfg.episode_length
    )
    p["episode_rewards_mean"] = episode_rewards.mean()
    p["episode_rewards_std"] = episode_rewards.std()


In [None]:
# Summary table: one row per expert, including the episode-return mean/std
# columns added by the evaluation loop.
results_df = pd.DataFrame(path_model)
results_df

In [12]:
# Persist the evaluation summary. The row index is only a position, so it is not written.
results_expert = pd.DataFrame(data=path_model)
results_expert.to_csv("results_expert.csv", index=False)