In [1]:
import os
%load_ext autoreload
%autoreload 2
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Import Libraries

In [2]:
import pandas as pd
import torch
import torch_geometric as pyg

from torch_geometric.explain import ExplainerConfig, ModelConfig
from torch_geometric.explain.config import ModelMode 

from model_store import get_gnn, model_names
from data_store import get_lp_dataset, lp_datasets
from explainer_store import get_explainer, explainer_names
from explain_lp import evaluate_lp_explainer, evaluate_lp_explainer_on_data
from utils import setup_models, get_motif_nodes

In [3]:
device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
dataset_path = 'datasets/'
model_path = 'models/'
model_name = 'all'
stds = ['00', '01', '02', '03', '04', '05', '06', '07', '08', '09', '10']
std = stds[4]

In [5]:
gcn = pyg.nn.GCN(in_channels=-1, hidden_channels=1, num_layers=1)
grapsage = pyg.nn.GraphSAGE(in_channels=-1, hidden_channels=1, num_layers=1)
gat = pyg.nn.GAT(in_channels=-1, hidden_channels=1, num_layers=1)
gin = pyg.nn.GIN(in_channels=-1, hidden_channels=1, num_layers=1)

# Explain Link Predictions

In [6]:
metric_names = ['accuracy', 'precision', 'recall', 'iou', 'fid+', 'fid-', 'unfaithfulness', 'characterization_score',
                'inference_time']
explainer_config = ExplainerConfig(
    explanation_type='phenomenon',
    node_mask_type='object',
    edge_mask_type='object',
)

In [7]:
def evaluation_df(eval_data, dataset_names, metric_names):
    # Flatten the nested dictionary into a list of rows
    rows = []
    for (explainer, model), dataset_metrics in eval_data.items():
        for dataset in dataset_names:
            row = {
                'explainer': explainer,
                'dataset': dataset
            }

            # Split the model name into 'model' (before the hyphen) and 'std' (after the hyphen)
            model_parts = model.split('-')
            if len(model_parts) == 2:
                row['model'] = model_parts[0]  # e.g., 'gcn'
                row['std'] = model_parts[1]  # e.g., '00' (store as 'std')
            else:
                row['model'] = model  # If no hyphen, keep the entire model as 'model'
                row['std'] = None  # No model ID, set 'std' as None

            # Add metric values to the row
            for metric in metric_names:
                row[metric] = dataset_metrics.get((dataset, metric), None)

            rows.append(row)

    # Create DataFrame from the list of rows
    dataframe = pd.DataFrame(rows)

    # Reorder columns: model and std columns should come before accuracy
    column_order = ['explainer', 'dataset', 'model', 'std'] + metric_names
    dataframe = dataframe[column_order]

    return dataframe

# Load Link Prediction Datasets

In [9]:
datasets = get_lp_dataset(dataset_path, 'all', std=std)
for idx, (dataset_name, train_data, val_data, test_data) in enumerate(datasets):
        datasets[idx] = (dataset_name, train_data.to(device), val_data.to(device), test_data.to(device))

