In [1]:
import os
%load_ext autoreload
%autoreload 2
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Import Libraries

In [None]:

import pandas as pd
import torch
import torch_geometric as pyg

from torch_geometric.explain import ExplainerConfig, ModelConfig
from torch_geometric.explain.config import ModelMode 

from model_store import get_gnn, model_names
from data_store import get_nc_dataset, nc_datasets
from explainer_store import get_explainer, explainer_names
from explain_nc import evaluate_nc_explainer, evaluate_nc_explainer_on_data
from utils import setup_models, get_motif_nodes

In [None]:
# Prefer the GPU when CUDA is available; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Directories the dataset and model loaders read from.
dataset_path, model_path = 'datasets/', 'models/'
# Model selector handed to get_gnn ('all' presumably loads every architecture; confirm in model_store).
model_name = 'all'
# Zero-padded two-digit standard-deviation tags: '00', '01', ..., '10'.
stds = [f'{level:02d}' for level in range(11)]
# Tag used for every dataset/model load in this run.
std = stds[4]

In [6]:
# Instantiate one single-layer, single-hidden-channel PyG model per architecture.
# in_channels=-1 presumably defers input-size inference to the first forward pass; confirm in the PyG docs.
# NOTE(review): none of these four objects is referenced again in the visible cells,
# and `grapsage` looks like a typo for `graphsage`.
gcn = pyg.nn.GCN(in_channels=-1, hidden_channels=1, num_layers=1)
grapsage = pyg.nn.GraphSAGE(in_channels=-1, hidden_channels=1, num_layers=1)
gat = pyg.nn.GAT(in_channels=-1, hidden_channels=1, num_layers=1)
gin = pyg.nn.GIN(in_channels=-1, hidden_channels=1, num_layers=1)

# Explain Node Classification Predictions

In [7]:
# Metrics requested from the evaluation helpers; the same names are used as the
# metric columns of the table built by evaluation_df.
metric_names = ['accuracy', 'precision', 'recall', 'iou', 'fid+', 'fid-', 'unfaithfulness', 'characterization_score',
                'inference_time']
# Explainer settings shared by every explainer built in this notebook.
explainer_config = ExplainerConfig(
    explanation_type='phenomenon',
    node_mask_type='object',
    edge_mask_type='object',
)

In [43]:
def evaluation_df(eval_data, dataset_names, metric_names):
    """Flatten nested evaluation results into one tidy DataFrame.

    Args:
        eval_data: Mapping keyed by ``(explainer, model)`` tuples. Each value is a
            mapping keyed by ``(dataset, metric)`` tuples holding the metric value.
            A model key of the form ``'<name>-<std>'`` (exactly one hyphen) is split
            into the ``model`` and ``std`` columns; any other form is stored whole
            in ``model`` with ``std`` set to ``None``.
        dataset_names: Dataset names to emit one row for, per ``(explainer, model)``.
        metric_names: Metric names to emit as columns, in this order.

    Returns:
        DataFrame with columns ``explainer, dataset, model, std`` followed by the
        metrics. A metric missing from ``eval_data`` is stored as ``None``. When
        there is nothing to tabulate, an empty frame with those columns is returned.
    """
    # Accept any sequence (list, tuple, ...) of metric names.
    metric_names = list(metric_names)
    column_order = ['explainer', 'dataset', 'model', 'std'] + metric_names

    rows = []
    for (explainer, model), dataset_metrics in eval_data.items():
        # The split depends only on the model key, so do it once per model
        # instead of once per dataset.
        model_parts = model.split('-')
        if len(model_parts) == 2:
            model_id, std_tag = model_parts  # e.g. 'gcn', '04'
        else:
            model_id, std_tag = model, None

        for dataset in dataset_names:
            row = {
                'explainer': explainer,
                'dataset': dataset,
                'model': model_id,
                'std': std_tag,
            }
            for metric in metric_names:
                row[metric] = dataset_metrics.get((dataset, metric), None)
            rows.append(row)

    # Passing `columns` fixes the column order and also keeps the columns when
    # `rows` is empty; selecting them afterwards raised KeyError in that case.
    return pd.DataFrame(rows, columns=column_order)

# Load Node Classification Datasets

In [12]:
# Load every node-classification dataset for the selected std tag.
# new=False presumably reuses the files already saved under dataset_path; confirm in data_store.
datasets = get_nc_dataset(dataset_path, 'all', std_str=std, new=False)
for idx, (dataset_name, data) in enumerate(datasets):
    # Overwrite each (name, data) entry with its data moved to `device`.
    datasets[idx] = (dataset_name, data.to(device))

