In [None]:
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pprint import pprint
from typing import Tuple, List
import ast
import json
import os
import pickle
import random
import re
import time

from datasets import load_from_disk
from loguru import logger
from openai import OpenAI
from tqdm import tqdm
import zstd

The output of the cell above has been cleared because it prints a warning that contains identifying information.

In [2]:
def decompress_data(b_str):
    """Reverse compress_data: zstd-decompress ``b_str``, then unpickle it.

    NOTE(review): pickle.loads can execute arbitrary code, so only feed this
    bytes that this project produced itself (e.g. the dataset's "src" column).
    """
    pickled = zstd.decompress(b_str)
    return pickle.loads(pickled)
def compress_data(obj):
    """Pickle ``obj`` and zstd-compress the bytes; decompress_data undoes this."""
    pickled = pickle.dumps(obj)
    return zstd.compress(pickled)

In [3]:
# Credentials and endpoint for the OpenAI-compatible chat API come from a local
# secret.json (keys LLM_KEY, LLM_URL, LLM_MODEL) so they are not hard-coded in
# the notebook. An explicit UTF-8 encoding keeps decoding independent of the
# platform's locale default.
with open("secret.json", encoding="utf-8") as f:
    secret = json.load(f)
LLM_KEY = secret["LLM_KEY"]
LLM_URL = secret["LLM_URL"]
LLM_MODEL = secret["LLM_MODEL"]

# Request settings passed to every chat-completion call made by complete().
TEMPERATURE = 1     # sampling temperature; non-zero, so answers vary between runs
MAX_TOKENS = 8192   # upper bound on tokens generated per response
TIMEOUT = 60        # per-request timeout handed to the OpenAI client (seconds)

# System prompt for the snippet-extraction task.
with open("prompts/snippet_extract.txt", encoding="utf-8") as f:
    SYSTEM_PROMPT = f.read()


In [4]:
client = OpenAI(api_key=LLM_KEY, base_url=LLM_URL)


def complete(user: str):
    """Send one chat-completion request and return the reply text.

    The request pairs the module-level SYSTEM_PROMPT (system message) with
    ``user`` (user message) and uses LLM_MODEL, TEMPERATURE, MAX_TOKENS and
    TIMEOUT from the config cell.

    Args:
        user: Content of the user message.

    Returns:
        The content of the first choice's message, or None if building the
        request, the request itself, or reading the response raised. The
        error is logged instead of propagated.
    """
    try:
        completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": SYSTEM_PROMPT},
                {"role": "user", "content": user},
            ],
            timeout=TIMEOUT,
            model=LLM_MODEL,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        return completion.choices[0].message.content
    except Exception as e:
        # Best-effort by design: report the failure and signal it with None.
        logger.error(e)
        return None

In [5]:
def source_template(src_list, max_lines=769):
    """Render source lines as a 0-indexed, line-numbered listing for the prompt.

    Each included line becomes "<index> <line>" followed by a newline. The
    indices are presumably what the ``range`` values in the model's answer
    refer to.

    Args:
        src_list: Iterable of source lines (presumably strings without a
            trailing newline - TODO confirm against the dataset).
        max_lines: Maximum number of leading lines to include, or None for
            all lines. The default of 769 reproduces the original
            ``cnt > 768`` cut-off exactly (lines 0..768), so prompts are
            unchanged unless a caller opts into a different limit.

    Returns:
        The numbered listing as a single string ("" for empty input).
    """
    if max_lines is None:
        numbered = enumerate(src_list)
    else:
        # zip() stops as soon as range() is exhausted, so nothing past the
        # cut-off is pulled from src_list.
        numbered = zip(range(max_lines), src_list)
    # join() instead of repeated += keeps building the string linear.
    return "".join(f"{idx} {row}\n" for idx, row in numbered)

In [6]:
def get_ranges(target):
    """Flatten a snippet tree into the list of its ``range`` values.

    Nodes are visited in pre-order (a node before its sub-snippets, siblings
    left to right). Every node must provide both a ``range`` and a
    ``sub_snippets`` key.
    """
    collected = []
    pending = [target]
    while pending:
        node = pending.pop()
        collected.append(node["range"])
        # Push children reversed so the leftmost child is popped next.
        pending.extend(reversed(node["sub_snippets"]))
    return collected

