In [2]:
# Step 4: Hybrid Explanation Generator
# Combines XAI (what) + Causal (why/how) for actionable NIDS alert explanations

import os
import warnings
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import joblib
from typing import Dict, List, Tuple, Optional
import json

warnings.filterwarnings('ignore')

print("="*70)
print("STEP 4: HYBRID EXPLANATION GENERATOR")
print("="*70)

# ==================== CONFIGURATION ====================
LSTM_MODEL_PATH = '../step1_lstm_xai/best_lstm.pt'
SCALER_PATH = '../step1_lstm_xai/scaler.joblib'
CAUSAL_GRAPH_PATH = '../step2_causal_discovery/causal_graph.gpickle'
DATA_PATH = '../step2_causal_discovery/causal_discovery_data.csv'

# Feature names - MUST match the 42 features used in Step 1 LSTM training
# These are ALL features from the original dataset (excluding dropped columns)
FEATURE_NAMES = [
    'SignatureID', 'SignatureMatchesPerDay', 'AlertCount', 'Proto',
    'ExtPort', 'IntPort', 'Similarity', 'SCAS', 'AppProtoSimilarity',
    'DnsRrnameSimilarity', 'DnsRrtypeSimilarity', 'EmailFromSimilarity',
    'EmailStatusSimilarity', 'EmailToSimilarity', 'ExtIPSimilarity',
    'ExtPortSimilarity', 'HttpContentTypeSimilarity', 'HttpHostnameSimilarity',
    'HttpMethodSimilarity', 'HttpProtocolSimilarity', 'HttpRequestBodySimilarity',
    'HttpResponseBodySimilarity', 'HttpStatusSimilarity', 'HttpUrlSimilarity',
    'HttpUserAgentSimilarity', 'IntIPSimilarity', 'IntPortSimilarity',
    'ProtoSimilarity', 'SignatureIDSimilarity', 'SmtpHeloSimilarity',
    'SmtpMailFromSimilarity', 'SmtpRcptToSimilarity', 'SshClientProtoSimilarity',
    'SshClientSoftwareSimilarity', 'SshServerProtoSimilarity',
    'SshServerSoftwareSimilarity', 'TlsFingerprintSimilarity',
    'TlsIssuerDnSimilarity', 'TlsJa3hashSimilarity', 'TlsSniSimilarity',
    'TlsSubjectSimilarity', 'TlsVersionSimilarity'
]

# Subset used in causal discovery (Step 2) - for causal analysis only
CAUSAL_FEATURES = [
    'SignatureMatchesPerDay', 'Similarity', 'SCAS', 'SignatureID',
    'SignatureIDSimilarity', 'Proto', 'AlertCount', 'IntPort', 
    'ExtPort', 'ProtoSimilarity'
]

# ==================== LOAD MODELS AND DATA ====================
print("\nLoading models and data...")

# Load LSTM model architecture (must match Step 1)
class LSTMClassifier(nn.Module):
    def __init__(self, input_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=dropout if num_layers > 1 else 0.0
        )
        self.fc = nn.Sequential(
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2)
        )
    
    def forward(self, x):
        out, _ = self.lstm(x)
        out = out[:, -1, :]
        return self.fc(out)

# Load trained model
device = torch.device('cpu')
input_size = len(FEATURE_NAMES)
model = LSTMClassifier(input_size=input_size).to(device)

if Path(LSTM_MODEL_PATH).exists():
    model.load_state_dict(torch.load(LSTM_MODEL_PATH, map_location=device))
    model.eval()
    print(f"✓ Loaded LSTM model from {LSTM_MODEL_PATH}")
else:
    print(f"⚠ Warning: {LSTM_MODEL_PATH} not found. Using untrained model.")

# Load scaler
if Path(SCALER_PATH).exists():
    scaler = joblib.load(SCALER_PATH)
    print(f"✓ Loaded scaler from {SCALER_PATH}")
