In [None]:
from openai import OpenAI
from config import OPENAI_API_KEY
# Single OpenAI client reused by every request in this notebook; the API key
# comes from the local `config` module imported above.
client = OpenAI(api_key=OPENAI_API_KEY)

In [None]:
# Name of the model under evaluation; it is interpolated into the output
# directory path (`runs/<model_alias>`) when WORK_DIR is built.
model_alias = "gpt-4o-mini"

In [None]:
from pathlib import Path
# Per-model output directory; the result CSVs of this notebook are written here.
WORK_DIR = Path(f'runs/{model_alias}')
# parents=True: on a fresh checkout the top-level `runs/` folder may not exist
# yet, and without it mkdir raises FileNotFoundError for the nested path.
WORK_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
from datasets import load_dataset

# Load the "test" split of the audit dataset as a pandas DataFrame and, in the
# same chain, turn every literal backslash-n sequence in the `code` column into
# a real line break.
test_df = (
    load_dataset("msc-smart-contract-auditing/audits-with-reasons", split="test")
    .to_pandas()
    .assign(code=lambda frame: frame['code'].str.replace(r'\\n', '\n', regex=True))
)

In [None]:
from pydantic import BaseModel

class Response(BaseModel):
    """Result of a single chat-completion request, as assembled by `prompt`."""

    # Text content of the first choice in the API response.
    message: str
    # Value of `usage.total_tokens` reported by the API for this request.
    total_tokens: int
    # The unmodified response object returned by the OpenAI client.
    obj: object

In [None]:
def prompt(messages, model=None) -> Response:
    """Send a chat conversation to the OpenAI API and wrap the reply.

    Args:
        messages: List of chat messages in the chat-completions format
            (dicts with "role" and "content").
        model: Name of the model to query. When omitted, the notebook-level
            `model_alias` is used so that the model being queried always
            matches the `runs/<model_alias>` directory the results go to.

    Returns:
        A `Response` holding the reply text of the first choice, the total
        token count for the request, and the raw completion object.
    """
    response = client.chat.completions.create(
        # Resolved at call time so re-running the `model_alias` cell is enough
        # to switch models.
        model=model_alias if model is None else model,
        messages=messages,
        temperature=0,  # deterministic-as-possible sampling
        max_tokens=256,  # replies longer than this are cut off by the API
        top_p=1,
        frequency_penalty=0,
        presence_penalty=0,
        response_format={
            "type": "text"
        }
    )

    # The API types both `message.content` and `usage` as optional. Coalesce
    # them so one empty reply does not fail validation of `Response`
    # (message: str, total_tokens: int) and abort a long batch run.
    content = response.choices[0].message.content
    usage = response.usage

    return Response(
        message=content if content is not None else "",
        total_tokens=usage.total_tokens if usage is not None else 0,
        obj=response,
    )

## Descriptions

In [None]:
# System prompt for the description task: the model is asked to describe the
# vulnerability (and an attack vector) in the supplied Solidity codeblocks, or
# to answer with a fixed sentence when there is none.
# NOTE(review): the misspellings inside the prompt ("vulnearblity",
# "vulnearbility") are part of the text sent to the model, and the fixed
# "There is no vulnearbility" sentence may be matched by downstream
# evaluation - confirm before correcting them.
SYSTEM_PROMPT_DESC = \
"""
Below are one or more Solidity codeblocks. The codeblocks might contain vulnerable code.
If there is a vulnerability please provide a description of the vulnearblity in terms of the code that is responsible for it.
Describe how an attacker would be able to take advantage of the vulnerability so the explanation is even more clear.

Output only the description of the vulnerability and the attacking vector. No additional information is needed.

If there is no vulnerability output "There is no vulnearbility".
"""

In [None]:
def describe(code) -> Response:
    """Ask the model for a vulnerability description of the given code.

    Args:
        code: Source text to analyse; it is embedded in the user message
            after a "Codeblocks:" header.

    Returns:
        The `Response` produced by `prompt` for a two-message conversation
        (system instructions from SYSTEM_PROMPT_DESC, then the code).
    """
    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT_DESC}],
    }
    user_message = {
        "role": "user",
        "content": [{"text": f"Codeblocks:\n{code}", "type": "text"}],
    }
    return prompt([system_message, user_message])

