# DGL Blackbox

In [4]:
import math
import pickle

import dgl
from dgl.nn.pytorch import GraphConv
import numpy as np
import torch
import torch.nn as nn
from torch.nn.parameter import Parameter
import torch.nn.functional as F

## Graph Convolution Layer

In [None]:
class GraphConvolution(nn.Module):
    """Graph-convolution layer after Kipf & Welling, https://arxiv.org/abs/1609.02907.

    The forward pass computes ``adj @ (input @ weight)`` and, when enabled,
    adds a learnable per-output bias.

    Args:
        in_features: size of each input node feature vector.
        out_features: size of each output node feature vector.
        bias: whether to create and add a bias parameter.
    """

    def __init__(self, in_features, out_features, bias=True):
        super(GraphConvolution, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.weight = Parameter(torch.FloatTensor(in_features, out_features))
        if not bias:
            # Registered as None so `bias` still appears in `_parameters`.
            self.register_parameter('bias', None)
        else:
            self.bias = Parameter(torch.FloatTensor(out_features))
        self.reset_parameters()

    def reset_parameters(self):
        """Draw weight, then bias, from U(-b, b) with b = 1/sqrt(out_features)."""
        bound = 1. / math.sqrt(self.weight.size(1))
        for param in (self.weight, self.bias):
            if param is not None:
                param.data.uniform_(-bound, bound)

    def forward(self, input, adj):
        """Return ``adj @ (input @ weight)`` plus the bias if one exists."""
        aggregated = torch.spmm(adj, torch.mm(input, self.weight))
        return aggregated if self.bias is None else aggregated + self.bias

    def __repr__(self):
        return f"{self.__class__.__name__} ({self.in_features} -> {self.out_features})"

## GCN

In [None]:
class GCN(torch.nn.Module):
    """Three-layer GCN with a concatenating readout, driven by a DGL graph.

    The layers mirror the ``gc1``/``gc2``/``gc3``/``lin`` layout printed for
    ``GCNSynthetic`` later in this notebook, presumably so its trained state
    dict can be loaded here -- NOTE(review): confirm against the checkpoint.

    Args:
        nfeat: number of input features per node.
        nhid: output size of ``gc1`` and ``gc2``.
        nout: output size of ``gc3``.
        nclass: number of outputs of the final linear layer.
        dropout: dropout probability applied after ``gc1`` and ``gc2``.
        device: device the per-graph target-node offsets are moved to.
        if_exp: explanation mode; when True ``forward`` returns softmax
            probabilities instead of raw scores.
    """

    def __init__(self, nfeat, nhid, nout, nclass, dropout, device, if_exp=False):
        super(GCN, self).__init__()
        self.gc1 = GraphConvolution(nfeat, nhid)
        self.gc2 = GraphConvolution(nhid, nhid)
        self.gc3 = GraphConvolution(nhid, nout)
        # Readout over the concatenated outputs of all three layers.
        self.lin = nn.Linear(nhid + nhid + nout, nclass)
        self.dropout = dropout
        self.if_exp = if_exp
        self.device = device

    @staticmethod
    def _dense_adj(g, e_weight, dtype, device):
        """Build an (N, N) adjacency with ``adj[dst, src] = weight`` for every edge of ``g``."""
        n = g.num_nodes()
        src, dst = g.edges()
        src, dst = src.to(device), dst.to(device)
        if e_weight is None:
            weights = torch.ones(src.shape[0], dtype=dtype, device=device)
        else:
            weights = e_weight.reshape(-1).to(dtype=dtype, device=device)
        adj = torch.zeros(n, n, dtype=dtype, device=device)
        # Out-of-place index_put keeps the result differentiable w.r.t. the
        # edge weights (needed in explanation mode); accumulate=True sums
        # parallel edges. Row = destination, so `adj @ h` aggregates in-edges.
        return adj.index_put((dst, src), weights, accumulate=True)

    def forward(self, g, in_feat, e_weight, target_node):
        """Score ``target_node`` of every graph in the (batched) DGL graph ``g``.

        Args:
            g: DGL graph, possibly a batch of graphs.
            in_feat: node features, one row per node of ``g``.
            e_weight: per-edge weights used directly as adjacency entries, or
                None for all-ones. NOTE(review): they are not normalised here;
                this assumes they already match what the trained layers expect
                (e.g. weights like those from ``normalize_adj``, self-loops
                included) -- confirm with the caller.
            target_node: per-graph local index of the node to score.

        Returns:
            The output rows of the target nodes: softmax probabilities when
            ``if_exp`` is set, raw linear outputs otherwise. The full output
            is also stored in ``g.ndata['h']``.
        """
        # Turn per-graph local indices into indices of the batched graph by
        # adding the node count of all preceding graphs (exclusive prefix sum).
        num_nodes = g.batch_num_nodes()
        offsets = torch.cumsum(num_nodes, dim=0) - num_nodes
        target_node = target_node + offsets.to(self.device)

        adj = self._dense_adj(g, e_weight, in_feat.dtype, in_feat.device)
        x1 = F.relu(self.gc1(in_feat, adj))
        x1 = F.dropout(x1, self.dropout, training=self.training)
        x2 = F.relu(self.gc2(x1, adj))
        x2 = F.dropout(x2, self.dropout, training=self.training)
        x3 = self.gc3(x2, adj)
        h = self.lin(torch.cat((x1, x2, x3), dim=1))
        if self.if_exp:  # explanation mode works on class probabilities
            h = F.softmax(h, dim=-1)
        g.ndata['h'] = h
        return g.ndata['h'][target_node]

## Load the dataset and trained weights

In [6]:
DATASET = 'syn1'

# Dataset used by CFGNNExplainer. The pickle is a local artifact; never point
# this at an untrusted file, since unpickling can execute arbitrary code.
with open(f"../../data/{DATASET}.pickle", "rb") as file:
    data = pickle.load(file)

# Float tensors with singleton dimensions squeezed away.
# The stored adjacency does not include self loops.
adj, features = (torch.Tensor(data[key]).squeeze() for key in ("adj", "feat"))
labels = torch.tensor(data["labels"]).squeeze()
idx_train = torch.tensor(data["train_idx"])

In [7]:
# Constructor arguments for the original model: feature width from the
# data, two hidden widths of 20, one output per distinct label, no dropout.
args = dict(
    nfeat=features.shape[1],
    nhid=20,
    nout=20,
    nclass=len(labels.unique()),
    dropout=0.0,
)

In [8]:
# Set up the original model and load its trained weights.
# NOTE(review): GCNSynthetic is not defined in the visible cells (the GCN class
# above has a different constructor signature) -- confirm where it comes from.
model = GCNSynthetic(**args)
# map_location="cpu" lets a checkpoint saved from a GPU run load on a CPU-only
# machine; load_state_dict copies the values into the (CPU) model regardless.
model.load_state_dict(
    torch.load(f"../../models/gcn_3layer_{DATASET}.pt", map_location="cpu")
)

<All keys matched successfully>

In [9]:
print(str(model))  # nn.Module's string form lists every submodule

GCNSynthetic(
  (gc1): GraphConvolution (10 -> 20)
  (gc2): GraphConvolution (20 -> 20)
  (gc3): GraphConvolution (20 -> 20)
  (lin): Linear(in_features=60, out_features=4, bias=True)
)


## Predict

In [10]:
def get_degree_matrix(adj):
    """Return the diagonal degree matrix of ``adj``.

    The degree of node ``j`` is the sum of column ``j`` (``adj.sum(dim=0)``).
    For a symmetric adjacency this equals the row sum.

    Args:
        adj: square 2-D adjacency tensor.

    Returns:
        A 2-D tensor with the column sums of ``adj`` on its diagonal.
    """
    # Vectorised reduction instead of the builtin `sum`, which looped over rows
    # in Python and returned the int 0 (crashing torch.diag) for an empty matrix.
    return torch.diag(adj.sum(dim=0))

In [11]:
def normalize_adj(adj):
    """Symmetrically normalise an adjacency matrix after adding self-loops.

    Applies the renormalisation trick from the GCN paper:
    ``D~^(-1/2) (A + I) D~^(-1/2)``, where ``D~`` is the degree matrix of
    ``A + I`` (column sums, the same convention as ``get_degree_matrix``).
    Any degree whose inverse square root is infinite gets a scale of 0.

    Args:
        adj: square 2-D adjacency tensor without self-loops.

    Returns:
        The normalised dense adjacency, on ``adj``'s device. Its dtype is
        ``adj``'s dtype if floating point, else the default float dtype.
    """
    dtype = adj.dtype if adj.is_floating_point() else torch.get_default_dtype()
    # Build the identity on adj's device/dtype so GPU or float64 inputs work.
    a_tilde = adj.to(dtype) + torch.eye(adj.shape[0], dtype=dtype, device=adj.device)
    # Raise only the degree vector to -1/2; the old version raised the whole
    # diagonal matrix, producing 0 ** -0.5 = inf off the diagonal.
    d_inv_sqrt = a_tilde.sum(dim=0) ** (-1 / 2)
    d_inv_sqrt[torch.isinf(d_inv_sqrt)] = 0
    # Broadcasting equals diag(d) @ a_tilde @ diag(d) without two N^3 matmuls.
    return d_inv_sqrt[:, None] * a_tilde * d_inv_sqrt[None, :]

In [12]:
# Run the original model on the normalised adjacency; predictions are the
# arg-max over the output columns of each node.
norm_adj = normalize_adj(adj)
output = model(features, norm_adj)
y_pred_orig = output.argmax(dim=1)

In [21]:
# Distinct predicted classes and how many nodes received each.
np.unique(np.asarray(y_pred_orig), return_counts=True)

(array([0, 1, 2, 3]), array([300, 166, 144,  90]))