In [None]:
# Import libraries
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import joblib
import json
import warnings

# Silence library warnings to keep notebook output readable.
# NOTE(review): this hides *all* warnings, including potentially important ones;
# consider narrowing the filter by category/module.
warnings.filterwarnings('ignore')

# The 'seaborn-v0_8-*' style names only exist in newer matplotlib releases;
# older releases expose the same sheet as 'seaborn-whitegrid'. Use whichever is
# available instead of crashing the first cell (plt.style.use raises on an
# unknown name).
for _style in ('seaborn-v0_8-whitegrid', 'seaborn-whitegrid'):
    if _style in plt.style.available:
        plt.style.use(_style)
        break

# SHAP is optional: later SHAP cells are guarded by `if shap:`, so bind the name
# to None when the package is missing and let them fall back / skip.
try:
    import shap
    shap.initjs()  # load SHAP's JavaScript so interactive plots render in the notebook
    print(f"✅ SHAP version: {shap.__version__}")
except ImportError:
    print("❌ SHAP not installed. Install with: pip install shap")
    shap = None

print("Libraries loaded successfully!")

## 1. Load Model and Data

In [None]:
# Directory layout, relative to the notebook's working directory.
DATA_DIR = Path('../data')
MODELS_DIR = Path('../models')
PROCESSED_DIR = DATA_DIR.joinpath('processed')  # X_test.csv, y_test.csv, label_encoder.pkl
RESULTS_DIR = DATA_DIR.joinpath('results')      # every figure / CSV / JSON written below

# Create the output directory up front so every later savefig/to_csv succeeds.
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

print(f"Data directory: {DATA_DIR}")
print(f"Models directory: {MODELS_DIR}")

In [None]:
def create_sample_data_and_model(n_samples=5000, random_state=42):
    """
    Create a synthetic flow-feature dataset and fit a small RandomForest on it.

    Used as a demonstration fallback when the real processed data / trained
    model cannot be loaded, so the SHAP cells still have something to explain.

    Parameters
    ----------
    n_samples : int, default 5000
        Number of synthetic rows to generate.
    random_state : int, default 42
        Seed for the synthetic data and for the RandomForest. A private
        ``np.random.RandomState`` is used, so the global NumPy RNG is left
        untouched. The generated data is identical to what seeding the global
        RNG with the same value would produce.

    Returns
    -------
    X : pd.DataFrame
        ``(n_samples, 20)`` standard-normal features with network-flow column names.
    y : np.ndarray
        Integer class codes (LabelEncoder's sorted class order).
    model : RandomForestClassifier
        Model fitted on exactly ``(X, y)``. The returned rows are the training
        rows, so any "test" evaluation on them is optimistic.
    le : LabelEncoder
        Encoder mapping the integer codes back to class names.
    """
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.preprocessing import LabelEncoder

    # Private generator: avoids mutating global np.random state as a side effect.
    rng = np.random.RandomState(random_state)

    # Feature names
    feature_names = [
        'Flow Duration', 'Total Fwd Packets', 'Total Bwd Packets',
        'Fwd Packet Length Mean', 'Bwd Packet Length Mean',
        'Flow Bytes/s', 'Flow Packets/s', 'Flow IAT Mean',
        'Fwd IAT Mean', 'Bwd IAT Mean', 'SYN Flag Count',
        'ACK Flag Count', 'FIN Flag Count', 'RST Flag Count',
        'Packet Length Mean', 'Packet Length Std', 'Packet Length Var',
        'Destination Port', 'Init Win Bytes Fwd', 'Init Win Bytes Bwd'
    ]

    # Generate features (drawn first so the random stream matches the original order)
    X = pd.DataFrame(rng.randn(n_samples, len(feature_names)), columns=feature_names)

    # Latent score driven by three features plus noise, so labels are learnable.
    score = (X['SYN Flag Count'] * 0.3 +
             X['Flow Packets/s'] * 0.2 +
             X['RST Flag Count'] * 0.2 +
             rng.randn(n_samples) * 0.3)

    # Bucket the score into classes: (-inf,-0.5] BENIGN, (-0.5,0.5] PortScan,
    # (0.5,1.0] DoS, (1.0,inf) DDoS (pd.cut intervals are right-closed).
    labels = pd.cut(score, bins=[-np.inf, -0.5, 0.5, 1.0, np.inf],
                    labels=['BENIGN', 'PortScan', 'DoS', 'DDoS'])

    # Encode labels
    le = LabelEncoder()
    y = le.fit_transform(labels)

    # Train model
    model = RandomForestClassifier(n_estimators=100, max_depth=10,
                                   random_state=random_state, n_jobs=-1)
    model.fit(X, y)

    return X, y, model, le

