In [1]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

'/home/sumeet/diffusion_models'

In [2]:
import pickle
import torch
import numpy as np

from autoencoders.policy.resnet3d import ResNet3DAutoEncoder
from autoencoders.policy.hypernet import HypernetAutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from dataset.tensor_elites_dataset import preprocess_model, postprocess_model
from utils.brax_utils import shared_params

In [3]:
# params to config
device = torch.device('cuda')
env_name = 'walker2d'
seed = 1111
normalize_obs = True
normalize_rewards = True
obs_shape = shared_params[env_name]['obs_dim']
action_shape = shared_params[env_name]['action_dim']
mlp_shape = (128, 128, action_shape)

env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1
})

In [4]:
archive_df_path = 'data/walker2d/archive_100x100_global_obs_norm.pkl'
with open(archive_df_path, 'rb') as f:
    archive_df = pickle.load(f)

scheduler_path = 'data/walker2d/scheduler_100x100_global_obs_norm.pkl'
with open(scheduler_path, 'rb') as f:
    scheduler = pickle.load(f)



In [5]:
# make the env
env = make_vec_env_brax(env_cfg)

2023-04-06 14:37:34.467606: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.8 which is older than the ptxas CUDA version (12.0.76). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [6]:
def get_best_elite():
    best_elite = scheduler.archive.best_elite
    agent = Actor(obs_shape, action_shape, True, True).deserialize(best_elite.solution).to(device)
    if normalize_obs:
        agent.obs_normalizer = best_elite.metadata['obs_normalizer']
    return agent

In [7]:
def get_random_elite():
    elite = scheduler.archive.sample_elites(1)
    agent = Actor(obs_shape, action_shape, True, True).deserialize(elite.solution_batch.flatten()).to(device)
    if normalize_obs:
        agent.obs_normalizer = elite.metadata_batch[0]['obs_normalizer']
    return agent

In [8]:
def integrate_obs_normalizer(agent: Actor):
    assert agent.obs_normalizer is not None
    w_in = agent.actor_mean[0].weight.data
    b_in = agent.actor_mean[0].bias.data
    mean, var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
    w_new = w_in / torch.sqrt(var + 1e-8)
    b_new = b_in - (mean / torch.sqrt(var + 1e-8)) @ w_in.T
    agent.actor_mean[0].weight.data = w_new
    agent.actor_mean[0].bias.data = b_new
    return agent


In [9]:
def enjoy_brax(agent, render=True, deterministic=True):
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
        print(f'{obs_mean=}, {obs_var=}')

    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    measures /= len(rollout)
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy()

In [24]:
agent = get_random_elite()
# make sure pre and post-processing are working correctly. This should return
# the exact same agent as the previous line
# agent = postprocess_model(agent, preprocess_model(agent, mlp_shape), mlp_shape, deterministic=False).to(device)
# if normalize_obs:
#     agent = integrate_obs_normalizer(agent)
enjoy_brax(agent, render=False)

obs_mean=tensor([ 1.0464, -0.6017, -0.1448, -0.8962,  0.3852, -1.2207, -0.1963, -0.0661,
         1.0486, -0.1671, -0.3600,  0.1230,  0.5467, -0.1813,  0.7596,  0.3892,
        -0.0556], device='cuda:0'), obs_var=tensor([0.0167, 0.1110, 0.0378, 0.2373, 0.1877, 0.3711, 0.1633, 0.3658, 1.7961,
        0.9004, 2.3453, 2.3420, 6.7307, 6.5571, 4.6513, 2.7656, 4.9342],
       device='cuda:0')
total_reward=tensor(67.2580, device='cuda:0')
 Rollout length: 53
Measures: [0.20754717 0.9433962 ]


array(67.25801, dtype=float32)

In [25]:
# load the VAE model
autoencoder_cp_path = 'checkpoints/autoencoder_walker2d.pt'
vae_model = HypernetAutoEncoder(emb_channels=4, z_channels=4, obs_shape=obs_shape, action_shape=np.array([action_shape]))
vae_model.load_state_dict(torch.load(autoencoder_cp_path))
vae_model.to(device)

Total size of z is: 64


HypernetAutoEncoder(
  (quant_conv): Conv2d(8, 8, kernel_size=(1, 1), stride=(1, 1))
  (post_quant_conv): Conv2d(4, 4, kernel_size=(1, 1), stride=(1, 1))
  (encoder): ModelEncoder(
    (cnns): ModuleDict(
      (actor_logstd): Sequential(
        (fc1): Linear(in_features=6, out_features=256, bias=True)
        (relu1): ReLU(inplace=True)
        (fc2): Linear(in_features=256, out_features=256, bias=True)
        (relu2): ReLU(inplace=True)
        (fc3): Linear(in_features=256, out_features=256, bias=True)
        (relu3): ReLU(inplace=True)
      )
      (actor_mean_0_weight): Sequential(
        (cnn_block_0): Sequential(
          (conv0): Conv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (batchnorm0): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
          (relu0): ReLU(inplace=True)
          (maxpool_0): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        )
        (cnn_block_1): Sequential

In [26]:
# get the policy weights dict for shank's ghn based VAE
input_weights_dict = {}
for key, param in agent.named_parameters():
    if 'weight' in key or 'bias' in key or 'logstd' in key:
        input_weights_dict[key] = param.unsqueeze(0)

In [27]:
# get the reconstructed model
out, _ = vae_model(input_weights_dict)

# this is the 'policy as a tensor' way of doing reconstruction
# model_in = Actor(obs_shape, action_shape, True, True).to(device)
# rec_agent = postprocess_model(model_in, out, (128, 128, 6), deterministic=False)
# rec_agent.obs_normalizer = agent.obs_normalizer
# rec_agent.to(device)

# this is the 'weights dict -> Actor' method of reconstruction i.e. out is already an Actor object
rec_agent = out[0]

rec_agent.obs_normalizer = agent.obs_normalizer
rec_agent.to(device)

  if param.grad is not None:


Actor(
  (actor_mean): Sequential(
    (0): Linear(in_features=17, out_features=128, bias=True)
    (1): Tanh()
    (2): Linear(in_features=128, out_features=128, bias=True)
    (3): Tanh()
    (4): Linear(in_features=128, out_features=6, bias=True)
  )
  (obs_normalizer): ObsNormalizer(
    (obs_rms): RunningMeanStd()
  )
)

In [28]:
# if normalize_obs:
#     rec_agent = integrate_obs_normalizer(rec_agent)

In [31]:
normalize_obs = True
enjoy_brax(rec_agent, render=True, deterministic=True)

obs_mean=tensor([ 1.0464, -0.6017, -0.1448, -0.8962,  0.3852, -1.2207, -0.1963, -0.0661,
         1.0486, -0.1671, -0.3600,  0.1230,  0.5467, -0.1813,  0.7596,  0.3892,
        -0.0556], device='cuda:0'), obs_var=tensor([0.0167, 0.1110, 0.0378, 0.2373, 0.1877, 0.3711, 0.1633, 0.3658, 1.7961,
        0.9004, 2.3453, 2.3420, 6.7307, 6.5571, 4.6513, 2.7656, 4.9342],
       device='cuda:0')


total_reward=tensor(199.7835, device='cuda:0')
 Rollout length: 96
Measures: [0.125     0.9166667]


array(199.78348, dtype=float32)