# Code Graph Visualization

This notebook shows how source code is turned into the AST (Abstract Syntax Tree) graph representation consumed by the vulnerability-detection GNN model, and how to visualize that graph.

In [5]:
# Import necessary libraries
import sys
import os
import zlib
import torch
import networkx as nx
import matplotlib.pyplot as plt
import numpy as np
from tree_sitter import Language, Parser
from torch_geometric.utils import from_networkx, to_networkx

# Import functions from our project
from gnn_main import build_graph_from_code, enrich_node_features

ModuleNotFoundError: No module named 'gnn_main'

## 1. Load Tree-sitter Language

First, load the compiled tree-sitter grammar library (`languages.so`) and select the C grammar from it.

In [None]:
def find_language_library(search_root, suffix="languages.so"):
    """Return the first file under ``search_root`` whose name ends with ``suffix``.

    The tree is walked top-down with ``os.walk``. The function returns the
    joined path of the first match, or ``None`` when nothing matches.
    """
    for root, _dirs, files in os.walk(search_root):
        for file_name in files:
            if file_name.endswith(suffix):
                return os.path.join(root, file_name)
    return None


# Check if we have the language file (expected at <parent of cwd>/build/languages.so)
SEARCH_ROOT = os.path.dirname(os.getcwd())
LANGUAGE_PATH = os.path.join(SEARCH_ROOT, 'build', 'languages.so')

if not os.path.exists(LANGUAGE_PATH):
    print(f"Language file not found at {LANGUAGE_PATH}")
    # Try to find it anywhere under the parent directory
    found_path = find_language_library(SEARCH_ROOT)
    if found_path is not None:
        LANGUAGE_PATH = found_path
        print(f"Found language file at: {LANGUAGE_PATH}")

# Bind C_LANGUAGE up front. If every load attempt below fails, later cells then
# see an explicit "not loaded" value (None) instead of a NameError.
C_LANGUAGE = None
try:
    # Load the C language for demonstration
    C_LANGUAGE = Language(LANGUAGE_PATH, 'c')
    print("Successfully loaded C language")
except Exception as e:
    print(f"Error loading language: {e}")
    print("Trying to load from default path...")
    try:
        # Try with a relative path that might work in your environment
        C_LANGUAGE = Language('build/languages.so', 'c')
        print("Successfully loaded C language from default path")
    except Exception as e:
        print(f"Error loading language from default path: {e}")
        print("Please ensure tree-sitter language files are properly installed")

## 2. Sample Code

Let's create a simple code example to visualize: a C function with a classic buffer overflow (`strcpy` into a fixed-size buffer without bounds checking).

In [None]:
# Sample vulnerable C code (buffer overflow).
# Raw string: the C source must contain the two characters backslash + 'n'
# inside the printf format. In a normal Python string "\n" becomes a real
# newline, which splits the C string literal across lines (invalid C) and
# makes tree-sitter emit an ERROR node.
sample_code = r"""
void copy_data(char *user_input) {
    char buffer[10];
    strcpy(buffer, user_input);  // Vulnerable: no bounds checking
    printf("Buffer contains: %s\n", buffer);
}
"""

print(sample_code)

## 3. Create Graph from Code

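Now let's build a graph representation of the code with tree-sitter. Rather than calling the project's `build_graph_from_code`, we use a local, more detailed variant that also records each node's source snippet for labeling.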

In [None]:
# A more detailed local version of build_graph_from_code, used for visualization
def extract_snippet(source_bytes, start_byte, end_byte, max_len=20):
    """Decode ``source_bytes[start_byte:end_byte]`` and truncate it for display.

    tree-sitter reports byte offsets into the UTF-8 encoded input, so the slice
    must be taken on the bytes, not on the original ``str``. The two disagree
    as soon as the source contains a non-ASCII character.

    Args:
        source_bytes: The UTF-8 encoded source that was parsed.
        start_byte: Inclusive start offset (bytes).
        end_byte: Exclusive end offset (bytes).
        max_len: Snippets longer than this many characters are cut to
            ``max_len - 3`` characters followed by "...".

    Returns:
        The decoded (possibly truncated) snippet.
    """
    snippet = source_bytes[start_byte:end_byte].decode("utf8", errors="replace")
    if len(snippet) > max_len:
        snippet = snippet[:max_len - 3] + "..."
    return snippet


def build_detailed_graph(source_code, language):
    """Convert source code to an AST graph with detailed node information.

    Args:
        source_code: Source text to parse.
        language: tree-sitter Language used to configure the parser.

    Returns:
        (graph, node_map):
            graph: an undirected ``nx.Graph``. Node ids are assigned in
                pre-order (the root is 0). Each node has the attributes
                ``type``, ``start_byte``, ``end_byte`` and ``code`` (a
                truncated snippet). Each edge joins a parent to a child and
                carries ``relationship="parent-child"``.
            node_map: dict mapping each node id to its tree-sitter node.
    """
    source_bytes = bytes(source_code, "utf8")
    parser = Parser()
    parser.set_language(language)
    tree = parser.parse(source_bytes)

    graph = nx.Graph()
    node_map = {}  # node id -> original tree-sitter node

    # Use an explicit stack instead of recursion, so deeply nested code cannot
    # hit Python's recursion limit. Children are pushed in reverse so that ids
    # are handed out in the same pre-order a recursive walk would produce.
    stack = [(tree.root_node, None)]
    while stack:
        node, parent_id = stack.pop()
        node_id = len(node_map)

        graph.add_node(node_id,
                       type=node.type,
                       start_byte=node.start_byte,
                       end_byte=node.end_byte,
                       code=extract_snippet(source_bytes, node.start_byte, node.end_byte))
        node_map[node_id] = node

        if parent_id is not None:
            graph.add_edge(parent_id, node_id, relationship="parent-child")

        stack.extend((child, node_id) for child in reversed(node.children))

    return graph, node_map

# Build the graph
try:
    detailed_graph, node_map = build_detailed_graph(sample_code, C_LANGUAGE)
    print(f"Created graph with {len(detailed_graph.nodes)} nodes and {len(detailed_graph.edges)} edges")
except Exception as e:
    print(f"Error building graph: {e}")

## 4. Visualize the Graph

Now let's visualize the graph, labeling each node with its id, its AST node type, and a (truncated) source snippet.

In [None]:
def visualize_ast_graph(graph, node_map, max_nodes=30):
    """Draw an AST graph, labeling each node with its id, type and code snippet.

    Args:
        graph: networkx graph whose nodes carry a 'type' attribute and,
            optionally, a 'code' snippet (as produced by build_detailed_graph).
        node_map: Mapping of node id to tree-sitter node. It is accepted for
            API compatibility but not used for drawing.
        max_nodes: Maximum number of nodes to draw. Larger graphs are reduced
            to the first ``max_nodes`` nodes in the graph's node order.

    Returns:
        dict mapping each drawn node id to its label text.
    """
    if len(graph.nodes) > max_nodes:
        print(f"Graph is too large ({len(graph.nodes)} nodes). Showing only the first {max_nodes} nodes.")
        # When ids are assigned in pre-order (as build_detailed_graph does),
        # every kept node's parent is also kept, so the subgraph stays connected.
        graph = graph.subgraph(list(graph.nodes)[:max_nodes])

    # Label: "<id>: <type>", plus the quoted code snippet when one is present
    node_labels = {}
    for node_id, attrs in graph.nodes(data=True):
        label = f"{node_id}: {attrs['type']}"
        code = attrs.get('code', '')
        if code:
            label += f"\n'{code}'"
        node_labels[node_id] = label

    # Explicit figure/axes instead of the pyplot state machine
    fig, ax = plt.subplots(figsize=(20, 16))
    pos = nx.spring_layout(graph, seed=42, k=0.8)  # fixed seed -> reproducible layout

    nx.draw_networkx_nodes(graph, pos, ax=ax, node_size=2000, node_color='lightblue', alpha=0.8)
    nx.draw_networkx_edges(graph, pos, ax=ax, width=1.5, alpha=0.7, edge_color='gray')
    nx.draw_networkx_labels(graph, pos, ax=ax, labels=node_labels, font_size=10, font_family='sans-serif')

    ax.set_title("Abstract Syntax Tree (AST) Graph Representation", fontsize=16)
    ax.axis('off')
    fig.tight_layout()
    plt.show()

    # Return node information for detailed explanation
    return node_labels

try:
    node_labels = visualize_ast_graph(detailed_graph, node_map)
except Exception as e:
    print(f"Error visualizing graph: {e}")

## 5. Explain Node and Edge Meanings

Let's explain what the node types and edges represent in the context of vulnerability detection.

In [None]:
# Plain-English descriptions of common tree-sitter C node types
node_type_explanations = {
    "translation_unit": "The root node representing the entire source file",
    "function_definition": "A function declaration with its implementation",
    "function_declarator": "A function's name together with its parameter list",
    "parameter_list": "List of parameters for a function",
    "parameter_declaration": "Declaration of a single parameter",
    "compound_statement": "A block of code enclosed in braces {}",
    "declaration": "Variable or type declaration",
    "init_declarator": "Declaration with initialization",
    "declarator": "The name and type being declared",
    "array_declarator": "Declaration of an array",
    "pointer_declarator": "Declaration of a pointer (e.g. char *p)",
    "primitive_type": "Basic types like int, char, etc.",
    "call_expression": "Function call",
    "argument_list": "Arguments passed to a function",
    "string_literal": "String constant in quotes",
    "string_content": "Literal text inside a string constant",
    "escape_sequence": "Escape sequence inside a string constant (e.g. \\n)",
    "identifier": "Name of a variable, function, etc.",
    "expression_statement": "Statement containing an expression",
    "binary_expression": "Expression with two operands and an operator",
    "parenthesized_expression": "Expression in parentheses",
    "number_literal": "Numeric constant",
    "comment": "Source code comment",
}

# Extra note printed only for node types that matter most for vulnerability detection
vulnerability_relevance = {
    "call_expression": "Critical for detecting unsafe function calls like strcpy, gets, etc.",
    "array_declarator": "Important for buffer overflow detection (array size vs usage)",
    "string_literal": "Can indicate hardcoded credentials or format string vulnerabilities",
    "binary_expression": "May reveal integer overflow issues or improper comparisons",
}


def explain_node_type(node_type):
    """Return a human-readable description of an AST node type."""
    if node_type in node_type_explanations:
        return node_type_explanations[node_type]
    # The graph is built from node.children, which also includes anonymous
    # tokens such as "(" or ";". Their type is not a valid identifier.
    if not node_type.isidentifier():
        return "Punctuation or operator token"
    return "Custom or specialized syntax element"


# Print explanations for node types in our graph
print("Node Type Explanations in the Context of Vulnerability Detection:\n")
node_types_in_graph = set(nx.get_node_attributes(detailed_graph, 'type').values())

for node_type in sorted(node_types_in_graph):
    print(f"- {node_type}: {explain_node_type(node_type)}")
    relevance = vulnerability_relevance.get(node_type)
    if relevance:
        print(f"  * VULNERABILITY RELEVANCE: {relevance}")

print("\nEdge Meaning:\n")
print("- Edges represent parent-child relationships in the AST")
print("- Parent nodes represent containing or higher-level syntax elements")
print("- Child nodes represent nested or component elements")
print("- The graph structure captures the syntactic structure of the code")
print("- GNN models use this structure to understand code context and detect patterns associated with vulnerabilities")

## 6. Node Features Used in GNN Model

Let's examine the node features used by the GNN model, reconstructed here to mirror `CodeDefectDataset.get`.

In [None]:
# Convert to PyTorch Geometric format
pyg_graph = from_networkx(detailed_graph)

# Apply the same feature enrichment as in the model. The gnn_main import in the
# first cell can fail (ModuleNotFoundError). Nothing below uses the enriched
# graph, so skip enrichment in that case instead of aborting the whole cell.
if "enrich_node_features" in globals():
    enriched_graph = enrich_node_features(pyg_graph)
else:
    enriched_graph = None
    print("Skipping enrich_node_features: gnn_main could not be imported\n")


def stable_type_hash(node_type, buckets=1000):
    """Map a node-type string to an integer in ``[0, buckets)`` deterministically.

    The built-in ``hash()`` on ``str`` is salted per interpreter process
    (PYTHONHASHSEED), so features built from it change after every kernel
    restart. The CRC-32 of the UTF-8 bytes is stable across runs and machines.
    """
    return zlib.crc32(node_type.encode("utf8")) % buckets


def node_feature_vector(node_type, start_byte, end_byte):
    """Build the 7-element feature vector for one AST node (see printout below).

    Relative to the description in CodeDefectDataset.get, this version uses a
    process-independent type hash. Its statement flag also matches
    tree-sitter-c's spelled-out "*_statement" types, which never contain "stmt".
    NOTE(review): if CodeDefectDataset.get in gnn_main still uses hash() or
    only checks "stmt", it has the same problems — TODO confirm and align.
    """
    return [
        stable_type_hash(node_type) / 1000.0,  # Normalized hash value
        float(len(node_type)) / 50.0,  # Normalized length of type name
        float(start_byte) / 10000.0,  # Normalized position
        float(end_byte - start_byte) / 1000.0,  # Normalized size
        1.0 if "expr" in node_type else 0.0,  # Is expression
        1.0 if "decl" in node_type else 0.0,  # Is declaration
        1.0 if ("stmt" in node_type or "statement" in node_type) else 0.0,  # Is statement
    ]


# Create more informative node features as in the CodeDefectDataset.get method
node_features = [
    node_feature_vector(node_type, int(pyg_graph.start_byte[i]), int(pyg_graph.end_byte[i]))
    for i, node_type in enumerate(pyg_graph.type)
]

# Convert to tensor
node_features_tensor = torch.tensor(node_features, dtype=torch.float)

# Print feature explanation
print("Node Feature Vector Explanation:\n")
print("Each node in the graph has a feature vector with the following dimensions:")
print("1. Normalized hash of node type: Deterministic bucket id for the syntax element type (collisions possible)")
print("2. Normalized length of type name: Longer names might indicate more specialized elements")
print("3. Normalized position in code: Captures the location in the source code")
print("4. Normalized size (bytes): Indicates the size/complexity of the syntax element")
print("5. Is expression flag: 1.0 if the node represents an expression, 0.0 otherwise")
print("6. Is declaration flag: 1.0 if the node represents a declaration, 0.0 otherwise")
print("7. Is statement flag: 1.0 if the node represents a statement, 0.0 otherwise")
print("\nAdditionally, the model uses one-hot encoding of node types from the enrichment function")

# Show example feature vectors for a few nodes
print("\nExample Feature Vectors:")
for i in range(min(5, len(node_features))):
    node_type = pyg_graph.type[i]
    print(f"Node {i} ({node_type}): {node_features[i]}")

## 7. Vulnerability Detection Process

Let's explain how the GNN uses this graph representation to detect vulnerabilities.

In [None]:
# End-to-end summary of how the GNN turns code into a vulnerability score.
# Entries are emitted one per line; entries that begin with "\n" produce the
# blank lines that separate the sections.
detection_overview_lines = (
    "GNN Vulnerability Detection Process:\n",
    "1. Code Parsing: Source code is parsed into an Abstract Syntax Tree (AST) using tree-sitter",
    "2. Graph Construction: The AST is converted to a graph where:",
    "   - Nodes represent syntax elements (function definitions, declarations, expressions, etc.)",
    "   - Edges represent the hierarchical structure of the code",
    "   - Node features capture the type and context of each syntax element",
    "\n3. Feature Engineering:",
    "   - Each node gets a feature vector encoding its type and properties",
    "   - Position, size, and semantic role (expression/declaration/statement) are captured",
    "   - One-hot encoding of node types adds categorical information",
    "\n4. Graph Neural Network Processing:",
    "   - GNN layers aggregate information from neighboring nodes",
    "   - This captures patterns across the code structure",
    "   - Multiple GNN layers allow information to flow across the entire graph",
    "\n5. Global Pooling:",
    "   - Node features are aggregated to create a single graph-level representation",
    "   - This represents the entire code snippet's vulnerability profile",
    "\n6. Classification:",
    "   - Fully connected layers process the graph representation",
    "   - Final output is a vulnerability score (0-1)",
    "   - Scores above a threshold (typically 0.5) indicate potential vulnerabilities",
    "\nKey Vulnerability Patterns Detected:",
    "- Unsafe function calls (strcpy, gets, etc.)",
    "- Buffer size mismatches",
    "- Missing bounds checks",
    "- Improper input validation",
    "- Integer overflow/underflow conditions",
    "- Format string vulnerabilities",
    "- Null pointer dereferences",
)

# A single print with sep="\n" emits exactly what one print() per entry would.
print(*detection_overview_lines, sep="\n")