In [None]:
from google.colab import userdata

# Read the Hugging Face token from Colab's secret store and bind it to a name.
# Leaving the call as the last bare expression of the cell would echo the
# token into the cell output, which is saved in plain text with the notebook.
HF_TOKEN = userdata.get('HF_TOKEN')

In [None]:
import json
import os
from pathlib import Path
from typing import List, Dict
from dataclasses import dataclass
from collections import Counter
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import requests
from bs4 import BeautifulSoup
from time import sleep
from tqdm import tqdm
import torch
from sentence_transformers import SentenceTransformer, util
import pandas as pd

# Input directory (holds the results JSON) and output directory (receives the
# artefacts written by this notebook). Both are created up front so later
# cells can read and write without checking for existence.
RESULTS_DIR = Path("./")
OUTPUT_DIR = Path("./reprocessed_results")
for _directory in (RESULTS_DIR, OUTPUT_DIR):
    _directory.mkdir(parents=True, exist_ok=True)

In [None]:
@dataclass
class PairResult:
    """Typed record describing one vulnerable/fixed code pair.

    NOTE(review): this class is not instantiated in the visible cells, which
    work on the raw JSON dicts instead. The field notes below are inferred
    from the field names and from the JSON keys read elsewhere in the
    notebook — confirm before relying on them.
    """
    pair_id: int  # presumably mirrors the "pair_id" key of a JSON pair
    commit_id: str  # presumably the commit the pair was taken from — TODO confirm
    ground_truth_cwe: str  # presumably mirrors the "ground_truth_cwe" key
    vuln_detected_cwes: List[str]  # presumably pair["vulnerable"]["detected_cwes"]
    fixed_detected_cwes: List[str]  # presumably pair["fixed"]["detected_cwes"]

# Location of the evaluation output to analyse; replace the placeholder name
# with the actual results file before running.
json_file = RESULTS_DIR / "YOUR_RESULTS_FILE.json"

# Parse the whole file into `data`; later cells read the list under "pairs".
with json_file.open("r", encoding="utf-8") as f:
    data = json.loads(f.read())

pair_total = len(data['pairs'])
print(f"Number of pairs: {pair_total}")


In [None]:
# Gather every CWE identifier that appears in the results: the ground truth of
# each pair (skipped when missing or empty) plus everything detected on either
# the vulnerable or the fixed version.
cwe_set = set()

for pair in data["pairs"]:
    gt = pair.get("ground_truth_cwe")
    if gt:
        cwe_set.add(gt)
    vuln_cwes = pair.get("vulnerable", {}).get("detected_cwes", [])
    fixed_cwes = pair.get("fixed", {}).get("detected_cwes", [])
    cwe_set |= set(vuln_cwes)
    cwe_set |= set(fixed_cwes)

# Sorted list so the fetch order (and the preview below) is deterministic.
cwe_list = sorted(cwe_set)
print(f"{len(cwe_list)} unique CWEs found.")
print(cwe_list[:20])

In [None]:
BASE_URL = "https://cwe.mitre.org/data/definitions/{}.html"


def fetch_cwe_description(cwe_id, timeout=10):
    """Download the name and description of one CWE entry from cwe.mitre.org.

    Args:
        cwe_id: Identifier of the form ``"CWE-<number>"`` (e.g. ``"CWE-79"``).
            The part after the first dash is used as the page number.
        timeout: Seconds to wait for the HTTP response before giving up, so a
            stalled connection cannot hang the whole fetch loop.

    Returns:
        A dict with keys ``"cwe_id"``, ``"name"`` and ``"description"`` on
        success (``name`` / ``description`` are empty strings when the
        corresponding page element is not found), or ``None`` when the
        identifier is malformed, the server does not answer with HTTP 200,
        or the request/parsing raises.
    """
    # Validate before touching the network: detected CWE lists may contain
    # values that are not of the "CWE-<number>" form, and those must be
    # skipped rather than crash the caller's loop.
    id_parts = cwe_id.split("-") if isinstance(cwe_id, str) else []
    number = id_parts[1].strip() if len(id_parts) > 1 else ""
    if not number.isdigit():
        print(f"Skipping malformed CWE id: {cwe_id!r}")
        return None

    url = BASE_URL.format(number)
    try:
        r = requests.get(url, timeout=timeout)
        if r.status_code != 200:
            return None
        soup = BeautifulSoup(r.text, "html.parser")
        title_tag = soup.find("span", {"id": "TitleText"})
        name = title_tag.text.strip() if title_tag else ""
        desc_tag = soup.find("div", {"id": "Description"})
        description = desc_tag.text.strip() if desc_tag else ""
        return {"cwe_id": cwe_id, "name": name, "description": description}
    except Exception as e:
        # Deliberately broad: fetching is best-effort and one bad page must
        # not abort the batch. cwe_id already carries the "CWE-" prefix.
        print(f"Error {cwe_id}: {e}")
        return None

