In [None]:
# Calculate average prediction error for each graph (load one graph at a time)

def evaluate_per_graph(model, dataset, dataset_name="train"):
    """Summarise the model's output graph by graph.

    The dataset is wrapped in a non-shuffled loader with ``batch_size=1`` so
    every iteration yields exactly one graph. For each graph the mean of the
    model output and the mean of ``data.y`` are computed, together with the
    absolute difference between those two means (note: this is the error of
    the means, not the mean of per-element errors).

    Relies on ``DataLoader`` and ``device`` being defined earlier in the
    notebook; neither is created here.

    Args:
        model: Callable network; switched to eval mode before use.
        dataset: Collection of graphs accepted by ``DataLoader``.
        dataset_name: Label used only in the summary line that is printed.

    Returns:
        list[dict]: One dict per graph, in loader order, with the keys
        ``index``, ``mean_pred``, ``mean_true`` and ``mean_error``.
    """
    model.eval()
    single_graph_loader = DataLoader(dataset, batch_size=1, shuffle=False)
    per_graph = []

    with torch.no_grad():
        for position, batch in enumerate(single_graph_loader):
            batch = batch.to(device)

            # Pull both tensors back to host memory before reducing with numpy.
            predicted = model(batch).cpu().numpy()
            actual = batch.y.cpu().numpy()

            avg_pred = float(np.mean(predicted))
            avg_true = float(np.mean(actual))

            per_graph.append(dict(
                index=position,
                mean_pred=avg_pred,
                mean_true=avg_true,
                mean_error=abs(avg_pred - avg_true),
            ))

    print(f"\n=== {dataset_name} set sample mean error statistics completed, total {len(per_graph)} graphs ===")
    return per_graph


# Evaluate the dataset objects themselves; evaluate_per_graph builds its own loader.
train_graph_metrics = evaluate_per_graph(model, train_dataset, "Training")
test_graph_metrics = evaluate_per_graph(model, test_dataset, "Test")

# Rank graphs from smallest to largest absolute error of the means.
train_sorted = list(train_graph_metrics)
train_sorted.sort(key=lambda entry: entry['mean_error'])
test_sorted = list(test_graph_metrics)
test_sorted.sort(key=lambda entry: entry['mean_error'])

# First ten entries are the best-predicted graphs, last ten the worst.
# With fewer than 20 graphs the two selections overlap.
best_train, worst_train = train_sorted[:10], train_sorted[-10:]
best_test, worst_test = test_sorted[:10], test_sorted[-10:]

# =============== Print results ===============
def print_graph_stats(title, items):
    """Print a fixed-width table of per-graph mean statistics.

    Args:
        title: Heading printed above the table.
        items: Iterable of dicts with the keys ``index``, ``mean_true``,
            ``mean_pred`` and ``mean_error``.
    """
    print(f"\n===== {title} =====")
    print(f"{'Graph Index':>11} {'True Mean':>12} {'Pred Mean':>12} {'Abs Error':>12}")
    for item in items:
        # Rows use the same column widths as the header so the values line up
        # underneath their headings.
        print(
            f"{item['index']:>11} "
            f"{item['mean_true']:>12.6f} "
            f"{item['mean_pred']:>12.6f} "
            f"{item['mean_error']:>12.6f}"
        )

# Report for the training split: best-predicted graphs first, then the worst.
print_graph_stats(
    title="Top 10 Best Predicted Graph Structures in Training Set (Smallest Error)",
    items=best_train,
)
print_graph_stats(
    title="Top 10 Worst Predicted Graph Structures in Training Set (Largest Error)",
    items=worst_train,
)

# Same report for the test split.
print_graph_stats(
    title="Top 10 Best Predicted Graph Structures in Test Set (Smallest Error)",
    items=best_test,
)
print_graph_stats(
    title="Top 10 Worst Predicted Graph Structures in Test Set (Largest Error)",
    items=worst_test,
)

In [None]:
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import networkx as nx
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
from sklearn.metrics import r2_score, mean_squared_error

