In [1]:
import ast

class ComplexityFeatureExtractor(ast.NodeVisitor):
    """Collect syntax-level features of Python code that hint at its complexity.

    Usage: create an instance, call ``visit(tree)`` on a parsed AST, then read
    the public attributes. Flags are stored as ints (0 or 1), counters as ints,
    and ``loop_bound_type`` as a string.

    The analysis is purely syntactic: it looks at node types and names only and
    never evaluates the code.
    """

    def __init__(self):
        # ---------- Loop structure ----------
        self.num_loops = 0            # total `for` + `while` statements
        self.num_for = 0
        self.num_while = 0
        self.current_loop_depth = 0   # nesting depth at the node being visited
        self.max_loop_depth = 0       # deepest nesting seen so far
        self.has_nested_loops = 0     # 1 once a loop is entered inside another loop

        # ---------- Growth behavior ----------
        # One of: constant | linear | log | mixed | unknown.
        # Only "linear", "log" and "mixed" are ever assigned by this class.
        self.loop_bound_type = "unknown"
        self.has_log_update = 0       # 1 after a `//=` or `>>=` statement
        self.uses_sort = 0            # 1 after `sorted(...)` or `<expr>.sort(...)`

        # ---------- Recursion ----------
        # Name of the function whose body is currently being visited
        # (None at module level).
        self.function_name = None
        self.recursion_flag = 0
        self.num_recursive_calls = 0

        # ---------- Control flow ----------
        self.has_break = 0
        self.has_continue = 0
        self.has_early_return = 0     # 1 when a `return` appears inside a loop
        self.num_return = 0

        # ---------- Hidden loops ----------
        self.uses_comprehension = 0   # list / set / dict comprehension
        self.uses_generator = 0       # generator expression

        # ---------- Calls ----------
        self.num_function_calls = 0   # every ast.Call node, of any kind

        # ---------- Data structures ----------
        # These track literal/display syntax ([...], {...}, (...)) only.
        self.uses_list = 0
        self.uses_dict = 0
        self.uses_set = 0
        self.uses_tuple = 0

    # ---------- Utilities ----------
    def _enter_loop(self):
        """Record that a loop body is being entered and update depth stats."""
        self.current_loop_depth += 1
        self.max_loop_depth = max(self.max_loop_depth, self.current_loop_depth)
        if self.current_loop_depth > 1:
            self.has_nested_loops = 1

    def _exit_loop(self):
        """Record that the innermost loop has been left."""
        self.current_loop_depth -= 1

    def _set_loop_bound(self, kind):
        """Merge a newly observed bound ``kind`` into ``loop_bound_type``.

        The first observation replaces "unknown"; any later observation that
        disagrees with the stored value turns it into "mixed".
        """
        if self.loop_bound_type == "unknown":
            self.loop_bound_type = kind
        elif self.loop_bound_type != kind:
            self.loop_bound_type = "mixed"

    # ---------- Functions ----------
    def visit_FunctionDef(self, node):
        """Track the enclosing function name while visiting its body.

        The previous name is restored afterwards (rather than reset to None)
        so that a nested ``def`` does not make the outer function forget its
        own name, which would hide recursive calls that follow the nested
        definition.
        """
        enclosing_name = self.function_name
        self.function_name = node.name
        try:
            self.generic_visit(node)
        finally:
            self.function_name = enclosing_name

    # `async def` nodes carry the same `name`/`body` fields, so the same
    # bookkeeping applies and recursion inside coroutines is tracked too.
    visit_AsyncFunctionDef = visit_FunctionDef

    # ---------- Loops ----------
    def visit_For(self, node):
        """Count a `for` loop; a direct `range(...)` iterable marks it linear."""
        self.num_loops += 1
        self.num_for += 1
        self._enter_loop()

        # range(...) -> linear
        if isinstance(node.iter, ast.Call) and isinstance(node.iter.func, ast.Name):
            if node.iter.func.id == "range":
                self._set_loop_bound("linear")

        self.generic_visit(node)
        self._exit_loop()

    def visit_While(self, node):
        """Count a `while` loop; its bound is not classified here."""
        self.num_loops += 1
        self.num_while += 1
        self._enter_loop()
        self.generic_visit(node)
        self._exit_loop()

    # ---------- Calls / recursion / sort ----------
    def visit_Call(self, node):
        """Count the call and detect direct recursion and sorting."""
        self.num_function_calls += 1

        if isinstance(node.func, ast.Name):
            # Recursion: a bare-name call matching the enclosing function.
            if node.func.id == self.function_name:
                self.recursion_flag = 1
                self.num_recursive_calls += 1

            if node.func.id == "sorted":
                self.uses_sort = 1
                self._set_loop_bound("linear")

        # <expr>.sort(...); the receiver's type is not checked.
        if isinstance(node.func, ast.Attribute):
            if node.func.attr == "sort":
                self.uses_sort = 1
                self._set_loop_bound("linear")

        self.generic_visit(node)

    # ---------- Logarithmic update ----------
    def visit_AugAssign(self, node):
        """Treat `x //= ...` and `x >>= ...` as a logarithmic update."""
        if isinstance(node.op, (ast.FloorDiv, ast.RShift)):
            self.has_log_update = 1
            self._set_loop_bound("log")
        self.generic_visit(node)

    # ---------- Control flow ----------
    def visit_Break(self, node):
        """Flag the presence of a `break` statement."""
        self.has_break = 1

    def visit_Continue(self, node):
        """Flag the presence of a `continue` statement."""
        self.has_continue = 1

    def visit_Return(self, node):
        """Count the `return`; inside a loop it is flagged as an early return."""
        self.num_return += 1
        if self.current_loop_depth > 0:
            self.has_early_return = 1
        self.generic_visit(node)

    # ---------- Hidden loops ----------
    def visit_ListComp(self, node):
        """Flag a list comprehension and record a linear bound."""
        self.uses_comprehension = 1
        self._set_loop_bound("linear")
        self.generic_visit(node)

    def visit_SetComp(self, node):
        """Flag a set comprehension and record a linear bound."""
        self.uses_comprehension = 1
        self._set_loop_bound("linear")
        self.generic_visit(node)

    def visit_DictComp(self, node):
        """Flag a dict comprehension and record a linear bound."""
        self.uses_comprehension = 1
        self._set_loop_bound("linear")
        self.generic_visit(node)

    def visit_GeneratorExp(self, node):
        """Flag a generator expression and record a linear bound."""
        self.uses_generator = 1
        self._set_loop_bound("linear")
        self.generic_visit(node)

    # ---------- Data structures ----------
    def visit_List(self, node):
        """Flag a list display (`[...]`)."""
        self.uses_list = 1
        self.generic_visit(node)

    def visit_Dict(self, node):
        """Flag a dict display (`{k: v}`)."""
        self.uses_dict = 1
        self.generic_visit(node)

    def visit_Set(self, node):
        """Flag a set display (`{a, b}`)."""
        self.uses_set = 1
        self.generic_visit(node)

    def visit_Tuple(self, node):
        """Flag a tuple node (this includes tuple unpacking targets)."""
        self.uses_tuple = 1
        self.generic_visit(node)

