In [1]:
%load_ext autoreload
%autoreload 2

In [12]:
from pathlib import Path

import pytorch_lightning as pl

import torch
from torch import nn
from torch_geometric.nn import knn_graph
from torch_geometric import transforms
from torch_geometric.data import Data, DataLoader
from torchmetrics.functional import mean_squared_error

from gvp import SyntheticGNN, SyntheticDataModule

In [3]:
data_dir = Path("../data/synthetic")

In [4]:
class ExtendedPPF:
    def __init__(self, norm=True, cat=True):
        self.norm = norm
        self.cat = cat

        self.ppf = transforms.PointPairFeatures(cat=False)
        self.distance = transforms.Distance(norm=norm, cat=False)

    def __call__(self, data):
        existing_features = data.edge_attr

        ppf_features = self.ppf(data).edge_attr
        ppf_features = torch.cos(ppf_features)
        dist_features = self.distance(data).edge_attr


        new_features = torch.cat([dist_features, ppf_features[:, 1:]], dim=-1)
        
        if existing_features is not None and self.cat:
            data.edge_attr = torch.cat([existing_features, new_features], dim=-1)
        else:
            data.edge_attr = new_features

        return data

In [16]:
model = SyntheticGNN(4)

transform = transforms.Compose([
    transforms.KNNGraph(k=10),
    ExtendedPPF()
])

dm = SyntheticDataModule(data_dir, 32, "off_center", transform, num_workers=2)
dm.setup('fit')

In [17]:
for batch in dm.train_dataloader():
    print(batch)
    out = model(batch)
    print(out.shape)
    print(batch.y.shape)
    loss = mean_squared_error(out, batch.y)
    break

Batch(batch=[3200], edge_attr=[32000, 4], edge_index=[2, 32000], norm=[3200, 3], pos=[3200, 3], ptr=[33], x=[3200, 4], y=[32])
torch.Size([3200, 32])
torch.Size([3200, 1])
torch.Size([32])


RuntimeError: Predictions and targets are expected to have the same shape