In [6]:
# %% [markdown] ---------------------------------------------------------------------------------
# # 🚦 JLR Hackathon – Static Scan Demo
# Upload (or point me at) a C/C++ file; I’ll flag potential vulnerabilities,
# attach nearest CWE examples, and rank them by severity.

# %% [code] 0 Imports & paths -------------------------------------------------------------------
import pathlib, json, re, time, warnings
import pandas as pd, numpy as np, torch, onnxruntime as ort, faiss
from transformers import AutoTokenizer, AutoModel, AutoModelForSequenceClassification
import safetensors.torch as st
import torch.nn as nn
from src.eval_demo import identify_cwe# ←❶ needed for CodeBERTRegressor
from IPython.display import display

warnings.filterwarnings("ignore")

# artefacts
ONNX_PATH  = pathlib.Path("../models/quantised/codebert_int8.onnx")
INDEX_PATH = pathlib.Path("../data/embeddings/embeddings.faiss")
OUT_PATH   = pathlib.Path("../data/results/result.json")
META_PATH  = pathlib.Path("../data/embeddings/embeddings.jsonl")
REG_DIR    = pathlib.Path("../models/cvss_regressor")          # regressor folder
THRESHOLD  = 0                                            # p_vuln cut-off
DEVICE = "cpu"
# SRC_PATH = pathlib.Path("../data/test/codebert_safe_test_file.c")
SRC_PATH = pathlib.Path("../data/test/vuln_7_CWE-129.c")

tok    = AutoTokenizer.from_pretrained("microsoft/codebert-base")

# %% [code] 1 Helper – slice C/C++ file into functions ---------------------------------------
FUNC_RE = re.compile(r"^[A-Za-z_][\w\s\*]*\s+([A-Za-z_]\w*)\s*\([^;]*\)\s*\{", re.M)

def extract_functions(text: str):
    stack, start, name = [], None, None
    for m in re.finditer(r"[\{\}]", text):
        if m.group() == "{":
            if not stack:
                header = text.rfind("\n", 0, m.start())
                mfunc  = FUNC_RE.search(text[header:m.start()])
                if mfunc:
                    start, name = header, mfunc.group(1)
            stack.append("{")
        else:
            if stack: stack.pop()
            if not stack and start is not None:
                yield name or "anon", text[start:m.end()]
                start = name = None

# %% [code] 2 Models – INT8 classifier + optional regressor -------------------------------
print("Loading INT-8 classifier …")
sess = ort.InferenceSession(str(ONNX_PATH), providers=["CPUExecutionProvider"])

# fp32 classifier for embedding → FAISS
hf = AutoModelForSequenceClassification.from_pretrained(
        "../models/quantised/fp32", num_labels=2
     ).to(DEVICE).eval()                    # ←❶ new: used for CLS embedding

# optional CVSS regressor --------------------------------------------------
class CodeBERTRegressor(nn.Module):
    def __init__(self, base="microsoft/codebert-base"):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(base)
        for p in self.encoder.parameters():
            p.requires_grad = False
        self.mlp = nn.Sequential(
            nn.Linear(768, 256),
            nn.ReLU(),
            nn.Linear(256, 1)
        )
    def forward(self, input_ids=None, attention_mask=None):
        cls = self.encoder(input_ids=input_ids,
                           attention_mask=attention_mask).last_hidden_state[:, 0, :]
        return {"logits": self.mlp(cls).squeeze(-1)}

regressor = None
if (REG_DIR / "model.safetensors").exists():
    regressor = CodeBERTRegressor()
    state = st.load_file(REG_DIR / "model.safetensors", device="cpu")
    regressor.load_state_dict(state); regressor.eval()

# %% [code] 3 FAISS index --------------------------------------------------------------
index = faiss.read_index(str(INDEX_PATH))
with META_PATH.open() as f:
    meta = [json.loads(l) for l in f]


code_text = ""
with open(SRC_PATH, 'r', encoding='utf-8') as f:
    code_text = f.read()# change as needed
results = identify_cwe(code_text)

code_text = SRC_PATH.read_text(encoding="utf-8", errors="ignore")
functions = list(extract_functions(code_text))
# print(f"Found {len(functions)} functions in {SRC_PATH.name}")
# %% [code] 5 Run scan ---------------------------------------------------
records = []
for func_name, snippet in functions:
    ids_np = tok(snippet, return_tensors="np", max_length=512,
                 truncation=True, padding="max_length")
    feed = {
        "input_ids":      ids_np["input_ids"],
        "attention_mask": ids_np["attention_mask"],
        "token_type_ids": np.zeros_like(ids_np["input_ids"]),
    }
    logits = sess.run(None, feed)[0]                      # (1,2)
    p_vuln = float(torch.softmax(torch.tensor(logits), -1)[0, 1])

    if p_vuln < THRESHOLD:
        continue

    # severity ----------------------------------------------------------
    cvss_pred = None
    if regressor:
        ids_pt = tok(snippet, return_tensors="pt", max_length=512,
                     truncation=True, padding="max_length")
        with torch.no_grad():
            cvss_pred = float(regressor(**ids_pt)["logits"][0])

    # nearest CWE via FAISS --------------------------------------------
    with torch.no_grad():
        cls_vec = hf.base_model(
            input_ids=torch.tensor(ids_np["input_ids"]),
            attention_mask=torch.tensor(ids_np["attention_mask"])
        ).last_hidden_state[:, 0, :].cpu().numpy().astype("float32")
    faiss.normalize_L2(cls_vec)
    D, I = index.search(cls_vec, k=3)
    examples = [meta[i]["cwe_id"] for i in I[0]]

    records.append({
        "Function":   func_name,
        "p_vuln":     round(p_vuln, 3),
        "CVSS_pred":  round(cvss_pred, 1) if cvss_pred else None,
        "Nearest_CWE": ", ".join(examples),
        "Snippet":    snippet[:120] + "…"
    })

# %% [code] 6 Display ----------------------------------------------------
df = pd.DataFrame([results])

# Display results
if df.empty or not df.at[0, 'is_vulnerable']:
    print("🎉  No vulnerabilities detected in the provided code.")
else:
    df = df.sort_values("is_vulnerable", ascending=False)
    display(df.style.bar(subset=["is_vulnerable"], color= "#ffa07a"))

# Save report
STAMP = time.strftime("%Y%m%d_%H%M%S")
report_path = pathlib.Path(f"scan_report_{STAMP}.json")
df.to_json(OUT_PATH, orient="records", indent=2)
print("✅  Report written →", report_path)


Loading INT-8 classifier …
🎉  No vulnerabilities detected in the provided code.
✅  Report written → scan_report_20250612_201222.json