# Try to load the real processed test split and trained model; fall back to a
# synthetic demo if any artifact is missing or unusable.
# NOTE: joblib.load unpickles arbitrary objects, so only load model files you trust.
try:
    X_test = pd.read_csv(PROCESSED_DIR / 'X_test.csv')
    y_test = pd.read_csv(PROCESSED_DIR / 'y_test.csv')['label']
    model = joblib.load(MODELS_DIR / 'xgboost_model.pkl')
    le = joblib.load(PROCESSED_DIR / 'label_encoder.pkl')
    y_test_encoded = le.transform(y_test)
    print("✅ Loaded existing model and data")
except Exception as exc:
    # Deliberate best-effort fallback, but report *why* loading failed (missing
    # file, missing 'label' column, unseen labels, uninstalled model package...)
    # and let KeyboardInterrupt/SystemExit propagate, unlike a bare `except:`.
    print(f"⚠️ Could not load saved artifacts ({type(exc).__name__}: {exc})")
    print("⚠️ Creating sample data and model for demonstration...")
    X_test, y_test_encoded, model, le = create_sample_data_and_model()
    y_test = pd.Series(le.inverse_transform(y_test_encoded))

# Class names, indexed by the model's integer predictions.
classes = le.classes_
print(f"\nData shape: {X_test.shape}")
print(f"Classes: {classes}")

## 2. SHAP Explainer Setup

In [None]:
if shap:
    # Explain a prefix subset of the test set to keep SHAP tractable.
    n_explain = min(1000, len(X_test))
    X_explain = X_test.iloc[:n_explain]

    print(f"Computing SHAP values for {n_explain} samples...")
    print("This may take a few minutes...")

    # TreeExplainer is fast for tree ensembles; fall back to the model-agnostic
    # (and much slower) KernelExplainer for anything it does not support.
    # NOTE(review): KernelExplainer on up to 1000 rows can take a very long time.
    try:
        explainer = shap.TreeExplainer(model)
        print("Using TreeExplainer")
    except Exception as exc:
        print(f"TreeExplainer unavailable ({type(exc).__name__}: {exc})")
        background = shap.sample(X_test, 100)
        explainer = shap.KernelExplainer(model.predict_proba, background)
        print("Using KernelExplainer")

    # Calculate SHAP values
    shap_values = explainer.shap_values(X_explain)

    # Recent SHAP releases may return multi-class attributions as a single
    # (n_samples, n_features, n_classes) array instead of a list of per-class
    # (n_samples, n_features) arrays. The later cells detect multi-class output
    # via `isinstance(shap_values, list)`, so normalise to the list form here.
    if isinstance(shap_values, np.ndarray) and shap_values.ndim == 3:
        shap_values = [shap_values[:, :, k] for k in range(shap_values.shape[2])]

    print(f"\n✅ SHAP values computed!")
    print(f"Shape: {np.array(shap_values).shape}")
else:
    print("SHAP not available - skipping SHAP analysis")
    shap_values = None

## 3. Global Feature Importance

