In [2]:
import numpy as np
%load_ext autoreload
%autoreload 2
from load_data import load_data
import torch
from modules import GNN
from train_model import train_model
from subgraph_relevance import subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook, get_H_transform
from utils import create_ground_truth, get_feat_order_local_best_guess, get_auac_aupc, get_stats
import time
from tqdm import tqdm
import matplotlib.pyplot as plt
import sys
import pandas as pd
from io import StringIO
import pickle as pkl

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


### Efficiency

As a function of model layers

In [3]:
# Shared setup for the efficiency experiments on the BA-2motif dataset.
graphs, pos_idx, neg_idx = load_data('BA-2motif')

# Model checkpoint file names, resolved relative to the models/ folder.
# (The names suggest GIN models with 2 to 7 layers.)
model_dirs = [
    'gin-2-ba2motif.torch', 'gin-3-ba2motif.torch', 'gin-4-ba2motif.torch',
    'gin-5-ba2motif.torch', 'gin-6-ba2motif.torch', 'gin-7-ba2motif.torch',
]

# Experiment parameters.
num_samples = 50
verbose = False
alpha = 0.
S = [0, 1, 2, 3]   # node indices forming the subgraph of interest
g = graphs[44]     # single example graph

# Draw num_samples distinct graph indices.
sample_idx = np.random.choice(len(graphs), num_samples, replace=False)

# One entry per configuration is appended by the timing loops.
model_times = []

# Default model: second entry of model_dirs.
nn = torch.load('models/' + model_dirs[1])



num_graphs: 1000


In [4]:
def softmin(b):
    """Smooth approximation of ``min(b, 0)``: ``-0.5 * log(1 + exp(-2 * b))``.

    Parameters
    ----------
    b : torch.Tensor
        Input tensor (any shape); the operation is applied element-wise.

    Returns
    -------
    torch.Tensor
        Tensor of the same shape as ``b``.

    Notes
    -----
    The expression is evaluated through ``softplus`` instead of the literal
    ``log(1 + exp(.))``. The literal form overflows for strongly negative
    ``b`` (``exp(-2 * b)`` becomes ``inf``) and returns ``-inf``, whereas the
    correct value is approximately ``b``.
    """
    return -0.5 * torch.nn.functional.softplus(-2 * b)


## Runtime experiments

### L dependency

In [None]:
# Runtime as a function of the model: for every checkpoint in model_dirs,
# measure the mean wall-clock time per graph of the three subgraph-relevance
# implementations on the same set of sampled graphs.
graphs, pos_idx, neg_idx = load_data('BA-2motif')

model_dirs = ['gin-2-ba2motif.torch',
            'gin-3-ba2motif.torch',
            'gin-4-ba2motif.torch',
            'gin-5-ba2motif.torch',
            'gin-6-ba2motif.torch',
            'gin-7-ba2motif.torch']

S = [0,1,2,3]
alpha = 0.
verbose = False
num_samples = 50
# Seeded local generator: the same graphs are timed on every run, and the
# global NumPy random state used by other cells is left untouched.
sample_idx = np.random.default_rng(0).choice(len(graphs), num_samples, replace=False)

# Order matters: downstream plotting reads column 0 (subgraph_original) and
# column 2 (subgraph_mp_forward_hook) of model_times.
methods = (subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook)

model_times = []

