In [6]:
import json
import os
from typing import Dict, Set

def extract_types_from_cfg(cfg: Dict) -> set:
    """Collect every value stored under a "type" key in a JSON-like AST.

    Args:
        cfg: Parsed JSON document made of nested dicts and lists. Any other
            value (strings, numbers, None) is ignored.

    Returns:
        The set of all values found under a "type" key at any nesting depth.
    """
    found = set()

    # Use an explicit stack instead of recursion so that very deeply nested
    # ASTs cannot exceed Python's recursion limit.
    pending = [cfg]
    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if "type" in node:
                found.add(node["type"])
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)

    return found

# Walk every JSON file directly under both AST directories and gather the
# union of all node types, skipping file names that were already seen.
all_types = set()
processed_files = set()

for root in ["./static_ast", "./llm_ast"]:
    for file in os.listdir(root):
        if not file.endswith('.json'):
            continue

        # De-duplicate by bare file name.
        # NOTE(review): a file in ./llm_ast that shares its name with one in
        # ./static_ast is therefore skipped entirely - confirm this is intended.
        if file in processed_files:
            continue

        processed_files.add(file)
        file_path = os.path.join(root, file)

        # Read as UTF-8 explicitly: the AST files are written with
        # encoding="utf-8" and ensure_ascii=False, so the platform default
        # encoding may fail to decode them.
        with open(file_path, encoding="utf-8") as f:
            cfg = json.load(f)
            types = extract_types_from_cfg(cfg)
            all_types.update(types)

print("所有文件中提取的唯一类型:")
for t in sorted(all_types):
    print(f"- {t}")


所有文件中提取的唯一类型:
- aliased_import
- argument_list
- as_pattern
- as_pattern_target
- assert_statement
- assignment
- attribute
- augmented_assignment
- await
- binary_operator
- block
- boolean_operator
- break_statement
- call
- class_definition
- comment
- comparison_operator
- concatenated_string
- conditional_expression
- continue_statement
- decorated_definition
- decorator
- default_parameter
- delete_statement
- dictionary
- dictionary_comprehension
- dictionary_splat
- dictionary_splat_pattern
- dotted_name
- elif_clause
- ellipsis
- else_clause
- escape_interpolation
- escape_sequence
- except_clause
- expression_list
- expression_statement
- false
- finally_clause
- float
- for_in_clause
- for_statement
- format_specifier
- function_definition
- future_import_statement
- generator_expression
- generic_type
- global_statement
- identifier
- if_clause
- if_statement
- import_from_statement
- import_prefix
- import_statement
- integer
- interpolation
- keyword_argument
- keyword_se

In [8]:
"""
Multi-threaded script where we separately:
1) Generate an LLM-based AST,
2) Generate a tree-sitter-based static AST,
3) Compare snippet-level labels,
4) Save both ASTs as JSON.

We split logic into two main functions:
- generate_llm_ast(code): returns LLM AST
- generate_tree_sitter_ast(code): returns tree-sitter static AST

Then process_single_file() calls both, compares, and saves results.
Finally main() uses ThreadPoolExecutor to process multiple files in parallel.

Dependencies:
- pip install tree-sitter tree-sitter-python tqdm
- Ensure you have a valid llm.py with get_llm_answers()
- (Optional) pip install transformers, if using HF tokenizer
"""

import os
import sys
import json
from typing import Any, Dict, List, Tuple
import concurrent.futures
from multiprocessing import cpu_count

from tqdm import tqdm

# LLM interface
from llm import get_llm_answers

# py-tree-sitter
import tree_sitter_python as tspython
from tree_sitter import Language, Parser

###############################################################################
#                          Tree-sitter initialization                         #
###############################################################################
# Build the tree-sitter Python grammar and a parser bound to it.
# NOTE(review): the original note assumed tree_sitter_python.language() returns
# a .so/.dylib path; here its return value is handed directly to Language().
# Adjust if the installed tree-sitter version expects something else.
PY_LANGUAGE = Language(tspython.language())
parser = Parser(PY_LANGUAGE)

###############################################################################
#                       Tokenization & Utility Functions                      #
###############################################################################
import re

# One alternative per token class. No regex flags are used: every alternative
# spells out explicitly whether it may cross a newline, so a single pattern
# can never swallow the remainder of the file.
_TOKEN_RE = re.compile(
    r'[A-Za-z_]\w*|'               # identifiers / keywords
    r'[0-9]+|'                     # digit runs
    r'"""[\s\S]*?"""|'             # triple-double-quoted strings (may span lines)
    r"'''[\s\S]*?'''|"             # triple-single-quoted strings (may span lines)
    r'"(?:\\[\s\S]|[^"\\])*"|'     # double-quoted strings, honouring backslash escapes
    r"'(?:\\[\s\S]|[^'\\])*'|"     # single-quoted strings, honouring backslash escapes
    r'#[^\n]*|'                    # Python line comment, up to (not including) the newline
    r'\\[ntr]|'                    # literal backslash escapes \n \t \r outside strings
    r'\n|\r|\t|'                   # raw newline / carriage return / tab
    r'\S'                          # any other single non-whitespace character
)


def tokenize_code(code: str) -> List[Tuple[int, int, str]]:
    """Split Python source into regex tokens with their character offsets.

    Newline, carriage-return and tab characters are emitted as tokens so that
    multi-line structure stays visible; spaces are skipped.

    The previous pattern treated ``//`` and ``/* */`` as comments and was run
    with ``re.DOTALL``, so the first ``//`` (floor division in Python) consumed
    everything up to the end of the input. Comments are now recognised with
    Python's ``#`` syntax, and string literals understand backslash escapes
    and triple quotes.

    Args:
        code: Source text to tokenize.

    Returns:
        A list of ``(start_offset, end_offset, token_text)`` tuples in source
        order, where ``code[start_offset:end_offset] == token_text``.
    """
    return [
        (match.start(), match.end(), match.group(0))
        for match in _TOKEN_RE.finditer(code)
    ]

def rebuild_label_llm(
    code: str,
    indexed_tokens: List[Tuple[int, int, str]],  # [(start_offset, end_offset, token_text), ...]
    start_token: int,
    end_token: int
) -> str:
    """Return the source text spanned by an inclusive token index range.

    The slice runs from the start offset of token ``start_token`` to the end
    offset of token ``end_token`` and is returned verbatim (newlines included).

    Args:
        code: Full source text the token offsets refer to.
        indexed_tokens: ``(start_offset, end_offset, token_text)`` tuples.
        start_token: Index of the first token of the span.
        end_token: Index of the last token of the span (inclusive).

    Returns:
        The matching slice of ``code``, or ``""`` when either index is not an
        int, is out of range, the range is reversed, or the offsets fall
        outside ``code``.
    """
    token_count = len(indexed_tokens)
    # The indices come from LLM-produced JSON, so they may be None, strings or
    # floats; treat anything that is not a usable int index as "no label"
    # instead of raising TypeError.
    for index in (start_token, end_token):
        if not isinstance(index, int) or not 0 <= index < token_count:
            return ""
    if end_token < start_token:
        return ""

    start_offset = indexed_tokens[start_token][0]  # start of the first token
    end_offset = indexed_tokens[end_token][1]      # end of the last token

    if start_offset < 0 or end_offset > len(code):
        return ""

    return code[start_offset:end_offset]

def collect_llm_labels(
    node: Dict[str, Any],
    code: str,
    indexed_tokens: List[Tuple[int, int, str]]
) -> None:
    """Recursively set ``node["label"]`` on every node of an LLM AST.

    The label is whatever ``rebuild_label_llm`` returns for the node's
    ``start_token`` / ``end_token`` (both default to -1 when missing).

    Args:
        node: AST node to annotate in place. Anything that is not a dict is
            ignored, because the tree comes from LLM-produced JSON and may
            contain malformed entries.
        code: Source text the token offsets refer to.
        indexed_tokens: ``(start_offset, end_offset, token_text)`` tuples.
    """
    if not isinstance(node, dict):
        return

    st = node.get("start_token", -1)
    et = node.get("end_token", -1)
    node["label"] = rebuild_label_llm(code, indexed_tokens, st, et)

    # "children" may be absent, null or some other non-list value in LLM
    # output; only a real list is descended into.
    children = node.get("children")
    if not isinstance(children, list):
        return
    for child in children:
        collect_llm_labels(child, code, indexed_tokens)

###############################################################################
#                         Compare: snippet-level text                         #
###############################################################################

def compare_ast_nodes(
    node1: Dict[str, Any],
    node2: Dict[str, Any],
    path: str = ""
):
    """Print the differences between two ASTs, walking them in lockstep.

    For each pair of nodes the stripped ``label`` texts and the number of
    children are compared; mismatches are reported on stdout. Children are
    then compared pairwise up to the shorter child list.

    Args:
        node1: First AST node (nothing happens if it is falsy).
        node2: Second AST node (nothing happens if it is falsy).
        path: Textual location of the current node, used in the messages.
    """
    if not (node1 and node2):
        return

    text_a = (node1.get("label", "") or "").strip()
    text_b = (node2.get("label", "") or "").strip()
    if text_a != text_b:
        print(f"[Diff] label mismatch at {path}:")
        print(f"  1: {repr(text_a)}")
        print(f"  2: {repr(text_b)}")

    kids_a = node1.get("children", [])
    kids_b = node2.get("children", [])
    if len(kids_a) != len(kids_b):
        print(f"[Diff] Child count mismatch at {path}: {len(kids_a)} vs {len(kids_b)}")

    # zip() stops at the shorter list, so only the common prefix is compared.
    for idx, (left, right) in enumerate(zip(kids_a, kids_b)):
        compare_ast_nodes(left, right, path + f".children[{idx}]")

###############################################################################
#                        1) Generate LLM AST (chunk-based)                    #
###############################################################################
def llm_build_ast_from_tokens(tokens_with_offset: List[Tuple[int, int, str]], top_level=True) -> Dict[str, Any]:
    """Ask the LLM to build a JSON AST for a token list.

    The tokens are rendered as ``index: token_text`` lines, embedded in a
    prompt that lists the allowed node types and structural rules, and sent to
    ``get_llm_answers``. The reply is parsed as JSON.

    Args:
        tokens_with_offset: ``(start_offset, end_offset, token_text)`` tuples;
            only the text and the list position are sent to the LLM.
        top_level: True only for the outermost call. The prompt then requires
            a 'module' root; otherwise it forbids creating a 'module' node.

    Returns:
        The parsed AST dict. If the LLM call fails, the reply is not valid
        JSON, or the reply is not a JSON object, an ``ErrorNode`` placeholder
        with no children is returned instead.
    """
    indexed_tokens = [(i, t[2]) for i, t in enumerate(tokens_with_offset)]
    token_info = "\n".join(f"{i}: {text}" for (i, text) in indexed_tokens)

    # Every node type the LLM is allowed to use.
    allowed_types = [
        "aliased_import", "argument_list", "as_pattern", "as_pattern_target", "assert_statement",
        "assignment", "attribute", "augmented_assignment", "await", "binary_operator", "block",
        "boolean_operator", "break_statement", "call", "class_definition", "comment",
        "comparison_operator", "concatenated_string", "conditional_expression",
        "continue_statement", "decorated_definition", "decorator", "default_parameter",
        "delete_statement", "dictionary", "dictionary_comprehension", "dictionary_splat",
        "dictionary_splat_pattern", "dotted_name", "elif_clause", "ellipsis", "else_clause",
        "escape_interpolation", "escape_sequence", "except_clause", "expression_list",
        "expression_statement", "false", "finally_clause", "float", "for_in_clause", "for_statement",
        "format_specifier", "function_definition", "future_import_statement",
        "generator_expression", "generic_type", "global_statement", "identifier", "if_clause",
        "if_statement", "import_from_statement", "import_prefix", "import_statement", "integer",
        "interpolation", "keyword_argument", "keyword_separator", "lambda", "lambda_parameters",
        "line_continuation", "list", "list_comprehension", "list_splat", "list_splat_pattern",
        "module", "named_expression", "none", "nonlocal_statement", "not_operator", "pair",
        "parameters", "parenthesized_expression", "pass_statement", "pattern_list", "raise_statement",
        "relative_import", "return_statement", "set", "set_comprehension", "slice", "string",
        "string_content", "string_end", "string_start", "subscript", "true", "try_statement",
        "tuple", "tuple_pattern", "type", "type_parameter", "typed_default_parameter", "typed_parameter",
        "unary_operator", "union_type", "while_statement", "with_clause", "with_item", "with_statement",
        "yield"
    ]
    allowed_types_str = ", ".join(allowed_types)

    # The "exactly one module" rule only makes sense for the outermost call;
    # sending it to a sub-chunk contradicts the "do NOT create a module"
    # instruction, so the two sentences are kept separate.
    module_root_rule = (
        "Exactly one 'module' node can appear at the root. For sub-blocks inside, use 'block' or other suitable types.\n"
    )
    no_repeat_rule = "Avoid repeating the same code block multiple times.\n"

    prompt = (
        f"Below is a list of tokens (index -> token_string) for a code snippet:\n"
        f"{token_info}\n\n"
        "Please create a JSON-based AST with the following requirements:\n\n"
        "1) Each node in the AST must have these fields:\n"
        f"   - 'type': a string from this set: {allowed_types_str}\n"
        "   - 'start_token': index in the token list where this node begins\n"
        "   - 'end_token': index in the token list where this node ends\n"
        "   - 'children': an array of child nodes\n\n"
        "2) All leaf nodes must represent exactly one token (start_token == end_token), i.e. the smallest indivisible lexical elements.\n\n"
        "3) Higher-level constructs (function defs, statements, expressions) should be internal nodes with 'children'.\n\n"
        "4) If a piece of code doesn't map well, choose the closest match (like 'expression_statement' or 'block').\n\n"
        "5) You may omit or merge purely structural tokens (e.g. whitespace) if not semantically relevant,\n"
        "   but keep meaningful tokens (identifiers, literals, operators, punctuation, etc.) as separate leaf nodes.\n\n"
        "6) Return valid JSON only, with no extra text.\n\n"
        "7) Sibling nodes at the same parent level must NOT have overlapping [start_token..end_token].\n"
        "   - Each child's range must strictly follow the previous one. If code repeats or overlaps, unify or skip.\n\n"
        "8) Sort children by ascending token range, ensuring start_token of each child is greater than\n"
        "   the end_token of the previous child.\n"
    )

    if top_level:
        prompt += (
            "\nImportant:\n"
            " * The root node must be of type 'module'.\n"
            " * Do not nest 'module' inside 'module'. If you need to group child statements, use 'block'.\n"
            f"{module_root_rule}{no_repeat_rule}"
        )
    else:
        prompt += (
            "\nImportant:\n"
            " * Do NOT create an additional 'module' node. The top-level is already a 'module'.\n"
            " * If needed, use 'block' or other node types for grouping.\n"
            f"{no_repeat_rule}"
        )

    try:
        llm_output = get_llm_answers(
            prompt,
            model_name="gpt-4o",  # example model name; replace as needed
            require_json=True,
            temperature=0
        )
        ast_dict = json.loads(llm_output)
        # Callers treat the result as a node dict; a JSON array or scalar
        # would break them later, so reject it here.
        if not isinstance(ast_dict, dict):
            raise ValueError(f"expected a JSON object, got {type(ast_dict).__name__}")
        return ast_dict
    except Exception as e:
        print(f"[Error] LLM build AST: {e}")
        return {
            "type": "ErrorNode",
            "start_token": -1,
            "end_token": -1,
            "children": []
        }


def llm_determine_chunk_boundaries(code: str) -> List[Tuple[int, int]]:
    """Ask the LLM to split the code into top-level sections by line range.

    The code is sent with 1-based line numbers and the reply is expected to be
    a JSON array of ``[start_line, end_line]`` pairs.

    Args:
        code: Source text to split.

    Returns:
        Validated ``(start_line, end_line)`` tuples sorted by start line. Each
        pair is clamped to ``1..number_of_lines``; entries that are not a
        two-int list or that are empty after clamping are dropped. If the LLM
        call fails or no usable pair remains, the single range covering the
        whole input is returned.
    """
    lines = code.splitlines()
    whole_range = [(1, len(lines))]

    numbered_code = "\n".join(f"{i+1}: {line}" for i, line in enumerate(lines))
    prompt = (
        f"Below is code with line numbers:\n{numbered_code}\n\n"
        "Please split this code into top-level sections by line range. Return JSON like:\n"
        "[[1, 10], [11, 30], ...]"
    )
    try:
        raw = get_llm_answers(
            prompt,
            model_name="deepseek-chat",
            require_json=True,
            temperature=0
        )
        arr = json.loads(raw)
    except Exception as e:
        print(f"[Error] LLM chunk boundary: {e}")
        return whole_range

    if not isinstance(arr, list):
        return whole_range

    boundaries = []
    for item in arr:
        if not (isinstance(item, list) and len(item) == 2):
            continue
        start, end = item
        # The reply is untrusted: non-integer line numbers would raise later
        # when the range is used for slicing.
        if not (isinstance(start, int) and isinstance(end, int)):
            continue
        start = max(start, 1)
        end = min(end, len(lines))
        if start > end:
            continue
        boundaries.append((start, end))

    if not boundaries:
        return whole_range  # nothing usable from the LLM: process as one chunk

    boundaries.sort()
    return boundaries

def extract_lines(code: str, start_line: int, end_line: int) -> str:
    """Return lines ``start_line..end_line`` (1-based, inclusive) of ``code``.

    Line endings are preserved. Out-of-range bounds are clamped to the
    available lines.

    Args:
        code: Source text.
        start_line: First line to include; values below 1 are treated as 1.
        end_line: Last line to include; values beyond the last line are
            treated as the last line.

    Returns:
        The selected lines joined together, or ``""`` when the clamped range
        is empty (including when ``end_line`` is zero or negative).
    """
    all_lines = code.splitlines(keepends=True)
    first = max(start_line, 1)
    last = min(end_line, len(all_lines))
    # Without this guard a negative end_line would become a negative slice
    # bound and select "everything but the last N lines".
    if last < first:
        return ""
    return "".join(all_lines[first - 1:last])

def _shift_token_indices(node: Any, delta: int) -> None:
    """Add ``delta`` to every non-negative int start_token/end_token in a subtree."""
    if delta == 0:
        return
    pending = [node]
    while pending:
        current = pending.pop()
        if not isinstance(current, dict):
            continue
        for key in ("start_token", "end_token"):
            value = current.get(key)
            # Negative values are "no range" placeholders and stay untouched.
            if isinstance(value, int) and value >= 0:
                current[key] = value + delta
        children = current.get("children")
        if isinstance(children, list):
            pending.extend(children)


def _tokens_before_line(code: str, line_no: int) -> int:
    """Count the tokens of ``code`` that precede 1-based line ``line_no``."""
    if line_no <= 1:
        return 0
    return len(tokenize_code(extract_lines(code, 1, line_no - 1)))


def chunk_and_build_llm_ast(
    code: str,
    recursion_level=0,
    max_chunk_size=1500,
    top_level=True
) -> Dict[str, Any]:
    """Split ``code`` into chunks with the LLM and build an AST per chunk.

    Chunk boundaries come from ``llm_determine_chunk_boundaries``. A chunk
    longer than ``max_chunk_size`` characters is split further by recursion
    (a single oversized range is halved by line count). Each chunk is
    tokenized on its own and handed to ``llm_build_ast_from_tokens``.

    Because each chunk is tokenized separately, the LLM reports token indices
    relative to that chunk. They are shifted here by the number of tokens that
    precede the chunk, so that all indices in the returned tree refer to the
    token list of ``code`` as a whole.
    NOTE(review): this assumes no token spans a chunk boundary (for example a
    multi-line string cut in two); otherwise counts may be off - confirm.

    Args:
        code: Source text to process.
        recursion_level: Current depth; beyond 5 the code is sent as one piece.
        max_chunk_size: Maximum chunk length in characters before splitting.
        top_level: True only for the outermost call.

    Returns:
        The AST of a single chunk, or a wrapper node ('module' at top level,
        'block' otherwise, with start_token/end_token of -1) whose children
        are the per-chunk ASTs.
    """
    if recursion_level > 5:
        return llm_build_ast_from_tokens(tokenize_code(code), top_level=top_level)

    boundaries = llm_determine_chunk_boundaries(code)
    split_in_half = False

    if len(boundaries) == 1:
        sl, el = boundaries[0]
        sl = max(sl, 1)
        el = min(el, len(code.splitlines()))
        snippet = extract_lines(code, sl, el)
        # Small enough, or a single line that cannot be halved by lines:
        # send it to the LLM as is rather than recursing to the depth limit.
        if len(snippet) <= max_chunk_size or el <= sl:
            sub_ast = llm_build_ast_from_tokens(tokenize_code(snippet), top_level=top_level)
            _shift_token_indices(sub_ast, _tokens_before_line(code, sl))
            return sub_ast
        half = (sl + el) // 2
        ranges = [(sl, half), (half + 1, el)]
        split_in_half = True
    else:
        ranges = boundaries

    partial_asts = []
    for first, last in ranges:
        snippet = extract_lines(code, first, last)
        if not snippet.strip():
            continue  # nothing to parse; avoid an LLM call on empty input
        # Children never get top_level=True: the wrapper built below is the
        # 'module' (or 'block'), and a nested 'module' would break the
        # prompt's own rule.
        if split_in_half or len(snippet) > max_chunk_size:
            sub_ast = chunk_and_build_llm_ast(snippet, recursion_level + 1, max_chunk_size, top_level=False)
        else:
            sub_ast = llm_build_ast_from_tokens(tokenize_code(snippet), top_level=False)
        _shift_token_indices(sub_ast, _tokens_before_line(code, first))
        partial_asts.append(sub_ast)

    # Merge the per-chunk ASTs under one parent node.
    node_type = "module" if top_level else "block"
    return {
        "type": node_type,
        "start_token": -1,
        "end_token": -1,
        "children": partial_asts
    }

def generate_llm_ast(code: str) -> Dict[str, Any]:
    """Build the LLM-based AST for ``code`` and label every node.

    The AST comes from the chunking pipeline (called with ``top_level=True``).
    The whole source is then tokenized once more, with offsets, and that token
    list is used to fill in each node's ``label``.

    Args:
        code: Source text to analyse.

    Returns:
        The root node of the labelled AST.
    """
    ast_root = chunk_and_build_llm_ast(code, top_level=True)
    collect_llm_labels(ast_root, code, tokenize_code(code))
    return ast_root

###############################################################################
#              2) Generate Tree-sitter-based "static" Python AST             #
###############################################################################

class PyTreeSitterStaticHandler:
    """Convert a tree-sitter parse tree into plain ``{type, label, children}`` dicts.

    ``label`` is the node's ``text`` decoded as UTF-8; only named nodes are kept.
    """

    def __init__(self):
        # Reuse the module-level parser.
        self.parser = parser

    def generate_static_ast(self, code: str) -> Dict[str, Any]:
        """Parse ``code`` and return the dict form of the tree's root node."""
        source_bytes = code.encode("utf-8")
        return self.ts_node_to_dict(self.parser.parse(source_bytes).root_node)

    def ts_node_to_dict(self, node) -> Dict[str, Any]:
        """Return the dict form of ``node``, or None if the node is not named.

        Children that convert to None (unnamed nodes) are left out.
        """
        if not node.is_named:
            return None

        # Lazy: children are converted only when the list below is built,
        # i.e. after type and label have been read.
        converted = (
            self.ts_node_to_dict(node.child(index))
            for index in range(node.child_count)
        )
        return {
            "type": node.type,
            "label": (node.text or b"").decode("utf-8"),
            "children": [sub for sub in converted if sub],
        }

def generate_tree_sitter_ast(code: str) -> Dict[str, Any]:
    """Parse ``code`` with tree-sitter and return its ``{type, label, children}`` tree.

    A fresh ``PyTreeSitterStaticHandler`` is created for each call.
    """
    return PyTreeSitterStaticHandler().generate_static_ast(code)

###############################################################################
#                         process_single_file logic                           #
###############################################################################

def process_llm_ast(code: str, file_path: str) -> dict:
    """Return the LLM AST for a source file, using an on-disk JSON cache.

    The cache file is ``llm_ast/gpt-4o/<basename of file_path>.json``. If it
    exists and can be parsed it is returned as is; otherwise the AST is
    generated with ``generate_llm_ast`` and written to that path.

    Args:
        code: Source text of the file.
        file_path: Path of the source file; only its base name is used.

    Returns:
        The cached or freshly generated AST.
    """
    llm_out_dir = "llm_ast/gpt-4o"
    os.makedirs(llm_out_dir, exist_ok=True)
    llm_json_path = os.path.join(llm_out_dir, os.path.basename(file_path) + ".json")

    # Reuse an earlier result when possible. A truncated or otherwise
    # unreadable cache file (e.g. left by an interrupted run) is regenerated
    # instead of failing this file on every later run.
    if os.path.exists(llm_json_path):
        try:
            with open(llm_json_path, "r", encoding="utf-8") as f:
                llm_ast = json.load(f)
        except (OSError, ValueError) as e:
            print(f"[LLM AST] unreadable cache {llm_json_path}: {e}; regenerating")
        else:
            print(f"[LLM AST] (cached) => {llm_json_path}")
            return llm_ast

    llm_ast = generate_llm_ast(code)

    # Write to a temporary file and rename it so that a partially written
    # file never appears under the cache name.
    tmp_path = llm_json_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fout:
        json.dump(llm_ast, fout, indent=4, ensure_ascii=False)
    os.replace(tmp_path, llm_json_path)
    print(f"[LLM AST] => {llm_json_path}")
    return llm_ast

def process_static_ast(code: str, file_path: str) -> dict:
    """Generate the tree-sitter AST for a source file and save it as JSON.

    The result is written to ``static_ast/<basename of file_path>.json``,
    overwriting any existing file.

    Args:
        code: Source text of the file.
        file_path: Path of the source file; only its base name is used.

    Returns:
        The generated AST.
    """
    static_tree = generate_tree_sitter_ast(code)

    out_dir = "static_ast"
    os.makedirs(out_dir, exist_ok=True)
    target = os.path.join(out_dir, os.path.basename(file_path) + ".json")

    with open(target, "w", encoding="utf-8") as sink:
        json.dump(static_tree, sink, indent=4, ensure_ascii=False)

    print(f"[TS AST] => {target}")
    return static_tree

def process_single_file(file_path: str):
    """Run the full pipeline for one source file.

    Steps:
      1. Read the file as UTF-8 (on failure, report it and stop).
      2. Produce and store the LLM AST.
      3. Produce and store the tree-sitter AST.
      4. Optional label comparison, currently disabled.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as source:
            code = source.read()
    except Exception as exc:
        print(f"[Error] reading {file_path}: {exc}")
        return

    # LLM AST first, then the static AST; both helpers save their own JSON.
    llm_ast = process_llm_ast(code, file_path)
    ts_ast = process_static_ast(code, file_path)

    # Optional snippet-level comparison (disabled):
    # compare_ast_nodes(llm_ast, ts_ast)

###############################################################################
#                                 main()                                      #
###############################################################################

def main():
    """Process up to 200 ``.py`` files from ``../source_code`` in a thread pool.

    Each file is handed to ``process_single_file``. A progress bar advances as
    files finish, and any exception raised by a worker is reported instead of
    being silently dropped with its future.
    """
    source_dir = "../source_code"
    if not os.path.isdir(source_dir):
        print(f"[Error] Directory {source_dir} does not exist.")
        return

    # Only the first 200 .py files are processed. The listing is sorted so
    # that the selection is the same on every run (os.listdir order is
    # arbitrary).
    files = sorted(f for f in os.listdir(source_dir) if f.endswith(".py"))[:200]

    # Process in parallel; both context managers guarantee that the pool is
    # shut down and the progress bar closed even if something fails.
    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        with tqdm(total=len(files), desc="处理文件") as pbar:
            futures = {
                executor.submit(process_single_file, os.path.join(source_dir, fname)): fname
                for fname in files
            }
            for future in concurrent.futures.as_completed(futures):
                pbar.update(1)
                error = future.exception()
                if error is not None:
                    print(f"[Error] processing {futures[future]}: {error}")

if __name__ == "__main__":
    main()


处理文件:   0%|          | 0/30 [00:00<?, ?it/s]

处理文件:   3%|▎         | 1/30 [00:42<20:40, 42.78s/it]

[LLM AST] => llm_ast/gpt-4o/9.py.json
[TS AST] => static_ast/9.py.json


处理文件:   7%|▋         | 2/30 [00:43<08:23, 17.98s/it]

[LLM AST] => llm_ast/gpt-4o/14.py.json
[TS AST] => static_ast/14.py.json


处理文件:  10%|█         | 3/30 [00:57<07:19, 16.29s/it]

[LLM AST] => llm_ast/gpt-4o/11.py.json
[TS AST] => static_ast/11.py.json


处理文件:  13%|█▎        | 4/30 [01:12<06:45, 15.61s/it]

[LLM AST] => llm_ast/gpt-4o/208.py.json
[TS AST] => static_ast/208.py.json


处理文件:  17%|█▋        | 5/30 [01:14<04:29, 10.78s/it]

[LLM AST] => llm_ast/gpt-4o/184.py.json
[TS AST] => static_ast/184.py.json


处理文件:  20%|██        | 6/30 [01:17<03:17,  8.22s/it]

[LLM AST] => llm_ast/gpt-4o/167.py.json
[TS AST] => static_ast/167.py.json


处理文件:  23%|██▎       | 7/30 [01:34<04:12, 10.99s/it]

[LLM AST] => llm_ast/gpt-4o/6.py.json
[TS AST] => static_ast/6.py.json


处理文件:  27%|██▋       | 8/30 [01:35<02:53,  7.88s/it]

[LLM AST] => llm_ast/gpt-4o/174.py.json
[TS AST] => static_ast/174.py.json


处理文件:  30%|███       | 9/30 [01:40<02:25,  6.95s/it]

[LLM AST] => llm_ast/gpt-4o/107.py.json
[TS AST] => static_ast/107.py.json


处理文件:  33%|███▎      | 10/30 [01:57<03:21, 10.10s/it]

[LLM AST] => llm_ast/gpt-4o/201.py.json
[TS AST] => static_ast/201.py.json


处理文件:  37%|███▋      | 11/30 [01:59<02:24,  7.59s/it]

[LLM AST] => llm_ast/gpt-4o/148.py.json
[TS AST] => static_ast/148.py.json


处理文件:  40%|████      | 12/30 [02:12<02:43,  9.07s/it]

[LLM AST] => llm_ast/gpt-4o/163.py.json
[TS AST] => static_ast/163.py.json


处理文件:  43%|████▎     | 13/30 [02:20<02:33,  9.02s/it]

[LLM AST] => llm_ast/gpt-4o/13.py.json
[TS AST] => static_ast/13.py.json


处理文件:  47%|████▋     | 14/30 [02:25<02:01,  7.61s/it]

[LLM AST] => llm_ast/gpt-4o/62.py.json
[TS AST] => static_ast/62.py.json


处理文件:  50%|█████     | 15/30 [02:35<02:04,  8.30s/it]

[LLM AST] => llm_ast/gpt-4o/139.py.json
[TS AST] => static_ast/139.py.json


处理文件:  53%|█████▎    | 16/30 [02:37<01:32,  6.58s/it]

[LLM AST] => llm_ast/gpt-4o/71.py.json
[TS AST] => static_ast/71.py.json


处理文件:  57%|█████▋    | 17/30 [02:56<02:14, 10.34s/it]

[LLM AST] => llm_ast/gpt-4o/202.py.json
[TS AST] => static_ast/202.py.json
[Error] LLM chunk boundary: Expecting ',' delimiter: line 923 column 9 (char 17312)


处理文件:  60%|██████    | 18/30 [03:34<03:41, 18.47s/it]

[LLM AST] => llm_ast/gpt-4o/180.py.json
[TS AST] => static_ast/180.py.json


处理文件:  63%|██████▎   | 19/30 [03:36<02:29, 13.59s/it]

[LLM AST] => llm_ast/gpt-4o/59.py.json
[TS AST] => static_ast/59.py.json


处理文件:  67%|██████▋   | 20/30 [04:28<04:09, 24.99s/it]

[LLM AST] => llm_ast/gpt-4o/138.py.json
[TS AST] => static_ast/138.py.json


处理文件:  70%|███████   | 21/30 [05:15<04:44, 31.60s/it]

[LLM AST] => llm_ast/gpt-4o/129.py.json
[TS AST] => static_ast/129.py.json


处理文件:  73%|███████▎  | 22/30 [05:15<02:58, 22.34s/it]

[LLM AST] => llm_ast/gpt-4o/176.py.json
[TS AST] => static_ast/176.py.json


处理文件:  77%|███████▋  | 23/30 [06:41<04:49, 41.34s/it]

[LLM AST] => llm_ast/gpt-4o/195.py.json
[TS AST] => static_ast/195.py.json


处理文件:  80%|████████  | 24/30 [08:02<05:18, 53.12s/it]

[LLM AST] => llm_ast/gpt-4o/54.py.json
[TS AST] => static_ast/54.py.json


处理文件:  83%|████████▎ | 25/30 [08:53<04:23, 52.70s/it]

[LLM AST] => llm_ast/gpt-4o/60.py.json
[TS AST] => static_ast/60.py.json


处理文件:  87%|████████▋ | 26/30 [10:02<03:49, 57.46s/it]

[LLM AST] => llm_ast/gpt-4o/120.py.json
[TS AST] => static_ast/120.py.json


处理文件:  90%|█████████ | 27/30 [11:17<03:08, 62.88s/it]

[LLM AST] => llm_ast/gpt-4o/98.py.json
[TS AST] => static_ast/98.py.json


处理文件:  93%|█████████▎| 28/30 [16:00<04:17, 128.93s/it]

[LLM AST] => llm_ast/gpt-4o/151.py.json
[TS AST] => static_ast/151.py.json


处理文件:  97%|█████████▋| 29/30 [17:47<02:02, 122.08s/it]

[LLM AST] => llm_ast/gpt-4o/12.py.json
[TS AST] => static_ast/12.py.json


### Visualization

In [4]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
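Usage:
  Run this cell/script directly; main() takes no command-line arguments.
  It renders every *.json AST in ./llm_ast/deepseek, plus the same-named file
  in ./static_ast, to PDFs under ./visualize/. To render a single file, call
  visualize_ast(input_json, output_pdf), e.g. visualize_ast("ast.json", "ast").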

This script will read the JSON AST (with fields: type, label, children, etc.)
and generate a PDF visualization using Graphviz.
"""

import os
import sys
import json
import uuid
import argparse
from graphviz import Digraph

def traverse_ast(dot: Digraph, node: dict, parent_id: str = None):
    """Add ``node`` and all of its descendants to a Graphviz graph (depth-first).

    Each AST node becomes a box whose text is its ``type`` on the first line
    and its (shortened) ``label`` on the second; an edge links it to its parent.

    :param dot:       Graphviz Digraph the nodes and edges are added to
    :param node:      current AST node (dict with type, label, children);
                      empty or non-dict values are skipped
    :param parent_id: graph ID of the parent node, or None for the root
    """
    if not node or not isinstance(node, dict):
        return

    # A fresh UUID per node keeps IDs unique even when type/label repeat.
    current_id = str(uuid.uuid4())

    node_type = node.get("type", "")
    # Flatten newlines BEFORE truncating so that long multi-line labels are
    # cleaned as well; a missing or null label becomes an empty string.
    node_label = (node.get("label", "") or "").replace('\n', ' ')

    # Cap the label length so boxes do not become excessively wide.
    if len(node_label) > 50:
        node_label = node_label[:47] + "..."

    # Type and label share one box, separated by a line break.
    label_text = f"{node_type}\n{node_label}"

    dot.node(current_id, label_text, shape='box', width='2', fixedsize='true')

    if parent_id is not None:
        dot.edge(parent_id, current_id)

    # "children" may be missing or null; treat both as no children.
    for child in node.get("children", []) or []:
        traverse_ast(dot, child, current_id)

def visualize_ast(input_json: str, output_pdf: str):
    """Render the JSON AST stored in ``input_json`` as a PDF via Graphviz.

    If the JSON file cannot be read or parsed, an error is printed and the
    process exits with status 1.

    :param input_json: path of the JSON AST file
    :param output_pdf: output path handed to ``Digraph.render``
    """
    try:
        with open(input_json, "r", encoding="utf-8") as handle:
            ast_data = json.load(handle)
    except Exception as exc:
        print(f"[Error] Failed to read {input_json}: {exc}")
        sys.exit(1)

    dot = Digraph(comment='AST', format='pdf')

    # Graph-level layout attributes, applied one at a time in this order.
    layout_attrs = (
        ('rankdir', 'TB'),       # top-to-bottom layout
        ('ranksep', '0.6'),      # spacing between ranks
        ('nodesep', '0.4'),      # spacing between nodes of the same rank
        ('size', '8,11.7'),      # drawing size (A4 per the original note)
        ('ratio', 'compress'),   # allow compression to fit the size
        ('dpi', '150'),          # output resolution
    )
    for attr_name, attr_value in layout_attrs:
        dot.attr(**{attr_name: attr_value})

    # Populate the graph from the AST root.
    traverse_ast(dot, ast_data, parent_id=None)

    # cleanup=True is passed so intermediate files are removed after rendering.
    rendered_path = dot.render(output_pdf, view=False, cleanup=True)
    print(f"[OK] PDF generated => {rendered_path}")

def main():
    """Render every LLM AST and its static counterpart to PDF.

    For each ``*.json`` file in ``./llm_ast/deepseek`` the LLM AST is rendered
    into ``./visualize/llm/deepseek``; the file of the same name in
    ``./static_ast`` is rendered into ``./visualize/static``. A missing static
    file is reported and skipped rather than aborting the whole run.
    """
    llm_ast_dir = "./llm_ast/deepseek"
    static_ast_dir = "./static_ast"

    visualize_static_dir = "./visualize/static"
    visualize_llm_dir = "./visualize/llm/deepseek"
    os.makedirs(visualize_static_dir, exist_ok=True)
    os.makedirs(visualize_llm_dir, exist_ok=True)

    suffix = '.json'
    for file in sorted(os.listdir(llm_ast_dir)):
        if not file.endswith(suffix):
            continue
        # Strip only the trailing ".json"; str.replace would also remove any
        # earlier occurrence inside the file name.
        stem = file[:-len(suffix)]

        json_file = os.path.join(llm_ast_dir, file)
        output_pdf = os.path.join(visualize_llm_dir, stem)
        visualize_ast(json_file, output_pdf)

        json_file = os.path.join(static_ast_dir, file)
        if not os.path.isfile(json_file):
            # visualize_ast exits the process on an unreadable file, so a
            # missing counterpart must be filtered out here.
            print(f"[Skip] no static AST for {file}")
            continue
        output_pdf = os.path.join(visualize_static_dir, stem)
        visualize_ast(json_file, output_pdf)

if __name__ == "__main__":
    main()




[OK] PDF generated => visualize/llm/deepseek/195.py.pdf




[OK] PDF generated => visualize/static/195.py.pdf
[OK] PDF generated => visualize/llm/deepseek/11.py.pdf




[OK] PDF generated => visualize/static/11.py.pdf
[OK] PDF generated => visualize/llm/deepseek/208.py.pdf




[OK] PDF generated => visualize/static/208.py.pdf
[OK] PDF generated => visualize/llm/deepseek/13.py.pdf




[OK] PDF generated => visualize/static/13.py.pdf




[OK] PDF generated => visualize/llm/deepseek/60.py.pdf




[OK] PDF generated => visualize/static/60.py.pdf
[OK] PDF generated => visualize/llm/deepseek/202.py.pdf




[OK] PDF generated => visualize/static/202.py.pdf




[OK] PDF generated => visualize/llm/deepseek/98.py.pdf




[OK] PDF generated => visualize/static/98.py.pdf




[OK] PDF generated => visualize/llm/deepseek/180.py.pdf




[OK] PDF generated => visualize/static/180.py.pdf




[OK] PDF generated => visualize/llm/deepseek/176.py.pdf




[OK] PDF generated => visualize/static/176.py.pdf




[OK] PDF generated => visualize/llm/deepseek/151.py.pdf




[OK] PDF generated => visualize/static/151.py.pdf
