In [1]:
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set

import numpy as np
import pandas as pd
from tqdm import tqdm

In [2]:
# CSV file for each attack capture, keyed by dataset name (relative paths).
DATASET_PATHS = {
    "DoS":   "../../Dataset/DoS_dataset_clean.csv",
    "Fuzzy": "../../Dataset/Fuzzy_dataset_clean.csv",
    "Gear":  "../../Dataset/gear_dataset_clean.csv",
    "RPM":   "../../Dataset/RPM_dataset_clean.csv",
}

# Windowing and sampling settings.
WINDOW_SIZE = 200  # rows per window (see iter_window_starts)
WINDOW_STRIDE = 200  # step between window starts; equal to WINDOW_SIZE, so windows do not overlap
SAMPLE_RATIO = 0.10  # fraction of window starts kept by sample_window_indices
MCQS_PER_WINDOW = 3  # not referenced in the visible cells; presumably questions emitted per window — TODO confirm
GLOBAL_SEED = 50  # seed for rng_global

BYTE_COLUMNS = [f"Byte{i}" for i in range(1, 9)]  # payload columns Byte1..Byte8
ATTACK_LABELS = ["DoS", "Fuzzy", "Gear", "RPM"]  # label pool used by generate_mcq_attack_type
# Both thresholds are compared against Timestamp gap statistics; units are presumably seconds — TODO confirm.
SUPPRESSION_THRESHOLD = 1e-3
FLOODING_THRESHOLD = 5e-4


# Subset of DATASET_PATHS keys that is passed to load_datasets.
SELECTED_DATASETS = ["DoS", "Fuzzy", "Gear", "RPM"]

In [None]:
def _normalize_flag_series(series: pd.Series) -> pd.Series:
    mapped = series.map({"R": 0, "T": 1})
    numeric = pd.to_numeric(series, errors="coerce")
    combined = mapped.fillna(numeric).fillna(0).astype(int)
    return combined


def load_datasets(paths: Dict[str, str]):
    """Read every CSV in ``paths`` and derive a small ID profile for each.

    A path that does not exist is reported with a warning and skipped. Each
    loaded frame always ends up with an integer ``Flag`` column: normalised
    through ``_normalize_flag_series`` when the CSV has one, all zeros
    otherwise.

    Args:
        paths: Mapping of dataset name to CSV path.

    Returns:
        ``(datasets, profiles)``. ``datasets`` maps name to DataFrame.
        ``profiles`` maps name to a dict with ``expected_ids`` (the five most
        frequent IDs), ``critical_ids`` (the three most frequent IDs) and
        ``attack_label`` (the dataset name).
    """
    loaded: Dict[str, pd.DataFrame] = {}
    id_profiles: Dict[str, dict] = {}

    for name, raw_path in paths.items():
        csv_path = Path(raw_path)
        if not csv_path.exists():
            print(f"[WARN] Dataset {name} not found at {csv_path}, skipping.")
            continue

        df = pd.read_csv(csv_path)
        df["Flag"] = _normalize_flag_series(df["Flag"]) if "Flag" in df.columns else 0

        # value_counts orders IDs by descending frequency.
        most_frequent = df["ID"].value_counts().index.tolist()[:5]
        expected_ids = {int(x) for x in most_frequent}
        critical_ids = {int(x) for x in most_frequent[:3]}

        loaded[name] = df
        id_profiles[name] = dict(
            expected_ids=expected_ids,
            critical_ids=critical_ids,
            attack_label=name,
        )

        print(f"[INFO] Loaded {name}: {len(df)} rows, "
              f"{len(expected_ids)} expected IDs, {len(critical_ids)} critical IDs.")

    return loaded, id_profiles


# Load only the datasets named in SELECTED_DATASETS (load_datasets skips missing files).
datasets, profiles = load_datasets({k: v for k, v in DATASET_PATHS.items() if k in SELECTED_DATASETS})
# NumPy generator seeded with GLOBAL_SEED.
rng_global = np.random.default_rng(GLOBAL_SEED)


# Cell 3: helpers (window, stats)
def iter_window_starts(num_rows: int) -> List[int]:
    """List the start offsets of every full window that fits in ``num_rows`` rows.

    Windows are ``WINDOW_SIZE`` rows long and start every ``WINDOW_STRIDE``
    rows; a trailing partial window is never produced.

    Args:
        num_rows: Total number of rows available.

    Returns:
        Ascending start offsets, or an empty list when not even one full
        window fits.
    """
    last_valid_start = num_rows - WINDOW_SIZE
    if last_valid_start < 0:
        return []
    return [start for start in range(0, last_valid_start + 1, WINDOW_STRIDE)]


def sample_window_indices(starts: List[int], rng: np.random.Generator,
                          sample_ratio: Optional[float] = None) -> List[int]:
    """Draw a random subset of window starts without replacement.

    Args:
        starts: Candidate window start offsets.
        rng: Generator used for the draw.
        sample_ratio: Fraction of ``starts`` to keep. ``None`` (the default)
            uses the module-level ``SAMPLE_RATIO``.

    Returns:
        The selected starts in ascending order as plain Python ``int``
        values. At least one start is returned whenever ``starts`` is
        non-empty; an empty input gives an empty list.
    """
    if not starts:
        return []

    ratio = SAMPLE_RATIO if sample_ratio is None else sample_ratio
    # Always keep at least one window, never more than exist.
    sample_size = min(max(1, int(len(starts) * ratio)), len(starts))
    picked = rng.choice(starts, size=sample_size, replace=False)
    # rng.choice yields NumPy integers; convert so the declared List[int]
    # holds and the values can be written with json.
    return sorted(int(i) for i in picked)


def format_window(df: pd.DataFrame) -> str:
    """Render each row of ``df`` as one pipe-delimited text line.

    Every line carries the timestamp (six decimals), the ID, the DLC, the
    payload bytes listed in ``BYTE_COLUMNS`` as a list of ints, and the flag.

    Args:
        df: Window with ``Timestamp``, ``ID``, ``DLC``, ``Flag`` and the
            ``BYTE_COLUMNS`` columns.

    Returns:
        The lines joined with newlines; an empty string for an empty frame.
    """
    def render(row) -> str:
        payload = [int(row[name]) for name in BYTE_COLUMNS]
        return (
            f"Timestamp={row['Timestamp']:.6f} | "
            f"ID={int(row['ID'])} | DLC={int(row['DLC'])} | "
            f"bytes={payload} | Flag={int(row['Flag'])} |"
        )

    return "\n".join(render(row) for _, row in df.iterrows())


