In [None]:
import os
import torch
import matplotlib.pyplot as plt
from experiments.train_rware.main import TrainingConfig
from diffusion_co_design.common import OUTPUT_DIR, get_latest_model, cuda

from diffusion_co_design.rware.model.classifier import make_model
from diffusion_co_design.rware.diffusion.transform import (
    storage_to_layout,
    graph_projection_constraint,
    image_projection_constraint,
)
from diffusion_co_design.rware.diffusion.generator import Generator, OptimizerDetails
from rware.warehouse import Warehouse

from dataset import load_dataset, make_dataloader


device = cuda

# RL training run whose config (scenario) this notebook analyses ("Four corners").
# Override with the RWARE_TRAINING_DIR environment variable to point at another run
# without editing the notebook; the default is the original hard-coded run.
training_dir = os.environ.get(
    "RWARE_TRAINING_DIR",
    "/home/markhaoxiang/.diffusion_co_design/experiments/train_rware/2025-04-18/00-36-30",
)
representation = "image"
# Fail fast: the model-loading and generation cells only handle these two encodings.
if representation not in ("graph", "image"):
    raise ValueError(f"Unsupported representation: {representation!r}")

# Load latest model and config
cfg = TrainingConfig.from_file(os.path.join(training_dir, ".hydra", "config.yaml"))
diffusion_dir = pretrain_dir = os.path.join(
    OUTPUT_DIR, "rware", "diffusion", representation, cfg.scenario.name
)
latest_diffusion_checkpoint = get_latest_model(diffusion_dir, "model")

# Make Dataset
# Single source of truth for the loader batch size (previously duplicated as literals).
BATCH_SIZE = 128

# presumably test_proportion=0.2 holds out 20% of samples as the eval split, and
# recompute=False reuses a previously computed dataset — confirm in dataset.load_dataset.
train_dataset, eval_dataset = load_dataset(
    scenario=cfg.scenario,
    training_dir=training_dir,
    dataset_size=10_000,
    num_workers=25,
    test_proportion=0.2,
    recompute=False,
    device=device,
)

train_dataloader = make_dataloader(
    train_dataset,
    scenario=cfg.scenario,
    batch_size=BATCH_SIZE,
    representation=representation,
    device=device,
)

eval_dataloader = make_dataloader(
    eval_dataset,
    scenario=cfg.scenario,
    batch_size=BATCH_SIZE,
    representation=representation,
    device=device,
)

match representation:
    case "graph":
        model = make_model(
            "gnn-cnn",
            cfg.scenario,
            model_kwargs={"add_goal_positions": False},
            device=device,
        )

        model.load_state_dict(
            torch.load(
                "/home/markhaoxiang/.diffusion_co_design/experiments/train_rware_classifier/2025-04-22/18-45-03/checkpoints/classifier.pt"
            )
        )
    case "image":
        model = make_model(
            "cnn",
            cfg.scenario,
            device=device,
        )

        model.load_state_dict(
            torch.load(
                "/home/markhaoxiang/.diffusion_co_design/experiments/train_rware_classifier/2025-04-23/09-58-41/checkpoints/classifier.pt"
            )
        )

In [None]:
from torchvision.models.vision_transformer import VisionTransformer

# Quick size check: parameter count of a small ViT on 16x16 inputs with 8x8 patches.
vit_config = dict(
    image_size=16,
    patch_size=8,
    num_classes=1,
    num_heads=8,
    hidden_dim=128,
    mlp_dim=128,
    num_layers=4,
)
vt = VisionTransformer(**vit_config)
vit_param_count = sum(param.numel() for param in vt.parameters())
print(vit_param_count)

In [None]:
import torch.nn as nn


# Candidate CNN classifier architecture, built here only to inspect its size.
# NOTE(review): this rebinds `model` to a freshly constructed network (no checkpoint is
# loaded here), replacing the classifier loaded in the first cell; the histogram and
# guided-generation cells below use `model` and will then see this network.
cnn_model_kwargs = {
    "model_channels": 64,
    "channel_mult": (1, 2, 2, 2),
    "num_attention_head_channels": 64,
    "resblock_updown": True,
    "attention_resolutions": (16, 8, 4),
    "depthwise_separable": True,
}
model = make_model("cnn", cfg.scenario, device=device, model_kwargs=cnn_model_kwargs)


def summarize_parameters(module: nn.Module, name: str = "", indent: int = 0):
    """Recursively summarize parameters in a PyTorch module."""
    total_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
    submodules = list(module.named_children())

    # Print current module
    prefix = " " * (indent * 2)
    print(f"{prefix}{name or module.__class__.__name__}: {total_params:,} params")

    # Recurse into children
    for child_name, child_module in submodules:
        summarize_parameters(child_module, name=child_name, indent=indent + 1)


# Total trainable parameters, followed by a per-submodule breakdown.
n_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(n_trainable)

summarize_parameters(model)

In [None]:
# Display the `input_blocks` submodule of the current `model` (rich repr as cell output).
model.input_blocks

In [None]:
# Trainable-parameter count of the current `model`.
trainable_count = sum(param.numel() for param in model.parameters() if param.requires_grad)
print(trainable_count)

In [None]:
# Compare the distribution of recorded episode rewards, critic estimates, and the
# classifier's predictions.
# NOTE(review): `model` is whichever network was bound last; if the architecture
# exploration cell above has run, it is the freshly built CNN, not the checkpointed
# classifier. Re-run the first cell before this one.
# NOTE(review): if load_dataset returns torch Subsets, `.dataset` is the underlying
# full dataset, so the reward stats may cover all samples while the classifier only
# sees the eval split — confirm in dataset.load_dataset.
eval_storage = eval_dataset.dataset.env_returns.storage

fig, axs = plt.subplots(1, 3, figsize=(15, 5))
axs[0].hist(eval_storage["episode_reward"], bins=100)
axs[0].set_title("Episode Reward")
axs[1].hist(eval_storage["expected_reward"], bins=100)
axs[1].set_title("Expected Reward (Critic)")

# Inference mode: without this, normalization/dropout layers (if any) would use
# batch statistics / randomness and update running stats during this pass.
model.eval()
classifier_returns = []
with torch.no_grad():
    for x, _, _ in eval_dataloader:
        classifier_returns.append(model(x.to(device)))
classifier_returns = torch.cat(classifier_returns, dim=0)
axs[2].hist(classifier_returns.cpu().numpy(), bins=100)
axs[2].set_title("Classifier Returns")

for ax in axs:
    ax.set_xlabel("Return")
    ax.set_ylabel("Count")

# Means are taken from the same eval source that is plotted, so the three are comparable.
print("Mean Episode Reward: ", eval_storage["episode_reward"].mean())
print("Mean Expected Reward: ", eval_storage["expected_reward"].mean())
print("Mean Classifier Returns: ", classifier_returns.mean().item())


In [None]:
generator = Generator(
    batch_size=10,
    generator_model_path=latest_diffusion_checkpoint,
    scenario=cfg.scenario,
    guidance_wt=200 if representation == "image" else 5.0,
    representation=representation,
)
guidance_model = model
guidance_model.eval()

operation = OptimizerDetails()
match representation:
    case "graph":
        operation.lr = 0.01
        operation.num_recurrences = 8
        operation.backward_steps = 16
        operation.projection_constraint = graph_projection_constraint(cfg.scenario)
    case "image":
        operation.num_recurrences = 8
        operation.backward_steps = 0
        operation.projection_constraint = image_projection_constraint(cfg.scenario)


def show_batch(
    environment_batch,
    representation,
    n: int = 8,
):
    """Render the first ``n`` environments of a generated batch in a grid of subplots.

    Each element is converted to a layout with ``storage_to_layout`` (using the
    notebook-level ``cfg.scenario``) and rendered as an RGB array by ``Warehouse``.

    Args:
        environment_batch: Indexable batch of environment encodings, e.g. the output of
            ``generator.generate_batch``.
        representation: Encoding of each element (this notebook uses "graph" or "image").
        n: Maximum number of environments to show; clamped to the batch size.

    Returns:
        ``(fig, axs)``: the figure and a flat array of its axes. With the default
        ``n=8`` the grid is 3x3 at 12x12 inches, as before.

    Raises:
        ValueError: if there is nothing to show (empty batch or ``n < 1``).
    """
    import math

    n = min(n, len(environment_batch))
    if n < 1:
        raise ValueError("show_batch needs at least one environment to display")

    # Only render what will be displayed; close each Warehouse even if rendering fails.
    images = []
    for theta in environment_batch[:n]:
        layout = storage_to_layout(theta, cfg.scenario, representation=representation)
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        try:
            images.append(warehouse.render())
        finally:
            warehouse.close()

    # Smallest near-square grid that fits n images (n=8 -> 3x3, the previous fixed layout).
    ncols = math.ceil(math.sqrt(n))
    nrows = math.ceil(n / ncols)
    fig, axs = plt.subplots(
        nrows, ncols, figsize=(4 * ncols, 4 * nrows), squeeze=False
    )
    axs = axs.ravel()
    for ax in axs:
        ax.axis("off")
    for ax, image in zip(axs, images):
        ax.imshow(image)
    return fig, axs


environment_batch = generator.generate_batch(
    value=guidance_model,
    use_operation=True,
    operation_override=operation,
)


for env in environment_batch:
    layout = storage_to_layout(env, cfg.scenario, representation=representation)
    print(len(layout.reset_shelves()))
fig, axs = show_batch(environment_batch, representation)
fig.suptitle("Guided Generation")
fig.tight_layout()

X_batch = torch.from_numpy(environment_batch).to(device=device, dtype=torch.float32)
match representation:
    case "graph":
        X_batch = (X_batch / (cfg.scenario.size - 1)) * 2 - 1
    case "image":
        X_batch = X_batch * 2 - 1

print(X_batch.shape)
print(guidance_model(X_batch))

In [None]:
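# NOTE: Disabled experiment — optimizes graph positions directly by gradient ascent on
# the classifier's prediction (loss = -y_pred.mean()). It is stale: `collate_fn` is not
# imported anywhere in this notebook, and `storage_to_layout` is called elsewhere with
# `representation=`, not `representation_override=`. Update both before re-enabling.
# FIGURE_SIZE_CNST = 2.5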

# test_layout = [next(iter(train_dataset))]
# test_layout, _ = collate_fn(test_layout)
# pos, color = test_layout

# pos.requires_grad = True
# pos_optim = torch.optim.Adam([pos], lr=0.01)

# constraint = graph_projection_constraint(cfg.scenario)

# n_iterations = 1000
# for iteration in range(n_iterations):
#     pos.requires_grad = True
#     pos_optim.zero_grad()
#     y_pred = model.predict((pos, color))
#     loss = -y_pred.mean()
#     loss.backward()
#     pos_optim.step()

#     if iteration % (n_iterations // 10) == 0:
#         print(f"Iteration {iteration} Loss: {loss.item()}")
#         # pos = constraint(pos.detach())

#         fig, ax = plt.subplots(figsize=(FIGURE_SIZE_CNST, FIGURE_SIZE_CNST))

#         show_pos = (pos.squeeze() + 1) / 2
#         show_pos = show_pos * cfg.scenario.size
#         layout = storage_to_layout(
#             features=show_pos.numpy(force=True),
#             config=cfg.scenario,
#             representation_override="graph",
#         )
#         warehouse = Warehouse(layout=layout, render_mode="rgb_array")
#         print(len(warehouse.shelves))
#         im = warehouse.render()
#         ax.imshow(im)
#         ax.axis("off")
#         plt.show()
#         warehouse.close()

In [None]:
from diffusion_co_design.rware.model.classifier import GNNClassifier

# Smoke test for the standalone GNN classifier: build it, push one training batch
# through `predict`, and check the prediction shape matches the targets.
# Only meaningful for the graph representation, whose inputs are (pos, color) pairs.
if representation == "graph":
    # Separate name so this cell does not clobber the classifier `model` used above.
    gnn_model = GNNClassifier(cfg=cfg.scenario).to(device=device)

    # Batches may carry extra trailing items (the eval loop above unpacks three);
    # assumes the second item is the target — confirm against make_dataloader.
    (pos, color), y, *_ = next(iter(train_dataloader))

    number_parameters = sum(p.numel() for p in gnn_model.parameters())
    print(f"Number of parameters: {number_parameters}")
    with torch.no_grad():
        prediction = gnn_model.predict((pos, color))
    assert prediction.shape == y.shape, (prediction.shape, y.shape)
else:
    print(f"Skipping GNN classifier smoke test for representation={representation!r}")