In [1]:
import argparse
from cmath import log
import os
import pickle
import sys
import numpy as np
import torch
from torch.nn import BCELoss
from torch.optim import Adam
from tqdm.notebook import tqdm
from data_loader.dataset import DataSet
from modules.model import DevignModel, GGNNSum

In [2]:
import warnings
# Suppress every warning for the rest of the session. This also hides deprecation
# notices raised by the libraries used in the cells below.
warnings.filterwarnings("ignore")

In [3]:
# Point `input_dir` at the processed dataset directory that should be interpreted.
input_dir = 'reveal_model_data/msr_data/ros_4x/'
processed_data_path = os.path.join(input_dir, 'processed.bin')
# NOTE: unpickling can execute arbitrary code; only load files you produced yourself.
# The context manager closes the file handle even when unpickling fails.
with open(processed_data_path, 'rb') as processed_file:
    dataset = pickle.load(processed_file)

In [4]:
# Set the dataset's batch size to 1 (presumably so that every test batch holds a
# single graph -- confirm against DataSet).
dataset.batch_size = 1

In [5]:
# Initialise the test batches. The returned value is displayed as the cell output;
# the explanation loop later uses the same call's result as the number of test items.
dataset.initialize_test_batch()

27726

In [6]:
# Instantiate the GGNN classifier. The input size and number of edge types are read
# from the loaded dataset; the remaining hyper-parameters are fixed here.
model = GGNNSum(
    input_dim=dataset.feature_size,
    output_dim=200,
    num_steps=6,
    max_edge_types=dataset.max_edge_type,
)

In [7]:
from torch import nn
from data_loader.batch_graph import GGNNBatchGraph
import copy

In [8]:
# Load the trained model weights that will be interpreted.
# map_location='cpu' deserialises the checkpoint tensors onto the CPU regardless of the
# device they were saved from, so loading does not depend on that device being present.
# The model itself is moved to the GPU in a later cell.
checkpoint_path = 'msr_result/ggnn_model/msr_4x/0/Model_ep_49.bin'
model.load_state_dict(torch.load(checkpoint_path, map_location='cpu'))

<All keys matched successfully>

In [9]:
# Move the model to CUDA device 0, the same device string the explainer wrappers use.
model.to('cuda:0')

GGNNSum(
  (ggnn): GatedGraphConv(
    (linears): ModuleList(
      (0): Linear(in_features=200, out_features=200, bias=True)
      (1): Linear(in_features=200, out_features=200, bias=True)
      (2): Linear(in_features=200, out_features=200, bias=True)
      (3): Linear(in_features=200, out_features=200, bias=True)
    )
    (gru): GRUCell(200, 200)
  )
  (classifier): Linear(in_features=200, out_features=1, bias=True)
  (sigmoid): Sigmoid()
)

In [10]:
import numpy

In [11]:
# Wrapper used when explaining the model with Sampling_R.
class GGNNSum_single(nn.Module):
    """Expose a trained ``GGNNSum`` through the ``model(graph, feat, eweight)`` call signature.

    The wrapped network is invoked on a one-graph ``GGNNBatchGraph`` and its output
    ``p`` is returned as a two-column tensor ``[[1 - p, p]]``.

    NOTE(review): ``feat`` and ``eweight`` are accepted but never forwarded to the
    wrapped network, so the returned prediction does not depend on them. Confirm this
    is intended for the explainer that calls this wrapper.
    """

    def __init__(self, GGNNSum, device='cuda:0'):
        """Store the wrapped network and the device it should run on.

        Args:
            GGNNSum: trained network to wrap; it is called as ``net(batch_graph, device=...)``.
            device: device identifier forwarded to the wrapped network on every call.
                Defaults to ``'cuda:0'``, the value that used to be hard-coded.
        """
        super(GGNNSum_single, self).__init__()
        self.net = GGNNSum
        self.device = device

    def forward(self, graph, feat, eweight=None):
        """Return ``[[1 - p, p]]`` for a single graph, where ``p`` is the wrapped network's output.

        Args:
            graph: graph to classify; a deep copy is added to the batch so the caller's
                object is not handed to the batching code directly.
            feat: unused (kept for call-signature compatibility).
            eweight: unused (kept for call-signature compatibility).

        Returns:
            A new tensor built with ``torch.tensor`` from the network output.
        """
        batch_graph = GGNNBatchGraph()
        batch_graph.add_subgraph(copy.deepcopy(graph))
        # The result is re-packed with torch.tensor below, which copies plain values, so an
        # autograd graph for the wrapped network would be discarded; skip building it.
        with torch.no_grad():
            outputs = self.net(batch_graph, device=self.device)
        return torch.tensor([[1 - outputs, outputs]])


