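# Evaluate the fidelity

Compute the sufficiency-based fidelity of several GNN explainers on the synthetic Infection dataset. The score is reported separately for each class label and for every trained model (GCN, GAT, GIN, GraphSAGE, Cheb).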

In [1]:
from Datasets.synthetics import Infection
import torch
import numpy as np

%load_ext autoreload
%autoreload 2

  from .autonotebook import tqdm as notebook_tqdm


## Setup: torch threads, compute device, and seeding utility

In [2]:
# Cap the number of CPU threads torch uses
torch.set_num_threads(4)

# Use the GPU when one is visible; otherwise fall back to CPU so the notebook
# still runs on machines without CUDA (identical to before when CUDA exists)
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# NOTE(review): this import raised ModuleNotFoundError in the recorded run;
# `utils` must be importable from the notebook's working directory.
from utils import set_seeds

ModuleNotFoundError: No module named 'utils'

## Compute the fidelity (sufficiency) for each model and explainer

In [3]:
# `set_seeds` is imported here too so this cell does not silently depend on the
# (previously failing) setup cell for that name
from utils import build_expl, compute_fidelity, set_seeds
from models.models_Infection import GCN_framework, GraphSAGE_framework, GAT_framework, GIN_framework, CHEB_framework


# Define the parameters
DATASET = 'Infection'
MODELS = ['GCN', 'GAT', 'GIN', 'GraphSAGE', 'Cheb']

EXPLS = ['cam', 'grad_cam', 'grad_exp', 'guided_bp', 'ig_node', 'pgmexplainer', 'gnnexpl', 'pgexplainer']
# NOTE(review): this overrides the full explainer list above, but the outputs
# recorded below were produced with the full list. Remove it to reproduce them.
EXPLS = ["subgraphX"]
MODES = ['train']

# (dataset, model, explainer, mode) combinations that are skipped and reported as NaN
IGNORE = [('Infection', 'Cheb', 'gnnexpl', 'train'),
          ('Infection', 'Cheb', 'pgexplainer', 'train'),
          ('Infection', 'GIN', 'gnnexpl', 'train')]

GNN_NUM_LAYERS = {'GCN': 2, 'GAT': 2, 'GIN': 2, 'GraphSAGE': 2, 'Cheb': 2}

# Value passed as `num_features` to both build_expl and compute_fidelity
NUM_FEATURES = 2

FRAMEWORKS = {'GCN': GCN_framework, 
              'GAT': GAT_framework, 
              'GIN': GIN_framework,
              'GraphSAGE': GraphSAGE_framework, 
              'Cheb': CHEB_framework
             }


# Load the dataset
set_seeds()
dataset = Infection()


# Define history variables: one sufficiency list per class label, plus the
# model / explainer name for each row. All lists must stay the same length.
suff = {0: [], 1: [], 2: []}
model, expl = [], []


# Compute the metrics
for MODE in MODES:
    for MODEL in MODELS:
        print(8 * '* ' + MODEL + 8 * ' *')
        
        # Define and load the trained model (any framework, not only GCN);
        # `device` comes from the setup cell
        gnn = FRAMEWORKS[MODEL](dataset, device=device)
        path = 'models/' + DATASET + '_' + MODEL
        gnn.load_model(path)

        # Loop over the explainers
        for EXPL in EXPLS:
            # Define the setting and store it
            ID = (DATASET, MODEL, EXPL, MODE)
            model.append(MODEL)
            expl.append(EXPL)
            
            # Compute the sufficiency
            if ID in IGNORE:
                # Skipped combination: NaN for every class label
                for label in suff:
                    suff[label].append(float('nan'))
            else:
                # Load and process the explanations
                graphs = build_expl(DATASET, MODEL, EXPL, GNN_NUM_LAYERS, num_features=NUM_FEATURES, cut_ego=False)

                # Append exactly one value per class label so every list in `suff`
                # stays aligned with `model` / `expl`, even if `graphs` is missing a label
                for label in suff:
                    if label in graphs and graphs[label] is not None:
                        # Valid explanations exist for this label: compute sufficiency
                        suff[label].append(compute_fidelity(gnn, dataset.data, graphs, num_features=NUM_FEATURES, y=label))
                    else:
                        # No explanations for this label: NaN metric
                        suff[label].append(float('nan'))
                    
            # Print the partial results (NaN shown as ' -----')
            print(' '.join(' -----' if np.isnan(suff[label][-1]) else ' {:.3f}'.format(suff[label][-1]) for label in suff) + '\t' + EXPL)

* * * * * * * * GCN * * * * * * * *
 -----  -----  -----	cam
 0.548  0.305  0.590	grad_cam
 0.098  0.533  0.521	grad_exp
 -----  0.267  0.644	guided_bp
 0.997  0.715  -0.205	ig_node
 -----  0.537  -----	pgmexplainer


  out=out, **kwargs)
  ret = ret.dtype.type(ret / rcount)


 -----  0.912  0.004	gnnexpl
 -----  -----  -----	pgexplainer
* * * * * * * * GAT * * * * * * * *
 -0.414  0.775  0.753	cam
 -0.380  0.774  -----	grad_cam
 -0.413  0.774  0.781	grad_exp
 -----  0.779  -----	guided_bp
 -0.433  0.775  0.795	ig_node
 -----  0.765  -----	pgmexplainer
 0.130  0.949  0.860	gnnexpl
 -----  -----  -----	pgexplainer
* * * * * * * * GIN * * * * * * * *
 0.046  0.526  -----	cam
 -0.498  0.496  0.595	grad_cam
 -0.026  0.197  0.679	grad_exp
 -0.474  0.562  0.656	guided_bp
 0.304  0.042  -----	ig_node
 -0.351  -----  -----	pgmexplainer
 -----  -----  -----	gnnexpl
 -----  -----  -----	pgexplainer
* * * * * * * * GraphSAGE * * * * * * * *
 0.643  -0.179  0.774	cam
 0.641  -0.193  0.775	grad_cam
 0.550  -0.193  0.795	grad_exp
 0.629  -0.189  0.779	guided_bp
 0.643  -0.191  0.775	ig_node
 -----  -----  -----	pgmexplainer
 0.998  -0.000  0.997	gnnexpl
 -----  -----  -----	pgexplainer
* * * * * * * * Cheb * * * * * * * *
 0.492  -0.106  0.626	cam
 0.579  0.057  0.500	gra

In [4]:
import os

import pandas as pd

# One row per (model, explainer) pair with its per-class sufficiency
results = pd.DataFrame({'model': model, 'expl': expl, 'class 0': suff[0], 'class 1': suff[1], 'class 2': suff[2]})

# Create the output directory so to_csv does not fail on a fresh checkout.
# NaN (skipped / unavailable) entries are written as the sentinel value -100.
# NOTE(review): the filename encodes only MODELS, not EXPLS, so runs with
# different explainer lists overwrite each other's CSV.
os.makedirs('./metrics', exist_ok=True)
results.fillna(-100).to_csv('./metrics/fidelity_sufficiency_' + '_'.join(MODELS) + '.csv')
results

Unnamed: 0,model,expl,class 0,class 1,class 2
0,GCN,cam,,,
1,GCN,grad_cam,0.547851,0.305446,0.590118
2,GCN,grad_exp,0.098242,0.533261,0.521082
3,GCN,guided_bp,,0.267366,0.64364
4,GCN,ig_node,0.99746,0.714688,-0.204753
5,GCN,pgmexplainer,,0.537414,
6,GCN,gnnexpl,,0.911612,0.003924
7,GCN,pgexplainer,,,
8,GAT,cam,-0.413834,0.774938,0.752584
9,GAT,grad_cam,-0.380134,0.773831,
