In [None]:
"""
explain_benchmarks.py

Compute sparsity and fidelity benchmarks for explanations of GNN models
trained on molecular graph datasets (e.g. Tox21, HIV) with PyTorch Geometric.

Supported explainers: GNNExplainer (PyTorch Geometric) and Integrated
Gradients (Captum). Results are written as CSV files to the output directory.
"""

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.explain import GNNExplainer
from captum.attr import IntegratedGradients
from torch_geometric.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import os

In [None]:
# Define placeholder functions
def load_model(model_path, map_location=None):
    """Load a fully pickled PyTorch model and put it in evaluation mode.

    Args:
        model_path: Path to a file written with ``torch.save(model, path)``,
            i.e. the whole ``nn.Module`` rather than just its ``state_dict``.
        map_location: Forwarded to ``torch.load``. Pass ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps the
            devices recorded in the file.

    Returns:
        The loaded ``torch.nn.Module``, switched to eval mode.

    Raises:
        TypeError: If the file does not contain an ``nn.Module`` (for example
            a bare ``state_dict``).

    Security: this unpickles arbitrary Python objects, which can execute
    code. Only load checkpoints from trusted sources.
    """
    import inspect

    load_kwargs = {'map_location': map_location}
    # PyTorch >= 2.6 defaults to weights_only=True, which refuses to unpickle
    # a full nn.Module. Opt out explicitly on versions that know the flag
    # (it does not exist before 1.13).
    if 'weights_only' in inspect.signature(torch.load).parameters:
        load_kwargs['weights_only'] = False
    model = torch.load(model_path, **load_kwargs)

    if not isinstance(model, torch.nn.Module):
        raise TypeError(
            f"{model_path!r} contains a {type(model).__name__}, not a "
            "torch.nn.Module; save the whole model with torch.save(model, path)."
        )
    model.eval()
    return model

In [None]:
def load_data(dataset_name, batch_size=32, root='./data'):
    """Load a MoleculeNet benchmark dataset wrapped in a PyG ``DataLoader``.

    Args:
        dataset_name: A MoleculeNet dataset name such as ``'Tox21'`` or
            ``'HIV'``.
        batch_size: Number of graphs per mini-batch.
        root: Directory where the raw and processed dataset files are cached.

    Returns:
        A ``DataLoader`` over the full dataset. ``shuffle=False`` keeps graph
        order (and thus the ``graph_idx`` written by the benchmarks) stable
        across runs.
    """
    # torch_geometric.datasets has no ``Tox21`` class. Tox21, HIV, etc. are
    # all served by MoleculeNet(name=...).
    from torch_geometric.datasets import MoleculeNet
    try:
        # PyG >= 2.0 location; the torch_geometric.data alias is deprecated.
        from torch_geometric.loader import DataLoader as GraphLoader
    except ImportError:  # older PyG releases
        from torch_geometric.data import DataLoader as GraphLoader

    dataset = MoleculeNet(root=root, name=dataset_name)
    return GraphLoader(dataset, batch_size=batch_size, shuffle=False)

