# In [1]:
import torch
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import networkx as nx
from matplotlib.cm import ScalarMappable
from matplotlib.colors import Normalize
import pandas as pd
import os
from tqdm.auto import tqdm
import matplotlib.ticker as ticker
from torch_geometric.utils import to_networkx
from collections import defaultdict
import matplotlib.patheffects as pe

class GCNNodeVisualization(torch.nn.Module):
    """
    A wrapper for the SarcasmGCNLSTMDetector that captures intermediate GCN layer outputs

    Forward hooks are registered on the wrapped model's GCN layers at construction
    time. After each call to `forward`, `gcn_outputs` holds one
    `(layer_index, output)` tuple per hooked layer call, ordered by layer index.

    Args:
        model: The model to wrap. It must expose every name in `layer_names`
            as a submodule attribute.
        layer_names: Attribute names of the layers to hook, in layer order.
            Defaults to ("gcn1", "gcn2", "gcn3", "gcn4").

    Raises:
        AttributeError: If the model lacks one of the requested layers. No hook
            is left registered in that case.
    """

    DEFAULT_LAYER_NAMES = ("gcn1", "gcn2", "gcn3", "gcn4")

    def __init__(self, model, layer_names=None):
        super().__init__()
        self.model = model
        self.layer_names = (
            tuple(layer_names) if layer_names is not None else self.DEFAULT_LAYER_NAMES
        )
        self.gcn_outputs = []
        self.hooks = []

        # Register forward hooks for each GCN layer
        self._register_hooks()

    def _register_hooks(self):
        """Attach one forward hook per configured layer."""
        def make_hook(layer_idx):
            # Bind layer_idx per hook so each output is tagged with its own layer
            def hook(module, inputs, output):
                self.gcn_outputs.append((layer_idx, output))
            return hook

        # Resolve every layer before registering anything: if one is missing we
        # fail here without leaving stray hooks on the layers that do exist.
        layers = [getattr(self.model, name) for name in self.layer_names]

        for layer_idx, layer in enumerate(layers):
            self.hooks.append(layer.register_forward_hook(make_hook(layer_idx)))

    def forward(self, input_ids, attention_mask, graph_x, graph_edge_index):
        """Run the wrapped model and return its output, refreshing `gcn_outputs`."""
        # Clear previous outputs
        self.gcn_outputs = []

        # Forward pass through the model (hooks fill self.gcn_outputs)
        output = self.model(input_ids, attention_mask, graph_x, graph_edge_index)

        # Sort outputs by layer index
        self.gcn_outputs.sort(key=lambda pair: pair[0])

        return output

    def remove_hooks(self):
        """Remove all hooks to avoid memory leaks"""
        for handle in self.hooks:
            handle.remove()
        self.hooks = []


def should_exclude_token(token, format_tokens):
    """Return True when any format token occurs (case-insensitively) inside `token`."""
    lowered = token.lower()
    for marker in format_tokens:
        # Substring match, not equality: "comment" also excludes "Comment:"
        if marker.lower() in lowered:
            return True
    return False


def clean_token_for_display(token):
    """Return `token` with every 'Ġ' marker character stripped out for display."""
    # 'Ġ' is the special marker RoBERTa's tokenizer attaches to tokens
    marker = 'Ġ'
    pieces = token.split(marker)
    return "".join(pieces)


