In [1]:
import numpy as np
import pickle as pkl
from Extractor import Extractor
from scipy.sparse import coo_matrix,csr_matrix
#import tensorflow as tf
import torch
from utils import *
import networkx as nx
from matplotlib import pyplot as plt
from model import *
from train import *

from explain import *
import time
from tqdm import tqdm
from sklearn.metrics import roc_auc_score
from torch_geometric.utils import to_networkx

import sys

np.set_printoptions(threshold=sys.maxsize)


In [2]:
class BA_Shape_Dataset(Dataset):
    def __init__(self, root, name, setting=1, hops=3, transform = None, pre_transform = None, subgraph = False, remap=False):
        super(BA_Shape_Dataset, self).__init__(root, transform, pre_transform)
        self.root = root
        self.subgraph = subgraph
        self.remap = remap
        self.name = name
        self.setting =setting
        with open(os.path.join(self.root, name + '.pkl'), 'rb') as fin:
            self.adj, self.features, self.y_train, self.y_val, self.y_test, self.train_mask, self.val_mask, self.test_mask, self.edge_label_matrix  = pkl.load(fin)
        self.hops = hops

        self.all_label = np.logical_or(self.y_train,np.logical_or(self.y_val,self.y_test))
        self.single_label = np.argmax(self.all_label,axis=-1)
        #self.allnodes = [i for i in range(self.single_label.shape[0]) if self.single_label[i] != 0] #[i for i in range(400,700,5)]
        self.csr_adj = csr_matrix(self.adj)
        self.extractor = Extractor(self.csr_adj,self.features,self.edge_label_matrix,self.all_label,self.hops)            
    @property
    def num_features(self):
        return 10
    @property
    def num_classes(self):
        return self.all_label.shape[1]

    @property 
    def allnodes(self):
        if self.setting==1:
            if self.name=='syn3':
                allnodes = [i for i in range(511,871,6)]
            elif self.name=='syn4':
                allnodes = [i for i in range(511,800,1)]
            else:
                allnodes = [i for i in range(400,700,5)] # setting from their original paper
        elif self.setting==2:
            allnodes = [i for i in range(self.single_label.shape[0]) if self.single_label[i] ==1]
        elif self.setting==3:
            if self.name == 'syn2':
                allnodes = [i for i in range(self.single_label.shape[0]) if self.single_label[i] != 0 and self.single_label[i] != 4]
            else:
                allnodes = [i for i in range(self.single_label.shape[0]) if self.single_label[i] != 0]
        return allnodes

    def set_hops(self, hops=3):
        self.hops = hops
        self.extractor = Extractor(self.csr_adj,self.features,self.edge_label_matrix,self.all_label,self.hops)  

    def get_subgraph(self, idx):
        #if self.subgraph:
        idx = self.allnodes[idx] if self.remap else idx
        sub_adj,sub_feature, sub_label,sub_edge_label_matrix = dataset.extractor.subgraph(idx)
        return sub_adj,sub_feature, sub_label,sub_edge_label_matrix
        #else:
        #    return None
    def len(self):
        if self.subgraph:
            if not self.remap:
                return len(self.single_label)
            return len(self.allnodes)
        else:
            return 1
    
    def get(self, idx):
        if self.subgraph:
            if self.remap:
                idx = self.allnodes[idx]
            sub_adj,sub_feature, sub_label,sub_edge_label_matrix = self.extractor.subgraph(idx)
            edge_index = torch.tensor(preprocess_adj(sub_adj)[0].T, dtype = torch.long)
            x = torch.tensor(sub_feature).float()
            y = torch.argmax(torch.tensor(sub_label, dtype = torch.int32), dim=-1)
            data = Data(edge_index = edge_index, x = x, y = y)
        else:

            edge_index = torch.tensor(preprocess_adj(self.adj)[0].T, dtype = torch.long)
            x = torch.tensor(self.features).float()
            y = torch.argmax(torch.tensor(np.logical_or(self.y_train,np.logical_or(self.y_val,self.y_test)), dtype = torch.int32), dim=-1)
            data = Data(edge_index = edge_index, x = x, y = y)
        return data