def compute_basic_stats(df: pd.DataFrame) -> dict:
    """Summarise ID frequency, inter-frame timing and payload variation of a window.

    Args:
        df: Window with ``ID``, ``Timestamp`` and the ``BYTE_COLUMNS`` columns.

    Returns:
        An empty dict when the window has no rows (or no non-null ID).
        Otherwise a dict with:

        * ``id_counts``: ``value_counts`` of ``ID``.
        * ``dominant_id`` / ``dominant_share``: most frequent ID and its
          share of all rows.
        * ``diffs``: consecutive ``Timestamp`` differences (empty for a
          single row).
        * ``gap_max``, ``gap_min``, ``gap_mean``, ``gap_std``: statistics of
          ``diffs`` (0.0 for a single row).
        * ``window_duration``: last minus first ``Timestamp``.
        * ``dominant_payload_var``: largest per-byte-position variance across
          the dominant ID's frames. It is 0.0 exactly when every frame of the
          dominant ID carries the same payload.
    """
    stats: dict = {}
    total_frames = len(df)
    if total_frames == 0:
        return stats

    id_counts = df["ID"].value_counts()
    if id_counts.empty:
        # value_counts drops NaN, so a window whose IDs are all null has no
        # dominant ID to report.
        return stats

    stats["id_counts"] = id_counts
    stats["dominant_id"] = int(id_counts.index[0])
    stats["dominant_share"] = id_counts.iloc[0] / total_frames

    # Timing: gaps between consecutive rows in their stored order.
    timestamps = df["Timestamp"].to_numpy()
    if total_frames > 1:
        diffs = np.diff(timestamps)
        stats["diffs"] = diffs
        stats["gap_max"] = float(diffs.max())
        stats["gap_min"] = float(diffs.min())
        stats["gap_mean"] = float(diffs.mean())
        stats["gap_std"] = float(diffs.std())
        stats["window_duration"] = float(timestamps[-1] - timestamps[0])
    else:
        stats["diffs"] = np.array([])
        stats["gap_max"] = 0.0
        stats["gap_min"] = 0.0
        stats["gap_mean"] = 0.0
        stats["gap_std"] = 0.0
        stats["window_duration"] = 0.0

    # Payload constancy of the dominant ID. The variance is taken per byte
    # position across frames (axis=0): a variance over the flattened matrix
    # would also count differences *between* byte positions, so a repeated
    # payload such as 01 45 60 FF ... would look highly variable.
    dom_group = df[df["ID"] == stats["dominant_id"]]
    if not dom_group.empty:
        per_byte_var = dom_group[BYTE_COLUMNS].to_numpy().var(axis=0)
        stats["dominant_payload_var"] = float(per_byte_var.max()) if per_byte_var.size else 0.0
    else:
        stats["dominant_payload_var"] = 0.0

    return stats


# Cell 4: MCQ generation (a few template types)
def generate_mcq_attack_type(stats: dict, profile: dict, rng: np.random.Generator) -> Optional[dict]:
    """
    Attack-type MCQ:
    Question: which attack type best fits this window?
    Options: 4 labels from ATTACK_LABELS.
    Ground truth: profile['attack_label'].

    Args:
        stats: Window statistics; accepted for a uniform generator signature
            but not read here.
        profile: Dataset profile; ``attack_label`` defaults to "DoS".
        rng: Generator used to pick and order the options.

    Returns:
        MCQ dict (``type``, ``question``, ``options``, ``answer``), or
        ``None`` when the label has no option text or fewer than three
        distractors are available.
    """
    label_to_text = {
        "DoS":   "Flooding / DoS-like behavior",
        "Fuzzy": "Fuzzy ID injection with many unseen IDs",
        "Gear":  "Gear spoofing affecting transmission or drive state",
        "RPM":   "RPM spoofing affecting engine speed readings",
    }

    gt_label = profile.get("attack_label", "DoS")
    if gt_label not in label_to_text:
        # A dataset name without a description cannot be rendered as an
        # option; skip instead of failing with KeyError.
        return None

    distractors = [x for x in ATTACK_LABELS if x != gt_label and x in label_to_text]
    if len(distractors) < 3:
        return None

    # rng.choice returns NumPy strings; keep plain str in the label list.
    chosen_distractors = [str(x) for x in rng.choice(distractors, size=3, replace=False)]
    labels = [gt_label] + chosen_distractors
    rng.shuffle(labels)

    letters = ["A", "B", "C", "D"]
    options = {letter: label_to_text[label] for letter, label in zip(letters, labels)}
    answer = letters[labels.index(gt_label)]

    return {
        "type": "attack_type",
        "question": "Which attack type best fits the behavior in this window?",
        "options": options,
        "answer": answer,
    }


def generate_mcq_dominant_id(df: pd.DataFrame, stats: dict, rng: np.random.Generator) -> Optional[dict]:
    """
    ID frequency MCQ:
    Question: which CAN ID appears most frequently in this window?

    Args:
        df: Window frames; only used to top up the options when ``stats``
            lists fewer than four IDs.
        stats: Needs ``id_counts`` (frequency-ordered counts per ID) and
            ``dominant_id``.
        rng: Generator used to order the options.

    Returns:
        MCQ dict, or ``None`` when there are no counts, fewer than two IDs,
        or several IDs share the highest count (no single correct answer).
    """
    id_counts = stats.get("id_counts")
    if id_counts is None or id_counts.empty:
        return None

    # With a tie for first place a distractor would be just as correct as
    # the designated answer, so the question is dropped.
    if int((id_counts == id_counts.max()).sum()) > 1:
        return None

    top_ids = [int(i) for i in id_counts.index[:4]]
    if len(top_ids) < 2:
        return None

    if len(top_ids) < 4:
        # Top up from IDs present in the frame but absent from id_counts.
        other_ids = [int(i) for i in df["ID"].unique() if int(i) not in top_ids]
        rng.shuffle(other_ids)
        top_ids = (top_ids + other_ids)[:4]

    dom_id = int(stats["dominant_id"])
    if dom_id not in top_ids:
        top_ids[0] = dom_id

    rng.shuffle(top_ids)
    letters = ["A", "B", "C", "D"]
    options = {letter: f"ID 0x{id_val:03X}" for letter, id_val in zip(letters, top_ids)}
    answer = letters[top_ids.index(dom_id)]

    return {
        "type": "dominant_id",
        "question": "Which CAN ID appears most frequently in this window?",
        "options": options,
        "answer": answer,
    }


def generate_mcq_timing(stats: dict) -> Optional[dict]:
    """
    Timing behavior MCQ.

    The answer is the first rule that matches:
      A - the largest gap exceeds five times SUPPRESSION_THRESHOLD,
      B - mean gap and gap spread are both below FLOODING_THRESHOLD,
      C - gap spread is below a tenth of the mean gap,
      D - anything else.

    Returns None when ``stats`` carries no gap samples in ``diffs``.
    """
    if stats.get("diffs", np.array([])).size == 0:
        return None

    largest_gap, jitter = stats["gap_max"], stats["gap_std"]

    answer = "D"
    if largest_gap > SUPPRESSION_THRESHOLD * 5:
        answer = "A"
    elif stats["gap_mean"] < FLOODING_THRESHOLD and jitter < FLOODING_THRESHOLD:
        answer = "B"
    elif jitter < stats["gap_mean"] * 0.1:
        answer = "C"

    return {
        "type": "timing",
        "question": "Which description best matches the timing behavior in this window?",
        "options": {
            "A": "Several large gaps suggest suppression behavior.",
            "B": "Extremely small gaps suggest flooding.",
            "C": "Timing is mostly uniform with minor jitter.",
            "D": "Timing is random with no clear pattern.",
        },
        "answer": answer,
    }



