In [None]:
# Model-Agnostic Meta-Learning (MAML) with Transformer Architecture for Network Intrusion Detection

import os
import time
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.model_selection import train_test_split
from sklearn.metrics import confusion_matrix, classification_report, roc_curve, auc
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import random
import warnings
warnings.filterwarnings('ignore')

# Reproducibility: seed Python's, NumPy's and TensorFlow's random generators
# with the same fixed value.
random.seed(42)
np.random.seed(42)
tf.random.set_seed(42)

# GPU setup: request memory growth on every GPU TensorFlow reports. With no
# GPUs the loop body never runs. A RuntimeError raised by TensorFlow is
# printed instead of aborting the script.
gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)

#############################################################
# Data Loading and Preprocessing
#############################################################

class NetworkDataProcessor:
    """Load, preprocess and episode-sample network-flow data for meta-learning.

    Typical use is ``load_data`` -> ``preprocess_data`` -> ``create_tasks``.

    Attributes:
        data_path: CSV path handed to ``pandas.read_csv``; ``None`` selects the
            synthetic demo data.
        label_encoder: ``LabelEncoder`` fitted on ``attack_cat`` by
            ``preprocess_data``.
        feature_scaler: ``StandardScaler`` fitted on the feature matrix by
            ``preprocess_data``.
        feature_encoders: one fitted ``LabelEncoder`` per categorical feature
            column, keyed by column name (filled by ``preprocess_data``).
        attack_mapping: ``{attack name: encoded id}``; only set once
            ``preprocess_data`` has seen an ``attack_cat`` column.
        max_seq_length: stored for sequential flows; no method of this class
            reads it.
    """

    # Columns that hold a target, or that ``load_data`` derives the binary
    # target from. None of them may end up in the feature matrix.
    TARGET_COLUMNS = ('attack_cat', 'attack_cat_encoded', 'binary_label', 'label', 'Label')

    # dataset name -> (column the binary label is derived from, benign value)
    _BENIGN_MARKERS = {
        "UNSW-NB15": ('label', 0),
        "NSL-KDD": ('label', 'normal'),
        "CICIDS2017": ('Label', 'BENIGN'),
    }

    def __init__(self, data_path=None):
        self.data_path = data_path
        self.label_encoder = LabelEncoder()
        self.feature_scaler = StandardScaler()
        self.feature_encoders = {}
        self.max_seq_length = 30  # For sequential flows

    def load_data(self, dataset_name="UNSW-NB15"):
        """Load a network intrusion detection dataset as a DataFrame.

        Supports: UNSW-NB15, NSL-KDD, CICIDS2017.

        Args:
            dataset_name: which labelling convention to apply to the CSV at
                ``self.data_path``. Ignored when ``self.data_path`` is None.

        Returns:
            The DataFrame with an added ``binary_label`` column
            (0 = benign, 1 = attack). Without a data path, synthetic data from
            ``_generate_synthetic_data`` is returned instead.

        Raises:
            ValueError: if ``dataset_name`` is not one of the supported names.
        """
        if self.data_path is None:
            # No file to read: fall back to randomly generated demo data.
            print("No data path provided. Using synthetic data for demonstration.")
            return self._generate_synthetic_data()

        try:
            label_column, benign_value = self._BENIGN_MARKERS[dataset_name]
        except (KeyError, TypeError):
            raise ValueError(f"Dataset {dataset_name} not supported") from None

        df = pd.read_csv(self.data_path)
        # Everything that is not the benign marker counts as an attack.
        df['binary_label'] = df[label_column].apply(lambda x: 0 if x == benign_value else 1)
        return df

    def _generate_synthetic_data(self, n_samples=10000):
        """Generate synthetic network flow data for demonstration.

        Every feature and the ``attack_cat`` label are drawn independently
        from NumPy's global random generator, so the features carry no signal
        about the label.

        Args:
            n_samples: number of rows to generate.

        Returns:
            DataFrame with the feature columns plus ``attack_cat`` and
            ``binary_label``.
        """
        # Create synthetic features resembling network flow data
        data = {
            'duration': np.random.exponential(scale=30, size=n_samples),
            'protocol_type': np.random.choice(['tcp', 'udp', 'icmp'], size=n_samples),
            'service': np.random.choice(['http', 'ftp', 'smtp', 'ssh', 'dns'], size=n_samples),
            'src_bytes': np.random.exponential(scale=1000, size=n_samples),
            'dst_bytes': np.random.exponential(scale=800, size=n_samples),
            'flag': np.random.choice(['SF', 'REJ', 'S0', 'RSTO'], size=n_samples),
            'land': np.random.choice([0, 1], size=n_samples, p=[0.99, 0.01]),
            'wrong_fragment': np.random.choice([0, 1, 2, 3], size=n_samples, p=[0.95, 0.02, 0.02, 0.01]),
            'urgent': np.random.choice([0, 1, 2], size=n_samples, p=[0.98, 0.01, 0.01]),
            'hot': np.random.poisson(lam=0.1, size=n_samples),
            'num_failed_logins': np.random.poisson(lam=0.05, size=n_samples),
            'logged_in': np.random.choice([0, 1], size=n_samples, p=[0.4, 0.6]),
            'num_compromised': np.random.poisson(lam=0.01, size=n_samples),
            'root_shell': np.random.choice([0, 1], size=n_samples, p=[0.99, 0.01]),
            'su_attempted': np.random.choice([0, 1], size=n_samples, p=[0.99, 0.01]),
            'num_root': np.random.poisson(lam=0.01, size=n_samples),
            'num_file_creations': np.random.poisson(lam=0.1, size=n_samples),
            'num_shells': np.random.poisson(lam=0.01, size=n_samples),
            'num_access_files': np.random.poisson(lam=0.05, size=n_samples),
            'is_host_login': np.random.choice([0, 1], size=n_samples, p=[0.99, 0.01]),
            'is_guest_login': np.random.choice([0, 1], size=n_samples, p=[0.95, 0.05]),
        }

        df = pd.DataFrame(data)

        # Generate labels: 5 attack types + normal
        attack_types = ['normal', 'dos', 'probe', 'r2l', 'u2r', 'backdoor']

        # Create imbalanced dataset (realistic for network security)
        # 80% normal, 20% attacks with different distributions
        attack_probs = [0.8, 0.1, 0.04, 0.03, 0.02, 0.01]
        df['attack_cat'] = np.random.choice(attack_types, size=n_samples, p=attack_probs)

        # Add binary label (0 for normal, 1 for attack)
        df['binary_label'] = df['attack_cat'].apply(lambda x: 0 if x == 'normal' else 1)

        # Add a few more network-specific features
        df['pkt_count'] = np.random.poisson(lam=15, size=n_samples)
        df['byte_count'] = df['pkt_count'] * np.random.lognormal(4, 1, size=n_samples)
        df['tcp_flags'] = np.random.choice(['000', '001', '010', '011', '100'], size=n_samples)

        return df

    def preprocess_data(self, df):
        """Turn a labelled DataFrame into a scaled feature matrix and targets.

        Steps: label-encode every object-typed feature column, encode
        ``attack_cat`` (if present) into ``attack_cat_encoded``, drop all
        target columns, replace missing values with 0 and standardise.

        Args:
            df: DataFrame containing a ``binary_label`` column and optionally
                ``attack_cat``. It is not modified.

        Returns:
            Tuple ``(X_scaled, y_binary, y_multiclass)``: the output of
            ``feature_scaler.fit_transform``, the ``binary_label`` column, and
            the ``attack_cat_encoded`` column (``None`` when the frame has no
            ``attack_cat``).
        """
        # Make a copy to avoid modifying the original
        processed_df = df.copy()

        # Encode categorical features. Each column gets its own encoder so its
        # mapping survives; a single shared encoder would be overwritten by
        # every later fit.
        self.feature_encoders = {}
        categorical_columns = processed_df.select_dtypes(include=['object']).columns
        for col in categorical_columns:
            if col in self.TARGET_COLUMNS:
                continue  # targets are handled below, never as features
            encoder = LabelEncoder()
            processed_df[col] = encoder.fit_transform(processed_df[col])
            self.feature_encoders[col] = encoder

        # Convert attack categories to numeric labels
        if 'attack_cat' in processed_df.columns:
            processed_df['attack_cat_encoded'] = self.label_encoder.fit_transform(processed_df['attack_cat'])
            self.attack_mapping = dict(zip(self.label_encoder.classes_, self.label_encoder.transform(self.label_encoder.classes_)))
            print("Attack mapping:", self.attack_mapping)
            y_multiclass = processed_df['attack_cat_encoded']
        else:
            y_multiclass = None
        y_binary = processed_df['binary_label']

        # Drop every target column, including the raw 'label'/'Label' column
        # that binary_label was derived from; keeping it would leak the answer
        # into the features.
        X = processed_df.drop(list(self.TARGET_COLUMNS), axis=1, errors='ignore')

        # Handle missing values
        X = X.fillna(0)

        # Scale numerical features
        X_scaled = self.feature_scaler.fit_transform(X)

        return X_scaled, y_binary, y_multiclass

    def create_tasks(self, X, y_multiclass, num_tasks=100, k_shot=5, query_size=15, n_way=5):
        """Create few-shot tasks (episodes) for meta-learning.

        Each task contains:
        - support set: ``k_shot`` examples of each of N classes
        - query set: ``query_size`` further examples of the same N classes

        Labels inside a task are re-indexed to ``0..N-1`` in the order of the
        task's ``classes`` array. A class with fewer than
        ``k_shot + query_size`` examples is sampled with replacement, so its
        support and query sets may share rows.

        Args:
            X: feature matrix indexable by an integer index array (rows are
                samples).
            y_multiclass: per-row class labels, aligned with ``X`` by position.
            num_tasks: number of tasks to generate.
            k_shot: support examples per class.
            query_size: query examples per class.
            n_way: classes per task; capped at the number of distinct labels.

        Returns:
            List of dicts with keys ``support_X``, ``support_y``, ``query_X``,
            ``query_y``, ``n_way``, ``k_shot`` and ``classes``.

        Raises:
            ValueError: if ``y_multiclass`` is None.
        """
        if y_multiclass is None:
            raise ValueError("Multiclass labels are required for creating few-shot tasks")

        X = np.asarray(X)
        labels = np.asarray(y_multiclass)  # positional view, independent of any pandas index

        # Get unique classes
        classes = np.unique(labels)
        n_way = min(n_way, len(classes))
        per_class = k_shot + query_size

        # The row positions of each class do not change between tasks, so look
        # them up once instead of rescanning the labels for every task.
        indices_by_class = {cls: np.where(labels == cls)[0] for cls in classes}

        tasks = []
        for _ in range(num_tasks):
            # Randomly select N classes for this task
            task_classes = np.random.choice(classes, n_way, replace=False)

            support_X, support_y = [], []
            query_X, query_y = [], []

            for class_idx, cls in enumerate(task_classes):
                cls_indices = indices_by_class[cls]

                # Sample with replacement only when the class is too small.
                selected_indices = np.random.choice(
                    cls_indices, per_class, replace=len(cls_indices) < per_class
                )

                # Split into support and query
                support_indices = selected_indices[:k_shot]
                query_indices = selected_indices[k_shot:per_class]

                # Use task-specific class indices as labels
                support_X.append(X[support_indices])
                support_y.append(np.full(k_shot, class_idx))
                query_X.append(X[query_indices])
                query_y.append(np.full(query_size, class_idx))

            # Combine and shuffle support set
            support_X = np.vstack(support_X)
            support_y = np.concatenate(support_y)
            support_order = np.arange(len(support_y))
            np.random.shuffle(support_order)
            support_X = support_X[support_order]
            support_y = support_y[support_order]

            # Combine and shuffle query set
            query_X = np.vstack(query_X)
            query_y = np.concatenate(query_y)
            query_order = np.arange(len(query_y))
            np.random.shuffle(query_order)
            query_X = query_X[query_order]
            query_y = query_y[query_order]

            tasks.append({
                'support_X': support_X,
                'support_y': support_y,
                'query_X': query_X,
                'query_y': query_y,
                'n_way': n_way,
                'k_shot': k_shot,
                'classes': task_classes,
            })

        return tasks