$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Loading ba_shapes dataset with 04 standard deviation from datasets/ba_shapes/ba_shapes04.pth
Number of PyTorch Geometric Data object (undirected) edges: 4110
Used feature matrix shape: torch.Size([700, 2])
Average node degree: 5.87
Number of ground truth edges: 480
Node mask shape: torch.Size([700])
Edge mask shape: torch.Size([4110])
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
Loading tree_grid dataset with 04 standard deviation from datasets/tree_grid/tree_grid04.pth
Number of PyTorch Geometric Data object (undirected) edges: 3130
Used feature matrix shape: torch.Size([1231, 2])
Average node degree: 2.54
Number of ground truth edges: 960
Node mask shape: torch.Size([1231])
Edge mask shape: torch.Size([3130])


# Load Node Classification Models

In [27]:
# Task identifier passed to get_gnn ('nc' presumably stands for node classification, matching the section title).
task = 'nc'
# Filled by the per-dataset cells below; later indexed with the same position as `datasets`.
dataset_models = []

## BA-Shapes

In [28]:
# Load the BA-Shapes models for the selected std and register them.
# setup_models presumably prepares them on `device`; confirm in utils.
# NOTE(review): append is not idempotent; re-running this cell without
# resetting `dataset_models` adds a duplicate entry.
bas_models = get_gnn(model_path, task, model_name, 'ba_shapes', std=std)
bas_models = setup_models(bas_models, device)
dataset_models.append(bas_models)

## Tree-Grid

In [29]:
# Load the Tree-Grid models for the selected std and register them.
# setup_models presumably prepares them on `device`; confirm in utils.
# NOTE(review): append is not idempotent; re-running this cell without
# resetting `dataset_models` adds a duplicate entry.
trg_models = get_gnn(model_path, task, model_name, 'tree_grid', std=std)
trg_models = setup_models(trg_models, device)
dataset_models.append(trg_models)

In [17]:
# Model description handed to get_explainer: node-level multiclass classification
# with raw model outputs.
# NOTE: `mode` is overwritten in a later cell when a tree dataset is selected.
model_config = ModelConfig(
    mode='multiclass_classification',
    task_level='node',
    return_type='raw'
)

In [30]:
# Select which (dataset, models) pair the single-node demos below operate on.
did = 0
models = dataset_models[did]
dataset_name, data = datasets[did]

# Nodes with a non-zero label are treated as motif nodes. The first and last such
# index, plus a per-dataset step, are handed to get_motif_nodes and
# evaluate_nc_explainer_on_data in later cells.
motif_nodes = (data.y > 0).nonzero().view(-1)
start = motif_nodes[0].item()
end = motif_nodes[-1].item()
step = 5 if dataset_name == 'ba_shapes' else 9  # presumably the motif size; confirm in utils

# Only motif nodes that also belong to the test split are explained.
motif_nodes_mask = (data.y > 0) & data.test_mask
test_nodes = motif_nodes_mask.nonzero().view(-1)

# Assign the mode in both directions. Previously it was only switched to binary
# for tree datasets and never switched back, so changing `did` from a tree
# dataset to ba_shapes left `model_config` in binary mode.
model_config.mode = (
    ModelMode.binary_classification if 'tree' in dataset_name
    else ModelMode.multiclass_classification
)
dataset_name, data, test_nodes.size(0)

('ba_shapes',
 Data(edge_index=[2, 4110], y=[700], train_mask=[700], val_mask=[700], test_mask=[700], gt_edges=[2, 480], edge_mask=[4110], node_mask=[700], num_classes=4, num_nodes=700, x=[700, 2]),
 42)

In [19]:
# Inspect what get_motif_nodes returns for the first test node.
get_motif_nodes(start, end, step, test_nodes[0].item())

tensor([300, 301, 302, 303, 304])

In [41]:
# Take the first (name, model) pair of the selected dataset and display it.
# NOTE(review): this rebinds the global `model_name`, which the config cell set to
# 'all' and which the get_gnn cells read; re-running those cells after this one
# would pass this model's name instead.
model_name, model = models[0]
model_name, model

('gcn',
 NC_GNN(
   (model): GCN(-1, 4, num_layers=3)
   (criterion): CrossEntropyLoss()
 ))

