In [39]:
import os
import pickle
import networkx as nx
from tqdm import tqdm
import torch
from torch.nn import (
    BatchNorm1d,
    Embedding,
    Linear,
    ModuleList,
    ReLU,
    Sequential,
)
from torch.optim.lr_scheduler import ReduceLROnPlateau

import numpy as np
import math

from torch_geometric.datasets import TUDataset
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, from_networkx, to_dense_adj
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader
from torch_geometric.nn import GCNConv, global_mean_pool, global_add_pool, global_max_pool

In [31]:
def generate_shortest_path_graph(num_nodes: int, topology: str = "complete") -> Data:
    """Build one shortest-path regression example as a PyG ``Data`` object.

    A graph of the requested topology is built. Two distinct nodes are drawn at
    random with ``np.random.choice`` and marked with node feature ``x = 1``; every
    other node gets ``x = 0``. The graph label ``y`` is the shortest-path length
    (number of hops) between the two marked nodes.

    Args:
        num_nodes: Size parameter forwarded to the topology builder. Must be at
            least 2 so that two distinct relevant nodes can be chosen. Note that
            "regular" uses it as both rows and cols of a periodic grid (so the
            graph has ``num_nodes ** 2`` nodes), and "tree" builds the smallest
            complete binary tree with at least ``num_nodes`` nodes.
        topology: One of "complete", "path", "cycle", "regular", "tree", "ER".

    Returns:
        A ``Data`` object produced by ``from_networkx`` with the 0/1 node
        attribute ``x`` and a one-element label tensor ``y``.

    Raises:
        AssertionError: if ``num_nodes < 2`` or ``topology`` is unknown.
        networkx.NetworkXNoPath: if the two chosen nodes are disconnected, which
            can happen for random "ER" graphs.
    """
    assert num_nodes >= 2, "Error: num_nodes must be >= 2 to pick two relevant nodes"

    # Dispatch table instead of a chain of independent `if`s. The lambdas resolve
    # the builder names at call time, so the builders may live in a later cell.
    builders = {
        "complete": lambda: create_complete_graph(num_nodes),
        "path": lambda: create_path_graph(num_nodes),
        "cycle": lambda: create_cycle_graph(num_nodes),
        "regular": lambda: create_4_regular_grid_graph(num_nodes, num_nodes),
        "tree": lambda: create_binary_tree(num_nodes),
        "ER": lambda: create_er_graph(num_nodes),
    }
    assert topology in builders, "Error: unknown topology"  # Extend `builders` for other topologies
    raw_graph = builders[topology]()

    # Sample from an explicit list of node labels (not a networkx NodeView) and
    # keep the chosen pair as plain Python ints.
    source, target = np.random.choice(list(raw_graph.nodes()), 2, replace=False).tolist()
    relevant_nodes = {source, target}

    # Node feature: 1 for the two relevant nodes, 0 for all others.
    for node in raw_graph.nodes():
        raw_graph.nodes[node]['x'] = 1 if node in relevant_nodes else 0

    # Convert the NetworkX graph to PyTorch Geometric's Data format.
    attributed_graph = from_networkx(raw_graph)

    # Graph label: hop distance between the two relevant nodes.
    shortest_path_length = nx.shortest_path_length(raw_graph, source=source, target=target)
    attributed_graph.y = torch.tensor([shortest_path_length])

    return attributed_graph

In [177]:
# shortest path task on complete graphs
# Seed numpy's global RNG so that both the graph sizes below and the
# relevant-node choice inside generate_shortest_path_graph (np.random.choice)
# are reproducible after a kernel restart.
SEED = 0
NUM_GRAPHS = 1000
MIN_NODES, MAX_NODES = 10, 100  # inclusive range of graph sizes

np.random.seed(SEED)
random_integers = np.random.randint(MIN_NODES, MAX_NODES + 1, size=NUM_GRAPHS)
graphs = [generate_shortest_path_graph(num_nodes=nodes, topology='complete') for nodes in random_integers]

In [175]:
# Evaluate the current `model` on every graph in `graphs`.
# NOTE(review): this cell depends on `model`, `device`, `optimizer` and `test`,
# which are defined in later cells (hidden kernel state). Run those first.
# If run after the train/val/test split, `graphs` also contains training data.
# Dedicated names are used so the split's `test_dataset` / `test_loader` are not
# overwritten when this cell is re-run.
complete_eval_dataset = graphs
complete_eval_loader = DataLoader(complete_eval_dataset, batch_size=64)

complete_eval_mae = test(complete_eval_loader, model, device, optimizer)
print(complete_eval_mae)

2.15819287109375


In [33]:
# topologies

