# Phase 4: Node & Edge Attribution

**Question:** Which nodes, edges, and features does the GNN rely on? Do alignment edges consistently get higher importance?

Methods:
1. Gradient saliency
2. Integrated Gradients (50 steps)
3. GNNExplainer (averaged over 3 runs)
4. Attention weight analysis (EdgeAwareGAT)
5. Aggregate by edge type

In [None]:
import sys
sys.path.insert(0, '../src')

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import torch

from calamr_interp.utils.data_loading import load_and_split
from calamr_interp.utils.model_loading import load_model_checkpoint, find_checkpoints
from calamr_interp.utils.visualization import setup_style, COLORS, EDGE_TYPE_COLORS
from calamr_interp.phase4_attribution import (
    GradientSaliency,
    IntegratedGradientsAttribution,
    GNNExplainerWrapper,
    AttentionAnalyzer,
    EdgeTypeImportanceAggregator,
)

# Apply the plotting style provided by the project's visualization utilities.
setup_style()

# Use CUDA when torch reports it as available; otherwise fall back to the CPU.
device_name = 'cuda' if torch.cuda.is_available() else 'cpu'
device = torch.device(device_name)
print(f'Device: {device}')

## 1. Load Model & Data

In [None]:
# List every checkpoint reported by find_checkpoints() together with its
# 'checkpoint_path' entry.
checkpoints = find_checkpoints()
print("Available checkpoints:")
for name, info in checkpoints.items():
    ckpt_path = info['checkpoint_path']
    print(f"  {name}: {ckpt_path}")

# Load the three splits; the test split is materialised as a list so that
# it can be sliced and measured with len().
train_data, val_data, test_data = load_and_split()
test_list = list(test_data)
num_test_graphs = len(test_list)
print(f"\nTest set: {num_test_graphs} graphs")

In [None]:
from calamr_interp.utils.model_loading import create_model

# To analyse a trained model, load a checkpoint instead (update path as needed):
# model = load_model_checkpoint('path/to/best_model.pt', 'EdgeAwareGAT', device=device)

# NOTE: for demonstration this builds a fresh model rather than loading a
# checkpoint, so the attributions computed below do not reflect a trained model.
model = create_model('EdgeAwareGAT').to(device)
model.eval()  # switch to evaluation mode before running any attribution
print(f'Model: {type(model).__name__}')

## 2. Gradient Saliency

In [None]:
# Gradient saliency over a capped subset of the test graphs.
saliency = GradientSaliency(model, device)

# At most 20 graphs are used to keep the run fast.
n_samples = min(len(test_list), 20)
saliency_graphs = test_list[:n_samples]

# Aggregate the 'edge_saliency' entry of each attribution by edge type.
edge_importance_df = EdgeTypeImportanceAggregator.batch_aggregate(
    saliency_graphs,
    saliency.attribute,
    importance_key='edge_saliency',
)
print(f'Computed saliency for {n_samples} graphs')
edge_importance_df

In [None]:
# Boxplot of gradient-saliency edge importance, one box per edge type.
edge_columns = ['internal/role', 'alignment']

# Reshape to long form: one row per (graph, edge type) pair.
melted = edge_importance_df.melt(
    id_vars=['label'],
    value_vars=edge_columns,
    var_name='Edge Type',
    value_name='Importance',
)

fig, ax = plt.subplots(figsize=(8, 5))

# Map each edge-type column to its colour from the shared EDGE_TYPE_COLORS table.
edge_palette = {
    'internal/role': EDGE_TYPE_COLORS['internal'],
    'alignment': EDGE_TYPE_COLORS['alignment'],
}
sns.boxplot(
    data=melted,
    x='Edge Type',
    y='Importance',
    hue='Edge Type',
    palette=edge_palette,
    ax=ax,
)
ax.set_title('Edge Importance by Type (Gradient Saliency)')
plt.tight_layout()
plt.show()

## 3. Integrated Gradients

In [None]:
# Integrated Gradients with 50 interpolation steps (slower, more robust).
ig = IntegratedGradientsAttribution(model, n_steps=50, device=device)

# Only a handful of graphs (at most 5) are attributed here.
n_ig = min(len(test_list), 5)
ig_graphs = test_list[:n_ig]

# Aggregate the 'edge_ig' entry of each attribution by edge type.
ig_edge_df = EdgeTypeImportanceAggregator.batch_aggregate(
    ig_graphs,
    ig.attribute,
    importance_key='edge_ig',
)
print(f'IG computed for {n_ig} graphs')
ig_edge_df

In [None]:
# Node feature saliency grouped by feature category.
# NOTE: this uses the gradient-saliency attributor (`saliency`) and the same
# `n_samples` subset as the gradient-saliency cell, not Integrated Gradients.
all_categories = []
for graph in test_list[:n_samples]:
    node_saliency = saliency.attribute(graph)['node_saliency']
    per_category = EdgeTypeImportanceAggregator.node_feature_saliency_by_category(node_saliency)
    # Attach the graph label so it can serve as the id column when melting.
    per_category['label'] = graph.y.item()
    all_categories.append(per_category)

cat_df = pd.DataFrame(all_categories)

fig, ax = plt.subplots(figsize=(8, 5))
# Long form: every non-label column becomes a 'Feature Category' value.
cat_melted = cat_df.melt(
    id_vars=['label'],
    var_name='Feature Category',
    value_name='Saliency',
)
sns.barplot(
    data=cat_melted,
    x='Feature Category',
    y='Saliency',
    color=COLORS['primary'],
    ax=ax,
)
ax.set_title('Node Feature Saliency by Category')
plt.tight_layout()
plt.show()

## 4. Attention Weight Analysis

In [None]:
# Attention analysis (EdgeAwareGAT only): runs only when the model exposes
# `get_attention_weights`.
if hasattr(model, 'get_attention_weights'):
    attn_analyzer = AttentionAnalyzer(model, device)

    # Collect per-graph attention summaries over the first `n_samples` test graphs.
    attn_results = []
    for graph_idx, data in enumerate(test_list[:n_samples]):
        try:
            by_type = attn_analyzer.attention_by_edge_type(data)
        except Exception as e:
            # Best-effort: a failure on one graph should not abort the whole
            # analysis, but report which graph failed and how so it can be
            # investigated.
            print(f'Error on graph {graph_idx}: {type(e).__name__}: {e}')
        else:
            attn_results.append(by_type)

    if attn_results:
        # Flatten the nested {layer: {edge_type: value}} mappings into one
        # record per (graph, layer, edge type) for plotting.
        records = []
        for result in attn_results:
            for layer, types in result.items():
                for etype, value in types.items():
                    records.append({'layer': layer, 'edge_type': etype, 'attention': value})

        attn_df = pd.DataFrame(records)
        fig, ax = plt.subplots(figsize=(10, 5))
        sns.boxplot(data=attn_df, x='layer', y='attention', hue='edge_type', ax=ax)
        ax.set_title('Attention Distribution by Layer and Edge Type')
        plt.tight_layout()
        plt.show()
    else:
        # Make the empty case explicit instead of silently showing nothing.
        print('No attention results were collected; nothing to plot')
else:
    print('Model does not support attention weight extraction')

In [None]:
# Interpretation guide for the results above, emitted as a single print call.
summary_lines = (
    "\n=== Summary ===",
    "If alignment edges consistently receive higher importance across methods,",
    "this validates the hypothesis that cross-component alignment is the key signal.",
    "\nIf SBERT features have high saliency, the model leverages semantic content.",
    "If node_type/component_type dominate, structural position matters more.",
)
print("\n".join(summary_lines))