In [None]:
from tree_sitter import Language, Parser, Node
from datasets import load_from_disk
import tree_sitter_cpp
import pickle, zstd

The output of the cell above has been cleared because it prints a warning that contains identifying information.

In [2]:
def decompress_data(b_str):
    """Undo ``compress_data``: zstd-decompress ``b_str``, then unpickle it.

    NOTE(review): ``pickle.loads`` can execute arbitrary code, so only feed this
    blobs that come from a trusted producer (e.g. ``compress_data``).
    """
    raw_bytes = zstd.decompress(b_str)
    return pickle.loads(raw_bytes)


def compress_data(obj):
    """Pickle ``obj`` and return the zstd-compressed bytes."""
    pickled = pickle.dumps(obj)
    return zstd.compress(pickled)

In [3]:
# C++ grammar wrapped once at module level; extract_snippets builds each Parser from it.
_cpp_grammar_handle = tree_sitter_cpp.language()
cpp_language = Language(_cpp_grammar_handle)
def traverse_node(node: "Node", mode: str = "post_order"):
    """Iterate over ``node`` and all of its descendants using its cursor.

    Args:
        node: root of the subtree to walk; the cursor is created with ``node.walk()``.
        mode: ``"post_order"`` yields every child before its parent (the root comes
            last). ``"depth_first"`` yields every parent before its children, i.e.
            pre-order (the root comes first).

    Returns:
        A generator of nodes in the requested order.

    Raises:
        ValueError: if ``mode`` is not a supported value. It is raised when
            ``traverse_node`` is called, not deferred to the first ``next()``.
    """
    if mode not in ("post_order", "depth_first"):
        raise ValueError("mode must be either post_order or depth_first")
    # Validation lives outside the generator: inside a generator function the
    # raise would be postponed until iteration starts.
    return _traverse_node_iter(node, mode == "post_order")


def _traverse_node_iter(node: "Node", post_order: bool):
    """Cursor-driven walk backing ``traverse_node``; ``post_order`` picks when nodes are yielded."""
    cursor = node.walk()
    while True:
        if not post_order:
            yield cursor.node  # pre-order: emit on entry
        if cursor.goto_first_child():
            continue
        if post_order:
            yield cursor.node  # leaf: nothing below it to emit first
        if cursor.goto_next_sibling():
            continue
        # Climb until some ancestor has a next sibling, emitting each ancestor on the
        # way up in post-order. The walk ends when goto_parent() fails, which
        # (assumption: tree-sitter cursor semantics) happens at the node the cursor
        # was created from.
        while True:
            if not cursor.goto_parent():
                return
            if post_order:
                yield cursor.node
            if cursor.goto_next_sibling():
                break
def extract_snippets(code):
    """Return the line ranges of C++ control-flow statements and blocks in ``code``.

    ``code`` is parsed with tree-sitter's C++ grammar. Every ``for`` / ``while`` /
    ``do`` / ``if`` statement and every ``compound_statement`` is reported, except a
    compound statement that is the body of a ``function_definition`` (blocks nested
    inside that body are still reported).

    Args:
        code: C++ source text; it is UTF-8 encoded before parsing.

    Returns:
        A sorted list of unique ``(start_row, end_row)`` tuples. ``start_row`` is the
        node's ``start_point.row`` and ``end_row`` is ``end_point.row + 1``, so the pair
        reads as a half-open row range (tree-sitter rows are 0-based).
    """
    snippet_types = {
        "for_statement",
        "while_statement",
        "do_statement",
        "if_statement",
        "compound_statement",
    }
    parser = Parser(language=cpp_language)
    tree = parser.parse(code.encode())
    ranges = set()  # a for/if and its braced body can share the same rows; keep one copy
    for node in traverse_node(tree.root_node):
        if node.type not in snippet_types:
            continue
        # Decided per node, not through a "skip" list filled when the
        # function_definition is visited: in post-order the body is visited before
        # its function_definition, so such a list would never take effect.
        if _is_function_body(node):
            continue
        ranges.add((node.start_point.row, node.end_point.row + 1))
    return sorted(ranges)


def _is_function_body(node):
    """True if ``node`` is the ``body`` field of its parent ``function_definition``."""
    parent = node.parent
    if parent is None or parent.type != "function_definition":
        return False
    return parent.child_by_field_name("body") == node

In [4]:
def do_extract_snippets(row):
    """Per-example ``ds.map`` callback that adds the rule-based snippet ranges.

    ``row["src"]`` is decoded with ``decompress_data`` into (presumably) a list of
    source lines. These are re-joined with ``"\\n"`` and passed to ``extract_snippets``.

    Returns:
        ``{"snippets_from_rule": <ranges from extract_snippets>}``, the column that
        ``map`` merges into the example.
    """
    source_text = "\n".join(decompress_data(row["src"]))
    return {"snippets_from_rule": extract_snippets(source_text)}

In [5]:
# Load the saved dataset and add a "snippets_from_rule" column, using 20 worker processes.
ds = load_from_disk("data/llm_extract_snippets").map(do_extract_snippets, num_proc=20)

Map (num_proc=20): 100%|██████████| 100/100 [00:00<00:00, 391.11 examples/s]


In [6]:
# Persist the dataset (now carrying "snippets_from_rule") under a new directory.
rule_output_dir = "data/rule_llm_extract_snippets"
ds.save_to_disk(rule_output_dir)

Saving the dataset (1/1 shards): 100%|██████████| 100/100 [00:00<00:00, 2703.82 examples/s]