def calculate_real_edge_importance_by_removal(model, data, layer_name="Layer1", criterion=nn.MSELoss()):
    """Score every edge by how much a layer-level loss changes when it is removed.

    The node features are first passed through ``model.cnn`` and
    ``model.fusion`` once (these do not depend on the edges). The loss is then
    evaluated on the output of ``conv1`` (``layer_name == "Layer1"``) or of
    ``conv1`` followed by ``conv2`` (any other ``layer_name``): the layer
    output is averaged over its feature dimension and compared with ``data.y``
    using ``criterion``. This is repeated with each column of
    ``data.edge_index`` removed in turn.

    Note that exactly one column of ``edge_index`` is dropped per step; if the
    graph stores an undirected edge as two directed columns, only one
    direction is removed at a time.

    Args:
        model: Module exposing ``cnn``, ``fusion``, ``conv1`` and ``conv2``.
        data: Graph object with ``x``, ``y``, ``edge_index``, ``num_nodes`` and
            ``to(device)``. The last column of ``x`` is treated as a scalar
            feature and the remaining columns are reshaped to ``1 x 21 x 21``.
        layer_name: ``"Layer1"`` to probe ``conv1``; anything else probes
            ``conv2`` on top of ``conv1``.
        criterion: Loss callable taking ``(prediction, target)``.

    Returns:
        dict with
            ``importance_scores``: ``|loss change|`` per edge, min-max scaled
                to [0, 1] (all 0.5 when every edge has the same effect, empty
                when the graph has no edges),
            ``removal_effects``: signed loss change per edge,
            ``original_loss``: loss with all edges present,
            ``edge_index``: the edge index as a numpy array,
            ``layer_name``: the ``layer_name`` argument.
    """
    device = next(model.parameters()).device
    data = data.to(device)
    model.eval()

    with torch.no_grad():
        x_all = data.x
        num_nodes = data.num_nodes

        # Edge-independent part of the network: run it once and reuse it.
        x_cnn = x_all[:, :-1].view(num_nodes, 1, 21, 21)
        x_mfrac = x_all[:, -1].unsqueeze(1)
        x_cnn = model.cnn(x_cnn).view(num_nodes, -1)
        x, _gate_weights = model.fusion(x_cnn, x_mfrac)

        edge_index = data.edge_index
        num_edges = edge_index.shape[1]

        # The prediction below has shape (num_nodes, 1). A 1-D target would be
        # broadcast against it to (num_nodes, num_nodes) inside the loss, so
        # give it the matching column shape.
        target = data.y
        if target.dim() == 1:
            target = target.unsqueeze(1)

        use_second_layer = layer_name != "Layer1"

        def _layer_loss(edges):
            """Loss of the probed layer's feature-averaged output for ``edges``."""
            hidden = F.relu(model.conv1(x, edges))
            if use_second_layer:
                hidden = F.relu(model.conv2(hidden, edges))
            return criterion(hidden.mean(dim=1, keepdim=True), target).item()

        original_loss = _layer_loss(edge_index)

        print(f"Analyzing {num_edges} edge importance for {layer_name}...")

        # One reusable mask on the same device as the edges; a single entry is
        # switched off for each evaluation and restored afterwards.
        keep = torch.ones(num_edges, dtype=torch.bool, device=edge_index.device)
        edge_removal_effects = []
        for edge_idx in range(num_edges):
            keep[edge_idx] = False
            modified_loss = _layer_loss(edge_index[:, keep])
            keep[edge_idx] = True

            edge_removal_effects.append(modified_loss - original_loss)

            if (edge_idx + 1) % 10 == 0 or (edge_idx + 1) == num_edges:
                print(f"Analyzed {edge_idx + 1}/{num_edges} edges")

        removal_effects = np.array(edge_removal_effects, dtype=float)

        # Importance is the magnitude of the loss change, min-max scaled to [0, 1].
        importance_scores = np.abs(removal_effects)
        if importance_scores.size > 0:
            lowest, highest = importance_scores.min(), importance_scores.max()
            if highest > lowest:
                importance_scores = (importance_scores - lowest) / (highest - lowest)
            else:
                # Every edge has the same effect: no ranking is possible.
                importance_scores = np.ones_like(importance_scores) * 0.5

        return {
            'importance_scores': importance_scores,
            'removal_effects': removal_effects,
            'original_loss': original_loss,
            'edge_index': edge_index.cpu().numpy(),
            'layer_name': layer_name
        }

def _aggregate_undirected_edge_scores(drawn_edges, edge_index, scores):
    """Return one score per drawn (undirected) edge, in ``drawn_edges`` order.

    ``scores[k]`` belongs to the directed edge ``edge_index[:, k]``. All
    directed edges joining the same pair of nodes are averaged; a drawn edge
    with no matching column gets 0.0.
    """
    per_pair = {}
    for (src, dst), score in zip(np.asarray(edge_index).T, scores):
        per_pair.setdefault(frozenset((int(src), int(dst))), []).append(float(score))

    return np.array(
        [
            float(np.mean(per_pair[frozenset((u, v))])) if frozenset((u, v)) in per_pair else 0.0
            for u, v in drawn_edges
        ],
        dtype=float,
    )