#############################################################
# Transformer Architecture
#############################################################

class TransformerBlock(layers.Layer):
    """Transformer encoder block: self-attention and a feed-forward network.

    Each of the two sub-layers is followed by dropout, a residual connection
    and layer normalisation (post-norm ordering).

    Args:
        embed_dim: width of the block's input and output (last axis).
        num_heads: number of attention heads.
        ff_dim: hidden width of the feed-forward network.
        rate: dropout rate applied after each sub-layer.
        **kwargs: standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class so ``from_config`` can rebuild the layer.
    """

    def __init__(self, embed_dim, num_heads, ff_dim, rate=0.1, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments: get_config() must return exactly
        # these so the layer can be re-created from its config.
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.rate = rate

        # NOTE: key_dim is passed the full embed_dim rather than
        # embed_dim // num_heads; kept as-is so weight shapes stay compatible
        # with previously saved checkpoints.
        self.att = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.ffn = keras.Sequential([
            layers.Dense(ff_dim, activation="relu"),
            layers.Dense(embed_dim),
        ])
        self.layernorm1 = layers.LayerNormalization(epsilon=1e-6)
        self.layernorm2 = layers.LayerNormalization(epsilon=1e-6)
        self.dropout1 = layers.Dropout(rate)
        self.dropout2 = layers.Dropout(rate)

    def call(self, inputs, training=None):
        """Apply the block to ``inputs``.

        Args:
            inputs: tensor whose last axis is ``embed_dim`` wide; the same
                tensor is used as attention query and value.
            training: forwarded to both dropout layers.

        Returns:
            Tensor produced by the second layer normalisation.
        """
        attn_output = self.att(inputs, inputs)
        attn_output = self.dropout1(attn_output, training=training)
        out1 = self.layernorm1(inputs + attn_output)
        ffn_output = self.ffn(out1)
        ffn_output = self.dropout2(ffn_output, training=training)
        return self.layernorm2(out1 + ffn_output)

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            "embed_dim": self.embed_dim,
            "num_heads": self.num_heads,
            "ff_dim": self.ff_dim,
            "rate": self.rate,
        })
        return config


