In [1]:
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch_geometric.nn as geom_nn
from torch_geometric.data import Data
from torch_geometric.data import DataLoader
from torch_geometric.data import InMemoryDataset
from torch_geometric.utils import from_networkx

# Generating a cycle dataset

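Here we generate a dataset of graph pairs that cannot be distinguished by the 1-WL (Weisfeiler–Leman) isomorphism test. Each pair consists of a single cycle on n nodes and the disjoint union of two smaller cycles with n nodes in total. Every node has degree 2 in both graphs, so 1-WL colour refinement assigns both graphs the same colouring.
Later we will use GNNs to try to tell them apart.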

In [2]:
# Build pairs of 2-regular graphs on n nodes: a single n-cycle (label True) and the
# disjoint union of a split_n-cycle and an (n - split_n)-cycle (label False).
# Every node has degree 2 in both, so the 1-WL test cannot tell them apart.
graphs_nx = []
graphs_is_cycle = []  # parallel to graphs_nx: True for a single cycle, False for two disjoint cycles
for n in range(6, 16):
    # One n-cycle with nodes 0..n-1; the same object is reused as the positive
    # partner of every split graph of this size.
    g_cyc = nx.Graph()
    g_cyc.add_nodes_from(range(n))
    nx.add_cycle(g_cyc, range(n))

    # split_n runs over 3..n-3, so both cycles of the split graph have at least 3 nodes.
    for split_n in range(3, n - 2):
        g_split = nx.Graph()
        g_split.add_nodes_from(range(n))  # fix the node order to 0..n-1
        nx.add_cycle(g_split, range(split_n))     # first cycle: nodes 0..split_n-1
        nx.add_cycle(g_split, range(split_n, n))  # second cycle: remaining nodes split_n..n-1
        graphs_nx.append(g_split)
        graphs_is_cycle.append(False)
        graphs_nx.append(g_cyc)  # add g_cyc every time to keep the two classes balanced
        graphs_is_cycle.append(True)

# Sanity checks: labels line up with graphs, and every graph is 2-regular.
assert len(graphs_nx) == len(graphs_is_cycle)
assert all(deg == 2 for g in graphs_nx for _, deg in g.degree())
        
        

In [3]:
# Turn every networkx graph into a torch_geometric Data object. All nodes share the
# same all-zero 50-dimensional feature vector, so the graph structure is the only
# usable signal. The graph label y is 1 for a single cycle and 0 for two disjoint cycles.
graphs = list(map(from_networkx, graphs_nx))
for g, is_cycle in zip(graphs, graphs_is_cycle):
    g.x = torch.zeros((g.num_nodes, 50))
    g.y = torch.tensor([int(is_cycle)])

# Building a GNN

In [13]:
# Custom message passing layer
#class CustomLayer(geom_nn.MessagePassing):
    #def __init__(self, in_channels, out_channels):
        #super().__init__(aggr='add')
        #self.lin = nn.Linear(in_channels, out_channels)
        #self.activation = nn.ReLU()
        
    # activation, linear etc probably goes here
    #def forward(self, x, edge_index):
        #return self.propagate(edge_index, x=x)
    
    # stuff that happens after all the message passing, so just id??
    #def update(self, x):
        #return self.activation(self.linear(x))
        
    # this looks fine
    #def message(self, x_j):
        #return x_j

class Net(nn.Module):
    """Graph-level classifier built from message passing, pooling and an MLP.

    The network has three stages:
    - a stack of message-passing layers, each followed by ReLU;
    - global mean pooling over the nodes of each graph;
    - a small MLP that ends in a Sigmoid.

    Args:
        input_dim: number of input node features.
        hidden_dim: width of every message-passing layer and of the MLP hidden layer.
        output_dim: number of sigmoid outputs per graph.
        depth: number of message-passing layers. Values below 1 still give a
            single layer.
        mp_layer: message-passing layer class, instantiated as
            ``mp_layer(in_dim, out_dim)`` and called as ``layer(x, edge_index)``.
            Defaults to ``geom_nn.GCNConv``.
    """

    def __init__(self, input_dim, hidden_dim, output_dim, depth, mp_layer=None):
        super().__init__()
        if mp_layer is None:
            mp_layer = geom_nn.GCNConv
        self.mp_layer = mp_layer  # layer class shared by the whole stack
        self.pool = geom_nn.global_mean_pool
        # Registered (empty) before post_mp and filled in afterwards.
        self.mp_layers = nn.ModuleList()

        # MLP applied to the pooled graph embedding. The final Sigmoid turns its
        # output into probabilities, which is what loss() expects.
        self.post_mp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(0.25),
            nn.Linear(hidden_dim, output_dim),
            nn.Sigmoid(),
        )

        # The first layer maps input_dim -> hidden_dim; every further layer
        # maps hidden_dim -> hidden_dim.
        in_dims = [input_dim] + [hidden_dim] * (depth - 1)
        for in_dim in in_dims:
            self.mp_layers.append(self.mp_layer(in_dim, hidden_dim))

    def forward(self, data):
        """Return the post-MLP output for each graph pooled from ``data``."""
        h = data.x
        for conv in self.mp_layers:
            h = F.relu(conv(h, data.edge_index))
        return self.post_mp(self.pool(h, data.batch))

    def loss(self, pred, label):
        """Binary cross-entropy between predicted probabilities and float targets."""
        return F.binary_cross_entropy(pred, label)
        

In [129]:
def train(dataset, nb_epochs=50, rni=False, lr=0.01, test_data=None,
          hidden_dim=50, depth=16, batch_size=32, log_every=10):
    """Train a freshly initialised graph classifier on ``dataset``.

    Args:
        dataset: non-empty sequence of torch_geometric ``Data`` objects with node
            features ``x`` and an integer 0/1 graph label ``y``. The model's input
            width is read from ``dataset[0]``.
        nb_epochs: number of passes over ``dataset``.
        rni: if True, build an ``RNINet`` (random node initialisation) instead of
            a plain ``Net``.
        lr: Adam learning rate.
        test_data: optional held-out dataset. When given, the accuracy on it is
            printed every ``log_every`` epochs, starting with epoch 0.
        hidden_dim: hidden width of the network (default 50).
        depth: number of message-passing layers (default 16).
        batch_size: training mini-batch size (default 32).
        log_every: positive evaluation interval in epochs, used with ``test_data``.

    Returns:
        The trained model.

    Raises:
        ValueError: if ``dataset`` is empty.
    """
    if len(dataset) == 0:
        raise ValueError("cannot train on an empty dataset")

    train_loader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # RNINet is only looked up when requested, so plain training does not need it.
    model_cls = RNINet if rni else Net
    model = model_cls(dataset[0].num_node_features, hidden_dim, 1, depth)
    opt = torch.optim.Adam(model.parameters(), lr=lr)

    for epoch in range(nb_epochs):
        # Sum of the per-batch mean losses over this epoch (printed as "Loss").
        total_loss = 0.0
        model.train()  # test() puts the model in eval mode; re-enable dropout every epoch
        for batch in train_loader:
            opt.zero_grad()
            pred = model(batch).flatten()
            label = batch.y.to(torch.float32)  # BCE needs float targets
            loss = model.loss(pred, label)
            loss.backward()
            opt.step()
            total_loss += loss.item()
        if test_data is not None and epoch % log_every == 0:
            test_acc = test(model, test_data)
            print("Epoch {}. Loss: {:.4f}. Test accuracy: {:.4f}".format(epoch, total_loss, test_acc))

    return model

In [59]:
def test(model, dataset):
    loader = DataLoader(dataset, batch_size=32, shuffle=True)
    model.eval()
    
    correct = 0
    incorrect = 0
    tot = 0
    for data in loader:
        with torch.no_grad():
            pred = model(data).flatten()
            label = data.y
            pred = [0 if p <= 0.5 else 1 for p in pred]  # 0 if prediction <= 0.5, 1 othwerwise
            correct_i = np.equal(pred, label)
            correct += np.array(correct_i).sum().item()
    return correct / len(loader.dataset)

In [127]:
def cross_validate(data, k=5, **kwargs):
    """Estimate test accuracy with k-fold cross-validation.

    ``data`` is split into k contiguous folds in its given order (no shuffling).
    Each fold serves once as the test set, while a fresh model is trained on the
    remaining folds via ``train``. When ``len(data)`` is not divisible by ``k``,
    the leftover samples go into the last fold, so every sample is tested once.
    If ``data`` is ordered (e.g. by graph size), the folds will not be identically
    distributed; shuffle beforehand if that is not intended.

    Args:
        data: sliceable sequence of samples (e.g. a list of Data objects).
        k: number of folds; must satisfy 2 <= k <= len(data).
        **kwargs: forwarded unchanged to ``train``.

    Returns:
        Mean of the k per-fold test accuracies.

    Raises:
        ValueError: if ``k`` is outside [2, len(data)].
    """
    if not 2 <= k <= len(data):
        raise ValueError("k must be between 2 and len(data) ({}), got {}".format(len(data), k))

    fold_size = len(data) // k
    # Fold boundaries; the last fold runs to the end of data and absorbs the remainder.
    bounds = [i * fold_size for i in range(k)] + [len(data)]
    folds = [data[bounds[i]:bounds[i + 1]] for i in range(k)]

    accuracies = []
    for i, test_set in enumerate(folds):
        train_set = [x for fold in folds[:i] + folds[i + 1:] for x in fold]
        model = train(train_set, **kwargs)
        acc = test(model, test_set)
        accuracies.append(acc)
        print('Accuracy at test-chunk {}: {}'.format(i, acc))
    return float(np.mean(accuracies))

In [128]:
# Baseline without RNI. Every node starts with the same all-zero features and has
# degree 2, so the GCN should not be able to separate single cycles from pairs of
# disjoint cycles: we expect a test accuracy of 0.5.
avg_acc = cross_validate(graphs, k=5)
print(avg_acc)

Accuracy at test-chunk 0: 0.5
Accuracy at test-chunk 1: 0.5
Accuracy at test-chunk 2: 0.5
Accuracy at test-chunk 3: 0.5
Accuracy at test-chunk 4: 0.5
0.5


In [117]:
# Net variant with random node initialisation (RNI). Before message passing, half of
# each node's feature vector is replaced with fresh N(0, 1) noise, so nodes that would
# otherwise carry identical (here all-zero) features get different inputs. This
# increases the model's expressive power.
class RNINet(Net):
    """``Net`` with random node initialisation applied to the input features.

    The constructor arguments, layers, pooling, MLP head and ``loss`` are inherited
    unchanged from ``Net``. Only ``forward`` differs: it first passes ``data.x``
    through ``rni``. Noise is drawn on every forward call, in both train and eval mode.
    """

    def rni(self, x):
        """Replace the last ``d // 2`` feature columns of ``x`` with N(0, 1) noise.

        Args:
            x: 2-D node-feature matrix of shape (num_nodes, d).

        Returns:
            A tensor with the same shape, dtype and device as ``x``. The first
            ``d - d // 2`` columns are copied from ``x``; the rest are random.
        """
        num_features = x.size(1)
        num_random = num_features // 2
        # Keep enough original columns that the width is preserved for odd d too.
        kept = x[:, :num_features - num_random]
        noise = torch.randn(x.size(0), num_random, dtype=x.dtype, device=x.device)
        x_rni = torch.cat((kept, noise), dim=1)
        assert x_rni.size(1) == num_features
        return x_rni

    def forward(self, data):
        x = self.rni(data.x)
        for layer in self.mp_layers:
            x = F.relu(layer(x, data.edge_index))
        x = self.pool(x, data.batch)
        return self.post_mp(x)

In [126]:
# Train with RNI on the first 80% of the graphs and report accuracy on the last 20%.
# We expect the accuracy to rise above 0.5, because RNI increases the model's
# expressive power. The split is not shuffled, and the graphs were generated in
# order of increasing size, so the held-out graphs are the largest ones.
model = train(
    graphs[:int(0.8 * len(graphs))],
    nb_epochs=300,
    rni=True,
    lr=1e-4,
    test_data=graphs[int(0.8 * len(graphs)):],
)

Epoch 0. Loss: 2.0795. Test accuracy: 0.5000
Epoch 10. Loss: 2.0727. Test accuracy: 0.5000
Epoch 20. Loss: 2.0832. Test accuracy: 0.5000
Epoch 30. Loss: 2.0812. Test accuracy: 0.5000
Epoch 40. Loss: 2.0779. Test accuracy: 0.5000
Epoch 50. Loss: 2.0757. Test accuracy: 0.5000
Epoch 60. Loss: 2.0749. Test accuracy: 0.5000
Epoch 70. Loss: 2.0815. Test accuracy: 0.5000
Epoch 80. Loss: 2.0573. Test accuracy: 0.5000
Epoch 90. Loss: 2.0637. Test accuracy: 0.5000
Epoch 100. Loss: 2.0158. Test accuracy: 0.6818
Epoch 110. Loss: 2.0014. Test accuracy: 0.8182
Epoch 120. Loss: 1.9460. Test accuracy: 0.5909
Epoch 130. Loss: 1.9411. Test accuracy: 0.5909
Epoch 140. Loss: 1.9185. Test accuracy: 0.5909
Epoch 150. Loss: 1.7162. Test accuracy: 0.4091
Epoch 160. Loss: 1.7870. Test accuracy: 0.7273
Epoch 170. Loss: 1.8048. Test accuracy: 0.5455
Epoch 180. Loss: 1.6966. Test accuracy: 0.6818
Epoch 190. Loss: 1.7480. Test accuracy: 0.6818
Epoch 200. Loss: 1.7930. Test accuracy: 0.7727
Epoch 210. Loss: 1.6587.

In [130]:
# 5-fold cross-validation of the RNI model, with the same hyper-parameters as above.
avg_acc = cross_validate(graphs, k=5, rni=True, nb_epochs=300, lr=1e-4)
print(avg_acc)

Accuracy at test-chunk 0: 0.6818181818181818
Accuracy at test-chunk 1: 0.9090909090909091
Accuracy at test-chunk 2: 0.6363636363636364
Accuracy at test-chunk 3: 0.6818181818181818
Accuracy at test-chunk 4: 0.6818181818181818
0.718181818181818
