In [1]:
import json
import time
from collections import Counter, defaultdict

import numpy as np
import pandas as pd

from sklearn.compose import ColumnTransformer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder
from sklearn.metrics import (
    accuracy_score,
    precision_recall_fscore_support,
    classification_report,
    confusion_matrix,
)
from sklearn.ensemble import HistGradientBoostingClassifier
from sklearn.utils.class_weight import compute_sample_weight
from sklearn.inspection import permutation_importance

In [3]:
# ============================================================================
# MITRE ATT&CK Mapping for NSL-KDD Attack Categories
# ============================================================================
# ATTACK_TO_MITRE: NSL-KDD attack label -> list of MITRE ATT&CK technique IDs.
# Keys are the normalised label strings (no surrounding whitespace, no trailing
# dot) that map_to_mitre_attack() looks up. Values are lists so one label could
# carry several techniques; every entry currently holds exactly one ID.
# The section comments follow the DoS / Probe / R2L / U2R sets defined in
# create_attack_category_labels(). "normal" is intentionally absent.
ATTACK_TO_MITRE = {
    # DoS
    "back": ["T1499.002"],
    "land": ["T1499.004"],
    "neptune": ["T1498.001"],
    "pod": ["T1499.004"],
    "smurf": ["T1498.001"],
    "teardrop": ["T1499.004"],
    "apache2": ["T1499.002"],
    "udpstorm": ["T1498.001"],
    "processtable": ["T1499.003"],
    "mailbomb": ["T1499.002"],
    # Probe
    "ipsweep": ["T1046"],
    "nmap": ["T1046"],
    "portsweep": ["T1046"],
    "satan": ["T1046"],
    "mscan": ["T1046"],
    "saint": ["T1046"],
    # R2L
    "ftp_write": ["T1071.002"],
    "guess_passwd": ["T1110.001"],
    "imap": ["T1078"],
    "multihop": ["T1090"],
    "phf": ["T1190"],
    "spy": ["T1056.001"],
    "warezclient": ["T1071.001"],
    "warezmaster": ["T1071.001"],
    "sendmail": ["T1190"],
    "named": ["T1190"],
    "snmpgetattack": ["T1046"],
    "snmpguess": ["T1110.001"],
    "xlock": ["T1110.001"],
    "xsnoop": ["T1056.001"],
    # U2R
    "buffer_overflow": ["T1068"],
    "loadmodule": ["T1068"],
    # NOTE(review): confirm this sub-technique ID is the intended one for a Perl-based exploit.
    "perl": ["T1059.006"],
    "rootkit": ["T1014"],
    "ps": ["T1057"],
    "sqlattack": ["T1190"],
    "xterm": ["T1068"],
    "httptunnel": ["T1572"],
}

# ATTACK_STAGES: kill-chain stage name -> NSL-KDD labels assigned to that stage.
# map_to_mitre_attack() increments a stage counter once for every stage whose
# list contains the row's label, so:
#   * "rootkit" appears under both "Privilege Escalation" and "Persistence" and
#     is therefore counted in both stages;
#   * "imap" and "warezclient" have a technique in ATTACK_TO_MITRE but are not
#     listed under any stage here, so they never contribute to stage counts.
#     NOTE(review): confirm whether that omission is intentional.
# Stage membership is independent of the DoS/Probe/R2L/U2R category split
# (e.g. "sqlattack" is in the U2R set but staged as "Initial Access").
ATTACK_STAGES = {
    "Reconnaissance": ["ipsweep", "nmap", "portsweep", "satan", "mscan", "saint"],
    "Initial Access": ["ftp_write", "phf", "sendmail", "named", "snmpgetattack", "sqlattack"],
    "Credential Access": ["guess_passwd", "snmpguess", "xlock"],
    "Privilege Escalation": ["buffer_overflow", "loadmodule", "perl", "rootkit", "ps", "xterm", "httptunnel"],
    "Persistence": ["spy", "xsnoop", "rootkit"],
    "Lateral Movement": ["multihop", "warezmaster"],
    "Impact": ["back", "land", "neptune", "pod", "smurf", "teardrop", "apache2", "udpstorm", "processtable", "mailbomb"],
}


