# In [None]:


import os
import sys
import logging
import warnings
import random
import time
from pathlib import Path
from typing import Dict, List, Any, Optional

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader, TensorDataset
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import (
    roc_auc_score,
    precision_score,
    recall_score,
    f1_score,
)
import optuna
import matplotlib.pyplot as plt
import seaborn as sns

# Publication-style Matplotlib defaults.
# Font sizes and family for axes, legends and figure titles.
_FIGURE_FONT_RC = {
    "font.size": 12,
    "axes.titlesize": 14,
    "axes.labelsize": 12,
    "legend.fontsize": 11,
    "figure.titlesize": 14,
    "text.usetex": False,
    "font.family": "serif",
}
# Font type 42 (TrueType) for PDF/PS output instead of Matplotlib's default Type 3.
_FONT_EMBEDDING_RC = {
    "pdf.fonttype": 42,
    "ps.fonttype": 42,
}
plt.rcParams.update(_FIGURE_FONT_RC)
plt.rcParams.update(_FONT_EMBEDDING_RC)

# ─────────────────────────────────────────────────────────────────────
# Configuration & Global Constants
# ─────────────────────────────────────────────────────────────────────

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
_DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for every experiment-wide setting.
CONFIG = {
    # One full training/evaluation run is executed per seed.
    "SEEDS": [42, 1337, 2024],

    # Device on which models and batches are placed.
    "DEVICE": _DEFAULT_DEVICE,

    # Directory holding KDDTrain+.txt and KDDTest+.txt.
    "DATA_DIR": Path("research_data"),
    # Each inner list names the attack categories that form one task
    # (normal traffic is added to every task by the loader).
    "TASK_DEFINITIONS": [
        ["dos"],
        ["probe"],
        ["r2l", "u2r"],
    ],
    # Tasks larger than this are randomly subsampled.
    "MAX_SAMPLES_PER_TASK": 10_000,

    # Per-task training budget.
    "EPOCHS": 5,
    "BATCH_SIZE": 256,

    # Optuna search budget (trials per model, epochs per trial).
    "TUNING_TRIALS": 10,
    "TUNING_EPOCHS": 3,
}

# Logging setup: timestamped INFO-level records on the root logger.
_LOG_FORMAT = "%(asctime)s | %(levelname)-8s | %(message)s"
_LOG_DATE_FORMAT = "%Y-%m-%d %H:%M:%S"
logging.basicConfig(format=_LOG_FORMAT, datefmt=_LOG_DATE_FORMAT, level=logging.INFO)
logger = logging.getLogger(__name__)

# Quieten third-party output: suppress UserWarning and keep Optuna at WARNING.
warnings.filterwarnings("ignore", category=UserWarning)
optuna.logging.set_verbosity(optuna.logging.WARNING)


