In [None]:
import numpy as np

import torch
from torch_geometric.data import Data
from torch_geometric.explain import Explainer, CaptumExplainer
import time

import import_ipynb
from dataset import get_data
from simulation import simul_fromData
from constants import *
from utils import *

from XAI_methods.gnnexplainer import GNNExplainer
from XAI_methods.gnnexplainer_variant import GNNExplainer_variant
from XAI_methods.gnnexplainer_variant_initsaliency import GNNExplainer_variant_initsaliency
from XAI_methods.pgexplainer import pretrain_pyg_pgexplainer
from XAI_methods.pgexplainer_variant import pretrain_pyg_pgexplainer_variant
from XAI_methods.pgexplainer_onlyconc import PGExplainer_onlyconc
from XAI_methods.pgexplainer_onlyconc_initsaliency import PGExplainer_onlyconc_initsaliency
from XAI_methods.pgexplainer_variant_GNN import pretrain_pyg_pgexplainer_variant_GNN


In [None]:
def pyg_expl_pretrain(alg_name, dataset_name, del_edge_num, **alg_kwargs):
    """Dispatch to the pretraining routine of a PGExplainer-family explainer.

    Args:
        alg_name: one of 'PGExplainer', 'PGExplainer_variant',
            'PGExplainer_variant_GNN'.
        dataset_name: forwarded as the first argument of the selected
            pretrain function.
        del_edge_num: forwarded to the variant pretrainers only; the base
            'PGExplainer' pretrainer is called without it.
        **alg_kwargs: forwarded unchanged to the selected pretrain function.

    Returns:
        Whatever the selected ``pretrain_pyg_pgexplainer*`` function returns
        (presumably a trained explainer usable as ``alg_name`` in ``pyg_expl``
        -- confirm against XAI_methods).

    Raises:
        ValueError: if ``alg_name`` is not a supported pretrainable explainer
            (previously this silently returned None).
    """
    if alg_name == 'PGExplainer':
        # The base pretrainer does not take a del_edge_num argument.
        return pretrain_pyg_pgexplainer(dataset_name, **alg_kwargs)
    if alg_name == 'PGExplainer_variant':
        return pretrain_pyg_pgexplainer_variant(dataset_name, del_edge_num, **alg_kwargs)
    if alg_name == 'PGExplainer_variant_GNN':
        return pretrain_pyg_pgexplainer_variant_GNN(dataset_name, del_edge_num, **alg_kwargs)
    raise ValueError(
        f"Unknown pretrainable explainer {alg_name!r}; expected one of "
        "'PGExplainer', 'PGExplainer_variant', 'PGExplainer_variant_GNN'."
    )

In [None]:

# Captum attribution methods that pyg_expl wraps via CaptumExplainer.
_CAPTUM_METHODS = ('Saliency', 'InputXGradient', 'Deconvolution',
                   'GuidedBackprop', 'IntegratedGradients')

# Every explainer name pyg_expl can construct itself. Any other *string* is
# rejected; non-string values are treated as ready-made explainers.
_NAMED_ALGORITHMS = frozenset((
    'GNNExplainer',
    'GNNExplainer_variant',
    'GNNExplainer_variant_initsaliency',
    'PGExplainer_onlyconc',
    'PGExplainer_onlyconc_initsaliency',
) + _CAPTUM_METHODS)


def _regression_explainer(model, algorithm):
    """Wrap ``algorithm`` in a PyG Explainer for graph-level regression.

    All named algorithms in ``pyg_expl`` share this configuration: a
    'phenomenon' explanation with an object-level edge mask, explaining the
    raw regression output of ``model`` for the whole graph.
    """
    return Explainer(
        model=model,
        algorithm=algorithm,
        explanation_type='phenomenon',
        edge_mask_type='object',
        model_config=dict(
            mode='regression',
            task_level='graph',
            return_type='raw',
        ),
    )


def _named_algorithm(alg_name, del_edge_num, kwargs):
    """Instantiate the explainer algorithm registered under ``alg_name``."""
    if alg_name == 'GNNExplainer':
        return GNNExplainer(**kwargs)
    if alg_name == 'GNNExplainer_variant':
        return GNNExplainer_variant(**kwargs)
    if alg_name == 'PGExplainer_onlyconc':
        return PGExplainer_onlyconc(**kwargs)
    if alg_name in ('GNNExplainer_variant_initsaliency', 'PGExplainer_onlyconc_initsaliency'):
        # These constructors were only ever given epochs / lr / del_edge_num
        # (the original referenced undefined globals `epochs` / `lr`). Forward
        # epochs / lr only when the caller supplied them; otherwise the class
        # defaults apply -- assumes such defaults exist, TODO confirm.
        init_kwargs = {key: kwargs[key] for key in ('epochs', 'lr') if key in kwargs}
        algorithm_cls = (GNNExplainer_variant_initsaliency
                         if alg_name == 'GNNExplainer_variant_initsaliency'
                         else PGExplainer_onlyconc_initsaliency)
        return algorithm_cls(del_edge_num=del_edge_num, **init_kwargs)
    if alg_name in _CAPTUM_METHODS:
        return CaptumExplainer(alg_name)
    raise ValueError(f"Unknown explainer {alg_name!r}")


