In [1]:
import json
import re
from pathlib import Path
from typing import List

import torch
from tqdm import tqdm
from transformers import AutoModelForCausalLM, AutoTokenizer

In [2]:
# Every dataset uses the same layout: <Name>_tf_qa/<name lowercased>_questions.json
QUESTION_FILES = {
    name: Path(f"{name}_tf_qa") / f"{name.lower()}_questions.json"
    for name in ("DoS", "Fuzzy", "Gear", "RPM")
}
# Run every configured dataset, in insertion order; edit this list to run a subset.
SELECTED_DATASETS = list(QUESTION_FILES)


In [None]:
# Switch models by editing MODEL_ID (and TRUST_REMOTE_CODE for checkpoints that
# ship custom code, as the Qwen3 / DeepSeek variants below were loaded).
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
TRUST_REMOTE_CODE = False
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# Filesystem-safe tag used in answer file names, e.g. "Llama_3_1_8B_Instruct".
MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")


def load_model_and_tokenizer(model_id: str, dtype: torch.dtype, trust_remote_code: bool = False):
    """Load a causal-LM checkpoint and its tokenizer, configured for left-padded generation.

    Args:
        model_id: Hugging Face model id or local path.
        dtype: Weight dtype passed to ``from_pretrained``.
        trust_remote_code: Forwarded to both ``from_pretrained`` calls.

    Returns:
        ``(tokenizer, model)`` tuple.
    """
    tok = AutoTokenizer.from_pretrained(model_id, use_fast=True, trust_remote_code=trust_remote_code)
    # Some checkpoints ship without a pad token; reuse EOS so padding is defined.
    if tok.pad_token is None:
        tok.pad_token = tok.eos_token
    # Decoder-only generation expects padding on the left.
    tok.padding_side = "left"

    mdl = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        pad_token_id=tok.pad_token_id,
        trust_remote_code=trust_remote_code,
    )
    return tok, mdl


tokenizer, model = load_model_and_tokenizer(MODEL_ID, DTYPE, trust_remote_code=TRUST_REMOTE_CODE)
# Log the active model so saved outputs always show which checkpoint produced them.
print(f"[INFO] loaded {MODEL_ID} (tag: {MODEL_TAG}, dtype: {DTYPE})")

In [4]:
# MODEL_ID = "Qwen/Qwen3-4B-Thinking-2507"  # OR "Qwen/Qwen3-4B-Instruct-2507"
# DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")

# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"

# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=DTYPE,
#     device_map="auto",
#     pad_token_id=tokenizer.pad_token_id,
#     trust_remote_code=True,
# )

In [None]:
# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")

# tokenizer = AutoTokenizer.from_pretrained(
#     MODEL_ID,
#     use_fast=True,
#     trust_remote_code=True,
# )
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"

# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=DTYPE,
#     device_map="auto",
#     pad_token_id=tokenizer.pad_token_id,
#     trust_remote_code=True,
# )


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
def build_prompt_text(context: str, question: str, context_type: str = "full") -> str:
    """Render the chat-formatted prompt string for one True/False question.

    Args:
        context: CAN-bus window text placed verbatim in the user turn.
        question: Statement the model must judge.
        context_type: ``"hidden_last_flag"`` or ``"hidden_middle_flag"`` adds a
            one-line hint about the masked frame; any other value adds nothing.

    Returns:
        The untokenized prompt produced by the module-level ``tokenizer``'s chat
        template, ending with the assistant generation prompt.
    """
    hint_table = (
        ("hidden_last_flag",
         "Use the preceding frames to infer the hidden flag of the final frame.\n"),
        ("hidden_middle_flag",
         "Use the frames before and after the hidden middle frame to infer its flag.\n"),
    )
    hint = next((text for key, text in hint_table if context_type == key), "")

    system_content = (
        "You are a CAN bus intrusion-detection analyst. Study timestamp ordering, ID frequency, "
        "payload stability, byte ranges, and gaps between frames. Use those characteristics plus the "
        "statement to reach a True/False conclusion. Respond with True or False only."
    )
    user_content = (
        "Below is a CAN bus time window. Review the sequence carefully, note anomalies or missing IDs, "
        "and reason about the claim.\n"
        f"{context}\n\n{hint}"
        f"Statement: {question}\nAnswer True or False only."
    )

    return tokenizer.apply_chat_template(
        [
            {"role": "system", "content": system_content},
            {"role": "user", "content": user_content},
        ],
        add_generation_prompt=True,
        tokenize=False,
    )


def normalize_tf_answer(text: str, default: str = "False") -> str:
    """Map a raw model completion onto the labels ``"True"`` / ``"False"``.

    Parsing rules, in order:
      1. If the completion contains ``</think>`` (reasoning models such as the
         DeepSeek-R1 distill tried in this notebook may emit one), only the text
         after the last marker is considered.
      2. Leading punctuation/markdown is skipped, along with an optional
         ``"Answer:"`` prefix. The next word must then be ``true``, ``false``,
         ``t`` or ``f`` (case-insensitive). ``"**True**"``, ``"True."`` and
         ``"Answer: False"`` therefore parse. Words that merely start with t/f
         (e.g. ``"The ..."``) no longer count as True.
      3. Anything else, including empty or whitespace-only text, yields ``default``.

    Args:
        text: Decoded completion; may be empty or ``None``.
        default: Label returned when no answer can be parsed. Defaults to
            ``"False"`` for backward compatibility; note this folds unparseable
            output into the "False" class.

    Returns:
        ``"True"``, ``"False"``, or ``default``.
    """
    if not text:
        return default
    if "</think>" in text:
        # Keep only the final answer that follows the reasoning trace.
        text = text.rsplit("</think>", 1)[-1]
    match = re.match(r"\W*(?:answer\W*)?(true|false|t|f)\b", text, flags=re.IGNORECASE)
    if match is None:
        return default
    return "True" if match.group(1)[0].lower() == "t" else "False"


