In [7]:
# Step 4: Hybrid Explanation Generator - FIXED VERSION
# Combines XAI (what) + Causal (why/how) for actionable NIDS alert explanations
# FIXES:
# 1. Feature filtering to exclude missing/NA values (-1.0)
# 2. Improved severity assessment logic for imbalanced dataset
# 3. Better handling of Irrelevant predictions

import os
import warnings
import numpy as np
import pandas as pd
import networkx as nx
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from pathlib import Path
import joblib
from typing import Dict, List, Tuple, Optional
import json

# NOTE(review): this silences *all* warnings process-wide (including
# deprecation warnings from torch/networkx/sklearn); consider narrowing it.
warnings.filterwarnings('ignore')

print("="*70)
print("STEP 4: HYBRID EXPLANATION GENERATOR (FIXED)")
print("="*70)

# ==================== CONFIGURATION ====================

# Artifacts produced by Step 1 (LSTM + XAI) and Step 2 (causal discovery)
LSTM_MODEL_PATH = '../step1_lstm_xai/best_lstm.pt'
SCALER_PATH = '../step1_lstm_xai/scaler.joblib'
CAUSAL_GRAPH_PATH = '../step2_causal_discovery/causal_graph.gpickle'
DATA_PATH = '../step2_causal_discovery/causal_discovery_data.csv'

# Feature names - MUST match the 42 features used in Step 1 LSTM training
# (order matters: it defines the model's input layout)
FEATURE_NAMES = [
    'SignatureID', 'SignatureMatchesPerDay', 'AlertCount', 'Proto',
    'ExtPort', 'IntPort', 'Similarity', 'SCAS', 'AppProtoSimilarity',
    'DnsRrnameSimilarity', 'DnsRrtypeSimilarity', 'EmailFromSimilarity',
    'EmailStatusSimilarity', 'EmailToSimilarity', 'ExtIPSimilarity',
    'ExtPortSimilarity', 'HttpContentTypeSimilarity', 'HttpHostnameSimilarity',
    'HttpMethodSimilarity', 'HttpProtocolSimilarity', 'HttpRequestBodySimilarity',
    'HttpResponseBodySimilarity', 'HttpStatusSimilarity', 'HttpUrlSimilarity',
    'HttpUserAgentSimilarity', 'IntIPSimilarity', 'IntPortSimilarity',
    'ProtoSimilarity', 'SignatureIDSimilarity', 'SmtpHeloSimilarity',
    'SmtpMailFromSimilarity', 'SmtpRcptToSimilarity', 'SshClientProtoSimilarity',
    'SshClientSoftwareSimilarity', 'SshServerProtoSimilarity',
    'SshServerSoftwareSimilarity', 'TlsFingerprintSimilarity',
    'TlsIssuerDnSimilarity', 'TlsJa3hashSimilarity', 'TlsSniSimilarity',
    'TlsSubjectSimilarity', 'TlsVersionSimilarity'
]

# Subset used in causal discovery (Step 2) - SOC analyst-determined important features
CAUSAL_FEATURES = [
    'SignatureMatchesPerDay', 'Similarity', 'SCAS', 'SignatureID',
    'SignatureIDSimilarity', 'Proto', 'AlertCount', 'IntPort',
    'ExtPort', 'ProtoSimilarity'
]

# Sentinel used in the dataset for missing / not-applicable feature values
MISSING_VALUE_INDICATOR = -1.0

# Fail fast on configuration drift: feature names are used as dictionary keys
# when building explanations (so they must be unique), and the causal subset
# must be drawn from the model's feature list.
if len(set(FEATURE_NAMES)) != len(FEATURE_NAMES):
    raise ValueError("FEATURE_NAMES contains duplicate entries")
_unknown_causal = sorted(set(CAUSAL_FEATURES) - set(FEATURE_NAMES))
if _unknown_causal:
    raise ValueError(f"CAUSAL_FEATURES not found in FEATURE_NAMES: {_unknown_causal}")
del _unknown_causal

# ==================== LOAD MODELS AND DATA ====================
print("\nLoading models and data...")

# LSTM classifier architecture. Submodule names (`lstm`, `fc`) and the layer
# order inside `fc` determine the state_dict keys, so they must stay in sync
# with the Step 1 training definition for load_state_dict() to succeed.
class LSTMClassifier(nn.Module):
    """Stacked LSTM followed by a small MLP head producing 2 logits.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the LSTM hidden state.
        num_layers: number of stacked LSTM layers.
        dropout: dropout between LSTM layers (only applied when
            num_layers > 1) and inside the classification head.
    """

    def __init__(self, input_size, hidden_size=128, num_layers=2, dropout=0.2):
        super().__init__()
        # nn.LSTM only applies inter-layer dropout to stacks of >1 layer.
        inter_layer_dropout = dropout if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size=input_size,
            hidden_size=hidden_size,
            num_layers=num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        head_layers = [
            nn.Linear(hidden_size, 64),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(64, 2),
        ]
        self.fc = nn.Sequential(*head_layers)

    def forward(self, x):
        """Map x of shape (batch, seq_len, input_size) to logits (batch, 2).

        Only the LSTM output at the final time step is classified.
        """
        sequence_out, _hidden = self.lstm(x)
        final_step = sequence_out[:, -1, :]
        return self.fc(final_step)

# Build the model on CPU; its input width must equal the number of training features.
device = torch.device('cpu')
input_size = len(FEATURE_NAMES)
model = LSTMClassifier(input_size=input_size).to(device)

if Path(LSTM_MODEL_PATH).exists():
    # NOTE: torch.load deserialises via pickle - only load trusted checkpoints.
    model.load_state_dict(torch.load(LSTM_MODEL_PATH, map_location=device))
    print(f"✓ Loaded LSTM model from {LSTM_MODEL_PATH}")
else:
    print(f"⚠ Warning: {LSTM_MODEL_PATH} not found. Using untrained model.")
# Always switch to inference mode (disables dropout), including for the
# untrained fallback model, which previously stayed in training mode.
model.eval()

# Load the scaler fitted in Step 1.
# NOTE: joblib.load unpickles the file - only load trusted artifacts.
if Path(SCALER_PATH).exists():
    scaler = joblib.load(SCALER_PATH)
    print(f"✓ Loaded scaler from {SCALER_PATH}")
else:
    print(f"⚠ Warning: {SCALER_PATH} not found. Scaling may be incorrect.")
    scaler = None