def visualize_gcn_node_influence(model, tokenizer, comment, context='', model_path=None,
                                device='cuda', save_path='gcn_node_visualization',
                                format_tokens=None, layout_type='spring'):
    """
    Visualize the influence of nodes in the GCN component of the sarcasm detection model
    with improved label rendering and formatting

    Runs one forward pass with hooks on the GCN layers, scores every graph node
    by the L2 norm of its feature vector at each layer, and writes these files
    (all prefixed with `save_path`):

        _evolution.png         importance of the top nodes across layers
        _layer<k>.png          one graph drawing per GCN layer
        _final.png             final-layer graph with importance-weighted edges
        _analysis.csv          per-node analysis table
        _token_importance.png  bar chart of average importance per token
        _report.md             markdown summary
        _no_outputs.png        written instead of the above when no GCN output was captured

    Args:
        model: The sarcasm detection model (SarcasmGCNLSTMDetector instance or None if model_path provided)
        tokenizer: RoBERTa tokenizer
        comment: The comment text to analyze
        context: The context for the comment (optional); a list is joined with spaces
        model_path: Path to load model from (optional)
        device: Device to run model on ('cuda' or 'cpu')
        save_path: Base path to save visualization files
        format_tokens: List of tokens to exclude from analysis (default: ["comment", "context", ":", " "])
        layout_type: Graph layout algorithm to use ('spring', 'kamada', 'spectral', 'circular');
            any other value falls back to 'spring'

    Returns:
        Dictionary with GCN node influence data: 'prediction', 'confidence',
        'tokens' (cleaned, filtered), 'node_analysis', 'node_importance_by_layer'
        and 'graph' (the filtered final-layer graph). When no GCN output was
        captured the dictionary instead holds 'prediction', 'confidence',
        'tokens' (raw), 'graph' (unfiltered) and 'error'.

    Raises:
        ValueError: If neither `model` nor `model_path` is given.
    """
    # Create directory for outputs (a bare file prefix has no dirname -> use cwd)
    os.makedirs(os.path.dirname(save_path) or '.', exist_ok=True)

    # Default format tokens to exclude if not provided
    if format_tokens is None:
        format_tokens = ["comment", "context", ":", " "]

    # Load model if path provided
    if model is None:
        if not model_path:
            raise ValueError("Either `model` or `model_path` must be provided")
        # NOTE(security): torch.load unpickles the checkpoint; only load files
        # from trusted sources.
        model = SarcasmGCNLSTMDetector().to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))

    # Normalize the context into a single string
    if context is None:
        context = ''
    elif isinstance(context, list):
        context = " ".join(str(c) for c in context if c)

    # Create a one-row dataset so the graph is built exactly as during training
    dummy_df = pd.DataFrame({'comment': [comment], 'context': [context], 'label': [0]})
    item = SarcasmGraphDataset(dummy_df, tokenizer)[0]
    graph_data = item['graph_data']
    tokens = item['tokens']

    # True for tokens that take part in the analysis
    token_mask = [not should_exclude_token(token, format_tokens) for token in tokens]

    # Prepare batch
    batch = collate_batch([item])
    input_ids = batch['input_ids'].to(device)
    attention_mask = batch['attention_mask'].to(device)
    graph_x = batch['graph_x'].to(device)
    graph_edge_index = batch['graph_edge_index'].to(device)

    # Wrap the model so the GCN layer outputs are captured. The hooks live on
    # the caller's model, so they are removed in `finally` whatever happens.
    wrapped_model = GCNNodeVisualization(model)
    try:
        wrapped_model.to(device)
        wrapped_model.eval()

        with torch.no_grad():
            try:
                logits = wrapped_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    graph_x=graph_x,
                    graph_edge_index=graph_edge_index
                )

                prediction_prob = torch.sigmoid(logits).item()
                prediction = "Sarcastic" if prediction_prob > 0.5 else "Not Sarcastic"
                confidence = prediction_prob if prediction_prob > 0.5 else 1 - prediction_prob
            except Exception as e:
                # Best effort: keep going with whatever layer outputs were captured
                print(f"Warning: Error during model prediction: {str(e)}")
                prediction = "Unknown"
                confidence = 0.0

        gcn_outputs = wrapped_model.gcn_outputs
    finally:
        wrapped_model.remove_hooks()

    print(f"Model prediction: {prediction} (Confidence: {confidence:.4f})")

    # Get the original graph structure using networkx
    G = to_networkx(graph_data, to_undirected=True)

    # If no nodes, create a simple placeholder graph
    if G.number_of_nodes() == 0:
        print("Warning: Graph has no nodes. Creating a simple placeholder graph.")
        G = nx.Graph()
        G.add_node(0, word="[No Valid Graph]")

    # Add token text and the include flag to graph nodes
    n_graph_nodes = len(G.nodes)
    for i, (token, keep) in enumerate(zip(tokens, token_mask)):
        if i >= n_graph_nodes:
            break
        G.nodes[i]['word'] = token
        G.nodes[i]['include'] = keep

    # Node importance at each layer = L2 norm of the node's feature vector
    node_importance_by_layer = [
        np.linalg.norm(layer_output.detach().cpu().numpy(), axis=1)
        for _, layer_output in gcn_outputs
    ]

    if not node_importance_by_layer:
        print("Warning: No GCN outputs captured. Visualization will be limited.")
        # Create a simple visualization to indicate the issue
        fig, ax = plt.subplots(figsize=(10, 6))
        ax.text(0.5, 0.5, "No GCN outputs captured", horizontalalignment='center',
                verticalalignment='center', transform=ax.transAxes, fontsize=14)
        ax.axis('off')
        fig.savefig(f"{save_path}_no_outputs.png", dpi=300, bbox_inches='tight')
        plt.close(fig)

        # Return limited data
        return {
            'prediction': prediction,
            'confidence': confidence,
            'tokens': tokens,
            'graph': G,
            'error': 'No GCN outputs captured'
        }

    # 1. Visualize node importance evolution across layers
    _gcnviz_plot_evolution(tokens, token_mask, node_importance_by_layer, save_path)

    # 2. Create graph visualizations for each layer
    for layer_number, node_importance in enumerate(node_importance_by_layer, start=1):
        _gcnviz_plot_layer(G, tokens, node_importance, layer_number, layout_type, save_path)

    # 3. Create a final visualization with all important connections
    G_filtered = _gcnviz_plot_final(G, tokens, node_importance_by_layer[-1],
                                    layout_type, save_path)

    # 4. Create a detailed analysis table - only for included tokens
    node_analysis = _gcnviz_node_analysis(tokens, token_mask, node_importance_by_layer)
    pd.DataFrame(node_analysis).to_csv(f"{save_path}_analysis.csv", index=False)

    # 5. Aggregated per-token importance chart
    _gcnviz_plot_token_importance(node_analysis, save_path)

    # 6. Create a summary report (best effort: a failure here must not lose the results)
    try:
        report = _gcnviz_build_report(comment, context, prediction, confidence,
                                      node_analysis, len(node_importance_by_layer))

        # Tokens may contain non-ASCII characters, so do not rely on the locale encoding
        with open(f"{save_path}_report.md", 'w', encoding='utf-8') as f:
            f.write(report)

        print(f"Full GCN node analysis report saved to {save_path}_report.md")
    except Exception as e:
        print(f"Error generating report: {str(e)}")

    # Return analysis data
    return {
        'prediction': prediction,
        'confidence': confidence,
        # Only return filtered tokens
        'tokens': [clean_token_for_display(token)
                   for token, keep in zip(tokens, token_mask) if keep],
        'node_analysis': node_analysis,
        'node_importance_by_layer': node_importance_by_layer,
        'graph': G_filtered  # Return the filtered graph
    }


def _gcnviz_layout(graph, layout_type, spring_k, spring_iterations):
    """Return node positions for `graph` using the requested layout algorithm."""
    if layout_type == 'kamada':
        return nx.kamada_kawai_layout(graph)
    if layout_type == 'spectral':
        return nx.spectral_layout(graph)
    if layout_type == 'circular':
        return nx.circular_layout(graph)
    # Default: seeded spring layout so repeated runs give the same picture
    return nx.spring_layout(graph, k=spring_k, iterations=spring_iterations, seed=42)


