In [18]:
import dgl
import torch.nn.functional as F
import torch.nn as nn
import dgl.nn as dglnn
from dgl.data import MUTAGDataset
# from dgl.nn import HeteroGNNExplainer
from gnnexplainer import HeteroGNNExplainer
# Load dataset
data = MUTAGDataset()
g = data[0]
predict_ntype = data.predict_category
train_mask = g.nodes[predict_ntype].data['train_mask']
test_mask = g.nodes[predict_ntype].data['test_mask']
labels = g.nodes[predict_ntype].data['labels']
features = {}
for ntype in g.ntypes:
    features[ntype] = torch.zeros((g.num_nodes(ntype), 10))
# Define a model
class Model(nn.Module):
    def __init__(self, in_feats, out_feats, rel_names):
        super(Model, self).__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            feat = self.conv1(graph, feat)
            feat = {k: F.relu(v) for k, v in feat.items()}
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                graph.edata['w'] = eweight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            return graph.ndata['h']
# Train the model
model = Model(len(features[predict_ntype][0]), data.num_classes, g.etypes)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    logits = model(g, features)[predict_ntype]
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Explain the prediction for node 10
explainer = HeteroGNNExplainer(model, num_hops=8)
new_center, sg, feat_mask, edge_mask = explainer.explain_node(predict_ntype, 10, g, features)

Done loading data from cached files.


  loss = F.cross_entropy(logits[train_mask], labels[train_mask])
Explain node 10: 100%|████████████████████████| 100/100 [00:04<00:00, 24.12it/s]


In [19]:
new_center

tensor([0])

In [20]:
sg.num_edges()

10

In [21]:
feat_mask

{'SCHEMA': tensor([0.2595, 0.2569, 0.2505, 0.2719, 0.2575, 0.2746, 0.2882, 0.2612, 0.2582,
         0.2774]),
 '_Literal': tensor([0.2729, 0.2802, 0.2796, 0.2739, 0.2619, 0.2903, 0.2771, 0.2515, 0.2879,
         0.2734]),
 'bond': tensor([0.3162, 0.2492, 0.2246, 0.2584, 0.2525, 0.2435, 0.2575, 0.2880, 0.3060,
         0.2872]),
 'd': tensor([0.2628, 0.2904, 0.2877, 0.2416, 0.2869, 0.2561, 0.2542, 0.2945, 0.2757,
         0.2786]),
 'hasStructure': tensor([0.2896, 0.2622, 0.2715, 0.3015, 0.2753, 0.2617, 0.2518, 0.2856, 0.2686,
         0.2764])}