In [21]:
import json
from pathlib import Path
from typing import Dict, Iterable, List, Optional, Set
from tqdm import tqdm
import numpy as np
import pandas as pd

In [2]:
# --- Input data ---------------------------------------------------------------
# One cleaned CSV per attack scenario; the keys double as attack labels.
DATASET_PATHS = dict(
    DoS="../../Dataset/DoS_dataset_clean.csv",
    Fuzzy="../../Dataset/Fuzzy_dataset_clean.csv",
    Gear="../../Dataset/gear_dataset_clean.csv",
    RPM="../../Dataset/RPM_dataset_clean.csv",
)
# Subset of DATASET_PATHS processed in this run.
SELECTED_DATASETS = ["DoS", "Fuzzy", "Gear", "RPM"]

# --- Windowing and sampling ---------------------------------------------------
WINDOW_SIZE = 200          # frames per window
WINDOW_STRIDE = 200        # equal to WINDOW_SIZE -> non-overlapping windows
SAMPLE_RATIO = 0.1         # fraction of windows sampled per dataset
QUESTIONS_PER_WINDOW = 3   # questions kept per sampled window
GLOBAL_SEED = 42           # seed for the shared numpy Generator

# --- Frame schema -------------------------------------------------------------
BYTE_COLUMNS = ["Byte%d" % n for n in range(1, 9)]  # Byte1 .. Byte8
ATTACK_LABELS = ["DoS", "Fuzzy", "Gear", "RPM"]

# --- Question thresholds ------------------------------------------------------
HIGH_DLC_RATIO_THRESHOLD = 0.5
WINDOW_DURATION_THRESHOLD = 2.0  # seconds
RARE_ID_RATIO_THRESHOLD = 0.3

In [3]:
# Helper functions: data loading, window formatting, question generation and sampling

def _normalize_flag_series(series: pd.Series) -> pd.Series:
    mapped = series.map({"R": 0, "T": 1})
    numeric = pd.to_numeric(series, errors="coerce")
    combined = mapped.fillna(numeric).fillna(0).astype(int)
    return combined


def load_datasets(paths: Dict[str, str]):
    """Load each CSV in ``paths`` and derive a per-dataset CAN-ID profile.

    Args:
        paths: Mapping of dataset name -> CSV path. A path that does not exist
            is reported with a ``[WARN]`` line and skipped.

    Returns:
        ``(datasets, profiles)``, two dicts keyed by dataset name.

        ``datasets`` holds the loaded DataFrames. ``Flag`` is normalized to
        ints, or added as all-zero when the column is absent.

        ``profiles`` holds, per dataset:

        * ``expected_ids``: the 5 most frequent IDs among Flag == 0 frames,
        * ``critical_ids``: the 3 most frequent IDs among Flag == 0 frames,
        * ``attack_label``: the dataset name,
        * ``unique_id_threshold``: ``max(10, len(expected_ids))``.

        If a file has no Flag == 0 frames, the ID profile falls back to all
        frames.
    """
    datasets: Dict[str, pd.DataFrame] = {}
    profiles: Dict[str, dict] = {}

    for name, path in paths.items():
        csv_path = Path(path)
        if not csv_path.exists():
            print(f"[WARN] Dataset {name} not found at {csv_path}, skipping.")
            continue

        df = pd.read_csv(csv_path)
        # Flags become ints ("R" -> 0, "T" -> 1); a file without a Flag column
        # is treated as entirely unflagged.
        if "Flag" in df.columns:
            df["Flag"] = _normalize_flag_series(df["Flag"])
        else:
            df["Flag"] = 0

        # Profile IDs from unflagged frames only. Otherwise an ID that is
        # over-represented because of flagged (injected) frames can land in the
        # "expected"/"critical" sets. The questions built from those sets
        # (e.g. "an expected control ID is missing") would then be answered True
        # for ordinary windows.
        benign_ids = df.loc[df["Flag"] == 0, "ID"]
        id_counts = (benign_ids if not benign_ids.empty else df["ID"]).value_counts()
        expected_ids = {int(x) for x in id_counts.head(5).index}
        critical_ids = {int(x) for x in id_counts.head(3).index}

        datasets[name] = df
        profiles[name] = {
            "expected_ids": expected_ids,
            "critical_ids": critical_ids,
            "attack_label": name,
            "unique_id_threshold": max(10, len(expected_ids)),
        }

        print(f"[INFO] Loaded {name}: {len(df)} rows, "
              f"{len(expected_ids)} expected IDs, {len(critical_ids)} critical IDs.")

    return datasets, profiles


