In [2]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from mutagDataset import MUTAG, MUTAGOneNtype

In [47]:
class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')

    def forward(self, graph, inputs):
        # inputs is features of nodes
        h = self.conv1(graph, inputs)
        h = {k: F.relu(v) for k, v in h.items()}
        h = self.conv2(graph, h)
        return h

class HeteroClassifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes, rel_names):
        super().__init__()
        
        features = {}
        for canonical_edge in g.canonical_etypes:
          features[str(canonical_edge)] = torch.randn((g[canonical_edge].num_nodes(), 10))
        rel_names = [str(canonical_etype) for canonical_etype in g.canonical_etypes]    
        self.rgcn = RGCN(in_dim, hidden_dim, hidden_dim, rel_names)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        h = g.ndata['feat']
        
        h = self.rgcn(g, h)
        with g.local_scope():
            print(g.ndata['h'])
            # Calculate graph representation by average readout.
            hg = 0
            for ntype in g.ntypes:
                hg = hg + dgl.mean_nodes(g, 'h', ntype=ntype)
            return self.classify(hg)


In [48]:
data = MUTAG()

In [49]:
from dgl.dataloading import GraphDataLoader

dataloader = GraphDataLoader(data)

In [50]:
for g, l in dataloader:
    break

In [51]:
rel_names = set()

for graph in data:
    for edge in graph[0].canonical_etypes:
        rel_names.add(str(edge))


In [52]:
model = HeteroClassifier(10, 20, 2, rel_names)

In [53]:
model(g)

{}


KeyError: 'h'

In [43]:
# opt = torch.optim.Adam(model.parameters())
# for epoch in range(20):
#     for batched_graph, labels in dataloader:
#         logits = model(batched_graph)
#         loss = F.cross_entropy(logits, labels)
#         opt.zero_grad()
#         loss.backward()
#         opt.step()

In [17]:
data = MUTAGOneNtype()

In [18]:
from dgl.dataloading import GraphDataLoader

dataloader = GraphDataLoader(data)

In [21]:
for g, l in dataloader:
    break
    
g

Graph(num_nodes={1: 18},
      num_edges={(1, 'C', 1): 66, (1, 'N', 1): 6, (1, 'O', 1): 4, (1, 'aromatic', 1): 32, (1, 'double', 1): 2, (1, 'single', 1): 4},
      metagraph=[(1, 1, 'C'), (1, 1, 'N'), (1, 1, 'O'), (1, 1, 'aromatic'), (1, 1, 'double'), (1, 1, 'single')])

In [34]:
rel_names = ['C', 'N', 'O', 'F', 'I', 'Cl', 'Br', 'aromatic', 'single', 'double', 'triple']
features = {}

for etype in g.etypes:
  features[etype] = torch.randn((g[etype].num_nodes(), 10))

{"('Br', 'single', 'C')",
 "('C', 'aromatic', 'C')",
 "('C', 'aromatic', 'N')",
 "('C', 'aromatic', 'O')",
 "('C', 'double', 'C')",
 "('C', 'double', 'N')",
 "('C', 'double', 'O')",
 "('C', 'single', 'Br')",
 "('C', 'single', 'C')",
 "('C', 'single', 'Cl')",
 "('C', 'single', 'F')",
 "('C', 'single', 'I')",
 "('C', 'single', 'N')",
 "('C', 'single', 'O')",
 "('C', 'triple', 'N')",
 "('Cl', 'single', 'C')",
 "('F', 'single', 'C')",
 "('I', 'single', 'C')",
 "('N', 'aromatic', 'C')",
 "('N', 'double', 'C')",
 "('N', 'double', 'N')",
 "('N', 'double', 'O')",
 "('N', 'single', 'C')",
 "('N', 'single', 'N')",
 "('N', 'single', 'O')",
 "('N', 'triple', 'C')",
 "('O', 'aromatic', 'C')",
 "('O', 'double', 'C')",
 "('O', 'double', 'N')",
 "('O', 'single', 'C')",
 "('O', 'single', 'N')"}