In [12]:
import pickle

import numpy as np
from attrdict import AttrDict
from RL.ppo import *
from utils.utils import log
from envs.cpu.env import make_env
from envs.brax_custom.gpu_env import make_vec_env_brax
from models.actor_critic import Actor

from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from brax import envs
from jax import numpy as jnp

In [13]:
%pwd
import os
os.chdir('/home/sbatra/QDPPO')

In [14]:
cfg = {'env_name': 'humanoid', 'env_batch_size': None, 'normalize_obs': False, 'normalize_rewards': True,
       'num_dims': 2, 'envs_per_model': 1, 'seed': 0, 'obs_shape': (87,), 'num_envs': 1}
cfg = AttrDict(cfg)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = make_vec_env_brax(cfg)

obs_shape, action_shape = env.observation_space.shape, env.action_space.shape

In [15]:
def load_agent_from_archive():
    archive_path = '/home/sbatra/QDPPO/logs/method3_humanoid_pycma/cma_maega/trial_0/checkpoints/cp_00001000/archive_00001000.pkl'
    with open(archive_path, 'rb') as f:
        archive_df = pickle.load(f)
    # elites = archive_df.query("objective > 6000").sort_values("objective", ascending=False)
    agent_params = archive_df.query('2805').to_numpy()[4:]
    agent = Actor(cfg, obs_shape=obs_shape, action_shape=action_shape).deserialize(agent_params).to(device)
    return agent

In [16]:
def enjoy_brax(agent=None, render=False):
    if agent is None:
        agent = Actor(cfg, obs_shape, action_shape).to(device)
        cp_path = "/home/sbatra/QDPPO/checkpoints/humanoid_brax_model_0_checkpoint"
        model_state_dict = torch.load(cp_path)['model_state_dict']
        model_state_dict['actor_logstd'] = model_state_dict['actor_logstd'].reshape(1, -1)
        agent.load_state_dict(model_state_dict)

    if cfg.normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var


    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if cfg.normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)
            act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew

    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
        print(f'{total_reward=}')
        print(f' Rollout length: {len(rollout)}')
        measures /= len(rollout)
        print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy()


In [17]:
agent = load_agent_from_archive()
enjoy_brax(agent, render=True)

total_reward=tensor(441.3353, device='cuda:0')
 Rollout length: 90
Measures: [0.72222227 0.17777778]


array(441.3353, dtype=float32)

In [18]:
# trials = 20
# scores = []
# for t in range(trials):
#     scores.append(enjoy_brax(agent, render=False))
# scores = np.array(scores)
# min_rew, max_rew, mean_rew = np.min(scores), np.max(scores), np.mean(scores)
# print(f'{min_rew=}, {max_rew=}, {mean_rew=}')
