### GNNExplainer: Model Interpretability

In this section, we use GNNExplainer to understand how our GNN model makes predictions by analyzing:
1. Which node features are most important for predictions
2. How these features reflect the global graph structure
3. The relationship between node properties and model decisions

This helps provide interpretability and insights into the model's decision-making process.

In [ ]:
import matplotlib.pyplot as plt
import seaborn as sns
from torch_geometric.utils import to_edge_index

# Convert SparseTensor to edge_index format
edge_index, edge_attr = to_edge_index(data.adj_t)
data.edge_index = edge_index
data.edge_attr = edge_attr
del edge_attr, edge_index

In [ ]:
from torch_geometric.explain import Explainer, GNNExplainer
import os
os.makedirs('figs_tabs', exist_ok=True)

# Setup the explainer
explainer = Explainer(
    model=best_model,
    algorithm=GNNExplainer(epochs=50),
    explanation_type='model',
    node_mask_type='attributes',
    edge_mask_type='object',
    model_config=dict(
        mode='multiclass_classification',
        task_level='node',
        return_type='log_probs',
    ),
)

# Select node to explain
node_index = 10
explanation = explainer(data.x, data.edge_index, index=node_index)

print(f'Generated explanations in {explanation.available_explanations}')

# Create enhanced visualization
plt.figure(figsize=(10, 6))
sns.set_style("whitegrid")

# Get top-10 feature importances
importances = explanation.node_mask[0]
top_k = 10
top_indices = importances.topk(top_k).indices
top_values = importances[top_indices]

# Create bar plot with improved styling
bars = plt.bar(range(top_k), top_values.cpu().numpy())
plt.xlabel('Feature Index', fontsize=12)
plt.ylabel('Feature Importance Score', fontsize=12)
plt.title('Top-10 Most Important Node Features for Prediction', fontsize=14, pad=20)

# Add value labels on bars
for bar in bars:
    height = bar.get_height()
    plt.text(bar.get_x() + bar.get_width()/2., height,
             f'{height:.3f}',
             ha='center', va='bottom')

# Add grid for better readability
plt.grid(True, axis='y', linestyle='--', alpha=0.7)
plt.tight_layout()

# Save enhanced plot
path = 'figs_tabs/feature_importance.png'
plt.savefig(path, dpi=300, bbox_inches='tight')
print(f"Enhanced feature importance plot saved to '{path}'")

# Print interpretation of results
print("\nFeature Importance Interpretation:")
print("--------------------------------")
print("1. Top Features: The plot shows the most influential node features that shape the model's predictions.")
print("2. Global Context: Higher importance scores indicate features that capture significant network patterns.")
print("3. Feature Roles: Features with larger bars have stronger influence on the node's classification.")
print("\nPrediction Analysis for Node {}:".format(node_index))
print("--------------------------------")
pred_class = best_model(data.x, data.adj_t)[node_index].argmax().item()
true_class = data.y[node_index].item()
print(f"Predicted Class: {pred_class}")
print(f"True Class: {true_class}")
print(f"Top Feature Index: {top_indices[0].item()} (Importance: {top_values[0]:.3f})")