def generate_mcq_abnormal_rate_id(stats: dict, rng: np.random.Generator,
                                  share_threshold: float = 0.5) -> Optional[dict]:
    """
    Q2/Q18 Which CAN ID shows an abnormal increase in transmission rate?

    The dominant ID counts as abnormal when its share of the window exceeds
    ``share_threshold``; otherwise "No ID shows abnormal frequency" is the
    answer.

    Both outcomes use the same layout: up to three ID options (always
    including the dominant ID) in an order drawn from ``rng``, followed by
    the "no ID" option. The answer therefore cannot be read off the option
    layout or letter position.

    Args:
        stats: Needs ``id_counts``, ``dominant_id`` and ``dominant_share``.
        rng: Generator used to order the ID options.
        share_threshold: Share above which the dominant ID is abnormal.

    Returns:
        MCQ dict, or ``None`` when ``stats`` has no ID counts.
    """
    id_counts = stats.get("id_counts")
    if id_counts is None or id_counts.empty:
        return None

    dom_id = int(stats["dominant_id"])
    is_abnormal = stats["dominant_share"] > share_threshold

    # Most frequent IDs, with the dominant ID guaranteed to be among them.
    shown_ids = [int(x) for x in id_counts.index[:3]]
    if dom_id not in shown_ids:
        shown_ids = [dom_id] + shown_ids[:2]
    rng.shuffle(shown_ids)

    letters = ["A", "B", "C", "D"]
    options = {letter: f"ID 0x{cid:03X}" for letter, cid in zip(letters, shown_ids)}
    none_letter = letters[len(shown_ids)]
    options[none_letter] = "No ID shows abnormal frequency"

    correct_letter = letters[shown_ids.index(dom_id)] if is_abnormal else none_letter

    return {
        "type": "id_abnormal_rate",
        "question": "Which CAN ID shows an abnormal increase in transmission rate?",
        "options": options,
        "answer": correct_letter,
    }


def generate_mcq_missing_expected_id(df: pd.DataFrame, profile: dict,
                                     rng: np.random.Generator) -> Optional[dict]:
    """
    Q3/Q21 Which ID is missing compared to expected control IDs?

    When at least one of the profile's ``expected_ids`` is absent from the
    window, one absent ID (drawn with ``rng``) is the answer and up to two
    expected IDs that are present act as distractors. When nothing is
    absent, up to three expected IDs are shown and "None is missing" is the
    answer.

    Both outcomes share one layout: ID options in an order drawn from
    ``rng``, then "None is missing" as the last option. The answer therefore
    cannot be inferred from letter position.

    Args:
        df: Window frames with an ``ID`` column.
        profile: Dataset profile providing ``expected_ids``.
        rng: Generator used to pick and order the options.

    Returns:
        MCQ dict, or ``None`` when the profile lists no expected IDs.
    """
    expected_ids: Set[int] = profile.get("expected_ids", set())
    if not expected_ids:
        return None

    present_ids = {int(x) for x in df["ID"].unique()}
    # Sort so the result depends only on the seed, not on set iteration order.
    ordered_expected = sorted(int(eid) for eid in expected_ids)
    missing = [eid for eid in ordered_expected if eid not in present_ids]
    present_expected = [eid for eid in ordered_expected if eid in present_ids]
    rng.shuffle(present_expected)

    if missing:
        correct_id = int(rng.choice(missing))
        shown_ids = [correct_id] + present_expected[:2]
    else:
        correct_id = None
        shown_ids = present_expected[:3]
    rng.shuffle(shown_ids)

    letters = ["A", "B", "C", "D"]
    options = {letter: f"ID 0x{eid:03X}" for letter, eid in zip(letters, shown_ids)}
    none_letter = letters[len(shown_ids)]
    options[none_letter] = "None is missing"

    correct_letter = none_letter if correct_id is None else letters[shown_ids.index(correct_id)]

    return {
        "type": "expected_id_missing",
        "question": "Which ID is missing compared to expected control IDs?",
        "options": options,
        "answer": correct_letter,
    }


def generate_mcq_constant_payload_id(df: pd.DataFrame,
                                     rng: np.random.Generator) -> Optional[dict]:
    """
    Q10 Which ID shows an unusually constant payload?

    An ID qualifies when it has at least three frames and every byte
    position listed in ``BYTE_COLUMNS`` has a variance below 1e-3 across
    those frames. The variance is taken per byte position (axis 0), so a
    repeated payload counts as constant even when its bytes differ from one
    another.

    When an ID qualifies, one of them (drawn with ``rng``) is the answer and
    the distractors are taken only from IDs whose payload varies, so no
    distractor is itself a constant-payload ID. Otherwise up to three IDs
    from the window are shown and "None shows constant payload" is the
    answer. In both cases the ID options come first in an order drawn from
    ``rng`` and the "None" option is last.

    Args:
        df: Window frames with ``ID`` and the ``BYTE_COLUMNS`` columns.
        rng: Generator used to pick and order the options.

    Returns:
        MCQ dict.
    """
    constant_ids = []
    varying_ids = []
    for id_val, group in df.groupby("ID"):
        per_byte_var = group[BYTE_COLUMNS].to_numpy().var(axis=0)
        is_constant = per_byte_var.size == 0 or per_byte_var.max() < 1e-3
        if not is_constant:
            varying_ids.append(int(id_val))
        elif len(group) >= 3:
            # Fewer than three identical frames is too little evidence, so
            # such IDs are neither an answer nor a distractor.
            constant_ids.append(int(id_val))

    if constant_ids:
        correct_id = int(rng.choice(constant_ids))
        rng.shuffle(varying_ids)
        shown_ids = [correct_id] + varying_ids[:2]
    else:
        correct_id = None
        all_ids = [int(x) for x in df["ID"].unique()]
        rng.shuffle(all_ids)
        shown_ids = all_ids[:3]
    rng.shuffle(shown_ids)

    letters = ["A", "B", "C", "D"]
    options = {letter: f"ID 0x{id_val:03X}" for letter, id_val in zip(letters, shown_ids)}
    none_letter = letters[len(shown_ids)]
    options[none_letter] = "None shows constant payload"

    correct_letter = none_letter if correct_id is None else letters[shown_ids.index(correct_id)]

    return {
        "type": "constant_payload_id",
        "question": "Which ID shows an unusually constant payload?",
        "options": options,
        "answer": correct_letter,
    }


def generate_mcq_payload_pattern(stats: dict) -> Optional[dict]:
    """
    Q11 Which phenomenon best describes the payload pattern?

    The answer is chosen from ``stats["dominant_payload_var"]`` (0.0 when
    absent): below 1e-3 gives "A", below 1.0 gives "B", anything else "C".
    Option "D" is never the answer.
    """
    spread = stats.get("dominant_payload_var", 0.0)

    # Thresholds are tested in ascending order, so each band excludes the
    # ones before it.
    answer = "A" if spread < 1e-3 else ("B" if spread < 1.0 else "C")

    return {
        "type": "payload_pattern",
        "question": "Which phenomenon best describes the payload pattern?",
        "options": {
            "A": "Several IDs transmit identical payloads repeatedly.",
            "B": "Payload values increase steadily.",
            "C": "Payload varies unpredictably.",
            "D": "Most payloads are near zero.",
        },
        "answer": answer,
    }


