In [None]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

In [None]:
import pickle
import torch
import numpy as np

from autoencoders.policy.resnet3d import ResNet3DAutoEncoder
from autoencoders.policy.hypernet import HypernetAutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from dataset.tensor_elites_dataset import preprocess_model, postprocess_model
from utils.brax_utils import shared_params
from tqdm import tqdm
import glob
from utils.normalize import ObsNormalizer
import json

In [None]:
# params to config
device = torch.device('cuda')
env_name = 'humanoid'
seed = 1111
normalize_obs = True
normalize_rewards = True
obs_shape = shared_params[env_name]['obs_dim']
action_shape = shared_params[env_name]['action_dim']
mlp_shape = (128, 128, action_shape)

env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1,
    'clip_obs_rew': True,
})

In [None]:
print(glob.glob(f'data/{env_name}/*.pkl'))
archive_df_path = f'data/{env_name}/archive100x100.pkl'
with open(archive_df_path, 'rb') as f:
    archive_df = pickle.load(f)

In [None]:
# make the env
env = make_vec_env_brax(env_cfg)

In [None]:
def integrate_obs_normalizer(agent: Actor):
    assert agent.obs_normalizer is not None
    w_in = agent.actor_mean[0].weight.data
    b_in = agent.actor_mean[0].bias.data
    mean, var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
    w_new = w_in / torch.sqrt(var + 1e-8)
    b_new = b_in - (mean / torch.sqrt(var + 1e-8)) @ w_in.T
    agent.actor_mean[0].weight.data = w_new
    agent.actor_mean[0].bias.data = b_new
    return agent

In [None]:
def enjoy_brax(agent, render=True, deterministic=True, normalize_obs=False):
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
#         print(f'{obs_mean=}, {obs_var=}')

    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    measures /= len(rollout)
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy(), measures.cpu().numpy()

In [None]:
elites_array = archive_df.filter(regex='solution*').to_numpy()
metadata_list = archive_df.filter(regex='metadata*').to_numpy()
labels = []
inferred_labels = []
label_distances = []

In [None]:
len(elites_array)

## Rerun next two / three cells until sufficiently labelled

In [None]:
# Render one elite
if len(label_distances) > 10:
    # If we have distances, use the elite with the largest distance
    elite_index = np.argmax(label_distances)
    print("elite distance to nearest label", np.max(label_distances))
    print("current label:", inferred_labels[elite_index][-1])
else:
    # Otherwise, use a random elite
    elite_index = np.random.randint(len(elites_array))
print('Elite index:', elite_index)
agent = Actor(obs_shape, action_shape, True, True).deserialize(elites_array[elite_index]).to(device)
obs_normalizer = ObsNormalizer(obs_shape).to(device)
obs_normalizer.load_state_dict(metadata_list[elite_index][0]['obs_normalizer'])
agent.obs_normalizer = obs_normalizer
# make sure pre and post-processing are working correctly. This should return
# the exact same agent as the previous line
# agent = postprocess_model(agent, preprocess_model(agent, mlp_shape), mlp_shape, deterministic=False).to(device)
# if normalize_obs:
#     agent = integrate_obs_normalizer(agent)
reward, measures = enjoy_brax(agent, render=True, normalize_obs=True)

In [None]:
# Add a new label
labels.append([float(reward), measures.tolist(), int(elite_index),
              "hop forward on your left foot while lifting your right foot"]) # REPLACE LABEL HERE
with open(f"data/{env_name}/text_labels_{len(labels):05d}.json", "w") as f:
    json.dump(labels, f, indent=True)
if label_distances:
    label_distances[elite_index] = 0.0
print(len(labels))
labels[-1]

In [None]:
# Recompute inferred labels
labelled_elites = np.array([elites_array[elite_index] for (returns, measures, elite_index, label) in labels])
labelled_elites_labels = [label for (returns, measures, elite_index, label) in labels]

def nearest_labelled_elite(archive_index):
    distances = ((labelled_elites - elites_array[archive_index]) ** 2).sum(axis=-1)
    return int(np.argmin(distances)), float(distances.min())

inferred_labels = []
label_distances = []
for archive_index in tqdm(range(len(elites_array))):
    inferred_label, label_dist = nearest_labelled_elite(archive_index)
    inferred_labels.append(labels[inferred_label])
    label_distances.append(label_dist)

In [None]:
with open(f"data/{env_name}/text_labels_{len(labels):05d}.pkl", "wb") as f:
    pickle.dump(inferred_labels, f)

In [None]:
import glob
sorted(glob.glob(f"data/{env_name}/text_labels_*.pkl"))

In [None]:
labels[-1][-1]

In [341]:
[label for label in labels]

[[9686.7509765625,
  [0.08891108632087708, 0.39960038661956787],
  2382,
  'quickly slide forward on right foot'],
 [340.4988098144531,
  [0.03846153989434242, 0.9615384936332703],
  5733,
  'plant right foot and fall'],
 [9545.927734375,
  [0.8031967878341675, 0.5234764814376831],
  4753,
  'quickly walk forward while dragging right foot'],
 [9687.6181640625,
  [0.568431556224823, 0.27272725105285645],
  483,
  'quickly run forward using only your left foot'],
 [9704.373046875,
  [0.4865134656429291, 0.009990009479224682],
  1107,
  'quickly run forward while holding your right foot off the ground'],
 [9644.732421875,
  [0.4205794036388397, 0.5034964680671692],
  3551,
  'quickly walk forward while dragging your right foot'],
 [9685.0009765625,
  [0.3306693136692047, 0.29270729422569275],
  1754,
  'quickly run forward while leaning left'],
 [9362.5634765625,
  [0.8691308498382568, 0.6183816194534302],
  6133,
  'quickly shuffle forward while dragging your right foot'],
 [9643.1162109