In [None]:
if shap and shap_values is not None:
    # Bar summary: mean |SHAP| per feature over the explained samples.
    plt.figure(figsize=(12, 8))

    if isinstance(shap_values, list):
        # Multi-class: average |SHAP| across classes first, leaving one
        # (n_samples, n_features) matrix, so a single bar is drawn per feature.
        mean_shap = np.abs(np.array(shap_values)).mean(axis=0)
        shap.summary_plot(mean_shap, X_explain, plot_type='bar', show=False)
    else:
        shap.summary_plot(shap_values, X_explain, plot_type='bar', show=False)

    plt.title('Global Feature Importance (SHAP)', fontsize=14)
    plt.tight_layout()
    plt.savefig(str(RESULTS_DIR / 'shap_global_importance.png'), dpi=150, bbox_inches='tight')
    plt.show()
else:
    # Fallback without SHAP: use the model's own built-in importances, if any.
    print("\nUsing model's built-in feature importance:")
    if hasattr(model, 'feature_importances_'):
        importance = pd.DataFrame({
            'feature': X_test.columns,
            'importance': model.feature_importances_
        }).sort_values('importance', ascending=True)

        fig, ax = plt.subplots(figsize=(10, 8))
        ax.barh(importance['feature'], importance['importance'], color='steelblue')
        ax.set_xlabel('Importance')
        ax.set_title('Feature Importance')
        fig.tight_layout()
        fig.savefig(str(RESULTS_DIR / 'feature_importance_fallback.png'), dpi=150)
        plt.show()
    else:
        # Previously this branch printed the header above and then silently did nothing.
        print(f"   {type(model).__name__} exposes no feature_importances_; nothing to plot.")

In [None]:
if shap and shap_values is not None:
    # Beeswarm summary: one dot per (sample, feature), coloured by feature value,
    # showing the direction of each feature's impact as well as its magnitude.
    plt.figure(figsize=(12, 10))

    title = 'SHAP Values (Feature Impact on Predictions)'
    if isinstance(shap_values, list):
        # A beeswarm shows a single output. For multi-class SHAP values, plot one
        # class and name it in the title so the figure is not read as covering
        # every class.
        beeswarm_class_idx = 0
        shap.summary_plot(shap_values[beeswarm_class_idx], X_explain, show=False)
        title = (f'SHAP Values for class "{classes[beeswarm_class_idx]}" '
                 f'(Feature Impact on Predictions)')
    else:
        shap.summary_plot(shap_values, X_explain, show=False)

    plt.title(title, fontsize=14)
    plt.tight_layout()
    plt.savefig(str(RESULTS_DIR / 'shap_beeswarm.png'), dpi=150, bbox_inches='tight')
    plt.show()

## 4. Per-Class Feature Importance

In [None]:
if shap and shap_values is not None and isinstance(shap_values, list):
    # One panel per class: top-15 features ranked by mean |SHAP| for that class,
    # laid out on a grid two panels wide.
    n_classes = len(classes)
    n_cols = min(2, n_classes)
    n_rows = (n_classes + 1) // 2

    # squeeze=False always yields a 2-D array of Axes, so the single-panel case
    # needs no special handling.
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(14, 5*n_rows), squeeze=False)
    axes = axes.ravel()

    for panel_idx, (class_name, class_shap) in enumerate(zip(classes, shap_values)):
        if panel_idx >= len(axes):
            break
        top_features_df = (
            pd.DataFrame({
                'feature': X_explain.columns,
                'importance': np.abs(class_shap).mean(axis=0),
            })
            .sort_values('importance', ascending=True)
            .tail(15)
        )
        ax = axes[panel_idx]
        ax.barh(top_features_df['feature'],
                top_features_df['importance'],
                color=plt.cm.Set2(panel_idx / n_classes))
        ax.set_xlabel('Mean |SHAP|')
        ax.set_title(f'Top Features for: {class_name}')

    # Blank out grid cells left over when the class count is odd.
    for spare_ax in axes[n_classes:]:
        spare_ax.set_visible(False)

    plt.tight_layout()
    plt.savefig(str(RESULTS_DIR / 'shap_per_class_importance.png'), dpi=150, bbox_inches='tight')
    plt.show()
