In [None]:
def calculate_node_importance_by_removal(model, data, layer_name="Layer1", criterion=nn.MSELoss()):
    """
    Score every node of one graph by how much a GCN layer's loss changes
    when that node is removed (layer-wise calculation).

    The CNN and fusion stages are run once on the unmodified node features.
    Then, for each node in turn, its fused feature row is zeroed and every
    edge touching it is dropped, the selected GCN layer(s) are re-run, and
    the resulting loss is compared with the loss of the intact graph.

    Parameters:
        model: Module exposing ``cnn``, ``fusion``, ``conv1`` and ``conv2``.
            It is switched to eval mode and left in eval mode on return.
        data: Graph object providing ``x``, ``edge_index``, ``y``,
            ``num_nodes`` and ``to(device)``. ``x`` must have 442 columns per
            node: the first 441 are reshaped to a 1x21x21 image for the CNN,
            the last column is handed to the fusion module separately.
        layer_name: "Layer1" evaluates the output of ``conv1``; any other
            value evaluates ``conv2`` applied on top of ``conv1``.
        criterion: Loss callable applied to the per-node mean of the layer
            output (shape ``[num_nodes, 1]``) and ``data.y``. The default
            instance is created once at definition time and shared by calls.

    Returns:
        dict with keys
            'importance_scores': |loss change| per node, min-max scaled to
                [0, 1]; all 0.5 when every node has the same |loss change|.
            'removal_effects': signed loss change per node
                (loss after removal minus original loss).
            'original_loss': loss of the intact graph (float).
            'node_features': ``data.x`` as a NumPy array.
            'edge_index': ``data.edge_index`` as a NumPy array.
            'layer_name': the ``layer_name`` argument, unchanged.
    """
    device = next(model.parameters()).device
    data = data.to(device)
    model.eval()

    with torch.no_grad():
        num_nodes = data.num_nodes
        edge_index = data.edge_index
        targets = data.y

        # Fixed part of the pipeline: CNN on the image columns, then fusion
        # with the trailing scalar column. Computed once and reused for
        # every ablation below.
        node_features = data.x
        images = node_features[:, :-1].view(num_nodes, 1, 21, 21)
        scalar_feature = node_features[:, -1].unsqueeze(1)
        cnn_features = model.cnn(images).view(num_nodes, -1)
        fused, _gate_weights = model.fusion(cnn_features, scalar_feature)

        # Baseline loss on the intact graph
        original_loss = _layer_loss(model, fused, edge_index, targets, layer_name, criterion)

        print(f"Analyzing {num_nodes} node importance for {layer_name}...")

        removal_effects = []
        for node_idx in range(num_nodes):
            # Remove the node: zero its fused features ...
            ablated = fused.clone()
            ablated[node_idx] = 0

            # ... and drop every edge that starts or ends at it
            keep = (edge_index[0] != node_idx) & (edge_index[1] != node_idx)
            ablated_loss = _layer_loss(
                model, ablated, edge_index[:, keep], targets, layer_name, criterion
            )

            # Signed change: positive means the loss went up without the node
            removal_effects.append(ablated_loss - original_loss)

            if (node_idx + 1) % 5 == 0 or (node_idx + 1) == num_nodes:
                print(f"Analyzed {node_idx + 1}/{num_nodes} nodes")

        removal_effects = np.array(removal_effects)

        # Importance is the magnitude of the change, min-max scaled to [0, 1]
        importance_scores = np.abs(removal_effects)
        lowest, highest = importance_scores.min(), importance_scores.max()
        if highest > lowest:
            importance_scores = (importance_scores - lowest) / (highest - lowest)
        else:
            # No spread to normalise over: give every node the same mid score
            importance_scores = np.full_like(importance_scores, 0.5)

        return {
            'importance_scores': importance_scores,
            'removal_effects': removal_effects,
            'original_loss': original_loss,
            'node_features': data.x.cpu().numpy(),
            'edge_index': data.edge_index.cpu().numpy(),
            'layer_name': layer_name
        }


def _layer_loss(model, x, edge_index, targets, layer_name, criterion):
    """Return the loss (as a float) of the selected GCN layer's output.

    Runs ``conv1`` + ReLU, and additionally ``conv2`` + ReLU unless
    ``layer_name`` is "Layer1", then compares the per-node mean of the
    output against ``targets`` with ``criterion``.
    """
    hidden = F.relu(model.conv1(x, edge_index))
    if layer_name != "Layer1":
        hidden = F.relu(model.conv2(hidden, edge_index))
    return criterion(hidden.mean(dim=1, keepdim=True), targets).item()

