### Easy Generate CFG

#### LLM

效果不太好，还是需要我们一步步进行处理！

step1 先将文件的嵌套类，方法给找到

In [1]:
from multiprocessing import cpu_count
from llm import get_llm_answers
import json

def get_step1_prompt(code_text: str, program_language: str):
    """
    生成第一步的Prompt
    """
    code_lines = code_text.splitlines()
    code_lines_json = [{
        "line": i + 1,
        "code": line
    } for i, line in enumerate(code_lines)]

    prompt = """
You are given a piece of """ +  program_language + """ code. Your goal is to find all the nested classes and methods in the code.

Please return the result in JSON format, your output should be the following format:

```json
{
    "name": "example_script",  // Name of the script or function
    "type": "CFG",
    "start_line": number,
    "end_line": number,
    "functions": [
      {
        "name": "function_name",
        "type": "function",
        "start_line": number,
        "end_line": number,
        "functions": [],         // Nested functions
        "classes": []            // Nested classes
      }
    ],
    "classes": [
      {
        "name": "class_name",
        "type": "class",
        "start_line": number,
        "end_line": number,
        "functions": [           // Methods of the class
          {
            "name": "method_name",
            "type": "function",
            "start_line": number,
            "end_line": number,
            "functions": [],     // Nested functions
            "classes": []        // Nested classes
          }
        ]
      }
    ]
}
```
The code lines are:
""" + json.dumps(code_lines_json, indent=2) + """
IMPORTANT: Make sure that the nested classes and methods are in the correct level. For example, if a function is nested in another class, the function should be in the nested class's functions list. 
Besides, if a class is nested in another class, the class should be in the nested class's classes list.
"""
    
    return prompt

def find_nested_classes_and_methods(code_text: str, program_language):
    """
    找到文件中的嵌套类，方法
    """
    prompt = get_step1_prompt(code_text, program_language)
    response = get_llm_answers(prompt, model_name="gpt-4o", require_json=True)
    nested_classes_and_methods = json.loads(response)
    return nested_classes_and_methods

def process_file_with_chain_of_thought(input_file: str, program_language: str):
    """
    读取 Python 文件 -> 生成’思维链’式Prompt -> 调用大模型 -> 写入结果JSON
    """
    with open(input_file, "r", encoding="utf-8") as f:
        code_text = f.read()

    # 找到文件中的嵌套类，方法
    step1_result = find_nested_classes_and_methods(code_text, program_language)
    # print(json.dumps(step1_result, indent=2))

    return step1_result

from difflib import SequenceMatcher


def get_code_by_line_range(code_block, code):
    code_lines = code.splitlines()
    start_line = code_block["start_line"]
    end_line = code_block["end_line"] + 1

    ## start_line 到 end_line 之间的代码， 但是要减去自身class和function的代码
    line_set = set(range(start_line, end_line))
    for func in code_block.get("functions", []):
        func_start_line = func.get("start_line", 0)
        func_end_line = func.get("end_line", 0)
        line_set.difference_update(range(func_start_line, func_end_line))

    for cls in code_block.get("classes", []):
        cls_start_line = cls.get("start_line", 0)
        cls_end_line = cls.get("end_line", 0)
        line_set.difference_update(range(cls_start_line, cls_end_line))

    # 将line_set转换为有序列表并排序,确保按行号顺序
    ordered_lines = sorted(list(line_set))
    sum_code = "\n".join([code_lines[i-1] for i in ordered_lines])

    code_block["simplified_code"] = sum_code

def recursive_get_code_by_line_range(code_block, code):
    get_code_by_line_range(code_block, code)
    for func in code_block.get("functions", []):
        recursive_get_code_by_line_range(func, code)
    for cls in code_block.get("classes", []):
        recursive_get_code_by_line_range(cls, code)


def print_simplified_code(code_block: dict, indent=0):
    """
    递归遍历并打印 simplified_code
    """
    print(" " * indent + "简化后的代码:")
    print(" " * indent + code_block.get("simplified_code", "").strip())

    # 递归处理嵌套的类
    for class_block in code_block.get("classes", []):
        print(" " * indent + f"\n类 {class_block.get('name', '')}:")
        print_simplified_code(class_block, indent + 2)

    # 递归处理嵌套的函数
    for function_block in code_block.get("functions", []):
        print(" " * indent + f"\n函数 {function_block.get('name', '')}:")
        print_simplified_code(function_block, indent + 2)

def get_code_cfg_prompt(code, program_language):
    """
    生成代码的CFG的Prompt
    """
    prompt = f"""
You are given a piece of {program_language} code. Your goal is to generate a CFG for the code. You should find the basic blocks of the code and find the successors of each block.

Attention to the following structure containing the branch:
1. if-else
2. for-while
3. try-except-finally
4. with-as
5. match-case
6. break-continue-return

You should identify the basic blocks and the successors of each block(which means the blocks that may be executed after this block).

Your output should follow the following json format:
""" + """
```json
{  
  "blocks": [
    {
      "id": 1,
      "label": "if a > 2:",
      "successors": [
        2,3
      ], # which means that there are two successors of this block which may be executed after this block
    },
    {
      "id": 2,
      "label": "print(a)",
      "successors": [
        3
      ]
    },
    {
      "id": 3,
      "label": "print(1)",
      "successors": [
        
      ]
    }
  ]
}
```

+ id: 1, 2, 3, ...
+ label: the entire code of the block (don't remove any code)
+ successors: the ids of the blocks that may be executed after this block

Make sure that the successors blocks exist in the blocks list before you finally output.
!!!IMPORTANT: Each blocks represent a basic block, which is a single statement or a group of statements that can be executed as a unit without any branch. 

Following is the given code:
""" + code
    return prompt


def get_single_block_cfg(code_block, program_languge):
    """
    获取每个代码块的CFG
    """
    prompt = get_code_cfg_prompt(code_block["simplified_code"], program_languge)
    response = get_llm_answers(prompt, model_name="gpt-4o", require_json=True)
    blocks = json.loads(response)["blocks"]
    code_block["blocks"] = blocks
    
def recursive_get_each_block_cfg(code_block, program_language):
    """
    递归获取每个代码块的CFG
    """
    get_single_block_cfg(code_block, program_language)
    for block in code_block.get("classes", []):
        recursive_get_each_block_cfg(block, program_language)
    for block in code_block.get("functions", []):
        recursive_get_each_block_cfg(block, program_language)

import os
def main():
    from concurrent.futures import ThreadPoolExecutor
    from functools import partial
    os.makedirs("llm_cfg_gpt-4o_50", exist_ok=True)
    def process_single_file(file):
        target_file = "llm_cfg_gpt-4o_50/" + file.replace(".py", ".json")
        if os.path.exists(target_file):
            return
        print("Processing " + file)
        step1_result = process_file_with_chain_of_thought("source_code/" + file, "python")
        with open("source_code/" + file, 'r') as f:
            code = f.read()
        recursive_get_code_by_line_range(step1_result, code)
        recursive_get_each_block_cfg(step1_result, "python")

        def remove_duplicate_blocks(code_block):
            """
            删除同一层级中start_line和end_line相同的重复块,仅保留最前面的一个
            """
            if "blocks" in code_block:
                # 用于记录已经出现过的(start_line, end_line)组合
                seen = set()
                # 用于存储不重复的blocks
                unique_blocks = []
                
                for block in code_block["blocks"]:
                    # 如果block有start_line和end_line属性
                    if "start_line" in block and "end_line" in block:
                        key = (block["start_line"], block["end_line"])
                        if key not in seen:
                            seen.add(key)
                            unique_blocks.append(block)
                    else:
                        # 如果没有这些属性,保留该block
                        unique_blocks.append(block)
                        
                code_block["blocks"] = unique_blocks
            
            # 递归处理子块
            for sub_block in code_block.get("classes", []):
                remove_duplicate_blocks(sub_block)
            for sub_block in code_block.get("functions", []):
                remove_duplicate_blocks(sub_block)
                
        # 处理整个CFG
        remove_duplicate_blocks(step1_result)
        
        with open(target_file, "w") as f:
            json.dump(step1_result, f, indent=2)

    # 使用线程池并行处理文件
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        files = os.listdir("source_code")[:50]
        executor.map(process_single_file, files)