else:
    print(f"⚠ Warning: {SCALER_PATH} not found. Scaling may be incorrect.")
    scaler = None

# Load causal graph
if Path(CAUSAL_GRAPH_PATH).exists():
    causal_graph = nx.read_gpickle(CAUSAL_GRAPH_PATH)
    print(f"✓ Loaded causal graph from {CAUSAL_GRAPH_PATH}")
    print(f"  Nodes: {causal_graph.number_of_nodes()}, Edges: {causal_graph.number_of_edges()}")
else:
    print(f"⚠ Warning: {CAUSAL_GRAPH_PATH} not found. Creating empty graph.")
    causal_graph = nx.DiGraph()

# Load sample data for testing
if Path(DATA_PATH).exists():
    df = pd.read_csv(DATA_PATH)
    print(f"✓ Loaded data from {DATA_PATH}: {df.shape}")
    
    # Check if we need to load the full original dataset instead
    if len(df.columns) < len(FEATURE_NAMES):
        print(f"  ⚠ Causal discovery data has only {len(df.columns)} columns")
        print(f"  ⚠ LSTM needs {len(FEATURE_NAMES)} features")
        print(f"  Loading original dataset instead...")
        
        # Try to load original dataset
        original_data_path = 'dataset-labeled-anon-ip.csv'
        if Path(original_data_path).exists():
            df_full = pd.read_csv(original_data_path)
            print(f"  ✓ Loaded full dataset: {df_full.shape}")
            
            # Drop non-feature columns
            drop_cols = ['SignatureText', 'Timestamp', 'ExtIP', 'IntIP']
            for col in drop_cols:
                if col in df_full.columns:
                    df_full = df_full.drop(columns=[col])
            
            # Use full dataset
            df = df_full
            print(f"  ✓ Using full dataset with {len(df.columns)} features")
        else:
            print(f"  ✗ Could not find {original_data_path}")
            print(f"  Will create synthetic examples instead")
            df = None
else:
    print(f"⚠ Warning: {DATA_PATH} not found.")
    df = None

# ==================== XAI COMPONENT ====================
print("\n" + "="*70)
print("XAI COMPONENT: Feature Importance")
print("="*70)

def compute_deeplift_attribution(model, alert_tensor, baseline=None):
    """
    Compute DeepLIFT attributions for a single alert
    
    Args:
        model: LSTM model
        alert_tensor: torch.Tensor of shape (1, 1, num_features)
        baseline: Baseline for comparison (default: zeros)
    
    Returns:
        numpy array of feature attributions
    """
    try:
        from captum.attr import DeepLift
        
        if baseline is None:
            baseline = torch.zeros_like(alert_tensor)
        
        dl = DeepLift(model)
        
        # Get model prediction
        with torch.no_grad():
            output = model(alert_tensor)
            pred_class = output.argmax(dim=1).item()
        
        # Compute attributions for predicted class
        attributions = dl.attribute(alert_tensor, baselines=baseline, target=pred_class)
        
        return attributions.squeeze().detach().cpu().numpy()
    
    except ImportError:
        print("⚠ Captum not installed. Using gradient-based approximation.")
        return compute_gradient_attribution(model, alert_tensor)

def compute_gradient_attribution(model, alert_tensor):
    """
    Fallback: Simple gradient-based attribution
    """
    alert_tensor.requires_grad = True
    output = model(alert_tensor)
    pred_class = output.argmax(dim=1).item()
    
    # Compute gradient
    model.zero_grad()
    output[0, pred_class].backward()
    
    gradients = alert_tensor.grad.squeeze().detach().cpu().numpy()
    values = alert_tensor.squeeze().detach().cpu().numpy()
    
    # Attribution = gradient * input
    attributions = gradients * values
    
    return attributions

