In [3]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.dataloading import GraphDataLoader
from mutagDataset import MUTAG

In [16]:
class RGCN(nn.Module):
    """Two-layer relational GCN followed by a graph-level mean readout.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``dglnn.GraphConv``
    per relation name; results from different relations are combined by
    summation (``aggregate='sum'``).

    Parameters
    ----------
    in_feats
        Input feature size given to every ``GraphConv`` of the first layer.
    hid_feats
        Output size of the first layer and input size of the second.
    out_feats
        Output size of every ``GraphConv`` of the second layer.
    rel_names
        Iterable of relation (edge type) names; one ``GraphConv`` is created
        per name in each layer.
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
        self.conv2 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(hid_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
        # Kept so the attribute remains available to existing callers;
        # forward() reads out per node type with dgl.mean_nodes instead,
        # because the conv output is a per-node-type dict.
        self.pool = dglnn.AvgPooling()

    def forward(self, graph, inputs, eweight=None):
        """Compute one representation per graph.

        Parameters
        ----------
        graph
            The (possibly batched) graph to run on. Temporary node/edge data
            is written inside ``graph.local_scope()``.
        inputs
            Node features handed unchanged to the first ``HeteroGraphConv``.
            NOTE(review): presumably a dict mapping node type to a feature
            tensor -- confirm against the caller.
        eweight
            Optional edge weights for the final aggregation step only (the
            two conv layers do not use them). Either a dict keyed by
            canonical edge type, or a single value used as edge data for
            every relation that is aggregated.

        Returns
        -------
        The sum, over node types that received messages, of the per-graph
        mean of the aggregated node features.

        Raises
        ------
        ValueError
            If no relation has a source node type with computed features.
        """
        with graph.local_scope():
            hidden = self.conv1(graph, inputs)
            hidden = {ntype: F.relu(h) for ntype, h in hidden.items()}
            hidden = self.conv2(graph, hidden)

            # The conv output is a dict keyed by node type, so it is stored
            # one node type at a time rather than through graph.ndata.
            for ntype, h in hidden.items():
                graph.nodes[ntype].data['h'] = h

            # One (message, reduce) pair per relation whose source node type
            # actually has features to send.
            funcs = {}
            for c_etype in graph.canonical_etypes:
                if c_etype[0] not in hidden:
                    continue
                if eweight is None:
                    message = fn.copy_u('h', 'm')
                else:
                    weight = eweight[c_etype] if isinstance(eweight, dict) else eweight
                    graph.edges[c_etype].data['w'] = weight
                    message = fn.u_mul_e('h', 'w', 'm')
                # Write to a separate field so source-only node types, which
                # keep their pre-aggregation 'h', are not mixed into the readout.
                funcs[c_etype] = (message, fn.sum('m', 'h_agg'))

            if not funcs:
                raise ValueError(
                    "No relation in the graph has a source node type with features.")
            graph.multi_update_all(funcs, 'sum')

            # Mean readout per destination node type, summed across types.
            dst_types = {c_etype[2] for c_etype in funcs}
            readouts = [
                dgl.mean_nodes(graph, 'h_agg', ntype=ntype)
                for ntype in graph.ntypes
                if ntype in dst_types and graph.num_nodes(ntype)
            ]
            return sum(readouts)

In [4]:
# Build the dataset using the MUTAG class from the project-local `mutagDataset` module.
# NOTE(review): what each item holds (graph, label, ...) is defined in that module -- confirm there.
data = MUTAG()

In [5]:
# GraphDataLoader is already imported in the notebook's import cell, so it is
# not re-imported here. No batching options are passed, so the loader's
# defaults apply.
dataloader = GraphDataLoader(data)