# In [None]:
# -*- coding: utf-8 -*-
"""
TADR-VAE v17.0: Advanced Research Engine for Continual Learning Anomaly Detection
Corrected Version: GRU hidden size and Optuna trial fixed.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import MinMaxScaler
from sklearn.metrics import roc_auc_score, precision_score, recall_score, f1_score
from torch.utils.data import DataLoader, TensorDataset
import logging
import warnings
import os
import random
import time
import optuna

# --- Config ---
# Global experiment settings read throughout the pipeline below.
CONFIG = dict(
    SEEDS=[42, 1337, 2024],      # one full continual-learning run per seed
    DEVICE='cuda' if torch.cuda.is_available() else 'cpu',
    DATASETS=['NSL-KDD'],        # NOTE(review): not read by the experiment code in this file
    TUNING_TRIALS=10,            # Optuna trials per model
    EPOCHS=5,                    # training epochs per task in the main experiments
    BATCH_SIZE=256,              # batch size for every DataLoader
)

# Silence every Python warning globally so the experiment log stays readable.
warnings.filterwarnings(action='ignore')
# Root logging at INFO with timestamps; the module logger below is used by all stages.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)
logger = logging.getLogger(name=__name__)
# Only let Optuna report warnings and errors (suppresses its per-trial INFO lines).
optuna.logging.set_verbosity(optuna.logging.WARNING)

def set_seed(seed_value):
    """Seed torch, NumPy and Python's ``random`` with ``seed_value``.

    When CUDA is available, all GPUs are seeded too and cuDNN is switched to
    deterministic, non-benchmarking mode.
    """
    for seeder in (torch.manual_seed, np.random.seed, random.seed):
        seeder(seed_value)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed_value)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

# =============================================================================
# 1. Models
# =============================================================================
class TADR_VAE(nn.Module):
    """Task-aware, dynamically gated VAE.

    A learned embedding of the task id initialises the hidden state of a GRU. The GRU
    output, concatenated with the raw input, yields per-feature sigmoid gates; the
    gated input then goes through a VAE encoder/decoder.

    Args:
        input_dim: number of input features.
        num_tasks: size of the task-embedding table (valid ids are 0..num_tasks-1).
        p: hyper-parameter dict; reads 'task_embedding_dim', 'hidden_dim' and
            'latent_dim' (other keys such as 'lr' are ignored).
    """

    def __init__(self, input_dim, num_tasks, p):
        super().__init__()
        # Module creation order is significant for seeded initialisation; keep it.
        self.task_embeddings = nn.Embedding(num_tasks, p['task_embedding_dim'])
        self.task_emb_to_hidden = nn.Linear(p['task_embedding_dim'], p['hidden_dim'])
        self.temporal_gating = nn.GRU(input_size=input_dim, hidden_size=p['hidden_dim'], batch_first=True)
        self.gate_generator = nn.Sequential(nn.Linear(input_dim + p['hidden_dim'], input_dim), nn.Sigmoid())
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, p['hidden_dim']),
            nn.LayerNorm(p['hidden_dim']),
            nn.GELU(),
            nn.Linear(p['hidden_dim'], p['hidden_dim'] // 2),
        )
        self.latent_mu = nn.Linear(p['hidden_dim'] // 2, p['latent_dim'])
        self.latent_logvar = nn.Linear(p['hidden_dim'] // 2, p['latent_dim'])
        self.decoder = nn.Sequential(
            nn.Linear(p['latent_dim'], p['hidden_dim'] // 2),
            nn.LayerNorm(p['hidden_dim'] // 2),
            nn.GELU(),
            nn.Linear(p['hidden_dim'] // 2, input_dim),
        )

    def forward(self, x, task_id):
        """Reconstruct ``x`` conditioned on ``task_id``.

        Args:
            x: (batch, input_dim) tensor; it is fed to the GRU as a length-1 sequence.
            task_id: LongTensor holding either a single task id (shape ``()`` or
                ``(1,)``), which is broadcast to every row, or one id per row
                (shape ``(batch,)``).

        Returns:
            ``(recon, mu, logvar, gating_weights)``; ``recon`` and ``gating_weights``
            have the same shape as ``x`` and the gates lie in (0, 1).

        Raises:
            ValueError: if ``task_id`` holds neither 1 nor ``batch`` ids.
        """
        batch_size = x.size(0)
        task_hidden = self.task_emb_to_hidden(self.task_embeddings(task_id))
        task_hidden = task_hidden.reshape(-1, task_hidden.size(-1))
        if task_hidden.size(0) == 1:
            # One shared task id: broadcast its hidden state over the batch.
            task_hidden = task_hidden.expand(batch_size, -1)
        elif task_hidden.size(0) != batch_size:
            raise ValueError(
                f"task_id must hold 1 or {batch_size} ids, got {task_hidden.size(0)}"
            )
        # GRU expects h0 as (num_layers=1, batch, hidden); expand() yields a view, so copy.
        h0 = task_hidden.unsqueeze(0).contiguous()
        gru_out, _ = self.temporal_gating(x.unsqueeze(1), h0)
        gating_weights = self.gate_generator(torch.cat([x, gru_out.squeeze(1)], dim=-1))
        gated_input = x * gating_weights
        encoded = self.encoder(gated_input)
        mu, logvar = self.latent_mu(encoded), self.latent_logvar(encoded)
        z = self._reparameterize(mu, logvar)
        recon = self.decoder(z)
        return recon, mu, logvar, gating_weights

    def _reparameterize(self, mu, logvar):
        """Sample z = mu + eps * exp(0.5 * logvar) with eps ~ N(0, I) (also in eval mode)."""
        std = torch.exp(0.5 * logvar)
        eps = torch.randn_like(std)
        return mu + eps * std

    def compute_loss(self, x, recon, mu, logvar, gate, kl_weight):
        """Return gated MSE reconstruction loss + ``kl_weight`` * mean KL to N(0, I).

        NOTE(review): ``gate`` is produced by the model and multiplies both sides of the
        MSE, so the reconstruction term can be lowered simply by pushing gates towards 0.
        Consider a gate regulariser or detaching ``gate`` here.
        """
        recon_loss = F.mse_loss(recon * gate, x * gate)
        kl_loss = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return recon_loss + kl_weight * kl_loss

class VanillaVAE(nn.Module):
    """Plain, task-agnostic VAE baseline for reconstruction-based anomaly scoring.

    ``p`` must provide 'hidden_dim' and 'latent_dim'; any other keys are ignored.
    """

    def __init__(self, input_dim, p):
        super().__init__()
        width = p['hidden_dim']
        bottleneck = width // 2
        code_size = p['latent_dim']
        # Submodules are created in the order encoder -> mu -> logvar -> decoder,
        # which fixes the seeded parameter initialisation and the state_dict keys.
        self.encoder = nn.Sequential(
            nn.Linear(input_dim, width),
            nn.ReLU(),
            nn.Linear(width, bottleneck),
        )
        self.fc_mu = nn.Linear(bottleneck, code_size)
        self.fc_logvar = nn.Linear(bottleneck, code_size)
        self.decoder = nn.Sequential(
            nn.Linear(code_size, bottleneck),
            nn.ReLU(),
            nn.Linear(bottleneck, input_dim),
        )

    def forward(self, x, task_id=None):
        """Encode, sample and decode ``x``; ``task_id`` is accepted but unused.

        Returns ``(recon, mu, logvar, gate)`` where ``gate`` is all ones, so callers
        can treat this model like the gated variant.
        """
        features = self.encoder(x)
        mu = self.fc_mu(features)
        logvar = self.fc_logvar(features)
        recon = self.decoder(self._reparameterize(mu, logvar))
        return recon, mu, logvar, torch.ones_like(x)

    def _reparameterize(self, mu, logvar):
        """Draw mu + eps * exp(0.5 * logvar) with eps ~ N(0, I)."""
        std = (0.5 * logvar).exp()
        return mu + torch.randn_like(std) * std

    def compute_loss(self, x, recon, mu, logvar, gate, kl_weight):
        """MSE reconstruction loss plus ``kl_weight`` times the mean KL to N(0, I).

        ``gate`` is part of the shared signature only; the baseline ignores it.
        """
        reconstruction = F.mse_loss(recon, x)
        kl_divergence = -0.5 * torch.mean(1 + logvar - mu.pow(2) - logvar.exp())
        return reconstruction + kl_weight * kl_divergence

class EWC(nn.Module):
    """Elastic Weight Consolidation wrapper around a model exposing ``compute_loss``.

    ``forward`` delegates to the wrapped model. ``compute_loss`` adds
    ``ewc_lambda * penalty()``, a Fisher-weighted quadratic pull of every trainable
    parameter towards the value it had when each earlier task was finished.
    """

    def __init__(self, model, ewc_lambda):
        super().__init__()
        self.model = model
        self.ewc_lambda = ewc_lambda
        # task_id -> {'mean': {param_name: snapshot}, 'fisher': {param_name: Fisher estimate}}
        self.tasks = {}

    def forward(self, x, task_id=None):
        return self.model(x, task_id)

    def compute_loss(self, x, recon, mu, logvar, gate, kl_weight):
        """Wrapped model's loss plus the EWC penalty scaled by ``ewc_lambda``."""
        return self.model.compute_loss(x, recon, mu, logvar, gate, kl_weight) + self.ewc_lambda * self.penalty()

    def penalty(self):
        """Sum of fisher * (param - snapshot)^2 over finished tasks (``0.`` before any)."""
        penalty = 0.
        for n, p in self.model.named_parameters():
            if p.requires_grad:
                for task_data in self.tasks.values():
                    penalty += (task_data['fisher'][n] * (p - task_data['mean'][n]).pow(2)).sum()
        return penalty

    def end_task(self, dataloader, task_id, kl_weight):
        """Snapshot parameters and estimate a diagonal Fisher for ``task_id``.

        The estimate is the sum, over the batches of ``dataloader``, of the squared
        gradient of the wrapped model's penalty-free loss, divided by the dataset size.
        NOTE(review): squaring per-batch mean gradients gives a much smaller value than
        the per-sample empirical Fisher; ``ewc_lambda`` effectively absorbs that scale.

        The wrapped model's train/eval mode is restored and its gradients are cleared
        afterwards, so this pass leaves no state behind for the next training phase.
        """
        # Use the wrapped model's own device instead of a global setting.
        first_param = next(self.model.parameters(), None)
        device = first_param.device if first_param is not None else torch.device('cpu')
        fisher = {n: torch.zeros_like(p) for n, p in self.model.named_parameters() if p.requires_grad}
        mean = {n: p.clone().detach() for n, p in self.model.named_parameters() if p.requires_grad}
        was_training = self.model.training
        self.model.eval()
        try:
            for x, _ in dataloader:
                x = x.to(device)
                self.model.zero_grad()
                recon, mu, logvar, gate = self.model(x, task_id)
                loss = self.model.compute_loss(x, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                for n, p in self.model.named_parameters():
                    if p.grad is not None:
                        fisher[n] += p.grad.detach().pow(2) / len(dataloader.dataset)
        finally:
            # Drop the estimation gradients and restore the caller's mode.
            self.model.zero_grad()
            self.model.train(was_training)
        self.tasks[task_id] = {'mean': mean, 'fisher': fisher}

# =============================================================================
# 2. Data Loader
# =============================================================================
class RealDatasetLoader:
    """Builds the NSL-KDD continual-learning task sequence from a processed CSV file.

    The file ``<data_dir>/nsl_kdd_processed.csv`` must contain the feature columns plus
    'attack', 'level' and 'attack_cat'. Object-typed feature columns are one-hot encoded.
    """

    # Attack categories forming the anomalous class of each successive task; every
    # task also contains the 'normal' rows as its negative class.
    TASK_ATTACKS = (('dos',), ('probe',), ('r2l', 'u2r'))
    # Maximum rows kept per task (random subsample without replacement).
    MAX_SAMPLES_PER_TASK = 10000

    def __init__(self, data_dir="research_data"):
        self.data_dir = data_dir
        os.makedirs(self.data_dir, exist_ok=True)

    def get_dataset_tasks(self, name='NSL-KDD'):
        """Return the list of NSL-KDD tasks.

        Each task is a dict with 'X' (features min-max scaled by one scaler fitted on
        all rows) and 'y' (1 for the task's attack categories, 0 for 'normal'). Each
        task is subsampled to at most ``MAX_SAMPLES_PER_TASK`` rows using ``np.random``.

        Raises:
            ValueError: if ``name`` is not 'NSL-KDD', the only implemented dataset.
            FileNotFoundError: if the processed CSV is missing and was not created.
        """
        if name != 'NSL-KDD':
            raise ValueError(f"Unsupported dataset {name!r}: only 'NSL-KDD' is implemented.")
        path = os.path.join(self.data_dir, "nsl_kdd_processed.csv")
        if not os.path.exists(path):
            self._download_and_process_nsl_kdd()
            if not os.path.exists(path):
                raise FileNotFoundError(
                    f"Processed NSL-KDD data not found at {path!r}; place the file there "
                    "(automatic download is not implemented)."
                )
        df = pd.read_csv(path)
        features = df.drop(columns=['attack', 'level', 'attack_cat'])
        categories = df['attack_cat']
        encoded = pd.get_dummies(features, columns=features.select_dtypes(include=['object']).columns)
        # One homogeneous float matrix: avoids object-dtype arrays when dummies are bool.
        X_all = encoded.to_numpy(dtype=float)
        # Fitted once on every row, so all tasks share the same feature scaling.
        scaler = MinMaxScaler().fit(X_all)
        tasks = []
        for attacks in self.TASK_ATTACKS:
            mask = (categories.isin(attacks) | (categories == 'normal')).to_numpy()
            y_task = categories[mask].isin(attacks).astype(int).to_numpy()
            X_sub, y_sub = self._subsample_data(X_all[mask], y_task, self.MAX_SAMPLES_PER_TASK)
            tasks.append({'X': scaler.transform(X_sub), 'y': y_sub})
        return tasks

    def _download_and_process_nsl_kdd(self):
        """Placeholder for fetching and preprocessing NSL-KDD into the processed CSV.

        NOTE(review): not implemented in this version. It only logs, and
        ``get_dataset_tasks`` raises FileNotFoundError if the CSV is still missing.
        """
        logger.info("Downloading NSL-KDD...")
        # TODO: restore the download/processing logic ("same as previous code").

    def _subsample_data(self, X, y, n):
        """Return ``(X, y)`` unchanged if it has at most ``n`` rows, else ``n`` random rows."""
        if len(X) <= n:
            return X, y
        indices = np.random.choice(len(X), n, replace=False)
        return X[indices], y[indices]

# =============================================================================
# 3. Hyperparameter Tuning
# =============================================================================
class HyperparameterTuner:
    """Optuna hyper-parameter search for the compared models.

    For each model name, ``tune`` runs ``CONFIG['TUNING_TRIALS']`` trials. Each trial
    samples a configuration and trains a fresh model for ``TRIAL_EPOCHS`` epochs on
    ``tuning_data``, always as task 0. It returns the ROC-AUC of per-sample
    reconstruction error against the labels, which is maximised.

    NOTE(review): trials are scored on the same samples they were trained on.

    Args:
        models_to_tune: model names; each must be "TADR-VAE", "Vanilla VAE" or "VAE+EWC".
        tuning_data: dict with 'X' (2-D feature array) and 'y' (binary label array).
        seed: if given, each study uses a TPESampler seeded with it; ``None``
            (default) keeps Optuna's default, unseeded sampler.
    """

    TRIAL_EPOCHS = 3  # training epochs per trial

    def __init__(self, models_to_tune, tuning_data, seed=None):
        self.models = models_to_tune
        self.X_tune, self.y_tune = tuning_data['X'], tuning_data['y']
        self.input_dim = self.X_tune.shape[1]
        self.seed = seed

    def _suggest_params(self, trial, model_name):
        """Sample one configuration for ``model_name`` (suggestion order is kept stable)."""
        if model_name not in ("TADR-VAE", "Vanilla VAE", "VAE+EWC"):
            raise ValueError(f"Unknown model name: {model_name!r}")
        params = {
            'lr': trial.suggest_float('lr', 1e-4, 1e-2, log=True),
            'latent_dim': trial.suggest_categorical('latent_dim', [16, 32, 64]),
            'hidden_dim': trial.suggest_categorical('hidden_dim', [64, 128, 256]),
        }
        if model_name == "TADR-VAE":
            params['task_embedding_dim'] = trial.suggest_categorical('task_embedding_dim', [8, 16])
        params['kl_weight'] = trial.suggest_float('kl_weight', 0.05, 0.5, log=True)
        if model_name == "VAE+EWC":
            params['ewc_lambda'] = trial.suggest_float('ewc_lambda', 100, 1000, log=True)
        return params

    def _build_model(self, model_name, params):
        """Instantiate ``model_name`` with ``params`` on CONFIG['DEVICE']."""
        device = CONFIG['DEVICE']
        if model_name == "TADR-VAE":
            # Trials only use task 0, but the embedding table holds 3 tasks.
            return TADR_VAE(self.input_dim, 3, params).to(device)
        base_model = VanillaVAE(self.input_dim, params).to(device)
        if model_name == "VAE+EWC":
            return EWC(base_model, params['ewc_lambda'])
        return base_model

    def _objective(self, trial, model_name):
        """Train one trial's model briefly and return its ROC-AUC on the tuning data."""
        params = self._suggest_params(trial, model_name)
        model = self._build_model(model_name, params)
        device = CONFIG['DEVICE']
        task_id = torch.LongTensor([0]).to(device)
        kl_weight = params.get('kl_weight', 0.1)

        optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
        dataset = TensorDataset(torch.FloatTensor(self.X_tune), torch.LongTensor(self.y_tune))
        train_loader = DataLoader(dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=True)

        model.train()
        for _ in range(self.TRIAL_EPOCHS):
            for data, _labels in train_loader:
                data = data.to(device)
                optimizer.zero_grad()
                recon, mu, logvar, gate = model(data, task_id)
                loss = model.compute_loss(data, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                optimizer.step()

        # Score with an UNSHUFFLED loader so that errors[i] belongs to y_tune[i];
        # a shuffled loader would pair errors with the wrong labels.
        eval_loader = DataLoader(dataset, batch_size=CONFIG['BATCH_SIZE'], shuffle=False)
        model.eval()
        errors = []
        with torch.no_grad():
            for batch, _labels in eval_loader:
                batch = batch.to(device)
                recon, _, _, _ = model(batch, task_id)
                errors.extend(torch.mean((batch - recon) ** 2, dim=1).cpu().numpy())

        return roc_auc_score(self.y_tune, np.array(errors))

    def tune(self):
        """Run one Optuna study per model and return ``{model_name: best_params}``."""
        logger.info("\n--- Hyperparameter Tuning Started ---")
        best_params = {}
        for name in self.models:
            sampler = None if self.seed is None else optuna.samplers.TPESampler(seed=self.seed)
            study = optuna.create_study(direction='maximize', sampler=sampler)
            # Bind `name` explicitly so the objective never depends on late binding.
            study.optimize(
                lambda trial, model_name=name: self._objective(trial, model_name),
                n_trials=CONFIG['TUNING_TRIALS'],
            )
            best_params[name] = study.best_params
            logger.info(f"Best params for {name}: {study.best_params}")
        return best_params

# =============================================================================
# 4. Advanced Experiment Framework
# =============================================================================
class AdvancedExperimentFramework:
    """Tune every model, then train each one sequentially over the tasks for every seed.

    After ``run()``, ``self.results`` maps each model name to a list of
    ``{'seed', 'perf', 'time'}`` dicts. ``perf[i, j]`` holds (precision, recall, F1,
    ROC-AUC) on task ``j`` after training through task ``i``; only ``j <= i`` is filled.

    NOTE(review): evaluation reuses each task's training samples, and the F1-optimal
    threshold is chosen with the evaluation labels, so the scores are optimistic.
    Consider a held-out split.
    """

    def _get_reconstruction_errors(self, model, X_test, task_id):
        """Per-sample mean squared reconstruction error, in ``X_test`` row order."""
        model.eval()
        errors = []
        with torch.no_grad():
            test_loader = DataLoader(TensorDataset(torch.FloatTensor(X_test), torch.zeros(len(X_test))), batch_size=CONFIG['BATCH_SIZE'])
            for batch, _ in test_loader:
                batch = batch.to(CONFIG['DEVICE'])
                recon, _, _, _ = model(batch, task_id)
                errors.extend(torch.mean((batch - recon) ** 2, dim=1).cpu().numpy())
        return np.array(errors)

    def _find_optimal_threshold(self, y_true, errors):
        """Return the F1-maximising threshold among 100 evenly spaced values in [min, max] error."""
        thresholds = np.linspace(np.min(errors), np.max(errors), 100)
        f1s = [f1_score(y_true, errors >= t) for t in thresholds]
        return thresholds[np.argmax(f1s)]

    def _build_model(self, name, params, input_dim, num_tasks):
        """Instantiate model ``name`` with tuned ``params`` on CONFIG['DEVICE']."""
        device = CONFIG['DEVICE']
        if name == "TADR-VAE":
            return TADR_VAE(input_dim, num_tasks, params).to(device)
        if name == "Vanilla VAE":
            return VanillaVAE(input_dim, params).to(device)
        if name == "VAE+EWC":
            return EWC(VanillaVAE(input_dim, params).to(device), params['ewc_lambda'])
        raise ValueError(f"Unknown model name: {name!r}")

    def _train_task(self, model, optimizer, train_loader, task_id_tensor, kl_weight):
        """Train ``model`` on one task for CONFIG['EPOCHS'] epochs."""
        model.train()
        for _ in range(CONFIG['EPOCHS']):
            for batch_x, _labels in train_loader:
                batch_x = batch_x.to(CONFIG['DEVICE'])
                optimizer.zero_grad()
                recon, mu, logvar, gate = model(batch_x, task_id_tensor)
                loss = model.compute_loss(batch_x, recon, mu, logvar, gate, kl_weight)
                loss.backward()
                optimizer.step()

    def _evaluate_seen_tasks(self, model, tasks, i, performance_matrix):
        """Fill ``performance_matrix[i, j]`` for every task ``j <= i``."""
        for j in range(i + 1):
            y_true = tasks[j]['y']
            errors = self._get_reconstruction_errors(model, tasks[j]['X'], torch.LongTensor([j]).to(CONFIG['DEVICE']))
            threshold = self._find_optimal_threshold(y_true, errors)
            y_pred = (errors >= threshold).astype(int)
            performance_matrix[i, j, 0] = precision_score(y_true, y_pred)
            performance_matrix[i, j, 1] = recall_score(y_true, y_pred)
            performance_matrix[i, j, 2] = f1_score(y_true, y_pred)
            performance_matrix[i, j, 3] = roc_auc_score(y_true, errors)

    def run(self):
        """Load tasks, tune all models, run the per-seed protocol and store ``self.results``."""
        # Task subsampling (np.random) and the tuning trials' weight initialisation (torch)
        # draw from global RNGs, so seed them before they run; otherwise the task data
        # differs between runs. Optuna's sampler keeps its own RNG and is not covered here.
        if CONFIG['SEEDS']:
            set_seed(CONFIG['SEEDS'][0])

        tasks = RealDatasetLoader().get_dataset_tasks()
        input_dim = tasks[0]['X'].shape[1]
        num_tasks = len(tasks)

        models_to_tune = ["TADR-VAE", "Vanilla VAE", "VAE+EWC"]
        best_params = HyperparameterTuner(models_to_tune, tasks[0]).tune()

        full_results = {name: [] for name in models_to_tune}

        logger.info("\n--- Main Experiments Started ---")
        for seed in CONFIG['SEEDS']:
            set_seed(seed)
            for name in models_to_tune:
                params = best_params[name]
                kl_weight = params.get('kl_weight', 0.1)
                model = self._build_model(name, params, input_dim, num_tasks)
                optimizer = torch.optim.Adam(model.parameters(), lr=params['lr'])
                performance_matrix = np.zeros((num_tasks, num_tasks, 4))
                start_time = time.time()
                for i, task in enumerate(tasks):
                    train_loader = DataLoader(
                        TensorDataset(torch.FloatTensor(task['X']), torch.LongTensor(task['y'])),
                        batch_size=CONFIG['BATCH_SIZE'],
                        shuffle=True,
                    )
                    task_id_tensor = torch.LongTensor([i]).to(CONFIG['DEVICE'])
                    self._train_task(model, optimizer, train_loader, task_id_tensor, kl_weight)
                    if name == "VAE+EWC":
                        model.end_task(train_loader, i, kl_weight)
                    self._evaluate_seen_tasks(model, tasks, i, performance_matrix)

                full_results[name].append({'seed': seed, 'perf': performance_matrix, 'time': time.time() - start_time})
                logger.info(f"Completed {name} with seed {seed}")

        logger.info("\n--- Experiments Completed ---")
        self.results = full_results

# =============================================================================
# 5. Run
# =============================================================================
if __name__ == '__main__':
    # Script entry point: tuning plus all seeded experiments. The instance stays bound
    # at module level so `framework.results` can be inspected afterwards.
    framework = AdvancedExperimentFramework()
    framework.run()
