In [1]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.data import MiniGCDataset

In [87]:
import dgl.nn.pytorch as dglnn
import torch.nn as nn

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = dglnn.GraphConv(in_dim, hidden_dim)
        self.conv2 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.conv3 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.conv4 = dglnn.GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g, h):
        # Apply graph convolution and activation.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        h = F.relu(self.conv3(g, h))
        h = F.relu(self.conv4(g, h))
        with g.local_scope():
            g.ndata['h'] = h
            # Calculate graph representation by average readout.
            hg = dgl.mean_nodes(g, 'h')
            return self.classify(hg)

In [24]:
data = MiniGCDataset(100, 16, 32, seed=0)

In [88]:
from dgl.dataloading import GraphDataLoader
dataloader = GraphDataLoader(
    data,
    batch_size=20,
    shuffle=True)

In [89]:
model = Classifier(10, 20, data.num_classes)
opt = torch.optim.Adam(model.parameters(), 0.001)
num_correct = 0
num_tests = 0
i = 0
for epoch in range(8000):
    for batched_graph, labels in dataloader:
        feats = torch.randn((batched_graph.num_nodes(), 10)) * 0.1
        logits = model(batched_graph, feats)
        loss = F.cross_entropy(logits, labels)
        opt.zero_grad()
        loss.backward()
        opt.step()
    pred = model(batched_graph, feats).argmax(1)
    num_correct += (pred == labels).sum().item()
    num_tests += len(labels)
    i += 1
    if i % 500 == 0:
        print('Epoch: {} Test accuracy: {:.3f}'.format(i, num_correct / num_tests))
        

Epoch: 500 Test accuracy: 0.524
Epoch: 1000 Test accuracy: 0.631
Epoch: 1500 Test accuracy: 0.685
Epoch: 2000 Test accuracy: 0.719
Epoch: 2500 Test accuracy: 0.746
Epoch: 3000 Test accuracy: 0.764
Epoch: 3500 Test accuracy: 0.778
Epoch: 4000 Test accuracy: 0.789
Epoch: 4500 Test accuracy: 0.799
Epoch: 5000 Test accuracy: 0.806
Epoch: 5500 Test accuracy: 0.812
Epoch: 6000 Test accuracy: 0.817
Epoch: 6500 Test accuracy: 0.822
Epoch: 7000 Test accuracy: 0.826
Epoch: 7500 Test accuracy: 0.829
Epoch: 8000 Test accuracy: 0.833


In [94]:
from dgl.nn import GNNExplainer

explainer = GNNExplainer(model, num_hops=4)
feats = torch.randn((batched_graph.num_nodes(), 10)) * 0.1
feat_mask, edge_mask = explainer.explain_graph(data[3], feats)

TypeError: forward() got an unexpected keyword argument 'graph'