In [None]:
from pathlib import Path

import torch
from matplotlib import pyplot as plt

from diffusion_co_design.wfcrl.schema import (
    ScenarioConfig,
    TrainingConfig,
    Diffusion,
    ClassifierConfig,
)
from diffusion_co_design.wfcrl.diffusion.generator import (
    OptimizerDetails,
    soft_projection_constraint,
    eval_to_train,
)
from diffusion_co_design.wfcrl.design import FixedDesigner, DesignerRegistry
from diffusion_co_design.wfcrl.env import create_env, _create_designable_windfarm
from diffusion_co_design.wfcrl.model.rl import wfcrl_models

In [None]:
# Training run whose Hydra config (and, optionally, critic checkpoint) is inspected here.
# Resolved from the user's home directory instead of a hard-coded /home/<user> path.
TRAIN_RUN_DIR = (
    Path.home()
    / ".diffusion_co_design"
    / "experiments"
    / "train_wfcrl"
    / "2025-05-29"
    / "12-39-25"
)
# Set to True to also restore the trained critic weights from TRAIN_RUN_DIR.
LOAD_CRITIC = False

train_cfg = TrainingConfig.from_file(str(TRAIN_RUN_DIR / ".hydra" / "config.yaml"))
scenario = train_cfg.scenario

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Reference environment built with a fixed designer; it is used to construct the
# policy/critic networks and to sanity-check the environment specs.
designer = FixedDesigner(scenario=scenario)
env = create_env(mode="reference", scenario=scenario, designer=designer, device=device)
env.check_env_specs()
policy, critic = wfcrl_models(
    env,
    train_cfg.policy,
    normalisation=train_cfg.normalisation,
    device=device,
)

# Diffusion-based designer; its value model is sized by the ClassifierConfig below
# and must match the architecture of the designer checkpoint loaded later.
diffusion, _ = DesignerRegistry.get(
    designer=Diffusion(
        type="diffusion",
        model=ClassifierConfig(
            node_emb_size=64,
            edge_emb_size=32,
            depth=2,
        ),
    ),
    artifact_dir=".",
    normalisation_statistics=train_cfg.normalisation,
    scenario=scenario,
    ppo_cfg=train_cfg.ppo,
    device=device,
)

if LOAD_CRITIC:
    # map_location allows a GPU-saved checkpoint to load on the CPU fallback above.
    critic.load_state_dict(
        torch.load(
            TRAIN_RUN_DIR / "checkpoints" / "critic_200.pt",
            map_location=device,
        )
    )

In [None]:
# Designer (value model) checkpoint from a separate training run.
DESIGNER_CHECKPOINT = (
    Path.home()
    / ".diffusion_co_design"
    / "experiments"
    / "train_wfcrl"
    / "2025-05-29"
    / "13-24-21"
    / "checkpoints"
    / "designer_150.pt"
)

model = diffusion.master_designer.model
# map_location lets a checkpoint saved on GPU load on the CPU fallback chosen earlier;
# without it torch.load fails on CUDA-less machines.
model.load_state_dict(torch.load(DESIGNER_CHECKPOINT, map_location=device))

generator = diffusion.master_designer.generator
generator.guidance_weight = 2.0

# Guided-sampling settings passed to generate_batch as an override.
operation = OptimizerDetails()
operation.num_recurrences = 4
operation.lr = 0.01
operation.backward_steps = 8
operation.use_forward = True
operation.projection_constraint = soft_projection_constraint(scenario)

batch = generator.generate_batch(
    value=model, use_operation=True, operation_override=operation, batch_size=9
)

# Score the generated layouts with the value model. This is evaluation only, so no
# autograd graph is needed. as_tensor avoids a copy warning if batch is already a tensor.
X = eval_to_train(torch.as_tensor(batch, device=device), cfg=scenario)
with torch.no_grad():
    y = model(X)
print(f"mean predicted value over {len(batch)} layouts: {y.mean().item():.4f}")

In [None]:
# Render every generated layout into a single titled grid figure.
# The loop uses `windfarm` (not `env`) so the torchrl environment built in the
# setup cell is not silently overwritten.
assert len(batch) > 0, "generate_batch returned no layouts to render"

MAX_GRID_COLS = 3
n_layouts = len(batch)
n_cols = min(MAX_GRID_COLS, n_layouts)
n_rows = -(-n_layouts // n_cols)  # ceiling division

fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
for ax in axes.flat:
    ax.axis("off")  # also hides any unused cells in the last row

for layout_idx, (ax, x) in enumerate(zip(axes.flat, batch)):
    # Column 0 holds x-coordinates and column 1 y-coordinates of each turbine.
    windfarm = _create_designable_windfarm(
        scenario=scenario,
        initial_xcoords=x[:, 0].tolist(),
        initial_ycoords=x[:, 1].tolist(),
        render=True,
    )
    windfarm.reset()
    ax.imshow(windfarm.render(), aspect="auto")
    ax.set_title(f"layout {layout_idx}")

fig.tight_layout()
plt.show()
