In [None]:
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torch_geometric.nn import global_add_pool
from tqdm import tqdm
from omegaconf import OmegaConf
import matplotlib.pyplot as plt
from diffusion_co_design.pretrain.rware.transform import storage_to_layout
from diffusion_co_design.bin.train_rware import (
    TrainingConfig,
    DesignerRegistry,
    DesignerConfig,
    ScenarioConfig,
)
from diffusion_co_design.utils import (
    omega_to_pydantic,
    get_latest_model,
    cuda,
)
from diffusion_co_design.rware.env import create_env
from diffusion_co_design.rware.model import rware_models
from diffusion_co_design.pretrain.rware.graph import WarehouseGNNBase, E3GNNLayer
from rware.warehouse import Warehouse

from dataset import load_dataset, CollateFn, working_dir

device = cuda
# Run directory of the trained policy ("Four corners" scenario). Set the
# DCD_TRAINING_DIR environment variable to analyse another run without editing
# the notebook; the default keeps the original run.
training_dir = os.environ.get(
    "DCD_TRAINING_DIR",
    "/home/markhaoxiang/.diffusion_co_design/training/2025-04-05/04-00-12",
)

# Load latest model and config.
# NOTE(review): get_latest_model presumably picks the newest "policy_*" file in
# checkpoint_dir — verify against diffusion_co_design.utils.
checkpoint_dir = os.path.join(training_dir, "checkpoints")
latest_policy = get_latest_model(checkpoint_dir, "policy_")

# Rebuild the run's TrainingConfig from the Hydra config stored with the run.
hydra_dir = os.path.join(training_dir, ".hydra")
cfg = omega_to_pydantic(
    OmegaConf.load(os.path.join(hydra_dir, "config.yaml")), TrainingConfig
)

# Override the representation stored in the run config for this notebook.
cfg.scenario.representation = "image"

# Build an environment around a random designer so the policy network can be
# constructed (rware_models needs an env) and its weights restored.
_, env_designer = DesignerRegistry.get(
    DesignerConfig(type="random"),
    cfg.scenario,
    working_dir,
    environment_batch_size=32,
    device=device,
)
env = create_env(cfg.scenario, env_designer, render=True, device=device)
policy, _ = rware_models(env, cfg.policy, device=device)
# map_location lets a checkpoint written on a different device (e.g. another
# GPU index, or a CPU-only machine) be loaded onto this notebook's device.
policy.load_state_dict(torch.load(latest_policy, map_location=device))

# Make Dataset: 10k layouts, 20% held out for evaluation. The restored policy is
# passed in, presumably to score each layout — confirm in dataset.load_dataset.
train_dataset, eval_dataset = load_dataset(
    scenario=cfg.scenario,
    policy=policy,
    dataset_size=10_000,
    num_workers=25,
    test_proportion=0.2,
    recompute=True,
    device=device,
)


In [None]:
BATCH_SIZE = 128

# NOTE(review): CollateFn receives `device`; if it moves tensors to CUDA inside
# DataLoader worker processes, that may conflict with fork-based workers — confirm.
collate_fn = CollateFn(cfg.scenario, device)


def make_dataloader(dataset, batch_size=BATCH_SIZE, shuffle=True):
    """Wrap `dataset` in a DataLoader that batches samples with `collate_fn`.

    Args:
        dataset: map-style dataset whose samples `collate_fn` accepts.
        batch_size: samples per batch; defaults to the notebook's BATCH_SIZE.
        shuffle: reshuffle every epoch. Disable for evaluation so each pass
            visits samples in the same order.

    Returns:
        A DataLoader using 4 persistent worker processes.
    """
    return DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=shuffle,
        num_workers=4,
        collate_fn=collate_fn,
        # Keep workers alive between epochs instead of re-spawning them.
        persistent_workers=True,
    )


train_dataloader = make_dataloader(train_dataset)
# Evaluation needs no shuffling; a fixed order makes eval passes reproducible.
eval_dataloader = make_dataloader(eval_dataset, shuffle=False)

