In [1]:
llm_cg_dir = "llm_cg"
import os
import json
for file in os.listdir(llm_cg_dir):
    if file.endswith(".json"):
        llm_cg = json.load(open(os.path.join(llm_cg_dir, file)))
        # 过滤掉重复的值,只保留一个
        llm_cg = {k: list(set(v)) for k, v in llm_cg.items()}
        json.dump(llm_cg, open(os.path.join(llm_cg_dir, file), "w"), indent=4, ensure_ascii=False)
        

In [4]:
import json
from difflib import SequenceMatcher


def normalize_name(call):
    """
    Normalize a function/method name for comparison.
    - Removes redundant module prefixes if possible.
    - Converts the name to a comparable base representation.
    """
    return call.split('.')[-1]  # Use only the last part (e.g., "backend.data.graph.Node" -> "Node")


def compare_call_graphs(generated, ground_truth, similarity_threshold=0.8):
    """
    Compares the generated call graph with the ground truth and calculates metrics.

    Parameters:
        generated (dict): The call graph generated by the analyzer.
        ground_truth (dict): The static ground truth call graph.
        similarity_threshold (float): Threshold for considering two function names as similar.

    Returns:
        dict: A detailed comparison report including metrics and mismatches.
    """
    report = {
        "missing_keys": [],
        "extra_keys": [],
        "mismatched_calls": [],
        "metrics": {
            "precision": 0.0,
            "recall": 0.0,
            "f1_score": 0.0
        }
    }

    # Helper function for similarity matching
    def is_similar(call1, call2):
        call1 = call1.split(".")[-1]
        call2 = call2.split(".")[-1]
        return SequenceMatcher(None, normalize_name(call1), normalize_name(call2)).ratio() >= similarity_threshold

    # 过滤掉builtin函数
    generated = {k: [c for c in v if 'builtin' not in c] for k, v in generated.items()}
    ground_truth = {k: [c for c in v if 'builtin' not in c] for k, v in ground_truth.items()}

    # Convert keys to sets
    generated_keys = set(generated.keys())
    ground_truth_keys = set(ground_truth.keys())

    # Detect missing and extra keys
    report["missing_keys"] = list(ground_truth_keys - generated_keys)
    report["extra_keys"] = list(generated_keys - ground_truth_keys)

    # Check mismatched calls for common keys
    total_matches = 0
    total_ground_truth_calls = 0
    total_generated_calls = 0

    common_keys = ground_truth_keys.intersection(generated_keys)
    for key in common_keys:
        ground_truth_calls = ground_truth[key]
        generated_calls = generated[key]
        total_ground_truth_calls += len(ground_truth_calls)
        total_generated_calls += len(generated_calls)

        matched_calls = set()
        unmatched_generated = []
        unmatched_ground_truth = []

        for gt_call in ground_truth_calls:
            found = False
            for gen_call in generated_calls:
                if is_similar(gt_call, gen_call):
                    matched_calls.add(gt_call)
                    found = True
                    break
            if not found:
                unmatched_ground_truth.append(gt_call)

        for gen_call in generated_calls:
            if not any(is_similar(gen_call, gt_call) for gt_call in ground_truth_calls):
                unmatched_generated.append(gen_call)

        total_matches += len(matched_calls)
        if unmatched_generated or unmatched_ground_truth:
            report["mismatched_calls"].append({
                "key": key,
                "missing_calls": unmatched_ground_truth,
                "extra_calls": unmatched_generated
            })

    # Calculate precision, recall, and F1-score
    if total_generated_calls > 0:
        precision = total_matches / total_generated_calls
    else:
        if total_ground_truth_calls > 0:
            precision = 0.0
        else:
            precision = 1.0

    if total_ground_truth_calls > 0:
        recall = total_matches / total_ground_truth_calls
    else:
        recall = 1.0

    if precision + recall > 0:
        f1_score = 2 * (precision * recall) / (precision + recall)
    else:
        f1_score = 0.0

    report["metrics"]["precision"] = round(precision, 4)
    report["metrics"]["recall"] = round(recall, 4)
    report["metrics"]["f1_score"] = round(f1_score, 4)
    report["metrics"]["total_matches"] = total_matches
    report["metrics"]["total_ground_truth_calls"] = total_ground_truth_calls
    report["metrics"]["total_generated_calls"] = total_generated_calls

    return report

