In [None]:
import hmac
import os
import secrets
import threading

import joblib
import matplotlib
matplotlib.use('Agg')  # Non-GUI backend; selected before pyplot (or shap) can import matplotlib.pyplot
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap
from flask import Flask, render_template, request, redirect, url_for, session

# Load trained model and components.
# NOTE(security): joblib.load unpickles the file, which can execute arbitrary
# code -- only point MODEL_PATH at a trusted artifact.
MODEL_PATH = os.environ.get("MODEL_PATH", "intrusion_detection.pkl")
model_data = joblib.load(MODEL_PATH)
scaler = model_data['scaler']
pca = model_data['pca']
model = model_data['model']
label_encoders = model_data['label_encoders']

app = Flask(__name__)
# Flask signs the session cookie with this key. A publicly known key lets
# anyone forge session['user'] and skip the login, so read it from the
# environment. The random fallback is per process: sessions then do not
# survive a restart and are not shared between worker processes.
app.secret_key = os.environ.get("FLASK_SECRET_KEY") or secrets.token_hex(32)

# Login credentials (demo only: stored in plaintext). The admin password can
# be overridden through the ADMIN_PASSWORD environment variable.
users = {"admin": os.environ.get("ADMIN_PASSWORD", "password")}

@app.route('/')
def home():
    """Send anyone hitting the site root straight to the login page."""
    login_page = url_for('login')
    return redirect(login_page)

def _credentials_valid(username, password):
    """Return True when *username* is in ``users`` and *password* matches it.

    Missing form fields (``None``) never authenticate. The password comparison
    is constant-time, so response timing does not reveal how much of a guessed
    password was correct.
    """
    expected = users.get(username)
    if expected is None or password is None:
        return False
    return hmac.compare_digest(expected.encode('utf-8'), password.encode('utf-8'))


@app.route('/login', methods=['GET', 'POST'])
def login():
    """Show the login form (GET) or authenticate a submitted one (POST).

    The session is cleared on every visit, so opening this page also logs out
    whoever was signed in. On success the username is stored in
    ``session['user']`` and the client is redirected to the prediction page.
    On failure a plain "Invalid Credentials!" string is returned.
    """
    session.clear()
    if request.method == 'POST':
        username = request.form.get('username')
        password = request.form.get('password')
        if _credentials_valid(username, password):
            session['user'] = username
            return redirect(url_for('predict'))
        return "Invalid Credentials!"
    return render_template('login.html')

# Form fields read for each prediction, in the column order fed to the scaler.
# This order must match the one used when the scaler/PCA/model were fitted.
EXPECTED_FEATURES = [
    'duration', 'protocol_type', 'service', 'flag', 'src_bytes', 'dst_bytes',
    'land', 'wrong_fragment', 'urgent', 'hot', 'num_failed_logins', 'logged_in',
    'num_compromised', 'root_shell', 'su_attempted', 'num_root', 'num_file_creations',
    'num_shells', 'num_access_files', 'num_outbound_cmds', 'is_host_login', 'is_guest_login',
    'count', 'srv_count', 'serror_rate', 'srv_serror_rate', 'rerror_rate', 'srv_rerror_rate',
    'same_srv_rate', 'diff_srv_rate', 'srv_diff_host_rate', 'dst_host_count',
    'dst_host_srv_count', 'dst_host_same_srv_rate', 'dst_host_diff_srv_rate',
    'dst_host_same_src_port_rate', 'dst_host_srv_diff_host_rate', 'dst_host_serror_rate',
    'dst_host_srv_serror_rate', 'dst_host_rerror_rate', 'dst_host_srv_rerror_rate',
]

# Precaution message per traffic family.
PRECAUTIONS = {
    "Normal": "No intrusion detected.",
    "DoS": "Mitigate DoS attacks with firewalls.",
    "Probe": "Monitor suspicious scanning attempts.",
    "R2L": "Ensure strong authentication mechanisms.",
    "U2R": "Update software and restrict user privileges.",
}

