In [1]:
import json
from pathlib import Path
from typing import List, Dict
from tqdm import tqdm

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

In [None]:
MCQ_QUESTION_FILES = {
    "DoS":   Path("DoS_mcq_qa/dos_mcq_questions.json"),
    "Fuzzy": Path("Fuzzy_mcq_qa/fuzzy_mcq_questions.json"),
    "Gear":  Path("Gear_mcq_qa/gear_mcq_questions.json"),
    "RPM":   Path("RPM_mcq_qa/rpm_mcq_questions.json"),
}

SELECTED_DATASETS = ["DoS", "Fuzzy", "Gear", "RPM"]


In [3]:
MODEL_ID = "meta-llama/Llama-3.1-8B-Instruct"
DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True)
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
tokenizer.padding_side = "left"

model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=DTYPE,
    device_map="auto",
    pad_token_id=tokenizer.pad_token_id,
)

In [4]:
# # Example: Qwen 3-4B Thinking
# MODEL_ID = "Qwen/Qwen3-4B-Thinking-2507"  # or "meta-llama/Meta-Llama-3.1-8B-Instruct"
# DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")

# tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, use_fast=True, trust_remote_code=True)
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"

# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=DTYPE,
#     device_map="auto",
#     pad_token_id=tokenizer.pad_token_id,
#     trust_remote_code=True,
# )


In [5]:
# MODEL_ID = "deepseek-ai/DeepSeek-R1-Distill-Llama-8B"
# DTYPE = torch.bfloat16 if torch.cuda.is_available() else torch.float32
# MODEL_TAG = MODEL_ID.split("/")[-1].replace(".", "_").replace("-", "_")

# tokenizer = AutoTokenizer.from_pretrained(
#     MODEL_ID,
#     use_fast=True,
#     trust_remote_code=True,
# )
# if tokenizer.pad_token is None:
#     tokenizer.pad_token = tokenizer.eos_token
# tokenizer.padding_side = "left"

# model = AutoModelForCausalLM.from_pretrained(
#     MODEL_ID,
#     torch_dtype=DTYPE,
#     device_map="auto",
#     pad_token_id=tokenizer.pad_token_id,
#     trust_remote_code=True,
# )


Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [6]:
def build_mcq_prompt_text(context: str,
                          question: str,
                          options: Dict[str, str]) -> str:
    """
    Build a chat-style prompt for MCQ answering.
    Model is expected to reply with a single letter: A/B/C/D.
    """
    # Ensure options are ordered A-D
    ordered_letters = ["A", "B", "C", "D"]
    lines = []
    for letter in ordered_letters:
        if letter in options:
            lines.append(f"{letter}. {options[letter]}")
    options_block = "\n".join(lines)

    system_prompt = (
        "You are a CAN bus intrusion-detection analyst. "
        "Given a CAN bus time window and a multiple-choice question about attack behavior, "
        "ID frequency, timing, payload, or flag patterns, carefully reason and choose the single best answer. "
        "Respond with only the option letter: A, B, C, or D."
    )

    user_prompt = (
        "Below is a CAN bus time window. Review the sequence carefully, note anomalies or missing IDs, "
        "and then answer the multiple-choice question.\n\n"
        f"{context}\n\n"
        f"Question: {question}\n\n"
        f"Options:\n{options_block}\n\n"
        "Answer with only a single letter: A, B, C, or D."
    )

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    return tokenizer.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=False,
    )


def normalize_mcq_answer(text: str) -> str:
    """
    Normalize raw model text to one of {'A','B','C','D'}.
    """
    if not text:
        return "A"
    s = text.strip()
    # First non-whitespace character
    first = s[0].upper()
    if first in {"A", "B", "C", "D"}:
        return first
    # Fallback: scan the string for A-D
    for ch in s.upper():
        if ch in {"A", "B", "C", "D"}:
            return ch
    # Default fallback
    return "A"


def query_mcq_llm(context: str,
                  question: str,
                  options: Dict[str, str],
                  max_new_tokens: int = 8) -> str:
    prompt_text = build_mcq_prompt_text(context, question, options)
    inputs = tokenizer(prompt_text, return_tensors="pt").to(model.device)
    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            temperature=None,
            top_p=None,
            top_k=None,
            pad_token_id=tokenizer.eos_token_id,
        )
    completion = tokenizer.decode(
        outputs[0][inputs["input_ids"].shape[-1]:],
        skip_special_tokens=True,
    ).strip()
    return normalize_mcq_answer(completion)


def load_questions(path: Path) -> List[dict]:
    """
    Load MCQ questions from .json (list) or .jsonl.
    """
    if not path.exists():
        return []
    if path.suffix == ".json":
        with path.open("r", encoding="utf-8") as f:
            return json.load(f)
    # jsonl fallback
    records = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            records.append(json.loads(line))
    return records


In [7]:
for ds_name in SELECTED_DATASETS:
    q_path = MCQ_QUESTION_FILES[ds_name]
    if not q_path.exists():
        print(f"[WARN] MCQ questions for {ds_name} not found at {q_path}, skip.")
        continue

    questions = load_questions(q_path)
    print(f"[INFO] {ds_name}: loaded {len(questions)} MCQ questions.")

    out_dir = q_path.parent
    ans_path = out_dir / f"{ds_name.lower()}_mcq_answers_{MODEL_TAG}.jsonl"

    # Truncate old answer file
    ans_path.write_text("", encoding="utf-8")

    with ans_path.open("a", encoding="utf-8") as f:
        for rec in tqdm(questions, desc=f"{ds_name} MCQ answering"):
            qa_id = rec.get("qa_id")
            metadata = rec.get("metadata", {})
            dataset = metadata.get("dataset", ds_name)
            context = rec["context"]
            question = rec["question"]
            options = rec["options"]
            gt = rec["answer"]  # correct option letter, e.g. "B"

            pred = query_mcq_llm(context, question, options)
            answer_rec = {
                "qa_id": qa_id,
                "dataset": dataset,
                "mcq_type": rec.get("mcq_type"),
                "model": MODEL_ID,
                "llm_answer": pred,
                "ground_truth": gt,
                "is_correct": pred == gt,
                "answer_valid": pred in {"A", "B", "C", "D"},
            }
            f.write(json.dumps(answer_rec, ensure_ascii=False) + "\n")

    print(f"[INFO] {ds_name}: MCQ answers saved to {ans_path}")


[INFO] DoS: loaded 5496 MCQ questions.


DoS MCQ answering: 100%|██████████| 5496/5496 [1:36:07<00:00,  1.05s/it]


[INFO] DoS: MCQ answers saved to DoS_mcq_qa/dos_mcq_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] Fuzzy: loaded 5757 MCQ questions.


Fuzzy MCQ answering: 100%|██████████| 5757/5757 [1:40:45<00:00,  1.05s/it]


[INFO] Fuzzy: MCQ answers saved to Fuzzy_mcq_qa/fuzzy_mcq_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] Gear: loaded 6663 MCQ questions.


Gear MCQ answering: 100%|██████████| 6663/6663 [1:56:42<00:00,  1.05s/it]


[INFO] Gear: MCQ answers saved to Gear_mcq_qa/gear_mcq_answers_DeepSeek_R1_Distill_Llama_8B.jsonl
[INFO] RPM: loaded 6930 MCQ questions.


RPM MCQ answering: 100%|██████████| 6930/6930 [2:27:44<00:00,  1.28s/it]     

[INFO] RPM: MCQ answers saved to RPM_mcq_qa/rpm_mcq_answers_DeepSeek_R1_Distill_Llama_8B.jsonl