In [3]:
dataset = BA_Shape_Dataset(root = './dataset', name = 'syn2', setting = 3)
#dataset.allnodes

explainer


In [16]:
gcn_model = GCN_Node(dataset, 3, 20)
dataset.set_hops(3)

explainer


In [17]:
dataset.setting = 3
dataset.subgraph = False
dataset.remap = True
#train(gcn_model, dataset, max_epoch=1000, lr=0.005, train_batch_size = 1024, temp_name = 'community_temp')
node_pred_task_train(gcn_model, dataset, max_epoch=1000, lr=0.005, temp_name='community_temp', train_rate = 0.70, save_model = True)

prepare dataloader
done
Epoch:  0 Avg loss:  2.0779347 ; acc:  0.1072523 ; epoch time:  0.3596928119659424
eval test...
test acc:  0.13064134
saving...


Epoch:  1 Avg loss:  2.067303 ; acc:  0.1072523 ; epoch time:  0.3488128185272217
eval test...
test acc:  0.13064134


Epoch:  2 Avg loss:  2.0568254 ; acc:  0.1072523 ; epoch time:  0.3212876319885254
eval test...
test acc:  0.13064134


Epoch:  3 Avg loss:  2.0463288 ; acc:  0.1072523 ; epoch time:  0.33660125732421875
eval test...
test acc:  0.13064134


Epoch:  4 Avg loss:  2.0358975 ; acc:  0.30439225 ; epoch time:  0.35362863540649414
eval test...
test acc:  0.34916866
saving...


Epoch:  5 Avg loss:  2.0248094 ; acc:  0.44739532 ; epoch time:  0.34763598442077637
eval test...
test acc:  0.47981
saving...


Epoch:  6 Avg loss:  2.0126474 ; acc:  0.43513793 ; epoch time:  0.3192760944366455
eval test...
test acc:  0.46080762


Epoch:  7 Avg loss:  1.9990535 ; acc:  0.42798775 ; epoch time:  0.3345980644226074
eval test...
test ac

In [18]:
gcn_model = torch.load('./checkpoint/community_temp')

In [16]:
dataset.subgraph = True
dataset.remap = True
train(gcn_model, dataset, max_epoch=200, lr=0.005, train_batch_size = 1024)

prepare dataloader
done
Epoch:  0 Avg loss:  2.9537191 ; acc:  0.7151807 ; epoch time:  2.35606050491333
Epoch:  1 Avg loss:  1.6656601 ; acc:  0.7950278 ; epoch time:  2.291776657104492
Epoch:  2 Avg loss:  1.4816986 ; acc:  0.8152323 ; epoch time:  2.270766258239746
Epoch:  3 Avg loss:  1.7021397 ; acc:  0.8165064 ; epoch time:  2.34914493560791
Epoch:  4 Avg loss:  1.85531 ; acc:  0.8166733 ; epoch time:  2.308025598526001
Epoch:  5 Avg loss:  1.8359437 ; acc:  0.8167643 ; epoch time:  2.39426589012146
Epoch:  6 Avg loss:  1.6931294 ; acc:  0.81711316 ; epoch time:  2.3895483016967773
Epoch:  7 Avg loss:  1.5028516 ; acc:  0.81863 ; epoch time:  2.3883438110351562
Epoch:  8 Avg loss:  1.3349968 ; acc:  0.818812 ; epoch time:  2.5111613273620605
Epoch:  9 Avg loss:  1.2339685 ; acc:  0.8165216 ; epoch time:  2.422281503677368
Epoch:  10 Avg loss:  1.2117757 ; acc:  0.8129418 ; epoch time:  2.4183578491210938
Epoch:  11 Avg loss:  1.2392082 ; acc:  0.80783004 ; epoch time:  2.42352867

In [22]:
dataset.setting = 3
dataset.subgraph = False
dataset.remap = False
#gcn_model = torch.load('./checkpoint/gcn_com_sub')
eval(gcn_model, dataset)

Epoch:  0 Avg loss:  0.71721673 ; acc:  0.9121429 ; epoch time:  0.020233631134033203


array(0.9121429, dtype=float32)