# Attack name -> family, following the KDD Cup '99 / NSL-KDD attack taxonomy
# (the feature set above is the KDD one). Extra NSL-KDD test-set names are
# included so they also resolve.
_ATTACK_FAMILIES = {
    "DoS": ("back", "land", "neptune", "pod", "smurf", "teardrop",
            "apache2", "mailbomb", "processtable", "udpstorm"),
    "Probe": ("ipsweep", "nmap", "portsweep", "satan", "mscan", "saint"),
    "R2L": ("ftp_write", "guess_passwd", "imap", "multihop", "phf", "spy",
            "warezclient", "warezmaster", "httptunnel", "named", "sendmail",
            "snmpgetattack", "snmpguess", "worm", "xlock", "xsnoop"),
    "U2R": ("buffer_overflow", "loadmodule", "perl", "rootkit", "ps",
            "sqlattack", "xterm"),
}
_CATEGORY_BY_NAME = {
    name: family for family, names in _ATTACK_FAMILIES.items() for name in names
}
# Labels that already are family names ("normal", "dos", ...) map to themselves.
_CATEGORY_BY_NAME.update({family.lower(): family for family in PRECAUTIONS})

SHAP_PLOT_PATH = "static/shap_plot.png"

# pyplot keeps global state and is not thread-safe. Flask may serve requests
# on several threads, so plotting and writing the shared image file are
# serialized with this lock.
_PLOT_LOCK = threading.Lock()

# The model never changes after startup, so one TreeExplainer is reused.
_explainer = None


def _get_explainer():
    """Return the shared ``shap.TreeExplainer`` for ``model``, building it on first use."""
    global _explainer
    if _explainer is None:
        _explainer = shap.TreeExplainer(model)
    return _explainer


def _parse_features(form):
    """Read EXPECTED_FEATURES from *form* into a ``(1, n_features)`` float array.

    Categorical fields that have an entry in ``label_encoders`` are encoded
    first.

    Returns:
        ``(features, None)`` on success, or ``(None, message)`` describing the
        first missing, unknown-categorical or non-numeric field.
    """
    values = []
    for feature in EXPECTED_FEATURES:
        raw = form.get(feature)
        if raw is None:
            return None, f"Error: Missing value for '{feature}'"
        if feature in label_encoders:
            try:
                raw = label_encoders[feature].transform([raw])[0]
            except ValueError:
                return None, f"Error: Unknown value '{raw}' for '{feature}'"
        try:
            values.append(float(raw))
        except ValueError:
            return None, f"Error: Invalid numeric value '{raw}' for '{feature}'"
    return np.array(values).reshape(1, -1), None


def _decode_attack_label(prediction):
    """Turn a raw model prediction into a readable class name.

    Uses the 'attack_type' label encoder when it exists. Otherwise the model's
    own class value is already the best name available.
    """
    if 'attack_type' in label_encoders:
        return label_encoders['attack_type'].inverse_transform([prediction])[0]
    return str(prediction)


def _precaution_for(attack_label):
    """Return the precaution message for the family *attack_label* belongs to."""
    # Case-insensitive match. A trailing '.' is also tolerated, since some
    # KDD Cup '99 exports keep one on every label.
    key = str(attack_label).strip().rstrip('.').lower()
    return PRECAUTIONS.get(_CATEGORY_BY_NAME.get(key), "No specific precaution available.")


def _class_index(prediction):
    """Return the position of *prediction* in ``model.classes_``.

    Per-class SHAP outputs are ordered by this position, which equals the
    label value only when the classes happen to be ``0..n-1``.
    """
    classes = list(getattr(model, 'classes_', []))
    if prediction in classes:
        return classes.index(prediction)
    return int(prediction)


def _shap_row(shap_values, expected_value, class_idx):
    """Extract the first sample's SHAP values and base value for *class_idx*.

    Accepted layouts:
    - a list of per-class ``(n_samples, n_features)`` arrays (older SHAP);
    - an ``(n_samples, n_features, n_classes)`` array (newer SHAP);
    - a single-output ``(n_samples, n_features)`` array.

    ``expected_value`` may be a scalar or one value per class.

    Returns:
        ``(values, base_value)``: a 1-D array of per-feature contributions and
        a float.
    """
    if isinstance(shap_values, list):
        values = np.asarray(shap_values[class_idx])
    else:
        values = np.asarray(shap_values)
        if values.ndim == 3:
            values = values[:, :, class_idx]
    row = np.atleast_2d(values)[0]

    base = np.asarray(expected_value, dtype=float).reshape(-1)
    base_value = base[class_idx] if base.size > 1 else base[0]
    return row, float(base_value)