In [None]:
class MLPClassifier(nn.Module):
    """Fully connected regressor mapping one flattened layout to one scalar.

    Positions and colors are concatenated on the last axis and flattened per
    sample, so the input width is ``(2 + n_colors) * n_shelves``.

    Args:
        cfg: scenario config; only ``n_colors`` and ``n_shelves`` are read.
        hidden_dim: width of every hidden layer.
        num_layers: total number of Linear layers including the output head.
            Values below 2 give a single linear map from input to output.
    """

    def __init__(self, cfg: "ScenarioConfig", hidden_dim: int = 512, num_layers: int = 4):
        super().__init__()
        in_dim = (2 + cfg.n_colors) * cfg.n_shelves

        layers = []
        dims = [in_dim] + [hidden_dim] * (num_layers - 1)

        # Hidden blocks: Linear -> LayerNorm -> SiLU between consecutive widths.
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            layers += [nn.Linear(d_in, d_out), nn.LayerNorm(d_out), nn.SiLU()]

        # The head must read the last width actually produced. That is
        # `in_dim` when there are no hidden blocks, not `hidden_dim`, so using
        # hidden_dim here broke num_layers=1.
        layers.append(nn.Linear(dims[-1], 1))
        self.net = nn.Sequential(*layers)

    def forward(self, pos, colors):
        """Return one prediction per sample, shape ``(batch,)``.

        Assumes pos is (batch, n_shelves, 2) and colors is (batch, n_shelves,
        n_colors) — TODO confirm against CollateFn. Only the flattened
        per-sample width must equal ``(2 + n_colors) * n_shelves``.
        """
        x = torch.cat([pos, colors], dim=-1)
        x = x.view(x.shape[0], -1)
        return self.net(x).squeeze(-1)


model = MLPClassifier(cfg.scenario).to(device)

# Smoke test: push one training batch through the untrained MLP and check that
# it produces exactly one prediction per target.
batch = next(iter(train_dataloader))
(pos, color), y = batch
number_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {number_parameters}")
predictions = model(pos, color)
assert predictions.shape == y.shape

In [None]:
TRAIN_NUM_EPOCHS = 100
RECOMPUTE = True

model = MLPClassifier(cfg.scenario).to(device)
model_dir = os.path.join(working_dir, "classifier_mlp.pt")


if RECOMPUTE or not os.path.exists(model_dir):
    optim = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.05)
    criterion = torch.nn.MSELoss()

    # Per-epoch mean batch losses, plotted by a later cell.
    train_losses = []
    eval_losses = []
    with tqdm(total=TRAIN_NUM_EPOCHS) as pbar:
        for epoch in range(TRAIN_NUM_EPOCHS):
            # Train: one full pass over the training split.
            model.train()
            running_train_loss = 0.0
            for (pos, color), y in train_dataloader:
                optim.zero_grad()
                loss = criterion(model(pos, color), y)
                loss.backward()
                optim.step()
                running_train_loss += loss.item()
            running_train_loss /= len(train_dataloader)

            # Evaluate on the held-out split. Iterate over and normalise by the
            # same loader, so the "eval" curve is a genuine held-out loss.
            model.eval()
            running_eval_loss = 0.0
            with torch.no_grad():
                for (pos, color), y in eval_dataloader:
                    # The model already returns shape (batch,); no extra squeeze.
                    running_eval_loss += criterion(model(pos, color), y).item()
            running_eval_loss /= len(eval_dataloader)

            train_losses.append(running_train_loss)
            eval_losses.append(running_eval_loss)
            pbar.set_description(
                f" Train Loss {running_train_loss} Eval Loss {running_eval_loss}"
            )
            pbar.update()

    torch.save(model.state_dict(), model_dir)
else:
    model.load_state_dict(torch.load(model_dir, map_location=device))
    # No learning curves exist for a checkpoint loaded from disk; clear any
    # left over from an earlier run so they are not plotted as this model's.
    train_losses = eval_losses = None

In [None]:
# Plot the MLP learning curves if the training cell recorded them. Only the
# "not recorded" case is tolerated; genuine plotting errors now surface instead
# of being swallowed by a bare except.
try:
    losses_available = train_losses is not None and eval_losses is not None
except NameError:
    losses_available = False

if losses_available:
    fig, ax = plt.subplots()
    ax.plot(train_losses, label="train")
    ax.plot(eval_losses, label="eval")
    ax.set_ylabel("Loss")
    ax.set_xlabel("Epoch")
    ax.set_title("Training Losses (MLP)")
    ax.legend()