# Fetch every CWE page in turn, keeping only the successful lookups. The
# half-second pause after each request throttles the load on the server.
cwe_data = []
for cwe_id in tqdm(cwe_list):
    fetched_entry = fetch_cwe_description(cwe_id)
    sleep(0.5)
    if fetched_entry:
        cwe_data.append(fetched_entry)

# Persist the descriptions so the embedding cell can reload them from disk.
descriptions_path = OUTPUT_DIR / "cwe_descriptions.json"
with descriptions_path.open("w", encoding="utf-8") as f:
    f.write(json.dumps(cwe_data, ensure_ascii=False, indent=2))

print(f"{len(cwe_data)} CWE descriptions fetched and saved.")

In [None]:
model = SentenceTransformer('all-MiniLM-L6-v2')

# Reload the descriptions written by the fetch cell.
with open(OUTPUT_DIR / "cwe_descriptions.json", "r", encoding="utf-8") as f:
    cwe_descriptions = json.load(f)

# Map each CWE id to the embedding of its description text. Entries without a
# description get an all-zero placeholder of the same dimension.
embedding_dim = model.get_sentence_embedding_dimension()
cwe_embeddings = {}
for item in cwe_descriptions:
    desc = item.get("description", "")
    if desc:
        cwe_embeddings[item["cwe_id"]] = model.encode(desc, convert_to_tensor=True)
    else:
        # Allocate the placeholder on the model's device: encoded tensors live
        # there, and comparing a CPU tensor with a GPU tensor later raises a
        # device-mismatch error.
        cwe_embeddings[item["cwe_id"]] = torch.zeros(embedding_dim, device=model.device)

In [None]:
def calculate_metrics(data):
    """Score the run by exact match of the ground-truth CWE.

    For every entry of ``data["pairs"]`` the vulnerable version counts as a
    true positive when the ground-truth CWE is among its detected CWEs (false
    negative otherwise), and the fixed version counts as a false positive
    when the ground-truth CWE is among its detected CWEs (true negative
    otherwise). A pair is "correct" when it is TP + TN and "reversed" when it
    is FN + FP.

    Returns:
        Dict with the raw counts, the pair-level counts, VPS (correct minus
        reversed pairs) and accuracy / precision / recall / F1 rounded to
        four decimals. Ratios with a zero denominator are reported as 0.
    """
    def ratio(numerator, denominator):
        # Guarded division: a zero (or negative) denominator yields 0.
        return numerator / denominator if denominator > 0 else 0

    counts = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
    correct_pairs = 0
    reversed_pairs = 0

    for entry in data["pairs"]:
        truth = entry["ground_truth_cwe"]
        vuln_hits = entry["vulnerable"]["detected_cwes"]
        fixed_hits = entry["fixed"]["detected_cwes"]

        flagged_vuln = truth in vuln_hits
        flagged_fixed = truth in fixed_hits

        counts["TP" if flagged_vuln else "FN"] += 1
        counts["FP" if flagged_fixed else "TN"] += 1

        if flagged_vuln and not flagged_fixed:
            correct_pairs += 1
        elif flagged_fixed and not flagged_vuln:
            reversed_pairs += 1

    tp, tn, fp, fn = (counts[key] for key in ("TP", "TN", "FP", "FN"))
    total = tp + tn + fp + fn
    accuracy = ratio(tp + tn, total)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1_score = ratio(2 * (precision * recall), precision + recall)

    summary = dict(counts)
    summary["Total"] = total
    summary["P-C (Pairs Correct)"] = correct_pairs
    summary["P-R (Pairs Reversed)"] = reversed_pairs
    summary["VPS"] = correct_pairs - reversed_pairs
    summary["Accuracy"] = round(accuracy, 4)
    summary["Precision"] = round(precision, 4)
    summary["Recall"] = round(recall, 4)
    summary["F1-Score"] = round(f1_score, 4)
    return summary