def visualize_node_importance(model, data, importance_results, title="Node Importance by Removal",
                            inner_label_size=10, outer_label_size=8, label_offset=0.15,
                            min_node_size=800, max_node_size=2800):
    """
    Visualize node importance (enhanced version)

    Draws the graph in a shell layout with every node sized and coloured
    (cividis, fixed 0..1 range) by its importance score, writes the node
    index inside each node and the score next to it, adds a colorbar, and
    shows the figure.

    Parameters:
        model: Only used to look up the device its parameters live on
        data: Graph object; moved to the model's device and converted to an
            undirected networkx graph
        importance_results: Dict with 'importance_scores' (array indexable by
            node index; expected to lie in [0, 1], as produced by
            calculate_node_importance_by_removal) and 'layer_name'
        title: Text appended to the layer name in the figure title
        inner_label_size: Size of circle inner number label
        outer_label_size: Size of circle outer importance label
        label_offset: Offset for outer label (in layout coordinates)
        min_node_size: Minimum node size (default 800)
        max_node_size: Maximum node size (default 2800)
    """
    device = next(model.parameters()).device
    data = data.to(device)

    importance_scores = importance_results['importance_scores']
    layer_name = importance_results['layer_name']

    # Create network graph
    G = to_networkx(data, to_undirected=True)
    pos = nx.shell_layout(G)

    fig, ax = plt.subplots(figsize=(12, 10))

    # Calculate node sizes (linear interpolation based on importance scores)
    node_sizes = min_node_size + (max_node_size - min_node_size) * importance_scores

    # Draw nodes (color represents importance)
    nx.draw_networkx_nodes(
        G, pos,
        node_size=node_sizes,
        node_color=importance_scores,
        cmap='cividis',
        vmin=0, vmax=1,
        alpha=0.9,
        ax=ax
    )

    # Add colorbar. A stand-alone mappable is used so the bar always spans
    # the full 0..1 range regardless of the scores actually present.
    colorbar_font_size = 35
    sm = plt.cm.ScalarMappable(cmap=plt.cm.cividis, norm=plt.Normalize(vmin=0, vmax=1))
    sm.set_array([])
    cbar = plt.colorbar(sm, ax=ax, label='Node Importance')
    cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

    # Set tick label font size
    cbar.ax.tick_params(labelsize=colorbar_font_size)

    # Set colorbar title font size
    cbar.set_label('Node Importance', size=colorbar_font_size, weight='normal')

    # Draw edges (gray translucent)
    nx.draw_networkx_edges(
        G, pos,
        width=1.5,
        edge_color='gray',
        alpha=0.5,
        ax=ax
    )

    # Add node labels (inner and outer layers)
    for node_idx, (px, py) in pos.items():
        score = importance_scores[node_idx]

        # Inner label (node number); text colour flips at 0.5 for contrast
        # against the node fill
        ax.text(px, py, str(node_idx),
                fontsize=inner_label_size,
                fontweight='bold',
                ha='center', va='center',
                color='black' if score > 0.5 else 'white')

        # Outer label (importance score)
        if score >= 0:
            # Push the label away from the layout centre along the node's
            # own direction. arctan2 already copes with px == 0 (giving
            # +pi/2 or -pi/2 according to the sign of py); only a node
            # sitting exactly on the centre has no direction, so it falls
            # back to "straight up".
            if px == 0 and py == 0:
                angle = np.pi / 2
            else:
                angle = np.arctan2(py, px)
            offset_x = label_offset * np.cos(angle)
            offset_y = label_offset * np.sin(angle)

            ax.text(px + offset_x, py + offset_y,
                    f"{score:.2f}",
                    fontsize=outer_label_size,
                    fontweight='normal',
                    ha='center', va='center')

    ax.set_title(f"{layer_name} - {title}", fontsize=16, pad=20)
    ax.axis('off')
    plt.tight_layout()
    plt.show()

