In [1]:
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import os
import json
import sys
from functools import partial
from concurrent.futures import as_completed, ThreadPoolExecutor
from multiprocessing import cpu_count
from typing import List, Dict, Any

# =========================================================
#  假设在 llm.py 中已有的函数，如下为占位示例
#  实际请用你自己的实现替换
# =========================================================
from llm import get_llm_answers

# =========================================================
#  第一步：从源文件生成 CFG（带行号），并保存到 llm_whole_cfg/
# =========================================================
def get_whole_code_cfg_prompt(code_text: str, program_language: str) -> str:
    """
    生成整体CFG的Prompt，直接处理整个代码
    """
    code_lines = code_text.splitlines()
    line_array = [{"lineno": i+1, "line": ln} for i, ln in enumerate(code_lines)]
    
    prompt = f"""
You will be given a {program_language} code file. Your task is to generate a complete Control Flow Graph (CFG) for the ENTIRE code.

Output Requirements:
1. Identify ALL basic blocks in the code
2. Capture control flow between ALL blocks
3. Handle nested structures (if/for/while/try/etc) properly
4. Use EXACT line numbers from the original code
5. Output format must be:

{{
    "blocks": [
        {{
            "id": 1,
            "start_line": 1,
            "end_line": 1,
            "label": "code snippet",
            "successors": [2, 3]
        }},
        // ... more blocks
    ]
}}

The code is:
{json.dumps(line_array, indent=2)}

IMPORTANT:
- Use original line numbers
- Include ALL code lines in blocks
- Maintain execution order
- Handle function/method calls as linear flow
- For loops, show back edges
- For conditionals, show both branches
"""
    return prompt.strip()

def process_whole_code_cfg(code: str, program_language: str) -> Dict[str, Any]:
    """
    调用LLM，基于给定源代码生成CFG（JSON格式）
    """
    prompt = get_whole_code_cfg_prompt(code, program_language)
    response = get_llm_answers(prompt, model_name="gpt-4o", require_json=True)
    try:
        return json.loads(response)
    except json.JSONDecodeError:
        # 如果LLM返回并非严格JSON，做一次简单修复或返回空结构
        return {
            "blocks": []
        }

def generate_llm_cfgs(start_idx: int, end_idx: int):
    """
    读取 [start_idx, end_idx] 范围内的源文件，调用LLM生成CFG。
    结果保存到 llm_whole_cfg/ 目录下。
    """
    source_code_dir = "../dataset/python"
    target_dir = "llm_whole_cfg"
    os.makedirs(target_dir, exist_ok=True)

    # 构造要处理的文件列表
    files = []
    for i in range(start_idx, end_idx+1):
        py_file = os.path.join(source_code_dir, f"{i}.py")
        out_file = os.path.join(target_dir, f"{i}.json")
        files.append((py_file, out_file))

    def process_single_file(source_file, target_file):
        if os.path.exists(target_file):
            return  # 已经生成过CFG，跳过

        if not os.path.exists(source_file):
            return  # 源文件不存在，跳过

        print(f"[GenerateCFG] Processing {source_file}")
        with open(source_file, 'r', encoding='utf-8') as f:
            code = f.read()

        # 调用LLM生成CFG
        cfg_result = process_whole_code_cfg(code, "python")

        # 包装后保存
        result = {
            "source_file": os.path.basename(source_file),
            "cfg": cfg_result,
            "code_length": len(code.splitlines())
        }

        with open(target_file, "w", encoding="utf-8") as fout:
            json.dump(result, fout, indent=2, ensure_ascii=False)

    # 并行处理
    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = [executor.submit(process_single_file, src, tgt) for src, tgt in files]
        for _ in as_completed(futures):
            pass


# =========================================================
#  第二步：对生成的 CFG 做后处理（去除不可达节点、合并线性块等）
#  该过程将结果保存到 merged_llm_cfg/ 目录
# =========================================================

