In [1]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from mutagDataset import MUTAG, MUTAGOneNtype

In [2]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs is features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()
        self.in_dim = in_dim
        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = {ntype:torch.zeros((g.num_nodes(ntype), self.in_dim)) for ntype in g.ntypes}
        
        h = self.rgcn(g, h)

        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)


In [27]:
from dgl.dataloading import GraphDataLoader
from torch.utils.data.sampler import SubsetRandomSampler

data = MUTAG()

num_examples = len(data)
num_train = int(num_examples * 0.8)

train_sampler = SubsetRandomSampler(torch.arange(num_train))
test_sampler = SubsetRandomSampler(torch.arange(num_train, num_examples))

train_dataloader = GraphDataLoader(
    data, sampler=train_sampler)
test_dataloader = GraphDataLoader(
    data, sampler=test_sampler)

In [6]:
rel_names = set()

for graph in data:
    for edge in graph[0].etypes:
        rel_names.add(str(edge))

rel_names = list(rel_names)

In [33]:
model = HeteroClassifier(10, 20, 2, rel_names)
opt = torch.optim.Adam(model.parameters())
for epoch in range(100):
    num_correct, num_tests = 0, 0
    for batched_graph, labels in train_dataloader:
        logits = model(batched_graph)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        num_correct += (logits.argmax(1) == labels).sum().item()
        num_tests += len(labels)
    if epoch % 10 == 0:
        print(f'Epochs {epoch}, Train accuracy: {num_correct / num_tests * 100:.2f}%')
    
    
num_correct, num_tests = 0, 0    
for batched_graph, labels in test_dataloader:
    pred = model(batched_graph)
    num_correct += (logits.argmax(1) == labels).sum().item()
    num_tests += len(labels)
        
print(f'Test accuracy: {num_correct / num_tests * 100:.2f}%')

Epochs 0, Train accuracy: 66.00%
Epochs 10, Train accuracy: 66.00%
Epochs 20, Train accuracy: 66.00%
Epochs 30, Train accuracy: 67.33%
Epochs 40, Train accuracy: 69.33%
Epochs 50, Train accuracy: 70.67%
Epochs 60, Train accuracy: 66.00%
Epochs 70, Train accuracy: 68.67%
Epochs 80, Train accuracy: 70.00%
Epochs 90, Train accuracy: 68.67%
Test accuracy: 68.42%
