In [None]:
"""
explain_benchmarks.py

This script computes sparsity and fidelity benchmarks for GNN models
trained on molecular datasets such as Tox21, using PyTorch Geometric.
It supports GNNExplainer and Integrated Gradients (Captum).
"""

In [None]:
import torch
import torch.nn.functional as F
from torch_geometric.explain import GNNExplainer
from captum.attr import IntegratedGradients
from torch_geometric.data import DataLoader
import matplotlib.pyplot as plt
import pandas as pd
import os

In [None]:
# === Original load_model (commented out) ===
# def load_model(model_path):
#     """Load a trained PyTorch model"""
#     model = torch.load(model_path)
#     model.eval()
#     return model

# === Revised functions for TOX21_GAT ===
def load_model(model_path, device, *, in_channels=75, hidden_channels=128,
               out_channels=12, num_heads=4, dropout=0.25):
    """Build the Tox21 GAT architecture and load trained weights into it.

    The architecture hyperparameters must match the ones the checkpoint was
    trained with. Otherwise ``load_state_dict`` (strict by default) raises on
    missing/unexpected keys or shape mismatches. The defaults reproduce the
    values this function previously hard-coded, so existing two-argument
    calls behave exactly as before.

    Args:
        model_path: Path to a checkpoint holding a state dict (it is passed
            straight to ``model.load_state_dict``).
        device: Target device (``torch.device`` or string such as ``"cpu"``).
            Weights are mapped onto it at load time and the model is moved
            to it.
        in_channels: Input feature dimension passed to ``GAT``.
        hidden_channels: Hidden dimension passed to ``GAT``.
        out_channels: Output dimension passed to ``GAT``.
        num_heads: Number of attention heads passed to ``GAT``.
        dropout: Dropout rate passed to ``GAT``.

    Returns:
        The ``GAT`` model, on ``device`` and switched to evaluation mode.
    """
    from models.gat import GAT  # Adjust path if needed

    model = GAT(
        in_channels=in_channels,
        hidden_channels=hidden_channels,
        out_channels=out_channels,
        num_heads=num_heads,
        dropout=dropout,
    )
    # NOTE(review): torch.load unpickles the file, so only load trusted
    # checkpoints. On torch>=1.13, consider passing weights_only=True.
    state_dict = torch.load(model_path, map_location=device)
    model.load_state_dict(state_dict)
    model.to(device)
    # Evaluation mode, so dropout is inactive while computing explanations.
    model.eval()
    return model

In [None]:
# === Original load_data (commented out) ===
# def load_data(dataset_name, batch_size=32):
#     """Load dataset (placeholder)"""
#     from torch_geometric.datasets import Tox21
#     dataset = Tox21(root='./data', task=dataset_name)
#     return DataLoader(dataset, batch_size=batch_size, shuffle=False)

def load_data(split='test', batch_size=128):
    """Return a PyG ``DataLoader`` over one split of the project's Tox21 dataset.

    Args:
        split: Split name forwarded to ``Tox21Dataset`` (defaults to ``'test'``).
        batch_size: Number of graphs per batch (defaults to 128).

    Returns:
        A ``torch_geometric.loader.DataLoader`` that iterates the split in its
        stored order (no shuffling).
    """
    # These imports are function-local, so they are only resolved when the
    # function is called. The local ``DataLoader`` deliberately shadows the
    # module-level ``torch_geometric.data.DataLoader`` import.
    from dataset import Tox21Dataset
    from torch_geometric.loader import DataLoader

    split_dataset = Tox21Dataset(split)
    # shuffle=False keeps batches in dataset order, so iteration is repeatable.
    return DataLoader(split_dataset, batch_size=batch_size, shuffle=False)