def _gcnviz_filtered_graph(graph, node_importance):
    """
    Build the graph of included nodes, annotated with their importance.

    Returns `(filtered, node_mapping)` where `node_mapping` maps original node
    ids to the consecutive ids used in `filtered`. `graph` itself is not modified.
    """
    annotated = graph.copy()
    n_nodes = len(annotated.nodes)
    for i, importance in enumerate(node_importance):
        if i < n_nodes:
            annotated.nodes[i]['importance'] = float(importance)

    filtered = nx.Graph()
    node_mapping = {}
    for node, attrs in annotated.nodes(data=True):
        # Nodes without a token (no 'include' flag) are kept
        if attrs.get('include', True):
            new_id = len(node_mapping)
            filtered.add_node(new_id, **attrs)
            node_mapping[node] = new_id

    # Keep only edges whose two endpoints both survived the filter
    for u, v in annotated.edges:
        if u in node_mapping and v in node_mapping:
            filtered.add_edge(node_mapping[u], node_mapping[v])

    return filtered, node_mapping


def _gcnviz_display_labels(node_mapping, tokens):
    """Map filtered node ids to the cleaned token text shown next to each node."""
    return {
        new_id: clean_token_for_display(tokens[orig_id])
        for orig_id, new_id in node_mapping.items()
        if orig_id < len(tokens)
    }


def _gcnviz_draw_nodes(fig, ax, graph, pos, node_sizes, node_colors,
                       max_importance, cmap, alpha, colorbar_label):
    """Draw the nodes and a colorbar that share one 0..max importance scale."""
    # Without explicit vmin/vmax networkx scales colors to min..max of the data,
    # which would disagree with a colorbar that starts at 0.
    color_max = max_importance if max_importance > 0 else 1.0
    nx.draw_networkx_nodes(graph, pos, node_size=node_sizes, node_color=node_colors,
                           cmap=cmap, vmin=0, vmax=color_max, alpha=alpha, ax=ax)

    sm = ScalarMappable(cmap=cmap, norm=Normalize(vmin=0, vmax=color_max))
    sm.set_array([])
    cbar = fig.colorbar(sm, ax=ax)
    cbar.set_label(colorbar_label)


def _gcnviz_draw_label(ax, x, y, text, fontsize, pad):
    """Draw one boxed, bold token label with a white outline for legibility."""
    label = ax.text(x, y, text,
                    fontsize=fontsize, ha='center', va='center', weight='bold',
                    bbox=dict(boxstyle="round", fc="white", ec="black", alpha=0.9, pad=pad))
    label.set_path_effects([pe.withStroke(linewidth=2, foreground='white')])


def _gcnviz_repelled_label_positions(pos, size_by_node, iterations=10):
    """
    Place one label per node, pushing close labels apart.

    Labels start diagonally offset away from the layout origin and are then
    adjusted for `iterations` rounds: labels closer than 0.2 repel each other,
    and a label further than 0.3 from its node is pulled back.
    """
    label_positions = {}

    # First pass - initial positions, offset towards the node's own quadrant
    for node, node_size in size_by_node.items():
        x, y = pos[node]
        node_radius = np.sqrt(node_size / np.pi)
        base_offset = max(0.1, node_radius / 2000)
        offset_x = base_offset if x >= 0 else -base_offset
        offset_y = base_offset if y >= 0 else -base_offset
        label_positions[node] = (x + offset_x, y + offset_y)

    # Second pass - sequential updates: each label sees the already-moved
    # positions of the labels handled before it in the same round.
    ordered_nodes = list(label_positions)
    for _ in range(iterations):
        for node in ordered_nodes:
            x, y = label_positions[node]
            fx, fy = 0, 0  # Force components

            # Repulsive forces from other labels
            for other_node, (other_x, other_y) in label_positions.items():
                if other_node == node:
                    continue
                dx = x - other_x
                dy = y - other_y
                dist = max(0.01, np.sqrt(dx*dx + dy*dy))  # Avoid division by zero
                if dist < 0.2:  # Only apply repulsion for close labels
                    force = 0.001 / (dist * dist)
                    fx += force * dx / dist
                    fy += force * dy / dist

            # Attractive force to original node position
            node_x, node_y = pos[node]
            dx = node_x - x
            dy = node_y - y
            dist = max(0.01, np.sqrt(dx*dx + dy*dy))
            if dist > 0.3:  # Only pull back if too far
                force = 0.1 * dist
                fx += force * dx / dist
                fy += force * dy / dist

            label_positions[node] = (x + 0.1 * fx, y + 0.1 * fy)

    return label_positions


def _gcnviz_radial_label_positions(graph, pos, labels, max_importance):
    """
    Place each label on a circle around its node at the least crowded angle.

    Sixteen angles are tried per node; a candidate is penalised for lying within
    0.2 of another node or of a label placed earlier. Ties keep the first angle.
    """
    angles = np.linspace(0, 2*np.pi, 16, endpoint=False)
    label_positions = {}

    for node in graph.nodes:
        if node not in labels:
            continue
        x, y = pos[node]

        # Offset grows slightly with the node's relative importance (its drawn size)
        importance = graph.nodes[node].get('importance', 0)
        relative_size = importance / max_importance if max_importance > 0 else 0.5
        offset = 0.15 + relative_size * 0.05

        best_angle = 0
        min_overlap = float('inf')

        for angle in angles:
            test_x = x + offset * np.cos(angle)
            test_y = y + offset * np.sin(angle)
            overlap_score = 0

            # Penalty for being close to other nodes
            for other_node in graph.nodes:
                if other_node != node:
                    other_x, other_y = pos[other_node]
                    dist = np.sqrt((test_x - other_x)**2 + (test_y - other_y)**2)
                    if dist < 0.2:
                        overlap_score += (0.2 - dist) * 5

            # Larger penalty for being close to labels that are already placed
            for other_x, other_y in label_positions.values():
                dist = np.sqrt((test_x - other_x)**2 + (test_y - other_y)**2)
                if dist < 0.2:
                    overlap_score += (0.2 - dist) * 10

            if overlap_score < min_overlap:
                min_overlap = overlap_score
                best_angle = angle

        label_positions[node] = (x + offset * np.cos(best_angle),
                                 y + offset * np.sin(best_angle))

    return label_positions