def format_window(df: pd.DataFrame, hide_indices: Optional[Iterable[int]] = None) -> str:
    """Render a window of CAN frames as text, one line per frame.

    Each line has the form
    ``Timestamp=<t, 6 decimals> | ID=<id> | DLC=<dlc> | bytes=[b1, ..., b8] | Flag=<f> |``.

    Args:
        df: Frames with ``Timestamp``, ``ID``, ``DLC``, ``Byte1``..``Byte8``
            and ``Flag`` columns.
        hide_indices: 0-based row *positions* whose Flag is rendered as
            ``Flag=? (hidden)``. Positions rather than index labels are used,
            so the output does not depend on whether ``df`` has a reset index.

    Returns:
        The formatted lines joined with newlines ("" for an empty frame).
    """
    hide_set = set(hide_indices) if hide_indices is not None else set()

    # Pull each column out once. This avoids iterrows, which builds a Series
    # per row and is the slow path when thousands of windows are rendered.
    timestamps = df["Timestamp"].to_numpy()
    can_ids = df["ID"].to_numpy()
    dlcs = df["DLC"].to_numpy()
    flags = df["Flag"].to_numpy()
    payloads = df[BYTE_COLUMNS].to_numpy()

    formatted_rows = []
    for pos in range(len(df)):
        byte_vals = [int(b) for b in payloads[pos]]
        flag_repr = "Flag=? (hidden)" if pos in hide_set else f"Flag={int(flags[pos])}"
        formatted_rows.append(
            f"Timestamp={timestamps[pos]:.6f} | "
            f"ID={int(can_ids[pos])} | DLC={int(dlcs[pos])} | "
            f"bytes={byte_vals} | {flag_repr} |"
        )
    return "\n".join(formatted_rows)


def _has_constant_payload(df: pd.DataFrame) -> bool:
    for _, group in df.groupby("ID"):
        if len(group) < 2:
            continue
        if all(group[col].nunique(dropna=False) == 1 for col in BYTE_COLUMNS):
            return True
    return False

