In [1]:
from pathlib import Path

from pytorch_lightning import Trainer, seed_everything
from torch.utils.data import DataLoader
from torch_geometric.nn import global_mean_pool
from torchvision import transforms

from layout_gnn.dataset.dataset import RICOTripletsDataset, DATA_PATH
from layout_gnn.dataset import transformations
from layout_gnn.model.model import LayoutGraphModel, CNNNeuralRasterizer
from layout_gnn.model.lightning_module import LayoutGraphModelCNNNeuralRasterizer
from layout_gnn.utils import pyg_triplets_data_collate


In [2]:
# Reproducibility: seed Python, NumPy and torch RNGs (model initialisation is random).
SEED = 42
# Side length used when rescaling images. The CNN rasterizer's output size is tied to
# it, presumably so the reconstruction (MSE) target and decoder output match — confirm.
IMAGE_SIZE = 256
BATCH_SIZE = 8

seed_everything(SEED)

dataset = RICOTripletsDataset(triplets=DATA_PATH / "pairs_0_10000.json")
# One integer index per label, in the iteration order of the dataset's colour map.
label_mappings = {label: index for index, label in enumerate(dataset.label_color_map)}

# Preprocessing pipeline assigned to the dataset.
# NOTE(review): training in the next cell fails with
# `RuntimeError: Found dtype Double but expected Float`; presumably one of these steps
# yields float64 tensors (e.g. normalized bboxes or the rescaled image) — confirm and
# cast to float32 in this pipeline.
# NOTE(review): only node labels are mapped to indexes; ConvertLabelsToIndexes also
# accepts `edge_label_mappings` — confirm whether edge labels need indexing too.
dataset.transform = transforms.Compose([
    transformations.process_data,
    transformations.normalize_bboxes,
    transformations.add_networkx,
    transformations.RescaleImage(IMAGE_SIZE, IMAGE_SIZE, allow_missing_image=True),
    transformations.ConvertLabelsToIndexes(node_label_mappings=label_mappings),
    transformations.convert_graph_to_pyg,
])

data_loader = DataLoader(dataset=dataset, batch_size=BATCH_SIZE, collate_fn=pyg_triplets_data_collate)
model = LayoutGraphModelCNNNeuralRasterizer(
    # NOTE(review): `+ 1` presumably reserves one index beyond the mapped labels — confirm.
    num_labels=len(label_mappings) + 1,
    cnn_output_dim=3,  # presumably RGB output channels — confirm
    cnn_output_size=IMAGE_SIZE,
)
trainer = Trainer(default_root_dir=DATA_PATH)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [3]:
# Train on the triplet loader only; no validation loader is supplied, so Lightning skips
# the model's validation loop.
# NOTE(review): this call has failed with `RuntimeError: Found dtype Double but expected
# Float`; presumably a float64 input tensor reaches the float32 model — confirm and cast
# in the dataset transform.
trainer.fit(model, train_dataloaders=data_loader)

  rank_zero_warn(f"attribute '{k}' removed from hparams because it cannot be pickled")
  rank_zero_warn("You defined a `validation_step` but have no `val_dataloader`. Skipping val loop.")

  | Name                | Type                          | Params
----------------------------------------------------------------------
0 | encoder             | LayoutGraphModel              | 42.3 K
1 | decoder             | CNNNeuralRasterizer           | 1.2 M 
2 | triplet_loss        | TripletMarginWithDistanceLoss | 0     
3 | reconstruction_loss | MSELoss                       | 0     
----------------------------------------------------------------------
1.3 M     Trainable params
0         Non-trainable params
1.3 M     Total params
5.066     Total estimated model params size (MB)
  rank_zero_warn(


Training: 0it [00:00, ?it/s]



RuntimeError: Found dtype Double but expected Float

In [None]:
# Inspect which container class Lightning uses for the module's `hparams`.
model.hparams.__class__

pytorch_lightning.utilities.parsing.AttributeDict