def generate_mcq_dlc_distribution(df: pd.DataFrame) -> Optional[dict]:
    """
    Q14 Which statement best describes DLC distribution?

    "A" when more than 60% of frames have DLC 8, "C" when more than 60% have
    DLC of 4 or less, "B" otherwise. Returns None for a window without rows.
    """
    dlc_values = df["DLC"].to_numpy()
    if dlc_values.size == 0:
        return None

    full_share = np.mean(dlc_values == 8)
    low_share = np.mean(dlc_values <= 4)

    answer = "B"
    if full_share > 0.6:
        answer = "A"
    elif low_share > 0.6:
        answer = "C"

    return {
        "type": "dlc_distribution",
        "question": "Which statement best describes DLC distribution?",
        "options": {
            "A": "Majority of frames have DLC = 8.",
            "B": "DLC values vary evenly.",
            "C": "DLC is consistently low.",
            "D": "DLC appears corrupted.",
        },
        "answer": answer,
    }


def generate_mcq_flag_behavior(df: pd.DataFrame) -> Optional[dict]:
    """
    Q24 What best describes flag behavior in this window?

    "A" when the window holds a single distinct flag value, "B" when it
    holds exactly two and both are 0 or 1, "C" in every other case. Returns
    None for a window without rows.
    """
    flags = df["Flag"].to_numpy()
    if flags.size == 0:
        return None

    distinct = np.unique(flags)
    if distinct.size == 1:
        answer = "A"
    elif distinct.size == 2 and not any(v not in (0, 1) for v in distinct):
        answer = "B"
    else:
        answer = "C"

    return {
        "type": "flag_behavior",
        "question": "What best describes flag behavior in this window?",
        "options": {
            "A": "All flags are identical.",
            "B": "Both flag values (0/1) appear frequently.",
            "C": "Flags are inconsistent and likely corrupted.",
            "D": "Several flags appear missing.",
        },
        "answer": answer,
    }


def generate_mcq_overall_window(stats: dict) -> Optional[dict]:
    """
    Q30 How would you best characterize this window overall?

    "B" when the dominant ID holds more than 70% of frames and the gap
    spread is below FLOODING_THRESHOLD; "C" when the dominant share is under
    40% and the gap spread is under a tenth of the mean gap (mean gap
    defaults to 1.0 when absent); "A" otherwise.
    """
    top_share = stats.get("dominant_share", 0.0)
    jitter = stats.get("gap_std", 0.0)

    answer = "A"
    if top_share > 0.7 and jitter < FLOODING_THRESHOLD:
        answer = "B"
    elif top_share < 0.4 and jitter < stats.get("gap_mean", 1.0) * 0.1:
        answer = "C"

    return {
        "type": "overall_window",
        "question": "How would you best characterize this window overall?",
        "options": {
            "A": "Mostly stable with minor anomalies.",
            "B": "Highly irregular and unsafe.",
            "C": "Uniform and typical.",
            "D": "Largely empty or incomplete.",
        },
        "answer": answer,
    }


def generate_mcq_id_most_irregular_timing(df: pd.DataFrame,
                                          rng: np.random.Generator) -> Optional[dict]:
    """
    Q4
    Which ID shows the most irregular timing pattern?

    Irregularity is the standard deviation of the gaps between an ID's
    consecutive timestamps; IDs with fewer than three frames are ignored.
    The ID with the largest value is the answer, and up to three other
    scored IDs (picked with ``rng``) are the distractors. Returns None when
    no ID has enough frames.
    """
    spread_by_id = {}
    for id_val, frames in df.groupby("ID"):
        stamps = frames["Timestamp"].to_numpy()
        if stamps.size >= 3:
            spread_by_id[int(id_val)] = float(np.diff(stamps).std())

    if not spread_by_id:
        return None

    ranked = [cid for cid, _ in sorted(spread_by_id.items(), key=lambda kv: kv[1], reverse=True)]
    correct_id, others = ranked[0], ranked[1:]

    rng.shuffle(others)
    shown_ids = [correct_id] + others[:3]
    rng.shuffle(shown_ids)

    options = {letter: f"ID 0x{cid:03X}" for letter, cid in zip("ABCD", shown_ids)}

    return {
        "type": "id_irregular_timing",
        "question": "Which CAN ID shows the most irregular timing pattern?",
        "options": options,
        "answer": "ABCD"[shown_ids.index(correct_id)],
    }


def generate_mcq_id_shortest_gap(df: pd.DataFrame,
                                 rng: np.random.Generator) -> Optional[dict]:
    """
    Q7
    Which ID shows the shortest average inter-frame gap?

    For every ID with at least three frames the mean gap between its
    consecutive timestamps is computed. The ID with the smallest mean is the
    answer, and up to three other scored IDs (picked with ``rng``) are the
    distractors. Returns None when no ID has enough frames.
    """
    mean_gap_by_id = {}
    for id_val, frames in df.groupby("ID"):
        stamps = frames["Timestamp"].to_numpy()
        if stamps.size >= 3:
            mean_gap_by_id[int(id_val)] = float(np.diff(stamps).mean())

    if not mean_gap_by_id:
        return None

    ranked = [cid for cid, _ in sorted(mean_gap_by_id.items(), key=lambda kv: kv[1])]
    correct_id, others = ranked[0], ranked[1:]

    rng.shuffle(others)
    shown_ids = [correct_id] + others[:3]
    rng.shuffle(shown_ids)

    options = {letter: f"ID 0x{cid:03X}" for letter, cid in zip("ABCD", shown_ids)}

    return {
        "type": "id_shortest_gap",
        "question": "Which CAN ID shows the shortest average inter-frame gap?",
        "options": options,
        "answer": "ABCD"[shown_ids.index(correct_id)],
    }


def generate_mcq_window_duration(stats: dict) -> Optional[dict]:
    """
    Q8
    What best describes the time coverage of this window?

    Uses ``stats["window_duration"]``: above 0.5 gives "A", below 0.05 gives
    "C", anything in between "B". Returns None when the duration is missing,
    zero or negative.
    """
    span = stats.get("window_duration", 0.0)
    if span <= 0:
        return None

    answer = "B"
    if span > 0.5:
        answer = "A"
    elif span < 0.05:
        answer = "C"

    return {
        "type": "window_duration",
        "question": "What best describes the time coverage of this window?",
        "options": {
            "A": "The window spans an unusually long duration.",
            "B": "The window spans a typical duration.",
            "C": "The window is too short to analyze.",
            "D": "Time information is inconsistent.",
        },
        "answer": answer,
    }


def generate_mcq_burst_explanation(stats: dict) -> Optional[dict]:
    """
    Q9/Q19/Q20
    What best explains the burst behavior observed?

    "B" when the mean gap is below FLOODING_THRESHOLD and the dominant ID
    holds more than half of the frames; "C" when the mean gap exceeds
    SUPPRESSION_THRESHOLD and the gap spread is more than half the mean gap;
    "A" otherwise. Missing statistics count as 0.0.
    """
    top_share = stats.get("dominant_share", 0.0)
    mean_gap = stats.get("gap_mean", 0.0)
    jitter = stats.get("gap_std", 0.0)

    looks_like_flooding = mean_gap < FLOODING_THRESHOLD and top_share > 0.5
    looks_like_recovery = mean_gap > SUPPRESSION_THRESHOLD and jitter > mean_gap * 0.5

    if looks_like_flooding:
        answer = "B"
    elif looks_like_recovery:
        answer = "C"
    else:
        answer = "A"

    return {
        "type": "burst_explanation",
        "question": "What best explains the burst behavior observed in this window?",
        "options": {
            "A": "Normal periodic behavior with minor bursts.",
            "B": "Flooding from a compromised ECU.",
            "C": "Diagnostic or recovery traffic causing bursts.",
            "D": "Random logging artifact with no pattern.",
        },
        "answer": answer,
    }