$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
ba_shapes_link with 04 standard deviation pickle file information
Train data: Data(x=[700, 2], edge_index=[2, 4110], y=[700], edge_label=[3288], edge_label_index=[2, 3288], motif_edge_label=[3288], node_mask=[700])
Validation data: Data(x=[700, 2], edge_index=[2, 4110], y=[700], edge_label=[410], edge_label_index=[2, 410], motif_edge_label=[410], node_mask=[700])
Test data: Data(x=[700, 2], edge_index=[2, 4110], y=[700], edge_label=[410], edge_label_index=[2, 410], motif_edge_label=[410], node_mask=[700])
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
tree_grid_link with 04 standard deviation pickle file information
Train data: Data(x=[1231, 2], edge_index=[2, 3130], y=[1231], edge_label=[2504], edge_label_index=[2, 2504], motif_edge_label=[2504], node_mask=[1231])
Validation data: Data(x=[1231, 2], edge_index=[2, 3130], y=[1231], 

## Load Link Prediction Models

In [10]:
task = 'lp'
dataset_models = []

### BA-Shapes-Link

In [11]:
basl_models = get_gnn(model_path, task, model_name, 'ba_shapes_link', std=std)
basl_models = setup_models(basl_models, device)
dataset_models.append(basl_models)

### Tree-Grid-Link

In [12]:
trgl_models = get_gnn(model_path, task, model_name, 'tree_grid_link', std=std)
trgl_models = setup_models(trgl_models, device)
dataset_models.append(trgl_models)

In [13]:
model_config = ModelConfig(
    mode='binary_classification',
    task_level='edge',
    return_type='raw'
)

In [14]:
did = 0
data = datasets[did][-1]
models = dataset_models[did]
motif_nodes = (data.y > 0).nonzero().view(-1)
start = motif_nodes[0].item()
end = motif_nodes[-1].item()
step = 5 if dataset_name == 'ba_shapes_link' else 9

In [15]:
model_name, model = models[0]
model

LP_GNN(
  (model): GCN(-1, 20, num_layers=2)
  (criterion): BCEWithLogitsLoss()
  (decoder): InnerProductDecoder()
)

# Explainers

In [16]:
pos_edges = data.edge_label_index[:, (data.motif_edge_label.bool() & data.edge_label.bool())]
#pos_edges = data.edge_label_index[:, data.edge_label.bool()]
deg = pyg.utils.degree(data.edge_index[0], num_nodes=data.num_nodes)
degree_zero_nodes = torch.nonzero(deg == 0, as_tuple=True)[0]
source_nodes = pos_edges[0]
target_nodes = pos_edges[1]

mask = ~torch.isin(source_nodes, degree_zero_nodes) & ~torch.isin(target_nodes, degree_zero_nodes)

num_examples = 200
edge_label_indices = pos_edges[:, mask][:, :num_examples]

idx = 50
edge_label_index = edge_label_indices[:, idx].view(-1, 1)
target = torch.tensor([1])  #data.edge_label[idx].unsqueeze(dim=0).long()
edge_label_index, target

(tensor([[545],
         [549]]),
 tensor([1]))

## Random

In [17]:
explainer_name = explainer_names[0]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
explanation = explainer(data.x, data.edge_index, target=target, edge_label_index=edge_label_index)
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[1], x=[700, 2], edge_index=[2, 4110], edge_label_index=[2, 1])

In [18]:
evaluate_lp_explainer_on_data(explainer, data, edge_label_indices, metric_names, start, end, step)

{'accuracy': 1.0,
 'precision': 0.013182675,
 'recall': 0.013182675,
 'iou': 0.0069790627,
 'fid+': 0.06779661016949153,
 'fid-': 0.0847457627118644,
 'unfaithfulness': 0.0,
 'characterization_score': 0.12624196376388078,
 'inference_time': 0.00026919882176286084}

In [19]:
eval_metrics = evaluate_lp_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, lp_datasets, metric_names)
eval_metrics_df

-- Evaluating random_explainer explainer on link prediction datasets...
--- Evaluating random_explainer explainer on ba_shapes_link dataset...
----- Evaluating random_explainer explainer on gcn-04 model...
------- Evaluation on gcn-04 model took 0.01 minutes.
----- Evaluating random_explainer explainer on graphsage-04 model...


  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


------- Evaluation on graphsage-04 model took 0.01 minutes.
----- Evaluating random_explainer explainer on gat-04 model...
------- Evaluation on gat-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on gin-04 model...
------- Evaluation on gin-04 model took 0.01 minutes.
------ Evaluation on ba_shapes_link dataset took 0.04 minutes.
--- Evaluating random_explainer explainer on tree_grid_link dataset...
----- Evaluating random_explainer explainer on gcn-04 model...
------- Evaluation on gcn-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on graphsage-04 model...
------- Evaluation on graphsage-04 model took 0.01 minutes.
----- Evaluating random_explainer explainer on gat-04 model...
------- Evaluation on gat-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on gin-04 model...
------- Evaluation on gin-04 model took 0.01 minutes.
------ Evaluation on tree_grid_link dataset took 0.07 minutes.
--- Evaluation on node classi

