In [2]:
import pandas as pd
df = pd.read_csv('devign.csv')
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14032 entries, 0 to 14031
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   func          14032 non-null  object
 1   func_cleaned  14032 non-null  object
 2   project       14032 non-null  object
 3   target        14032 non-null  bool  
dtypes: bool(1), object(3)
memory usage: 342.7+ KB


In [3]:
df['target'].astype(int)
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 14032 entries, 0 to 14031
Data columns (total 4 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   func          14032 non-null  object
 1   func_cleaned  14032 non-null  object
 2   project       14032 non-null  object
 3   target        14032 non-null  bool  
dtypes: bool(1), object(3)
memory usage: 342.7+ KB


In [5]:
import subprocess
import tempfile
import os
import json

def semgrep_trace(code, rule_path="p/cwe-top-25"):

    with tempfile.NamedTemporaryFile(delete=False, suffix=".c", mode="w") as tmp:
        tmp.write(code)
        tmp_filename = tmp.name
    try:
        result = subprocess.run(
            ["semgrep", "--config", rule_path, tmp_filename, "--json"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
            timeout=10,
        )
        findings = json.loads(result.stdout).get("results", [])
        if not findings:
            return "No_issue_found"
        traces = []
        for f in findings:
            rule = f.get("check_id", "unknown_rule")
            start = f.get("start", {}).get("line", "NA")
            traces.append(f"{rule}@L{start}")
        return "; ".join(traces)
    except Exception as e:
        return "Error_or_Timeout"
    finally:
        os.remove(tmp_filename)


In [6]:
df["event_trace"] = df["func"].apply(semgrep_trace)


In [7]:
df.to_csv("event_trace.csv",index=False,quoting=1)

In [8]:
DANGEROUS_CALLS = {
    # Buffer overflows, format strings, etc.
    "strcpy": "buffer_overflow",
    "strcat": "buffer_overflow",
    "sprintf": "format_string",
    "vsprintf": "format_string",
    "gets": "buffer_overflow",
    "scanf": "input_unsanitized",
    "sscanf": "input_unsanitized",
    "fscanf": "input_unsanitized",
    "memcpy": "buffer_overflow",
    "strncpy": "buffer_overflow",
    "strncat": "buffer_overflow",
    "memset": "buffer_operation",
    "memmove": "buffer_overflow",
    "strlen": "buffer_operation",
    "read": "buffer_overflow",
    "write": "buffer_overflow",
    # Dangerous system interaction
    "system": "command_injection",
    "popen": "command_injection",
    "exec": "command_injection",
    "execl": "command_injection",
    "execlp": "command_injection",
    "execle": "command_injection",
    "execv": "command_injection",
    "execvp": "command_injection",
    "execve": "command_injection",
    # Memory management
    "malloc": "memory_alloc",
    "calloc": "memory_alloc",
    "realloc": "memory_alloc",
    "free": "memory_free",
    "delete": "memory_free",
    "delete[]": "memory_free",
    "new": "memory_alloc",
    # Pointer arithmetic, dangerous casts
    "reinterpret_cast": "dangerous_cast",
    "const_cast": "dangerous_cast",
    # Network functions
    "recv": "network_input",
    "send": "network_output",
    "accept": "network_input",
    "connect": "network_output",
    "bind": "network_input",
    "listen": "network_input",
    "inet_addr": "network_input",
    # Misc
    "chown": "privilege_escalation",
    "chmod": "privilege_escalation",
    "fopen": "file_operation",
    "fclose": "file_operation",
    "fread": "file_operation",
    "fwrite": "file_operation",
    # Add more as needed
}


In [10]:
import json
import os

def extract_event_trace_from_ast(ast_path):
    # Load the AST JSON file
    with open(ast_path, "r") as f:
        ast_json = json.load(f)
    event_trace = []
    # --- Track pointer names and free usage for UAF/double-free ---
    freed_pointers = set()
    pointer_names = set()
    free_lines = {}
    for node in ast_json.get('nodes', []):
        label = node.get('label', '')
        code = node.get('code', '')
        line = node.get('lineNumber', '?')
        # --- Dangerous API calls ---
        if label == 'CALL':
            for func, vuln_type in DANGEROUS_CALLS.items():
                if func in code:
                    event_trace.append(f"{vuln_type}:{func}@L{line}")
            # Track free() calls for UAF/Double-Free
            if 'free' in code:
                # Try to extract pointer var name (basic heuristic)
                tokens = code.split('free')
                if len(tokens) > 1:
                    ptr_name = tokens[1].replace('(', '').replace(')', '').replace(';','').strip()
                    if ptr_name:
                        if ptr_name in freed_pointers:
                            event_trace.append(f"double_free:{ptr_name}@L{line}")
                        freed_pointers.add(ptr_name)
                        free_lines[ptr_name] = line
        # --- Pointer usage (look for use-after-free) ---
        # Very basic: look for use of pointers (dereference or member access) after free
        if label == 'IDENTIFIER' and '*' in code:
            pointer_names.add(code.replace('*','').strip())
        if label in ('CALL', 'FIELD_IDENTIFIER', 'IDENTIFIER'):
            for ptr in freed_pointers:
                if ptr and ptr in code and 'free' not in code:
                    event_trace.append(f"use_after_free:{ptr}@L{line}")
        # --- Checks ---
        if label == 'CONTROL_STRUCTURE':
            cond = code.lower()
            if 'if' in cond:
                if any(op in cond for op in ['<', '>', '!=', '==', '>=', '<=']):
                    event_trace.append(f"check@L{line}")
                if 'null' in cond or 'nullptr' in cond or '0' in cond:
                    event_trace.append(f"null_check@L{line}")
                if 'strlen' in cond or 'sizeof' in cond:
                    event_trace.append(f"length_check@L{line}")
    if not event_trace:
        return "No_issue_found"
    # Remove duplicates while keeping order
    seen = set()
    unique = []
    for ev in event_trace:
        if ev not in seen:
            seen.add(ev)
            unique.append(ev)
    return "; ".join(unique)


In [17]:
DANGEROUS_CALLS = {
    "strcpy": "buffer_overflow",
    "strcat": "buffer_overflow",
    "sprintf": "format_string",
    "vsprintf": "format_string",
    "gets": "buffer_overflow",
    "scanf": "input_unsanitized",
    "sscanf": "input_unsanitized",
    "fscanf": "input_unsanitized",
    "memcpy": "buffer_overflow",
    "strncpy": "buffer_overflow",
    "strncat": "buffer_overflow",
    "memset": "buffer_operation",
    "memmove": "buffer_overflow",
    "strlen": "buffer_operation",
    "read": "buffer_overflow",
    "write": "buffer_overflow",
    "system": "command_injection",
    "popen": "command_injection",
    "exec": "command_injection",
    "execl": "command_injection",
    "execlp": "command_injection",
    "execle": "command_injection",
    "execv": "command_injection",
    "execvp": "command_injection",
    "execve": "command_injection",
    "malloc": "memory_alloc",
    "calloc": "memory_alloc",
    "realloc": "memory_alloc",
    "free": "memory_free",
    "delete": "memory_free",
    "delete[]": "memory_free",
    "new": "memory_alloc",
    "reinterpret_cast": "dangerous_cast",
    "const_cast": "dangerous_cast",
    "recv": "network_input",
    "send": "network_output",
    "accept": "network_input",
    "connect": "network_output",
    "bind": "network_input",
    "listen": "network_input",
    "inet_addr": "network_input",
    "chown": "privilege_escalation",
    "chmod": "privilege_escalation",
    "fopen": "file_operation",
    "fclose": "file_operation",
    "fread": "file_operation",
    "fwrite": "file_operation",
}


In [18]:
import json
import re

def extract_event_trace_from_edges(ast_json):
    event_trace = []
    calls = []
    frees = set()
    pointer_uses = []
    freed_ptrs = {}
    call_lines = {}
    # Compile regex for call node extraction
    call_re = re.compile(r'call\(\"([^\"]+)\"\)')
    identifier_re = re.compile(r'identifier\(\"([^\"]+)\"\)')

    # 1. Find all call nodes and process them
    for edge in ast_json.get("edges", []):
        # Check if src or dst is a function call
        for direction in ["src", "dst"]:
            node = edge[direction]
            # Extract call info
            call_match = call_re.search(node)
            if call_match:
                call_content = call_match.group(1)
                # Dangerous function call detection
                for func, vuln in DANGEROUS_CALLS.items():
                    # Match whole word function (avoid substrings)
                    if re.search(r"\b" + re.escape(func) + r"\b", call_content):
                        event_trace.append(f"{vuln}:{func}")
                # Memory management (track frees)
                if re.match(r"free\s*\(([^)]+)\)", call_content):
                    freed_var = re.match(r"free\s*\(([^)]+)\)", call_content).group(1).strip()
                    if freed_var in frees:
                        event_trace.append(f"double_free:{freed_var}")
                    frees.add(freed_var)
                    freed_ptrs[freed_var] = True
                # Detect basic pointer use (very basic)
                ptr_use = re.findall(r"\*([a-zA-Z_][a-zA-Z0-9_]*)", call_content)
                pointer_uses.extend(ptr_use)
                # Detect if-NULL or similar
                if "if" in call_content and ("null" in call_content.lower() or "nullptr" in call_content.lower() or "!=" in call_content):
                    event_trace.append("null_check")
            # Pointer identifier detection (use-after-free)
            id_match = identifier_re.search(node)
            if id_match:
                var_name = id_match.group(1)
                if var_name in freed_ptrs:
                    event_trace.append(f"use_after_free:{var_name}")

    # Remove duplicates, keep order
    seen = set()
    unique_events = []
    for evt in event_trace:
        if evt not in seen:
            seen.add(evt)
            unique_events.append(evt)
    return "; ".join(unique_events) if unique_events else "No_issue_found"


In [19]:
import pandas as pd

df = pd.read_csv("devign.csv")
df = df.iloc[:-1]


def get_ast_path(idx):
    return f"/mnt/c/Users/user01/PycharmProjects/PythonProject2/joern_outputs/sample_{idx}/ast.json"

df["ast_path"] = df.index.map(get_ast_path)

df["event_trace"] = df["ast_path"].apply(lambda p: extract_event_trace_from_edges(json.load(open(p))))


df.to_csv("devign_with_event_trace.csv", index=False)


In [25]:
import json

def extract_event_trace(graph_json):
    nodes = {str(n['id']): n for n in graph_json['nodes']}
    edges = graph_json['edges']

    # Find METHOD node (the function itself)
    method_nodes = [n for n in nodes.values() if n['label'].strip('"') == 'METHOD']
    if not method_nodes:
        return []
    method = method_nodes[0]

    # Get AST root for the function body
    ast_body_edges = [e for e in edges if e['src'] == method['id'] and e['label'] == '"AST"']
    body_id = None
    for e in ast_body_edges:
        dst = nodes.get(str(e['dst']))
        if dst and dst['label'].strip('"') == 'BLOCK':
            body_id = dst['id']
            break
    if not body_id:
        return []

    # Recursively traverse AST for this body
    trace = []
    def walk_ast(node_id):
        node = nodes[str(node_id)]
        label = node['label'].strip('"')
        code = node.get('CODE', '').strip('"')
        if label == "LOCAL":
            t_full = node['TYPE_FULL_NAME'].strip('"')
            name = node['NAME'].strip('"')
            trace.append(f"declare {t_full} {name}")
        elif label == "CALL":
            mfn = node.get('METHOD_FULL_NAME', '').strip('"')
            code_short = code.split('(')[0]
            # Assignment
            if mfn == "<operator>.assignment":
                trace.append(f"assign: {code}")
            # Function call
            elif not mfn.startswith("<operator>"):
                trace.append(f"call: {mfn} ({code})")
            # Conditional or logical operators
            elif mfn in ["<operator>.conditional", "<operator>.logicalNot", "<operator>.notEquals"]:
                trace.append(f"condition: {code}")
            # Field/pointer access
            elif mfn == "<operator>.indirectFieldAccess":
                trace.append(f"field_access: {code}")
        elif label == "IDENTIFIER":
            tname = node.get("TYPE_FULL_NAME", "")
            type_clean = tname.strip('"')
            code_clean = code
            trace.append(f"use {code_clean} as {type_clean}")
        elif label == "LITERAL":
            trace.append(f"literal: {code}")
        elif label == "FIELD_IDENTIFIER":
            field_name = node['CANONICAL_NAME'].strip('"')
            trace.append(f"field: {field_name}")

        # Add more node type cases as needed.

        # Recurse on AST children
        for e in edges:
            if e['src'] == node_id and e['label'] == '"AST"':
                walk_ast(e['dst'])

    walk_ast(body_id)
    return trace


In [36]:
import json

def extract_event_trace_from_joern(graph_json):
    """
    Extracts a detailed, stepwise event trace from a Joern graph.json file
    for a C/C++ function. Designed for conflict-driven learning (CDL),
    contrastive neuro-symbolic research, or explainable code analysis.

    Args:
        graph_json (dict): Loaded Joern graph.json object.

    Returns:
        List[str]: Stepwise event trace, formatted for symbolic learning.
    """
    nodes = {str(n['id']): n for n in graph_json['nodes']}
    edges = graph_json['edges']

    # Find the main METHOD node (function)
    method_nodes = [n for n in nodes.values() if n['label'].strip('"') == 'METHOD']
    if not method_nodes:
        raise ValueError("No METHOD node found in the graph.")
    method = method_nodes[0]

    # Find the AST root for the function body (BLOCK node)
    ast_body_edges = [e for e in edges if e['src'] == method['id'] and e['label'] == '"AST"']
    body_id = None
    for e in ast_body_edges:
        dst = nodes.get(e['dst'])
        if dst and dst['label'].strip('"') == 'BLOCK':
            body_id = dst['id']
            break
    if not body_id:
        raise ValueError("No BLOCK node (function body) found in the graph.")

    # Helper to clean field values
    def clean(field):
        return field.strip('"') if isinstance(field, str) else str(field)

    # Helper to get AST children
    def get_ast_children(node_id):
        return [e['dst'] for e in edges if e['src'] == node_id and e['label'] == '"AST"']

    trace = []

    def walk_ast(node_id, indent=0):
        node = nodes[str(node_id)]
        label = node['label'].strip('"')
        code = clean(node.get('CODE', ''))

        # --- Variable declarations ---
        if label == "LOCAL":
            trace.append("    " * indent + f"declare {clean(node['TYPE_FULL_NAME'])} {clean(node['NAME'])}")

        # --- Assignment, calls, field access, operators ---
        elif label == "CALL":
            mfn = clean(node.get('METHOD_FULL_NAME', ''))
            if mfn == "<operator>.assignment":
                trace.append("    " * indent + f"assign: {code}")
            elif mfn == "<operator>.indirectFieldAccess":
                trace.append("    " * indent + f"field_access: {code}")
            elif mfn in ["<operator>.conditional", "<operator>.logicalNot", "<operator>.notEquals", "<operator>.equals", "<operator>.and", "<operator>.or"]:
                trace.append("    " * indent + f"condition: {code}")
            elif mfn.startswith("<operator>"):
                trace.append("    " * indent + f"operator: {mfn} ({code})")
            else:  # Regular function call
                trace.append("    " * indent + f"call: {mfn} ({code})")

        # --- Variable usage ---
        elif label == "IDENTIFIER":
            tname = clean(node.get("TYPE_FULL_NAME", ""))
            trace.append("    " * indent + f"use {code} as {tname}")

        # --- Constants/literals ---
        elif label == "LITERAL":
            trace.append("    " * indent + f"literal: {code}")

        # --- Field identifiers ---
        elif label == "FIELD_IDENTIFIER":
            field_name = clean(node.get('CANONICAL_NAME', ''))
            trace.append("    " * indent + f"field: {field_name}")

        # --- Control structures: if, else, loops, etc ---
        elif label == "CONTROL_STRUCTURE":
            cs_code = code.lower()
            children = get_ast_children(node_id)
            if cs_code.startswith('if'):
                # Try to extract the condition node and then/else blocks
                if len(children) >= 2:
                    cond_id = children[0]
                    then_id = children[1]
                    cond_code = nodes[str(cond_id)].get('CODE', '').strip()
                    trace.append("    " * indent + f"if: {cond_code}")
                    walk_ast(then_id, indent + 1)
                    # Check for 'else' block
                    if len(children) > 2:
                        else_id = children[2]
                        trace.append("    " * indent + "else")
                        walk_ast(else_id, indent + 1)
                else:
                    trace.append("    " * indent + f"if: {code}")
            elif cs_code.startswith('else if'):
                trace.append("    " * indent + f"else if: {code}")
            elif cs_code.startswith('else'):
                trace.append("    " * indent + "else")
                if children:
                    walk_ast(children[0], indent + 1)
            elif cs_code.startswith('for'):
                trace.append("    " * indent + f"for: {code}")
                if children:
                    walk_ast(children[-1], indent + 1)
            elif cs_code.startswith('while'):
                trace.append("    " * indent + f"while: {code}")
                if children:
                    walk_ast(children[-1], indent + 1)
            elif cs_code.startswith('switch'):
                trace.append("    " * indent + f"switch: {code}")
                if children:
                    for cid in children[1:]:
                        walk_ast(cid, indent + 1)
            elif cs_code.startswith('return'):
                trace.append("    " * indent + f"return: {code}")
            else:
                trace.append("    " * indent + f"control: {code}")

        # --- Return statements ---
        elif label == "RETURN":
            trace.append("    " * indent + f"return: {code}")

        # --- Recursively process children (except CONTROL_STRUCTURE, which we handled above) ---
        if label not in ["CONTROL_STRUCTURE"]:
            for child_id in get_ast_children(node_id):
                walk_ast(child_id, indent)

    walk_ast(body_id)
    return trace


In [38]:
import pandas as pd

df = pd.read_csv("devign.csv")
df = df.iloc[:-1]


def get_graph_path(idx):
    return f"/mnt/c/Users/user01/PycharmProjects/PythonProject2/joern_outputs/sample_{idx}/json/graph.json"

df["graph_path"] = df.index.map(get_graph_path)

df["event_trace"] = df["graph_path"].apply(lambda p: extract_event_trace_from_joern(json.load(open(p))))


df.to_csv("devign_with_event_trace_graph.csv", index=False,quoting=1)

In [39]:
df['event_trace'].iloc[14000]

['declare uint32_t sctlr',
 'assign: sctlr = A32_BANKED_CURRENT_REG_GET(env, sctlr)',
 'use sctlr as uint32_t',
 'call: A32_BANKED_CURRENT_REG_GET (A32_BANKED_CURRENT_REG_GET(env, sctlr))',
 'use env as CPUARMState*',
 'use sctlr as uint32_t',
 'if: "address < 0x02000000"',
 '    operator: <operator>.assignmentPlus (address += env->cp15.c13_fcse)',
 '    use address as target_ulong',
 '    operator: <operator>.fieldAccess (env->cp15.c13_fcse)',
 '    field_access: env->cp15',
 '    use env as CPUARMState*',
 '    field: cp15',
 '    field: c13_fcse',
 'if: "(sctlr & SCTLR_M) == 0"',
 '    assign: *phys_ptr = address',
 '    operator: <operator>.indirection (*phys_ptr)',
 '    use phys_ptr as hwaddr*',
 '    use address as target_ulong',
 '    assign: *prot = PAGE_READ | PAGE_WRITE | PAGE_EXEC',
 '    operator: <operator>.indirection (*prot)',
 '    use prot as int*',
 '    condition: PAGE_READ | PAGE_WRITE | PAGE_EXEC',
 '    condition: PAGE_READ | PAGE_WRITE',
 '    use <unknown> PAGE

Fine-tune

In [41]:
import pandas as pd


def build_input(row):
    trace_str = "; ".join(row['event_trace']) if isinstance(row['event_trace'], list) else str(row['event_trace'])
    return row['func'] + " [SEP] " + trace_str

df['input'] = df.apply(build_input, axis=1)


In [42]:
from sklearn.model_selection import train_test_split

train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['target'])


train_df.to_csv("cdl_train.csv", index=False)
test_df.to_csv("cdl_test.csv", index=False)



In [43]:
import pandas as pd
from sklearn.model_selection import train_test_split


df = df.rename(columns={"target": "label"})
# Split
train_df, test_df = train_test_split(df, test_size=0.2, random_state=42, stratify=df['label'])

# Save splits
train_df.to_csv("cdl_train.csv", index=False)
test_df.to_csv("cdl_test.csv", index=False)




In [52]:
train_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11224 entries, 0 to 11223
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   func          11224 non-null  object
 1   func_cleaned  11224 non-null  object
 2   project       11224 non-null  object
 3   label         11224 non-null  int64 
 4   graph_path    11224 non-null  object
 5   event_trace   11224 non-null  object
 6   input         11224 non-null  object
dtypes: int64(1), object(6)
memory usage: 613.9+ KB


In [51]:


test_df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 11224 entries, 0 to 11223
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   func          11224 non-null  object
 1   func_cleaned  11224 non-null  object
 2   project       11224 non-null  object
 3   label         11224 non-null  int64 
 4   graph_path    11224 non-null  object
 5   event_trace   11224 non-null  object
 6   input         11224 non-null  object
dtypes: int64(1), object(6)
memory usage: 613.9+ KB
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 2807 entries, 0 to 2806
Data columns (total 7 columns):
 #   Column        Non-Null Count  Dtype 
---  ------        --------------  ----- 
 0   func          2807 non-null   object
 1   func_cleaned  2807 non-null   object
 2   project       2807 non-null   object
 3   label         2807 non-null   int64 
 4   graph_path    2807 non-null   object
 5   event_trace   2807 non-null   object
 6   input         2

In [50]:
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, matthews_corrcoef,
    cohen_kappa_score, mean_squared_error,
    mean_absolute_error, roc_auc_score
)
import torch
import numpy as np
# Load splits
train_df = pd.read_csv("cdl_train.csv")
test_df = pd.read_csv("cdl_test.csv")
# Before creating the HuggingFace Dataset
train_df['label'] = train_df['label'].astype(int)
test_df['label'] = test_df['label'].astype(int)

# Use input and label columns
hf_train = Dataset.from_pandas(train_df[['input', 'label']])
hf_test = Dataset.from_pandas(test_df[['input', 'label']])

# Use CodeBERT (recommended for code classification)
model_name = "microsoft/codebert-base"
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForSequenceClassification.from_pretrained(model_name, num_labels=2)

def tokenize_function(example):
    return tokenizer(example["input"], truncation=True, padding="max_length", max_length=256)

tokenized_train = hf_train.map(tokenize_function, batched=True)
tokenized_test = hf_test.map(tokenize_function, batched=True)

tokenized_train.set_format("torch", columns=["input_ids", "attention_mask", "label"])
tokenized_test.set_format("torch", columns=["input_ids", "attention_mask", "label"])


def compute_metrics(eval_pred):
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    tn, fp, fn, tp = confusion_matrix(labels, preds).ravel()
    probs = torch.softmax(torch.tensor(logits), dim=1).numpy()[:, 1]
    try:
        auc = roc_auc_score(labels, probs)
    except ValueError:
        auc = float('nan')
    return {
        "accuracy":    accuracy_score(labels, preds),
        "precision":   precision_score(labels, preds, zero_division=0),
        "recall":      recall_score(labels, preds, zero_division=0),
        "specificity": tn / (tn + fp) if (tn + fp) > 0 else 0,
        "fpr":         fp / (fp + tn) if (fp + tn) > 0 else 0,
        "f1":          f1_score(labels, preds, zero_division=0),
        "mcc":         matthews_corrcoef(labels, preds),
        "kappa":       cohen_kappa_score(labels, preds),
        "mse":         mean_squared_error(labels, preds),
        "mae":         mean_absolute_error(labels, preds),
        "auc":         auc
    }


training_args = TrainingArguments(
    output_dir="./cdl_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    learning_rate=2e-5,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    per_device_eval_batch_size=4,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",  # <-- must match key in your compute_metrics dict!
    logging_steps=10,
)

from transformers import Trainer

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_train,
    eval_dataset=tokenized_test,
    tokenizer=tokenizer,
    compute_metrics=compute_metrics,   # <-- this is the key part!
)
trainer.train()