# Exact-match metrics: a detection counts only when the ground-truth CWE
# itself appears in the detected list. `metrics` is reused by the plotting
# cells below.
metrics = calculate_metrics(data)
print(metrics)

In [None]:
def plot_confusion_matrix(metrics):
    """Show a 2x2 confusion-matrix heatmap built from a metrics dict.

    Args:
        metrics: Mapping providing integer counts under "TP", "FN", "FP"
            and "TN". Rows are the actual class, columns the predicted one.
    """
    counts = np.array([
        [metrics["TP"], metrics["FN"]],
        [metrics["FP"], metrics["TN"]],
    ])
    fig, ax = plt.subplots(figsize=(8, 6))
    sns.heatmap(
        counts,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=["Pred Vulnerable", "Pred Safe"],
        yticklabels=["Actual Vulnerable", "Actual Safe"],
        ax=ax,
    )
    ax.set_xlabel("Predicted")
    ax.set_ylabel("Actual")
    ax.set_title("Confusion Matrix")
    plt.show()

plot_confusion_matrix(metrics)

In [None]:
def _max_similarity_to_gt(gt_cwe, detected, cwe_embeddings):
    """Return the highest cosine similarity between gt_cwe and any detected CWE.

    CWEs without an entry in ``cwe_embeddings`` are skipped. Returns 0.0 when
    the ground-truth CWE itself has no embedding or nothing can be compared.
    """
    gt_emb = cwe_embeddings.get(gt_cwe)
    if gt_emb is None:
        # No embedding for the ground truth: there is nothing to compare
        # against, and passing None to cos_sim would raise.
        return 0.0
    scores = []
    for det in detected:
        det_emb = cwe_embeddings.get(det)
        if det_emb is not None:
            scores.append(util.cos_sim(gt_emb, det_emb).item())
    return max(scores, default=0.0)


def compute_semantic_fn_tp(pair, cwe_embeddings):
    """Classify one pair by exact CWE match and attach similarity scores.

    Args:
        pair: Dict with "pair_id", "ground_truth_cwe", and "vulnerable" /
            "fixed" sub-dicts that each hold a "detected_cwes" list.
        cwe_embeddings: Mapping from CWE id to its description embedding.

    Returns:
        Dict with the pair id, the four exact-match flags (vulnerable TP/FN,
        fixed TN/FP) and two scores:
        ``vuln_semantic_score`` — for a vulnerable false negative, the best
        cosine similarity between the ground-truth CWE and the CWEs that
        were detected instead; 0.0 otherwise.
        ``fixed_semantic_score`` — for a fixed false positive, the best
        cosine similarity between the ground-truth CWE and the CWEs detected
        on the fixed version; 0.0 otherwise.
    """
    gt_cwe = pair["ground_truth_cwe"]
    vuln_detected = pair["vulnerable"]["detected_cwes"]
    fixed_detected = pair["fixed"]["detected_cwes"]

    is_vuln_tp = gt_cwe in vuln_detected
    is_vuln_fn = not is_vuln_tp
    is_fixed_fp = gt_cwe in fixed_detected
    is_fixed_tn = not is_fixed_fp

    vuln_sem_score = 0.0
    if is_vuln_fn and vuln_detected:
        vuln_sem_score = _max_similarity_to_gt(gt_cwe, vuln_detected, cwe_embeddings)

    # NOTE(review): is_fixed_fp means gt_cwe is itself in fixed_detected, so
    # this maximum includes the ground truth compared with itself — confirm
    # that is the intended quantity.
    fixed_sem_score = 0.0
    if is_fixed_fp and fixed_detected:
        fixed_sem_score = _max_similarity_to_gt(gt_cwe, fixed_detected, cwe_embeddings)

    return {
        "pair_id": pair["pair_id"],
        "is_vuln_tp": is_vuln_tp,
        "is_vuln_fn": is_vuln_fn,
        "vuln_semantic_score": vuln_sem_score,
        "is_fixed_tn": is_fixed_tn,
        "is_fixed_fp": is_fixed_fp,
        "fixed_semantic_score": fixed_sem_score
    }

# One similarity record per pair, in the same order as data["pairs"].
semantic_results = [compute_semantic_fn_tp(pair, cwe_embeddings) for pair in data["pairs"]]

# Print the first five records as a quick sanity check.
for r in semantic_results[:5]:
    print(r)