In [32]:
# Accuracy of the selected model on the test motif nodes.
model.eval()
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    if model_config.mode == ModelMode.multiclass_classification:
        # Multiclass: predicted label is the most probable class.
        pred = torch.softmax(logits, dim=-1).argmax(dim=-1)
    else:
        # Otherwise: threshold the sigmoid of the single output at 0.5.
        pred = (logits.sigmoid().view(-1) > 0.5).long()
    num_correct = pred[test_nodes].eq(data.y[test_nodes]).sum().item()
    acc = num_correct / test_nodes.size(0)
print(acc)

0.8333333333333334


# Explainers

In [33]:
# Position within `test_nodes` of the node explained in the single-node demos below.
test_node_idx = 0

## Random

In [36]:
# Build the first explainer listed in explainer_names for the selected model and
# explain one test node, passing the labels in data.y as the target.
explainer_name = explainer_names[0]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
explanation = explainer(data.x, data.edge_index, target=data.y, index=test_nodes[test_node_idx].item())
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[700], index=[1], x=[700, 2], edge_index=[2, 4110])

In [40]:
# Evaluate the current explainer on the selected dataset/model over `test_nodes`; returns one value per metric name.
evaluate_nc_explainer_on_data(explainer, data, test_nodes, metric_names, start, end, step)

{'accuracy': 0.5129643621898833,
 'precision': 0.02857143,
 'recall': 0.02857143,
 'iou': 0.016534392,
 'fid+': 0.30952380952380953,
 'fid-': 0.3333333333333333,
 'unfaithfulness': 0.4449986834522514,
 'characterization_score': 0.4227642276422764,
 'inference_time': 0.0005607321148826962}

In [42]:
# Evaluate the random explainer across the loaded datasets and tabulate the result.
# NOTE(review): the explainer name is hard-coded here, whereas the matching cells
# for the other explainers pass `explainer_name`.
# NOTE(review): the table rows come from `nc_datasets`, not from the names in
# `datasets`; this assumes both list the same datasets.
eval_metrics = evaluate_nc_explainer(model_path, 'random_explainer', explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, nc_datasets, metric_names)
eval_metrics_df

-- Evaluating random_explainer explainer on node classification datasets...
--- Evaluating random_explainer explainer on ba_shapes dataset...
----- Evaluating random_explainer explainer on gcn-04 model...
------- Evaluation on gcn-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on graphsage-04 model...
------- Evaluation on graphsage-04 model took 0.01 minutes.
----- Evaluating random_explainer explainer on gat-04 model...
------- Evaluation on gat-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on gin-04 model...
------- Evaluation on gin-04 model took 0.01 minutes.
------ Evaluation on ba_shapes dataset took 0.08 minutes.
--- Evaluating random_explainer explainer on tree_grid dataset...
----- Evaluating random_explainer explainer on gcn-04 model...
------- Evaluation on gcn-04 model took 0.03 minutes.
----- Evaluating random_explainer explainer on graphsage-04 model...


  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


------- Evaluation on graphsage-04 model took 0.02 minutes.
----- Evaluating random_explainer explainer on gat-04 model...
------- Evaluation on gat-04 model took 0.03 minutes.
----- Evaluating random_explainer explainer on gin-04 model...
------- Evaluation on gin-04 model took 0.02 minutes.
------ Evaluation on tree_grid dataset took 0.11 minutes.
--- Evaluation on node classification took 0.18 minutes.


Unnamed: 0,explainer,dataset,model,std,accuracy,precision,recall,iou,fid+,fid-,unfaithfulness,characterization_score,inference_time
0,random_explainer,ba_shapes,gcn,4,0.528513,0.009524,0.009524,0.005291,0.428571,0.47619,0.3466,0.471429,0.001017
1,random_explainer,tree_grid,gcn,4,0.505284,0.004566,0.004566,0.002417,0.027397,0.09589,0.0,0.053183,0.000556
2,random_explainer,ba_shapes,graphsage,4,0.48985,0.0,0.0,0.0,0.142857,0.238095,0.299182,0.240602,0.000491
3,random_explainer,tree_grid,graphsage,4,0.487908,0.003044,0.003044,0.001612,0.0,0.09589,0.0,0.0,0.000289
4,random_explainer,ba_shapes,gat,4,0.512814,0.014286,0.014286,0.007937,0.380952,0.380952,0.248987,0.471655,0.000467
5,random_explainer,tree_grid,gat,4,0.508708,0.009132,0.009132,0.004835,0.082192,0.328767,0.0,0.146451,0.000272
6,random_explainer,ba_shapes,gin,4,0.498207,0.0,0.0,0.0,0.309524,0.595238,0.541563,0.350794,0.000545
7,random_explainer,tree_grid,gin,4,0.488495,0.00761,0.00761,0.004029,0.0,0.0,0.0,0.0,0.000579


