In [1]:
import os
from pathlib import Path
project_root = os.path.join(str(Path.home()), 'diffusion_models')
os.chdir(project_root)
%pwd # should be PPGA root dir

'/home/sumeet/diffusion_models'

In [2]:
import torch
import pickle
import json

from diffusion.gaussian_diffusion import cosine_beta_schedule, linear_beta_schedule, GaussianDiffusion
from diffusion.latent_diffusion import LatentDiffusion
from diffusion.ddim import DDIMSampler
from autoencoders.policy.hypernet import HypernetAutoEncoder as AutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from models.cond_unet import ConditionalUNet
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image

In [3]:
# params to config
device = torch.device('cuda')
env_name = 'halfcheetah'
seed = 1111
normalize_obs = True
normalize_rewards = False
obs_shape = 18
action_shape = 6
mlp_shape = (128, 128, 6)

env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1
})

In [4]:
archive_df_path = 'data/archive_100x100_global.pkl'
with open(archive_df_path, 'rb') as f:
    archive_df = pickle.load(f)

scheduler_path = 'data/scheduler_global_obs_norm.pkl'
with open(scheduler_path, 'rb') as f:
    scheduler = pickle.load(f)



In [5]:
# make the env
env = make_vec_env_brax(env_cfg)

2023-04-01 17:53:57.666330: W external/org_tensorflow/tensorflow/compiler/xla/service/gpu/nvptx_compiler.cc:497] The NVIDIA driver's CUDA version is 11.8 which is older than the ptxas CUDA version (12.0.76). Because the driver is older than the ptxas version, XLA is disabling parallel compilation, which may slow down compilation. You should update your NVIDIA driver or use the NVIDIA-provided CUDA forward compatibility packages.


In [17]:
def enjoy_brax(agent, render=True, deterministic=True):
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
        print(f'{obs_mean=}')

    obs = env.reset()
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    done = False
    while not done:
        with torch.no_grad():
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
    if render:
        i = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(i)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    measures /= len(rollout)
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu(), measures.detach().cpu()

In [7]:
# diffusion model params
latent_diffusion = True
use_ddim = True
latent_channels = 8
latent_size = 4
timesteps = 600

cfg_path = './checkpoints/cfg.json'
with open(cfg_path, 'r') as f:
    cfg = json.load(f)
    cfg = AttrDict(cfg)

scale_factor = cfg.scale_factor if latent_diffusion else None

device = 'cuda' if torch.cuda.is_available() else 'cpu'

betas = cosine_beta_schedule(timesteps)

In [8]:
# paths to VAE and diffusion model checkpoint
model_path = './checkpoints/model_cp.pt'
autoencoder_path = './checkpoints/autoencoder_20230401-162949_autoencoder.pt'

In [9]:
# load the diffusion model
logvar = torch.full(fill_value=0., size=(timesteps,))
model = ConditionalUNet(
    in_channels=latent_channels,
    out_channels=latent_channels,
    channels=64,
    n_res_blocks=1,
    attention_levels=[],
    channel_multipliers=[1, 2, 4],
    n_heads=4,
    d_cond=256,
    logvar=logvar
)
autoencoder = AutoEncoder(emb_channels=8, z_channels=4)
autoencoder.load_state_dict(torch.load(autoencoder_path))
autoencoder.to(device)
autoencoder.eval()

gauss_diff = LatentDiffusion(betas, num_timesteps=timesteps, device=device)
model.load_state_dict(torch.load(model_path))
model.to(device)