# Example usage
import os
if __name__ == "__main__":
    total_matches = 0
    total_ground_truth_calls = 0
    total_generated_calls = 0
    
    for i in range(200):
        if not os.path.exists(f"llm_cg/{i}.json"):
            continue
        if not os.path.exists(f"../../dataset/cangjie_cg/{i}.json"):
            continue
            
        generated = json.load(open(f"llm_cg/{i}.json"))
        ground_truth = json.load(open(f"../../dataset/cangjie_cg/{i}.json"))
        
        differences = compare_call_graphs(generated, ground_truth)
        metrics = differences["metrics"]
        
        total_matches += metrics["total_matches"]
        total_ground_truth_calls += metrics["total_ground_truth_calls"]
        total_generated_calls += metrics["total_generated_calls"]
            
        
        # print(f"file {i} metrics:")
        # print(json.dumps(metrics, indent=4))
        # 如果precision低于0.5,删除该文件
        # if metrics["precision"] < 0.5:
        #     print(f"file {i} metrics:")
        #     print(json.dumps(metrics, indent=4))
            # os.remove(f"llm_cg/{i}.json")
            # print(f"已删除文件 llm_cg/{i}.json (precision={metrics['precision']})")
    
    # 使用总体数据计算指标
    precision = round(total_matches / total_generated_calls, 4) if total_generated_calls > 0 else 0
    recall = round(total_matches / total_ground_truth_calls, 4) if total_ground_truth_calls > 0 else 0
    f1_score = round(2 * (precision * recall) / (precision + recall), 4) if precision + recall > 0 else 0
    
    print("\n平均指标:")
    print(f"Precision: {precision}")
    print(f"Recall: {recall}")  
    print(f"F1 Score: {f1_score}")



平均指标:
Precision: 0.8003
Recall: 0.6488
F1 Score: 0.7166


In [6]:
llm_cg_dir = "llm_cg"
source_dir = "../../dataset/cangjie"

import os
import json

datas = []

for i in range(200):
    llm_cg_path = os.path.join(llm_cg_dir, f"{i}.json")
    # static_cg_path = os.path.join(static_cg_dir, f"{i}.json")
    source_path = os.path.join(source_dir, f"{i}.cj")

    if not os.path.exists(llm_cg_path) or not os.path.exists(source_path):
        continue

    with open(source_path, "r", encoding="utf-8") as f:
        source_code = f.read()

    with open(llm_cg_path, "r", encoding="utf-8") as f:
        llm_cg = json.load(f)  

    static_cg = {}

    data = {    
        "source_code": source_code,
        "llm_cg": llm_cg,
        "static_cg": static_cg
    }
    datas.append(data)

with open("cg_task.jsonl", "w", encoding="utf-8") as f:
    for data in datas:
        f.write(json.dumps(data, ensure_ascii=False) + "\n")



In [8]:
from llm import get_llm_answers
import json
import concurrent.futures
import os

def process_line(line):
    return{
        "sum_call_from_static": 0,
        "correct_call_from_llm": 0,
        "missing_call_from_llm": 0,
        "extra_call_from_llm": 0
    }

jsonl_file = "cg_task.jsonl"
results = []

# 创建输出目录
os.makedirs("results_local", exist_ok=True)

with open(jsonl_file, "r", encoding="utf-8") as f:
    lines = f.readlines()
    # 解析每一行获取文件名
    files = [f"{i}.cj" for i in range(200)]
    
with concurrent.futures.ThreadPoolExecutor(max_workers=4) as executor:
    future_to_line = {executor.submit(process_line, line): i for i, line in enumerate(lines)}
    
    for future in concurrent.futures.as_completed(future_to_line):
        line_index = future_to_line[future]
        try:
            result = future.result()
            # 添加文件名到结果中
            result["file_name"] = files[line_index]
            results.append(result)
            # 保存单个结果
            with open(f"results_local/result_{line_index}.json", "w") as f:
                json.dump(result, f, indent=4)
        except Exception as e:
            print(f'处理第 {line_index} 行时发生错误: {str(e)}')

# 按照file_name排序
results.sort(key=lambda x: x["file_name"])

# 保存所有结果
with open("results_local/all_results.json", "w") as f:
    json.dump(results, f, indent=4)

# 计算总体指标
total_correct = sum(r["correct_call_from_llm"] for r in results)
total_static = sum(r["sum_call_from_static"] for r in results)
total_llm = sum(r["correct_call_from_llm"] + r["extra_call_from_llm"] for r in results)

precision = round(total_correct / total_llm, 4) if total_llm > 0 else 0
recall = round(total_correct / total_static, 4) if total_static > 0 else 0
f1_score = round(2 * (precision * recall) / (precision + recall), 4) if precision + recall > 0 else 0

print("\n总体指标:")
print(f"Precision: {precision}")
print(f"Recall: {recall}")
print(f"F1 Score: {f1_score}")


总体指标:
Precision: 0
Recall: 0
F1 Score: 0
