In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch_geometric.nn import global_add_pool
from tqdm import tqdm
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from diffusion_co_design.pretrain.rware.transform import (
    graph_projection_constraint,
    storage_to_layout,
)
from experiments.train_rware.main import TrainingConfig, ScenarioConfig
from diffusion_co_design.common import (
    OUTPUT_DIR,
    omega_to_pydantic,
    get_latest_model,
    cuda,
)
from diffusion_co_design.pretrain.rware.graph import WarehouseGNNBase, E3GNNLayer
from guided_diffusion.script_util import create_classifier, classifier_defaults
from diffusion_co_design.pretrain.rware.generator import (
    Generator,
    OptimizerDetails,
)
from rware.warehouse import Warehouse

from dataset import (
    load_dataset,
    make_dataloader,
    CollateFn,
    ImageCollateFn,
    working_dir,
)
from diffusion_co_design.rware.model.classifier import GNNCNN

device = cuda

# Training run whose checkpoints and Hydra config are analysed in this notebook.
# Expanded from "~" instead of a hard-coded /home/<user> path so the notebook is portable.
training_dir = os.path.expanduser(
    "~/.diffusion_co_design/training/2025-04-05/04-00-12"
)  # Four corners

# Load latest policy checkpoint from the training run
checkpoint_dir = os.path.join(training_dir, "checkpoints")
latest_policy = get_latest_model(checkpoint_dir, "policy_")

# Rebuild the typed training config from the run's saved Hydra config
hydra_dir = os.path.join(training_dir, ".hydra")
cfg = omega_to_pydantic(
    OmegaConf.load(os.path.join(hydra_dir, "config.yaml")), TrainingConfig
)

# Pretrained graph diffusion model for this scenario (`pretrain_dir` is an alias of the same path)
diffusion_dir = pretrain_dir = os.path.join(
    OUTPUT_DIR, "diffusion_pretrain", "graph", cfg.scenario.name
)
latest_checkpoint = get_latest_model(diffusion_dir, "model")

# Build the train/eval splits (80/20) from the training run.
# NOTE(review): recompute=False presumably reuses a cached dataset when one exists — confirm in dataset.load_dataset.
train_dataset, eval_dataset = load_dataset(
    scenario=cfg.scenario,
    training_dir=training_dir,
    dataset_size=10_000,
    num_workers=25,
    test_proportion=0.2,
    recompute=False,
    device=device,
)


In [None]:
BATCH_SIZE = 128

# Graph-representation dataloaders for classifier training / evaluation.
# Both use BATCH_SIZE so changing the constant actually takes effect.
train_dataloader = make_dataloader(
    train_dataset,
    scenario=cfg.scenario,
    batch_size=BATCH_SIZE,
    representation="graph",
    device=device,
)

eval_dataloader = make_dataloader(
    eval_dataset,
    scenario=cfg.scenario,
    batch_size=BATCH_SIZE,
    representation="graph",
    device=device,
)

In [None]:
from diffusion_co_design.rware.model.classifier import make_model

# NOTE: this mutates cfg.scenario in place, so the change persists for all subsequent cells.
cfg.scenario.representation = "graph"
model = make_model("gnn-cnn", cfg.scenario, model_kwargs={}, device=device)

# Designer weights from a separate training run ("~" expanded rather than a hard-coded home path).
designer_checkpoint = os.path.expanduser(
    "~/.diffusion_co_design/training/2025-04-09/18-44-16/checkpoints/designer_1900.pt"
)
# map_location lets a checkpoint saved on another device (e.g. a GPU) be loaded onto `device`.
model.load_state_dict(torch.load(designer_checkpoint, map_location=device))

# Alternative: load a fully pickled classifier from the diffusion playground instead.
# model = torch.load(
#     "/home/markhaoxiang/.diffusion_co_design/experiments/diffusion_playground/gnn-cnn_graph/2025-04-06 20-41-44/checkpoints/classifier.pt",
#     weights_only=False,
# )