def visualize_real_gcn_attention(model, sample_data, importance_results, title_suffix="",
                               node_size=1600, label_font_size=14, edge_width_multiplier=4,
                               edge_label_size=10, show_edge_labels=True, 
                               edge_label_color='darkred', label_pos_offset=0.3):
    """
    Draw the graph with edges styled by their removal-based importance.

    The graph is drawn as an undirected NetworkX graph in a shell layout.
    Nodes are coloured by the last column of ``sample_data.x``; edges are
    drawn wider and darker (``Blues`` colormap) the higher their score.

    Importance scores are indexed like the columns of ``edge_index`` (one per
    directed edge), whereas the drawn graph has one edge per node pair in
    NetworkX's own order. The scores are therefore looked up by node pair and
    averaged over both directions before being used for widths, colours and
    labels.

    Parameters:
        model: Only used to find the device that ``sample_data`` is moved to.
        sample_data: Graph object accepted by ``to_networkx``.
        importance_results: Mapping with ``importance_scores`` and
            ``layer_name``; ``edge_index`` is used when present, otherwise
            ``sample_data.edge_index`` is used.
        title_suffix: Text appended to the figure title.
        node_size: Node marker size.
        label_font_size: Font size of the node labels.
        edge_width_multiplier: Width added per unit of importance (base width 1).
        edge_label_size: Font size of the edge labels.
        show_edge_labels: Whether to print the score on each edge.
        edge_label_color: Edge label font color (default 'darkred').
        label_pos_offset: Accepted for backward compatibility. Labels are
            placed at the edge midpoint; this value does not move them (the
            previous implementation computed offset positions but never
            passed them to the drawing call).

    Returns:
        The unmodified ``importance_results['importance_scores']`` array.
    """
    device = next(model.parameters()).device
    sample_data = sample_data.to(device)

    importance_scores = importance_results['importance_scores']
    layer_name = importance_results['layer_name']

    # Create network graph
    G = to_networkx(sample_data, to_undirected=True)
    pos = nx.shell_layout(G)

    # Align the per-directed-edge scores with the edges that are actually drawn.
    if 'edge_index' in importance_results:
        directed_edges = importance_results['edge_index']
    else:
        directed_edges = sample_data.edge_index.detach().cpu().numpy()
    edge_list = list(G.edges())
    edge_scores = _aggregate_undirected_edge_scores(edge_list, directed_edges, importance_scores)

    fig, ax = plt.subplots(figsize=(15, 10))

    # Draw nodes, coloured by the last feature column (Mfraction values)
    node_features = sample_data.x.detach().cpu().numpy()
    node_colors = node_features[:, -1]

    # Fixed colormap range so figures of different graphs are comparable.
    # NOTE(review): 0.26 is presumably the largest Mfraction in the dataset - confirm.
    vmin, vmax = 0, 0.26
    nodes = nx.draw_networkx_nodes(G, pos, node_color=node_colors,
                                 cmap='viridis', vmin=vmin, vmax=vmax, 
                                 node_size=node_size, ax=ax)

    # Add Mfraction colorbar
    cbar_mfrac = plt.colorbar(nodes, ax=ax, label='Global Martensite Fraction Value')
    cbar_mfrac.set_ticks([0, 0.052, 0.104, 0.156, 0.208, 0.26])

    # Draw edges; passing edgelist explicitly pins the order the widths and
    # colours are applied in.
    edge_widths = edge_scores * edge_width_multiplier + 1
    edge_colors = plt.cm.Blues(edge_scores)
    nx.draw_networkx_edges(G, pos, edgelist=edge_list, width=edge_widths,
                           edge_color=edge_colors, ax=ax)

    # Add edge labels at the midpoint of each edge
    if show_edge_labels and edge_list:
        edge_labels = {
            edge: f"{score:.2f}" for edge, score in zip(edge_list, edge_scores)
        }
        nx.draw_networkx_edge_labels(
            G,
            pos=pos,
            edge_labels=edge_labels,
            label_pos=0.5,
            font_color=edge_label_color,
            font_size=edge_label_size,
            bbox=dict(facecolor='white', edgecolor='none', alpha=0.7),
            ax=ax,
            rotate=False  # Keep labels horizontal
        )

    # Add edge importance colorbar (scores are normalised to [0, 1])
    sm = plt.cm.ScalarMappable(cmap=plt.cm.Blues, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar_attn = plt.colorbar(sm, ax=ax, label='Edge Importance')

    # Colorbar ticks and font sizes
    cbar_attn.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])
    cbar_attn.ax.tick_params(labelsize=35)
    cbar_attn.set_label('Edge Importance', size=35, weight='normal')

    # Add node labels
    nx.draw_networkx_labels(G, pos, ax=ax, font_size=label_font_size, font_weight='bold')

    # Set title
    title = f"GCN {layer_name} Edge Importance" + title_suffix
    ax.set_title(title, fontsize=16, pad=20)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

    return importance_scores

