In [None]:
# ============= 2) py-tree-sitter 相关 ============= #
from tree_sitter import Language, Parser

import os
import sys
import json
from typing import Any, Dict, List, Tuple
import re
import concurrent.futures
from multiprocessing import cpu_count

from tqdm import tqdm
# NOTE: this code requires the tree-sitter Python package pinned to version 0.19.0 (newer releases changed the Language API used below).
def create_parser(library_path='../cj.so', language_name='char'):
    """
    Build a tree-sitter parser configured with a compiled grammar.

    Args:
        library_path: Path to the compiled grammar shared library. The
            default is relative to the current working directory.
        language_name: Name of the language to load from that library.

    Returns:
        A `Parser` with the loaded language set on it.
    """
    cangjie_parser = Parser()
    cangjie_language = Language(library_path, language_name)
    cangjie_parser.set_language(cangjie_language)
    return cangjie_parser

# Module-level parser, created once when this module/cell runs and reused by
# every PyTreeSitterStaticHandler instance.
parser = create_parser()

# ============= 3) Tokenization & Utility Functions ============= #
import re

def tokenize_code(code: str) -> List[Tuple[int, int, str]]:
    """
    Split source code into regex tokens while recording where each one sits.

    Matching runs directly on the original text with `re.finditer`, and
    newline / carriage-return / tab characters are kept as tokens so that
    multi-line structure stays visible. Spaces are not emitted as tokens.

    Args:
        code: The source text to tokenize.

    Returns:
        A list of `(start_offset, end_offset, token_text)` tuples in source
        order. Offsets are character indices into `code` (not byte offsets).
    """
    token_pattern = (
        r'[A-Za-z_]\w*|'  # identifiers
        r'[0-9]+|'        # integer literals
        r'"[^"]*"|'       # double-quoted strings
        r"'[^']*'|"       # single-quoted strings
        r'\\[ntr]|'       # escape sequences \n \t \r
        # Single-line comment. It must stop at the end of the line: the
        # pattern is compiled with re.DOTALL, so a plain `.*` would also
        # match newlines and swallow the rest of the file.
        r'//[^\r\n]*|'
        r'/\*.*?\*/|'     # block comment (relies on re.DOTALL to span lines)
        r'\n|\r|\t|'      # newline / carriage return / tab
        r'\S'             # any other single non-whitespace character
    )

    return [
        (*match.span(), match.group(0))
        for match in re.finditer(token_pattern, code, re.MULTILINE | re.DOTALL)
    ]

# ============= 6) Tree-sitter-based Static AST (带 start_token/end_token) ============= #

def find_token_index_for_byte(byte_offset: int, tokens: List[Tuple[int, int, str]], is_start=True) -> int:
    """
    Return the index of the first token whose span contains `byte_offset`.

    With `is_start` truthy the span is treated as half-open on the right
    (`start <= offset < end`), which suits the first token of a node. With
    `is_start` falsy it is half-open on the left (`start < offset <= end`),
    which suits the last token of a node. The search is a linear scan.

    Returns:
        The matching token index, or -1 when no token contains the offset.
    """
    if is_start:
        def contains(lo: int, hi: int) -> bool:
            return lo <= byte_offset < hi
    else:
        def contains(lo: int, hi: int) -> bool:
            return lo < byte_offset <= hi

    return next(
        (idx for idx, (lo, hi, _) in enumerate(tokens) if contains(lo, hi)),
        -1,
    )