In [None]:
# Similarity-weighted recall over the vulnerable versions. An exact hit earns
# full credit (1.0); a miss earns partial credit equal to its best similarity
# to the ground-truth CWE. Exact hits carry vuln_semantic_score == 0.0, so
# averaging the raw scores would make better detection lower this number.
recall_credits = [
    1.0 if r["is_vuln_tp"] else r["vuln_semantic_score"]
    for r in semantic_results
]
# Avoid numpy's "mean of empty slice" warning / NaN when there are no pairs.
semantic_recall = np.mean(recall_credits) if recall_credits else 0.0
print(f"Semantic recall approx.: {semantic_recall:.3f}")

In [None]:
# Similarity scores restricted to the error cases: vulnerable versions that
# were missed (FN) and fixed versions that were flagged (FP).
vuln_scores = [res["vuln_semantic_score"] for res in semantic_results if res["is_vuln_fn"]]
fixed_scores = [res["fixed_semantic_score"] for res in semantic_results if res["is_fixed_fp"]]

# Two side-by-side histograms sharing the same layout.
fig, (ax_vuln, ax_fixed) = plt.subplots(1, 2, figsize=(12, 5))
hist_panels = (
    (ax_vuln, vuln_scores, 'tomato', "Vulnerable FN — Semantic Similarity Scores"),
    (ax_fixed, fixed_scores, 'skyblue', "Fixed FP — Semantic Similarity Scores"),
)
for panel_ax, panel_scores, panel_color, panel_title in hist_panels:
    panel_ax.hist(panel_scores, bins=20, color=panel_color, alpha=0.7)
    panel_ax.set_title(panel_title)
    panel_ax.set_xlabel("Max Semantic Score with GT CWE")
    panel_ax.set_ylabel("Count")

fig.tight_layout()
plt.show()

print(f"Vulnerable FN count: {len(vuln_scores)}, mean score: {np.mean(vuln_scores):.3f}")
print(f"Fixed FP count: {len(fixed_scores)}, mean score: {np.mean(fixed_scores):.3f}")

In [None]:
def calculate_prediction_metrics(data):
    """Score the run from the per-version "prediction" / "correct" fields.

    For every entry of ``data["pairs"]`` the vulnerable version is a true
    positive when both its "prediction" and "correct" values are truthy
    (false negative otherwise); the fixed version is a true negative when its
    "prediction" is falsy and its "correct" is truthy (false positive
    otherwise). A pair is "correct" when it is TP + TN and "reversed" when it
    is FN + FP.

    Returns:
        Dict with the raw counts, the pair-level counts, VPS (correct minus
        reversed pairs) and accuracy / precision / recall / F1 rounded to
        four decimals. Ratios with a zero denominator are reported as 0.
    """
    def ratio(numerator, denominator):
        # Guarded division: a zero (or negative) denominator yields 0.
        return numerator / denominator if denominator > 0 else 0

    tally = {"TP": 0, "TN": 0, "FP": 0, "FN": 0}
    correct_pairs = 0
    reversed_pairs = 0

    for entry in data["pairs"]:
        vuln_side = entry["vulnerable"]
        vuln_pred = vuln_side["prediction"]
        vuln_correct = vuln_side["correct"]

        fixed_side = entry["fixed"]
        fixed_pred = fixed_side["prediction"]
        fixed_correct = fixed_side["correct"]

        vuln_caught = bool(vuln_pred and vuln_correct)
        fixed_cleared = bool(not fixed_pred and fixed_correct)

        tally["TP" if vuln_caught else "FN"] += 1
        tally["TN" if fixed_cleared else "FP"] += 1

        if vuln_caught and fixed_cleared:
            correct_pairs += 1
        elif not vuln_caught and not fixed_cleared:
            reversed_pairs += 1

    tp, tn, fp, fn = (tally[key] for key in ("TP", "TN", "FP", "FN"))
    total = tp + tn + fp + fn
    accuracy = ratio(tp + tn, total)
    precision = ratio(tp, tp + fp)
    recall = ratio(tp, tp + fn)
    f1_score = ratio(2 * (precision * recall), precision + recall)

    summary = dict(tally)
    summary["Total"] = total
    summary["P-C (Pairs Correct)"] = correct_pairs
    summary["P-R (Pairs Reversed)"] = reversed_pairs
    summary["VPS"] = correct_pairs - reversed_pairs
    summary["Accuracy"] = round(accuracy, 4)
    summary["Precision"] = round(precision, 4)
    summary["Recall"] = round(recall, 4)
    summary["F1-Score"] = round(f1_score, 4)
    return summary

