In [10]:
pip install eli5==0.11.0 scikit-learn==0.24.2

Collecting eli5==0.11.0
  Downloading eli5-0.11.0-py2.py3-none-any.whl.metadata (17 kB)
Collecting scikit-learn==0.24.2
  Downloading scikit-learn-0.24.2.tar.gz (7.5 MB)
     ---------------------------------------- 0.0/7.5 MB ? eta -:--:--
     -- ------------------------------------- 0.5/7.5 MB 2.4 MB/s eta 0:00:03
     ------ --------------------------------- 1.3/7.5 MB 3.5 MB/s eta 0:00:02
     --------- ------------------------------ 1.8/7.5 MB 3.0 MB/s eta 0:00:02
     ----------- ---------------------------- 2.1/7.5 MB 2.4 MB/s eta 0:00:03
     ----------- ---------------------------- 2.1/7.5 MB 2.4 MB/s eta 0:00:03
     ----------- ---------------------------- 2.1/7.5 MB 2.4 MB/s eta 0:00:03
     ------------ --------------------------- 2.4/7.5 MB 1.7 MB/s eta 0:00:04
     ------------ --------------------------- 2.4/7.5 MB 1.7 MB/s eta 0:00:04
     ------------- -------------------------- 2.6/7.5 MB 1.4 MB/s eta 0:00:04
     --------------- ------------------------ 2.9/7.5 MB 

  error: subprocess-exited-with-error
  
  Preparing metadata (pyproject.toml) did not run successfully.
  exit code: 1
  
  [72 lines of output]
  Partial import of sklearn during the build process.
  
    `numpy.distutils` is deprecated since NumPy 1.23.0, as a result
    of the deprecation of `distutils` itself. It will be removed for
    Python >= 3.12. For older Python versions it will remain present.
    It is recommended to use `setuptools < 60.0` for those Python versions.
    For more details, see:
      https://numpy.org/devdocs/reference/distutils_status_migration.html
  
  
  !!
  
          ********************************************************************************
          Please consider removing the following classifiers in favor of a SPDX license expression:
  
          License :: OSI Approved
  
          See https://packaging.python.org/en/latest/guides/writing-pyproject-toml/#license for details.
          *****************************************************

In [12]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.inspection import permutation_importance
from sklearn.metrics import confusion_matrix, classification_report
import shap
import lime
import lime.lime_tabular
import joblib
import pickle
import os
from IPython.display import display

