In [16]:
import dgl
import torch.nn.functional as F
import torch
import dgl.function as fn
import torch.nn as nn
import dgl.nn as dglnn
from dgl.data import MUTAGDataset
# from dgl.nn import HeteroGNNExplainer
from gnnexplainer import HeteroGNNExplainer
# Load dataset
data = MUTAGDataset()
g = data[0]
predict_ntype = data.predict_category
# Cast the masks to bool before they are used for tensor indexing. The training
# cell emits a warning on the `logits[train_mask]` line, which is consistent with
# PyTorch's deprecation of uint8 masks (presumably the cached masks are stored as
# uint8 -- verify with `train_mask.dtype`). `.bool()` is a no-op on bool tensors.
train_mask = g.nodes[predict_ntype].data['train_mask'].bool()
test_mask = g.nodes[predict_ntype].data['test_mask'].bool()
labels = g.nodes[predict_ntype].data['labels']
# Placeholder inputs: every node type gets an all-zero feature matrix with
# 10 columns, one row per node of that type.
features = {}
for ntype in g.ntypes:
    features[ntype] = torch.zeros((g.num_nodes(ntype), 10))
# Define a model
class Model(nn.Module):
    """One heterogeneous graph-conv layer followed by a sum-aggregation step.

    Parameters
    ----------
    in_feats : int
        Input feature size handed to every per-relation ``GraphConv``.
    out_feats : int
        Output feature size of every per-relation ``GraphConv``.
    rel_names : iterable of str
        Relation names; one ``GraphConv`` is registered per name.
    """

    def __init__(self, in_feats, out_feats, rel_names):
        super().__init__()
        # Build the per-relation modules one by one; a repeated relation name
        # simply replaces the module created earlier for that name.
        per_relation = {}
        for rel in rel_names:
            per_relation[rel] = dglnn.GraphConv(in_feats, out_feats)
        self.conv1 = dglnn.HeteroGraphConv(per_relation, aggregate='sum')

    def forward(self, graph, feat, eweight=None):
        """Return the node data ``'h'`` after conv1, ReLU and one sum aggregation.

        ``feat`` is passed straight to ``conv1``. When ``eweight`` is given it is
        stored as edge data ``'w'`` and multiplies each message in the final
        aggregation step; otherwise source features are copied unweighted.
        """
        with graph.local_scope():
            hidden = self.conv1(graph, feat)
            graph.ndata['h'] = {ntype: F.relu(h) for ntype, h in hidden.items()}
            if eweight is not None:
                graph.edata['w'] = eweight
                message_fn = fn.u_mul_e('h', 'w', 'm')
            else:
                message_fn = fn.copy_u('h', 'm')
            graph.update_all(message_fn, fn.sum('m', 'h'))
            return graph.ndata['h']
# Train the model
# Seed PyTorch's RNG so that weight initialisation is repeatable across
# Restart & Run All (the explainer's own randomness presumably also draws from
# the torch RNG -- verify).
torch.manual_seed(0)
# Input width is the number of feature columns; `.shape[1]` also works when the
# node type has zero rows, unlike `len(features[...][0])`.
in_feats = features[predict_ntype].shape[1]
model = Model(in_feats, data.num_classes, g.etypes)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(10):
    # Full-graph forward pass; only the nodes of the predicted type are scored.
    logits = model(g, features)[predict_ntype]
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Explain the prediction for node 23
# NOTE(review): Model.forward runs conv1 and then a second update_all, i.e. two
# rounds of message passing, while the explainer is built with num_hops=1 --
# confirm that a 1-hop subgraph is really what is intended here.
explainer = HeteroGNNExplainer(model, num_hops=1)
new_center, sg, feat_mask, edge_mask = explainer.explain_node(predict_ntype, 23, g, features)

Done loading data from cached files.


  loss = F.cross_entropy(logits[train_mask], labels[train_mask])
Explain node 23: 100%|████████████████████████| 100/100 [00:04<00:00, 24.30it/s]


In [17]:
# First value returned by explain_node; presumably the explained node's ID
# inside the subgraph `sg` -- confirm against the explainer implementation.
new_center

tensor([0])

In [18]:
# Node count of the subgraph `sg` returned by explain_node (called without a
# node type, so presumably the total over all node types).
sg.num_nodes()

5

In [19]:
# Feature mask returned by explain_node: displayed as a dict keyed by node type,
# each value holding one entry per input feature column (10 here).
feat_mask

{'SCHEMA': tensor([0.2978, 0.2735, 0.2618, 0.2601, 0.2320, 0.2645, 0.2795, 0.2812, 0.2774,
         0.2777]),
 '_Literal': tensor([0.2591, 0.2965, 0.2651, 0.3066, 0.2702, 0.2980, 0.3261, 0.3015, 0.2642,
         0.2909]),
 'bond': tensor([0.2386, 0.2759, 0.2739, 0.2762, 0.2892, 0.2817, 0.2889, 0.2984, 0.2520,
         0.2521]),
 'd': tensor([0.3003, 0.2939, 0.2714, 0.2745, 0.2751, 0.2384, 0.3025, 0.2379, 0.3005,
         0.2823]),
 'hasStructure': tensor([0.2902, 0.2823, 0.2649, 0.2675, 0.2985, 0.2650, 0.2838, 0.2440, 0.3224,
         0.2721])}

In [20]:
# Edge mask returned by explain_node: displayed as a dict keyed by canonical
# edge type (src type, relation, dst type); several entries are empty tensors,
# presumably because the subgraph has no edges of that type.
edge_mask

{('SCHEMA', '22-rdf-syntax-ns#type', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'owl#disjointWith', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#domain', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#range', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#subClassOf', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'bond'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'd'): tensor([0.1438]),
 ('SCHEMA', 'rev-owl#disjointWith', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#domain', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#range', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#subClassOf', 'SCHEMA'): tensor([]),
 ('_Literal', 'rev-amesTestPositive', 'd'): tensor([]),
 ('_Literal', 'rev-charge', 'd'): tensor([0.0960]),
 ('_Literal', 'rev-chromaberr', 'd'): tensor([]),
 ('_Literal', 'rev-chromex', 'd'): tensor([]),
 ('_Literal', 'rev-cytogen_ca', 'd'): tensor([]),
 ('_Literal', 'rev-cytogen