# Load the causal graph produced by Step 2.
# nx.read_gpickle was removed in NetworkX 3.0. Per the NetworkX migration
# notes, a gpickle file is an ordinary pickle, so read it with the pickle
# module; this works on NetworkX 2.x and 3.x.
if Path(CAUSAL_GRAPH_PATH).exists():
    import pickle

    # NOTE: unpickling can execute arbitrary code - only load trusted files.
    with open(CAUSAL_GRAPH_PATH, 'rb') as graph_file:
        causal_graph = pickle.load(graph_file)
    print(f"✓ Loaded causal graph from {CAUSAL_GRAPH_PATH}")
    print(f"  Nodes: {causal_graph.number_of_nodes()}, Edges: {causal_graph.number_of_edges()}")
    print("  Note: Causal graph limited to 10 SOC analyst-identified features")
else:
    print(f"⚠ Warning: {CAUSAL_GRAPH_PATH} not found. Creating empty graph.")
    causal_graph = nx.DiGraph()

# Load sample data for testing. The Step 2 CSV may only carry the causal
# feature subset; when it has fewer columns than the LSTM needs, fall back to
# the full labelled dataset.
if Path(DATA_PATH).exists():
    df = pd.read_csv(DATA_PATH)
    print(f"✓ Loaded data from {DATA_PATH}: {df.shape}")

    if len(df.columns) < len(FEATURE_NAMES):
        print(f"  ⚠ Causal discovery data has only {len(df.columns)} columns")
        print(f"  ⚠ LSTM needs {len(FEATURE_NAMES)} features")
        print("  Loading original dataset instead...")

        original_data_path = '../dataset-labeled-anon-ip.csv'
        if Path(original_data_path).exists():
            df_full = pd.read_csv(original_data_path)
            print(f"  ✓ Loaded full dataset: {df_full.shape}")

            # Drop non-feature columns; columns that are absent are skipped.
            drop_cols = ['SignatureText', 'Timestamp', 'ExtIP', 'IntIP']
            df_full = df_full.drop(columns=drop_cols, errors='ignore')

            df = df_full
            print(f"  ✓ Using full dataset with {len(df.columns)} features")

            # Surface schema mismatches early instead of failing at inference time.
            absent_features = [name for name in FEATURE_NAMES if name not in df.columns]
            if absent_features:
                print(f"  ⚠ Dataset lacks {len(absent_features)} model features: {absent_features}")
        else:
            print(f"  ✗ Could not find {original_data_path}")
            df = None
else:
    print(f"⚠ Warning: {DATA_PATH} not found.")
    df = None

# ==================== XAI COMPONENT ====================
print(f"\n{'=' * 70}")
print("XAI COMPONENT: Feature Importance (with Missing Value Filtering)")
print('=' * 70)

def compute_deeplift_attribution(model, alert_tensor, baseline=None):
    """
    Attribute the model's predicted-class score to each input feature using
    Captum's DeepLIFT.

    If Captum is not installed, or DeepLIFT raises, this falls back to
    ``compute_gradient_attribution`` (gradient x input). The fallback does not
    use *baseline*.

    Args:
        model: LSTM model
        alert_tensor: torch.Tensor of shape (1, 1, num_features) - MUST BE SCALED
        baseline: reference input for DeepLIFT (default: all zeros)

    Returns:
        numpy array of feature attributions (squeezed)
    """
    try:
        from captum.attr import DeepLift

        reference = torch.zeros_like(alert_tensor) if baseline is None else baseline
        explainer = DeepLift(model)

        # Explain whichever class the model currently predicts.
        with torch.no_grad():
            target_class = model(alert_tensor).argmax(dim=1).item()

        attr = explainer.attribute(alert_tensor, baselines=reference, target=target_class)
        return attr.squeeze().detach().cpu().numpy()

    except ImportError:
        print("⚠ Captum not installed. Using gradient-based approximation.")
        return compute_gradient_attribution(model, alert_tensor)
    except Exception as e:
        print(f"⚠ DeepLIFT failed: {e}. Using gradient-based approximation.")
        return compute_gradient_attribution(model, alert_tensor)

def compute_gradient_attribution(model, alert_tensor):
    """
    Fallback: simple gradient x input attribution for the predicted class.

    Works on a detached copy of ``alert_tensor``. The caller's tensor is
    therefore never mutated, and gradients do not accumulate across repeated
    calls. Previously ``requires_grad`` was set on the caller's tensor and its
    ``.grad`` kept growing.

    Args:
        model: classifier returning logits of shape (batch, num_classes);
            assumes batch size 1 (the predicted class is read via ``.item()``)
        alert_tensor: input tensor to explain

    Returns:
        numpy array of per-feature attributions (squeezed)
    """
    # Fresh leaf tensor: works for non-leaf inputs too, and starts with no .grad.
    inputs = alert_tensor.detach().clone().requires_grad_(True)

    # enable_grad so this also works when called inside a torch.no_grad() block.
    with torch.enable_grad():
        output = model(inputs)
        pred_class = output.argmax(dim=1).item()

        # Compute gradient of the predicted-class logit w.r.t. the input
        model.zero_grad()
        output[0, pred_class].backward()

    gradients = inputs.grad.squeeze().detach().cpu().numpy()
    values = inputs.squeeze().detach().cpu().numpy()

    # Attribution = gradient * input
    return gradients * values

def is_missing_value(value, threshold=MISSING_VALUE_INDICATOR):
    """
    Check if a feature value represents missing/NA data.

    In this dataset the sentinel ``threshold`` (MISSING_VALUE_INDICATOR, -1.0)
    means 'not applicable' or 'missing'. A NaN value (e.g. an empty CSV cell
    read by pandas) carries no information either, so it also counts as
    missing.

    Args:
        value: numeric feature value (Python or numpy scalar)
        threshold: sentinel value that marks a missing feature

    Returns:
        bool: True if the value is NaN or equal to the sentinel (within 1e-6)
    """
    if np.isnan(value):
        return True
    # Tolerance comparison guards against float round-off around the sentinel.
    return bool(abs(value - threshold) < 1e-6)

