In [None]:
import os
import torch
from matplotlib import pyplot as plt
from diffusion_co_design.wfcrl.schema import ScenarioConfig
from diffusion_co_design.wfcrl.diffusion.generator import (
    Generator,
    OptimizerDetails,
    soft_projection_constraint,
)
from diffusion_co_design.wfcrl.env import _create_designable_windfarm
from diffusion_co_design.common import OUTPUT_DIR, get_latest_model

# ---- Configuration -----------------------------------------------------------
# SCENARIO picks both the YAML config under conf/ and the checkpoint folder
# under OUTPUT_DIR/wfcrl/diffusion/.
SCENARIO = "wfcrl_10"
BATCH_SIZE = 9  # layouts produced per Generator batch
GUIDANCE_WEIGHT = 5  # forwarded to Generator as guidance_wt

scenario = ScenarioConfig.from_file(f"conf/{SCENARIO}.yaml")

# Directory holding the pretrained diffusion checkpoints for this scenario;
# get_latest_model presumably returns the newest "model" checkpoint in it.
pretrain_dir = os.path.join(OUTPUT_DIR, "wfcrl", "diffusion", SCENARIO)
latest_checkpoint = get_latest_model(pretrain_dir, "model")

generator = Generator(
    generator_model_path=latest_checkpoint,
    scenario=scenario,
    batch_size=BATCH_SIZE,
    guidance_wt=GUIDANCE_WEIGHT,
)

# Guided-sampling overrides handed to generate_batch via operation_override.
# The soft projection constraint is built from the scenario; the exact meaning
# of num_recurrences / backward_steps is defined by OptimizerDetails.
operation = OptimizerDetails()
operation.projection_constraint = soft_projection_constraint(scenario)
operation.num_recurrences = 2
operation.backward_steps = 0

In [None]:
class PlaceholderValueFn(torch.nn.Module):
    """Dummy value function that scores each sample by the sum of its entries.

    Presumably a stand-in for a learned value model, used only to exercise
    the guided-generation path; it encodes no notion of layout quality.
    """

    def forward(self, x):
        # Merge all non-batch dimensions into one, then total them per sample,
        # giving one scalar score per leading-dimension entry.
        per_sample = torch.flatten(x, start_dim=1)
        return per_sample.sum(dim=-1)


# Seed torch so that re-running this cell reproduces the same sampled layouts.
# NOTE(review): assumes the generator draws its noise from torch's global RNG;
# if it also uses numpy/random, seed those too.
SEED = 0
torch.manual_seed(SEED)

placeholder_value_fn = PlaceholderValueFn()

# Sample a guided batch of layouts using the placeholder value function and
# the custom optimizer settings in `operation`.
batch = generator.generate_batch(
    batch_size=9,  # keep in sync with the batch_size given to Generator
    value=placeholder_value_fn,
    use_operation=True,
    operation_override=operation,
)

In [None]:
# Build a rendering-enabled wind farm from the first generated layout.
# NOTE(review): assumes each layout is (n_turbines, >=2) with x-coordinates in
# column 0 and y-coordinates in column 1, as the indexing below implies.
x = batch[0]
env = _create_designable_windfarm(
    scenario=scenario,
    initial_xcoords=x[:, 0].tolist(),
    initial_ycoords=x[:, 1].tolist(),
    render=True,
)

env.reset()

# NOTE(review): assumes env.render() returns an image array suitable for imshow.
fig, ax = plt.subplots()
ax.imshow(env.render())
ax.set_title(f"{SCENARIO}: generated layout 0")
ax.axis("off")  # pixel-index ticks carry no meaning for a rendered frame
plt.show()