# In [None]:
def _make_normal_vs_attack_task(X, normal_indices, attack_indices, k_shot=5, query_size=15):
    """Build a 2-way (normal vs. one attack type) few-shot task dict.

    The first ``k_shot`` indices of each group form the support set and the
    next ``query_size`` indices form the query set. Labels are 0 for normal
    and 1 for attack. Label arrays are sized from the indices actually
    available, so a group with fewer than ``k_shot + query_size`` samples
    yields a smaller but consistent task instead of features and labels of
    different lengths.

    Args:
        X: Feature array indexable by integer index arrays.
        normal_indices: Row indices of normal samples.
        attack_indices: Row indices of samples of a single attack type.
        k_shot: Support examples taken from each group.
        query_size: Query examples taken from each group.

    Returns:
        Dict with keys ``support_X``, ``support_y``, ``query_X``, ``query_y``,
        ``n_way`` (always 2) and ``k_shot``.
    """
    support_normal = normal_indices[:k_shot]
    query_normal = normal_indices[k_shot:k_shot + query_size]
    support_attack = attack_indices[:k_shot]
    query_attack = attack_indices[k_shot:k_shot + query_size]

    support_indices = np.concatenate([support_normal, support_attack])
    query_indices = np.concatenate([query_normal, query_attack])

    # Derive label counts from the slices actually taken so labels always line up with rows of X.
    support_y = np.concatenate([np.zeros(len(support_normal)), np.ones(len(support_attack))])
    query_y = np.concatenate([np.zeros(len(query_normal)), np.ones(len(query_attack))])

    # Shuffle the support set only; the query set stays in normal-then-attack order.
    support_perm = np.arange(len(support_y))
    np.random.shuffle(support_perm)

    return {
        'support_X': X[support_indices][support_perm],
        'support_y': support_y[support_perm],
        'query_X': X[query_indices],
        'query_y': query_y,
        'n_way': 2,
        'k_shot': k_shot,
    }