In [13]:
class HNIDS_Explainer:
    """
    Explainable AI (XAI) toolkit for HNIDS (Hybrid Network-based Intrusion Detection System).
    Provides various methods to interpret and explain model decisions.
    """
    
    def __init__(self, hnids_model=None, X_train=None, y_train=None, X_test=None, y_test=None):
        """
        Initialize the HNIDS Explainer with model and data.
        
        Args:
            hnids_model: Trained HNIDS model
            X_train: Training features
            y_train: Training labels
            X_test: Test features
            y_test: Test labels
        """
        self.hnids_model = hnids_model
        self.X_train = X_train
        self.y_train = y_train
        self.X_test = X_test
        self.y_test = y_test
        self.explainers = {}
        
    def load_model(self, model_path):
        """
        Load a previously saved HNIDS model.
        
        Args:
            model_path: Path to the saved model file
        """
        try:
            with open(model_path, 'rb') as f:
                self.hnids_model = pickle.load(f)
            print(f"Model loaded successfully from {model_path}")
        except Exception as e:
            print(f"Error loading model: {e}")
    
    def load_data(self, X_train=None, y_train=None, X_test=None, y_test=None):
        """
        Load data for explanation.
        
        Args:
            X_train: Training features
            y_train: Training labels
            X_test: Test features
            y_test: Test labels
        """
        if X_train is not None:
            self.X_train = X_train
        if y_train is not None:
            self.y_train = y_train
        if X_test is not None:
            self.X_test = X_test
        if y_test is not None:
            self.y_test = y_test
    
    def get_feature_importance(self, n_top_features=20, plot=True):
        """
        Get feature importance using permutation importance method.
        
        Args:
            n_top_features: Number of top features to display
            plot: Whether to plot the feature importance
            
        Returns:
            DataFrame with feature importance scores
        """
        if self.hnids_model is None or self.X_test is None or self.y_test is None:
            print("Model or test data not available. Please load model and data first.")
            return None
        
        try:
            # Get the core Random Forest model from HNIDS
            rf_model = self.hnids_model.irf_model
            
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_test_selected = self.X_test[selected_features]
                else:
                    # Use all features if feature names not available
                    X_test_selected = self.X_test
            else:
                # Use all features if no feature selection
                X_test_selected = self.X_test
            
            # Calculate permutation importance
            perm_importance = permutation_importance(rf_model, X_test_selected, self.y_test, 
                                                   n_repeats=10, random_state=42, n_jobs=-1)
            
            # Create DataFrame with importance scores
            feature_names = X_test_selected.columns
            importance_df = pd.DataFrame({
                'Feature': feature_names,
                'Importance': perm_importance.importances_mean,
                'Std_Dev': perm_importance.importances_std
            })
            
            # Sort by importance
            importance_df = importance_df.sort_values('Importance', ascending=False).reset_index(drop=True)
            
            # Plot feature importance
            if plot:
                plt.figure(figsize=(12, 8))
                
                # Take top N features
                plot_df = importance_df.head(n_top_features).copy()
                
                # Plot
                ax = sns.barplot(x='Importance', y='Feature', data=plot_df, 
                                palette='viridis', xerr=plot_df['Std_Dev'])
                
                plt.title('Feature Importance (Permutation Method)', fontsize=16)
                plt.xlabel('Importance Score', fontsize=14)
                plt.ylabel('Features', fontsize=14)
                plt.tight_layout()
                plt.show()
            
            return importance_df
            
        except Exception as e:
            print(f"Error calculating feature importance: {e}")
            return None
    
    def initialize_shap_explainer(self):
        """
        Initialize SHAP explainer for the model.
        """
        if self.hnids_model is None or self.X_train is None:
            print("Model or training data not available. Please load model and data first.")
            return
        
        try:
            # Get the core Random Forest model from HNIDS
            rf_model = self.hnids_model.irf_model
            
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_train_selected = self.X_train[selected_features]
                else:
                    # Use all features if feature names not available
                    X_train_selected = self.X_train
            else:
                # Use all features if no feature selection
                X_train_selected = self.X_train
            
            # Create a smaller sample for faster SHAP explanation
            X_sample = X_train_selected.sample(min(1000, len(X_train_selected)), random_state=42)
            
            # Initialize SHAP explainer
            self.explainers['shap'] = shap.TreeExplainer(rf_model)
            print("SHAP explainer initialized successfully.")
            
            return X_sample
            
        except Exception as e:
            print(f"Error initializing SHAP explainer: {e}")
    
    def global_shap_explanation(self, max_display=20):
        """
        Provide global SHAP explanation for the model.
        
        Args:
            max_display: Maximum number of features to display
        """
        if 'shap' not in self.explainers:
            X_sample = self.initialize_shap_explainer()
            if X_sample is None:
                return
        else:
            # Get selected features
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_sample = self.X_train[selected_features].sample(min(1000, len(self.X_train)), random_state=42)
                else:
                    X_sample = self.X_train.sample(min(1000, len(self.X_train)), random_state=42)
            else:
                X_sample = self.X_train.sample(min(1000, len(self.X_train)), random_state=42)
        
        try:
            # Calculate SHAP values
            shap_values = self.explainers['shap'].shap_values(X_sample)
            
            # Plot summary
            plt.figure(figsize=(12, 10))
            
            # For binary classification, we need the SHAP values for class 1 (intrusion)
            if isinstance(shap_values, list) and len(shap_values) > 1:
                shap_values_to_plot = shap_values[1]  # Class 1 (intrusion)
                shap.summary_plot(shap_values_to_plot, X_sample, max_display=max_display, 
                                show=False, plot_size=(12, 10))
            else:
                shap.summary_plot(shap_values, X_sample, max_display=max_display, 
                                show=False, plot_size=(12, 10))
            
            plt.title('SHAP Feature Importance', fontsize=16)
            plt.tight_layout()
            plt.show()
            
            # Plot bar summary
            plt.figure(figsize=(12, 8))
            
            if isinstance(shap_values, list) and len(shap_values) > 1:
                shap.summary_plot(shap_values_to_plot, X_sample, plot_type="bar", 
                                max_display=max_display, show=False, plot_size=(12, 8))
            else:
                shap.summary_plot(shap_values, X_sample, plot_type="bar", 
                                max_display=max_display, show=False, plot_size=(12, 8))
            
            plt.title('SHAP Mean Absolute Value (Impact on Model Output)', fontsize=16)
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f"Error in SHAP explanation: {e}")
    
    def local_shap_explanation(self, instance_idx, force_plot=False):
        """
        Provide local SHAP explanation for a specific instance.
        
        Args:
            instance_idx: Index of the instance to explain
            force_plot: Whether to create a force plot
        """
        if 'shap' not in self.explainers:
            X_sample = self.initialize_shap_explainer()
            if X_sample is None:
                return
        
        try:
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    instance = self.X_test[selected_features].iloc[instance_idx:instance_idx+1]
                else:
                    instance = self.X_test.iloc[instance_idx:instance_idx+1]
            else:
                instance = self.X_test.iloc[instance_idx:instance_idx+1]
            
            # Get true label and prediction
            true_label = self.y_test.iloc[instance_idx]
            prediction = self.hnids_model.predict(pd.DataFrame(instance).reset_index(drop=True))[0]
            
            # Calculate SHAP values
            shap_values = self.explainers['shap'].shap_values(instance)
            
            # For binary classification
            if isinstance(shap_values, list) and len(shap_values) > 1:
                print(f"Instance {instance_idx}:")
                print(f"True Label: {true_label} (0=Normal, 1=Intrusion)")
                print(f"Prediction: {prediction} (0=Normal, 1=Intrusion)")
                
                # Generate waterfall plot for the predicted class
                plt.figure(figsize=(12, 8))
                shap.waterfall_plot(shap.Explanation(
                    values=shap_values[prediction][0], 
                    base_values=self.explainers['shap'].expected_value[prediction],
                    data=instance.values[0],
                    feature_names=instance.columns.tolist()
                ), max_display=20, show=False)
                plt.title(f'SHAP Waterfall Plot for Instance {instance_idx}', fontsize=16)
                plt.tight_layout()
                plt.show()
                
                # Generate force plot if requested
                if force_plot:
                    shap.force_plot(
                        self.explainers['shap'].expected_value[prediction],
                        shap_values[prediction][0],
                        instance.iloc[0],
                        matplotlib=True,
                        show=False
                    )
                    plt.title(f'SHAP Force Plot for Instance {instance_idx}', fontsize=16)
                    plt.tight_layout()
                    plt.show()
            else:
                print(f"Instance {instance_idx}:")
                print(f"True Label: {true_label} (0=Normal, 1=Intrusion)")
                print(f"Prediction: {prediction} (0=Normal, 1=Intrusion)")
                
                # Generate waterfall plot
                plt.figure(figsize=(12, 8))
                shap.waterfall_plot(shap.Explanation(
                    values=shap_values[0], 
                    base_values=self.explainers['shap'].expected_value,
                    data=instance.values[0],
                    feature_names=instance.columns.tolist()
                ), max_display=20, show=False)
                plt.title(f'SHAP Waterfall Plot for Instance {instance_idx}', fontsize=16)
                plt.tight_layout()
                plt.show()
                
                # Generate force plot if requested
                if force_plot:
                    shap.force_plot(
                        self.explainers['shap'].expected_value,
                        shap_values[0],
                        instance.iloc[0],
                        matplotlib=True,
                        show=False
                    )
                    plt.title(f'SHAP Force Plot for Instance {instance_idx}', fontsize=16)
                    plt.tight_layout()
                    plt.show()
                    
        except Exception as e:
            print(f"Error in local SHAP explanation: {e}")
    
    def initialize_lime_explainer(self):
        """
        Initialize LIME explainer for the model.
        """
        if self.hnids_model is None or self.X_train is None:
            print("Model or training data not available. Please load model and data first.")
            return
        
        try:
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_train_selected = self.X_train[selected_features]
                else:
                    X_train_selected = self.X_train
            else:
                X_train_selected = self.X_train
            
            # Initialize LIME explainer
            self.explainers['lime'] = lime.lime_tabular.LimeTabularExplainer(
                X_train_selected.values,
                feature_names=X_train_selected.columns.tolist(),
                class_names=['Normal', 'Intrusion'],
                discretize_continuous=True,
                mode='classification'
            )
            
            print("LIME explainer initialized successfully.")
            
        except Exception as e:
            print(f"Error initializing LIME explainer: {e}")
    
    def local_lime_explanation(self, instance_idx, num_features=10):
        """
        Provide local LIME explanation for a specific instance.
        
        Args:
            instance_idx: Index of the instance to explain
            num_features: Number of features to include in the explanation
        """
        if 'lime' not in self.explainers:
            self.initialize_lime_explainer()
            if 'lime' not in self.explainers:
                return
        
        try:
            # Get the Random Forest model
            rf_model = self.hnids_model.irf_model
            
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    instance = self.X_test[selected_features].iloc[instance_idx].values
                    feature_names = selected_features
                else:
                    instance = self.X_test.iloc[instance_idx].values
                    feature_names = self.X_test.columns.tolist()
            else:
                instance = self.X_test.iloc[instance_idx].values
                feature_names = self.X_test.columns.tolist()
            
            # Get true label and prediction
            true_label = self.y_test.iloc[instance_idx]
            prediction_fn = lambda x: rf_model.predict_proba(x)
            
            # Generate LIME explanation
            explanation = self.explainers['lime'].explain_instance(
                instance, prediction_fn, num_features=num_features
            )
            
            # Display explanation
            print(f"Instance {instance_idx}:")
            print(f"True Label: {true_label} (0=Normal, 1=Intrusion)")
            predicted_prob = rf_model.predict_proba([instance])[0]
            predicted_class = np.argmax(predicted_prob)
            print(f"Predicted: {predicted_class} (0=Normal, 1=Intrusion)")
            print(f"Prediction Probabilities: Normal={predicted_prob[0]:.4f}, Intrusion={predicted_prob[1]:.4f}")
            
            # Plot explanation
            plt.figure(figsize=(12, 8))
            explanation.as_pyplot_figure(label=1)  # Plot for intrusion class
            plt.title(f'LIME Explanation for Instance {instance_idx}', fontsize=16)
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f"Error in LIME explanation: {e}")
    
    def partial_dependence_plot(self, feature_names, grid_resolution=20, max_plots=6):
        """
        Create partial dependence plots for selected features.
        
        Args:
            feature_names: List of feature names to plot
            grid_resolution: Resolution of the grid
            max_plots: Maximum number of plots to display
        """
        if self.hnids_model is None or self.X_train is None:
            print("Model or training data not available. Please load model and data first.")
            return
        
        try:
            # Get the Random Forest model
            rf_model = self.hnids_model.irf_model
            
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_data = self.X_train[selected_features]
                else:
                    X_data = self.X_train
            else:
                X_data = self.X_train
            
            # Validate feature names
            valid_feature_names = []
            for feature in feature_names:
                if feature in X_data.columns:
                    valid_feature_names.append(feature)
                else:
                    print(f"Warning: Feature '{feature}' not found in the data.")
            
            if not valid_feature_names:
                print("No valid features to plot.")
                return
            
            # Limit the number of plots
            valid_feature_names = valid_feature_names[:max_plots]
            
            # Calculate feature indices
            feature_indices = [list(X_data.columns).index(feat) for feat in valid_feature_names]
            
            # Create partial dependence plots
            from sklearn.inspection import partial_dependence
            
            # Set up the figure
            n_features = len(valid_feature_names)
            n_cols = min(3, n_features)
            n_rows = (n_features + n_cols - 1) // n_cols
            
            fig, axes = plt.subplots(n_rows, n_cols, figsize=(15, 5 * n_rows))
            if n_features == 1:
                axes = np.array([axes])  # Make it indexable for a single plot
            axes = axes.flatten()
            
            # Generate plots
            for i, (feature_idx, feature_name) in enumerate(zip(feature_indices, valid_feature_names)):
                # Calculate partial dependence
                pdp = partial_dependence(
                    rf_model, X_data, features=[feature_idx], 
                    kind='average', grid_resolution=grid_resolution
                )
                
                # Plot
                feature_values = pdp['values'][0]
                avg_effect = pdp['average'][0]
                
                ax = axes[i]
                ax.plot(feature_values, avg_effect, 'b-', linewidth=2)
                ax.set_title(f'Partial Dependence for {feature_name}', fontsize=12)
                ax.set_xlabel(feature_name, fontsize=10)
                ax.set_ylabel('Average Prediction (Probability)', fontsize=10)
                ax.grid(True, alpha=0.3)
            
            # Hide any unused axes
            for j in range(i+1, len(axes)):
                axes[j].set_visible(False)
                
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f"Error creating partial dependence plots: {e}")
    
    def feature_interaction_plot(self, feature1, feature2):
        """
        Create a 2D partial dependence plot to show interaction between two features.
        
        Args:
            feature1: First feature name
            feature2: Second feature name
        """
        if self.hnids_model is None or self.X_train is None:
            print("Model or training data not available. Please load model and data first.")
            return
        
        try:
            # Get the Random Forest model
            rf_model = self.hnids_model.irf_model
            
            # Get selected features if available
            if hasattr(self.hnids_model, 'best_features') and self.hnids_model.best_features is not None:
                feature_indices = np.where(self.hnids_model.best_features == 1)[0]
                if hasattr(self.hnids_model, 'feature_names') and self.hnids_model.feature_names is not None:
                    valid_indices = [idx for idx in feature_indices if idx < len(self.hnids_model.feature_names)]
                    selected_features = [self.hnids_model.feature_names[idx] for idx in valid_indices]
                    X_data = self.X_train[selected_features]
                else:
                    X_data = self.X_train
            else:
                X_data = self.X_train
            
            # Validate features
            if feature1 not in X_data.columns:
                print(f"Feature '{feature1}' not found in the data.")
                return
            if feature2 not in X_data.columns:
                print(f"Feature '{feature2}' not found in the data.")
                return
            
            # Get feature indices
            feature1_idx = list(X_data.columns).index(feature1)
            feature2_idx = list(X_data.columns).index(feature2)
            
            # Calculate 2D partial dependence
            from sklearn.inspection import partial_dependence
            
            pdp = partial_dependence(
                rf_model, X_data, features=[(feature1_idx, feature2_idx)], 
                kind='average', grid_resolution=20
            )
            
            # Extract data for plotting
            XX, YY = np.meshgrid(pdp['values'][0][0], pdp['values'][0][1])
            Z = pdp['average'][0].T
            
            # Create the plot
            fig, ax = plt.subplots(figsize=(10, 8))
            
            # Contour plot
            contour = ax.contourf(XX, YY, Z, cmap='viridis', alpha=0.8)
            
            # Add colorbar
            cbar = plt.colorbar(contour, ax=ax)
            cbar.set_label('Average Prediction (Probability)', fontsize=12)
            
            # Set labels and title
            ax.set_xlabel(feature1, fontsize=12)
            ax.set_ylabel(feature2, fontsize=12)
            ax.set_title(f'Interaction Effect between {feature1} and {feature2}', fontsize=14)
            
            plt.tight_layout()
            plt.show()
            
        except Exception as e:
            print(f"Error creating feature interaction plot: {e}")
    
    def plot_confusion_matrix(self, normalized=True):
        """
        Plot confusion matrix for the model's predictions.
        
        Args:
            normalized: Whether to normalize the confusion matrix
        """
        if self.hnids_model is None or self.X_test is None or self.y_test is None:
            print("Model or test data not available. Please load model and data first.")
            return
        
        try:
            # Get predictions
            y_pred = self.hnids_model.predict(self.X_test)
            
            # Calculate confusion matrix
            cm = confusion_matrix(self.y_test, y_pred)
            
            # Normalize if requested
            if normalized:
                cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
                
            # Create the plot
            plt.figure(figsize=(10, 8))
            sns.heatmap(cm, annot=True, fmt='.2f' if normalized else 'd', 
                      cmap='Blues', cbar=False, square=True)
            
            # Set labels and title
            plt.xlabel('Predicted Label', fontsize=12)
            plt.ylabel('True Label', fontsize=12)
            title = 'Normalized Confusion Matrix' if normalized else 'Confusion Matrix'
            plt.title(title, fontsize=14)
            
            # Set tick labels
            plt.xticks([0.5, 1.5], ['Normal', 'Intrusion'])
            plt.yticks([0.5, 1.5], ['Normal', 'Intrusion'])
            
            plt.tight_layout()
            plt.show()
            
            # Print classification report
            print("Classification Report:")
            print(classification_report(self.y_test, y_pred, target_names=['Normal', 'Intrusion']))
            
        except Exception as e:
            print(f"Error plotting confusion matrix: {e}")
    
    def get_detection_threshold_metrics(self, thresholds=None):
        """
        Calculate model metrics at different detection thresholds.

        Column 1 of the underlying model's ``predict_proba`` output is used as
        the intrusion score. For every threshold the score is binarised
        (``score >= threshold`` -> 1), the confusion-matrix counts are turned
        into summary metrics, and the metric curves are plotted.

        Args:
            thresholds: Iterable of threshold values to evaluate. Defaults to
                0.1, 0.2, ..., 0.9.

        Returns:
            DataFrame with one row per threshold and the columns Threshold,
            Accuracy, Precision, Recall, F1-Score, Specificity,
            False Alarm Rate, TP, FP, TN and FN. Returns None when the model or
            test data are missing, the model has no ``predict_proba``, no
            thresholds are given, or an error occurs (the error is printed).
        """
        if self.hnids_model is None or self.X_test is None or self.y_test is None:
            print("Model or test data not available. Please load model and data first.")
            return None

        try:
            # Get predictions
            if not hasattr(self.hnids_model.irf_model, 'predict_proba'):
                print("Model does not support probability predictions.")
                return None

            y_proba = self.hnids_model.irf_model.predict_proba(self.X_test)[:, 1]

            # Define thresholds if not provided. Integer steps divided by 10
            # avoid the float drift of np.arange(0.1, 1.0, 0.1), which yields
            # values such as 0.30000000000000004 and would push scores sitting
            # exactly on a nominal threshold to the wrong side.
            if thresholds is None:
                thresholds = np.arange(1, 10) / 10

            # Calculate metrics for each threshold
            results = []

            for threshold in thresholds:
                y_pred = (y_proba >= threshold).astype(int)

                # labels=[0, 1] forces a 2x2 matrix even when only one class
                # occurs, so the four-way unpacking below cannot fail.
                tn, fp, fn, tp = (
                    int(count)
                    for count in confusion_matrix(self.y_test, y_pred, labels=[0, 1]).ravel()
                )

                total = tp + tn + fp + fn
                accuracy = (tp + tn) / total if total > 0 else 0
                precision = tp / (tp + fp) if (tp + fp) > 0 else 0
                recall = tp / (tp + fn) if (tp + fn) > 0 else 0
                f1 = 2 * precision * recall / (precision + recall) if (precision + recall) > 0 else 0
                specificity = tn / (tn + fp) if (tn + fp) > 0 else 0
                # False alarm rate is the false-positive rate (1 - specificity).
                false_alarm = fp / (fp + tn) if (fp + tn) > 0 else 0

                results.append({
                    'Threshold': threshold,
                    'Accuracy': accuracy,
                    'Precision': precision,
                    'Recall': recall,
                    'F1-Score': f1,
                    'Specificity': specificity,
                    'False Alarm Rate': false_alarm,
                    'TP': tp,
                    'FP': fp,
                    'TN': tn,
                    'FN': fn
                })

            if not results:
                # An empty frame has no 'Threshold' column to plot against.
                print("No thresholds supplied; nothing to evaluate.")
                return None

            # Create DataFrame
            results_df = pd.DataFrame(results)

            # Plot metrics
            plt.figure(figsize=(12, 8))

            plotted_metrics = (
                ('Accuracy', 'b-'),
                ('Precision', 'g-'),
                ('Recall', 'r-'),
                ('F1-Score', 'c-'),
                ('False Alarm Rate', 'm-'),
            )
            for column, line_style in plotted_metrics:
                plt.plot(results_df['Threshold'], results_df[column], line_style,
                         linewidth=2, label=column)

            plt.grid(True, alpha=0.3)
            plt.xlabel('Threshold', fontsize=12)
            plt.ylabel('Metric Value', fontsize=12)
            plt.title('Model Metrics at Different Detection Thresholds', fontsize=14)
            plt.legend(fontsize=10)

            plt.tight_layout()
            plt.show()

            return results_df

        except Exception as e:
            print(f"Error calculating threshold metrics: {e}")
            return None
    
    def plot_roc_curve(self):
        """
        Draw the ROC curve for the test set and report its AUC.

        Column 1 of the underlying model's ``predict_proba`` output is used as
        the score. A message is printed and the method returns early when the
        model or test data are missing, or when the model cannot produce
        probabilities. Any exception is printed instead of being raised.
        """
        required = (self.hnids_model, self.X_test, self.y_test)
        if any(item is None for item in required):
            print("Model or test data not available. Please load model and data first.")
            return

        try:
            from sklearn.metrics import roc_curve, auc

            classifier = self.hnids_model.irf_model
            if not hasattr(classifier, 'predict_proba'):
                print("Model does not support probability predictions.")
                return

            # Score of the positive class for every test row
            scores = classifier.predict_proba(self.X_test)[:, 1]

            # ROC points and the area underneath them
            fpr, tpr, _ = roc_curve(self.y_test, scores)
            roc_auc = auc(fpr, tpr)

            # Curve plus the chance diagonal
            plt.figure(figsize=(10, 8))
            curve_label = f'ROC curve (AUC = {roc_auc:.4f})'
            plt.plot(fpr, tpr, 'b-', linewidth=2, label=curve_label)
            plt.plot([0, 1], [0, 1], 'k--', linewidth=1)

            # Axes cosmetics
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('False Positive Rate', fontsize=12)
            plt.ylabel('True Positive Rate', fontsize=12)
            plt.title('Receiver Operating Characteristic (ROC) Curve', fontsize=14)
            plt.legend(loc='lower right', fontsize=10)
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()

            print(f"Area Under the ROC Curve (AUC): {roc_auc:.4f}")

        except Exception as exc:
            print(f"Error plotting ROC curve: {exc}")
    
    def plot_precision_recall_curve(self):
        """
        Draw the Precision-Recall curve for the test set and report the
        average precision.

        Column 1 of the underlying model's ``predict_proba`` output is used as
        the score. A message is printed and the method returns early when the
        model or test data are missing, or when the model cannot produce
        probabilities. Any exception is printed instead of being raised.
        """
        required = (self.hnids_model, self.X_test, self.y_test)
        if any(item is None for item in required):
            print("Model or test data not available. Please load model and data first.")
            return

        try:
            from sklearn.metrics import precision_recall_curve, average_precision_score

            classifier = self.hnids_model.irf_model
            if not hasattr(classifier, 'predict_proba'):
                print("Model does not support probability predictions.")
                return

            # Score of the positive class for every test row
            scores = classifier.predict_proba(self.X_test)[:, 1]

            # Curve points and the summary statistic
            precision, recall, _ = precision_recall_curve(self.y_test, scores)
            average_precision = average_precision_score(self.y_test, scores)

            plt.figure(figsize=(10, 8))
            curve_label = f'Precision-Recall curve (AP = {average_precision:.4f})'
            plt.plot(recall, precision, 'b-', linewidth=2, label=curve_label)

            # Axes cosmetics
            plt.xlim([0.0, 1.0])
            plt.ylim([0.0, 1.05])
            plt.xlabel('Recall', fontsize=12)
            plt.ylabel('Precision', fontsize=12)
            plt.title('Precision-Recall Curve', fontsize=14)
            plt.legend(loc='lower left', fontsize=10)
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()

            print(f"Average Precision (AP): {average_precision:.4f}")

        except Exception as exc:
            print(f"Error plotting Precision-Recall curve: {exc}")
    
    def analyze_attack_types(self, attack_mapping=None):
        """
        Analyze model accuracy separately for each attack type.

        Args:
            attack_mapping: Dictionary keyed by the index labels of ``y_test``
                (one entry per test sample) whose values are attack-type
                names. Samples missing from the mapping are grouped under
                'Unknown'.

        Returns:
            DataFrame indexed by attack type with the columns 'count' (number
            of test samples) and 'accuracy' (share of correct predictions), or
            None when the model/data/mapping are unavailable, ``y_test``
            carries no pandas index, or an error occurs (the error is printed).
        """
        if self.hnids_model is None or self.X_test is None or self.y_test is None:
            print("Model or test data not available. Please load model and data first.")
            return None

        # Validate before predicting so a missing mapping costs nothing.
        # Lists and tuples expose an ``index`` *method*, so hasattr() is not
        # enough; require a real pandas index.
        sample_index = getattr(self.y_test, 'index', None)
        if attack_mapping is None or not isinstance(sample_index, pd.Index):
            print("Attack mapping not provided or index not available in test data.")
            return None

        try:
            # Get predictions
            y_pred = self.hnids_model.predict(self.X_test)

            # True label, prediction and attack type for every test sample
            results_df = pd.DataFrame({
                'true_label': self.y_test,
                'predicted': y_pred
            })
            results_df['attack_type'] = [
                attack_mapping.get(idx, 'Unknown') for idx in sample_index
            ]

            # Analyze performance by attack type (order of first appearance)
            is_correct = results_df['true_label'] == results_df['predicted']
            attack_type_results = {}
            for attack_type in results_df['attack_type'].unique():
                in_group = results_df['attack_type'] == attack_type
                attack_type_results[attack_type] = {
                    'count': int(in_group.sum()),
                    'accuracy': is_correct[in_group].mean()
                }

            attack_results_df = pd.DataFrame.from_dict(attack_type_results, orient='index')

            # Plot results
            plt.figure(figsize=(12, 6))

            ax = sns.barplot(x=attack_results_df.index, y='accuracy', data=attack_results_df)

            # Add count annotations; zip() stops at the shorter sequence, so
            # any extra patches on the axes cannot cause an IndexError.
            for bar, count in zip(ax.patches, attack_results_df['count']):
                ax.annotate(f'n={count}', (bar.get_x() + bar.get_width() / 2., bar.get_height()),
                            ha='center', va='center', fontsize=10, color='black',
                            xytext=(0, 5), textcoords='offset points')

            plt.xticks(rotation=45, ha='right')
            plt.xlabel('Attack Type', fontsize=12)
            plt.ylabel('Accuracy', fontsize=12)
            plt.title('Model Accuracy by Attack Type', fontsize=14)
            plt.grid(True, alpha=0.3)

            plt.tight_layout()
            plt.show()

            return attack_results_df

        except Exception as e:
            print(f"Error analyzing attack types: {e}")
            return None
            
    def analyze_errors(self, n_samples=5):
        """
        Analyze model errors by showing examples of false positives and false negatives.

        Prints the number of false positives (true label 0, predicted 1) and
        false negatives (true label 1, predicted 0). For up to ``n_samples``
        rows of each kind it prints the ten largest feature values and, when a
        'lime' explainer is registered in ``self.explainers``, the LIME
        contributions towards the predicted class.

        Args:
            n_samples: Number of examples to show for each error type

        Returns:
            None. Any exception is printed instead of being raised.
        """
        if self.hnids_model is None or self.X_test is None or self.y_test is None:
            print("Model or test data not available. Please load model and data first.")
            return

        def report_examples(error_rows, heading, example_name, lime_label):
            """Print up to n_samples rows of one error type, with LIME detail if available."""
            if len(error_rows) == 0:
                return

            print(heading)
            # Slicing past the end is safe; max() guards against a negative count.
            shown_indices = error_rows.index[:max(n_samples, 0)]

            for i, idx in enumerate(shown_indices):
                print(f"\n{example_name} Example {i+1} (Index: {idx}):")
                instance = self.X_test.loc[idx]

                # Ten features with the largest raw values for this row
                features_df = pd.DataFrame({'Feature': instance.index, 'Value': instance.values})
                features_df = features_df.sort_values('Value', ascending=False)
                print(features_df.head(10))

                # If LIME explainer is available, show explanation
                if 'lime' in self.explainers:
                    print("\nLIME Explanation:")
                    rf_model = self.hnids_model.irf_model

                    # LIME only explains the labels it is asked for (default
                    # is label 1), so the wanted label must be requested here
                    # or as_list(label=0) raises KeyError.
                    # NOTE(review): assumes the stored explainer follows the
                    # LimeTabularExplainer API - confirm.
                    explanation = self.explainers['lime'].explain_instance(
                        instance.values, rf_model.predict_proba,
                        labels=(lime_label,), num_features=10
                    )

                    # Get features contributing to prediction
                    lime_exp = explanation.as_list(label=lime_label)
                    lime_df = pd.DataFrame(lime_exp, columns=['Feature', 'Contribution'])
                    lime_df = lime_df.sort_values('Contribution', ascending=False)
                    print(lime_df)

        try:
            # Get predictions
            y_pred = self.hnids_model.predict(self.X_test)

            # Create DataFrame with results
            results_df = pd.DataFrame({
                'true_label': self.y_test,
                'predicted': y_pred
            })

            # Identify false positives and false negatives
            false_positives = results_df[(results_df['true_label'] == 0) & (results_df['predicted'] == 1)]
            false_negatives = results_df[(results_df['true_label'] == 1) & (results_df['predicted'] == 0)]

            # Print summary
            print(f"Total test samples: {len(results_df)}")
            print(f"False positives (Normal classified as Intrusion): {len(false_positives)}")
            print(f"False negatives (Intrusion classified as Normal): {len(false_negatives)}")

            # False positives were predicted as intrusion -> explain class 1
            report_examples(
                false_positives,
                "\n===== FALSE POSITIVES (Normal misclassified as Intrusion) =====",
                "False Positive",
                1,
            )

            # False negatives were predicted as normal -> explain class 0
            report_examples(
                false_negatives,
                "\n===== FALSE NEGATIVES (Intrusion misclassified as Normal) =====",
                "False Negative",
                0,
            )

        except Exception as e:
            print(f"Error analyzing model errors: {e}")
            
    def save_explainer(self, filepath):
        """
        Save the explainer for future use.

        Pickles ``self.explainers`` to ``filepath``, creating the parent
        directory first when the path names one. Failures are printed rather
        than raised.

        Args:
            filepath: Path to save the explainer
        """
        try:
            # os.path.dirname() returns '' for a bare filename such as
            # 'hnids_explainer.pkl', and os.makedirs('') raises
            # FileNotFoundError, so only create a directory when there is one.
            target_dir = os.path.dirname(filepath)
            if target_dir:
                os.makedirs(target_dir, exist_ok=True)

            # Save explainers
            with open(filepath, 'wb') as f:
                pickle.dump(self.explainers, f)

            print(f"Explainer saved to {filepath}")

        except Exception as e:
            print(f"Error saving explainer: {e}")
            
    def load_explainer(self, filepath):
        """
        Restore explainers that were previously written with pickle.

        On success ``self.explainers`` is replaced by the unpickled object; on
        any failure the error is printed and ``self.explainers`` is left
        untouched. Unpickling can execute arbitrary code, so only load files
        from a trusted source.

        Args:
            filepath: Path to the saved explainer
        """
        try:
            with open(filepath, 'rb') as handle:
                restored = pickle.load(handle)
                self.explainers = restored
        except Exception as exc:
            print(f"Error loading explainer: {exc}")
        else:
            print(f"Explainer loaded from {filepath}")


