In [5]:
# LLM
from langchain_ollama import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Vector DB
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain_community.document_loaders import JSONLoader

# Output
import os
import pprint as pp
from jinja2 import Template

In [6]:
root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
root

'c:\\Users\\ezequ\\Desktop\\ITBA\\2C2024\\DeepLearning'

# Dataset Preparation

In [7]:
import re

class ContractParser:
    def __init__(self, contract_path) -> None:
        with open(contract_path, "r") as file:
            self.solidity_contract = file.read()

    @property
    def str_functions_with_names(self):
        start_idx = 0
        str_functions = []
        for line_num, line in enumerate(self.solidity_contract.splitlines(), start=1):
            if line.startswith("    function"):
                start_idx = line_num
            if line.startswith("    }"):
                lines = self.solidity_contract.splitlines()[start_idx - 1:line_num]
                func_code = '\n'.join(lines)
                if func_code: str_functions.append(func_code)

        function_names = [self._extract_function_name(f) for f in str_functions]

        return str_functions, function_names

    def _extract_function_name(self, func_code):
            pattern = r'\s+(function)\s+(\w+)\s*'
            match = re.match(pattern, func_code)
            return match.group(2) if match else None

In [8]:
cp = ContractParser(contract_path=f"{root}/contracts/ERC20.sol")
functions, function_names = cp.str_functions_with_names
print(function_names)
pp.pprint(functions[0])

['name', 'symbol', 'decimals', 'totalSupply', 'balanceOf', 'transfer', 'allowance', 'approve', 'transferFrom', 'increaseAllowance', 'decreaseAllowance', '_transfer', 'mint', '_burn', '_approve', '_spendAllowance']
('    function name() public view virtual returns (string memory) {\n'
 '        return _name;\n'
 '    }')


In [43]:
directory = f"{root}/data/contracts"
dataset = []
    
for filename in os.listdir(directory):
    filepath = os.path.join(directory, filename)
    contract_name = filename.split('.')[0]
    
    if os.path.isfile(filepath) and filename.endswith(".t.sol"):        
        contract_path = os.path.join(directory, f"{contract_name}.sol")

        test_contract = ContractParser(contract_path=filepath)
        code_contract = ContractParser(contract_path=contract_path)

        # Parse functions and tests
        functions, function_names = code_contract.str_functions_with_names
        tests, test_names = test_contract.str_functions_with_names
        
        ### For each function in the contract code, search for all functions in the test code whose name includes the function name and the word "test".
        ### Append the setup code to the tests' code.

        # Search for the setup function of the test contract
        setup_body = ""
        for i, test_name in enumerate(test_names):
            if test_name == "setUp":
                # remove function signature and closing bracket
                setup_body = tests[i].split('\n')[1:-1]
                setup_body = '\n'.join(setup_body)
                break

        # Search for the tests that correspond to each function in the contract code, append the setup code to the test code, and add them to the dataset
        for i, function_name in enumerate(function_names):
            tests_found = []
            for j, test_name in enumerate(test_names):
                if "test" in test_name and function_name.lower() in test_name.lower():
                    # add the setup code to the test code
                    test_code = tests[j].split('\n')
                    setup_code = setup_body.split('\n')
                    test_with_setup = test_code[:1] + setup_code + test_code[1:]
                    # append the test code to the list of tests found for the current function
                    tests_found.append('\n'.join(test_with_setup))
            
            if len(tests_found) > 0:
                dataset.append({
                    "function": functions[i],
                    "tests": tests_found
                })

print(len(dataset))
pp.pprint(dataset[2])

17
{'function': '    function safeBatchTransferFrom(\n'
             '        address from,\n'
             '        address to,\n'
             '        uint256[] calldata ids,\n'
             '        uint256[] calldata amounts,\n'
             '        bytes calldata data\n'
             '    ) public virtual {\n'
             '        require(ids.length == amounts.length, '
             '"LENGTH_MISMATCH");\n'
             '\n'
             '        require(msg.sender == from || '
             'isApprovedForAll[from][msg.sender], "NOT_AUTHORIZED");\n'
             '\n'
             '        // Storing these outside the loop saves ~15 gas per '
             'iteration.\n'
             '        uint256 id;\n'
             '        uint256 amount;\n'
             '\n'
             '        for (uint256 i = 0; i < ids.length; ) {\n'
             '            id = ids[i];\n'
             '            amount = amounts[i];\n'
             '\n'
             '            balanceOf[from][id]

In [44]:
# Generate a description for each function in the dataset, using a light language model

model = ChatOllama(
    model="llama3.2",
)

PROMPT_TEMPLATE = """
Based on the following function written in the Solidity language, summarize its behavior in plain text, without giving a line-by-line description and without making any reference to the code.

```solidity
{function_code}
```
"""

prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)

chain = (
    prompt
    | model
    | StrOutputParser()
)

for sample in dataset:
    sample["description"] = chain.invoke({
        "function_code": sample["function"],
    })

