In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 1: Setup (Colab-aware)
# ═══════════════════════════════════════════════════════════
GIT_REPO_URL = "https://github.com/Arif-Foysal/FAA-Net.git"
REPO_DIR = "FAA-Net"

import os, sys

# Clone repo if on Colab
if not os.path.exists(REPO_DIR):
    !git clone {GIT_REPO_URL}
    !git checkout edt
if os.path.exists(REPO_DIR):
    os.chdir(REPO_DIR)
    print(f"Working directory: {os.getcwd()}")

# Mount Drive for saving artifacts
try:
    from google.colab import drive
    drive.mount('/content/drive')
    SAVE_DIR = '/content/drive/MyDrive/EDANet_Models'
    os.makedirs(SAVE_DIR, exist_ok=True)
    IN_COLAB = True
except ImportError:
    SAVE_DIR = '.'
    IN_COLAB = False
    print("Not in Colab — saving artifacts locally.")

!pip install -q -r requirements.txt
print(f"\nSave directory: {os.path.abspath(SAVE_DIR)}")

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 2: Imports & Configuration
# ═══════════════════════════════════════════════════════════
# All library and project imports, plot styling, the figure output directory
# and the compute device. Relies on `os`, `sys` and `SAVE_DIR` from Cell 1.
# NOTE(review): import order is kept exactly as written — native ML libraries
# (torch / xgboost / lightgbm) can be sensitive to import order; confirm before
# reordering.
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import time
import copy
import joblib
from sklearn.metrics import (
    precision_recall_curve, roc_curve, auc,
    confusion_matrix, classification_report,
    accuracy_score, precision_score, recall_score, f1_score,
    roc_auc_score, average_precision_score
)
import xgboost as xgb
import lightgbm as lgb

# Make the repository's `core` package importable (after Cell 1's chdir into
# REPO_DIR, the cwd is the repo root).
sys.path.insert(0, os.getcwd())

# Project modules. `copy`, `ImbalanceAwareFocalLoss` and `EntropyRegularization`
# are not referenced by any cell of this notebook.
from core.config import EDA_CONFIG, ABLATION_CONFIGS, RANDOM_STATE
from core.data_loader import load_and_preprocess_data, create_dataloaders, get_data_paths
from core.model import EDANet, MinorityPrototypeGenerator
from core.ablation import (
    VanillaDNN_Ablation, FixedTempNet_Ablation,
    HeuristicEDTNet_Ablation, EDANet_Ablation
)
from core.loss import (
    ImbalanceAwareFocalLoss, ImbalanceAwareFocalLoss_Logits,
    EntropyRegularization, EDANetLoss
)
from core.trainer import train_model
from core.utils import (
    set_all_seeds, evaluate_model, print_metrics,
    save_training_history, save_predictions,
    collect_edt_analysis, save_edt_analysis
)

# Plot style: serif fonts, 150 dpi on screen, 300 dpi tight-bbox saved figures.
plt.rcParams.update({
    'font.size': 12, 'font.family': 'serif',
    'axes.labelsize': 13, 'axes.titlesize': 14,
    'legend.fontsize': 11, 'xtick.labelsize': 11, 'ytick.labelsize': 11,
    'figure.dpi': 150, 'savefig.dpi': 300, 'savefig.bbox': 'tight',
})

# Every paper figure (PDF + PNG) is written under SAVE_DIR/paper_figures.
FIG_DIR = os.path.join(SAVE_DIR, 'paper_figures')
os.makedirs(FIG_DIR, exist_ok=True)

# Use the GPU when one is visible to torch, otherwise the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Device: {device}")
print(f"Figures directory: {os.path.abspath(FIG_DIR)}")

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 3: Load & Preprocess Dataset
# ═══════════════════════════════════════════════════════════
# Colab keeps uploaded data under /content; anywhere else read from the cwd.
data_dir = '/content' if os.path.exists('/content') else '.'

(X_train_scaled, X_test_scaled,
 y_train, y_test,
 y_train_cat, y_test_cat) = load_and_preprocess_data(data_dir=data_dir)

# Batched loaders plus the full test split as a single tensor (used by evaluate_model).
train_loader, val_loader, test_loader, X_test_tensor = create_dataloaders(
    X_train_scaled, y_train, X_test_scaled, y_test,
    batch_size=EDA_CONFIG['batch_size'],
)

input_dim = X_train_scaled.shape[1]

# Label 1 = attack = minority class; everything else = normal = majority class.
minority_mask = y_train.values == 1
X_minority, X_majority = X_train_scaled[minority_mask], X_train_scaled[~minority_mask]
class_counts = [len(X_majority), len(X_minority)]          # [majority, minority]
majority_to_minority = class_counts[0] / class_counts[1]
pos_weight = torch.tensor([majority_to_minority], device=device, dtype=torch.float32)

# Minority prototypes, shared by every attention-based model trained below.
proto_gen = MinorityPrototypeGenerator(
    n_prototypes=EDA_CONFIG['n_prototypes'], random_state=RANDOM_STATE
)
minority_prototypes = proto_gen.fit(X_minority)

print("\n".join([
    f"\nInput dim: {input_dim}",
    f"Train: {len(y_train):,}  |  Test: {len(y_test):,}",
    f"Minority (attack): {len(X_minority):,}  |  Majority (normal): {len(X_majority):,}",
    f"Class ratio: 1:{majority_to_minority:.2f}",
    f"Prototypes extracted: {EDA_CONFIG['n_prototypes']}",
]))

---
## Table 1: Dataset Statistics

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 4: Table 1 — Dataset Statistics
# ═══════════════════════════════════════════════════════════
# Per-split sample / attack / normal counts and the normal:attack ratio.
n_train, n_test = len(y_train), len(y_test)
attack_train, attack_test = int(y_train.sum()), int(y_test.sum())
normal_train, normal_test = n_train - attack_train, n_test - attack_test

d1 = pd.DataFrame({
    'Split': ['Training', 'Testing', 'Total'],
    'Samples': [n_train, n_test, n_train + n_test],
    'Attack': [attack_train, attack_test, attack_train + attack_test],
    'Normal': [normal_train, normal_test, normal_train + normal_test],
    'Features': [input_dim] * 3,
})
d1['Imbalance Ratio'] = d1['Normal'].div(d1['Attack']).round(2)
print('Table 1: UNSW-NB15 Dataset Statistics')
display(d1)

---
## Figure 2: EDT Mechanism Visualisation (Synthetic)

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 5: Figure 2 — EDT Mechanism (no training needed)
# ═══════════════════════════════════════════════════════════
np.random.seed(42)
n_proto, d_k = 8, 32   # prototypes per panel; key dimension for the 1/sqrt(d_k) scale

# Hand-made attention logits over 8 prototypes: one dominant prototype (easy),
# a graded profile (medium) and an almost flat profile (hard).
logits_easy = np.array([5.0, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1])
logits_hard = np.array([1.2, 1.0, 1.1, 0.9, 1.0, 1.1, 0.8, 1.0])
logits_med  = np.array([3.0, 1.5, 0.5, 0.3, 0.2, 0.1, 0.1, 0.1])


def _stable_softmax(z):
    """Softmax of a 1-D array, shifted by its maximum so exp() cannot overflow."""
    z = np.asarray(z, dtype=float)
    e = np.exp(z - z.max())
    return e / e.sum()


def edt_demo(logits, tau=1.0, d_k=d_k):
    """Toy scaled attention over prototypes, with and without a temperature.

    Args:
        logits: 1-D array of raw attention scores, one per prototype.
        tau: Temperature applied on top of the 1/sqrt(d_k) scaling.
        d_k: Key dimension for the 1/sqrt(d_k) scaling (defaults to the value
            defined above).

    Returns:
        (entropy_norm, p_base, p): the entropy of the tau=1 distribution divided
        by log(n) (0.0 when there is a single logit), the tau=1 attention
        weights, and the attention weights at temperature ``tau``.
    """
    scaled = np.asarray(logits, dtype=float) * d_k ** -0.5
    p_base = _stable_softmax(scaled)
    p = _stable_softmax(scaled / tau)
    entropy = -np.sum(p_base * np.log(p_base + 1e-8))   # 1e-8 guards log(0)
    n = len(p_base)
    # log(1) == 0, so a single logit would otherwise divide by zero.
    entropy_norm = entropy / np.log(n) if n > 1 else 0.0
    return entropy_norm, p_base, p

fig, axes = plt.subplots(2, 3, figsize=(15, 8))
samples = [
    ('Easy (Clear DoS)', logits_easy, '#2ecc71'),
    ('Medium (Brute Force)', logits_med, '#f39c12'),
    ('Hard (Reconnaissance)', logits_hard, '#e74c3c'),
]
# Bounds of the illustrative entropy -> temperature mapping.
DEMO_TAU_LO, DEMO_TAU_HI = 0.1, 5.0

for col, (name, logits, color) in enumerate(samples):
    ent_n, p_fixed, _ = edt_demo(logits, tau=1.0)
    # Low normalised entropy (confident attention) maps to a high tau,
    # near-uniform attention to a low tau.
    tau_dyn = DEMO_TAU_LO + (DEMO_TAU_HI - DEMO_TAU_LO) * (1.0 - ent_n)
    _, _, p_edt = edt_demo(logits, tau=tau_dyn)

    top, bottom = axes[0, col], axes[1, col]
    top.bar(range(n_proto), p_fixed, color='steelblue', alpha=0.8, edgecolor='navy')
    top.set_title(f'{name}\nFixed \u03c4=1.0 | H_norm={ent_n:.2f}')
    bottom.bar(range(n_proto), p_edt, color=color, alpha=0.8, edgecolor='black')
    bottom.set_title(f'EDT \u03c4={tau_dyn:.2f}')

    # Shared cosmetics for both rows of this column.
    for ax in (top, bottom):
        ax.set_ylim(0, 1.05)
        ax.set_xlabel('Prototype')
        if col == 0:
            ax.set_ylabel('Attention Weight')
        ax.grid(axis='y', alpha=0.3)

fig.suptitle('Figure 2: EDT Attention Mechanism\n(Top: Fixed Temperature | Bottom: Entropy-Dynamic Temperature)',
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
for ext in ('pdf', 'png'):
    out_path = os.path.join(FIG_DIR, f'fig2_edt_mechanism.{ext}')
    if ext == 'pdf':
        plt.savefig(out_path, format='pdf')
    else:
        plt.savefig(out_path)
plt.show()

---
## Train Main EDA-Net Model

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 6: Train Main EDA-Net Model
# ═══════════════════════════════════════════════════════════
# Trains the full EDA-Net, evaluates it on the test split, collects per-sample
# EDT statistics, and writes every artifact to SAVE_DIR.
rule = "=" * 60
print(rule)
print("  Training: EDA-Net (Full Model)")
print(rule)

set_all_seeds(RANDOM_STATE)

# Architecture hyper-parameters forwarded verbatim from EDA_CONFIG.
EDANET_CONFIG_KEYS = (
    'num_heads', 'attention_dim', 'n_prototypes', 'hidden_units',
    'dropout_rate', 'attention_dropout', 'tau_min', 'tau_max',
    'tau_hidden_dim', 'edt_mode', 'normalize_entropy',
)
edanet_model = EDANet(
    input_dim=input_dim,
    num_classes=1,
    output_logits=True,
    **{key: EDA_CONFIG[key] for key in EDANET_CONFIG_KEYS},
).to(device)

# Seed the attention prototypes with the minority prototypes from Cell 3.
edanet_model.edt_attention.initialize_all_prototypes(minority_prototypes, device)
print(f"Parameters: {edanet_model.count_parameters():,}")
print(f"EDT mode: {EDA_CONFIG['edt_mode']} | \u03c4 range: [{EDA_CONFIG['tau_min']}, {EDA_CONFIG['tau_max']}]")

edanet_criterion = EDANetLoss(
    gamma=EDA_CONFIG['focal_gamma'],
    class_counts=class_counts,
    entropy_reg_weight=EDA_CONFIG['entropy_reg_weight'],
)

edanet_model, edanet_history = train_model(
    edanet_model, train_loader, val_loader, EDA_CONFIG,
    edanet_criterion, device, use_edt_loss=True,
)

# Test-set evaluation
edanet_metrics, edanet_probs, edanet_preds = evaluate_model(
    edanet_model, X_test_tensor, y_test, device
)
print_metrics(edanet_metrics, "\nEDA-Net Test Results")

print("\nCollecting per-sample EDT analysis...")
edt_analysis_df = collect_edt_analysis(edanet_model, X_test_tensor, y_test, device)


def _save_path(fname):
    """Location of an artifact file inside SAVE_DIR."""
    return os.path.join(SAVE_DIR, fname)


pd.DataFrame([edanet_metrics]).to_csv(_save_path('edanet_metrics.csv'), index=False)
torch.save(edanet_model.state_dict(), _save_path('edanet_main.pt'))
save_training_history(edanet_history, _save_path('edanet_history.csv'))
save_predictions(y_test, edanet_probs, _save_path('edanet_predictions.npz'))
save_edt_analysis(edt_analysis_df, _save_path('edanet_edt_analysis.csv'))
print(f"\n\u2713 All EDA-Net artifacts saved to {SAVE_DIR}")

---
## Run Full Ablation Study (8 Experiments)

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 7: Ablation Study — Helper
# ═══════════════════════════════════════════════════════════

def run_ablation_experiment(name, model, config, criterion, use_edt_loss=False):
    """Train, evaluate and persist one ablation experiment.

    Uses the notebook-level ``train_loader`` / ``val_loader`` for training and
    ``X_test_tensor`` / ``y_test`` for evaluation, then writes the state dict,
    test predictions and training history to ``SAVE_DIR`` (plus a per-sample
    EDT analysis CSV when the model exposes a non-None ``last_edt_info``).

    Args:
        name: Human-readable experiment label; also used to derive file names.
        model: Model to train (already on ``device``); must provide
            ``count_parameters()``.
        config: Training config dict forwarded to ``train_model``.
        criterion: Loss module forwarded to ``train_model``.
        use_edt_loss: Forwarded to ``train_model``.

    Returns:
        (metrics, history) as produced by ``evaluate_model`` / ``train_model``.
    """
    print(f"\n{'='*60}")
    print(f"  Experiment: {name}")
    print(f"  Parameters: {model.count_parameters():,}")
    print(f"{'='*60}")

    model, history = train_model(
        model, train_loader, val_loader, config, criterion, device,
        use_edt_loss=use_edt_loss
    )
    metrics, y_probs, _ = evaluate_model(model, X_test_tensor, y_test, device)
    print_metrics(metrics, f"{name} Results")

    # File stem, e.g. 'EDA-Net (Narrow τ)' -> 'eda-net_narrow_tau'. Later cells
    # load some of these files by hard-coded name, so keep this mapping stable.
    safe_name = name.replace(' ', '_').replace('(', '').replace(')', '').replace('\u03c4', 'tau').lower()

    # Save model, predictions, history
    torch.save(model.state_dict(), os.path.join(SAVE_DIR, f"{safe_name}.pt"))
    save_predictions(y_test, y_probs, os.path.join(SAVE_DIR, f"{safe_name}_predictions.npz"))
    save_training_history(history, os.path.join(SAVE_DIR, f"{safe_name}_history.csv"))

    # Per-sample EDT analysis only for models that record EDT info.
    if getattr(model, 'last_edt_info', None) is not None:
        try:
            adf = collect_edt_analysis(model, X_test_tensor, y_test, device)
            save_edt_analysis(adf, os.path.join(SAVE_DIR, f"{safe_name}_edt_analysis.csv"))
        except Exception as exc:
            # Best-effort: a failed analysis must not abort the ablation sweep,
            # but it should not vanish silently either (and Ctrl-C must still work).
            print(f"  \u26a0 EDT analysis skipped for {name}: {type(exc).__name__}: {exc}")

    return metrics, history


def make_edanet_ablation(config):
    """Build an ``EDANet_Ablation`` variant from ``config`` on ``device``.

    ``normalize_entropy`` defaults to True when absent from ``config``. The
    attention prototypes are initialised from the notebook-level
    ``minority_prototypes``.
    """
    m = EDANet_Ablation(
        input_dim=input_dim,
        num_heads=config['num_heads'],
        attention_dim=config['attention_dim'],
        n_prototypes=config['n_prototypes'],
        tau_min=config['tau_min'],
        tau_max=config['tau_max'],
        tau_hidden_dim=config['tau_hidden_dim'],
        edt_mode=config['edt_mode'],
        normalize_entropy=config.get('normalize_entropy', True),
        hidden_units=config['hidden_units'],
        dropout_rate=config['dropout_rate'],
        attention_dropout=config['attention_dropout'],
    ).to(device)
    m.edt_attention.initialize_all_prototypes(minority_prototypes, device)
    return m

print("\u2713 Ablation helpers defined.")

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 8: Ablation Experiments 1-4 (Baselines → Heuristic)
# ═══════════════════════════════════════════════════════════
ablation_results = {}
ablation_histories = {}

# Exps 2–4 use focal loss. Take γ from EDA_CONFIG (the value the full EDA-Net
# is trained with in Cell 6 / Exp 5) instead of a hard-coded 2.0, so these
# baselines differ from EDA-Net only in architecture, not in loss settings.
ablation_focal_gamma = EDA_CONFIG['focal_gamma']


def _focal_loss():
    """Fresh imbalance-aware focal loss (logit inputs) using the shared γ."""
    return ImbalanceAwareFocalLoss_Logits(class_counts=class_counts, gamma=ablation_focal_gamma)


def _vanilla_dnn():
    """VanillaDNN_Ablation baseline on ``device``."""
    return VanillaDNN_Ablation(input_dim=input_dim).to(device)


def _prototype_attention_net(model_cls):
    """Build an attention ablation model on ``device`` and seed its prototypes."""
    model = model_cls(
        input_dim=input_dim, num_heads=EDA_CONFIG['num_heads'],
        attention_dim=EDA_CONFIG['attention_dim'], n_prototypes=EDA_CONFIG['n_prototypes']
    ).to(device)
    model.edt_attention.initialize_all_prototypes(minority_prototypes, device)
    return model


# (label, model factory, loss factory). Factories run after re-seeding so every
# experiment starts from the same RNG state.
baseline_experiments = [
    ('Vanilla DNN + BCE', _vanilla_dnn,
     lambda: nn.BCEWithLogitsLoss(pos_weight=pos_weight)),
    ('Vanilla DNN + Focal', _vanilla_dnn, _focal_loss),
    ('Fixed-Temp Attn + Focal',
     lambda: _prototype_attention_net(FixedTempNet_Ablation), _focal_loss),
    ('Heuristic EDT + Focal',
     lambda: _prototype_attention_net(HeuristicEDTNet_Ablation), _focal_loss),
]

for exp_name, build_model, build_criterion in baseline_experiments:
    set_all_seeds(RANDOM_STATE)
    exp_model = build_model()
    exp_criterion = build_criterion()
    ablation_results[exp_name], ablation_histories[exp_name] = \
        run_ablation_experiment(exp_name, exp_model, EDA_CONFIG, exp_criterion)

print("\n\u2713 Experiments 1-4 complete.")

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 9: Ablation Experiments 5-8 (Full EDA-Net + Sensitivity)
# ═══════════════════════════════════════════════════════════
# Exp 5 trains the full model on EDA_CONFIG; Exps 6–8 swap in a variant config
# from ABLATION_CONFIGS (narrow τ range, wide τ range, no entropy normalisation).
# A key of None means "use EDA_CONFIG"; variant configs are looked up lazily.
edanet_variant_specs = (
    ('EDA-Net (Full)', None),
    ('EDA-Net (Narrow \u03c4)', 'narrow_tau'),
    ('EDA-Net (Wide \u03c4)', 'wide_tau'),
    ('EDA-Net (No Norm)', 'no_entropy_norm'),
)

for exp_name, variant_key in edanet_variant_specs:
    set_all_seeds(RANDOM_STATE)
    exp_cfg = EDA_CONFIG if variant_key is None else ABLATION_CONFIGS[variant_key]
    exp_model = make_edanet_ablation(exp_cfg)
    exp_criterion = EDANetLoss(
        gamma=exp_cfg['focal_gamma'],
        class_counts=class_counts,
        entropy_reg_weight=exp_cfg['entropy_reg_weight'],
    )
    ablation_results[exp_name], ablation_histories[exp_name] = run_ablation_experiment(
        exp_name, exp_model, exp_cfg, exp_criterion, use_edt_loss=True
    )

# --- Save ablation summary (one row per experiment) ---
abl_df = pd.DataFrame(ablation_results).T
abl_df.to_csv(os.path.join(SAVE_DIR, 'ablation_summary.csv'))
heavy_rule = "=" * 80
print("\n" + heavy_rule)
print("ABLATION STUDY SUMMARY")
print(heavy_rule)
display(abl_df)
print("\n\u2713 All 8 ablation experiments complete. Summary saved.")

---
## Train ML Baselines (XGBoost & LightGBM)

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 10: Train XGBoost & LightGBM Baselines
# ═══════════════════════════════════════════════════════════
# NOTE(review): OpenMP usually reads OMP_NUM_THREADS when its runtime starts;
# xgboost/lightgbm were already imported in Cell 2, so this may have no effect.
os.environ['OMP_NUM_THREADS'] = '4'
set_all_seeds(RANDOM_STATE)
baseline_results = {}


def eval_sklearn(model, X_test, y_test):
    """Score a fitted binary classifier that exposes ``predict_proba``.

    Hard labels come from thresholding the positive-class probability at 0.5.

    Returns:
        (metrics, probs): a dict with Accuracy, Precision, Recall, F1-Score,
        AUC-ROC and Avg Precision, and the positive-class probabilities.
    """
    probs = model.predict_proba(X_test)[:, 1]
    labels = (probs > 0.5).astype(int)

    metrics = {'Accuracy': accuracy_score(y_test, labels)}
    for key, scorer in (('Precision', precision_score),
                        ('Recall', recall_score),
                        ('F1-Score', f1_score)):
        metrics[key] = scorer(y_test, labels, zero_division=0)
    metrics['AUC-ROC'] = roc_auc_score(y_test, probs)
    metrics['Avg Precision'] = average_precision_score(y_test, probs)
    return metrics, probs

def _fit_score_save(label, model, stem):
    """Fit ``model`` on the training split, score it on the test split, and
    persist the fitted model and its test probabilities to SAVE_DIR.

    Stores the metrics in ``baseline_results[label]`` and returns the
    positive-class test probabilities.
    """
    model.fit(X_train_scaled, y_train)
    baseline_results[label], probs = eval_sklearn(model, X_test_scaled, y_test)
    print_metrics(baseline_results[label], f'{label} Results')
    joblib.dump(model, os.path.join(SAVE_DIR, f'{stem}_baseline.joblib'))
    save_predictions(y_test, probs, os.path.join(SAVE_DIR, f'{stem}_predictions.npz'))
    return probs


# --- XGBoost ---
print("\n--- Training XGBoost ---")
xgb_model = xgb.XGBClassifier(
    n_estimators=100, max_depth=6, learning_rate=0.1,
    subsample=0.8, colsample_bytree=0.8,
    scale_pos_weight=class_counts[0] / class_counts[1],   # majority / minority
    random_state=RANDOM_STATE, n_jobs=-1, eval_metric='logloss',
)
xgb_probs = _fit_score_save('XGBoost', xgb_model, 'xgboost')

# --- LightGBM ---
print("\n--- Training LightGBM ---")
lgb_model = lgb.LGBMClassifier(
    n_estimators=100, num_leaves=31, learning_rate=0.1,
    class_weight='balanced', random_state=RANDOM_STATE, n_jobs=-1, verbose=-1,
)
lgb_probs = _fit_score_save('LightGBM', lgb_model, 'lightgbm')

# Summary (one row per baseline)
bl_df = pd.DataFrame(baseline_results).T
bl_df.to_csv(os.path.join(SAVE_DIR, 'baseline_summary.csv'))
print("\n=== Baseline Summary ===")
display(bl_df)
print("\n\u2713 Baselines trained and saved.")

---
# Paper Figures (from trained artifacts)

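All models are now trained. The cells below generate the remaining paper figures and tables (Figures 3–11, Tables 2–5) from the in-memory results and the artifacts saved above.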

## Figure 3: Training Convergence

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 11: Figure 3 — Training Convergence
# ═══════════════════════════════════════════════════════════
histories = {'EDA-Net': pd.DataFrame(edanet_history)}

# For each baseline: the history CSV written by run_ablation_experiment, and the
# in-memory ablation_histories key used only when that CSV is missing.
history_sources = {
    'Vanilla DNN': ('vanilla_dnn_+_focal_history.csv', 'Vanilla DNN + Focal'),
    'Fixed-Temp': ('fixed-temp_attn_+_focal_history.csv', 'Fixed-Temp Attn + Focal'),
}
for label, (csv_name, memory_key) in history_sources.items():
    csv_path = os.path.join(SAVE_DIR, csv_name)
    if os.path.exists(csv_path):
        histories[label] = pd.read_csv(csv_path)
    elif memory_key in ablation_histories:
        histories[label] = pd.DataFrame(ablation_histories[memory_key])

fig, axes = plt.subplots(1, 3, figsize=(18, 5))
colors = {'EDA-Net': '#e74c3c', 'Vanilla DNN': '#95a5a6', 'Fixed-Temp': '#3498db'}
# (history column suffix, panel title, y-axis label), one entry per subplot.
panels = [('loss', 'Loss', 'Loss'), ('f1', 'F1-Score', 'F1'), ('recall', 'Recall', 'Recall')]

for label, hist in histories.items():
    epochs = range(1, len(hist) + 1)
    line_color = colors.get(label, 'gray')
    width = 2.5 if label == 'EDA-Net' else 1.5
    for ax, (metric, _, _) in zip(axes, panels):
        ax.plot(epochs, hist[f'train_{metric}'], label=f'{label} train', color=line_color, lw=width)
        ax.plot(epochs, hist[f'val_{metric}'], label=f'{label} val', color=line_color,
                ls='--', alpha=0.7, lw=width * 0.8)

for ax, (_, title, ylabel) in zip(axes, panels):
    ax.set_title(title)
    ax.set_xlabel('Epoch')
    ax.set_ylabel(ylabel)
    ax.legend(fontsize=9)
    ax.grid(True, alpha=0.3)

fig.suptitle('Figure 3: Training Convergence', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig3_convergence.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig3_convergence.png'))
plt.show()

## Figure 4: Temperature Evolution During Training

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 12: Figure 4 — Temperature & Entropy Evolution
# ═══════════════════════════════════════════════════════════
hist_eda = pd.DataFrame(edanet_history)
ep = range(1, len(hist_eda) + 1)
tau_color = '#e74c3c'

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 5))

# Left panel: per-epoch mean temperature, optional ±1 std band, and τ bounds.
ax1.plot(ep, hist_eda['mean_tau'], color=tau_color, lw=2, label='Mean \u03c4')
if 'tau_std' in hist_eda.columns:
    tau_mu = hist_eda['mean_tau'].to_numpy()
    tau_sd = hist_eda['tau_std'].to_numpy()
    ax1.fill_between(ep, tau_mu - tau_sd, tau_mu + tau_sd, alpha=0.2, color=tau_color, label='\u00b11 std')
for bound in ('tau_min', 'tau_max'):
    suffix = bound.split('_')[1]
    ax1.axhline(EDA_CONFIG[bound], color='gray', ls=':', alpha=0.5,
                label=f"\u03c4_{suffix}={EDA_CONFIG[bound]}")
ax1.set_xlabel('Epoch')
ax1.set_ylabel('Temperature')
ax1.set_title('Mean Temperature per Epoch')
ax1.legend()
ax1.grid(True, alpha=0.3)

# Right panel: mean normalised attention entropy, when the history records it.
if 'mean_entropy' in hist_eda.columns:
    ax2.plot(ep, hist_eda['mean_entropy'], color='#3498db', lw=2)
    ax2.set_xlabel('Epoch')
    ax2.set_ylabel('Normalised Entropy')
    ax2.set_title('Mean Attention Entropy per Epoch')
    ax2.set_ylim(0, 1)
    ax2.grid(True, alpha=0.3)

fig.suptitle('Figure 4: EDT Dynamics During Training', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig4_tau_evolution.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig4_tau_evolution.png'))
plt.show()

## Figures 5–6: ROC and PR Curves

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 13: Figures 5–6 — ROC and PR Curves
# ═══════════════════════════════════════════════════════════
# Gather (y_true, y_score) pairs: EDA-Net and the GBDT baselines from memory,
# the ablation DNNs from the .npz files written by run_ablation_experiment.
y_true_test = y_test.values if hasattr(y_test, 'values') else y_test
all_preds = {
    'EDA-Net': (y_true_test, edanet_probs),
    'XGBoost': (y_true_test, xgb_probs),
    'LightGBM': (y_true_test, lgb_probs),
}

for label, fname in [
    ('Vanilla DNN', 'vanilla_dnn_+_focal_predictions.npz'),
    ('Fixed-Temp', 'fixed-temp_attn_+_focal_predictions.npz'),
    ('Heuristic EDT', 'heuristic_edt_+_focal_predictions.npz'),
]:
    npz_path = os.path.join(SAVE_DIR, fname)
    if os.path.exists(npz_path):
        # np.load on an .npz returns an archive that keeps the file open;
        # the context manager closes it once the arrays have been read.
        with np.load(npz_path) as d:
            all_preds[label] = (d['y_true'], d['y_probs'])

clr = {'EDA-Net': '#e74c3c', 'Fixed-Temp': '#3498db', 'Vanilla DNN': '#95a5a6',
       'Heuristic EDT': '#f39c12', 'XGBoost': '#2ecc71', 'LightGBM': '#9b59b6'}

fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(14, 6))

for name, (yt, yp) in all_preds.items():
    c = clr.get(name, 'gray')
    lw = 2.5 if 'EDA' in name and 'Net' in name else 1.5
    fpr, tpr, _ = roc_curve(yt, yp)
    ra = auc(fpr, tpr)
    ax1.plot(fpr, tpr, label=f'{name} (AUC={ra:.4f})', color=c, lw=lw)
    prec, rec, _ = precision_recall_curve(yt, yp)
    # The legend says "AP", so report true average precision (the same metric
    # eval_sklearn reports as 'Avg Precision'). auc(rec, prec) applies the
    # trapezoidal rule, which linearly interpolates the PR curve and overstates it.
    ap = average_precision_score(yt, yp)
    ax2.plot(rec, prec, label=f'{name} (AP={ap:.4f})', color=c, lw=lw)

ax1.plot([0, 1], [0, 1], 'k--', alpha=0.3)   # chance diagonal
ax1.set_xlabel('FPR'); ax1.set_ylabel('TPR'); ax1.set_title('Figure 5: ROC Curves')
ax1.legend(loc='lower right'); ax1.grid(True, alpha=0.3)

ax2.set_xlabel('Recall'); ax2.set_ylabel('Precision'); ax2.set_title('Figure 6: Precision-Recall Curves')
ax2.legend(loc='lower left'); ax2.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig5_6_roc_pr.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig5_6_roc_pr.png'))
plt.show()

## Figures 7–8: EDT Per-Sample Analysis

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 14: Figures 7–8 — EDT Per-Sample Analysis
# ═══════════════════════════════════════════════════════════
edt_df = edt_analysis_df.copy()
edt_df['class'] = edt_df['y_true'].map({0: 'Normal', 1: 'Attack'})

CLASS_COLORS = [('Normal', '#3498db'), ('Attack', '#e74c3c')]


def _overlay_class_hist(ax, column, xlabel, title):
    """Draw one 50-bin histogram of ``edt_df[column]`` per class on ``ax``."""
    for cls_name, cls_color in CLASS_COLORS:
        subset = edt_df[edt_df['class'] == cls_name]
        ax.hist(subset[column], bins=50, alpha=0.6, color=cls_color,
                label=f'{cls_name} (\u03bc={subset[column].mean():.2f})', edgecolor='white')
    ax.set_xlabel(xlabel)
    ax.set_ylabel('Count')
    ax.set_title(title)
    ax.legend()
    ax.grid(True, alpha=0.3)


fig, axes = plt.subplots(1, 3, figsize=(18, 5))

# (a) temperature and (b) entropy distributions, split by class.
if 'tau' in edt_df.columns:
    _overlay_class_hist(axes[0], 'tau', 'Temperature (\u03c4)', '(a) Temperature by Class')
if 'entropy' in edt_df.columns:
    _overlay_class_hist(axes[1], 'entropy', 'Normalised Entropy', '(b) Entropy by Class')

# (c) Entropy vs temperature on a reproducible subsample of at most 5000 points.
if 'entropy' in edt_df.columns and 'tau' in edt_df.columns:
    samp = edt_df.sample(min(5000, len(edt_df)), random_state=42)
    sc = axes[2].scatter(samp['entropy'], samp['tau'], c=samp['y_true'],
                         cmap='coolwarm', alpha=0.3, s=10, edgecolors='none')
    axes[2].set_xlabel('Normalised Entropy')
    axes[2].set_ylabel('Temperature (\u03c4)')
    axes[2].set_title('(c) Entropy vs Temperature')
    cb = plt.colorbar(sc, ax=axes[2], ticks=[0, 1])
    cb.set_ticklabels(['Normal', 'Attack'])
    axes[2].grid(True, alpha=0.3)

fig.suptitle('Figures 7\u20138: EDT Per-Sample Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig7_8_edt_analysis.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig7_8_edt_analysis.png'))
plt.show()

print('\nEDT Statistics by Class:')
cols = [col_name for col_name in ('entropy', 'tau') if col_name in edt_df.columns]
display(edt_df.groupby('class')[cols].describe().round(3))

## Figure 9: Confusion Matrix

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 15: Figure 9 — Confusion Matrix
# ═══════════════════════════════════════════════════════════
yt_arr = y_test.values if hasattr(y_test, 'values') else y_test
yp_bin = (edanet_probs > 0.5).astype(int)

# Pin the label set so the matrix is always 2×2 in (Normal=0, Attack=1) order
# and target_names stay aligned, even if a class never appears in y_true or
# in the predictions.
CM_LABELS = [0, 1]
CM_NAMES = ['Normal', 'Attack']
cm = confusion_matrix(yt_arr, yp_bin, labels=CM_LABELS)

fig, ax = plt.subplots(figsize=(6, 5))
sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
            xticklabels=CM_NAMES, yticklabels=CM_NAMES, ax=ax)
ax.set_xlabel('Predicted'); ax.set_ylabel('Actual')
ax.set_title('Figure 9: EDA-Net Confusion Matrix')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig9_confusion_matrix.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig9_confusion_matrix.png'))
plt.show()

print(classification_report(yt_arr, yp_bin, labels=CM_LABELS,
                            target_names=CM_NAMES, zero_division=0))

## Figure 10: Per-Attack Detection Rates + Mean τ

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 16: Figure 10 — Per-Attack Detection & Mean \u03c4
# ═══════════════════════════════════════════════════════════
_, test_path = get_data_paths(data_dir=data_dir)
df_raw = pd.read_csv(test_path)

yt_arr = y_test.values if hasattr(y_test, 'values') else np.array(y_test)

analysis = pd.DataFrame({
    'True': yt_arr,
    'Prob': edanet_probs,
    'Pred': (edanet_probs > 0.5).astype(int),
    'Category': df_raw['attack_cat'].fillna('Normal').values[:len(yt_arr)]
})
analysis['Category'] = analysis['Category'].replace({'Backdoors': 'Backdoor'})

# Merge EDT analysis
if 'tau' in edt_analysis_df.columns and len(edt_analysis_df) == len(analysis):
    analysis['tau'] = edt_analysis_df['tau'].values

# Per-attack metrics
mlist = []
for cat in analysis['Category'].unique():
    s = analysis[analysis['Category'] == cat]
    m = {
        'Attack': cat,
        'Samples': len(s),
        'Detection Rate': s['Pred'].sum() / max(len(s), 1),
        'Type': 'Minority' if len(s) < 5000 else 'Majority'
    }
    if 'tau' in s.columns:
        m['Mean \u03c4'] = s['tau'].mean()
    mlist.append(m)

adf = pd.DataFrame(mlist).sort_values('Samples', ascending=False)
print('Table: Per-Attack Metrics')
display(adf)

fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# (a) Detection rate
asrt = adf.sort_values('Detection Rate')
clrs = ['#e74c3c' if t == 'Minority' else '#3498db' for t in asrt['Type']]
axes[0].barh(asrt['Attack'], asrt['Detection Rate'], color=clrs, edgecolor='white')
axes[0].set_xlabel('Detection Rate')
axes[0].set_title('(a) Per-Attack Detection Rate')
axes[0].set_xlim(0, 1.05)
axes[0].grid(axis='x', alpha=0.3)
for i, (_, row) in enumerate(asrt.iterrows()):
    axes[0].text(row['Detection Rate'] + 0.01, i, f"{row['Detection Rate']:.3f}", va='center', fontsize=9)

# (b) Mean \u03c4
if 'Mean \u03c4' in adf.columns:
    tsrt = adf.sort_values('Mean \u03c4')
    axes[1].barh(tsrt['Attack'], tsrt['Mean \u03c4'], color='#f39c12', edgecolor='white')
    axes[1].set_xlabel('Mean \u03c4')
    axes[1].set_title('(b) Mean EDT Temperature by Attack')
    axes[1].grid(axis='x', alpha=0.3)
    for i, (_, row) in enumerate(tsrt.iterrows()):
        axes[1].text(row['Mean \u03c4'] + 0.01, i, f"{row['Mean \u03c4']:.3f}", va='center', fontsize=9)

fig.suptitle('Figure 10: Per-Attack Analysis', fontsize=14, fontweight='bold')
plt.tight_layout()
plt.savefig(os.path.join(FIG_DIR, 'fig10_per_attack.pdf'), format='pdf')
plt.savefig(os.path.join(FIG_DIR, 'fig10_per_attack.png'))
plt.show()

## Tables 2–4 & Figure 11: Ablation Results

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 17: Tables 2–4 & Figure 11 — Ablation Summary
# ═══════════════════════════════════════════════════════════
abl = pd.DataFrame(ablation_results).T

# (caption, experiment labels shown in that table)
ABLATION_TABLES = (
    ('Table 2: Main Comparison (Attention Impact)',
     ['Vanilla DNN + BCE', 'Vanilla DNN + Focal', 'Fixed-Temp Attn + Focal', 'EDA-Net (Full)']),
    ('\nTable 3: EDT Component Ablation',
     ['Fixed-Temp Attn + Focal', 'Heuristic EDT + Focal', 'EDA-Net (Full)']),
    ('\nTable 4: Sensitivity Analysis (\u03c4 range & normalisation)',
     ['EDA-Net (Full)', 'EDA-Net (Narrow \u03c4)', 'EDA-Net (Wide \u03c4)', 'EDA-Net (No Norm)']),
)
for caption, members in ABLATION_TABLES:
    print(caption)
    display(abl.loc[abl.index.isin(members)])


def _ablation_bar_color(label):
    """Red for the full model, orange for other EDA-Net variants, blue for the
    other attention models, grey for everything else."""
    if 'Full' in label:
        return '#e74c3c'
    if 'EDA' in label:
        return '#f39c12'
    if 'Fixed' in label or 'Heuristic' in label:
        return '#3498db'
    return '#95a5a6'


# Figure 11: F1 bar chart, weakest to strongest.
if 'F1-Score' in abl.columns:
    fig, ax = plt.subplots(figsize=(12, 6))
    sa = abl.sort_values('F1-Score')
    colors = [_ablation_bar_color(label) for label in sa.index]
    y_pos = range(len(sa))
    bars = ax.barh(y_pos, sa['F1-Score'], color=colors, edgecolor='white', height=0.6)
    ax.set_yticks(y_pos)
    ax.set_yticklabels(sa.index)
    ax.set_xlabel('F1-Score')
    ax.set_title('Figure 11: Ablation Study — F1-Score Comparison')
    ax.set_xlim(max(0, sa['F1-Score'].min() - 0.05), 1.0)
    ax.grid(axis='x', alpha=0.3)
    for bar, val in zip(bars, sa['F1-Score']):
        bar_mid_y = bar.get_y() + bar.get_height() / 2
        ax.text(bar.get_width() + 0.005, bar_mid_y, f'{val:.4f}', va='center', fontsize=10)
    plt.tight_layout()
    plt.savefig(os.path.join(FIG_DIR, 'fig11_ablation.pdf'), format='pdf')
    plt.savefig(os.path.join(FIG_DIR, 'fig11_ablation.png'))
    plt.show()

## Table 5: Model Complexity & Inference Speed

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 18: Table 5 — Model Complexity
# ═══════════════════════════════════════════════════════════
# Untrained copies of each architecture, built on the default (CPU) device and
# fed random input: only parameter counts and forward-pass latency matter here.
BENCH_BATCH = 256    # samples per forward pass
BENCH_WARMUP = 5     # untimed passes before measuring
BENCH_ITERS = 100    # timed passes; the reported latency is their mean

models_bench = {
    'Vanilla DNN': VanillaDNN_Ablation(input_dim=input_dim),
    'Fixed-Temp Attn': FixedTempNet_Ablation(
        input_dim=input_dim, num_heads=EDA_CONFIG['num_heads'],
        attention_dim=EDA_CONFIG['attention_dim'], n_prototypes=EDA_CONFIG['n_prototypes']
    ),
    # Same constructor arguments as the trained model in Cell 6. Leaving e.g.
    # hidden_units at the class default could report a parameter count that
    # differs from the network whose metrics are reported above.
    'EDA-Net (Full)': EDANet(
        input_dim=input_dim,
        num_heads=EDA_CONFIG['num_heads'],
        attention_dim=EDA_CONFIG['attention_dim'],
        n_prototypes=EDA_CONFIG['n_prototypes'],
        hidden_units=EDA_CONFIG['hidden_units'],
        dropout_rate=EDA_CONFIG['dropout_rate'],
        attention_dropout=EDA_CONFIG['attention_dropout'],
        tau_min=EDA_CONFIG['tau_min'],
        tau_max=EDA_CONFIG['tau_max'],
        tau_hidden_dim=EDA_CONFIG['tau_hidden_dim'],
        edt_mode=EDA_CONFIG['edt_mode'],
        normalize_entropy=EDA_CONFIG['normalize_entropy'],
        num_classes=1,
        output_logits=True,
    ),
}

x_bench = torch.randn(BENCH_BATCH, input_dim)
rows = []
for name, mdl in models_bench.items():
    mdl.eval()
    with torch.no_grad():
        for _ in range(BENCH_WARMUP):
            mdl(x_bench)
        # perf_counter is monotonic and high-resolution, unlike time.time().
        t0 = time.perf_counter()
        for _ in range(BENCH_ITERS):
            mdl(x_bench)
        elapsed = time.perf_counter() - t0
    sec_per_batch = elapsed / BENCH_ITERS
    rows.append({
        'Model': name,
        'Parameters': f'{mdl.count_parameters():,}',
        f'ms/batch ({BENCH_BATCH})': f'{sec_per_batch * 1000:.2f}',
        'Samples/sec': f'{BENCH_BATCH / sec_per_batch:,.0f}',
    })

print('\nTable 5: Model Complexity & Inference Speed')
display(pd.DataFrame(rows))

## Final Summary

In [None]:
# ═══════════════════════════════════════════════════════════
# Cell 19: Final Summary
# ═══════════════════════════════════════════════════════════
print("=" * 60)
print("  \u2713 ALL TRAINING AND ARTIFACT GENERATION COMPLETE")
print("=" * 60)

print(f"\nSave directory: {os.path.abspath(SAVE_DIR)}")
print(f"Figures directory: {os.path.abspath(FIG_DIR)}")


def _format_size(n_bytes):
    """Right-aligned size string using 1024-based units for both KB and MB."""
    kib = 1024
    if n_bytes < kib ** 2:
        return f"{n_bytes / kib:>8.1f} KB"
    return f"{n_bytes / kib ** 2:>8.1f} MB"


def _format_metric(value):
    """Four-decimal string for numeric metrics; plain ``str()`` for anything else."""
    try:
        return f"{value:.4f}"
    except (TypeError, ValueError):
        return str(value)


# List saved artifacts (files only; sub-directories such as FIG_DIR are skipped).
print("\nSaved artifacts:")
for fname in sorted(os.listdir(SAVE_DIR)):
    fpath = os.path.join(SAVE_DIR, fname)
    if os.path.isfile(fpath):
        print(f"  {fname:<50} {_format_size(os.path.getsize(fpath))}")

print("\nFigures:")
if os.path.exists(FIG_DIR):
    for fname in sorted(os.listdir(FIG_DIR)):
        print(f"  {fname}")

print("\n\u2713 EDA-Net Main Model Metrics:")
for k, v in edanet_metrics.items():
    print(f"  {k:<15}: {_format_metric(v)}")

print("\n\u2713 Ablation Summary:")
display(abl_df)