# ============================================================================
# Data loading and labels
# ============================================================================
def load_nsl_kdd(trainPath="./dataset/KDDTrain+.txt", testPath="./dataset/KDDTest+.txt"):
    """
    Read the NSL-KDD train and test files into DataFrames.

    Both files are parsed as header-less CSV and given the same 43 column
    names: 41 feature columns followed by "label" and "difficulty".

    Parameters
    ----------
    trainPath, testPath : str
        Locations of the training and test files.

    Returns
    -------
    (trainDf, testDf) : tuple of pandas.DataFrame
    """
    # Column names in file order; kept as one whitespace-separated listing so
    # the order is easy to compare against the dataset description.
    columns = """
        duration protocol_type service flag src_bytes dst_bytes land
        wrong_fragment urgent hot num_failed_logins logged_in num_compromised
        root_shell su_attempted num_root num_file_creations num_shells
        num_access_files num_outbound_cmds is_host_login is_guest_login
        count srv_count serror_rate srv_serror_rate rerror_rate srv_rerror_rate
        same_srv_rate diff_srv_rate srv_diff_host_rate
        dst_host_count dst_host_srv_count dst_host_same_srv_rate
        dst_host_diff_srv_rate dst_host_same_src_port_rate
        dst_host_srv_diff_host_rate dst_host_serror_rate
        dst_host_srv_serror_rate dst_host_rerror_rate dst_host_srv_rerror_rate
        label difficulty
    """.split()

    trainDf, testDf = (pd.read_csv(path, names=columns) for path in (trainPath, testPath))
    return trainDf, testDf


def create_attack_category_labels(df):
    """
    Convert the raw "label" column into integer attack-category codes.

    Codes: 0 = Normal, 1 = DoS, 2 = Probe, 3 = R2L, 4 = U2R.
    Labels are stringified, stripped of surrounding whitespace and of any
    trailing "." before lookup. Anything not in the lookup becomes -1 so the
    caller can exclude it from training / evaluation.

    Returns a NumPy integer array with one code per row of ``df``.
    """
    membersByCode = {
        1: (
            "back", "land", "neptune", "pod", "smurf", "teardrop",
            "apache2", "udpstorm", "processtable", "mailbomb",
        ),
        2: ("ipsweep", "nmap", "portsweep", "satan", "mscan", "saint"),
        3: (
            "ftp_write", "guess_passwd", "imap", "multihop", "phf", "spy",
            "warezclient", "warezmaster", "sendmail", "named",
            "snmpgetattack", "snmpguess", "xlock", "xsnoop",
        ),
        4: (
            "buffer_overflow", "loadmodule", "perl", "rootkit",
            "ps", "sqlattack", "xterm", "httptunnel",
        ),
    }

    # Flatten {code: names} into {name: code}; "normal" is the only code-0 label.
    codeByLabel = {"normal": 0}
    for code, names in membersByCode.items():
        for name in names:
            codeByLabel[name] = code

    cleaned = df["label"].astype(str).str.strip().str.rstrip(".")
    codes = cleaned.map(codeByLabel)  # unmapped labels become NaN here
    return codes.fillna(-1).astype(int).to_numpy()


# ============================================================================
# Temporal features (corrected: rolling is per connGroup, not global)
# ============================================================================
def extract_temporal_features(df, windowSizes=(2, 5, 10, 30, 100)):
    """
    Derive sequence-style features from row order.

    Rows are grouped by the string key "<protocol_type>_<service>_<flag>" and,
    for every window size, six trailing rolling statistics are computed inside
    each group (count / mean duration / summed src and dst bytes / mean failed
    logins / mean serror_rate). On top of that the function adds:

    * protocol_change / service_change / flag_change - 1 when the value differs
      from the previous row in global row order (the first row is always 1);
    * time_since_failed_login - number of rows since the last row with
      num_failed_logins > 0 in global row order, capped at 999 (999 when no
      such row has been seen yet);
    * conn_burst / traffic_spike - 1 when the window-10 signal exceeds 2x / 3x
      its value 10 rows earlier. The 10-row shift is applied over the whole
      frame, not per group. Both are 0 when window size 10 was not requested.

    Returns a DataFrame indexed like ``df`` with remaining NaNs replaced by 0.
    """
    frame = df.copy()

    groupKey = (
        frame["protocol_type"].astype(str)
        + "_"
        + frame["service"].astype(str)
        + "_"
        + frame["flag"].astype(str)
    )
    byGroup = frame.groupby(groupKey, sort=False)
    temporal = pd.DataFrame(index=frame.index)

    # (output prefix, source column, rolling statistic) - order fixes column order.
    rollingSpecs = (
        ("conn_count", "duration", "count"),
        ("avg_duration", "duration", "mean"),
        ("total_src_bytes", "src_bytes", "sum"),
        ("total_dst_bytes", "dst_bytes", "sum"),
        ("failed_login_rate", "num_failed_logins", "mean"),
        ("serror_rate_avg", "serror_rate", "mean"),
    )
    for window in windowSizes:
        for prefix, sourceCol, statName in rollingSpecs:
            temporal[f"{prefix}_{window}"] = byGroup[sourceCol].transform(
                lambda s, w=window, stat=statName: getattr(s.rolling(window=w, min_periods=1), stat)()
            )

    # Row-to-row change flags over the global sequence.
    for flagName, sourceCol in (
        ("protocol_change", "protocol_type"),
        ("service_change", "service"),
        ("flag_change", "flag"),
    ):
        temporal[flagName] = frame[sourceCol].ne(frame[sourceCol].shift(1)).astype(int)

    # Rows elapsed since the most recent failed login (global sequence).
    positions = np.arange(len(frame), dtype=int)
    sawFailure = frame["num_failed_logins"].to_numpy() > 0
    lastFailurePos = (
        pd.Series(np.where(sawFailure, positions, np.nan)).ffill().fillna(-1).to_numpy(dtype=int)
    )
    elapsed = np.clip(positions - lastFailurePos, 0, 999)
    temporal["time_since_failed_login"] = np.where(lastFailurePos == -1, 999, elapsed)

    # Burst indicator: window-10 count more than double its value 10 rows back.
    if "conn_count_10" in temporal.columns:
        countNow = temporal["conn_count_10"]
        countBefore = countNow.shift(10).fillna(0)
        temporal["conn_burst"] = countNow.gt(countBefore * 2).astype(int)
    else:
        temporal["conn_burst"] = 0

    # Spike indicator: window-10 byte volume more than triple its value 10 rows back.
    if {"total_src_bytes_10", "total_dst_bytes_10"} <= set(temporal.columns):
        srcBytes = temporal["total_src_bytes_10"]
        dstBytes = temporal["total_dst_bytes_10"]
        volumeNow = srcBytes + dstBytes
        volumeBefore = srcBytes.shift(10).fillna(0) + dstBytes.shift(10).fillna(0)
        temporal["traffic_spike"] = volumeNow.gt(volumeBefore * 3).astype(int)
    else:
        temporal["traffic_spike"] = 0

    return temporal.fillna(0)