def query_llm(context: str, question: str, context_type: str = "full",
              max_new_tokens: int = 16) -> str:
    """Ask the loaded model one True/False question and return the parsed label.

    Decodes greedily with the module-level ``tokenizer`` and ``model``.

    Args:
        context: CAN-bus window text inserted into the prompt.
        question: Statement the model must judge.
        context_type: Forwarded to ``build_prompt_text`` to select an extra hint.
        max_new_tokens: Generation budget. 16 tokens suffices for a one-word
            answer from an instruct model. A reasoning model that writes a long
            thinking trace first will likely be cut off before answering and fall
            through to ``normalize_tf_answer``'s default.

    Returns:
        The label produced by ``normalize_tf_answer`` for the decoded completion.
    """
    prompt_text = build_prompt_text(context, question, context_type=context_type)
    # The chat template already renders any BOS token (e.g. Llama 3's
    # <|begin_of_text|>); letting the tokenizer add special tokens again would
    # duplicate it.
    inputs = tokenizer(prompt_text, return_tensors="pt", add_special_tokens=False).to(model.device)
    # Match the pad id the model was configured with at load time.
    pad_id = tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            # Explicit None overrides sampling defaults that a checkpoint's
            # generation_config may set, since decoding here is greedy.
            temperature=None,
            top_p=None,
            top_k=None,
            pad_token_id=pad_id,
        )

    # Decode only the newly generated tokens, not the echoed prompt.
    prompt_len = inputs["input_ids"].shape[-1]
    completion = tokenizer.decode(outputs[0][prompt_len:], skip_special_tokens=True).strip()
    return normalize_tf_answer(completion)


def load_questions(path: Path) -> List[dict]:
    """Read question records from ``path``.

    A ``.json`` file is parsed as one JSON document (presumably a list of
    records). Any other suffix is read as JSON Lines: one record per non-blank
    line, with surrounding whitespace ignored. A missing file yields ``[]``
    instead of raising.
    """
    if not path.exists():
        return []
    with path.open("r", encoding="utf-8") as handle:
        if path.suffix == ".json":
            return json.load(handle)
        return [json.loads(raw) for raw in map(str.strip, handle) if raw]


In [7]:
# Answer every question of every selected dataset and write one JSONL record per question.
for ds_name in SELECTED_DATASETS:
    q_path = Path(QUESTION_FILES[ds_name])
    if not q_path.exists():
        print(f"[WARN] Questions for {ds_name} not found at {q_path}, skip.")
        continue

    questions = load_questions(q_path)
    print(f"[INFO] {ds_name}: loaded {len(questions)} questions.")

    # Answers go next to the question file; MODEL_TAG keeps runs of different models apart.
    out_dir = q_path.parent
    ans_path = out_dir / f"{ds_name.lower()}_answers_{MODEL_TAG}.jsonl"

    # Mode "w" truncates and opens in one step. Each record is flushed so a
    # multi-hour run that dies midway still leaves the answers produced so far.
    with ans_path.open("w", encoding="utf-8") as f:
        for rec in tqdm(questions, desc=f"{ds_name} answering"):
            # Read every required field before the slow model call so a
            # malformed record fails fast.
            qa_id = rec["qa_id"]
            # Fall back to the loop's dataset name when metadata is missing or null.
            dataset = (rec.get("metadata") or {}).get("dataset", ds_name)
            ctx_type = rec.get("context_type", "full")
            ctx = rec["context"]
            q_text = rec["question"]
            gt = rec["ground_truth"]

            pred = query_llm(ctx, q_text, context_type=ctx_type)
            answer_rec = {
                "qa_id": qa_id,
                "dataset": dataset,
                "model": MODEL_ID,
                "llm_answer": pred,
                "ground_truth": gt,
                # str() lets JSON booleans (true/false) match the "True"/"False"
                # labels; string labels compare exactly as before.
                "is_correct": pred == str(gt),
                # NOTE(review): normalize_tf_answer falls back to "False" for
                # unparseable output, so this is currently always True and does
                # not flag invalid answers. Confirm the intended semantics.
                "answer_valid": pred in ("True", "False"),
            }
            f.write(json.dumps(answer_rec, ensure_ascii=False) + "\n")
            f.flush()

    print(f"[INFO] {ds_name}: answers saved to {ans_path}")


[INFO] DoS: loaded 5496 questions.


DoS answering: 100%|██████████| 5496/5496 [1:59:37<00:00,  1.31s/it]


[INFO] DoS: answers saved to DoS_tf_qa/dos_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] Fuzzy: loaded 5757 questions.


Fuzzy answering: 100%|██████████| 5757/5757 [2:05:01<00:00,  1.30s/it]  


[INFO] Fuzzy: answers saved to Fuzzy_tf_qa/fuzzy_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] Gear: loaded 6663 questions.


Gear answering: 100%|██████████| 6663/6663 [2:24:58<00:00,  1.31s/it]  


[INFO] Gear: answers saved to Gear_tf_qa/gear_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] RPM: loaded 6930 questions.


RPM answering: 100%|██████████| 6930/6930 [2:30:37<00:00,  1.30s/it]  

[INFO] RPM: answers saved to RPM_tf_qa/rpm_answers_DeepSeek_R1_Distill_Llama_8B.jsonl



