In [186]:
import torch
import torch_geometric
from torch_geometric.data import Data
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader
import numpy as np
from tqdm import tqdm

In [187]:
def graphToData(file):
    """Load one generated graph from an ``.npz`` archive into a PyG ``Data`` object.

    The archive must contain the arrays ``label``, ``nodes``, ``edges``, ``adj_list``
    and ``positions``. The shapes noted below are the ones the generator is stated to
    write (TODO confirm against ../GenData).

    Args:
        file: path (or file-like object) accepted by ``np.load``.

    Returns:
        torch_geometric.data.Data with
          - x:          node features, float,  [num_nodes, 1]
          - edge_index: connectivity in COO format, long, [2, num_edges]
          - edge_attr:  edge features, float,  [num_edges, 2]
          - y:          graph label, long,     [1, ]
          - pos:        node positions, float, [num_nodes, 2]
    """
    # np.load on an .npz returns an NpzFile that keeps the archive open; the context
    # manager closes it. torch.tensor copies, so the tensors outlive the file.
    with np.load(file) as graph:
        # NLLLoss requires int64 class indices, so pin the dtype instead of
        # inheriting whatever integer type the generator saved.
        label = torch.tensor(graph["label"], dtype=torch.long)  # shape=[1, ]
        node_values = torch.tensor(graph["nodes"], dtype=torch.float)  # shape=[num_nodes, 1]
        edge_values = torch.tensor(graph["edges"], dtype=torch.float)  # shape=[num_edges, 2]
        adj_list = torch.tensor(graph["adj_list"], dtype=torch.long)  # shape=[num_edges, 2]
        positions = torch.tensor(graph["positions"], dtype=torch.float)  # shape=[num_nodes, 2]

    # PyG expects edge_index as [2, num_edges]; the file stores [num_edges, 2].
    edge_index = adj_list.t().contiguous()

    return Data(x=node_values, edge_index=edge_index, edge_attr=edge_values, y=label, pos=positions)

def loadData(data_dir="../GenData", num_per_class=200):
    """Load every generated lattice graph, one class after another.

    Files are read as ``<data_dir>/<subdir>/<prefix>_<i>.npz`` for
    ``i in range(num_per_class)``. Classes are visited in a fixed order:
    squares, rectangles, hexagons, oblique. Progress for each class is shown with tqdm.

    Args:
        data_dir: root directory of the generated data (default "../GenData").
        num_per_class: number of files to read per class (default 200).

    Returns:
        list of torch_geometric.data.Data with ``4 * num_per_class`` entries, grouped
        by class in the order above.
    """
    # (sub-directory, file prefix). The sub-directory name doubles as the progress
    # label. Order matters: it decides the order of graphs in the returned list.
    lattice_sources = (
        ("squares", "square"),
        ("rectangles", "rectangular"),
        ("hexagons", "hexagonal"),
        ("oblique", "oblique"),
    )

    graphs = []
    for subdir, prefix in lattice_sources:
        print(f"load {subdir}")
        for i in tqdm(range(num_per_class)):
            graphs.append(graphToData(f"{data_dir}/{subdir}/{prefix}_{i}.npz"))

    return graphs
        

