In [None]:
!pip install gradio pandas numpy matplotlib seaborn

In [None]:
"""
Leveraging Cross-Domain Transfer Learning for Enhanced Multi-Protocol Network Intrusion Detection
Master's Project 2025
Author: Oluwaseyi Oladejo | Supervisor: Dr. Ahmed A Ahmed | Advisor: Dr. Lin Li
"""

import gradio as gr
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import time
import os

# Configuration: every figure produced by the layer handlers is written under
# OUTPUT_DIR/figures, so that directory is created up front.
OUTPUT_DIR = 'outputs'
os.makedirs(os.path.join(OUTPUT_DIR, 'figures'), exist_ok=True)

# Cyan plot styling: seaborn's white grid as the base, then tint the figure
# and axes backgrounds to match the app palette.
sns.set_style('whitegrid')
plt.rcParams.update({
    'figure.dpi': 100,
    'figure.facecolor': '#E0F7FA',
    'axes.facecolor': '#F1FAFB',
})

# Data
# All values below are hard-coded: the app replays stored experiment results
# instead of loading data, training, or evaluating models at runtime.

# Classifier names; also the choices of the Layer 3 "Model" dropdown.
MODELS = ['RandomForest', 'GradientBoosting', 'SVM', 'MLP', 'XGBoost']
# The two attack classes of the binary task, shared by all datasets.
COMMON_CLASSES = ['DoS', 'Reconnaissance']

# Per-dataset summary displayed by Layer 1.
#   samples / features : dataset size and feature count as reported
#   dos / recon        : counts of the two COMMON_CLASSES
# NOTE(review): for CICIoMT, dos + recon = 7,724 (the figure Layer 2 reports as
# its training set) while 'samples' is 21,034; the other two entries satisfy
# samples == dos + recon. Confirm whether 21,034 is a pre-filtering size.
DATASET_INFO = {
    'CICIoMT': {'samples': 21034, 'features': 50, 'dos': 7000, 'recon': 724, 'description': 'Medical IoT (Source Domain)'},
    'CIC-IoT': {'samples': 1569, 'features': 51, 'dos': 1500, 'recon': 69, 'description': 'Smart Home IoT (Target Domain)'},
    'IoT-23': {'samples': 3000, 'features': 14, 'dos': 1500, 'recon': 1500, 'description': 'Botnet IoT (Target Domain)'}
}

# Per-model training time in seconds, displayed by Layer 2 (not measured here).
TRAINING_TIMES = {'RandomForest': 1.60, 'GradientBoosting': 10.79, 'SVM': 1.12, 'MLP': 3.64, 'XGBoost': 0.36}

# Stored zero-shot transfer results keyed by (target_dataset, model_name).
# Metric values are percentages. 'cm' is a 2x2 confusion matrix plotted by
# Layer 3 with rows = true label and columns = predicted label, both in
# [DoS, Reconnaissance] order.
# NOTE(review): the IoT-23 rows look internally inconsistent. The matrix
# [[750, 750], [750, 750]] yields 50% precision, recall and F1 for both classes.
# Precision 25.0 / F1 33.3 is instead what weighted averaging (zero_division=0)
# gives when every sample is predicted as one class, e.g. [[1500, 0], [1500, 0]].
# Verify which of the two reflects the real experiment.
# NOTE(review): CIC-IoT GradientBoosting: (1456 + 55) / 1569 = 96.3%, while
# 96.4% is stored.
RESULTS_DATABASE = {
    ('CIC-IoT', 'RandomForest'): {'accuracy': 97.6, 'precision': 97.5, 'recall': 97.6, 'f1': 97.6, 'cm': [[1482, 18], [20, 49]]},
    ('CIC-IoT', 'GradientBoosting'): {'accuracy': 96.4, 'precision': 97.2, 'recall': 96.4, 'f1': 96.7, 'cm': [[1456, 44], [14, 55]]},
    ('CIC-IoT', 'SVM'): {'accuracy': 72.9, 'precision': 95.6, 'recall': 72.9, 'f1': 80.9, 'cm': [[1081, 419], [7, 62]]},
    ('CIC-IoT', 'MLP'): {'accuracy': 98.2, 'precision': 98.1, 'recall': 98.2, 'f1': 98.1, 'cm': [[1492, 8], [21, 48]]},
    ('CIC-IoT', 'XGBoost'): {'accuracy': 98.2, 'precision': 98.3, 'recall': 98.2, 'f1': 98.2, 'cm': [[1483, 17], [12, 57]]},
    ('IoT-23', 'RandomForest'): {'accuracy': 50.0, 'precision': 25.0, 'recall': 50.0, 'f1': 33.3, 'cm': [[750, 750], [750, 750]]},
    ('IoT-23', 'GradientBoosting'): {'accuracy': 50.0, 'precision': 25.0, 'recall': 50.0, 'f1': 33.3, 'cm': [[750, 750], [750, 750]]},
    ('IoT-23', 'SVM'): {'accuracy': 50.0, 'precision': 25.0, 'recall': 50.0, 'f1': 33.3, 'cm': [[750, 750], [750, 750]]},
    ('IoT-23', 'MLP'): {'accuracy': 50.0, 'precision': 25.0, 'recall': 50.0, 'f1': 33.3, 'cm': [[750, 750], [750, 750]]},
    ('IoT-23', 'XGBoost'): {'accuracy': 50.0, 'precision': 25.0, 'recall': 50.0, 'f1': 33.3, 'cm': [[750, 750], [750, 750]]}
}