main()


Processing 11.py
Processing 60.py
Processing 180.py
Processing 195.py
Processing 202.py
Processing 208.py
Processing 151.py
Processing 176.py
Processing 98.py
Processing 13.py
Processing 120.py
Processing 71.py
Processing 12.py
Processing 167.py
Processing 186.py
Processing 62.py
Processing 14.py
Processing 6.py
Processing 59.py
Processing 107.py
Processing 9.py
Processing 163.py
Processing 129.py
Processing 139.py
Processing 54.py
Processing 184.py
Processing 174.py
Processing 138.py
Processing 148.py
Processing 201.py
Processing 75.py
Processing 159.py
Processing 55.py
Processing 123.py
Processing 160.py
Processing 29.py
Processing 100.py
Processing 45.py
Processing 116.py
Processing 25.py
Processing 89.py
Processing 28.py
Processing 86.py
Processing 99.py
Processing 19.py
Processing 90.py
Processing 132.py
Processing 77.py
Processing 144.py
Processing 93.py


### LLM生成的代码可能可以合并

In [2]:
def process_cfg(cfg):
    """
    Process a CFG to separate loop headers and bodies, merge non-loop blocks, and remove unreachable blocks.

    Args:
        cfg (dict): The CFG data structure.

    Returns:
        dict: Processed CFG.
    """
    def filter_connected_blocks(blocks):
        """Keep only connected blocks reachable from the first block (ID = 1)."""
        # 如果block中没有successors键,则设为空列表
        for block in blocks:
            if "successors" not in block:
                block["successors"] = []
        adjacency_list = {block["id"]: block["successors"] for block in blocks}

        # Perform DFS to find all reachable blocks
        connected = set()
        stack = [1]  # Start from block with ID = 1
        while stack:
            current = stack.pop()
            if current not in connected:
                connected.add(current)
                stack.extend(adjacency_list.get(current, []))

        return [block for block in blocks if block["id"] in connected]

    def find_loop_body(header_id, block_map):
        """Find all nodes in the loop body starting from the loop header."""
        loop_body = set()
        stack = [header_id]
        while stack:
            current = stack.pop()
            if current not in loop_body:
                loop_body.add(current)
                for succ in block_map[current]["successors"]:
                    # Avoid re-adding the loop header itself
                    if succ != header_id:
                        stack.append(succ)
        return loop_body

    def merge_blocks(blocks):
        """Merge blocks by separating loop headers and combining loop bodies."""
        merged_blocks = []
        visited = set()

        block_map = {block["id"]: block for block in blocks}
        predecessors = {block["id"]: set() for block in blocks}
        for block in blocks:
            for succ in block["successors"]:
                if succ in predecessors:
                    predecessors[succ].add(block["id"])

        def is_loop_header(block):
            """Check if a block is a loop header (e.g., 'for', 'while')."""
            return len(block["successors"]) == 1 and block["successors"][0] == block["id"]

        for block in blocks:
            if block["id"] in visited:
                continue

            if is_loop_header(block):
                # Loop header: keep as a separate block
                merged_blocks.append(block)
                visited.add(block["id"])

                # Find and merge the loop body
                loop_body = find_loop_body(block["id"], block_map)
                loop_body_nodes = [block_map[node_id] for node_id in loop_body if node_id not in visited]
                if loop_body_nodes:
                    merged_label = "\n".join(node["label"] for node in loop_body_nodes)
                    merged_blocks.append({
                        "id": min(loop_body),  # Use the smallest ID in the loop body
                        "label": merged_label,
                        "successors": list(loop_body_nodes[-1]["successors"]),  # Use the successors of the last node
                    })
                    visited.update(loop_body)

            else:
                # Attempt to merge blocks
                current_block = block
                merged_label = current_block["label"]
                visited.add(current_block["id"])

                while current_block["successors"]:
                    successor_id = current_block["successors"][0]
                    if (
                        successor_id in visited or
                        successor_id not in block_map or
                        is_loop_header(block_map[successor_id])
                    ):
                        break

                    next_block = block_map[successor_id]
                    if len(predecessors[successor_id]) > 1:
                        break  # Cannot merge due to multiple predecessors

                    # Merge the block
                    merged_label += f"\n{next_block['label']}"
                    visited.add(successor_id)
                    current_block = next_block

                merged_blocks.append({
                    "id": block["id"],
                    "label": merged_label,
                    "successors": current_block["successors"]
                })

        return merged_blocks

    # Process top-level blocks
    if "blocks" in cfg:
        cfg["blocks"] = filter_connected_blocks(cfg["blocks"])
        cfg["blocks"] = merge_blocks(cfg["blocks"])

    # Recursively process functions
    if "functions" in cfg:
        for func in cfg["functions"]:
            process_cfg(func)

    # Recursively process classes
    if "classes" in cfg:
        for cls in cfg["classes"]:
            process_cfg(cls)

    return cfg


for i, file in enumerate(os.listdir("llm_cfg_gpt-4o_50")):
    with open("llm_cfg_gpt-4o_50/" + file, "r") as f:
        # if file != '98.json':
        #     continue
        print("Merging " + file)
        llm_cfg = json.load(f)
        process_cfg(llm_cfg)
        os.makedirs("merged_llm_cfg_50", exist_ok=True)
        with open("merged_llm_cfg_50/" + file, "w") as f:
            json.dump(llm_cfg, f, indent=2)


Merging 160.json
Merging 132.json
Merging 116.json
Merging 14.json
Merging 167.json
Merging 107.json
Merging 89.json
Merging 176.json
Merging 98.json
Merging 123.json
Merging 28.json
Merging 45.json
Merging 13.json
Merging 90.json
Merging 138.json
Merging 163.json
Merging 54.json
Merging 59.json
Merging 71.json
Merging 195.json
Merging 201.json
Merging 93.json
Merging 184.json
Merging 120.json
Merging 151.json
Merging 25.json
Merging 144.json
Merging 208.json
Merging 75.json
Merging 180.json
Merging 86.json
Merging 148.json
Merging 11.json
Merging 174.json
Merging 29.json
Merging 159.json
Merging 19.json
Merging 55.json
Merging 6.json
Merging 129.json
Merging 60.json
Merging 77.json
Merging 202.json
Merging 9.json
Merging 99.json
Merging 100.json
Merging 62.json
Merging 12.json


#### static analysis

In [1]:
from scalpel.cfg import CFGBuilder

