In [17]:
import os
import re
import subprocess
from pathlib import Path
from typing import Optional
from tree_sitter import Language, Parser

def ensure_tree_sitter_python():
    """Ensure the compiled tree-sitter Python grammar library is present.

    If ``build/my-languages.so`` (relative to the current working directory)
    already exists, nothing is built. Otherwise this creates ``build/``,
    clones the tree-sitter-python repository into the working directory
    (skipped if a ``tree-sitter-python`` directory is already there), and
    compiles it into the shared library with ``Language.build_library``.

    NOTE(review): newer py-tree-sitter releases reportedly removed
    ``Language.build_library``; confirm the pinned tree_sitter version.

    Returns:
        The absolute path of the shared library, as a string.

    Raises:
        subprocess.CalledProcessError: if the ``git clone`` command fails.
    """
    lib_file = Path('build') / 'my-languages.so'
    grammar_dir = Path('tree-sitter-python')

    if not lib_file.exists():
        lib_file.parent.mkdir(exist_ok=True)
        if not grammar_dir.exists():
            # git names the checkout after the repo, which matches grammar_dir.
            clone_cmd = [
                'git', 'clone',
                'https://github.com/tree-sitter/tree-sitter-python.git'
            ]
            subprocess.run(clone_cmd, check=True)
        Language.build_library(str(lib_file), [str(grammar_dir)])

    return str(lib_file.absolute())

def setup_parser():
    """Create a tree-sitter Parser configured for the Python grammar.

    The grammar library is located (and built on first use) by
    ensure_tree_sitter_python().
    """
    grammar = Language(ensure_tree_sitter_python(), 'python')
    ts_parser = Parser()
    ts_parser.set_language(grammar)
    return ts_parser

def tokenize_code(code: str, start_index: int = 0):
    """Split source text into an ordered sequence of simple lexical tokens.

    Alternatives are tried in this order at each position:
    identifiers, common single-character punctuation, runs of digits,
    double- or single-quoted string literals (no escape handling),
    the characters newline / tab / carriage return (each its own token),
    and finally any other single non-whitespace character. Spaces are
    dropped.

    Args:
        code: source text to tokenize.
        start_index: key assigned to the first token.

    Returns:
        dict mapping consecutive integers (starting at ``start_index``) to
        token text, in source order.
    """
    # Whitespace characters are matched directly. The previous approach
    # rewrote them to "\\n"-style escapes and mapped those back. That also
    # turned a literal backslash+n/t/r in the source (e.g. inside a comment)
    # into a fake whitespace token. It also corrupted string tokens that
    # spanned a newline.
    token_pattern = (
        r"[A-Za-z_]\w*"      # identifiers / keywords
        r"|[(){};,=+\-*/]"   # common punctuation (other chars fall to \S)
        r"|[0-9]+"           # digit runs
        r'|"[^"]*"'          # double-quoted strings
        r"|'[^']*'"          # single-quoted strings
        r"|[\n\t\r]"         # significant whitespace, one char per token
        r"|\S"               # any other single non-space character
    )
    return {
        idx: tok
        for idx, tok in enumerate(re.findall(token_pattern, code), start=start_index)
    }

def get_token_index(pos: int, code: str, token_map: dict) -> Optional[int]:
    """Return the key of the token whose span in ``code`` covers ``pos``.

    Each token's span is found by searching for its text in ``code``,
    in order, starting just after the previous token. This follows the
    real layout of the source (any amount of whitespace, no separators at
    all, newlines, indentation). The previous version assumed exactly one
    space between tokens. Tokens whose text cannot be found in ``code``
    are skipped.

    Args:
        pos: character offset into ``code``. Callers holding tree-sitter
            byte offsets must convert them for non-ASCII source.
        code: the text the tokens were produced from.
        token_map: ordered mapping of token index -> token text, as
            returned by tokenize_code().

    Returns:
        The matching token index, or None when ``pos`` falls on whitespace,
        outside ``code``, or on a token that could not be located.
    """
    search_from = 0
    for idx, token in token_map.items():
        start = code.find(token, search_from)
        if start == -1:
            continue  # token text not present verbatim; skip it
        end = start + len(token)
        if start <= pos < end:
            return idx
        if start > pos:
            # Spans are increasing, so no later token can cover pos.
            return None
        search_from = end
    return None

def tree_to_dict(node, code: str, token_map: dict):
    """Convert a tree-sitter node and its named descendants to nested dicts.

    Each dict has the form ``{"value": <stripped source text of the node>,
    "children": [<dicts of named child nodes>]}``. Anonymous child nodes
    (punctuation, keywords) are not emitted as children, and neither are
    their descendants.

    Args:
        node: tree-sitter node (e.g. ``tree.root_node``).
        code: the source that was parsed, as ``str`` (encoded to UTF-8
            here) or as the exact ``bytes`` given to the parser.
        token_map: kept for backward compatibility; not used. Token ranges
            were previously computed per node but never stored in the
            result, which made the conversion quadratic.

    Returns:
        The nested dict for ``node``.
    """
    # tree-sitter reports start_byte/end_byte as offsets into the UTF-8
    # bytes. Slicing the str with them misplaces every value after the
    # first non-ASCII character, so slice the encoded bytes. Encode once,
    # not per node.
    source = code.encode("utf-8") if isinstance(code, str) else code
    return _node_to_dict(node, source)


def _node_to_dict(node, source: bytes) -> dict:
    """Recursively build the dict for ``node``, slicing ``source`` by byte offsets."""
    value = source[node.start_byte:node.end_byte].decode("utf-8", errors="replace").strip()
    children = []
    cursor = node.walk()
    if cursor.goto_first_child():
        while True:
            child = cursor.node
            if child.is_named:  # keep only named nodes
                children.append(_node_to_dict(child, source))
            if not cursor.goto_next_sibling():
                break
    return {"value": value, "children": children}

def parse_code(code: str, parser=None):
    """Parse Python source and return its named-node tree as nested dicts.

    Args:
        code: Python source text.
        parser: optional pre-configured tree-sitter Parser to reuse across
            many calls. When omitted, a fresh one is created with
            setup_parser(), which also checks for (and may build) the
            grammar library.

    Returns:
        The dict produced by tree_to_dict() for the tree's root node.
    """
    if parser is None:
        parser = setup_parser()

    tree = parser.parse(code.encode("utf-8"))

    # Token map for tree_to_dict.
    token_map = tokenize_code(code)

    return tree_to_dict(tree.root_node, code, token_map)

# 测试代码
if __name__ == "__main__":
    # Convert every data/*.py file into AST/<name>.json.
    import json

    data_dir = Path("data")
    ast_dir = Path("AST")
    ast_dir.mkdir(parents=True, exist_ok=True)

    # Regular files only (a directory named "x.py" would fail to open),
    # sorted so runs are deterministic.
    py_files = sorted(
        p for p in data_dir.iterdir() if p.name.endswith(".py") and p.is_file()
    )
    for src in py_files:
        code = src.read_text(encoding="utf-8")
        result = parse_code(code)
        # Replace only the trailing ".py". str.replace would also rewrite
        # earlier occurrences, e.g. "a.pyc_tools.py" -> "a.jsonc_tools.json".
        out_path = ast_dir / (src.name[: -len(".py")] + ".json")
        with out_path.open("w", encoding="utf-8") as f:
            # ensure_ascii=False keeps non-ASCII source text readable in the
            # UTF-8 output file.
            json.dump(result, f, indent=2, ensure_ascii=False)
