# In [None]:
#############################################################
# Network Security Scenario Analysis
#############################################################

class NetworkSecurityAnalyzer:
    def __init__(self, maml_model, data_processor, log_dir='logs'):
        self.maml_model = maml_model
        self.data_processor = data_processor
        self.log_dir = log_dir
        os.makedirs(log_dir, exist_ok=True)

    def analyze_novel_attack(self, X, y, attack_indices, normal_indices, n_shot=5):
        """
        Analyze the model's ability to detect a novel attack type

        Args:
            X: Feature matrix
            y: Binary labels
            attack_indices: Indices of examples of the novel attack type
            normal_indices: Indices of normal traffic
            n_shot: Number of examples to use for adaptation
        """
        # Ensure we have enough examples
        if len(attack_indices) < n_shot * 2:
            print(f"Warning: Not enough attack examples. Using {len(attack_indices) // 2} shots instead.")
            n_shot = len(attack_indices) // 2

        # Split attack indices into support and query sets
        attack_support_indices = attack_indices[:n_shot]
        attack_query_indices = attack_indices[n_shot:2*n_shot]

        # Select normal examples for support and query sets
        normal_support_indices = np.random.choice(normal_indices, n_shot, replace=False)
        normal_query_indices = np.random.choice(
            [i for i in normal_indices if i not in normal_support_indices],
            n_shot,
            replace=False
        )

        # Create binary classification task (normal vs attack)
        support_indices = np.concatenate([normal_support_indices, attack_support_indices])
        query_indices = np.concatenate([normal_query_indices, attack_query_indices])

        support_X = X[support_indices]
        support_y = np.concatenate([np.zeros(n_shot), np.ones(n_shot)])
        query_X = X[query_indices]
        query_y = np.concatenate([np.zeros(n_shot), np.ones(n_shot)])

        # Shuffle support set
        support_shuffle = np.arange(len(support_y))
        np.random.shuffle(support_shuffle)
        support_X = support_X[support_shuffle]
        support_y = support_y[support_shuffle]

        # Create task dictionary
        task = {
            'support_X': support_X,
            'support_y': support_y,
            'query_X': query_X,
            'query_y': query_y,
            'n_way': 2,  # Binary classification
            'k_shot': n_shot
        }

        # Evaluate adaptation performance
        accuracies = []
        precisions = []
        recalls = []
        f1_scores = []
        confusion_matrices = []

        update_steps = [0, 1, 3, 5, 10]

        for steps in update_steps:
            # Adapt the model to the task
            adapted_model = self.maml_model.adapt_to_task(support_X, support_y, num_inner_updates=steps)

            # Make predictions on query set
            query_logits = adapted_model(query_X, training=False)
            pred_y = tf.argmax(query_logits, axis=1).numpy()

            # Calculate metrics
            accuracy = np.mean(pred_y == query_y)
            cm = confusion_matrix(query_y, pred_y)

            # Calculate precision, recall, and F1 for the attack class (label 1)
            tn, fp, fn, tp = cm.ravel()
            precision = tp / (tp + fp) if (tp + fp) > 0 else 0
            recall = tp / (tp + fn) if (tp + fn) > 0 else 0
            f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0

            accuracies.append(accuracy)
            precisions.append(precision)
            recalls.append(recall)
            f1_scores.append(f1)
            confusion_matrices.append(cm)

        # Plot adaptation results
        self._plot_adaptation_metrics(update_steps, accuracies, precisions, recalls, f1_scores)

        # Plot confusion matrix for best step
        best_idx = np.argmax(f1_scores)
        self._plot_confusion_matrix(confusion_matrices[best_idx], ['Normal', 'Attack'])

        return {
            'accuracies': accuracies,
            'precisions': precisions,
            'recalls': recalls,
            'f1_scores': f1_scores,
            'best_steps': update_steps[best_idx],
            'best_f1': f1_scores[best_idx]
        }

    def _plot_adaptation_metrics(self, steps, accuracies, precisions, recalls, f1_scores):
        """Draw the four metric curves against the number of gradient steps.

        The figure is written to 'adaptation_metrics.png' inside the log
        directory and then displayed.
        """
        # (values, line/marker format, legend label) for each curve, in draw order
        curves = (
            (accuracies, 'o-', 'Accuracy'),
            (precisions, 's-', 'Precision'),
            (recalls, '^-', 'Recall'),
            (f1_scores, 'D-', 'F1 Score'),
        )

        plt.figure(figsize=(12, 8))
        for values, line_format, legend_label in curves:
            plt.plot(steps, values, line_format, label=legend_label, linewidth=2)

        plt.title('Attack Detection Performance vs Adaptation Steps')
        plt.xlabel('Number of Gradient Updates')
        plt.ylabel('Metric Value')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        output_path = os.path.join(self.log_dir, 'adaptation_metrics.png')
        plt.savefig(output_path, dpi=300)
        plt.show()

    def _plot_confusion_matrix(self, cm, classes):
        """Render a confusion matrix as an annotated heatmap.

        Args:
            cm: Confusion matrix of integer counts
            classes: Class names used as tick labels on both axes

        The figure is written to 'confusion_matrix.png' inside the log
        directory and then displayed.
        """
        heatmap_options = {
            'annot': True,
            'fmt': 'd',  # cells are formatted as integers
            'cmap': 'Blues',
            'xticklabels': classes,
            'yticklabels': classes,
        }

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, **heatmap_options)
        plt.title('Confusion Matrix')
        plt.ylabel('True Label')
        plt.xlabel('Predicted Label')
        plt.tight_layout()

        output_path = os.path.join(self.log_dir, 'confusion_matrix.png')
        plt.savefig(output_path, dpi=300)
        plt.show()

    def analyze_attack_types(self, tasks, attack_names=None):
        """
        Analyze performance across different attack types

        Args:
            tasks: List of few-shot tasks for different attack types
            attack_names: List of attack type names corresponding to tasks
        """
        if attack_names is None:
            attack_names = [f"Attack Type {i+1}" for i in range(len(tasks))]

        n_attacks = len(tasks)
        accuracies = []
        f1_scores = []

        for i, task in enumerate(tasks):
            # Adapt the model to the task
            adapted_model = self.maml_model.adapt_to_task(
                task['support_X'], task['support_y'], num_inner_updates=5
            )

            # Make predictions on query set
            query_X = task['query_X']
            query_y = task['query_y']
            query_logits = adapted_model(query_X, training=False)
            pred_y = tf.argmax(query_logits, axis=1).numpy()

            # Calculate metrics
            accuracy = np.mean(pred_y == query_y)
            cm = confusion_matrix(query_y, pred_y)

            # For binary classification (assuming class 1 is the attack)
            if task['n_way'] == 2:
                tn, fp, fn, tp = cm.ravel()
                precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
            else:
                # For multiclass, use macro-averaged F1
                report = classification_report(query_y, pred_y, output_dict=True)
                f1 = report['macro avg']['f1-score']

            accuracies.append(accuracy)
            f1_scores.append(f1)

        # Plot results
        self._plot_attack_performance(attack_names, accuracies, f1_scores)

        return {
            'attack_names': attack_names,
            'accuracies': accuracies,
            'f1_scores': f1_scores
        }

    def _plot_attack_performance(self, attack_names, accuracies, f1_scores):
        """Draw grouped bars of accuracy and F1 score, one group per attack type.

        The figure is written to 'attack_performance.png' inside the log
        directory and then displayed.
        """
        positions = np.arange(len(attack_names))
        bar_width = 0.35
        half_width = bar_width / 2

        plt.figure(figsize=(12, 6))

        # Accuracy bars sit left of each tick, F1 bars right of it.
        bar_groups = (
            (-half_width, accuracies, 'Accuracy'),
            (half_width, f1_scores, 'F1 Score'),
        )
        for offset, scores, legend_label in bar_groups:
            plt.bar(positions + offset, scores, bar_width, label=legend_label)

        plt.xlabel('Attack Type')
        plt.ylabel('Score')
        plt.title('Performance Across Different Attack Types')
        plt.xticks(positions, attack_names, rotation=45, ha='right')
        plt.legend()
        plt.tight_layout()
        plt.grid(True, axis='y', linestyle='--', alpha=0.7)

        output_path = os.path.join(self.log_dir, 'attack_performance.png')
        plt.savefig(output_path, dpi=300)
        plt.show()

    def visualize_attention_weights(self, X, y, attack_type_idx=None):
        """
        Visualize attention weights to explain model decisions

        Args:
            X: Feature matrix
            y: Labels
            attack_type_idx: Optional index of attack type to visualize
        """
        # Get a sample
        if attack_type_idx is not None:
            sample_idx = np.where(y == attack_type_idx)[0][0]
        else:
            sample_idx = np.random.choice(len(y))

        sample_X = X[sample_idx:sample_idx+1]

        # Create a transformer model with accessible attention weights
        input_shape = X.shape[1:]
        embed_dim = self.maml_model.embed_dim
        num_heads = self.maml_model.num_heads
        ff_dim = self.maml_model.ff_dim

        # A simplified transformer model for visualization
        inputs = layers.Input(shape=input_shape)
        x = layers.Dense(embed_dim)(inputs)
        x = tf.expand_dims(x, axis=1)  # Add sequence dimension

        # Add positional encoding
        x = PositionalEncoding(position=50, d_model=embed_dim)(x)

        # Use the first transformer block for visualization
        attention_layer = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        attn_output = attention_layer(x, x)

        # Create the model
        vis_model = keras.Model(inputs=inputs, outputs=attn_output)

        # Copy weights from trained model (first layer only)
        vis_model.layers[1].set_weights(self.maml_model.model.layers[1].get_weights())

        # Run the model to get attention outputs
        attention_output = vis_model(sample_X)

        # Feature names (for visualization)
        if hasattr(self.data_processor, 'X_columns'):
            feature_names = self.data_processor.X_columns
        else:
            feature_names = [f"Feature_{i}" for i in range(input_shape[0])]

        # Plot attention heatmap
        plt.figure(figsize=(12, 10))
        sns.heatmap(attention_output[0].numpy(), cmap='viridis')
        plt.title(f"Attention Heatmap for Sample (Class {y[sample_idx]})")
        plt.xlabel("Embedded Features")
        plt.ylabel("Sequence Position")
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'attention_heatmap.png'), dpi=300)
        plt.show()

        # Return the most important features based on attention
        feature_importance = attention_output[0].numpy().mean(axis=1).flatten()
        top_k = 10  # Top k important features
        top_indices = np.argsort(feature_importance)[-top_k:]

        plt.figure(figsize=(12, 6))
        plt.barh(np.array(feature_names)[top_indices], feature_importance[top_indices])
        plt.xlabel("Average Attention")
        plt.ylabel("Feature")
        plt.title("Top Features by Attention Weight")
        plt.tight_layout()
        plt.savefig(os.path.join(self.log_dir, 'feature_importance.png'), dpi=300)
        plt.show()