def set_seed(seed: int) -> None:
    """Seed every random number generator used by the experiments.

    Seeds Python's ``random`` module, NumPy's global generator and PyTorch.
    When CUDA is available, all CUDA devices are seeded as well and cuDNN is
    switched to deterministic, non-benchmarking mode.

    Args:
        seed: Value passed to each generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)

    if not torch.cuda.is_available():
        return

    torch.cuda.manual_seed_all(seed)
    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True


# ─────────────────────────────────────────────────────────────────────
# Model Definitions
# ─────────────────────────────────────────────────────────────────────

class TADR_VAE(nn.Module):
    """Task-Aware Dynamic Reconstruction Variational Autoencoder.

    A learned task embedding initialises a GRU that reads the input as a
    length-1 sequence. The GRU output, concatenated with the input, produces a
    per-feature sigmoid gate. The gated input is then encoded to a Gaussian
    latent and decoded back to input space.

    Expected ``hparams`` keys: ``task_embedding_dim``, ``hidden_dim``,
    ``latent_dim``.
    """

    def __init__(self, input_dim: int, num_tasks: int, hparams: Dict[str, Any]):
        super().__init__()
        self.input_dim = input_dim
        self.num_tasks = num_tasks
        self.hparams = hparams

        emb_dim = hparams["task_embedding_dim"]
        hidden = hparams["hidden_dim"]
        latent = hparams["latent_dim"]
        half_hidden = hidden // 2

        # Submodules are created in a fixed order so that seeded weight
        # initialisation stays the same from run to run.
        self.task_embeddings = nn.Embedding(num_tasks, emb_dim)
        self.task_emb_to_hidden = nn.Linear(emb_dim, hidden)
        self.temporal_gating = nn.GRU(
            input_size=input_dim, hidden_size=hidden, batch_first=True
        )
        self.gate_generator = nn.Sequential(
            nn.Linear(input_dim + hidden, input_dim),
            nn.Sigmoid(),
        )
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.LayerNorm(hidden),
            nn.GELU(),
            nn.Linear(hidden, half_hidden),
        )
        self.latent_mu = nn.Linear(half_hidden, latent)
        self.latent_logvar = nn.Linear(half_hidden, latent)
        self.decoder = nn.Sequential(
            nn.Linear(latent, half_hidden),
            nn.LayerNorm(half_hidden),
            nn.GELU(),
            nn.Linear(half_hidden, input_dim),
        )

    def forward(self, x: torch.Tensor, task_id: torch.Tensor):
        """Return ``(recon, mu, logvar, gating_weights)`` for a batch.

        Args:
            x: Input batch of shape ``[B, input_dim]``.
            task_id: Long tensor of shape ``[B]`` with one task index per row.
        """
        # Task embedding -> initial GRU state of shape [1, B, hidden_dim].
        initial_state = self.task_emb_to_hidden(self.task_embeddings(task_id)).unsqueeze(0)
        # Each sample is fed to the GRU as a sequence of length one.
        gru_features, _ = self.temporal_gating(x.unsqueeze(1), initial_state)
        gate_input = torch.cat([x, gru_features.squeeze(1)], dim=-1)
        gating_weights = self.gate_generator(gate_input)

        hidden = self.encoder(x * gating_weights)
        mu = self.latent_mu(hidden)
        logvar = self.latent_logvar(hidden)
        recon = self.decoder(self._reparameterize(mu, logvar))
        return recon, mu, logvar, gating_weights

    def _reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def compute_loss(
        self,
        x: torch.Tensor,
        recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        gate: torch.Tensor,
        kl_weight: float,
    ) -> torch.Tensor:
        """Gate-weighted MSE reconstruction loss plus ``kl_weight`` times the KL term.

        Both reconstruction and target are multiplied by ``gate`` before the
        MSE, so features with a small gate contribute little to the loss.
        """
        gated_recon = recon * gate
        gated_target = x * gate
        recon_loss = F.mse_loss(gated_recon, gated_target)
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = -0.5 * kl_terms.mean()
        return recon_loss + kl_weight * kl_loss


class VanillaVAE(nn.Module):
    """Baseline Variational Autoencoder with no task conditioning.

    Shares the call signature of the task-aware model so the same training and
    evaluation loops can drive it. ``task_id`` is accepted and ignored, and the
    returned gate is all ones.

    Expected ``hparams`` keys: ``hidden_dim``, ``latent_dim``.
    """

    def __init__(self, input_dim: int, hparams: Dict[str, Any]):
        super().__init__()
        hidden = hparams["hidden_dim"]
        half_hidden = hidden // 2

        self.encoder = nn.Sequential(
            nn.Linear(input_dim, hidden),
            nn.ReLU(),
            nn.Linear(hidden, half_hidden),
        )
        latent = hparams["latent_dim"]
        self.fc_mu = nn.Linear(half_hidden, latent)
        self.fc_logvar = nn.Linear(half_hidden, latent)
        self.decoder = nn.Sequential(
            nn.Linear(latent, half_hidden),
            nn.ReLU(),
            nn.Linear(half_hidden, input_dim),
        )

    def forward(self, x: torch.Tensor, task_id: Optional[torch.Tensor] = None):
        """Return ``(recon, mu, logvar, gate)``, where ``gate`` is all ones shaped like ``x``."""
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        recon = self.decoder(self._reparameterize(mu, logvar))
        return recon, mu, logvar, torch.ones_like(x)

    def _reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
        """Draw ``z = mu + eps * std`` with ``eps`` sampled from a standard normal."""
        std = (0.5 * logvar).exp()
        eps = torch.randn_like(std)
        return mu + eps * std

    def compute_loss(
        self,
        x: torch.Tensor,
        recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        gate: torch.Tensor,
        kl_weight: float,
    ) -> torch.Tensor:
        """Plain MSE reconstruction loss plus ``kl_weight`` times the KL term.

        ``gate`` is unused; it is only part of the signature for parity with
        the task-aware model.
        """
        recon_loss = F.mse_loss(recon, x)
        kl_terms = 1 + logvar - mu.pow(2) - logvar.exp()
        kl_loss = -0.5 * kl_terms.mean()
        return recon_loss + kl_weight * kl_loss


class EWC(nn.Module):
    """Elastic Weight Consolidation wrapper for continual learning.

    Wraps a model exposing ``forward(x, task_id)`` and
    ``compute_loss(x, recon, mu, logvar, gate, kl_weight)``. After each task,
    ``end_task`` stores a snapshot of the trainable parameters and a diagonal
    Fisher estimate. ``compute_loss`` then adds ``ewc_lambda * penalty()`` to
    the wrapped model's loss.
    """

    def __init__(self, model: nn.Module, ewc_lambda: float):
        super().__init__()
        self.model = model
        self.ewc_lambda = ewc_lambda
        # task_id -> {"mean": {param_name: snapshot}, "fisher": {param_name: estimate}}
        self.tasks: Dict[int, Dict[str, Dict[str, torch.Tensor]]] = {}

    def forward(self, x: torch.Tensor, task_id: Optional[torch.Tensor] = None):
        """Delegate to the wrapped model."""
        return self.model(x, task_id)

    def compute_loss(
        self,
        x: torch.Tensor,
        recon: torch.Tensor,
        mu: torch.Tensor,
        logvar: torch.Tensor,
        gate: torch.Tensor,
        kl_weight: float,
    ) -> torch.Tensor:
        """Return the wrapped model's loss plus the weighted EWC penalty."""
        base_loss = self.model.compute_loss(x, recon, mu, logvar, gate, kl_weight)
        return base_loss + self.ewc_lambda * self.penalty()

    def penalty(self) -> torch.Tensor:
        """Sum of ``fisher * (param - snapshot)^2`` over all consolidated tasks.

        Returns a zero scalar tensor when no task has been consolidated yet.
        """
        penalty = torch.tensor(0.0, device=next(self.parameters()).device)
        if not self.tasks:
            return penalty
        for name, param in self.model.named_parameters():
            if not param.requires_grad:
                continue
            for task_data in self.tasks.values():
                anchor = task_data["mean"].get(name)
                if anchor is None:
                    # Parameter was not trainable when this task was
                    # consolidated, so there is nothing to anchor it to.
                    continue
                # Out-of-place accumulation keeps the autograd graph simple.
                penalty = penalty + (task_data["fisher"][name] * (param - anchor).pow(2)).sum()
        return penalty

    def end_task(self, dataloader: DataLoader, task_id: int, kl_weight: float):
        """Consolidate the current parameters as the anchor for ``task_id``.

        The diagonal Fisher estimate is the sample-weighted average, over
        batches, of the squared gradient of the batch-mean loss.

        Args:
            dataloader: Yields ``(x, _)`` batches from the task just learned.
            task_id: Index stored as the key in ``self.tasks`` and broadcast
                to every sample when calling the model.
            kl_weight: KL weight passed to the wrapped model's loss.

        Raises:
            ValueError: If ``dataloader`` yields no samples. Dividing by zero
                would otherwise store NaN Fisher values that corrupt every
                later loss.
        """
        device = next(self.parameters()).device
        trainable = {n: p for n, p in self.model.named_parameters() if p.requires_grad}
        fisher = {n: torch.zeros_like(p.detach()) for n, p in trainable.items()}
        mean = {n: p.detach().clone() for n, p in trainable.items()}

        was_training = self.model.training
        self.model.eval()
        total_samples = 0
        try:
            for x, _ in dataloader:
                x = x.to(device)
                batch_size = x.size(0)
                self.model.zero_grad()
                # One task id per sample so it matches the batch dimension.
                task_tensor = torch.full((batch_size,), task_id, dtype=torch.long, device=device)
                recon, mu, logvar, gate = self.model(x, task_tensor)
                loss = self.model.compute_loss(x, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                for n, p in trainable.items():
                    if p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2) * batch_size
                total_samples += batch_size
        finally:
            # Do not leak Fisher gradients or eval mode into later training.
            self.model.zero_grad()
            self.model.train(was_training)

        if total_samples == 0:
            raise ValueError("Cannot estimate Fisher information from an empty dataloader")

        for n in fisher:
            fisher[n] /= total_samples
        self.tasks[task_id] = {"mean": mean, "fisher": fisher}