def generate_xai_explanation(model, alert_features, feature_names, top_k=5):
    """
    Generate XAI explanation for an alert
    
    Args:
        model: LSTM model
        alert_features: numpy array of feature values
        feature_names: list of feature names
        top_k: number of top features to return
    
    Returns:
        Dictionary with XAI results
    """
    # Prepare input tensor
    alert_tensor = torch.tensor(alert_features, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)
    
    # Get prediction
    model.eval()
    with torch.no_grad():
        output = model(alert_tensor)
        probs = torch.softmax(output, dim=1)[0]
        pred_class = output.argmax(dim=1).item()
        confidence = probs[pred_class].item()
    
    # Compute attributions
    attributions = compute_deeplift_attribution(model, alert_tensor)
    
    # Combine features with their attributions
    feature_importance = []
    for i, (name, attr, value) in enumerate(zip(feature_names, attributions, alert_features)):
        feature_importance.append({
            'feature': name,
            'importance': float(attr),
            'value': float(value),
            'abs_importance': float(abs(attr))
        })
    
    # Sort by absolute importance
    feature_importance.sort(key=lambda x: x['abs_importance'], reverse=True)
    
    return {
        'prediction': 'Important' if pred_class == 1 else 'Irrelevant',
        'confidence': confidence,
        'pred_class': pred_class,
        'top_features': feature_importance[:top_k],
        'all_features': feature_importance
    }

# ==================== CAUSAL COMPONENT ====================
print("\n" + "="*70)
print("CAUSAL COMPONENT: Root Cause Analysis")
print("="*70)

def find_root_causes(graph, target_feature):
    """Find all root causes (ancestors with no incoming edges) of target"""
    if target_feature not in graph:
        return []
    
    ancestors = nx.ancestors(graph, target_feature)
    root_causes = [node for node in ancestors if graph.in_degree(node) == 0]
    
    return root_causes

def find_causal_path(graph, source, target):
    """Find shortest causal path from source to target"""
    if source not in graph or target not in graph:
        return None
    
    try:
        path = nx.shortest_path(graph, source, target)
        return path
    except nx.NetworkXNoPath:
        return None

def get_direct_causes(graph, feature):
    """Get direct causes (parents) of a feature"""
    if feature not in graph:
        return []
    return list(graph.predecessors(feature))

def analyze_causal_chain(graph, target_feature, alert_data, all_feature_names):
    """
    Analyze causal chains leading to target feature
    
    Args:
        graph: NetworkX causal graph (contains CAUSAL_FEATURES only)
        target_feature: Feature to analyze (e.g., 'SCAS')
        alert_data: Dictionary of ALL feature values (42 features)
        all_feature_names: List of all feature names (42 features)
    
    Returns:
        Dictionary with causal analysis
    """
    # Only analyze if feature is in the causal graph
    if target_feature not in graph:
        return {
            'target': target_feature,
            'in_graph': False,
            'root_causes': [],
            'causal_paths': [],
            'direct_causes': []
        }
    
    # Find root causes
    root_causes = find_root_causes(graph, target_feature)
    
    # Find causal paths from each root cause
    causal_paths = []
    for root in root_causes:
        path = find_causal_path(graph, root, target_feature)
        if path:
            # Add feature values to path
            path_with_values = []
            for feature in path:
                value = alert_data.get(feature, 'N/A')
                path_with_values.append({
                    'feature': feature,
                    'value': value
                })
            
            causal_paths.append({
                'root': root,
                'path': path,
                'path_with_values': path_with_values,
                'length': len(path)
            })
    
    # Get direct causes
    direct_causes = get_direct_causes(graph, target_feature)
    direct_causes_with_values = []
    for cause in direct_causes:
        direct_causes_with_values.append({
            'feature': cause,
            'value': alert_data.get(cause, 'N/A')
        })
    
    return {
        'target': target_feature,
        'in_graph': True,
        'root_causes': root_causes,
        'causal_paths': causal_paths,
        'direct_causes': direct_causes_with_values,
        'num_paths': len(causal_paths)
    }

# ==================== HYBRID EXPLAINER ====================
print("\n" + "="*70)
print("HYBRID EXPLAINER: Combining XAI + Causal")
print("="*70)