def generate_xai_explanation(model, alert_features, feature_names, top_k=5, scaler=None):
    """
    Generate XAI explanation for an alert.

    FIXED: missing values (-1.0) are filtered out of the top features. When
    fewer than ``top_k`` non-missing features exist, the list is padded with
    the highest-ranked missing features *after* every present feature. A
    present feature is therefore never displaced by a missing one; previously
    the unfiltered top-k was used in that case.

    Args:
        model: LSTM model
        alert_features: feature values (UNSCALED from dataset) ordered like
            ``feature_names``; numpy array, list or pandas Series
        feature_names: list of feature names
        top_k: number of top features to return
        scaler: MinMaxScaler used during training (REQUIRED for correct results!)

    Returns:
        Dictionary with keys 'prediction', 'confidence', 'pred_class',
        'top_features', 'all_features' (sorted by |attribution|),
        'num_missing_features', 'num_present_features'
    """
    # Normalise to a float ndarray so .reshape() and zip() work for lists/Series.
    alert_features = np.asarray(alert_features, dtype=np.float64)

    # Scale the features before passing to model
    if scaler is not None:
        alert_features_scaled = scaler.transform(alert_features.reshape(1, -1)).flatten()
        print(f"  [DEBUG] Scaled features range: [{alert_features_scaled.min():.4f}, {alert_features_scaled.max():.4f}]")
    else:
        print("  ⚠ WARNING: No scaler provided. Results may be incorrect!")
        alert_features_scaled = alert_features

    # Model input of shape (batch=1, seq_len=1, num_features), built from SCALED values
    alert_tensor = torch.tensor(alert_features_scaled, dtype=torch.float32).unsqueeze(0).unsqueeze(0).to(device)

    # Get prediction
    model.eval()
    with torch.no_grad():
        output = model(alert_tensor)
        probs = torch.softmax(output, dim=1)[0]
        pred_class = output.argmax(dim=1).item()
        confidence = probs[pred_class].item()

    print(f"  [DEBUG] Prediction: {pred_class}, Confidence: {confidence:.4f}")

    # Compute attributions
    attributions = compute_deeplift_attribution(model, alert_tensor)

    print(f"  [DEBUG] Attribution range: [{attributions.min():.4f}, {attributions.max():.4f}]")
    print(f"  [DEBUG] Non-zero attributions: {np.count_nonzero(attributions)}/{len(attributions)}")

    # Pair each feature with its attribution; 'value' is the ORIGINAL unscaled
    # value so analysts can interpret it.
    feature_importance = [
        {
            'feature': name,
            'importance': float(attr),
            'value': float(value),
            'abs_importance': float(abs(attr)),
            'is_missing': is_missing_value(value),
        }
        for name, attr, value in zip(feature_names, attributions, alert_features)
    ]

    # Sort by absolute importance (stable, so ties keep feature order)
    feature_importance.sort(key=lambda feat: feat['abs_importance'], reverse=True)

    # FIX 1: Split present vs. missing features (both stay importance-ordered)
    feature_importance_present = [feat for feat in feature_importance if not feat['is_missing']]
    feature_importance_missing = [feat for feat in feature_importance if feat['is_missing']]

    missing_count = len(feature_importance_missing)
    present_count = len(feature_importance_present)

    print(f"  [DEBUG] Features: {present_count} present, {missing_count} missing (filtered)")

    if present_count < top_k:
        print(f"  ⚠ Warning: Only {present_count} non-missing features available")
        # Pad with missing features only after all present ones
        top_features = (feature_importance_present + feature_importance_missing)[:top_k]
    else:
        top_features = feature_importance_present[:top_k]

    return {
        'prediction': 'Important' if pred_class == 1 else 'Irrelevant',
        'confidence': confidence,
        'pred_class': pred_class,
        'top_features': top_features,
        'all_features': feature_importance,
        'num_missing_features': missing_count,
        'num_present_features': present_count
    }

# ==================== CAUSAL COMPONENT ====================
print(f"\n{'=' * 70}")
print("CAUSAL COMPONENT: Root Cause Analysis")
print('=' * 70)

def find_root_causes(graph, target_feature):
    """
    Find all root causes of *target_feature*: its ancestors with no incoming edges.

    The result is sorted, so explanations are reproducible across runs.
    nx.ancestors returns a set, and iteration order for string nodes varies
    with hash randomisation.

    Returns:
        Sorted list of root-cause nodes ([] if target_feature is not in graph)
    """
    if target_feature not in graph:
        return []

    return sorted(
        (node for node in nx.ancestors(graph, target_feature) if graph.in_degree(node) == 0),
        key=str,  # key=str keeps sorting safe if node labels are mixed types
    )

def find_causal_path(graph, source, target):
    """
    Return the shortest causal path from *source* to *target* as a list of
    nodes, or None if either node is absent or no path connects them.
    """
    if source in graph and target in graph:
        try:
            return nx.shortest_path(graph, source, target)
        except nx.NetworkXNoPath:
            pass
    return None

def get_direct_causes(graph, feature):
    """Return the direct causes (predecessors) of *feature*, or [] if it is not a node."""
    return list(graph.predecessors(feature)) if feature in graph else []

def analyze_causal_chain(graph, target_feature, alert_data, all_feature_names):
    """
    Analyze causal chains leading to target feature.

    Args:
        graph: NetworkX causal graph (contains CAUSAL_FEATURES only)
        target_feature: Feature to analyze (e.g., 'SCAS')
        alert_data: Dictionary of ALL feature values (42 features)
        all_feature_names: List of all feature names (42 features); currently
            unused, kept for API compatibility

    Returns:
        Dictionary with keys 'target', 'in_graph', 'root_causes',
        'causal_paths', 'direct_causes' and 'num_paths'. When the feature is
        not in the graph, it also carries a 'reason' and all collections are
        empty.
    """
    # Only analyze if feature is in the causal graph
    if target_feature not in graph:
        return {
            'target': target_feature,
            'in_graph': False,
            'root_causes': [],
            'causal_paths': [],
            'direct_causes': [],
            'reason': 'Feature not in SOC analyst-determined causal graph',
            # Present in both branches so consumers can rely on the key.
            'num_paths': 0,
        }

    def with_values(features):
        """Pair each feature with its value from the alert ('N/A' if absent)."""
        return [{'feature': feature, 'value': alert_data.get(feature, 'N/A')} for feature in features]

    root_causes = find_root_causes(graph, target_feature)

    # Shortest causal path from each root cause to the target
    causal_paths = []
    for root in root_causes:
        path = find_causal_path(graph, root, target_feature)
        if path:
            causal_paths.append({
                'root': root,
                'path': path,
                'path_with_values': with_values(path),
                'length': len(path),
            })

    return {
        'target': target_feature,
        'in_graph': True,
        'root_causes': root_causes,
        'causal_paths': causal_paths,
        'direct_causes': with_values(get_direct_causes(graph, target_feature)),
        'num_paths': len(causal_paths),
    }

# ==================== HYBRID EXPLAINER ====================
print(f"\n{'=' * 70}")
print("HYBRID EXPLAINER: Combining XAI + Causal")
print('=' * 70)

