In [None]:
import os
from pathlib import Path
from utils.tensor_dict import TensorDict
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

In [None]:
import torch
import pickle
import json
import numpy as np

from diffusion.gaussian_diffusion import cosine_beta_schedule, linear_beta_schedule, GaussianDiffusion
from diffusion.latent_diffusion import LatentDiffusion
from diffusion.ddim import DDIMSampler
from autoencoders.policy.hypernet import HypernetAutoEncoder as AutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from utils.normalize import ObsNormalizer
from models.cond_unet import ConditionalUNet
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image
from utils.brax_utils import shared_params

In [None]:
# params to config
# Use CUDA when it is available and fall back to the CPU otherwise, matching
# the device selection done again in the diffusion-parameter cell further down.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env_name = 'humanoid'
seed = 1111
# `normalize_obs` is read by `enjoy_brax` to decide whether observations are
# standardized with the agent's running statistics before the policy is queried.
normalize_obs = True
normalize_rewards = False
# Observation / action sizes come from the per-environment table in
# `utils.brax_utils.shared_params`.
obs_shape = shared_params[env_name]['obs_dim']
action_shape = np.array([shared_params[env_name]['action_dim']])
mlp_shape = (128, 128, action_shape)

# Configuration handed to `make_vec_env_brax`; `num_dims` is also used by
# `enjoy_brax` as the length of the accumulated measure vector.
env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1,
    'clip_obs_rew': True,
})

In [None]:
# Locations of the pickled archive dataframe and scheduler for this environment.
archive_df_path = f'data/{env_name}/archive_100x100.pkl'
scheduler_path = f'data/{env_name}/scheduler_100x100.pkl'

# NOTE: unpickling executes arbitrary code embedded in the file -- only load
# archives / schedulers that come from a trusted source.
archive_df = pickle.loads(Path(archive_df_path).read_bytes())
scheduler = pickle.loads(Path(scheduler_path).read_bytes())

In [None]:
# Make the env from `env_cfg` (configured above with `num_envs` = 1).
# The resulting module-level `env` is the one `enjoy_brax` resets and steps.
env = make_vec_env_brax(env_cfg)

In [None]:
def enjoy_brax(agent, render=True, deterministic=True):
    """Roll out ``agent`` for one episode in the module-level ``env``.

    Relies on the notebook globals ``env``, ``env_cfg``, ``device`` and
    ``normalize_obs``.

    Args:
        agent: policy exposing ``actor_mean`` (used when ``deterministic`` is
            True) and ``get_action`` (used otherwise). When the global
            ``normalize_obs`` is set, it must also carry
            ``obs_normalizer.obs_rms`` with ``mean`` and ``var``.
        render: if True, display an HTML rendering of the recorded rollout.
        deterministic: if True, act with ``agent.actor_mean(obs)``; otherwise
            use the first value returned by ``agent.get_action(obs)``.

    Returns:
        ``(total_reward, measures)`` as detached CPU tensors. ``measures`` is
        the sum of ``info['measures']`` over the episode divided by the number
        of environment steps taken.
    """
    if normalize_obs:
        obs_rms = agent.obs_normalizer.obs_rms
        obs_mean, obs_var = obs_rms.mean, obs_rms.var
        print('Normalize Obs Enabled')

    obs = env.reset()
    # The first entry is the state right after reset; one more state is
    # appended after every environment step.
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                # Small epsilon keeps the division finite for zero variance.
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    # `rollout` also contains the reset state, so the number of steps (and of
    # measure vectors summed above) is one less than its length. Dividing by
    # len(rollout) would bias the average low by a factor of T / (T + 1).
    num_steps = len(rollout) - 1
    measures /= num_steps
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu(), measures.detach().cpu()

In [None]:
# diffusion model params
latent_diffusion = True
use_ddim = True
latent_channels = latent_size = 4
timesteps = 600

# Arguments stored as JSON next to the diffusion-model run, wrapped in an
# AttrDict so entries can be read as attributes.
cfg_path = 'results/humanoid/diffusion_model/humanoid_diffusion_model_20230501-174628_0/args.json'
with open(cfg_path, 'r') as f:
    cfg = AttrDict(json.load(f))

# The latent scale factor is only read from the config for latent diffusion.
if latent_diffusion:
    scale_factor = cfg.scale_factor
else:
    scale_factor = None

# From here on `device` is a plain string: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

betas = cosine_beta_schedule(timesteps)

In [None]:
# Checkpoint files for the diffusion model and for the autoencoder (VAE).
# The literals are split per path component purely for readability; adjacent
# string literals are concatenated into a single path.
model_path = (
    'results/humanoid/diffusion_model/'
    'humanoid_diffusion_model_20230501-174628_0/'
    'model_checkpoints/'
    'humanoid_diffusion_model_20230501-174628_0.pt'
)
autoencoder_path = (
    'results/humanoid/autoencoder/'
    'humanoid_autoencoder_20230501-104217/'
    'model_checkpoints/'
    'humanoid_autoencoder_20230501-104217.pt'
)

