In [9]:
import dgl
import dgl.function as fn
import torch
import dgl.nn as dglnn
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import AMDataset

In [10]:
# Define a Heterograph Conv model

class RGCN(nn.Module):
    def __init__(self, in_feats, hid_feats, out_feats, rel_names):
        super().__init__()

        self.conv1 = dglnn.HeteroGraphConv({
            rel: dglnn.GraphConv(in_feats, hid_feats)
            for rel in rel_names}, aggregate='sum')
#         self.conv2 = dglnn.HeteroGraphConv({
#             rel: dglnn.GraphConv(hid_feats, hid_feats)
#             for rel in rel_names}, aggregate='sum')
#         self.conv3 = dglnn.HeteroGraphConv({
#             rel: dglnn.GraphConv(hid_feats, hid_feats)
#             for rel in rel_names}, aggregate='sum')
#         self.conv4 = dglnn.HeteroGraphConv({
#             rel: dglnn.GraphConv(hid_feats, out_feats)
#             for rel in rel_names}, aggregate='sum')

    def forward(self, graph, feat, eweight=None):
        # inputs are features of nodes
        with graph.local_scope():
            feat = self.conv1(graph, feat)
            feat = {k: F.relu(v) for k, v in feat.items()}
#             feat = self.conv2(graph, feat)
#             feat = {k: F.relu(v) for k, v in feat.items()}
#             feat = self.conv3(graph, feat)
#             feat = {k: F.relu(v) for k, v in feat.items()}
#             feat = self.conv4(graph, feat)
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                graph.edata['w'] = eweight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            return graph.ndata['h']

In [11]:
def train(model, hetero_graph, node_features, epochs, printInterval):
    opt = torch.optim.Adam(model.parameters())
    train_mask = g.nodes[category].data['train_mask']
    test_mask = g.nodes[category].data['test_mask']
    labels = g.nodes[category].data['labels']

    for epoch in range(epochs):
        model.train()
        # forward propagation by using all nodes and extracting the user embeddings
        logits = model(hetero_graph, node_features)[category]
        pred = logits.argmax(1)
        # compute loss
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        # Compute validation accuracy.  Omitted in this example.
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()
        # backward propagation
        opt.zero_grad()
        loss.backward()
        opt.step()
        
        if epoch % printInterval == 0:
            print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, test acc: {:.3f})'.format(
                epoch, loss,train_acc, test_acc))
    print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, test acc: {:.3f})'.format(
                epoch, loss,train_acc, test_acc))


In [12]:
dataset = AMDataset()
g = dataset[0]

Done loading data from cached files.


In [13]:
num_classes = dataset.num_classes
category = dataset.predict_category

In [14]:
features = {}
for ntype in g.ntypes:
    features[ntype] = torch.zeros((g.num_nodes(ntype), 10))

In [15]:
model = RGCN(10, 20, num_classes, g.etypes)
train(model, g, features, epochs=2, printInterval=1)
torch.save(model, 'AM_Trained_Model.pt')

  loss = F.cross_entropy(logits[train_mask], labels[train_mask])
  train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
  test_acc = (pred[test_mask] == labels[test_mask]).float().mean()


In epoch 0, loss: 2.996, train acc: 0.347, test acc: 0.348)
In epoch 1, loss: 2.996, train acc: 0.347, test acc: 0.348)
In epoch 1, loss: 2.996, train acc: 0.347, test acc: 0.348)


In [17]:
from gnnexplainer import HeteroGNNExplainer

explainer = HeteroGNNExplainer(model, num_hops=1, lr=0.01, num_epochs=150)
new_center, sg, feat_mask, edge_mask = explainer.explain_node(category, 32, g, features)

Explain node 32: 100%|████████████████████████| 150/150 [00:13<00:00, 11.11it/s]


In [26]:
feat_mask