class HybridExplainer:
    """
    Combines XAI and Causal Analysis for comprehensive NIDS alert explanations.

    Args:
        model: LSTM classifier passed to ``generate_xai_explanation``
        causal_graph: NetworkX graph over the causal feature subset
        feature_names: ordered feature names matching the model input
        scaler: scaler fitted during training (optional; used to scale inputs)
    """

    def __init__(self, model, causal_graph, feature_names, scaler=None):
        self.model = model
        self.graph = causal_graph
        self.feature_names = feature_names
        self.scaler = scaler

    def explain(self, alert_data, alert_id=None):
        """
        Generate hybrid explanation for a single alert.

        Args:
            alert_data: dict of feature values (absent keys default to 0), or
                an array-like (numpy array, list, pandas Series) ordered like
                ``feature_names``
            alert_id: Optional alert identifier

        Returns:
            HybridExplanation object
        """
        if isinstance(alert_data, dict):
            alert_features = np.array([alert_data.get(f, 0) for f in self.feature_names])
            alert_dict = alert_data
        else:
            # np.asarray returns numpy arrays unchanged; for lists / pandas
            # Series it yields the ndarray generate_xai_explanation needs.
            alert_features = np.asarray(alert_data)
            alert_dict = dict(zip(self.feature_names, alert_features))

        # Step 1: LSTM prediction and XAI explanation (uses all model features)
        xai_results = generate_xai_explanation(
            self.model,
            alert_features,
            self.feature_names,
            top_k=5,
            scaler=self.scaler,
        )

        # Step 2: causal chains for the top XAI features. analyze_causal_chain
        # itself returns an 'in_graph': False record (with a 'reason') for
        # features outside the causal graph.
        causal_analyses = [
            analyze_causal_chain(self.graph, feat_info['feature'], alert_dict, self.feature_names)
            for feat_info in xai_results['top_features']
        ]

        # Step 3: causal chain to the outcome node, if the graph contains it
        label_causal = None
        if 'Label' in self.graph:
            label_causal = analyze_causal_chain(self.graph, 'Label', alert_dict, self.feature_names)

        # Step 4: actionable recommendations
        recommendations = self._generate_recommendations(xai_results, causal_analyses, alert_dict)

        return HybridExplanation(
            alert_id=alert_id,
            alert_data=alert_dict,
            xai_results=xai_results,
            causal_analyses=causal_analyses,
            label_causal=label_causal,
            recommendations=recommendations,
        )

    def _generate_recommendations(self, xai_results, causal_analyses, alert_data):
        """
        Generate actionable recommendations based on XAI and causal analysis.

        FIX 2: severity thresholds account for the imbalanced dataset.
        In addition:
          * list entries are de-duplicated, since several top features can
            share the same root causes;
          * the "safe to dismiss" hint is omitted once severity has been
            upgraded (e.g. outlier detected), because that upgrade demands
            investigation.

        Returns:
            dict with keys 'severity', 'immediate_actions',
            'investigation_steps', 'root_cause_mitigation', 'severity_reasoning'
        """
        confidence = xai_results['confidence']
        prediction = xai_results['prediction']

        severity, reasoning = self._assess_severity(prediction, confidence)
        recommendations = {
            'severity': severity,
            'immediate_actions': [],
            'investigation_steps': [],
            'root_cause_mitigation': [],
            'severity_reasoning': reasoning,
        }

        self._add_feature_recommendations(recommendations, xai_results['top_features'], alert_data)
        self._add_causal_recommendations(recommendations, causal_analyses)

        # Default action if nothing specific found
        if not recommendations['immediate_actions'] and prediction == 'Important':
            recommendations['immediate_actions'].append(
                "Review alert details and correlate with other security events in SIEM"
            )

        if not recommendations['investigation_steps'] and prediction == 'Important':
            recommendations['investigation_steps'].append(
                f"Investigate SignatureID {int(alert_data.get('SignatureID', 0))} in threat intelligence feeds"
            )

        # Dismissal guidance only while severity is still MINIMAL/LOW, i.e. it
        # was not upgraded to require investigation.
        if (prediction == 'Irrelevant' and confidence > 0.9
                and recommendations['severity'] in ('MINIMAL', 'LOW')):
            recommendations['immediate_actions'].append(
                "Alert classified as Irrelevant with high confidence - Safe to dismiss after quick review"
            )

        # Remove duplicate entries while preserving first-seen order.
        for key in ('immediate_actions', 'investigation_steps', 'root_cause_mitigation'):
            recommendations[key] = list(dict.fromkeys(recommendations[key]))

        return recommendations

    @staticmethod
    def _assess_severity(prediction, confidence):
        """Return (severity, reasoning) for a prediction and its confidence."""
        if prediction == 'Important':
            # Important alerts are ~1.5% of the dataset; require high confidence.
            if confidence > 0.95:
                return 'HIGH', (
                    f"High-confidence Important alert ({confidence:.1%}). "
                    "Given dataset imbalance (1.5% Important), high confidence indicates strong evidence."
                )
            if confidence > 0.80:
                return 'MEDIUM', (
                    f"Medium-confidence Important alert ({confidence:.1%}). "
                    "Requires further investigation to validate."
                )
            return 'LOW', (
                f"Low-confidence Important alert ({confidence:.1%}). "
                "Model uncertain - likely borderline case requiring manual review."
            )

        # Irrelevant alerts (~98.5% of the dataset)
        if confidence > 0.95:
            return 'MINIMAL', (
                f"High-confidence Irrelevant alert ({confidence:.1%}). "
                "Can likely be safely dismissed."
            )
        if confidence > 0.80:
            return 'LOW', (
                f"Medium-confidence Irrelevant alert ({confidence:.1%}). "
                "Quick review recommended to confirm."
            )
        return 'MEDIUM', (
            f"Low-confidence Irrelevant alert ({confidence:.1%}). "
            "Uncertain classification - may be Important. Investigate carefully."
        )

    @staticmethod
    def _add_feature_recommendations(recommendations, top_features, alert_data):
        """Append feature-specific actions for each non-missing top feature (mutates recommendations)."""
        immediate = recommendations['immediate_actions']
        investigate = recommendations['investigation_steps']
        mitigate = recommendations['root_cause_mitigation']

        for feat in top_features:
            # Skip if missing value
            if feat.get('is_missing', False):
                continue

            feature = feat['feature']
            value = feat['value']

            if feature == 'IntPort':
                if value == 22:
                    # SSH-related
                    immediate.append(
                        "SSH port (22) targeted - Enable SSH hardening (key-only auth, fail2ban)"
                    )
                elif value == 53:
                    # DNS-related
                    immediate.append(
                        "DNS query detected - Check for DNS tunneling or exfiltration attempts"
                    )

            elif feature == 'SCAS':
                # Outlier detection (SCAS = Similarity-based Contextual Anomaly Score)
                if value == 1:
                    immediate.append(
                        "⚠ OUTLIER DETECTED (SCAS=1) - Novel attack pattern not seen before"
                    )
                    investigate.append(
                        "Compare with historical alerts - this is a unique pattern requiring manual analysis"
                    )
                    # Outliers should always be investigated
                    if recommendations['severity'] in ['MINIMAL', 'LOW']:
                        recommendations['severity'] = 'MEDIUM'
                        recommendations['severity_reasoning'] += " Upgraded due to outlier detection."

            elif feature == 'SignatureMatchesPerDay':
                # High signature matching frequency
                if value > 50000:
                    immediate.append(
                        f"⚠ EXTREMELY HIGH signature match frequency ({int(value):,}/day) - Potential coordinated campaign"
                    )
                    investigate.append(
                        "Search for other hosts with SignatureMatchesPerDay > 50K in last 24h"
                    )
                    mitigate.append(
                        f"Review signature rules - {int(value):,} matches/day may indicate rule needs tuning"
                    )
                elif value > 10000:
                    investigate.append(
                        f"High signature match frequency ({int(value):,}/day) - Monitor for escalation"
                    )

            elif feature == 'SignatureIDSimilarity':
                # High similarity to known attacks
                if value > 0.95:
                    investigate.append(
                        f"Pattern closely matches known attacks (similarity={value:.2f}) - "
                        f"Check threat intelligence for SignatureID={int(alert_data.get('SignatureID', 0))}"
                    )

            elif feature == 'Similarity':
                # Overall similarity score
                if value > 0.9:
                    investigate.append(
                        f"High contextual similarity ({value:.2f}) - Alert matches common attack patterns"
                    )

            elif feature == 'AlertCount':
                # Alert volume/flooding
                if value > 100:
                    immediate.append(
                        f"⚠ HIGH VOLUME ATTACK ({int(value)} alerts) - Consider rate limiting or blocking source"
                    )
                elif value > 10:
                    investigate.append(
                        f"Multiple alerts ({int(value)}) from same source - Possible persistent attack"
                    )

            elif feature == 'Proto':
                # Protocol-specific recommendations (IANA protocol numbers)
                if value == 6:  # TCP
                    investigate.append(
                        "TCP traffic detected - Review connection state and payload"
                    )
                elif value == 17:  # UDP
                    investigate.append(
                        "UDP traffic detected - Check for amplification or flooding attacks"
                    )

    @staticmethod
    def _add_causal_recommendations(recommendations, causal_analyses):
        """Append root-cause mitigations derived from the causal analyses (mutates recommendations)."""
        mitigate = recommendations['root_cause_mitigation']

        for causal in causal_analyses:
            roots = causal.get('root_causes')
            if not (causal.get('in_graph', False) and roots):
                continue

            mitigate.append(f"Root causes of {causal['target']}: {', '.join(roots)}")

            # Add specific mitigation for known root causes
            if 'SignatureMatchesPerDay' in roots:
                mitigate.append(
                    "Mitigate high signature matches: Review and tune signature rules, "
                    "implement rate limiting for repeat offenders"
                )
            if 'Proto' in roots:
                mitigate.append(
                    "Protocol-based attack: Consider protocol-specific firewall rules"
                )

