In [2]:
import os
import pprint
from solidity_parser import parser
from tkinter import Tk, filedialog
import json
import torch
from torch_geometric.data import Data
import torch.nn.functional as F
from torch_geometric.nn import GCNConv, global_mean_pool
import re

In [3]:
# Define the GNN model
class GNN(torch.nn.Module):
    """Graph-level classifier: two GCN layers, mean pooling, linear head.

    Args:
        num_node_features: Width of each node feature vector (columns of ``data.x``).
        hidden_dim: Output width of both ``GCNConv`` layers.
        num_classes: Number of output classes.

    The attribute names ``conv1``, ``conv2`` and ``fc`` are the keys of the
    state_dict, so they must stay stable for saved checkpoints to load.
    """

    def __init__(self, num_node_features, hidden_dim, num_classes):
        super().__init__()
        self.conv1 = GCNConv(num_node_features, hidden_dim)
        self.conv2 = GCNConv(hidden_dim, hidden_dim)
        self.fc = torch.nn.Linear(hidden_dim, num_classes)

    def forward(self, data):
        """Return per-graph log-probabilities (log-softmax over dim 1).

        Args:
            data: Object exposing ``x`` and ``edge_index`` and, optionally,
                ``batch`` (the node-to-graph assignment vector).

        Returns:
            Tensor with one row per graph and ``num_classes`` columns.
        """
        x, edge_index = data.x, data.edge_index

        # A single graph that was not produced by a DataLoader may carry no
        # batch vector (missing attribute or None). Treat every node as
        # belonging to graph 0 so pooling yields exactly one row.
        batch = getattr(data, 'batch', None)
        if batch is None:
            batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device)

        x = F.relu(self.conv1(x, edge_index))
        x = F.relu(self.conv2(x, edge_index))
        x = global_mean_pool(x, batch)
        x = self.fc(x)
        return F.log_softmax(x, dim=1)

In [61]:

# Function to parse Solidity file and return its AST
def parse_solidity_file(file_path):
    """Parse the Solidity file at ``file_path`` and return its AST.

    The parser is invoked with ``loc=True``. Any exception raised while
    parsing is reported on stdout and turned into a ``None`` return value,
    so callers must check for ``None``.
    """
    try:
        return parser.parse_file(file_path, loc=True)
    except Exception as exc:
        print(f"Error parsing {file_path}: {exc}")
    return None
    
# Function to convert Python-repr text (single quotes, True/False/None) into JSON-compatible text
def convert_json_files(data):
    """Rewrite Python-literal spellings in ``data`` as their JSON equivalents.

    Plain, context-free text substitution applied in a fixed order:
    single quotes become double quotes, then ``False``/``True``/``None``
    become ``false``/``true``/``null``. Every occurrence is replaced,
    including ones that happen to sit inside string values.
    """
    substitutions = (
        ("'", '"'),
        ("False", "false"),
        ("True", "true"),
        ("None", "null"),
    )
    for python_token, json_token in substitutions:
        data = data.replace(python_token, json_token)
    return data
               

# Function to combine string literals
def combine_string_literals(text):
    """Merge runs of adjacent quoted strings in ``stringLiteral`` values.

    Looks for ``"type": "stringLiteral", "value": "<first>"`` optionally
    followed by further quoted strings, and rewrites the whole match as a
    single ``"value"`` whose content is the pieces joined by spaces, with
    newlines turned into spaces and surrounding whitespace stripped.
    Text outside such matches is returned unchanged.
    """
    literal_re = re.compile(r'"type":\s*"stringLiteral",\s*"value":\s*"([^"]*)"\s*(("([^"]*)"\s*)*)')

    def _merge(found):
        first_piece = found.group(1)
        # group(2) is the raw run of any further quoted strings (may be empty).
        later_pieces = re.findall(r'"([^"]*)"', found.group(2))
        merged = f"{first_piece} {' '.join(later_pieces)}"
        merged = merged.replace('\n', ' ').strip()
        return f'"type": "stringLiteral", "value": "{merged}"'

    return literal_re.sub(_merge, text)

