In [1]:
from torch_geometric.data import Data
from torch_geometric.explain import Explainer, GNNExplainer

# Placeholder homogeneous graph: `Data(...)` must be replaced with real tensors
# (node features `x`, connectivity `edge_index`) before the explainer can run.
data = Data(...)

# GNNExplainer (200 optimisation epochs) explaining the model's own predictions.
# NOTE(review): `model` is not defined in this cell; it has to exist in the kernel already.
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=200),
    explanation_type='model',
    model_config={
        'mode': 'multiclass_classification',
        'task_level': 'node',
        'return_type': 'log_probs',  # the model outputs log-probabilities
    },
    node_mask_type='attributes',  # one mask value per node feature
    edge_mask_type='object',      # one mask value per edge
)

# Explain the prediction for the node at index 10, then show both learned masks
# (edge mask first, node mask second, one per line).
explanation = explainer(data.x, data.edge_index, index=10)
print(explanation.edge_mask, explanation.node_mask, sep='\n')

In [2]:
# Display the `data` object currently bound in the kernel (rich repr of the last expression).
data

Data(x=Ellipsis)

In [None]:
# work with this version



import torch
import torch.nn.functional as F
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.explain.config import ModelConfig
import torch_geometric as pyg

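# Prerequisites: a trained `model` (the LitGNN module) and `data` (the object whose
# `test_dataloader()` yields DGL graphs) must already be defined in the kernel.

# NOTE: the existing GNN model is assumed to be DGL-based, while the Explainer below comes
# from PyTorch Geometric and calls the model as `model(x, edge_index)`; the model must
# accept that call signature for the explanation to run.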

# Define the function that sets up and runs GNNExplainer
def run_gnn_explainer(model, data, node_idx, explain_node=True,
                      feature_key='_CODEBERT', visualize=True):
    """
    Run GNNExplainer on the first test batch and explain one prediction.

    The batch is expected to be a DGL graph; its node features and edge list are
    converted to the ``(x, edge_index)`` form that the PyTorch Geometric
    ``Explainer`` passes to ``model``.

    NOTE(review): the Explainer calls ``model(x, edge_index)`` and GNNExplainer
    learns edge masks by hooking PyG message-passing layers -- confirm that
    ``model`` supports both before relying on the edge mask.

    :param model: The GNN model to explain; must return log-probabilities.
    :param data: Object exposing ``test_dataloader()`` whose batches are DGL graphs.
    :param node_idx: Index of the node (or edge, see ``explain_node``) to explain.
    :param explain_node: If True, explain a node-level prediction, otherwise an
        edge-level prediction.
    :param feature_key: Key in ``graph.ndata`` that holds the node feature matrix.
    :param visualize: If True, also draw the explanation graph. This needs a
        plotting backend supported by PyG (graphviz or networkx + matplotlib).
    :return: The ``Explanation`` object, or ``None`` if the test loader is empty.
    """
    # Only the first batch is explained (demonstration purposes).
    graph = next(iter(data.test_dataloader()), None)
    if graph is None:
        return None

    node_features = graph.ndata[feature_key]
    # DGL returns (src, dst) tensors; PyG expects a single [2, num_edges] long tensor.
    edge_index = torch.stack(graph.edges()).long()

    model_config = ModelConfig(
        mode='multiclass_classification',
        task_level='node' if explain_node else 'edge',
        return_type='log_probs',
    )
    explainer = Explainer(
        model=model,
        algorithm=GNNExplainer(),
        explanation_type='model',
        node_mask_type='attributes',  # per-feature mask, i.e. feature contributions
        edge_mask_type='object',      # per-edge importance
        model_config=model_config,
    )

    # Explain the requested index (previously hard-coded to 0).
    explanation = explainer(
        x=node_features,
        edge_index=edge_index,
        index=node_idx,
    )

    print(f"Explanation for node {node_idx}:\n", explanation)

    if visualize:
        # `Explainer` has no `visualize_subgraph`; plotting lives on the Explanation.
        explanation.visualize_graph()

    return explanation

# Explain node 0 of the first test batch; run this once `model` is trained and
# `data` (the object providing `test_dataloader()`) is available in the kernel.
node_to_explain = 0
run_gnn_explainer(model=model, data=data, node_idx=node_to_explain)


In [None]:


# Grad-CAM
import torch
import torch.nn.functional as F
import dgl

class GNNGradCAM:
    """Grad-CAM style importance scores for the nodes of a graph.

    A forward hook on the target layer records that layer's output and attaches a
    tensor hook to it, so the gradient of the output is captured during
    ``backward``. ``generate_cam`` then combines activations and gradients into
    one score per node.

    :param model: The model to explain.
    :param target_layer: Layer whose output is analysed. Defaults to ``model.gat2``.
    """

    def __init__(self, model, target_layer=None):
        self.model = model
        self.gradients = None
        self.activations = None
        self._hook_handles = []

        layer = model.gat2 if target_layer is None else target_layer

        # Tensor hook: receives d(target) / d(layer output) during backward.
        def save_grad(grad):
            self.gradients = grad.detach()

        # Forward hook: remember the layer output and watch its gradient.
        def save_activation(module, inputs, output):
            # NOTE(review): if the layer returns a tuple, the first element is
            # assumed to be the node representation -- confirm for the layer used.
            feats = output[0] if isinstance(output, (tuple, list)) else output
            self.activations = feats
            # Gradients from an earlier pass no longer match these activations.
            self.gradients = None
            # Registering a hook on a tensor that does not require grad raises.
            if feats.requires_grad:
                feats.register_hook(save_grad)

        self._hook_handles.append(layer.register_forward_hook(save_activation))

    def remove_hooks(self):
        """Detach the hooks from the target layer."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []

    def forward(self, g):
        """Forward pass through the model; returns whatever the model returns."""
        return self.model(g)

    def backward(self, class_idx, logits):
        """Back-propagate the score of ``class_idx`` (summed over rows) from ``logits``."""
        self.model.zero_grad()  # Clear previous parameter gradients
        one_hot_output = torch.zeros_like(logits)
        one_hot_output[..., class_idx] = 1  # Select only the target class
        logits.backward(gradient=one_hot_output, retain_graph=True)

    def generate_cam(self):
        """Return one importance score per node, scaled to ``[0, 1]``.

        Channel weights are the gradients averaged over nodes (dim 0); each node's
        score is the ReLU of the weighted sum of its activations.

        :raises RuntimeError: if ``forward`` and ``backward`` have not both run.
        """
        if self.activations is None or self.gradients is None:
            raise RuntimeError(
                "No activations/gradients captured: call forward() and backward() "
                "before generate_cam()."
            )

        acts = self.activations.detach()
        weights = self.gradients.mean(dim=0, keepdim=True)

        # Collapse every non-node dimension (channels, attention heads, ...).
        weighted = (weights * acts).reshape(acts.shape[0], -1)
        cam = F.relu(weighted.sum(dim=1))

        # Normalise for visualisation; skip the division for an all-zero map
        # so the result is zeros rather than NaNs.
        cam = cam - cam.min()
        peak = cam.max()
        if peak > 0:
            cam = cam / peak

        return cam

# Example usage: Grad-CAM scores for the first training batch.
# `model` (trained) and `data` (providing `train_dataloader()`) must already
# exist in the kernel.
gradcam = GNNGradCAM(model)

# Evaluation mode disables dropout so repeated runs give the same scores.
model.eval()

# Take a single batch (a DGL graph batch); replace with the batch you want to explain.
batch = next(iter(data.train_dataloader()))
g = batch

# Forward pass: the model returns a pair, of which only the logits are used here.
logits, _ = gradcam.forward(g)

# Pick a target class to compute Grad-CAM for (e.g., class 1)
class_idx = 1
gradcam.backward(class_idx, logits)

# Generate CAM
cam = gradcam.generate_cam()

# Print the CAM computed for the batch
print(cam)