# ============================================================================
# LAYER 1: DATA PROCESSING
# ============================================================================

def _percent_of(part, whole):
    """Return ``part`` as a percentage of ``whole``, or 0.0 when ``whole`` is 0."""
    return part / whole * 100 if whole else 0.0


def _save_class_distribution_plot(dataset_name, dos_count, recon_count):
    """Save the DoS/Reconnaissance bar chart for one dataset and return the PNG path."""
    classes = ['DoS', 'Reconnaissance']
    counts = [dos_count, recon_count]
    total = sum(counts)
    colors = ['#00ACC1', '#26C6DA']  # Cyan tones

    # Draw on an explicit Figure/Axes rather than pyplot's global "current
    # figure", which is shared by every handler in the process.
    fig, ax = plt.subplots(figsize=(10, 6), facecolor='#E0F7FA')
    try:
        ax.set_facecolor('#F1FAFB')

        bars = ax.bar(classes, counts, color=colors, alpha=0.85, edgecolor='#00838F', linewidth=2.5)
        ax.set_title(f'{dataset_name} - Class Distribution',
                     fontsize=17, fontweight='bold', color='#006064', pad=20)
        ax.set_ylabel('Sample Count', fontsize=13, fontweight='bold', color='#00838F')
        ax.set_xlabel('Attack Class', fontsize=13, fontweight='bold', color='#00838F')

        # Label each bar with its count and its share of the two classes.
        for bar, count in zip(bars, counts):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height(),
                    f'{count:,}\n({_percent_of(count, total):.1f}%)',
                    ha='center', va='bottom', fontsize=12, fontweight='bold', color='#006064')

        ax.grid(axis='y', alpha=0.3, color='#80DEEA', linewidth=0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_color('#80DEEA')
        ax.spines['left'].set_color('#80DEEA')

        fig.tight_layout()
        plot_path = f'{OUTPUT_DIR}/figures/class_dist_{dataset_name}.png'
        fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='#E0F7FA')
    finally:
        # Release the figure even when drawing or saving fails, so repeated
        # failures cannot pile up open figures in pyplot's registry.
        plt.close(fig)
    return plot_path