Unnamed: 0,explainer,dataset,model,std,accuracy,precision,recall,iou,fid+,fid-,unfaithfulness,characterization_score,inference_time
0,random_explainer,ba_shapes_link,gcn,4,1.0,0.00339,0.00339,0.001883,0.084746,0.067797,0.0,0.155367,0.00051
1,random_explainer,tree_grid_link,gcn,4,1.0,0.002554,0.002554,0.001352,0.011494,0.011494,0.0,0.022724,0.000673
2,random_explainer,ba_shapes_link,graphsage,4,1.0,0.00339,0.00339,0.001883,0.0,0.0,0.0,0.0,0.000388
3,random_explainer,tree_grid_link,graphsage,4,1.0,0.002554,0.002554,0.001352,0.0,0.0,0.0,0.0,0.000579
4,random_explainer,ba_shapes_link,gat,4,1.0,0.013559,0.013559,0.007533,0.084746,0.084746,0.0,0.155128,0.000193
5,random_explainer,tree_grid_link,gat,4,1.0,0.007663,0.007663,0.004057,0.08046,0.045977,0.0,0.148404,0.000283
6,random_explainer,ba_shapes_link,gin,4,1.0,0.010169,0.010169,0.00565,0.067797,0.050847,0.0,0.126554,0.000191
7,random_explainer,tree_grid_link,gin,4,1.0,0.00894,0.00894,0.004733,0.0,0.0,0.0,0.0,0.000329


## GNNExplainer

In [20]:
explainer_name = explainer_names[1]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
explanation = explainer(data.x, data.edge_index, target=target, edge_label_index=edge_label_index)
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[1], x=[700, 2], edge_index=[2, 4110], edge_label_index=[2, 1])

In [21]:
evaluate_lp_explainer_on_data(explainer, data, edge_label_indices, metric_names, start, end, step)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 1.0,
 'precision': 0.2888889,
 'recall': 0.2888889,
 'iou': 0.1804525,
 'fid+': 0.0,
 'fid-': 0.0,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 1.8985695838928223}

In [None]:
eval_metrics = evaluate_lp_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, lp_datasets, metric_names)
eval_metrics_df

## PGExplainer

In [27]:
explainer_name = explainer_names[2]
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=data, edge_label_indices=edge_label_indices)

PGExplainer took 0.45 minutes to train. Best loss: 0.7040


In [28]:
explanation = explainer(data.x, data.edge_index, target=target, edge_label_index=edge_label_index)
explanation

Explanation(edge_mask=[4110], target=[1], x=[700, 2], edge_index=[2, 4110], edge_label_index=[2, 1])

In [None]:
evaluate_lp_explainer_on_data(explainer, data, edge_label_indices, metric_names, start, end, step)

{'accuracy': 1.0,
 'precision': 0.22222224,
 'recall': 0.22222224,
 'iou': 0.14192308,
 'fid+': 0.2,
 'fid-': 0.2,
 'unfaithfulness': 0.0,
 'characterization_score': 0.32,
 'inference_time': 0.007938146591186523}

In [None]:
eval_metrics = evaluate_lp_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, lp_datasets, metric_names)
eval_metrics_df

## CIExplainer

In [30]:
explainer_name = explainer_names[4]
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=data, dataset_name=dataset_name, edge_label_indices=edge_label_indices)
target = model(data.x, data.edge_index, edge_label_index=edge_label_index).sigmoid()
explanation = explainer(data.x, data.edge_index, target=target, edge_label_index=edge_label_index)
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[1], x=[700, 2], edge_index=[2, 4110], edge_label_index=[2, 1])

In [None]:
evaluate_lp_explainer_on_data(explainer, data, edge_label_indices, metric_names, start, end, step)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 1.0,
 'precision': 0.24444444,
 'recall': 0.24444444,
 'iou': 0.15368779,
 'fid+': 0.0,
 'fid-': 0.0,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 0.024094915390014647}

In [None]:
eval_metrics = evaluate_lp_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, lp_datasets, metric_names)
eval_metrics_df