In [1]:
# LLM
from langchain_ollama import ChatOllama
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate

# Vector DB
from langchain_chroma import Chroma
from langchain_ollama import OllamaEmbeddings
from langchain_community.document_loaders import JSONLoader

# Vector DB

In [10]:
def metadata_func(record: dict, metadata: dict) -> dict:
    """Attach the raw test code and the function it exercises to a Document.

    Used as JSONLoader's ``metadata_func``: ``record`` is one element selected by
    the loader's ``jq_schema`` and ``metadata`` is the metadata dict being built
    for it. The ``"test"`` and ``"function"`` fields are copied over (``None``
    when absent). ``metadata`` is updated in place and also returned.
    """
    for field in ("test", "function"):
        metadata[field] = record.get(field)
    return metadata

# Load the reference tests. Each element of the top-level "tests" array becomes
# one Document: its "description" field is the page content, and metadata_func
# copies the "test" and "function" fields into the Document's metadata.
# Reference: https://python.langchain.com/docs/integrations/document_loaders/json
loader = JSONLoader(
    jq_schema=".tests[]",
    content_key="description",
    metadata_func=metadata_func,
    file_path="../data/tests.json",
)
docs = loader.load()

# Number of reference tests loaded (shown as the cell output).
len(docs)

2

In [11]:

# Embed the reference-test descriptions and store them in a Chroma collection
# persisted under ../chroma_db so the generation step can retrieve them.
embedding_model = OllamaEmbeddings(model="mxbai-embed-large")

# Deterministic ids keep this cell idempotent. The collection is persisted to
# disk, and without explicit ids each re-run would append a fresh copy of every
# document under a random id, which duplicates similarity-search hits. With
# fixed ids, Chroma upserts the existing entries instead.
# NOTE: entries written by earlier runs without ids are not removed by this;
# clear ../chroma_db once if it already holds duplicates.
doc_ids = [f"{doc.metadata['source']}:{doc.metadata['seq_num']}" for doc in docs]

vectorstore = Chroma.from_documents(
    documents=docs,
    ids=doc_ids,
    collection_name="foundry_tests",
    embedding=embedding_model,
    persist_directory="../chroma_db"
)

In [12]:
# Sanity-check the index: fetch the stored test closest to a generic query.
# The results get their own name so the `docs` list loaded above is not
# clobbered; otherwise re-running the indexing cell afterwards would index
# only this single search hit.
search_results = vectorstore.similarity_search("test", k=1)
search_results