In [None]:
# Working frame for the description run: the code to analyse plus the
# reference description that is written next to the model output.
data = test_df[['code', 'description']]

In [None]:
from tqdm import tqdm
import csv
# Running sum of tokens consumed by the description requests, shown in the
# progress bar.
total_tokens = 0
# newline='' is required by the csv module so it controls line endings itself
# (otherwise rows get an extra blank line on Windows and newlines embedded in
# fields are mangled). The explicit UTF-8 encoding keeps non-ASCII characters
# in model output from depending on the platform default encoding.
with open(WORK_DIR / 'descriptions.csv', 'w', newline='', encoding='utf-8') as f:
    w = csv.writer(f)
    w.writerow(['id', 'output', 'real'])
    with tqdm(data.iterrows(), total=len(data), desc="Processing", unit="row") as progress_bar:
        for idx, row in progress_bar:
            r = describe(row['code'])
            # Model output is stored on one line: real line breaks are written
            # as the two-character sequence backslash-n.
            w.writerow([idx, r.message.replace('\n', '\\n'), row['description']])
            total_tokens += r.total_tokens
            progress_bar.set_postfix({'total_tokens': total_tokens})

## Recommendations

In [None]:
# System prompt for the recommendation task: given code and a description of
# its vulnerability, the model is asked how to mitigate or fix it.
SYSTEM_PROMPT_REC = \
"""
Below is some solidity code and a description of a vulnerability that the code contains.

Explain how to mitigate or fix the vulnerability.
"""

In [None]:
def recommend(code, description) -> Response:
    """Ask the model how to fix a described vulnerability in the given code.

    Args:
        code: Source text containing the vulnerability; placed after a
            "Codeblocks:" header in the user message.
        description: Text describing the vulnerability; appended after a
            "Vulnerability:" label in the same user message.

    Returns:
        The `Response` produced by `prompt` for a two-message conversation
        (system instructions from SYSTEM_PROMPT_REC, then code + description).
    """
    system_message = {
        "role": "system",
        "content": [{"type": "text", "text": SYSTEM_PROMPT_REC}],
    }
    user_message = {
        "role": "user",
        "content": [{
            "text": f"Codeblocks:\n{code}\nVulnerability:{description}",
            "type": "text",
        }],
    }
    return prompt([system_message, user_message])

In [None]:
# Only rows that actually have a vulnerability description can be used for
# the recommendation task.
test_df = test_df[test_df['description'].notnull()]
# Build the working frame with .assign, which returns a new DataFrame, instead
# of assigning through .loc on a slice of test_df (that pattern triggers
# SettingWithCopyWarning). The replacement itself is unchanged: literal
# backslash-n sequences in `description` become real line breaks.
data = test_df[['code', 'description', 'recommendation']].assign(
    description=lambda frame: frame['description'].str.replace(r'\\n', '\n', regex=True)
)

In [None]:
from tqdm import tqdm
import csv
# Running sum of tokens consumed by the recommendation requests, shown in the
# progress bar.
total_tokens = 0
# newline='' is required by the csv module so it controls line endings itself
# (otherwise rows get an extra blank line on Windows and newlines embedded in
# fields are mangled). The explicit UTF-8 encoding keeps non-ASCII characters
# in model output from depending on the platform default encoding.
with open(WORK_DIR / 'recommendations.csv', 'w', newline='', encoding='utf-8') as f:
    w = csv.writer(f)
    w.writerow(['id', 'output', 'real'])
    with tqdm(data.iterrows(), total=len(data), desc="Processing", unit="row") as progress_bar:
        for idx, row in progress_bar:
            r = recommend(row['code'], row['description'])
            # Model output is stored on one line: real line breaks are written
            # as the two-character sequence backslash-n.
            w.writerow([idx, r.message.replace('\n', '\\n'), row['recommendation']])
            total_tokens += r.total_tokens
            progress_bar.set_postfix({'total_tokens': total_tokens})