Some weights of RobertaForSequenceClassification were not initialized from the model checkpoint at microsoft/codebert-base and are newly initialized: ['classifier.dense.bias', 'classifier.dense.weight', 'classifier.out_proj.bias', 'classifier.out_proj.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 11224/11224 [00:08<00:00, 1313.96 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2807/2807 [00:01<00:00, 1450.38 examples/s]
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
  trainer = Trainer(


Epoch,Training Loss,Validation Loss,Accuracy,Precision,Recall,Specificity,Fpr,F1,Mcc,Kappa,Mse,Mae,Auc
1,0.7216,0.68471,0.572497,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.427503,0.427503,0.482588
2,0.6954,0.682603,0.572497,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.427503,0.427503,0.487289
3,0.6885,0.682873,0.572497,0.0,0.0,1.0,0.0,0.0,0.0,0.0,0.427503,0.427503,0.48498


TrainOutput(global_step=8418, training_loss=0.6882378701558728, metrics={'train_runtime': 1124.8673, 'train_samples_per_second': 29.934, 'train_steps_per_second': 7.484, 'total_flos': 4429737728040960.0, 'train_loss': 0.6882378701558728, 'epoch': 3.0})

In [5]:
import pandas as pd
from datasets import Dataset
from transformers import AutoTokenizer, AutoModelForSequenceClassification, Trainer, TrainingArguments
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score,
    f1_score, confusion_matrix, matthews_corrcoef,
    cohen_kappa_score, mean_squared_error,
    mean_absolute_error, roc_auc_score
)
import torch
import numpy as np
# Load splits
train_df = pd.read_csv("cdl_train.csv")
test_df = pd.read_csv("cdl_test.csv")
# Before creating the HuggingFace Dataset
train_df['label'] = train_df['label'].astype(int)
test_df['label'] = test_df['label'].astype(int)
train_df['label'] = train_df['label'].astype(str)
test_df['label'] = test_df['label'].astype(str)

# Optionally reduce dataset for quick test
# train_df = train_df.sample(500)
# test_df = test_df.sample(100)

train_dataset = Dataset.from_pandas(train_df[['input', 'label']])
test_dataset = Dataset.from_pandas(test_df[['input', 'label']])

In [7]:
def format_example(example):
    # Format as prompt and response, for each example
    return {
        "prompt": f"<s>Code:\n{example['input']}\nLabel:",
        "completion": f" {example['label']}</s>"
    }

train_dataset = train_dataset.map(format_example)
test_dataset = test_dataset.map(format_example)


Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 11224/11224 [00:00<00:00, 11772.21 examples/s]
Map: 100%|â–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆâ–ˆ| 2807/2807 [00:00<00:00, 15612.53 examples/s]


Conflict

In [40]:
def identify_conflicts(event_trace, predicted_label):
    """
    Given an event trace (list of symbolic actions) and the model's predicted label,
    return (conflict: bool, reasons: list of explanations).
    """
    reasons = []
    conflict = False

    # Unsafe pointer assignment with no prior check
    for i, event in enumerate(event_trace):
        if event.startswith("assign: *"):
            # Look back for any if/condition/guard in the last few events
            guarded = any(event_trace[j].startswith('if:') or 'check' in event_trace[j]
                          for j in range(max(0, i-5), i))
            if not guarded and predicted_label == 0:
                conflict = True
                reasons.append(f"Unsafe pointer assignment '{event}' not preceded by a safety check, but model predicted not vulnerable.")

    # Use of <unknown> types/constants
    for event in event_trace:
        if 'use <unknown>' in event and predicted_label == 0:
            conflict = True
            reasons.append(f"Use of unknown/untrusted type in '{event}' but model predicted not vulnerable.")

    # Direct field access without checks
    for event in event_trace:
        if "field_access" in event and predicted_label == 0:
            conflict = True
            reasons.append(f"Direct field access '{event}' could be unsafe, but model predicted not vulnerable.")

    # Return after unsafe op
    for i, event in enumerate(event_trace):
        if event.startswith("return:") and predicted_label == 0:
            # Check for unsafe ops in the previous few events
            unsafe_ops = any("assign: *" in e for e in event_trace[max(0, i-5):i])
            if unsafe_ops:
                conflict = True
                reasons.append(f"Returning after unsafe op before '{event}', but model predicted not vulnerable.")

    # Unsafe operation inside an else branch
    for i, event in enumerate(event_trace):
        if event == "else" and predicted_label == 0:
            possible_unsafe = any("assign: *" in e for e in event_trace[i:i+5])
            if possible_unsafe:
                conflict = True
                reasons.append("Possible unsafe write in else-branch, but model predicted not vulnerable.")

    return conflict, reasons




In [None]:
def conflict_wrapper(row):
    conflict, reasons = identify_conflicts(row['event_trace'], row['predicted_label'])
    return pd.Series({
        "conflict": conflict,
        "conflict_reasons": reasons
    })

df[['conflict', 'conflict_reasons']] = df.apply(conflict_wrapper, axis=1)
