In [23]:
from collections import defaultdict
from typing import Dict, List, Set

class CallGraphExtractor:
    """Extract a call graph from a dict-based (JSON-decoded) AST.

    Nodes are dicts that may carry a ``"type"`` string, a ``"label"`` string
    (read for ``identifier`` nodes) and a ``"children"`` list; a missing or
    ``null`` ``"children"`` entry is treated as "no children". Each call is
    recorded under the name of the innermost enclosing node whose type is in
    ``context_types``; calls outside any such node are ignored.
    """

    # Expression types for which only the first child is followed when
    # resolving a dotted call target such as ``a.b().c``.
    _FIRST_CHILD_TYPES = frozenset(
        {"callExpression", "parenthesizedExpression", "subscriptExpression"}
    )

    def __init__(self):
        # context name -> call targets recorded directly inside that context
        self.call_graph: Dict[str, Set[str]] = defaultdict(set)
        # names of the context nodes currently open, innermost last
        self.context_stack: List[str] = []
        # node types that open a new context
        self.context_types = {
            "functionDefinition", "classDefinition", "macroDefinition",
            "structDefinition", "enumDefinition", "mainDefinition"
        }

        # node types recorded as constructor calls ("<Class>::constructor")
        self.instantiation_types = {"newExpression", "constructorCall"}

    def build_call_graph(self, ast_node: dict) -> Dict[str, List[str]]:
        """Build the call graph of *ast_node* and return it with sorted targets.

        Any state left from a previous call is discarded first, so each call
        describes exactly one AST.

        Args:
            ast_node: Root node of the AST.

        Returns:
            Mapping of context name to the sorted list of call targets recorded
            in that context. Contexts in which no call was recorded are absent.
        """
        # Reset so repeated calls (or a call after an aborted traversal that
        # left contexts on the stack) don't mix results between ASTs.
        self.call_graph.clear()
        self.context_stack.clear()
        self._traverse(ast_node)
        return {ctx: sorted(calls) for ctx, calls in self.call_graph.items()}

    @staticmethod
    def _children(node: dict) -> List[dict]:
        """Return *node*'s children, treating a missing or null entry as empty."""
        return node.get("children") or []

    def _traverse(self, node: dict):
        """Visit *node* and its descendants in pre-order, recording calls.

        A context node pushes its name before its subtree is visited and pops
        it afterwards, so every call is attributed to the innermost context.
        """
        is_context = self._is_context_node(node)
        if is_context:
            self._enter_context(node)

        self._process_node(node)

        for child in self._children(node):
            self._traverse(child)

        if is_context:
            self._exit_context()

    def _is_context_node(self, node: dict) -> bool:
        """Return True if *node*'s type opens a new context."""
        return node.get("type") in self.context_types

    def _enter_context(self, node: dict):
        """Push the name of context node *node* onto the context stack."""
        self.context_stack.append(self._extract_ctx_name(node))

    def _exit_context(self):
        """Pop the innermost context, if any."""
        if self.context_stack:
            self.context_stack.pop()

    def _extract_ctx_name(self, node: dict) -> str:
        """Return the name of context node *node*, or ``"<anonymous>"``.

        The label of the first direct ``identifier`` child is preferred.
        Otherwise the first labelled ``identifier`` found by a depth-first
        search of the subtree is used.
        """
        for child in self._children(node):
            if child.get("type") == "identifier":
                # An empty label would otherwise become an empty context name.
                return child.get("label") or "<anonymous>"

        # NOTE(review): this fallback can return an identifier from the body
        # (e.g. a callee's name) when the definition's name is not a direct
        # identifier child -- possibly the case for mainDefinition; confirm
        # against the AST schema.
        def dfs(n: dict) -> str:
            if n.get("type") == "identifier":
                return n.get("label", "<anonymous>")
            for c in self._children(n):
                if name := dfs(c):
                    return name
            return ""

        return dfs(node) or "<anonymous>"

    def _process_node(self, node: dict):
        """Record the call represented by *node*, if it is a call node."""
        node_type = node.get("type")

        # plain function / method call
        if node_type == "callExpression":
            self._handle_method_call(node)

        # class instantiation
        elif node_type in self.instantiation_types:
            self._handle_instantiation(node)

    def _handle_method_call(self, node: dict):
        """Record the dotted target of a ``callExpression`` node."""
        if not self.context_stack:
            return

        if target := self._parse_call_target(node):
            self._record_call(target)

    def _handle_instantiation(self, node: dict):
        """Record an instantiation as ``"<ClassName>::constructor"``."""
        if not self.context_stack:
            return

        if class_name := self._parse_class_name(node):
            self._record_call(f"{class_name}::constructor")

    def _parse_call_target(self, node: dict) -> str:
        """Return the dotted name of *node*'s first child (e.g. ``a.b.c`` for
        ``a.b().c``), or ``""`` if it cannot be resolved."""
        parts: List[str] = []
        if children := self._children(node):
            self._resolve_expression(children[0], parts)
        return ".".join(parts)

    def _parse_class_name(self, node: dict) -> str:
        """Return the class name of an instantiation node, or ``""``."""
        node_type = node.get("type")
        # form 1: new ClassName(...)
        if node_type == "newExpression":
            if children := self._children(node):
                return self._parse_simple_class(children[0])

        # form 2: ClassName(...)
        elif node_type == "constructorCall":
            return self._parse_call_target(node)

        return ""

    def _parse_simple_class(self, node: dict) -> str:
        """Return the class name of an ``identifier`` or ``fieldExpression``
        node (e.g. ``mymodule.ClassName``), or ``""`` for anything else."""
        node_type = node.get("type")
        if node_type == "identifier":
            return node.get("label", "")
        if node_type == "fieldExpression":
            parts: List[str] = []
            self._resolve_expression(node, parts)
            return ".".join(parts)
        return ""

    def _resolve_expression(self, node: dict, parts: list):
        """Append the identifier labels making up *node*'s dotted name to *parts*.

        ``fieldExpression`` nodes contribute all their children in order; the
        types in ``_FIRST_CHILD_TYPES`` contribute only their first child;
        other node types contribute nothing.
        """
        node_type = node.get("type")

        if node_type == "identifier":
            parts.append(node.get("label", ""))

        elif node_type == "fieldExpression":
            for child in self._children(node):
                self._resolve_expression(child, parts)

        elif node_type in self._FIRST_CHILD_TYPES:
            if children := self._children(node):
                self._resolve_expression(children[0], parts)

    def _record_call(self, target: str):
        """Add *target* to the innermost context's call set (if both exist)."""
        if target and self.context_stack:
            self.call_graph[self.context_stack[-1]].add(target)

In [25]:
import json
import os

# Build one call-graph JSON file per AST file.
# Reads <static_ast_dir>/<i>.cj.json for i in 0..299 (missing indices are
# skipped) and writes the extracted call graph to <static_cg_dir>/<i>.json.
static_ast_dir = "../../dataset/cangjie_ast"
static_cg_dir = "../../dataset/cangjie_cg"
os.makedirs(static_cg_dir, exist_ok=True)
for i in range(300):
    ast_file = os.path.join(static_ast_dir, f"{i}.cj.json")
    if not os.path.exists(ast_file):
        continue
    # Close each handle deterministically, and decode as UTF-8 (the JSON
    # standard encoding) instead of the platform's default encoding.
    with open(ast_file, encoding="utf-8") as f:
        ast = json.load(f)
    extractor = CallGraphExtractor()
    call_graph = extractor.build_call_graph(ast)
    with open(os.path.join(static_cg_dir, f"{i}.json"), "w", encoding="utf-8") as f:
        json.dump(call_graph, f, indent=4)