class HybridExplanation:
    """
    Container for hybrid explanation results.

    Bundles the XAI output (what drove the prediction), the causal analyses
    (why/how) and the generated recommendations for one alert, and renders
    them as a JSON-safe dict / JSON string, plain text, or a matplotlib figure.
    """
    
    def __init__(self, alert_id, alert_data, xai_results, causal_analyses, 
                 label_causal, recommendations):
        """
        Args:
            alert_id: Identifier shown in titles and exported as-is. ``None``
                (or an empty string) is displayed as "Unknown"; ``0`` is a
                valid id and is displayed as ``0``.
            alert_data: Raw alert values, exported under 'alert_data'.
            xai_results: Dict read for 'prediction', 'confidence',
                'num_present_features', 'num_missing_features' and
                'top_features' (list of dicts with 'feature', 'value',
                'importance' and optional 'is_missing').
            causal_analyses: List of per-feature dicts with 'target' and
                optional 'in_graph', 'reason', 'root_causes', 'causal_paths'.
            label_causal: Optional dict with 'in_graph' and 'direct_causes'
                (list of dicts with 'feature' and 'value').
            recommendations: Dict with 'severity', optional
                'severity_reasoning', and the lists 'immediate_actions',
                'investigation_steps', 'root_cause_mitigation'.
        """
        self.alert_id = alert_id
        self.alert_data = alert_data
        self.xai = xai_results
        self.causal = causal_analyses
        self.label_causal = label_causal
        self.recommendations = recommendations
    
    def _display_id(self):
        """Return the alert id for display; only None/'' become 'Unknown'."""
        # A truthiness test (`alert_id or 'Unknown'`) would hide alert id 0.
        if self.alert_id is None or (isinstance(self.alert_id, str) and not self.alert_id):
            return 'Unknown'
        return self.alert_id
    
    def to_dict(self):
        """Convert to dictionary with JSON-safe types"""
        
        def safe_key(key):
            """Convert numpy scalar dict keys (rejected by json) to Python scalars."""
            return key.item() if isinstance(key, np.generic) else key
        
        def make_json_safe(obj):
            """Recursively convert numpy types and tuples to native Python types"""
            if isinstance(obj, dict):
                return {safe_key(k): make_json_safe(v) for k, v in obj.items()}
            elif isinstance(obj, (list, tuple)):
                # Tuples become lists (what json emits anyway), but their
                # elements must be converted too.
                return [make_json_safe(item) for item in obj]
            elif isinstance(obj, np.ndarray):
                # tolist() of an object array can still contain numpy scalars.
                return make_json_safe(obj.tolist())
            elif isinstance(obj, (np.bool_, bool)):
                return bool(obj)
            elif isinstance(obj, np.integer):
                return int(obj)
            elif isinstance(obj, np.floating):
                return float(obj)
            else:
                return obj
        
        data = {
            'alert_id': self.alert_id,
            'alert_data': self.alert_data,
            'xai_analysis': self.xai,
            'causal_analysis': self.causal,
            'label_causal': self.label_causal,
            'recommendations': self.recommendations
        }
        
        # Make all data JSON-safe
        return make_json_safe(data)
    
    def to_json(self, filepath=None):
        """Export to JSON.

        Returns the file path when ``filepath`` is given (file is written),
        otherwise the JSON text.
        """
        data = self.to_dict()
        if filepath:
            with open(filepath, 'w', encoding='utf-8') as f:
                json.dump(data, f, indent=2)
            return filepath
        else:
            return json.dumps(data, indent=2)
    
    def to_text(self):
        """Generate natural language explanation"""
        lines = []
        lines.append("="*70)
        lines.append(f"HYBRID EXPLANATION - Alert #{self._display_id()}")
        lines.append("="*70)
        lines.append("")
        
        # Classification
        lines.append("🎯 CLASSIFICATION")
        lines.append(f"  Prediction: {self.xai['prediction']}")
        lines.append(f"  Confidence: {self.xai['confidence']:.1%}")
        lines.append(f"  Severity: {self.recommendations['severity']}")
        if self.recommendations.get('severity_reasoning'):
            lines.append(f"  Reasoning: {self.recommendations['severity_reasoning']}")
        lines.append(f"  Present Features: {self.xai['num_present_features']}/42")
        lines.append(f"  Missing Features: {self.xai['num_missing_features']}/42")
        lines.append("")
        
        # XAI Analysis
        lines.append("📊 XAI ANALYSIS: What triggered this alert?")
        lines.append("  (Showing only features with values present in this alert)")
        lines.append("")
        for i, feat in enumerate(self.xai['top_features'], 1):
            missing_marker = " [MISSING]" if feat.get('is_missing', False) else ""
            lines.append(f"  {i}. {feat['feature']}{missing_marker}")
            lines.append(f"     Value: {feat['value']:.4f}")
            lines.append(f"     Importance: {feat['importance']:.4f}")
            lines.append("")
        
        # Causal Analysis
        lines.append("🔍 CAUSAL ANALYSIS: Why/How did this happen?")
        lines.append("  (Limited to 10 SOC analyst-identified features)")
        lines.append("")
        
        has_causal_info = False
        for causal in self.causal:
            if not causal.get('in_graph', False):
                lines.append(f"  Feature: {causal['target']}")
                lines.append(f"  ⚠ {causal.get('reason', 'Not in causal graph')}")
                lines.append("")
                continue
            
            has_causal_info = True
            lines.append(f"  Feature: {causal['target']}")
            
            if causal.get('root_causes'):
                lines.append(f"  Root Causes: {', '.join(map(str, causal['root_causes']))}")
            
            if causal.get('causal_paths'):
                lines.append("  Causal Chains:")
                for path_info in causal['causal_paths'][:2]:  # Show top 2 paths
                    # Graph nodes are not guaranteed to be strings.
                    path_str = ' → '.join(map(str, path_info['path']))
                    lines.append(f"    • {path_str}")
            
            lines.append("")
        
        if not has_causal_info:
            lines.append("  ⚠ No top features found in causal graph")
            lines.append("  This indicates top XAI features are protocol-specific")
            lines.append("  (e.g., TLS, HTTP) and were not among SOC analyst-selected features")
            lines.append("")
        
        # Label causal analysis
        if self.label_causal and self.label_causal.get('in_graph'):
            lines.append("  Direct Causes of Alert Classification:")
            for cause in self.label_causal.get('direct_causes', []):
                value = cause['value']
                # numpy scalars (np.int64, np.float32, ...) are not subclasses
                # of int/float, so they must be listed explicitly; -1.0 is the
                # missing-value indicator and is skipped.
                if isinstance(value, (int, float, np.integer, np.floating)) and value != -1.0:
                    lines.append(f"    • {cause['feature']} = {value:.4f}")
            lines.append("")
        
        # Recommendations
        lines.append("✅ RECOMMENDED ACTIONS")
        lines.append("")
        
        if self.recommendations['immediate_actions']:
            lines.append("  Immediate Actions:")
            for action in self.recommendations['immediate_actions']:
                lines.append(f"    • {action}")
            lines.append("")
        
        if self.recommendations['investigation_steps']:
            lines.append("  Investigation Steps:")
            for step in self.recommendations['investigation_steps']:
                lines.append(f"    • {step}")
            lines.append("")
        
        if self.recommendations['root_cause_mitigation']:
            lines.append("  Root Cause Mitigation:")
            for mitigation in self.recommendations['root_cause_mitigation']:
                lines.append(f"    • {mitigation}")
            lines.append("")
        
        lines.append("="*70)
        
        return '\n'.join(lines)
    
    def visualize(self, save_path=None):
        """Create visual explanation.

        Builds a 3x2 grid: feature importances, prediction summary, causal
        chains and recommendations. Saves the figure to ``save_path`` when
        given, and returns it (the caller is responsible for closing it).
        """
        fig = plt.figure(figsize=(16, 10))
        gs = fig.add_gridspec(3, 2, hspace=0.3, wspace=0.3)
        
        # Title
        pred_emoji = "⚠️" if self.xai['prediction'] == 'Important' else "✓"
        fig.suptitle(f'{pred_emoji} Hybrid Explanation - Alert #{self._display_id()}', 
                     fontsize=16, fontweight='bold')
        
        # 1. XAI Feature Importance (top subplot)
        ax1 = fig.add_subplot(gs[0, :])
        top_feats = self.xai['top_features']
        features = [f['feature'] for f in top_feats]
        importances = [f['importance'] for f in top_feats]
        
        # Color code: red for positive, blue for negative, gray for missing
        colors = []
        for feat in top_feats:
            if feat.get('is_missing', False):
                colors.append('gray')
            elif feat['importance'] > 0:
                colors.append('red')
            else:
                colors.append('blue')
        
        ax1.barh(features, importances, color=colors, alpha=0.7)
        ax1.set_xlabel('Importance Score')
        ax1.set_title('XAI Analysis: Top Feature Importance (Gray = Missing Values)', fontweight='bold')
        ax1.axvline(x=0, color='black', linestyle='-', linewidth=0.5)
        ax1.grid(axis='x', alpha=0.3)
        
        # 2. Prediction info (middle left)
        ax2 = fig.add_subplot(gs[1, 0])
        ax2.axis('off')
        
        severity_color = {
            'HIGH': '🔴', 'MEDIUM': '🟡', 'LOW': '🟢', 'MINIMAL': '⚪'
        }.get(self.recommendations['severity'], '⚪')
        
        pred_text = f"""
PREDICTION
{'='*30}
Class: {self.xai['prediction']}
Confidence: {self.xai['confidence']:.1%}
Severity: {severity_color} {self.recommendations['severity']}

Present Features: {self.xai['num_present_features']}/42
Missing Features: {self.xai['num_missing_features']}/42
"""
        ax2.text(0.1, 0.5, pred_text, fontsize=10, family='monospace',
                verticalalignment='center')
        
        # 3. Causal paths (middle right)
        ax3 = fig.add_subplot(gs[1, 1])
        ax3.axis('off')
        
        causal_text = "CAUSAL CHAINS\n" + "="*30 + "\n"
        has_paths = False
        for causal in self.causal[:3]:
            if causal.get('in_graph') and causal.get('causal_paths'):
                has_paths = True
                for path_info in causal['causal_paths'][:1]:
                    path_str = ' → '.join(map(str, path_info['path'][:4]))
                    if len(path_info['path']) > 4:
                        path_str += ' ...'
                    causal_text += f"• {path_str}\n"
        
        if not has_paths:
            causal_text += "\n⚠ Top features not in\nSOC analyst-selected\ncausal graph\n"
        
        ax3.text(0.1, 0.5, causal_text, fontsize=10, family='monospace',
                verticalalignment='center')
        
        # 4. Recommendations (bottom)
        ax4 = fig.add_subplot(gs[2, :])
        ax4.axis('off')
        
        rec_text = "RECOMMENDED ACTIONS\n" + "="*50 + "\n"
        if self.recommendations['immediate_actions']:
            rec_text += "\nImmediate:\n"
            for action in self.recommendations['immediate_actions'][:3]:
                shortened = action[:70] + "..." if len(action) > 70 else action
                rec_text += f"  • {shortened}\n"
        
        if self.recommendations['investigation_steps']:
            rec_text += "\nInvestigation:\n"
            for step in self.recommendations['investigation_steps'][:2]:
                shortened = step[:70] + "..." if len(step) > 70 else step
                rec_text += f"  • {shortened}\n"
        
        ax4.text(0.05, 0.5, rec_text, fontsize=9, family='monospace',
                verticalalignment='center')
        
        if save_path:
            # Save this figure explicitly rather than whatever is "current".
            fig.savefig(save_path, dpi=300, bbox_inches='tight')
            print(f"✓ Saved visualization to: {save_path}")
        
        return fig