# Create T/F questions
def generate_tf_questions(df: pd.DataFrame, profile: dict,
                          rng: np.random.Generator) -> List[dict]:
    """Build the full pool of True/False questions for one window of CAN frames.

    Args:
        df: One window of frames with ``Timestamp``, ``ID``, ``DLC``,
            ``Byte1``..``Byte8`` and integer ``Flag`` columns. Rows are read
            positionally, in order.
        profile: Per-dataset profile as built by ``load_datasets``. Reads
            ``expected_ids``, ``critical_ids``, ``attack_label`` and
            ``unique_id_threshold``, with fallbacks when a key is missing.
        rng: Shared generator. It is drawn from exactly once per non-empty
            window (to pick the attack label posed in the last question).

    Returns:
        A list of ``{"question", "answer", "context_type"}`` dicts, in a fixed
        order. ``answer`` is ``"True"`` or ``"False"``. ``context_type`` is
        ``"full"``, except for the hidden-flag questions (``"hidden_last_flag"``
        / ``"hidden_middle_flag"``). Those must be paired with a rendering that
        masks that frame's Flag. Empty when ``df`` has no rows.
    """
    qas: List[dict] = []
    total_frames = len(df)
    if total_frames == 0:
        return qas

    id_counts = df["ID"].value_counts()  # most frequent ID first
    unique_ids = set(df["ID"].tolist())
    byte_matrix = df[BYTE_COLUMNS]
    expected_ids: Set[int] = profile.get("expected_ids", set())
    critical_ids: Set[int] = profile.get("critical_ids", set())
    unique_id_threshold = profile.get("unique_id_threshold", 20)
    # A non-zero Flag is treated as marking an attack frame
    # (_normalize_flag_series maps "T" -> 1, "R" -> 0).
    has_flagged_frame = bool((df["Flag"] != 0).any())

    def make_entry(question: str, condition, context_type: str = "full") -> dict:
        # Strings pass through unchanged; anything else (Python or numpy bool)
        # is reduced to "True"/"False" by truthiness.
        answer = condition if isinstance(condition, str) else ("True" if condition else "False")
        return {"question": question, "answer": answer, "context_type": context_type}

    FLOOD_THRESHOLD = 3          # per-ID frame count above which an ID counts as repeated
    FLOOD_SIGNATURE_SHARE = 0.6  # dominant-ID share required by the flood signature

    # NOTE: question order matters. The caller samples questions by position
    # with the shared rng, so reordering changes which questions get picked.

    # --- Flags and ID frequencies ---------------------------------------------
    qas.append(make_entry("All CAN frames in this window have Flag equal to 0.",
                          bool((df["Flag"] == 0).all())))
    qas.append(make_entry("At least one CAN frame in this window carries a non-zero Flag.",
                          has_flagged_frame))
    max_id_share = id_counts.max() / total_frames
    qas.append(make_entry("A single CAN ID accounts for more than half of the frames in this window.",
                          bool(max_id_share > 0.5)))
    qas.append(make_entry(f"Some CAN ID appears more than {FLOOD_THRESHOLD} times in this window.",
                          bool((id_counts > FLOOD_THRESHOLD).any())))

    # --- Timing ---------------------------------------------------------------
    # is_monotonic_increasing accepts equal neighbours, so the wording must be
    # "non-decreasing", not "strictly". Windows with duplicate timestamps do
    # occur (see the shared-timestamp question below) and are answered True.
    qas.append(make_entry("Timestamps are non-decreasing across consecutive frames.",
                          bool(df["Timestamp"].is_monotonic_increasing)))
    suppression_threshold = 1e-3  # seconds
    flooding_threshold = 5e-4     # seconds
    if total_frames > 1:
        diffs = np.diff(df["Timestamp"].to_numpy())
        has_large_gap = bool((diffs > suppression_threshold).any())
        all_small_gaps = bool((diffs < flooding_threshold).all())
    else:
        has_large_gap = False
        all_small_gaps = False
    qas.append(make_entry(
        f"Some inter-frame gap exceeds {suppression_threshold} seconds (possible suppression).",
        has_large_gap,
    ))
    qas.append(make_entry(
        f"All inter-frame gaps stay below {flooding_threshold} seconds (possible flooding).",
        all_small_gaps,
    ))

    # --- Payloads, ID coverage and window shape --------------------------------
    qas.append(make_entry(
        "Some CAN ID transmits an identical payload across the entire window.",
        _has_constant_payload(df),
    ))
    qas.append(make_entry(
        "At least one expected control ID is missing from this window.",
        bool(expected_ids) and any(eid not in unique_ids for eid in expected_ids),
    ))
    qas.append(make_entry(
        "At least one data byte in this window is greater than 200.",
        bool((byte_matrix > 200).any().any()),
    ))
    qas.append(make_entry(
        f"The number of distinct CAN IDs exceeds {unique_id_threshold} in this window.",
        len(unique_ids) > unique_id_threshold,
    ))
    high_dlc_ratio = (df["DLC"] >= 8).mean()
    qas.append(make_entry(
        f"At least {int(HIGH_DLC_RATIO_THRESHOLD * 100)}% of frames use DLC >= 8.",
        high_dlc_ratio >= HIGH_DLC_RATIO_THRESHOLD,
    ))
    window_duration = float(df["Timestamp"].iloc[-1] - df["Timestamp"].iloc[0]) if total_frames > 1 else 0.0
    qas.append(make_entry(
        f"The time span of this window exceeds {WINDOW_DURATION_THRESHOLD} seconds.",
        window_duration > WINDOW_DURATION_THRESHOLD,
    ))
    # Share of frames whose ID occurs exactly once in the window.
    rare_id_ratio = (id_counts == 1).sum() / total_frames
    qas.append(make_entry(
        f"More than {int(RARE_ID_RATIO_THRESHOLD * 100)}% of frames are unique IDs (possible fuzzing).",
        rare_id_ratio > RARE_ID_RATIO_THRESHOLD,
    ))

    # --- Dominant-ID payload behaviour ------------------------------------------
    dominant_id = int(id_counts.index[0])
    dominant_group = df[df["ID"] == dominant_id]
    if not dominant_group.empty:
        byte1_series = dominant_group["Byte1"].tolist()
        mode_value = int(pd.Series(byte1_series).mode().iloc[0])
        # Look for a run of >= 3 consecutive transmissions carrying the mode value.
        run = 0
        has_run = False
        for val in byte1_series:
            if val == mode_value:
                run += 1
                if run >= 3:
                    has_run = True
                    break
            else:
                run = 0
        qas.append(make_entry(
            f"ID {dominant_id} produces Byte1 = {mode_value} for three consecutive transmissions.",
            has_run,
        ))
        all_zero_payload = bool((dominant_group[BYTE_COLUMNS] == 0).all(axis=None))
        qas.append(make_entry(
            f"ID {dominant_id} maintains an all-zero payload across the window.",
            all_zero_payload,
        ))

    # --- Cross-frame heuristics ---------------------------------------------------
    # True when some timestamp is shared by frames with more than one distinct ID.
    duplicate_timestamps = total_frames > 1 and bool(
        df.groupby("Timestamp")["ID"].nunique(dropna=False).gt(1).any()
    )
    qas.append(make_entry(
        "Multiple different CAN IDs share the exact same timestamp in this window.",
        duplicate_timestamps,
    ))
    flood_signature = bool((max_id_share > FLOOD_SIGNATURE_SHARE) and all_small_gaps)
    qas.append(make_entry(
        "The window matches the heuristic signature of a flooding attack (high share + dense timing).",
        flood_signature,
    ))
    critical_disruption = bool(critical_ids) and any(
        (id_counts.get(cid, 0) > FLOOD_THRESHOLD) or
        bool(((df["ID"] == cid) & (df["Flag"] != 0)).any())
        for cid in critical_ids
    )
    qas.append(make_entry(
        "A critical-control ID exhibits anomalous activity in this window.",
        critical_disruption,
    ))

    # --- Hidden-flag questions (paired with a masked rendering by the caller) -------
    qas.append(make_entry(
        "Given the window where the last frame's Flag is hidden, does the final frame have Flag = 0?",
        bool(df["Flag"].iloc[-1] == 0),
        context_type="hidden_last_flag",
    ))
    if total_frames > 2:
        mid_index = total_frames // 2
        qas.append(make_entry(
            "Given the window where the middle frame's Flag is hidden, does that frame have Flag = 0?",
            bool(df["Flag"].iloc[mid_index] == 0),
            context_type="hidden_middle_flag",
        ))

    # --- Attack-type question ---------------------------------------------------------
    # The dataset label alone is not enough: a window with no flagged frames is
    # normal traffic, even though it comes from that attack's capture file. The
    # answer is True only when the window contains flagged frames AND the posed
    # label matches. rng is drawn unconditionally to keep the caller's stream stable.
    attack_label = profile.get("attack_label", "DoS")
    predicted_label = rng.choice(ATTACK_LABELS)
    qas.append(make_entry(
        f"This window most plausibly corresponds to the {predicted_label} attack type.",
        has_flagged_frame and attack_label == predicted_label,
    ))

    return qas