def _gcnviz_plot_evolution(tokens, token_mask, node_importance_by_layer, save_path,
                           max_nodes=10):
    """Plot how the importance of the top final-layer nodes changes across layers."""
    n_layers = len(node_importance_by_layer)
    layer_numbers = list(range(1, n_layers + 1))
    final_importance = node_importance_by_layer[-1]

    # Indices of tokens that passed the format-token filter
    candidates = [i for i, keep in enumerate(token_mask) if keep and i < len(tokens)]

    fig, ax = plt.subplots(figsize=(14, 8))
    plotted_any = False

    if candidates and len(final_importance) > 0:
        scores = [final_importance[i] if i < len(final_importance) else 0
                  for i in candidates]
        # argsort is ascending, so the tail holds the most important candidates
        top_positions = np.argsort(scores)[-min(max_nodes, len(candidates)):]

        for node_idx in (candidates[p] for p in top_positions):
            display_token = clean_token_for_display(tokens[node_idx])
            values = [layer[node_idx] if node_idx < len(layer) else 0
                      for layer in node_importance_by_layer]
            ax.plot(layer_numbers, values, marker='o',
                    label=f"{display_token} (Node {node_idx})")
            plotted_any = True

    ax.set_title('Node Importance Evolution Across GCN Layers')
    ax.set_xlabel('GCN Layer')
    ax.set_ylabel('Node Importance (L2 Norm)')
    ax.set_xticks(layer_numbers)
    ax.grid(True, linestyle='--', alpha=0.7)
    if plotted_any:
        # An empty legend only produces a matplotlib warning
        ax.legend(loc='best')

    fig.tight_layout()
    fig.savefig(f"{save_path}_evolution.png", dpi=300, bbox_inches='tight')
    plt.close(fig)


def _gcnviz_plot_layer(graph, tokens, node_importance, layer_number, layout_type, save_path):
    """Draw the filtered graph for one GCN layer, sized and colored by importance."""
    fig, ax = plt.subplots(figsize=(12, 12))
    filtered, node_mapping = _gcnviz_filtered_graph(graph, node_importance)

    if filtered.number_of_nodes() > 0:
        importances = [filtered.nodes[n].get('importance', 0) for n in filtered.nodes]
        max_importance = max(importances)

        # Set node sizes proportional to importance
        node_sizes = [1000 * (value / max_importance if max_importance > 0 else 0)
                      for value in importances]

        pos = _gcnviz_layout(filtered, layout_type, spring_k=0.9, spring_iterations=100)

        nx.draw_networkx_edges(filtered, pos, alpha=0.3, width=1.0, ax=ax)
        _gcnviz_draw_nodes(fig, ax, filtered, pos, node_sizes, importances,
                           max_importance, plt.cm.viridis, 0.8, 'Node Importance')

        labels = _gcnviz_display_labels(node_mapping, tokens)
        label_positions = _gcnviz_repelled_label_positions(
            pos, dict(zip(filtered.nodes, node_sizes)))

        for node, (label_x, label_y) in label_positions.items():
            if node in labels:
                _gcnviz_draw_label(ax, label_x, label_y, labels[node], fontsize=11, pad=0.3)
    else:
        # No nodes in filtered graph
        ax.text(0.5, 0.5, "No nodes remain after filtering",
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=12)

    ax.set_title(f'Node Importance in GCN Layer {layer_number}', pad=20)
    ax.axis('off')

    fig.tight_layout()
    fig.savefig(f"{save_path}_layer{layer_number}.png", dpi=300, bbox_inches='tight')
    plt.close(fig)


def _gcnviz_plot_final(graph, tokens, final_importance, layout_type, save_path):
    """
    Draw the final-layer graph with importance-weighted edges and return it.

    Each edge of the returned graph carries a 'weight' attribute equal to the
    mean importance of its two endpoints.
    """
    fig, ax = plt.subplots(figsize=(15, 15))
    filtered, node_mapping = _gcnviz_filtered_graph(graph, final_importance)

    if filtered.number_of_nodes() > 0:
        importances = [filtered.nodes[n].get('importance', 0) for n in filtered.nodes]
        max_importance = max(importances)

        pos = _gcnviz_layout(filtered, layout_type, spring_k=1.2, spring_iterations=200)

        # Edge weight = mean importance of the two endpoints
        edge_weights = []
        for u, v in filtered.edges:
            weight = (filtered.nodes[u].get('importance', 0)
                      + filtered.nodes[v].get('importance', 0)) / 2.0
            filtered[u][v]['weight'] = weight
            edge_weights.append(weight)

        # Normalize edge weights for visualization
        if edge_weights:
            max_edge_weight = max(edge_weights)
            edge_widths = [3.0 * weight / max_edge_weight if max_edge_weight > 0 else 1.0
                           for weight in edge_weights]
        else:
            edge_widths = [1.0]

        # Set node sizes based on importance
        node_sizes = [2000 * (value / max_importance) if max_importance > 0 else 500
                      for value in importances]

        # Draw graph - first edges, then nodes
        nx.draw_networkx_edges(filtered, pos, alpha=0.5, width=edge_widths, ax=ax)
        _gcnviz_draw_nodes(fig, ax, filtered, pos, node_sizes, importances,
                           max_importance, plt.cm.plasma, 0.9,
                           'Node Importance (Final Layer)')

        labels = _gcnviz_display_labels(node_mapping, tokens)
        label_positions = _gcnviz_radial_label_positions(filtered, pos, labels,
                                                         max_importance)

        # Add node labels and connecting lines
        for node, (label_x, label_y) in label_positions.items():
            node_x, node_y = pos[node]
            ax.plot([node_x, label_x], [node_y, label_y], 'k-', alpha=0.3, linewidth=0.5)
            _gcnviz_draw_label(ax, label_x, label_y, labels[node], fontsize=12, pad=0.4)
    else:
        # No nodes in filtered graph
        ax.text(0.5, 0.5, "No nodes remain after filtering",
                horizontalalignment='center', verticalalignment='center',
                transform=ax.transAxes, fontsize=14)

    ax.set_title('Final GCN Layer Node Importance and Connections', fontsize=16, pad=20)
    ax.axis('off')

    fig.tight_layout()
    fig.savefig(f"{save_path}_final.png", dpi=300, bbox_inches='tight')
    plt.close(fig)

    return filtered