# ==================== DEMO EXAMPLES ====================
print("\n" + "="*70)
print("GENERATING DEMO EXPLANATIONS")
print("="*70)

# Create explainer instance
explainer = HybridExplainer(
    model=model,
    causal_graph=causal_graph,
    feature_names=FEATURE_NAMES,
    scaler=scaler
)


def _sample_positions(mask, n, seed=42):
    """Return up to ``n`` random *positional* row indices where ``mask`` is True.

    Positions (valid for ``iloc``) are sampled instead of index labels, so the
    result stays correct even when ``df`` does not have a 0..N-1 RangeIndex.
    pandas draws sample positions from the object's length and
    ``random_state`` only, so with the same seed this should select the same
    rows as ``df[mask].sample(n, random_state=seed)`` would -- without copying
    the filtered DataFrame (the irrelevant class alone is ~1.37M rows).
    """
    positions = np.flatnonzero(mask)
    if positions.size == 0:
        return []
    picked = pd.Series(positions).sample(n=min(n, positions.size), random_state=seed)
    return [int(p) for p in picked]


def _label_display(label_value, numeric_labels):
    """Map a raw label to display text: 1 -> 'Important', other numeric -> 'Irrelevant'."""
    if numeric_labels:
        return 'Important' if label_value == 1 else 'Irrelevant'
    return str(label_value)