In [2]:
def extract_complexity_features(code: str) -> dict:
    """Parse ``code`` and return its static complexity features as a flat dict.

    Args:
        code: Python source text; it is handed to ``ast.parse`` unchanged.

    Returns:
        A dict whose first key is ``"representation"`` (always
        ``"ast_static_complete"``), followed by one entry per attribute
        collected by ``ComplexityFeatureExtractor``.

    Raises:
        Whatever ``ast.parse`` raises for source it cannot parse
        (e.g. ``SyntaxError``).
    """
    visitor = ComplexityFeatureExtractor()
    visitor.visit(ast.parse(code))

    # Loops were found but nothing classified their bound: assume linear.
    if visitor.num_loops > 0 and visitor.loop_bound_type == "unknown":
        visitor.loop_bound_type = "linear"

    # Attribute names in the exact order they appear in the output dict.
    ordered_fields = (
        # Loop structure
        "num_loops",
        "num_for",
        "num_while",
        "max_loop_depth",
        "has_nested_loops",
        # Growth behavior
        "loop_bound_type",
        "has_log_update",
        "uses_sort",
        # Recursion
        "recursion_flag",
        "num_recursive_calls",
        # Control flow
        "has_break",
        "has_continue",
        "has_early_return",
        "num_return",
        # Hidden loops
        "uses_comprehension",
        "uses_generator",
        # Calls
        "num_function_calls",
        # Data structures
        "uses_list",
        "uses_dict",
        "uses_set",
        "uses_tuple",
    )

    result = {"representation": "ast_static_complete"}
    for field in ordered_fields:
        result[field] = getattr(visitor, field)
    return result