# ============================================================================
# Preprocessing and model
# ============================================================================
def build_preprocessor(trainX):
    """
    Assemble the ColumnTransformer used in front of the classifier.

    The three string columns (protocol_type, service, flag) are one-hot encoded
    into a dense array, with categories unseen during fit encoded as all-zero.
    Every other column of ``trainX`` is passed through unchanged, in its
    original order. Output feature names are not prefixed with the transformer
    name.
    """
    categoricalCols = ["protocol_type", "service", "flag"]
    passthroughCols = [name for name in trainX.columns if name not in categoricalCols]

    encoder = OneHotEncoder(handle_unknown="ignore", sparse_output=False)
    steps = [
        ("cat", encoder, categoricalCols),
        ("num", "passthrough", passthroughCols),
    ]
    return ColumnTransformer(
        transformers=steps,
        remainder="drop",
        verbose_feature_names_out=False,
    )


def train_multiclass_model(trainX, trainY):
    """
    Fit a preprocessing + HistGradientBoosting pipeline on the given data.

    Parameters
    ----------
    trainX : pandas.DataFrame
        Feature frame; must contain the categorical columns expected by
        build_preprocessor().
    trainY : array-like
        Integer class labels aligned with ``trainX``.

    Returns
    -------
    sklearn.pipeline.Pipeline
        Fitted pipeline with steps "preprocess" and "model".

    Notes
    -----
    Per-row weights from ``compute_sample_weight("balanced")`` are forwarded to
    the classifier so that rare classes are not swamped by frequent ones. Early
    stopping is enabled with a 10% validation split, so fewer than ``max_iter``
    boosting iterations may actually be run. The elapsed fit time is printed.
    """
    preprocessor = build_preprocessor(trainX)

    model = HistGradientBoostingClassifier(
        max_iter=400,
        learning_rate=0.05,
        max_leaf_nodes=63,
        min_samples_leaf=20,
        l2_regularization=0.1,
        random_state=42,
        early_stopping=True,
        validation_fraction=0.1,
        verbose=0,
    )

    pipeline = Pipeline([("preprocess", preprocessor), ("model", model)])

    sampleWeight = compute_sample_weight(class_weight="balanced", y=trainY)

    # perf_counter is monotonic, so the reported duration cannot be skewed by
    # wall-clock adjustments that happen while fitting.
    startTime = time.perf_counter()
    # The "model__" prefix routes the weights to the classifier step only.
    pipeline.fit(trainX, trainY, model__sample_weight=sampleWeight)
    trainTime = time.perf_counter() - startTime

    print(f"Training completed in {trainTime:.2f}s")
    return pipeline


