# In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import matplotlib.colors as colors
from matplotlib.gridspec import GridSpec
import torch.nn.functional as F
import cv2  # Add OpenCV import
import os

def set_deterministic():
    """Seed NumPy and PyTorch (CPU and all CUDA devices) and pin cuDNN settings.

    Uses the fixed seed 42 everywhere, turns cuDNN's deterministic flag on and
    its benchmark flag off.
    """
    seed = 42

    # Random number generators.
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # cuDNN flags.
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

def plot_gcn_importance(model, sample_data, n=5, cols=5,
                        cbar_label_size=25, cbar_tick_size=25, title_fontsize=30,
                        node_title_fontsize=33):
    """
    Visualize per-node predictions with a CAM importance overlay.

    For each of the first ``n`` nodes of the graph, the node's 21x21
    microstructure image is drawn with a black frame, the heatmap returned by
    ``compute_respond_cam`` is overlaid on top, and the title shows the true
    and predicted value. The title colour encodes the absolute error relative
    to the sample's own error distribution: red above the 75th percentile,
    orange above the median, green otherwise.

    Parameters:
        model: Trained model; called as ``model(sample_data)`` and moved
            nowhere (the data is moved to the model's device instead).
        sample_data: Single graph data sample. ``sample_data.x`` must hold,
            per node, 441 microstructure values followed by one extra column
            (the last column is dropped before reshaping to 21x21), and
            ``sample_data.y`` the per-node targets.
        n: Number of nodes to visualize (clamped to the node count)
        cols: Number of nodes per row
        cbar_label_size: Colorbar label font size
        cbar_tick_size: Colorbar tick font size
        title_fontsize: Figure title font size
        node_title_fontsize: Font size of each node's title

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    # Run the model on whatever device it already lives on.
    device = next(model.parameters()).device
    sample_data = sample_data.to(device)

    # Single forward pass: the debug printout below and the plot are
    # guaranteed to describe the same predictions.
    with torch.no_grad():
        y_pred = model(sample_data).cpu().numpy().flatten()
    y_true = sample_data.y.cpu().numpy().flatten()

    # === Debug checks ===
    print("\n=== Internal data check in visualization function ===")
    print(f"Current sample true value range: [{np.min(y_true):.4f}, {np.max(y_true):.4f}]")
    print(f"Current sample prediction range: [{np.min(y_pred):.4f}, {np.max(y_pred):.4f}]")
    print(f"Current sample negative predictions count: {np.sum(y_pred < 0)}")
    print("=" * 60)

    # Error percentiles drive the title colours.
    abs_errors = np.abs(y_true - y_pred)
    p50, p75 = np.percentile(abs_errors, [50, 75])

    # Microstructure images: every feature column except the last, as 21x21.
    micro_features = sample_data.x[:, :-1].view(-1, 21, 21).cpu().numpy()

    # Never try to draw more nodes than the graph has, and size the grid for
    # the nodes that are actually drawn (no empty trailing rows).
    n_show = min(n, len(micro_features))
    if n_show <= 0:
        print("No nodes to visualize")
        return

    rows = int(np.ceil(n_show / cols))
    fig = plt.figure(figsize=(cols * 6, rows * 6), dpi=100)

    # One extra narrow column on the right is reserved for the colorbar.
    gs = GridSpec(rows, cols + 1, figure=fig, width_ratios=[1] * cols + [0.1])

    # Last drawn heatmap; all heatmaps share vmin/vmax so any one of them can
    # back the colorbar.
    heatmap_im = None

    for i in range(n_show):
        ax = fig.add_subplot(gs[i // cols, i % cols])

        # --- 1. Original image padded with a 1-pixel border of ones ---
        bordered_img = np.ones((23, 23))
        bordered_img[1:-1, 1:-1] = micro_features[i]
        ax.imshow(bordered_img, cmap='binary')

        # --- 2. Black frame around the 21x21 image area ---
        border = patches.Rectangle((0.5, 0.5), 21, 21,
                                   linewidth=5.0,
                                   edgecolor='black',
                                   facecolor='none',
                                   zorder=10)
        ax.add_patch(border)

        # --- 3. CAM heatmap for this node ---
        heatmap = compute_respond_cam(model, micro_features[i])

        # --- 4. Overlay, skipped when the heatmap is (numerically) empty ---
        if heatmap.max() > 1e-5:
            heatmap_im = ax.imshow(heatmap,
                                   extent=[0.5, 21.5, 21.5, 0.5],
                                   cmap='jet',
                                   alpha=0.8,
                                   vmin=0,
                                   vmax=1,
                                   interpolation='bilinear',
                                   zorder=5)

        # --- 5. Title, coloured by error percentile ---
        error = abs_errors[i]
        title_color = 'red' if error > p75 else 'orange' if error > p50 else 'green'
        ax.set_title(f"Node {i}\nTrue: {y_true[i]:.3f}\nPred: {y_pred[i]:.3f}",
                     fontsize=node_title_fontsize, color=title_color, pad=10)
        ax.axis('off')

    # --- 6. Colorbar spanning all rows (same height as the image grid) ---
    if heatmap_im is not None:
        cax = fig.add_subplot(gs[:, -1])
        cbar = fig.colorbar(heatmap_im, cax=cax, orientation='vertical')
        cbar.set_label('Importance Score',
                       fontsize=cbar_label_size,
                       rotation=270,
                       labelpad=30)
        cbar.ax.tick_params(labelsize=cbar_tick_size)
        cbar.set_ticks([0, 0.2, 0.4, 0.6, 0.8, 1.0])

    plt.suptitle("Node-level Prediction",
                 y=1.02, fontsize=title_fontsize, weight='bold')
    plt.tight_layout()
    plt.show()

class RespondCAM:
    """Capture activations and gradients of the last Conv2d in ``model.cnn``.

    On construction the model is switched to eval mode and a forward and a
    backward hook are attached to the last ``torch.nn.Conv2d`` found among the
    *direct* children of ``model.cnn``. If the model has no ``cnn`` attribute,
    or ``cnn`` has no direct Conv2d child, no hooks are attached and
    ``features`` / ``gradients`` stay ``None``.

    Call ``remove_hooks()`` when done; otherwise the hooks stay on the model
    (and keep this object alive) for the lifetime of the model.

    Attributes are documented inline in ``__init__``.
    """

    def __init__(self, model):
        """Initialize Respond-CAM for model interpretation"""
        self.model = model
        # Output of the hooked layer from the most recent forward pass.
        self.features = None
        # Gradient w.r.t. that output from the most recent backward pass.
        self.gradients = None
        # Handles returned by the hook registrations, kept for removal.
        self._hook_handles = []
        self.model.eval()
        self.register_hooks()

    def register_hooks(self):
        """Register forward and backward hooks on the last Conv2d of ``model.cnn``.

        Any hooks previously registered by this instance are removed first, so
        calling this twice does not stack duplicate hooks.
        """
        self.remove_hooks()

        def forward_hook(module, inputs, output):
            self.features = output.detach()

        def backward_hook(module, grad_input, grad_output):
            # Only grad_output (gradient w.r.t. the layer output) is used.
            self.gradients = grad_output[0].detach()

        if hasattr(self.model, 'cnn'):
            for layer in reversed(list(self.model.cnn.children())):
                if isinstance(layer, torch.nn.Conv2d):
                    self._hook_handles.append(
                        layer.register_forward_hook(forward_hook))
                    self._hook_handles.append(
                        layer.register_backward_hook(backward_hook))
                    break

    def remove_hooks(self):
        """Detach every hook this instance registered. Safe to call repeatedly."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles = []


def compute_respond_cam(model, micro_structure):
    """Calculate a CAM heatmap for a single microstructure image.

    The image is fed through ``model.cnn`` alone, the mean of the CNN output
    is back-propagated, and the captured activations of the last Conv2d are
    combined using the spatial mean of their gradients as channel weights.
    The result is passed through ReLU, normalised to [0, 1] and resized to the
    input's height and width.

    NOTE(review): spatially-averaged gradients as weights is the Grad-CAM
    formulation; confirm whether the Respond-CAM weighting was intended.

    Parameters:
        model: Model exposing a ``cnn`` sub-module.
        micro_structure: 2-D array-like (H, W) image accepted by
            ``torch.tensor``.

    Returns:
        numpy array of shape (H, W). All zeros when no activations/gradients
        were captured (e.g. no Conv2d was hooked).

    Side effects:
        Puts the model in eval mode and leaves gradients accumulated on the
        parameters of ``model.cnn``; hooks are removed before returning.
    """
    # Build the input on the model's device; a CPU tensor fed to a CUDA model
    # raises a device-mismatch error.
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device('cpu')

    input_tensor = torch.tensor(micro_structure, dtype=torch.float32, device=device)
    input_tensor = input_tensor.unsqueeze(0).unsqueeze(0)
    input_tensor.requires_grad = True
    height, width = input_tensor.shape[-2:]

    respond_cam = RespondCAM(model)
    try:
        # enable_grad: stay usable even when the caller is inside no_grad().
        with torch.enable_grad():
            # Forward pass
            output = model.cnn(input_tensor)
            target = output.mean()

            # Backward pass
            model.zero_grad()
            target.backward()
    finally:
        # Always detach the hooks so repeated calls do not pile them up on
        # the model.
        respond_cam.remove_hooks()

    features = respond_cam.features
    gradients = respond_cam.gradients
    if features is None or gradients is None:
        return np.zeros((height, width))

    # Channel weights = spatial mean of the gradients.
    weights = torch.mean(gradients, dim=(2, 3), keepdim=True)
    cam = torch.sum(weights * features, dim=1)
    cam = F.relu(cam)
    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:
        cam = cam / peak

    # Drop only the batch dimension (squeeze() would also collapse a 1-pixel
    # wide/high feature map) and move to CPU before converting to numpy.
    cam = cam[0].detach().cpu().numpy()

    # Upsample to the input size; cv2 expects dsize as (width, height).
    cam = cv2.resize(cam, (width, height), interpolation=cv2.INTER_LINEAR)
    return cam