class HybridExplainer:
    """
    Combines XAI and Causal Analysis for comprehensive NIDS alert explanations
    """
    
    def __init__(self, model, causal_graph, feature_names, scaler=None):
        self.model = model
        self.graph = causal_graph
        self.feature_names = feature_names
        self.scaler = scaler
    
    def explain(self, alert_data, alert_id=None):
        """
        Generate hybrid explanation for a single alert
        
        Args:
            alert_data: numpy array (42 features) or dict of feature values
            alert_id: Optional alert identifier
        
        Returns:
            HybridExplanation object
        """
        # Convert dict to array if needed
        if isinstance(alert_data, dict):
            # Ensure all 42 features are present
            alert_features = np.array([alert_data.get(f, 0) for f in self.feature_names])
            alert_dict = alert_data
        else:
            alert_features = alert_data
            # Create dict with all features
            alert_dict = {f: v for f, v in zip(self.feature_names, alert_features)}
        
        # Step 1: Get LSTM prediction and XAI explanation (uses all 42 features)
        xai_results = generate_xai_explanation(
            self.model, 
            alert_features, 
            self.feature_names,
            top_k=5
        )
        
        # Step 2: Analyze causal chains for top XAI features
        # Only analyze features that exist in causal graph
        causal_analyses = []
        for feat_info in xai_results['top_features']:
            feature_name = feat_info['feature']
            
            # Only do causal analysis if feature is in causal graph
            if feature_name in self.graph:
                causal_analysis = analyze_causal_chain(
                    self.graph,
                    feature_name,
                    alert_dict,
                    self.feature_names
                )
                causal_analyses.append(causal_analysis)
        
        # Step 3: Analyze causal chain to Label (outcome)
        label_causal = None
        if 'Label' in self.graph:
            label_causal = analyze_causal_chain(
                self.graph,
                'Label',
                alert_dict,
                self.feature_names
            )
        
        # Step 4: Generate recommendations
        recommendations = self._generate_recommendations(
            xai_results,
            causal_analyses,
            alert_dict
        )
        
        # Create explanation object
        explanation = HybridExplanation(
            alert_id=alert_id,
            alert_data=alert_dict,
            xai_results=xai_results,
            causal_analyses=causal_analyses,
            label_causal=label_causal,
            recommendations=recommendations
        )
        
        return explanation
    
    def _generate_recommendations(self, xai_results, causal_analyses, alert_data):
        """Generate actionable recommendations based on XAI and causal analysis"""
        recommendations = {
            'severity': 'MEDIUM',
            'immediate_actions': [],
            'investigation_steps': [],
            'root_cause_mitigation': []
        }
        
        # Determine severity
        confidence = xai_results['confidence']
        if confidence > 0.9:
            recommendations['severity'] = 'HIGH'
        elif confidence > 0.7:
            recommendations['severity'] = 'MEDIUM'
        else:
            recommendations['severity'] = 'LOW'
        
        # Analyze top features for specific actions
        top_features = xai_results['top_features']
        
        for feat in top_features:
            feature = feat['feature']
            value = feat['value']
            
            # SSH port targeted
            if feature == 'IntPort' and value == 22:
                recommendations['immediate_actions'].append(
                    "SSH port (22) targeted - Enable SSH hardening (key-only auth, fail2ban)"
                )
            
            # Outlier detection
            if feature == 'SCAS' and value == 1:
                recommendations['immediate_actions'].append(
                    "Outlier detected (SCAS=1) - Novel attack pattern, requires manual investigation"
                )
                recommendations['investigation_steps'].append(
                    "Compare with historical alerts - this pattern hasn't been seen before"
                )
            
            # High signature matching
            if feature == 'SignatureMatchesPerDay' and value > 50000:
                recommendations['immediate_actions'].append(
                    "Extremely high signature match frequency - Potential coordinated campaign"
                )
                recommendations['investigation_steps'].append(
                    f"Search for other hosts with SignatureMatchesPerDay > 50K in last 24h"
                )
                recommendations['root_cause_mitigation'].append(
                    f"Review signature rules - {int(value)} matches/day suggests tuning needed"
                )
            
            # High similarity to known attacks
            if feature == 'SignatureIDSimilarity' and value > 0.9:
                recommendations['investigation_steps'].append(
                    f"Pattern matches known attacks (similarity={value:.2f}) - Check threat intelligence"
                )
            
            # Alert flood
            if feature == 'AlertCount' and value > 100:
                recommendations['immediate_actions'].append(
                    f"High volume attack ({int(value)} alerts) - Consider rate limiting/blocking"
                )
        
        # Add causal-based recommendations
        for causal in causal_analyses:
            if causal['root_causes']:
                root_cause_str = ', '.join(causal['root_causes'])
                recommendations['root_cause_mitigation'].append(
                    f"Root causes of {causal['target']}: {root_cause_str}"
                )
        
        # Default action if nothing specific
        if not recommendations['immediate_actions']:
            recommendations['immediate_actions'].append(
                "Review alert details and correlate with other security events"
            )
        
        return recommendations

