In [None]:
import os
import shutil
from IPython.display import Video
from omegaconf import OmegaConf
from tqdm import tqdm
from gymnasium.utils.save_video import save_video
import torch
from tensordict import TensorDict
from torchrl.envs import EnvBase
from torchrl.data import ReplayBuffer, LazyTensorStorage, SamplerWithoutReplacement
from matplotlib import pyplot as plt
from guided_diffusion.script_util import create_classifier, classifier_defaults
from rware.warehouse import Warehouse
import pickle as pkl
from diffusion_co_design.rware.env import create_env, create_batched_env
from diffusion_co_design.rware.model import rware_models
from diffusion_co_design.utils import (
    omega_to_pydantic,
    get_latest_model,
    cuda,
    OUTPUT_DIR,
)
from diffusion_co_design.pretrain.rware.transform import storage_to_layout
from diffusion_co_design.bin.train_rware import TrainingConfig, DesignerRegistry
from diffusion_co_design.pretrain.rware.generator import Generator, GeneratorConfig

# Side length (in inches) given to every subplot when sizing figures.
FIGURE_SIZE_CNST = 3

# Parameters
# Directory of the training run to inspect. Set the
# DIFFUSION_CO_DESIGN_TRAINING_DIR environment variable to point the notebook
# at another run without editing this cell; the default is the original run.
training_dir = os.environ.get(
    "DIFFUSION_CO_DESIGN_TRAINING_DIR",
    "/home/markhaoxiang/.diffusion_co_design/training/2025-02-17/09-40-27",
)
device = cuda

# Get latest policy
checkpoint_dir = os.path.join(training_dir, "checkpoints")
latest_policy = get_latest_model(checkpoint_dir, "policy_")

# Get config
hydra_dir = os.path.join(training_dir, ".hydra")
training_config = os.path.join(hydra_dir, "config.yaml")
cfg = omega_to_pydantic(OmegaConf.load(training_config), TrainingConfig)

# Create environment
# The cache directory is wiped and recreated on every run so that nothing left
# over from a previous run (e.g. an old video) can be picked up by mistake.
cache_dir = os.path.join(OUTPUT_DIR, ".tmp")
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)
os.makedirs(cache_dir)
master_designer, env_designer = DesignerRegistry.get(
    "random",
    cfg.scenario,
    cache_dir,
    environment_batch_size=32,
    device=device,
)
env = create_env(cfg.scenario, env_designer, render=True, device=device)
policy, _ = rware_models(env, cfg.policy, device=device)
# map_location remaps the checkpoint tensors onto `device`, so loading does not
# depend on the device the checkpoint happened to be saved from.
policy.load_state_dict(torch.load(latest_policy, map_location=device))


def view_video(env: EnvBase, policy, fps: int = 10):
    """Roll out ``policy`` in ``env``, save the rendered frames as an mp4.

    One frame is captured through the rollout callback (``env.render()``) each
    time the callback fires, for at most ``cfg.scenario.max_steps`` steps. The
    video is written under ``<cache_dir>/video``.

    Args:
        env: Environment to roll out; it must support ``render()``.
        policy: Policy passed to ``env.rollout``.
        fps: Frame rate of the saved video. Defaults to 10.

    Returns:
        A zero-argument callable that builds an embedded
        ``IPython.display.Video`` for the saved file.
    """
    video_folder = os.path.join(cache_dir, "video")
    # save_video names its output "<name_prefix>-episode-<episode_index>.mp4".
    # Both values are passed explicitly so the path returned below is
    # guaranteed to match the file that was written.
    name_prefix = "rl-video"
    episode_index = 0
    video_out = os.path.join(
        video_folder, f"{name_prefix}-episode-{episode_index}.mp4"
    )

    frames = []

    def append_frames(env, td):
        frames.append(env.render())

    # Only the rendered frames are needed, so skip building an autograd graph.
    with torch.no_grad():
        env.rollout(
            max_steps=cfg.scenario.max_steps,
            policy=policy,
            callback=append_frames,
            auto_cast_to_device=True,
        )

    save_video(
        frames=frames,
        video_folder=video_folder,
        fps=fps,
        name_prefix=name_prefix,
        episode_index=episode_index,
    )

    return lambda: Video(filename=video_out, embed=True)

