In [1]:
from pathlib import Path

from torch.utils.data import DataLoader
from torchvision import transforms

from layout_gnn.dataset.dataset import RICOSemanticAnnotationsDataset
from layout_gnn.dataset import transformations
from layout_gnn.model.model import LayoutGraphModel
from layout_gnn.utils import pyg_data_collate


# Data directory, addressed relative to the working directory the kernel was started in.
DATA_PATH = Path.cwd().joinpath('../data')

In [2]:
# Load the RICO semantic-annotation dataset from DATA_PATH.
dataset = RICOSemanticAnnotationsDataset(root_dir=DATA_PATH)

# Give every key of the dataset's label/color map an integer index, in iteration order.
label_mappings = {
    label: index for index, label in enumerate(dataset.label_color_map)
}

# Per-sample preprocessing pipeline; Compose applies the steps in the listed order.
transform = transforms.Compose([
    transformations.process_data,
    transformations.normalize_bboxes,
    transformations.add_networkx,
    transformations.ConvertLabelsToIndexes(
        node_label_mappings=label_mappings,
        # edge_label_mappings={"parent_of": 0, "child_of": 1},
    ),
    transformations.convert_graph_to_pyg,
])
# Attached after construction because the pipeline needs `label_mappings`,
# which is itself read from the already-built dataset.
dataset.transform = transform

# Batches of 8 samples; collation is delegated to pyg_data_collate.
data_loader = DataLoader(
    dataset=dataset,
    batch_size=8,
    collate_fn=pyg_data_collate,
)

# num_labels is one more than the number of mapped labels.
# NOTE(review): the extra slot is presumably reserved for an unknown/padding label — confirm.
model = LayoutGraphModel(
    num_labels=len(label_mappings) + 1,
    label_embedding_dim=32,
    bbox_embedding_layer_dims=32,
    gnn_hidden_channels=128,
    gnn_num_layers=3,
)

In [3]:
# Smoke test: run one batch through the model and display the batch next to the output.
# next(iter(...)) fetches only the first batch and raises right here if the loader is
# empty, instead of falling through to a NameError (or stale kernel values) on the
# display line below.
batch = next(iter(data_loader))
x = model(batch)
batch, x

(DataBatch(edge_index=[2, 360], bbox=[188, 4], label=[188], edge_label=[8], num_nodes=188, batch=[188], ptr=[9]),
 tensor([[-1.6283e-01,  1.0870e-01, -2.3987e-01,  ...,  1.0042e-01,
           1.8368e-01,  5.5398e-01],
         [-1.6283e-01,  1.0870e-01, -2.3987e-01,  ...,  1.0042e-01,
           1.8368e-01,  5.5398e-01],
         [-7.0449e-02,  8.2786e-02, -1.3573e-01,  ...,  1.8040e-01,
           1.0888e-01,  6.4485e-01],
         ...,
         [ 3.5886e-02,  6.5445e-02, -2.8341e-02,  ...,  2.3916e-01,
           2.0711e-01,  2.4402e-01],
         [-3.1221e-02,  3.1263e-02,  1.1671e-02,  ...,  2.0170e-01,
           2.0555e-01,  2.7374e-01],
         [ 5.3307e-04,  5.5070e-02, -2.9330e-02,  ...,  2.8223e-01,
           2.6488e-01,  2.9109e-01]], grad_fn=<AddBackward0>))