def cfg_to_dict(cfg):
    """
    遍历 CFG 和所有子 CFG，生成嵌套的 Python 字典
    """
    visited = set()  # 防止重复访问

    def traverse(block):
        """
        遍历单个 CFG 块并生成节点和边的数据结构
        """
        if block.id in visited:
            return None
        visited.add(block.id)
    
        # 获取当前块的源代码
        block_label = block.get_source().strip() if not block.is_empty() else "Empty Block"
    
        # 按行拆分代码
        lines = block_label.split('\n')
    
        # 去掉包含 `...` 的行
        filtered_lines = [line for line in lines if '...' not in line]
    
        # 将过滤后的行重新合并为块标签
        block_label = '\n'.join(filtered_lines).strip()
    
        block_dict = {
            "id": block.id,
            "label": block_label,
            "successors": []
        }
    
        # 遍历后继节点
        for exit in block.exits:
            successor = traverse(exit.target)
            if successor:
                block_dict["successors"].append(successor)
    
        return block_dict


    def process_cfg(cfg, prefix="Main"):
        """
        处理当前 CFG，包括其子 CFG（函数和类），并返回嵌套字典
        """
        cfg_dict = {
            "name": cfg.name,
            "type": "CFG",
            "blocks": []
        }

        # 处理当前 CFG 的入口块
        if cfg.entryblock:
            entry_block = traverse(cfg.entryblock)
            if entry_block:
                cfg_dict["blocks"].append(entry_block)

        # 递归处理函数的子 CFG
        cfg_dict["functions"] = []
        for func_name, func_cfg in cfg.functioncfgs.items():
            func_dict = process_cfg(func_cfg, prefix=f"{prefix}_func_{func_name}")
            cfg_dict["functions"].append(func_dict)

        # 递归处理类的子 CFG
        cfg_dict["classes"] = []
        for class_name, class_cfg in cfg.class_cfgs.items():
            class_dict = process_cfg(class_cfg, prefix=f"{prefix}_class_{class_name}")
            cfg_dict["classes"].append(class_dict)

        return cfg_dict

    return process_cfg(cfg)


import json
import os
os.makedirs("static_cfg", exist_ok=True)

for i, file in enumerate(os.listdir("source_code")):
    with open("source_code/" + file, "r") as f:
        source_code = f.read()

    # 构建控制流图
    cfg = CFGBuilder().build_from_src(file, source_code)

    # 将 CFG 转换为字典数据结构
    cfg_dict = cfg_to_dict(cfg)

    # 保存到文件
    with open("static_cfg/" + file.replace(".py", ".json"), "w") as f:
        json.dump(cfg_dict, f, indent=2)

### Compare CFG

In [1]:
import json
import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from multiprocessing import cpu_count
from typing import Dict, List, Tuple
from dataclasses import dataclass
from sklearn.feature_extraction.text import CountVectorizer
from sklearn.metrics.pairwise import cosine_similarity
import numpy as np
from tqdm import tqdm
from llm import get_llm_answers

def compare_cfg_similarity(llm_cfg, static_cfg):
    max_retries = 3
    retry_count = 0
    
    while retry_count < max_retries:
        try:
            prompt = f"""
You are a CFG evaluator to evaluate whether the generated CFG is correct based on the static CFG.

You should first compare the structure of the CFG, then compare the content of the CFG. Focus on the flow of the CFG and ignore the details such as content and block_id, block_name.

Your output should be a json with the following format:
{{
    "reasonable": true/false,
    "structure_similarity": 0.8,
    "content_similarity": 0.9,
    "total_similarity": 0.85,
    "reason": ""
}}

Ground truth:
{static_cfg}

Generated CFG:
{llm_cfg}
"""
            similarity = json.loads(get_llm_answers(prompt, model_name="gpt-4o", require_json=True))
            return similarity
            
        except Exception as e:
            retry_count += 1
            if retry_count == max_retries:
                raise e
            print(f"重试第{retry_count}次,错误信息:{str(e)}")

from dataclasses import dataclass
from typing import Dict, List, Optional, Union
import json
import numpy as np
from pathlib import Path

@dataclass
class CFGSimilarityResult:
    """存储CFG比较结果的数据类"""
    filename: str
    edge_coverage: float
    content_similarity: float
    structure_similarity: float
    matched_edges: int
    gt_edges: int
    llm_edges: int
    nested_results: Optional[Dict[str, 'CFGSimilarityResult']] = None
    llm_similarity: Optional[Dict[str, Union[float, bool]]] = None

class CFGComparator:
    def __init__(self):
        """初始化CFG比较器"""
        pass
    
    @staticmethod
    def count_edges(cfg: Dict) -> int:
        """递归计算CFG中的边数量"""
        edge_count = sum(
            len(block.get("successors", []))
            for block in cfg.get("blocks", [])
        )
        
        # 递归计算嵌套CFG的边
        for func in cfg.get("functions", []):  # functions是列表
            edge_count += CFGComparator.count_edges(func)
        for cls in cfg.get("classes", []):     # classes是列表
            edge_count += CFGComparator.count_edges(cls)
            
        return edge_count

    def structure_similarity(self, llm_cfg: Dict, static_cfg: Dict) -> float:
        """计算两个CFG的结构相似度"""
        # 获取两个CFG的blocks
        llm_blocks = llm_cfg.get("blocks", [])
        static_blocks = static_cfg.get("blocks", [])
        
        # 如果两个CFG都没有blocks，返回1.0
        if not llm_blocks and not static_blocks:
            return 1.0
        # 如果其中一个没有blocks，返回0.0
        if not llm_blocks or not static_blocks:
            return 0.0
            
        # 计算边的匹配度
        llm_edges = sum(len(block.get("successors", [])) for block in llm_blocks)
        static_edges = sum(len(block.get("successors", [])) for block in static_blocks)
        
        if llm_edges == 0 and static_edges == 0:
            return 1.0
        if llm_edges == 0 or static_edges == 0:
            return 0.0
            
        # 使用边数的比例计算相似度
        return min(llm_edges, static_edges) / max(llm_edges, static_edges)

    def content_similarity(self, llm_cfg: Dict, static_cfg: Dict) -> float:
        """计算两个CFG的内容相似度"""
        # 获取简化的代码内容
        llm_code = llm_cfg.get("simplified_code", "")
        static_code = static_cfg.get("simplified_code", "")
        
        # 如果两个都为空，返回1.0
        if not llm_code and not static_code:
            return 1.0
        # 如果其中一个为空，返回0.0
        if not llm_code or not static_code:
            return 0.0
        
        # 将代码分割成行并去除空白行
        llm_lines = [line.strip() for line in llm_code.split("\n") if line.strip()]
        static_lines = [line.strip() for line in static_code.split("\n") if line.strip()]
        
        # 计算行的匹配度
        common_lines = set(llm_lines) & set(static_lines)
        return len(common_lines) / max(len(llm_lines), len(static_lines))

    def compare_cfgs(self, llm_cfg: Dict, static_cfg: Dict, name: str) -> CFGSimilarityResult:
        """递归比较两个CFG并返回相似度结果"""
        # 计算当前层级的相似度
        structure_sim = self.structure_similarity(llm_cfg, static_cfg)
        content_sim = self.content_similarity(llm_cfg, static_cfg)
        
        # 计算边的统计信息
        gt_edges = self.count_edges(static_cfg)
        llm_edges = self.count_edges(llm_cfg)
        matched_edges = int(structure_sim * min(gt_edges, llm_edges))
        edge_coverage = matched_edges / gt_edges if gt_edges > 0 else 0
        
        # 递归比较嵌套的CFG
        nested_results = {}
        
        # 比较函数CFG
        llm_functions = {f["name"]: f for f in llm_cfg.get("functions", [])}
        static_functions = {f["name"]: f for f in static_cfg.get("functions", [])}
        common_functions = set(llm_functions.keys()) & set(static_functions.keys())
        
        for func_name in common_functions:
            nested_results[f"function_{func_name}"] = self.compare_cfgs(
                llm_functions[func_name],
                static_functions[func_name],
                func_name
            )
        
        # 比较类CFG
        llm_classes = {c["name"]: c for c in llm_cfg.get("classes", [])}
        static_classes = {c["name"]: c for c in static_cfg.get("classes", [])}
        common_classes = set(llm_classes.keys()) & set(static_classes.keys())
        
        for class_name in common_classes:
            nested_results[f"class_{class_name}"] = self.compare_cfgs(
                llm_classes[class_name],
                static_classes[class_name],
                class_name
            )
        
        return CFGSimilarityResult(
            filename=name,
            edge_coverage=edge_coverage,
            content_similarity=content_sim,
            structure_similarity=structure_sim,
            matched_edges=matched_edges,
            gt_edges=gt_edges,
            llm_edges=llm_edges,
            nested_results=nested_results if nested_results else None,
            llm_similarity=None  # 将在process_file中设置
        )