def comprehensive_node_importance_analysis(model, sample_data,
                                         inner_label_size=32,
                                         outer_label_size=28,
                                         label_offset=0.225,
                                         min_node_size=2000,
                                         max_node_size=5000):
    """
    Run the node-removal importance analysis for both GCN layers of one sample.

    For "Layer1" and then "Layer2" the importance scores are computed with
    calculate_node_importance_by_removal and plotted with
    visualize_node_importance (the label/size arguments are forwarded to the
    plot unchanged). Afterwards a text comparison of the two layers is
    printed: mean score per layer, Pearson correlation between the two score
    vectors, mean and maximum absolute difference, and the three
    highest-scoring nodes of each layer.

    Returns:
        (layer1_results, layer2_results): the two result dicts returned by
        calculate_node_importance_by_removal.
    """
    banner = "=" * 60
    print(banner)
    print("Real node importance analysis based on node removal (layer-wise)")
    print(banner)

    # Plot settings shared by both layer figures
    plot_kwargs = dict(
        inner_label_size=inner_label_size,
        outer_label_size=outer_label_size,
        label_offset=label_offset,
        min_node_size=min_node_size,
        max_node_size=max_node_size,
    )

    # Compute and plot each layer in turn (layer 1 first, then layer 2)
    per_layer = {}
    stages = (
        ("\n=== GCN Layer 1 Analysis ===", "Layer1"),
        ("\n=== GCN Layer 2 Analysis ===", "Layer2"),
    )
    for heading, layer_key in stages:
        print(heading)
        outcome = calculate_node_importance_by_removal(model, sample_data, layer_key)
        visualize_node_importance(model, sample_data, outcome, **plot_kwargs)
        per_layer[layer_key] = outcome

    layer1_results = per_layer["Layer1"]
    layer2_results = per_layer["Layer2"]

    # Compare the normalised scores of the two layers
    print("\n=== Comparison of Node Importance Between Layers ===")
    scores_l1 = layer1_results['importance_scores']
    scores_l2 = layer2_results['importance_scores']
    correlation = np.corrcoef(scores_l1, scores_l2)[0, 1]
    differences = np.abs(scores_l1 - scores_l2)

    print(f"Layer 1 average importance: {scores_l1.mean():.4f}")
    print(f"Layer 2 average importance: {scores_l2.mean():.4f}")
    print(f"Correlation between layers: {correlation:.4f}")
    print(f"Average difference: {differences.mean():.4f}")
    print(f"Maximum difference: {differences.max():.4f}")

    # Report the highest-scoring nodes of each layer, best first
    top_k = 3
    for layer_name, results in (("Layer 1", layer1_results), ("Layer 2", layer2_results)):
        scores = results['importance_scores']
        effects = results['removal_effects']
        ranked = np.argsort(scores)[-top_k:][::-1]

        print(f"\n{layer_name} most important {top_k} nodes:")
        for rank, idx in enumerate(ranked, start=1):
            print(f"  {rank}. Node {idx}: Importance={scores[idx]:.4f}, Effect={effects[idx]:.6f}")

    return layer1_results, layer2_results

if __name__ == "__main__":
    # Seed the torch and NumPy RNGs so repeated runs are reproducible
    torch.manual_seed(42)
    np.random.seed(42)

    # Sample to analyse; change the index to inspect a different graph
    sample_idx = 1
    sample_data = train_dataset[sample_idx]

    # Basic facts about the chosen sample
    total_samples = len(train_dataset)
    print(f"\nAnalyzing sample {sample_idx} (total {total_samples} samples)")
    print(f"Number of nodes: {sample_data.num_nodes}")
    edge_count = sample_data.edge_index.shape[1]
    print(f"Number of edges: {edge_count}")

    # Node importance for both GCN layers, with figure styling for this run
    importance_plot_options = {
        "inner_label_size": 32,
        "outer_label_size": 28,
        "label_offset": 0.2,
        "min_node_size": 1200,
        "max_node_size": 3500,
    }
    layer1_results, layer2_results = comprehensive_node_importance_analysis(
        model, sample_data, **importance_plot_options
    )

    # Finish with the microstructure images and DamageStrain values
    print("\n=== Displaying microstructure images and DamageStrain values ===")
    microstructure_plot_options = {
        "cmap": 'gray_r',
        "figsize": (20, 4),
        "border_color": 'k',
        "border_width": 6,
    }
    visualize_microstructures_with_damage(sample_data, **microstructure_plot_options)

    # Completion message
    print("\n=== Analysis Completed ===")
    print(f"Node importance analysis and microstructure visualization for sample {sample_idx} completed")