In [7]:
def get_json(s):
    """Extract and parse the first fenced code block of an LLM response.

    The text between the first opening fence line (three backticks plus an
    optional info string such as "json") and the next closing fence is
    parsed as JSON. If that fails, it is parsed as a Python literal, which
    covers single-quoted dicts.

    Args:
        s: Raw LLM response text.

    Returns:
        The parsed object.

    Raises:
        ValueError: if ``s`` is None, contains no fenced block, or the block
            is neither valid JSON nor a valid Python literal.
    """
    if s is None:
        raise ValueError("no LLM response to parse")
    # The info string stays on the fence line; the former [\w\s]* could also
    # swallow leading content lines made only of word characters/whitespace.
    pattern = r"```[^\n`]*\n(.*?)```"
    match = re.search(pattern, s, re.DOTALL)
    if match is None:
        raise ValueError("no fenced code block found in LLM response")
    code_block = match.group(1).strip()
    # SECURITY: the block is model output, i.e. untrusted text, so it is never
    # passed to eval(). json.loads also accepts true/false/null, which eval()
    # rejected; ast.literal_eval accepts Python literals without running code.
    try:
        return json.loads(code_block)
    except json.JSONDecodeError:
        pass
    try:
        return ast.literal_eval(code_block)
    except (ValueError, TypeError, SyntaxError) as exc:
        raise ValueError(f"could not parse code block: {exc}") from exc

In [8]:
def get_snippets(src, max_attempts=3):
    """Ask the LLM for the snippet tree of one program and flatten it.

    Args:
        src: Compressed source as stored in the dataset's "src" column, i.e.
            anything ``decompress_data`` accepts.
        max_attempts: Number of LLM calls to try before giving up. Any
            exception during an attempt (failed request, missing or
            unparsable code block, missing ``range``/``sub_snippets`` keys)
            is logged and consumes one attempt.

    Returns:
        The list of ``range`` values produced by ``get_ranges``, or None if
        the source could not be decoded or every attempt failed.
    """
    try:
        # Decoding and prompt construction are identical for every attempt,
        # so do them once; a corrupt row is logged once, not per attempt.
        src_list = decompress_data(src)
        user_input = f"Here's the source code of the target program:\n{source_template(src_list)}\nPlease extract the snippets that are relevant to the target program."
    except Exception as e:
        logger.error(e)
        return None
    for attempt in range(1, max_attempts + 1):
        try:
            llm_response = complete(user_input)
            structured_result = get_json(llm_response)
            return get_ranges(structured_result)
        except Exception as e:
            logger.error(f"snippet extraction attempt {attempt}/{max_attempts} failed: {e}")
    # Every attempt failed: signal it explicitly so the caller can count it.
    return None

In [9]:
# Load the dataset saved under data/line_matched (Hugging Face `datasets`
# format). Later cells read row["src"] and pass it to decompress_data, so each
# row's "src" holds a zstd-compressed pickle (presumably a list of source
# lines - confirm against the producing step).
ds = load_from_disk("data/line_matched")

# Illustration of our snippet extraction prompt

In [10]:
# Run the extraction prompt end to end on one randomly chosen row so the raw
# structured answer can be inspected before mapping over the whole dataset.
# The row choice is unseeded and TEMPERATURE is 1, so the output varies per run.
row = random.choice(ds)
src = decompress_data(row["src"])
user_input = (
    "Here's the source code of the target program:\n"
    + source_template(src)
    + "\nPlease extract the snippets that are relevant to the target program."
)
llm_response = complete(user_input)
structured_result = get_json(llm_response)
pprint(structured_result)

{'description': 'Function to set the CPU frequency for a given CPU',
 'range': [0, 37],
 'sub_snippets': [{'description': 'Retrieve the current CPU frequency policy',
                   'range': [1, 2],
                   'sub_snippets': []},
                  {'description': 'Initialize variables and set the userspace '
                                  'governor',
                   'range': [3, 7],
                   'sub_snippets': []},
                  {'description': 'Check if the policy retrieval was '
                                  'successful',
                   'range': [9, 10],
                   'sub_snippets': []},
                  {'description': 'Ensure the CPU is using the userspace '
                                  'governor',
                   'range': [12, 20],
                   'sub_snippets': [{'description': 'Compare the current '
                                                    'governor with the '
                                                    

# Create the new dataset of LLM-extracted snippets

In [11]:
# Run the extraction over every row with 20 worker processes. Each row's
# "snippets_from_llm" is the list of ranges returned by get_snippets, or None
# when extraction for that row failed.
ds = ds.map(lambda x: {"snippets_from_llm": get_snippets(x['src'])}, num_proc=20)
# Failures are stored as None instead of raising, so report how many rows were
# affected before persisting the result.
n_failed = sum(snippets is None for snippets in ds["snippets_from_llm"])
if n_failed:
    logger.warning(f"snippet extraction failed for {n_failed}/{len(ds)} rows")
ds.save_to_disk("data/llm_extract_snippets")

Map (num_proc=20): 100%|██████████| 100/100 [04:04<00:00,  2.45s/ examples]
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 2948.67 examples/s]
