In [1]:
import argparse
import os.path as osp
import os
import torch
import torch.nn.functional as F
import torch_geometric
import torch_geometric.data as geom_data
import numpy as np
from torch_geometric.datasets import Entities
from torch_geometric.utils import k_hop_subgraph
from torch_geometric.nn import RGCNConv, FastRGCNConv


In [2]:
# Download (or reuse) the MUTAG entity-classification dataset, using the
# current working directory as the dataset root.
cwd = os.getcwd()
dataset = Entities(root=cwd, name="MUTAG")

# Take the first (index 0) graph object from the dataset.
data = dataset[0]

Available dataset info includes: `name`, `num_classes`, `num_edge_features`, `num_features`, `num_node_features`, `num_relations`.

In [3]:
# Print a short summary of the loaded dataset, one labelled line per item.
for label, value in (
    ("Data object:", dataset.data),
    ("Length:", len(dataset)),
    ("Dataset: ", dataset),
):
    print(label, value)
#print("Average label: %4.2f" % (dataset.data.y.float().mean().item()))

# Number of relation types followed by the number of target classes.
print(dataset.num_relations, dataset.num_classes)


Data object: Data(edge_index=[2, 148454], edge_type=[148454], test_idx=[68], test_y=[68], train_idx=[272], train_y=[272])
Length: 1
Dataset:  MUTAGEntities()
46 2


In [4]:
# Splitting into train and test.
#
# `Dataset.shuffle()` does not shuffle in place: it returns a shuffled copy.
# The result therefore has to be captured, otherwise the slices below are
# taken from the original, unshuffled ordering. `dataset` itself is left
# untouched so every other cell keeps seeing the same object.
#
# NOTE(review): the summary cell above reports a dataset length of 1, so
# `[:150]` holds that single graph and `[150:]` is empty. These graph-level
# splits look like a leftover from a multi-graph example -- confirm whether
# they are still needed.
torch.manual_seed(42)  # fixes the permutation drawn by shuffle()
shuffled_dataset = dataset.shuffle()
train_dataset = shuffled_dataset[:150]
test_dataset = shuffled_dataset[150:]

In [5]:
# PyTorch Geometric batches several graphs by stacking them into one large
# block-diagonal adjacency matrix and concatenating their feature matrices.
# NOTE(review): newer PyG releases expose DataLoader from
# `torch_geometric.loader` -- confirm against the installed version.

# Training batches are reshuffled on every pass.
graph_train_loader = geom_data.DataLoader(
    train_dataset,
    batch_size=64,
    shuffle=True,
)

# Validation reuses the test split; kept as a separate loader so a larger
# dataset with its own validation split can be plugged in later.
graph_val_loader = geom_data.DataLoader(
    test_dataset,
    batch_size=64,
)

graph_test_loader = geom_data.DataLoader(
    test_dataset,
    batch_size=64,
)

In [6]:
# The BGS and AM graphs are too large for full-batch processing. Because the
# model only has a small receptive field, the graph is reduced to the nodes
# lying within 2 hops of any training/test node.
#
# k_hop_subgraph(node_idx, num_hops, edge_index, ...) returns a 4-tuple:
#   (1) the nodes that make up the subgraph,
#   (2) the filtered edge_index connectivity,
#   (3) the new positions of the nodes originally passed in as node_idx,
#   (4) a boolean mask over the original edges marking the ones kept.

# Remember how many seed nodes are training nodes: `mapping` follows the
# order of the concatenation below (train first, then test).
num_train = data.train_idx.size(0)

node_idx, edge_index, mapping, edge_mask = k_hop_subgraph(
    torch.cat([data.train_idx, data.test_idx], dim=0),
    2,
    data.edge_index,
    relabel_nodes=True,
)

# Overwrite the graph in place with its relabelled 2-hop subgraph.
data.num_nodes = node_idx.size(0)
data.edge_index = edge_index
data.edge_type = data.edge_type[edge_mask]
data.train_idx, data.test_idx = mapping[:num_train], mapping[num_train:]

# Relation ids that are still present after filtering.
np.unique(data.edge_type)


array([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16,
       17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33,
       34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45])

We have 38 graphs stacked together for the test dataset. The batch indices, stored in `batch`, show that the first 12 nodes belong to the first graph, the next 22 to the second graph, and so on.

These indices are important for performing the final prediction. To perform a prediction over a whole graph, we usually apply a pooling operation over all nodes after running the GNN model. In this case, we will use average pooling. Hence, we need to know which nodes should be included in which average pool. Using this pooling, we can already create our graph network below. Specifically, we re-use our class `GNNModel` from before, and simply add an average-pooling step and a single linear layer for the graph prediction task.