In [3]:
# Demo: run the extractor on a small doubly-nested loop and show the result.
sample_code = """
def example(arr):
    total = 0
    for i in range(len(arr)):
        for j in range(len(arr)):
            total += arr[i] * arr[j]
    return total
"""

features = extract_complexity_features(sample_code)

# Echo the analyzed source under its heading.
print("=== Sample Code ===")
print(sample_code)

# One "name: value" line per extracted feature, in dict order.
print("\n=== Extracted Features ===")
for key, value in features.items():
    line = f"{key}: {value}"
    print(line)


=== Sample Code ===

def example(arr):
    total = 0
    for i in range(len(arr)):
        for j in range(len(arr)):
            total += arr[i] * arr[j]
    return total


=== Extracted Features ===
representation: ast_static_complete
num_loops: 2
num_for: 2
num_while: 0
max_loop_depth: 2
has_nested_loops: 1
loop_bound_type: linear
has_log_update: 0
uses_sort: 0
recursion_flag: 0
num_recursive_calls: 0
has_break: 0
has_continue: 0
has_early_return: 0
num_return: 1
uses_comprehension: 0
uses_generator: 0
num_function_calls: 4
uses_list: 0
uses_dict: 0
uses_set: 0
uses_tuple: 0


In [4]:
import json
from tqdm import tqdm

# Batch job: read one JSON object per line from INPUT_PATH, add the extracted
# complexity features as top-level keys, and write the result to OUTPUT_PATH.
INPUT_PATH = "data.jsonl"
OUTPUT_PATH = "data_features.jsonl"

processed = 0
skipped = 0
# Exception class name -> number of records skipped because of it, so that
# dropped records are visible in the summary instead of vanishing silently.
skip_reasons = {}

with open(INPUT_PATH, "r", encoding="utf-8") as fin, \
     open(OUTPUT_PATH, "w", encoding="utf-8") as fout:

    for line in tqdm(fin, desc="Extracting features"):
        try:
            item = json.loads(line)
            code = item["code"]

            features = extract_complexity_features(code)

            # Flatten features into top-level keys (existing keys with the
            # same name are overwritten by dict.update).
            item.update(features)

            fout.write(json.dumps(item, ensure_ascii=False) + "\n")
            processed += 1

        except Exception as exc:
            # Best-effort: a bad record (invalid JSON, missing "code",
            # unparsable source, ...) is skipped, but the reason is tallied.
            skipped += 1
            reason = type(exc).__name__
            skip_reasons[reason] = skip_reasons.get(reason, 0) + 1

print("Done ✅")
print(f"Processed: {processed}")
print(f"Skipped:   {skipped}")
# Break the skip count down by cause; prints nothing when nothing was skipped.
for reason, count in sorted(skip_reasons.items()):
    print(f"  {reason}: {count}")
print(f"Saved to:  {OUTPUT_PATH}")


Extracting features: 40030it [00:05, 7843.71it/s]

Done ✅
Processed: 40030
Skipped:   0
Saved to:  data_features.jsonl



