In [None]:
import os
import torch
import shutil
from matplotlib import pyplot as plt
from diffusion_co_design.wfcrl.schema import ScenarioConfig
from diffusion_co_design.wfcrl.diffusion.generator import (
    Generator,
    OptimizerDetails,
    soft_projection_constraint,
)
from diffusion_co_design.wfcrl.env import render_layout
from diffusion_co_design.common import OUTPUT_DIR, get_latest_model

SCENARIO = "wfcrl_10"
scenario = ScenarioConfig.from_file(f"conf/{SCENARIO}.yaml")

# Clear the "__simul__" directory left behind by earlier runs, if it is present
# (presumably simulator scratch output — verify). An unconditional rmtree raises
# FileNotFoundError when the directory is missing, which would break
# Restart & Run All on a fresh checkout.
if os.path.isdir("__simul__"):
    shutil.rmtree("__simul__")

In [None]:
# Load the "model" checkpoint selected by get_latest_model from this scenario's
# pretraining directory, and wrap it in a Generator built with batch_size=9.
pretrain_dir = os.path.join(OUTPUT_DIR, "wfcrl", "diffusion", SCENARIO)
latest_checkpoint = get_latest_model(pretrain_dir, "model")
generator = Generator(
    generator_model_path=latest_checkpoint,
    scenario=scenario,
    batch_size=9,
)

# Sampling overrides handed to generate_batch via operation_override:
# the scenario's soft projection constraint, num_recurrences=2 and
# backward_steps=0 (exact semantics live in OptimizerDetails / the generator).
operation = OptimizerDetails()
operation.projection_constraint = soft_projection_constraint(scenario)
operation.num_recurrences = 2
operation.backward_steps = 0


class PlaceholderValueFn(torch.nn.Module):
    """Stand-in value function: scores each sample by the sum of its entries.

    Everything after the leading (batch) dimension is flattened, then summed,
    giving one scalar per sample. It carries no learned parameters.
    """

    def forward(self, x):
        per_sample = torch.flatten(x, 1)  # (batch, everything_else)
        return torch.sum(per_sample, dim=-1)


placeholder_value_fn = PlaceholderValueFn()

# Diffusion sampling is stochastic; seed torch (all devices) so re-running this
# cell reproduces the same batch. Assumes the generator draws its noise from
# torch's global RNG — verify if outputs still differ between runs.
SEED = 0
torch.manual_seed(SEED)

batch = generator.generate_batch(
    batch_size=9,  # keep in sync with the batch_size passed to Generator(...)
    value=placeholder_value_fn,
    use_operation=True,
    operation_override=operation,
)

In [None]:
# Render one generated layout; the index is an arbitrary pick from the batch.
LAYOUT_INDEX = 1
x = batch[LAYOUT_INDEX]

fig, ax = plt.subplots(1, 1, figsize=(5, 10))
ax.imshow(render_layout(x, scenario))
ax.set_title(f"{SCENARIO}: generated layout #{LAYOUT_INDEX}")
plt.show()  # display the figure without the trailing AxesImage repr

In [None]:
# Raw values of the layout rendered in the previous cell (element of `batch`;
# exact shape/meaning defined by Generator.generate_batch — TODO confirm).
x