In [None]:
import ast
import json
import os
import pickle
import random
import re
import time
from concurrent.futures import ThreadPoolExecutor, as_completed
from hashlib import sha256
from pprint import pprint
from typing import Tuple, List

import zstd
from datasets import load_from_disk
from loguru import logger
from openai import OpenAI
from tqdm import tqdm

The output of the import cell above has been cleared, because it emits a warning message that contains identifying information.

In [2]:
def decompress_data(b_str):
    """Zstd-decompress ``b_str`` and unpickle the resulting bytes.

    NOTE(review): unpickling can execute arbitrary code, so this should only
    be given bytes that come from a trusted source.
    """
    pickled_bytes = zstd.decompress(b_str)
    return pickle.loads(pickled_bytes)
def compress_data(obj):
    """Pickle ``obj`` and return the zstd-compressed bytes."""
    pickled_bytes = pickle.dumps(obj)
    return zstd.compress(pickled_bytes)

In [3]:
# Request settings passed to every chat-completion call.
TEMPERATURE = 1
MAX_TOKENS = 8192
TIMEOUT = 60

# Key, base URL and model name live in a local JSON file, not in the notebook.
with open("secret.json") as f:
    secret = json.load(f)
LLM_KEY, LLM_URL, LLM_MODEL = (
    secret["LLM_KEY"],
    secret["LLM_URL"],
    secret["LLM_MODEL"],
)

# System prompt sent as the first message of every request.
with open("prompts/text_generation.txt") as f:
    SYSTEM_PROMPT = f.read()


In [4]:
# One client instance, reused by every call to `complete`.
client = OpenAI(
    base_url=LLM_URL,
    api_key=LLM_KEY,
)
def complete(user: str):
    """Send one user message to the configured chat model and return its reply.

    The request consists of two messages: ``SYSTEM_PROMPT`` as the system
    message followed by ``user`` as the user message. It is issued through the
    module-level ``client`` with ``LLM_MODEL``, ``TEMPERATURE``,
    ``MAX_TOKENS`` and ``TIMEOUT``.

    Args:
        user: Content of the user message.

    Returns:
        The message content of the first returned choice, or ``None`` when the
        request (or reading its first choice) raised. In that case the
        exception is logged and not re-raised, so callers must handle ``None``.
    """
    messages = [
        {
            "role": "system",
            "content": SYSTEM_PROMPT,
        },
        {
            "role": "user",
            "content": user,
        },
    ]
    try:
        completion = client.chat.completions.create(
            messages=messages,
            timeout=TIMEOUT,
            model=LLM_MODEL,
            temperature=TEMPERATURE,
            max_tokens=MAX_TOKENS,
        )
        result = completion.choices[0].message.content
    except Exception as e:
        # Best-effort by design: report the failure and signal it with None
        # instead of aborting the caller.
        logger.error(e)
        return None
    else:
        return result

In [5]:
def source_template(src_list, max_lines=769):
    """Render source rows as numbered lines for inclusion in a prompt.

    Each row becomes ``"<index> <row>\\n"``, with the index starting at 0.

    Args:
        src_list: Iterable of source rows (anything usable in an f-string).
        max_lines: Maximum number of rows to render; rows beyond it are
            dropped. ``None`` disables the limit. The default of 769 is the
            number of rows the previously hard-coded ``cnt > 768`` guard let
            through, so existing prompts are unchanged.

    Returns:
        The concatenated numbered lines, or ``""`` for empty input.
    """
    numbered_lines = []
    for src_idx, src_row in enumerate(src_list):
        if max_lines is not None and src_idx >= max_lines:
            break
        numbered_lines.append(f"{src_idx} {src_row}\n")
    # Join once instead of growing a string with repeated `+=`.
    return "".join(numbered_lines)

In [6]:
def get_json(s):
    """Parse the first fenced code block found in ``s`` into a Python object.

    The block content is parsed as JSON first. If that fails, it is parsed as
    a Python literal (single-quoted strings, tuples, ``True``/``None``, ...).
    The content is never executed: it comes from a model response and must be
    treated as untrusted input, so arbitrary expressions are rejected rather
    than evaluated.

    Args:
        s: Response text expected to contain a fenced code block.

    Returns:
        The object decoded from the code block.

    Raises:
        ValueError: If ``s`` is not a string (for example ``None`` from a
            failed request), if it has no fenced code block, or if the block
            is neither valid JSON nor a valid Python literal.
        SyntaxError: If the block is not JSON and is not syntactically valid
            Python.
    """
    if not isinstance(s, str):
        raise ValueError(
            f"expected the response text as str, got {type(s).__name__}"
        )
    pattern = r"```[\w\s]*\n(.*?)```"
    match = re.search(pattern, s, re.DOTALL)
    if match is None:
        raise ValueError("no fenced code block found in the response")
    code_block = match.group(1).strip()
    try:
        return json.loads(code_block)
    except ValueError:
        # Not strict JSON: accept Python-literal syntax, but only literals.
        return ast.literal_eval(code_block)

