In [29]:
import dgl
import torch.nn.functional as F
import torch
import dgl.function as fn
import torch.nn as nn
import dgl.nn as dglnn
from dgl.data import MUTAGDataset
# from dgl.nn import HeteroGNNExplainer
from gnnexplainer import HeteroGNNExplainer
# Load dataset
data = MUTAGDataset()
g = data[0]
predict_ntype = data.predict_category
train_mask = g.nodes[predict_ntype].data['train_mask']
test_mask = g.nodes[predict_ntype].data['test_mask']
labels = g.nodes[predict_ntype].data['labels']
features = {}
for ntype in g.ntypes:
    features[ntype] = torch.zeros((g.num_nodes(ntype), 10))
# Define a model
class Model(nn.Module):
    def __init__(self, in_feats, out_feats, rel_names):
        super(Model, self).__init__()
        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, out_feats)
            for rel in rel_names}, aggregate='sum')
    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            feat = self.conv1(graph, feat)
            feat = {k: F.relu(v) for k, v in feat.items()}
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                graph.edata['w'] = eweight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            return graph.ndata['h']
# Train the model
model = Model(len(features[predict_ntype][0]), data.num_classes, g.etypes)
optimizer = torch.optim.Adam(model.parameters())
for epoch in range(100):
    logits = model(g, features)[predict_ntype]
    loss = F.cross_entropy(logits[train_mask], labels[train_mask])
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
# Explain the prediction for node 10
explainer = HeteroGNNExplainer(model, num_hops=1)
new_center, sg, feat_mask, edge_mask = explainer.explain_node(predict_ntype, 9528, g, features)

Done loading data from cached files.


  loss = F.cross_entropy(logits[train_mask], labels[train_mask])
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Explain node 9528: 100%|██████████████████████| 100/100 [00:04<00:00, 23.91it/s]


In [17]:
new_center

tensor([23])

In [30]:
sg.num_nodes()

65

In [19]:
feat_mask

{'SCHEMA': tensor([0.2715, 0.2779, 0.2594, 0.2666, 0.3035, 0.2533, 0.2411, 0.2482, 0.2685,
         0.2813]),
 '_Literal': tensor([0.2530, 0.2757, 0.2860, 0.2599, 0.2508, 0.2748, 0.2750, 0.2758, 0.2875,
         0.2710]),
 'bond': tensor([0.2569, 0.2744, 0.3060, 0.2867, 0.2376, 0.2873, 0.2758, 0.2877, 0.3181,
         0.2730]),
 'd': tensor([0.2866, 0.2375, 0.2769, 0.2677, 0.2374, 0.2500, 0.2763, 0.2418, 0.2812,
         0.2618]),
 'hasStructure': tensor([0.2688, 0.2561, 0.2461, 0.2923, 0.2595, 0.2908, 0.2913, 0.2530, 0.2375,
         0.2512])}

In [20]:
edge_mask

{('SCHEMA', '22-rdf-syntax-ns#type', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'owl#disjointWith', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#domain', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#range', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rdf-schema#subClassOf', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'bond'): tensor([]),
 ('SCHEMA', 'rev-22-rdf-syntax-ns#type', 'd'): tensor([0.7980]),
 ('SCHEMA', 'rev-owl#disjointWith', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#domain', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#range', 'SCHEMA'): tensor([]),
 ('SCHEMA', 'rev-rdf-schema#subClassOf', 'SCHEMA'): tensor([]),
 ('_Literal', 'rev-amesTestPositive', 'd'): tensor([0.8443]),
 ('_Literal', 'rev-charge', 'd'): tensor([]),
 ('_Literal', 'rev-chromaberr', 'd'): tensor([]),
 ('_Literal', 'rev-chromex', 'd'): tensor([]),
 ('_Literal', 'rev-cytogen_ca', 'd'): tensor([0.8960]),
 ('_Literal', 'rev-c

In [21]:
g.ndata['label']['d'][9528]

tensor(1)

In [27]:
def createMaskMeanDict(mask):
    maskMean = {}
    for key in mask:
        if list(mask[key]):
            maskMean[key] =  torch.mean(mask[key])
    return maskMean

def getMaxKey(dictionary):
    maxKey = None
    maxValue = 0
    for key in dictionary:
        if dictionary[key] > maxValue:
            maxValue = dictionary[key]
            maxKey = key
    
    return maxKey, maxValue

In [28]:
getMaxKey(createMaskMeanDict(edge_mask))

(('_Literal', 'rev-cytogen_ca', 'd'), tensor(0.8960))