def _adj_list_to_data(adj_list, seed_idx, prob):
    """Convert an adjacency list into a PyG ``Data`` object plus a flat edge table.

    Args:
        adj_list: sequence of length n; ``adj_list[u]`` iterates over
            ``(v, p)`` pairs, one per directed edge u -> v with weight p.
        seed_idx: index (or indices) into range(n) marking the seed nodes.
        prob: per-node target values; expanded to shape (n, 1).

    Returns:
        ``(data, edge)`` where ``data.x`` is the (n, 1) float seed indicator,
        ``data.edge_index`` is (2, E) long, ``data.edge_attr`` is (E, 1)
        float, ``data.y`` is (n, 1) float, and ``edge`` is a numpy array of
        ``(u, v, p)`` rows in the same order as the columns of ``edge_index``.
    """
    n = len(adj_list)
    is_seed = np.zeros(n, dtype=int)
    is_seed[seed_idx] = 1

    sources, targets, weights = [], [], []
    for u in range(n):
        for v, p in adj_list[u]:
            sources.append(u)
            targets.append(v)
            weights.append(p)

    # Explicit dtypes keep an edgeless graph well-formed ((2, 0) long, (0, 1) float).
    edge_index = torch.tensor([sources, targets], dtype=torch.long)
    edge_attr = torch.tensor(weights, dtype=torch.float).reshape(-1, 1)
    seed = torch.from_numpy(np.expand_dims(is_seed, axis=-1)).float()
    y = torch.from_numpy(np.expand_dims(prob, axis=-1)).float()

    data = Data(x=seed, edge_index=edge_index, edge_attr=edge_attr, y=y)
    edge = np.array(list(zip(sources, targets, weights)))
    return data, edge


def pyg_expl(alg_name, adj_list, seed_idx, prob, del_edge_num, **kwargs):
    """Rank graph edges by explainer importance and return the top ``del_edge_num``.

    Args:
        alg_name: either the name of an explainer to build here (one of
            ``_NAMED_ALGORITHMS``) or an already-constructed explainer object
            (e.g. from ``pyg_expl_pretrain``), which is called as-is.
        adj_list: ``adj_list[u]`` iterates over ``(v, p)`` edge pairs.
        seed_idx: seed node index / indices.
        prob: per-node values; their sum over nodes is the explanation target.
        del_edge_num: number of top-ranked edges to return; also forwarded
            to the explainer algorithm.
        **kwargs: ``model_name`` (default ``DEFAULT_MODEL``) and ``gpu_num``
            (default 'cpu') select the model/device and are still forwarded
            to the algorithm; ``gnn_latent_dim`` configures the model and is
            NOT forwarded; a truthy ``prod_p`` multiplies the edge mask by the
            edge weight before ranking; everything else is forwarded to the
            GNNExplainer / GNNExplainer_variant / PGExplainer_onlyconc
            constructors.

    Returns:
        ``(top_edges, elapsed)``: a numpy array of ``(u, v, p)`` rows sorted by
        descending importance and truncated to ``del_edge_num``, and the
        elapsed seconds (including model loading and graph construction).

    Raises:
        ValueError: if ``alg_name`` is a string this function cannot build.
    """
    if isinstance(alg_name, str) and alg_name not in _NAMED_ALGORITHMS:
        raise ValueError(
            f"Unknown explainer {alg_name!r}; expected one of "
            f"{sorted(_NAMED_ALGORITHMS)} or an explainer object "
            "(e.g. from pyg_expl_pretrain)."
        )

    start = time.perf_counter()
    model_name = kwargs.get('model_name', DEFAULT_MODEL)
    gpu_num = kwargs.get('gpu_num', 'cpu')
    kwargs['del_edge_num'] = del_edge_num
    # gnn_latent_dim shapes the model only; the algorithms never receive it.
    gnn_latent_dim = kwargs.pop('gnn_latent_dim', [128, 128, 128, 128, 128, 128])

    device = set_gpu(gpu_num)
    # The model is loaded (and timed) even for a ready-made explainer, so
    # timings stay comparable with the original implementation.
    model = load_model(model_name, device, gnn_latent_dim=gnn_latent_dim)

    if isinstance(alg_name, str):
        explainer = _regression_explainer(model, _named_algorithm(alg_name, del_edge_num, kwargs))
    else:
        explainer = alg_name

    data, edge = _adj_list_to_data(adj_list, seed_idx, prob)
    data = data.to(device)

    # Explain the total of the per-node values, shaped (1, 1).
    target = torch.sum(data.y, dim=0).unsqueeze(0)
    explanation = explainer(data.x, data.edge_index, edge_attr=data.edge_attr, target=target)
    mask = explanation.edge_mask.detach().cpu()

    if kwargs.get('prod_p', False):
        mask = mask * data.edge_attr.view(-1).cpu()

    # Stable sort: tied mask values keep adjacency-list order, so the
    # ranking is reproducible (ties are common for gradient-based masks).
    _, order = torch.sort(mask, descending=True, stable=True)
    ranked_edges = edge[order.numpy()]

    elapsed = time.perf_counter() - start
    return ranked_edges[:del_edge_num], elapsed