else:
    print("Per-class SHAP analysis requires multi-class SHAP values")

## 5. Individual Prediction Explanations

In [None]:
def _select_row_shap(shap_values, idx, class_idx):
    """Return the 1-D SHAP vector of row ``idx`` toward class ``class_idx``, or None."""
    if shap_values is None:
        return None
    if isinstance(shap_values, list):
        # Per-class list of (n_samples, n_features) arrays.
        return np.asarray(shap_values[class_idx])[idx]
    values = np.asarray(shap_values)
    if values.ndim == 3:
        # (n_samples, n_features, n_classes) layout.
        return values[idx, :, class_idx]
    # Single-output (n_samples, n_features) array.
    return values[idx]


def explain_prediction(idx, X, shap_values, model, classes, explainer, top_n=5):
    """
    Print a readable explanation of one prediction and return its raw pieces.

    Shows the predicted class, the class probabilities (when the model has
    ``predict_proba``) and the ``top_n`` features with the largest |SHAP|
    contribution toward the predicted class for this row.

    Parameters
    ----------
    idx : int
        Positional row index into ``X`` and into the SHAP values.
    X : pd.DataFrame
        Feature matrix the SHAP values were computed for.
    shap_values : list of arrays, 2-D array, 3-D array or None
        Per-class list of ``(n_samples, n_features)`` arrays, a single
        ``(n_samples, n_features)`` array, or a ``(n_samples, n_features,
        n_classes)`` array. ``None`` skips the contribution section.
    model : estimator
        Object with ``predict`` (and optionally ``predict_proba``) returning
        integer class codes that index ``classes``.
    classes : sequence
        Class names indexed by the model's integer predictions.
    explainer : object
        Unused; kept so existing callers keep working.
    top_n : int, default 5
        Number of top SHAP contributors to print.

    Returns
    -------
    tuple
        ``(sample, prediction, proba)``: the one-row DataFrame, the predicted
        class code, and the probability vector (or ``None``).
    """
    sample = X.iloc[idx:idx+1]
    prediction = model.predict(sample)[0]

    # Get probabilities if available
    proba = model.predict_proba(sample)[0] if hasattr(model, 'predict_proba') else None

    print("=" * 60)
    print(f"PREDICTION EXPLANATION - Sample #{idx}")
    print("=" * 60)
    print(f"\n🎯 Predicted Class: {classes[prediction]}")

    if proba is not None:
        print(f"\n📊 Class Probabilities:")
        for cls, prob in zip(classes, proba):
            bar = '█' * int(prob * 30)
            print(f"   {cls:15s}: {prob:.3f} {bar}")

    # The SHAP values were previously accepted but never used; show what drove
    # this particular prediction.
    row_shap = _select_row_shap(shap_values, idx, prediction)
    if row_shap is not None and top_n > 0:
        print(f"\n🔍 Top {top_n} SHAP contributions toward {classes[prediction]}:")
        for feat_pos in np.argsort(np.abs(row_shap))[::-1][:top_n]:
            contribution = float(row_shap[feat_pos])
            print(f"   {X.columns[feat_pos]:30s} value={X.iloc[idx, feat_pos]:.3f}  "
                  f"SHAP={contribution:+.4f}")

    return sample, prediction, proba

# Explain a few predictions: the first explained row predicted as each of the
# first three classes (classes with no such row are skipped).
if shap and shap_values is not None:
    predictions = model.predict(X_explain)

    print("\n" + "=" * 60)
    print("SAMPLE PREDICTION EXPLANATIONS")
    print("=" * 60)

    for class_idx in range(min(3, len(classes))):  # limit to the first 3 classes
        matching_rows = np.flatnonzero(predictions == class_idx)
        if matching_rows.size == 0:
            continue
        explain_prediction(matching_rows[0], X_explain, shap_values, model, classes, explainer)
        print()