def evaluate_multiclass(pipeline, testX, testY, validMask=None):
    """
    Score a fitted pipeline on the test set and print a text report.

    Parameters
    ----------
    pipeline : fitted estimator exposing ``predict``
    testX : feature frame passed to ``pipeline.predict``
    testY : numpy.ndarray
        Integer labels aligned with ``testX`` (may contain -1 for unknown).
    validMask : boolean array, optional
        Rows to include in the metrics. Defaults to all rows.

    Returns
    -------
    (predY, metrics)
        ``predY`` holds predictions for every row of ``testX`` (not only the
        masked ones). ``metrics`` is a dict with weighted precision / recall /
        F1, accuracy, and macro F1, all computed on the masked rows.
    """
    predY = pipeline.predict(testX)

    if validMask is None:
        validMask = np.ones_like(testY, dtype=bool)

    testYValid = testY[validMask]
    predYValid = predY[validMask]

    classNames = ["Normal", "DoS", "Probe", "R2L", "U2R"]
    classIds = list(range(len(classNames)))

    acc = accuracy_score(testYValid, predYValid)
    prec, rec, f1, _ = precision_recall_fscore_support(
        testYValid, predYValid, average="weighted", zero_division=0
    )
    macroF1 = precision_recall_fscore_support(
        testYValid, predYValid, average="macro", zero_division=0
    )[2]

    print("=" * 80)
    print("MULTI-CLASS ATTACK DETECTION RESULTS")
    print("=" * 80)
    print(f"Overall Accuracy (valid labels): {acc:.4f}")
    print(f"Weighted F1 (valid labels):      {f1:.4f}")
    print(f"Macro F1 (valid labels):         {macroF1:.4f}")
    print()
    # Pin the label set explicitly: without ``labels`` the report infers the
    # classes from the data and raises when that count differs from
    # ``target_names`` (e.g. a class absent from both y_true and y_pred, or
    # stray -1 labels when no mask is supplied).
    print(
        classification_report(
            testYValid,
            predYValid,
            labels=classIds,
            target_names=classNames,
            digits=4,
            zero_division=0,
        )
    )

    cm = confusion_matrix(testYValid, predYValid, labels=classIds)
    cmDf = pd.DataFrame(cm, index=classNames, columns=classNames)
    print("Confusion Matrix (valid labels):")
    print(cmDf)

    metrics = {
        "accuracy": float(acc),
        "precision": float(prec),
        "recall": float(rec),
        "f1score": float(f1),
        "macroF1": float(macroF1),
    }
    return predY, metrics


# ============================================================================
# Analysis helpers
# ============================================================================
# Pattern labels emitted for a window that mixes several attack categories.
_PATTERN_KILL_CHAIN = "Reconnaissance -> Exploitation (Kill Chain)"
_PATTERN_PARALLEL = "Reconnaissance + Exploitation (Parallel)"
_PATTERN_PRIV_ESC = "Initial Access -> Privilege Escalation"
_PATTERN_DISTRACTION = "DoS + Other Attack (Distraction Tactic)"
_PATTERN_DEFAULT = "Multi-vector Attack"

# pattern -> (score contribution, reason code, human-readable reason)
_PATTERN_EVIDENCE = {
    _PATTERN_KILL_CHAIN: (
        55,
        "KILL_CHAIN_ORDERED",
        "Probe detected before exploitation in the same window (kill-chain progression)",
    ),
    _PATTERN_PARALLEL: (
        45,
        "RECON_EXPLOIT_PARALLEL",
        "Reconnaissance and exploitation co-occur in the same window",
    ),
    _PATTERN_PRIV_ESC: (
        55,
        "ACCESS_TO_ROOT_CHAIN",
        "R2L and U2R both present (access followed by privilege escalation)",
    ),
    _PATTERN_DISTRACTION: (
        25,
        "DOS_DISTRACTION",
        "DoS occurs alongside other attack types (possible distraction)",
    ),
}


def _classify_attack_pattern(window, uniqueAttacks):
    """Return (pattern label, order evidence) for one window with >= 2 attack classes."""
    hasDoS = 1 in uniqueAttacks
    hasProbe = 2 in uniqueAttacks
    hasR2L = 3 in uniqueAttacks
    hasU2R = 4 in uniqueAttacks

    if hasProbe and (hasR2L or hasU2R):
        # Positions are indices into the attack-only subsequence of the window,
        # i.e. normal predictions are skipped before indexing.
        attackOrder = [int(p) for p in window if p > 0]
        probeIdx = next((i for i, a in enumerate(attackOrder) if a == 2), None)
        exploitIdx = next((i for i, a in enumerate(attackOrder) if a in (3, 4)), None)
        orderEvidence = {"probe_idx": probeIdx, "exploit_idx": exploitIdx}
        probeFirst = probeIdx is not None and exploitIdx is not None and probeIdx < exploitIdx
        return (_PATTERN_KILL_CHAIN if probeFirst else _PATTERN_PARALLEL), orderEvidence

    if hasR2L and hasU2R:
        return _PATTERN_PRIV_ESC, {}
    if hasDoS:
        return _PATTERN_DISTRACTION, {}
    return _PATTERN_DEFAULT, {}