In [80]:
#torch.save(gcn_model, './checkpoint/')

In [7]:
correct = 0
dataset.setting = 3
dataset.subgraph = True
dataset.remap = True
#load_model = torch.load('./checkpoint/gcn_mix').to('cuda')
load_model = gcn_model
load_model.eval()
in_correct = []
for idx in range(dataset.len()):
    print('index: ', idx)
    data = get_data(dataset, idx).to('cuda')
    pred = load_model.forward(data)
    #print(preds)
    label = data.y.to('cuda')

    _, indices = torch.max(pred, 1)
    #print(pred)
    #print(label)
    #print(torch.sum(label))
    #print(indices)
    #print(torch.argmax(load_model(data)[0]))
    if (indices[0] == label[0]):
        correct += 1
    else:
        in_correct.append(idx)
print(correct)
print(in_correct)

index:  0
index:  1
index:  2
index:  3
index:  4
index:  5
index:  6
index:  7
index:  8
index:  9
index:  10
index:  11
index:  12
index:  13
index:  14
index:  15
index:  16
index:  17
index:  18
index:  19
index:  20
index:  21
index:  22
index:  23
index:  24
index:  25
index:  26
index:  27
index:  28
index:  29
index:  30
index:  31
index:  32
index:  33
index:  34
index:  35
index:  36
index:  37
index:  38
index:  39
index:  40
index:  41
index:  42
index:  43
index:  44
index:  45
index:  46
index:  47
index:  48


KeyboardInterrupt: 

In [7]:
acc_list = []
auc_list = []
clust_list = []
edge_auc_list = []



dataset.subgraph = True
dataset.remap = True
dataset.setting=3
dataset.set_hops(3)
load_model = gcn_model
#load_model = torch.load('./checkpoint/gcn_com_temp').to('cuda')
load_model.eval()

all_node_label = []
all_node_color = []

#for idx in range(dataset.len()):
for idx in range(len(dataset)):
    #if idx in in_correct:
    #    continue
    #idx=15
    print('\nindex: ', idx)
    sub_adj,sub_feature, sub_label,sub_edge_label_matrix = dataset.get_subgraph(idx)
    #truth_node = np.where(sub_label[:,1] == True)[0]
    truth_node = get_node_set(sub_edge_label_matrix)
    #if len(truth_node) > 6:
    #    continue
    print('0 label: ', np.argmax(sub_label[0],axis=-1))
    node_sort, node_color = print_explain(dataset, load_model, idx, class_idx = np.argmax(sub_label[0],axis=-1), visible = False)



    node_label = np.array([0] * sub_label.shape[0])
    node_label[list(truth_node)] = 1
    #pred  = np.array([0] * sub_label.shape[0])
    #pred[node_sort[:6]] = 1
    
    #edge_label = []
    #edge_pred = []
    #for r,c in list(zip(sub_adj.row,sub_adj.col)):
    #    sub_edge_label = sub_edge_label_matrix.todense()
    #    edge_label.append(sub_edge_label[r,c] or sub_edge_label[c,r])
    #    edge_pred.append((node_color[r] + node_color[c])/2)
    #print(edge_label)
    try:
        auc = roc_auc_score(node_label, node_color)
    except:
        print('foo')
        auc = 1.0

    #print(truth_node)
    #print(node_sort)
    acc = len([node for node in node_sort[:5] if node in truth_node])/5
    acc_list.append(acc)
    auc_list.append(auc)



    print('acc: ', acc)
    print('auc: ', auc)
    #if auc == 0.0:
    #    print_explain(dataset, load_model, idx, class_idx = 0, visible = True)
    print('mean acc: ', np.mean(acc_list))
    print('mean auc: ', np.mean(auc_list))
    #print('mean clust: ', np.mean(clust_list))





explainer

index:  0
0 label:  1
epoch time:  1.5644793510437012
acc:  1.0
auc:  0.7905270655270656
mean acc:  1.0
mean auc:  0.7905270655270656

index:  1
0 label:  1
epoch time:  0.586745023727417
acc:  1.0
auc:  0.7011494252873562
mean acc:  1.0
mean auc:  0.745838245407211