def _gcnviz_node_analysis(tokens, token_mask, node_importance_by_layer):
    """Return one analysis dict per included token, most important first."""
    first_layer = node_importance_by_layer[0]
    final_layer = node_importance_by_layer[-1]

    rows = []
    for i, token in enumerate(tokens):
        if i >= len(final_layer) or not token_mask[i]:
            continue
        # Growth rate is only defined against a positive first-layer importance
        has_baseline = i < len(first_layer) and first_layer[i] > 0
        rows.append({
            'node_idx': i,
            'token': token,
            'display_token': clean_token_for_display(token),
            'final_importance': float(final_layer[i]),
            'layer_evolution': [float(layer[i]) if i < len(layer) else 0.0
                                for layer in node_importance_by_layer],
            'growth_rate': float(final_layer[i] / first_layer[i]) if has_baseline else 0
        })

    rows.sort(key=lambda row: row['final_importance'], reverse=True)
    return rows


def _gcnviz_plot_token_importance(node_analysis, save_path, max_tokens=15):
    """Plot the average final-layer importance of the top display tokens."""
    # The same display token can occur at several nodes; average over them
    aggregated_importance = defaultdict(float)
    token_count = defaultdict(int)
    for row in node_analysis:
        aggregated_importance[row['display_token']] += row['final_importance']
        token_count[row['display_token']] += 1

    avg_importance = {token: total / token_count[token]
                      for token, total in aggregated_importance.items()}
    sorted_tokens = sorted(avg_importance.items(), key=lambda pair: pair[1], reverse=True)

    if not sorted_tokens:
        print("Warning: No token importance data available for visualization")
        return

    top = sorted_tokens[:max_tokens]
    bar_positions = list(range(len(top)))

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.barh(bar_positions, [value for _, value in top], color='coral')
    ax.set_yticks(bar_positions)
    ax.set_yticklabels([token for token, _ in top])
    ax.set_title('Most Important Tokens in Sarcasm Detection (Final GCN Layer)')
    ax.set_xlabel('Average Node Importance')
    ax.grid(axis='x', linestyle='--', alpha=0.7)

    fig.tight_layout()
    fig.savefig(f"{save_path}_token_importance.png", dpi=300, bbox_inches='tight')
    plt.close(fig)


def _gcnviz_growth_pattern(values):
    """Name the trend of a node's per-layer importance values."""
    if len(values) < 2:
        return "N/A"
    pattern = "Increasing" if values[-1] > values[0] else "Decreasing"
    if len(values) >= 3:
        # A turn at the second layer overrides the overall trend
        if values[1] > values[0] and values[2] < values[1]:
            pattern = "Peak at Layer 2"
        elif values[1] < values[0] and values[2] > values[1]:
            pattern = "Dip at Layer 2"
    return pattern


def _gcnviz_build_report(comment, context, prediction, confidence, node_analysis,
                         n_layers, top_n=5):
    """Return the markdown summary report as a single string."""
    parts = [
        "\n"
        "# GCN Node Influence Analysis for Sarcasm Detection\n"
        "\n"
        "## Overview\n"
        f"- **Text**: \"{comment}\"\n"
        f"- **Context**: \"{context if context else 'None'}\"\n"
        f"- **Prediction**: {prediction} (Confidence: {confidence:.4f})\n"
        f"- **Number of Nodes**: {len(node_analysis)} (after filtering)\n"
        f"- **Tokens Analyzed**: {len(node_analysis)}\n"
        "\n"
        "## Key Findings\n"
        "\n"
        "### Most Important Nodes\n"
        "The following tokens have the highest importance in the final GCN layer:\n"
        "        \n"
        "| Node | Token | Importance | Growth Rate |\n"
        "|------|-------|------------|-------------|\n"
    ]

    top_rows = node_analysis[:top_n]
    for row in top_rows:
        parts.append(
            f"| {row['node_idx']} | {row['display_token']} | "
            f"{row['final_importance']:.4f} | {row['growth_rate']:.2f}x |\n"
        )

    if node_analysis:
        # One column per captured layer, so the table always matches the data
        layer_columns = "".join(f" Layer {k} |" for k in range(1, n_layers + 1))
        parts.append(
            "\n"
            "### Node Importance Evolution\n"
            "The following shows how node importance evolves across GCN layers:\n"
            "\n"
            f"| Node | Token |{layer_columns} Growth Pattern |\n"
            "|------|-------|" + "---------|" * n_layers + "----------------|\n"
        )
        for row in top_rows:
            values = row['layer_evolution']
            cells = " | ".join(f"{value:.4f}" for value in values)
            parts.append(
                f"| {row['node_idx']} | {row['display_token']} | {cells} | "
                f"{_gcnviz_growth_pattern(values)} |\n"
            )
    else:
        parts.append("\nNo significant node importance data available for analysis.")

    return "".join(parts)

# In [2]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from transformers import (
    RobertaTokenizer,
    RobertaModel,
    # AdamW,
    get_linear_schedule_with_warmup,
)
from torch.optim import AdamW
from sklearn.metrics import (
    classification_report,
    accuracy_score,
    f1_score,
    confusion_matrix,
)
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import ast
import os
import gc
import networkx as nx
from torch_geometric.nn import GCNConv
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
import numpy as np
import re
import nltk
from nltk.tokenize import word_tokenize
from sklearn.utils import resample
import spacy
from senticnet.senticnet import SenticNet
import gensim.downloader as gensim_downloader

# ---------------------------------------------------------------------------
# Module-level resources shared by the dataset and model definitions below:
#   glove_embeddings - gensim word vectors, or None when loading failed
#   EMBEDDING_DIM    - width of one word vector (300 for this GloVe release)
#   nlp              - spaCy pipeline used for tokenization / dependency parsing
#   sn               - SenticNet lexicon, or None when loading failed
# ---------------------------------------------------------------------------
EMBEDDING_DIM = 300