class PyTreeSitterStaticHandler:
    """
    Build an AST of nested dicts: {type, label, start_token, end_token, children}.

    `start_token` / `end_token` are indices into the token list produced by
    `tokenize_code` for the same source text.
    """
    def __init__(self, code: str):
        self.parser = parser
        self.code = code
        # Tokens with character offsets, exactly as tokenize_code returns them.
        self.tokens = tokenize_code(code)
        # The source is parsed as UTF-8 bytes and nodes are located by
        # start_byte/end_byte, while tokenize_code yields character offsets.
        # Keep a byte-offset copy so lookups stay correct after any
        # non-ASCII character (e.g. comments or strings in Chinese).
        self._byte_tokens = self._tokens_to_byte_spans(code, self.tokens)

    @staticmethod
    def _tokens_to_byte_spans(code: str, tokens: List[Tuple[int, int, str]]) -> List[Tuple[int, int, str]]:
        """
        Convert character-offset token spans into UTF-8 byte-offset spans.

        `tokens` must be in source order and non-overlapping. Token order and
        text are preserved, so indices into the result equal indices into
        `tokens`.
        """
        spans = []
        char_pos = 0
        byte_pos = 0
        for start, end, text in tokens:
            # Advance over the untokenized gap (e.g. spaces) before the token.
            byte_pos += len(code[char_pos:start].encode("utf-8"))
            byte_start = byte_pos
            byte_pos += len(code[start:end].encode("utf-8"))
            char_pos = end
            spans.append((byte_start, byte_pos, text))
        return spans

    def generate_static_ast(self) -> Dict[str, Any]:
        """Parse the stored source and return the dict form of the root node."""
        tree = self.parser.parse(self.code.encode("utf-8"))
        return self.ts_node_to_dict(tree.root_node)

    def ts_node_to_dict(self, node) -> Dict[str, Any]:
        """
        Recursively convert a tree-sitter node into a plain dict.

        Returns None for unnamed nodes, which are left out of the parent's
        `children` list. A token index of -1 means no token covers the
        corresponding byte offset.
        """
        # Skip unnamed nodes.
        if not node.is_named:
            return None

        start_token_idx = find_token_index_for_byte(node.start_byte, self._byte_tokens, is_start=True)
        end_token_idx = find_token_index_for_byte(node.end_byte, self._byte_tokens, is_start=False)

        custom = {
            "type": node.type,
            # The full source snippet covered by this node is kept as the label.
            "label": (node.text or b"").decode("utf-8"),
            "start_token": start_token_idx,
            "end_token": end_token_idx,
            "children": []
        }

        for i in range(node.child_count):
            child_dict = self.ts_node_to_dict(node.child(i))
            if child_dict is not None:
                custom["children"].append(child_dict)

        return custom

def generate_tree_sitter_ast(code: str) -> Dict[str, Any]:
    """
    Public entry point: parse `code` with the module-level tree-sitter parser
    and return its static AST, annotated with (start_token, end_token).
    """
    return PyTreeSitterStaticHandler(code).generate_static_ast()

def process_static_ast(code: str, file_path: str) -> dict:
    """
    Generate the tree-sitter AST for `code` and save it as JSON.

    The output goes to `cangjie_ast/<basename of file_path>.json`; the
    directory is created if it does not exist. Returns the AST dict.
    """
    ast_dict = generate_tree_sitter_ast(code)

    out_dir = "cangjie_ast"
    os.makedirs(out_dir, exist_ok=True)

    json_name = os.path.basename(file_path) + ".json"
    target_path = os.path.join(out_dir, json_name)
    with open(target_path, "w", encoding="utf-8") as sink:
        json.dump(ast_dict, sink, indent=4, ensure_ascii=False)

    print(f"[TS AST] => {target_path}")
    return ast_dict

def process_single_file(file_path: str):
    """
    Convert one source file into a saved AST JSON file.

    Steps:
    1) read the file as UTF-8 text;
    2) build the static tree-sitter AST;
    3) save it as JSON (done by `process_static_ast`).

    A failure while reading is printed and the file is skipped. Errors raised
    while building or saving the AST are not caught here.
    """
    try:
        with open(file_path, "r", encoding="utf-8") as src:
            code = src.read()
    except Exception as e:
        print(f"[Error] reading {file_path}: {e}")
    else:
        # Build and save the static AST; the returned dict is not needed here.
        process_static_ast(code, file_path)

# ============= 8) 主函数：多线程并行处理 ============= #