def process_cfg(cfg: Dict[str, Any]) -> Dict[str, Any]:
    """
    对CFG进行后处理：
      1. 移除不可达节点（假设 blocks[0] 为 root）
      2. 合并线性块
      3. 递归处理嵌套的functions/classes
    """
    def filter_connected_blocks(blocks: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
        """
        过滤不可达块
        """
        visited_ids = set()
        id_to_block = {}

        def collect_all_blocks(block_list):
            for b in block_list:
                id_to_block[b["id"]] = b
                if "successors" in b:
                    collect_all_blocks(b["successors"])

        collect_all_blocks(blocks)

        def dfs(block):
            if block["id"] in visited_ids:
                return
            visited_ids.add(block["id"])
            for succ_block in block.get("successors", []):
                dfs(succ_block)

        # 假设 blocks[0] 是root
        if blocks:
            root_block = blocks[0]
            dfs(root_block)

        def filter_nested(block_list):
            filtered = []
            for b in block_list:
                if b["id"] in visited_ids:
                    new_successors = filter_nested(b.get("successors", []))
                    filtered.append({
                        "id": b["id"],
                        "label": b["label"],
                        "successors": new_successors
                    })
            return filtered

        return filter_nested(blocks)

    def is_loop_header(block: Dict[str, Any]) -> bool:
        """
        简单判断块是否为循环头（以for或while开头）
        """
        code_str = block["label"].strip() if "label" in block else ""
        return code_str.startswith("for ") or code_str.startswith("while ")

    def merge_blocks_in_place(block: Dict[str, Any]) -> Dict[str, Any]:
        """
        递归合并线性块
        """
        successors = block.get("successors", [])
        if not successors:
            return block

        # 多个后继，说明是分支点，不合并
        if len(successors) > 1:
            for i, succ in enumerate(successors):
                successors[i] = merge_blocks_in_place(succ)
            block["successors"] = successors
            return block

        # 只有1个后继
        single_succ = successors[0]

        # 遇到循环头，不合并
        if is_loop_header(block) or is_loop_header(single_succ):
            block["successors"][0] = merge_blocks_in_place(single_succ)
            return block

        # 可以合并
        block["label"] = block["label"] + "\n" + single_succ["label"]
        block["successors"] = single_succ.get("successors", [])

        # 递归继续尝试合并
        if block["successors"]:
            new_succ_list = []
            for succ in block["successors"]:
                new_succ_list.append(merge_blocks_in_place(succ))
            block["successors"] = new_succ_list

        return block

    # 开始处理
    if "blocks" in cfg:
        # 1. 移除不可达块
        cfg["blocks"] = filter_connected_blocks(cfg["blocks"])

        # 2. 合并线性块
        if cfg["blocks"]:
            merged = []
            for b in cfg["blocks"]:
                merged_block = merge_blocks_in_place(b)
                merged.append(merged_block)
            cfg["blocks"] = merged

    # 递归处理 functions/classes
    if "functions" in cfg:
        for func in cfg["functions"]:
            process_cfg(func)

    if "classes" in cfg:
        for cls in cfg["classes"]:
            process_cfg(cls)

    return cfg


def postprocess_all_llm_cfgs():
    """
    批量读取 llm_whole_cfg/ 下的文件，进行后处理并保存到 merged_llm_cfg/
    """
    input_dir = "llm_whole_cfg"
    output_dir = "merged_llm_cfg"
    os.makedirs(output_dir, exist_ok=True)

    for filename in os.listdir(input_dir):
        if not filename.endswith(".json"):
            continue

        input_path = os.path.join(input_dir, filename)
        output_path = os.path.join(output_dir, filename)

        # 如果已经存在后处理结果，可选择跳过
        # if os.path.exists(output_path):
        #     continue

        print(f"[PostProcessCFG] Processing {filename}")
        try:
            with open(input_path, "r", encoding="utf-8") as f:
                llm_result = json.load(f)
        except Exception as e:
            print(f"Error loading {filename}: {str(e)}")
            continue

        # 注意：llm_result 的结构：{"source_file":..., "cfg":..., "code_length":...}
        # 我们只后处理其中的 cfg
        cfg_data = llm_result.get("cfg", {})
        try:
            cfg_data = process_cfg(cfg_data)
            # 更新回 llm_result
            llm_result["cfg"] = cfg_data
        except Exception as e:
            print(f"Error processing {filename}: {str(e)}")
            continue

        with open(output_path, "w", encoding="utf-8") as f:
            json.dump(llm_result, f, indent=2, ensure_ascii=False)


# =========================================================
# 第三步：将后处理过的 CFG 与静态分析结果做比较
#         并将对比结果输出到 results/ 目录
# =========================================================

def convert_cfg_json_to_edges(cfg_json: Dict[str, Any]) -> List[str]:
    """
    将 CFG 转换为 edges 列表，以便更方便地做文本层面的对比。
    返回形如：
      [
        "Edge 0: [Source] labelA => [Target] labelB",
        ...
      ]
    """
    edges = []
    edge_counter = 0

    def get_block_label(block: Dict[str, Any]) -> str:
        label = block.get("label", "")
        if not isinstance(label, str):
            label = str(label)
        # 去除多余空白
        label = ' '.join(label.replace('\n', '\\n').strip().split())
        return label

    def process_entity(entity: Dict[str, Any]):
        # 处理自身 blocks
        for block in entity.get("blocks", []):
            process_block(block)
        # 递归处理 functions/classes
        for key in ["functions", "classes"]:
            for sub_ent in entity.get(key, []):
                process_entity(sub_ent)

    def process_block(block: Dict[str, Any]):
        nonlocal edge_counter
        source_label = get_block_label(block)
        if not source_label:
            return

        # 处理后继
        for succ in block.get("successors", []):
            target_label = get_block_label(succ)
            if target_label:
                edges.append(f"Edge {edge_counter}: [Source] {source_label} => [Target] {target_label}")
                edge_counter += 1

        # 递归处理内嵌的 blocks 与 successors
        for sub_key in ["blocks", "successors"]:
            for sub_block in block.get(sub_key, []):
                process_block(sub_block)

    process_entity(cfg_json)
    return edges


def get_compare_prompt(
    code: str,
    llm_cfg: Dict[str, Any],
    static_cfg: Dict[str, Any]
) -> str:
    """
    根据给定的 code, LLM CFG, 静态 CFG 生成对比 prompt，
    让 LLM 给出 JSON 格式的对比结果。
    """
    llm_edges = convert_cfg_json_to_edges(llm_cfg)
    static_edges = convert_cfg_json_to_edges(static_cfg)

    prompt = f"""
Role: Control Flow Graph Validation Specialist
Objective: Accurately compare CFG structures between static analysis (ground truth) and LLM generation

### JSON Structure Definition
Ground Truth (static_cfg) & LLM Output (llm_cfg) follow:
[
    "Edge 0: [Source] node_A -> [Target] node_B",
    "Edge 1: [Source] node_C -> [Target] node_D",
    ...
]

### Comparison Criteria
1. Structure Matching:
   Match edges when:
   - Same branching pattern (sequential/conditional/loop)
   - Equivalent depth in nested structure
   - Matching control flow order

2. Mismatch Conditions:
   - Different number of successors in equivalent blocks
   - Inconsistent branch types (e.g., true/false vs multiple)
   - Missing/extra exception handling flows

### Analysis Task
1. For static_cfg:
   - Count total edges (ground truth)
   - Map control flow patterns

2. For llm_cfg:
   - Count total generated edges
   - Identify structurally matched edges
   - Traverse each edge and check if it corresponds to static analysis

3. Output (JSON):
{{
  "edge_analysis": {{
    "static_total": "Number of edges from static analysis",
    "llm_total": "Number of edges generated by LLM",
    "matched_edges": {{
      "exact_matches": "Number of exactly matched edges (type + position)",
      "partial_matches": "Number of type-matched edges with different positions"
    }},
    "accuracy_metrics": {{
      "precision": "exact_matches / llm_total",
      "recall": "exact_matches / static_total",
      "f1_score": "2*(precision*recall)/(precision+recall)"
    }}
  }},
  "structure_validation": {{
    "missing_blocks": ["Unmatched static block IDs"],
    "extra_blocks": ["Extra LLM block IDs"]
  }}
}}

### Input Data
Python Code:
{code}

Static Analysis CFG (Ground Truth):
{json.dumps(static_edges, indent=2)}

LLM Generated CFG:
{json.dumps(llm_edges, indent=2)}

Output JSON analysis ONLY.
"""
    return prompt.strip()


def compare_single_cfg(i: int):
    """
    对单个文件的 CFG 做比较，结果保存到 results/{i}.json
    """
    SOURCE_CODE_DIR = "../dataset/python"
    LLM_CFG_DIR = "merged_llm_cfg"
    STATIC_CFG_DIR = "../dataset/python_cfg"
    RESULTS_DIR = "cfg_results"

    code_path = os.path.join(SOURCE_CODE_DIR, f"{i}.py")
    llm_cfg_path = os.path.join(LLM_CFG_DIR, f"{i}.json")
    static_cfg_path = os.path.join(STATIC_CFG_DIR, f"{i}.json")
    result_path = os.path.join(RESULTS_DIR, f"{i}.json")

    # 如果结果已经存在可跳过
    if os.path.exists(result_path):
        return

    missing_files = [
        p for p in [code_path, llm_cfg_path, static_cfg_path]
        if not os.path.exists(p)
    ]
    if missing_files:
        print(f"[CompareCFG] Missing files for index {i}: {missing_files}")
        return

    try:
        code = open(code_path, "r", encoding="utf-8").read()
        with open(llm_cfg_path, "r", encoding="utf-8") as f:
            llm_data = json.load(f)
        with open(static_cfg_path, "r", encoding="utf-8") as f:
            static_cfg = json.load(f)

        # 注意：llm_data 的核心CFG在 llm_data["cfg"] 中
        llm_cfg = llm_data.get("cfg", {})

        # 构造 prompt
        prompt = get_compare_prompt(code, llm_cfg, static_cfg)
        # 调用 LLM 得到对比结果
        response = get_llm_answers(prompt, model_name="gpt-4o", require_json=True)
        # 解析对比结果为 JSON
        result_json = json.loads(response)

        # 保存
        os.makedirs("results", exist_ok=True)
        with open(result_path, "w", encoding="utf-8") as f:
            json.dump(result_json, f, indent=2, ensure_ascii=False)

    except Exception as e:
        print(f"[CompareCFG] Error processing file {i}: {str(e)}")


def compare_cfgs(start_idx: int, end_idx: int):
    """
    并行对比 [start_idx, end_idx] 范围内的文件CFG
    """
    os.makedirs("cfg_results", exist_ok=True)
    file_indices = range(start_idx, end_idx+1)

    with ThreadPoolExecutor(max_workers=cpu_count()) as executor:
        futures = [executor.submit(compare_single_cfg, i) for i in file_indices]
        # 这里可以用进度条，也可直接等待
        for _ in as_completed(futures):
            pass


# =========================================================
# 第四步：全局指标统计
# =========================================================
def calculate_global_metrics(results_dir: str) -> Dict[str, Any]:
    """
    遍历 results_dir 下所有JSON文件，统计全局的精确率、召回率、F1 等
    预期每个 results/xxx.json 都包含:
    {
      "edge_analysis": {
        "static_total": int,
        "llm_total": int,
        "matched_edges": {
          "exact_matches": int,
          "partial_matches": int
        },
        "accuracy_metrics": {
          "precision": float,
          "recall": float,
          "f1_score": float
        }
      },
      "structure_validation": {
        "missing_blocks": [...],
        "extra_blocks": [...]
      }
    }
    """
    metrics = {
        "total_files": 0,
        "gt_edges": 0,
        "llm_edges": 0,
        "exact_matches": 0,
        "partial_matches": 0,
        "missing_blocks": set(),
        "extra_blocks": set(),
        "file_errors": []
    }

    for filename in os.listdir(results_dir):
        if not filename.endswith(".json"):
            continue

        filepath = os.path.join(results_dir, filename)
        try:
            with open(filepath, "r", encoding="utf-8") as f:
                data = json.load(f)

            edge_analysis = data["edge_analysis"]
            metrics["gt_edges"] += edge_analysis["static_total"]
            metrics["llm_edges"] += edge_analysis["llm_total"]
            metrics["exact_matches"] += edge_analysis["matched_edges"]["exact_matches"]
            metrics["partial_matches"] += edge_analysis["matched_edges"]["partial_matches"]

            structure_val = data["structure_validation"]
            # missing_blocks/extra_blocks 在不同文件可能是数字或字符串
            # 这里统一转为字符串集合
            missing_b = structure_val["missing_blocks"]
            extra_b = structure_val["extra_blocks"]
            metrics["missing_blocks"].update(map(str, missing_b))
            metrics["extra_blocks"].update(map(str, extra_b))

            metrics["total_files"] += 1

        except Exception as e:
            metrics["file_errors"].append(f"{filename}: {str(e)}")

    total_matched = metrics["exact_matches"] + metrics["partial_matches"]
    precision = (total_matched / metrics["llm_edges"]) if metrics["llm_edges"] > 0 else 0
    recall = (total_matched / metrics["gt_edges"]) if metrics["gt_edges"] > 0 else 0
    f1 = 0
    if (precision + recall) > 0:
        f1 = 2 * (precision * recall) / (precision + recall)

    return {
        "file_count": metrics["total_files"],
        "edge_metrics": {
            "static_total": metrics["gt_edges"],
            "llm_total": metrics["llm_edges"],
            "exact_matches": metrics["exact_matches"],
            "partial_matches": metrics["partial_matches"],
            "total_matched": total_matched,
            "precision": round(precision, 4),
            "recall": round(recall, 4),
            "f1_score": round(f1, 4)
        },
        "structure_metrics": {
            "missing_blocks_count": len(metrics["missing_blocks"]),
            "extra_blocks_count": len(metrics["extra_blocks"]),
            "missing_block_samples": list(metrics["missing_blocks"])[:5],
            "extra_block_samples": list(metrics["extra_blocks"])[:5]
        },
        "data_quality": {
            "error_files": len(metrics["file_errors"]),
            "error_samples": metrics["file_errors"][:3]  # 只示例输出前3个错误
        }
    }

def print_report(metrics: Dict[str, Any]):
    """
    控制台打印统计报告
    """
    print("CFG 评估报告".center(60, "="))
    print(f"分析文件总数: {metrics['file_count']}")

    em = metrics["edge_metrics"]
    print("\n[边匹配分析]")
    print(f"  静态分析总边数: {em['static_total']}")
    print(f"  LLM生成总边数: {em['llm_total']}")
    print(f"  精确匹配边数: {em['exact_matches']} (占静态 {em['exact_matches']/em['static_total']:.2%} )")
    print(f"  部分匹配边数: {em['partial_matches']} (占静态 {em['partial_matches']/em['static_total']:.2%} )")
    print(f"  总匹配边数: {em['total_matched']} (占静态 {em['total_matched']/em['static_total']:.2%} )")
    print(f"  Precision: {em['precision']:.4f}")
    print(f"  Recall: {em['recall']:.4f}")
    print(f"  F1 Score: {em['f1_score']:.4f}")

    sm = metrics["structure_metrics"]
    print("\n[结构验证]")
    print(f"  缺失块总数: {sm['missing_blocks_count']} (示例: {sm['missing_block_samples']})")
    print(f"  多余块总数: {sm['extra_blocks_count']} (示例: {sm['extra_block_samples']})")

    dq = metrics["data_quality"]
    if dq["error_files"] > 0:
        print("\n[数据质量问题]")
        print(f"  错误文件数: {dq['error_files']}")
        print("  示例错误:")
        for err in dq["error_samples"]:
            print(f"    - {err}")

    print("="*60)


# =========================================================
# 主流程示例：可按需拆分 0-50, 51-200, 201-9999 等
# =========================================================
def main():
    """
    主函数示例：
      1) 解析命令行参数决定处理范围
      2) 生成 LLM CFG
      3) 后处理
      4) 对比静态分析
      5) 汇总统计
    用法：python script.py start_idx end_idx
      如果不指定，则默认 range(0, 50) 做示例。
    """
    if len(sys.argv) == 3:
        start_idx = int(sys.argv[1])
        end_idx = int(sys.argv[2])
    else:
        # 默认只示例处理 0~50
        start_idx, end_idx = 0, 200

    print(f"==> 全流程处理，文件范围: [{start_idx}, {end_idx}]")

    # 1) 生成 CFG
    generate_llm_cfgs(start_idx, end_idx)

    # 2) 后处理
    postprocess_all_llm_cfgs()

    # 3) 对比静态分析
    compare_cfgs(start_idx, end_idx)

    # 4) 汇总统计
    final_metrics = calculate_global_metrics("results")
    print_report(final_metrics)


if __name__ == "__main__":
    main()

==> 全流程处理，文件范围: [0, 200]
[GenerateCFG] Processing ../dataset/python/0.py
[GenerateCFG] Processing ../dataset/python/3.py
[GenerateCFG] Processing ../dataset/python/5.py
[GenerateCFG] Processing ../dataset/python/8.py
[GenerateCFG] Processing ../dataset/python/11.py
[GenerateCFG] Processing ../dataset/python/9.py
[GenerateCFG] Processing ../dataset/python/14.py
[GenerateCFG] Processing ../dataset/python/15.py
[GenerateCFG] Processing ../dataset/python/17.py
[GenerateCFG] Processing ../dataset/python/18.py
[GenerateCFG] Processing ../dataset/python/20.py
[GenerateCFG] Processing ../dataset/python/21.py
[GenerateCFG] Processing ../dataset/python/23.py
[GenerateCFG] Processing ../dataset/python/25.py
[GenerateCFG] Processing ../dataset/python/26.py
[GenerateCFG] Processing ../dataset/python/28.py
[GenerateCFG] Processing ../dataset/python/29.py
[GenerateCFG] Processing ../dataset/python/31.py
[GenerateCFG] Processing ../dataset/python/33.py
[GenerateCFG] Processing ../dataset/python/30.py


ZeroDivisionError: division by zero

In [None]:
generate_llm_cfgs(3, 3)
