In [None]:
import os
from pathlib import Path
# Run everything from the project root so that the relative paths used later
# ('data/...', './checkpoints/...') resolve. The default location is
# ~/diffusion_models; set the DIFFUSION_MODELS_ROOT environment variable to
# point somewhere else.
project_root = os.environ.get('DIFFUSION_MODELS_ROOT', str(Path.home() / 'diffusion_models'))
os.chdir(project_root)
os.getcwd()  # should be the diffusion_models project root

In [None]:
import torch
import pickle
import json

from diffusion.gaussian_diffusion import cosine_beta_schedule, linear_beta_schedule, GaussianDiffusion
from diffusion.latent_diffusion import LatentDiffusion
from diffusion.ddim import DDIMSampler
from autoencoders.policy.hypernet import HypernetAutoEncoder as AutoEncoder
from attrdict import AttrDict
from RL.actor_critic import Actor
from models.unet import Unet
from envs.brax_custom.brax_env import make_vec_env_brax
from IPython.display import HTML, Image
from IPython.display import display
from brax.io import html, image

In [None]:
# params to config
# Fall back to CPU so the notebook can still run on a machine without a GPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
env_name = 'halfcheetah'
seed = 1111
# Seed torch so that torch.randint and diffusion sampling later in the notebook
# are repeatable across runs (torch.manual_seed seeds all devices).
torch.manual_seed(seed)
normalize_obs = False
normalize_rewards = False
obs_shape = 18
action_shape = 6
mlp_shape = (128, 128, 6)

# Environment config consumed by make_vec_env_brax.
env_cfg = AttrDict({
    'env_name': env_name,
    'env_batch_size': None,
    'num_dims': 2,
    'seed': seed,
    'num_envs': 1
})

In [None]:
def _unpickle(path):
    """Open ``path`` in binary mode and return the object pickled inside it.

    NOTE: pickle.load can execute arbitrary code; only use this on trusted files.
    """
    with open(path, mode='rb') as handle:
        return pickle.load(handle)


archive_df_path = 'data/archive_100x100_no_obs_norm.pkl'
archive_df = _unpickle(archive_df_path)

scheduler_path = 'data/scheduler_no_obs_norm.pkl'
scheduler = _unpickle(scheduler_path)

In [None]:
# make the env
# Create the environment described by env_cfg via make_vec_env_brax.
# enjoy_brax() resets and steps this global `env` directly.
env = make_vec_env_brax(env_cfg)

In [None]:
def enjoy_brax(agent, render=True, deterministic=True):
    """Run one episode of ``agent`` in the global ``env`` and report statistics.

    Reads the notebook globals ``env``, ``env_cfg``, ``device`` and
    ``normalize_obs``.

    Args:
        agent: policy to evaluate. If ``deterministic`` is True, its
            ``actor_mean`` is called on each observation. Otherwise
            ``agent.get_action(obs)`` is called and the first of its three
            return values is used as the action. When ``normalize_obs`` is True,
            ``agent.obs_normalizer.obs_rms`` must expose ``mean`` and ``var``.
        render: if True, display an inline HTML rendering of the episode.
        deterministic: pick ``actor_mean`` (True) or ``get_action`` (False).

    Returns:
        The summed episode reward, detached and converted with ``.numpy()``.

    Side effects:
        Prints the total reward, the rollout length (number of stored states,
        including the initial reset state) and the per-step mean of
        ``info['measures']``.
    """
    if normalize_obs:
        obs_mean, obs_var = agent.obs_normalizer.obs_rms.mean, agent.obs_normalizer.obs_rms.var
        print(f'{obs_mean=}')

    obs = env.reset()
    # Holds the reset state plus one state per env step (used for rendering).
    rollout = [env.unwrapped._state]
    total_reward = 0
    measures = torch.zeros(env_cfg.num_dims).to(device)
    num_steps = 0
    done = False
    # Evaluation only: no autograd graph is needed anywhere in the episode.
    with torch.no_grad():
        while not done:
            obs = obs.unsqueeze(dim=0).to(device)
            if normalize_obs:
                obs = (obs - obs_mean) / torch.sqrt(obs_var + 1e-8)

            if deterministic:
                act = agent.actor_mean(obs)
            else:
                act, _, _ = agent.get_action(obs)
            act = act.squeeze()
            obs, rew, done, info = env.step(act.cpu())
            measures += info['measures']
            rollout.append(env.unwrapped._state)
            total_reward += rew
            num_steps += 1
    if render:
        rendered = HTML(html.render(env.unwrapped._env.sys, [s.qp for s in rollout]))
        display(rendered)
    print(f'{total_reward=}')
    print(f' Rollout length: {len(rollout)}')
    # Average over env steps. `rollout` also contains the initial reset state,
    # so dividing by len(rollout) would bias the mean toward zero.
    # The loop always runs at least once, so num_steps >= 1.
    measures /= num_steps
    print(f'Measures: {measures.cpu().numpy()}')
    return total_reward.detach().cpu().numpy()