def iter_window_starts(num_rows: int) -> List[int]:
    """Return the start offsets of every complete window of WINDOW_SIZE rows.

    Offsets are WINDOW_STRIDE apart and any trailing partial window is dropped.
    With fewer than WINDOW_SIZE rows the result is empty.
    """
    last_start = num_rows - WINDOW_SIZE
    if last_start < 0:
        return []
    return [offset for offset in range(0, last_start + 1, WINDOW_STRIDE)]


def sample_window_indices(starts: List[int], rng: np.random.Generator) -> List[int]:
    """Draw a sorted random subset of window starts, without replacement.

    Keeps ``int(len(starts) * SAMPLE_RATIO)`` starts, clamped to at least one
    and at most all of them. Returns [] for empty input without using ``rng``.
    """
    if not starts:
        return []
    wanted = int(len(starts) * SAMPLE_RATIO)
    n_pick = min(len(starts), max(1, wanted))
    chosen = rng.choice(starts, size=n_pick, replace=False)
    return sorted(chosen)


def select_questions(qas: List[dict], rng: np.random.Generator) -> List[dict]:
    """Keep at most QUESTIONS_PER_WINDOW questions, chosen at random.

    If the pool is already small enough, the same list object is returned and
    ``rng`` is not used. Otherwise distinct positions are drawn without
    replacement and the questions are returned in draw order.
    """
    if len(qas) > QUESTIONS_PER_WINDOW:
        picked = rng.choice(len(qas), size=QUESTIONS_PER_WINDOW, replace=False)
        return [qas[int(pos)] for pos in picked]
    return qas

In [4]:
# Create the True/False QA dataset: sample windows per dataset, generate and
# subsample questions, and write both a JSONL file and a JSON array.

