In [None]:
import os
import torch
from matplotlib import pyplot as plt

from experiments.train_wfcrl.main import TrainingConfig
from conf.schema import Config
from diffusion_co_design.common import cuda as device
from diffusion_co_design.wfcrl.model.classifier import GNNCritic
from diffusion_co_design.wfcrl.model.rl import maybe_make_denormaliser
from diffusion_co_design.wfcrl.diffusion.generator import eval_to_train

from diffusion_co_design.wfcrl.schema import ScenarioConfig
from diffusion_co_design.wfcrl.diffusion.generator import (
    Generator,
    OptimizerDetails,
    soft_projection_constraint,
)
from diffusion_co_design.wfcrl.env import _create_designable_windfarm
from diffusion_co_design.common import OUTPUT_DIR, get_latest_model

from dataset import load_dataset


# Hydra run directory of the RL training run whose returns/critic we analyse.
# Set the WFCRL_TRAINING_DIR environment variable to point at a different run
# (or a different machine) without editing this cell; the default is the
# original author's run.
TRAINING_DIR_DEFAULT = (
    "/home/markhaoxiang/.diffusion_co_design/experiments/train_wfcrl/2025-05-28/13-55-27"
)
training_dir = os.environ.get("WFCRL_TRAINING_DIR", TRAINING_DIR_DEFAULT)
training_cfg = TrainingConfig.from_file(
    os.path.join(training_dir, ".hydra", "config.yaml")
)

# Train/eval split of sampled layouts (20% held out).
# NOTE(review): recompute=False presumably reuses a cached dataset when one
# exists — confirm against dataset.load_dataset.
train_dataset, eval_dataset = load_dataset(
    scenario=training_cfg.scenario,
    training_dir=training_dir,
    dataset_size=10_000,
    num_workers=25,
    test_proportion=0.2,
    recompute=False,
    device=device,
)

In [None]:
# Distribution of sampled episode returns vs the critic's expected return
# for the same stored layouts.
env_returns = train_dataset.dataset.env_returns.storage

episode_reward = env_returns["episode_reward"]
expected_reward = env_returns["expected_reward"]

# .numpy(force=True) detaches and copies to CPU, so this also works when the
# storage lives on the GPU; ravel() gives the 1-D input a multi-dataset hist needs.
episode_np = episode_reward.numpy(force=True).ravel()
expected_np = expected_reward.numpy(force=True).ravel()

fig, ax = plt.subplots(1, 1)
# Passing both datasets in one call makes matplotlib compute shared bin edges;
# step outlines keep one histogram from hiding the other.
ax.hist(
    [episode_np, expected_np],
    bins=30,
    histtype="step",
    label=["Sample", "Critic"],
)
ax.set(xlabel="Return", ylabel="Count", title="Sampled vs critic-estimated return")
ax.legend();

In [None]:
# Per-layout agreement between the sampled episode return and the critic's
# estimate. Point order does not matter for a scatter, so no sorting is needed.
fig, ax = plt.subplots(1, 1)
ax.scatter(
    episode_reward.numpy(force=True),
    expected_reward.numpy(force=True),
    s=4,
)
# Points on this line mean the critic matches the sampled return exactly.
ax.axline((0, 0), slope=1, color="grey", linestyle="--", label="y = x")
ax.set(
    xlabel="Sampled episode reward",
    ylabel="Critic expected reward",
    title="Critic estimate vs sampled return",
)
ax.legend();

In [None]:
# Critic architecture sizes come from the local config; the scenario and the
# output (de)normalisation come from the training run.
cfg = Config.from_file("conf/config.yaml")

model = torch.nn.Sequential(
    GNNCritic(
        cfg=training_cfg.scenario,
        node_emb_dim=cfg.model.node_emb_size,
        edge_emb_dim=cfg.model.edge_emb_size,
        n_layers=cfg.model.depth,
    ),
    maybe_make_denormaliser(training_cfg.normalisation),
).to(device=device)

