In [None]:
# Hugging Face Hub id of the model to evaluate, and a short alias that the
# notebook uses as the name of its per-model output directory.
model_name, model_alias = (
    'deepseek-ai/deepseek-coder-1.3b-instruct',
    'deepseek-coder-1.3b',
)

In [None]:
from datasets import load_dataset
from tqdm import tqdm
import csv
from pathlib import Path
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

In [None]:
# All outputs for this model are written under ./<model_alias>/.
# Only the final path component is created (parents=False, the default), and
# re-running the cell is harmless because an existing directory is accepted.
WORK_DIR = Path(model_alias)
WORK_DIR.mkdir(parents=False, exist_ok=True)

In [None]:
# Load the tokenizer and causal-LM weights for `model_name`.
# NOTE: trust_remote_code=True lets the Hub repository execute its own Python
# code on load; only use it with repositories you trust.
# The weights are loaded as bfloat16, and device_map='auto' delegates device
# placement to transformers/accelerate.
load_kwargs = dict(trust_remote_code=True)

tokenizer = AutoTokenizer.from_pretrained(model_name, **load_kwargs)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    device_map='auto',
    **load_kwargs,
)

In [None]:
# Held-out split of the audit dataset. Later cells read its 'code',
# 'description' and 'recommendation' columns.
DATASET_ID = "msc-smart-contract-audition/audits-with-reasons"
test_dataset = load_dataset(DATASET_ID, split="test")

# Show the split's summary (features / number of rows) as the cell output.
test_dataset

In [None]:
# User prompt sent to the model for every test row. The two positional `{}`
# slots are filled by str.format with, in order, the contract source code and
# the vulnerability description. The leading "\n" is intentional: it
# reproduces the original triple-quoted literal exactly.
query_template = (
    "\n"
    "Below is some solidity code and a description of a vulnerability that the code contains.\n"
    "\n"
    "Explain how to mitigate or fix the vulnerability.\n"
    "Codeblocks:\n"
    "{}\n"
    "\n"
    "Vulnerability:\n"
    "{}"
)

In [None]:
def _build_query(row):
    """Render one dataset row into the prompt text.

    The 'code' and 'description' fields are stored with newlines written as
    the two-character sequence backslash-n; those are turned back into real
    newlines before being substituted into ``query_template``.
    """
    unescaped_code = row['code'].replace('\\n', '\n')
    unescaped_description = row['description'].replace('\\n', '\n')
    return query_template.format(unescaped_code, unescaped_description)


# Keep only rows that actually have a vulnerability description to explain.
df_test = test_dataset.to_pandas()
has_description = df_test['description'].notnull()
df_test = df_test[has_description]

# One prompt per remaining row, indexed like df_test.
queries = df_test.apply(_build_query, axis=1)

In [None]:
# Generate a mitigation recommendation for every prompt in `queries` and write
# it next to the dataset's reference recommendation, one CSV row per prompt.
#
# Columns of recommendations.csv:
#   id     - 0-based position of the prompt among the rows kept in df_test
#   output - model-generated recommendation, newlines escaped as backslash-n
#   real   - reference recommendation taken from the same df_test row
#
# The references must come from df_test, the description-filtered frame that
# `queries` was built from. Zipping against the unfiltered
# test_dataset['recommendation'] pairs every output after the first dropped
# row with the wrong reference.
# NOTE(review): sampling is enabled (do_sample=True) and no seed is set in this
# notebook, so the outputs differ between runs.
references = df_test['recommendation']

# newline="" is what the csv module expects for files it writes. An explicit
# UTF-8 encoding avoids failures on platforms with a non-UTF-8 default.
with open(WORK_DIR / "recommendations.csv", "w", newline="", encoding="utf-8") as f:
    writer = csv.writer(f)
    writer.writerow(["id", "output", "real"])

    for idx, (query, real) in tqdm(enumerate(zip(queries, references)), total=len(queries)):
        messages = [{'role': 'user', 'content': query}]
        inputs = tokenizer.apply_chat_template(
            messages, add_generation_prompt=True, return_tensors="pt"
        ).to(model.device)
        outputs = model.generate(
            inputs,
            max_new_tokens=512,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
            eos_token_id=tokenizer.eos_token_id,
        )
        # outputs[0] echoes the prompt tokens first; decode only the continuation.
        prompt_len = len(inputs[0])
        recommendation = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True)
        # Store newlines in escaped form. This is presumably the same convention
        # as the dataset's text fields (see the '\\n' un-escaping applied to
        # code/description); TODO confirm for 'recommendation'.
        writer.writerow([idx, recommendation.replace('\n', '\\n'), real])