# Metrics derived from the stored "prediction" / "correct" fields rather than
# from CWE matching; reused by the confusion-matrix cell below.
prediction_metrics = calculate_prediction_metrics(data)
print("=== Metrics based on predictions ===")
# One "name: value" line per metric, in the order the dict was built.
for k,v in prediction_metrics.items():
    print(f"{k}: {v}")

In [None]:
def _best_semantic_score(result, scalar_key, list_key):
    """Return the best similarity score stored on one semantic result.

    The records built in this notebook store a single float under
    ``scalar_key``. A list of ``{"semantic_score": ...}`` dicts under
    ``list_key`` is still accepted as a fallback. Returns 0 when neither
    key yields a score.
    """
    if scalar_key in result:
        return result[scalar_key]
    entries = result.get(list_key, [])
    return max((entry["semantic_score"] for entry in entries), default=0)


# Vulnerable versions that were missed, with their best similarity score.
# The records carry that score under "vuln_semantic_score"; reading only the
# list-style key would find nothing and report every score as 0.
vuln_fn_results = [r for r in semantic_results if r.get("is_vuln_fn", False)]
vuln_fn_count = len(vuln_fn_results)
vuln_fn_scores = [
    _best_semantic_score(r, "vuln_semantic_score", "vulnerable_semantic_scores")
    for r in vuln_fn_results
]
vuln_fn_mean_score = np.mean(vuln_fn_scores) if vuln_fn_scores else 0

# Fixed versions that were flagged, with their best similarity score.
fixed_fp_results = [r for r in semantic_results if r.get("is_fixed_fn", False) or r.get("is_fixed_fp", False)]
fixed_fp_count = len(fixed_fp_results)
fixed_fp_scores = [
    _best_semantic_score(r, "fixed_semantic_score", "fixed_semantic_scores")
    for r in fixed_fp_results
]
fixed_fp_mean_score = np.mean(fixed_fp_scores) if fixed_fp_scores else 0

In [None]:
def plot_confusion(metrics, save_path=None):
    """Show a 2x2 confusion-matrix heatmap and optionally save it to disk.

    Args:
        metrics: Mapping providing integer counts under 'TP', 'FN', 'FP'
            and 'TN'. Rows are the actual class, columns the predicted one.
        save_path: When truthy, the figure is written there (300 dpi, tight
            bounding box) before being shown.
    """
    counts = np.array([
        [metrics['TP'], metrics['FN']],
        [metrics['FP'], metrics['TN']],
    ])

    fig, ax = plt.subplots(figsize=(6, 5))
    sns.heatmap(
        counts,
        annot=True,
        fmt='d',
        cmap='Blues',
        xticklabels=['Predicted Vulnerable', 'Predicted Safe'],
        yticklabels=['Actually Vulnerable', 'Actually Safe'],
        ax=ax,
    )
    ax.set_title("Confusion Matrix", fontsize=14)
    ax.set_ylabel("Actual")
    ax.set_xlabel("Predicted")

    # Save before show(): showing may release the figure.
    if save_path:
        fig.savefig(save_path, dpi=300, bbox_inches='tight')
    plt.show()

plot_confusion(prediction_metrics, save_path="confusion_matrix.png")

In [None]:
# Bar chart of the three headline scores taken from the exact-match `metrics`
# dict, each bar annotated with its value as a percentage.
summary_keys = ['Precision', 'Recall', 'F1-Score']
metrics_summary = pd.DataFrame({
    'Metric': summary_keys,
    'Value': [metrics[key] for key in summary_keys],
})

fig, ax = plt.subplots(figsize=(6, 4))
sns.barplot(x='Metric', y='Value', data=metrics_summary, palette='coolwarm', ax=ax)
ax.set_ylim(0, 1)
ax.set_title("Key Performance Metrics")
for i, v in enumerate(metrics_summary['Value']):
    # Label sits slightly above the top of the bar.
    ax.text(i, v + 0.02, f"{v:.2%}", ha='center', fontweight='bold')
fig.tight_layout()
fig.savefig("performance_metrics.png", dpi=300)
plt.show()

In [None]:
# Named `metric_names` rather than `metrics`: `metrics` already holds the
# exact-match metrics dict used by earlier cells, and rebinding it to a list
# here would make those cells fail when re-run.
metric_names = ["PC", "PR", "VPS", "FPR", "F1-score", "ACC"]