# Example 1: Use data from file if available
if df is not None and len(df) > 0:
    print("\nGenerating explanations for diverse sample alerts...")
    
    has_labels = 'Label' in df.columns
    # Decide the label encoding once from the column dtype. Inspecting the type
    # of a single element misses float-encoded (1.0) or bool labels.
    numeric_labels = has_labels and pd.api.types.is_numeric_dtype(df['Label'])
    
    # Check label distribution
    if has_labels:
        label_counts = df['Label'].value_counts()
        print("\nDataset distribution:")
        for label, count in label_counts.items():
            pct = count / len(df) * 100
            print(f"  {label}: {count:,} ({pct:.1f}%)")
    
    # Ensure we have all required features
    available_features = [f for f in FEATURE_NAMES if f in df.columns]
    print(f"\nAvailable features: {len(available_features)}/{len(FEATURE_NAMES)}")
    
    if len(available_features) < len(FEATURE_NAMES):
        print("⚠ Warning: Not all features available. Creating synthetic examples...")
        df = None
    else:
        # Select diverse samples: Important alerts AND Irrelevant alerts.
        # All entries of sample_indices are *positions* into df.
        important_samples = []
        irrelevant_samples = []
        
        if has_labels:
            labels = df['Label']
            if numeric_labels:
                # Numeric labels: 1 = Important, 0 = Irrelevant
                print("  Detected numeric labels (0=Irrelevant, 1=Important)")
                important_mask = (labels == 1).to_numpy()
                irrelevant_mask = (labels == 0).to_numpy()
            else:
                # String labels
                print("  Detected string labels")
                important_mask = (labels == 'Important').to_numpy()
                irrelevant_mask = (labels == 'Irrelevant').to_numpy()
            
            important_samples = _sample_positions(important_mask, 3)
            irrelevant_samples = _sample_positions(irrelevant_mask, 2)
            
            sample_indices = important_samples + irrelevant_samples
            print(f"\nSelected {len(important_samples)} Important + {len(irrelevant_samples)} Irrelevant alerts")
        else:
            # Fixed positional fallback; drop positions past the end of df.
            sample_indices = [p for p in (0, 100, 500, 1000, 5000) if p < len(df)]
        
        for idx in sample_indices:
            # idx is a position (for iloc); alert_label is the row's index
            # label, used as the human-facing alert id and in file names.
            alert_row = df.iloc[idx]
            alert_label = df.index[idx]
            if has_labels:
                # Read from the column so the label keeps the column dtype; a
                # row Series of an all-numeric frame would upcast ints to float.
                true_label_str = _label_display(df['Label'].iat[idx], numeric_labels)
            else:
                true_label_str = 'Unknown'
            
            print(f"\n{'='*70}")
            print(f"EXAMPLE ALERT #{alert_label} (True Label: {true_label_str})")
            print(f"{'='*70}")
            
            # Get alert data (all 42 features)
            alert_features = alert_row[FEATURE_NAMES].values
            
            # Generate explanation
            explanation = explainer.explain(alert_features, alert_id=alert_label)
            
            # Print text explanation
            print(explanation.to_text())
            
            # Save JSON
            json_path = f'hybrid_explanation_fixed_{alert_label}.json'
            explanation.to_json(json_path)
            print(f"\n✓ Saved JSON to: {json_path}")
            
            # Save visualization and close that specific figure
            viz_path = f'hybrid_explanation_fixed_{alert_label}.png'
            fig = explanation.visualize(save_path=viz_path)
            plt.close(fig)