class HybridExplanation:
    """
    Container for hybrid explanation results
    """
    
    def __init__(self, alert_id, alert_data, xai_results, causal_analyses, 
                 label_causal, recommendations):
        self.alert_id = alert_id
        self.alert_data = alert_data
        self.xai = xai_results
        self.causal = causal_analyses
        self.label_causal = label_causal
        self.recommendations = recommendations
    
    def to_dict(self):
        """Convert to dictionary"""
        return {
            'alert_id': self.alert_id,
            'alert_data': self.alert_data,
            'xai_analysis': self.xai,
            'causal_analysis': self.causal,
            'label_causal': self.label_causal,
            'recommendations': self.recommendations
        }
    
    def to_json(self, filepath=None):
        """Export to JSON"""
        data = self.to_dict()
        if filepath:
            with open(filepath, 'w') as f:
                json.dump(data, f, indent=2)
            return filepath
        else:
            return json.dumps(data, indent=2)
    
    def to_text(self):
        """Generate natural language explanation"""
        lines = []
        lines.append("="*70)
        lines.append(f"HYBRID EXPLANATION - Alert #{self.alert_id or 'Unknown'}")
        lines.append("="*70)
        lines.append("")
        
        # Classification
        lines.append("🎯 CLASSIFICATION")
        lines.append(f"  Prediction: {self.xai['prediction']}")
        lines.append(f"  Confidence: {self.xai['confidence']:.1%}")
        lines.append(f"  Severity: {self.recommendations['severity']}")
        lines.append("")
        
        # XAI Analysis
        lines.append("📊 XAI ANALYSIS: What triggered this alert?")
        lines.append("")
        for i, feat in enumerate(self.xai['top_features'], 1):
            lines.append(f"  {i}. {feat['feature']}")
            lines.append(f"     Value: {feat['value']:.4f}")
            lines.append(f"     Importance: {feat['importance']:.4f}")
            lines.append("")
        
        # Causal Analysis
        lines.append("🔍 CAUSAL ANALYSIS: Why/How did this happen?")
        lines.append("")
        
        for causal in self.causal:
            if not causal['in_graph']:
                continue
            
            lines.append(f"  Feature: {causal['target']}")
            
            if causal['root_causes']:
                lines.append(f"  Root Causes: {', '.join(causal['root_causes'])}")
            
            if causal['causal_paths']:
                lines.append(f"  Causal Chains:")
                for path_info in causal['causal_paths'][:2]:  # Show top 2 paths
                    path_str = ' → '.join(path_info['path'])
                    lines.append(f"    • {path_str}")
            
            lines.append("")
        
        # Label causal analysis
        if self.label_causal and self.label_causal['in_graph']:
            lines.append("  Direct Causes of Alert Classification:")
            for cause in self.label_causal['direct_causes']:
                lines.append(f"    • {cause['feature']} = {cause['value']:.4f}")
            lines.append("")
        
        # Recommendations
        lines.append("✅ RECOMMENDED ACTIONS")
        lines.append("")
        
        if self.recommendations['immediate_actions']:
            lines.append("  Immediate Actions:")
            for action in self.recommendations['immediate_actions']:
                lines.append(f"    • {action}")
            lines.append("")
        
        if self.recommendations['investigation_steps']:
            lines.append("  Investigation Steps:")
            for step in self.recommendations['investigation_steps']:
                lines.append(f"    • {step}")
            lines.append("")
        
        if self.recommendations['root_cause_mitigation']:
            lines.append("  Root Cause Mitigation:")
            for mitigation in self.recommendations['root_cause_mitigation']:
                lines.append(f"    • {mitigation}")
            lines.append("")
        
        lines.append("="*70)
        
        return '\n'.join(lines)
    
    def visualize(self, save_path=None):
        """Create visual explanation"""
        fig = plt.figure(figsize=(16, 10))
        gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)
        
        # Title
        fig.suptitle(f'Hybrid Explanation - Alert #{self.alert_id or "Unknown"}', 
                     fontsize=16, fontweight='bold')
        
        # 1. XAI Feature Importance (top subplot)
        ax1 = fig.add_subplot(gs[0, :])
        top_feats = self.xai['top_features']
        features = [f['feature'] for f in top_feats]
        importances = [f['importance'] for f in top_feats]
        colors = ['red' if imp > 0 else 'blue' for imp in importances]
        
        ax1.barh(features, importances, color=colors, alpha=0.7)
        ax1.set_xlabel('Importance Score')
        ax1.set_title('XAI Analysis: Top Feature Importance', fontweight='bold')
        ax1.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        ax1.grid(axis='x', alpha=0.3)
        
        # 2. Prediction info (middle left)
        ax2 = fig.add_subplot(gs[1, 0])
        ax2.axis('off')
        
        pred_text = f"""
PREDICTION
{'='*30}
Class: {self.xai['prediction']}
Confidence: {self.xai['confidence']:.1%}
Severity: {self.recommendations['severity']}
"""
        ax2.text(0.1, 0.5, pred_text, fontsize=11, family='monospace',
                verticalalignment='center')
        
        # 3. Causal paths (middle right)
        ax3 = fig.add_subplot(gs[1, 1])
        ax3.axis('off')
        
        causal_text = "CAUSAL CHAINS\n" + "="*30 + "\n"
        for causal in self.causal[:3]:
            if causal['causal_paths']:
                for path_info in causal['causal_paths'][:1]:
                    path_str = ' → '.join(path_info['path'][:4])
                    if len(path_info['path']) > 4:
                        path_str += ' ...'
                    causal_text += f"• {path_str}\n"
        
        ax3.text(0.1, 0.5, causal_text, fontsize=10, family='monospace',
                verticalalignment='center')
        
        # 4. Recommendations (bottom)
        ax4 = fig.add_subplot(gs[2, :])
        ax4.axis('off')
        
        rec_text = "RECOMMENDED ACTIONS\n" + "="*50 + "\n"
        if self.recommendations['immediate_actions']:
            rec_text += "\nImmediate:\n"
            for action in self.recommendations['immediate_actions'][:3]:
                rec_text += f"  • {action[:60]}...\n" if len(action) > 60 else f"  • {action}\n"
        
        ax4.text(0.05, 0.5, rec_text, fontsize=10, family='monospace',
                verticalalignment='center')
        
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"Saved visualization to: {save_path}")
        
        return fig