def generate_mcq_high_dlc_id(df: pd.DataFrame,
                             rng: np.random.Generator) -> Optional[dict]:
    """
    Q15
    Which ID shows the highest average DLC usage?

    The answer is the ID with the largest mean DLC. Distractors are drawn
    (with ``rng``) only from IDs whose mean DLC is strictly lower: an ID
    tied at the maximum would be an equally correct answer.

    Args:
        df: Window frames with ``ID`` and ``DLC`` columns.
        rng: Generator used to pick and order the options.

    Returns:
        MCQ dict, or ``None`` when the frame has no ``DLC`` column, no rows,
        or no ID with a strictly lower mean DLC (for example when every ID
        uses the same DLC, so the question has no discriminating answer).
    """
    if "DLC" not in df.columns:
        return None

    dlc_mean_by_id = df.groupby("ID")["DLC"].mean()
    if dlc_mean_by_id.empty:
        return None

    top_mean = dlc_mean_by_id.max()
    correct_id = int(dlc_mean_by_id.idxmax())

    other_ids = [int(id_val) for id_val, mean_dlc in dlc_mean_by_id.items() if mean_dlc < top_mean]
    if not other_ids:
        return None

    rng.shuffle(other_ids)
    ids_for_options = [correct_id] + other_ids[:3]
    rng.shuffle(ids_for_options)

    letters = ["A", "B", "C", "D"]
    options = {letter: f"0x{id_val:03X}" for letter, id_val in zip(letters, ids_for_options)}
    correct_letter = letters[ids_for_options.index(correct_id)]

    return {
        "type": "high_dlc_id",
        "question": "Which ID shows the highest DLC usage?",
        "options": options,
        "answer": correct_letter,
    }


def generate_mcq_high_dlc_increase(df: pd.DataFrame) -> Optional[dict]:
    """
    Q16
    What does a sudden increase in high-DLC frames suggest?

    "B" when more than half of the frames have a DLC of 7 or more, "A"
    otherwise. Returns None when the frame has no ``DLC`` column or no rows.
    """
    if "DLC" not in df.columns:
        return None

    dlc_values = df["DLC"].to_numpy()
    if dlc_values.size == 0:
        return None

    mostly_high = np.mean(dlc_values >= 7) > 0.5

    return {
        "type": "high_dlc_increase",
        "question": "What does a sudden increase in high-DLC frames suggest?",
        "options": {
            "A": "Onboard diagnostics or normal high-payload traffic.",
            "B": "Injection of crafted frames.",
            "C": "Normal low-rate sensor updates.",
            "D": "Bus-off recovery sequence.",
        },
        "answer": "B" if mostly_high else "A",
    }


def generate_mcq_critical_id_abnormal(df: pd.DataFrame,
                                      profile: dict,
                                      rng: np.random.Generator) -> Optional[dict]:
    """
    Q22
    Which critical ID behaves abnormally?

    Only frames whose ID is in the profile's ``critical_ids`` are
    considered. Each such ID is scored as its share of the critical frames
    plus the variance of its Flag values; the highest score is the answer,
    and up to three other critical IDs (picked with ``rng``) are the
    distractors. Returns None when the profile lists no critical IDs or none
    of them occurs in the window.
    """
    watch_list: Set[int] = profile.get("critical_ids", set())
    if not watch_list:
        return None

    crit_rows = df[df["ID"].isin(watch_list)]
    if crit_rows.empty:
        return None

    counts = crit_rows["ID"].value_counts()
    n_critical = counts.sum()

    scores = {}
    for id_val, count in counts.items():
        id_flags = crit_rows.loc[crit_rows["ID"] == id_val, "Flag"].to_numpy()
        flag_spread = float(np.var(id_flags)) if id_flags.size > 0 else 0.0
        scores[int(id_val)] = count / n_critical + flag_spread

    if not scores:
        return None

    ranked = [cid for cid, _ in sorted(scores.items(), key=lambda kv: kv[1], reverse=True)]
    correct_id, others = ranked[0], ranked[1:]

    rng.shuffle(others)
    shown_ids = [correct_id] + others[:3]
    rng.shuffle(shown_ids)

    options = {letter: f"ID 0x{cid:03X}" for letter, cid in zip("ABCD", shown_ids)}

    return {
        "type": "critical_id_abnormal",
        "question": "Which critical ID behaves abnormally?",
        "options": options,
        "answer": "ABCD"[shown_ids.index(correct_id)],
    }


def generate_mcq_flag_suspicious_id(df: pd.DataFrame,
                                    rng: np.random.Generator) -> Optional[dict]:
    """
    Q25
    Which ID's flag pattern appears most suspicious?

    Each ID with at least three frames is scored by the fraction of
    consecutive frame pairs whose Flag differs (0.0 for a constant flag).
    The ID with the highest score is the answer. Distractors are drawn
    (with ``rng``) only from IDs with a strictly lower score, so no
    distractor is as suspicious as the answer.

    Args:
        df: Window frames with ``ID`` and ``Flag`` columns.
        rng: Generator used to pick and order the options.

    Returns:
        MCQ dict, or ``None`` when no ID has enough frames, when no ID's
        flag ever changes (nothing is suspicious, so any answer would be
        arbitrary), or when no lower-scoring ID is available as a
        distractor.
    """
    suspicious_scores = {}
    for id_val, group in df.groupby("ID"):
        flags = group["Flag"].to_numpy()
        if flags.size < 3:
            continue
        suspicious_scores[int(id_val)] = float(np.mean(flags[1:] != flags[:-1]))

    if not suspicious_scores:
        return None

    best_score = max(suspicious_scores.values())
    if best_score <= 0.0:
        return None

    sorted_ids = sorted(suspicious_scores.items(), key=lambda x: x[1], reverse=True)
    correct_id = sorted_ids[0][0]

    other_ids = [id_ for id_, score in sorted_ids[1:] if score < best_score]
    if not other_ids:
        return None

    rng.shuffle(other_ids)
    ids_for_options = [correct_id] + other_ids[:3]
    rng.shuffle(ids_for_options)

    letters = ["A", "B", "C", "D"]
    options = {
        letter: f"ID with flag pattern on 0x{id_val:03X}"
        for letter, id_val in zip(letters, ids_for_options)
    }
    correct_letter = letters[ids_for_options.index(correct_id)]

    return {
        "type": "flag_suspicious_id",
        "question": "Which ID's flag pattern appears most suspicious?",
        "options": options,
        "answer": correct_letter,
    }


