In [1]:
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as geom_nn
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_networkx

# Generating a cycle dataset

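Here we generate a dataset containing pairs of graphs that are not distinguishable by the 1-WL isomorphism test: for each node count `n`, a single cycle on `n` nodes is paired with a graph made of two disjoint cycles on the same `n` nodes.
Later we will train GNNs on the task of telling the two kinds of graphs apart.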

In [2]:
# Build a balanced dataset of graph pairs on n = 6..15 nodes:
#   * g_cyc   - one cycle through all n nodes            (label True)
#   * g_split - two disjoint cycles covering the n nodes (label False)
graphs_nx = []
graphs_is_cycle = []  # parallel to graphs_nx: True for a single cycle, False for two disjoint cycles
for n in range(6, 16):
    g_cyc = nx.Graph()
    g_cyc.add_nodes_from(range(n))
    g_cyc.add_edges_from([(x, x + 1) for x in range(n - 1)] + [(n - 1, 0)])  # close the path into a cycle

    # split_n runs over 3..n-3, so both cycles of the split graph have at least 3 nodes
    for split_n in range(3, n - 2):
        g_split = nx.Graph()
        g_split.add_nodes_from(range(n))
        # first cycle uses nodes 0..split_n-1
        g_split.add_edges_from([(x, x + 1) for x in range(split_n - 1)] + [(split_n - 1, 0)])
        # second cycle uses the remaining nodes split_n..n-1
        g_split.add_edges_from([(x, x + 1) for x in range(split_n, n - 1)] + [(n - 1, split_n)])
        graphs_nx.append(g_split)
        graphs_is_cycle.append(False)
        graphs_nx.append(g_cyc)  # add g_cyc every time to keep the two classes balanced
        graphs_is_cycle.append(True)
        
        

In [3]:
# Turn each networkx graph into a torch_geometric Data object and attach
# constant node features plus the graph-level target.
graphs = []
for nx_graph, is_cycle in zip(graphs_nx, graphs_is_cycle):
    data = from_networkx(nx_graph)
    # every node gets the same all-zero 50-dimensional feature vector
    data.x = torch.zeros((data.num_nodes, 50))
    # target label: 1 for a single cycle, 0 for two disjoint cycles
    data.y = torch.tensor([int(is_cycle)])
    graphs.append(data)

# Building a GNN

In [13]:
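# Custom message passing layer (unused draft, kept commented out; Net falls back to GCNConv).
# NOTE: the draft is inconsistent as written: __init__ defines `self.lin` but update() calls `self.linear`.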
#class CustomLayer(geom_nn.MessagePassing):
    #def __init__(self, in_channels, out_channels):
        #super().__init__(aggr='add')
        #self.lin = nn.Linear(in_channels, out_channels)
        #self.activation = nn.ReLU()
        
    # activation, linear etc probably goes here
    #def forward(self, x, edge_index):
        #return self.propagate(edge_index, x=x)
    
    # stuff that happens after all the message passing, so just id??
    #def update(self, x):
        #return self.activation(self.linear(x))
        
    # this looks fine
    #def message(self, x_j):
        #return x_j

class Net(nn.Module):
    """Graph-level classifier: message passing layers, mean pooling, then an MLP head.

    Args:
        input_dim: size of the node feature vectors fed to the first layer.
        hidden_dim: output size of every message passing layer and width of the head.
        output_dim: number of outputs produced per graph.
        depth: number of message passing layers (at least one layer is always built).
        mp_layer: layer class called as ``mp_layer(in_dim, out_dim)``;
            ``geom_nn.GCNConv`` is used when this is None.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, depth, mp_layer=None):
        super().__init__()
        # layer class used for every message passing step of the net
        self.mp_layer = mp_layer if mp_layer is not None else geom_nn.GCNConv
        self.pool = geom_nn.global_mean_pool
        # created empty here so it is registered ahead of the head; filled below
        self.mp_layers = nn.ModuleList()

        # MLP applied to the pooled graph embedding; the sigmoid maps outputs into (0, 1)
        head_modules = [
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        ]
        self.post_mp = nn.Sequential(*head_modules)

        # first layer maps input_dim -> hidden_dim, the other depth-1 keep hidden_dim
        layer_in_dims = [input_dim] + [hidden_dim] * (depth - 1)
        for in_dim in layer_in_dims:
            self.mp_layers.append(self.mp_layer(in_dim, hidden_dim))

    def forward(self, data):
        """Run the net on ``data`` (reads ``data.x``, ``data.edge_index``, ``data.batch``)."""
        h = data.x
        edges = data.edge_index
        graph_index = data.batch
        for conv in self.mp_layers:
            h = F.relu(conv(h, edges))
        pooled = self.pool(h, graph_index)
        return self.post_mp(pooled)

    def loss(self, pred, label):
        """Binary cross entropy between predicted probabilities and float targets."""
        return F.binary_cross_entropy(pred, label)
        

In [47]:
def train(dataset, nb_epochs=50, hidden_dim=50, depth=16, lr=0.01, batch_size=32, train_frac=0.8):
    """Train a ``Net`` on ``dataset`` and report test accuracy every 10 epochs.

    The first ``train_frac`` of ``dataset`` is used for training and the
    remainder for testing, so the two splits never share a graph.

    Args:
        dataset: sequence of graph data objects; it is sliced, and
            ``dataset[0].num_node_features`` sets the model input size.
        nb_epochs: number of passes over the training split.
        hidden_dim: hidden size passed to ``Net``.
        depth: number of message passing layers passed to ``Net``.
        lr: learning rate of the Adam optimizer.
        batch_size: batch size of both data loaders.
        train_frac: fraction of ``dataset`` used for training; it must leave a
            non-empty test split, because ``test`` divides by the test set size.

    Returns:
        The trained model.
    """
    # split on the dataset that was passed in, not on a notebook-level global
    split_idx = int(train_frac * len(dataset))
    # shuffle only the training data; evaluation order does not matter
    train_loader = DataLoader(dataset[:split_idx], batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(dataset[split_idx:], batch_size=batch_size, shuffle=False)

    model = Net(dataset[0].num_node_features, hidden_dim, 1, depth)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(nb_epochs):
        total_loss = 0  # sum of the per-batch losses of this epoch
        model.train()
        for batch in train_loader:
            opt.zero_grad()
            pred = model(batch).flatten()
            label = batch.y.to(torch.float32)  # binary cross entropy needs float targets
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        if epoch % 10 == 0:
            test_acc = test(model, test_loader)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(epoch, total_loss, test_acc))

    return model

In [54]:
def test(model, loader):
    model.eval()
    
    correct = 0
    incorrect = 0
    tot = 0
    for data in loader:
        with torch.no_grad():
            pred = model(data).flatten()
            label = data.y
            pred = [0 if p <= 0.5 else 1 for p in pred]  # 0 if prediction <= 0.5, 1 othwerwise
            correct_i = np.equal(pred, label)
            correct += np.array(correct_i).sum().item()
    return correct / len(loader.dataset)

In [53]:
# we expect a test accuracy of 0.5, since these types of graphs can't be distinguished by the net
train(graphs, nb_epochs=200)

Epoch 0. Loss: 2.0793. Test accuracy: 0.5000
Epoch 10. Loss: 2.0774. Test accuracy: 0.5000
Epoch 20. Loss: 2.0768. Test accuracy: 0.5000
Epoch 30. Loss: 2.0840. Test accuracy: 0.5000
Epoch 40. Loss: 2.0705. Test accuracy: 0.5000
Epoch 50. Loss: 2.0759. Test accuracy: 0.5000
Epoch 60. Loss: 2.0783. Test accuracy: 0.5000
Epoch 70. Loss: 2.0833. Test accuracy: 0.5000
Epoch 80. Loss: 2.0687. Test accuracy: 0.5000
Epoch 90. Loss: 2.0671. Test accuracy: 0.5000
Epoch 100. Loss: 2.0713. Test accuracy: 0.5000
Epoch 110. Loss: 2.0845. Test accuracy: 0.5000
Epoch 120. Loss: 2.0832. Test accuracy: 0.5000
Epoch 130. Loss: 2.0796. Test accuracy: 0.5000
Epoch 140. Loss: 2.0877. Test accuracy: 0.5000
Epoch 150. Loss: 2.0820. Test accuracy: 0.5000
Epoch 160. Loss: 2.0794. Test accuracy: 0.5000
Epoch 170. Loss: 2.0784. Test accuracy: 0.5000
Epoch 180. Loss: 2.0786. Test accuracy: 0.5000
Epoch 190. Loss: 2.0784. Test accuracy: 0.5000


Net(
  (mp_layers): ModuleList(
    (0): GCNConv(50, 50)
    (1): GCNConv(50, 50)
    (2): GCNConv(50, 50)
    (3): GCNConv(50, 50)
    (4): GCNConv(50, 50)
    (5): GCNConv(50, 50)
    (6): GCNConv(50, 50)
    (7): GCNConv(50, 50)
    (8): GCNConv(50, 50)
    (9): GCNConv(50, 50)
    (10): GCNConv(50, 50)
    (11): GCNConv(50, 50)
    (12): GCNConv(50, 50)
    (13): GCNConv(50, 50)
    (14): GCNConv(50, 50)
    (15): GCNConv(50, 50)
  )
  (post_mp): Sequential(
    (0): Linear(in_features=50, out_features=50, bias=True)
    (1): ReLU()
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=50, out_features=1, bias=True)
    (4): Sigmoid()
  )
)