In [None]:
# Disabled experiment: gradient ascent directly on the layout positions `pos`.
# Starting from one training sample, Adam updates `pos` to maximise the mean of
# `model.predict` (loss = -y_pred.mean()), printing the loss and rendering the
# resulting warehouse every n_iterations // 10 iterations.
# NOTE(review): `collate_fn` is not defined anywhere in this notebook (only the
# `CollateFn` class is imported), so one must be constructed before re-enabling. TODO.
# NOTE(review): the projection step `constraint(...)` is itself commented out, so
# positions are unconstrained during the ascent.
# NOTE(review): the rendering maps pos via (pos + 1) / 2 * size, i.e. it assumes pos
# is normalised to [-1, 1] — confirm against the dataset's collate function.
# FIGURE_SIZE_CNST = 2.5

# test_layout = [next(iter(train_dataset))]
# test_layout, _ = collate_fn(test_layout)
# pos, color = test_layout

# pos.requires_grad = True
# pos_optim = torch.optim.Adam([pos], lr=0.01)

# constraint = graph_projection_constraint(cfg.scenario)

# n_iterations = 1000
# for iteration in range(n_iterations):
#     pos.requires_grad = True
#     pos_optim.zero_grad()
#     y_pred = model.predict((pos, color))
#     loss = -y_pred.mean()
#     loss.backward()
#     pos_optim.step()

#     if iteration % (n_iterations // 10) == 0:
#         print(f"Iteration {iteration} Loss: {loss.item()}")
#         # pos = constraint(pos.detach())

#         fig, ax = plt.subplots(figsize=(FIGURE_SIZE_CNST, FIGURE_SIZE_CNST))

#         show_pos = (pos.squeeze() + 1) / 2
#         show_pos = show_pos * cfg.scenario.size
#         layout = storage_to_layout(
#             features=show_pos.numpy(force=True),
#             config=cfg.scenario,
#             representation_override="graph",
#         )
#         warehouse = Warehouse(layout=layout, render_mode="rgb_array")
#         print(len(warehouse.shelves))
#         im = warehouse.render()
#         ax.imshow(im)
#         ax.axis("off")
#         plt.show()
#         warehouse.close()

In [None]:
# Disabled experiment: classifier-guided generation with the pretrained graph diffusion model.
# Builds a Generator from `latest_checkpoint`, guides it with `model` through the
# OptimizerDetails settings (graph_projection_constraint as the projection constraint),
# renders the generated layouts, and prints the guidance model's scores on the batch
# after rescaling coordinates from [0, size - 1] to [-1, 1].
# NOTE(review): show_batch always draws on a fixed 3x3 grid, so `n` must be <= 9 and
# must not exceed the batch size (generator batch_size=10 here).
# generator = Generator(
#     batch_size=10,
#     generator_model_path=latest_checkpoint,
#     scenario=cfg.scenario,
#     guidance_wt=4,
#     representation="graph",
# )
# # guidance_model = model
# guidance_model = model
# guidance_model.eval()
# operation = OptimizerDetails()
# operation.lr = 0.003
# operation.num_recurrences = 32
# operation.backward_steps = 80
# operation.projection_constraint = graph_projection_constraint(cfg.scenario)
# # operation.print = True
# # operation.print_every = 5
# # operation.folder = "test_diffusion"


# def show_batch(environment_batch, n: int = 8):
#     layouts = []
#     for image in environment_batch:
#         layout = storage_to_layout(image, cfg.scenario)
#         warehouse = Warehouse(layout=layout, render_mode="rgb_array")
#         layouts.append(warehouse.render())
#         warehouse.close()

#     fig, axs = plt.subplots(3, 3, figsize=(12, 12))
#     axs = axs.ravel()
#     for ax in axs:
#         ax.axis("off")
#     for i in range(n):
#         axs[i].imshow(layouts[i])
#     return fig, axs


# environment_batch = generator.generate_batch(
#     value=guidance_model,
#     use_operation=True,
#     operation_override=operation,
# )


# cfg.scenario.representation = "graph"
# for env in environment_batch:
#     layout = storage_to_layout(env, cfg.scenario)
#     print(len(layout.reset_shelves()))
# fig, axs = show_batch(environment_batch)
# fig.suptitle("Guided Generation")
# fig.tight_layout()

# X_batch = (
#     torch.from_numpy(environment_batch).to(device=device, dtype=torch.float32)
#     # .moveaxis((0, 1, 2, 3), (0, 2, 3, 1))
# )
# # X_batch = torch.cat([X_batch, goal_map.unsqueeze(0).expand(8, -1, -1, -1)], dim=1)
# X_batch = (X_batch / (cfg.scenario.size - 1)) * 2 - 1
# print(X_batch.shape)
# print(guidance_model(X_batch))