def generate_mcq_attack_explanation(stats: dict) -> Optional[dict]:
    """
    Q26/Q29
    Which explanation best fits the irregular behavior in this window?

    Rules, first match wins (missing statistics count as 0.0):
      A - dominant share above 0.7 together with either a mean gap below
          FLOODING_THRESHOLD or a dominant payload variance below 1e-3,
      B - mean gap above SUPPRESSION_THRESHOLD and gap spread above half the
          mean gap,
      C - dominant payload variance below 1e-3,
      D - anything else.
    """
    top_share = stats.get("dominant_share", 0.0)
    mean_gap = stats.get("gap_mean", 0.0)
    jitter = stats.get("gap_std", 0.0)
    payload_spread = stats.get("dominant_payload_var", 0.0)

    payload_is_constant = payload_spread < 1e-3

    if top_share > 0.7 and (mean_gap < FLOODING_THRESHOLD or payload_is_constant):
        answer = "A"
    elif mean_gap > SUPPRESSION_THRESHOLD and jitter > mean_gap * 0.5:
        answer = "B"
    elif payload_is_constant:
        answer = "C"
    else:
        answer = "D"

    return {
        "type": "attack_explanation",
        "question": "Which explanation best fits the irregular behavior in this window?",
        "options": {
            "A": "Malicious flooding or fabricated high-rate data.",
            "B": "Overloaded network or suppression-like behavior.",
            "C": "Legitimate mode transition or sensor update burst.",
            "D": "Logging artifact or mild, non-critical anomaly.",
        },
        "answer": answer,
    }


def generate_mcq_analysis_method(stats: dict) -> Optional[dict]:
    """
    Q31
    Which type of analysis would be most appropriate for this window?

    Three indicators are derived from ``stats`` (missing values count as
    0.0): a timing issue (gap spread above half of a positive mean gap), a
    payload issue (dominant payload variance above 1.0 or below 1e-3) and an
    ID issue (dominant share above 0.5). "D" is the answer when all three or
    none of them hold; otherwise the first that holds decides, in the order
    timing ("A"), payload ("B"), ID ("C").
    """
    top_share = stats.get("dominant_share", 0.0)
    jitter = stats.get("gap_std", 0.0)
    mean_gap = stats.get("gap_mean", 0.0)
    payload_spread = stats.get("dominant_payload_var", 0.0)

    timing_issue = jitter > mean_gap * 0.5 and mean_gap > 0
    payload_issue = payload_spread > 1.0 or payload_spread < 1e-3
    id_issue = top_share > 0.5

    answer = "D"
    if not (timing_issue and payload_issue and id_issue):
        if timing_issue:
            answer = "A"
        elif payload_issue:
            answer = "B"
        elif id_issue:
            answer = "C"

    return {
        "type": "analysis_method",
        "question": "Which type of analysis would be most appropriate for this window?",
        "options": {
            "A": "Temporal anomaly detection.",
            "B": "Payload entropy or pattern analysis.",
            "C": "ID frequency and distribution monitoring.",
            "D": "All of the above.",
        },
        "answer": answer,
    }


[INFO] Loaded DoS: 3665771 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded Fuzzy: 3838860 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded Gear: 4443142 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded RPM: 4621702 rows, 5 expected IDs, 3 critical IDs.


In [None]:
# ==== Optional: enrich profiles with global baseline ID rates ====
# Run once after datasets, profiles are created.
for name, df_full in datasets.items():
    # Fraction of the whole dataset's rows carried by each ID.
    id_counts_full = df_full["ID"].value_counts()
    total_full = len(df_full)
    baseline = {int(i): float(c / total_full) for i, c in id_counts_full.items()}
    # Added to the existing profile in place; the generators visible in this
    # part of the notebook do not read this key.
    profiles[name]["baseline_id_rate"] = baseline


# ==== Extra MCQ templates ====
def generate_mcq_spoofing_suspect_id(df: pd.DataFrame,
                                     profile: dict,
                                     rng: np.random.Generator) -> Optional[dict]:
    """Which ID most likely indicates a spoofing attempt?

    An ID is a candidate when it occurs in the window but is not among the
    profile's ``expected_ids``. The most frequent candidate is the answer
    and up to two other candidates (picked with ``rng``) are shown beside
    it. The ID options are placed in an order drawn from ``rng`` and "No ID
    shows clear spoofing behavior." follows as the next letter, so the
    answer is not tied to one letter and no letter is skipped.

    Without any candidate, a fixed set of generic options is returned with
    "D" (no spoofing) as the answer.

    Args:
        df: Window frames with an ``ID`` column.
        profile: Dataset profile; ``expected_ids`` defaults to an empty set.
        rng: Generator used to pick and order the ID options.

    Returns:
        MCQ dict, or ``None`` when the window has no IDs.
    """
    expected_ids: Set[int] = profile.get("expected_ids", set())
    id_counts = df["ID"].value_counts()
    if id_counts.empty:
        return None

    question = "Which ID most likely indicates a spoofing attempt?"
    none_text = "No ID shows clear spoofing behavior."

    # value_counts is frequency-ordered, so the first candidate is the most
    # frequent one.
    novel_ids = [int(i) for i in id_counts.index if int(i) not in expected_ids]
    if not novel_ids:
        # No clear spoofing candidate
        return {
            "type": "spoofing_suspect_id",
            "question": question,
            "options": {
                "A": "The ID with the most irregular timing.",
                "B": "The ID with the largest DLC.",
                "C": "The ID that appears only once.",
                "D": none_text,
            },
            "answer": "D",
        }

    correct_id = novel_ids[0]
    other_novel = novel_ids[1:]
    rng.shuffle(other_novel)
    ids_for_options = [correct_id] + other_novel[:2]
    rng.shuffle(ids_for_options)

    letters = ["A", "B", "C", "D"]
    options = {letter: f"ID 0x{id_val:03X}" for letter, id_val in zip(letters, ids_for_options)}
    options[letters[len(ids_for_options)]] = none_text

    return {
        "type": "spoofing_suspect_id",
        "question": question,
        "options": options,
        "answer": letters[ids_for_options.index(correct_id)],
    }


def generate_mcq_fabricated_payload_id(df: pd.DataFrame,
                                       rng: np.random.Generator) -> Optional[dict]:
    """Build an MCQ asking which ID's payload looks like fabricated sensor data.

    An ID qualifies as a candidate when it has at least three frames in the
    window and the variance taken over all of its ``BYTE_COLUMNS`` values
    (flattened, not per column) is below 1e-3.

    Args:
        df: Window of frames with an ``ID`` column and the byte columns.
        rng: Random generator used to pick the correct candidate and to
            shuffle the distractor IDs.

    Returns:
        A dict with ``type``, ``question``, ``options`` and ``answer``. With
        no candidate the options are generic descriptions and the answer is
        "D"; otherwise the chosen candidate is option "A" (the answer),
        followed by up to two other IDs and a "none" option.
    """
    question_text = "Which ID's payload most likely suggests fabricated sensor data?"
    none_text = "None of the IDs shows fabricated data."

    # IDs whose payload bytes are (near-)constant across enough frames.
    low_variance_ids = [
        int(can_id)
        for can_id, frames in df.groupby("ID")
        if len(frames) >= 3 and frames[BYTE_COLUMNS].to_numpy().var() < 1e-3
    ]

    if not low_variance_ids:
        return {
            "type": "fabricated_payload_id",
            "question": question_text,
            "options": {
                "A": "The ID with repetitive identical frames.",
                "B": "The ID with the largest DLC.",
                "C": "The ID that appears very rarely.",
                "D": none_text,
            },
            "answer": "D",
        }

    chosen_id = rng.choice(low_variance_ids)
    distractor_ids = [int(v) for v in df["ID"].unique() if int(v) != chosen_id]
    rng.shuffle(distractor_ids)

    # Correct ID first, then at most two distractors, then the "none" option;
    # letters are handed out in that order.
    option_texts = [f"ID 0x{int(chosen_id):03X}"]
    option_texts.extend(f"ID 0x{other:03X}" for other in distractor_ids[:2])
    option_texts.append(none_text)
    choices = dict(zip(["A", "B", "C", "D"], option_texts))

    return {
        "type": "fabricated_payload_id",
        "question": question_text,
        "options": choices,
        "answer": "A",
    }