# Example usage
def main():
    """Demonstrate XAI for HNIDS model.

    Loads a pickled model from 'hnids_model.pkl' and test data from
    'X_test.csv' / 'y_test.csv' in the working directory, then runs the
    explainer's feature-importance, SHAP, LIME, partial-dependence,
    interaction, confusion-matrix, ROC, precision-recall, threshold and
    error analyses, and finally saves the explainer to 'hnids_explainer.pkl'.
    Returns None; if the model or data cannot be loaded, a message is printed
    and the function returns early.
    """
    # Load your trained HNIDS model
    model_path = 'hnids_model.pkl'

    # Try to load previously saved model
    try:
        # NOTE(security): unpickling can execute arbitrary code; only load
        # model files that come from a trusted source.
        with open(model_path, 'rb') as f:
            hnids_model = pickle.load(f)

        print("Model loaded successfully.")
    except FileNotFoundError:
        print("Model file not found. Please train a model first or specify the correct path.")
        return
    except Exception as e:
        print(f"Error loading model: {e}")
        return

    # Load test data (assuming you have this data saved)
    try:
        X_test = pd.read_csv('X_test.csv')
        # read_csv(..., squeeze=True) was removed in pandas 2.0; squeezing the
        # column axis afterwards turns a single-column file into a Series on
        # every pandas version.
        y_test = pd.read_csv('y_test.csv').squeeze("columns")

        print("Test data loaded successfully.")
    except FileNotFoundError:
        print("Test data files not found. Please prepare your data first.")
        return
    except Exception as e:
        print(f"Error loading test data: {e}")
        return

    # Initialize explainer
    explainer = HNIDS_Explainer(hnids_model, X_test=X_test, y_test=y_test)

    # Get feature importance
    print("\n===== Feature Importance =====")
    importance_df = explainer.get_feature_importance(n_top_features=15)
    if importance_df is not None:
        print(importance_df.head(15))

    # Initialize SHAP explainer (can take some time for large models)
    print("\n===== Initializing SHAP Explainer =====")
    explainer.initialize_shap_explainer()

    # Global SHAP explanation
    print("\n===== Global SHAP Explanation =====")
    explainer.global_shap_explanation(max_display=15)

    # Local SHAP explanation for a specific instance
    print("\n===== Local SHAP Explanation =====")
    explainer.local_shap_explanation(instance_idx=0, force_plot=True)

    # Initialize LIME explainer
    print("\n===== Initializing LIME Explainer =====")
    explainer.initialize_lime_explainer()

    # Local LIME explanation
    print("\n===== Local LIME Explanation =====")
    explainer.local_lime_explanation(instance_idx=0, num_features=10)

    # Partial dependence plots for top features
    if importance_df is not None:
        top_features = importance_df['Feature'].head(6).tolist()
        print("\n===== Partial Dependence Plots for Top Features =====")
        explainer.partial_dependence_plot(top_features)

    # Feature interaction plot for top 2 features
    if importance_df is not None and len(importance_df) >= 2:
        feature1 = importance_df['Feature'].iloc[0]
        feature2 = importance_df['Feature'].iloc[1]
        print(f"\n===== Feature Interaction Plot: {feature1} vs {feature2} =====")
        explainer.feature_interaction_plot(feature1, feature2)

    # Confusion matrix
    print("\n===== Confusion Matrix =====")
    explainer.plot_confusion_matrix(normalized=True)

    # ROC curve
    print("\n===== ROC Curve =====")
    explainer.plot_roc_curve()

    # Precision-Recall curve
    print("\n===== Precision-Recall Curve =====")
    explainer.plot_precision_recall_curve()

    # Detection threshold metrics
    print("\n===== Detection Threshold Metrics =====")
    thresholds_df = explainer.get_detection_threshold_metrics()
    if thresholds_df is not None:
        print(thresholds_df)

    # Error analysis
    print("\n===== Error Analysis =====")
    explainer.analyze_errors(n_samples=3)

    # Save explainer
    explainer.save_explainer('hnids_explainer.pkl')

    print("\nXAI analysis completed.")

# Run the demo when this code is executed directly. In a Jupyter kernel
# ``__name__`` is also "__main__", so executing this cell calls main() at once.
if __name__ == "__main__":
    main()

Model file not found. Please train a model first or specify the correct path.