ConditionalUNet(
  (cond_embed): Sequential(
    (0): Linear(in_features=2, out_features=256, bias=True)
    (1): SiLU()
    (2): Linear(in_features=256, out_features=256, bias=True)
  )
  (time_embed): Sequential(
    (0): SinusoidalPositionEmbeddings()
    (1): Linear(in_features=64, out_features=256, bias=True)
    (2): SiLU()
    (3): Linear(in_features=256, out_features=256, bias=True)
  )
  (input_blocks): ModuleList(
    (0): TimestepEmbedSequential(
      (0): Conv2d(8, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    )
    (1): TimestepEmbedSequential(
      (0): ResBlock(
        (in_layers): Sequential(
          (0): GroupNorm32(32, 64, eps=1e-05, affine=True)
          (1): SiLU()
          (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        )
        (emb_layers): Sequential(
          (0): SiLU()
          (1): Linear(in_features=256, out_features=64, bias=True)
        )
        (out_layers): Sequential(
          (0): GroupNorm32(32

In [34]:
ddim_sampler = DDIMSampler(gauss_diff, n_steps=100)

random_idx = torch.randint(0, 64, (1,))
print(f'{random_idx=}')

random_idx=tensor([58])


In [33]:
cond = torch.ones((64, 2)) * 0.9
cond = cond.to(device)

In [28]:
samples = ddim_sampler.sample(model, shape=[64, latent_channels, latent_size, latent_size], cond=cond)
samples = samples * (1 / scale_factor)
samples = autoencoder.decode(samples)

In [35]:
rec_agent = samples[random_idx]
obs_normalizer = scheduler.archive.sample_elites(1)[-1][0]['obs_normalizer']
rec_agent.obs_normalizer = obs_normalizer
enjoy_brax(rec_agent)

obs_mean=tensor([ 0.4914,  0.8485,  0.0244,  0.1310, -0.0145,  0.1651,  0.0988, -0.7122,
        -0.2895,  4.3650, -0.0157,  0.0596, -0.0773,  0.1568, -0.7128,  1.0650,
        -1.7570,  0.3130], device='cuda:0')


total_reward=tensor(2155.4854, device='cuda:0')
 Rollout length: 1001
Measures: [0.8321678  0.85814184]


(tensor(2155.4854), tensor([0.8322, 0.8581]))

In [40]:
def rollout_n_times(agent: Actor, N = 20):
    rews, measures = [], []
    for _ in range(N):
        f, m = enjoy_brax(agent, render=False)
        rews.append(f)
        measures.append(m)

    rews = torch.stack(rews)
    measures = torch.stack(measures)
    print(f'{measures.shape}')
    print(f'{measures.mean(0)=}')
    m_cond = torch.ones_like(measures) * 0.9
    print(m_cond)
    mse = torch.nn.functional.mse_loss(measures[:,0], m_cond[:, 0])
    mse2 = torch.nn.functional.mse_loss(measures[:,1], m_cond[:, 1])
    print(f'{mse=}, {mse2=}')


In [41]:
rollout_n_times(rec_agent, 20)

obs_mean=tensor([ 0.4914,  0.8485,  0.0244,  0.1310, -0.0145,  0.1651,  0.0988, -0.7122,
        -0.2895,  4.3650, -0.0157,  0.0596, -0.0773,  0.1568, -0.7128,  1.0650,
        -1.7570,  0.3130], device='cuda:0')
total_reward=tensor(2051.5850, device='cuda:0')
 Rollout length: 1001
Measures: [0.8151848 0.8321678]
obs_mean=tensor([ 0.4914,  0.8485,  0.0244,  0.1310, -0.0145,  0.1651,  0.0988, -0.7122,
        -0.2895,  4.3650, -0.0157,  0.0596, -0.0773,  0.1568, -0.7128,  1.0650,
        -1.7570,  0.3130], device='cuda:0')
total_reward=tensor(2209.2290, device='cuda:0')
 Rollout length: 1001
Measures: [0.80419576 0.8261738 ]
obs_mean=tensor([ 0.4914,  0.8485,  0.0244,  0.1310, -0.0145,  0.1651,  0.0988, -0.7122,
        -0.2895,  4.3650, -0.0157,  0.0596, -0.0773,  0.1568, -0.7128,  1.0650,
        -1.7570,  0.3130], device='cuda:0')
total_reward=tensor(2458.4268, device='cuda:0')
 Rollout length: 1001
Measures: [0.7492507 0.7852148]
obs_mean=tensor([ 0.4914,  0.8485,  0.0244,  0.1310, 