print("Loading GloVe embeddings...")
try:
    glove_embeddings = gensim_downloader.load("glove-wiki-gigaword-300")
    print(f"Loaded GloVe embeddings with dimension: {EMBEDDING_DIM}")
except Exception as e:
    # Best-effort: downstream code falls back to random vectors when this is None.
    print(f"Error loading GloVe embeddings: {str(e)}")
    print("Using random embeddings instead")
    glove_embeddings = None

# Load spaCy model
try:
    nlp = spacy.load("en_core_web_sm")
    print("Loaded spaCy model successfully")
except OSError:
    # spaCy raises OSError when the model package is not installed. Catching
    # only that keeps Ctrl-C and unrelated failures from triggering a download.
    print("Downloading spaCy model...")
    import subprocess
    import sys

    # Use the running interpreter so the model lands in the active environment,
    # and fail loudly (check=True) if the download itself does not succeed.
    subprocess.run(
        [sys.executable, "-m", "spacy", "download", "en_core_web_sm"],
        check=True,
    )
    nlp = spacy.load("en_core_web_sm")

# Initialize SenticNet
try:
    sn = SenticNet()
    print("Loaded SenticNet successfully")
except Exception as e:
    # Best-effort: sentiment features become all-zero when this is None.
    print(f"Error loading SenticNet: {str(e)}")
    sn = None

class SarcasmGraphDataset(Dataset):
    """Dataset yielding a RoBERTa encoding plus a token graph per example.

    Each row's context and comment are merged into one string. That string is
    tokenized for the transformer and, separately, parsed with spaCy into a
    graph whose nodes carry GloVe + SenticNet features.

    Relies on the module-level ``nlp``, ``glove_embeddings``, ``sn`` and
    ``EMBEDDING_DIM`` objects defined earlier in this file.
    """

    # Number of sentiment values appended to every word vector:
    # [polarity, is_positive, is_negative, is_neutral, intensity]
    SENTIMENT_DIM = 5

    def __init__(self, df, tokenizer, max_length=128):
        """Store the columns of *df* that the dataset reads.

        Args:
            df: table with ``"comment"``, ``"context"`` and ``"label"`` columns.
            tokenizer: callable tokenizer used in ``__getitem__``.
            max_length: padding/truncation length passed to the tokenizer.
        """
        self.tokenizer = tokenizer
        self.comments = df["comment"].values
        self.contexts = df["context"].values
        self.labels = df["label"].values
        self.max_length = max_length
        self.window_size = 2  # Window size for graph construction

    def __len__(self):
        """Return the number of examples."""
        return len(self.comments)

    def get_embedding(self, word):
        """Return the GloVe vector for *word* (lower-cased) as a float tensor.

        Falls back to a freshly sampled random vector of length
        ``EMBEDDING_DIM`` when the embeddings are unavailable or the word is
        out of vocabulary. The fallback is re-sampled on every call.
        """
        word = word.lower()
        if glove_embeddings is not None and word in glove_embeddings:
            return torch.tensor(glove_embeddings[word], dtype=torch.float)
        # Use random embedding if word not found
        return torch.randn(EMBEDDING_DIM, dtype=torch.float)

    def get_sentiment_features(self, word):
        """Return the 5 sentiment features of *word* from SenticNet.

        The layout is ``[polarity, is_positive, is_negative, is_neutral,
        intensity]``. A zero vector is returned when SenticNet is unavailable
        or has no usable entry for the word.
        """
        if sn is None:
            return torch.zeros(self.SENTIMENT_DIM, dtype=torch.float)

        try:
            polarity = float(sn.concept(word)["polarity_value"])
        except (KeyError, IndexError, ValueError, TypeError):
            # Word not found in SenticNet, or its polarity is not numeric
            return torch.zeros(self.SENTIMENT_DIM, dtype=torch.float)

        is_positive = 1.0 if polarity > 0.1 else 0.0
        is_negative = 1.0 if polarity < -0.1 else 0.0
        is_neutral = 1.0 if abs(polarity) <= 0.1 else 0.0
        intensity = abs(polarity)
        return torch.tensor(
            [polarity, is_positive, is_negative, is_neutral, intensity],
            dtype=torch.float,
        )

    def create_graph_from_text(self, text):
        """Build the token graph for *text*.

        Nodes are spaCy tokens of the lower-cased text. Edges connect tokens
        at most ``self.window_size`` positions apart (``edge_type=0``) and
        every token to its dependency head (``edge_type=1``). Because the
        graph is undirected and simple, a dependency edge between two tokens
        that are already window neighbours overwrites that edge's type.

        Returns:
            ``(data, tokens)`` where ``data`` is the PyTorch Geometric object
            with ``data.x`` of shape ``(num_tokens, EMBEDDING_DIM + 5)`` and
            ``tokens`` is the list of token strings. For text that yields no
            tokens, a placeholder with one all-zero node and no edges is
            returned together with an empty token list.
        """
        # Computed up front: both the normal and the empty-graph path need it.
        feature_dim = EMBEDDING_DIM + self.SENTIMENT_DIM

        # Parse text with spaCy for dependency parsing
        doc = nlp(text.lower())

        # Store tokens for later embedding lookup
        tokens = [token.text for token in doc]

        if not tokens:
            # Return empty graph if there are no nodes
            empty_data = Data(
                x=torch.zeros((1, feature_dim), dtype=torch.float),
                edge_index=torch.zeros((2, 0), dtype=torch.long),
            )
            return empty_data, []

        # Create a graph where nodes are tokens
        G = nx.Graph()
        for i, token in enumerate(doc):
            G.add_node(i, word=token.text, pos=token.pos_)

        # 1. Window-based edges
        num_tokens = len(tokens)
        for i in range(num_tokens):
            for j in range(i + 1, min(i + self.window_size + 1, num_tokens)):
                G.add_edge(i, j, edge_type=0)  # Type 0: window edge

        # 2. Dependency-based edges
        for token in doc:
            G.add_edge(token.i, token.head.i, edge_type=1)  # Type 1: dependency edge

        # Convert to PyTorch Geometric Data object
        data = from_networkx(G)

        # Feature matrix for nodes: [GloVe (EMBEDDING_DIM) + sentiment (5)]
        features = torch.zeros((num_tokens, feature_dim), dtype=torch.float)
        for i, token_text in enumerate(tokens):
            glove_feature = self.get_embedding(token_text)
            sentiment_feature = self.get_sentiment_features(token_text)
            # Rows whose parts have unexpected sizes are left as zeros.
            if (
                len(glove_feature) == EMBEDDING_DIM
                and len(sentiment_feature) == self.SENTIMENT_DIM
            ):
                features[i] = torch.cat([glove_feature, sentiment_feature])

        data.x = features
        return data, tokens

    def __getitem__(self, idx):
        """Return the model inputs for example *idx*.

        Returns:
            dict with ``"input_ids"`` and ``"attention_mask"`` (flattened
            tokenizer output), ``"graph_data"``, ``"tokens"`` and a float
            ``"label"`` tensor.
        """
        comment = str(self.comments[idx])

        # Contexts may be stored as the string form of a Python list.
        raw_context = self.contexts[idx]
        if isinstance(raw_context, str):
            try:
                parsed = ast.literal_eval(raw_context)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                parsed = None

            if isinstance(parsed, (list, tuple)):
                context_list = parsed
            elif isinstance(parsed, str):
                # A quoted string must stay one element, not be split into chars.
                context_list = [parsed]
            else:
                # Not a literal list (plain text, a number, ...): use it verbatim.
                context_list = [raw_context]
        else:
            context_list = raw_context

        # Join all context elements
        context = " ".join(str(c) for c in context_list)

        # Combine context and comment
        combined_text = f"Context: {context} Comment: {comment}"

        # Create graph data with enhanced features
        graph_data, tokens = self.create_graph_from_text(combined_text)

        # Encode with truncation and padding for transformer
        encoding = self.tokenizer(
            combined_text,
            truncation=True,
            padding="max_length",
            max_length=self.max_length,
            return_tensors="pt",
        )

        return {
            "input_ids": encoding["input_ids"].flatten(),
            "attention_mask": encoding["attention_mask"].flatten(),
            "graph_data": graph_data,
            "tokens": tokens,
            "label": torch.tensor(self.labels[idx], dtype=torch.float),
        }