class PositionalEncoding(layers.Layer):
    """Add a fixed sinusoidal positional encoding to a sequence tensor.

    The encoding table is computed once in ``__init__`` and has no trainable
    weights.

    Args:
        position: number of positions to precompute (maximum sequence length).
        d_model: width of the encoded vectors; must match the last axis of the
            inputs.
        **kwargs: standard Keras layer arguments (``name``, ``dtype``, ...),
            forwarded to the base class so ``from_config`` can rebuild the layer.
    """

    def __init__(self, position, d_model, **kwargs):
        super().__init__(**kwargs)
        # Keep the constructor arguments for get_config(); the table itself is
        # derived from them and therefore never needs to be serialised.
        self.position = position
        self.d_model = d_model
        self.pos_encoding = self.positional_encoding(position, d_model)

    def get_angles(self, position, i, d_model):
        """Return ``position / 10000 ** (2 * (i // 2) / d_model)``."""
        angles = 1 / tf.pow(10000, (2 * (i // 2)) / tf.cast(d_model, tf.float32))
        return position * angles

    def positional_encoding(self, position, d_model):
        """Build the encoding table of shape ``(1, position, d_model)``.

        The sines of the even-indexed angle columns are concatenated with the
        cosines of the odd-indexed ones: the first half of the last axis holds
        all sines and the second half all cosines (they are not interleaved).
        """
        angle_rads = self.get_angles(
            position=tf.range(position, dtype=tf.float32)[:, tf.newaxis],
            i=tf.range(d_model, dtype=tf.float32)[tf.newaxis, :],
            d_model=d_model
        )

        # Sine of the even-indexed angle columns (2i)
        sines = tf.math.sin(angle_rads[:, 0::2])

        # Cosine of the odd-indexed angle columns (2i+1)
        cosines = tf.math.cos(angle_rads[:, 1::2])

        pos_encoding = tf.concat([sines, cosines], axis=-1)
        pos_encoding = pos_encoding[tf.newaxis, ...]  # leading axis broadcasts over the batch

        return tf.cast(pos_encoding, tf.float32)

    def call(self, inputs):
        """Add the encoding for the first ``inputs.shape[1]`` positions to ``inputs``."""
        return inputs + self.pos_encoding[:, :tf.shape(inputs)[1], :]

    def get_config(self):
        """Return the constructor arguments so the layer can be serialised."""
        config = super().get_config()
        config.update({
            "position": self.position,
            "d_model": self.d_model,
        })
        return config


#############################################################
# MAML Implementation
#############################################################

class MAMLTransformer:
    """Meta-learner that trains a Transformer classifier with first-order MAML.

    ``self.model`` is a single Keras model shared by every method. Task
    adaptation overwrites its variables in place; ``self.meta_weights`` holds
    the meta-parameters that training, evaluation and adaptation all start
    from.

    The meta-gradient is first-order: inner-loop steps are applied by
    assigning to the model's variables, so the query-loss gradient is taken
    with respect to the adapted weights and is not back-propagated through
    the inner updates.

    Args:
        input_shape: shape of one sample, passed to ``layers.Input``.
        n_way: number of output classes; task labels must lie in ``0..n_way-1``.
        k_shot: stored for reference; not read by any method of this class.
        inner_lr: step size of the inner-loop gradient descent.
        meta_lr: learning rate of the Adam meta-optimizer.
        meta_batch_size: maximum number of tasks used per ``train_on_batch``.
        num_inner_updates: default number of inner-loop steps per task.
        embed_dim, num_heads, ff_dim, num_transformer_blocks: Transformer sizes.
        mlp_units: widths of the dense layers in the classification head.
        dropout: dropout rate for the Transformer blocks and the head.
        name: name given to the Keras model.
    """

    def __init__(
        self,
        input_shape,
        n_way=5,
        k_shot=5,
        inner_lr=0.01,
        meta_lr=0.001,
        meta_batch_size=32,
        num_inner_updates=5,
        embed_dim=256,
        num_heads=8,
        ff_dim=512,
        num_transformer_blocks=4,
        mlp_units=(128, 64),
        dropout=0.1,
        name="maml_transformer"
    ):
        self.input_shape = input_shape
        self.n_way = n_way
        self.k_shot = k_shot
        self.inner_lr = inner_lr
        self.meta_lr = meta_lr
        self.meta_batch_size = meta_batch_size
        self.num_inner_updates = num_inner_updates
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.ff_dim = ff_dim
        self.num_transformer_blocks = num_transformer_blocks
        # Copy into a fresh list so instances never share the default or
        # alias a list owned by the caller.
        self.mlp_units = list(mlp_units)
        self.dropout = dropout
        self.name = name

        # Build the model
        self.model = self._build_model()

        # Initialize optimizer
        self.meta_optimizer = keras.optimizers.Adam(learning_rate=self.meta_lr)

        # Store the original weights
        self.meta_weights = self.model.get_weights()

    def _build_model(self):
        """Build the transformer-based model for few-shot learning."""
        inputs = layers.Input(shape=self.input_shape)

        # Embedding layer to transform input features to embedding space
        x = layers.Dense(self.embed_dim)(inputs)

        # Insert a sequence axis of length 1 at position 1. For a flat
        # (n_features,) input the whole embedded sample becomes one token.
        # NOTE(review): with a single token the attention has only one
        # position to attend to - confirm this is intended.
        x = tf.expand_dims(x, axis=1)

        # Apply positional encoding
        x = PositionalEncoding(position=50, d_model=self.embed_dim)(x)

        # Transformer blocks
        for _ in range(self.num_transformer_blocks):
            x = TransformerBlock(
                embed_dim=self.embed_dim,
                num_heads=self.num_heads,
                ff_dim=self.ff_dim,
                rate=self.dropout
            )(x)

        # Global average pooling
        x = layers.GlobalAveragePooling1D()(x)

        # MLP for classification
        for dim in self.mlp_units:
            x = layers.Dense(dim, activation="relu")(x)
            x = layers.Dropout(self.dropout)(x)

        # Output layer: softmax probabilities, not logits
        outputs = layers.Dense(self.n_way, activation="softmax")(x)

        return keras.Model(inputs, outputs, name=self.name)

    def _one_hot(self, labels):
        """One-hot encode task labels to depth ``n_way``.

        Labels are cast to int32 first because ``tf.one_hot`` only accepts
        integer indices, while callers may hand over float label arrays.
        """
        return tf.one_hot(tf.cast(labels, tf.int32), depth=self.n_way)

    def _adapt_in_place(self, support_X, support_y_one_hot, num_inner_updates):
        """Run the inner loop: plain gradient-descent steps on the support set.

        The model's trainable variables are modified in place; callers are
        responsible for restoring the meta-weights afterwards.
        """
        for _ in range(num_inner_updates):
            with tf.GradientTape() as inner_tape:
                support_probs = self.model(support_X, training=True)
                support_loss = tf.reduce_mean(
                    keras.losses.categorical_crossentropy(support_y_one_hot, support_probs)
                )

            # Pair each gradient with the variable it was computed for.
            # Zipping against get_weights() would misalign as soon as the
            # model contains a non-trainable weight.
            trainable = self.model.trainable_variables
            gradients = inner_tape.gradient(support_loss, trainable)
            for variable, grad in zip(trainable, gradients):
                if grad is not None:
                    variable.assign_sub(self.inner_lr * grad)

    def train_on_batch(self, batch_of_tasks):
        """Perform one meta-update on a batch of tasks.

        For each task the model is reset to the meta-weights, adapted on the
        support set, and the gradient of the query loss is taken at the
        adapted weights. The task gradients are averaged and applied to the
        meta-weights with the meta-optimizer.

        Args:
            batch_of_tasks: indexable collection of task dicts with keys
                ``support_X``, ``support_y``, ``query_X`` and ``query_y``. At
                most ``self.meta_batch_size`` tasks (the first ones) are used.

        Returns:
            Mean query loss over the tasks used, as a NumPy scalar.

        Raises:
            ValueError: if there is no task to train on.
        """
        meta_batch_size = min(len(batch_of_tasks), self.meta_batch_size)
        if meta_batch_size <= 0:
            raise ValueError("train_on_batch needs at least one task")

        # Start from the stored meta-parameters, not from whatever the shared
        # model currently holds (adapt_to_task leaves it task-adapted).
        meta_weights = self.meta_weights
        trainable = self.model.trainable_variables
        summed_grads = [None] * len(trainable)
        total_meta_loss = tf.constant(0.0, dtype=tf.float32)

        try:
            for i in range(meta_batch_size):
                task = batch_of_tasks[i]
                self.model.set_weights(meta_weights)

                # Inner loop - adapt to the task
                self._adapt_in_place(
                    task['support_X'], self._one_hot(task['support_y']), self.num_inner_updates
                )

                # Only the query forward pass needs to be taped; recording the
                # inner loop as well would just hold on to its activations.
                query_y_one_hot = self._one_hot(task['query_y'])
                with tf.GradientTape() as query_tape:
                    query_probs = self.model(task['query_X'], training=True)
                    query_loss = tf.reduce_mean(
                        keras.losses.categorical_crossentropy(query_y_one_hot, query_probs)
                    )
                task_grads = query_tape.gradient(query_loss, trainable)

                for idx, grad in enumerate(task_grads):
                    if grad is None:
                        continue
                    summed_grads[idx] = grad if summed_grads[idx] is None else summed_grads[idx] + grad

                total_meta_loss += query_loss
        finally:
            # Never leave task-adapted weights behind, even if a step failed.
            self.model.set_weights(meta_weights)

        # Average over tasks and apply to the (restored) meta-weights.
        grads_and_vars = [
            (grad_sum / meta_batch_size, variable)
            for grad_sum, variable in zip(summed_grads, trainable)
            if grad_sum is not None
        ]
        self.meta_optimizer.apply_gradients(grads_and_vars)

        # Update stored meta weights
        self.meta_weights = self.model.get_weights()

        return (total_meta_loss / meta_batch_size).numpy()

    def evaluate(self, tasks, num_inner_updates=None):
        """Evaluate the meta-model on a list of tasks.

        Every task is adapted from the stored meta-weights on its support set
        and then scored on its query set. The model holds the meta-weights
        again when this method returns.

        Args:
            tasks: iterable of task dicts to evaluate on.
            num_inner_updates: number of inner updates to perform (if None,
                use ``self.num_inner_updates``).

        Returns:
            Tuple ``(mean accuracy, mean loss)`` across tasks.
        """
        if num_inner_updates is None:
            num_inner_updates = self.num_inner_updates

        accuracies = []
        losses = []

        try:
            for task in tasks:
                query_X = task['query_X']
                query_y = task['query_y']

                # Each task starts from the same meta-weights.
                self.model.set_weights(self.meta_weights)
                self._adapt_in_place(
                    task['support_X'], self._one_hot(task['support_y']), num_inner_updates
                )

                # Evaluate on query set
                query_probs = self.model(query_X, training=False)
                pred_y = tf.argmax(query_probs, axis=1).numpy()
                accuracies.append(np.mean(pred_y == query_y))

                # Calculate loss
                loss = keras.losses.categorical_crossentropy(self._one_hot(query_y), query_probs)
                losses.append(tf.reduce_mean(loss).numpy())
        finally:
            # Restore meta weights
            self.model.set_weights(self.meta_weights)

        # Return mean accuracy and loss
        return np.mean(accuracies), np.mean(losses)

    def adapt_to_task(self, support_X, support_y, num_inner_updates=None):
        """Adapt the model to a new task using the support set.

        Args:
            support_X: support set features
            support_y: support set labels (class indices in ``0..n_way-1``)
            num_inner_updates: number of inner updates to perform (if None,
                use ``self.num_inner_updates``)

        Returns:
            ``self.model`` itself, now holding the task-adapted weights. It is
            not a copy: the adapted weights stay in the shared model until
            another method resets it to ``self.meta_weights``.
        """
        if num_inner_updates is None:
            num_inner_updates = self.num_inner_updates

        # Reset to the meta weights, then run the inner loop
        self.model.set_weights(self.meta_weights)
        self._adapt_in_place(support_X, self._one_hot(support_y), num_inner_updates)

        return self.model

    def save_meta_model(self, filepath):
        """Save the meta-model weights (never task-adapted ones) to ``filepath``."""
        # Set to meta weights before saving
        self.model.set_weights(self.meta_weights)
        self.model.save_weights(filepath)

    def load_meta_model(self, filepath):
        """Load weights from ``filepath`` and adopt them as the meta-weights."""
        self.model.load_weights(filepath)
        self.meta_weights = self.model.get_weights()


#############################################################
# Training and Evaluation
#############################################################

class MAMLTrainer:
    """Drive meta-training of a MAML model with validation and early stopping.

    Args:
        maml_model: object providing ``train_on_batch``, ``evaluate``,
            ``adapt_to_task``, ``save_meta_model``, ``model`` and
            ``meta_weights``.
        train_tasks: tasks to sample meta-batches from.
        val_tasks: tasks used for the periodic validation.
        test_tasks: optional tasks for one final evaluation after training.
        meta_epochs: number of meta-iterations (one task batch each).
        meta_batch_size: tasks sampled per meta-iteration.
        eval_interval: validate every this many meta-iterations.
        early_stopping_patience: number of consecutive validations without a
            new best accuracy after which training stops.
        log_dir: directory for the best checkpoint and the figures; created
            if missing.
    """

    def __init__(
        self,
        maml_model,
        train_tasks,
        val_tasks,
        test_tasks=None,
        meta_epochs=10000,
        meta_batch_size=16,
        eval_interval=100,
        early_stopping_patience=10,
        log_dir='logs'
    ):
        self.maml_model = maml_model
        self.train_tasks = train_tasks
        self.val_tasks = val_tasks
        self.test_tasks = test_tasks
        self.meta_epochs = meta_epochs
        self.meta_batch_size = meta_batch_size
        self.eval_interval = eval_interval
        self.early_stopping_patience = early_stopping_patience
        self.log_dir = log_dir

        # Create log directory if it doesn't exist
        os.makedirs(log_dir, exist_ok=True)

        # Training history: train_loss has one entry per meta-iteration, the
        # validation lists one entry per evaluation.
        self.history = {
            'train_loss': [],
            'val_accuracy': [],
            'val_loss': []
        }

        # Initialize early stopping variables
        self.best_val_accuracy = 0
        self.patience_counter = 0
        self.best_weights = None

    def train(self):
        """Run meta-training and return the history dict.

        Validates every ``eval_interval`` iterations, checkpoints the best
        model to ``<log_dir>/best_model.h5``, stops early when validation
        accuracy stops improving, restores the best weights, and finally
        evaluates on ``test_tasks`` when they were provided.
        """
        print("Starting meta-training...")

        start_time = time.time()
        n_train_tasks = len(self.train_tasks)
        tasks_per_batch = min(self.meta_batch_size, n_train_tasks)

        for epoch in range(self.meta_epochs):
            # Sample batch of tasks
            batch_indices = np.random.choice(n_train_tasks, tasks_per_batch, replace=False)
            batch_of_tasks = [self.train_tasks[i] for i in batch_indices]

            # Train on batch of tasks
            loss = self.maml_model.train_on_batch(batch_of_tasks)
            self.history['train_loss'].append(loss)

            # Evaluate periodically
            if (epoch + 1) % self.eval_interval != 0:
                continue

            val_accuracy, val_loss = self.maml_model.evaluate(self.val_tasks)
            self.history['val_accuracy'].append(val_accuracy)
            self.history['val_loss'].append(val_loss)

            elapsed_time = time.time() - start_time
            print(f"Epoch {epoch+1}/{self.meta_epochs} - "
                  f"Loss: {loss:.4f} - "
                  f"Val Accuracy: {val_accuracy:.4f} - "
                  f"Val Loss: {val_loss:.4f} - "
                  f"Time: {elapsed_time:.2f}s")

            # Check for early stopping
            if val_accuracy > self.best_val_accuracy:
                self.best_val_accuracy = val_accuracy
                self.patience_counter = 0
                self.best_weights = self.maml_model.model.get_weights()
                # Save best model
                self.maml_model.save_meta_model(os.path.join(self.log_dir, 'best_model.h5'))
            else:
                self.patience_counter += 1

            if self.patience_counter >= self.early_stopping_patience:
                print(f"Early stopping at epoch {epoch+1}")
                break

        # Restore best weights
        if self.best_weights is not None:
            self.maml_model.model.set_weights(self.best_weights)
            self.maml_model.meta_weights = self.best_weights

        print(f"Meta-training completed in {time.time() - start_time:.2f}s")

        # Final evaluation on test set if available
        if self.test_tasks is not None:
            test_accuracy, test_loss = self.maml_model.evaluate(self.test_tasks)
            print(f"Test Accuracy: {test_accuracy:.4f} - Test Loss: {test_loss:.4f}")

        return self.history

    def _validation_steps(self):
        """Return the 0-based meta-iteration index of every recorded validation.

        The k-th validation (k = 1, 2, ...) runs right after iteration
        ``k * eval_interval``, i.e. at index ``k * eval_interval - 1`` of
        ``history['train_loss']``. Deriving the positions from the number of
        recorded validations keeps x and y the same length for any run length.
        """
        n_evals = len(self.history['val_loss'])
        return self.eval_interval * np.arange(1, n_evals + 1) - 1

    def visualize_training(self):
        """Plot loss and validation accuracy; save to ``training_history.png``."""
        val_steps = self._validation_steps()

        plt.figure(figsize=(15, 5))

        # Plot training loss
        plt.subplot(1, 2, 1)
        plt.plot(self.history['train_loss'], label='Training Loss')
        plt.plot(val_steps, self.history['val_loss'], 'r-', label='Validation Loss')
        plt.title('Meta-Learning Loss')
        plt.xlabel('Meta-Iterations')
        plt.ylabel('Loss')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        # Plot validation accuracy
        plt.subplot(1, 2, 2)
        plt.plot(val_steps, self.history['val_accuracy'], 'g-', label='Validation Accuracy')
        plt.title('Few-Shot Classification Accuracy')
        plt.xlabel('Meta-Iterations')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'training_history.png'), dpi=300)
        plt.show()

    def plot_adaptation_curve(self, task, updates_range=(0, 1, 2, 3, 5, 10)):
        """Plot query accuracy on one task against the number of inner updates.

        Args:
            task: task dict with ``support_X``, ``support_y``, ``query_X`` and
                ``query_y``.
            updates_range: numbers of inner-loop updates to try; each one is
                adapted afresh from the meta-weights.

        Returns:
            List of query accuracies, one per entry of ``updates_range``. The
            figure is saved to ``<log_dir>/adaptation_curve.png``.
        """
        support_X = task['support_X']
        support_y = task['support_y']
        query_X = task['query_X']
        query_y = task['query_y']

        update_counts = list(updates_range)
        accuracies = []

        try:
            for updates in update_counts:
                # Adapt model to task with different number of gradient updates
                adapted_model = self.maml_model.adapt_to_task(support_X, support_y, num_inner_updates=updates)

                # Evaluate on query set
                query_probs = adapted_model(query_X, training=False)
                pred_y = tf.argmax(query_probs, axis=1).numpy()
                accuracies.append(np.mean(pred_y == query_y))
        finally:
            # adapt_to_task returns the shared model with task-adapted weights;
            # put the meta-weights back so later use is not silently affected.
            self.maml_model.model.set_weights(self.maml_model.meta_weights)

        plt.figure(figsize=(10, 6))
        plt.plot(update_counts, accuracies, 'o-', linewidth=2)
        plt.title('Adaptation Performance vs Gradient Steps')
        plt.xlabel('Number of Gradient Updates')
        plt.ylabel('Query Set Accuracy')
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.savefig(os.path.join(self.log_dir, 'adaptation_curve.png'), dpi=300)
        plt.show()

        return accuracies


#############################################################
# Network Security Scenario Analysis
#############################################################

class NetworkSecurityAnalyzer:
    """Few-shot security analyses for a MAML model, with plots saved to disk.

    Every analysis adapts ``maml_model`` via ``adapt_to_task`` on a support
    set, scores the adapted model on a query set and writes the resulting
    figures into ``log_dir``.
    """

    def __init__(self, maml_model, data_processor, log_dir='logs'):
        """Store the collaborators and make sure ``log_dir`` exists.

        Args:
            maml_model: Object exposing ``adapt_to_task`` (and, for the
                attention visualisation, ``embed_dim``, ``num_heads`` and
                ``model``).
            data_processor: Data processor; only its optional ``X_columns``
                attribute is read (for feature names).
            log_dir: Directory in which all figures are saved.
        """
        self.maml_model = maml_model
        self.data_processor = data_processor
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

    @staticmethod
    def _binary_scores(cm):
        """Return ``(precision, recall, f1)`` of class 1 from a 2x2 confusion matrix.

        ``cm`` must be laid out as ``[[tn, fp], [fn, tp]]``. Each score falls
        back to 0 when its denominator is zero.
        """
        _tn, fp, fn, tp = np.asarray(cm).ravel()
        precision = tp / (tp + fp) if (tp + fp) > 0 else 0
        recall = tp / (tp + fn) if (tp + fn) > 0 else 0
        f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
        return precision, recall, f1

    def analyze_novel_attack(self, X, y, attack_indices, normal_indices, n_shot=5):
        """
        Analyze the model's ability to detect a novel attack type

        Builds one balanced binary task (normal = 0, attack = 1), adapts the
        model with 0, 1, 3, 5 and 10 inner updates and records accuracy,
        precision, recall, F1 and the confusion matrix for each setting.

        Args:
            X: Feature matrix
            y: Binary labels (not read here; task labels are rebuilt from
                the two index sets)
            attack_indices: Indices of examples of the novel attack type
            normal_indices: Indices of normal traffic
            n_shot: Number of examples to use for adaptation

        Returns:
            Dict with per-step 'accuracies', 'precisions', 'recalls' and
            'f1_scores', plus 'best_steps' and 'best_f1' for the step count
            with the highest F1.

        Raises:
            ValueError: If fewer than two attack examples are available, so
                that no support/query split can be formed.
        """
        # Ensure we have enough examples
        if len(attack_indices) < n_shot * 2:
            print(f"Warning: Not enough attack examples. Using {len(attack_indices) // 2} shots instead.")
            n_shot = len(attack_indices) // 2
        if n_shot < 1:
            raise ValueError(
                "At least two attack examples are required to build a support and a query set."
            )

        # Split attack indices into support and query sets
        attack_support_indices = attack_indices[:n_shot]
        attack_query_indices = attack_indices[n_shot:2*n_shot]

        # Select normal examples for support and query sets. The query pool is
        # the normal traffic minus the support picks, in the original order.
        normal_indices = np.asarray(normal_indices)
        normal_support_indices = np.random.choice(normal_indices, n_shot, replace=False)
        remaining_normal = normal_indices[~np.isin(normal_indices, normal_support_indices)]
        normal_query_indices = np.random.choice(remaining_normal, n_shot, replace=False)

        # Create binary classification task (normal vs attack)
        support_indices = np.concatenate([normal_support_indices, attack_support_indices])
        query_indices = np.concatenate([normal_query_indices, attack_query_indices])

        support_X = X[support_indices]
        support_y = np.concatenate([np.zeros(n_shot), np.ones(n_shot)])
        query_X = X[query_indices]
        query_y = np.concatenate([np.zeros(n_shot), np.ones(n_shot)])

        # Shuffle support set
        support_shuffle = np.arange(len(support_y))
        np.random.shuffle(support_shuffle)
        support_X = support_X[support_shuffle]
        support_y = support_y[support_shuffle]

        # Evaluate adaptation performance
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        confusion_matrices = []

        update_steps = [0, 1, 3, 5, 10]

        for steps in update_steps:
            # Adapt the model to the task
            adapted_model = self.maml_model.adapt_to_task(support_X, support_y, num_inner_updates=steps)

            # Make predictions on query set
            query_logits = adapted_model(query_X, training=False)
            pred_y = tf.argmax(query_logits, axis=1).numpy()

            # Calculate metrics
            accuracy = np.mean(pred_y == query_y)
            # Pin the label set so the matrix is always 2x2. The model may
            # have more than two output units; predictions outside {0, 1}
            # are left out of the matrix but still count against accuracy.
            cm = confusion_matrix(query_y, pred_y, labels=[0, 1])

            # Calculate precision, recall, and F1 for the attack class (label 1)
            precision, recall, f1 = self._binary_scores(cm)

            accuracies.append(accuracy)
            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)
            confusion_matrices.append(cm)

        # Plot adaptation results
        self._plot_adaptation_metrics(update_steps, accuracies, precisions, recalls, f1_scores)

        # Plot confusion matrix for best step
        best_idx = np.argmax(f1_scores)
        self._plot_confusion_matrix(confusion_matrices[best_idx], ['Normal', 'Attack'])

        return {
            'accuracies': accuracies,
            'precisions': precisions,
            'recalls': recalls,
            'f1_scores': f1_scores,
            'best_steps': update_steps[best_idx],
            'best_f1': f1_scores[best_idx]
        }

    def _plot_adaptation_metrics(self, steps, accuracies, precisions, recalls, f1_scores):
        """Plot the four metrics against the number of gradient steps.

        The figure is saved as ``adaptation_metrics.png`` in ``self.log_dir``.
        """
        plt.figure(figsize=(12, 8))

        plt.plot(steps, accuracies, 'o-', label='Accuracy', linewidth=2)
        plt.plot(steps, precisions, 's-', label='Precision', linewidth=2)
        plt.plot(steps, recalls, '^-', label='Recall', linewidth=2)
        plt.plot(steps, f1_scores, 'D-', label='F1 Score', linewidth=2)

        plt.title('Attack Detection Performance vs Adaptation Steps')
        plt.xlabel('Number of Gradient Updates')
        plt.ylabel('Metric Value')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)
        plt.savefig(os.path.join(self.log_dir, 'adaptation_metrics.png'), dpi=300)
        plt.show()

    def _plot_confusion_matrix(self, cm, classes):
        """Draw ``cm`` as an annotated heatmap labelled with ``classes``.

        The figure is saved as ``confusion_matrix.png`` in ``self.log_dir``.
        """
        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=classes, yticklabels=classes)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'confusion_matrix.png'), dpi=300)
        plt.show()

    def analyze_attack_types(self, tasks, attack_names=None):
        """
        Analyze performance across different attack types

        Each task is adapted with 5 inner updates and scored on its query
        set. Binary tasks (``n_way == 2``) report F1 for class 1; all other
        tasks report the macro-averaged F1.

        Args:
            tasks: List of few-shot tasks for different attack types; each
                is a dict with 'support_X', 'support_y', 'query_X',
                'query_y' and 'n_way'
            attack_names: List of attack type names corresponding to tasks

        Returns:
            Dict with 'attack_names', 'accuracies' and 'f1_scores'.
        """
        if attack_names is None:
            attack_names = [f"Attack Type {i+1}" for i in range(len(tasks))]

        accuracies = []
        f1_scores = []

        for task in tasks:
            # Adapt the model to the task
            adapted_model = self.maml_model.adapt_to_task(
                task['support_X'], task['support_y'], num_inner_updates=5
            )

            # Make predictions on query set
            query_X = task['query_X']
            query_y = task['query_y']
            query_logits = adapted_model(query_X, training=False)
            pred_y = tf.argmax(query_logits, axis=1).numpy()

            # Calculate metrics
            accuracy = np.mean(pred_y == query_y)

            # For binary classification (assuming class 1 is the attack)
            if task['n_way'] == 2:
                # Fixed label set keeps the matrix 2x2 even when the model
                # predicts a class index outside {0, 1}.
                cm = confusion_matrix(query_y, pred_y, labels=[0, 1])
                _, _, f1 = self._binary_scores(cm)
            else:
                # For multiclass, use macro-averaged F1
                report = classification_report(query_y, pred_y, output_dict=True)
                f1 = report['macro avg']['f1-score']

            accuracies.append(accuracy)
            f1_scores.append(f1)

        # Plot results
        self._plot_attack_performance(attack_names, accuracies, f1_scores)

        return {
            'attack_names': attack_names,
            'accuracies': accuracies,
            'f1_scores': f1_scores
        }

    def _plot_attack_performance(self, attack_names, accuracies, f1_scores):
        """Plot accuracy and F1 side by side for every attack type.

        The figure is saved as ``attack_performance.png`` in ``self.log_dir``.
        """
        plt.figure(figsize=(12, 6))

        x = np.arange(len(attack_names))
        width = 0.35

        plt.bar(x - width/2, accuracies, width, label='Accuracy')
        plt.bar(x + width/2, f1_scores, width, label='F1 Score')

        plt.xlabel('Attack Type')
        plt.ylabel('Score')
        plt.title('Performance Across Different Attack Types')
        plt.xticks(x, attack_names, rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()
        plt.grid(True, axis='y', linestyle='--', alpha=0.7)
        plt.savefig(os.path.join(self.log_dir, 'attack_performance.png'), dpi=300)
        plt.show()

    def visualize_attention_weights(self, X, y, attack_type_idx=None):
        """
        Visualize attention weights to explain model decisions

        Builds a small stand-alone model (Dense embedding -> positional
        encoding -> one MultiHeadAttention layer), copies the weights of
        layer 1 of the trained model into layer 1 of that model, and plots
        the attention layer's output for a single sample. The attention
        layer itself is created here and never receives trained weights.

        Args:
            X: Feature matrix
            y: Labels
            attack_type_idx: Optional index of attack type to visualize;
                when omitted a random sample is used

        Raises:
            ValueError: If ``attack_type_idx`` is given but no sample in
                ``y`` carries that label.
        """
        # Get a sample
        if attack_type_idx is not None:
            matching = np.where(y == attack_type_idx)[0]
            if len(matching) == 0:
                raise ValueError(f"No samples found with label {attack_type_idx}.")
            sample_idx = matching[0]
        else:
            sample_idx = np.random.choice(len(y))

        sample_X = X[sample_idx:sample_idx+1]

        # Create a transformer model with accessible attention weights
        input_shape = X.shape[1:]
        embed_dim = self.maml_model.embed_dim
        num_heads = self.maml_model.num_heads

        # A simplified transformer model for visualization
        inputs = layers.Input(shape=input_shape)
        x = layers.Dense(embed_dim)(inputs)
        # NOTE(review): calling a raw tf op on a symbolic Keras tensor may not
        # be supported by every Keras version - confirm against the installed one.
        x = tf.expand_dims(x, axis=1)  # Add sequence dimension

        # Add positional encoding
        x = PositionalEncoding(position=50, d_model=embed_dim)(x)

        # Use the first transformer block for visualization
        attention_layer = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        attn_output = attention_layer(x, x)

        # Create the model
        vis_model = keras.Model(inputs=inputs, outputs=attn_output)

        # Copy weights from trained model (first layer only)
        # NOTE(review): presumes layer 1 of both models is the same Dense
        # embedding - verify against the trained model's layer order.
        vis_model.layers[1].set_weights(self.maml_model.model.layers[1].get_weights())

        # Run the model to get attention outputs
        attention_output = vis_model(sample_X)

        # Feature names (for visualization)
        if hasattr(self.data_processor, 'X_columns'):
            feature_names = self.data_processor.X_columns
        else:
            feature_names = [f"Feature_{i}" for i in range(input_shape[0])]

        # Plot attention heatmap
        plt.figure(figsize=(12, 10))
        sns.heatmap(attention_output[0].numpy(), cmap='viridis')
        plt.title(f"Attention Heatmap for Sample (Class {y[sample_idx]})")
        plt.xlabel("Embedded Features")
        plt.ylabel("Sequence Position")
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'attention_heatmap.png'), dpi=300)
        plt.show()

        # Return the most important features based on attention
        # NOTE(review): averaging over axis=1 yields one value per sequence
        # position; with the single position added above this is a single
        # score, so only one bar is drawn - confirm the intended axis and
        # how embedding dimensions should map back to input features.
        feature_importance = attention_output[0].numpy().mean(axis=1).flatten()
        top_k = 10  # Top k important features
        top_indices = np.argsort(feature_importance)[-top_k:]

        plt.figure(figsize=(12, 6))
        plt.barh(np.array(feature_names)[top_indices], feature_importance[top_indices])
        plt.xlabel("Average Attention")
        plt.ylabel("Feature")
        plt.title("Top Features by Attention Weight")
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'feature_importance.png'), dpi=300)
        plt.show()