# ==================== DEMO EXAMPLES ====================
print("\n" + "="*70)
print("GENERATING DEMO EXPLANATIONS")
print("="*70)

# Create explainer instance
explainer = HybridExplainer(
    model=model,
    causal_graph=causal_graph,
    feature_names=FEATURE_NAMES,
    scaler=scaler
)

# Example 1: Use data from file if available
if df is not None and len(df) > 0:
    print("\nGenerating explanations for sample alerts...")
    
    # Ensure we have all required features
    available_features = [f for f in FEATURE_NAMES if f in df.columns]
    print(f"Available features: {len(available_features)}/{len(FEATURE_NAMES)}")
    
    if len(available_features) < len(FEATURE_NAMES):
        print("⚠ Warning: Not all features available. Creating synthetic example...")
        df = None
    else:
        # Select a few interesting samples
        sample_indices = [0, 100, 500]
        
        for idx in sample_indices:
            if idx >= len(df):
                continue
            
            print(f"\n{'='*70}")
            print(f"EXAMPLE ALERT #{idx}")
            print(f"{'='*70}")
            
            # Get alert data (all 42 features)
            alert_row = df.iloc[idx]
            alert_features = alert_row[FEATURE_NAMES].values
            
            # Generate explanation
            explanation = explainer.explain(alert_features, alert_id=idx)
            
            # Print text explanation
            print(explanation.to_text())
            
            # Save JSON
            json_path = f'hybrid_explanation_{idx}.json'
            explanation.to_json(json_path)
            print(f"\nSaved JSON to: {json_path}")
            
            # Save visualization
            viz_path = f'hybrid_explanation_{idx}.png'
            explanation.visualize(save_path=viz_path)
            plt.close()