def main():
    """
    Convert every `.cj` file directly inside the `cangjie` directory.

    Files are submitted to a thread pool sized by `cpu_count()` and handled
    by `process_single_file`. Progress is shown with a tqdm bar. If a worker
    raises, the error is printed with the file path instead of being left
    unread inside its future.
    """
    source_dir = "cangjie"
    if not os.path.isdir(source_dir):
        print(f"[Error] Directory {source_dir} does not exist.")
        return

    files = [f for f in os.listdir(source_dir) if f.endswith(".cj")]

    # Process the files in parallel.
    with concurrent.futures.ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        with tqdm(total=len(files), desc="处理文件") as pbar:
            future_to_path = {}
            for fname in files:
                full_path = os.path.join(source_dir, fname)
                future = executor.submit(process_single_file, full_path)
                future_to_path[future] = full_path

            for future in concurrent.futures.as_completed(future_to_path):
                pbar.update(1)
                # An exception raised in a worker is stored on the future;
                # without this check it would vanish silently.
                error = future.exception()
                if error is not None:
                    print(f"[Error] processing {future_to_path[future]}: {error}")

# Run the batch conversion only when this code is executed as the main module.
if __name__ == "__main__":
    main()




In [4]:
import json
import os
from typing import Dict, Set

def extract_types_from_cfg(cfg: Dict) -> set:
    """
    Collect every value stored under a "type" key in a nested structure.

    Dicts and lists are walked to any depth; other values are ignored. The
    walk uses an explicit stack rather than recursion, so deeply nested ASTs
    do not hit Python's recursion limit.

    Args:
        cfg: The loaded JSON document (nested dicts and lists).

    Returns:
        The set of distinct "type" values found.
    """
    types = set()
    pending = [cfg]

    while pending:
        node = pending.pop()
        if isinstance(node, dict):
            if "type" in node:
                types.add(node["type"])
            pending.extend(node.values())
        elif isinstance(node, list):
            pending.extend(node)

    return types

# Walk every JSON file in the listed directories and collect the distinct types.
all_types = set()
processed_files = set()

for root in ["../../dataset/cangjie_ast"]:
    for file in os.listdir(root):
        if not file.endswith('.json'):
            continue

        # Skip files already handled; de-duplication is by file name only.
        if file in processed_files:
            continue

        processed_files.add(file)
        file_path = os.path.join(root, file)

        # Read as UTF-8 explicitly: the AST files are written as UTF-8 with
        # ensure_ascii=False, so the platform default encoding may fail to
        # decode them.
        with open(file_path, encoding="utf-8") as f:
            cfg = json.load(f)
        types = extract_types_from_cfg(cfg)
        all_types.update(types)

print("所有文件中提取的唯一类型:")
for t in sorted(all_types):
    print(f"- {t}")


所有文件中提取的唯一类型:
- ABSTRACT
- AS
- BOOLEAN
- BREAK
- CASE
- CATCH
- CHAR
- CLASS
- CONST
- CONTINUE
- DO
- ELSE
- ENUM
- ERROR
- EXTEND
- FALSE
- FINALLY
- FOR
- FOREIGN
- FROM
- FUNC
- IF
- IMPORT
- IN
- INIT
- INOUT
- INTERFACE
- INTNATIVE
- IS
- LET
- MACRO
- MAIN
- MATCH
- MUT
- OPEN
- OPERATOR
- OVERRIDE
- PACKAGE
- PRIVATE
- PROP
- PROTECTED
- PUBLIC
- QUOTE
- REDEF
- RETURN
- SEALED
- SPAWN
- STATIC
- STRUCT
- SUPER
- SYNCHRONIZED
- THIS
- THROW
- TRUE
- TRY
- TYPE
- UINTNATIVE
- UNIT
- UNSAFE
- VAR
- WHERE
- WHILE
- argumentList
- arrayLiteral
- arrowType
- assignmentExpression
- binaryExpression
- block
- body
- booleanLiteral
- breakExpression
- builtinFunction
- callExpression
- caseBody
- charLangTypes
- characterLiteral
- classDefinition
- classType
- collectionLiteral
- comment
- constantPattern
- continueExpression
- dollarIdentifier
- element
- elements
- enumBody
- enumDefinition
- enumPattern
- enumPatternParameters
- escapeSeq
- exceptionTypePattern
- extendBody
- exten