In [2]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

'/home/shashank/research/qd/main/diffusion_models'

In [15]:
import pickle
import torch
import numpy as np

from autoencoders.policy.resnet3d import ResNet3DAutoEncoder
from autoencoders.policy.hypernet import HypernetAutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from dataset.tensor_elites_dataset import preprocess_model, postprocess_model
from utils.brax_utils import shared_params
from tqdm import tqdm
import glob
from utils.normalize import ObsNormalizer


In [4]:
# params to config
device = torch.device('cuda')
env_name = 'humanoid'
seed = 1111
normalize_obs = True
normalize_rewards = True
obs_shape = shared_params[env_name]['obs_dim']
action_shape = shared_params[env_name]['action_dim']
mlp_shape = (128, 128, action_shape)

env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1,
    'clip_obs_rew': True,
})

In [5]:
print(glob.glob(f'data/{env_name}/*.pkl'))
archive_df_path = f'data/{env_name}/archive100x100.pkl'
with open(archive_df_path, 'rb') as f:
    archive_df = pickle.load(f)

['data/humanoid/archive100x100.pkl']




In [6]:
# make the env
env = make_vec_env_brax(env_cfg)

In [7]:
def integrate_obs_normalizer(agent: Actor):
    assert agent.obs_normalizer is not None
    w_in = agent.actor_mean[0].weight.data
    b_in = agent.actor_mean[0].bias.data
    mean, var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
    w_new = w_in / torch.sqrt(var + 1e-8)
    b_new = b_in - (mean / torch.sqrt(var + 1e-8)) @ w_in.T
    agent.actor_mean[0].weight.data = w_new
    agent.actor_mean[0].bias.data = b_new
    return agent

In [8]:
def enjoy_brax(agent, render=True, deterministic=True, normalize_obs=False):
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
        print(f'{obs_mean=}, {obs_var=}')

    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    measures /= len(rollout)
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy(), measures.cpu().numpy()

In [10]:
elites_array = archive_df.filter(regex='solution*').to_numpy()
metadata_list = archive_df.filter(regex='metadata*').to_numpy()
inferred_labels = []
label_distances = []

In [None]:
len(elites_array)

## Rerun next two / three cells until sufficiently labelled

In [17]:
# Render one elite
if len(label_distances) > 50:
    # If we have distances, use the elite with the largest distance
    elite_index = np.argmax(label_distances)
    print("elite distance to nearest label", np.max(label_distances))
    print("current label:", inferred_labels[elite_index])
else:
    # Otherwise, use a random elite
    elite_index = np.random.randint(len(elites_array))
print('Elit index:', elite_index)
agent = Actor(obs_shape, action_shape, True, True).deserialize(elites_array[elite_index]).to(device)
obs_normalizer = ObsNormalizer(obs_shape).to(device)
obs_normalizer.load_state_dict(metadata_list[elite_index][0]['obs_normalizer'])
agent.obs_normalizer = obs_normalizer
# make sure pre and post-processing are working correctly. This should return
# the exact same agent as the previous line
# agent = postprocess_model(agent, preprocess_model(agent, mlp_shape), mlp_shape, deterministic=False).to(device)
# if normalize_obs:
#     agent = integrate_obs_normalizer(agent)
reward, measures = enjoy_brax(agent, render=True, normalize_obs=True)

Elit index: 6381
obs_mean=tensor([ 1.1337e+00,  9.7395e-01, -5.5708e-02,  3.6482e-02, -4.6043e-02,
        -3.8665e-04, -1.2335e+00, -9.7330e-01, -1.8264e-01,  2.6335e-01,
         1.8258e-01, -5.4048e-01, -1.9653e-01,  2.4558e-02, -8.8508e-03,
        -4.9567e-01, -4.8779e-02,  3.2579e-01,  1.3397e-02, -2.1375e-02,
         1.2111e-01, -1.4854e-02,  3.9019e+00,  3.9537e-01, -6.6402e-02,
         3.4984e-02,  1.7214e-02, -1.0813e-02,  1.2334e-02,  2.0938e-01,
        -1.3632e-02, -3.4590e-02, -1.0492e-01, -9.1260e-02, -1.5966e-01,
        -3.7884e-02, -5.0611e-02, -6.3386e-03,  1.3548e-01, -2.9977e-02,
         7.2164e-02, -1.8925e-01, -1.2265e-01,  8.1530e-02,  1.7283e-01,
         1.9596e+00,  1.6927e-03,  2.3400e-02,  1.6927e-03,  1.9657e+00,
        -8.3290e-03,  2.3400e-02, -8.3290e-03,  1.8684e+00,  1.0288e+00,
         4.0590e-04,  5.5268e-03,  4.0590e-04,  1.0360e+00, -3.6920e-04,
         5.5268e-03, -3.6920e-04,  1.0317e+00,  1.0341e+00,  8.5313e-05,
        -1.4814e-04,  8.5

total_reward=tensor(9393.7598, device='cuda:0')
 Rollout length: 1001
Measures: [0.6953047  0.75524473]


In [None]:
# Add a new label
labels_archive_idx.append([float(reward), measures.tolist(), int(elite_index),
                          "fall forwards while keeping both feet planted"]) # REPLACE LABEL HERE
with open(f"data/{env_name}/text_labels_{len(labels_archive_idx):05d}.json", "w") as f:
    json.dump(labels_archive_idx, f, indent=True)
label_distances[elite_index] = 0.0
labels_archive_idx[-1]

In [None]:
# Recompute inferred labels
labelled_elites = np.array([elites_array[elite_index] for (returns, measures, elite_index, label) in labels_archive_idx])
labelled_elites_labels = [label for (returns, measures, elite_index, label) in labels_archive_idx]

def nearest_labelled_elite(archive_index):
    distances = ((labelled_elites - elites_array[archive_index]) ** 2).sum(axis=-1)
    return int(np.argmin(distances)), float(distances.min())

inferred_labels = []
label_distances = []
for archive_index in tqdm(range(len(elites_array))):
    inferred_label, label_dist = nearest_labelled_elite(archive_index)
    inferred_labels.append(labelled_elites_labels[inferred_label])
    label_distances.append(label_dist)

In [None]:
with open(f"data/{env_name}/text_labels_{len(labels_archive_idx):05d}.pkl", "wb") as f:
    pickle.dump(inferred_labels, f)

In [None]:
import glob
sorted(glob.glob(f"data/{env_name}/text_labels_*.pkl"))

In [None]:
labels_archive_idx[-1]