# Usage example
if __name__ == "__main__":
    set_deterministic()

    # file_info must have one entry per sample across both datasets.
    # Raised explicitly (not asserted) so the check survives `python -O`.
    total_samples = len(train_dataset) + len(test_dataset)
    if len(file_info) != total_samples:
        raise ValueError("file_info length does not match dataset")

    # Pick the sample to visualize. file_info is indexed with the training
    # entries first and the test entries after them, so a test-set index is
    # offset by len(train_dataset).
    use_test_set = False  # set True to take the sample from test_dataset
    sample_idx = 8
    if use_test_set:
        sample_data = test_dataset[sample_idx]
        file_info_idx = len(train_dataset) + sample_idx
    else:
        sample_data = train_dataset[sample_idx]
        file_info_idx = sample_idx

    micro_files = file_info[file_info_idx]['micro']
    micro_filenames = [os.path.basename(f) for f in micro_files]
    print(f"MicroInfo files: {micro_filenames}")

    # === Key modification: Ensure model architecture matches weights ===
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model = Optimized_CNN_GCN(
        filters=32,  # Must match weight file
        kernel_size=5,
        dense_units=128,
        dropout_rate=0.2,
        gcn_hidden_dim=32,
        learning_rate=0.0015
    ).to(device)

    # Load weights. map_location remaps tensors saved from a CUDA device so
    # the checkpoint also loads on a CPU-only machine.
    model.load_state_dict(torch.load('best_model.pth', map_location=device))
    model.eval()
    print("Loaded trained best model")

    # Visualization
    plot_gcn_importance(
        model,
        sample_data,
        n=5,
        cols=5,
        cbar_label_size=25,
        cbar_tick_size=25,
        title_fontsize=30
    )