In [1]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from MUTAGDataset import MUTAG

In [9]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
        self.pool = dglnn.AvgPooling()

    def forward(self, graph, inputs, eweight=None):
        with graph.local_scope():
            feat = self.conv1(graph, inputs)
            feat = {k: F.relu(v) for k, v in h.items()}
            feat = self.conv2(graph, feat)
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                graph.edata['w'] = eweight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            return self.pool(graph, graph.ndata['h'])
        
class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()

        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, eweight=None):
        h = g.ndata['feat']
        h = self.rgcn(g, h)
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)

In [10]:
data = MUTAG()

In [21]:
from dgl.dataloading import GraphDataLoader

dataloader = GraphDataLoader(data)

In [25]:
etypes = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br', 'aromatic', 'single', 'double', 'triple']
model = HeteroClassifier(10, 20, 2, etypes)

KeyError: "attribute 'double' already exists"

In [None]:
opt = torch.optim.Adam(model.parameters())
for epoch in range(20):
    for graph, labels in dataloader:
        feats = torch.randn((graph.num_nodes(), 10)) * 0.1
        logits = model(graph)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

In [26]:
data[0][0]

Graph(num_nodes={1: 18},
      num_edges={(1, 'C', 1): 15, (1, 'N', 1): 3, (1, 'O', 1): 2, (1, 'aromatic', 1): 32, (1, 'double', 1): 2, (1, 'single', 1): 4},
      metagraph=[(1, 1, 'C'), (1, 1, 'N'), (1, 1, 'O'), (1, 1, 'aromatic'), (1, 1, 'double'), (1, 1, 'single')])