def _score_attack_window(pattern, uniqueAttacks, burstCount, spikeCount, burstHighThreshold, spikeHighThreshold):
    """Return (score capped to 0..100, reason codes, reason texts) for one window."""
    score = 0
    reasonCodes = []
    reasonText = []

    # Pattern-based weight
    if pattern in _PATTERN_EVIDENCE:
        points, code, text = _PATTERN_EVIDENCE[pattern]
        score += points
        reasonCodes.append(code)
        reasonText.append(text)

    # Presence-based weight (order fixes the order of the reported reasons)
    presenceEvidence = (
        (4, 35, "HAS_U2R", "U2R present (privilege escalation / high impact)"),
        (3, 20, "HAS_R2L", "R2L present (remote-to-local compromise)"),
        (2, 10, "HAS_PROBE", "Probe present (reconnaissance / scanning)"),
    )
    for classId, points, code, text in presenceEvidence:
        if classId in uniqueAttacks:
            score += points
            reasonCodes.append(code)
            reasonText.append(text)

    # Volume/temporal-based weight: only the "high" level adds to the score,
    # lower non-zero activity is reported as a reason without points.
    if burstCount >= burstHighThreshold:
        score += 10
        reasonCodes.append("HIGH_BURST")
        reasonText.append(f"High connection burst activity (burst_count={burstCount})")
    elif burstCount > 0:
        reasonCodes.append("BURST_PRESENT")
        reasonText.append(f"Connection burst activity present (burst_count={burstCount})")

    if spikeCount >= spikeHighThreshold:
        score += 10
        reasonCodes.append("HIGH_TRAFFIC_SPIKE")
        reasonText.append(f"High traffic spike activity (traffic_spikes={spikeCount})")
    elif spikeCount > 0:
        reasonCodes.append("TRAFFIC_SPIKE_PRESENT")
        reasonText.append(f"Traffic spikes present (traffic_spikes={spikeCount})")

    return int(max(0, min(100, score))), reasonCodes, reasonText


