<a href="https://colab.research.google.com/github/Khalidaman9555/IDS-AI/blob/main/Multi_Model_Realtime_IDS_Dashboard_EdgeIIoT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ... (Outer Colab setup code: ngrok, pip installs etc. remain the same) ...

# 3. Define the Streamlit application content and write it to app.py
import os
app_content = """
# Streamlit Edge-IIoT intrusion detection dashboard. This text is written to
# app.py by the notebook cell that contains it. The app replays rows of the
# Edge-IIoT CSV as a simulated live stream, scores each row with a
# user-selected pre-trained model and redraws the charts, confusion matrix
# and recent alerts after every sample.
import streamlit as st

# MUST BE THE VERY FIRST STREAMLIT COMMAND
st.set_page_config(layout="wide", page_title="Enhanced IDS Dashboard")

import pandas as pd
import numpy as np
import time
from datetime import datetime
import altair as alt
from collections import defaultdict, deque
import matplotlib.pyplot as plt  # NOTE(review): plt is not referenced anywhere in this app
from sklearn.preprocessing import MinMaxScaler # Instantiated as the LSTM input scaler kept in session state
from sklearn.metrics import confusion_matrix, classification_report  # NOTE(review): neither is referenced in this app; the confusion matrix is computed by hand
import joblib # For loading scikit-learn models
import os # Added for os.path.exists

# Conditional import for TensorFlow: the app still runs without it, only the
# LSTM model becomes unavailable.
try:
    from tensorflow.keras.models import load_model
    TENSORFLOW_AVAILABLE = True
except ImportError:
    TENSORFLOW_AVAILABLE = False
    # load_lstm_model_keras() checks this flag and returns a warning message
    # instead of a model.

# --- Configuration ---
# All artefacts live under one Google Drive folder; adjust BASE_PATH if your
# Drive is mounted somewhere else.
BASE_PATH = "/content/drive/MyDrive/Colab Notebooks/"
RESULTS_PATH = os.path.join(BASE_PATH, "results")
DATA_PATH = os.path.join(BASE_PATH, "datasets/ML-EdgeIIoT-dataset.csv")


def _results_file(filename):
    '''Return the path of a trained artefact stored inside RESULTS_PATH.'''
    return os.path.join(RESULTS_PATH, filename)


# Scikit-learn models (UI name -> joblib file) and the scaler they all share.
SKLEARN_MODEL_FILES = dict(
    RandomForest=_results_file("randomforest_model.joblib"),
    DecisionTree=_results_file("decisiontree_model.joblib"),
    LogisticRegression=_results_file("logisticregression_model.joblib"),
    XGBoost=_results_file("xgboost_model.joblib"),
    SVM=_results_file("svm_model.joblib"),
)
COMMON_SCALER_PATH = _results_file("common_scaler.joblib")

# Keras LSTM model.
LSTM_MODEL_PATH = _results_file("lstm_model.h5")

# CRITICAL: names AND order must match the feature vector the models were
# trained on; the order defines the model input layout.
FEATURE_COLUMNS = [
    # ARP / HTTP
    'arp.hw.size', 'http.content_length', 'http.response', 'http.tls_port',
    # TCP
    'tcp.ack_raw', 'tcp.checksum', 'tcp.connection.fin', 'tcp.connection.rst',
    'tcp.connection.syn', 'tcp.connection.synack', 'tcp.dstport',
    'tcp.flags.ack', 'tcp.len',
    # UDP
    'udp.stream', 'udp.time_delta',
    # DNS
    'dns.qry.qu', 'dns.qry.type', 'dns.retransmission',
    'dns.retransmit_request', 'dns.retransmit_request_in',
    # MQTT
    'mqtt.conflag.cleansess', 'mqtt.hdrflags', 'mqtt.len', 'mqtt.msg_decoded_as',
    # Modbus/TCP
    'mbtcp.len', 'mbtcp.trans_id', 'mbtcp.unit_id',
]
# Each LSTM input is reshaped to (1, LSTM_SEQUENCE_LENGTH, n_features).
LSTM_SEQUENCE_LENGTH = 1

# --- Session State Initialization ---
# st.session_state survives Streamlit reruns, so these defaults are created on
# the first run of a browser session and reused afterwards.
if 'data' not in st.session_state:
    st.session_state.data = dict(
        events=deque(maxlen=500),               # samples as returned by the data iterator
        predictions_log=deque(maxlen=500),      # one record per scored sample
        attack_stats=defaultdict(int),          # NOTE(review): not read or written elsewhere in this app
        model_metrics=dict(tp=0, fp=0, tn=0, fn=0),  # running confusion-matrix counts
        common_scaler_sklearn=None,             # set once the shared scaler is loaded
        lstm_scaler=MinMaxScaler(),             # fitted lazily before the first LSTM prediction
        data_iterator=None,
        selected_model_name=None,
        loaded_models={},
    )
if 'page_loaded' not in st.session_state:
    st.session_state.page_loaded = False

# --- Model and Scaler Loading ---
@st.cache_resource # Use cache_resource for model objects
def load_sklearn_models_and_scaler():
    '''Load every scikit-learn model in SKLEARN_MODEL_FILES plus the shared scaler.

    Nothing is raised for missing or broken files. Each problem is recorded
    as a message so the caller can show it once the page is set up.
    NOTE: joblib.load unpickles its input; only point these paths at files
    you trust.

    Returns:
        (models, scaler, load_messages). models maps a model name to the
        deserialized estimator and only contains the files that loaded.
        scaler is the common scaler, or None. load_messages is a list of
        strings prefixed with Error:, Warning: or Success:.
    '''
    load_messages = []

    models = {}
    for name, path in SKLEARN_MODEL_FILES.items():
        if not os.path.exists(path):
            load_messages.append(f"Warning: Scikit-learn model not found: {path}")
            continue
        try:
            models[name] = joblib.load(path)
        except Exception as e:
            load_messages.append(f"Error: Failed to load {name} from {path}: {e}")

    scaler = None
    if not os.path.exists(COMMON_SCALER_PATH):
        load_messages.append(f"Warning: Common Scaler not found: {COMMON_SCALER_PATH}")
    else:
        try:
            scaler = joblib.load(COMMON_SCALER_PATH)
        except Exception as e:
            load_messages.append(f"Error: Failed to load Common Scaler from {COMMON_SCALER_PATH}: {e}")
        else:
            load_messages.append("Success: Common Scaler (for scikit-learn) loaded.")
    return models, scaler, load_messages

@st.cache_resource
def load_lstm_model_keras():
    '''Load the Keras LSTM model from LSTM_MODEL_PATH, if possible.

    Returns:
        (model, load_messages). model is None when TensorFlow is not
        installed, the file is missing or loading fails. load_messages is a
        one-element list describing the outcome (Success:, Warning: or
        Error: prefix).
    '''
    if not TENSORFLOW_AVAILABLE:
        return None, ["Warning: TensorFlow/Keras not found. LSTM model will be unavailable."]
    if not os.path.exists(LSTM_MODEL_PATH):
        return None, [f"Warning: LSTM Model not found: {LSTM_MODEL_PATH}"]
    try:
        model = load_model(LSTM_MODEL_PATH)
    except Exception as e:
        return None, [f"Error: LSTM Model loading failed from {LSTM_MODEL_PATH}: {e}"]
    return model, ["Success: LSTM Model loaded."]

# Initial loading messages will be displayed in the sidebar later.
# Both loaders are wrapped in st.cache_resource, so calling them on every run is
# cheap: the cached result is returned. Calling them every run keeps the loading
# status visible in the sidebar after reruns, not only on the very first run.
_initial_load_messages = []

sklearn_models_loaded, common_scaler_loaded, skl_msgs = load_sklearn_models_and_scaler()
_initial_load_messages.extend(skl_msgs)

# load_lstm_model_keras checks TENSORFLOW_AVAILABLE itself and returns a
# warning when TensorFlow is missing. Calling it unconditionally lets that
# warning reach the sidebar; the TENSORFLOW_AVAILABLE guard here used to hide it.
lstm_model_loaded, lstm_msgs = load_lstm_model_keras()
_initial_load_messages.extend(lstm_msgs)

if not st.session_state.data['loaded_models']: # Populate session state only once
    # Explicit None checks: the loaders signal "not loaded" with None, and
    # estimator/model truthiness is not a reliable stand-in for that.
    if common_scaler_loaded is not None:
        st.session_state.data['common_scaler_sklearn'] = common_scaler_loaded
    st.session_state.data['loaded_models'].update(sklearn_models_loaded)
    if lstm_model_loaded is not None:
        st.session_state.data['loaded_models']['LSTM'] = lstm_model_loaded

# --- Data Pipeline ---
class EdgeIIoTDataIterator:
    '''Endlessly replay the rows of the Edge-IIoT CSV as if they were live samples.

    next() returns a dict with a wall-clock timestamp string, the ground-truth
    is_attack flag (0 or 1) and one float per name in feature_columns. After
    the last row it wraps around to the first one, so StopIteration is only
    raised when no data could be loaded at all. Loading problems are never
    raised. They are recorded as prefixed strings (Error:, Warning:, Success:,
    Info:) that get_load_messages() returns.
    '''

    def __init__(self, file_path, feature_columns):
        self.feature_columns = feature_columns
        self.df_processed = pd.DataFrame()  # stays empty when loading fails
        self.current_idx = 0
        self.load_messages = []  # Messages specific to data iterator

        if not os.path.exists(file_path):
            self.load_messages.append(f"Error: Dataset CSV not found at {file_path}.")
            return

        try:
            self.df = pd.read_csv(file_path, low_memory=False)
            self.load_messages.append(f"Info: Raw CSV loaded: {len(self.df)} rows, {len(self.df.columns)} columns.")

            # The models need every configured feature as a number: absent
            # columns are filled with 0 and unparsable cells become 0.
            for col in self.feature_columns:
                if col not in self.df.columns:
                    self.load_messages.append(f"Warning: Feature '{col}' not in CSV! Filling with 0.")
                    self.df[col] = 0
                else:
                    self.df[col] = pd.to_numeric(self.df[col], errors='coerce').fillna(0)

            if 'Attack_label' in self.df.columns:
                self.df['is_attack'] = self._binarize_labels(self.df['Attack_label'])
            else:
                self.load_messages.append("Warning: 'Attack_label' not in CSV. Assuming normal.")
                self.df['is_attack'] = 0

            self.df_processed = self.df[self.feature_columns + ['is_attack']].copy()
            self.load_messages.append(f"Success: Dataset processed. Features: {len(self.feature_columns)}.")
        except Exception as e:
            self.load_messages.append(f"Error: Error loading/processing dataset: {str(e)}")

    @staticmethod
    def _binarize_labels(labels):
        '''Map raw Attack_label values to 0 (normal) or 1 (attack).

        Labels are compared numerically, so 0, 0.0 and the text 0 all mean
        normal traffic. A string comparison against the text 0 would treat
        the float 0.0 as an attack. Pandas produces such floats whenever the
        label column contains a blank. Values that are not numeric (including
        blanks) still count as attacks.
        '''
        numeric = pd.to_numeric(labels, errors='coerce')
        return (numeric.fillna(1) != 0).astype(int)

    def get_load_messages(self):
        '''Return the status messages collected while loading the CSV.'''
        return self.load_messages

    def __iter__(self):
        return self

    def __next__(self):
        '''Return the next row as a sample dict, wrapping to row 0 at the end.'''
        if self.df_processed.empty:
            # Nothing was loaded; the main loop reacts to StopIteration.
            self.current_idx = 0
            raise StopIteration
        if self.current_idx >= len(self.df_processed):
            self.current_idx = 0

        row = self.df_processed.iloc[self.current_idx]
        self.current_idx += 1
        sample_features = {col: float(row[col]) for col in self.feature_columns}
        return {
            'timestamp': datetime.now().strftime('%H:%M:%S.%f')[:-3],
            'is_attack': int(row['is_attack']),
            **sample_features
        }

# Create the replay iterator once per session. Its load messages are stored on
# the instance, so they are read back on every run and stay visible in the
# sidebar after reruns, not only on the run that created the iterator.
_data_iterator_messages = []
iterator_instance = st.session_state.data.get('data_iterator')
if iterator_instance is None and DATA_PATH:
    iterator_instance = EdgeIIoTDataIterator(DATA_PATH, FEATURE_COLUMNS)
    st.session_state.data['data_iterator'] = iterator_instance
if iterator_instance is not None:
    _data_iterator_messages = iterator_instance.get_load_messages()

# --- Preprocessing and Prediction ---
def _fit_lstm_scaler(scaler, fallback_features_2d, feature_columns):
    '''Fit the LSTM input scaler once, on the whole replay dataset when possible.

    A MinMaxScaler fitted on a single sample gives every feature a zero range.
    That sample maps to all zeros, and later samples are only shifted, never
    scaled into [0, 1]. Fitting on the full feature matrix of the replayed
    CSV gives real per-feature min/max values. NOTE(review): this presumes the
    LSTM was trained on MinMax-scaled features of this dataset. Persisting
    and loading the exact training scaler would be more faithful; confirm.
    '''
    data_iterator = st.session_state.data.get('data_iterator')
    dataset = getattr(data_iterator, 'df_processed', None)
    if dataset is not None and not dataset.empty and set(feature_columns).issubset(dataset.columns):
        scaler.fit(dataset[list(feature_columns)].to_numpy(dtype=float))
    else:
        # No replay data to learn ranges from: fall back to the single sample.
        scaler.fit(fallback_features_2d)


def preprocess_and_predict(sample_features_dict, model_name, feature_columns):
    '''Scale one sample and classify it with the named loaded model.

    Args:
        sample_features_dict: mapping from feature name to numeric value; it
            must contain every name in feature_columns.
        model_name: key into st.session_state.data['loaded_models']. LSTM
            selects the Keras path; any other name is used as a scikit-learn
            style estimator with the shared common scaler.
        feature_columns: ordered feature names that define the input vector.

    Returns:
        (predicted_class, attack_probability, status_message). On failure the
        class is 0, the probability 0.0, and the message contains error or
        not loaded.
    '''
    features_2d = np.array([[sample_features_dict[col] for col in feature_columns]])
    current_model = st.session_state.data['loaded_models'].get(model_name)

    if current_model is None:
        return 0, 0.0, f"Model '{model_name}' not loaded"

    try:
        if model_name == "LSTM":
            scaler_lstm = st.session_state.data['lstm_scaler']
            if not hasattr(scaler_lstm, 'scale_'):  # not fitted yet
                _fit_lstm_scaler(scaler_lstm, features_2d, feature_columns)
            model_input = scaler_lstm.transform(features_2d)
            # One sample becomes one sequence; this reshape only succeeds
            # while LSTM_SEQUENCE_LENGTH is 1.
            model_input = model_input.reshape(1, LSTM_SEQUENCE_LENGTH, len(feature_columns))
            status_msg = "LSTM Preprocessed"
            prediction_prob = current_model.predict(model_input, verbose=0)[0][0]
            predicted_class = int(prediction_prob > 0.5)
        else: # Scikit-learn
            scaler_sklearn = st.session_state.data.get('common_scaler_sklearn')
            if scaler_sklearn is None:
                return 0, 0.0, "Common scaler not loaded"
            model_input = scaler_sklearn.transform(features_2d)
            status_msg = "Scikit-learn Preprocessed"
            predicted_class = current_model.predict(model_input)[0]
            # Column 1 of predict_proba is taken as the attack class;
            # models without predict_proba get a hard 0/1 probability.
            if hasattr(current_model, "predict_proba"):
                prediction_prob = current_model.predict_proba(model_input)[0][1]
            else:
                prediction_prob = 1.0 if predicted_class == 1 else 0.0

        return int(predicted_class), float(prediction_prob), status_msg
    except Exception as e:
        return 0, 0.0, f"Prediction error ({model_name}): {str(e)}"

# --- Dashboard UI ---
# Drawn here, after st.set_page_config has already run at the top of the app.
st.title("🛡️ Edge-IIoT Intrusion Detection System")

# Replay the collected loading messages in the sidebar, styled by their prefix.
# The first matching marker wins; anything unmarked is shown as info.
st.sidebar.header("🔄 Loading Status")
_STATUS_RENDERERS = (
    ("Error:", st.sidebar.error),
    ("Warning:", st.sidebar.warning),
    ("Success:", st.sidebar.success),
)
for msg in [*_initial_load_messages, *_data_iterator_messages]:
    _render_status = next((fn for marker, fn in _STATUS_RENDERERS if marker in msg), st.sidebar.info)
    _render_status(msg)


st.sidebar.header("⚙️ Controls & Settings")

_loaded_models = st.session_state.data['loaded_models']
available_model_names = [name for name in _loaded_models if _loaded_models[name] is not None]
if not available_model_names:
    # Nothing to score with: halt this run before the live loop starts.
    st.sidebar.error("No models loaded! Dashboard cannot operate.")
    st.stop()

active_model_name = st.sidebar.selectbox("🧠 Select Model:", available_model_names, index=0)
st.sidebar.info(f"Active Model: **{active_model_name}**")
# Stored in session state so the live loop reads the same model name.
st.session_state.data['selected_model_name'] = active_model_name

update_interval_ms = st.sidebar.slider("⏱️ Update Interval (ms)", 100, 2000, 750)
history_length = st.sidebar.slider("📊 History Length (samples)", 20, 200, 50)

# Resize both rolling buffers when the history slider moves; deque(old, maxlen=n)
# keeps the newest n entries.
if st.session_state.data['events'].maxlen != history_length:
    for _buffer_key in ('events', 'predictions_log'):
        st.session_state.data[_buffer_key] = deque(st.session_state.data[_buffer_key], maxlen=history_length)

col1, col2 = st.columns([2, 1])

# Left column: live charts, each redrawn in place through an st.empty() slot.
col1.subheader("🚦 Real-time Detections")
event_plot_placeholder = col1.empty()
col1.subheader("📈 Prediction Confidence")
confidence_plot_placeholder = col1.empty()

# Right column: running statistics, recent alerts and, when the model exposes
# them, feature importances.
col2.subheader("📊 Session Statistics")
stats_placeholder = col2.empty()
col2.subheader("⚠️ Recent Alerts (Predicted Attacks)")
alerts_placeholder = col2.empty()

selected_model_obj = st.session_state.data['loaded_models'].get(active_model_name)
_exposes_importances = (
    hasattr(selected_model_obj, 'feature_importances_') or hasattr(selected_model_obj, 'coef_')
)
if active_model_name not in ["LSTM", "SVM"] and selected_model_obj and _exposes_importances:
    # fi_placeholder only exists in this case; the live loop checks for it.
    col2.subheader(f"Feature Importance ({active_model_name})")
    fi_placeholder = col2.empty()

# Without replayable data the live loop has nothing to do, so stop this run
# and point the user to the sidebar loading messages.
data_iterator = st.session_state.data.get('data_iterator')
_has_no_data = data_iterator is None or (
    hasattr(data_iterator, 'df_processed') and data_iterator.df_processed.empty
)
if _has_no_data:
    st.error("Data not available or failed to load. Dashboard cannot run. Check sidebar for loading status messages.")
    st.stop()

st.session_state.page_loaded = True

# --- Live rendering helpers ---
# Each helper redraws one dashboard area inside its st.empty() placeholder,
# replacing whatever the previous loop iteration drew there.

def _render_detection_charts(predictions_log_df, event_placeholder, confidence_placeholder):
    '''Redraw the attack-probability timeline and the per-sample bar chart.

    predictions_log_df has one row per scored sample with the columns
    timestamp (HH:MM:SS.mmm string), actual, predicted and probability.
    '''
    if predictions_log_df.empty:
        with event_placeholder.container():
            st.caption("Waiting for data to plot...")
        return

    chart_data = predictions_log_df.copy()
    # The timestamps carry only a time of day; parse them to get a temporal
    # x-axis and drop any rows that fail to parse.
    chart_data['time_obj'] = pd.to_datetime(chart_data['timestamp'], format='%H:%M:%S.%f', errors='coerce')
    chart_data = chart_data.dropna(subset=['time_obj'])
    if chart_data.empty:
        with event_placeholder.container():
            st.warning("Timestamp conversion failed or no valid data for plotting main chart.")
        return

    with event_placeholder.container():
        base_chart = alt.Chart(chart_data).encode(x='time_obj:T')
        line_prob = base_chart.mark_line(opacity=0.8).encode(
            y=alt.Y('probability:Q', title='Attack Probability', scale=alt.Scale(domain=[0, 1])),
            tooltip=['timestamp:N', 'probability:Q', 'actual:N', 'predicted:N']
        ).properties(title="Attack Prediction Probability")
        # Red circles mark ground-truth attacks, orange diamonds predicted ones.
        actual_attacks = base_chart.transform_filter(alt.datum.actual == 1).mark_circle(size=100, color='red', opacity=0.7).encode(y='probability:Q')
        predicted_markers = base_chart.transform_filter(alt.datum.predicted == 1).mark_point(shape='diamond', size=80, color='orange', opacity=0.7, filled=True).encode(y='probability:Q')
        st.altair_chart(line_prob + actual_attacks + predicted_markers, use_container_width=True)

    with confidence_placeholder.container():
        conf_chart = alt.Chart(chart_data).mark_bar().encode(
            x='time_obj:T',
            y=alt.Y('probability:Q', title='Confidence', scale=alt.Scale(domain=(0, 1))),
            color=alt.condition(alt.datum.probability > 0.5, alt.value('#e45756'), alt.value('#54a24b'))
        ).properties(title="Prediction Confidence", height=200)
        st.altair_chart(conf_chart, use_container_width=True)


def _render_session_stats(metrics, placeholder):
    '''Redraw the running confusion matrix and the scores derived from it.'''
    tp, fp, tn, fn = metrics['tp'], metrics['fp'], metrics['tn'], metrics['fn']
    total = tp + tn + fp + fn
    accuracy = (tp + tn) / total if total > 0 else 0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0
    f1 = 2 * (precision * recall) / (precision + recall) if (precision + recall) > 0 else 0

    with placeholder.container():
        st.markdown("**Confusion Matrix (Session Totals):**")
        cm_df = pd.DataFrame([[tn, fp], [fn, tp]], columns=['Pred Normal', 'Pred Attack'], index=['Actual Normal', 'Actual Attack'])
        st.table(cm_df)
        st.metric("Accuracy", f"{accuracy:.2%}")
        m_col1, m_col2, m_col3 = st.columns(3)
        m_col1.metric("Precision", f"{precision:.2%}")
        m_col2.metric("Recall", f"{recall:.2%}")
        m_col3.metric("F1-Score", f"{f1:.2%}")


def _render_recent_alerts(predictions_log_df, placeholder):
    '''Show the five most recent samples the model flagged as attacks.'''
    with placeholder.container():
        if predictions_log_df.empty:
            recent_alerts = predictions_log_df
        else:
            recent_alerts = predictions_log_df[predictions_log_df['predicted'] == 1].tail(5)
        # Show the info message whenever nothing was flagged, not only before
        # the first sample arrives.
        if recent_alerts.empty:
            st.info("No attacks detected recently.")
        else:
            st.dataframe(recent_alerts[['timestamp', 'probability', 'actual']], hide_index=True, use_container_width=True)


def _render_feature_importance(model, model_name, placeholder):
    '''Draw the ten largest absolute feature importances (or coefficients).'''
    importances = getattr(model, 'feature_importances_', None)
    if importances is None and hasattr(model, 'coef_'):
        # Linear models: use the first coefficient row.
        importances = np.atleast_2d(model.coef_)[0]

    with placeholder.container():
        if importances is None:
            st.caption("Feature importance not available for this model type.")
            return
        importances = np.abs(np.asarray(importances, dtype=float)).ravel()
        # A length mismatch means the model was trained on a different feature
        # list; building the DataFrame would raise and kill the live loop.
        if len(importances) != len(FEATURE_COLUMNS):
            st.caption("Could not generate feature importances.")
            return
        fi_df = pd.DataFrame({'feature': FEATURE_COLUMNS, 'importance': importances}
                             ).sort_values('importance', ascending=False).head(10)
        fi_chart = alt.Chart(fi_df).mark_bar().encode(
            x='importance:Q',
            y=alt.Y('feature:N', sort='-x')
        ).properties(title=f"Top 10 Features ({model_name})", height=250)
        st.altair_chart(fi_chart, use_container_width=True)


# Confusion-matrix bucket for each (actual, predicted) pair of 0/1 labels.
_CONFUSION_BUCKETS = {(1, 1): 'tp', (0, 1): 'fp', (0, 0): 'tn', (1, 0): 'fn'}

# Feature importances depend only on the active model, which is fixed for the
# whole run (changing the selectbox reruns the script). So draw them once here
# rather than on every tick. fi_placeholder exists only when the layout
# section decided the model exposes importances.
_fi_placeholder = globals().get('fi_placeholder')
if _fi_placeholder is not None:
    _render_feature_importance(
        st.session_state.data['loaded_models'].get(st.session_state.data['selected_model_name']),
        st.session_state.data['selected_model_name'],
        _fi_placeholder,
    )

# A single sidebar slot for prediction failures. It is overwritten on every
# iteration, so a persistently failing model shows one message rather than a
# growing list.
prediction_error_placeholder = st.sidebar.empty()

# Main display loop: score one replayed sample per tick and redraw the dashboard.
while True:
    try:
        sample = next(data_iterator)
    except StopIteration:
        st.warning("Data source depleted.")
        break
    except Exception as e:
        st.error(f"Error fetching data: {e}")
        time.sleep(2)
        continue

    model_name = st.session_state.data['selected_model_name']
    sample_features = {k: sample[k] for k in FEATURE_COLUMNS}
    predicted_class, prediction_prob, pred_status = preprocess_and_predict(sample_features, model_name, FEATURE_COLUMNS)

    status_lower = pred_status.lower()
    if "error" in status_lower or "not loaded" in status_lower:
        # preprocess_and_predict signals failure with class 0 / probability 0.0.
        # Recording that would count every failure as a TN or FN, so surface
        # the reason and skip this sample instead.
        prediction_error_placeholder.error(pred_status)
        time.sleep(update_interval_ms / 1000)
        continue
    prediction_error_placeholder.empty()

    st.session_state.data['events'].append(sample)
    st.session_state.data['predictions_log'].append({
        'timestamp': sample['timestamp'], 'actual': sample['is_attack'],
        'predicted': predicted_class, 'probability': prediction_prob, 'model': model_name
    })

    metrics = st.session_state.data['model_metrics']
    bucket = _CONFUSION_BUCKETS.get((sample['is_attack'], predicted_class))
    if bucket is not None:
        metrics[bucket] += 1

    predictions_log_df = pd.DataFrame(list(st.session_state.data['predictions_log']))

    _render_detection_charts(predictions_log_df, event_plot_placeholder, confidence_plot_placeholder)
    _render_session_stats(metrics, stats_placeholder)
    _render_recent_alerts(predictions_log_df, alerts_placeholder)

    time.sleep(update_interval_ms / 1000)

"""

