In [None]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

In [None]:
import torch
import pickle
import json
import numpy as np
import matplotlib.pyplot as plt

from diffusion.gaussian_diffusion import cosine_beta_schedule, linear_beta_schedule, GaussianDiffusion
from diffusion.latent_diffusion import LatentDiffusion
from diffusion.ddim import DDIMSampler
from autoencoders.policy.hypernet import HypernetAutoEncoder as AutoEncoder
from dataset.shaped_elites_dataset import WeightNormalizer
from attrdict import AttrDict
from utils.tensor_dict import TensorDict, cat_tensordicts
from RL.actor_critic import Actor
from utils.normalize import ObsNormalizer
from models.cond_unet import ConditionalUNet
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from utils.brax_utils import shared_params, rollout_many_agents

In [None]:
# params to config
# Pick CUDA only when it is actually available, so this cell agrees with the
# CPU fallback used when `device` is assigned again in the diffusion-params cell.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env_name = 'humanoid'
seed = 1111
normalize_obs = True
normalize_rewards = False
# Observation / action sizes come from the per-environment table in utils.brax_utils.
obs_shape = shared_params[env_name]['obs_dim']
action_shape = np.array([shared_params[env_name]['action_dim']])
# NOTE(review): the last entry is the 1-element `action_shape` array, not a plain int;
# nothing else in this notebook reads `mlp_shape` -- confirm whether it is still needed.
mlp_shape = (128, 128, action_shape)

# Configuration handed to make_vec_env_brax for the single-environment rollouts.
env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1,
    'clip_obs_rew': True,
})

