In [None]:
from tqdm import tqdm
from transformers import PreTrainedTokenizerFast
from common import mean_pooling
from transformers import MPNetModel
from datasets import load_from_disk
from model import AsmEncoder
from pprint import pprint
from openai import OpenAI
import pickle
import torch
import json
import zstd
import sys
import re
import os

We cleared the output of the cell above because it prints a warning that contains identifying information.

In [2]:
def decompress_data(b_str):
    """Inverse of ``compress_data``: zstd-decompress ``b_str``, then unpickle it.

    NOTE(review): ``pickle.loads`` can execute arbitrary code while loading,
    so only call this on data from a trusted source.
    """
    raw_bytes = zstd.decompress(b_str)
    return pickle.loads(raw_bytes)


def compress_data(obj):
    """Pickle ``obj`` and zstd-compress the resulting bytes."""
    pickled_bytes = pickle.dumps(obj)
    return zstd.compress(pickled_bytes)

# We used our own model weights when running this notebook; they are not yet publicly available.

In [3]:
# Stop early when either model-weight directory is missing locally.
# Both paths must be checked with os.path.exists: os.path.join would just
# return a non-empty (always truthy) string and never detect a missing dir.
if not (os.path.exists("models/binquery/asm") and os.path.exists("models/binquery/desc")):
    print("This Jupyter notebook is just a demonstration, the model weights will be released when the paper is officially published.")
    sys.exit(0)

In [4]:
# Load the assembly encoder, the description encoder, and one tokenizer for each.
# Loading the tokenizers through PreTrainedTokenizerFast emits a class-mismatch
# warning, because the checkpoints declare CodeRangeTokenizer / MPNetTokenizer.
asm_model = AsmEncoder.from_pretrained("models/binquery/asm")
desc_model = MPNetModel.from_pretrained("models/binquery/desc")
asm_tokenizer, desc_tokenizer = (
    PreTrainedTokenizerFast.from_pretrained(f"tokenizers/{tokenizer_name}")
    for tokenizer_name in ("asm", "desc")
)

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'CodeRangeTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.
The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'MPNetTokenizer'. 
The class this function is called from is 'PreTrainedTokenizerFast'.


In [5]:
# Load the retrieval task:
#   - "cve_description" is the natural-language query;
#   - "target" is later used as the row index of the ground-truth function
#     in the candidate score matrix;
#   - "project" selects which saved dataset of binary functions to search.
# JSON text is UTF-8; don't depend on the platform's locale encoding.
with open("dataset/vulsearch/task.json", encoding="utf-8") as f:
    task = json.load(f)
target = task["target"]
query = task["cve_description"]
project = task["project"]

In [6]:
# One saved `datasets` dataset per project; rows carry a compressed "asm" field.
vulsearch_dataset_path = f"dataset/vulsearch/{project}"
ds = load_from_disk(vulsearch_dataset_path)

# Compute embedding vectors for all binary functions in the project.

In [7]:
# Encode every candidate binary function with the assembly encoder.
# Each row's "asm" field decompresses to a sequence of strings (presumably one
# per instruction), which are joined with newlines into one text per function.
EMBED_BATCH_SIZE = 64
ASM_MAX_LENGTH = 1024
# Fall back to CPU so the notebook still runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

candidate_list = ["\n".join(decompress_data(row["asm"])) for row in ds]
assert candidate_list, f"no candidate functions found for project {project!r}"
tokenized_candidate = asm_tokenizer(
    candidate_list,
    truncation=True,
    return_tensors="pt",
    padding="max_length",
    max_length=ASM_MAX_LENGTH,
)
asm_model = asm_model.to(device)
asm_model.eval()
asm_model.requires_grad_(False)

embedding_list = []
with torch.no_grad():
    for start in tqdm(range(0, len(candidate_list), EMBED_BATCH_SIZE)):
        stop = start + EMBED_BATCH_SIZE
        # Keep CPU slices; only copies are moved to the compute device.
        # token_type_ids are not used by the encoder call, so they are not sliced.
        input_ids = tokenized_candidate["input_ids"][start:stop]
        attention_mask = tokenized_candidate["attention_mask"][start:stop]
        output = asm_model(input_ids.to(device), attention_mask=attention_mask.to(device))
        # Pool on the CPU with the CPU mask so both tensors share a device.
        embedding_list.append(mean_pooling(output.last_hidden_state.cpu(), attention_mask))
candidate_embeddings = torch.cat(embedding_list, dim=0)

100%|██████████| 39/39 [00:43<00:00,  1.11s/it]


# Prepare the environment for query augmentation.

In [8]:
# Load LLM endpoint credentials from a local secrets file instead of
# hardcoding them in the notebook. Read as UTF-8 (JSON's required encoding).
with open("dataset/secret.json", encoding="utf-8") as f:
    secret = json.load(f)
LLM_KEY = secret["LLM_KEY"]
LLM_URL = secret["LLM_URL"]
LLM_MODEL = secret["LLM_MODEL"]

# Request settings for the augmentation call. With TEMPERATURE = 1 the model
# samples, so the augmented query (and the resulting rank) can vary per run.
TEMPERATURE = 1
MAX_TOKENS = 8192
TIMEOUT = 60  # request timeout passed to the OpenAI client, in seconds