In [12]:
# Wrapper used when explaining the model with Sampling_L.
class GGNNSum_latent(nn.Module):
    """Combine the GGNN layer of a trained ``GGNNSum`` with an external classifier.

    The graph is run through ``net.ggnn`` only; the resulting states are summed over
    dim 1 and handed to ``clf.predict_proba``, whose output is returned as a tensor.
    ``feat`` and ``eweight`` are accepted but not used by this wrapper.
    """

    def __init__(self, GGNNSum, skMLP):
        """Keep references to the trained network and to the classifier.

        Args:
            GGNNSum: trained network whose ``ggnn`` sub-module produces the node states.
            skMLP: classifier object exposing ``predict_proba``.
        """
        super().__init__()
        self.net = GGNNSum
        self.clf = skMLP

    def forward(self, graph, feat, eweight=None):
        """Return ``clf.predict_proba`` of the summed GGNN states for one graph."""
        target_device = 'cuda:0'

        # Batch a private copy of the graph so the caller's object is not modified.
        single_batch = GGNNBatchGraph()
        single_batch.add_subgraph(copy.deepcopy(graph))

        batched_graph, node_features, edge_kinds = single_batch.get_network_inputs(
            cuda=True, device=target_device
        )
        batched_graph = batched_graph.to(target_device)
        node_features = node_features.to(target_device)
        edge_kinds = edge_kinds.to(target_device)

        node_states = self.net.ggnn(batched_graph, node_features, edge_kinds)
        per_graph_states, _ = single_batch.de_batchify_graphs(node_states)

        # The classifier works on NumPy arrays, so leave the device and autograd first.
        latent_vector = per_graph_states.sum(dim=1).cpu().detach().numpy()
        probabilities = self.clf.predict_proba(latent_vector)

        del batched_graph, edge_kinds, node_features
        return torch.tensor(probabilities)

In [13]:
# Sampling_L only: load the classifier you trained with Sampling_L (uncomment the line below).
# clf = pickle.load(open('msr_result/backbone_ggnn/smote/sk_model.pkl', 'rb'))

In [14]:
# Choose the wrapper to explain: GGNNSum_single for Sampling_R (active below), or
# GGNNSum_latent for Sampling_L (uncomment the second line; it needs `clf` to be loaded).
exp_model = GGNNSum_single(model)
# exp_model = GGNNSum_latent(model,clf)

In [15]:
from dgl.nn.pytorch.explain import GNNExplainer

In [16]:
# Build the explainer around the selected wrapper, with num_hops=1 and log=False.
gnnexplainer = GNNExplainer(
    exp_model,
    num_hops=1,
    log=False,
)

In [17]:
# Explain the true positives: test items labelled 1 whose second prediction column
# exceeds 0.5. For each of them, store the endpoints of the 10 edges with the largest
# explainer edge-mask values, keyed by the item's index in the test set.
TP_explaination_dict = {}
total_test_item = dataset.initialize_test_batch()
for index in tqdm(range(total_test_item)):
    example = dataset.test_examples[index]
    if not example.target == 1:
        continue

    graph = example.graph
    # Only graphs with more than 10 edges and more than 10 nodes are explained.
    if not (graph.num_edges() > 10 and graph.num_nodes() > 10):
        continue

    features = graph.ndata['features']
    pred = exp_model(graph, features)
    if not (pred[0][1] > 0.5):
        continue

    _, edge_mask = gnnexplainer.explain_graph(graph=graph, feat=features)
    # argpartition places the 10 largest mask values in the last 10 positions,
    # in no particular order among themselves.
    top_edge_ids = np.argpartition(edge_mask.numpy(), -10)[-10:]

    endpoint_ids = []
    for edge_id in top_edge_ids:
        src_nodes, dst_nodes = graph.find_edges(edge_id)
        endpoint_ids.extend((src_nodes.numpy()[0], dst_nodes.numpy()[0]))
    TP_explaination_dict[index] = endpoint_ids

  0%|          | 0/27726 [00:00<?, ?it/s]

In [18]:
# Number of test items for which an explanation was stored.
len(TP_explaination_dict)

411

In [19]:
# Number of test items the explanation loop iterated over.
total_test_item

27726

In [20]:
# Size of the explanation dictionary that is written to disk in the next cell.
len(TP_explaination_dict)

411

In [21]:
# Save the explanation results for further analysis.
import pickle
result_path = 'gnnexplainer_result/msr_4x_split_0_hop_1.pkl'
# Create the output directory first: without it, open() raises FileNotFoundError and the
# results of the long-running explanation loop would not be written.
os.makedirs(os.path.dirname(result_path), exist_ok=True)
with open(result_path, 'wb') as fp:
    pickle.dump(TP_explaination_dict, fp)