In [1]:
from pathlib import Path

from pytorch_lightning import Trainer
from torch.utils.data import DataLoader
from torch_geometric.nn import global_mean_pool
from torchvision import transforms

from layout_gnn.dataset.dataset import RICOTripletsDataset, DATA_PATH
from layout_gnn.dataset.transforms.core import process_data, normalize_bboxes
from layout_gnn.dataset.transforms.pyg import convert_graph_to_pyg
from layout_gnn.dataset.transforms.image import RescaleImage
from layout_gnn.dataset.transforms.nx import add_networkx, ConvertLabelsToIndexes
from layout_gnn.nn.model import LayoutGraphModel
from layout_gnn.nn.neural_rasterizer import CNNNeuralRasterizer
from layout_gnn.lightning_module import LayoutGraphModelCNNNeuralRasterizer
from layout_gnn.utils import pyg_triplets_data_collate


In [2]:
# Build everything training needs: the triplet dataset with its preprocessing
# pipeline, the data loader, the Lightning module and the trainer.

# Side length in pixels, shared by the image rescaling step and the decoder's
# output size so the two cannot drift apart.
# NOTE(review): presumably these must match for the reconstruction loss to
# compare like-sized images -- confirm against the lightning module.
IMAGE_SIZE = 256
BATCH_SIZE = 8

dataset = RICOTripletsDataset(triplets=DATA_PATH / "pairs_0_10000.json")

# Give every key of the dataset's label/color map a contiguous integer index.
label_mappings = {k: i for i, k in enumerate(dataset.label_color_map)}

# Preprocessing pipeline applied by the dataset. It is bound to its own name
# first and then attached to the dataset, instead of a chained assignment
# (`a = b = ...`) that reads like a stray keyword argument.
transform = transforms.Compose([
    process_data,
    normalize_bboxes,
    add_networkx,
    RescaleImage(IMAGE_SIZE, IMAGE_SIZE, allow_missing_image=True),
    ConvertLabelsToIndexes(
        node_label_mappings=label_mappings,
        # edge_label_mappings={"parent_of": 0, "child_of": 1},
    ),
    convert_graph_to_pyg,
])
dataset.transform = transform

data_loader = DataLoader(
    dataset=dataset,
    batch_size=BATCH_SIZE,
    collate_fn=pyg_triplets_data_collate,
)

model = LayoutGraphModelCNNNeuralRasterizer(
    # One more than the number of mapped labels.
    # NOTE(review): presumably reserves an extra index (unknown/padding) -- confirm.
    num_labels=len(label_mappings) + 1,
    cnn_output_dim=3,
    cnn_output_size=IMAGE_SIZE,
)

# DATA_PATH is used as the trainer's default root directory.
trainer = Trainer(default_root_dir=DATA_PATH)

# Last expression of the cell: display the resolved hyperparameters.
model.hparams

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


"bbox_embedding_layer_dims":      32
"cnn_hidden_dim":                 8
"edge_label_embedding_dim":       None
"gnn_hidden_channels":            128
"gnn_model_cls":                  <class 'torch_geometric.nn.models.basic_gnn.GCN'>
"gnn_num_layers":                 3
"gnn_out_channels":               None
"label_embedding_dim":            32
"lr":                             0.001
"readout":                        <function LayoutGraphModelCNNNeuralRasterizer.__init__.<locals>.<lambda> at 0x7f61390f0790>
"reconstruction_loss_weight":     1
"triplet_loss_distance_function": None
"triplet_loss_margin":            1
"triplet_loss_swap":              False
"use_edge_attr":                  False

In [3]:
# Train the model on the triplet data loader. No validation dataloader is
# passed, so Lightning skips the validation loop (see the warning in this
# cell's output).
trainer.fit(model, data_loader)

  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

  | Name                | Type                          | Params
----------------------------------------------------------------------
0 | encoder             | LayoutGraphModel              | 42.3 K
1 | decoder             | CNNNeuralRasterizer           | 1.2 M 
2 | triplet_loss        | TripletMarginWithDistanceLoss | 0     
3 | reconstruction_loss | MSELoss                       | 0     
----------------------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.066     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]