def generate_mcq_all_zero_payload_anomaly(df: pd.DataFrame) -> Optional[dict]:
    """Build an MCQ about what a high share of all-zero payloads suggests.

    The answer is "A" when more than 30% of the rows have every byte column
    equal to zero, and "D" otherwise.

    Args:
        df: Window of frames containing the ``BYTE_COLUMNS`` columns.

    Returns:
        ``None`` when ``BYTE_COLUMNS`` is empty or the selected byte matrix
        has no elements; otherwise a dict with ``type``, ``question``,
        ``options`` and ``answer``.
    """
    if not BYTE_COLUMNS:
        return None

    payload = df[BYTE_COLUMNS].to_numpy()
    if payload.size == 0:
        return None

    # Fraction of frames whose bytes are all zero.
    share_all_zero = (payload == 0).all(axis=1).mean()
    answer_letter = "A" if share_all_zero > 0.3 else "D"

    return {
        "type": "all_zero_payload_anomaly",
        "question": "Which type of anomaly is most consistent with constant all-zero payloads?",
        "options": {
            "A": "Uninitialized or placeholder sensor data.",
            "B": "Normal control frames under light load.",
            "C": "Overloaded network with random drops.",
            "D": "Legitimate control frames with diverse payloads.",
        },
        "answer": answer_letter,
    }

# Store, for every loaded dataset, the share of rows contributed by each ID
# over the whole dataset. generate_mcq_expected_id_lower_rate reads this via
# profile["baseline_id_rate"] to compare per-window rates against.
# Relies on `datasets` and `profiles` already existing in the kernel
# (presumably the values returned by load_datasets -- TODO confirm).
# NOTE(review): the same two closing statements appear right before the
# "Extra MCQ templates" section; this looks like a repeat of that loop --
# confirm and keep only one.
for name, df_full in datasets.items():
    id_counts_full = df_full["ID"].value_counts()
    total_full = len(df_full)
    # Keys are cast to int and rates to plain float (count / total rows).
    baseline = {int(i): float(c / total_full) for i, c in id_counts_full.items()}
    profiles[name]["baseline_id_rate"] = baseline

def generate_mcq_expected_id_lower_rate(df: pd.DataFrame,
                                        profile: dict,
                                        rng: np.random.Generator,
                                        low_ratio_threshold: float = 0.5) -> Optional[dict]:
    """Build an MCQ asking which expected ID is under-represented in the window.

    For every expected ID with a positive baseline rate, the window rate
    (count in ``df`` divided by ``len(df)``) is divided by the baseline rate.
    An ID counts as "low" when that ratio is below ``low_ratio_threshold``.

    Args:
        df: Window of frames with an ``ID`` column.
        profile: Dataset profile; ``expected_ids`` and ``baseline_id_rate``
            (mapping of ID to baseline rate) are read.
        rng: Random generator used to shuffle the distractor IDs.
        low_ratio_threshold: Ratio below which an expected ID is considered
            to appear at a lower rate than normal. Defaults to 0.5.

    Returns:
        ``None`` when the profile lacks expected IDs or baseline rates, when
        the window is empty, or when no expected ID has a positive baseline.
        Otherwise a dict with ``type``, ``question``, ``options`` and
        ``answer``: answer "D" ("No deviation observed.") when no ID is low,
        else answer "A" holding the expected ID with the smallest ratio.
    """
    question = "Which expected ID appears at a lower rate than normal?"
    no_deviation_text = "No deviation observed."

    expected_ids: Set[int] = profile.get("expected_ids", set())
    baseline: Dict[int, float] = profile.get("baseline_id_rate", {})
    if not expected_ids or not baseline:
        return None

    id_counts = df["ID"].value_counts()
    total = len(df)
    if total == 0:
        return None

    # Window rate relative to baseline rate, per expected ID.
    ratios: Dict[int, float] = {}
    for eid in expected_ids:
        base = baseline.get(int(eid), 0.0)
        if base <= 0:
            # No usable baseline for this ID; a ratio would be undefined.
            continue
        window_rate = float(id_counts.get(int(eid), 0.0)) / total
        ratios[int(eid)] = window_rate / base

    if not ratios:
        return None

    if not any(r < low_ratio_threshold for r in ratios.values()):
        # NOTE(review): these distractors are fixed placeholder IDs rather
        # than IDs taken from the profile or the window -- confirm that this
        # is intended for the generated dataset.
        return {
            "type": "expected_id_lower_rate",
            "question": question,
            "options": {
                "A": "ID 0x000",
                "B": "ID 0x001",
                "C": "ID 0x002",
                "D": no_deviation_text,
            },
            "answer": "D",
        }

    # Smallest ratio first: that ID is the correct answer.
    sorted_ids = sorted(ratios.items(), key=lambda x: x[1])
    correct_id = sorted_ids[0][0]

    other_ids = [eid for eid, _ in sorted_ids[1:]]
    rng.shuffle(other_ids)

    option_texts = [f"ID 0x{int(eid):03X}" for eid in [correct_id] + other_ids[:2]]
    option_texts.append(no_deviation_text)
    options = dict(zip(["A", "B", "C", "D"], option_texts))

    return {
        "type": "expected_id_lower_rate",
        "question": question,
        "options": options,
        "answer": "A",
    }





In [None]:
def shuffle_mcq_options_inplace(mcq: dict,
                                rng: np.random.Generator) -> dict:
    """Randomly reorder an MCQ's options and re-point ``answer`` accordingly.

    Expected MCQ layout::

        {
          "question": ...,
          "options": {"A": "...", "B": "...", "C": "...", "D": "..."},
          "answer": "B",
          ...
        }

    The option texts are shuffled and re-labelled with consecutive letters
    starting at "A" (so gaps in the incoming keys disappear). The correct
    option is tracked by its original key rather than by its text, so two
    options with identical text cannot re-point the answer at the wrong
    entry, and any number of options is supported.

    Args:
        mcq: MCQ dict; modified in place.
        rng: Random generator (anything with a list-compatible ``shuffle``).

    Returns:
        The same ``mcq`` object. It is returned unchanged when ``options`` is
        missing/empty, when ``answer`` is not one of the option keys, or when
        there is only one option.
    """
    options = mcq.get("options")
    answer = mcq.get("answer")

    if not options or answer not in options:
        return mcq

    entries = list(options.items())
    if len(entries) <= 1:
        return mcq

    # Shuffle positions instead of texts so the correct entry can be followed
    # by key. Shuffling a plain list of the same length draws the same
    # permutation as shuffling the texts directly.
    order = list(range(len(entries)))
    rng.shuffle(order)

    new_options = {}
    new_answer = answer
    for position, source_idx in enumerate(order):
        old_key, text = entries[source_idx]
        letter = chr(ord("A") + position)  # A, B, C, ... without a fixed cap
        new_options[letter] = text
        if old_key == answer:
            new_answer = letter

    mcq["options"] = new_options
    mcq["answer"] = new_answer
    return mcq


