In [None]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

In [None]:
import pickle
import torch
import numpy as np

from autoencoders.policy.resnet3d import ResNet3DAutoEncoder
from autoencoders.policy.hypernet import HypernetAutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from dataset.policy_dataset import preprocess_model, postprocess_model

In [None]:
# params to config
device = torch.device('cuda')
env_name = 'halfcheetah'
seed = 1111
normalize_obs = True
normalize_rewards = False
obs_shape = 18  # observation size passed to Actor
action_shape = 6  # action size passed to Actor
# layer widths handed to preprocess_model / postprocess_model
# (presumably (hidden, hidden, action_shape) — confirm against dataset.policy_dataset)
mlp_shape = (128, 128, 6)

env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1
})

# Seed the global RNGs so stochastic steps (e.g. sampled actions when
# enjoy_brax is called with deterministic=False) are reproducible.
np.random.seed(seed)
torch.manual_seed(seed)

In [None]:
# Checkpoint to inspect. The archive dataframe and the scheduler of one checkpoint
# live in the same directory and share the zero-padded iteration suffix, so both
# paths are derived from `cp_iter` and cannot drift apart.
# NOTE(review): absolute path to one user's experiment directory — adjust per machine.
cp_iter = 1990
cp_dir = os.path.join(
    '/home/sumeet/QDPPO/experiments/ppga_halfcheetah_adaptive_stddev/1111/checkpoints',
    f'cp_{cp_iter:08d}',
)

# pickle.load can execute arbitrary code: only load checkpoints you produced or trust.
archive_df_path = os.path.join(cp_dir, f'archive_df_{cp_iter:08d}.pkl')
with open(archive_df_path, 'rb') as f:
    archive_df = pickle.load(f)

scheduler_path = os.path.join(cp_dir, f'scheduler_{cp_iter:08d}.pkl')
with open(scheduler_path, 'rb') as f:
    scheduler = pickle.load(f)

In [None]:
# make the env
# make the env from env_cfg (built with num_envs=1); enjoy_brax steps this global `env`
env = make_vec_env_brax(env_cfg)

In [None]:
def get_best_elite():
    """Build an Actor from the best elite stored in the loaded scheduler's archive.

    Reads the module-level ``scheduler``, ``obs_shape``, ``action_shape``,
    ``device`` and ``normalize_obs`` values.

    Returns:
        The deserialized ``Actor`` moved to ``device``. When ``normalize_obs`` is
        set, its ``obs_normalizer`` is taken from the elite's metadata.

    Raises:
        ValueError: if the archive holds no elites.
    """
    best_elite = scheduler.archive.best_elite
    # NOTE(review): pyribs reports an empty archive as best_elite == None; fail
    # with a clear message instead of an AttributeError on `None.solution`.
    if best_elite is None:
        raise ValueError('archive is empty; there is no best elite to load')
    agent = Actor(obs_shape, action_shape, True, True).deserialize(best_elite.solution).to(device)
    if normalize_obs:
        agent.obs_normalizer = best_elite.metadata['obs_normalizer']
    return agent

In [None]:
def get_random_elite():
    """Build an Actor from one elite sampled from the loaded scheduler's archive.

    Reads the module-level ``scheduler``, ``obs_shape``, ``action_shape``,
    ``device`` and ``normalize_obs`` values. When ``normalize_obs`` is set, the
    returned actor's ``obs_normalizer`` comes from the sampled elite's metadata.
    """
    sampled = scheduler.archive.sample_elites(1)
    # a batch of one: flatten the solution batch into a single parameter vector
    flat_params = sampled.solution_batch.flatten()
    actor = Actor(obs_shape, action_shape, True, True).deserialize(flat_params).to(device)
    if normalize_obs:
        first_meta = sampled.metadata_batch[0]
        actor.obs_normalizer = first_meta['obs_normalizer']
    return actor

In [None]:
def integrate_obs_normalizer(agent: 'Actor', eps: float = 1e-8):
    """Fold the agent's observation normalization into the first ``actor_mean`` layer.

    With ``s = sqrt(var + eps)`` the first layer computes
    ``W @ ((x - mean) / s) + b == (W / s) @ x + (b - W @ (mean / s))``, so after
    this call that layer can be fed raw (unnormalized) observations.

    Args:
        agent: actor whose ``obs_normalizer.obs_rms`` provides ``mean``/``var``
            and whose ``actor_mean[0]`` is the input layer (``weight``/``bias``).
        eps: added to the variance before the square root (default ``1e-8``,
            the same constant ``enjoy_brax`` uses when normalizing).

    Returns:
        The same ``agent``, modified in place.

    Note:
        ``agent.obs_normalizer`` is left untouched, so any code that still
        normalizes observations explicitly will normalize them twice.
    """
    assert agent.obs_normalizer is not None, 'agent has no obs_normalizer to integrate'
    layer = agent.actor_mean[0]
    obs_rms = agent.obs_normalizer.obs_rms
    with torch.no_grad():
        w_in = layer.weight
        # Match the layer's dtype/device: statistics stored in another dtype
        # (e.g. float64) would otherwise silently change the layer's dtype.
        mean = obs_rms.mean.to(w_in)
        std = torch.sqrt(obs_rms.var.to(w_in) + eps)
        # compute the new bias from the ORIGINAL weights before overwriting them
        b_new = layer.bias - (mean / std) @ w_in.T
        layer.weight.copy_(w_in / std)
        layer.bias.copy_(b_new)
    return agent


In [None]:
def enjoy_brax(agent, render=True, deterministic=True):
    """Roll out one episode of ``agent`` in the module-level brax ``env``.

    Args:
        agent: actor exposing ``actor_mean`` (used when ``deterministic``) and
            ``get_action`` (used otherwise). When the module-level
            ``normalize_obs`` flag is set, it must also carry ``obs_normalizer``.
        render: if True, display an HTML rendering of the rollout.
        deterministic: if True, act with ``agent.actor_mean(obs)``; otherwise use
            the action returned by ``agent.get_action(obs)``.

    Returns:
        The summed episode reward as a NumPy value.

    Also prints the normalizer statistics (when used), the total reward, the
    number of recorded states and the per-step average of ``info['measures']``.
    """
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
        print(f'{obs_mean=}, {obs_var=}')

    obs = env.reset()
    # holds the reset state plus one state per step, i.e. num_steps + 1 entries
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims, device=device)
    num_steps = 0
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
            num_steps += 1
    if render:
        rollout_html = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(rollout_html)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    # One measures sample is accumulated per step, so average over the steps
    # taken — not over len(rollout), which also counts the initial reset state.
    measures /= num_steps
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy()

In [None]:
# Sample a random elite and evaluate it without rendering.
agent = get_random_elite()
# Sanity check for the tensor round trip: preprocess_model followed by
# postprocess_model should reproduce exactly the agent sampled above.
# agent = postprocess_model(agent, preprocess_model(agent, mlp_shape), mlp_shape, deterministic=False).to(device)
# NOTE: integrate_obs_normalizer folds the normalizer into the first layer but
# enjoy_brax still normalizes when normalize_obs is True, so enabling the two
# lines below as-is would normalize observations twice.
# if normalize_obs:
#     agent = integrate_obs_normalizer(agent)
enjoy_brax(agent, render=False, deterministic=True)

In [None]:
# load the VAE model
# load the VAE model; the path is relative to project_root (the cwd set in the first cell)
autoencoder_cp_path = 'checkpoints/autoencoder.pt'
vae_model = HypernetAutoEncoder(emb_channels=8, z_channels=4)
# map_location loads the tensors straight onto the target device, so a checkpoint
# saved from a different device still loads. torch.load unpickles the file:
# only load trusted checkpoints.
vae_model.load_state_dict(torch.load(autoencoder_cp_path, map_location=device))
vae_model.to(device)
# The model is only used for inference below: switch any dropout/normalization
# layers to their evaluation behaviour (returns the model, displayed as cell output).
vae_model.eval()

In [None]:
# get the policy input tensor
# get the policy input tensor (with a leading batch dim of 1), using the
# configured layer widths rather than a duplicated literal
policy_tensor = preprocess_model(agent, mlp_shape).to(device).unsqueeze(dim=0)
policy_tensor.shape

In [None]:
# get the policy weights dict for shank's ghn based VAE
# Policy weights dict for shank's GHN-based VAE: every agent parameter whose
# name mentions weight, bias or logstd, each given a leading batch dim of 1.
input_weights_dict = {
    name: tensor.unsqueeze(0)
    for name, tensor in agent.named_parameters()
    if any(tag in name for tag in ('weight', 'bias', 'logstd'))
}

In [None]:
# get the reconstructed model
# get the reconstructed model. Inference only: without no_grad the reconstructed
# actor's tensors would keep the whole VAE autograd graph (and activations) alive.
with torch.no_grad():
    out, _ = vae_model(input_weights_dict)

# this is the 'policy as a tensor' way of doing reconstruction
# model_in = Actor(obs_shape, action_shape, True, True).to(device)
# rec_agent = postprocess_model(model_in, out, mlp_shape, deterministic=False)
# rec_agent.obs_normalizer = agent.obs_normalizer
# rec_agent.to(device)

# this is the 'weights dict -> Actor' method of reconstruction i.e. out is already an Actor object
rec_agent = out[0]
# reuse the original agent's normalizer so enjoy_brax normalizes observations identically
rec_agent.obs_normalizer = agent.obs_normalizer
rec_agent.to(device)

In [None]:
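# NOTE: integrate_obs_normalizer folds the normalization into the first layer but
# leaves rec_agent.obs_normalizer set, and enjoy_brax still normalizes when
# normalize_obs is True, so enabling these lines as-is would normalize twice.
# if normalize_obs:
#     rec_agent = integrate_obs_normalizer(rec_agent)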

In [None]:

# Roll out the reconstructed agent with its mean (deterministic) action and render the episode.
enjoy_brax(agent=rec_agent, render=True, deterministic=True)