In [None]:
# --- diffusion model params ------------------------------------------------
latent_diffusion = True
use_ddim = True  # NOTE(review): appears unused; sampling below always builds a DDIMSampler
latent_channels = 8
latent_size = 4
timesteps = 600

# Config JSON stored alongside the checkpoints, wrapped for attribute access.
cfg_path = './checkpoints/cfg.json'
with open(cfg_path, 'r') as f:
    cfg = AttrDict(json.load(f))

# scale_factor is only taken from the config when latent diffusion is enabled.
if latent_diffusion:
    scale_factor = cfg.scale_factor
else:
    scale_factor = None

# Note: this rebinds `device` to a plain string ('cuda' or 'cpu').
device = 'cuda' if torch.cuda.is_available() else 'cpu'

betas = cosine_beta_schedule(timesteps)

In [None]:
# Checkpoint locations, relative to the project root.
model_path, autoencoder_path = (
    './checkpoints/model_cp.pt',     # diffusion model (Unet) weights
    './checkpoints/autoencoder.pt',  # autoencoder (VAE) weights
)

In [None]:
# load the diffusion model (Unet denoiser) and the policy autoencoder
logvar = torch.full(fill_value=0., size=(timesteps,))
model = Unet(
    dim=64,
    channels=latent_channels,
    dim_mults=(1, 2, 4,),
    logvar=logvar
)
autoencoder = AutoEncoder(emb_channels=8, z_channels=4)
# map_location lets checkpoints saved on a GPU load on the CPU fallback device.
autoencoder.load_state_dict(torch.load(autoencoder_path, map_location=device))
autoencoder.to(device)
autoencoder.eval()

gauss_diff = LatentDiffusion(betas, num_timesteps=timesteps, device=device)
model.load_state_dict(torch.load(model_path, map_location=device))
model.to(device)
# Inference only: put the denoiser in eval mode, matching the autoencoder.
model.eval()



In [None]:
# DDIM sampler over the latent diffusion process
# (n_steps=50 is presumably the number of DDIM sampling steps).
ddim_sampler = DDIMSampler(gauss_diff, n_steps=50)

# Index of one policy out of the batch of 64 drawn in the sampling cell.
random_idx = torch.randint(low=0, high=64, size=(1,))
print(f'{random_idx=}')

In [None]:
# Draw a batch of 64 latents with DDIM, divide out scale_factor (presumably
# the latent scaling used at training time) and decode them with the autoencoder.
# no_grad: sampling is inference-only, so skip building an autograd graph
# through every denoising step and the decoder (saves GPU memory).
with torch.no_grad():
    samples = ddim_sampler.sample(model, shape=[64, latent_channels, latent_size, latent_size], cond=None)
    samples = samples * (1. / scale_factor)
    samples = autoencoder.decode(samples)

In [None]:
# Roll out the randomly selected decoded agent, rendering the episode and
# acting with its mean action.
rec_agent = samples[random_idx]
enjoy_brax(rec_agent, render=True, deterministic=True)