In [7]:
def generate_descriptions(row):
    """Ask the LLM to describe one function and its candidate snippets.

    Args:
        row: Mapping with a compressed ``src`` field (decoded with
            ``decompress_data``) and two lists of candidates,
            ``snippets_from_llm`` and ``snippets_from_rule``.

    Returns:
        The object parsed by ``get_json`` from the model response.

    Raises:
        Whatever ``get_json`` raises when the response is missing or cannot
        be parsed.
    """
    src_list = decompress_data(row["src"])
    llm_candidates = row["snippets_from_llm"]
    rule_candidates = row["snippets_from_rule"]
    # De-duplicate while keeping first-seen order, so the same row always
    # produces the same prompt (set iteration order is not stable).
    candidates = list(
        dict.fromkeys(tuple(c) for c in llm_candidates + rule_candidates)
    )
    # Render the numbered source once and reuse it.
    src_s = source_template(src_list)
    user_input = "Here's the target source function:" + "\n"
    user_input += src_s + "\n"
    user_input += f"candidates: {candidates}\n"
    response = complete(user_input)
    return get_json(response)

In [8]:
# Load the dataset saved on disk. `generate_descriptions` reads the `src`,
# `snippets_from_llm` and `snippets_from_rule` fields of each row.
ds = load_from_disk("data/rule_llm_extract_snippets")

# Illustration of text generation at both the `function-level` and the `snippet-level`

In [9]:
# Pick one example at random and show its source in the numbered form that
# is sent to the model.
row = random.choice(ds)
src_lines = decompress_data(row["src"])
print(source_template(src_lines))

0 int process_cddb_titles(int sock_fd, char *inbuff, int readbytes)
1 {
2 	int	finished = 0;
3 	char	*p = inbuff;
4 	int	ind = 0;
5 	unsigned char **	target = &global.creator;
6 
7 	do {
8 		while (readbytes > 0) {
9 			/* do we have a complete line in the buffer? */
10 			p = (char *)memchr(inbuff+ind, '\n', readbytes);
11 			if (p == NULL) break;
12 
13 			/* look for the terminator first */
14 			if (!strncmp(".\r\n", inbuff+ind, 3)) {
15 				finished = 1;
16 				break;
17 			}
18 			/* kill carriage return */
19 			if (p > inbuff+ind && *(p-1) == '\r') {
20 				*(p-1) = '\0';
21 			}
22 			/* kill line feed */
23 			*p = '\0';
24 
25 			/* handle escaped characters */
26 
27 			{
28 				char *q = inbuff+ind;
29 				while (*q) {
30 					if (*q++ == '\\' && *q != '\0') {
31 						if (*q == '\\') {
32 							readbytes--;
33 							p--;
34 							memmove(q, q+1, readbytes - (q-inbuff-ind));
35 						} else if (*q == 'n') {
36 							*(q-1) = '\n';
37 							readbytes--;
38 							p--;


In [10]:
# Run the description pipeline on the sampled row (this issues one LLM
# request) and pretty-print the parsed result.
result = generate_descriptions(row)
pprint(result)

{'function_description': {'functionality': 'Process and parse CDDB (Compact '
                                           'Disc Data Base) titles from a '
                                           'buffer, handling various metadata '
                                           'fields such as disc titles, track '
                                           'titles, release years, and genres. '
                                           'The function reads data from a '
                                           'socket, processes it line by line, '
                                           'and updates global variables with '
                                           'the parsed information.',
                          'implementation': 'The function iteratively reads '
                                            'data from a buffer, checks for '
                                            'complete lines, and processes '
                                            'each line based on 

# Create new dataset

In [11]:
def _with_description(example):
    """Build the ``description`` column value for one dataset row."""
    description = generate_descriptions(example)
    return {"description": compress_data(description)}


# Generate descriptions for every row in parallel, then persist the result.
ds = ds.map(_with_description, num_proc=25)
ds.save_to_disk("data/function_snippets_with_descriptions")

Map (num_proc=25): 100%|██████████| 100/100 [03:02<00:00,  1.82s/ examples]
Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 3380.38 examples/s]