In [None]:
if shap and shap_values is not None:
    # Waterfall: step-by-step breakdown of one prediction, from the explainer's
    # expected value to the model output, one feature at a time.
    sample_idx = 0  # row of X_explain to break down

    print(f"\nWaterfall plot for sample #{sample_idx}:")

    plt.figure(figsize=(12, 8))

    expected = explainer.expected_value
    if not isinstance(shap_values, list):
        # Single-output SHAP values: one row, one base value.
        sample_shap = shap_values[sample_idx]
        base_value = expected
    else:
        # Multi-class: explain the class the model actually predicted for this row.
        pred_class = model.predict(X_explain.iloc[sample_idx:sample_idx+1])[0]
        sample_shap = shap_values[pred_class][sample_idx]
        if hasattr(expected, '__len__'):
            base_value = expected[pred_class]
        else:
            base_value = expected

    # Wrap everything in an Explanation object, as shap.plots.waterfall expects.
    row = X_explain.iloc[sample_idx]
    explanation = shap.Explanation(
        values=sample_shap,
        base_values=base_value,
        data=row.values,
        feature_names=X_explain.columns.tolist()
    )

    shap.plots.waterfall(explanation, show=False)
    plt.title(f'SHAP Waterfall - Sample #{sample_idx}')
    plt.tight_layout()
    plt.savefig(str(RESULTS_DIR / 'shap_waterfall.png'), dpi=150, bbox_inches='tight')
    plt.show()

## 6. Feature Interaction Analysis

In [None]:
if shap and shap_values is not None:
    # Dependence plots for the four features with the largest mean |SHAP|.
    print("Feature Dependence Analysis:")

    multi_class = isinstance(shap_values, list)
    # Mean |SHAP| per feature; for multi-class output also average over classes.
    if multi_class:
        mean_shap = np.abs(np.array(shap_values)).mean(axis=(0, 1))
    else:
        mean_shap = np.abs(shap_values).mean(axis=0)

    ranked = np.argsort(mean_shap)[::-1]       # most important first
    top_features = X_explain.columns[ranked[:4]]

    fig, axes = plt.subplots(2, 2, figsize=(14, 10))
    axes = axes.flatten()

    for ax, feature in zip(axes, top_features):
        plt.sca(ax)

        if multi_class:
            # Multi-class: plot the attributions of the first class only.
            shap.dependence_plot(
                X_explain.columns.get_loc(feature), shap_values[0], X_explain,
                ax=ax, show=False
            )
        else:
            shap.dependence_plot(
                feature, shap_values, X_explain,
                ax=ax, show=False
            )

        ax.set_title(f'Dependence: {feature}')

    plt.tight_layout()
    plt.savefig(str(RESULTS_DIR / 'shap_dependence.png'), dpi=150, bbox_inches='tight')
    plt.show()

## 7. Attack Type Explanations