def _save_shap_waterfall(features_pca, class_idx, path=SHAP_PLOT_PATH):
    """Render a SHAP waterfall plot for one PCA-space sample and save it to *path*.

    Returns:
        *path*, for the template.
    """
    explainer = _get_explainer()
    shap_values = explainer.shap_values(features_pca)
    app.logger.debug("SHAP values shape: %s; expected value shape: %s",
                     np.shape(shap_values), np.shape(explainer.expected_value))

    row, base_value = _shap_row(shap_values, explainer.expected_value, class_idx)
    explanation = shap.Explanation(
        values=row,
        base_values=base_value,
        data=features_pca[0],
        feature_names=[f"PC{i + 1}" for i in range(features_pca.shape[1])],
    )

    os.makedirs(os.path.dirname(path) or ".", exist_ok=True)
    with _PLOT_LOCK:
        plt.figure()
        try:
            # show=False: the Agg backend cannot display, we only save the figure.
            shap.plots.waterfall(explanation, show=False)
            plt.savefig(path, bbox_inches='tight', dpi=300)
        finally:
            # Always release the figure, even if plotting fails, so figures
            # do not pile up in pyplot's global registry.
            plt.close()
    # NOTE(review): all requests share one file name, so a later request can
    # overwrite the image before an earlier client fetches it.
    return path


@app.route('/predict', methods=['GET', 'POST'])
def predict():
    """Prediction page; requires a logged-in session.

    Without a session, redirects to the login page. GET renders
    ``predict.html``. POST does the following:
    - reads EXPECTED_FEATURES from the form;
    - runs scaler -> PCA -> model;
    - saves a SHAP waterfall image for the predicted class;
    - renders ``result.html`` with the decoded label, confidence (percent, two
      decimals), precaution message and image path.

    Invalid input and unexpected failures are returned as plain
    ``"Error: ..."`` strings.
    """
    if 'user' not in session:
        return redirect(url_for('login'))

    if request.method == 'GET':
        return render_template('predict.html')

    try:
        features, error = _parse_features(request.form)
        if error is not None:
            return error

        features_pca = pca.transform(scaler.transform(features))

        prediction = model.predict(features_pca)[0]
        prediction_proba = model.predict_proba(features_pca).max() * 100

        # Decode before choosing the precaution: the precaution depends on the
        # real class name, and the encoded index of "normal" is not
        # necessarily 0.
        attack_label = _decode_attack_label(prediction)
        precaution_message = _precaution_for(attack_label)
        app.logger.debug("Decoded attack label: %s", attack_label)

        shap_plot_path = _save_shap_waterfall(features_pca, _class_index(prediction))

        return render_template('result.html',
                               attack_type=attack_label,
                               confidence=round(prediction_proba, 2),
                               precaution=precaution_message,
                               shap_img=shap_plot_path)
    except Exception as e:
        # Top-level request boundary: keep the traceback in the server log and
        # report the message to the client as before.
        app.logger.exception("Prediction request failed")
        return f"Error: {str(e)}"

@app.route('/logout')
def logout():
    """End the current login and send the client back to the login page."""
    # Dropping just the 'user' key is enough: it is what predict() checks
    # before serving the prediction form.
    session.pop('user', None)
    login_page = url_for('login')
    return redirect(login_page)

if __name__ == '__main__':
    # Debug mode enables the Werkzeug interactive debugger and reloader, so it
    # must never be on for a reachable server. It is opt-in via FLASK_DEBUG=1.
    debug_enabled = os.environ.get('FLASK_DEBUG', '').lower() in ('1', 'true', 'yes')
    app.run(debug=debug_enabled)