# The system prompt may contain non-ASCII text; don't rely on the locale default.
with open("dataset/prompts/augmentation.txt", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()

In [9]:
client = OpenAI(api_key=LLM_KEY, base_url=LLM_URL)


def complete(user: str):
    """Send one chat-completion request and return the reply text.

    The request uses ``SYSTEM_PROMPT`` as the system message and ``user`` as the
    user message, with ``LLM_MODEL``, ``TEMPERATURE``, ``MAX_TOKENS`` and
    ``TIMEOUT`` as request settings.

    Args:
        user: Content of the user message.

    Returns:
        The content of the first returned choice, or ``None`` if anything in
        the request/response handling raised. The exception is printed rather
        than propagated, so callers must check for ``None``.
    """
    try:
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            timeout=TIMEOUT,
            model=LLM_MODEL,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Deliberately best-effort: report the failure and signal it with None.
        print(e)
        return None

In [10]:
def get_json(s):
    """Extract and parse the first ```-fenced code block in an LLM response.

    The block body is parsed as JSON first. If that fails, it is parsed as a
    Python literal (e.g. a dict written with single quotes) via
    ``ast.literal_eval``. Neither parser executes code, unlike ``eval``, so a
    malformed or malicious model response cannot run arbitrary Python.

    Args:
        s: Raw response text expected to contain a fenced code block.

    Returns:
        The parsed object (typically a dict).

    Raises:
        ValueError: if ``s`` is None/empty, contains no fenced block, or the
            block is neither valid JSON nor a valid Python literal.
    """
    import ast

    if not s:
        raise ValueError("empty LLM response; no code block to parse")
    # Optional language tag / whitespace after the opening fence, then the body.
    pattern = r"```[\w\s]*\n(.*?)```"
    match = re.search(pattern, s, re.DOTALL)
    if match is None:
        raise ValueError("no fenced code block found in LLM response")
    code_block = match.group(1).strip()
    try:
        return json.loads(code_block)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(code_block)
    except (ValueError, SyntaxError, TypeError) as exc:
        raise ValueError(
            f"code block is neither JSON nor a Python literal: {code_block[:200]!r}"
        ) from exc

In [11]:
def augment_query(query: str):
    """Augment a natural-language query using the LLM.

    Sends ``"Original query: " + query`` as the user message and returns the
    ``"description"`` entry of the object that ``get_json`` parses from the reply.

    Raises:
        RuntimeError: if the LLM request failed (``complete`` returned None).
        KeyError: if the parsed reply has no ``"description"`` key.
    """
    user_input = "Original query: " + query
    response = complete(user_input)
    if response is None:
        # complete() prints request errors and returns None; fail with a clear
        # message instead of handing None to the regex-based parser.
        raise RuntimeError("LLM completion failed; see the error printed above")
    return get_json(response)["description"]

# Run natural-language-based binary function retrieval

In [12]:
# Show the ground-truth target index and the raw CVE description used as the query.
for label, value in (("Target", target), ("Original Query", query)):
    pprint(f"{label}: {value}")

'Target: 3'
('Original Query: Multiple use-after-free vulnerabilities in the (1) '
 'htmlPArsePubidLiteral and (2) htmlParseSystemiteral functions in libxml2 '
 'before 2.9.4, as used in Apple iOS before 9.3.2, OS X before 10.11.5, tvOS '
 'before 9.2.1, and watchOS before 2.2.1, allow remote attackers to cause a '
 'denial of service via a crafted XML document.')


In [13]:
# Ask the LLM to augment the CVE description, then display the result.
augmented_query = augment_query(query)
augmented_banner = f"Augmented Query: {augmented_query}"
pprint(augmented_banner)

('Augmented Query: Implements two functions, (1) htmlParsePubidLiteral and (2) '
 'htmlParseSystemLiteral, which parse public and system identifiers in XML '
 'documents, respectively. These functions are designed to handle specific XML '
 'syntax elements and ensure proper parsing of identifiers within XML '
 'documents. The functions are part of the libxml2 library, which is used in '
 'various Apple operating systems for XML processing.')


In [14]:
# Rank the target function among all candidates using the ORIGINAL query.
# Inference only: no_grad avoids building an autograd graph for the
# max_length=1024 forward pass through the description encoder.
with torch.no_grad():
    tokenized = desc_tokenizer(query, truncation=True, return_tensors="pt", padding="max_length", max_length=1024)
    query_embedding = mean_pooling(
        desc_model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"]).last_hidden_state,
        tokenized["attention_mask"],
    )
    # Dot-product similarity between every candidate embedding and the query.
    scores = torch.matmul(candidate_embeddings, query_embedding.T)
target_score = scores[target]
# Rank = 1 + number of candidates scoring strictly higher (ties favor the target).
rank = (scores > target_score).sum().item() + 1
print(f"The result of original query is ranked {rank} out of {scores.size(0)}")

The result of original query is ranked 1131 out of 2477


In [15]:
# Rank the target function among all candidates using the AUGMENTED query.
# Inference only: no_grad avoids building an autograd graph for the
# max_length=1024 forward pass through the description encoder.
with torch.no_grad():
    tokenized = desc_tokenizer(augmented_query, truncation=True, return_tensors="pt", padding="max_length", max_length=1024)
    query_embedding = mean_pooling(
        desc_model(input_ids=tokenized["input_ids"], attention_mask=tokenized["attention_mask"]).last_hidden_state,
        tokenized["attention_mask"],
    )
    # Dot-product similarity between every candidate embedding and the query.
    scores = torch.matmul(candidate_embeddings, query_embedding.T)
target_score = scores[target]
# Rank = 1 + number of candidates scoring strictly higher (ties favor the target).
rank = (scores > target_score).sum().item() + 1
print(f"The result of augmented query is ranked {rank} out of {scores.size(0)}")

The result of augmented query is ranked 30 out of 2477
