In [None]:
import os
import numpy as np
import gzip
import pickle

import torch

import import_ipynb
from constants import *
from simulation import simul
from utils import *

from algorithms.centrality import outdegree, betweenness, pagerank
from algorithms.BPM import BPM, greedy_BPM
from algorithms.KED import KED
from algorithms.MDS import MDS
from algorithms.greedy import greedy_orig, greedy_GNN
from algorithms.random import random_edge
from algorithms.pyg_expl import pyg_expl, pyg_expl_pretrain

In [None]:
def save_mask(save_dir, file_num, mask, seed_size, infl_orig, infl_masked, time_alg):
    """Write the per-instance result file ``<save_dir>/mask_<file_num>.txt``.

    Layout:
      * header ``seed infl_orig infl_masked reldec%``
      * one summary line: seed_size, infl_orig, infl_masked and
        ``(infl_orig - infl_masked) / (infl_orig - seed_size)``
      * header ``u v p time%``
      * one ``u v p time`` line per entry of ``mask``; the k-th mask entry is
        paired with ``time_alg[k]`` (so ``time_alg`` must be at least as long
        as ``mask``).

    Args:
        save_dir: directory the file is written into (must already exist).
        file_num: instance index used in the file name.
        mask: iterable of ``(u, v, p)`` triples; ``u`` and ``v`` are written as ints.
        seed_size, infl_orig, infl_masked: summary values written verbatim.
        time_alg: indexable of per-edge times.
    """
    target = save_dir + f'/mask_{file_num}.txt'
    with open(target, 'w') as out:
        out.write('seed infl_orig infl_masked reldec%\n')
        rel_dec = (infl_orig - infl_masked) / (infl_orig - seed_size)
        out.write(f'{seed_size} {infl_orig} {infl_masked} {rel_dec}\n')
        out.write('u v p time%\n')
        for idx, edge in enumerate(mask):
            src, dst, prob = edge
            elapsed = time_alg[idx]
            out.write(f'{int(src)} {int(dst)} {prob} {elapsed}\n')


def save_record(save_dir, record, time_pretrain):
    """Write ``<save_dir>/summary.txt`` for a whole run.

    ``record`` is indexed as a 2-D array whose rows are
    ``[seed_size, infl_orig, infl_masked, time]``. The file contains:
      * the mean relative decrease ``(infl_orig - infl_masked) / (infl_orig - seed_size)``
        and the mean time,
      * the pretrain time,
      * a column header, then one line per row with that row's relative
        decrease inserted before its time.
    """
    seeds = record[:, 0]
    origs = record[:, 1]
    maskeds = record[:, 2]
    times = record[:, 3]

    avgtime = np.mean(times)
    avgreldec = np.mean((origs - maskeds) / (origs - seeds))

    with open(save_dir + '/summary.txt', 'w') as out:
        out.write(f'avgreldec%: {avgreldec}, avgtime: {avgtime}\n')
        out.write(f'pretrain time: {time_pretrain}\n')
        out.write('seed infl_orig infl_masked reldec% time\n')
        out.writelines(
            f'{seed} {orig} {masked} {(orig - masked) / (orig - seed)} {elapsed}\n'
            for seed, orig, masked, elapsed in record
        )
            