In [None]:
def explain_attack_type(attack_type, X, shap_values, model, classes, top_n=10):
    """
    Summarise which features drive the model's predictions for one class.

    Restricts to the rows of ``X`` the model predicts as ``attack_type``, takes
    the SHAP values for that class, and prints the ``top_n`` features by mean
    |SHAP| together with their average value over those rows.

    Parameters
    ----------
    attack_type : str
        Class name to explain; must be an element of ``classes``.
    X : pd.DataFrame
        Feature matrix the SHAP values were computed for.
    shap_values : list of arrays, 2-D array or 3-D array
        Per-class list of ``(n_samples, n_features)`` arrays, a single
        ``(n_samples, n_features)`` array, or ``(n_samples, n_features, n_classes)``.
    model : estimator
        Object whose ``predict`` returns integer class codes indexing ``classes``.
    classes : sequence
        Class names (NumPy array or plain list) indexed by class code.
    top_n : int, default 10
        Number of features to report (clamped to the number of features).

    Returns
    -------
    tuple or None
        ``(mean_shap, top_idx)``: mean |SHAP| per feature for this class, and
        feature positions sorted by it (descending). ``None`` when no row is
        predicted as ``attack_type``.

    Raises
    ------
    ValueError
        If ``attack_type`` is not one of ``classes``.
    """
    # list(...) works for both ndarray and list inputs; `classes == x` on a
    # plain list is just False, which made the old lookup raise IndexError.
    class_list = list(classes)
    if attack_type not in class_list:
        raise ValueError(f"Unknown class {attack_type!r}; expected one of {class_list}")
    attack_idx = class_list.index(attack_type)

    predictions = np.asarray(model.predict(X))
    attack_samples = predictions == attack_idx

    if not attack_samples.any():
        print(f"No samples predicted as {attack_type}")
        return None

    print(f"\n{'='*60}")
    print(f"ATTACK TYPE EXPLANATION: {attack_type}")
    print(f"{'='*60}")
    print(f"\n📊 Samples predicted as {attack_type}: {attack_samples.sum()}")

    # SHAP values of the selected rows, for this class's output only.
    if isinstance(shap_values, list):
        attack_shap = np.asarray(shap_values[attack_idx])[attack_samples]
    else:
        attack_shap = np.asarray(shap_values)[attack_samples]
        if attack_shap.ndim == 3:
            # (n_samples, n_features, n_classes) layout: keep this class only.
            attack_shap = attack_shap[:, :, attack_idx]

    # Mean absolute SHAP values for this attack
    mean_shap = np.abs(attack_shap).mean(axis=0)

    # Clamp, and slice from the front: `[-0:]` would have selected every feature.
    top_n = max(0, min(top_n, len(mean_shap)))
    top_idx = np.argsort(mean_shap)[::-1][:top_n]

    attack_rows = X.iloc[attack_samples]  # computed once, reused for every feature

    print(f"\n🔍 Top {top_n} Contributing Features:")
    print("-" * 50)
    for rank, idx in enumerate(top_idx, 1):
        feature = X.columns[idx]
        importance = mean_shap[idx]
        avg_value = attack_rows[feature].mean()
        print(f"   {rank:2d}. {feature:30s} | Importance: {importance:.4f} | Avg Value: {avg_value:.2f}")

    return mean_shap, top_idx

# Explain each non-BENIGN class the model can predict.
if shap and shap_values is not None:
    for attack in classes:
        if attack == 'BENIGN':
            continue
        try:
            explain_attack_type(attack, X_explain, shap_values, model, classes)
        except Exception as exc:
            # Carry on with the remaining classes, but report why this one failed
            # (the previous bare `except:` hid the cause and caught Ctrl-C too).
            print(f"Could not explain {attack}: {type(exc).__name__}: {exc}")

## 8. Generate SOC-Ready Explanations