def analyze_layer_comparison(layer1_results, layer2_results):
    """
    Print a side-by-side summary of two layers' edge importance scores.

    Both arguments are mappings holding an ``importance_scores`` numpy array
    of equal length. The report lists each layer's mean score, the Pearson
    correlation between the two score vectors, and the mean and maximum
    absolute element-wise difference.

    Returns:
        The correlation coefficient between the two score vectors.
    """
    first = layer1_results['importance_scores']
    second = layer2_results['importance_scores']

    # Off-diagonal entry of the 2x2 correlation matrix.
    correlation = np.corrcoef(first, second)[0,1]
    gap = np.abs(first - second)

    rule = "="*60
    report = (
        "\n" + rule,
        "GCN Layer Edge Importance Comparison Analysis",
        rule,
        f"Layer 1 average importance: {first.mean():.4f}",
        f"Layer 2 average importance: {second.mean():.4f}",
        f"Correlation between layers: {correlation:.4f}",
        f"Average difference: {gap.mean():.4f}",
        f"Maximum difference: {gap.max():.4f}",
    )
    for line in report:
        print(line)

    return correlation

def comprehensive_gcn_analysis(model, sample_data, node_size=1600, label_font_size=14, 
                             edge_width_multiplier=4, edge_label_size=10, show_edge_labels=True,
                             edge_label_color='darkred', label_pos_offset=0.3):
    """
    Run the removal-based edge importance analysis for both GCN layers.

    For "Layer1" and then "Layer2" the importance is computed and plotted;
    afterwards the two layers are compared, the baseline losses are printed
    and the three highest-scoring edges of each layer are listed.

    Parameters:
        model, sample_data: Forwarded to the calculation and plotting helpers.
        node_size, label_font_size, edge_width_multiplier, edge_label_size,
        show_edge_labels: Plot styling, forwarded unchanged.
        edge_label_color: Edge label font color (default 'darkred')
        label_pos_offset: Label position offset (default 0.3)

    Returns:
        (layer1_results, layer2_results) as produced by
        calculate_real_edge_importance_by_removal.
    """
    rule = "="*60
    print(rule)
    print("Starting real GCN edge importance analysis based on edge removal")
    print(rule)

    # Styling shared by both figures.
    plot_options = dict(
        node_size=node_size,
        label_font_size=label_font_size,
        edge_width_multiplier=edge_width_multiplier,
        edge_label_size=edge_label_size,
        show_edge_labels=show_edge_labels,
        edge_label_color=edge_label_color,
        label_pos_offset=label_pos_offset,
    )

    # Compute and plot each layer in turn: Layer 1 first, then Layer 2.
    stages = (
        ("\n=== GCN Layer 1 Analysis ===", "Layer1"),
        ("\n=== GCN Layer 2 Analysis ===", "Layer2"),
    )
    collected = []
    for heading, layer_key in stages:
        print(heading)
        outcome = calculate_real_edge_importance_by_removal(model, sample_data, layer_key)
        visualize_real_gcn_attention(
            model, sample_data, outcome,
            " - Real Importance",
            **plot_options
        )
        collected.append(outcome)
    layer1_results, layer2_results = collected

    # Comparative analysis (prints its own report)
    analyze_layer_comparison(layer1_results, layer2_results)

    # Detailed statistics
    print("\n=== Detailed Statistics ===")
    print(f"Layer 1 - Original loss: {layer1_results['original_loss']:.6f}")
    print(f"Layer 2 - Original loss: {layer2_results['original_loss']:.6f}")

    # List the highest-scoring edges; Layer 1's edge index is used for both layers.
    edge_index = layer1_results['edge_index']
    top_k = 3

    for layer_name, results in (("Layer 1", layer1_results), ("Layer 2", layer2_results)):
        scores = results['importance_scores']
        effects = results['removal_effects']
        # Descending order, truncated to the first top_k entries.
        ranked = np.argsort(scores)[::-1][:top_k]

        print(f"\n{layer_name} most important {top_k} edges:")
        for rank, idx in enumerate(ranked, start=1):
            source, target = edge_index[0, idx], edge_index[1, idx]
            print(f"  {rank}. Edge {source}-{target}: Importance={scores[idx]:.4f}, Effect={effects[idx]:.6f}")

    return layer1_results, layer2_results