def process_dataset_layer1(dataset_name):
    """Narrate the Layer 1 preprocessing pipeline for one dataset.

    Generator bound to a Gradio button. Each ``yield`` sends a
    ``(status_text, plot_path, summary_df)`` triple to the UI:

    - Intermediate yields carry only the accumulated status log; plot and
      table are ``None``.
    - The final yield also carries the saved class-distribution PNG path and
      a summary ``pandas.DataFrame``.

    The steps are fixed narration and every number comes from
    ``DATASET_INFO``; no data is loaded or transformed here.

    Args:
        dataset_name: Key into ``DATASET_INFO`` (e.g. ``'CICIoMT'``).

    Yields:
        tuple: ``(str, str | None, pandas.DataFrame | None)``. On any
        exception, a final triple is yielded with an ``Error:`` line appended
        to the log and ``None`` for plot and table.
    """
    status_updates = []

    def add_status(msg):
        # Append one line and return the whole log, so every yield re-renders
        # the complete history in the Textbox.
        status_updates.append(msg)
        return "\n".join(status_updates)

    try:
        info = DATASET_INFO[dataset_name]
        yield add_status(f"Starting {dataset_name} processing..."), None, None
        time.sleep(0.3)
        yield add_status("Step 1: Data Loading"), None, None
        yield add_status(f"  - Dataset: {info['description']}"), None, None
        time.sleep(0.3)
        yield add_status("\nStep 2: Data Cleaning"), None, None
        yield add_status("  - Removing duplicates"), None, None
        yield add_status("  - Handling missing values"), None, None
        time.sleep(0.3)
        yield add_status("\nStep 3: Label Standardization"), None, None
        yield add_status(f"  - Binary classification: {', '.join(COMMON_CLASSES)}"), None, None
        time.sleep(0.3)
        yield add_status("\nStep 4: Feature Engineering"), None, None
        yield add_status("  - Cybersecurity features: 4"), None, None
        yield add_status("  - Statistical features: 4"), None, None
        time.sleep(0.3)
        yield add_status("\nStep 5: Class Balancing"), None, None
        yield add_status("  - Hybrid sampling (SMOTE + Undersampling)"), None, None
        time.sleep(0.3)
        yield add_status("\n" + "="*50), None, None
        yield add_status("PROCESSING COMPLETE"), None, None
        yield add_status(f"\nTotal Samples: {info['samples']:,}"), None, None
        yield add_status(f"Features: {info['features']}"), None, None

        # Class shares are taken over DoS + Reconnaissance, matching the
        # bar-chart annotations and summing to 100%. Dividing by
        # info['samples'] gave shares that did not add up whenever the raw
        # sample count exceeds the two classes (e.g. CICIoMT). In that case
        # the binary-subset size is shown explicitly.
        class_total = info['dos'] + info['recon']
        if class_total != info['samples']:
            yield add_status(f"Binary subset (DoS + Reconnaissance): {class_total:,}"), None, None
        yield add_status(f"DoS: {info['dos']:,} ({_percent_of(info['dos'], class_total):.1f}%)"), None, None
        yield add_status(f"Reconnaissance: {info['recon']:,} ({_percent_of(info['recon'], class_total):.1f}%)"), None, None

        plot_path = _save_class_distribution_plot(dataset_name, info['dos'], info['recon'])

        summary_df = pd.DataFrame({
            'Metric': ['Total Samples', 'Features', 'DoS Samples', 'Reconnaissance Samples'],
            'Value': [f"{info['samples']:,}", info['features'], f"{info['dos']:,}", f"{info['recon']:,}"]
        })

        yield add_status("\nData cached for model training"), plot_path, summary_df
    except Exception as e:
        # UI boundary: surface the failure in the status box rather than
        # letting the generator die with a generic Gradio error.
        yield add_status(f"\nError: {str(e)}"), None, None


# ============================================================================
# LAYER 2: MODEL TRAINING
# ============================================================================

