In [91]:
import re
from transformers import AutoTokenizer, AutoModel
import torch
import numpy as np
import os
import json
import subprocess

class Parser():

    def __init__(self, path):
        
        with open(path, 'r') as f:
            self.contract = f.readlines()
        self.path=path
        self.functions = []
        self.semantic_vectors = []
        self.output_dir = 'parsed_contracts'
        self.contract_name = path[:-4]
        self.slither_output=""
        self.echidna_output=""

    def parse_contract_to_functions(self):

        curr_fun = ''
        reading_function = False
        bracket_balance = None
        first_line=0
        end_line=0
        for index,line in enumerate(self.contract):
            if 'function' in line:
                first_line=index
                curr_fun += line
                bracket_balance = 1
                reading_function = True
            elif reading_function:
                left_bracket = line.count('{')
                right_bracket = line.count('}')

                bracket_balance += left_bracket
                bracket_balance -= right_bracket

                curr_fun += line
                if bracket_balance == 0:
                    end_line=index
                    self.functions.append([curr_fun,first_line,end_line,''])
                    curr_fun = ''
                    bracket_balance = None
                    reading_function = False

    def get_semantic_vectors(self):
        tokenizer = AutoTokenizer.from_pretrained("microsoft/codebert-base")
        model = AutoModel.from_pretrained("microsoft/codebert-base")

        for input_text in self.functions:
            self.get_vector_from_input(tokenizer, model, input_text[0])

        self.semantic_vectors_whitening()
        # self.save_functions_and_vectors()

    def get_vector_from_input(self, tokenizer, model, input_text):
        input_ids = tokenizer.encode(input_text, return_tensors="pt", 
                                     max_length=512, truncation=True)
        with torch.no_grad():
            outputs = model(input_ids)
            semantic_vector = outputs.last_hidden_state.mean(dim=1) 
            self.semantic_vectors.append(semantic_vector.squeeze().numpy())
        
    def semantic_vectors_whitening(self):
        vectors = np.asarray(self.semantic_vectors)
        covariance_matrix = np.cov(vectors, rowvar=False)
        eigenvalues, eigenvectors = np.linalg.eigh(covariance_matrix)
        whitening_matrix = np.dot(np.dot(eigenvectors, np.diag(1.0 / np.sqrt(eigenvalues + 1e-5))), eigenvectors.T)

        whitened_vectors = np.dot(vectors, whitening_matrix)
        mean = np.mean(whitened_vectors, axis=0)
        std = np.std(whitened_vectors, axis=0)

        normalized_whitened_vectors = (whitened_vectors - mean) / std
        self.semantic_vectors = normalized_whitened_vectors

    def get_slither_tests(self):
        result = subprocess.run(['slither', self.path], 
                stdout=subprocess.PIPE,
                stderr=subprocess.PIPE)
        output=result.stderr.decode('cp1252')
        self.slither_output=output
    
    def get_echidna_tests(self):
        result = subprocess.run(['echidna',self.path], 
                        stdout=subprocess.PIPE,
                        stderr=subprocess.PIPE)
        output=result.stderr.decode()
        self.echidna_output=output
    def parse_slither_to_functions(self):
        self.get_slither_tests()
        self.save_tests()
        found_lines = ''
        inside_block=False
        mean=0


        with open(os.path.join(self.output_dir, self.contract_name, 'test_outputs', 'slither.txt'), 'r+')  as file:
            for line in file:
                function_match = re.match(r'.*\((' + re.escape(self.contract_name+".sol") + r'#.*?)\)', line)

                if function_match and not inside_block:
                    inside_block=True
                    file_reference = function_match.group(1)
                    
                    
                    if '-' in file_reference.split('#')[1]:
                        numbers = re.findall(r'\d+', file_reference.split('#')[1])
                        num1 = int(numbers[0])
                        num2 = int(numbers[1])
                        mean = (num1 + num2) / 2
                        
                        
                    else:
                        mean = int(re.findall(r'\d+', file_reference.split('#')[1])[0])     
                        
                    found_lines+=line
                    
                elif inside_block and  line.startswith("   "):
                    found_lines+=line      
                elif inside_block and not  line.startswith("   "):
                    inside_block=False 
                    for function in self.functions:
                        start=function[1]
                        end=function[2]
                        if  start <= mean <=end:
                            function[3]+=found_lines
                            print(function)
                    found_lines=''
                    mean=0
    def save_functions_and_vectors(self):
        if not os.path.exists(os.path.join(self.output_dir, self.contract_name, 'functions')):
            os.makedirs(os.path.join(self.output_dir, self.contract_name, 'functions'))

        if not os.path.exists(os.path.join(self.output_dir, self.contract_name, 'semantic_vectors')):
            os.makedirs(os.path.join(self.output_dir, self.contract_name, 'semantic_vectors'))
        if not os.path.exists(os.path.join(self.output_dir, self.contract_name, 'tests')):
            os.makedirs(os.path.join(self.output_dir, self.contract_name, 'tests'))
        
        for i, (fun, vec) in enumerate(zip(self.functions, self.semantic_vectors)):
            
            with open(os.path.join(self.output_dir, self.contract_name, 'functions', f'{i}.txt'), 'w+')  as f:
                f.write(fun[0])
            if fun[3]!='':
                with open(os.path.join(self.output_dir, self.contract_name, 'tests', f'{i}.txt'), 'w+')  as f:
                    f.write(fun[3])

            with open(os.path.join(self.output_dir, self.contract_name, 'semantic_vectors', f'{i}.txt'), 'w+') as f:
                json.dump(vec.tolist(), f)
 
                
    def save_tests(self):  
        if not os.path.exists(os.path.join(self.output_dir, self.contract_name, 'test_outputs')):
            os.makedirs(os.path.join(self.output_dir, self.contract_name, 'test_outputs')) 
        with open(os.path.join(self.output_dir, self.contract_name, 'test_outputs', 'slither.txt'), 'w+')  as f:
            f.write(self.slither_output)
        with open(os.path.join(self.output_dir, self.contract_name, 'test_outputs', 'echidna.txt'), 'w+')  as f:
            f.write(self.echidna_output)
        

