In [1]:
import os
# Reload edited project modules automatically so changes to the local
# *_store / utils modules are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2
# Debug aid (disabled): uncomment to set CUDA_LAUNCH_BLOCKING before any CUDA work.
#os.environ['CUDA_LAUNCH_BLOCKING'] = '1'

# Import Libraries

In [2]:
import pandas as pd
import torch
import torch_geometric as pyg

from torch_geometric.explain import ExplainerConfig, ModelConfig
from torch_geometric.explain.config import ModelMode 

from model_store import get_gnn, model_names
from data_store import get_gc_dataset, gc_datasets
from explainer_store import get_explainer, explainer_names
from explain_gc import evaluate_gc_explainer, evaluate_gc_explainer_on_data
from utils import setup_models, get_motif_nodes

In [3]:
# Everything in this notebook runs on CPU; the automatic CUDA selection is
# kept in the trailing comment for reference.
device = 'cpu'#torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [4]:
# Root folders for the serialized datasets and the trained model checkpoints.
dataset_path = 'datasets/'
model_path = 'models/'

# Model selector forwarded to get_gnn ('all' is the value used for loading).
model_name = 'all'

# Zero-padded two-digit std suffixes '00' .. '10'; `std` picks the one used
# for dataset and model loading in this run.
stds = [f'{level:02d}' for level in range(11)]
std = stds[4]

In [5]:
# Instantiate one minimal (1 layer, 1 hidden channel) PyG model per architecture.
# in_channels=-1 leaves the input size unspecified at construction time.
# NOTE(review): none of these four objects is referenced again in the visible
# cells, and `grapsage` looks like a typo for "graphsage" — confirm before renaming.
gcn = pyg.nn.GCN(in_channels=-1, hidden_channels=1, num_layers=1)
grapsage = pyg.nn.GraphSAGE(in_channels=-1, hidden_channels=1, num_layers=1)
gat = pyg.nn.GAT(in_channels=-1, hidden_channels=1, num_layers=1)
gin = pyg.nn.GIN(in_channels=-1, hidden_channels=1, num_layers=1)

# Explain Graph Classification Predictions

In [6]:
# Metrics requested from the evaluation helpers; the same names become the
# metric columns of the table built by evaluation_df.
metric_names = ['accuracy', 'precision', 'recall', 'iou', 'fid+', 'fid-', 'unfaithfulness', 'characterization_score',
                'inference_time']
# Explainer settings shared by every explainer below: explanations are
# requested for a supplied target ('phenomenon'), with 'object'-type node and
# edge masks (the outputs below show node_mask=[N, 1] and edge_mask=[E]).
explainer_config = ExplainerConfig(
    explanation_type='phenomenon',
    node_mask_type='object',
    edge_mask_type='object',
)

In [7]:
def evaluation_df(eval_data, dataset_names, metric_names):
    """Flatten nested explainer-evaluation results into one tidy DataFrame.

    Args:
        eval_data: Mapping keyed by ``(explainer, model)`` tuples. Each value
            is a mapping keyed by ``(dataset, metric)`` tuples that holds the
            metric value. A model key containing exactly one hyphen
            (e.g. ``'gcn-04'``) is split into a model name and an std suffix.
        dataset_names: Dataset names to report; one row is produced for every
            ``(explainer, model)`` entry and every dataset name.
        metric_names: Metric names to extract, in the desired column order.

    Returns:
        pandas.DataFrame with the columns ``explainer``, ``dataset``,
        ``model``, ``std`` followed by one column per metric. Metrics missing
        from ``eval_data`` are filled with ``None``. When there is nothing to
        report (empty ``eval_data`` or ``dataset_names``) an empty frame with
        the same columns is returned instead of raising ``KeyError``.
    """
    column_order = ['explainer', 'dataset', 'model', 'std'] + list(metric_names)

    rows = []
    for (explainer, model), dataset_metrics in eval_data.items():
        # Split 'name-std' once per model; any other shape (no hyphen, or
        # several) keeps the full key as the model name with no std.
        model_parts = model.split('-')
        if len(model_parts) == 2:
            model_label, std_label = model_parts
        else:
            model_label, std_label = model, None

        for dataset in dataset_names:
            row = {
                'explainer': explainer,
                'dataset': dataset,
                'model': model_label,
                'std': std_label,
            }
            for metric in metric_names:
                row[metric] = dataset_metrics.get((dataset, metric), None)
            rows.append(row)

    if not rows:
        # A frame built from no rows has no columns, so selecting
        # column_order from it would fail; build the empty schema directly.
        return pd.DataFrame(columns=column_order)

    # Enforce the identifying columns first, then the metrics in caller order.
    return pd.DataFrame(rows)[column_order]

# Load Graph Classification Datasets