index:  2
0 label:  2
epoch time:  0.10940814018249512
acc:  1.0
auc:  1.0
mean acc:  1.0
mean auc:  0.830558830271474

index:  3
0 label:  2
epoch time:  0.5214858055114746
acc:  1.0
auc:  0.6360225140712945
mean acc:  1.0
mean auc:  0.7819247512214291

index:  4
0 label:  3
epoch time:  0.20128726959228516
acc:  0.8
auc:  0.697089947089947
mean acc:  0.96
mean auc:  0.7649577903951327

index:  5
0 label:  1
epoch time:  0.9581899642944336
acc:  1.0
auc:  1.0
mean acc:  0.9666666666666667
mean auc:  0.8041314919959439

index:  6
0 label:  1
epoch time:  0.16858673095703125
acc:  1.0
auc:  1.0
mean acc:  0.9714285714285714
mean auc:  0.8321127074250948

index:  7
0 label:  2
epoch time:  0.05781245231628418
acc:

KeyboardInterrupt: 

In [12]:
acc_list = []
auc_list = []
clust_list = []

node_num_list = []
dataset.subgraph = True
dataset.remap = True
dataset.setting=3
dataset.set_hops(3)
#load_model = gcn_model
load_model = torch.load('./checkpoint/gcn_com_sub').to('cuda')
load_model.eval()

all_node_label = []
all_node_color = []
#for idx in range(dataset.len()):
for idx in range(800):
    #if idx in in_correct:
    #    continue
    idx=15
    print('\nindex:======================================= ', idx)
    sub_adj,sub_feature, sub_label,sub_edge_label_matrix = dataset.get_subgraph(idx)
    #truth_node = np.where(sub_label[:,1] == True)[0]
    truth_node = get_node_set(sub_edge_label_matrix)
    #if len(truth_node) > 6:
    #    continue
    class_index = np.argmax(sub_label[0],axis=-1)
    print('0 label: ', class_index)
    #node_sort, node_color = print_explain(dataset, load_model, idx, class_idx = class_index, visible = False)

    #_, _, adj, edge_pred = find_edge(load_model, dataset, idx, class_idx = None, node_sort = node_sort, topk = 5, start_num=1)
    #print(_)

    edge_dense = sub_edge_label_matrix.todense()
    edge_list = preprocess_adj(sub_edge_label_matrix)[0]
    edge_label = []
    for r,c in list(zip(sub_adj.row,sub_adj.col)):
        edge_label.append(edge_dense[r,c] or edge_dense[c,r])



    edge_score = Edge_explain(load_model, dataset, idx = idx, edge_list = edge_list)
    edge_colors = []
    for i in range(len(edge_score[0]['rel'])):
        edge_colors.append( edge_score[class_index]['rel'][i] )#- edge_score[0]['rel'][i] )
    print('edge sort: ', np.argsort(edge_colors)[::-1])
    #print(edge_list[np.argsort(edge_colors)[::-1]])
    try:
        auc = roc_auc_score(edge_label,edge_colors)
    except:
        auc = 1.0
    print('auc: ', auc)
    auc_list.append(auc)

    print('auc mean: ', np.mean(auc_list))
    break

explainer