In [None]:
def generate_mcq_for_window(df: pd.DataFrame, profile: dict, rng: np.random.Generator) -> List[dict]:
    """Generate the MCQs for one window of frames.

    Every template generator is invoked once, in a fixed order (they share
    ``rng``, so the order determines the random draws). Falsy results are
    dropped. If more than ``MCQS_PER_WINDOW`` questions remain, that many are
    sampled without replacement. Finally each kept question has its options
    shuffled in place.

    Args:
        df: Window of frames.
        profile: Dataset profile passed through to the profile-aware templates.
        rng: Random generator shared by all templates, the sampling step and
            the option shuffling.

    Returns:
        List of MCQ dicts (possibly empty).
    """
    stats = compute_basic_stats(df)

    # (generator, positional arguments) in the order they must be invoked.
    templates = (
        (generate_mcq_attack_type, (stats, profile, rng)),
        (generate_mcq_dominant_id, (df, stats, rng)),
        (generate_mcq_abnormal_rate_id, (stats, rng)),
        (generate_mcq_missing_expected_id, (df, profile, rng)),
        (generate_mcq_timing, (stats,)),
        (generate_mcq_constant_payload_id, (df, rng)),
        (generate_mcq_payload_pattern, (stats,)),
        (generate_mcq_dlc_distribution, (df,)),
        (generate_mcq_flag_behavior, (df,)),
        (generate_mcq_overall_window, (stats,)),
        (generate_mcq_id_most_irregular_timing, (df, rng)),
        (generate_mcq_id_shortest_gap, (df, rng)),
        (generate_mcq_window_duration, (stats,)),
        (generate_mcq_burst_explanation, (stats,)),
        (generate_mcq_high_dlc_id, (df, rng)),
        (generate_mcq_high_dlc_increase, (df,)),
        (generate_mcq_critical_id_abnormal, (df, profile, rng)),
        (generate_mcq_flag_suspicious_id, (df, rng)),
        (generate_mcq_attack_explanation, (stats,)),
        (generate_mcq_analysis_method, (stats,)),
        (generate_mcq_spoofing_suspect_id, (df, profile, rng)),
        (generate_mcq_fabricated_payload_id, (df, rng)),
        (generate_mcq_all_zero_payload_anomaly, (df,)),
        (generate_mcq_expected_id_lower_rate, (df, profile, rng)),
    )

    mcqs: List[dict] = []
    for build_mcq, call_args in templates:
        candidate = build_mcq(*call_args)
        if candidate:
            mcqs.append(candidate)

    # Limit per-window MCQs
    if len(mcqs) > MCQS_PER_WINDOW:
        kept_positions = rng.choice(len(mcqs), size=MCQS_PER_WINDOW, replace=False)
        mcqs = [mcqs[pos] for pos in kept_positions]

    for item in mcqs:
        shuffle_mcq_options_inplace(item, rng)

    return mcqs


In [None]:
# Cell 5: main loop – per-dataset question generation
# For each loaded dataset: sample windows, generate MCQs per window, and write
# every question both as one JSON line (JSONL) and into a single JSON array.
# Relies on names defined in other cells: datasets, profiles, rng_global,
# iter_window_starts, sample_window_indices, format_window.
# NOTE(review): ds_idx is never used inside the loop.
# NOTE(review): the saved output of this cell reports paths under "*_tf_qa",
# which does not match the "_mcq_qa" directory built here -- the stored
# output looks stale; confirm by re-running.
for ds_idx, (name, df) in enumerate(datasets.items()):
    print(f"[INFO] Generating MCQ questions for {name}")
    starts = iter_window_starts(len(df))
    sampled_starts = sample_window_indices(starts, rng_global)

    # Output directory is relative to the current working directory.
    out_dir = Path(f"{name}_mcq_qa")
    out_dir.mkdir(parents=True, exist_ok=True)
    ql_path = out_dir / f"{name.lower()}_mcq_questions.jsonl"
    qj_path = out_dir / f"{name.lower()}_mcq_questions.json"

    # Truncate any previous JSONL so the append-mode handle below starts clean.
    ql_path.write_text("", encoding="utf-8")

    qa_id_counter = 0  # number of records written for this dataset
    records_all = []   # every record is also kept in memory for the JSON dump

    with ql_path.open("a", encoding="utf-8") as f:
        # window_idx is the position within sampled_starts; it still advances
        # for windows that produce no MCQs, so qa_id window numbers can skip.
        for window_idx, start in enumerate(tqdm(sampled_starts, desc=f"{name} windows")):
            window = df.iloc[start:start + WINDOW_SIZE].copy().reset_index(drop=True)
            mcq_items = generate_mcq_for_window(window, profiles[name], rng_global)
            if not mcq_items:
                continue

            # The same formatted window text is attached to every question
            # generated from this window.
            context = format_window(window)
            for local_q_idx, mcq in enumerate(mcq_items):
                # e.g. "<dataset>_MCQ_000012_01": window position + question index
                qa_id = f"{name}_MCQ_{window_idx:06d}_{local_q_idx:02d}"
                record = {
                    "qa_id": qa_id,
                    "metadata": {
                        "dataset": name,
                        "window_index": int(window_idx),
                        "window_start": int(start),
                        "window_size": int(WINDOW_SIZE),
                    },
                    "context": context,
                    "mcq_type": mcq["type"],
                    "question": mcq["question"],
                    "options": mcq["options"],
                    "answer": mcq["answer"], 
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                records_all.append(record)
                qa_id_counter += 1

    # Same records again, as one pretty-printed JSON array.
    with qj_path.open("w", encoding="utf-8") as jf:
        json.dump(records_all, jf, ensure_ascii=False, indent=2)

    print(f"[INFO] {name}: saved {qa_id_counter} MCQ questions -> {ql_path}, {qj_path}")


[INFO] Generating MCQ questions for DoS


DoS windows: 100%|██████████| 1832/1832 [00:50<00:00, 36.09it/s]


[INFO] DoS: saved 5496 MCQ questions -> DoS_tf_qa/dos_mcq_questions.jsonl, DoS_tf_qa/dos_mcq_questions.json
[INFO] Generating MCQ questions for Fuzzy


Fuzzy windows: 100%|██████████| 1919/1919 [00:59<00:00, 32.47it/s]


[INFO] Fuzzy: saved 5757 MCQ questions -> Fuzzy_tf_qa/fuzzy_mcq_questions.jsonl, Fuzzy_tf_qa/fuzzy_mcq_questions.json
[INFO] Generating MCQ questions for Gear


Gear windows: 100%|██████████| 2221/2221 [01:02<00:00, 35.35it/s]


[INFO] Gear: saved 6663 MCQ questions -> Gear_tf_qa/gear_mcq_questions.jsonl, Gear_tf_qa/gear_mcq_questions.json
[INFO] Generating MCQ questions for RPM


RPM windows: 100%|██████████| 2310/2310 [01:05<00:00, 35.34it/s]


[INFO] RPM: saved 6930 MCQ questions -> RPM_tf_qa/rpm_mcq_questions.jsonl, RPM_tf_qa/rpm_mcq_questions.json