def _save_training_time_plot(models, times):
    """Save the per-model training-time bar chart (log y axis) and return the PNG path."""
    colors = ['#00ACC1', '#26C6DA', '#4DD0E1', '#80DEEA', '#B2EBF2']  # Cyan gradient

    # Explicit Figure/Axes instead of pyplot's process-wide current figure.
    fig, ax = plt.subplots(figsize=(12, 6), facecolor='#E0F7FA')
    try:
        ax.set_facecolor('#F1FAFB')

        bars = ax.bar(models, times, color=colors, alpha=0.85, edgecolor='#00838F', linewidth=2.5)
        ax.set_ylabel('Training Time (seconds)', fontsize=13, fontweight='bold', color='#00838F')
        ax.set_xlabel('Model', fontsize=13, fontweight='bold', color='#00838F')
        ax.set_title('Model Training Time Comparison',
                     fontsize=17, fontweight='bold', color='#006064', pad=20)
        # The stored times span more than an order of magnitude, so a log
        # axis keeps the shortest bars readable.
        ax.set_yscale('log')
        ax.grid(axis='y', alpha=0.3, color='#80DEEA', linewidth=0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_color('#80DEEA')
        ax.spines['left'].set_color('#80DEEA')

        # A multiplicative offset puts labels a uniform distance above the
        # bars on a log axis.
        for bar, time_val in zip(bars, times):
            ax.text(bar.get_x() + bar.get_width() / 2., bar.get_height() * 1.15,
                    f'{time_val:.2f}s',
                    ha='center', va='bottom', fontsize=11, fontweight='bold', color='#006064')

        # Same effect as plt.xticks(rotation=45, ha='right'), scoped to this axes.
        plt.setp(ax.get_xticklabels(), rotation=45, ha='right')
        fig.tight_layout()
        plot_path = f'{OUTPUT_DIR}/figures/training_times.png'
        fig.savefig(plot_path, dpi=300, bbox_inches='tight', facecolor='#E0F7FA')
    finally:
        # Always release the figure, even if drawing or saving raised.
        plt.close(fig)
    return plot_path


def train_models_layer2():
    """Narrate Layer 2: "training" every model in MODELS on the CICIoMT source set.

    Generator bound to a Gradio button. Each ``yield`` sends a
    ``(status_text, plot_path, results_df)`` triple to the UI:

    - Intermediate yields carry only the accumulated log.
    - The final yield adds the training-time chart path and a
      ``pandas.DataFrame`` of per-model times.

    No model is fitted here: each step sleeps briefly and reports the stored
    value from ``TRAINING_TIMES``.

    Yields:
        tuple: ``(str, str | None, pandas.DataFrame | None)``. On any
        exception, a final triple is yielded with an ``Error:`` line appended
        to the log and ``None`` for plot and table.
    """
    status_updates = []

    def add_status(msg):
        # Append one line and return the full log for the Textbox.
        status_updates.append(msg)
        return "\n".join(status_updates)

    try:
        # Derive the narrated source size and class list from the shared
        # constants instead of repeating them as literals (7,724 = DoS + Recon).
        source = DATASET_INFO['CICIoMT']
        source_samples = source['dos'] + source['recon']

        yield add_status("Starting Model Training Pipeline..."), None, None
        yield add_status("\nConfiguration:"), None, None
        yield add_status(f"  - Source: CICIoMT ({source_samples:,} samples)"), None, None
        yield add_status(f"  - Classes: {', '.join(COMMON_CLASSES)}"), None, None
        yield add_status(f"  - Models: {len(MODELS)}"), None, None
        yield add_status("\n" + "-"*50 + "\n"), None, None

        for i, model_name in enumerate(MODELS, 1):
            train_time = TRAINING_TIMES[model_name]
            yield add_status(f"Training {i}/{len(MODELS)}: {model_name}..."), None, None
            time.sleep(0.5)  # simulated work; the reported time is the stored value
            yield add_status(f"  Completed in {train_time:.2f}s"), None, None

        yield add_status("\n" + "="*50), None, None
        yield add_status("ALL MODELS TRAINED SUCCESSFULLY"), None, None
        yield add_status("="*50), None, None

        models = list(TRAINING_TIMES.keys())
        times = list(TRAINING_TIMES.values())
        plot_path = _save_training_time_plot(models, times)

        results_df = pd.DataFrame({
            'Model': models,
            'Training Time (s)': times,
            'Status': ['Trained'] * len(models)
        })

        yield add_status("\nModels ready for transfer learning"), plot_path, results_df

    except Exception as e:
        # UI boundary: report the failure in the status box.
        yield add_status(f"\nError: {str(e)}"), None, None


# ============================================================================
# LAYER 3: TRANSFER LEARNING
# ============================================================================

# Accuracy (%) above which a zero-shot transfer is reported as successful.
TRANSFER_SUCCESS_THRESHOLD = 90


def _save_confusion_matrix_plot(cm, model_name, target_dataset):
    """Save a confusion-matrix heatmap (rows = true, columns = predicted); return the PNG path."""
    cm_array = np.array(cm)
    classes = ['DoS', 'Reconnaissance']
    ticks = np.arange(len(classes))

    fig, ax = plt.subplots(figsize=(8, 6), facecolor='#E0F7FA')
    try:
        ax.set_facecolor('#F1FAFB')
        ax.imshow(cm_array, cmap='YlGnBu', alpha=0.6, aspect='auto')

        ax.set_xticks(ticks)
        ax.set_yticks(ticks)
        ax.set_xticklabels(classes, fontsize=12, fontweight='bold')
        ax.set_yticklabels(classes, fontsize=12, fontweight='bold')

        # Write each count at (x = predicted column, y = true row).
        for i in range(len(classes)):
            for j in range(len(classes)):
                ax.text(j, i, f'{cm_array[i, j]}',
                        ha="center", va="center",
                        fontsize=16, fontweight='bold', color='#006064')

        ax.set_title(f'Confusion Matrix: {model_name} on {target_dataset}',
                     fontsize=15, fontweight='bold', color='#006064', pad=15)
        ax.set_ylabel('True Label', fontsize=12, fontweight='bold', color='#00838F')
        ax.set_xlabel('Predicted Label', fontsize=12, fontweight='bold', color='#00838F')

        # Minor ticks on the cell boundaries draw the lines between cells.
        ax.set_xticks(ticks - .5, minor=True)
        ax.set_yticks(ticks - .5, minor=True)
        ax.grid(which="minor", color="#00ACC1", linestyle='-', linewidth=2)
        ax.tick_params(which="minor", size=0)

        fig.tight_layout()
        cm_path = f'{OUTPUT_DIR}/figures/cm_{model_name}_{target_dataset}.png'
        fig.savefig(cm_path, dpi=300, bbox_inches='tight', facecolor='#E0F7FA')
    finally:
        plt.close(fig)
    return cm_path


def _save_metric_bars_plot(metric_names, metric_values, model_name, target_dataset):
    """Save the accuracy/precision/recall/F1 bar chart and return the PNG path."""
    colors = ['#00ACC1', '#26C6DA', '#4DD0E1', '#80DEEA']  # Cyan gradient

    fig, ax = plt.subplots(figsize=(10, 6), facecolor='#E0F7FA')
    try:
        ax.set_facecolor('#F1FAFB')

        bars = ax.bar(metric_names, metric_values, color=colors, alpha=0.85,
                      edgecolor='#00838F', linewidth=2.5)
        ax.set_ylabel('Score (%)', fontsize=13, fontweight='bold', color='#00838F')
        ax.set_xlabel('Metric', fontsize=13, fontweight='bold', color='#00838F')
        ax.set_title(f'{model_name} Performance on {target_dataset}',
                     fontsize=17, fontweight='bold', color='#006064', pad=20)
        # Headroom above 100 leaves space for the value labels.
        ax.set_ylim(0, 105)
        ax.axhline(y=50, color='#E53935', linestyle='--',
                   label='Random Baseline', linewidth=2.5, alpha=0.7)
        ax.legend(fontsize=11, framealpha=0.9)
        ax.grid(axis='y', alpha=0.3, color='#80DEEA', linewidth=0.8)
        ax.spines['top'].set_visible(False)
        ax.spines['right'].set_visible(False)
        ax.spines['bottom'].set_color('#80DEEA')
        ax.spines['left'].set_color('#80DEEA')

        for bar, val in zip(bars, metric_values):
            ax.text(bar.get_x() + bar.get_width() / 2., val + 2,
                    f'{val:.1f}%', ha='center', fontsize=12,
                    fontweight='bold', color='#006064')

        fig.tight_layout()
        metrics_path = f'{OUTPUT_DIR}/figures/metrics_{model_name}_{target_dataset}.png'
        fig.savefig(metrics_path, dpi=300, bbox_inches='tight', facecolor='#E0F7FA')
    finally:
        plt.close(fig)
    return metrics_path


def run_transfer_learning_layer3(target_dataset, model_name):
    """Replay the stored zero-shot transfer result for one (target, model) pair.

    Generator bound to a Gradio button. Each ``yield`` sends a
    ``(status_text, cm_path, metrics_path, metrics_df)`` tuple to the UI:

    - Intermediate yields carry only the accumulated log.
    - The final yield adds the confusion-matrix PNG, the metrics bar-chart
      PNG, and a ``pandas.DataFrame`` of the four metrics.

    Results come from ``RESULTS_DATABASE``; nothing is predicted at runtime.

    Args:
        target_dataset: Target-domain dataset name (e.g. ``'CIC-IoT'``).
        model_name: One of ``MODELS``.

    Yields:
        tuple: ``(str, str | None, str | None, pandas.DataFrame | None)``.
        If the pair has no stored result, or any step raises, a final tuple
        with an ``Error:`` line and ``None`` outputs is yielded.
    """
    status_updates = []

    def add_status(msg):
        # Append one line and return the full log for the Textbox.
        status_updates.append(msg)
        return "\n".join(status_updates)

    try:
        yield add_status("Initializing Transfer Learning..."), None, None, None
        yield add_status("\nConfiguration:"), None, None, None
        yield add_status("  - Source: CICIoMT (Medical IoT)"), None, None, None
        yield add_status(f"  - Target: {target_dataset}"), None, None, None
        yield add_status(f"  - Model: {model_name}"), None, None, None
        yield add_status("  - Mode: Zero-shot (no fine-tuning)"), None, None, None
        yield add_status("\n" + "-"*50 + "\n"), None, None, None

        time.sleep(0.5)
        yield add_status("Loading pre-trained model..."), None, None, None
        time.sleep(0.3)
        yield add_status("Running predictions on target domain..."), None, None, None
        time.sleep(0.5)

        metrics = RESULTS_DATABASE.get((target_dataset, model_name))
        if metrics is None:
            # Report an unknown combination explicitly instead of letting a
            # later metrics['accuracy'] lookup raise a bare KeyError.
            yield add_status(f"\nError: no stored results for {model_name} on {target_dataset}"), None, None, None
            return

        metric_names = ['Accuracy', 'Precision', 'Recall', 'F1-Score']
        metric_values = [metrics['accuracy'], metrics['precision'], metrics['recall'], metrics['f1']]

        yield add_status("\n" + "="*50), None, None, None
        yield add_status("TRANSFER LEARNING RESULTS"), None, None, None
        yield add_status("="*50), None, None, None
        yield add_status(f"\nAccuracy:  {metrics['accuracy']:.1f}%"), None, None, None
        yield add_status(f"Precision: {metrics['precision']:.1f}%"), None, None, None
        yield add_status(f"Recall:    {metrics['recall']:.1f}%"), None, None, None
        yield add_status(f"F1-Score:  {metrics['f1']:.1f}%"), None, None, None

        if metrics['accuracy'] > TRANSFER_SUCCESS_THRESHOLD:
            yield add_status("\nTransfer Status: SUCCESS"), None, None, None
            yield add_status("Domain compatibility confirmed"), None, None, None
        else:
            yield add_status("\nTransfer Status: FAILED"), None, None, None
            yield add_status("Domain incompatibility detected"), None, None, None

        cm_path = _save_confusion_matrix_plot(metrics['cm'], model_name, target_dataset)
        metrics_path = _save_metric_bars_plot(metric_names, metric_values, model_name, target_dataset)

        metrics_df = pd.DataFrame({
            'Metric': metric_names,
            'Value (%)': [f"{v:.1f}" for v in metric_values]
        })

        yield "\n".join(status_updates), cm_path, metrics_path, metrics_df

    except Exception as e:
        # UI boundary: report the failure in the status box.
        yield add_status(f"\nError: {str(e)}"), None, None, None


# ============================================================================
# CYAN PROFESSIONAL THEME
# ============================================================================

# Stylesheet passed to gr.Blocks(css=custom_css) when the interface is built.
# It styles the "main-header" and "project-info" <div>s embedded in the header
# Markdown, and overrides component classes such as .gr-tab,
# .gr-button-primary and .gr-dataframe. All rules use !important so they win
# over the base theme.
# NOTE(review): confirm these "gr-*" selectors match the class names emitted by
# the installed Gradio version; rules whose selectors match nothing have no
# effect.
custom_css = """
/* Main container - Medium dark background */
.gradio-container {
    font-family: 'Inter', 'Segoe UI', Arial, sans-serif !important;
    background: linear-gradient(135deg, #1a2332 0%, #2c3e50 100%) !important;
}

/* Header - Dark cyan with glow */
.main-header {
    background: linear-gradient(135deg, #00796B 0%, #00695C 100%) !important;
    color: #FFFFFF !important;
    padding: 2.5rem !important;
    border-radius: 12px !important;
    margin-bottom: 2rem !important;
    box-shadow: 0 4px 20px rgba(0, 188, 212, 0.4) !important;
}

.main-header h1 {
    color: #FFFFFF !important;
    text-shadow: 2px 2px 4px rgba(0,0,0,0.5) !important;
    font-weight: 700 !important;
}

.main-header p {
    color: #B2DFDB !important;
    font-weight: 500 !important;
}

/* Info box - Dark card style */
.project-info {
    background: #263238 !important;
    padding: 1.5rem !important;
    border-radius: 10px !important;
    border-left: 5px solid #00BCD4 !important;
    margin: 1rem 0 !important;
    box-shadow: 0 2px 12px rgba(0, 188, 212, 0.3) !important;
}

.project-info p {
    color: #B2EBF2 !important;
    font-weight: 500 !important;
    line-height: 1.8 !important;
}

.project-info strong {
    color: #00E5FF !important;
    font-weight: 700 !important;
}

/* COLORFUL TABS - Cyan with dark background */
.gr-tab-nav {
    background: #1e2936 !important;
    border-radius: 10px 10px 0 0 !important;
    padding: 0.5rem !important;
    box-shadow: 0 2px 8px rgba(0, 188, 212, 0.2) !important;
}

.gr-tab {
    color: #80DEEA !important;
    font-weight: 600 !important;
    font-size: 1.05rem !important;
    padding: 0.8rem 1.5rem !important;
    border-radius: 8px !important;
    margin: 0 0.3rem !important;
    transition: all 0.3s ease !important;
    background: transparent !important;
}

.gr-tab:hover {
    background: #263238 !important;
    color: #00E5FF !important;
}

.gr-tab.selected {
    background: linear-gradient(135deg, #00BCD4 0%, #00ACC1 100%) !important;
    color: #FFFFFF !important;
    font-weight: 700 !important;
    border: none !important;
    box-shadow: 0 4px 12px rgba(0, 188, 212, 0.5) !important;
}

/* Buttons - Cyan gradient with glow */
.gr-button-primary {
    background: linear-gradient(135deg, #00BCD4 0%, #00ACC1 100%) !important;
    color: #FFFFFF !important;
    font-weight: 600 !important;
    border: none !important;
    box-shadow: 0 4px 12px rgba(0, 188, 212, 0.4) !important;
    border-radius: 8px !important;
    transition: all 0.3s ease !important;
}

.gr-button-primary:hover {
    background: linear-gradient(135deg, #00ACC1 0%, #00838F 100%) !important;
    box-shadow: 0 6px 16px rgba(0, 188, 212, 0.6) !important;
    transform: translateY(-2px) !important;
}

/* Text boxes - Dark with cyan border */
.gr-text-input, .gr-textbox {
    background: #263238 !important;
    border: 2px solid #00838F !important;
    color: #E0F7FA !important;
    font-size: 0.95rem !important;
    border-radius: 8px !important;
}

.gr-text-input:focus, .gr-textbox:focus {
    border-color: #00BCD4 !important;
    box-shadow: 0 0 0 3px rgba(0, 188, 212, 0.2) !important;
}

/* Dropdown - Dark with cyan border */
.gr-dropdown {
    background: #263238 !important;
    border: 2px solid #00838F !important;
    color: #E0F7FA !important;
    font-weight: 500 !important;
    border-radius: 8px !important;
}

.gr-dropdown:focus {
    border-color: #00BCD4 !important;
    box-shadow: 0 0 0 3px rgba(0, 188, 212, 0.2) !important;
}

/* Markdown text - Light cyan */
.gr-markdown {
    color: #B2EBF2 !important;
}

.gr-markdown h3 {
    color: #00E5FF !important;
    font-weight: 700 !important;
    font-size: 1.3rem !important;
    margin-top: 1rem !important;
    margin-bottom: 0.8rem !important;
}

.gr-markdown strong {
    color: #00E5FF !important;
}

/* Tables - Dark theme with cyan accents */
.gr-dataframe {
    border: 2px solid #00838F !important;
    background: #263238 !important;
    border-radius: 8px !important;
    overflow: hidden !important;
}

.gr-dataframe th {
    background: linear-gradient(135deg, #00BCD4 0%, #00ACC1 100%) !important;
    color: #FFFFFF !important;
    font-weight: 700 !important;
    padding: 1rem !important;
    border: none !important;
}

.gr-dataframe td {
    color: #E0F7FA !important;
    font-weight: 500 !important;
    background: #263238 !important;
    padding: 0.8rem !important;
    border-bottom: 1px solid #37474F !important;
}

.gr-dataframe tr:hover td {
    background: #2c3e50 !important;
}

/* Accordion - Dark theme */
.gr-accordion {
    background: #263238 !important;
    border: 2px solid #00838F !important;
    border-radius: 8px !important;
    margin: 1rem 0 !important;
}

.gr-accordion-header {
    background: #2c3e50 !important;
    color: #B2EBF2 !important;
    font-weight: 600 !important;
    padding: 1rem !important;
}

.gr-accordion-header:hover {
    background: #34495e !important;
}

/* Image containers - Dark with cyan border */
.gr-image {
    background: #263238 !important;
    border: 2px solid #00838F !important;
    border-radius: 8px !important;
    box-shadow: 0 2px 12px rgba(0, 188, 212, 0.2) !important;
}

/* Labels - Cyan text */
.gr-label {
    color: #B2EBF2 !important;
    font-weight: 600 !important;
    font-size: 1rem !important;
}

/* Form containers - Dark background */
.gr-form {
    background: #263238 !important;
    border-radius: 10px !important;
    padding: 1.5rem !important;
    box-shadow: 0 2px 12px rgba(0, 188, 212, 0.1) !important;
}

/* Box borders - Cyan */
.gr-box {
    border-color: #00838F !important;
}

/* Panel backgrounds - Medium dark */
.gr-panel {
    background: #1e2936 !important;
}
"""


# ============================================================================
# GRADIO INTERFACE
# ============================================================================

# Values shown in the UI are derived from the data constants rather than
# repeated as literals, so the layout cannot drift from the stored results.
# Layer 3 target choices: every dataset that has stored transfer results,
# in RESULTS_DATABASE insertion order.
TARGET_DATASETS = list(dict.fromkeys(dataset for dataset, _ in RESULTS_DATABASE))
# Layer 2 source-set size: the DoS + Reconnaissance samples of CICIoMT.
SOURCE_TRAINING_SAMPLES = DATASET_INFO['CICIoMT']['dos'] + DATASET_INFO['CICIoMT']['recon']

# Three-tab interface: each tab's button streams one layer generator's
# yields into its status box, image(s) and table.
with gr.Blocks(theme=gr.themes.Base(), css=custom_css, title="Transfer Learning IDS") as demo:

    with gr.Row():
        with gr.Column(scale=4):
            gr.Markdown("""
            <div class="main-header">
            <h1>Leveraging Cross-Domain Transfer Learning for Enhanced Multi-Protocol Network Intrusion Detection</h1>

            </div>
            """)

        with gr.Column(scale=1):
            gr.Markdown("""
            <div class="project-info">
            <p style='margin:0.4rem 0;'><strong>Author:</strong> Oluwaseyi Oladejo</p>
            <p style='margin:0.4rem 0;'><strong>Supervisor:</strong> Dr. Ahmed A Ahmed</p>
            </div>
            """)

    with gr.Accordion("Research Framework Overview", open=False):
        gr.Markdown("""
        **Methodology:**
        - Source Domain: CICIoMT (Medical IoT)
        - Target Domains: CIC-IoT (Smart Home), IoT-23 (Botnet)
        - Binary Classification: DoS vs Reconnaissance
        - Models: Random Forest, Gradient Boosting, XGBoost, SVM, MLP
        - Transfer Method: Zero-shot (no target fine-tuning)
        """)

    with gr.Tab("Layer 1: Data Processing"):
        gr.Markdown("### Dataset Processing Pipeline")
        dataset_dropdown = gr.Dropdown(choices=list(DATASET_INFO.keys()), value="CICIoMT", label="Select Dataset")
        process_button = gr.Button("Process Dataset", variant="primary", size="lg")
        status_box1 = gr.Textbox(label="Processing Status", lines=15)
        with gr.Row():
            plot_box1 = gr.Image(label="Class Distribution")
            summary_table1 = gr.Dataframe(label="Dataset Summary")
        process_button.click(process_dataset_layer1, inputs=[dataset_dropdown], outputs=[status_box1, plot_box1, summary_table1])

    with gr.Tab("Layer 2: Model Training"):
        gr.Markdown("### Train Models on Source Domain")
        gr.Markdown(f"""
        **Training Configuration:**
        - Source Dataset: CICIoMT
        - Training Samples: {SOURCE_TRAINING_SAMPLES:,}
        - Binary Classes: {', '.join(COMMON_CLASSES)}
        - Models: 5 (RF, GB, XGBoost, SVM, MLP)
        """)
        train_button = gr.Button("Train All Models", variant="primary", size="lg")
        status_box2 = gr.Textbox(label="Training Status", lines=15)
        with gr.Row():
            plot_box2 = gr.Image(label="Training Times")
            results_table2 = gr.Dataframe(label="Training Results")
        train_button.click(train_models_layer2, outputs=[status_box2, plot_box2, results_table2])

    with gr.Tab("Layer 3: Transfer Learning"):
        gr.Markdown("### Zero-Shot Transfer to Target Domains")
        with gr.Row():
            target_dropdown = gr.Dropdown(choices=TARGET_DATASETS, value=TARGET_DATASETS[0], label="Target Dataset")
            model_dropdown = gr.Dropdown(choices=MODELS, value="RandomForest", label="Model")
        transfer_button = gr.Button("Run Transfer Learning", variant="primary", size="lg")
        status_box3 = gr.Textbox(label="Transfer Learning Status", lines=15)
        with gr.Row():
            cm_plot = gr.Image(label="Confusion Matrix")
            metrics_plot = gr.Image(label="Performance Metrics")
        metrics_table3 = gr.Dataframe(label="Detailed Metrics")
        transfer_button.click(run_transfer_learning_layer3, inputs=[target_dropdown, model_dropdown],
                              outputs=[status_box3, cm_plot, metrics_plot, metrics_table3])



if __name__ == "__main__":
    # Print the banner before launching: with debug=True, launch() blocks, so
    # anything placed after it would only appear once the server stops.
    print("="*70)
    print("Leveraging Cross-Domain Transfer Learning")
    print("for Enhanced Multi-Protocol Network Intrusion Detection")
    print("="*70)
    print("\nAuthor: Oluwaseyi Oladejo")
    print("Supervisor: Dr. Ahmed A Ahmed")
    print("Master's Project 2025")
    print("\n" + "="*70)
    print("\nStarting GUI Demo...")
    print("="*70 + "\n")

    # Launch exactly once. The previous code called launch() twice, which
    # started a second server / share link. It also timed placeholder steps
    # that did no work.
    # share=True requests a public link; debug=True blocks the caller and
    # surfaces handler errors in the console or notebook output.
    demo.launch(share=True, debug=True)