class CFGEvaluator:
    def __init__(self, llm_cfg_dir: str, static_cfg_dir: str, result_file: str):
        """初始化评估器
        
        Args:
            llm_cfg_dir: LLM生成的CFG文件目录
            static_cfg_dir: 静态分析生成的CFG文件目录
            result_file: 结果保存文件路径
        """
        self.llm_cfg_dir = Path(llm_cfg_dir)
        self.static_cfg_dir = Path(static_cfg_dir)
        self.result_file = Path(result_file)
        self.comparator = CFGComparator()
        self.results = []  # 存储所有结果
    
    def process_file(self, llm_cfg_path: Path) -> Optional[CFGSimilarityResult]:
        """处理单个CFG文件对的比较
        
        Args:
            llm_cfg_path: LLM生成的CFG文件路径
            
        Returns:
            CFGSimilarityResult 或 None（如果没有对应的静态CFG文件）
        """
        # 获取对应的静态CFG文件路径
        static_cfg_path = self.static_cfg_dir / llm_cfg_path.name
        if not static_cfg_path.exists():
            return None
            
        # 读取CFG文件
        with open(llm_cfg_path) as f:
            llm_cfg = json.load(f)
        with open(static_cfg_path) as f:
            static_cfg = json.load(f)
            
        # 比较CFG
        result = self.comparator.compare_cfgs(llm_cfg, static_cfg, llm_cfg_path.name)
        llm_sim = compare_cfg_similarity(llm_cfg, static_cfg)
        result.llm_similarity = llm_sim
        
        # 将结果添加到列表并保存
        self.results.append(result)
        self.save_results()
        
        return result
    
    def save_results(self):
        """保存当前所有结果到文件"""
        with open(self.result_file, "w") as f:
            json.dump(
                [self._result_to_dict(r) for r in self.results],
                f,
                indent=2
            )
    
    def evaluate_all(self) -> List[CFGSimilarityResult]:
        """评估所有CFG文件对
        
        Returns:
            所有比较结果的列表
        """
        # 处理每个LLM生成的CFG文件
        llm_cfg_paths = list(self.llm_cfg_dir.glob("*.json"))
        
        # 使用多线程并行处理
        with ThreadPoolExecutor() as executor:
            futures = []
            for llm_cfg_path in llm_cfg_paths:
                future = executor.submit(self.process_file, llm_cfg_path)
                futures.append(future)
            
            # 使用tqdm显示进度
            for future in tqdm(as_completed(futures), total=len(futures), desc="处理CFG文件"):
                future.result()
                
        return self.results
    
    @staticmethod
    def _result_to_dict(result: CFGSimilarityResult) -> Dict:
        """将CFGSimilarityResult转换为可JSON序列化的字典"""
        return {
            "filename": result.filename,
            "edge_coverage": result.edge_coverage,
            "content_similarity": result.content_similarity,
            "structure_similarity": result.structure_similarity,
            "matched_edges": result.matched_edges,
            "gt_edges": result.gt_edges,
            "llm_edges": result.llm_edges,
            "nested_results": {
                k: CFGEvaluator._result_to_dict(v)
                for k, v in result.nested_results.items()
            } if result.nested_results else None,
            "llm_similarity": result.llm_similarity
        }

def calculate_aggregate_metrics(results: List[CFGSimilarityResult]) -> Dict:
    """计算聚合指标
    
    Args:
        results: CFGSimilarityResult列表
        
    Returns:
        包含聚合指标的字典
    """
    metrics = {
        "total_cfgs_compared": len(results),
        "average_edge_coverage": np.mean([r.edge_coverage for r in results]),
        "average_content_similarity": np.mean([r.content_similarity for r in results]),
        "average_structure_similarity": np.mean([r.structure_similarity for r in results]),
        "total_gt_edges": sum(r.gt_edges for r in results),
        "total_llm_edges": sum(r.llm_edges for r in results),
        "total_matched_edges": sum(r.matched_edges for r in results)
    }
    return metrics

def main():
    evaluator = CFGEvaluator(
        llm_cfg_dir="merged_llm_cfg_50",
        static_cfg_dir="static_cfg",
        result_file="evaluation_results_50.json"
    )
    
    # 评估所有CFG
    results = evaluator.evaluate_all()
    
    # 计算统计指标
    metrics = calculate_aggregate_metrics(results)
    
    # 计算LLM评估的平均值
    llm_metrics = {
        "average_llm_structure_similarity": np.mean([
            r.llm_similarity["structure_similarity"] 
            for r in results if r.llm_similarity
        ]),
        "average_llm_content_similarity": np.mean([
            r.llm_similarity["content_similarity"] 
            for r in results if r.llm_similarity
        ]),
        "average_llm_total_similarity": np.mean([
            r.llm_similarity["total_similarity"] 
            for r in results if r.llm_similarity
        ]),
        "reasonable_percentage": np.mean([
            float(r.llm_similarity["reasonable"]) 
            for r in results if r.llm_similarity
        ]) * 100
    }
    
    # 输出评估结果
    print("\nAutomatic Evaluation Summary:")
    print(f"Total CFGs compared: {metrics['total_cfgs_compared']}")
    print(f"Average Edge Coverage: {metrics['average_edge_coverage']:.2f}")
    print(f"Average Content Similarity: {metrics['average_content_similarity']:.2f}")
    print(f"Average Structure Similarity: {metrics['average_structure_similarity']:.2f}")
    
    print("\nLLM Evaluation Summary:")
    print(f"Average Structure Similarity: {llm_metrics['average_llm_structure_similarity']:.2f}")
    print(f"Average Content Similarity: {llm_metrics['average_llm_content_similarity']:.2f}")
    print(f"Average Total Similarity: {llm_metrics['average_llm_total_similarity']:.2f}")
    print(f"Reasonable Percentage: {llm_metrics['reasonable_percentage']:.1f}%")
    
    # 保存完整的评估指标
    metrics.update(llm_metrics)
    with open("evaluation_metrics_50.json", "w") as f:
        json.dump(metrics, f, indent=2)

if __name__ == "__main__":
    main()


处理CFG文件: 100%|██████████| 48/48 [00:18<00:00,  2.60it/s]


Automatic Evaluation Summary:
Total CFGs compared: 48
Average Edge Coverage: 0.52
Average Content Similarity: 0.00
Average Structure Similarity: 0.83

LLM Evaluation Summary:
Average Structure Similarity: 0.84
Average Content Similarity: 0.86
Average Total Similarity: 0.85
Reasonable Percentage: 77.1%





### Refine

In [5]:
from dataclasses import dataclass
from typing import List, Dict, Optional
import json
from pathlib import Path

@dataclass
class CFGNode:
    id: str
    code: str
    order: int  # 添加order来保持顺序
    start_line: int = 0
    end_line: int = 0

