In [57]:
import dgl
import dgl.function as fn
import torch
import torch.nn as nn
import dgl.nn as dglnn
import torch.nn.functional as F
from dgl.data import MUTAGDataset

In [69]:
# Define a Heterograph Conv model

class RGCN(nn.Module):
    """Heterogeneous GCN: four HeteroGraphConv layers plus one parameter-free propagation.

    Each layer is a ``dglnn.HeteroGraphConv`` holding one ``dglnn.GraphConv`` per
    relation name, with the per-relation results summed. Layer widths are
    ``in_feats -> hid_feats -> hid_feats -> hid_feats -> out_feats``, with ReLU
    after the first three layers.

    Note: ``forward`` ends with an extra ``update_all`` over all edges, so each
    output depends on a 5-hop in-neighbourhood (four conv layers + one step).
    """

    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()
        # Layers are built in the order conv1..conv4 and keep these attribute
        # names, so parameter initialisation order and pickled models are unaffected.
        self.conv1 = self._hetero_layer(in_feats, hid_feats, rel_names)
        self.conv2 = self._hetero_layer(hid_feats, hid_feats, rel_names)
        self.conv3 = self._hetero_layer(hid_feats, hid_feats, rel_names)
        self.conv4 = self._hetero_layer(hid_feats, out_feats, rel_names)

    @staticmethod
    def _hetero_layer(n_in, n_out, rel_names):
        """Build one HeteroGraphConv with a GraphConv(n_in, n_out) per relation, summed."""
        per_relation = {rel: dglnn.GraphConv(n_in, n_out) for rel in rel_names}
        return dglnn.HeteroGraphConv(per_relation, aggregate='sum')

    def forward(self, graph, feat, eweight=None):
        """Return the output node representations, keyed by node type.

        Args:
            graph: heterograph to run on; all writes happen inside ``local_scope``.
            feat: dict mapping node type -> input feature tensor.
            eweight: optional edge weights stored as ``graph.edata['w']`` —
                presumably a dict keyed by (canonical) edge type, as heterograph
                ``edata`` expects; TODO confirm against the explainer. They only
                scale messages in the final propagation step; the GraphConv
                layers never see them.
        """
        with graph.local_scope():
            hidden = feat
            for layer in (self.conv1, self.conv2, self.conv3):
                hidden = {ntype: F.relu(h) for ntype, h in layer(graph, hidden).items()}
            graph.ndata['h'] = self.conv4(graph, hidden)

            # One more hop: sum neighbour outputs, optionally weighted per edge.
            if eweight is not None:
                graph.edata['w'] = eweight
                message = fn.u_mul_e('h', 'w', 'm')
            else:
                message = fn.copy_u('h', 'm')
            graph.update_all(message, fn.sum('m', 'h'))
            return graph.ndata['h']

In [70]:
def train(model, g, node_features, category, epochs, printInterval, lr=1e-3):
    """Full-graph training of a node classifier on one node type of a heterograph.

    Args:
        model: module called as ``model(g, node_features)`` that returns a dict of
            logits keyed by node type.
        g: graph whose ``category`` nodes carry ``train_mask``, ``test_mask`` and
            class-index ``labels`` in ``g.nodes[category].data``.
        node_features: input features, passed straight through to ``model``.
        category: node type whose nodes are classified.
        epochs: number of optimisation steps (one full-graph forward pass each).
            With ``epochs <= 0`` nothing is trained or printed.
        printInterval: log every ``printInterval`` epochs; the last epoch is
            always logged exactly once. Must be non-zero.
        lr: Adam learning rate (defaults to PyTorch's Adam default, 1e-3).
    """
    opt = torch.optim.Adam(model.parameters(), lr=lr)
    node_data = g.nodes[category].data
    train_mask = node_data['train_mask']
    test_mask = node_data['test_mask']
    labels = node_data['labels']

    model.train()  # nothing below switches back to eval mode, so set it once
    for epoch in range(epochs):
        # Full-graph forward pass; keep only the logits of the target node type.
        logits = model(g, node_features)[category]
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Accuracies reflect the parameters *before* this epoch's update.
        pred = logits.argmax(1)
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        opt.zero_grad()
        loss.backward()
        opt.step()

        # Log on the interval and on the final epoch. Doing the final log inside
        # the loop avoids a duplicate line and a NameError when epochs == 0.
        if epoch % printInterval == 0 or epoch == epochs - 1:
            print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, test acc: {:.3f}'.format(
                epoch, loss.item(), train_acc.item(), test_acc.item()))


In [71]:
# Load MUTAG through DGL's dataset class. `data[0]` is the heterograph used for
# the rest of the notebook; `data` itself is kept because later cells read its
# metadata (predict_category, num_classes).
data = MUTAGDataset()
g = data[0]

Done loading data from cached files.


In [72]:
# Node type whose labels the model predicts, and the number of label classes.
category, num_classes = data.predict_category, data.num_classes

In [73]:
# All-zero placeholder input features of width 10 for every node type. The
# model's first layer must be built with in_feats equal to this width.
features = {
    ntype: torch.zeros((g.num_nodes(ntype), 10))
    for ntype in g.ntypes
}

In [74]:
# Set to True to retrain the model instead of loading the saved checkpoint in
# the next cell. The checkpoint pickles the whole module, so loading it also
# requires the RGCN class defined above.
TRAIN_FROM_SCRATCH = False

if TRAIN_FROM_SCRATCH:
    # Derive the input width from the placeholder features instead of hard-coding it.
    model = RGCN(features[category].shape[1], 20, num_classes, g.etypes)
    train(model, g, features, category, 300, 50)
    torch.save(model, '../models/Mutag_Trained_Model.pt')

In [75]:
# The checkpoint is a fully pickled nn.Module, so unpickling can execute
# arbitrary code: only load files you trust. map_location="cpu" keeps the model
# on the same device as `g` and `features`, which are built on the CPU above.
MODEL_PATH = "../models/Mutag_Trained_Model.pt"
try:
    # Newer PyTorch defaults to weights_only=True, which rejects whole-module pickles.
    model = torch.load(MODEL_PATH, map_location="cpu", weights_only=False)
except TypeError:
    # Older PyTorch has no weights_only argument.
    model = torch.load(MODEL_PATH, map_location="cpu")
model.eval();

In [76]:
from gnnexplainer import HeteroGNNExplainer

# RGCN propagates over 5 hops: four HeteroGraphConv layers plus the final
# update_all in forward(). The explainer's k-hop subgraph must be at least that
# deep, otherwise the prediction on the subgraph can differ from the full-graph
# prediction being explained.
NUM_HOPS = 5
NODE_ID = 10  # index, within the `predict_ntype` nodes, of the node to explain
# DGL's GNNExplainer initialises its masks randomly; presumably this local copy
# does too, so fix the seed for a reproducible explanation.
SEED = 0

predict_ntype = category  # same node type the model was trained to classify
torch.manual_seed(SEED)
explainer = HeteroGNNExplainer(model, num_hops=NUM_HOPS)
new_center, sg, feat_mask, edge_mask = explainer.explain_node(predict_ntype, NODE_ID, g, features)

Explain node 10: 100%|████████████████████████| 100/100 [00:28<00:00,  3.47it/s]