## GNNExplainer

In [44]:
# Build the second explainer listed in explainer_names for the selected model and
# explain one test node, passing the labels in data.y as the target.
explainer_name = explainer_names[1]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
explanation = explainer(data.x, data.edge_index, target=data.y, index=test_nodes[test_node_idx].item())
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[700], index=[1], x=[700, 2], edge_index=[2, 4110])

In [46]:
# Evaluate the current explainer on the selected dataset/model over `test_nodes`; returns one value per metric name.
evaluate_nc_explainer_on_data(explainer, data, test_nodes, metric_names, start, end, step)

{'accuracy': 0.7070877254009247,
 'precision': 0.76,
 'recall': 0.76,
 'iou': 0.65000004,
 'fid+': 0.4,
 'fid-': 0.2,
 'unfaithfulness': 0.1726197600364685,
 'characterization_score': 0.5333333333333333,
 'inference_time': 3.3128846168518065}

In [None]:
# Evaluate the explainer named by `explainer_name` across the loaded datasets and tabulate the result.
# NOTE(review): depends on `explainer_name` still holding the value set in this
# section's first cell; the table rows come from `nc_datasets`, not from `datasets`.
eval_metrics = evaluate_nc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, nc_datasets, metric_names)
eval_metrics_df

## PGExplainer

In [47]:
# Build the third explainer listed in explainer_names; it additionally receives the dataset.
# The recorded output shows that construction includes a training phase.
explainer_name = explainer_names[2]
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=data)

PGExplainer took 2.12 minutes to train. Best loss: 1.4407


In [48]:
# Explain one test node with the explainer built above, passing the labels in data.y as the target.
explanation = explainer(data.x, data.edge_index, target=data.y, index=test_nodes[test_node_idx].item())
explanation

Explanation(edge_mask=[4110], target=[700], index=[1], x=[700, 2], edge_index=[2, 4110])

In [49]:
# Evaluate the current explainer on the selected dataset/model over `test_nodes`; returns one value per metric name.
evaluate_nc_explainer_on_data(explainer, data, test_nodes, metric_names, start, end, step)

{'accuracy': 0.40384192587364287,
 'precision': 0.13809524,
 'recall': 0.13809524,
 'iou': 0.11139456,
 'fid+': 0.5952380952380952,
 'fid-': 0.3333333333333333,
 'unfaithfulness': 0.6410738031956411,
 'characterization_score': 0.628930817610063,
 'inference_time': 0.00857302120753697}

In [50]:
# Display the explainer names provided by explainer_store.
explainer_names

['random_explainer', 'gnnexplainer', 'pgexplainer', 'subgraphx', 'ciexplainer']

In [None]:
# Evaluate the explainer named by `explainer_name` across the loaded datasets and tabulate the result.
# NOTE(review): depends on `explainer_name` still holding the value set in this
# section's first cell; the table rows come from `nc_datasets`, not from `datasets`.
eval_metrics = evaluate_nc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, nc_datasets, metric_names)
eval_metrics_df

## SubgraphX

In [52]:
# Build the fourth explainer listed in explainer_names for the selected model and
# explain one test node, passing the labels in data.y as the target.
explainer_name = explainer_names[3]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
explanation = explainer(data.x, data.edge_index, target=data.y, index=test_nodes[test_node_idx].item())
explanation

The nodes in graph is Data(x=[6, 2], edge_index=[2, 14])
At the 0 rollout, 6states that have been explored.
At the 1 rollout, 6states that have been explored.
At the 2 rollout, 6states that have been explored.
At the 3 rollout, 6states that have been explored.
At the 4 rollout, 6states that have been explored.
At the 5 rollout, 6states that have been explored.
At the 6 rollout, 6states that have been explored.
At the 7 rollout, 6states that have been explored.
At the 8 rollout, 6states that have been explored.
At the 9 rollout, 6states that have been explored.


