In [1]:
import numpy as np
import pickle as pkl
from Extractor import Extractor
from scipy.sparse import coo_matrix,csr_matrix
#import tensorflow as tf
import torch
from utils import *
import networkx as nx
from matplotlib import pyplot as plt
from model import *
from train import *
import os
from explain import *
import time
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
# Identifier of the synthetic benchmark evaluated in this notebook; the dataset
# class below special-cases names such as 'syn2', 'syn3' and 'syn4'.
dataset_name = 'syn3'

class BA_Shape_Dataset(Dataset):
    """Dataset wrapper around a pickled synthetic node-classification graph.

    The file ``<root>/<name>.pkl`` is unpickled into an adjacency matrix, a
    feature matrix, train/val/test label matrices with their masks, and an
    edge label matrix.  Two access modes are supported through ``get``:

    * ``subgraph=False`` -- the dataset has a single item: the whole graph.
    * ``subgraph=True``  -- one item per node, produced by
      ``Extractor.subgraph`` (presumably the ``hops``-hop neighbourhood of the
      node; confirm in ``Extractor``).  With ``remap=True`` the item index is
      an index into ``allnodes`` rather than a raw node id.

    Args:
        root: Directory containing ``<name>.pkl``.
        name: Dataset name, e.g. ``'syn2'``, ``'syn3'``, ``'syn4'``.
        hops: Forwarded to ``Extractor``.
        transform: Forwarded to the ``Dataset`` base class.
        pre_transform: Forwarded to the ``Dataset`` base class.
        subgraph: Select per-node subgraph mode (see above).
        remap: Interpret item indices as positions in ``allnodes``.

    Attributes set after construction:
        setting: Integer 1, 2 or 3 selecting which nodes ``allnodes`` returns
            (defaults to 1).
    """

    def __init__(self, root, name, hops=3,transform = None, pre_transform = None, subgraph = False, remap=False):
        super().__init__(root, transform, pre_transform)
        self.root = root
        self.subgraph = subgraph
        self.remap = remap
        self.name = name
        self.setting = 1
        # NOTE: pickle.load can execute arbitrary code; only point `root` at
        # dataset files from a trusted source.
        with open(os.path.join(self.root, name + '.pkl'), 'rb') as fin:
            (
                self.adj,
                self.features,
                self.y_train,
                self.y_val,
                self.y_test,
                self.train_mask,
                self.val_mask,
                self.test_mask,
                self.edge_label_matrix,
            ) = pkl.load(fin)
        self.hops = hops

        # Union of the three label matrices, then the per-node class index.
        self.all_label = np.logical_or(self.y_train, np.logical_or(self.y_val, self.y_test))
        self.single_label = np.argmax(self.all_label, axis=-1)
        self.csr_adj = csr_matrix(self.adj)
        self.extractor = Extractor(self.csr_adj, self.features, self.edge_label_matrix, self.all_label, self.hops)

    @property
    def num_features(self):
        """Number of node features (fixed at 10)."""
        return 10

    @property
    def num_classes(self):
        """Number of classes, i.e. the column count of the label matrix."""
        return self.all_label.shape[1]

    @property
    def allnodes(self):
        """Return the list of node ids to evaluate, chosen by ``self.setting``.

        * ``setting == 1``: fixed id ranges per dataset name.
        * ``setting == 2``: every node whose class index is 1.
        * ``setting == 3``: every node whose class index is non-zero (for
          ``'syn2'`` class 4 is excluded as well).

        Raises:
            ValueError: If ``self.setting`` is not 1, 2 or 3.
        """
        if self.setting == 1:
            if self.name == 'syn3':
                return list(range(511, 871, 6))
            if self.name == 'syn4':
                return list(range(511, 800, 1))
            return list(range(400, 700, 5))  # setting from their original paper
        if self.setting == 2:
            return [i for i, label in enumerate(self.single_label) if label == 1]
        if self.setting == 3:
            if self.name == 'syn2':
                return [i for i, label in enumerate(self.single_label) if label != 0 and label != 4]
            return [i for i, label in enumerate(self.single_label) if label != 0]
        # Previously an unknown setting surfaced as an UnboundLocalError.
        raise ValueError("unsupported setting: {!r} (expected 1, 2 or 3)".format(self.setting))

    def get_subgraph(self, idx):
        """Return the raw subgraph tuple for ``idx`` from the extractor.

        ``idx`` is a position in ``allnodes`` when ``self.remap`` is true,
        otherwise a node id.

        Returns:
            ``(sub_adj, sub_feature, sub_label, sub_edge_label_matrix)`` as
            produced by ``self.extractor.subgraph``.
        """
        idx = self.allnodes[idx] if self.remap else idx
        # Use this instance's extractor; the previous code reached for a
        # module-level global named `dataset`.
        sub_adj, sub_feature, sub_label, sub_edge_label_matrix = self.extractor.subgraph(idx)
        return sub_adj, sub_feature, sub_label, sub_edge_label_matrix

    def len(self):
        """Return the number of items: 1 for the whole graph, else one per node."""
        if not self.subgraph:
            return 1
        if self.remap:
            return len(self.allnodes)
        return len(self.single_label)

    @staticmethod
    def _build_data(adj, features, labels):
        """Assemble a ``Data`` object from an adjacency, features and label matrix."""
        # The first element of preprocess_adj's result is transposed into
        # edge_index -- presumably an (E, 2) coordinate array; confirm in utils.
        edge_index = torch.tensor(preprocess_adj(adj)[0].T, dtype = torch.long)
        x = torch.tensor(features).float()
        y = torch.argmax(torch.tensor(labels, dtype = torch.int32), dim=-1)
        return Data(edge_index = edge_index, x = x, y = y)

    def get(self, idx):
        """Return item ``idx`` as a ``Data`` object.

        In subgraph mode the item is the extractor's subgraph for the node
        (after optional remapping through ``allnodes``); otherwise ``idx`` is
        ignored and the whole graph is returned.
        """
        if self.subgraph:
            if self.remap:
                idx = self.allnodes[idx]
            sub_adj, sub_feature, sub_label, _ = self.extractor.subgraph(idx)
            return self._build_data(sub_adj, sub_feature, sub_label)
        # self.all_label is the same label union the original recomputed here.
        return self._build_data(self.adj, self.features, self.all_label)