In [92]:
parser = Parser('example2.sol')
parser.parse_contract_to_functions()
parser.get_semantic_vectors()
parser.parse_slither_to_functions()
parser.save_functions_and_vectors()

['    function sendETHToFee(uint256 amount) private {\n        _taxWallet.transfer(amount);\n    }\n', 303, 305, 'META.sendETHToFee(uint256) (example2.sol#304-306) sends eth to arbitrary user\n']
['    function sendETHToFee(uint256 amount) private {\n        _taxWallet.transfer(amount);\n    }\n', 303, 305, 'META.sendETHToFee(uint256) (example2.sol#304-306) sends eth to arbitrary user\n\t- _taxWallet.transfer(amount) (example2.sol#305)\n']
['    function _transfer(address from, address to, uint256 amount) private {\n        require(from != address(0), "ERC20: transfer from the zero address");\n        require(to != address(0), "ERC20: transfer to the zero address");\n        require(amount > 0, "Transfer amount must be greater than zero");\n        uint256 taxAmount=0;\n        if (from != owner() && to != owner()) {\n            require(!bots[from] && !bots[to]);\n\n            if(_buyCount==0){\n                taxAmount = amount.mul((_buyCount>_reduceBuyTaxAt)?_finalBuyTax:_initialB

In [9]:
parser.save_functions_and_vectors()

In [22]:
line = "META.sendETHToFee(uint256) (example2.sol#304) sends eth to arbitrary user"

# Adjusted regex pattern to capture the file reference part
function_match = re.match(r'.*\((.*)\)', line)

if function_match:
    file_reference = function_match.group(1)
    print("File reference:", file_reference)
    
    # Check if there's a hyphen after #
    if '-' in file_reference.split('#')[1]:
        # Split the string at the hyphen and calculate the mean
        numbers = re.findall(r'\d+', file_reference.split('#')[1])
        num1 = int(numbers[0])
        num2 = int(numbers[1])
        mean = (num1 + num2) / 2
        print("Mean:", mean)
    else:
        # Take everything after #
        numbers = re.findall(r'\d+', file_reference.split('#')[1])
        print("Number:", int(numbers[0]))

File reference: example2.sol#304
Number: 304