def run_demo(data_path=None, dataset_name="synthetic"):
    """Run a complete demo of the MAML transformer for network intrusion detection.

    Steps: load and preprocess data, build few-shot tasks, initialize and
    meta-train a ``MAMLTransformer``, visualize training, plot the adaptation
    curve of one random test task, run per-attack novel-attack analysis and
    normal-vs-attack cross-attack analysis, and visualize attention weights.

    Args:
        data_path: Forwarded to ``NetworkDataProcessor``; may be None.
        dataset_name: ``"synthetic"`` generates 10,000 synthetic samples;
            any other value is passed to ``data_processor.load_data``.

    Returns:
        ``(maml_model, trainer, security_analyzer)``. When the preprocessed
        data has no multiclass labels, few-shot tasks cannot be built and
        ``(None, None, None)`` is returned so tuple-unpacking callers do not
        raise ``TypeError``.
    """
    k_shot = 5
    query_size = 15
    meta_batch_size = 16
    min_attack_samples = 10  # attack types with fewer samples than this are skipped
    log_dir = 'logs/maml_transformer'

    print("=" * 80)
    print("MAML Transformer for Network Intrusion Detection - Demo")
    print("=" * 80)

    # Step 1: Data preparation
    print("\n1. Loading and preprocessing data...")
    data_processor = NetworkDataProcessor(data_path)

    if dataset_name == "synthetic":
        df = data_processor._generate_synthetic_data(n_samples=10000)
    else:
        df = data_processor.load_data(dataset_name)

    X, y_binary, y_multiclass = data_processor.preprocess_data(df)

    print(f"Total samples: {len(X)}")
    print(f"Features: {X.shape[1]}")
    if y_multiclass is not None:
        print(f"Number of attack classes: {len(np.unique(y_multiclass))}")
    print(f"Attack samples: {np.sum(y_binary)}")
    print(f"Normal samples: {len(y_binary) - np.sum(y_binary)}")

    # Step 2: Create tasks for meta-learning
    print("\n2. Creating few-shot learning tasks...")
    if y_multiclass is None:
        print("Multiclass labels not available. Cannot create few-shot tasks.")
        # Return a 3-tuple rather than a bare None so `a, b, c = run_demo()` still unpacks.
        return None, None, None

    all_tasks = data_processor.create_tasks(
        X, y_multiclass, num_tasks=200, k_shot=k_shot, query_size=query_size
    )

    # Split into train (70%), validation (15%) and test (remainder) tasks
    num_train = int(len(all_tasks) * 0.7)
    num_val = int(len(all_tasks) * 0.15)

    train_tasks = all_tasks[:num_train]
    val_tasks = all_tasks[num_train:num_train + num_val]
    test_tasks = all_tasks[num_train + num_val:]

    print(f"Number of training tasks: {len(train_tasks)}")
    print(f"Number of validation tasks: {len(val_tasks)}")
    print(f"Number of test tasks: {len(test_tasks)}")

    # Step 3: Initialize MAML model
    print("\n3. Initializing MAML Transformer model...")
    input_shape = X.shape[1:]  # Feature dimensions
    # NOTE(review): create_tasks is not passed n_way; confirm its tasks use this many classes.
    n_way = min(5, len(np.unique(y_multiclass)))  # Number of classes per task

    maml_model = MAMLTransformer(
        input_shape=input_shape,
        n_way=n_way,
        k_shot=k_shot,
        inner_lr=0.01,
        meta_lr=0.001,
        meta_batch_size=meta_batch_size,
        num_inner_updates=5,
        embed_dim=128,
        num_heads=4,
        ff_dim=256,
        num_transformer_blocks=3,
        mlp_units=[64, 32],
        dropout=0.1,
    )

    print(f"Model initialized with {n_way}-way classification")
    print(f"Input shape: {input_shape}")

    # Step 4: Meta-training
    print("\n4. Starting meta-training...")
    trainer = MAMLTrainer(
        maml_model=maml_model,
        train_tasks=train_tasks,
        val_tasks=val_tasks,
        test_tasks=test_tasks,
        meta_epochs=1000,  # Reduced for demo
        meta_batch_size=meta_batch_size,
        eval_interval=50,
        early_stopping_patience=5,
        log_dir=log_dir,
    )

    trainer.train()  # the returned history is not used by this demo

    # Step 5: Visualize training process
    print("\n5. Visualizing training history...")
    trainer.visualize_training()

    # Step 6: Adaptation analysis on one randomly chosen test task
    print("\n6. Analyzing adaptation to novel attacks...")
    if len(test_tasks) == 0:
        # np.random.randint(0) would raise ValueError here.
        print("No test tasks available. Skipping adaptation analysis.")
    else:
        random_task_idx = np.random.randint(len(test_tasks))
        print(f"Analyzing adaptation curve for task {random_task_idx}")
        trainer.plot_adaptation_curve(test_tasks[random_task_idx])

    # Step 7: Network security analysis
    print("\n7. Performing network security analysis...")
    security_analyzer = NetworkSecurityAnalyzer(
        maml_model=maml_model,
        data_processor=data_processor,
        log_dir=log_dir,
    )

    attack_types = np.unique(y_multiclass)
    normal_indices = np.where(y_binary == 0)[0]

    for attack_idx in attack_types:
        # Skip the normal class (assumed to be labeled 0)
        if attack_idx == 0:
            continue

        attack_name = f"Attack Type {attack_idx}"
        print(f"\nAnalyzing novel attack detection: {attack_name}")

        attack_indices = np.where(y_multiclass == attack_idx)[0]
        if len(attack_indices) < min_attack_samples:
            print(f"Not enough samples for attack type {attack_idx}. Skipping.")
            continue

        results = security_analyzer.analyze_novel_attack(
            X, y_binary, attack_indices, normal_indices, n_shot=k_shot
        )

        print(f"Best adaptation steps: {results['best_steps']}")
        print(f"Best F1 score: {results['best_f1']:.4f}")

    # Step 8: Cross-attack analysis with one binary (normal vs. attack) task per attack type
    print("\n8. Analyzing performance across attack types...")
    attack_tasks = []
    attack_names = []

    for attack_idx in attack_types:
        if attack_idx == 0:  # Skip normal class
            continue

        attack_indices = np.where(y_multiclass == attack_idx)[0]
        if len(attack_indices) < min_attack_samples:
            continue

        attack_tasks.append(
            _make_normal_vs_attack_task(
                X, normal_indices, attack_indices, k_shot=k_shot, query_size=query_size
            )
        )
        attack_names.append(f"Attack {attack_idx}")

    if attack_tasks:
        performance = security_analyzer.analyze_attack_types(attack_tasks, attack_names)

        print("\nPerformance across attack types:")
        for name, acc, f1 in zip(performance['attack_names'], performance['accuracies'], performance['f1_scores']):
            print(f"{name}: Accuracy={acc:.4f}, F1={f1:.4f}")

    # Step 9: Attention visualization (y_multiclass is guaranteed non-None past Step 2)
    print("\n9. Visualizing attention weights for explainability...")
    security_analyzer.visualize_attention_weights(X, y_multiclass)

    print("\nDemo completed successfully!")
    return maml_model, trainer, security_analyzer


# Main function to run the entire pipeline
if __name__ == "__main__":
    # Use moderately sized figures (convenient when running inside Jupyter notebooks).
    plt.rcParams['figure.figsize'] = (10, 6)

    # Run the complete demo. run_demo may return None when few-shot tasks cannot
    # be built; fall back to three Nones so the unpacking below never raises
    # TypeError and the names stay bound for interactive inspection.
    demo_result = run_demo()
    maml_model, trainer, analyzer = demo_result if demo_result is not None else (None, None, None)