In [None]:
def _load_pickle(path):
    """Return the object unpickled from the file at ``path``."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# NOTE: unpickling runs arbitrary code from the file, so only load archives you trust.
archive_df_path = f'data/{env_name}/archive_100x100.pkl'
scheduler_path = f'data/{env_name}/scheduler_100x100.pkl'

archive_df = _load_pickle(archive_df_path)
scheduler = _load_pickle(scheduler_path)

In [None]:
# make the env
# Built from `env_cfg` (a single environment); `enjoy_brax` reads this module-level `env` directly.
env = make_vec_env_brax(env_cfg)

In [None]:
def enjoy_brax(agent, render=True, deterministic=True):
    """Roll out ``agent`` for one episode in the module-level ``env`` and report the results.

    Reads the notebook globals ``env``, ``env_cfg``, ``device`` and ``normalize_obs``.

    Args:
        agent: policy exposing ``actor_mean(obs)`` and ``get_action(obs)``; when
            ``normalize_obs`` is set it must also expose
            ``obs_normalizer.obs_rms.mean`` / ``.var``.
        render: when True, display an HTML rendering of the collected states.
        deterministic: use ``agent.actor_mean`` when True, otherwise take the first
            element returned by ``agent.get_action``.

    Returns:
        ``(total_reward, measures)`` as detached CPU tensors, where ``measures`` is the
        per-step average of ``info['measures']`` over the episode.
    """
    if normalize_obs:
        obs_rms = agent.obs_normalizer.obs_rms
        obs_mean, obs_var = obs_rms.mean, obs_rms.var
        print('Normalize Obs Enabled')

    obs = env.reset()
    # The rollout starts with the reset state, so it always holds one entry more than
    # the number of environment steps taken.
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    num_steps = 0
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
            num_steps += 1
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    # Measures are accumulated once per step, so average over the step count rather
    # than len(rollout), which also counts the initial state.
    measures /= num_steps
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu(), measures.detach().cpu()

In [None]:
# diffusion model params
latent_diffusion = True
use_ddim = True
center_data = True
latent_channels = 4
latent_size = 4
timesteps = 600

# Training arguments saved alongside the diffusion-model run; only `scale_factor` is read below.
cfg_path = 'results/humanoid/diffusion_model/humanoid_diffusion_model_paper_111/humanoid_diffusion_model_20230504-041644_111/args.json'
with open(cfg_path, 'r') as cfg_file:
    cfg = AttrDict(json.load(cfg_file))

# The latent scale factor only applies when sampling in the autoencoder's latent space.
if latent_diffusion:
    scale_factor = cfg.scale_factor
else:
    scale_factor = None

device = 'cuda' if torch.cuda.is_available() else 'cpu'

betas = cosine_beta_schedule(timesteps)

In [None]:
# paths to VAE and diffusion model checkpoint
# State dict loaded into `autoencoder` in the model-loading cell.
autoencoder_path = 'results/humanoid/autoencoder/humanoid_autoencoder_paper_111/model_checkpoints/humanoid_autoencoder_20230502-081156_111.pt'
# State dict loaded into the conditional UNet (`model`) in the model-loading cell.
model_path = 'results/humanoid/diffusion_model/humanoid_diffusion_model_paper_111/humanoid_diffusion_model_20230504-041644_111/model_checkpoints/humanoid_diffusion_model_20230504-041644_111.pt'
# File read by WeightNormalizer.load; only used when `center_data` is enabled.
weight_normalizer_path = 'results/humanoid/weight_normalizer.pkl'

In [None]:
# load the diffusion model
logvar = torch.full(fill_value=0., size=(timesteps,))
model = ConditionalUNet(
    in_channels=latent_channels,
    out_channels=latent_channels,
    channels=64,
    n_res_blocks=1,
    attention_levels=[],
    channel_multipliers=[1, 2, 4],
    n_heads=4,
    d_cond=256,
    logvar=logvar
)
autoencoder = AutoEncoder(emb_channels=4,
                          z_channels=4,
                          obs_shape=obs_shape,
                          action_shape=action_shape,
                          z_height=4,
                          enc_fc_hid=64,
                          obsnorm_hid=64,
                          ghn_hid=8)
# map_location remaps the checkpoint tensors onto `device`, so a checkpoint saved on a
# GPU still loads when `device` has fallen back to CPU.
autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
autoencoder.to(device)
autoencoder.eval()

gauss_diff = LatentDiffusion(betas, num_timesteps=timesteps, device=device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# This notebook only samples from the model, so switch it to inference mode just like
# the autoencoder above.
model.eval()

# The normalizer is only needed when the training data was centered.
weight_normalizer = None
if center_data:
    weight_normalizer = WeightNormalizer(TensorDict({}), TensorDict({}))
    weight_normalizer.load(weight_normalizer_path)


In [None]:
def postprocess_agents(rec_agents: list[Actor], obsnorms: list[dict]):
    '''Denormalize outputs of the decoder and return a list of Actors that can be rolled out

    Reads the notebook globals ``action_shape``, ``obs_shape``, ``device``,
    ``center_data``, ``normalize_obs`` and ``weight_normalizer``.

    Args:
        rec_agents: policies produced by the decoder; their state dicts are stacked
            into one batched TensorDict.
        obsnorms: observation-normalizer entries produced by the decoder. They are
            merged into the batched parameters with ``update``.
            NOTE(review): annotated as ``list[dict]`` but used as a mapping here -- confirm.

    Returns:
        A list of freshly constructed ``Actor`` instances on ``device`` with the
        reconstructed parameters loaded.
    '''
    batch_size = len(rec_agents)
    # Wrap the decoder's normalizer outputs and actually use the wrapped value for the merge.
    obsnorms = TensorDict(obsnorms)
    rec_agents_params = cat_tensordicts([TensorDict(agent.state_dict()) for agent in rec_agents])
    rec_agents_params.update(obsnorms)
    # decoder doesn't fill in the logstd param, so we manually set it to default values
    rec_agents_params['actor_logstd'] = torch.zeros(batch_size, 1, action_shape[0], device=device)
    # if data centering was used during training, we need to denormalize the weights
    if center_data:
        rec_agents_params = weight_normalizer.denormalize(rec_agents_params)

    if normalize_obs:
        # The Actor's normalizer stores a variance, so convert the log-std entry
        # (var = exp(2 * logstd)) and drop the key the Actor does not know about.
        logstd = rec_agents_params['obs_normalizer.obs_rms.logstd']
        rec_agents_params['obs_normalizer.obs_rms.var'] = torch.exp(logstd * 2)
        rec_agents_params['obs_normalizer.obs_rms.count'] = torch.zeros(batch_size, 1, device=device)
        del rec_agents_params['obs_normalizer.obs_rms.logstd']

    restored_agents = []
    for i in range(len(rec_agents_params)):
        actor = Actor(obs_shape, action_shape, normalize_obs=normalize_obs).to(device)
        # Integer indexing selects the i-th policy's parameters from the batched dict.
        actor.load_state_dict(rec_agents_params[i])
        restored_agents.append(actor)

    return restored_agents

In [None]:
# Sampler built on top of `gauss_diff` and used by every sampling cell below.
# NOTE(review): n_steps=100 is presumably the number of DDIM steps (vs. `timesteps` = 600) -- confirm.
ddim_sampler = DDIMSampler(gauss_diff, n_steps=100)

In [None]:
def get_agent_with_measure(m: list[float], classifier_scale: float = 1.0):
    """Sample a single policy conditioned on the measure vector ``m``.

    Reads the notebook globals ``ddim_sampler``, ``model``, ``autoencoder``,
    ``scale_factor``, ``latent_channels``, ``latent_size`` and ``device``.

    Args:
        m: target measure values; reshaped to a ``(1, len(m))`` float32 condition.
        classifier_scale: guidance scale forwarded to the sampler. Defaults to 1.0.

    Returns:
        The first (and only) agent returned by ``postprocess_agents``.
    """
    batch_size = 1
    # as_tensor builds the tensor from the values in `m`; the legacy torch.Tensor(...)
    # constructor would treat a bare int as a size and return uninitialised memory.
    cond = torch.as_tensor(m, dtype=torch.float32).view(1, -1).to(device)
    shape = [batch_size, latent_channels, latent_size, latent_size]
    samples = ddim_sampler.sample(
        model,
        shape=shape,
        cond=cond,
        classifier_free_guidance=True,
        classifier_scale=classifier_scale,
    )
    # Rescale the sampled latents by 1 / scale_factor before decoding them.
    samples = samples * (1 / scale_factor)
    rec_agents, obsnorms = autoencoder.decode(samples)
    rec_agents = postprocess_agents(rec_agents, obsnorms)
    return rec_agents[0]

In [None]:
batch_size = 1
# Conditioning tensor with one row per sample and both measure entries set to 0.2.
cond = torch.full((batch_size, 2), 0.2).to(device)

In [None]:
# Draw latents conditioned on `cond`, rescale them, decode them into policies and
# rebuild rollout-ready actors.
shape = [batch_size, latent_channels, latent_size, latent_size]
samples = ddim_sampler.sample(
    model,
    shape=shape,
    cond=cond,
    classifier_free_guidance=True,
    classifier_scale=2.0,
)
samples = samples * (1 / scale_factor)
rec_agents, obsnorms = autoencoder.decode(samples)
rec_agents = postprocess_agents(rec_agents, obsnorms)

In [None]:
# Sample a fresh policy conditioned on measures [0.5, 0.5] and roll it out once;
# the (total_reward, measures) pair returned by enjoy_brax is shown as the cell output.
rec_agent = get_agent_with_measure([0.5, 0.5])
enjoy_brax(rec_agent)

In [None]:
# evaluate an agent on many envs in parallel
N = 50  # used for both 'env_batch_size' and 'num_envs' below
# Same settings as the single-environment config except for the two N-sized entries.
multi_env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': N,
    'num_envs': N,
    'num_dims': 2,
    'seed': seed,
    'clip_obs_rew': True,
})
multi_vec_env = make_vec_env_brax(multi_env_cfg)

In [None]:
# Evaluate the sampled agent on the batched environment. Observation normalization
# follows the notebook-level `normalize_obs` flag so it stays in sync with how the
# agent was rebuilt.
rollout_many_agents([rec_agent], multi_env_cfg, multi_vec_env, device, verbose=True, normalize_obs=normalize_obs)

In [None]:
def compose_behaviors(measures, env, env_cfg, device, deterministic: bool = True, render: bool = True,
                      episode_length: int = 1000):
    """Run one episode in which control passes through a sequence of sampled policies.

    One policy is sampled per entry of ``measures`` (via ``get_agent_with_measure``).
    The step range ``[0, episode_length)`` is split into ``len(measures)`` contiguous,
    near-equal chunks and the i-th policy acts during the i-th chunk.

    Args:
        measures: sequence of target measure vectors, in the order they are applied.
        env: environment exposing ``reset()``, ``step(action)`` and ``unwrapped._state``.
        env_cfg: kept for interface compatibility; not read by this function.
        device: device the observations are moved to before being fed to the policies.
        deterministic: use ``agent.actor_mean`` when True, otherwise the first element
            returned by ``agent.get_action``.
        render: when True, display an HTML rendering of the collected states.
        episode_length: number of steps that is divided among the policies. Steps
            beyond it (if the episode runs longer) stay with the last policy.

    Returns:
        A list with one entry per chunk, each a list of the per-step
        ``info['measures']`` values (as numpy arrays) recorded while that chunk's
        policy was acting. A chunk's list is empty if the episode ended before it.

    Raises:
        ValueError: if ``measures`` is empty or ``episode_length`` is not positive.
    """
    num_chunks = len(measures)
    if num_chunks == 0:
        raise ValueError('measures must contain at least one target measure vector')
    if episode_length <= 0:
        raise ValueError('episode_length must be positive')

    agents = [get_agent_with_measure(m) for m in measures]

    # Near-equal contiguous split: the first `extra` chunks get one additional step.
    # Precompute which chunk owns every step instead of scanning the chunks each step.
    base, extra = divmod(episode_length, num_chunks)
    chunk_sizes = [base + 1 if c < extra else base for c in range(num_chunks)]
    chunk_of_step = np.repeat(np.arange(num_chunks), chunk_sizes)
    last_step = len(chunk_of_step) - 1

    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    done = False
    # get the per-chunk measures independent of other chunks
    measure_data = [[] for _ in range(num_chunks)]

    t = 0
    while not done:
        # Clamp so an episode that outlives episode_length keeps using the last policy
        # rather than failing on a missing chunk index.
        interval_idx = int(chunk_of_step[min(t, last_step)])
        agent = agents[interval_idx]
        obs_rms = agent.obs_normalizer.obs_rms

        # Pure inference: no autograd graph is needed for the rollout.
        with torch.no_grad():
            obs = obs.to(device)
            obs = (obs - obs_rms.mean) / torch.sqrt(obs_rms.var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())

        measure_t = info['measures']
        measure_data[interval_idx].append(measure_t.detach().cpu().numpy())
        rollout.append(env.unwrapped._state)
        total_reward += rew
        t += 1
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}, Trajectory Length: {t}')
    return measure_data

In [None]:
# Target measure vectors, one per time chunk of the composed rollout, in the order they are applied.
measures = [
    [0.9, 0.9],
    [0.2, 0.2],
    [0.5, 0.0],
    [0.0, 0.5]
]
# render=False skips the HTML visualisation; the returned per-chunk measure logs are analysed below.
measure_data = compose_behaviors(measures, env, env_cfg, device, render=False)

In [None]:
# get the avg measure for each time interval independent of the other ones. Sanity check
# A chunk has no samples when the episode ended before its time slot; report NaN for it
# so the list stays aligned with `measures`.
interval_measures = []
for ms in measure_data:
    if len(ms) > 0:
        interval_measures.append(np.mean(np.array(ms), axis=0))
    else:
        interval_measures.append(np.nan)
print(f'{interval_measures=}')

# get the moving average measures
window_size = 50
# Drop empty chunks before stacking: an empty list becomes a 1-D array and cannot be
# concatenated with the higher-dimensional arrays of the recorded chunks.
recorded_chunks = [np.array(ms) for ms in measure_data if len(ms) > 0]
all_measure_data = np.concatenate(recorded_chunks) if recorded_chunks else np.empty((0,))
# One average per full window; the list is empty when fewer than `window_size` steps exist.
moving_averages = [
    np.sum(all_measure_data[start:start + window_size], axis=0) / window_size
    for start in range(len(all_measure_data) - window_size + 1)
]

In [None]:
# Moving average of the first measure across the composed rollout.
first_measure_avg = [avg[0] for avg in moving_averages]
plt.plot(np.arange(len(moving_averages)), first_measure_avg)

In [None]:
# One panel per measure dimension, both plotted against the moving-average window index.
fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(10, 6))
steps = np.arange(len(moving_averages))
for axis, dim in ((ax1, 0), (ax2, 1)):
    axis.plot(steps, [avg[dim] for avg in moving_averages])
ax1.set_ylabel('Measure 0')
ax2.set_ylabel('Measure 1')

In [None]:
# Same four target measure vectors as passed to compose_behaviors earlier in the notebook.
measures = [
    [0.9, 0.9],
    [0.2, 0.2],
    [0.5, 0.0],
    [0.0, 0.5]
]