In [None]:
import verification_library as veri   
from peft import PeftConfig, PeftModel
from transformers import LlamaForCausalLM, AutoTokenizer, BitsAndBytesConfig
import torch
import os  
import importlib
import gc
from dotenv import load_dotenv
from collections import defaultdict
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"

importlib.reload(veri)

counter_greater_than_200 = 0

contract_lines_4X = []
contract_lines_8X = []

contracts_4X = veri.get_files("/home/matteo/FLAMES/verification-results/sb-heists/smartbugs-curated/0.4.x/contracts/dataset")
contracts_8X = veri.get_files("/home/matteo/FLAMES/verification-results/sb-heists/smartbugs-curated/0.8.x/contracts/dataset/arithmetic") 

for contract_path, contract_name in contracts_4X:
    contract, line = veri.find_occurrences(contract_path, "// <yes> <report>")
    
    if contract.count('\n') > 200:
        counter_greater_than_200 += 1 
   
    contract = veri.replace_lines_with_string(contract, line, '')
    contract_lines_4X.append((contract_name, contract, line))

for contract_path, contract_name in contracts_8X:
    contract, line = veri.find_occurrences(contract_path, "// <yes> <report>")
    
    if contract.count('\n') > 200:
        counter_greater_than_200 += 1
  
    contract = veri.replace_lines_with_string(contract, line, '')
    contract_lines_8X.append((contract_name, contract, line))

print(counter_greater_than_200, len(contracts_4X), len(contracts_8X))
veri.print_json_report("reports/contract_no_comment8X.json", contract_lines_8X)
veri.print_txt_report("reports/contract_no_comment8x.txt", contract_lines_8X)

veri.print_json_report("reports/contract_no_comment4X.json", contract_lines_4X)
veri.print_txt_report("reports/contract_no_comment4x.txt", contract_lines_4X)


In [None]:
load_dotenv()
token = os.getenv("HF_TOKEN")

all_contracts = []
mapping = []  

contract_lines = contract_lines_4X  # or _8X

for idx, (contract_name, contract, lines) in enumerate(contract_lines):
    for line in lines:
        prompt_with_fill = veri.replace_lines_with_string(contract, [line], 'require(<FILL_ME>);')
        all_contracts.append(prompt_with_fill)
        mapping.append((contract_name, contract, line)) 


In [None]:

config = PeftConfig.from_pretrained("GGmorello/FLAMES-20k", token=token)

ft_model = LlamaForCausalLM.from_pretrained(
    config.base_model_name_or_path,
    token = token,
    quantization_config=BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.bfloat16,
    ),
    cache_dir=os.environ.get("TMPDIR")
)

ft_model_20 = PeftModel.from_pretrained(ft_model, "GGmorello/FLAMES-20k", token=token)

#llama_tokenizer = AutoTokenizer.from_pretrained("huggyllama/llama-7b", token=token)
llama_tokenizer = AutoTokenizer.from_pretrained(config.base_model_name_or_path, token=token) 
ft_model_20 = ft_model_20.to('cuda')


In [None]:

PROMPT = all_contracts[2]

input_ids = llama_tokenizer(PROMPT, return_tensors="pt")["input_ids"].to('cuda')
generated_ids = ft_model_20.generate(input_ids, max_new_tokens=128)

filling = llama_tokenizer.batch_decode(generated_ids[:, input_ids.shape[1]:], skip_special_tokens = True)[0]
#print(filling)

In [12]:
from tqdm import tqdm

results_20 = []

chunk_size = 8  
for i in tqdm(range(0, len(all_contracts), chunk_size)):
    chunk = all_contracts[i:i + chunk_size]

    for data in chunk:
        tok = llama_tokenizer(data, return_tensors='pt', truncation=True, max_length=2048)
        tok = {k: v.to('cuda') for k, v in tok.items()}

        with torch.no_grad():
            generated_ids = ft_model_20.generate(
                **tok,
                max_new_tokens=256,
                pad_token_id=llama_tokenizer.eos_token_id
            )

        ft_filling = llama_tokenizer.batch_decode(
            generated_ids[:, tok['input_ids'].shape[1]:],
            skip_special_tokens=True
        )[0]

        results_20.append(ft_filling)

        del tok
        del generated_ids
        torch.cuda.empty_cache()
        gc.collect()

100%|██████████| 26/26 [12:10<00:00, 28.11s/it]


In [13]:
results_20 += [''] * (len(all_contracts) - len(results_20))

contracts_with_results = defaultdict(list)

for generated, (contract_name, contract, line) in zip(results_20, mapping):
    replaced_contract = veri.replace_lines_with_string(contract, [line], f'require({generated});')
    contracts_with_results[contract_name].append((replaced_contract,line, f'require({generated});'))
veri.print_json_report("reports/contract_with_results_20.json", contracts_with_results)

####use contarct lines to take the original contract