In [None]:
def compute_gnnexplainer_benchmark(model, loader, device, output_dir, epochs=200,
                                   edge_mask_threshold=0.5):
    """Explain every graph with GNNExplainer and record sparsity and fidelity.

    Uses the ``torch_geometric.explain.Explainer`` front-end. The
    ``GNNExplainer`` from ``torch_geometric.explain`` is only the algorithm
    object: it takes no model and no ``return_type``, and it has no
    ``explain_graph`` method.

    Args:
        model: Trained graph-level model called as
            ``model(x, edge_index, batch)``. It is assumed to return raw logits
            for a binary (possibly multi-label) task such as Tox21 or HIV --
            TODO confirm for the actual checkpoint. GNNExplainer multiplies
            ``x`` by a float mask, so integer-encoded features fed to
            embedding layers may need adapting.
        loader: PyG ``DataLoader`` yielding ``Batch`` objects.
        device: Device the model lives on; each batch is moved there.
        output_dir: Directory for ``gnnexplainer_sparsity.csv`` and
            ``gnnexplainer_fidelity.csv``. Created if missing.
        epochs: GNNExplainer optimisation steps per graph.
        edge_mask_threshold: Edge-mask value above which an edge counts as
            part of the explanation when computing sparsity.

    Returns:
        ``(sparsity_df, fidelity_df)``, the two frames written to disk.
    """
    from torch_geometric.explain import Explainer
    from torch_geometric.explain.metric import fidelity

    class _PositionalBatchModel(torch.nn.Module):
        # Explainer forwards extra model inputs as keyword arguments, whereas
        # the wrapped model is called positionally as model(x, edge_index, batch).
        def __init__(self, inner):
            super().__init__()
            self.inner = inner

        def forward(self, x, edge_index, batch):
            return self.inner(x, edge_index, batch)

    explainer = Explainer(
        model=_PositionalBatchModel(model),
        algorithm=GNNExplainer(epochs=epochs),
        explanation_type='model',
        node_mask_type='attributes',
        edge_mask_type='object',
        # Graph-level binary task on raw logits, matching the original
        # return_type='logits' intent. TODO confirm against the model head.
        model_config=dict(
            mode='binary_classification',
            task_level='graph',
            return_type='raw',
        ),
    )

    os.makedirs(output_dir, exist_ok=True)
    nan = float('nan')
    sparsity_results = []
    fidelity_results = []
    graph_idx = 0  # running index over all graphs (not over batches)

    for batch in loader:
        batch = batch.to(device)
        for graph in batch.to_data_list():
            if graph.num_edges == 0:
                # Nothing to mask (e.g. single-atom molecules): record NaN
                # metrics instead of aborting the whole benchmark.
                sparsity_results.append(
                    {'graph_idx': graph_idx, 'sparsity': nan, 'auc_drop': nan})
                fidelity_results.append(
                    {'graph_idx': graph_idx, 'fidelity': nan, 'fidelity_minus': nan})
                graph_idx += 1
                continue

            # A single graph: every node belongs to graph 0.
            batch_vec = torch.zeros(graph.num_nodes, dtype=torch.long,
                                    device=graph.edge_index.device)
            explanation = explainer(graph.x, graph.edge_index, batch=batch_vec)

            # Sparsity = fraction of edges NOT selected by the explanation.
            selected = (explanation.edge_mask > edge_mask_threshold).float().mean().item()
            # (fid+, fid-): prediction change when the explanation is removed
            # and when only the explanation is kept, respectively.
            fid_plus, fid_minus = fidelity(explainer, explanation)

            sparsity_results.append({
                'graph_idx': graph_idx,
                'sparsity': 1.0 - selected,
                # AUC is a dataset-level metric and cannot be computed per
                # graph here. It is NaN rather than a misleading 0.0. TODO.
                'auc_drop': nan,
            })
            fidelity_results.append({
                'graph_idx': graph_idx,
                'fidelity': fid_plus,
                'fidelity_minus': fid_minus,
            })
            graph_idx += 1

    sparsity_df = pd.DataFrame(sparsity_results)
    fidelity_df = pd.DataFrame(fidelity_results)
    sparsity_df.to_csv(os.path.join(output_dir, 'gnnexplainer_sparsity.csv'), index=False)
    fidelity_df.to_csv(os.path.join(output_dir, 'gnnexplainer_fidelity.csv'), index=False)
    return sparsity_df, fidelity_df