[Document(metadata={'function': 'function safeTransferFrom(\n    address from,\n    address to,\n    uint256 id,\n    uint256 amount,\n    bytes calldata data\n) public virtual {\n    require(msg.sender == from || isApprovedForAll[from][msg.sender], "NOT_AUTHORIZED");\n\n    balanceOf[from][id] -= amount;\n    balanceOf[to][id] += amount;\n\n    emit TransferSingle(msg.sender, from, to, id, amount);\n\n    require(\n        to.code.length == 0\n            ? to != address(0)\n            : ERC1155TokenReceiver(to).onERC1155Received(msg.sender, from, id, amount, data) ==\n                ERC1155TokenReceiver.onERC1155Received.selector,\n        "UNSAFE_RECIPIENT"\n    );\n}', 'seq_num': 2, 'source': 'C:\\Users\\ezequ\\Desktop\\ITBA\\2C2024\\DeepLearning\\data\\tests.json', 'test': 'function testSafeTransferFromToEOA() public {\n    address from = address(0xABCD);\n\n    token.mint(from, 1337, 100, "");\n\n    vm.prank(from);\n    token.setApprovalForAll(address(this), true);\n\n    toke

# Generate a test function

In [121]:
# Chat model (served by a local Ollama instance) that writes the Solidity test.
model = ChatOllama(model="gemma2:27b")

In [184]:
# One-shot prompt: show the model a (reference function, reference test) pair
# and ask it for a Foundry test for `function_code` inside `contract_code`.
# The bracketed labels used in the instruction must match the labels on the
# inputs below exactly, so the model can tell which text each one refers to.
PROMPT_TEMPLATE = """
Based on the test function ([reference function test code]) which tests the function code example ([reference function code example]), generate a corresponding test function for the ([function code to be tested]) function within the ([contract code to be tested]) contract. This function is for use within the Foundry framework for writing smart contracts.

---

[reference function test code]: {reference_function_test_code}
[reference function code example]: {reference_function_code_example}
[contract code to be tested]: {contract_code}
[function code to be tested]: {function_code}

---

Your output MUST be a single valid Solidity function, with the setup and assertions necessary to test the function. Do NOT wrap the function in a markdown code block.
REMEMBER, do NOT include a description of the function or any other text, only the code.
REMEMBER, you MUST only generate a single function, not a full test contract.
"""

prompt = ChatPromptTemplate.from_template(PROMPT_TEMPLATE)

# prompt -> chat model -> plain string (the generated Solidity function).
chain = (
    prompt
    | model
    | StrOutputParser()
)

# Full source of the contract under test. The encoding is explicit so the read
# does not depend on the platform default (e.g. cp1252 on Windows).
with open("./contracts/ERC20.sol", "r", encoding="utf-8") as f:
    contract_code = f.read()

# One-shot example, part 1: a reference contract function...
reference_function_code_example = """
function transfer(address to, uint256 amount) public virtual override returns (bool) {
    address owner = msg.sender;
    _transfer(owner, to, amount);
    return true;
}
"""

# ...and part 2: the Foundry test written for it.
reference_function_test_code = """
function test_transfer() public {
    ERC20 token = new ERC20("USD Coin", "USDC");
    address owner = address(this);
    address recipient = address(0x1);
    token.mint(owner, 100);

    token.transfer(recipient, 50);
    assertEq(token.balanceOf(owner), 50);
    assertEq(token.balanceOf(recipient), 50);
}
"""

# The function we want a new test for.
function_code = """
function mint(address account, uint256 amount) public virtual {
    require(account != address(0), "ERC20: mint to the zero address");

    _beforeTokenTransfer(address(0), account, amount);

    _totalSupply += amount;
    unchecked {
        // Overflow not possible: balance + amount is at most totalSupply + amount, which is checked above.
        _balances[account] += amount;
    }
    emit Transfer(address(0), account, amount);

    _afterTokenTransfer(address(0), account, amount);
}
"""

# Generated Solidity test function, as a plain string.
response = chain.invoke({
    "reference_function_test_code": reference_function_test_code,
    "reference_function_code_example": reference_function_code_example,
    "contract_code": contract_code,
    "function_code": function_code
})

In [155]:
# Save the generated test function so it can be copied into the Foundry project.
# UTF-8 is explicit because LLM output may contain non-ASCII characters that the
# platform default encoding (e.g. cp1252 on Windows) cannot encode.
with open("response.txt", "w", encoding="utf-8") as f:
    f.write(response)

# Fix compiler errors

In [189]:
# Separate chat-model handle for the compiler-error repair chain
# (currently the same model as the test generator).
err_model = ChatOllama(model="gemma2:27b")

In [108]:
# Repair prompt: give the model a generated test, the Foundry compiler error it
# produced, and the function/contract under test, and ask for a fixed test.
ERROR_PROMPT_TEMPLATE = """
I wrote the Solidity test function ([test function code]) to run on the Foundry framework. When this code is compiled with Foundry, I get this error ([compiler error]).
This test is for the function ([function code to be tested]) within the contract ([contract code to be tested]). 

Your task is to understand the test I provided, fix the test code, and correct the error within the test. You MUST modify the test function code while maintaining its functionality, but do NOT add other unrelated code.

---

[test function code]: {test_res}
[compiler error]: {error_info}
[function code to be tested]: {function_code}
[contract code to be tested]: {contract_code}

---

Your output MUST be the single valid Solidity test function you fixed.
REMEMBER, do NOT include a description of the function or any other text, only the code.
REMEMBER, you MUST only generate a single test function, not a full test contract.
"""

error_prompt = ChatPromptTemplate.from_template(ERROR_PROMPT_TEMPLATE)

# Named err_chain (not `chain`) so running this cell does not silently rebind
# the test-generation chain defined in the "generate" section.
err_chain = (
    error_prompt
    | err_model
    | StrOutputParser()
)

# Sample Foundry compiler output for the faulty test below.
error_info = """
Compiler run failed:
Error (6160): Wrong argument count for function call: 0 arguments given but expected 2.
  --> test/ERC20.t.sol:10:23:
   |
10 |         ERC20 token = new ERC20();
   |                       ^^^^^^^^^^^

Error (9574): Type int_const 4886718345 is not implicitly convertible to expected type address.
  --> test/ERC20.t.sol:11:9:
   |
11 |         address account = 0x123456789;
   | 
"""

# A generated test that fails to compile with the errors above.
test_res = """
function testMint() public {
    ERC20 token = new ERC20();
    address account = 0x123456789;
    uint256 amount = 100;

    // Initial balance
    assertEq(token.balanceOf(account), 0);

    // Mint tokens
    token.mint(account, amount);

    // Check balance after minting
    assertEq(token.balanceOf(account), amount);
}
"""

# contract_code and function_code come from the test-generation cell, which
# must be run first.
error_response = err_chain.invoke({
    "error_info": error_info,
    "test_res": test_res,
    "contract_code": contract_code,
    "function_code": function_code
})

In [109]:
# Save the repaired test function. UTF-8 is explicit because LLM output may contain
# non-ASCII characters that the platform default encoding (e.g. cp1252 on
# Windows) cannot encode.
with open("err_response.txt", "w", encoding="utf-8") as f:
    f.write(error_response)