def visualize_microstructures_with_damage(sample_data, cmap='gray', figsize=(15, 3), 
                                        border_color='k', border_width=3):
    """
    Show every node's 21x21 microstructure image and print its DamageStrain.

    All columns of ``sample_data.x`` except the last are reshaped into one
    21x21 image per node; the last column is shown as "GMF" in the panel
    title and ``sample_data.y`` (flattened) as "Damage". One panel is drawn
    per node, so graphs with any number of nodes are supported.

    Parameters:
        sample_data: Graph object with ``x`` and ``y`` tensors.
            NOTE(review): assumes ``y`` holds one value per node - confirm.
        cmap: Colormap passed to ``imshow``.
        figsize: Size of the whole figure.
        border_color: Colour of the frame drawn around each image.
        border_width: Line width of that frame.
    """
    import matplotlib.patches as patches

    # Get microstructure data and DamageStrain values
    micro_features = sample_data.x[:, :-1].view(-1, 21, 21).detach().cpu().numpy()
    mfrac_values = sample_data.x[:, -1].detach().cpu().numpy()
    damage_strains = sample_data.y.detach().cpu().numpy().flatten()

    # Print DamageStrain values
    print("\n=== Node DamageStrain Values ===")
    for i, damage in enumerate(damage_strains):
        print(f"Node {i}: DamageStrain = {damage:.5f}")

    # One panel per node. squeeze=False keeps `axes` a 2-D array even for a
    # single node, so it can always be flattened and iterated.
    num_panels = len(micro_features)
    fig, axes = plt.subplots(1, num_panels, figsize=figsize, squeeze=False)

    for i, ax in enumerate(axes.ravel()):
        image = micro_features[i]
        ax.imshow(image, cmap=cmap)

        # Frame around the image; imshow places pixel centres on integer
        # coordinates, hence the -0.5 origin.
        rect = patches.Rectangle((-0.5, -0.5), image.shape[1], 
                                image.shape[0], 
                                linewidth=border_width, edgecolor=border_color, 
                                facecolor='none', alpha=0.8)
        ax.add_patch(rect)

        # Add DamageStrain information in title
        ax.set_title(f'Node {i}\nGMF: {mfrac_values[i]:.3f}\nDamage: {damage_strains[i]:.3f}', 
                    fontsize=26)
        ax.axis('off')

    plt.suptitle("Node Microstructures with DamageStrain Values", fontsize=16, y=1.05)
    plt.tight_layout()
    plt.show()

# Modified main function
if __name__ == "__main__":
    # The graph at position 1 of the training set is the one being explained.
    sample_data = train_dataset[1]

    # Edge importance for both GCN layers, with large fonts and markers for
    # the figures.
    layer1_results, layer2_results = comprehensive_gcn_analysis(
        model=model,
        sample_data=sample_data,
        node_size=4000,
        label_font_size=32,
        edge_width_multiplier=5,
        edge_label_size=28,
        show_edge_labels=True,
        edge_label_color='k',
        label_pos_offset=2,
    )

    # Sanity check: mean absolute gap between the two layers' normalised scores.
    diff_ratio = np.abs(
        layer1_results['importance_scores'] - layer2_results['importance_scores']
    ).mean()
    print(f"\nVerification result: Average difference ratio between layer importance = {diff_ratio:.4f}")

    # Nearly identical scores would suggest both layers react to the edges the same way.
    print(
        "Warning: Small difference in layer importance, may need to check model structure"
        if diff_ratio < 0.1
        else "Normal: Significant difference in layer importance"
    )

    # Finish with the per-node microstructure images and their DamageStrain values.
    print("\n=== Displaying microstructure images with DamageStrain values ===")
    visualize_microstructures_with_damage(
        sample_data=sample_data,
        cmap='gray_r',
        figsize=(20, 4),
        border_color='k',
        border_width=6,
    )