# Function to clean the serialized AST text (combines adjacent string literals; nothing is written to disk)
def process_file(data):
    """Return ``data`` after running it through ``combine_string_literals``.

    Works purely on the text passed in and returns the cleaned text.
    """
    return combine_string_literals(data)

# Functions to process the AST into a graph
def extract_nodes_edges(ast):
    """Flatten a nested-dict AST into a node list and a parent->child edge list.

    Every dict reachable from ``ast`` becomes a node, whether it is a direct
    dict value or a dict item inside a list value. Lists nested inside lists
    are not descended into. Nodes are numbered in depth-first pre-order,
    following each dict's key order.

    The walk uses an explicit stack instead of recursion so that very deeply
    nested ASTs cannot raise ``RecursionError``; numbering and edge order are
    the same as a recursive pre-order traversal.

    Args:
        ast: Root dict of the AST.

    Returns:
        ``(nodes, edges)`` where ``nodes`` is the list of visited dicts (the
        original objects, not copies) and ``edges`` is a list of
        ``(parent_index, child_index)`` tuples indexing into ``nodes``.
    """
    nodes, edges = [], []
    pending = [(ast, None)]  # stack of (node, index of its parent)

    while pending:
        node, parent_index = pending.pop()
        node_index = len(nodes)
        nodes.append(node)
        if parent_index is not None:
            edges.append((parent_index, node_index))

        children = []
        for value in node.values():
            if isinstance(value, dict):
                children.append(value)
            elif isinstance(value, list):
                children.extend(item for item in value if isinstance(item, dict))

        # Push in reverse so the first child is popped (and numbered) first,
        # which reproduces pre-order numbering.
        pending.extend((child, node_index) for child in reversed(children))

    return nodes, edges

def create_node_features(nodes):
    """Build the node feature tensor from a list of AST node dicts.

    Each node's ``'type'`` entry (``'Unknown'`` when the key is absent) is
    encoded with ``one_hot_encode_node_type``; the encodings are stacked in
    node order into a ``torch.float`` tensor.
    """
    feature_rows = [
        one_hot_encode_node_type(ast_node.get('type', 'Unknown'))
        for ast_node in nodes
    ]
    return torch.tensor(feature_rows, dtype=torch.float)

def one_hot_encode_node_type(node_type):
    """Return a 6-element 0/1 list marking ``node_type``'s slot.

    Slots, in order: PragmaDirective, ContractDefinition, FunctionDefinition,
    VariableDeclaration, BinaryOperation, Unknown. A type that is not one of
    these six strings yields an all-zero vector; only the literal string
    ``'Unknown'`` sets the last slot.
    """
    known_types = (
        'PragmaDirective',
        'ContractDefinition',
        'FunctionDefinition',
        'VariableDeclaration',
        'BinaryOperation',
        'Unknown',
    )
    return [1 if candidate == node_type else 0 for candidate in known_types]

def create_edge_index(edges):
    """Convert a list of ``(source, target)`` pairs into a ``[2, E]`` long tensor.

    Row 0 holds the source indices and row 1 the target indices.

    Args:
        edges: Sequence of 2-element index pairs; may be empty.

    Returns:
        Contiguous ``torch.long`` tensor of shape ``[2, len(edges)]``.
    """
    if not edges:
        # torch.tensor([]) is 1-D, so transposing it would give shape [0]
        # rather than the [2, 0] layout an edge index needs. A graph with a
        # single node (no edges) hits this case.
        return torch.empty((2, 0), dtype=torch.long)
    return torch.tensor(edges, dtype=torch.long).t().contiguous()

# Function to process AST into graph data
def process_ast(ast):
    """Turn an AST dict into a ``Data`` graph object.

    Nodes and edges come from ``extract_nodes_edges``; node features from
    ``create_node_features`` and the edge tensor from ``create_edge_index``.
    The result carries only ``x`` and ``edge_index``.
    """
    ast_nodes, ast_edges = extract_nodes_edges(ast)
    return Data(
        x=create_node_features(ast_nodes),
        edge_index=create_edge_index(ast_edges),
    )