In [None]:
def generate_soc_explanation(sample, shap_vals, features, prediction, proba, classes, top_n=5):
    """
    Build a human-readable explanation of one alert for SOC analysts.

    Parameters
    ----------
    sample : pd.DataFrame
        One-row frame holding the alert's feature values, columns aligned with ``features``.
    shap_vals : array-like
        1-D SHAP vector for ``sample`` toward the predicted class, one entry per feature.
    features : sequence of str
        Feature names, positionally aligned with ``shap_vals`` and ``sample``'s columns.
    prediction : int
        Predicted class code; indexes ``classes`` and ``proba``.
    proba : array-like or None
        Class-probability vector, or ``None`` if the model gives no probabilities.
    classes : sequence
        Class names indexed by class code.
    top_n : int, default 5
        Number of strongest SHAP contributors to report as indicators.

    Returns
    -------
    dict
        ``verdict`` (class name), ``confidence`` (float or None), ``risk_level``
        ('HIGH' if confidence > 0.8, 'MEDIUM' if > 0.5, 'LOW' otherwise, and
        'UNKNOWN' when no probabilities are available), ``top_indicators``
        (list of dicts with feature/value/contribution/direction) and
        ``narrative`` (one-paragraph summary).

    Raises
    ------
    ValueError
        If ``shap_vals`` is not one-dimensional.
    """
    shap_vals = np.asarray(shap_vals)
    if shap_vals.ndim != 1:
        raise ValueError(f"shap_vals must be 1-D (one value per feature), got shape {shap_vals.shape}")

    confidence = float(proba[prediction]) if proba is not None else None

    # Without probabilities the risk cannot be graded; reporting LOW would be misleading.
    if confidence is None:
        risk_level = 'UNKNOWN'
    elif confidence > 0.8:
        risk_level = 'HIGH'
    elif confidence > 0.5:
        risk_level = 'MEDIUM'
    else:
        risk_level = 'LOW'

    explanation = {
        'verdict': classes[prediction],
        'confidence': confidence,
        'risk_level': risk_level,
        'top_indicators': [],
        'narrative': ''
    }

    # Strongest contributors by |SHAP|, largest first.
    for idx in np.argsort(np.abs(shap_vals))[::-1][:top_n]:
        contribution = float(shap_vals[idx])
        explanation['top_indicators'].append({
            'feature': features[idx],
            'value': float(sample.iloc[0, idx]),
            'contribution': contribution,
            # Positive SHAP pushes toward the predicted class.
            'direction': 'increased' if contribution > 0 else 'decreased'
        })

    if explanation['verdict'] != 'BENIGN':
        narrative_parts = [
            f"{ind['feature']} ({ind['value']:.2f}) {ind['direction']} suspicion"
            for ind in explanation['top_indicators'][:3]
        ]
        # Previously `None * 100` raised TypeError when proba was None.
        confidence_text = (f"{confidence*100:.1f}% confidence"
                           if confidence is not None else "unknown confidence")
        explanation['narrative'] = (
            f"This traffic was classified as {explanation['verdict']} with "
            f"{confidence_text}. "
            f"Key indicators: {'; '.join(narrative_parts)}."
        )
    else:
        explanation['narrative'] = "Traffic appears normal with no significant anomalies detected."

    return explanation

# Generate SOC-style explanations for the first few alerts (non-BENIGN predictions).
if shap and shap_values is not None:
    print("\n" + "=" * 70)
    print("SOC-READY ALERT EXPLANATIONS")
    print("=" * 70)

    predictions = model.predict(X_explain)

    # Alerts are rows not predicted as BENIGN. If there is no BENIGN class every
    # row is an alert. The old one-liner parsed as `(preds != code) if ... else -1`
    # and ended up selecting index 0 only in that case.
    if 'BENIGN' in classes:
        benign_code = list(classes).index('BENIGN')
        is_alert = predictions != benign_code
    else:
        is_alert = np.ones(len(predictions), dtype=bool)
    attack_indices = np.flatnonzero(is_alert)[:3]

    for idx in attack_indices:
        sample = X_explain.iloc[idx:idx+1]
        pred = predictions[idx]
        proba = model.predict_proba(sample)[0] if hasattr(model, 'predict_proba') else None

        if isinstance(shap_values, list):
            sample_shap = shap_values[pred][idx]
        else:
            sample_shap = shap_values[idx]
            if np.ndim(sample_shap) == 2:
                # Row of a (n_samples, n_features, n_classes) array: keep the predicted class.
                sample_shap = sample_shap[:, pred]

        explanation = generate_soc_explanation(
            sample, sample_shap, X_explain.columns.tolist(),
            pred, proba, classes
        )

        print(f"\n🚨 ALERT #{idx}")
        print(f"   Verdict: {explanation['verdict']}")
        print(f"   Risk Level: {explanation['risk_level']}")
        # `is not None` so a genuine 0.0 confidence is not shown as N/A.
        if explanation['confidence'] is not None:
            print(f"   Confidence: {explanation['confidence']*100:.1f}%")
        else:
            print("   Confidence: N/A")
        print(f"\n   📝 Narrative:")
        print(f"   {explanation['narrative']}")
        print(f"\n   🔍 Key Indicators:")
        for ind in explanation['top_indicators']:
            print(f"      - {ind['feature']}: {ind['value']:.3f} ({ind['direction']} risk by {abs(ind['contribution']):.3f})")