@dataclass
class CFGEdge:
    from_node: str
    to_node: str

@dataclass
class CFGData:
    nodes: List[CFGNode]
    edges: List[CFGEdge]

@dataclass
class CodeBlock:
    decl_name: str
    start_line: int
    end_line: int
    children: List['CodeBlock']
    code: str
    cfg: Optional[CFGData] = None

def parse_json_to_cfg(json_data: dict) -> CodeBlock:
    """将JSON数据解析为CodeBlock对象"""
    cfg = None
    if 'cfg' in json_data:
        cfg_data = json_data['cfg']
        # 创建节点，添加order属性
        nodes = [
            CFGNode(
                id=node['id'], 
                code=node['code'],
                order=idx  # 使用索引作为顺序
            )
            for idx, node in enumerate(cfg_data['nodes'])
        ]
        # 创建边
        edges = [
            CFGEdge(
                from_node=edge['from'], 
                to_node=edge['to']
            )
            for edge in cfg_data['edges']
        ]
        cfg = CFGData(nodes=nodes, edges=edges)

    children = [
        parse_json_to_cfg(child) 
        for child in json_data.get('children', [])
    ]

    return CodeBlock(
        decl_name=json_data['decl_name'],
        start_line=json_data['start_line'],
        end_line=json_data['end_line'],
        code=json_data['code'],
        children=children,
        cfg=cfg
    )

def can_merge_global_nodes(node1: CFGNode, node2: CFGNode, edges: List[CFGEdge]) -> bool:
    """判断全局作用域中的节点是否可以合并"""
    # 检查是否都是全局作用域的节点
    if not (node1.id.startswith("GlobalBlock_") and node2.id.startswith("GlobalBlock_")):
        return False
        
    # 检查节点间是否有其他控制流（如if/else, try/except等）
    node1_code = node1.code.strip()
    node2_code = node2.code.strip()
    
    # 简单检查是否都是导入语句或简单的赋值语句
    def is_simple_statement(code: str) -> bool:
        lines = code.split('\n')
        for line in lines:
            line = line.strip()
            if not line:
                continue
            if not (line.startswith('from ') or 
                   line.startswith('import ') or 
                   '=' in line or 
                   line.startswith('#')):
                return False
        return True
        
    return is_simple_statement(node1_code) and is_simple_statement(node2_code)

def merge_nodes(node1_id: str, node2_id: str, cfg_data: CFGData) -> CFGData:
    """合并两个节点，返回新的CFG数据"""
    nodes_dict = {node.id: node for node in cfg_data.nodes}
    
    # 获取原始节点
    node1 = nodes_dict[node1_id]
    node2 = nodes_dict[node2_id]
    
    # 合并代码
    merged_code = f"{node1.code}\n{node2.code}"
    
    # 创建新节点，保持较小的order
    merged_node = CFGNode(
        id=node1_id,
        code=merged_code,
        order=min(node1.order, node2.order)
    )
    
    # 更新节点列表，保持顺序
    new_nodes = [node for node in cfg_data.nodes if node.id not in {node1_id, node2_id}]
    new_nodes.append(merged_node)
    new_nodes.sort(key=lambda x: x.order)
    
    # 更新边
    new_edges = []
    for edge in cfg_data.edges:
        if edge.from_node == node2_id:
            new_edges.append(CFGEdge(from_node=node1_id, to_node=edge.to_node))
        elif edge.to_node == node2_id:
            continue
        elif edge.from_node != node1_id or edge.to_node != node2_id:
            new_edges.append(edge)
    
    return CFGData(nodes=new_nodes, edges=new_edges)

def can_merge_nodes(from_node_id: str, to_node_id: str, edges: List[CFGEdge], nodes_dict: Dict[str, CFGNode]) -> bool:
    """判断两个节点是否可以合并
    条件：
    1. from_node 必须直接连接到 to_node
    2. to_node 只能有一个入边（来自 from_node）
    3. from_node 只能有一个出边（到 to_node）
    """
    # 检查 from_node -> to_node 的直接连接
    is_directly_connected = any(
        edge.from_node == from_node_id and edge.to_node == to_node_id
        for edge in edges
    )
    if not is_directly_connected:
        return False

    # 检查 to_node 的入边数量
    incoming_edges_to_node = sum(1 for edge in edges if edge.to_node == to_node_id)
    if incoming_edges_to_node > 1:
        return False

    # 检查 from_node 的出边数量
    outgoing_edges_from_node = sum(1 for edge in edges if edge.from_node == from_node_id)
    if outgoing_edges_from_node > 1:
        return False

    return True

def optimize_cfg(cfg_data: CFGData) -> CFGData:
    """优化CFG，合并可以合并的节点"""
    if not cfg_data or len(cfg_data.nodes) <= 1:
        return cfg_data
    
    nodes_dict = {node.id: node for node in cfg_data.nodes}
    changed = True
    
    while changed:
        changed = False
        nodes = sorted(cfg_data.nodes, key=lambda x: x.order)
        
        # 首先尝试合并全局节点
        for i in range(len(nodes) - 1):
            node1 = nodes[i]
            node2 = nodes[i + 1]
            if can_merge_global_nodes(node1, node2, cfg_data.edges):
                cfg_data = merge_nodes(node1.id, node2.id, cfg_data)
                nodes_dict = {node.id: node for node in cfg_data.nodes}
                changed = True
                break
        
        # 如果没有全局节点可以合并，再尝试常规的边合并
        if not changed:
            edges = sorted(cfg_data.edges, 
                         key=lambda e: (nodes_dict[e.from_node].order, nodes_dict[e.to_node].order))
            for edge in edges:
                if can_merge_nodes(edge.from_node, edge.to_node, cfg_data.edges, nodes_dict):
                    cfg_data = merge_nodes(edge.from_node, edge.to_node, cfg_data)
                    nodes_dict = {node.id: node for node in cfg_data.nodes}
                    changed = True
                    break
    
    return cfg_data

def optimize_code_block(block: CodeBlock) -> CodeBlock:
    """优化单个代码块"""
    if block.cfg:
        block.cfg = optimize_cfg(block.cfg)
    
    # 递归优化子块
    block.children = [optimize_code_block(child) for child in block.children]
    return block

def parse_and_optimize_json_to_cfg(json_data: List[dict]) -> List[CodeBlock]:
    """解析JSON并优化所有代码块"""
    blocks = [parse_json_to_cfg(block_data) for block_data in json_data]
    return [optimize_code_block(block) for block in blocks]

def print_code_block_hierarchy(block: CodeBlock, indent: int = 0):
    """打印代码块层次结构"""
    indent_str = "  " * indent
    print(f"{indent_str}Block: {block.decl_name} (lines {block.start_line}-{block.end_line})")
    if block.cfg:
        print(f"{indent_str}CFG Nodes: {len(block.cfg.nodes)}")
        print(f"{indent_str}CFG Edges: {len(block.cfg.edges)}")
        print(f"{indent_str}Nodes (in order):")
        for node in sorted(block.cfg.nodes, key=lambda x: x.order):
            print(f"{indent_str}  Node {node.id} (order {node.order}):")
            print(f"{indent_str}    {node.code.strip()}")
        
        # 打印边的关系
        print(f"{indent_str}Edges:")
        for edge in block.cfg.edges:
            print(f"{indent_str}  {edge.from_node} -> {edge.to_node}")
        print()  # 空行分隔
        
    for child in block.children:
        print_code_block_hierarchy(child, indent + 1)