# ─────────────────────────────────────────────────────────────────────
# Data Loader: Handles Numeric Attack Labels
# ─────────────────────────────────────────────────────────────────────

class RealDatasetLoader:
    """Load the raw NSL-KDD ``.txt`` files and split them into continual-learning tasks.

    The raw files have no header row. Each record holds the 41 features named
    in ``FEATURE_COLUMNS``, then the attack label, then optionally a
    difficulty level. Attack labels may be attack names (e.g. ``"neptune"``)
    or numeric codes; a numeric code is the position of the attack in
    ``ATTACKS``.
    """

    # The 41 feature names, in file order. A schema with extra or missing
    # names silently shifts every later column, including the label.
    FEATURE_COLUMNS = [
        'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
        'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
        'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
        'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login',
        'is_guest_login', 'count', 'srv_count', 'serror_rate', 'srv_serror_rate',
        'rerror_rate', 'srv_rerror_rate', 'same_srv_rate', 'diff_srv_rate',
        'srv_diff_host_rate', 'dst_host_count', 'dst_host_srv_count',
        'dst_host_same_srv_rate', 'dst_host_diff_srv_rate', 'dst_host_same_src_port_rate',
        'dst_host_srv_diff_host_rate', 'dst_host_serror_rate', 'dst_host_srv_serror_rate',
        'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
    ]

    # (attack name, category) pairs; the index of a pair is its numeric code.
    ATTACKS = (
        ('normal', 'normal'),
        ('back', 'dos'),
        ('buffer_overflow', 'u2r'),
        ('ftp_write', 'r2l'),
        ('guess_passwd', 'r2l'),
        ('imap', 'r2l'),
        ('ipsweep', 'probe'),
        ('land', 'dos'),
        ('loadmodule', 'u2r'),
        ('multihop', 'r2l'),
        ('neptune', 'dos'),
        ('nmap', 'probe'),
        ('perl', 'u2r'),
        ('phf', 'r2l'),
        ('pod', 'dos'),
        ('portsweep', 'probe'),
        ('rootkit', 'u2r'),
        ('satan', 'probe'),
        ('smurf', 'dos'),
        ('spy', 'r2l'),
        ('teardrop', 'dos'),
        ('warezclient', 'r2l'),
        ('warezmaster', 'r2l'),
        ('apache2', 'dos'),
        ('udpstorm', 'dos'),
        ('processtable', 'dos'),
        ('mailbomb', 'dos'),
        ('saint', 'probe'),
        ('mscan', 'probe'),
        ('xterm', 'u2r'),
        ('ps', 'u2r'),
        ('sqlattack', 'u2r'),
        ('httptunnel', 'r2l'),
        ('sendmail', 'r2l'),
        ('named', 'r2l'),
        ('snmpgetattack', 'r2l'),
        ('snmpguess', 'r2l'),
        ('xlock', 'r2l'),
        ('xsnoop', 'r2l'),
    )
    ATTACK_NAME_TO_CATEGORY = dict(ATTACKS)
    ATTACK_CODE_TO_CATEGORY = {code: category for code, (_, category) in enumerate(ATTACKS)}

    def __init__(self, data_dir: Path = CONFIG["DATA_DIR"]):
        """Check that both raw files exist under ``data_dir``.

        Raises:
            FileNotFoundError: If ``KDDTrain+.txt`` or ``KDDTest+.txt`` is missing.
        """
        self.data_dir = Path(data_dir)
        train_path = self.data_dir / "KDDTrain+.txt"
        test_path = self.data_dir / "KDDTest+.txt"

        if not train_path.exists():
            raise FileNotFoundError(f"Missing: {train_path}")
        if not test_path.exists():
            raise FileNotFoundError(f"Missing: {test_path}")
        logger.info(f"✅ Raw NSL-KDD files found in: {self.data_dir}")

    def _read_split(self, path: Path) -> pd.DataFrame:
        """Read one raw file and name its columns from the actual column count.

        Raises:
            ValueError: If the file has neither 42 nor 43 columns. Guessing
                would misalign the label column.
        """
        frame = pd.read_csv(path, header=None)
        extra = frame.shape[1] - len(self.FEATURE_COLUMNS)
        if extra == 2:
            label_columns = ['attack', 'level']
        elif extra == 1:
            label_columns = ['attack']
        else:
            raise ValueError(
                f"{path} has {frame.shape[1]} columns; expected "
                f"{len(self.FEATURE_COLUMNS)} features plus 'attack' and optionally 'level'"
            )
        frame.columns = self.FEATURE_COLUMNS + label_columns
        return frame

    def _attack_categories(self, attack: pd.Series) -> pd.Series:
        """Map raw attack labels (names or numeric codes) to categories.

        Labels that match neither mapping come back as NaN.
        """
        names = attack.astype(str).str.strip().str.rstrip('.').str.lower()
        by_name = names.map(self.ATTACK_NAME_TO_CATEGORY)
        by_code = pd.to_numeric(attack, errors='coerce').map(self.ATTACK_CODE_TO_CATEGORY)
        return by_name.fillna(by_code)

    def get_dataset_tasks(self) -> List[Dict[str, np.ndarray]]:
        """Load NSL-KDD and partition it into one binary task per task definition.

        Train and test files are concatenated, categorical features are one-hot
        encoded, and all features are min-max scaled over the combined data.
        Each task keeps normal traffic plus the attack categories named in
        ``CONFIG["TASK_DEFINITIONS"]`` and is subsampled to at most
        ``CONFIG["MAX_SAMPLES_PER_TASK"]`` rows.

        Returns:
            One dict per task with ``"X"`` (float32 features) and ``"y"``
            (int64 labels, 1 = attack, 0 = normal).

        Raises:
            ValueError: If a file has an unexpected column count, or no row
                has a recognised attack label.
        """
        logger.info("Loading NSL-KDD from .txt files...")

        frames = [
            self._read_split(self.data_dir / file_name)
            for file_name in ("KDDTrain+.txt", "KDDTest+.txt")
        ]
        df = pd.concat(frames, ignore_index=True)
        logger.info(f"Combined dataset: {len(df)} samples")

        categories = self._attack_categories(df['attack'])
        known = categories.notna()
        dropped = int((~known).sum())
        if dropped:
            logger.warning(f"Dropping {dropped} samples with unmapped attack labels")
        df = df[known]
        categories = categories[known]
        logger.info(f"After filtering: {len(df)} samples")
        if df.empty:
            raise ValueError("No samples left after mapping attack labels to categories")

        # Only the 41 feature columns are used, so labels cannot leak into X.
        X_raw = df[self.FEATURE_COLUMNS]
        categorical = list(X_raw.select_dtypes(exclude=['number']).columns)
        X_processed = pd.get_dummies(
            X_raw, columns=categorical, drop_first=False, dtype=np.float32
        )

        # NOTE(review): the scaler is fitted on train and test rows together.
        X_scaled = MinMaxScaler().fit_transform(X_processed)

        # A dedicated seeded generator keeps the subsample identical across
        # runs no matter when the global seeds are set.
        rng = np.random.default_rng(CONFIG["SEEDS"][0])
        is_normal = (categories == 'normal').to_numpy()

        tasks = []
        for task_idx, attack_types in enumerate(CONFIG["TASK_DEFINITIONS"]):
            is_attack = categories.isin(attack_types).to_numpy()
            keep = is_attack | is_normal
            X_task = X_scaled[keep]
            y_task = is_attack[keep].astype(np.int64)

            if len(X_task) > CONFIG["MAX_SAMPLES_PER_TASK"]:
                idx = rng.choice(len(X_task), CONFIG["MAX_SAMPLES_PER_TASK"], replace=False)
                X_task, y_task = X_task[idx], y_task[idx]

            tasks.append({"X": X_task.astype(np.float32), "y": y_task.astype(np.int64)})
            logger.info(f"Task {task_idx}: {len(X_task)} samples ({y_task.sum()} anomalies)")

        return tasks