## 9. Save Explainability Results

In [None]:
# Save feature importance rankings
if shap and shap_values is not None:
    # Global ranking: mean |SHAP| per feature (also averaged over classes when multi-class).
    if isinstance(shap_values, list):
        mean_shap = np.abs(np.array(shap_values)).mean(axis=(0, 1))
    else:
        mean_shap = np.abs(shap_values).mean(axis=0)

    importance_df = pd.DataFrame({
        'feature': X_explain.columns,
        'mean_abs_shap': mean_shap
    }).sort_values('mean_abs_shap', ascending=False)

    importance_path = RESULTS_DIR / 'shap_feature_importance.csv'
    importance_df.to_csv(importance_path, index=False)
    print(f"\n✅ Feature importance saved to: {importance_path}")

    # Save explainability report. `.tolist()` turns NumPy scalars (e.g. integer
    # class labels) into native Python values; json.dump cannot serialise np.int64.
    explainability_report = {
        'analysis_date': pd.Timestamp.now().isoformat(),
        'n_samples_analyzed': len(X_explain),
        'n_features': len(X_explain.columns),
        'classes': np.asarray(classes).tolist(),
        'top_10_features': importance_df.head(10).to_dict(orient='records')
    }

    report_path = RESULTS_DIR / 'explainability_report.json'
    with open(report_path, 'w', encoding='utf-8') as f:
        json.dump(explainability_report, f, indent=2)

    print(f"✅ Explainability report saved to: {report_path}")

In [None]:
# Summary
print("\n" + "=" * 60)
print("EXPLAINABILITY ANALYSIS SUMMARY")
print("=" * 60)

if shap and shap_values is not None:
    print(f"\n📊 Samples analyzed: {len(X_explain)}")
    print(f"📊 Features analyzed: {len(X_explain.columns)}")
    print(f"\n🏆 Top 5 Most Important Features (SHAP):")
    for row in importance_df.head(5).itertuples(index=False):
        print(f"   {row.feature}: {row.mean_abs_shap:.4f}")

    # List only artifacts that actually exist. For example, the per-class figure
    # is skipped for single-output SHAP values. NOTE: files left over from an
    # earlier run also count as existing.
    expected_outputs = [
        'shap_global_importance.png',
        'shap_beeswarm.png',
        'shap_per_class_importance.png',
        'shap_waterfall.png',
        'shap_dependence.png',
        'shap_feature_importance.csv',
        'explainability_report.json',
    ]
    print(f"\n📁 Results saved to: {RESULTS_DIR}")
    for name in expected_outputs:
        if (RESULTS_DIR / name).exists():
            print(f"   - {name}")
else:
    print("\n⚠️ SHAP analysis was not performed.")
    print("Install SHAP with: pip install shap")

## Key Insights for SOC Analysts

### How to Use SHAP Explanations:

1. **Global Importance**: Shows which features are most influential across all predictions
2. **Beeswarm Plot**: Shows how high or low feature values push predictions up or down (for multi-class models, one class at a time)
3. **Waterfall Plot**: Explains individual alerts step-by-step
4. **Dependence Plots**: Show how a feature's value relates to its SHAP contribution and reveal interactions with other features

### Actionable Recommendations:

- Focus monitoring on top features identified by SHAP
- Set thresholds based on feature values that increase attack probability
- Use individual explanations to validate/triage alerts
- Correlate SHAP insights with known attack signatures: SHAP explains the model's behavior, not ground truth, so confirm indicators before acting on them