In [1]:
from pathlib import Path

from torch.utils.data import DataLoader
from torchvision import transforms

from layout_gnn.dataset.dataset import RICOSemanticAnnotationsDataset
from layout_gnn.dataset import transformations
from layout_gnn.model.model import Model
from layout_gnn.utils import pyg_data_collate


# Dataset root: the `data` directory one level above the current working
# directory (the `..` component is kept as-is, not resolved).
DATA_PATH = Path.cwd().joinpath('..', 'data')

In [2]:
# Load the dataset from DATA_PATH; the transform is attached afterwards because
# the label -> index mapping below is derived from the dataset itself.
dataset = RICOSemanticAnnotationsDataset(root_dir=DATA_PATH)

# Assign each key of the dataset's label/colour map a consecutive integer
# index, in the map's iteration order.
label_mappings = {label: index for index, label in enumerate(dataset.label_color_map)}

# Per-sample preprocessing pipeline; steps are applied in the order listed.
# Built under an explicit name first, then attached to the dataset. The name
# `transform` stays bound at notebook level, as it was in the original cell.
transform = transforms.Compose([
    transformations.process_data,
    transformations.normalize_bboxes,
    transformations.add_networkx,
    transformations.ConvertLabelsToIndexes(
        node_label_mappings=label_mappings,
        # edge_label_mappings={"parent_of": 0, "child_of": 1},
    ),
    transformations.convert_graph_to_pyg,
])
dataset.transform = transform

# Batches are assembled by the project's `pyg_data_collate` rather than the
# default collate function.
data_loader = DataLoader(dataset=dataset, batch_size=8, collate_fn=pyg_data_collate)

# `num_labels` is tied to the mapping above so the model and the label indexes
# produced by the transform always agree on the number of labels.
model = Model(
    num_labels=len(label_mappings),
    label_embedding_dim=32,
    bbox_embedding_layer_dims=32,
    gnn_hidden_channels=128,
    gnn_num_layers=3,
)

In [3]:
# Smoke test: run the model once on the first batch only.
# `next(iter(...))` pulls exactly one batch. If the loader is empty it raises
# StopIteration here, instead of leaving `batch`/`x` undefined (or stale from a
# previous kernel run) for the display expression below.
batch = next(iter(data_loader))
x = model(batch)
batch, x

(DataBatch(edge_index=[2, 360], bbox=[188, 4], label=[188], edge_label=[360], num_nodes=188, batch=[188], ptr=[9]),
 tensor([[ 0.1158,  0.2690,  0.4183,  ...,  0.1469,  0.1217, -0.0774],
         [ 0.1158,  0.2690,  0.4183,  ...,  0.1469,  0.1217, -0.0774],
         [ 0.1188,  0.2146,  0.3822,  ...,  0.0989,  0.1267, -0.1092],
         ...,
         [-0.0128,  0.0412,  0.3399,  ..., -0.0985,  0.0875, -0.2561],
         [ 0.0624,  0.0467,  0.2904,  ..., -0.1278,  0.1082, -0.2458],
         [ 0.0206,  0.0686,  0.2431,  ..., -0.1598,  0.1279, -0.2055]],
        grad_fn=<AddBackward0>))