Explanation(
  x=[700, 2],
  edge_index=[2, 4110],
  node_mask=[700, 1],
  edge_mask=[4110],
  results=[6],
  subselt=[6],
  related_pred={
    masked=0.9828965067863464,
    maskout=0.02969662845134735,
    origin=0.8863950371742249,
    sparsity=0.16666666666666663,
  },
  masked_node_list=[5],
  explained_edge_list=[2, 8],
  target=[700],
  index=[1]
)

In [53]:
# Evaluate the current explainer on the selected dataset/model over `test_nodes`; returns one value per metric name.
evaluate_nc_explainer_on_data(explainer, data, test_nodes, metric_names, start, end, step)

The nodes in graph is Data(x=[6, 2], edge_index=[2, 14])
At the 0 rollout, 6states that have been explored.
At the 1 rollout, 6states that have been explored.
At the 2 rollout, 6states that have been explored.
At the 3 rollout, 6states that have been explored.
At the 4 rollout, 6states that have been explored.
At the 5 rollout, 6states that have been explored.
At the 6 rollout, 6states that have been explored.
At the 7 rollout, 6states that have been explored.
At the 8 rollout, 6states that have been explored.
At the 9 rollout, 6states that have been explored.


{'accuracy': 0.6666666865348816,
 'precision': 0.8,
 'recall': 0.8,
 'iou': 0.6666667,
 'fid+': 1.0,
 'fid-': 0.0,
 'unfaithfulness': 0.0445246696472168,
 'characterization_score': 1.0,
 'inference_time': 0.0999295711517334}

In [None]:
# Evaluate the explainer named by `explainer_name` across the loaded datasets and tabulate the result.
# NOTE(review): depends on `explainer_name` still holding the value set in this
# section's first cell; the table rows come from `nc_datasets`, not from `datasets`.
eval_metrics = evaluate_nc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, nc_datasets, metric_names)
eval_metrics_df

## CIExplainer

In [56]:
# Build the fifth explainer listed in explainer_names; it additionally receives the dataset and its name.
explainer_name = explainer_names[4]
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=data, dataset_name=dataset_name)
test_node_idx = 0
index = test_nodes[test_node_idx].item()
# Here the target is the model's highest class probability per node, not the labels in data.y.
# NOTE(review): this forward pass is not wrapped in torch.no_grad(), so `target`
# may keep autograd history; confirm that is intended.
target = model(data.x, data.edge_index).softmax(dim=-1).max(dim=-1)[0]
explanation = explainer(data.x, data.edge_index, target=target, index=index)
explanation

Explanation(node_mask=[700, 1], edge_mask=[4110], target=[700], index=[1], x=[700, 2], edge_index=[2, 4110])

In [57]:
# Per-node target probability: the model's highest class probability at each test
# node, zero everywhere else.
# The forward pass does not depend on the node, so it is run once instead of once
# per test node. It runs under no_grad, like the accuracy cell, because these
# values are presumably used only as fixed targets; confirm in explain_nc.
# Assumes the model is deterministic at this point (e.g. in eval mode).
with torch.no_grad():
    class_probs = torch.softmax(model(data.x, data.edge_index), dim=-1)
y_p = torch.zeros_like(data.y, dtype=torch.double)  # zeros_like keeps data.y's device
# Indexed assignment needs matching dtypes, hence the explicit cast.
y_p[test_nodes] = class_probs[test_nodes].max(dim=-1).values.to(y_p.dtype)
# Attached to the Data object; use_prob=True presumably makes the evaluation read it (confirm in explain_nc).
data.y_p = y_p
evaluate_nc_explainer_on_data(explainer, data, test_nodes, metric_names, start, end, step, use_prob=True, threshold=0.0)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 1.0,
 'precision': 0.8428572,
 'recall': 0.8428572,
 'iou': 0.7764551,
 'fid+': 0.0,
 'fid-': 0.0,
 'unfaithfulness': 0.24025894346691312,
 'characterization_score': 0.0,
 'inference_time': 0.1358734198978969}

In [None]:
# Evaluate the explainer named by `explainer_name` across the loaded datasets and tabulate the result.
# NOTE(review): the single-dataset evaluation for this explainer passes
# use_prob=True and threshold=0.0, but this call does not; confirm that
# evaluate_nc_explainer applies equivalent settings for this explainer.
eval_metrics = evaluate_nc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names,
                                     std=std)
eval_metrics_df = evaluation_df(eval_metrics, nc_datasets, metric_names)
eval_metrics_df