In [None]:
from pathlib import Path

import torch
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

from diffusion_co_design.wfcrl.schema import ScenarioConfig, RLConfig
from diffusion_co_design.wfcrl.design import FixedDesigner, RandomDesigner
from diffusion_co_design.wfcrl.env import create_env
from diffusion_co_design.wfcrl.model import wfcrl_models


In [None]:
# Scenario under evaluation: 10 turbines on a 1000 x 1000 map, episodes capped at 150 steps,
# turbines kept at least 100 units apart.
scenario = ScenarioConfig(
    n_turbines=10,
    max_steps=150,
    map_x_length=1000,
    map_y_length=1000,
    min_distance_between_turbines=100,
)

# Seeded layout designer feeding the "reference" environment; validate the env specs up front
# so spec mismatches surface here rather than mid-rollout.
designer = FixedDesigner(scenario, seed=0)
env = create_env(mode="reference", scenario=scenario, designer=designer)
env.check_env_specs()

# Model architecture hyper-parameters. These must agree with the training run whose
# checkpoints are loaded below, otherwise `load_state_dict` will reject the weights.
rl_config = RLConfig(
    backbone_depth=3,
    edge_hidden_size=32,
    mlp_hidden_size=128,
    node_hidden_size=128,
)
policy, critic = wfcrl_models(env, rl_config, "cpu")

# Checkpoint directory of the training run being evaluated. Resolved from the user's home
# directory rather than a hard-coded /home/<user> path so the notebook runs on other machines.
CHECKPOINT_DIR = (
    Path.home()
    / ".diffusion_co_design"
    / "experiments"
    / "train_wfcrl"
    / "2025-04-30"
    / "02-21-47"
    / "checkpoints"
)
base = str(CHECKPOINT_DIR)  # kept as a string for any cells that build paths by concatenation

# NOTE(review): policy and critic come from different iterations (1900 vs 1980) — confirm
# this pairing is intentional.
POLICY_CHECKPOINT = CHECKPOINT_DIR / "policy_1900.pt"
CRITIC_CHECKPOINT = CHECKPOINT_DIR / "critic_1980.pt"

# The models above are built on "cpu"; map_location makes checkpoints saved from a GPU run
# load on a CPU-only machine instead of failing on CUDA deserialization.
# torch.load unpickles the file, so only point it at trusted checkpoints.
policy.load_state_dict(torch.load(POLICY_CHECKPOINT, map_location="cpu"))

critic.load_state_dict(torch.load(CRITIC_CHECKPOINT, map_location="cpu"))

# Turbine positions from the designer. Indexed below as a 2-column array of (x, y) map
# coordinates, one row per turbine (presumably n_turbines rows — confirm against FixedDesigner).
coords = designer.layout_image

# Circle radius = half the minimum allowed spacing, so two circles overlap only when their
# turbines are closer than `min_distance_between_turbines` (100 here -> radius 50).
spacing_radius = scenario.min_distance_between_turbines / 2

fig, ax = plt.subplots(figsize=(6, 6))

# Scatter points
ax.scatter(coords[:, 0], coords[:, 1], alpha=0.7, edgecolors="k", zorder=2)

# Draw half-spacing circles underneath the points
for x, y in coords:
    circle = Circle(
        (x, y),
        radius=spacing_radius,
        edgecolor="r",
        facecolor="none",
        linewidth=1.2,
        zorder=1,
    )
    ax.add_patch(circle)

ax.set_title("2D Coordinate Scatter Plot with Circles")
ax.set_xlabel("X")
ax.set_ylabel("Y")
# `ax.axis("equal")` uses adjustable="datalim" and may override the explicit limits set
# below; adjusting the box instead keeps the full map extent AND a 1:1 aspect (round circles).
ax.set_aspect("equal", adjustable="box")
ax.grid(True)
ax.set_xlim(0, scenario.map_x_length)
ax.set_ylim(0, scenario.map_y_length)
plt.show()

In [None]:
# Evaluate the loaded policy for up to the scenario's episode cap (previously a duplicated
# literal 150). No gradients are needed for evaluation, so autograd is disabled to avoid
# building a graph across every step of the rollout.
# NOTE(review): if a deterministic (mean-action) evaluation is intended, also wrap this in
# torchrl's `set_exploration_type(...)` — confirm how the policy selects actions.
with torch.no_grad():
    td = env.rollout(max_steps=scenario.max_steps, policy=policy)

In [None]:
# TorchRL rollouts store each step's outcome (reward, done, terminated, next obs) under
# "next". The root-level "terminated" is the flag of the state each step started from
# (reset / previous step), so it stays False and cannot show the episode terminating.
td["next", "terminated"]