In [None]:
# load the diffusion model
# Per-timestep log-variance tensor handed to the UNet, initialised to zeros.
logvar = torch.full(fill_value=0., size=(timesteps,))
model = ConditionalUNet(
    in_channels=latent_channels,
    out_channels=latent_channels,
    channels=64,
    n_res_blocks=1,
    attention_levels=[],
    channel_multipliers=[1, 2, 4],
    n_heads=4,
    d_cond=256,
    logvar=logvar
)
autoencoder = AutoEncoder(emb_channels=4,
                          z_channels=4,
                          obs_shape=obs_shape,
                          action_shape=action_shape,
                          z_height=4,
                          enc_fc_hid=64,
                          obsnorm_hid=64,
                          ghn_hid=8)
# `map_location` remaps the stored tensors onto `device`, so checkpoints saved
# from a GPU still load when `device` fell back to 'cpu'.
autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
autoencoder.to(device)
autoencoder.eval()

gauss_diff = LatentDiffusion(betas, num_timesteps=timesteps, device=device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# The notebook only samples from the model, so put it in inference mode just
# like the autoencoder above (disables dropout / batch-norm updates, if any).
model.eval()



In [None]:
# DDIM sampler built on top of the diffusion object created above.
# NOTE(review): n_steps=100 is presumably the number of sampling steps (fewer
# than the 600 `timesteps` of the diffusion process) -- confirm in DDIMSampler.
ddim_sampler = DDIMSampler(gauss_diff, n_steps=100)

In [None]:
# Conditioning passed to the sampler: 64 rows with two entries each, every
# entry set to 0.9, placed on `device`.
cond = torch.full((64, 2), 0.9).to(device)

In [None]:
# Draw a batch of 64 latents of size (latent_channels, latent_size, latent_size)
# with the DDIM sampler, conditioned on `cond`.
shape = [64, latent_channels] + [latent_size] * 2
# The raw samples are multiplied by 1 / scale_factor before decoding
# (presumably undoing the latent scaling used at training time -- confirm).
samples = ddim_sampler.sample(model, shape=shape, cond=cond) * (1 / scale_factor)
# The decoder returns the reconstructed agents plus their observation-normalizer
# parameters; the latter are wrapped in a TensorDict for indexing below.
rec_agents, decoded_obsnorms = autoencoder.decode(samples)
obsnorms = TensorDict(decoded_obsnorms)

In [None]:
# Pick one of the 64 decoded agents at random.
random_idx = torch.randint(0, 64, (1,))
print(f'{random_idx=}')
print(len(obsnorms))
rec_agent = rec_agents[random_idx]
agent_obsnorm = obsnorms[random_idx]

# State dict for the observation normalizer: the decoded mean, the variance
# recovered from the decoded log-std (var = exp(2 * logstd)), and a zero count.
obsnorm_sd = dict([
    ('obs_rms.mean', agent_obsnorm['obs_normalizer.obs_rms.mean'].flatten()),
    ('obs_rms.var', (2 * agent_obsnorm['obs_normalizer.obs_rms.logstd']).exp().flatten()),
    ('obs_rms.count', torch.zeros(1)),
])
obs_normalizer = ObsNormalizer(obs_shape)
obs_normalizer = obs_normalizer.to(device)
obs_normalizer.load_state_dict(obsnorm_sd)

# Attach the normalizer so `enjoy_brax` can read its running statistics.
rec_agent.obs_normalizer = obs_normalizer
enjoy_brax(rec_agent)

In [None]:
def rollout_n_times(agent: Actor, N = 20, target_measure: float = 0.9):
    """Roll out ``agent`` ``N`` times and report how close its measures are to a target.

    Each episode is run with ``enjoy_brax(agent, render=False)``. The function
    prints the stacked measures' shape, their mean over episodes, the target
    tensor, and the mean-squared error between the observed measures and the
    target, separately for measure dimension 0 and dimension 1.

    Args:
        agent: policy to evaluate; forwarded unchanged to ``enjoy_brax``.
        N: number of episodes to run.
        target_measure: value every measure is compared against. Defaults to
            0.9, the value used for the conditioning tensor in this notebook.
    """
    rews, measures = [], []
    for _ in range(N):
        episode_reward, episode_measures = enjoy_brax(agent, render=False)
        rews.append(episode_reward)
        measures.append(episode_measures)

    # Rewards are stacked alongside the measures but are not reported below.
    rews = torch.stack(rews)
    measures = torch.stack(measures)
    print(f'{measures.shape}')
    print(f'{measures.mean(0)=}')
    # Target tensor with the same shape as the observed measures.
    m_cond = torch.ones_like(measures) * target_measure
    print(m_cond)
    mse = torch.nn.functional.mse_loss(measures[:,0], m_cond[:, 0])
    mse2 = torch.nn.functional.mse_loss(measures[:,1], m_cond[:, 1])
    print(f'{mse=}, {mse2=}')