In [8]:
# Load every graph-classification dataset for the selected std, attach the
# reference edge mask where this notebook defines one, and move each graph
# onto `device`.
datasets = get_gc_dataset(dataset_path, 'all', explain=True, std_str=std)
for i, (dataset_name, data_list, num_classes) in enumerate(datasets):
    for j, data in enumerate(data_list):
        if 'ba' in dataset_name:
            # Reference explanation mask: edges whose two endpoints both have
            # node index >= 20.
            # NOTE(review): presumably nodes 0-19 form the base graph and the
            # motif occupies the remaining indices — confirm against the dataset.
            data.edge_mask = torch.logical_and(data.edge_index[0] >= 20, data.edge_index[1] >= 20)
        # Store the moved graph at its own position j. Indexing with the
        # dataset counter i would overwrite a single slot with every graph in
        # turn and leave the rest of the list untouched.
        data_list[j] = data.to(device)
    datasets[i] = (dataset_name, data_list, num_classes)

Loading ba_2motif dataset with 04 standard deviation from datasets/ba_2motif/ba_2motif04.pth
Number of mutagen graphs with NO2 and NH2 1015
$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$$
mutag pickle file information
Adjacency matrix shape: (1015, 418, 418)
Features matrix shape: (1015, 418, 14)
Labels shape: (1015, 2)
Number of graphs: 1015
Number of classes: 2
Number of nodes: 418
Number of edges: 454560


# Load Graph Classification Models

In [9]:
# Task key passed to get_gnn when loading models.
task = 'gc'
# Per-dataset model collections; filled by the cells below and later indexed
# with the same position as `datasets`.
dataset_models = []

### BA-2motif

In [10]:
# Load the BA-2motif models for the selected std, prepare them on `device`
# via setup_models, and register them as the first entry of dataset_models.
ba2_models = get_gnn(model_path, task, model_name, 'ba_2motif', std=std)
ba2_models = setup_models(ba2_models, device)
dataset_models.append(ba2_models)

### MUTAG

In [11]:
# Load the MUTAG models, prepare them on `device`, and register them.
# NOTE(review): unlike the BA-2motif cell, no std is passed here — presumably
# MUTAG has no per-std variants; confirm against get_gnn.
mu_models = get_gnn(model_path, task, model_name, 'mutag')
mu_models = setup_models(mu_models, device)
dataset_models.append(mu_models)

In [34]:
# Describe the model to the explainers: a graph-level binary classifier whose
# forward pass returns raw (un-squashed) scores — the accuracy cell applies
# .sigmoid() to the model output itself.
model_config = ModelConfig(
    mode='binary_classification',
    task_level='graph',
    return_type='raw'
)

In [25]:
# Index of the dataset to work with; it selects matching entries from both
# dataset_models and datasets (index 1 displays as 'mutag' in the output below).
did = 1
models = dataset_models[did]
dataset_name, data_list, num_classes = datasets[did]
dataset_name, data_list, num_classes