0 label:  1
edge sort:  [2722 2724 2719 2721    0 2725 2313 2716    2 2718 2726 2720 2727 2754
 3635 2717 2752 2748 2983 3879 1803  890 1298 2045 3474 3006 1754  666
 2394 2340 3010 3570 2823 3213 2300 1773 2547 1648 1398  745 2656 2142
 3969 3516 2001 2395 2393 2171 2038 2414 2481  285 2081 1949 3220 3266
   77 1604 2216 2461 1213 2435 3812 4297 1659 2651  940 2127 1960  938
 3231 2998 2278 1463 2338 1458  592 1984 1400  877  751 2176 1477 2102
  742 2234 1174 2346  758 2609 2434 2545 1651  544 1459 1739 1276 1219
  946 1536 1740 1986 2999 3275 2203 1868 2240 1427 1016 2204  533  746
  664 1571  733 2269 4036 1306 1447 2996 2408 1601  869  543 2979 3785
 2220  916 3596 3981 2985 3941 4123 3728  681 2222 3476 4137 1562 1377
 2415 2287 2025 2044 1323 1535  298  549 1520  736 4388 3212 3216 4008
 3230 2769 4050 3194  651 2049 1010  536  959 2039 3584 3254  744 2270
 2175 2710 2768 2986 1658 1966 4382 4104 1695  747 3946 2471 2007 4133
  662 1474  684 2446 4300 3901 3200 4142 2

In [21]:
acc_list = []
auc_list = []

dataset.subgraph = False
dataset.remap = False
dataset.setting=1
load_model = gcn_model
#load_model = torch.load('./checkpoint/community_temp')
#load_model = torch.load('./checkpoint/gcn_mix').to('cuda')
load_model.eval()

all_node_label = []
all_node_color = []
for idx in dataset.allnodes:
    #idx = 313
    print('\nindex: ', idx)
    sub_adj,sub_feature, sub_label,sub_edge_label_matrix = dataset.get_subgraph(idx)
    #truth_node = np.where(sub_label[:,1] == True)[0]
    truth_node = list(get_node_set(sub_edge_label_matrix))
    
    class_idx = np.argmax(sub_label[0],axis=-1)
    print('0 label: ', class_idx)
    node_range = dataset.extractor.nodes
    print('node range: ', node_range)
    node_sort, node_color = print_subgraph_explain(dataset = dataset, model = load_model, idx = 0, class_idx = class_idx, visible = False, figsize = (12,9), node_range = node_range)
    print(node_sort)

    node_label = np.array([0] * sub_label.shape[0])
    #node_label[list(truth_node)] = 1
    # find truth node, far node is not real truth
    for n in truth_node:
        if abs((node_range[n] - node_range[0])) <= 8:
            node_label[n] = 1
            print(n)

    try:
        auc = roc_auc_score(node_label, node_color)
    except:
        print('foo')
        auc = 1.0

    print("truth node: ", truth_node)
    #print(node_sort)
    acc = len([node for node in node_sort[:5] if node in truth_node])/5
    acc_list.append(acc)
    auc_list.append(auc)
    #all_node_label.extend(node_label)
    #all_node_color.extend(node_color)
    print('acc: ', acc)
    print('auc: ', auc)
    #if acc == 0.0:
    #    print(node_sort)
    #    print_explain(dataset, load_model, idx, class_idx = np.argmax(sub_label[0],axis=-1), visible = True)
    print('mean acc: ', np.mean(acc_list))
    print('mean auc: ', np.mean(auc_list))
    #break


index:  400
0 label:  1
node range:  [400, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 21, 22, 23, 24, 25, 26, 27, 29, 30, 31, 32, 34, 35, 37, 38, 42, 43, 44, 45, 46, 47, 49, 50, 52, 54, 55, 56, 57, 58, 60, 61, 65, 68, 69, 71, 72, 73, 75, 76, 78, 81, 84, 85, 87, 88, 89, 91, 92, 93, 94, 95, 96, 97, 98, 100, 101, 105, 109, 111, 112, 113, 114, 115, 116, 118, 122, 123, 128, 129, 130, 135, 136, 137, 140, 141, 143, 144, 145, 148, 150, 151, 153, 154, 155, 157, 159, 161, 163, 164, 167, 170, 173, 174, 177, 178, 181, 183, 185, 189, 190, 191, 194, 197, 199, 200, 201, 202, 207, 216, 217, 223, 224, 229, 231, 233, 235, 240, 241, 243, 244, 245, 246, 253, 262, 266, 269, 271, 272, 279, 281, 284, 285, 293, 300, 325, 330, 401, 402, 403, 404, 415, 425, 435, 455, 505, 569, 700, 701, 702, 703, 704, 705, 706, 707, 708, 709, 710, 711, 712, 713, 714, 715, 717, 718, 719, 720, 722, 724, 725, 726, 729, 730, 731, 732, 733, 735, 737, 742, 743, 744, 745, 746, 749, 751, 753, 754, 755, 756, 

KeyboardInterrupt: 