In [None]:
from diffusion_co_design.rware.model.classifier import GNNClassifier

# Fresh GNN classifier with default hyper-parameters (rebinds `model`).
model = GNNClassifier(cfg=cfg.scenario).to(device=device)

# Smoke test: predictions on one training batch must match the target shape.
(pos, color), y = next(iter(train_dataloader))

number_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {number_parameters}")

# No gradients are needed for a shape check, so skip building the autograd graph.
with torch.no_grad():
    y_pred = model.predict((pos, color))
assert y_pred.shape == y.shape, (
    f"prediction shape {tuple(y_pred.shape)} != target shape {tuple(y.shape)}"
)

In [None]:
from diffusion_co_design.pretrain.rware.graph import visualize_warehouse_graph

# Plot a single graph (index 6) from the `pos` / `color` smoke-test batch.
fig, ax = plt.subplots()
graph_batch = model.gnn.make_graph_batch_from_data(pos, color=color)[0]
data = graph_batch.to_data_list()[6]
visualize_warehouse_graph(data=data, ax=ax)

In [None]:
TRAIN_NUM_EPOCHS = 50
RECOMPUTE = True  # True: always retrain; False: reuse saved weights at `model_dir` if present

model = GNNClassifier(
    cfg=cfg.scenario, node_embedding_dim=512, edge_embedding_dim=32, num_layers=4
).to(device=device)
# NOTE: despite its name, `model_dir` is the path of the saved state-dict file.
model_dir = os.path.join(working_dir, "classifier_gnn.pt")

number_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {number_parameters}")


def _run_epoch(
    net: nn.Module,
    dataloader: DataLoader,
    criterion: nn.Module,
    optim: "torch.optim.Optimizer | None" = None,
) -> float:
    """Run one pass over `dataloader` and return the sample-weighted mean loss.

    Args:
        net: classifier exposing ``predict((pos, colors))``.
        dataloader: yields ``((pos, colors), y)`` batches; ``pos.shape[0]`` is the batch size.
        criterion: loss with mean reduction over a batch (e.g. ``nn.MSELoss()``).
        optim: if given, the net is put in train mode and updated after every batch;
            if ``None``, it is put in eval mode and run without gradients.

    Returns:
        Mean loss over all samples seen (0.0 for an empty dataloader). Weighting each
        batch by its size keeps a smaller final batch from skewing the epoch average.
    """
    training = optim is not None
    net.train(training)
    total_loss, total_samples = 0.0, 0
    with torch.set_grad_enabled(training):
        for (pos, colors), y in dataloader:
            batch_size = pos.shape[0]
            y_pred = net.predict((pos, colors))
            # Flatten per-sample outputs so predictions and targets align regardless of trailing dims.
            loss = criterion(y_pred.view(batch_size, -1), y.view(batch_size, -1))
            if training:
                optim.zero_grad()
                loss.backward()
                optim.step()
            total_loss += loss.item() * batch_size
            total_samples += batch_size
    return total_loss / max(total_samples, 1)


if RECOMPUTE or not os.path.exists(model_dir):
    optim = torch.optim.Adam(model.parameters(), lr=1e-3)
    criterion = torch.nn.MSELoss()

    train_losses = []
    eval_losses = []
    with tqdm(range(TRAIN_NUM_EPOCHS)) as pbar:
        for _ in pbar:
            running_train_loss = _run_epoch(model, train_dataloader, criterion, optim)
            running_eval_loss = _run_epoch(model, eval_dataloader, criterion)

            train_losses.append(running_train_loss)
            eval_losses.append(running_eval_loss)
            pbar.set_description(
                f" Train Loss {running_train_loss:.4g} Eval Loss {running_eval_loss:.4g}"
            )

    os.makedirs(os.path.dirname(model_dir) or ".", exist_ok=True)
    torch.save(model.state_dict(), model_dir)
else:
    # map_location lets weights saved on another device load onto `device`.
    model.load_state_dict(torch.load(model_dir, map_location=device))

# Leave the model in eval mode on both branches for downstream inference.
model.eval()