if df is None:
    print("\n⚠ No data available for demo. Creating synthetic example...")
    
    # Create synthetic alert with ALL 42 features
    synthetic_alert = {f: 0.0 for f in FEATURE_NAMES}
    
    # Set specific values for key features (SSH brute force scenario)
    synthetic_alert.update({
        'SignatureMatchesPerDay': 5000,
        'Similarity': 0.88,
        'SCAS': 1.0,
        'SignatureID': 2001219,
        'SignatureIDSimilarity': 0.92,
        'Proto': 6,
        'AlertCount': 250,
        'IntPort': 22,
        'ExtPort': 54321,
        'ProtoSimilarity': 1.0
    })
    
    explanation = explainer.explain(synthetic_alert, alert_id="DEMO-001")
    print(explanation.to_text())
    explanation.to_json('hybrid_explanation_demo.json')
    explanation.visualize(save_path='hybrid_explanation_demo.png')
    plt.close()

print("\n" + "="*70)
print("STEP 4 COMPLETE!")
print("="*70)
print("\nGenerated files:")
print("  - hybrid_explanation_*.json (structured data)")
print("  - hybrid_explanation_*.png (visualizations)")
print("\nNext: Step 5 - Evaluation (user study + quantitative metrics)")

STEP 4: HYBRID EXPLANATION GENERATOR

Loading models and data...
✓ Loaded LSTM model from ../step1_lstm_xai/best_lstm.pt
✓ Loaded scaler from ../step1_lstm_xai/scaler.joblib
✓ Loaded causal graph from ../step2_causal_discovery/causal_graph.gpickle
  Nodes: 11, Edges: 20
✓ Loaded data from ../step2_causal_discovery/causal_discovery_data.csv: (10000, 11)
  ⚠ Causal discovery data has only 11 columns
  ⚠ LSTM needs 42 features
  Loading original dataset instead...
  ✗ Could not find dataset-labeled-anon-ip.csv
  Will create synthetic examples instead

XAI COMPONENT: Feature Importance

CAUSAL COMPONENT: Root Cause Analysis

HYBRID EXPLAINER: Combining XAI + Causal

GENERATING DEMO EXPLANATIONS

⚠ No data available for demo. Creating synthetic example...
HYBRID EXPLANATION - Alert #DEMO-001

🎯 CLASSIFICATION
  Prediction: Irrelevant
  Confidence: 100.0%
  Severity: HIGH

📊 XAI ANALYSIS: What triggered this alert?

  1. SignatureID
     Value: 2001219.0000
     Importance: 0.0000

  2. Sign