else:
    print("No training losses available; skipping loss plot.")

In [None]:
torch.cuda.empty_cache()

FIGURE_SIZE_CNST = 2.5

# Index the eval split once instead of re-evaluating `eval_dataset[:]` on every
# access inside the loops. Element 0 holds layout features, element 1 the returns.
eval_split = eval_dataset[:]
eval_features, eval_returns = eval_split[0], eval_split[1]

env_returns_sorted_index = torch.argsort(eval_returns, descending=True)
best_5 = env_returns_sorted_index[:5]
worst_5 = env_returns_sorted_index[-5:]


def _render_layouts(ax_row, indices, features, returns):
    """Render the stored layouts at `indices` onto consecutive axes of `ax_row`.

    Each layout's return is printed as it is drawn, so the figure can be matched
    against the numbers.
    """
    for ax, idx in zip(ax_row, indices):
        layout = storage_to_layout(
            features=features[idx].numpy(force=True),
            config=cfg.scenario,
        )
        print(returns[idx])
        warehouse = Warehouse(layout=layout, render_mode="rgb_array")
        try:
            ax.imshow(warehouse.render())
        finally:
            # Release the renderer even if drawing fails.
            warehouse.close()
        ax.axis("off")


fig, axs = plt.subplots(2, 5)
fig.set_size_inches(5 * FIGURE_SIZE_CNST, 2 * FIGURE_SIZE_CNST)

# Top row: highest-return layouts; bottom row: lowest-return layouts.
_render_layouts(axs[0], best_5, eval_features, eval_returns)
print("===")
_render_layouts(axs[1], worst_5, eval_features, eval_returns)

In [None]:
# Gradient ascent on positions: start from one training layout and move its
# positions to increase the current classifier's prediction (colors fixed).
sample = next(iter(train_dataset))
(pos, color), _ = collate_fn([sample])

# Optimise an explicit leaf copy of the positions; only they are handed to the
# optimiser, so the classifier's weights are never stepped.
pos = pos.detach().clone().requires_grad_(True)
pos_optim = torch.optim.Adam([pos], lr=0.01)