('mutag',
 [Data(x=[28, 14], edge_index=[2, 58], y=[1], edge_mask=[58], num_classes=2, num_nodes=28),
  Data(x=[26, 14], edge_index=[2, 56], y=[1], edge_mask=[56], num_classes=2, num_nodes=26),
  Data(x=[33, 14], edge_index=[2, 70], y=[1], edge_mask=[70], num_classes=2, num_nodes=33),
  Data(x=[26, 14], edge_index=[2, 56], y=[1], edge_mask=[56], num_classes=2, num_nodes=26),
  Data(x=[20, 14], edge_index=[2, 42], y=[1], edge_mask=[42], num_classes=2, num_nodes=20),
  Data(x=[29, 14], edge_index=[2, 60], y=[1], edge_mask=[60], num_classes=2, num_nodes=29),
  Data(x=[35, 14], edge_index=[2, 76], y=[1], edge_mask=[76], num_classes=2, num_nodes=35),
  Data(x=[18, 14], edge_index=[2, 36], y=[1], edge_mask=[36], num_classes=2, num_nodes=18),
  Data(x=[27, 14], edge_index=[2, 56], y=[1], edge_mask=[56], num_classes=2, num_nodes=27),
  Data(x=[30, 14], edge_index=[2, 64], y=[1], edge_mask=[64], num_classes=2, num_nodes=30),
  Data(x=[17, 14], edge_index=[2, 34], y=[1], edge_mask=[34], num_clas

In [26]:
# Keep only the tail of data_list as the test split: the leading 80% and the
# next 10% of graphs (each truncated to a whole count) are skipped.
n_graphs = len(data_list)
train_size = int(0.8 * n_graphs)
val_size = int(0.1 * n_graphs)
test_data_list = data_list[train_size + val_size:]
len(test_data_list)

102

In [27]:
# Pick the first (name, model) pair of the selected dataset.
# NOTE(review): this rebinds `model_name`, which the configuration cell set to
# 'all' and the model-loading cells pass to get_gnn — re-running those cells
# after this one would use the new value.
model_name, model = models[0]
model_name, model

('gcn',
 GC_GNN(
   (model): GCN(-1, 20, num_layers=3)
   (criterion): BCEWithLogitsLoss()
   (lin): Linear(in_features=20, out_features=1, bias=True)
 ))

In [28]:
# evaluate model on test_data_list
acc = 0
for data in test_data_list:
    pred = (model(data.x, data.edge_index).sigmoid() > 0.5).long()
    acc += pred.eq(data.y).sum().item()
acc / len(test_data_list)

0.7450980392156863

# Explainers

In [56]:
# Position (within test_data_list) of the single graph used for the
# per-explainer demo explanations below.
test_graph_idx = 24
data = test_data_list[test_graph_idx]
data

Data(x=[70, 14], edge_index=[2, 146], y=[1], edge_mask=[146], num_classes=2, num_nodes=70)

## Random

In [57]:
# First registered explainer — presumably the random baseline, per the
# section heading.
explainer_name = explainer_names[0]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
# Explain the selected test graph, using its label data.y as the target.
explanation = explainer(data.x, data.edge_index, target=data.y)
explanation

Explanation(node_mask=[70, 1], edge_mask=[146], target=[1], x=[70, 14], edge_index=[2, 146])

In [48]:
# Score the current explainer on test_data_list for every metric in metric_names.
evaluate_gc_explainer_on_data(explainer, test_data_list, metric_names)

{'accuracy': 0.477629878357345,
 'precision': 0.14106755,
 'recall': 0.14106755,
 'iou': 0.086317845,
 'fid+': 0.37254901960784315,
 'fid-': 0.2647058823529412,
 'unfaithfulness': 0.0,
 'characterization_score': 0.4945340968245706,
 'inference_time': 0.00023971118179022096}

In [None]:
# Full benchmark for the current explainer over all loaded datasets, then
# flatten the nested result into a table with evaluation_df.
eval_metrics = evaluate_gc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names, std=std)
eval_metrics_df = evaluation_df(eval_metrics, gc_datasets, metric_names)
eval_metrics_df

## GNNExplainer

In [61]:
# Second registered explainer — presumably GNNExplainer, per the section heading.
explainer_name = explainer_names[1]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
# NOTE(review): the target is unsqueezed to shape [1, 1] here, whereas the
# other explainer cells pass data.y directly — confirm this is required.
explanation = explainer(data.x, data.edge_index, target=data.y.unsqueeze(0))
explanation

Explanation(node_mask=[70, 1], edge_mask=[146], target=[1, 1], x=[70, 14], edge_index=[2, 146])

In [None]:
# Score the current explainer on test_data_list for every metric in metric_names.
evaluate_gc_explainer_on_data(explainer, test_data_list, metric_names)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 1.0,
 'precision': 1.0,
 'recall': 1.0,
 'iou': 1.0,
 'fid+': 0.0,
 'fid-': 0.0,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 1.0242867469787598}

In [None]:
# Full benchmark for the current explainer over all loaded datasets, then
# flatten the nested result into a table with evaluation_df.
eval_metrics = evaluate_gc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names, std=std)
eval_metrics_df = evaluation_df(eval_metrics, gc_datasets, metric_names)
eval_metrics_df

## PGExplainer

In [None]:
# Third registered explainer — presumably PGExplainer, per the section heading.
explainer_name = explainer_names[2]
# The dataset argument is used to train the explainer (see the training log
# printed below).
# NOTE(review): it is trained on test_data_list, the same graphs it is
# evaluated on afterwards — confirm this is intended.
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=test_data_list)

PGExplainer took 1.07 minutes to train. Best loss: 0.4581


In [64]:
# Explain the selected test graph with the current explainer; the result
# shown below carries an edge mask only (no node_mask).
explanation = explainer(data.x, data.edge_index, target=data.y)
explanation

Explanation(edge_mask=[146], target=[1], x=[70, 14], edge_index=[2, 146])

In [None]:
# Score the current explainer on test_data_list for every metric in metric_names.
evaluate_gc_explainer_on_data(explainer, test_data_list, metric_names)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 0.6535714268684387,
 'precision': 0.0,
 'recall': 0.0,
 'iou': 0.0,
 'fid+': 0.0,
 'fid-': 0.5,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 0.003965973854064941}

In [None]:
# Full benchmark for the current explainer over all loaded datasets, then
# flatten the nested result into a table with evaluation_df.
eval_metrics = evaluate_gc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names, std=std)
eval_metrics_df = evaluation_df(eval_metrics, gc_datasets, metric_names)
eval_metrics_df