dataset[0]

{'function': '    function setApprovalForAll(address operator, bool approved) public virtual {\n        isApprovedForAll[msg.sender][operator] = approved;\n\n        emit ApprovalForAll(msg.sender, operator, approved);\n    }',
 'tests': ['    function testDirectSetApprovalForAll(address by, address operator, bool approved) public {\n        token = new MockERC1155();\n        _expectApprovalForAllEvent(by, operator, approved);\n        vm.prank(by);\n        token.directSetApprovalForAll(operator, approved);\n    }'],
 'description': 'This function allows a user (the "sender") to specify whether another user\'s approval for certain operators should be considered. When this function is called, it updates the approval status and notifies all other users of the change.'}

In [45]:
import json

# Prepare the data and dump it into a JSON file

dataset = {"data": dataset}

with open(f"{root}/data/dataset.json", "w") as json_file:
    json.dump(dataset, json_file, indent=4)

# Vector DB

In [50]:
def metadata_func(record: dict, metadata: dict) -> dict:
    metadata["tests"] = json.dumps(record.get("tests"))
    metadata["function"] = record.get("function")
    return metadata

# Reference: https://python.langchain.com/docs/integrations/document_loaders/json
loader = JSONLoader(
    file_path=f"{root}/data/dataset.json",
    jq_schema=".data[]",
    content_key="description",
    metadata_func=metadata_func,
)

docs = loader.load()
len(docs)

17

In [51]:

embedding_model = OllamaEmbeddings(model="mxbai-embed-large")

vectorstore = Chroma.from_documents(
    documents=docs,
    collection_name="foundry_tests",
    embedding=embedding_model,
    persist_directory=f"{root}/chroma_db"
)

In [74]:
docs = vectorstore.similarity_search("test", k=1)

print(docs[0].metadata["function"]) # function code of the first retrieved document
print(json.loads(docs[0].metadata["tests"])[0]) # parse the first test code from the metadata

    function balanceOfBatch(address[] calldata owners, uint256[] calldata ids)
        public
        view
        virtual
        returns (uint256[] memory balances)
    {
        require(owners.length == ids.length, "LENGTH_MISMATCH");

        balances = new uint256[](owners.length);

        // Unchecked because the only math done is incrementing
        // the array index counter which cannot possibly overflow.
        unchecked {
            for (uint256 i = 0; i < owners.length; ++i) {
                balances[i] = balanceOf[owners[i]][ids[i]];
            }
        }
    }
    function testBalanceOfBatchWithArrayMismatchReverts(uint256) public {
        token = new MockERC1155();
        address[] memory tos = new address[](_random() % 4);
        uint256[] memory ids = new uint256[](_random() % 4);
        if (tos.length == ids.length) return;

        vm.expectRevert(ERC1155.ArrayLengthsMismatch.selector);
        token.balanceOfBatch(tos, ids);
    }


# Create code

In [16]:
model = ChatOllama(
    model="llama3.2",
    temperature=0,
    top_p=1,
)

PROMPT_TEMPLATE = """
Based on the test function ([reference function test code]) which tests the function code example ([reference function code example]), generate a corresponding test function for the ([function to be tested]) function within the ([contract code to be tested]) contract. This function is for use within the Foundry framework for writing smart contracts.

---

[reference function test code]: {reference_function_test_code}
[reference function code example]: {reference_function_code_example}
[contract code to be tested]: {contract_code}
[function code to be tested]: {function_code}

---

Your output MUST be a single valid Solidity function, with the setup and assertions necessary to test the function. Do NOT wrap the function in a markdown code block.
REMEMBER, do NOT include a description of the function or any other text, only the code.
REMEMBER, you MUST only generate a single function, not a full test contract.
"""

prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)

chain = (
    prompt
    | model
    | StrOutputParser()
)

with open(f"{root}/contracts/ERC20.sol", "r") as f:    
    contract_code = f.read()

reference_function_code_example = """
function transfer(address to, uint256 amount) public virtual override returns (bool) {
    address owner = msg.sender;
    _transfer(owner, to, amount);
    return true;
}
"""

reference_function_test_code = """
function test_transfer() public {
    ERC20 token = new ERC20("USD Coin", "USDC");
    address owner = address(this);
    address recipient = address(0x1);
    token.mint(owner, 100);

    token.transfer(recipient, 50);
    assertEq(token.balanceOf(owner), 50);
    assertEq(token.balanceOf(recipient), 50);
}
"""

function_code = """
function mint(address account, uint256 amount) public virtual {
    require(account != address(0), "ERC20: mint to the zero address");

    _beforeTokenTransfer(address(0), account, amount);

    _totalSupply += amount;
    unchecked {
        // Overflow not possible: balance + amount is at most totalSupply + amount, which is checked above.
        _balances[account] += amount;
    }
    emit Transfer(address(0), account, amount);

    _afterTokenTransfer(address(0), account, amount);
}
"""