n_iterations = 500
# Log about 10 times per run; max(1, ...) avoids modulo-by-zero for < 10 steps.
log_every = max(1, n_iterations // 10)


def _show_position_layout(positions):
    """Render the graph-representation layout for `positions` in its own figure."""
    fig, ax = plt.subplots(figsize=(FIGURE_SIZE_CNST, FIGURE_SIZE_CNST))
    # Map from the model's coordinate range to grid units. This assumes
    # positions are normalised to [-1, 1], mapped to [0, size] — TODO confirm
    # against CollateFn.
    show_pos = (positions.squeeze() + 1) / 2 * cfg.scenario.size
    layout = storage_to_layout(
        features=show_pos.numpy(force=True),
        config=cfg.scenario,
        representation_override="graph",
    )
    warehouse = Warehouse(layout=layout, render_mode="rgb_array")
    try:
        print(len(warehouse.shelves))
        ax.imshow(warehouse.render())
        ax.axis("off")
        plt.show()
    finally:
        warehouse.close()
        # Free the figure; this loop can create many of them.
        plt.close(fig)


for iteration in range(n_iterations):
    pos_optim.zero_grad()
    y_pred = model(pos, color)
    loss = -y_pred.mean()  # minimising the negative maximises the prediction
    loss.backward()
    pos_optim.step()

    # Also show the final iteration so the end state of the ascent is visible.
    if iteration % log_every == 0 or iteration == n_iterations - 1:
        print(f"Iteration {iteration} Loss: {loss.item()}")
        _show_position_layout(pos)

In [None]:
class GraphClassifier(WarehouseGNNBase):
    """Regress one scalar per layout with a stack of E3GNNLayer blocks.

    Goals and shelves become graph nodes (built by the base class, with color
    features included). Node embeddings are refined by message passing, then
    summed per graph and mapped to a scalar by a small MLP head.
    """

    def __init__(
        self,
        scenario: ScenarioConfig,
        node_embedding_dim: int = 64,
        edge_embedding_dim: int = 32,
        num_layers: int = 5,
        use_radius_graph: bool = True,
        radius: float = 0.5,
    ):
        super().__init__(
            scenario=scenario,
            use_radius_graph=use_radius_graph,
            radius=radius,
            include_color_features=True,
        )

        self.embedding_dim = node_embedding_dim
        self.num_nodes = scenario.n_goals + scenario.n_shelves
        self.num_layers = num_layers

        # Project raw node features (width from the base class) to embeddings.
        self.h_in = nn.Linear(self.feature_dim, node_embedding_dim)

        # All layers except the last are built with update_node_features=True.
        # NOTE(review): the final layer is built with update_node_features=False
        # and its `h` output feeds the readout — confirm that is intended.
        self.convs = nn.ModuleList(
            [
                E3GNNLayer(
                    node_embedding_dim=node_embedding_dim,
                    edge_embedding_dim=edge_embedding_dim,
                    graph_embedding_dim=0,  # no timestep embeddings
                    update_node_features=layer_idx < num_layers - 1,
                    use_attention=True,
                )
                for layer_idx in range(num_layers)
            ]
        )

        # Sum-pool node embeddings per graph, then map to one scalar.
        self.readout = global_add_pool
        self.out_mlp = nn.Sequential(
            nn.Linear(node_embedding_dim, node_embedding_dim),
            nn.SiLU(),
            nn.Linear(node_embedding_dim, 1),
        )

    def forward(self, pos: torch.Tensor, color: torch.Tensor) -> torch.Tensor:
        """Return one prediction per graph in the batch (trailing dim squeezed)."""
        graph, _ = self.make_graph_from_data(pos, color=color)
        batch_index = graph.batch  # [N]
        node_h = self.h_in(graph.h)  # [N, d]
        node_pos = graph.pos  # [N, 2]

        for layer in self.convs:
            node_h, node_pos = layer(node_h, graph.edge_index, node_pos, None, batch_index)

        # Readout across the entire graph (goals + shelves).
        pooled = self.readout(node_h, batch_index)
        return self.out_mlp(pooled).squeeze(-1)


model = GraphClassifier(cfg.scenario).to(device)

# Smoke test: a single training batch through the untrained GNN must yield one
# prediction per target.
batch = next(iter(train_dataloader))
(pos, color), y = batch
number_parameters = sum(p.numel() for p in model.parameters())
print(f"Number of parameters: {number_parameters}")
predictions = model(pos, color)
assert predictions.shape == y.shape

In [None]:
TRAIN_NUM_EPOCHS = 50
RECOMPUTE = True

model = GraphClassifier(cfg.scenario).to(device)
model_dir = os.path.join(working_dir, "classifier_gnn.pt")


if RECOMPUTE or not os.path.exists(model_dir):
    optim = torch.optim.Adam(model.parameters(), lr=3e-4, weight_decay=0.05)
    criterion = torch.nn.MSELoss()

    # Per-epoch mean batch losses for the GNN.
    train_losses = []
    eval_losses = []
    with tqdm(total=TRAIN_NUM_EPOCHS) as pbar:
        for epoch in range(TRAIN_NUM_EPOCHS):
            # Train: one full pass over the training split.
            model.train()
            running_train_loss = 0.0
            for (pos, colors), y in train_dataloader:
                optim.zero_grad()
                loss = criterion(model(pos, colors), y)
                loss.backward()
                optim.step()
                running_train_loss += loss.item()
            running_train_loss /= len(train_dataloader)

            # Evaluate on the held-out split. Iterate over and normalise by the
            # same loader, so the "eval" curve is a genuine held-out loss.
            model.eval()
            running_eval_loss = 0.0
            with torch.no_grad():
                for (pos, colors), y in eval_dataloader:
                    # The model already returns shape (batch,); no extra squeeze.
                    running_eval_loss += criterion(model(pos, colors), y).item()
            running_eval_loss /= len(eval_dataloader)

            train_losses.append(running_train_loss)
            eval_losses.append(running_eval_loss)
            pbar.set_description(
                f" Train Loss {running_train_loss} Eval Loss {running_eval_loss}"
            )
            pbar.update()

    torch.save(model.state_dict(), model_dir)
else:
    model.load_state_dict(torch.load(model_dir, map_location=device))
    # No learning curves exist for a checkpoint loaded from disk; clear any
    # left over from an earlier cell so they are not mistaken for this model's.
    train_losses = eval_losses = None