def pipeline(alg_name, dataset_name, del_edge_num, save=True, save_tag=None, data_num=None, adv_log_save=False, **alg_kwags):
    """Run one edge-deletion algorithm on every instance of a dataset and report the influence reduction.

    For each ``(is_seed, probs)`` instance stored in ``DATASET_DIR + dataset_name``, the selected
    algorithm returns a mask of edges to delete. Those edges are removed from the graph, the
    influence of the seed set is re-simulated with ``simul``, and the result is compared with
    ``sum(probs)``.

    Args:
        alg_name: a key of ``basic_algs`` or an entry of ``pyg_algs`` (both listed below).
        dataset_name: gzipped pickle under ``DATASET_DIR``. The text before the first ``'-'``
            is the graph name passed to ``txt2adj``.
        del_edge_num: number of edges to delete per instance, or ``'all'`` to use the graph's
            edge count ``m`` returned by ``txt2adj``.
        save: if True, write ``mask_<i>.txt`` per instance and ``summary.txt`` under
            ``RESULT_DIR + <dataset tag>/<del_edge_num>/<alg_name><save_tag>``.
        save_tag: optional suffix appended to the result directory name.
        data_num: number of leading instances to process. ``None`` processes all of them.
        adv_log_save: if True, explainer (``pyg_algs``) runs receive a per-instance log file path
            under ``<result dir>/log/``. This works whether or not ``save`` is set.
        **alg_kwags: forwarded to the algorithm, and to ``pyg_expl_pretrain`` for pretrained explainers.

    Returns:
        ``(avgreldec, record, time_pretrain)``. ``record`` is a 2-D array with one row
        ``[seed_size, infl_orig, infl_masked, time]`` per processed instance.

    Raises:
        KeyError: if ``save`` or ``adv_log_save`` is set and ``dataset_name`` has no result-directory tag.
        Exception: if ``alg_name`` is not a supported algorithm.
    """
    # The log directory lives under the result directory, so the path is built whenever either
    # output is requested. This lets adv_log_save work even when save=False.
    save_dir = None
    adv_log_dir = None
    if save or adv_log_save:
        dataset_tag = {'wiki_indeg-50.pkl.gz':'wiki', 'Extended_test_LP-50.pkl.gz':'E', 'Celebrity_test_LP-50.pkl.gz':'C', 'WannaCry_test_LP-50.pkl.gz':'W'}[dataset_name]
        save_dir = RESULT_DIR + dataset_tag + '/' + str(del_edge_num) + '/' + alg_name
        if save_tag:
            save_dir += save_tag
        if save:
            os.makedirs(save_dir, exist_ok=True)
        if adv_log_save:
            adv_log_dir = save_dir + '/log/'
            os.makedirs(adv_log_dir, exist_ok=True)
    torch.cuda.empty_cache()

    # Algorithm registry: basic algorithms are called directly; explainer names are routed through pyg_expl.
    basic_algs = {'BPM':BPM, 'KED':KED, 'MDS':MDS, 'greedy':greedy_orig, 'greedy_BPM':greedy_BPM, 'greedy_GNN':greedy_GNN, 'random':random_edge, 'outdegree':outdegree, 'betweenness':betweenness, 'pagerank':pagerank}
    pyg_algs = ['GNNExplainer', 'Saliency', 'InputXGradient', 'Deconvolution', 'GuidedBackprop', 'IntegratedGradients','GNNExplainer_variant', 'PGExplainer', 'PGExplainer_variant', 'PGExplainer_onlyconc', 'GNNExplainer_variant_initsaliency', 'PGExplainer_onlyconc_initsaliency']
    pyg_pretrain_algs = ['PGExplainer', 'PGExplainer_variant']

    time_pretrain = 0
    if alg_name in basic_algs:
        is_expl = False
        algorithm = basic_algs[alg_name]
    elif alg_name in pyg_algs:
        is_expl = True
        algorithm = pyg_expl
        # Runs once before the instance loop. It returns the name to pass to pyg_expl (which may
        # differ from the input name) and a value reported as the pretrain time.
        if alg_name in pyg_pretrain_algs:
            alg_name, time_pretrain = pyg_expl_pretrain(alg_name, dataset_name, del_edge_num, **alg_kwags)
    else:
        raise Exception("not supported algorithm name")

    graph_name = dataset_name.split('-')[0]
    _, m, adj_list = txt2adj(graph_name)
    # NOTE: pickle.load can execute arbitrary code; only load trusted dataset files.
    with gzip.open(DATASET_DIR + dataset_name, 'rb') as f:
        rawdata = pickle.load(f)
    if data_num is None:
        data_num = len(rawdata)
    if del_edge_num == 'all':
        del_edge_num = m

    record = []
    print('seed\tinfl_orig\tinfl_mask\treldec%\t\ttime')
    for file_num, (is_seed, probs) in enumerate(rawdata[:data_num]):
        seed_idx = np.where(is_seed == 1)[0]

        adv_log_filename = adv_log_dir + f'{file_num}.txt' if adv_log_save else None

        if is_expl:
            mask, time_alg = algorithm(alg_name, adj_list, seed_idx, probs, del_edge_num, log_filename=adv_log_filename, **alg_kwags)
        else:
            mask, time_alg = algorithm(adj_list, seed_idx, del_edge_num, **alg_kwags)

        # Expand a scalar time, or a single-element list, to one entry per deleted edge.
        # save_mask pairs these entries with the mask rows.
        if not isinstance(time_alg, list):
            time_alg = [time_alg] * del_edge_num
        elif len(time_alg) == 1:
            time_alg = time_alg * del_edge_num

        # Remove the selected edges: zero entry [u][v] of the matrix form, then convert back to lists.
        adj_mat_masked = list2mat(adj_list)
        for u, v, _ in mask:
            adj_mat_masked[int(u)][int(v)] = 0
        adj_list_masked = mat2list(adj_mat_masked)

        seed_size = len(seed_idx)
        infl_orig = sum(probs)
        infl_masked = sum(simul(adj_list_masked, seed_idx))
        reldec = (infl_orig - infl_masked) / (infl_orig - seed_size)
        # The last entry of time_alg is what gets reported as this instance's time.
        print(f'{seed_size}\t{infl_orig:.2f}\t\t{infl_masked:.2f}\t\t{reldec:.4f}\t\t{time_alg[-1]:.4f}')
        record.append([seed_size, infl_orig, infl_masked, time_alg[-1]])

        if save:
            save_mask(save_dir, file_num, mask, seed_size, infl_orig, infl_masked, time_alg)

    # reshape keeps the array 2-D (shape (0, 4)) when no instances were processed. Without it,
    # the column slicing below and in save_record would raise IndexError.
    record = np.array(record).reshape(-1, 4)
    avgtime = np.mean(record[:, 3])
    avgreldec = np.mean((record[:, 1] - record[:, 2]) / (record[:, 1] - record[:, 0]))
    print(f'avgreldec%: {avgreldec:.4f}, avgtime: {avgtime:.4f}')

    if save:
        save_record(save_dir, record, time_pretrain)
    return avgreldec, record, time_pretrain

In [None]:
# Example: run `greedy_GNN` on the wiki dataset, deleting 4 edges per instance.
# `model_name` and `gnn_latent_dim` are forwarded to the algorithm through pipeline's **alg_kwags.
# The result is assigned so Jupyter does not echo the returned tuple (including the full `record`
# array) as cell output. pipeline() already prints the per-instance table and the averages.
avgreldec, record, time_pretrain = pipeline(
    "greedy_GNN",
    'wiki_indeg-50.pkl.gz',
    del_edge_num=4,
    save_tag='_6GCN',
    model_name='wiki_1000_6GCN.pt',
    gnn_latent_dim=[128] * 6,
)