In [1]:
%%time
import src
from pathlib import Path
from tqdm.notebook import trange
from src import Petri_Cheb_GNN, Petri_GCN

# Both splits live side by side under one directory and differ only in the
# split name embedded in the file name.
batch_size = 32
data_dir = Path('Data')
train_dataset, test_dataset = (
    src.get_reachability_dataset(data_dir / f'RandData_DS2_{split}_data.processed', batch_size=batch_size)
    for split in ('train', 'test')
)

CPU times: user 1.79 s, sys: 1.01 s, total: 2.81 s
Wall time: 2.05 s


In [2]:
import torch
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, MLP
from torch_geometric.utils import scatter

class GCN(torch.nn.Module):
    """Two-layer GCN regressor that produces one scalar per graph.

    Node features go through two ``GCNConv`` layers, with ReLU and dropout
    after the first. An MLP readout then maps each node to a single value, and
    the node values are mean-pooled per graph.

    Args:
        num_node_features: Size of the input node feature vectors.
        num_hidden_features: Width of both GCN layers. The MLP readout widens
            this to 2x and 3x before projecting to 1.
        dropout: Dropout probability applied after the first convolution. It is
            only active in training mode. The default of 0.5 matches the value
            previously implied by ``F.dropout``'s default.
    """

    def __init__(self, num_node_features, num_hidden_features, dropout=0.5):
        super().__init__()
        self.dropout = dropout
        self.conv1 = GCNConv(num_node_features, num_hidden_features)
        self.conv2 = GCNConv(num_hidden_features, num_hidden_features)
        self.readout = MLP([num_hidden_features, num_hidden_features * 2, num_hidden_features * 3, 1])

    def forward(self, data):
        """Return a flat tensor holding one prediction per graph in ``data``.

        ``data`` must provide ``x`` and ``edge_index``. It may also provide
        ``batch``, the node-to-graph index produced by batching. When ``batch``
        is absent or None (a single un-batched graph), all nodes are pooled
        into one graph.
        """
        x, edge_index = data.x, data.edge_index
        batch = getattr(data, 'batch', None)
        if batch is None:
            # Un-batched input: every node belongs to graph 0, so scatter
            # still receives a valid index and returns a single prediction.
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = F.dropout(x, p=self.dropout, training=self.training)
        x = self.conv2(x, edge_index)
        x = self.readout(x)  # one value per node (the MLP's last width is 1)
        return scatter(x, batch, dim=0, reduce='mean').view(-1)

In [3]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# NOTE: the model is not moved to `device` here, so it stays on the default
# (CPU) device. Moving it with `.to(device)` also requires every training and
# evaluation batch to be moved to that same device.

# Seed before building the model so that weight initialisation is reproducible
# across kernel restarts. This also covers any later torch-RNG-driven
# randomness, such as dropout or loader shuffling if enabled.
SEED = 0
torch.manual_seed(SEED)

model = Petri_Cheb_GNN(train_dataset.num_features, 16, 3)
model = torch.compile(model, dynamic=True)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01, weight_decay=5e-4)

In [4]:
num_epochs = 100
# Send each batch to wherever the model's parameters live, so the loop keeps
# working if the model is placed on a GPU.
model_device = next(model.parameters()).device

model.train()
progress = trange(num_epochs)
for _ in progress:
    epoch_loss, num_batches = 0.0, 0
    for graph in train_dataset:
        graph = graph.to(model_device)
        optimizer.zero_grad()
        out = torch.flatten(model(graph))
        # Flatten the target as well. A (B, 1) target against the (B,)
        # prediction would silently broadcast to (B, B) inside l1_loss.
        loss = F.l1_loss(out, torch.flatten(graph.y))
        loss.backward()
        optimizer.step()
        epoch_loss += loss.item()
        num_batches += 1
    # Show the mean batch loss on the progress bar for this epoch.
    progress.set_postfix(loss=epoch_loss / max(num_batches, 1))

  0%|          | 0/100 [00:00<?, ?it/s]

In [5]:
model.eval()
# Send each batch to wherever the model's parameters live.
model_device = next(model.parameters()).device

pred_batches, actual_batches = [], []
# Inference only: no_grad avoids building an autograd graph for every batch.
with torch.no_grad():
    for graph in test_dataset:
        graph = graph.to(model_device)
        pred_batches.append(torch.flatten(model(graph)).cpu())
        # Take targets from the same batches as the predictions. Reading them
        # from the underlying dataset would misalign the two if the loader
        # shuffles.
        actual_batches.append(torch.flatten(graph.y).cpu())

pred = torch.cat(pred_batches)
actual = torch.cat(actual_batches).to(pred.dtype)

print(f'MAE: {F.l1_loss(pred, actual)}')
# This metric is the mean squared error. It was previously mislabelled "MRE".
print(f'MSE: {F.mse_loss(pred, actual)}')

MAE: 0.6327988505363464
MRE: 0.6862345933914185