def create_complete_graph(num_nodes: int) -> nx.Graph:
    """Return the complete graph K_n on nodes ``0 .. num_nodes - 1``.

    ``nx.complete_graph`` already returns an undirected ``nx.Graph``. A trailing
    ``.to_undirected()`` would only make a redundant deep copy, so it is omitted.
    """
    return nx.complete_graph(num_nodes)

def create_path_graph(num_nodes: int) -> nx.Graph:
    """Return the path graph 0 - 1 - ... - (num_nodes - 1) built by networkx."""
    return nx.path_graph(num_nodes)

def create_cycle_graph(num_nodes: int) -> nx.Graph:
    """Return the cycle graph on nodes ``0 .. num_nodes - 1`` built by networkx."""
    return nx.cycle_graph(num_nodes)

def create_4_regular_grid_graph(rows: int, cols: int) -> nx.Graph:
    """Return a ``rows x cols`` periodic 2-D grid (torus) with integer node labels.

    ``periodic=True`` wraps the grid around in both directions. The result is
    strictly 4-regular only when both ``rows`` and ``cols`` are at least 3.
    Every node's attribute dict is cleared, so only attributes the caller adds
    afterwards remain.
    """
    torus = nx.convert_node_labels_to_integers(
        nx.grid_2d_graph(rows, cols, periodic=True)
    )
    for _, attrs in torus.nodes(data=True):
        attrs.clear()
    return torus

def create_binary_tree(num_nodes: int) -> nx.Graph:
    """Return the smallest complete binary tree with at least ``num_nodes`` nodes.

    A complete binary tree with L levels (height L - 1) has 2**L - 1 nodes.
    The number of levels is therefore ceil(log2(num_nodes + 1)).
    """
    levels = math.ceil(math.log2(num_nodes + 1))
    return nx.balanced_tree(r=2, h=levels - 1)

def create_er_graph(num_nodes: int, probability: float = 0.5, seed=None) -> nx.Graph:
    """Return a G(n, p) Erdos-Renyi random graph.

    Args:
        num_nodes: number of nodes n.
        probability: independent edge probability p.
        seed: optional seed forwarded to ``nx.erdos_renyi_graph``. The default
            ``None`` keeps the previous behaviour of using networkx's own default
            random state, which ``np.random.seed`` does not control.

    Note:
        The graph may be disconnected. A shortest-path query between two nodes
        in different components raises ``networkx.NetworkXNoPath``.
    """
    return nx.erdos_renyi_graph(n=num_nodes, p=probability, seed=seed)

## Train an example model

In [178]:
class GCN(torch.nn.Module):
    """A stack of ``GCNConv`` layers followed by global max pooling and an MLP head.

    The model produces one scalar prediction per graph. Here that prediction is
    the shortest-path length.

    Args:
        channels: Hidden width of the node embedding and of every ``GCNConv``.
            Must be >= 4, because the MLP head narrows to ``channels // 4``;
            smaller values would create zero-width layers.
        num_layers: Number of ``GCNConv`` layers.

    Raises:
        ValueError: if ``channels < 4``.
    """

    def __init__(self, channels, num_layers):
        super().__init__()
        if channels < 4:
            raise ValueError(
                f"channels must be >= 4 (the MLP head uses channels // 4), got {channels}"
            )

        # One scalar input feature per node -> hidden embedding.
        self.node_emb = Linear(1, channels)
        # NOTE(review): pe_norm and edge_emb are never used in forward(). They are
        # kept so parameter order and state_dict keys stay unchanged.
        self.pe_norm = BatchNorm1d(20)
        self.edge_emb = Linear(3, channels)

        self.convs = ModuleList(
            GCNConv(channels, channels, normalize=True) for _ in range(num_layers)
        )

        # Regression head that halves the width twice before the scalar output.
        self.mlp = Sequential(
            Linear(channels, channels),
            ReLU(),
            Linear(channels, channels // 2),
            ReLU(),
            Linear(channels // 2, channels // 4),
            ReLU(),
            Linear(channels // 4, 1),
        )

    def forward(self, x, edge_index, edge_attr, batch):
        """Return the model output for each graph in ``batch``; the last dimension is 1.

        ``edge_attr`` is accepted for interface compatibility but is ignored.
        """
        # Cast node features to float and reshape to (num_nodes, 1) for node_emb.
        x = self.node_emb(x.float().view(-1, 1))

        # NOTE(review): there is no nonlinearity between the GCNConv layers, so the
        # conv stack composes linear graph filters. Add an activation if that is
        # not intended.
        for conv in self.convs:
            x = conv(x, edge_index)

        # Graph-level readout: max over each graph's nodes, then the MLP head.
        return self.mlp(global_max_pool(x, batch))

In [179]:
def train(train_loader, model, device, optimizer):
    """Run one training epoch and return the per-graph mean absolute error.

    For every batch: move it to ``device`` and compute the L1 (MAE) loss between
    the squeezed model output and ``y``. Then back-propagate and take one
    optimizer step. The returned value weights each batch loss by the batch's
    number of graphs and divides by the dataset size.
    """
    model.train()

    loss_sum = 0
    for batch_data in train_loader:
        batch_data = batch_data.to(device)
        optimizer.zero_grad()
        prediction = model(
            batch_data.x, batch_data.edge_index, batch_data.edge_attr, batch_data.batch
        ).squeeze()
        batch_loss = (prediction - batch_data.y).abs().mean()
        batch_loss.backward()
        loss_sum += batch_loss.item() * batch_data.num_graphs
        optimizer.step()

    num_examples = len(train_loader.dataset)
    return loss_sum / num_examples


@torch.no_grad()
def test(loader, model, device, optimizer):
    model.eval()

    total_error = 0
    for data in loader:
        data = data.to(device)
        # out = model(data.x, data.pe, data.edge_index, data.edge_attr, data.batch)
        out = model(data.x, data.edge_index, data.edge_attr, data.batch)
        total_error += (out.squeeze() - data.y).abs().sum().item()
    return total_error / len(loader.dataset)

In [186]:
# Training configuration. Automatic CUDA selection is commented out and CPU is
# forced instead: torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = 'cpu'
model = GCN(channels=4, num_layers=12).to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, weight_decay=1e-5)
# Halve the learning rate after 10 epochs without validation improvement.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    factor=0.5,
    patience=10,
    min_lr=1e-6,
)

In [187]:
# Split into 50% train / 25% validation / 25% test. The cut points are derived
# from len(graphs) instead of hard-coding 500 / 750, which assumed exactly 1000
# graphs.
train_end = len(graphs) // 2
val_end = (3 * len(graphs)) // 4

train_dataset = graphs[:train_end]
val_dataset = graphs[train_end:val_end]
test_dataset = graphs[val_end:]
assert len(train_dataset) + len(val_dataset) + len(test_dataset) == len(graphs)

train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=64)
test_loader = DataLoader(test_dataset, batch_size=64)