In [188]:
class gineConv(torch.nn.Module):
    """Graph-level classifier: GINE message passing -> mean pooling -> linear head.

    Args:
        in_channels: node feature width (``data.x.size(-1)``).
        num_classes: number of output classes.
        edge_dim: edge feature width (``data.edge_attr.size(-1)``). Each GINEConv
            projects edge features to the input width of its own MLP.
        hidden_channels: output width of the first GINE layer (default 10).
        second_hidden_channels: output width of the optional second GINE layer
            (default 30).
        use_second_layer: if True, apply the second GINE layer before pooling. The
            default (False) runs only the first layer.
    """

    def __init__(self, in_channels, num_classes, edge_dim, hidden_channels=10,
                 second_hidden_channels=30, use_second_layer=False):
        super().__init__()
        self.use_second_layer = use_second_layer

        self.nn_1 = self._mlp(in_channels, hidden_channels)
        # The second MLP consumes conv1's output, so its input width must equal
        # hidden_channels. It used to be hard-coded to 20 while conv1 emits 10,
        # which made the second layer crash as soon as it was enabled.
        self.nn_2 = self._mlp(hidden_channels, second_hidden_channels)

        self.gine_conv1 = torch_geometric.nn.GINEConv(nn=self.nn_1, train_eps=True, edge_dim=edge_dim)
        # Always built, even when unused, so the module's parameter names stay the
        # same whichever depth is selected.
        self.gine_conv2 = torch_geometric.nn.GINEConv(nn=self.nn_2, train_eps=True, edge_dim=edge_dim)

        # The head reads the width of whichever GINE layer runs last.
        head_in = second_hidden_channels if use_second_layer else hidden_channels
        self.classifier = torch.nn.Sequential(
            torch.nn.Linear(head_in, out_features=num_classes, bias=True)
        )

    @staticmethod
    def _mlp(in_features, out_features):
        """Linear-ReLU-Linear-ReLU-Linear MLP used as the update function of a GINE layer."""
        return torch.nn.Sequential(
            torch.nn.Linear(in_features=in_features, out_features=out_features, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(out_features, out_features),
            torch.nn.ReLU(),
            torch.nn.Linear(out_features, out_features),
        )

    def forward(self, data):
        """Classify each graph in a PyG batch.

        Args:
            data: batch exposing ``x``, ``edge_index``, ``edge_attr`` and ``batch``.

        Returns:
            Log-probabilities of shape [num_graphs, num_classes]. Pair them with
            ``torch.nn.NLLLoss``.
        """
        x, edge_index, edge_attr, batch = data.x, data.edge_index, data.edge_attr, data.batch

        # Message passing.
        x = torch.nn.functional.relu(self.gine_conv1(x, edge_index, edge_attr))
        if self.use_second_layer:
            x = torch.nn.functional.relu(self.gine_conv2(x, edge_index, edge_attr))

        # Readout: mean over the nodes of each graph, as given by `batch`.
        x = torch_geometric.nn.global_mean_pool(x, batch)

        x = self.classifier(x)
        return torch.nn.functional.log_softmax(x, dim=1)

In [189]:
def train(loader, model, loss_fn, optimizer, device, log_every=20):
    """Run one training epoch over ``loader`` and return the mean loss per graph.

    Args:
        loader: iterable of batches. Each batch must support ``.to(device)`` and
            expose ``.y`` and ``.num_graphs``. ``loader.dataset`` must support ``len``.
        model: module mapping a batch to predictions accepted by ``loss_fn``.
        loss_fn: called as ``loss_fn(pred, batch.y)``. It is assumed to average over
            the batch (e.g. NLLLoss's default ``reduction="mean"``), which is why each
            batch loss is weighted by its graph count below.
        optimizer: optimizer over ``model``'s parameters.
        device: device every batch is moved to.
        log_every: print the current loss every ``log_every`` batches (default 20).

    Returns:
        float: batch losses weighted by graph count and averaged over all graphs
        seen, or NaN if the loader yields nothing.
    """
    total_num_dataset = len(loader.dataset)
    model.train()

    seen = 0
    running_loss = 0.0
    for batch_nr, batch_dat in enumerate(loader):
        batch_dat = batch_dat.to(device)

        # Zero before backward so stale gradients from earlier activity can never
        # leak into this step.
        optimizer.zero_grad()
        pred = model(batch_dat)
        loss = loss_fn(pred, batch_dat.y)
        loss.backward()
        optimizer.step()

        # Count real graphs so a smaller final batch is reported correctly.
        batch_size = batch_dat.num_graphs
        seen += batch_size
        # detach keeps the accumulator out of the autograd graph; no host sync here.
        running_loss = running_loss + loss.detach() * batch_size

        if batch_nr % log_every == 0:
            print(f"loss: {loss.item():>7f} [{seen:>5d}/{total_num_dataset:>5d}]")

    return float(running_loss) / seen if seen else float("nan")

In [190]:
# Read every generated graph (all four lattice classes) into memory.
data_list = loadData()

# PyG's DataLoader collates each mini-batch of graphs into a single Batch object.
# The model relies on that object's `batch` vector for mean pooling.
train_dataloader = DataLoader(
    data_list,
    shuffle=True,
    batch_size=8,
)

dataset_size = len(train_dataloader.dataset)
print(dataset_size)

load squares


100%|██████████| 200/200 [00:00<00:00, 646.41it/s]


load rectangles


100%|██████████| 200/200 [00:00<00:00, 728.84it/s]


load hexagons


100%|██████████| 200/200 [00:00<00:00, 523.22it/s]


load oblique


100%|██████████| 200/200 [00:00<00:00, 742.85it/s]

800





In [191]:
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

cuda


In [192]:
# Seed PyTorch's global RNG (torch.manual_seed seeds CPU and CUDA) so weight
# initialisation and the DataLoader's shuffle order repeat when this cell is re-run.
# NOTE(review): GPU scatter kernels used in message passing may still be
# non-deterministic.
SEED = 0
torch.manual_seed(SEED)

# Infer the input widths from the data instead of hard-coding them.
sample_graph = train_dataloader.dataset[0]
NUM_CLASSES = 4  # one per lattice type loaded by loadData; assumes labels are 0..3 -- TODO confirm

model = gineConv(sample_graph.num_node_features, NUM_CLASSES, sample_graph.num_edge_features)
model.to(device)
loss_fn = torch.nn.NLLLoss()  # pairs with the model's log_softmax output
optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)

epochs = 20
for t in range(epochs):
    print(f"\nEpoch {t+1}\n----------------------------------------")
    train(train_dataloader, model, loss_fn, optimizer, device)
    # TODO: evaluate on a held-out split. No test loader exists yet; every loaded
    # graph currently goes into train_dataloader.
print("Done")


Epoch 1
----------------------------------------
loss: 1.424645 [    8/  800]
loss: 1.385198 [  168/  800]
loss: 1.339840 [  328/  800]
loss: 1.354807 [  488/  800]
loss: 1.042459 [  648/  800]

Epoch 2
----------------------------------------
loss: 0.795919 [    8/  800]
loss: 0.847001 [  168/  800]
loss: 0.992510 [  328/  800]
loss: 0.537189 [  488/  800]
loss: 0.924451 [  648/  800]

Epoch 3
----------------------------------------
loss: 1.077623 [    8/  800]
loss: 0.388849 [  168/  800]
loss: 0.845413 [  328/  800]
loss: 0.892118 [  488/  800]
loss: 0.942486 [  648/  800]

Epoch 4
----------------------------------------
loss: 0.702253 [    8/  800]
loss: 0.966353 [  168/  800]
loss: 0.866604 [  328/  800]
loss: 0.533416 [  488/  800]
loss: 0.696019 [  648/  800]

Epoch 5
----------------------------------------
loss: 0.893044 [    8/  800]
loss: 0.937990 [  168/  800]
loss: 1.096215 [  328/  800]
loss: 0.442699 [  488/  800]
loss: 0.828826 [  648/  800]

Epoch 6
----------------

Epoch 1
----------------------------------------
Epoch 2
----------------------------------------
Epoch 3
----------------------------------------
Epoch 4
----------------------------------------
Epoch 5
----------------------------------------
Done