# ---------------------------------------------------------------------------
# Explanation-quality evaluation: for every selected node that the model does
# not assign to class 0, run the explainer on the node's subgraph and score
# the node ranking against the ground-truth motif nodes (top-k accuracy and
# ROC-AUC).
# ---------------------------------------------------------------------------
dataset = BA_Shape_Dataset(root = './dataset', name = dataset_name, hops=3)

# `device` is not defined in this cell; presumably provided by one of the
# star imports above -- confirm.
load_model = torch.load('./checkpoint/gcn_grid_', map_location=torch.device(device))

acc_list = []
auc_list = []

# Whole-graph mode, raw node ids, evaluate every node with a non-zero label.
dataset.subgraph = False
dataset.remap = False
dataset.setting=3

load_model.eval()

data = get_data(dataset, 0)#.to('cuda')
# Predictions are only used to filter nodes, so no autograd graph is needed.
with torch.no_grad():
    preds = load_model.forward(data)
_, indices = torch.max(preds, 1)

# Number of top-ranked nodes compared against the ground truth.
TOP_K = 6
# Ground-truth nodes count as positives only if their original id is within
# this distance of the target node's id (presumably the motif size; confirm).
MAX_ID_GAP = 6

all_node_label = []
all_node_color = []
all_elapsed = []
all_len = []
for idx in dataset.allnodes:
    print('\nindex: ', idx)
    # Skip nodes the model predicts as class 0.
    if indices[idx] == 0:
        continue
    sub_adj,sub_feature, sub_label,sub_edge_label_matrix = dataset.get_subgraph(idx)

    truth_node = list(get_node_set(sub_edge_label_matrix))

    # Row 0 of the subgraph labels is read as the target node's label.
    class_idx = np.argmax(sub_label[0],axis=-1)
    print('0 label: ', class_idx)
    # Read right after get_subgraph: maps subgraph-local indices to node ids
    # (see the 'node range' lines in the output).
    node_range = dataset.extractor.nodes
    print('node range: ', node_range)
    node_sort, node_color, elapsed = print_subgraph_explain(dataset = dataset, model = load_model, idx = 0, class_idx = class_idx, visible = False, figsize = (12,9), node_range = node_range)
    all_elapsed.append(elapsed)
    all_len.append(len(node_color))

    # Binary ground-truth vector over the subgraph's nodes.
    node_label = np.zeros(sub_label.shape[0], dtype=int)
    for n in truth_node:
        if abs(node_range[n] - node_range[0]) <= MAX_ID_GAP:
            node_label[n] = 1

    try:
        auc = roc_auc_score(node_label, node_color)
    except ValueError as err:
        # roc_auc_score raises ValueError when AUC is undefined (e.g. only one
        # class present in node_label); skip this node rather than abort.
        print('skipping node, AUC undefined: ', err)
        continue

    print("truth node: ", truth_node)

    # Fraction of the TOP_K highest-ranked nodes that are ground-truth nodes.
    acc = sum(1 for node in node_sort[:TOP_K] if node in truth_node) / TOP_K
    acc_list.append(acc)
    auc_list.append(auc)

    print('acc: ', acc)
    print('auc: ', auc)

    # Running means over all nodes scored so far.
    print('mean acc: ', np.mean(acc_list))
    print('mean auc: ', np.mean(auc_list))


  return torch._C._cuda_getDeviceCount() > 0


explainer

index:  511
0 label:  1
node range:  [511, 0, 513, 2, 515, 1, 3, 4, 5, 6, 512, 516, 514]
target node:  511
epoch time:  0.11343884468078613
truth node:  [0, 2, 4, 10, 11, 12]
acc:  0.8333333333333334
auc:  0.9285714285714285
mean acc:  0.8333333333333334
mean auc:  0.9285714285714285

index:  512
0 label:  1
node range:  [512, 0, 2, 514, 516, 1, 513, 515, 5, 6, 511]
target node:  512
epoch time:  0.09611892700195312
truth node:  [0, 3, 4, 6, 7, 10]
acc:  0.8333333333333334
auc:  0.9333333333333333
mean acc:  0.8333333333333334
mean auc:  0.930952380952381

index:  513
0 label:  1
node range:  [513, 0, 1, 515, 2, 5, 6, 512, 514, 516, 11, 12, 13, 14, 511]
target node:  513
epoch time:  0.1220548152923584
truth node:  [0, 3, 7, 8, 9, 14]
acc:  0.8333333333333334
auc:  0.9814814814814815
mean acc:  0.8333333333333334
mean auc:  0.9477954144620812

index:  514
0 label:  1
node range:  [514, 512, 0, 2, 516, 5, 6, 513, 515, 511]
target node:  514
epoch time:  0.08436226844787598
tr