In [43]:
import pickle

import numpy as np
from attrdict import AttrDict
from RL.ppo import *
from utils.utils import log
from envs.cpu.env import make_env
from envs.brax_custom.gpu_env import make_vec_env_brax
from models.actor_critic import Actor

from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from brax import envs
from jax import numpy as jnp

In [44]:
%pwd
import os
os.chdir('/home/sbatra/QDPPO')

In [45]:
cfg = {'env_name': 'ant', 'env_batch_size': None, 'normalize_obs': False, 'normalize_rewards': True,
       'num_dims': 4, 'envs_per_model': 1, 'seed': 0, 'obs_shape': (87,), 'num_envs': 1}
cfg = AttrDict(cfg)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env = make_vec_env_brax(cfg)

obs_shape, action_shape = env.observation_space.shape, env.action_space.shape

In [46]:
def load_agent_from_archive():
    archive_path = '/home/sbatra/QDPPO/logs/debug/cma_mega_adam/trial_0/checkpoints/cp_00000610/archive_00000610.pkl'
    with open(archive_path, 'rb') as f:
        archive_df = pickle.load(f)
    # elites = archive_df.query("objective > 6000").sort_values("objective", ascending=False)
    agent_params = archive_df.query('323').to_numpy()[6:]
    agent = Actor(cfg, obs_shape=obs_shape, action_shape=action_shape).deserialize(agent_params).to(device)
    return agent

In [47]:
def enjoy_brax(agent=None):
    if agent is None:
        agent = Actor(cfg, obs_shape, action_shape).to(device)
        cp_path = "checkpoints/brax_model_0_checkpoint"
        model_state_dict = torch.load(cp_path)['model_state_dict']
        model_state_dict['actor_logstd'] = model_state_dict['actor_logstd'].reshape(1, -1)
        agent.load_state_dict(model_state_dict)

    if cfg.normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var


    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if cfg.normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)
            act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew

    i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
    display(i) 
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    measures /= len(rollout)
    print(f'Measures: {measures.cpu().numpy()}')


In [48]:
agent = load_agent_from_archive()
enjoy_brax(agent)

total_reward=tensor(934.4839, device='cuda:0')
 Rollout length: 1001
Measures: [0.85814184 0.8261738  0.8371628  0.87212783]