In [17]:
contract_name = "ERC20"
generated_tests = []

response = chain.invoke({
    "reference_function_test_code": reference_function_test_code,
    "reference_function_code_example": reference_function_code_example,
    "contract_code": contract_code,
    "function_code": function_code
})

generated_tests.append(response)

pp.pprint(response)

('function testMint() public {\n'
 '    ERC20 token = new ERC20("Test Token");\n'
 '    address account = token.minter();\n'
 '    uint256 amount = 100;\n'
 '    token.mint(account, amount);\n'
 '    require(token.balanceOf(account) == amount, "ERC20: mint failed");\n'
 '}')


In [18]:
template = """// SPDX-License-Identifier: MIT
pragma solidity ^0.8.0;

import {Test} from "forge-std/Test.sol";
import "src/{{ contract_name }}.sol";

contract {{ contract_name }}Test is Test {

    {% for test in tests %}
    {{ test }}
    {% endfor %}

}
"""

t = Template(template)

test_file_content = t.render(
    contract_name=contract_name,
    tests=generated_tests,
)

pp.pprint(test_file_content)

('// SPDX-License-Identifier: MIT\n'
 'pragma solidity ^0.8.0;\n'
 '\n'
 'import {Test} from "forge-std/Test.sol";\n'
 'import "src/ERC20.sol";\n'
 '\n'
 'contract ERC20Test is Test {\n'
 '\n'
 '    \n'
 '    function testMint() public {\n'
 '    ERC20 token = new ERC20("Test Token");\n'
 '    address account = token.minter();\n'
 '    uint256 amount = 100;\n'
 '    token.mint(account, amount);\n'
 '    require(token.balanceOf(account) == amount, "ERC20: mint failed");\n'
 '}\n'
 '    \n'
 '\n'
 '}')


In [19]:
with open(f"{root}/foundry/test/{contract_name}.t.sol", "w") as file:
    file.write(test_file_content)

# Run Foundry Test

In [20]:
import subprocess

In [21]:
foundry_dir = os.path.join(root, "foundry")

command = ["forge", "test", "--match-contract", f"{contract_name}Test"]

result = subprocess.run(command, cwd=foundry_dir, capture_output=True, text=True)

print(result.stdout)

Compiler run failed:
Error (6160): Wrong argument count for function call: 1 arguments given but expected 2.
  --> test/ERC20.t.sol:11:19:
   |
11 |     ERC20 token = new ERC20("Test Token");
   |                   ^^^^^^^^^^^^^^^^^^^^^^^

Error (9582): Member "minter" not found or not visible after argument-dependent lookup in contract ERC20.
  --> test/ERC20.t.sol:12:23:
   |
12 |     address account = token.minter();
   |                       ^^^^^^^^^^^^




# Fix compiler errors

In [25]:
err_model = ChatOllama(model="llama3.1:8b", temperature=0, top_p=1)

In [26]:
ERROR_PROMPT_TEMPLATE = """
I wrote the Solidity test function ([test function code]) to run on the Foundry framework. When this code is compiled with Foundry, I get this error ([compiler error]).
This test is for the function ([function code to be tested]) within the contract ([contract code to be tested]). 

Your task is to understand the test I provided, fix the test code, and correct the error within the test. You MUST modify the test function code while maintaining its functionality, but do NOT add other unrelated code.

---

[test function code]: {test_res}
[compiler error]: {error_info}
[function code to be tested]: {function_code}
[contract code to be tested]: {contract_code}

---

Your output MUST be a single valid Solidity function, with the setup and assertions necessary to test the function. Do NOT wrap the function in a markdown code block.
REMEMBER, do NOT include a description of the function or any other text, only the code.
REMEMBER, you MUST only generate a single test function, not a full test contract.
"""

error_prompt = ChatPromptTemplate.from_template(ERROR_PROMPT_TEMPLATE)


chain = (
    error_prompt
    | err_model
    | StrOutputParser()
)

error_info = """
Compiler run failed:
Error (6160): Wrong argument count for function call: 0 arguments given but expected 2.
  --> test/ERC20.t.sol:10:23:
   |
10 |         ERC20 token = new ERC20();
   |                       ^^^^^^^^^^^

Error (9574): Type int_const 4886718345 is not implicitly convertible to expected type address.
  --> test/ERC20.t.sol:11:9:
   |
11 |         address account = 0x123456789;
   | 
"""

test_res = """
function testMint() public {
    ERC20 token = new ERC20();
    address account = 0x123456789;
    uint256 amount = 100;

    // Initial balance
    assertEq(token.balanceOf(account), 0);

    // Mint tokens
    token.mint(account, amount);

    // Check balance after minting
    assertEq(token.balanceOf(account), amount);
}
"""

error_response = chain.invoke({
    "error_info": error_info,
    "test_res": test_res,
    "contract_code": contract_code,
    "function_code": function_code
})

In [27]:
with open("err_response.txt", "w") as f:
    f.write(error_response)