# ─────────────────────────────────────────────────────────────────────
# Hyperparameter Tuning
# ─────────────────────────────────────────────────────────────────────

class HyperparameterTuner:
    """Optuna-based hyperparameter search for reconstruction-based anomaly detection.

    Each trial trains a fresh model on the tuning data for
    ``CONFIG["TUNING_EPOCHS"]`` epochs. It is scored by the ROC-AUC of the
    per-sample reconstruction error against the labels.
    """

    def __init__(self, model_names: List[str], tuning_data: Dict[str, np.ndarray]):
        """Store the tuning set.

        Args:
            model_names: Models to tune: ``"TADR-VAE"``, ``"Vanilla VAE"`` or
                ``"VAE+EWC"``.
            tuning_data: Dict with ``"X"`` (feature array) and ``"y"`` (labels).
        """
        self.model_names = model_names
        self.X_tune = torch.from_numpy(tuning_data["X"])
        self.y_tune = tuning_data["y"]
        self.input_dim = self.X_tune.shape[1]
        self.device = CONFIG["DEVICE"]

    def _objective(self, trial: optuna.Trial, model_name: str) -> float:
        """Train one candidate and return its ROC-AUC on the tuning data.

        Raises:
            ValueError: If ``model_name`` is not a known model.
        """
        lr = trial.suggest_float("lr", 1e-4, 1e-2, log=True)
        latent_dim = trial.suggest_categorical("latent_dim", [16, 32, 64])
        hidden_dim = trial.suggest_categorical("hidden_dim", [64, 128, 256])
        kl_weight = trial.suggest_float("kl_weight", 0.05, 0.5, log=True)

        hparams = {
            "latent_dim": latent_dim,
            "hidden_dim": hidden_dim,
            "kl_weight": kl_weight,
        }

        if model_name == "TADR-VAE":
            hparams["task_embedding_dim"] = trial.suggest_categorical("task_embedding_dim", [8, 16])
            # Match the number of configured tasks instead of hard-coding it.
            model = TADR_VAE(
                self.input_dim, num_tasks=len(CONFIG["TASK_DEFINITIONS"]), hparams=hparams
            ).to(self.device)
        elif model_name == "Vanilla VAE":
            model = VanillaVAE(self.input_dim, hparams=hparams).to(self.device)
        elif model_name == "VAE+EWC":
            # NOTE(review): no task is consolidated during tuning, so the EWC
            # penalty stays zero and ewc_lambda cannot affect the score here.
            ewc_lambda = trial.suggest_float("ewc_lambda", 100, 1000, log=True)
            base_model = VanillaVAE(self.input_dim, hparams=hparams).to(self.device)
            model = EWC(base_model, ewc_lambda).to(self.device)
        else:
            raise ValueError(f"Unknown model: {model_name}")

        optimizer = torch.optim.Adam(model.parameters(), lr=lr)
        batch_size = CONFIG["BATCH_SIZE"]
        train_loader = DataLoader(TensorDataset(self.X_tune), batch_size=batch_size, shuffle=True)
        # Scoring must walk the samples in their original order: a shuffled
        # loader would pair each error with the wrong entry of self.y_tune.
        eval_loader = DataLoader(TensorDataset(self.X_tune), batch_size=batch_size, shuffle=False)

        model.train()
        for _ in range(CONFIG["TUNING_EPOCHS"]):
            for (batch_x,) in train_loader:
                batch_x = batch_x.to(self.device)
                optimizer.zero_grad()
                # Tuning uses a single task, so every sample gets task id 0.
                task_id = torch.zeros(batch_x.size(0), dtype=torch.long, device=self.device)
                recon, mu, logvar, gate = model(batch_x, task_id)
                loss = model.compute_loss(batch_x, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                optimizer.step()

        model.eval()
        batch_errors = []
        with torch.no_grad():
            for (batch_x,) in eval_loader:
                batch_x = batch_x.to(self.device)
                task_id = torch.zeros(batch_x.size(0), dtype=torch.long, device=self.device)
                recon, _, _, _ = model(batch_x, task_id)
                batch_errors.append(torch.mean((batch_x - recon) ** 2, dim=1).cpu().numpy())
        errors = np.concatenate(batch_errors)
        return roc_auc_score(self.y_tune, errors)

    def tune(self) -> Dict[str, Dict[str, Any]]:
        """Run one Optuna study per model and return the best parameters by model name."""
        logger.info("Starting hyperparameter tuning...")
        best_params = {}
        for name in self.model_names:
            # Seeding the sampler makes the sequence of suggestions repeatable.
            sampler = optuna.samplers.TPESampler(seed=CONFIG["SEEDS"][0])
            study = optuna.create_study(direction="maximize", sampler=sampler)
            study.optimize(
                lambda trial, model_name=name: self._objective(trial, model_name),
                n_trials=CONFIG["TUNING_TRIALS"],
                show_progress_bar=False,
            )
            best_params[name] = study.best_params
            logger.info(f"Best params for {name}: {study.best_params}")
        return best_params


# ─────────────────────────────────────────────────────────────────────
# Experiment Framework with Visualization
# ─────────────────────────────────────────────────────────────────────

class AdvancedExperimentFramework:
    """Run the continual-learning comparison and write tables and figures.

    For every seed and model, tasks are learned one after another. After each
    task the model is scored on that task and all earlier ones. Scores go into
    a ``[num_tasks, num_tasks, 4]`` matrix indexed as
    ``[trained_through_task, evaluated_task, metric]``, with metrics ordered
    precision, recall, F1, AUC. Entries with ``evaluated_task >
    trained_through_task`` are never filled and stay zero.
    """

    def __init__(self) -> None:
        # model name -> list of per-seed result dicts; filled by run().
        self.results: Dict[str, List[Dict[str, Any]]] = {}

    def _get_reconstruction_errors(
        self, model: nn.Module, X_test: np.ndarray, task_id: int
    ) -> np.ndarray:
        """Return the per-sample mean squared reconstruction error, in input order."""
        model.eval()
        device = next(model.parameters()).device
        loader = DataLoader(
            TensorDataset(torch.from_numpy(X_test)), batch_size=CONFIG["BATCH_SIZE"]
        )
        batch_errors = []
        with torch.no_grad():
            for (batch_x,) in loader:
                batch_x = batch_x.to(device)
                # One task id per sample so it matches the batch dimension.
                task_tensor = torch.full(
                    (batch_x.size(0),), task_id, dtype=torch.long, device=device
                )
                recon, _, _, _ = model(batch_x, task_tensor)
                batch_errors.append(torch.mean((batch_x - recon) ** 2, dim=1).cpu().numpy())
        if not batch_errors:
            return np.array([])
        return np.concatenate(batch_errors)

    def _find_optimal_threshold(self, y_true: np.ndarray, errors: np.ndarray) -> float:
        """Return the error threshold, out of 100 evenly spaced candidates, with the best F1.

        NOTE(review): the threshold is selected on the same labels it is later
        scored against, so the reported precision/recall/F1 are optimistic.
        """
        thresholds = np.linspace(errors.min(), errors.max(), 100)
        f1_scores = [f1_score(y_true, errors >= t) for t in thresholds]
        return thresholds[np.argmax(f1_scores)]

    @staticmethod
    def _forgetting_per_task(metric_matrix: np.ndarray) -> List[float]:
        """Forgetting of each non-final task.

        For task ``j`` this is the best score reached on ``j`` before the last
        task was learned, minus the score on ``j`` after the last task. Only
        rows ``i >= j`` are read, because earlier rows were never evaluated.
        Negative values mean the score improved.
        """
        final_row = metric_matrix.shape[0] - 1
        return [
            float(np.max(metric_matrix[task:final_row, task]) - metric_matrix[final_row, task])
            for task in range(final_row)
        ]

    def _build_model(
        self, model_name: str, input_dim: int, num_tasks: int, params: Dict[str, Any]
    ) -> nn.Module:
        """Instantiate ``model_name`` on the configured device."""
        device = CONFIG["DEVICE"]
        if model_name == "TADR-VAE":
            return TADR_VAE(input_dim, num_tasks, params).to(device)
        if model_name == "Vanilla VAE":
            return VanillaVAE(input_dim, params).to(device)
        if model_name == "VAE+EWC":
            base = VanillaVAE(input_dim, params).to(device)
            return EWC(base, params["ewc_lambda"]).to(device)
        raise ValueError(f"Unknown model: {model_name}")

    def _train_on_task(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        train_loader: DataLoader,
        task_id: int,
        kl_weight: float,
    ) -> None:
        """Train ``model`` on one task for ``CONFIG["EPOCHS"]`` epochs."""
        device = CONFIG["DEVICE"]
        model.train()
        for _ in range(CONFIG["EPOCHS"]):
            for batch_x, _ in train_loader:
                batch_x = batch_x.to(device)
                optimizer.zero_grad()
                tid_tensor = torch.full(
                    (batch_x.size(0),), task_id, dtype=torch.long, device=device
                )
                recon, mu, logvar, gate = model(batch_x, tid_tensor)
                loss = model.compute_loss(batch_x, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                optimizer.step()

    def _evaluate(self, model: nn.Module, task: Dict[str, np.ndarray], task_id: int) -> List[float]:
        """Return ``[precision, recall, f1, auc]`` of ``model`` on ``task``.

        NOTE(review): models are scored on the same samples they were trained on.
        """
        y_true = task["y"]
        errors = self._get_reconstruction_errors(model, task["X"], task_id)
        thresh = self._find_optimal_threshold(y_true, errors)
        y_pred = (errors >= thresh).astype(int)
        return [
            precision_score(y_true, y_pred),
            recall_score(y_true, y_pred),
            f1_score(y_true, y_pred),
            roc_auc_score(y_true, errors),
        ]

    def run(self):
        """Load data, tune hyperparameters, train every model per seed, and save outputs."""
        logger.info("🚀 Starting Advanced Experiment Framework")
        # Seed before data sampling and tuning so those stages are repeatable too.
        set_seed(CONFIG["SEEDS"][0])

        data_loader = RealDatasetLoader()
        tasks = data_loader.get_dataset_tasks()
        input_dim = tasks[0]["X"].shape[1]
        num_tasks = len(tasks)

        tuner = HyperparameterTuner(
            model_names=["TADR-VAE", "Vanilla VAE", "VAE+EWC"],
            tuning_data=tasks[0],
        )
        best_params = tuner.tune()

        full_results = {name: [] for name in best_params}

        for seed in CONFIG["SEEDS"]:
            set_seed(seed)
            logger.info(f"Running experiments with seed: {seed}")
            for model_name, params in best_params.items():
                logger.info(f"Training {model_name} (seed={seed})")

                model = self._build_model(model_name, input_dim, num_tasks, params)
                optimizer = torch.optim.Adam(model.parameters(), lr=params["lr"])
                perf_matrix = np.zeros((num_tasks, num_tasks, 4))
                start_time = time.time()

                for task_id, task in enumerate(tasks):
                    train_loader = DataLoader(
                        TensorDataset(torch.from_numpy(task["X"]), torch.from_numpy(task["y"])),
                        batch_size=CONFIG["BATCH_SIZE"],
                        shuffle=True,
                    )
                    self._train_on_task(
                        model, optimizer, train_loader, task_id, params["kl_weight"]
                    )

                    if isinstance(model, EWC):
                        model.end_task(train_loader, task_id, params["kl_weight"])

                    # Score the current task and every task learned before it.
                    for eval_task in range(task_id + 1):
                        perf_matrix[task_id, eval_task] = self._evaluate(
                            model, tasks[eval_task], eval_task
                        )

                elapsed = time.time() - start_time
                full_results[model_name].append({
                    "seed": seed,
                    "performance_matrix": perf_matrix,
                    "training_time_sec": elapsed,
                })
                logger.info(f"✅ Completed {model_name} (seed={seed}) in {elapsed:.2f}s")

        self.results = full_results
        logger.info(" All experiments completed successfully.")

        self.save_results_and_figures()

    def save_results_and_figures(self, output_dir: str = "results"):
        """Write the LaTeX results table and all figures into ``output_dir``.

        Raises:
            RuntimeError: If there are no results yet, i.e. ``run()`` has not
                completed.
        """
        if not getattr(self, "results", None):
            raise RuntimeError("No results to save; call run() first")

        output_path = Path(output_dir)
        output_path.mkdir(parents=True, exist_ok=True)

        # Aggregate across seeds.
        summary = {}
        for model_name, runs in self.results.items():
            perf = np.array([run["performance_matrix"] for run in runs])  # [seeds, T, T, 4]
            # Per-seed metrics after the last task, averaged over tasks: [seeds, 4].
            final_per_seed = perf[:, -1, :, :].mean(axis=1)
            summary[model_name] = {
                "mean": np.mean(perf, axis=0),
                "std": np.std(perf, axis=0),
                "final_mean": final_per_seed.mean(axis=0),
                "final_std": final_per_seed.std(axis=0),
            }

        # 1. Final average metrics table. The caption promises mean and
        #    standard deviation, so both are written.
        with open(output_path / "results_table.tex", "w", encoding="utf-8") as f:
            f.write("\\begin{table}[ht]\n\\centering\n\\caption{Final Performance (Mean $\\pm$ Std)}\n")
            f.write("\\begin{tabular}{lcccc}\n\\toprule\nModel & Precision & Recall & F1-Score & AUC \\\\\n\\midrule\n")
            for model_name, stats in summary.items():
                cells = " & ".join(
                    f"{m:.3f} $\\pm$ {s:.3f}"
                    for m, s in zip(stats["final_mean"], stats["final_std"])
                )
                f.write(f"{model_name} & {cells} \\\\\n")
            f.write("\\bottomrule\n\\end{tabular}\n\\end{table}\n")

        # 2. F1 on each task right after it was learned (matrix diagonal).
        plt.figure(figsize=(8, 5))
        for model_name, stats in summary.items():
            diag_f1 = np.diag(stats["mean"][:, :, 2])
            plt.plot(range(len(diag_f1)), diag_f1, marker='o', label=model_name, linewidth=2)
        plt.xlabel("Task Index")
        plt.ylabel("F1-Score")
        plt.title("Continual Learning Performance")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(output_path / "f1_over_tasks.pdf", dpi=300, bbox_inches='tight')
        plt.savefig(output_path / "f1_over_tasks.png", dpi=300, bbox_inches='tight')
        plt.close()

        # 3. Forgetting per non-final task, measured after the last task.
        plt.figure(figsize=(8, 5))
        for model_name, stats in summary.items():
            forgetting = self._forgetting_per_task(stats["mean"][:, :, 2])
            if forgetting:
                plt.plot(range(len(forgetting)), forgetting, marker='s', label=model_name, linewidth=2)
        plt.xlabel("Task Index")
        plt.ylabel("Forgetting (Δ F1)")
        plt.title("Catastrophic Forgetting Analysis")
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.6)
        plt.tight_layout()
        plt.savefig(output_path / "forgetting.pdf", dpi=300, bbox_inches='tight')
        plt.savefig(output_path / "forgetting.png", dpi=300, bbox_inches='tight')
        plt.close()

        # 4. Per-model heatmap of final F1 on every task.
        for model_name, stats in summary.items():
            plt.figure(figsize=(6, 2))
            final_f1 = stats["mean"][-1, :, 2]
            sns.heatmap(
                final_f1.reshape(1, -1),
                annot=True,
                fmt=".3f",
                cmap="viridis",
                xticklabels=[f"Task {i}" for i in range(len(final_f1))],
                yticklabels=["Final"],
                cbar_kws={'label': 'F1-Score'},
            )
            plt.title(f"Final Task Performance: {model_name}")
            plt.tight_layout()
            plt.savefig(output_path / f"heatmap_{model_name.replace(' ', '_')}.pdf", dpi=300, bbox_inches='tight')
            plt.close()

        logger.info(f" figures and tables saved to: {output_path.resolve()}")


# ─────────────────────────────────────────────────────────────────────
# Entry Point
# ─────────────────────────────────────────────────────────────────────

if __name__ == "__main__":
    # Top-level boundary: log the full traceback and report failure via exit status 1.
    try:
        AdvancedExperimentFramework().run()
    except Exception:
        logger.exception("Fatal error during execution:")
        sys.exit(1)

# Captured output from a previous run of the cell above:
# 2025-10-02 23:26:14 | INFO     | 🚀 Starting Advanced Experiment Framework
# 2025-10-02 23:26:14 | INFO     | ✅ Raw NSL-KDD files found in: research_data
# 2025-10-02 23:26:14 | INFO     | Loading NSL-KDD from .txt files...
# 2025-10-02 23:26:18 | INFO     | Combined dataset: 148517 samples
# 2025-10-02 23:26:18 | INFO     | After filtering: 148517 samples
# 2025-10-02 23:26:24 | INFO     | Task 0: 10000 samples (9963 anomalies)
# 2025-10-02 23:26:24 | INFO     | Task 1: 10000 samples (9830 anomalies)
# 2025-10-02 23:26:25 | INFO     | Task 2: 10000 samples (9985 anomalies)
# 2025-10-02 23:26:25 | INFO     | Starting hyperparameter tuning...
# 2025-10-02 23:28:04 | INFO     | Best params for TADR-VAE: {'lr': 0.0008983801068631165, 'latent_dim': 16, 'hidden_dim': 256, 'kl_weight': 0.2933216713441632, 'task_embedding_dim': 8}
# 2025-10-02 23:28:46 | INFO     | Best params for Vanilla VAE: {'lr': 0.008889988444746321, 'latent_dim': 32, 'hidden_dim': 128, 'kl_weight': 0.07761425986216605}
2025-10-02 23:29:42 |