datasets, profiles = load_datasets({k: v for k, v in DATASET_PATHS.items() if k in SELECTED_DATASETS})
rng = np.random.default_rng(GLOBAL_SEED)


def render_context(window: pd.DataFrame, context_type: str) -> str:
    """Render ``window`` for a question's context_type.

    "hidden_last_flag" masks the last frame's Flag and "hidden_middle_flag"
    masks the frame at ``len(window) // 2``. Any other type, or a window too
    short for the requested mask, falls back to the full rendering.
    """
    n_frames = len(window)
    if context_type == "hidden_last_flag" and n_frames > 0:
        return format_window(window, hide_indices={n_frames - 1})
    if context_type == "hidden_middle_flag" and n_frames > 2:
        return format_window(window, hide_indices={n_frames // 2})
    return format_window(window)


for name, df in datasets.items():
    print(f"[INFO] Generating questions for {name}")
    starts = iter_window_starts(len(df))
    sampled_starts = sample_window_indices(starts, rng)
    output_dir = Path(f"{name}_tf_qa")
    output_dir.mkdir(parents=True, exist_ok=True)
    q_path = output_dir / f"{name.lower()}_questions.jsonl"

    records: List[dict] = []
    # "w" truncates output from any previous run, so re-running the cell is idempotent.
    with q_path.open("w", encoding="utf-8") as f:
        for window_idx, start in enumerate(tqdm(sampled_starts, desc=f"{name} windows")):
            window = df.iloc[start:start + WINDOW_SIZE].copy().reset_index(drop=True)
            qa_items = generate_tf_questions(window, profiles[name], rng)
            selected_qas = select_questions(qa_items, rng)

            # Render each context variant at most once, and only if a selected
            # question needs it: formatting is the per-window hot spot.
            rendered: Dict[str, str] = {}
            for local_q_idx, qa in enumerate(selected_qas):
                context_type = qa.get("context_type", "full")
                if context_type not in rendered:
                    rendered[context_type] = render_context(window, context_type)
                record = {
                    "qa_id": f"{name}_{window_idx:06d}_{local_q_idx:02d}",
                    "metadata": {
                        "dataset": name,
                        "window_index": int(window_idx),
                        "window_start": int(start),
                        "window_size": int(WINDOW_SIZE),
                    },
                    "context": rendered[context_type],
                    "context_type": context_type,
                    "question": qa["question"],
                    "ground_truth": qa["answer"],
                }
                f.write(json.dumps(record, ensure_ascii=False) + "\n")
                records.append(record)

    print(f"[INFO] {name}: saved {len(records)} questions -> {q_path}")

    # Same records as the JSONL, written as one JSON array (no need to re-read the file).
    json_path = q_path.with_suffix(".json")
    with json_path.open("w", encoding="utf-8") as f_out:
        json.dump(records, f_out, ensure_ascii=False, indent=2)

    print(f"[INFO] {name}: saved JSON array -> {json_path}")


[INFO] Loaded DoS: 3665771 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded Fuzzy: 3838860 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded Gear: 4443142 rows, 5 expected IDs, 3 critical IDs.
[INFO] Loaded RPM: 4621702 rows, 5 expected IDs, 3 critical IDs.
[INFO] Generating questions for DoS


DoS windows:   0%|          | 0/1832 [00:00<?, ?it/s]

[INFO] DoS: saved 5496 questions -> DoS_tf_qa/dos_questions.jsonl
[INFO] DoS: saved JSON array -> DoS_tf_qa/dos_questions.json
[INFO] Generating questions for Fuzzy


Fuzzy windows:   0%|          | 0/1919 [00:00<?, ?it/s]

[INFO] Fuzzy: saved 5757 questions -> Fuzzy_tf_qa/fuzzy_questions.jsonl
[INFO] Fuzzy: saved JSON array -> Fuzzy_tf_qa/fuzzy_questions.json
[INFO] Generating questions for Gear


Gear windows:   0%|          | 0/2221 [00:00<?, ?it/s]

[INFO] Gear: saved 6663 questions -> Gear_tf_qa/gear_questions.jsonl
[INFO] Gear: saved JSON array -> Gear_tf_qa/gear_questions.json
[INFO] Generating questions for RPM


RPM windows:   0%|          | 0/2310 [00:00<?, ?it/s]

[INFO] RPM: saved 6930 questions -> RPM_tf_qa/rpm_questions.jsonl
[INFO] RPM: saved JSON array -> RPM_tf_qa/rpm_questions.json