if df is None:
    print("\n⚠ No data available for demo. Creating synthetic examples...")
    
    # Each synthetic alert starts from an all-zero feature vector and overrides
    # only the fields listed. A value of -1.0 is the missing/N-A indicator.
    
    # Scenario 1: SSH brute force (Important)
    ssh_alert = dict.fromkeys(FEATURE_NAMES, 0.0)
    ssh_alert.update({
        'SignatureMatchesPerDay': 50000,
        'Similarity': 0.95,
        'SCAS': 1.0,  # Outlier
        'SignatureID': 2001219,
        'SignatureIDSimilarity': 0.98,
        'Proto': 6,  # TCP
        'AlertCount': 250,
        'IntPort': 22,  # SSH
        'ExtPort': 54321,
        'ProtoSimilarity': 1.0,
        'IntPortSimilarity': 0.95,
        # TLS features are N/A for SSH
        'TlsFingerprintSimilarity': -1.0,
        'TlsIssuerDnSimilarity': -1.0,
        'TlsSubjectSimilarity': -1.0,
        'TlsVersionSimilarity': -1.0,
        'TlsSniSimilarity': -1.0,
        'TlsJa3hashSimilarity': -1.0,
    })
    
    # Scenario 2: Benign DNS query (Irrelevant)
    dns_alert = dict.fromkeys(FEATURE_NAMES, 0.0)
    dns_alert.update({
        'SignatureMatchesPerDay': 50,
        'Similarity': 0.3,  # Low similarity
        'SCAS': 0.0,  # Not an outlier
        'SignatureID': 2027757,
        'SignatureIDSimilarity': 0.4,
        'Proto': 17,  # UDP
        'AlertCount': 1,
        'IntPort': 53,  # DNS
        'ExtPort': 49152,
        'ProtoSimilarity': 0.5,
        'DnsRrnameSimilarity': 0.8,
        'DnsRrtypeSimilarity': 0.9,
        # TLS/HTTP features are N/A for DNS
        'TlsFingerprintSimilarity': -1.0,
        'TlsIssuerDnSimilarity': -1.0,
        'TlsSubjectSimilarity': -1.0,
        'HttpContentTypeSimilarity': -1.0,
        'HttpStatusSimilarity': -1.0,
    })
    
    # Scenario 3: HTTP attack with high confidence (Important)
    http_alert = dict.fromkeys(FEATURE_NAMES, 0.0)
    http_alert.update({
        'SignatureMatchesPerDay': 15000,
        'Similarity': 0.88,
        'SCAS': 0.0,
        'SignatureID': 2842116,
        'SignatureIDSimilarity': 0.92,
        'Proto': 6,  # TCP
        'AlertCount': 50,
        'IntPort': 80,  # HTTP
        'ExtPort': 45678,
        'ProtoSimilarity': 0.95,
        'HttpContentTypeSimilarity': 0.85,
        'HttpStatusSimilarity': 0.90,
        'HttpMethodSimilarity': 0.88,
        'HttpUrlSimilarity': 0.92,
        'AppProtoSimilarity': 0.95,
        # TLS is N/A (plain HTTP, not HTTPS)
        'TlsFingerprintSimilarity': -1.0,
        'TlsIssuerDnSimilarity': -1.0,
        'TlsSubjectSimilarity': -1.0,
    })
    
    # (title, alert id, output file stem, feature dict) per scenario, in order.
    synthetic_scenarios = (
        ("SSH Brute Force Attack", "SSH-001", "ssh_attack", ssh_alert),
        ("Benign DNS Query", "DNS-001", "dns_benign", dns_alert),
        ("HTTP Attack", "HTTP-001", "http_attack", http_alert),
    )
    
    for scenario_no, (scenario_title, scenario_id, file_stem, scenario_alert) in enumerate(
            synthetic_scenarios, start=1):
        print("\n" + "=" * 70)
        print(f"SYNTHETIC EXAMPLE {scenario_no}: {scenario_title}")
        print("=" * 70)
        
        explanation = explainer.explain(scenario_alert, alert_id=scenario_id)
        print(explanation.to_text())
        explanation.to_json(f'hybrid_explanation_fixed_{file_stem}.json')
        explanation.visualize(save_path=f'hybrid_explanation_fixed_{file_stem}.png')
        plt.close()

# Completion banner and summary of generated artifacts (one line per entry).
print(
    "\n" + "=" * 70,
    "STEP 4 COMPLETE! (FIXED VERSION)",
    "=" * 70,
    "\nKey Improvements:",
    "  ✓ Missing value filtering (-1.0 indicators excluded from top features)",
    "  ✓ Improved severity assessment for imbalanced dataset (1.5% vs 98.5%)",
    "  ✓ Better handling of Irrelevant predictions",
    "  ✓ Enhanced domain-specific recommendations",
    "  ✓ Severity reasoning explanations",
    "\nGenerated files:",
    "  - hybrid_explanation_fixed_*.json (structured data)",
    "  - hybrid_explanation_fixed_*.png (visualizations)",
    "\nNext: Step 5 - Evaluation (quantitative metrics + comparison)",
    sep="\n",
)

STEP 4: HYBRID EXPLANATION GENERATOR (FIXED)

Loading models and data...
✓ Loaded LSTM model from ../step1_lstm_xai/best_lstm.pt
✓ Loaded scaler from ../step1_lstm_xai/scaler.joblib
✓ Loaded causal graph from ../step2_causal_discovery/causal_graph.gpickle
  Nodes: 11, Edges: 20
  Note: Causal graph limited to 10 SOC analyst-identified features
✓ Loaded data from ../step2_causal_discovery/causal_discovery_data.csv: (10000, 11)
  ⚠ Causal discovery data has only 11 columns
  ⚠ LSTM needs 42 features
  Loading original dataset instead...
  ✓ Loaded full dataset: (1395324, 47)
  ✓ Using full dataset with 43 features

XAI COMPONENT: Feature Importance (with Missing Value Filtering)

CAUSAL COMPONENT: Root Cause Analysis

HYBRID EXPLAINER: Combining XAI + Causal

GENERATING DEMO EXPLANATIONS

Generating explanations for diverse sample alerts...

Dataset distribution:
  0: 1,374,372 (98.5%)
  1: 20,952 (1.5%)

Available features: 42/42
  Detected numeric labels (0=Irrelevant, 1=Important)

Se