In [None]:
import os
import re
import shutil
from IPython.display import Video
from omegaconf import OmegaConf
import torch
from gymnasium.utils.save_video import save_video
from torchrl.envs import EnvBase
from diffusion_co_design.rware.env import create_env
from diffusion_co_design.rware.model import rware_models
from diffusion_co_design.utils import omega_to_pydantic, cuda, OUTPUT_DIR
from diffusion_co_design.bin.train_rware import TrainingConfig, DesignerRegistry

# Parameters
# NOTE(review): machine-specific absolute path to the run directory of the
# training job to visualise — edit before running on another machine.
training_dir = "/home/markhaoxiang/.diffusion_co_design/training/2025-02-17/09-40-27"
device = cuda

# Get latest policy.
# Checkpoints are expected to be named exactly `policy_<N>.pt`. `fullmatch` rejects
# near-misses such as `policy_10.pt.bak` or `policy_10.pth`, which a start-anchored
# `re.match` would have accepted and then handed to `torch.load`.
checkpoint_dir = os.path.join(training_dir, "checkpoints")
policy_pattern = re.compile(r"policy_(\d+)\.pt")
policy_files = [f for f in os.listdir(checkpoint_dir) if policy_pattern.fullmatch(f)]
if not policy_files:
    raise FileNotFoundError(f"No policy_<N>.pt checkpoints found in {checkpoint_dir}")
# Highest N wins; compare numerically so that policy_10 beats policy_9.
latest_policy = max(policy_files, key=lambda f: int(policy_pattern.fullmatch(f).group(1)))
latest_policy = os.path.join(checkpoint_dir, latest_policy)

# Get config: the training run's Hydra config lives at `<run>/.hydra/config.yaml`
# and is converted into the typed TrainingConfig model.
hydra_dir = os.path.join(training_dir, ".hydra")
training_config = os.path.join(hydra_dir, "config.yaml")
cfg = omega_to_pydantic(OmegaConf.load(training_config), TrainingConfig)

# Create environment
# Scratch directory handed to the designer and reused below for the rendered video.
# It is deleted and recreated on every run, so nothing from a previous execution leaks in.
cache_dir = os.path.join(OUTPUT_DIR, ".tmp")
if os.path.exists(cache_dir):
    shutil.rmtree(cache_dir)
os.makedirs(cache_dir)
master_designer, env_designer = DesignerRegistry.get(
    cfg.designer,
    cfg.scenario,
    cache_dir,
    environment_batch_size=32,  # NOTE(review): magic constant — confirm intended value
    device=device,
)
env = create_env(cfg.scenario, env_designer, render=True, device=device)
policy, _ = rware_models(env, cfg.policy, device=device)
# map_location restores the checkpoint's tensors directly onto `device` rather than
# onto whichever device they were saved from (which may not exist on this machine).
policy.load_state_dict(torch.load(latest_policy, map_location=device))


def view_video(env: EnvBase, policy, max_steps=None, fps: int = 10):
    """Roll out ``policy`` in ``env`` once, save the rendered frames as an MP4,
    and return a zero-argument callable that builds an embeddable ``Video``.

    Args:
        env: Environment with rendering enabled; ``env.render()`` is called from
            the rollout callback to capture each frame.
        policy: Policy module passed straight to ``env.rollout``.
        max_steps: Rollout horizon. ``None`` (default) uses the notebook-level
            ``cfg.scenario.max_steps``.
        fps: Frame rate of the written video (default 10).

    Returns:
        A thunk; calling it returns ``IPython.display.Video`` for the saved file,
        embedded in the notebook.
    """
    if max_steps is None:
        max_steps = cfg.scenario.max_steps

    video_folder = os.path.join(cache_dir, "video")
    # gymnasium's save_video writes `<name_prefix>-episode-<episode_index>.mp4`;
    # with its defaults (prefix "rl-video", episode 0) that is the name below.
    video_out = os.path.join(video_folder, "rl-video-episode-0.mp4")
    frames = []

    def append_frames(env, td):
        # Called by rollout with (env, tensordict); only the rendered frame is kept.
        frames.append(env.render())

    # Pure inference: avoid recording an autograd graph across the whole rollout.
    with torch.no_grad():
        env.rollout(
            max_steps=max_steps,
            policy=policy,
            callback=append_frames,
            auto_cast_to_device=True,
        )

    save_video(
        frames=frames,
        video_folder=video_folder,
        fps=fps,
    )

    return lambda: Video(filename=video_out, embed=True)

In [None]:
# Roll out the loaded policy once and embed the recorded episode in the notebook.
make_video = view_video(env, policy)
make_video()