class GCNLayer(nn.Module):
    """One graph-convolution stage: GCNConv -> BatchNorm -> ReLU -> Dropout(0.2)."""

    def __init__(self, in_features, out_features):
        super().__init__()
        self.gc = GCNConv(in_features, out_features)
        self.bn = nn.BatchNorm1d(out_features)
        self.dropout = nn.Dropout(0.2)

    def forward(self, x, edge_index):
        """Apply the stage to node features *x* over the edges in *edge_index*."""
        hidden = self.gc(x, edge_index)

        # Batch normalization is skipped when there is only a single row.
        multiple_rows = hidden.size(0) > 1
        if multiple_rows:
            hidden = self.bn(hidden)

        activated = F.relu(hidden)
        return self.dropout(activated)


class LSTM(nn.Module):
    """Thin wrapper around ``nn.LSTM`` that returns only the final time step.

    The wrapped LSTM is batch-first, so inputs are
    ``(batch, seq_len, input_dim)``.
    """

    def __init__(self, input_dim, hidden_dim, num_layers=1, bidirectional=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.num_layers = num_layers
        self.bidirectional = bidirectional
        self.lstm = nn.LSTM(
            input_size=input_dim,
            hidden_size=hidden_dim,
            num_layers=num_layers,
            batch_first=True,
            bidirectional=bidirectional,
        )

    def forward(self, x):
        """Run the LSTM over *x* and return the output at the last position.

        The slice is taken at index ``-1`` of the sequence axis; padding is not
        taken into account.
        """
        sequence_output = self.lstm(x)[0]
        return sequence_output[:, -1, :]


class SarcasmGCNLSTMDetector(nn.Module):
    """Sarcasm classifier fusing a RoBERTa text vector with a GCN/LSTM graph vector.

    The graph branch is four stacked ``GCNLayer`` stages (``gcn1``..``gcn4``)
    followed by a bidirectional ``LSTM``. The two vectors are concatenated,
    scaled by a learned sigmoid gate, and passed through a small MLP that
    emits one logit per example.
    """

    def __init__(
        self, pretrained_model="roberta-base", gcn_hidden_dim=64, dropout_rate=0.3
    ):
        super().__init__()

        # Text encoder
        self.roberta = RobertaModel.from_pretrained(pretrained_model)
        self.hidden_dim = self.roberta.config.hidden_size

        # Node features are a GloVe vector plus 5 sentiment values
        node_feature_dim = EMBEDDING_DIM + 5
        wide_dim = gcn_hidden_dim * 2

        # 4-layer GCN as per the paper: narrow -> wide -> wide -> narrow
        self.gcn1 = GCNLayer(node_feature_dim, gcn_hidden_dim)
        self.gcn2 = GCNLayer(gcn_hidden_dim, wide_dim)
        self.gcn3 = GCNLayer(wide_dim, wide_dim)
        self.gcn4 = GCNLayer(wide_dim, gcn_hidden_dim)

        # Bidirectional LSTM over the node sequence; each direction gets half
        # of gcn_hidden_dim
        self.lstm = LSTM(gcn_hidden_dim, gcn_hidden_dim // 2, bidirectional=True)

        # Gate and classifier both consume the concatenated text + graph vector
        fused_dim = self.hidden_dim + gcn_hidden_dim
        self.attention = nn.Linear(fused_dim, 1)

        self.dropout = nn.Dropout(dropout_rate)
        self.fc1 = nn.Linear(fused_dim, 256)
        self.fc2 = nn.Linear(256, 1)

    def forward(self, input_ids, attention_mask, graph_x, graph_edge_index):
        """Return raw logits (no sigmoid applied) for the given inputs.

        When the text batch holds more than one example, the graph vector is
        the mean of all final node embeddings repeated for every example and
        the LSTM is bypassed; otherwise the node embeddings are fed through
        the LSTM as one sequence.
        """
        # Text branch: RoBERTa pooled output
        encoder_out = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        text_vector = encoder_out.pooler_output

        # Graph branch: run the four GCN stages in order
        node_states = graph_x
        for gcn_stage in (self.gcn1, self.gcn2, self.gcn3, self.gcn4):
            node_states = gcn_stage(node_states, graph_edge_index)

        num_examples = text_vector.shape[0]
        if num_examples <= 1:
            # [num_nodes, features] -> [1, num_nodes, features] for the LSTM
            graph_vector = self.lstm(node_states.unsqueeze(0))
        else:
            # Simplification for batches: one mean-pooled vector shared by all rows
            pooled = torch.mean(node_states, dim=0).unsqueeze(0)
            graph_vector = pooled.expand(num_examples, -1)

        # Fuse both branches and gate the result with a scalar in (0, 1)
        fused = torch.cat((text_vector, graph_vector), dim=1)
        gate = torch.sigmoid(self.attention(fused))
        gated = fused * gate

        # Classification head
        hidden = self.dropout(gated)
        hidden = F.relu(self.fc1(hidden))
        hidden = self.dropout(hidden)
        return self.fc2(hidden)

Loading GloVe embeddings...
Loaded GloVe embeddings with dimension: 300
Loaded spaCy model successfully
Loaded SenticNet successfully


In [3]:
def collate_batch(batch):
    """Collate dataset items into one batch dict for the detector.

    Text tensors and labels are stacked across the batch. Graphs are NOT
    merged: only the first item's graph is forwarded (a deliberate
    simplification), with a single zero node and no edges as the fallback when
    that graph has no features.
    """
    stacked = {
        key: torch.stack([sample[key] for sample in batch])
        for key in ("input_ids", "attention_mask", "label")
    }
    tokens_list = [sample["tokens"] for sample in batch]

    # Collect per-item graph tensors; only index 0 is used below.
    graph_xs = []
    graph_edge_indices = []
    for sample in batch:
        graph = sample["graph_data"]
        graph_xs.append(graph.x)
        graph_edge_indices.append(graph.edge_index)

    fallback_dim = EMBEDDING_DIM + 5  # GloVe + Sentiment

    first_x = graph_xs[0] if graph_xs else None
    first_graph_usable = first_x is not None and first_x.numel() > 0
    if first_graph_usable:
        graph_x, graph_edge_index = first_x, graph_edge_indices[0]
    else:
        # Fallback for empty graphs
        graph_x = torch.zeros((1, fallback_dim), dtype=torch.float)
        graph_edge_index = torch.zeros((2, 0), dtype=torch.long)

    return {
        "input_ids": stacked["input_ids"],
        "attention_mask": stacked["attention_mask"],
        "graph_x": graph_x,
        "graph_edge_index": graph_edge_index,
        "tokens": tokens_list,
        "label": stacked["label"],
    }


In [4]:
# Checkpoint with the trained detector weights, relative to the notebook directory.
MODEL_PATH = "../sarcasm_gcn_lstm_detector_best.pt"
# Run on the GPU when CUDA is available, otherwise on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

def load_model_and_tokenizer():
    """Build the detector and its tokenizer, loading trained weights if present.

    The model is created on ``DEVICE`` and switched to evaluation mode. When
    no checkpoint exists at ``MODEL_PATH`` a warning is printed and the model
    is returned without the trained weights.

    Returns:
        ``(model, tokenizer)``
    """
    print(f"Loading model from {MODEL_PATH}...")

    tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
    model = SarcasmGCNLSTMDetector().to(DEVICE)

    checkpoint_found = os.path.exists(MODEL_PATH)
    if not checkpoint_found:
        print(f"Warning: Model file not found at {MODEL_PATH}")
    else:
        # Map tensors onto the active device so GPU checkpoints load on CPU too
        state_dict = torch.load(MODEL_PATH, map_location=DEVICE)
        model.load_state_dict(state_dict)
        print("Model loaded successfully!")

    # Evaluation mode for inference
    model.eval()

    return model, tokenizer

# Build the detector and tokenizer once for the analysis below.
model, tokenizer = load_model_and_tokenizer()

# Example usage: analyze one comment and write the report files under the
# 'sarcasm_gcn_analysis' prefix. The call stays the last expression so the
# notebook displays the returned analysis dict.
comment = "Congratulations on stating the obvious. I am sure glaciers will start moving any minute now"

visualize_gcn_node_influence(
    model,
    tokenizer,
    comment=comment,
    save_path='sarcasm_gcn_analysis',
)


Loading model from ../sarcasm_gcn_lstm_detector_best.pt...


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['pooler.dense.bias', 'pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Model loaded successfully!
Model prediction: Sarcastic (Confidence: 0.9583)
Full GCN node analysis report saved to sarcasm_gcn_analysis_report.md


{'prediction': 'Sarcastic',
 'confidence': 0.9583242535591125,
 'tokens': ['congratulations',
  'on',
  'stating',
  'the',
  'obvious',
  '.',
  'i',
  'am',
  'sure',
  'glaciers',
  'will',
  'start',
  'moving',
  'any',
  'minute',
  'now'],
 'node_analysis': [{'node_idx': 5,
   'token': 'congratulations',
   'display_token': 'congratulations',
   'final_importance': 10.397899627685547,
   'layer_evolution': [4.880075454711914,
    10.821146965026855,
    12.076385498046875,
    10.397899627685547],
   'growth_rate': 2.1306841373443604},
  {'node_idx': 6,
   'token': 'on',
   'display_token': 'on',
   'final_importance': 7.988095760345459,
   'layer_evolution': [4.9547834396362305,
    8.539734840393066,
    9.949209213256836,
    7.988095760345459],
   'growth_rate': 1.6121987104415894},
  {'node_idx': 13,
   'token': 'sure',
   'display_token': 'sure',
   'final_importance': 7.758020401000977,
   'layer_evolution': [5.312359809875488,
    8.459394454956055,
    11.00170135498046