{'TYPE': tensor([0.1979, 0.1968, 0.1828, 0.1847, 0.1761, 0.1910, 0.2116, 0.2108, 0.1597,
         0.2061]),
 '_BNode': tensor([0.1997, 0.1969, 0.2070, 0.1955, 0.1915, 0.1837, 0.1921, 0.2019, 0.1834,
         0.2056]),
 'aggregation': tensor([0.1956, 0.1971, 0.1883, 0.2087, 0.1977, 0.1782, 0.1929, 0.2099, 0.2100,
         0.1964]),
 'p': tensor([0.1985, 0.1913, 0.1896, 0.2207, 0.2088, 0.1855, 0.2107, 0.2033, 0.2044,
         0.2050]),
 'physical': tensor([0.1996, 0.2027, 0.1963, 0.1982, 0.2002, 0.2121, 0.2018, 0.1795, 0.1987,
         0.1821]),
 'proxy': tensor([0.1954, 0.1960, 0.1966, 0.1810, 0.2008, 0.1999, 0.2007, 0.2011, 0.2108,
         0.1814]),
 't': tensor([0.1864, 0.1943, 0.2178, 0.1862, 0.2217, 0.1977, 0.1964, 0.1986, 0.1838,
         0.1957])}

In [27]:
edge_mask

{('TYPE', 'rev-dimensionPart', '_BNode'): tensor([]),
 ('TYPE',
  'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type',
  '_BNode'): tensor([]),
 ('TYPE',
  'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type',
  'p'): tensor([]),
 ('TYPE',
  'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type',
  't'): tensor([]),
 ('TYPE', 'rev-http://www_w3_org/2004/02/skos/core#inScheme', 't'): tensor([]),
 ('TYPE', 'rev-http://www_w3_org/2004/02/skos/core#narrower', 't'): tensor([]),
 ('_BNode', 'alternativeNumberInstitution', 't'): tensor([]),
 ('_BNode', 'currentLocation', 't'): tensor([]),
 ('_BNode', 'currentLocationFitness', 't'): tensor([]),
 ('_BNode', 'dimensionNotes', 't'): tensor([]),
 ('_BNode', 'dimensionPart', 'TYPE'): tensor([]),
 ('_BNode', 'dimensionPart', 't'): tensor([]),
 ('_BNode', 'dimensionType', 't'): tensor([]),
 ('_BNode', 'documentationAuthor', 'p'): tensor([]),
 ('_BNode', 'documentationTitle', 't'): tensor([]),
 ('_BNode', 'exhibitionOrganiser', 'p'): tensor([]),
 ('

In [28]:
sg

Graph(num_nodes={'TYPE': 0, '_BNode': 7, 'aggregation': 1, 'p': 0, 'physical': 1, 'proxy': 1, 't': 6},
      num_edges={('TYPE', 'rev-dimensionPart', '_BNode'): 0, ('TYPE', 'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type', '_BNode'): 0, ('TYPE', 'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type', 'p'): 0, ('TYPE', 'rev-http://www_w3_org/1999/02/22-rdf-syntax-ns#type', 't'): 0, ('TYPE', 'rev-http://www_w3_org/2004/02/skos/core#inScheme', 't'): 0, ('TYPE', 'rev-http://www_w3_org/2004/02/skos/core#narrower', 't'): 0, ('_BNode', 'alternativeNumberInstitution', 't'): 0, ('_BNode', 'currentLocation', 't'): 0, ('_BNode', 'currentLocationFitness', 't'): 0, ('_BNode', 'dimensionNotes', 't'): 0, ('_BNode', 'dimensionPart', 'TYPE'): 0, ('_BNode', 'dimensionPart', 't'): 0, ('_BNode', 'dimensionType', 't'): 0, ('_BNode', 'documentationAuthor', 'p'): 0, ('_BNode', 'documentationTitle', 't'): 0, ('_BNode', 'exhibitionOrganiser', 'p'): 0, ('_BNode', 'exhibitionVenue', 't'): 0, ('_BNode', 'http