llm_cfg_path = Path("llm_cfg/0.json")
with open(llm_cfg_path, 'r', encoding='utf-8') as f:
    llm_cfg = json.load(f)

# 解析和优化
llm_blocks = parse_and_optimize_json_to_cfg(llm_cfg)

# 打印结果
print("\nOptimized Code Block Hierarchy:")
for block in llm_blocks:
    print_code_block_hierarchy(block)



Optimized Code Block Hierarchy:
Block: GlobalBlock (lines 0-9)
CFG Nodes: 1
CFG Edges: 0
Nodes (in order):
  Node GlobalBlock_1 (order 0):
    from prisma.models import User

from backend.blocks.basic import AgentInputBlock, PrintToConsoleBlock
from backend.blocks.text import FillTextTemplateBlock
from backend.data import graph
from backend.data.graph import create_graph
from backend.data.user import get_or_create_user
from backend.util.test import SpinTestServer, wait_execution
Edges:

Block: create_test_user (lines 10-17)
CFG Nodes: 1
CFG Edges: 0
Nodes (in order):
  Node create_test_user_1 (order 0):
    async def create_test_user() -> User:

    test_user_data = {
        "sub": "ef3b97d7-1161-4eb4-92b2-10c24fb154c1",
        "email": "testuser#example.com",
        "name": "Test User",
    }

    user = await get_or_create_user(test_user_data)

    return user
Edges:

