In [None]:
from omegaconf import OmegaConf
import torch
import numpy as np

from diffusion_co_design.pretrain.rware.generate import (
    generate,
    WarehouseRandomGeneratorConfig,
)
from diffusion_co_design.pretrain.rware.generator import (
    create_model_and_diffusion_rware,
)
from diffusion_co_design.pretrain.rware.graph import WarehouseDiffusionModel
from diffusion_co_design.utils import omega_to_pydantic

# Load the `rware_16_50_5_4_corners` scenario YAML and validate it into the
# generator's pydantic config type.
config = omega_to_pydantic(
    OmegaConf.load(
        "../diffusion_co_design/bin/conf/scenario/rware_16_50_5_4_corners.yaml"
    ),
    WarehouseRandomGeneratorConfig,
)

# Draw a small batch (n=2) of random warehouse layouts in the graph
# representation, using the scenario's layout parameters.
samples = generate(
    n=2,
    size=config.size,
    n_shelves=config.n_shelves,
    agent_idxs=config.agent_idxs,
    goal_idxs=config.goal_idxs,
    n_colors=config.n_colors,
    training_dataset=True,
    representation="graph",
)

# Stack through numpy before converting: `samples` is presumably a list of
# ndarrays (TODO confirm), and torch.tensor on a list of ndarrays is slow and
# warns. np.array stacks them into one array first.
out = torch.tensor(np.array(samples))

model = WarehouseDiffusionModel(config)

# This forward pass only inspects the output shape, so skip autograd
# bookkeeping (no gradient graph is built or kept alive).
with torch.no_grad():
    pred = model(out)
pred.shape

In [None]:
import torch
from torch_geometric.nn import radius_graph

# Sanity-check radius_graph on the four corners of a 2x2 square centred at the
# origin. Adjacent corners are 2.0 apart, which is inside r=2.2. Diagonal
# corners are 2*sqrt(2) ~= 2.83 apart, which is outside r=2.2. Self-loops are
# excluded. So each node links to exactly two neighbours, giving
# 4 * 2 = 8 directed edges.
x = torch.tensor([[-1.0, -1.0], [-1.0, 1.0], [1.0, -1.0], [1.0, 1.0]])
batch = torch.zeros(x.size(0), dtype=torch.long)  # every node belongs to graph 0
edge_index = radius_graph(x, r=2.2, batch=batch, loop=False)
assert edge_index.shape == (2, 8), f"unexpected edge_index shape {tuple(edge_index.shape)}"
edge_index.shape


In [None]:
# Compare trainable-parameter totals for the same scenario across the three
# input representations. The loop uses its own local name so it does not
# overwrite `model` from the earlier cell.
param_counts = {}
for representation in ["image", "flat", "graph"]:
    # Only element [0] of the returned sequence is used; it is presumably the
    # network of a (model, diffusion) pair (TODO confirm).
    rep_model = create_model_and_diffusion_rware(
        scenario=config,
        representation=representation,
    )[0]
    param_counts[representation] = sum(p.numel() for p in rep_model.parameters())
param_counts