# VulAgent baseline (paper)
vulagent = [17.7, 8.74, 8.96, 19.95, 41.59, 54.73]

# VulTrial baseline
vultrial = [18.6, 11.4, 7.13, 52.6, 56.1, 53.4]

# VulPrune
# NOTE(review): these values are hard-coded rather than derived from the
# metrics computed earlier in the notebook — confirm they match the run
# being reported.
vulprune = [
    5.0,   # PC (%)
    0.0,   # PR (%)
    5.0,   # VPS (%)
    90.0,   # FPR (%)
    68.97, # F1-score (%)
    55.0   # Accuracy (%)
]

df = pd.DataFrame({
    "Metric": metric_names,
    "VulAgent": vulagent,
    "VulPrune": vulprune,
    "VulTrial": vultrial
})

# Metrics for which a smaller value is the better one; for all others a
# larger value is better. The arrow compares VulPrune against VulAgent.
lower_is_better = {"FPR", "PR"}
arrows = []
for name, ours, baseline in zip(df["Metric"], df["VulPrune"], df["VulAgent"]):
    improved = ours < baseline if name in lower_is_better else ours > baseline
    arrows.append('↑' if improved else '↓')

colors = ['green' if a == '↑' else 'red' for a in arrows]

fig, ax = plt.subplots(figsize=(10, 4))
ax.axis('tight')
ax.axis('off')

# Pair rows with arrows by position instead of by DataFrame index label.
table_data = [
    [name, f"{agent:.2f}", f"{prune:.2f} {arrow}", f"{trial:.2f}"]
    for name, agent, prune, trial, arrow in zip(
        df["Metric"], df["VulAgent"], df["VulPrune"], df["VulTrial"], arrows
    )
]

table = ax.table(
    cellText=table_data,
    colLabels=["Metric", "VulAgent", "VulPrune", "VulTrial"],
    cellLoc='center',
    loc='center'
)

# Row 0 is the header, so data row i lives at table row i + 1; column 2 is
# the VulPrune column.
for i, color in enumerate(colors):
    table[i+1, 2].set_text_props(color=color, weight='bold')

plt.title("Metrics Comparison with Arrows (↑ better, ↓ worse)", fontsize=14)
plt.show()

In [None]:
# Frequency of each ground-truth CWE across all pairs, most common first.
pairs = data["pairs"]
gt_list = [entry["ground_truth_cwe"] for entry in pairs]

gt_frequencies = Counter(gt_list).most_common()
dist = pd.DataFrame(gt_frequencies, columns=["vulnerability", "count"])
display(dist)

fig, ax = plt.subplots(figsize=(10, 5))

# One colour per bar, spread evenly over the tab20 palette.
colors = plt.cm.tab20(np.linspace(0, 1, len(dist)))

bars = ax.bar(dist["vulnerability"], dist["count"], color=colors)

# Write each count just above its bar.
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2, height, str(int(height)), ha="center", va="bottom")

ax.set_xlabel("vulnerability")
ax.set_ylabel("count")
ax.set_title("Distribution of unique ground-truth vulnerabilities")
plt.xticks(rotation=45, ha="right")
fig.tight_layout()
plt.show()

In [None]:
# Frequency of every CWE reported by the detector, pooling the vulnerable and
# the fixed version of each pair (vulnerable first), most common first.
pairs = data["pairs"]
detected = [
    cwe
    for entry in pairs
    for side in ("vulnerable", "fixed")
    for cwe in entry.get(side, {}).get("detected_cwes", [])
]

detected_frequencies = Counter(detected).most_common()
dist_det = pd.DataFrame(detected_frequencies, columns=["vulnerability", "count"])
display(dist_det)

fig, ax = plt.subplots(figsize=(10, 6))

# One colour per bar, spread evenly over the tab20 palette.
colors = plt.cm.tab20(np.linspace(0, 1, len(dist_det)))
bars = ax.bar(dist_det["vulnerability"], dist_det["count"], color=colors)

# Write each count just above its bar.
for bar in bars:
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width()/2, height, str(int(height)), ha="center", va="bottom")

ax.set_xlabel("vulnerability")
ax.set_ylabel("count")
ax.set_title("Distribution of detected CWE")
plt.xticks(rotation=45, ha="right")
fig.tight_layout()
plt.show()