# Candidate state_dicts for the model above; select one via CHECKPOINT_PATH.
# Classifier trained on sampled returns.
SAMPLING_CLASSIFIER_CKPT = "/home/markhaoxiang/.diffusion_co_design/experiments/train_wfcrl_classifier/2025-05-28/16-55-35/checkpoints/classifier.pt"
# Classifier trained on critic estimates.
CRITIC_CLASSIFIER_CKPT = "/home/markhaoxiang/.diffusion_co_design/experiments/train_wfcrl_classifier/2025-05-29/02-27-38/checkpoints/classifier.pt"
# Designer checkpoint saved during RL training.
DESIGNER_CKPT = "/home/markhaoxiang/.diffusion_co_design/experiments/train_wfcrl/2025-05-29/16-41-56/checkpoints/designer_100.pt"

CHECKPOINT_PATH = DESIGNER_CKPT

# map_location puts tensors directly on `device`, so a checkpoint saved on a
# different GPU (or loaded on a CPU-only machine) still loads.
model.load_state_dict(torch.load(CHECKPOINT_PATH, map_location=device))
# Evaluation mode: disables dropout / batch-norm statistic updates if the
# critic has any (no-op otherwise).
model.eval()

In [None]:
# Model predictions on a fixed subset of stored layouts vs the stored
# expected reward for those same layouts.
N_EVAL = 1024  # number of stored layouts to score

# The stored environments are cloned before conversion so the dataset is left
# untouched. NOTE(review): presumably guards against in-place edits in
# eval_to_train — confirm.
X = eval_to_train(env_returns["env"][:N_EVAL].clone(), training_cfg.scenario).to(
    device, torch.float32
)

# Inference only: no autograd graph is needed.
with torch.no_grad():
    y_pred = model(X)
actual_y = expected_reward[:N_EVAL]

fig, ax = plt.subplots(1, 1)
ax.scatter(actual_y.numpy(force=True), y_pred.numpy(force=True), s=4)
ax.set(
    xlabel="Stored expected reward",
    ylabel="Model prediction",
    title=f"Model vs stored expected reward (first {N_EVAL} layouts)",
);

In [None]:
# Pretrained layout diffusion model for this scenario; the latest checkpoint
# found in the directory is used.
pretrain_dir = os.path.join(
    OUTPUT_DIR, "wfcrl", "diffusion", training_cfg.scenario_name
)
latest_checkpoint = get_latest_model(pretrain_dir, "model")

N_DESIGNS = 9  # number of layouts to generate (rendered in the next cell)

generator = Generator(
    generator_model_path=latest_checkpoint,
    scenario=training_cfg.scenario,
    batch_size=N_DESIGNS,
    default_guidance_wt=5,
    device=device,
)
# NOTE(review): this presumably overrides default_guidance_wt=5 above —
# confirm against Generator and keep a single source for the value.
generator.guidance_weight = 2.0

# Guidance optimiser settings passed to generate_batch via operation_override.
operation = OptimizerDetails()
operation.projection_constraint = soft_projection_constraint(training_cfg.scenario)
operation.num_recurrences = 4
operation.lr = 0.01
operation.backward_steps = 8
operation.use_forward = True

# Sample layouts guided by `model`.
batch = generator.generate_batch(
    value=model,
    use_operation=True,
    operation_override=operation,
    batch_size=N_DESIGNS,
)

# Score the generated layouts with the guiding model. torch.tensor copies
# `batch`, so nothing done to X can alter the layouts rendered later.
with torch.no_grad():
    X = eval_to_train(torch.tensor(batch).to(device), cfg=training_cfg.scenario)
    y = model(X)
print(y.mean())


In [None]:
# Render every generated layout as one panel of a single grid figure
# (previously one separate figure per layout).
n_layouts = len(batch)
n_cols = max(1, min(3, n_layouts))
n_rows = max(1, -(-n_layouts // n_cols))  # ceiling division
fig, axes = plt.subplots(
    n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows), squeeze=False
)
for ax in axes.flat:
    ax.axis("off")  # also hides any unused trailing panels

for ax, layout in zip(axes.flat, batch):
    # Column 0 / column 1 of each layout are passed as the x / y coordinates.
    env = _create_designable_windfarm(
        scenario=training_cfg.scenario,
        initial_xcoords=layout[:, 0].tolist(),
        initial_ycoords=layout[:, 1].tolist(),
        render=True,
    )
    try:
        env.reset()
        ax.imshow(env.render(), aspect="auto")
    finally:
        # Best-effort release of rendering resources.
        # NOTE(review): assumes a gym-style close(); skipped if the env has none.
        close = getattr(env, "close", None)
        if callable(close):
            close()

fig.tight_layout()