Block: create_test_graph (lines 20-72)
CFG Nodes: 1
CFG Edges: 0
Nodes (in order):
  Node create_test_graph_1 (or

In [6]:
from typing import List, Tuple, Optional
from dataclasses import dataclass

@dataclass
class CFGNode:
    id: str
    code: str
    order: int
    start_line: int = 0
    end_line: int = 0

@dataclass
class CFGEdge:
    from_node: str
    to_node: str

@dataclass
class CFGData:
    nodes: List[CFGNode]
    edges: List[CFGEdge]

@dataclass
class CodeBlock:
    decl_name: str
    start_line: int
    end_line: int
    children: List['CodeBlock']
    code: str
    cfg: Optional[CFGData] = None

def parse_block(block):
    """解析基本块"""
    # 获取块中语句的行号范围
    start_line = min([stmt.lineno for stmt in block.statements if hasattr(stmt, 'lineno')] or [0])
    end_line = max([getattr(stmt, 'end_lineno', stmt.lineno) for stmt in block.statements if hasattr(stmt, 'lineno')] or [0])
    
    statements = []
    for stmt in block.statements:
        if hasattr(stmt, 'lineno'):
            try:
                import ast
                if hasattr(ast, 'unparse'):
                    code = ast.unparse(stmt)
                else:
                    code = str(stmt)
            except:
                code = str(stmt)
            statements.append(code)
    
    return {
        "id": block.id,
        "code": "\n".join(statements),
        "start_line": start_line,
        "end_line": end_line,
        "order": getattr(block, 'order', 0)
    }

def process_block(name, block_cfg):
    """处理单个代码块（可以是函数或类方法）"""
    if not hasattr(block_cfg, 'entryblock'):
        return [], []
        
    nodes = []
    edges = []
    visited = set()
    block_order = 0
    
    # 使用广度优先搜索来保持正确的顺序
    queue = [(block_cfg.entryblock, block_order)]
    while queue:
        current_block, order = queue.pop(0)
        if current_block in visited:
            continue
            
        visited.add(current_block)
        
        block_data = parse_block(current_block)
        block_data["order"] = order
        nodes.append(block_data)
        
        # 添加所有出边
        for link in current_block.exits:
            if hasattr(link, "target"):
                edges.append({
                    "from": current_block.id,
                    "to": link.target.id
                })
                if link.target not in visited:
                    block_order += 1
                    queue.append((link.target, block_order))
    
    # 按order排序节点
    nodes.sort(key=lambda x: x["order"])
    return nodes, edges

def process_cfg_recursively(cfg_obj, prefix="", processed=None, context=None):
    """递归处理CFG对象，返回所有代码块"""
    if processed is None:
        processed = set()
    if context is None:
        context = {"in_class": False}
    
    blocks = []
    
    # 使用cfg_obj的id作为唯一标识
    cfg_id = id(cfg_obj)
    if cfg_id in processed:
        return blocks
    processed.add(cfg_id)
    
    # 处理当前层级的主体代码
    if hasattr(cfg_obj, 'entryblock'):
        nodes, edges = process_block(f"{prefix}", cfg_obj)
        if nodes:  # 只有当有实际内容时才添加
            block_type = "GlobalBlock"
            if context.get("in_class"):
                if prefix.endswith(".__init__"):
                    block_type = "Constructor"
                elif "." in prefix:
                    block_type = "Method"
                else:
                    block_type = "ClassBody"
            elif prefix:
                block_type = "Function"
                
            # 处理嵌套类名称
            display_name = prefix
            if "." in prefix and not context.get("in_class"):
                parts = prefix.split(".")
                display_name = ".".join([p if i == 0 else f"Input" if p == "Input" else p 
                                       for i, p in enumerate(parts)])
            
            blocks.append({
                "name": f"{block_type}: {display_name}" if prefix else block_type,
                "nodes": nodes,
                "edges": edges,
                "type": block_type,
                "original_name": prefix,  # 保存原始名称用于排序
                "line_info": (min([n.get("start_line", 0) for n in nodes] or [0]),
                            max([n.get("end_line", 0) for n in nodes] or [0]))
            })
    
    # 获取所有需要处理的项
    all_items = []
    
    # 添加类
    if hasattr(cfg_obj, 'class_cfgs'):
        for class_name, class_cfg in cfg_obj.class_cfgs.items():
            all_items.append(('class', class_name, class_cfg))
    
    # 添加方法
    if hasattr(cfg_obj, 'methodcfgs'):
        for method_name, method_cfg in cfg_obj.methodcfgs.items():
            all_items.append(('method', method_name, method_cfg))
    
    # 添加函数
    if hasattr(cfg_obj, 'functioncfgs'):
        for func_name, func_cfg in cfg_obj.functioncfgs.items():
            all_items.append(('function', func_name, func_cfg))
    
    # 按照源代码中的顺序排序
    def get_first_line(item):
        _, _, cfg = item
        if hasattr(cfg, 'entryblock') and hasattr(cfg.entryblock, 'statements'):
            statements = cfg.entryblock.statements
            if statements and hasattr(statements[0], 'lineno'):
                return statements[0].lineno
        return float('inf')
    
    all_items.sort(key=get_first_line)
    
    # 按顺序处理所有项
    for item_type, name, sub_cfg in all_items:
        if id(sub_cfg) in processed:
            continue
        
        new_prefix = f"{prefix}.{name}" if prefix else name
        new_context = {"in_class": item_type == 'class'}
        
        # 递归处理
        sub_blocks = process_cfg_recursively(
            sub_cfg,
            new_prefix,
            processed,
            new_context
        )
        blocks.extend(sub_blocks)
    
    return blocks

def parse_cfg(cfg):
    """解析CFG为结构化数据"""
    return process_cfg_recursively(cfg)

def dedent_code(code: str) -> str:
    """处理代码缩进"""
    lines = code.split('\n')
    if not lines:
        return ''
    
    # 找到最小的非空行缩进
    min_indent = float('inf')
    for line in lines:
        if line.strip():
            indent = len(line) - len(line.lstrip())
            min_indent = min(min_indent, indent)
    
    if min_indent == float('inf'):
        return code
    
    # 删除多余的缩进
    result = []
    for line in lines:
        if line.strip():
            result.append(line[min_indent:])
        else:
            result.append(line)
    
    return '\n'.join(result)

def convert_parsed_cfg_to_codeblock(cfg_blocks) -> List[CodeBlock]:
    """将解析后的CFG转换为CodeBlock结构"""
    def create_cfg_data(block) -> CFGData:
        nodes = []
        edges = []
        
        # 转换节点
        for node in block["nodes"]:
            nodes.append(CFGNode(
                id=node["id"],
                code=node["code"],
                order=node["order"],
                start_line=node.get("start_line", 0),
                end_line=node.get("end_line", 0)
            ))
        
        # 转换边
        for edge in block["edges"]:
            edges.append(CFGEdge(
                from_node=edge["from"],
                to_node=edge["to"]
            ))
        
        return CFGData(nodes=nodes, edges=edges)
    
    def get_block_code(nodes: List[dict]) -> str:
        """从节点列表中提取完整的代码"""
        sorted_nodes = sorted(nodes, key=lambda x: x["order"])
        code_parts = []
        for node in sorted_nodes:
            if node["code"]:
                code_parts.append(dedent_code(node["code"]))
        return "\n".join(code_parts)
    
    def build_block_hierarchy(blocks: List[dict]) -> List[CodeBlock]:
        """构建代码块层级结构"""
        # 按照行号排序所有块
        sorted_blocks = sorted(blocks, key=lambda x: x["line_info"][0])
        
        # 创建一个映射来存储所有块
        block_map = {}  # original_name -> CodeBlock
        
        # 第一遍：创建所有 CodeBlock 对象
        for block in sorted_blocks:
            name = block["original_name"]
            cfg_data = create_cfg_data(block)
            code = get_block_code(block["nodes"])
            start_line, end_line = block["line_info"]
            
            code_block = CodeBlock(
                decl_name=name,
                start_line=start_line,
                end_line=end_line,
                children=[],
                code=code,
                cfg=cfg_data
            )
            block_map[name] = code_block
        
        # 第二遍：构建层级关系
        result = []
        for block in sorted_blocks:
            name = block["original_name"]
            code_block = block_map[name]
            
            # 检查是否是某个类的成员
            if "." in name:
                parent_name = name.split(".")[0]
                if parent_name in block_map:
                    block_map[parent_name].children.append(code_block)
                    continue
            
            # 如果不是成员，添加到结果列表
            result.append(code_block)
        
        return result
    
    return build_block_hierarchy(cfg_blocks)

def visualize_cfg(code_blocks: List[CodeBlock], indent: str = "") -> str:
    """生成CFG的文本可视化"""
    result = ["Optimized Code Block Hierarchy:"]
    
    def visualize_block(block: CodeBlock, level: int = 0):
        """递归可视化代码块及其子块"""
        indent = "  " * level
        lines = []
        
        # 添加块名称和行号范围
        block_info = f"{indent}Block: {block.decl_name}"
        if block.start_line > 0 and block.end_line > 0:
            block_info += f" (lines {block.start_line}-{block.end_line})"
        lines.append(block_info)
        
        # 添加代码内容
        lines.append(f"{indent}Code:")
        for line in block.code.split('\n'):
            if line.strip():
                lines.append(f"{indent}  {line}")
        
        # 添加CFG信息（如果存在）
        if block.cfg:
            lines.append(f"{indent}CFG Nodes: {len(block.cfg.nodes)}")
            lines.append(f"{indent}CFG Edges: {len(block.cfg.edges)}")
            
            # 按顺序显示节点
            lines.append(f"{indent}Nodes (in order):")
            sorted_nodes = sorted(block.cfg.nodes, key=lambda x: x.order)
            for node in sorted_nodes:
                lines.append(f"{indent}  Node {node.id} (order {node.order})")
            
            # 显示边
            lines.append(f"{indent}Edges:")
            for edge in block.cfg.edges:
                lines.append(f"{indent}  {edge.from_node} -> {edge.to_node}")
        
        # 递归处理子块
        if block.children:
            lines.append(f"{indent}Children:")
            for child in block.children:
                lines.extend(visualize_block(child, level + 1))
        
        lines.append("")  # 添加空行分隔不同的块
        return lines
    
    # 处理所有顶层块
    for block in code_blocks:
        result.extend(visualize_block(block))
    
    return '\n'.join(result)

# 使用示例：
if __name__ == "__main__":
    from scalpel.cfg import CFGBuilder

    file_path = "source_code/0.py"
    src = open(file_path, 'r', encoding='utf-8').read()

    cfg = CFGBuilder().build_from_src("1", src)
    parsed_cfg = parse_cfg(cfg)
    static_blocks = convert_parsed_cfg_to_codeblock(parsed_cfg)
    print(visualize_cfg(static_blocks))


Optimized Code Block Hierarchy:
Block:  (lines 1-92)
Code:
  from prisma.models import User
  from backend.blocks.basic import AgentInputBlock, PrintToConsoleBlock
  from backend.blocks.text import FillTextTemplateBlock
  from backend.data import graph
  from backend.data.graph import create_graph
  from backend.data.user import get_or_create_user
  from backend.util.test import SpinTestServer, wait_execution
  async def create_test_user() -> User:
      test_user_data = {'sub': 'ef3b97d7-1161-4eb4-92b2-10c24fb154c1', 'email': 'testuser#example.com', 'name': 'Test User'}
      user = await get_or_create_user(test_user_data)
      return user
  def create_test_graph() -> graph.Graph:
      """
      InputBlock
                                  ---- FillTextTemplateBlock ---- PrintToConsoleBlock
                 /
      InputBlock
      """
      nodes = [graph.Node(block_id=AgentInputBlock().id, input_default={'name': 'input_1'}), graph.Node(block_id=AgentInputBlock().id, input_default=

In [8]:
from collections import defaultdict, deque


def extract_calls_from_cfg(cfg_node: CFGNode) -> List[str]:
    """从CFG节点的代码中提取函数调用"""
    import ast
    
    class CallVisitor(ast.NodeVisitor):
        def __init__(self):
            self.calls = []
            
        def visit_Call(self, node):
            if isinstance(node.func, ast.Name):
                self.calls.append(node.func.id)
            elif isinstance(node.func, ast.Attribute):
                self.calls.append(node.func.attr)
            self.generic_visit(node)
    
    calls = []
    try:
        tree = ast.parse(cfg_node.code)
        visitor = CallVisitor()
        visitor.visit(tree)
        calls = visitor.calls
    except:
        pass
    
    return calls

@dataclass
class BlockWithCalls:
    """包含调用信息的代码块"""
    decl_name: str
    code: str
    calls: List[str]

def convert_to_blocks_with_calls(code_blocks: List[CodeBlock]) -> List[BlockWithCalls]:
    """将CodeBlock转换为包含调用信息的BlockWithCalls"""
    blocks_with_calls = []
    
    for block in code_blocks:
        calls = []
        if block.cfg:
            # 从所有CFG节点中收集调用
            for node in block.cfg.nodes:
                calls.extend(extract_calls_from_cfg(node))
        
        # 去重调用
        calls = list(set(calls))
        
        blocks_with_calls.append(BlockWithCalls(
            decl_name=block.decl_name,
            code=block.code,
            calls=calls
        ))
    
    return blocks_with_calls

class TopologicalBlockMatcher:
    def __init__(self):
        self.similarity_threshold = 0.7

    def blocks_to_graph(self, blocks: List[BlockWithCalls]) -> Dict[str, List[str]]:
        """将代码块列表转换为邻接表表示的图"""
        graph = defaultdict(list)
        for block in blocks:
            graph[block.decl_name].extend(block.calls)
        return dict(graph)

    def get_block_by_name(self, blocks: List[BlockWithCalls], name: str) -> BlockWithCalls:
        """通过函数名获取代码块"""
        for block in blocks:
            if block.decl_name == name:
                return block
        return None

    def topological_sort(self, blocks: List[BlockWithCalls]) -> List[str]:
        """
        对代码块进行拓扑排序
        Returns:
            排序后的函数名列表
        """
        if not blocks:
            return []
            
        # 构建图
        graph = self.blocks_to_graph(blocks)
        
        # 计算入度
        in_degree = defaultdict(int)
        for node in graph:
            for successor in graph[node]:
                if successor:  # 确保successor不是None
                    in_degree[successor] += 1
            if node not in in_degree:
                in_degree[node] = 0
    
        # 初始化队列（入度为0的节点）
        queue = deque([node for node, degree in in_degree.items() if degree == 0 and node])
        result = []
    
        # BFS进行拓扑排序
        while queue:
            node = queue.popleft()
            if node:  # 确保node不是None
                result.append(node)
                for successor in graph.get(node, []):
                    if successor:  # 确保successor不是None
                        in_degree[successor] -= 1
                        if in_degree[successor] == 0:
                            queue.append(successor)
    
        # 添加可能未被引用的节点
        all_nodes = {block.decl_name for block in blocks if block.decl_name}
        for node in all_nodes:
            if node and node not in result:
                result.append(node)
    
        return result


    def calculate_code_similarity(self, code1: str, code2: str) -> float:
        """计算代码相似度"""
        # 简化代码（移除空白字符和注释）
        def clean_code(code: str) -> str:
            import re
            # 移除注释
            code = re.sub(r'#.*$', '', code, flags=re.MULTILINE)
            # 移除多余空白字符
            code = ' '.join(code.split())
            return code
        
        code1 = clean_code(code1)
        code2 = clean_code(code2)
        
        if not code1 or not code2:
            return 0.0
        
        if code1 == code2:
            return 1.0
            
        # 使用最长公共子序列计算相似度
        return self.lcs_similarity(code1, code2)

    def lcs_similarity(self, s1: str, s2: str) -> float:
        """使用最长公共子序列计算相似度"""
        m, n = len(s1), len(s2)
        dp = [[0] * (n + 1) for _ in range(m + 1)]
        
        for i in range(1, m + 1):
            for j in range(1, n + 1):
                if s1[i-1] == s2[j-1]:
                    dp[i][j] = dp[i-1][j-1] + 1
                else:
                    dp[i][j] = max(dp[i-1][j], dp[i][j-1])
        
        lcs_length = dp[m][n]
        return 2.0 * lcs_length / (len(s1) + len(s2))

    def calculate_structural_similarity(self, 
                                     block1: BlockWithCalls, 
                                     block2: BlockWithCalls, 
                                     position1: int, 
                                     position2: int,
                                     max_pos: int) -> float:
        """计算结构相似度"""
        # 位置相似度
        pos_sim = 1.0 - abs(position1 - position2) / max_pos if max_pos > 0 else 1.0
        
        # 调用关系相似度
        calls1 = set(block1.calls)
        calls2 = set(block2.calls)
        calls_sim = len(calls1 & calls2) / len(calls1 | calls2) if calls1 or calls2 else 1.0
        
        # 计算加权平均
        return 0.4 * pos_sim + 0.6 * calls_sim

    def match_blocks(self, 
                    llm_blocks: List[BlockWithCalls], 
                    static_blocks: List[BlockWithCalls]) -> Dict:
        """匹配两组代码块"""
        # 1. 对两组代码块进行拓扑排序
        sorted_llm = self.topological_sort(llm_blocks)
        sorted_static = self.topological_sort(static_blocks)

        # 2. 计算匹配
        matches = []
        max_pos = max(len(sorted_llm), len(sorted_static)) - 1
        if max_pos < 0:
            max_pos = 0

        # 构建快速查找字典
        llm_blocks_dict = {b.decl_name: b for b in llm_blocks}
        static_blocks_dict = {b.decl_name: b for b in static_blocks}

        for i, name1 in enumerate(sorted_llm):
            block1 = llm_blocks_dict.get(name1)
            if not block1:
                continue

            for j, name2 in enumerate(sorted_static):
                block2 = static_blocks_dict.get(name2)
                if not block2:
                    continue

                # 计算代码相似度
                code_sim = self.calculate_code_similarity(block1.code, block2.code)
                
                # 计算结构相似度
                struct_sim = self.calculate_structural_similarity(
                    block1, block2, i, j, max_pos
                )
                
                # 综合相似度
                similarity = 0.6 * code_sim + 0.4 * struct_sim
                
                if similarity >= self.similarity_threshold:
                    matches.append((name1, name2, similarity))

        return {
            'matches': sorted(matches, key=lambda x: x[2], reverse=True),
            'topological_order_llm': sorted_llm,
            'topological_order_static': sorted_static,
            'coverage': len(set(m[0] for m in matches)) / len(llm_blocks) if llm_blocks else 0.0
        }

def format_match_line(block1: str, block2: str, similarity: float, width: int = 30) -> str:
    """格式化匹配结果行"""
    block1_str = str(block1).ljust(width)
    block2_str = str(block2).ljust(width)
    return f"{block1_str} <-> {block2_str} (相似度: {similarity:.2%})"

def print_match_results(result: Dict):
    """打印匹配结果"""
    print("\n" + "="*50)
    print("代码块匹配结果")
    print("="*50)
    
    print("\n1. 匹配的代码块对:")
    if result['matches']:
        for block1, block2, sim in result['matches']:
            print(format_match_line(block1, block2, sim))
    else:
        print("未找到匹配的代码块")
    
    print("\n2. LLM代码块的拓扑排序:")
    llm_order = [str(x) for x in result['topological_order_llm'] if x is not None]
    print(" -> ".join(llm_order) if llm_order else "空")
    
    print("\n3. 静态分析代码块的拓扑排序:")
    static_order = [str(x) for x in result['topological_order_static'] if x is not None]
    print(" -> ".join(static_order) if static_order else "空")
    
    print(f"\n4. 覆盖率: {result['coverage']:.2%}")
    print("="*50 + "\n")


def main():
    # 转换为带调用信息的代码块
    llm_blocks_with_calls = convert_to_blocks_with_calls(llm_blocks)
    static_blocks_with_calls = convert_to_blocks_with_calls(static_blocks)
    
    # 进行匹配
    matcher = TopologicalBlockMatcher()
    result = matcher.match_blocks(llm_blocks_with_calls, static_blocks_with_calls)
    print_match_results(result)

if __name__ == "__main__":
    main()



代码块匹配结果

1. 匹配的代码块对:
create_test_user               <-> (1, 'create_test_user')        (相似度: 80.90%)
sample_agent                   <-> (1, 'sample_agent')            (相似度: 80.13%)
create_test_graph              <-> (1, 'create_test_graph')       (相似度: 77.09%)

2. LLM代码块的拓扑排序:
GlobalBlock -> run -> sample_agent -> SpinTestServer -> print -> test_execute_graph -> create_test_user -> create_graph -> create_test_graph -> wait_execution -> get_or_create_user -> AgentInputBlock -> Node -> Link -> FillTextTemplateBlock -> PrintToConsoleBlock -> Graph

3. 静态分析代码块的拓扑排序:
(1, 'create_test_user') -> (1, 'create_test_graph') -> (1, 'sample_agent')

4. 覆盖率: 60.00%