In [2]:
# view_video(env, policy)()

In [None]:
# Test if the value function is able to discern good environments!
# Build a dataset of (environment state, mean episode reward) pairs by rolling
# out the loaded policy in a batch of environments.

NUM_PARALLEL_COLLECTION = 25
DATASET_SIZE = 10_000
BATCH_SIZE = 128

# A smaller DATASET_SIZE gives a quick smoke test, but it must allow at least
# one collection round: otherwise the loop below never runs and the buffer
# stays empty.
assert DATASET_SIZE >= NUM_PARALLEL_COLLECTION, (
    "DATASET_SIZE must be at least NUM_PARALLEL_COLLECTION"
)

collection_env = create_batched_env(
    num_environments=NUM_PARALLEL_COLLECTION,
    scenario=cfg.scenario,
    designer=env_designer,
    is_eval=False,
    device="cpu",
)

env_returns = ReplayBuffer(
    storage=LazyTensorStorage(max_size=DATASET_SIZE),
    sampler=SamplerWithoutReplacement(),
    batch_size=BATCH_SIZE,
)

for _ in tqdm(range(DATASET_SIZE // NUM_PARALLEL_COLLECTION)):
    # Pure data collection: no gradients are needed through the policy.
    with torch.no_grad():
        rollout = collection_env.rollout(
            max_steps=cfg.scenario.max_steps, policy=policy, auto_cast_to_device=True
        )
    done = rollout.get(("next", "done"))
    # Drop only the trailing singleton dimension of `done`. A bare squeeze()
    # would also remove a batch or time dimension of size 1 and the mask would
    # no longer line up with the leading dimensions of "state".
    done_mask = done.squeeze(-1)
    # Keep only the steps at which an episode finished.
    X = rollout.get("state")[done_mask]
    # Average the episode reward over the agent dimension, then select the
    # same finished steps.
    y = rollout.get(("next", "agents", "episode_reward")).mean(-2)[done]
    data = TensorDict({"env": X, "episode_reward": y}, batch_size=len(y))
    env_returns.extend(data)
del rollout, done

In [None]:
# Train a timestep-conditioned value model that regresses the mean episode
# reward from a noised environment layout.
GUIDANCE_WT = 50
TRAIN_NUM_ITERATIONS = 10000
VALUE_LR = 3e-4
VALUE_WEIGHT_DECAY = 0.05
OUTPUT_CHANNELS = 1

# Create value model
pretrain_dir = os.path.join(OUTPUT_DIR, "diffusion_pretrain", cfg.scenario.name)
latest_checkpoint = get_latest_model(pretrain_dir, "model")

# The generator is used here only for its timestep sampler and its forward
# diffusion process (q_sample).
gen_cfg = GeneratorConfig(
    batch_size=8,
    generator_model_path=latest_checkpoint,
    size=cfg.scenario.size,
    num_channels=OUTPUT_CHANNELS,
)
generator = Generator(gen_cfg, guidance_wt=GUIDANCE_WT)

model_dict = classifier_defaults()
model_dict["image_size"] = cfg.scenario.size
model_dict["image_channels"] = OUTPUT_CHANNELS
model_dict["classifier_width"] = 256
model_dict["classifier_depth"] = 2
model_dict["classifier_attention_resolutions"] = "16, 8, 4"
model_dict["output_dim"] = 1

model = create_classifier(**model_dict).to(device)

# Train
optim = torch.optim.Adam(
    model.parameters(), lr=VALUE_LR, weight_decay=VALUE_WEIGHT_DECAY
)
criterion = torch.nn.MSELoss()

model.train()
# Per-iteration training loss, kept so the curve can be inspected afterwards.
losses = []
with tqdm(range(TRAIN_NUM_ITERATIONS)) as pbar:
    # Iterating the bar itself advances it, so no manual update() is needed.
    for _ in pbar:
        optim.zero_grad()
        sample = env_returns.sample(batch_size=BATCH_SIZE)
        X_batch = sample.get("env").to(dtype=torch.float32, device=device)
        y_batch = sample.get("episode_reward").to(dtype=torch.float32, device=device)
        # Sample one diffusion timestep per example and noise the layouts with
        # the generator's forward process at those timesteps.
        t, _ = generator.schedule_sampler.sample(len(X_batch), device)
        X_batch = generator.diffusion.q_sample(X_batch, t)
        # Remove only the trailing output dimension. A bare squeeze() would
        # turn a batch of size 1 into a scalar, which then broadcasts silently
        # against y_batch inside the loss.
        y_pred = model(X_batch, t).squeeze(-1)
        loss = criterion(y_pred, y_batch)
        loss.backward()
        optim.step()

        loss_value = loss.item()
        losses.append(loss_value)
        pbar.set_description(f"Loss {loss_value}")

In [None]:
torch.cuda.empty_cache()


def _render_layout_row(axes_row, indices):
    """Render the stored layouts at ``indices`` onto one row of axes.

    For each index the stored episode reward is printed, the stored
    environment is converted to a layout, and that layout is rendered in a
    ``Warehouse`` and drawn on the matching axis.

    Args:
        axes_row: Matplotlib axes, one per layout.
        indices: Indices into ``env_returns.storage``.
    """
    for ax, idx in zip(axes_row, indices):
        layout = storage_to_layout(
            env_returns.storage["env"][idx],
            cfg.scenario.agent_idxs,
            cfg.scenario.goal_idxs,
        )
        print(env_returns.storage["episode_reward"][idx])
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        try:
            ax.imshow(warehouse.render())
        finally:
            # Always release the renderer, even if rendering fails.
            warehouse.close()
        ax.axis("off")


# Rank every collected environment by its mean episode reward, best first.
env_returns_sorted_index = torch.argsort(
    env_returns.storage["episode_reward"], descending=True
)
best_5 = env_returns_sorted_index[:5]
worst_5 = env_returns_sorted_index[-5:]

fig, axs = plt.subplots(2, 5)
fig.set_size_inches(5 * FIGURE_SIZE_CNST, 2 * FIGURE_SIZE_CNST)
fig.suptitle("Top row: 5 highest-return layouts. Bottom row: 5 lowest-return layouts.")

_render_layout_row(axs[0], best_5)

print("===")

_render_layout_row(axs[1], worst_5)

In [None]:
# Value-model predictions for the five highest-return layouts.
best_5_env = env_returns.storage["env"][best_5].to(device=device, dtype=torch.float32)
# The model is timestep-conditioned (it is trained as model(x, t)), so a
# timestep must be supplied. The layouts here are not noised, so timestep 0 is
# used. NOTE(review): assumes index 0 is the lowest-noise step of the schedule.
best_5_t = torch.zeros(len(best_5_env), dtype=torch.long, device=device)
# Inference only: switch off train-mode behaviour and gradient tracking.
model.eval()
with torch.no_grad():
    best_5_pred = model(best_5_env, best_5_t)
best_5_pred

In [None]:
# Value-model predictions for the five lowest-return layouts.
worst_5_env = env_returns.storage["env"][worst_5].to(device=device, dtype=torch.float32)
# The model is timestep-conditioned (it is trained as model(x, t)), so a
# timestep must be supplied. The layouts here are not noised, so timestep 0 is
# used. NOTE(review): assumes index 0 is the lowest-noise step of the schedule.
worst_5_t = torch.zeros(len(worst_5_env), dtype=torch.long, device=device)
# Inference only: switch off train-mode behaviour and gradient tracking.
model.eval()
with torch.no_grad():
    worst_5_pred = model(worst_5_env, worst_5_t)
worst_5_pred