In [151]:
# Re-create the optimizer with a smaller learning rate. ReduceLROnPlateau keeps a
# reference to the optimizer it was constructed with, so the scheduler must be
# rebuilt as well; otherwise scheduler.step() would only adjust the old, unused
# optimizer.
optimizer = torch.optim.Adam(model.parameters(), lr=0.00001, weight_decay=1e-5)
scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10,
                              min_lr=0.000001)

In [None]:
NUM_EPOCHS = 100

for epoch in range(1, NUM_EPOCHS + 1):
    loss = train(train_loader, model, device, optimizer)
    # Evaluate both held-out splits. Only the validation MAE drives the LR
    # schedule; the test MAE is just reported.
    val_mae, test_mae = (test(split_loader, model, device, optimizer)
                         for split_loader in (val_loader, test_loader))
    scheduler.step(val_mae)
    print(f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Val: {val_mae:.4f}, '
          f'Test: {test_mae:.4f}')

Epoch: 01, Loss: 0.5455, Val: 0.5446, Test: 0.5446
Epoch: 02, Loss: 0.5439, Val: 0.5430, Test: 0.5430
Epoch: 03, Loss: 0.5423, Val: 0.5414, Test: 0.5414
Epoch: 04, Loss: 0.5407, Val: 0.5398, Test: 0.5398
Epoch: 05, Loss: 0.5391, Val: 0.5382, Test: 0.5382
Epoch: 06, Loss: 0.5375, Val: 0.5366, Test: 0.5366
Epoch: 07, Loss: 0.5359, Val: 0.5350, Test: 0.5350
Epoch: 08, Loss: 0.5343, Val: 0.5334, Test: 0.5334
Epoch: 09, Loss: 0.5327, Val: 0.5318, Test: 0.5318
Epoch: 10, Loss: 0.5311, Val: 0.5302, Test: 0.5302
Epoch: 11, Loss: 0.5295, Val: 0.5286, Test: 0.5286
Epoch: 12, Loss: 0.5279, Val: 0.5270, Test: 0.5270
Epoch: 13, Loss: 0.5263, Val: 0.5254, Test: 0.5254
Epoch: 14, Loss: 0.5247, Val: 0.5238, Test: 0.5238
Epoch: 15, Loss: 0.5231, Val: 0.5222, Test: 0.5222
Epoch: 16, Loss: 0.5215, Val: 0.5206, Test: 0.5206
Epoch: 17, Loss: 0.5199, Val: 0.5190, Test: 0.5190
Epoch: 18, Loss: 0.5183, Val: 0.5174, Test: 0.5174
Epoch: 19, Loss: 0.5167, Val: 0.5158, Test: 0.5158
Epoch: 20, Loss: 0.5151, Val: 0