In [1]:
import os

# Show the kernel's working directory: the relative paths used in this
# notebook (the .sol input file and "gnn_model.pth") resolve against it.
print(os.getcwd())

c:\Users\Cody\Desktop\FYP Project\Code\AI Model\CFG Model Training


In [60]:
# Alternative input (non-vulnerable sample):
# file_path = ('./AST/Testing data/non-vulnerable/non-vulnerable1.sol')
file_path = ('./AST/Testing data/reentrancy/vulnerable3.sol')

if file_path:
    # Step 1: Parse the Solidity file into an AST.
    parsed_ast = parse_solidity_file(file_path)

    # parse_solidity_file returns None on failure. Check for that BEFORE
    # converting to text: str(None) is the truthy string "None", which would
    # skip the failure branch and crash later inside process_ast.
    if parsed_ast is not None:

        # Step 2: Serialise the AST to its Python repr and rewrite it as JSON text
        ast_json_text = convert_json_files(str(parsed_ast))

        # Step 3: Process the JSON text to combine string literals
        ast_json_text = process_file(ast_json_text)

        ast = json.loads(ast_json_text)

        # Step 4: Process the decoded AST into graph data
        new_graph = process_ast(ast)

        # Step 5: Define the model and load the pre-trained weights.
        # hidden_dim and num_classes must match the values the checkpoint
        # was trained with, otherwise load_state_dict raises a size mismatch.
        num_node_features = new_graph.x.size(1)
        hidden_dim = 64
        num_classes = 2

        loaded_model = GNN(num_node_features, hidden_dim, num_classes)
        # NOTE(security): torch.load unpickles the file, so only load
        # checkpoints from a trusted source (consider weights_only=True on
        # torch versions that support it). map_location="cpu" lets a
        # checkpoint saved on a GPU load on a CPU-only machine.
        loaded_model.load_state_dict(torch.load("gnn_model.pth", map_location="cpu"))
        loaded_model.eval()

        # Step 6: Perform prediction
        with torch.no_grad():
            out = loaded_model(new_graph)
            pred = out.argmax(dim=1).item()  # Get the predicted class

        # Step 7: Output the prediction result (class 1 is reported as vulnerable)
        vulnerability_status = "Vulnerable" if pred == 1 else "Not Vulnerable"
        print(f"The file is predicted to be: {vulnerability_status}")
    else:
        print("Failed to parse the Solidity file.")
else:
    print("No file selected.")

{'type': 'SourceUnit', 'children': [{'type': 'PragmaDirective', 'name': 'solidity', 'value': '^0.4.18', 'loc': {'start': {'line': 1, 'column': 0}, 'end': {'line': 1, 'column': 23}}}, {'type': 'ContractDefinition', 'name': 'Ownable', 'baseContracts': [], 'subNodes': [{'type': 'StateVariableDeclaration', 'variables': [{'type': 'VariableDeclaration', 'typeName': {'type': 'ElementaryTypeName', 'name': 'address', 'loc': {'start': {'line': 4, 'column': 4}, 'end': {'line': 4, 'column': 4}}}, 'name': 'owner', 'expression': None, 'visibility': 'public', 'isStateVar': True, 'isDeclaredConst': False, 'isIndexed': False, 'loc': {'start': {'line': 4, 'column': 4}, 'end': {'line': 4, 'column': 24}}}], 'initialValue': None, 'loc': {'start': {'line': 4, 'column': 4}, 'end': {'line': 4, 'column': 24}}}, {'type': 'EventDefinition', 'name': 'OwnershipTransferred', 'parameters': {'type': 'ParameterList', 'parameters': [{'type': 'VariableDeclaration', 'typeName': {'type': 'ElementaryTypeName', 'name': 'add