In [90]:
import dgl
import dgl.function as fn
import torch
from dgl.nn import GraphConv
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import BAShapeDataset

In [159]:
class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, hidden_feats=20):
        '''
        in_feats: Input features
        out_feats: Output features
        hidden_feats: Hidden layer features
        '''
        super(GCN, self).__init__()
        self.conv1 = GraphConv(in_feats, hidden_feats)
        self.conv2 = GraphConv(hidden_feats, out_feats)
        
    def forward(self, graph, feat, eweight=None):
        with graph.local_scope():
            feat = self.conv1(graph, feat)
            feat = F.relu(feat)
            feat = self.conv2(graph, feat)
            graph.ndata['h'] = feat
            if eweight is None:
                graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h'))
            else:
                graph.edata['w'] = eweight
                graph.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h'))
            return graph.ndata['h']

In [208]:
def train(model, g, epochs=10, printInterval=5, lr=0.001):
    '''
    model: Training Model
    g: Training graph
    epochs: Number of epochs
    printInterval: Interval that data is displayed at
    lr: Learning rate
    '''
    optimizer = torch.optim.Adam(model.parameters(), lr)

    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    test_mask = g.ndata['test_mask']
    
    for epoch in range(epochs):
        model.train()
        logits = model(g, features)
        pred = logits.argmax(1)
        
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])
        
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
        if epoch % printInterval == 0:
            print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, test acc: {:.3f})'.format(
                epoch, loss, train_acc, test_acc))
        
    print('In epoch {}, loss: {:.3f}, train acc: {:.3f}, test acc: {:.3f})'.format(
            epoch, loss, train_acc, test_acc))

In [221]:
def displaySG(model, nodeID):
    # Defines explainer model and trains model to explain node
    features = g.ndata['feat']
    explainer = GNNExplainer(model, num_hops=2, lr=0.001, num_epochs=150)
    new_center, sg, feat_mask, edge_mask = explainer.explain_node(nodeID, g, features)
    
    # Draws sub-graph
    nx_G = sg.to_networkx().to_undirected()
    # Create Node Color Map
    colors = ['azure', 'tan', 'wheat', 'lavender', 'lightskyblue', 'lightsalmon', 'lightgreen', 'thistle', 
              'slateblue', 'roseybrown', 'plum', 'peru', 'pink', 'palegreen', 'olive', 'moccasin', 'mintcream',
              'oldlace', 'linen', 'maroon', 'green', 'yellow', 'blue', 'orange', 'lightyellow', 'lightpink',
              'lavenderblush', 'ivory', 'purple', 'violet', 'lightgray', 'darkgreen', 'darkyellow', 'darkblue',
              'honeydue', 'darkpurple', 'darkorange', 'coral', 'hotpink', 'gainsboro']
    color_map = []
    for i, node in enumerate(nx_G):
        if node == new_center.numpy()[0]:
            color_map.append('gray')
        else:
            color_map.append(colors[sg.ndata['label'].numpy()[i]])
        
    # Draw Graph
    pos = nx.kamada_kawai_layout(nx_G)
    nx.draw_networkx(nx_G, pos, node_color=color_map, with_labels=True)
#     print(colors[sg.ndata['label'].numpy()[new_center.numpy()[0]]])
    
    return feat_mask, edge_mask

In [33]:
data = BAShapeDataset()
g = data[0]
g = dgl.add_self_loop(g)

Done loading data from cached files.


In [173]:
n_nodes = g.num_nodes()

train_mask = torch.zeros((n_nodes), dtype=bool)
test_mask = torch.zeros((n_nodes), dtype=bool)

n_train = int(n_nodes * 0.8)

train_mask[:n_train] = True
test_mask[n_train:] = True

g.ndata['train_mask'] = train_mask
g.ndata['test_mask'] = test_mask

In [213]:
# features = g.ndata['feat']
# model = GCN(features.shape[1], data.num_classes)
# train(model, g, epochs=1000, printInterval=100)
# torch.save(model, 'BAShape_Trained_Model.pt')
model = torch.load('./models/BAShape_Trained_Model.pt')

In [245]:
uniqDict = {}
labels = g.ndata['label'].numpy()

for i in range(len(labels)):
    if not labels[i] in uniqDict:
        uniqDict[labels[i]] = [i]
    else:
        uniqDict[labels[i]].append(i)

In [246]:
uniqDict # Dictionary for node index of each class

{2: [0,
  1,
  2,
  3,
  4,
  5,
  6,
  7,
  8,
  9,
  10,
  11,
  12,
  13,
  14,
  15,
  16,
  17,
  18,
  19,
  20,
  21,
  22,
  23,
  24,
  25,
  26,
  27,
  29,
  31,
  32,
  33,
  34,
  35,
  36,
  37,
  39,
  41,
  43,
  45,
  46,
  47,
  48,
  50,
  52,
  53,
  55,
  57,
  58,
  59,
  60,
  61,
  62,
  64,
  66,
  67,
  68,
  69,
  70,
  71,
  72,
  74,
  75,
  76,
  78,
  81,
  83,
  84,
  85,
  87,
  90,
  93,
  96,
  100,
  102,
  104,
  106,
  108,
  111,
  113,
  115,
  118,
  121,
  123,
  124,
  126,
  129,
  131,
  133,
  136,
  139,
  142,
  144,
  146,
  149,
  152,
  155,
  159,
  162,
  165,
  168,
  171,
  173,
  176,
  181,
  183,
  186,
  190,
  193,
  196,
  199,
  200,
  203,
  205,
  209,
  211,
  213,
  217,
  219,
  221,
  223,
  225,
  228,
  231,
  234,
  236,
  239,
  241,
  243,
  246,
  249,
  253,
  256,
  258,
  261,
  264,
  266,
  271,
  280,
  287,
  295,
  298,
  300,
  304,
  324,
  333,
  343,
  345,
  356,
  367,
  372,
  378,
  387,
  409,
  

In [1]:
displaySG(model, 3) # Each classes sg should look the same if the explainer works properly

NameError: name 'displaySG' is not defined