def detect_attack_sequences(dfWithTemporal, predictions, windowSize=20, burstHighThreshold=8, spikeHighThreshold=8):
    """
    Detect multi-stage attack patterns and export explicit severity reasoning.

    The prediction sequence is scanned with windows of ``windowSize`` rows that
    advance by half a window (at least one row). Only full windows are
    examined, and a window is reported only when it contains at least two
    distinct attack classes (codes > 0).

    Parameters
    ----------
    dfWithTemporal : pandas.DataFrame
        Frame aligned row-for-row with ``predictions``. The optional columns
        "conn_burst" and "traffic_spike" are summed per window when present.
    predictions : array-like of int
        Predicted class codes (0 Normal, 1 DoS, 2 Probe, 3 R2L, 4 U2R). Any
        sequence accepted by ``numpy.asarray`` works, including plain lists.
    windowSize : int
        Number of rows per window.
    burstHighThreshold, spikeHighThreshold : int
        Per-window counts at or above which burst / spike activity is labelled
        "high" and adds 10 points to the score.

    Returns
    -------
    list of dict
        One entry per reported window with keys: window_start, window_end,
        attack_types, attack_type_counts, severity, severity_score, pattern,
        reason_codes, reason_text, order_evidence, temporal_indicators.

    Notes
    -----
    Severity is derived from the numeric score alone (>= 70 CRITICAL,
    >= 45 HIGH, otherwise MEDIUM), so the label always matches the evidence
    listed in ``reason_codes``.
    """
    classNames = ["Normal", "DoS", "Probe", "R2L", "U2R"]
    # Coerce once so boolean masking below works for lists and Series as well.
    predictions = np.asarray(predictions)
    attackSequences = []

    stepSize = max(1, windowSize // 2)
    n = len(predictions)

    hasBurstCol = "conn_burst" in dfWithTemporal.columns
    hasSpikeCol = "traffic_spike" in dfWithTemporal.columns

    for startIdx in range(0, n - windowSize + 1, stepSize):
        endIdx = startIdx + windowSize
        window = predictions[startIdx:endIdx]

        # Count classes in the window (indices 0..4)
        counts = np.bincount(window, minlength=5)
        attackTypeCounts = {name: int(counts[code]) for code, name in enumerate(classNames)}

        uniqueAttacks = set(np.unique(window[window > 0]))
        if len(uniqueAttacks) < 2:
            continue

        # Temporal indicators: keys are only emitted for columns that exist.
        temporalIndicators = {}
        burstCount = 0
        spikeCount = 0
        windowRows = dfWithTemporal.iloc[startIdx:endIdx]
        if hasBurstCol:
            burstCount = int(windowRows["conn_burst"].sum())
            temporalIndicators["burst_count"] = burstCount
        if hasSpikeCol:
            spikeCount = int(windowRows["traffic_spike"].sum())
            temporalIndicators["traffic_spikes"] = spikeCount

        pattern, orderEvidence = _classify_attack_pattern(window, uniqueAttacks)
        score, reasonCodes, reasonText = _score_attack_window(
            pattern, uniqueAttacks, burstCount, spikeCount, burstHighThreshold, spikeHighThreshold
        )

        if score >= 70:
            severity = "CRITICAL"
        elif score >= 45:
            severity = "HIGH"
        else:
            severity = "MEDIUM"

        attackSequences.append(
            {
                "window_start": int(startIdx),
                "window_end": int(endIdx),
                "attack_types": [classNames[a] for a in sorted(uniqueAttacks)],
                "attack_type_counts": attackTypeCounts,
                "severity": severity,
                "severity_score": score,
                "pattern": pattern,
                "reason_codes": reasonCodes,
                "reason_text": reasonText,
                "order_evidence": orderEvidence,
                "temporal_indicators": temporalIndicators,
            }
        )

    return attackSequences



def map_to_mitre_attack(df, predictions):
    """
    Tally MITRE ATT&CK techniques and kill-chain stages for flagged rows.

    A row contributes only when its prediction is an attack class (> 0) and its
    normalised ground-truth label (taken from ``df["label"]``) has an entry in
    ATTACK_TO_MITRE. Each technique listed for the label is counted once, and
    every stage in ATTACK_STAGES that lists the label is counted once.

    Returns
    -------
    (ttpCounts, stageCounts) : two ``defaultdict(int)`` keyed by technique ID
    and stage name respectively.
    """
    normalisedLabels = df["label"].astype(str).str.strip().str.rstrip(".").to_numpy()

    ttpCounts = defaultdict(int)
    stageCounts = defaultdict(int)

    for rowLabel, rowPred in zip(normalisedLabels, predictions):
        flaggedAsAttack = not (rowPred <= 0)
        if not flaggedAsAttack or rowLabel not in ATTACK_TO_MITRE:
            continue

        for technique in ATTACK_TO_MITRE[rowLabel]:
            ttpCounts[technique] += 1

        matchingStages = (stage for stage, members in ATTACK_STAGES.items() if rowLabel in members)
        for stage in matchingStages:
            stageCounts[stage] += 1

    return ttpCounts, stageCounts


def compute_temporal_feature_importance(pipeline, testX, testY, temporalCols, validMask, maxSamples=5000):
    """
    Rank the temporal input columns by permutation importance.

    A fixed-seed random sample of at most ``maxSamples`` rows is drawn (without
    replacement) from the rows selected by ``validMask``. Permutation
    importance is computed on that sample for every input column using the
    weighted-F1 scorer with 5 repeats, and the result is then restricted to
    the columns named in ``temporalCols``.

    Returns
    -------
    pandas.DataFrame or None
        Columns "feature" and "importance", sorted by descending importance.
        None when ``temporalCols`` is empty or no row is selected by the mask.
    """
    if len(temporalCols) == 0:
        return None

    eligibleRows = np.flatnonzero(validMask)
    if eligibleRows.size == 0:
        return None

    sampler = np.random.default_rng(42)
    sampleSize = min(maxSamples, eligibleRows.size)
    chosenRows = sampler.choice(eligibleRows, size=sampleSize, replace=False)

    featureSample = testX.iloc[chosenRows].copy()
    labelSample = testY[chosenRows].copy()

    # Importance is measured on the raw pipeline inputs, not on encoded features.
    permResult = permutation_importance(
        pipeline,
        featureSample,
        labelSample,
        scoring="f1_weighted",
        n_repeats=5,
        random_state=42,
        n_jobs=-1,
    )

    allImportances = pd.DataFrame(
        {
            "feature": np.array(featureSample.columns),
            "importance": permResult.importances_mean,
        }
    )
    isTemporal = allImportances["feature"].isin(temporalCols)
    ranked = allImportances.loc[isTemporal].sort_values("importance", ascending=False)
    return ranked.reset_index(drop=True)


# ============================================================================
# Dashboard export
# ============================================================================
class NumpyEncoder(json.JSONEncoder):
    """
    JSON encoder that understands NumPy scalars and arrays.

    NumPy booleans, integers and floating-point scalars are converted to the
    matching built-in Python type, and arrays are converted to (nested) lists.
    Anything else is delegated to the base class, which raises TypeError.
    """

    def default(self, obj):
        # np.bool_ is not a subclass of bool or np.integer, so it needs its own branch.
        if isinstance(obj, np.bool_):
            return bool(obj)
        # np.integer / np.floating are the abstract bases of every sized
        # integer / float scalar type, so no individual widths are listed.
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        return super().default(obj)


# ----------------------------------------------------------------------------
# Driver, part 1: load data, derive labels, build baseline / temporal matrices
# ----------------------------------------------------------------------------
print(
    "=" * 80,
    "MULTI-STAGE ATTACK DETECTION PIPELINE",
    "NSL-KDD with HistGradientBoosting + Corrected Temporal Features",
    "=" * 80,
    sep="\n",
)
print()

print("Loading NSL-KDD dataset...")
trainDf, testDf = load_nsl_kdd()
print(f"Train samples: {len(trainDf)}, Test samples: {len(testDf)}", end="\n\n")

trainYAll, testYAll = (create_attack_category_labels(frame) for frame in (trainDf, testDf))

# Training rows with an unrecognised label (-1) are removed outright; the index
# is reset so the frame lines up positionally with trainY.
trainValidMask = trainYAll != -1
trainY = trainYAll[trainValidMask]
trainDf = trainDf.loc[trainValidMask].reset_index(drop=True)

# Baseline features: everything except the target and the difficulty score.
trainXBase, testXBase = (frame.drop(columns=["label", "difficulty"]) for frame in (trainDf, testDf))

# Temporal features are derived separately for each split, from row order.
print("Extracting temporal features...")
trainTemporal, testTemporal = (
    extract_temporal_features(frame, windowSizes=(2, 5, 10, 30, 100)) for frame in (trainDf, testDf)
)

trainXTemporal = pd.concat([trainXBase, trainTemporal], axis=1)
testXTemporal = pd.concat([testXBase, testTemporal], axis=1)

# Test rows are all kept (so predictions stay aligned with testDf); rows with
# an unrecognised label are only masked out when metrics are computed.
testY = testYAll
testValidMask = testYAll != -1

# ----------------------------------------------------------------------------
# Driver, part 2: train both models and compare them on valid test labels
# ----------------------------------------------------------------------------
print("", "=" * 80, "PHASE 1: TRAIN BASELINE MODEL", "=" * 80, sep="\n")
pipelineBaseline = train_multiclass_model(trainXBase, trainY)

# Baseline is scored quietly (no report printed) on the same masked rows that
# the temporal model is evaluated on below.
baselinePred = pipelineBaseline.predict(testXBase)
baselineAcc = accuracy_score(testY[testValidMask], baselinePred[testValidMask])
baselinePrec, baselineRec, baselineF1 = precision_recall_fscore_support(
    testY[testValidMask], baselinePred[testValidMask], average="weighted", zero_division=0
)[:3]

print("", "=" * 80, "PHASE 2: TRAIN TEMPORAL MODEL", "=" * 80, sep="\n")
pipelineTemporal = train_multiclass_model(trainXTemporal, trainY)

print("", "=" * 80, "PHASE 3: EVALUATION (TEMPORAL MODEL)", "=" * 80, sep="\n")
predictions, temporalMetrics = evaluate_multiclass(pipelineTemporal, testXTemporal, testY, validMask=testValidMask)

temporalAcc, temporalPrec, temporalRec, temporalF1 = (
    temporalMetrics[key] for key in ("accuracy", "precision", "recall", "f1score")
)

# Side-by-side summary; precision / recall / F1 are the weighted averages.
comparisonDf = pd.DataFrame(
    {
        "Model": ["Baseline", "With Temporal Features"],
        "Accuracy": [baselineAcc, temporalAcc],
        "Precision": [baselinePrec, temporalPrec],
        "Recall": [baselineRec, temporalRec],
        "F1-Score": [baselineF1, temporalF1],
    }
)
print()
print("Performance Comparison (valid labels only):")
print(comparisonDf.to_string(index=False))

# ----------------------------------------------------------------------------
# Driver, part 3: feature importance, multi-stage sequences, MITRE mapping
# ----------------------------------------------------------------------------
print("", "=" * 80, "PHASE 4: TEMPORAL FEATURE IMPORTANCE (PERMUTATION)", "=" * 80, sep="\n")
temporalCols = list(trainTemporal.columns)
temporalImportanceDf = compute_temporal_feature_importance(
    pipelineTemporal, testXTemporal, testY, temporalCols, testValidMask, maxSamples=5000
)
if temporalImportanceDf is None or len(temporalImportanceDf) == 0:
    print("Temporal importance not computed or no temporal columns found")
else:
    print("Top 10 temporal features (permutation importance):")
    print(temporalImportanceDf.head(10).to_string(index=False))

# Sequence detection reads the burst / spike columns, so it needs the test
# frame with the temporal features attached.
print("", "=" * 80, "PHASE 5: MULTI-STAGE ATTACK SEQUENCES", "=" * 80, sep="\n")
testDfWithTemporal = pd.concat([testDf, testTemporal], axis=1)
attackSequences = detect_attack_sequences(testDfWithTemporal, predictions, windowSize=20)
print(f"Detected sequences: {len(attackSequences)}")

# Technique / stage tallies use the ground-truth label of each row that the
# model flagged as an attack.
print("", "=" * 80, "PHASE 6: MITRE ATT&CK MAPPING", "=" * 80, sep="\n")
ttpCounts, stageCounts = map_to_mitre_attack(testDf, predictions)
print(f"Unique TTPs detected: {len(ttpCounts)}")

# Persist tabular outputs; the importance file is skipped when nothing was computed.
comparisonDf.to_csv("baseline_vs_temporal_comparison.csv", index=False)
if temporalImportanceDf is not None:
    temporalImportanceDf.to_csv("temporal_feature_importance.csv", index=False)

# ------------------------------------------------------------------------
# EXPORT DATA FOR DASHBOARD
# ------------------------------------------------------------------------
# Predicted-class histogram with plain Python ints as keys and values.
classCounts = {int(code): int(total) for code, total in Counter(predictions).items()}

# Timeline sample: up to 500 evenly spaced rows with prediction and true label
# (the true label may be -1 for rows whose attack name was not recognised).
sampleIdx = np.linspace(0, len(predictions) - 1, min(500, len(predictions)), dtype=int)
timelineData = [
    {
        "index": int(pos),
        "prediction": int(predictions[pos]),
        "true_label": int(testY[pos]),
    }
    for pos in sampleIdx
]

# Number of predicted attacks in consecutive, non-overlapping blocks of 100 rows.
attackTimeline = [
    {"window": int(start), "attacks": int(np.sum(predictions[start : start + 100] > 0))}
    for start in range(0, len(predictions), 100)
]

# Ten most important temporal features (empty when importance was not computed).
topTemporalFeatures = []
if temporalImportanceDf is not None and len(temporalImportanceDf) > 0:
    topDf = temporalImportanceDf.head(10)
    topTemporalFeatures = [
        {"feature": str(name), "importance": float(value)}
        for name, value in zip(topDf["feature"], topDf["importance"])
    ]

# First ten detected sequences, reduced to the fields the dashboard consumes.
formattedSequences = [
    {
        "pattern": str(seq["pattern"]),
        "severity": str(seq["severity"]),
        "attack_types": ", ".join(seq["attack_types"]),
        "window_start": int(seq["window_start"]),
        "window_end": int(seq["window_end"]),
        "temporal_indicators": {
            str(name): int(value) for name, value in seq.get("temporal_indicators", {}).items()
        },
    }
    for seq in attackSequences[:10]
]

# Ten most frequent techniques, highest count first.
topTtps = sorted(ttpCounts.items(), key=lambda pair: pair[1], reverse=True)[:10]
mitreData = [{"ttp_id": str(ttp), "count": int(count)} for ttp, count in topTtps]

# Kill chain stages
killChainData = [{"stage": str(stage), "count": int(count)} for stage, count in stageCounts.items()]

dashboardData = {
    "metadata": {
        "timestamp": pd.Timestamp.now().isoformat(),
        "total_samples": int(len(predictions)),
        "dataset": "NSL-KDD",
        "note": "Metrics computed on valid labels only (unknown attacks mapped to -1 are excluded)",
    },
    "stats": {
        "totalConnections": int(len(predictions)),
        "attacksDetected": int(np.sum(predictions > 0)),
        "attackRate": float(np.mean(predictions > 0) * 100),
        "multiStageAttacks": int(len(attackSequences)),
        "modelAccuracy": float(temporalAcc * 100),
    },
    # All performance figures are exported as percentages.
    "performance": {
        "baseline": {
            metric: float(value * 100)
            for metric, value in (
                ("accuracy", baselineAcc),
                ("precision", baselinePrec),
                ("recall", baselineRec),
                ("f1score", baselineF1),
            )
        },
        "temporal": {
            metric: float(value * 100)
            for metric, value in (
                ("accuracy", temporalAcc),
                ("precision", temporalPrec),
                ("recall", temporalRec),
                ("f1score", temporalF1),
            )
        },
    },
    "attackDistribution": {
        name: int(classCounts.get(code, 0))
        for code, name in enumerate(("Normal", "DoS", "Probe", "R2L", "U2R"))
    },
    "attackTimeline": attackTimeline,
    "timelineData": timelineData,
    "temporalFeatures": topTemporalFeatures,
    "attackSequences": formattedSequences,
    "mitreAttack": mitreData,
    "killChain": killChainData,
}

with open("dashboard_data.json", "w") as outFile:
    json.dump(dashboardData, outFile, indent=2, cls=NumpyEncoder)

print(
    "",
    "=" * 80,
    "EXPORT COMPLETE",
    "=" * 80,
    "Saved:",
    "  baseline_vs_temporal_comparison.csv",
    "  temporal_feature_importance.csv (if computed)",
    "  dashboard_data.json",
    "",
    "Open your HTML dashboard via http.server and load the page in the browser",
    sep="\n",
)

MULTI-STAGE ATTACK DETECTION PIPELINE
NSL-KDD with HistGradientBoosting + Corrected Temporal Features

Loading NSL-KDD dataset...
Train samples: 125973, Test samples: 22544

Extracting temporal features...

PHASE 1: TRAIN BASELINE MODEL
Training completed in 5.40s

PHASE 2: TRAIN TEMPORAL MODEL
Training completed in 6.47s

PHASE 3: EVALUATION (TEMPORAL MODEL)
MULTI-CLASS ATTACK DETECTION RESULTS
Overall Accuracy (valid labels): 0.8097
Weighted F1 (valid labels):      0.7880
Macro F1 (valid labels):         0.5989

              precision    recall  f1-score   support

      Normal     0.7384    0.9717    0.8391      9711
         DoS     0.9644    0.8174    0.8848      7458
       Probe     0.7726    0.8406    0.8051      2421
         R2L     0.8514    0.2456    0.3813      2752
         U2R     0.6429    0.0450    0.0841       200

    accuracy                         0.8097     22542
   macro avg     0.7939    0.5841    0.5989     22542
weighted avg     0.8298    0.8097    0.7880   