In [None]:
import os
import torch
import matplotlib.pyplot as plt
from diffusion_co_design.common import get_latest_model, cuda

from diffusion_co_design.common.design import DesignerParams
from diffusion_co_design.vmas.model.rl import create_critic
from diffusion_co_design.vmas.schema import (
    TrainingConfig,
    EnvCriticConfig,
    DiffusionOperation,
)
from diffusion_co_design.vmas.design import DicodeDesigner, RandomDesigner
from diffusion_co_design.vmas.scenario.env import create_env, render_layout

# Device on which the designer, the reference environment and the critic
# are placed.
device = cuda

# Output directory of the training run under inspection; the Hydra config
# and the model checkpoints are both read from here.
training_dir = (
    "/home/markhaoxiang/.diffusion_co_design/experiments/train_vmas/2025-08-28/23-03-28"
)

# Rebuild the training configuration from the Hydra config stored with the run.
hydra_config_path = os.path.join(training_dir, ".hydra", "config.yaml")
cfg = TrainingConfig.from_file(hydra_config_path)

# Random designer used as the comparison point for the guided designer.
baseline = RandomDesigner(DesignerParams.placeholder(cfg.scenario))

# Placeholder designer parameters for the scenario of the inspected run.
designer_params = DesignerParams.placeholder(cfg.scenario)

# Classifier hyper-parameters are copied from the critic section of the
# training config so that the designer's model can receive the critic's
# weights further down.
critic_cfg = cfg.policy.critic
classifier_cfg = EnvCriticConfig(
    depth=critic_cfg.depth,
    hidden_size=critic_cfg.hidden_size,
    k=critic_cfg.k,
)

# Guidance settings for the diffusion sampler.
diffusion_op = DiffusionOperation(
    num_recurrences=8,
    backward_lr=0.01,
    backward_steps=6,
    forward_guidance_wt=50,
    forward_guidance_annealing=False,
)

designer = DicodeDesigner(
    designer_setting=designer_params,
    classifier=classifier_cfg,
    diffusion=diffusion_op,
    device=device,
)

# Reference environment; here it is only used to construct the critic.
ref_env = create_env(
    mode="reference",
    scenario=cfg.scenario,
    designer=RandomDesigner(DesignerParams.placeholder(cfg.scenario)).get_placeholder(),
    device=device,
)

critic = create_critic(
    env=ref_env, scenario=cfg.scenario, cfg=cfg.policy.critic, device=device
)

# Read the most recent critic checkpoint of the run. map_location remaps the
# stored tensors onto `device`, so the load also works when the checkpoint
# was written on a different device.
state_dict = torch.load(
    get_latest_model(dir=os.path.join(training_dir, "checkpoints"), prefix="critic_"),
    map_location=device,
)

# Apply the checkpoint to the critic before copying its weights. Previously
# the loaded state dict was never used, so the designer received the freshly
# initialised (untrained) critic parameters.
# NOTE(review): assumes the checkpoint stores the state dict of the object
# returned by create_critic — confirm against the training script.
critic.load_state_dict(state_dict)

# Copy the weights of the critic's inner module into the designer's
# guidance model.
designer.model.load_state_dict(critic.module.state_dict())

# Number of layouts sampled per designer; also the number of subplot columns.
num_layouts = 9

# Sample one batch of layouts from each designer and make sure both batches
# live on the same device as the model that scores them.
baseline_envs = torch.stack(
    baseline.generate_layout_batch(batch_size=num_layouts)
).to(device)
envs = torch.stack(designer.generate_layout_batch(batch_size=num_layouts)).to(device)

# Mean model output for each batch. Gradients are not needed for this
# read-only evaluation, so autograd tracking is switched off.
with torch.no_grad():
    print(designer.model(baseline_envs).mean())
    print(designer.model(envs).mean())

# Top row: baseline layouts. Bottom row: layouts from the guided designer.
# squeeze=False keeps `axes` two-dimensional even for a single column.
fig, axes = plt.subplots(2, num_layouts, figsize=(18, 4), squeeze=False)
for row, layouts in enumerate((baseline_envs, envs)):
    for col in range(num_layouts):
        axes[row, col].imshow(render_layout(x=layouts[col], scenario=cfg.scenario))
        axes[row, col].axis("off")