for model_dir in tqdm(model_dirs):
    nn = torch.load('models/'+model_dir)

    ts = []
    for method in methods:
        elapsed = 0.0
        # Iterate tqdm over the array itself so the bar knows its total.
        for i in tqdm(sample_idx, leave=False):
            g = graphs[i]
            # perf_counter is monotonic and high-resolution; only the call
            # itself is inside the timed region.
            start = time.perf_counter()
            method(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
            elapsed += time.perf_counter() - start
        ts.append(elapsed / num_samples)

    model_times.append(ts)

num_graphs: 1000


  0%|          | 0/6 [00:00<?, ?it/s]
0it [00:00, ?it/s][A
1it [00:00,  6.65it/s][A
7it [00:00, 31.82it/s][A
13it [00:00, 41.16it/s][A
19it [00:00, 45.55it/s][A
25it [00:00, 48.61it/s][A
31it [00:00, 50.01it/s][A
37it [00:00, 49.01it/s][A
42it [00:00, 48.38it/s][A
50it [00:01, 45.36it/s][A

  0%|          | 0/50 [00:00<?, ?it/s][A
100%|██████████| 50/50 [00:00<00:00, 276.22it/s][A

100%|██████████| 50/50 [00:00<00:00, 748.77it/s]
 17%|█▋        | 1/6 [00:01<00:06,  1.36s/it]
0it [00:00, ?it/s][A
1it [00:00,  6.79it/s][A
2it [00:00,  7.41it/s][A
3it [00:00,  8.12it/s][A
4it [00:00,  8.17it/s][A
5it [00:00,  8.27it/s][A
6it [00:00,  8.29it/s][A
7it [00:00,  8.35it/s][A
8it [00:00,  8.38it/s][A
9it [00:01,  8.07it/s][A
10it [00:01,  8.31it/s][A
11it [00:01,  8.29it/s][A
12it [00:01,  8.32it/s][A
13it [00:01,  8.36it/s][A
14it [00:01,  8.36it/s][A
15it [00:01,  8.36it/s][A
16it [00:01,  8.39it/s][A
17it [00:02,  8.39it/s][A
18it [00:02,  8.09it/s][A
19it [00:

In [None]:
# Broken-axis figure of the runtimes collected per model: the tall upper panel
# shows the large values, the short lower panel zooms in near zero.
model_times = np.array(model_times)
num_layers = np.arange(1, len(model_times) + 1)   # x positions, one per model

fig, (ax1, ax2) = plt.subplots(
    2, 1,
    sharex=True,
    figsize=(3.5,4),
    gridspec_kw={'height_ratios': [3, 1]},
)
fig.subplots_adjust(hspace=0.05)  # adjust space between axes

plt.rc('legend', fontsize=14.5)
ax2.spines['top'].set_visible(False)

# Same font size for titles, axis labels and tick labels of both panels.
for ax in (ax1, ax2):
    for item in [ax.title, ax.xaxis.label, ax.yaxis.label,
                 *ax.get_xticklabels(), *ax.get_yticklabels()]:
        item.set_fontsize(15)

ax1.set_ylabel("Time (s)")
ax2.set_xlabel(r'$L$')

# Tick labels count from 2; only odd values are printed.
tick_labels = []
for depth in range(2, len(model_times) + 2):
    tick_labels.append(str(depth) if depth % 2 == 1 else '')
plt.xticks(num_layers, tick_labels)

# Upper panel: column 0 only. A throw-away red line is drawn so the legend
# gets a second handle, then removed again.
ax1.plot(num_layers, model_times[:, 0], 'c--')
line2, = ax1.plot(num_layers, [0] * len(num_layers), 'r-')
ax1.legend(['GNN-LRP naive', 'sGNN-LRP'])
line2.remove()

# Lower panel: columns 0 and 2.
ax2.plot(num_layers, model_times[:, 0], 'c--')
ax2.plot(num_layers, model_times[:, 2], 'r-')

# Hide the spine between the two panels and keep ticks away from the cut.
ax1.spines['bottom'].set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top

ax1.set_ylim(0.005)  # outliers only
ax2.set_ylim(0, 0.013)

# Slanted cut marks on both sides of the axis break.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

plt.savefig('imgs/time_consumption_L.eps', dpi=600, format='eps', bbox_inches='tight')
# plt.show()

### |S| dependency

In [None]:
# Runtime as a function of the subgraph size |S|: for a fixed model, measure
# the mean wall-clock time per graph of the three subgraph-relevance
# implementations for S = [0, ..., size_S - 1].
graphs, pos_idx, neg_idx = load_data('BA-2motif')

model_dir = 'gin-3-ba2motif.torch'

alpha = 0.
verbose = False
num_samples = 50
# Seeded local generator: the same graphs are timed on every run, and the
# global NumPy random state used by other cells is left untouched.
sample_idx = np.random.default_rng(0).choice(len(graphs), num_samples, replace=False)
nn = torch.load('models/'+model_dir)

# Order matters: downstream plotting reads column 0 (subgraph_original) and
# column 2 (subgraph_mp_forward_hook) of model_times.
methods = (subgraph_original, subgraph_mp_transcription, subgraph_mp_forward_hook)

model_times = []

for size_S in tqdm(range(25)):
    S = list(range(size_S))

    ts = []
    for method in methods:
        elapsed = 0.0
        # Iterate tqdm over the array itself so the bar knows its total.
        for i in tqdm(sample_idx, leave=False):
            g = graphs[i]
            # perf_counter is monotonic and high-resolution; only the call
            # itself is inside the timed region.
            start = time.perf_counter()
            method(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
            elapsed += time.perf_counter() - start
        ts.append(elapsed / num_samples)

    model_times.append(ts)

In [None]:
# Broken-axis figure of the runtimes collected per subgraph size: the tall
# upper panel shows the large values, the short lower panel zooms in near zero.
# NOTE(review): the timing cell iterates size_S over range(25), i.e. |S| = 0..24,
# while the ticks below are labelled 1..len(model_times) -- confirm which is intended.
model_times = np.array(model_times)
num_layers = np.arange(1, len(model_times) + 1)   # x positions (one per subgraph size)

fig, (ax1, ax2) = plt.subplots(
    2, 1,
    sharex=True,
    figsize=(3.5,4),
    gridspec_kw={'height_ratios': [3, 1]},
)
fig.subplots_adjust(hspace=0.05)  # adjust space between axes

# y ticks and y label go on the right-hand side.
ax1.yaxis.tick_right()
ax1.yaxis.set_label_position("right")
ax2.yaxis.tick_right()

plt.rc('legend', fontsize=12)
ax2.spines['top'].set_visible(False)

# Same font size for titles, axis labels and tick labels of both panels.
for ax in (ax1, ax2):
    for item in [ax.title, ax.xaxis.label, ax.yaxis.label,
                 *ax.get_xticklabels(), *ax.get_yticklabels()]:
        item.set_fontsize(15)

ax2.set_xlabel(r'$|\mathcal{S}|$')

# Print every third tick label (1, 4, 7, ...); the others stay blank.
tick_labels = []
for size in range(1, len(model_times) + 1):
    tick_labels.append(str(size) if size % 3 == 1 else '')
plt.xticks(num_layers, tick_labels)

# Upper panel: column 0 only. The red zero line is added and removed right
# away, exactly as in the layer-dependency figure (no legend is drawn here).
ax1.plot(num_layers, model_times[:, 0], 'c--')
line2, = ax1.plot(num_layers, [0] * len(num_layers), 'r-')
line2.remove()

# Lower panel: columns 0 and 2.
ax2.plot(num_layers, model_times[:, 0], 'c--')
ax2.plot(num_layers, model_times[:, 2], 'r-')

# Hide the spine between the two panels and keep ticks away from the cut.
ax1.spines['bottom'].set_visible(False)
ax1.xaxis.tick_top()
ax1.tick_params(labeltop=False)  # don't put tick labels at the top

ax1.set_ylim(0.05)  # outliers only
ax2.set_ylim(0, 0.01)

# Slanted cut marks on both sides of the axis break.
d = .5  # proportion of vertical to horizontal extent of the slanted line
kwargs = {
    'marker': [(-1, -d), (1, d)],
    'markersize': 12,
    'linestyle': "none",
    'color': 'k',
    'mec': 'k',
    'mew': 1,
    'clip_on': False,
}
ax1.plot([0, 1], [0, 0], transform=ax1.transAxes, **kwargs)
ax2.plot([0, 1], [1, 1], transform=ax2.transAxes, **kwargs)

plt.savefig('imgs/time_consumption_S.eps', dpi=600, format='eps', bbox_inches='tight')

Compare the time consumption of the three methods

In [None]:
# Per-dataset timing breakdown of the three implementations.
# With verbose=True the subgraph_* functions print to stdout; that output is
# captured and parsed into one row of numbers per call.
alpha = 0.
S = np.arange(5)
verbose = True
dataset_model_dirs = [['BA-2motif','gin-3-ba2motif.torch'],
                      ['BA-2motif','gin-5-ba2motif.torch'],
                      ['BA-2motif','gin-7-ba2motif.torch'],
                      ['MUTAG', 'gin-3-mutag.torch'],
                      ['Mutagenicity', 'gin-3-mutagenicity.torch'],
                      ['REDDIT-BINARY', 'gin-5-reddit.torch'],
                      ['Graph-SST2', 'gcn-3-sst2graph.torch']]

# One pandas Series of column means per (dataset, model) pair.
efficiency_result_originals = []
efficiency_result_mp_transcs = []
efficiency_result_forward_hooks = []

for dataset, model_dir in dataset_model_dirs:
    print(dataset, model_dir)
    graphs, pos_idx, neg_idx = load_data(dataset)

    nn = torch.load('models/'+model_dir)

    # Capture everything printed to stdout during the timed calls. The
    # try/finally guarantees stdout is restored even if a call raises;
    # otherwise all later notebook output would be silently swallowed.
    s = StringIO()
    old_stdout = sys.stdout
    sys.stdout = s
    try:
        for _ in tqdm(range(50)):
            # Resample until the graph has at least 5 nodes
            # (S indexes nodes 0..4).
            g = graphs[np.random.randint(len(graphs))]
            while g.nbnodes < 5:
                g = graphs[np.random.randint(len(graphs))]

            for _repeat in range(3):
                subgraph_original(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
                subgraph_mp_transcription(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
                subgraph_mp_forward_hook(nn, g, S, alpha=alpha, gamma=None, verbose=verbose)
    finally:
        sys.stdout = old_stdout

    lists = s.getvalue().splitlines()

    # Each captured line is tab-separated; the first field is skipped and from
    # every remaining field the number after the last ': ' (up to an optional
    # comma) is extracted.
    lists_ = [[float(data.split(': ')[-1].split(',')[0]) for data in l.split('\t')[1:]] for l in lists]

    # Lines are de-interleaved by call order.
    # NOTE(review): this assumes every call prints exactly one line -- confirm.
    efficiency_result_original = pd.DataFrame(lists_[::3],columns=['nbnodes','layers','overhead','subrel'])
    efficiency_result_mp_transc = pd.DataFrame(lists_[1::3],columns=['nbnodes','layers','overhead','subrel'])
    efficiency_result_forward_hook = pd.DataFrame(lists_[2::3],columns=['nbnodes','layers','forward1','backward1'])

    efficiency_result_originals.append(efficiency_result_original.mean(axis=0))
    efficiency_result_mp_transcs.append(efficiency_result_mp_transc.mean(axis=0))
    efficiency_result_forward_hooks.append(efficiency_result_forward_hook.mean(axis=0))

### Test different alphas

ACC AUC AUPC AUFC

In [None]:
# Sweep alpha over [0, 1] in steps of 0.05. For every alpha, compute a node
# ordering with get_feat_order_local_best_guess on a label-balanced set of
# graphs and print the averaged evaluation statistics.
verbose = False
test_samples = 200

# BA-2motif
dataset = 'BA-2motif'
graphs, pos_idx, neg_idx = load_data('BA-2motif')
model_dir = "models/gin-3-ba2motif.torch"; num_layer= 3
nn = torch.load(model_dir)

# Alternative dataset/model configurations (uncomment exactly one block).
# # MUTAG
# dataset = 'MUTAG'
# graphs, pos_idx, neg_idx = load_data('MUTAG')
# model_dir = "models/gin-3-mutag.torch"; num_layer= 3
# nn = torch.load(model_dir)

# # Graph-SST2
# dataset = 'Graph-SST2'
# graphs, pos_idx, neg_idx = load_data('Graph-SST2')
# model_dir = "models/gcn-3-sst2graph.torch"; num_layer= 3
# nn = torch.load(model_dir)

# Header of the report. The model type 'gin' is hard-coded in the text.
message = '{}\nModel depth: {}, model: {}, nb of samples: {}\n'.format(model_dir,num_layer, 'gin', test_samples)
print(message)

# NOTE(review): `messages` is never appended to in this cell, and the results
# are only printed -- the plotting cells read them from text files that this
# cell does not write. Confirm how those files are produced.
messages = []
# Graph indices drawn during the first alpha; reused for all later alphas so
# every alpha is evaluated on the same graphs.
test_sample_idx = []
for alpha in np.arange(0.0,1.01,0.05):
    # acc/auc are only tracked for BA-2motif.
    if dataset == 'BA-2motif':
        stats = {'acc': [], 'auc': [], 'auac': [], 'aupc': [], 'acs': [], 'pcs': [], 'label': []}
    else:
        stats = {'auac': [], 'aupc': [], 'acs': [], 'pcs': [], 'label': []}
    # Remaining quota per class (label 0 counts against cnt_pos, any other
    # label against cnt_neg).
    cnt_pos = test_samples / 2
    cnt_neg = test_samples / 2
    start = time.time()
    # True only on the first alpha, when no indices have been collected yet.
    random_sample = True if test_sample_idx == [] else False
    i = 0
    while cnt_pos > 0 or cnt_neg > 0:
        if random_sample:
            # Rejection sampling: skip graphs with fewer than 3 nodes and
            # graphs whose class quota is already exhausted.
            idx = np.random.randint(len(graphs))
            g = graphs[idx]
            if g.nbnodes < 3: continue
            if g.label == 0:
                if cnt_pos == 0: continue
                else: cnt_pos -= 1
            else:
                if cnt_neg == 0: continue
                else: cnt_neg -= 1
            test_sample_idx.append(idx)
        else:
            # Replay mode: the counters are not decremented here, so the loop
            # ends through this break once all stored indices were visited.
            if i >= len(test_sample_idx): break
            g = graphs[test_sample_idx[i]]
            i += 1

        gr_tr, all_feats = create_ground_truth(g)

        # Evaluation task; switch by swapping which line is commented out.
        # mode = 'prun'
        mode = 'extr'

        H, transforms = get_H_transform(g.get_adj(),nn,gammas=None)
        # NOTE(review): mode='extr' is passed as a literal here, independent of
        # the `mode` variable above -- confirm this is intended for 'prun'.
        fo = get_feat_order_local_best_guess(nn, g, alpha, H, transforms, mode='extr')
        
        # Only the metrics of the selected task are computed; the other ones
        # are left as empty-list placeholders.
        if mode == 'extr':
            acc, auc = get_stats(gr_tr, fo, all_feats)
            auac, acs = get_auac_aupc(nn, g, fo, task=mode, use_softmax=True)
            aupc, pcs = [], []
        else:
            acc, auc = [], []
            aupc, pcs = get_auac_aupc(nn, g, fo, task=mode, use_softmax=False)
            auac, acs = [], []

        if dataset == 'BA-2motif':
            stats['acc'].append(acc)
            stats['auc'].append(auc)
        stats['auac'].append(auac)
        stats['aupc'].append(aupc)
        stats['acs'].append(acs)
        stats['pcs'].append(pcs)
        stats['label'].append(g.label)

    # Build the two-line report for this alpha: a header line followed by
    # tab-separated "key : value" entries.
    message = 'alpha {}, took {} s\n'.format(alpha, round(time.time() - start, 4))
    for key, lst in  stats.items():
        # Per-sample curves and labels are not summarised.
        if key in ['acs', 'pcs', 'label']: continue
        elif key in ['acc', 'auc']:
            mstat = round(sum(lst)/len(lst), 4)
            message += '\t {} : {}'.format(key, mstat)
        elif key in ['auac', 'aupc']:
            # Reported three times: mean over label-0 graphs ("_pos"), mean
            # over label-1 graphs ("_neg"), and mean over all graphs.
            lst = np.array(lst)
            pos = list(np.argwhere(np.array(stats['label']) == 0).flatten())
            mstat = round(np.sum(lst[pos])/len(pos), 4)
            message += '\t {}_pos : {}'.format(key, mstat)
            neg = list(np.argwhere(np.array(stats['label']) == 1).flatten())
            mstat = round(np.sum(lst[neg])/len(neg), 4)
            message += '\t {}_neg : {}'.format(key, mstat)
            mstat = round(np.sum(lst)/len(lst), 4)
            message += '\t {} : {}'.format(key, mstat)
    message += '\n'

    print(message)


In [None]:
# plot acc auc
# Reads the alpha-sweep report and plots accuracy and AUROC against alpha.
# File layout expected by the parser: two header lines, then for every alpha
# one line "alpha <value>, ..." followed by one line of tab-separated
# "<name> : <value>" entries.

with open('evaluation_results/result_local_best_guess.txt','r') as f:
    s = f.readlines()[2:]
stats = {}
alphas = []   # alpha values actually recorded in the file, in file order
for i in range(len(s)//2):
    alpha = float(s[i * 2].split(',')[0].split(' ')[-1])
    alphas.append(alpha)
    for sss in s[i * 2 + 1].split('\n')[0].split('\t')[1:]:
        sss = sss.split(':')
        stats.setdefault(sss[0].strip(), []).append(float(sss[1]))

fig, ax = plt.subplots(figsize=(5,1.5))
ax.spines['right'].set_visible(False)
ax.spines['top'].set_visible(False)

for item in ([ax.title, ax.xaxis.label, ax.yaxis.label] +
             ax.get_xticklabels() + ax.get_yticklabels()):
    item.set_fontsize(15)
plt.rc('legend', fontsize=15)

# Plot against the alphas parsed from the file rather than a hard-coded grid,
# so the x values always match the number and position of recorded results.
ax.plot(alphas, stats['acc'], 'r-')
ax.plot(alphas, stats['auc'], 'b--')
ax.legend(["Accuracy", "AUROC"])
ax.set_xlim(0,1)
ax.set_ylim(top=1)
ax.set_xlabel(r'$\alpha$')
fig.savefig('imgs/ba2motif_acc_auc.eps', dpi=600, format='eps', bbox_inches='tight')

In [None]:
# plot aupc aufc
# Reads an alpha-sweep report and plots the per-class curve against alpha on
# a twin-axis figure: the "_pos" curve in red on the left axis and the "_neg"
# curve in blue on the right axis. An arrow marks the best alpha per curve.

# data_dir = 'evaluation_results/mutag_prun_result.txt'; dataset = 'mutag'
# data_dir = 'evaluation_results/mutag_acti_result.txt'; dataset = 'mutag'
# data_dir = 'evaluation_results/graphsst2_prun_result.txt'; dataset = 'graphsst2'
data_dir = 'evaluation_results/graphsst2_acti_result.txt'; dataset = 'graphsst2'

# File layout expected by the parser: two header lines, then for every alpha
# one line "alpha <value>, ..." followed by one line of tab-separated
# "<name> : <value>" entries.
with open(data_dir,'r') as f:
    s = f.readlines()[2:]
stats = {}
alphas = []   # alpha values actually recorded in the file, in file order
for i in range(len(s)//2):
    alpha = float(s[i * 2].split(',')[0].split(' ')[-1])
    alphas.append(alpha)
    for sss in s[i * 2 + 1].split('\n')[0].split('\t')[1:]:
        sss = sss.split(':')
        stats.setdefault(sss[0].strip(), []).append(float(sss[1]))


def _plot_class_curve(ax, aupc_key, auac_key, color, fmt):
    """Draw one per-class curve on ``ax`` and mark its best alpha with an arrow.

    If ``aupc_key`` is present in ``stats`` that curve is plotted and its
    minimum is marked (arrow starting 0.15 below the point); otherwise
    ``auac_key`` is plotted and its maximum is marked (arrow starting 0.25
    below the point). The x values are the alphas parsed from the file, so
    curve and arrow always agree with the recorded results.
    """
    if aupc_key in stats:
        key, best_idx, offset = aupc_key, np.argmin(stats[aupc_key]), 0.15
    else:
        key, best_idx, offset = auac_key, np.argmax(stats[auac_key]), 0.25
    x_best, y_best = alphas[best_idx], stats[key][best_idx]
    ax.annotate(xy=(x_best, y_best), text='',
                xytext=(x_best, y_best - offset),
                arrowprops={'arrowstyle': '->', 'color': color})
    ax.plot(alphas, stats[key], fmt)


fig = plt.figure(figsize=(3,3))
ax1 = fig.add_subplot(111)

for item in ([ax1.title, ax1.xaxis.label, ax1.yaxis.label] +
             ax1.get_xticklabels() + ax1.get_yticklabels()):
    item.set_fontsize(15)

_plot_class_curve(ax1, 'aupc_pos', 'auac_pos', 'red', 'r-')

ax1.set_ylabel('positive')
ax1.yaxis.label.set_color('red')
ax1.tick_params(axis='y', colors='red')
ax1.set_xlabel(r'$\alpha$')

# Second y-axis sharing the same x-axis for the negative-class curve.
ax2 = ax1.twinx()
for item in ([ax2.title, ax2.xaxis.label, ax2.yaxis.label] +
             ax2.get_xticklabels() + ax2.get_yticklabels()):
    item.set_fontsize(15)

_plot_class_curve(ax2, 'aupc_neg', 'auac_neg', 'blue', 'b--')

ax2.set_ylabel('negative')
ax2.yaxis.label.set_color('blue')
ax2.tick_params(axis='y', colors='blue')

ax2.set_xlim(0,1)
ax2.set_xlabel(r'$\alpha$')

# The output file name follows the metric that was plotted.
if 'aupc_neg' in stats.keys():
    plt.savefig('imgs/'+dataset+'_aupc.eps', dpi=600, format='eps',bbox_inches='tight')
else:
    plt.savefig('imgs/'+dataset+'_auac.eps', dpi=600, format='eps',bbox_inches='tight')