#############################################################
# End-to-End Pipeline Demo
#############################################################

def run_demo(data_path=None, dataset_name="synthetic"):
    """Run a complete demo of the MAML transformer for network intrusion detection.

    The pipeline loads and preprocesses data, builds few-shot tasks,
    meta-trains a ``MAMLTransformer``, plots the training history and an
    adaptation curve, then runs the per-attack and cross-attack security
    analyses and the attention visualisation.

    Args:
        data_path: Passed to ``NetworkDataProcessor``.
        dataset_name: ``"synthetic"`` generates 10000 synthetic samples;
            any other value is forwarded to ``load_data``.

    Returns:
        ``(maml_model, trainer, security_analyzer)`` on success, or ``None``
        when preprocessing yields no multiclass labels (few-shot tasks
        cannot be built in that case).
    """
    print("=" * 80)
    print("MAML Transformer for Network Intrusion Detection - Demo")
    print("=" * 80)

    # Step 1: Data preparation
    print("\n1. Loading and preprocessing data...")
    data_processor = NetworkDataProcessor(data_path)

    if dataset_name == "synthetic":
        df = data_processor._generate_synthetic_data(n_samples=10000)
    else:
        df = data_processor.load_data(dataset_name)

    X, y_binary, y_multiclass = data_processor.preprocess_data(df)

    print(f"Total samples: {len(X)}")
    print(f"Features: {X.shape[1]}")
    if y_multiclass is not None:
        print(f"Number of attack classes: {len(np.unique(y_multiclass))}")
    print(f"Attack samples: {np.sum(y_binary)}")
    print(f"Normal samples: {len(y_binary) - np.sum(y_binary)}")

    # Step 2: Create tasks for meta-learning
    print("\n2. Creating few-shot learning tasks...")
    if y_multiclass is None:
        print("Multiclass labels not available. Cannot create few-shot tasks.")
        return

    # From here on y_multiclass is guaranteed to be available.
    all_tasks = data_processor.create_tasks(X, y_multiclass, num_tasks=200, k_shot=5, query_size=15)

    # Split into train, validation and test tasks
    num_train = int(len(all_tasks) * 0.7)
    num_val = int(len(all_tasks) * 0.15)

    train_tasks = all_tasks[:num_train]
    val_tasks = all_tasks[num_train:num_train+num_val]
    test_tasks = all_tasks[num_train+num_val:]

    print(f"Number of training tasks: {len(train_tasks)}")
    print(f"Number of validation tasks: {len(val_tasks)}")
    print(f"Number of test tasks: {len(test_tasks)}")

    # Step 3: Initialize MAML model
    print("\n3. Initializing MAML Transformer model...")
    input_shape = X.shape[1:]  # Feature dimensions
    n_way = min(5, len(np.unique(y_multiclass)))  # Number of classes per task

    maml_model = MAMLTransformer(
        input_shape=input_shape,
        n_way=n_way,
        k_shot=5,
        inner_lr=0.01,
        meta_lr=0.001,
        meta_batch_size=16,
        num_inner_updates=5,
        embed_dim=128,
        num_heads=4,
        ff_dim=256,
        num_transformer_blocks=3,
        mlp_units=[64, 32],
        dropout=0.1
    )

    print(f"Model initialized with {n_way}-way classification")
    print(f"Input shape: {input_shape}")

    # Step 4: Meta-training
    print("\n4. Starting meta-training...")
    trainer = MAMLTrainer(
        maml_model=maml_model,
        train_tasks=train_tasks,
        val_tasks=val_tasks,
        test_tasks=test_tasks,
        meta_epochs=1000,  # Reduced for demo
        meta_batch_size=16,
        eval_interval=50,
        early_stopping_patience=5,
        log_dir='logs/maml_transformer'
    )

    trainer.train()

    # Step 5: Visualize training process
    print("\n5. Visualizing training history...")
    trainer.visualize_training()

    # Step 6: Adaptation analysis
    print("\n6. Analyzing adaptation to novel attacks...")
    # Select a random test task for analysis
    random_task_idx = np.random.randint(len(test_tasks))
    random_task = test_tasks[random_task_idx]

    print(f"Analyzing adaptation curve for task {random_task_idx}")
    trainer.plot_adaptation_curve(random_task)

    # Step 7: Network security analysis
    print("\n7. Performing network security analysis...")
    security_analyzer = NetworkSecurityAnalyzer(
        maml_model=maml_model,
        data_processor=data_processor,
        log_dir='logs/maml_transformer'
    )

    # Find indices for each attack type
    attack_types = np.unique(y_multiclass)
    normal_indices = np.where(y_binary == 0)[0]

    # Skip normal class (usually labeled as 0)
    for attack_idx in attack_types:
        if attack_idx == 0:  # Skip normal class
            continue

        attack_name = f"Attack Type {attack_idx}"
        print(f"\nAnalyzing novel attack detection: {attack_name}")

        # Get indices for this attack type
        attack_indices = np.where(y_multiclass == attack_idx)[0]

        if len(attack_indices) < 10:
            print(f"Not enough samples for attack type {attack_idx}. Skipping.")
            continue

        # Analyze novel attack detection
        results = security_analyzer.analyze_novel_attack(
            X, y_binary, attack_indices, normal_indices, n_shot=5
        )

        print(f"Best adaptation steps: {results['best_steps']}")
        print(f"Best F1 score: {results['best_f1']:.4f}")

    # Step 8: Cross-attack analysis
    print("\n8. Analyzing performance across attack types...")
    # Create tasks for different attack types
    attack_tasks = []
    attack_names = []

    # The normal-traffic part is identical for every attack task:
    # up to 5 examples for support and up to 15 for query.
    support_normal = normal_indices[:5]
    query_normal = normal_indices[5:20]
    has_normal_traffic = len(support_normal) > 0 and len(query_normal) > 0

    for attack_idx in attack_types:
        if attack_idx == 0:  # Skip normal class
            continue

        # Create a binary classification task (normal vs this attack)
        attack_indices = np.where(y_multiclass == attack_idx)[0]

        if len(attack_indices) < 10 or not has_normal_traffic:
            continue

        # Select 5 examples for support and up to 15 for query
        support_attack = attack_indices[:5]
        query_attack = attack_indices[5:20]

        # Create support and query sets
        support_indices = np.concatenate([support_normal, support_attack])
        query_indices = np.concatenate([query_normal, query_attack])

        # Create binary labels sized from the actual slices: a class with
        # 10-19 samples yields fewer than 15 query rows, and fixed-size
        # labels would no longer line up with X[query_indices].
        support_y = np.concatenate([np.zeros(len(support_normal)), np.ones(len(support_attack))])
        query_y = np.concatenate([np.zeros(len(query_normal)), np.ones(len(query_attack))])

        # Shuffle support set
        support_shuffle = np.arange(len(support_y))
        np.random.shuffle(support_shuffle)
        support_X = X[support_indices][support_shuffle]
        support_y = support_y[support_shuffle]

        # Create task
        task = {
            'support_X': support_X,
            'support_y': support_y,
            'query_X': X[query_indices],
            'query_y': query_y,
            'n_way': 2,
            'k_shot': 5
        }

        attack_tasks.append(task)
        attack_names.append(f"Attack {attack_idx}")

    if attack_tasks:
        performance = security_analyzer.analyze_attack_types(attack_tasks, attack_names)

        print("\nPerformance across attack types:")
        for name, acc, f1 in zip(performance['attack_names'], performance['accuracies'], performance['f1_scores']):
            print(f"{name}: Accuracy={acc:.4f}, F1={f1:.4f}")

    # Step 9: Attention visualization
    print("\n9. Visualizing attention weights for explainability...")
    security_analyzer.visualize_attention_weights(X, y_multiclass)

    print("\nDemo completed successfully!")
    return maml_model, trainer, security_analyzer


# Main function to run the entire pipeline
if __name__ == "__main__":
    # Set smaller figures for Jupyter notebooks if needed
    plt.rcParams['figure.figsize'] = (10, 6)

    # Run the complete demo. run_demo() returns None instead of a 3-tuple when
    # it stops early (no multiclass labels), so check before unpacking.
    demo_result = run_demo()
    if demo_result is None:
        maml_model = trainer = analyzer = None
    else:
        maml_model, trainer, analyzer = demo_result