In [16]:
class Net(torch.nn.Module):
    """Two-layer relational GCN producing per-node class log-probabilities.

    The first ``RGCNConv`` layer is called with ``None`` as node features, so
    its input size is the number of nodes. (Per the PyG documentation this
    makes the layer learn one embedding per node -- confirm against the
    installed version.) The second layer maps the hidden representation to
    one score per class, which is normalised with ``log_softmax``.

    Args:
        num_nodes: Input size of the first layer. Defaults to the notebook's
            ``data.num_nodes``.
        num_relations: Number of edge types. Defaults to the notebook's
            ``dataset.num_relations``.
        num_classes: Output size of the second layer. Defaults to the
            notebook's ``dataset.num_classes``.
        hidden_channels: Output size of the first layer.
        num_bases: Number of bases passed to both ``RGCNConv`` layers.

    Attributes:
        node_embeddings: An empty list until the first forward pass, then the
            output tensor of the most recent forward pass.
    """

    def __init__(self, num_nodes=None, num_relations=None, num_classes=None,
                 hidden_channels=16, num_bases=30):
        super().__init__()
        # Fall back to the notebook-level objects so that a bare ``Net()``
        # keeps building exactly the same model as before.
        if num_nodes is None:
            num_nodes = data.num_nodes
        if num_relations is None:
            num_relations = dataset.num_relations
        if num_classes is None:
            num_classes = dataset.num_classes

        self.conv1 = RGCNConv(num_nodes, hidden_channels, num_relations,
                              num_bases=num_bases)
        self.conv2 = RGCNConv(hidden_channels, num_classes, num_relations,
                              num_bases=num_bases)
        self.node_embeddings = []

    def forward(self, edge_index, edge_type):
        """Run both layers and return log-probabilities for every node.

        Args:
            edge_index: Graph connectivity passed to both ``RGCNConv`` layers.
            edge_type: Relation id of each edge, passed to both layers.

        Returns:
            The ``log_softmax`` (over dim 1) of the second layer's output.
            The same tensor is also stored on ``self.node_embeddings``.
        """
        # No explicit node features: the first layer receives ``None``.
        x = F.relu(self.conv1(None, edge_index, edge_type))
        x = self.conv2(x, edge_index, edge_type)
        x = F.log_softmax(x, dim=1)
        # Keep the latest per-node output around for inspection after a pass.
        self.node_embeddings = x
        return x


# Use the GPU when one is available, otherwise stay on the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the model first (its default sizes are read from `data`), then move
# both the model and the graph to the selected device.
model = Net().to(device)
data = data.to(device)

# Adam with L2 regularisation via weight decay.
optimizer = torch.optim.Adam(
    model.parameters(),
    lr=0.01,
    weight_decay=0.0005,
)


In [18]:
def train():
    """Run one full-batch optimisation step and return the training loss.

    Uses the notebook-level ``model``, ``optimizer`` and ``data``. The loss is
    the negative log-likelihood of the model output at ``data.train_idx``
    against ``data.train_y``.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    # Forward pass over the whole graph; only training nodes enter the loss.
    log_probs = model(data.edge_index, data.edge_type)
    train_loss = F.nll_loss(log_probs[data.train_idx], data.train_y)

    train_loss.backward()
    optimizer.step()
    return train_loss.item()


@torch.no_grad()
def test():
    model.eval()
    pred = model(data.edge_index, data.edge_type).argmax(dim=-1)
    train_acc = pred[data.train_idx].eq(data.train_y).to(torch.float).mean()
    test_acc = pred[data.test_idx].eq(data.test_y).to(torch.float).mean()
    return train_acc.item(), test_acc.item()

# Short run of 5 epochs (the upper bound of the range was originally 51).
num_epochs = 5
for epoch in range(1, num_epochs + 1):
    loss = train()
    train_acc, test_acc = test()
    progress = (f'Epoch: {epoch:02d}, Loss: {loss:.4f}, Train: {train_acc:.4f} '
                f'Test: {test_acc:.4f}')
    print(progress)

# Show the per-node output stored by the model's most recent forward pass
# (the evaluation pass of the last epoch) together with its shape.
final_embeddings = model.node_embeddings
print(final_embeddings)
print(final_embeddings.shape)

conv1.weight torch.Size([30, 23606, 16])
conv1.comp torch.Size([46, 30])
conv1.root torch.Size([23606, 16])
conv1.bias torch.Size([16])
conv2.weight torch.Size([30, 16, 2])
conv2.comp torch.Size([46, 30])
conv2.root torch.Size([16, 2])
conv2.bias torch.Size([2])
conv1 weight torch.Size([30, 23606, 16])
conv2 weight torch.Size([30, 16, 2])
Edge index and edge type torch.Size([2, 148082]) torch.Size([148082])
x shape before conv2 torch.Size([23606, 16])
x shape after conv2 torch.Size([23606, 2])
torch.Size([23606, 2])
conv1.weight torch.Size([30, 23606, 16])
conv1.comp torch.Size([46, 30])
conv1.root torch.Size([23606, 16])
conv1.bias torch.Size([16])
conv2.weight torch.Size([30, 16, 2])
conv2.comp torch.Size([46, 30])
conv2.root torch.Size([16, 2])
conv2.bias torch.Size([2])
conv1 weight torch.Size([30, 23606, 16])
conv2 weight torch.Size([30, 16, 2])
Edge index and edge type torch.Size([2, 148082]) torch.Size([148082])
x shape before conv2 torch.Size([23606, 16])
x shape after conv2 to