# Write the app source to app.py.
# app_content is an ordinary (non-raw) triple-quoted string, so Python has
# already processed any backslash escapes inside it (for example, a backslash
# at the end of a line is removed together with the newline). Keep the
# embedded app free of triple double-quotes and careful with backslashes.
# The app contains emoji, so write UTF-8 explicitly instead of relying on the
# platform's default encoding.
with open('app.py', 'w', encoding='utf-8') as f:
    f.write(app_content)

# 4. Run Streamlit
import threading
import time
import os # Ensure os is imported here as well

def run_streamlit_server(port=8501):
    '''Restart the Streamlit server for app.py on *port*; blocks until it exits.

    Any running Streamlit process is killed first so the port is free. The
    server runs headless with CORS and XSRF protection disabled; these are
    the flags this notebook uses for Colab compatibility.
    '''
    import subprocess

    # Kill existing streamlit processes to avoid port conflicts. This is best
    # effort: pkill exits non-zero when nothing matched, which is fine.
    try:
        subprocess.run(['pkill', '-f', 'streamlit'], check=False)
    except FileNotFoundError:
        print("pkill not available; skipping cleanup of old Streamlit processes.")
    time.sleep(2)  # Give killed processes a moment to release the port

    # Argument lists avoid going through a shell.
    command = [
        'streamlit', 'run', 'app.py',
        '--server.port', str(port),
        '--server.headless', 'true',
        '--server.enableCORS', 'false',
        '--server.enableXsrfProtection', 'false',
    ]
    try:
        subprocess.run(command, check=False)
    except FileNotFoundError:
        print("The 'streamlit' command was not found; install it with pip first.")

# Start Streamlit in a background thread. daemon=True means this thread never
# blocks interpreter shutdown; the cell itself finishes after the prints below.
thread = threading.Thread(target=run_streamlit_server, daemon=True)
thread.start()
print("Streamlit server starting in background...")
time.sleep(5)  # fixed grace period for Streamlit to initialize

# 5. Provide access link using Colab's proxy (port 8501, as used by run_streamlit_server).
from google.colab.output import eval_js
dashboard_url = eval_js('google.colab.kernel.proxyPort(8501)')
print(f"\n📊 Access your dashboard at: {dashboard_url}")
for hint in (
    "If the link above doesn't work or shows an error, please check the Colab cell output for any Streamlit errors.",
    "Ensure your Google Drive is mounted and all file paths are correct.",
):
    print(hint)

Streamlit server starting in background...

📊 Access your dashboard at: https://8501-m-s-19dsq8pbxxhz3-b.europe-west4-1.prod.colab.dev
If the link above doesn't work or shows an error, please check the Colab cell output for any Streamlit errors.
Ensure your Google Drive is mounted and all file paths are correct.
