In [1]:
import dgl
import torch
import torch.nn as nn
import torch.nn.functional as F

Using backend: pytorch


In [2]:
# Generate a synthetic dataset with 10000 graphs, ranging from 10 to 500 nodes
import dgl.data

dataset = dgl.data.GINDataset('PROTEINS', self_loop=True)
print('Node feature dimensionality:', dataset.dim_nfeats)
print('Number of graph categories:', dataset.gclasses)


Node feature dimensionality: 3
Number of graph categories: 2


In [3]:
#Load data with mini-batches

from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

num_samples = len(dataset)
num_train_samples = int(0.8 * num_samples)
num_valid_samples = int(0.1 * num_samples)
num_test_samples = int(0.1 * num_samples)

# Choose a data sampler
# There are many other options, see details below:
# https://pytorch.org/docs/stable/data.html#data-loading-order-and-sampler
train_sampler = SubsetRandomSampler(torch.arange(num_train_samples))
valid_sampler = SubsetRandomSampler(torch.arange(num_train_samples, num_train_samples+num_valid_samples))
test_sampler = SubsetRandomSampler(torch.arange(num_train_samples+num_valid_samples, num_train_samples+num_valid_samples+num_test_samples))

train_loader = GraphDataLoader(
    dataset, sampler=train_sampler, batch_size=16, drop_last=False)
valid_loader = GraphDataLoader(
    dataset, sampler=valid_sampler, batch_size=16, drop_last=False)
test_loader = GraphDataLoader(
    dataset, sampler=test_sampler, batch_size=16, drop_last=False)

# Check datapoint
item = iter(train_loader)
batch = next(item)
print(batch)

[Graph(num_nodes=1006, num_edges=4572,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), tensor([0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0])]


In [4]:
graph_batch, labels = batch
print('Number of nodes for each graph:', graph_batch.batch_num_nodes())
print('Number of edges for each graph:', graph_batch.batch_num_edges())

# Reconstruct original graphs in a mini-batch
ori_graphs = dgl.unbatch(graph_batch)
print('The original graphs in the minibatch:', ori_graphs)

Number of nodes for each graph: tensor([246, 126,  43,  71,  43,   6,  18,  27, 146,  23,  50,   6, 144,  13,
         18,  26])
Number of edges for each graph: tensor([1116,  656,  195,  317,  199,   30,   80,  149,  628,  107,  226,   28,
         582,   63,   80,  116])
The original graphs in the minibatch: [Graph(num_nodes=246, num_edges=1116,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=126, num_edges=656,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=43, num_edges=195,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
      edata_schemes={}), Graph(num_nodes=71, num_edges=317,
      ndata_schemes={'label': Scheme(shape=(), dtype=torch.int64), 'attr': Scheme(shape=(3,), dtype=torch.float32)}
     

In [18]:
# Build models

from dgl.nn import GraphConv, GATConv, GINConv

# Graph Convolutional Networks
# TODO: more options for graph pooling
class GCN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, h_feats)
        self.conv2 = GraphConv(h_feats, num_classes)
        
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')
    
# Graph Attention Networks
class GAT(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GAT, self).__init__()
        self.conv1 = GATConv(in_feats, h_feats, num_heads=3)
        self.conv2 = GATConv(h_feats, num_classes, num_heads=3)
        
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = torch.mean(h, dim=1) # Average value over three attention heads
        h = self.conv2(g, h)
        h = torch.mean(h, dim=1) # Average value over three attention heads
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')
    
# Graph Isomorphism Networks
class GIN(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(GIN, self).__init__()
        lin1 = torch.nn.Linear(in_feats, h_feats)
        lin2 = torch.nn.Linear(h_feats, num_classes)
        self.conv1 = GINConv(lin1, 'sum')
        self.conv2 = GINConv(lin2, 'sum')
        
    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        g.ndata['h'] = h
        return dgl.mean_nodes(g, 'h')

In [13]:
# Train models

model_gcn = GCN(dataset.dim_nfeats, 16, dataset.gclasses)

optimizer = torch.optim.Adam(model_gcn.parameters(), lr=0.01)

for epoch in range(20):
    for graph_batch, labels in train_loader:
        pred = model_gcn(graph_batch, graph_batch.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_valids = 0
for graph_batch, labels in valid_loader:
    pred = model_gcn(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_valids += len(labels)        

print('Valid accuracy:', num_correct / num_valids)
    
num_correct = 0
num_tests = 0
for graph_batch, labels in test_loader:
    pred = model_gcn(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Valid accuracy: 0.2882882882882883
Test accuracy: 0.2702702702702703


In [20]:
# Train GAT model

model_gat = GAT(dataset.dim_nfeats, 16, dataset.gclasses)

optimizer = torch.optim.Adam(model_gat.parameters(), lr=0.01)

for epoch in range(20):
    for graph_batch, labels in train_loader:
        pred = model_gat(graph_batch, graph_batch.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_valids = 0
for graph_batch, labels in valid_loader:
    pred = model_gat(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_valids += len(labels)        

print('Valid accuracy:', num_correct / num_valids)
    
num_correct = 0
num_tests = 0
for graph_batch, labels in test_loader:
    pred = model_gat(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Valid accuracy: 0.2972972972972973
Test accuracy: 0.27927927927927926


In [21]:
# Train GIN model

model_gin = GIN(dataset.dim_nfeats, 16, dataset.gclasses)

optimizer = torch.optim.Adam(model_gin.parameters(), lr=0.01)

for epoch in range(20):
    for graph_batch, labels in train_loader:
        pred = model_gin(graph_batch, graph_batch.ndata['attr'].float())
        loss = F.cross_entropy(pred, labels)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

num_correct = 0
num_valids = 0
for graph_batch, labels in valid_loader:
    pred = model_gin(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_valids += len(labels)        

print('Valid accuracy:', num_correct / num_valids)
    
num_correct = 0
num_tests = 0
for graph_batch, labels in test_loader:
    pred = model_gin(graph_batch, graph_batch.ndata['attr'].float())
    num_correct += (pred.argmax(1) == labels).sum().item()
    num_tests += len(labels)

print('Test accuracy:', num_correct / num_tests)

Valid accuracy: 0.2702702702702703
Test accuracy: 0.24324324324324326