## SubgraphX

In [66]:
# Fourth registered explainer — presumably SubgraphX, per the section heading.
explainer_name = explainer_names[3]
explainer = get_explainer(explainer_name, explainer_config, model, model_config)
# Explain the selected test graph; this explainer prints its rollout progress
# to the cell output.
explanation = explainer(data.x, data.edge_index, target=data.y)
explanation

The nodes in graph is Data(x=[70, 14], edge_index=[2, 146])
At the 0 rollout, 311states that have been explored.
At the 1 rollout, 311states that have been explored.
At the 2 rollout, 311states that have been explored.
At the 3 rollout, 311states that have been explored.
At the 4 rollout, 311states that have been explored.
At the 5 rollout, 311states that have been explored.
At the 6 rollout, 311states that have been explored.
At the 7 rollout, 311states that have been explored.
At the 8 rollout, 311states that have been explored.
At the 9 rollout, 311states that have been explored.


Explanation(
  x=[70, 14],
  edge_index=[2, 146],
  node_mask=[70, 1],
  edge_mask=[146],
  results=[311],
  related_pred={
    masked=1.0,
    maskout=1.0,
    origin=1.0,
    sparsity=0.9285714285714286,
  },
  masked_node_list=[5],
  explained_edge_list=[2, 8],
  target=[1]
)

In [None]:
# Score the current explainer on test_data_list for every metric in
# metric_names (prints rollout progress for every graph).
evaluate_gc_explainer_on_data(explainer, test_data_list, metric_names)

The nodes in graph is Data(x=[20, 14], edge_index=[2, 42])
At the 0 rollout, 89states that have been explored.
At the 1 rollout, 89states that have been explored.
At the 2 rollout, 89states that have been explored.
At the 3 rollout, 89states that have been explored.
At the 4 rollout, 89states that have been explored.
At the 5 rollout, 89states that have been explored.
At the 6 rollout, 89states that have been explored.
At the 7 rollout, 89states that have been explored.
At the 8 rollout, 89states that have been explored.
At the 9 rollout, 89states that have been explored.
The nodes in graph is Data(x=[35, 14], edge_index=[2, 78])
At the 0 rollout, 235states that have been explored.
At the 1 rollout, 235states that have been explored.
At the 2 rollout, 235states that have been explored.
At the 3 rollout, 235states that have been explored.
At the 4 rollout, 235states that have been explored.
At the 5 rollout, 235states that have been explored.
At the 6 rollout, 235states that have been e

{'accuracy': 0.6857143044471741,
 'precision': 0.0,
 'recall': 0.0,
 'iou': 0.0,
 'fid+': 0.0,
 'fid-': 0.0,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 1.3401110172271729}

In [None]:
# Full benchmark for the current explainer over all loaded datasets, then
# flatten the nested result into a table with evaluation_df.
eval_metrics = evaluate_gc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names, std=std)
eval_metrics_df = evaluation_df(eval_metrics, gc_datasets, metric_names)
eval_metrics_df

## CIExplainer

In [71]:
# Fifth registered explainer — presumably CIExplainer, per the section heading.
explainer_name = explainer_names[4]
# NOTE(review): test_data_list is handed to the explainer factory, so any
# fitting it does uses the same graphs it is evaluated on — confirm intent.
explainer = get_explainer(explainer_name, explainer_config, model, model_config, dataset=test_data_list)
# Here the target is the model's own sigmoid output flattened to 1-D, not the
# label data.y used by the other explainer cells.
target = torch.sigmoid(model(data.x, data.edge_index)).view(-1)
explanation = explainer(data.x, data.edge_index, target=target)
explanation

Explanation(node_mask=[70, 1], edge_mask=[146], target=[1], x=[70, 14], edge_index=[2, 146])

In [None]:
# Score the current explainer on test_data_list for every metric in metric_names.
evaluate_gc_explainer_on_data(explainer, test_data_list, metric_names)

  denom = (pos_weight / pos_fidelity) + (neg_weight / (1. - neg_fidelity))


{'accuracy': 0.8852207899093628,
 'precision': 0.0,
 'recall': 0.0,
 'iou': 0.0,
 'fid+': 0.0,
 'fid-': 0.4,
 'unfaithfulness': 0.0,
 'characterization_score': 0.0,
 'inference_time': 0.002267122268676758}

In [None]:
# Full benchmark for the current explainer over all loaded datasets, then
# flatten the nested result into a table with evaluation_df.
eval_metrics = evaluate_gc_explainer(model_path, explainer_name, explainer_config, datasets, metric_names, std=std)
eval_metrics_df = evaluation_df(eval_metrics, gc_datasets, metric_names)
eval_metrics_df