In [None]:
def compute_integrated_gradients(model, loader, device, output_dir, task_index=0,
                                 top_k_frac=0.2, n_steps=50):
    """Benchmark Integrated Gradients (Captum) node-feature attributions.

    Each graph is explained on its own. IG attributes the model's raw output
    column ``task_index`` to the node features, relative to Captum's default
    all-zero baseline. Nodes are ranked by total absolute attribution, and
    the top ``top_k_frac`` fraction have their features zeroed. ``fidelity``
    is the resulting drop in ``sigmoid(output)``: higher means the attributed
    nodes matter more to the prediction.

    Args:
        model: Trained graph-level model called as
            ``model(x, edge_index, batch)``. It is assumed to return raw logits
            shaped ``[num_graphs, num_tasks]`` or ``[num_graphs]`` -- TODO
            confirm. Node features are cast to float for IG, so models that
            embed integer atom codes need a different attribution input.
        loader: PyG ``DataLoader`` yielding ``Batch`` objects.
        device: Device the model lives on; batches are moved there.
        output_dir: Directory for ``integrated_gradients_fidelity.csv``.
            Created if missing.
        task_index: Output column to explain (for multi-task datasets).
        top_k_frac: Fraction of nodes (at least one) removed for fidelity.
        n_steps: Number of IG integration steps.

    Returns:
        The ``DataFrame`` written to disk, with columns ``graph_idx``,
        ``fidelity`` and ``convergence_delta``. The delta is the IG
        completeness error ``sum(attr) - (f(x) - f(baseline))``.
    """
    os.makedirs(output_dir, exist_ok=True)

    def graph_output(x, edge_index, batch_vec):
        # One scalar per forward pass (shape [1]) so Captum can differentiate
        # without a ``target``. ``view`` also accepts 1-D model outputs.
        out = model(x, edge_index, batch_vec)
        return out.view(out.shape[0], -1)[:, task_index]

    ig = IntegratedGradients(graph_output)
    nan = float('nan')
    fidelity_results = []
    graph_idx = 0  # running index over all graphs (not over batches)

    for batch in loader:
        batch = batch.to(device)
        # Explain per graph: Captum stacks interpolated inputs along dim 0,
        # which would no longer match edge_index for a whole mini-batch.
        for graph in batch.to_data_list():
            num_nodes = graph.num_nodes
            if not num_nodes:
                fidelity_results.append(
                    {'graph_idx': graph_idx, 'fidelity': nan, 'convergence_delta': nan})
                graph_idx += 1
                continue

            x = graph.x.float()  # IG interpolates inputs, so they must be float
            batch_vec = torch.zeros(num_nodes, dtype=torch.long, device=x.device)
            fwd_args = (graph.edge_index, batch_vec)

            # internal_batch_size=num_nodes -> one interpolation step per
            # forward pass, so x keeps shape [num_nodes, F].
            attributions = ig.attribute(
                inputs=x,
                additional_forward_args=fwd_args,
                n_steps=n_steps,
                internal_batch_size=num_nodes,
            )

            with torch.no_grad():
                full_out = graph_output(x, *fwd_args)
                baseline_out = graph_output(torch.zeros_like(x), *fwd_args)
                node_importance = attributions.reshape(num_nodes, -1).abs().sum(dim=1)
                k = min(num_nodes, max(1, round(top_k_frac * num_nodes)))
                top_nodes = node_importance.topk(k).indices
                x_masked = x.clone()
                x_masked[top_nodes] = 0
                masked_out = graph_output(x_masked, *fwd_args)

            fidelity_results.append({
                'graph_idx': graph_idx,
                'fidelity': (torch.sigmoid(full_out) - torch.sigmoid(masked_out)).item(),
                'convergence_delta': (attributions.sum()
                                      - (full_out - baseline_out).sum()).item(),
            })
            graph_idx += 1

    fidelity_df = pd.DataFrame(fidelity_results)
    fidelity_df.to_csv(os.path.join(output_dir, 'integrated_gradients_fidelity.csv'), index=False)
    return fidelity_df

In [None]:
def main(argv=None):
    """Command-line entry point: load the model and data, then run both benchmarks.

    Args:
        argv: Argument list to parse. ``None`` (the default) reads
            ``sys.argv``; pass an explicit list to drive this from a notebook,
            e.g. ``main(['--model_path', 'model.pt'])``.

    Flags:
        --model_path  Path to a pickled model (required).
        --dataset     Dataset name passed to ``load_data`` (default: 'hiv').
        --outdir      Output directory for the CSV files (default: './benchmarks').
        --batch_size  Graphs per loader batch (default: 32).
        --seed        Torch RNG seed. GNNExplainer initialises its masks
                      randomly, so seeding makes the CSVs reproducible
                      (default: 0).
    """
    import argparse

    parser = argparse.ArgumentParser(
        description='Sparsity/fidelity benchmarks for GNN explainers.')
    parser.add_argument('--model_path', type=str, required=True)
    parser.add_argument('--dataset', type=str, default='hiv')
    parser.add_argument('--outdir', type=str, default='./benchmarks')
    parser.add_argument('--batch_size', type=int, default=32)
    parser.add_argument('--seed', type=int, default=0)
    args = parser.parse_args(argv)

    torch.manual_seed(args.seed)
    os.makedirs(args.outdir, exist_ok=True)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = load_model(args.model_path).to(device)
    loader = load_data(args.dataset, batch_size=args.batch_size)

    compute_gnnexplainer_benchmark(model, loader, device, args.outdir)
    compute_integrated_gradients(model, loader, device, args.outdir)

In [None]:
if __name__ == "__main__":
    # Script entry point. Note: inside a Jupyter kernel ``__name__`` is also
    # "__main__", so this would parse the kernel's own command-line arguments.
    main()