In [6]:
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx

# Generating a cycle dataset

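Here we generate a dataset of graph pairs that the 1-WL (Weisfeiler–Leman) isomorphism test cannot distinguish. Each pair consists of a single $n$-cycle and the disjoint union of two smaller cycles with $n$ nodes in total. Every node in both graphs has degree 2, so 1-WL colour refinement gives all nodes the same colour and never separates the two graphs. With uniform node features, standard message-passing GNNs are at most as expressive as 1-WL, which makes this a hard case for them.
Later we will use GNNs to learn to tell them apart.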

In [9]:
# Build a balanced dataset of (two disjoint cycles, single n-cycle) pairs on the
# same n nodes. Every node has degree 2 in both graphs, so 1-WL colour
# refinement cannot separate them.
MIN_NODES, MAX_NODES = 6, 16  # total node counts n range over [MIN_NODES, MAX_NODES)
MIN_CYCLE_LEN = 3  # shortest simple cycle; applies to both parts of a split

graphs_nx = []
graphs_is_cycle = []  # True: single n-cycle, False: two disjoint cycles
for n in range(MIN_NODES, MAX_NODES):
    # Single cycle 0 -> 1 -> ... -> n-1 -> 0.
    g_cyc = nx.Graph()
    g_cyc.add_nodes_from(range(n))
    g_cyc.add_edges_from([(x, x + 1) for x in range(n - 1)] + [(n - 1, 0)])

    # split_n ranges over [MIN_CYCLE_LEN, n - MIN_CYCLE_LEN], so both the first
    # cycle (split_n nodes) and the second (n - split_n nodes) have >= 3 nodes.
    for split_n in range(MIN_CYCLE_LEN, n - MIN_CYCLE_LEN + 1):
        g_split = nx.Graph()
        # Add all n nodes up front so node order is 0..n-1 independent of edges.
        g_split.add_nodes_from(range(n))
        # First cycle on nodes 0 .. split_n-1.
        g_split.add_edges_from([(x, x + 1) for x in range(split_n - 1)] + [(split_n - 1, 0)])
        # Second cycle on the remaining nodes split_n .. n-1.
        g_split.add_edges_from([(x, x + 1) for x in range(split_n, n - 1)] + [(n - 1, split_n)])
        graphs_nx.append(g_split)
        graphs_is_cycle.append(False)
        # Append the same g_cyc object once per split so both classes have equal
        # counts; because it is shared, it must not be mutated afterwards.
        graphs_nx.append(g_cyc)
        graphs_is_cycle.append(True)

# Sanity checks: labels line up with graphs and the classes are balanced.
assert len(graphs_nx) == len(graphs_is_cycle)
assert 2 * sum(graphs_is_cycle) == len(graphs_is_cycle)
        
        

In [10]:
# Convert the networkx graphs to torch_geometric.data.Data objects.
NUM_NODE_FEATURES = 50  # width of the constant (all-zero) node-feature vector

graphs = [from_networkx(g) for g in graphs_nx]
assert len(graphs) == len(graphs_is_cycle), "labels must align with graphs"
for g, is_cycle in zip(graphs, graphs_is_cycle):
    # Identical all-zero features: nodes carry no information of their own, so
    # any distinction between graphs has to come from their connectivity.
    g.x = torch.zeros((g.num_nodes, NUM_NODE_FEATURES))
    # Graph-level target: 1 for a single cycle, 0 for two disjoint cycles.
    g.y = torch.tensor([int(is_cycle)])