In [1]:
import torch
import torch_geometric.nn
from torch_geometric.explain import Explainer, GNNExplainer
from torch_geometric.datasets import Planetoid
from torch_geometric.loader import DataLoader

# Load a dataset (e.g., Cora)
dataset = Planetoid(root='/tmp/Cora', name='Cora')
data = dataset[0]

print(f"Dataset: {dataset}")
print(f"Number of graphs: {len(dataset)}")
print(f"Number of features: {dataset.num_features}")
print(f"Number of classes: {dataset.num_classes}")
print(f"Number of nodes: {data.num_nodes}")
print(f"Number of edges: {data.num_edges}")

# Create a GNN model
class GCN(torch.nn.Module):
    def __init__(self, in_channels, out_channels):
        super(GCN, self).__init__()
        self.conv1 = torch_geometric.nn.GCNConv(in_channels, 16)
        self.conv2 = torch_geometric.nn.GCNConv(16, out_channels)

    def forward(self, x, edge_index):
        x = torch.relu(self.conv1(x, edge_index))
        x = self.conv2(x, edge_index)
        return x

# Initialize the model and optimizer
model = GCN(dataset.num_node_features, dataset.num_classes)
optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

# Train the model (reduced epochs for faster execution)
model.train()
num_epochs = 50  # Reduced from 200
for epoch in range(num_epochs):
    optimizer.zero_grad()
    out = model(data.x, data.edge_index)
    loss = torch.nn.CrossEntropyLoss()(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    
    if epoch % 10 == 0:
        print(f'Epoch {epoch:03d}, Loss: {loss:.4f}')

print("Model trained successfully!")

# Evaluate the model quickly
model.eval()
with torch.no_grad():
    pred = model(data.x, data.edge_index).argmax(dim=-1)
    correct = (pred[data.test_mask] == data.y[data.test_mask]).float().mean()
    print(f'Test Accuracy: {correct:.4f}')

# Initialize the explainer using the new API
explainer = Explainer(
    model=model,
    algorithm=GNNExplainer(epochs=50),  # Reduced epochs for faster explanation
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Explain a prediction
node_idx = 42  # Example node for explanation
print(f"\nExplaining prediction for node {node_idx}...")
explanation = explainer(data.x, data.edge_index, index=node_idx)

print(f"Explanation for node {node_idx}:")
print(f"Node importance scores shape: {explanation.node_mask.shape if explanation.node_mask is not None else 'None'}")
print(f"Edge importance scores shape: {explanation.edge_mask.shape if explanation.edge_mask is not None else 'None'}")

# Show some basic info about the explanation
if explanation.node_mask is not None:
    print(f"Top 5 most important features for node {node_idx}:")
    top_features = explanation.node_mask[node_idx].argsort(descending=True)[:5]
    for i, feat_idx in enumerate(top_features):
        print(f"  Feature {feat_idx}: {explanation.node_mask[node_idx][feat_idx]:.4f}")

print("\nExplanation completed successfully!")


Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.x
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.tx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.allx
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.y
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ty
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.ally
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.graph
Downloading https://github.com/kimiyoung/planetoid/raw/master/data/ind.cora.test.index
Processing...
Done!


Dataset: Cora()
Number of graphs: 1
Number of features: 1433
Number of classes: 7
Number of nodes: 2708
Number of edges: 10556
Epoch 000, Loss: 1.9538
Epoch 010, Loss: 0.6884
Epoch 020, Loss: 0.1508
Epoch 030, Loss: 0.0272
Epoch 040, Loss: 0.0087
Model trained successfully!
Test Accuracy: 0.7790

Explaining prediction for node 42...
Explanation for node 42:
Node importance scores shape: torch.Size([2708, 1433])
Edge importance scores shape: torch.Size([10556])
Top 5 most important features for node 42:
  Feature 421: 0.6658
  Feature 725: 0.6657
  Feature 702: 0.6574
  Feature 416: 0.6499
  Feature 653: 0.6472

Explanation completed successfully!


In [None]:
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
from torch_geometric.utils import to_networkx

# Create a visualization of the explanation
def visualize_explanation(data, explanation, node_idx, subset_size=30):
    """
    Visualize the graph explanation with highlighted important nodes and edges
    """
    # Convert to NetworkX graph for easier visualization
    G = to_networkx(data, to_undirected=True)
    
    # Get the subgraph around the explained node for better visualization
    # Find neighbors within 2 hops of the target node
    neighbors = set([node_idx])
    for _ in range(2):  # 2-hop neighborhood
        new_neighbors = set()
        for node in neighbors:
            if node in G:
                new_neighbors.update(G.neighbors(node))
        neighbors.update(new_neighbors)
    
    # Limit the size for better visualization
    neighbors = list(neighbors)[:subset_size]
    subgraph = G.subgraph(neighbors)
    
    # Create the plot
    plt.figure(figsize=(15, 10))
    
    # Get positions for nodes
    pos = nx.spring_layout(subgraph, k=1, iterations=50)
    
    # Prepare node colors based on importance
    node_colors = []
    node_sizes = []
    
    for node in subgraph.nodes():
        if explanation.node_mask is not None and node < len(explanation.node_mask):
            # Use the sum of feature importance for this node
            importance = explanation.node_mask[node].sum().item()
            # Normalize importance for color mapping
            importance = max(0, min(1, (importance + 1) / 2))  # Normalize to [0, 1]
        else:
            importance = 0.1
        
        if node == node_idx:
            node_colors.append('red')  # Target node in red
            node_sizes.append(800)
        else:
            # Color based on importance (blue to yellow gradient)
            node_colors.append(plt.cm.RdYlBu_r(importance))
            node_sizes.append(300 + importance * 300)
    
    # Prepare edge colors and widths based on importance
    edge_colors = []
    edge_widths = []
    
    if explanation.edge_mask is not None:
        # Create edge index mapping for the subgraph
        edge_mapping = {}
        for i, (u, v) in enumerate(data.edge_index.t().tolist()):
            if u in neighbors and v in neighbors:
                edge_mapping[(u, v)] = i
                edge_mapping[(v, u)] = i  # For undirected graphs
        
        for edge in subgraph.edges():
            u, v = edge
            if (u, v) in edge_mapping:
                edge_idx = edge_mapping[(u, v)]
                if edge_idx < len(explanation.edge_mask):
                    importance = explanation.edge_mask[edge_idx].item()
                    importance = max(0, min(1, (importance + 1) / 2))  # Normalize
                else:
                    importance = 0.1
            else:
                importance = 0.1
            
            edge_colors.append(plt.cm.Reds(importance))
            edge_widths.append(0.5 + importance * 3)
    else:
        edge_colors = ['gray'] * len(subgraph.edges())
        edge_widths = [1] * len(subgraph.edges())
    
    # Draw the graph
    nx.draw_networkx_edges(subgraph, pos, edge_color=edge_colors, width=edge_widths, alpha=0.7)
    nx.draw_networkx_nodes(subgraph, pos, node_color=node_colors, node_size=node_sizes, alpha=0.8)
    
    # Add labels for important nodes
    important_nodes = {node_idx: str(node_idx)}
    nx.draw_networkx_labels(subgraph, pos, important_nodes, font_size=12, font_weight='bold')
    
    plt.title(f'Graph Explanation for Node {node_idx}\n'
              f'Red node = target, Color intensity = importance', fontsize=14)
    plt.axis('off')
    
    # Add colorbar for node importance
    sm = plt.cm.ScalarMappable(cmap=plt.cm.RdYlBu_r, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=plt.gca(), shrink=0.6)
    cbar.set_label('Node Importance', fontsize=12)
    
    plt.tight_layout()
    plt.show()

# Visualize the explanation
print(f"Visualizing explanation for node {node_idx}...")
visualize_explanation(data, explanation, node_idx)

In [None]:
# Additional visualization: Feature importance for the explained node
def plot_feature_importance(explanation, node_idx, top_k=10):
    """
    Plot the top-k most important features for the explained node
    """
    if explanation.node_mask is None:
        print("No node feature importance available")
        return
    
    # Get feature importance for the target node
    node_importance = explanation.node_mask[node_idx]
    
    # Get top-k features
    top_indices = node_importance.argsort(descending=True)[:top_k]
    top_values = node_importance[top_indices]
    
    plt.figure(figsize=(12, 6))
    
    # Create bar plot
    bars = plt.bar(range(len(top_values)), top_values.detach().numpy(), 
                   color='skyblue', alpha=0.7, edgecolor='navy')
    
    # Highlight the most important feature
    if len(bars) > 0:
        bars[0].set_color('orange')
    
    plt.xlabel('Feature Index', fontsize=12)
    plt.ylabel('Importance Score', fontsize=12)
    plt.title(f'Top {top_k} Most Important Features for Node {node_idx}', fontsize=14)
    plt.xticks(range(len(top_values)), [f'F{idx.item()}' for idx in top_indices])
    plt.grid(axis='y', alpha=0.3)
    
    # Add value labels on bars
    for i, (bar, value) in enumerate(zip(bars, top_values)):
        plt.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001,
                f'{value:.3f}', ha='center', va='bottom', fontsize=10)
    
    plt.tight_layout()
    plt.show()

# Plot feature importance
plot_feature_importance(explanation, node_idx, top_k=min(10, data.x.shape[1]))

# Print detailed explanation statistics
print(f"\n=== Detailed Explanation Analysis ===")
print(f"Target node: {node_idx}")
print(f"True label: {data.y[node_idx].item()}")

# Get model prediction for this node
model.eval()
with torch.no_grad():
    logits = model(data.x, data.edge_index)
    pred_label = logits[node_idx].argmax().item()
    confidence = torch.softmax(logits[node_idx], dim=0).max().item()

print(f"Predicted label: {pred_label}")
print(f"Prediction confidence: {confidence:.4f}")
print(f"Prediction correct: {pred_label == data.y[node_idx].item()}")

if explanation.node_mask is not None:
    print(f"\nNode feature importance:")
    print(f"  Shape: {explanation.node_mask.shape}")
    print(f"  Mean importance: {explanation.node_mask[node_idx].mean():.4f}")
    print(f"  Max importance: {explanation.node_mask[node_idx].max():.4f}")
    print(f"  Min importance: {explanation.node_mask[node_idx].min():.4f}")

if explanation.edge_mask is not None:
    print(f"\nEdge importance:")
    print(f"  Shape: {explanation.edge_mask.shape}")
    print(f"  Mean importance: {explanation.edge_mask.mean():.4f}")
    print(f"  Max importance: {explanation.edge_mask.max():.4f}")
    print(f"  Min importance: {explanation.edge_mask.min():.4f}")

# Show the most important neighboring nodes
if explanation.node_mask is not None:
    print(f"\nMost important neighboring nodes (within 1-hop):")
    edge_index = data.edge_index
    neighbors = edge_index[1][edge_index[0] == node_idx].unique()
    
    if len(neighbors) > 0:
        neighbor_importance = []
        for neighbor in neighbors:
            if neighbor < len(explanation.node_mask):
                importance = explanation.node_mask[neighbor].sum().item()
                neighbor_importance.append((neighbor.item(), importance))
        
        # Sort by importance
        neighbor_importance.sort(key=lambda x: x[1], reverse=True)
        
        for i, (neighbor, importance) in enumerate(neighbor_importance[:5]):
            print(f"  Node {neighbor}: importance = {importance:.4f}, label = {data.y[neighbor].item()}")
    else:
        print("  No direct neighbors found")