# In [None]:
# Step 2: Causal Discovery for NIDS Alert Classification
# Runs the PC (constraint-based) and GES (score-based, Hill-Climb-like) algorithms to learn causal relationships
# Optimized for M1 Mac

import importlib
import os
import pickle
import subprocess
import sys
import warnings
from pathlib import Path

import matplotlib.pyplot as plt
import networkx as nx
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.preprocessing import KBinsDiscretizer

# Silence all warnings for the rest of the session so notebook output stays
# readable. NOTE(review): this also hides library warnings about the data
# (e.g. degenerate bins) — consider narrowing the filter.
warnings.filterwarnings('ignore')

print("\n".join(["="*70, "STEP 2: CAUSAL DISCOVERY FOR NIDS ALERTS", "="*70]))

# ==================== CONFIGURATION ====================
DATA_PATH = 'dataset-labeled-anon-ip.csv'  # resolved relative to the working directory
SEED = 42
SAMPLE_SIZE = 5000  # rows drawn per class (two classes -> 10K total)
np.random.seed(SEED)  # seeds NumPy's global RNG

# Feature selection based on:
# 1. SOC expert features (from paper)
# 2. Top XAI features (from your Step 1 SHAP analysis)
SOC_FEATURES = [
    'SignatureMatchesPerDay',
    'Similarity',
    'SCAS',
    'SignatureID',
    'SignatureIDSimilarity'
]

# Add top features from network traffic analysis
ADDITIONAL_FEATURES = [
    'Proto',
    'AlertCount',
    'IntPort',
    'ExtPort',
    'ProtoSimilarity'
]

# Discovery runs over SOC features first, then the additional ones; this order
# is also the column order of the data matrix passed to the algorithms.
SELECTED_FEATURES = [*SOC_FEATURES, *ADDITIONAL_FEATURES]
TARGET = 'Label'

print(f"\nSelected {len(SELECTED_FEATURES)} features for causal discovery:")
print("\n".join(f"  {i}. {f}" for i, f in enumerate(SELECTED_FEATURES, 1)))

# ==================== DATA LOADING & PREPROCESSING ====================
print("\n" + "="*70)
print("LOADING AND PREPROCESSING DATA")
print("="*70)

# Fail fast with a helpful message if the dataset is not where we expect it
if not Path(DATA_PATH).exists():
    print(f"ERROR: {DATA_PATH} not found in current directory")
    print(f"Current directory: {os.getcwd()}")
    print("\nPlease ensure the dataset file is in the same directory as this script.")
    raise FileNotFoundError(f"{DATA_PATH} not found")

# File size in megabytes, for the log only
file_size = Path(DATA_PATH).stat().st_size / (1024**2)
print(f"Found {DATA_PATH} ({file_size:.2f} MB)")

print("Loading data...")
try:
    # Method 1: read the whole file with the Python parser. Malformed lines are
    # skipped silently, so the row count may be lower than the file's line count.
    df = pd.read_csv(DATA_PATH,
                     engine='python',
                     encoding='utf-8',
                     on_bad_lines='skip')
    print(f"✓ Loaded dataset: {df.shape}")

except Exception as e1:
    print(f"Method 1 failed: {e1}")
    try:
        # Method 2: read only the first 50K rows. NOTE: this is a prefix of the
        # file, not a random sample, and may not hold SAMPLE_SIZE rows per class.
        print("Trying to read specific number of rows...")
        df = pd.read_csv(DATA_PATH,
                         nrows=50000,
                         engine='python')
        print(f"✓ Loaded subset: {df.shape}")

    except Exception as e2:
        print(f"Method 2 failed: {e2}")
        print("\nTrying to diagnose the file...")

        # Distinguish "cannot open/decode at all" from "parser cannot cope"
        try:
            with open(DATA_PATH, 'r', encoding='utf-8') as fh:
                first_line = fh.readline()
            print(f"File is readable. Header: {first_line[:100]}")
        except Exception as e3:
            print(f"Cannot read file: {e3}")
            raise

        # Chain the last parser error so the traceback shows the real cause
        raise RuntimeError("All loading methods failed. File may be corrupted.") from e2

# Select relevant columns; report every missing one explicitly instead of the
# less readable KeyError raised by df[cols].
required_cols = SELECTED_FEATURES + [TARGET]
missing_cols = [col for col in required_cols if col not in df.columns]
if missing_cols:
    raise KeyError(f"Dataset is missing required columns: {missing_cols}")
df = df[required_cols].copy()

# Handle missing values and -1 placeholders: -1 is treated as "missing" in
# every selected column and, like NaN, filled with the column median
# (numeric columns only).
df = df.replace({-1: np.nan})
df = df.fillna(df.median(numeric_only=True))

# Balance classes by undersampling label 0 and label 1 to the same size (rows
# with any other label value are dropped). If a class has fewer than
# SAMPLE_SIZE rows (e.g. after the 50K-row fallback load), use the largest
# size both classes can supply instead of letting DataFrame.sample raise.
print(f"\nBalancing classes (sampling {SAMPLE_SIZE} per class)...")
negatives = df[df[TARGET] == 0]
positives = df[df[TARGET] == 1]
per_class = min(SAMPLE_SIZE, len(negatives), len(positives))
if per_class == 0:
    raise ValueError(
        f"Cannot balance classes: {TARGET}==0 has {len(negatives)} rows, "
        f"{TARGET}==1 has {len(positives)} rows"
    )
if per_class < SAMPLE_SIZE:
    print(f"WARNING: smaller class has only {per_class} rows; "
          f"sampling {per_class} per class instead")

df_balanced = pd.concat([
    negatives.sample(n=per_class, random_state=SEED),
    positives.sample(n=per_class, random_state=SEED),
]).reset_index(drop=True)

print(f"Balanced dataset: {df_balanced.shape}")
print(f"Class distribution:\n{df_balanced[TARGET].value_counts()}")

# ==================== DISCRETIZATION ====================
print("\n" + "="*70)
print("DISCRETIZING CONTINUOUS FEATURES")
print("="*70)

# Features with more than 10 distinct values are mapped to quantile-based
# ordinal bin codes (requested: 5 bins; heavily tied columns can end up with
# fewer). The discretizer is re-fitted for every column, so each feature
# gets its own bin edges. Low-cardinality features are left untouched.
df_discrete = df_balanced.copy()
discretizer = KBinsDiscretizer(n_bins=5, encode='ordinal', strategy='quantile')

continuous_features = [col for col in SELECTED_FEATURES
                       if df_discrete[col].nunique() > 10]

print(f"Discretizing {len(continuous_features)} continuous features:")
for col in continuous_features:
    n_distinct_before = df_discrete[col].nunique()
    df_discrete[col] = discretizer.fit_transform(df_discrete[[col]])
    n_distinct_after = df_discrete[col].nunique()
    print(f"  {col}: {n_distinct_before} → {n_distinct_after} bins")

# Persist the balanced, discretized table for later steps
df_discrete.to_csv('causal_discovery_data.csv', index=False)
print("\nSaved preprocessed data to: causal_discovery_data.csv")

# ==================== DOMAIN KNOWLEDGE CONSTRAINTS ====================
print("\n" + "="*70)
print("APPLYING DOMAIN KNOWLEDGE CONSTRAINTS")
print("="*70)

# Directed edges the structure learner must not output, based on network
# security domain knowledge. Each tuple is (cause, effect), read as
# "cause cannot directly cause effect". Only algorithms that are given
# background knowledge enforce these.
FORBIDDEN_EDGES = [
    # Low-level protocol features cannot cause high-level signature features
    ('Proto', 'SignatureID'),
    ('Proto', 'SignatureMatchesPerDay'),
    ('ExtPort', 'SignatureID'),
    ('IntPort', 'SignatureID'),

    # Signature ID is determined by signature rules, not by similarity
    ('Similarity', 'SignatureID'),
    ('SCAS', 'SignatureID'),

    # Port numbers don't cause protocol
    ('IntPort', 'Proto'),
    ('ExtPort', 'Proto'),

    # Alert count is an effect, not a cause of individual features
    ('AlertCount', 'Proto'),
    ('AlertCount', 'IntPort'),
    ('AlertCount', 'ExtPort'),
]

# Show a preview of the first five constraints and a count of the rest
print(f"Defined {len(FORBIDDEN_EDGES)} forbidden edges:")
for rank, (cause, effect) in enumerate(FORBIDDEN_EDGES[:5], start=1):
    print(f"  {rank}. {cause} → {effect} (forbidden)")
n_not_shown = len(FORBIDDEN_EDGES) - 5
print(f"  ... and {n_not_shown} more")

# ==================== INSTALL CAUSAL-LEARN ====================
print("\n" + "="*70)
print("CHECKING CAUSAL-LEARN INSTALLATION")
print("="*70)

# The PyPI distribution "causal-learn" is imported as "causallearn". The import
# must actually run inside the try, otherwise ImportError can never fire and
# the install branch is unreachable.
try:
    import causallearn  # noqa: F401 -- availability check only
    print("causal-learn is installed")
except ImportError:
    print("Installing causal-learn...")
    # Use this interpreter's pip so the package lands in the running
    # environment; check_call raises CalledProcessError if pip fails.
    subprocess.check_call([sys.executable, "-m", "pip", "install", "causal-learn"])
    # Make the freshly installed package visible to the import system
    importlib.invalidate_caches()

from causallearn.search.ConstraintBased.PC import pc
from causallearn.search.ScoreBased.GES import ges
from causallearn.utils.GraphUtils import GraphUtils
from causallearn.utils.PCUtils.BackgroundKnowledge import BackgroundKnowledge
from causallearn.graph.GraphNode import GraphNode

# ==================== PC ALGORITHM ====================
print("\n" + "="*70)
print("RUNNING PC ALGORITHM (Constraint-Based)")
print("="*70)

# Data matrix: one column per feature, in SELECTED_FEATURES order
X = df_discrete[SELECTED_FEATURES].values
feature_names = SELECTED_FEATURES

# Background knowledge is matched against the learned graph's nodes by name.
# pc() only names its nodes after the features when `node_names` is passed
# (otherwise they are "X1", "X2", ...). The same names are used for the
# GraphNodes here and in the pc() call below; without that, the constraints
# would silently never match.
bk = BackgroundKnowledge()
graph_nodes = [GraphNode(f) for f in feature_names]

n_forbidden_applied = 0
for src, dst in FORBIDDEN_EDGES:
    if src in feature_names and dst in feature_names:
        src_idx = feature_names.index(src)
        dst_idx = feature_names.index(dst)
        bk.add_forbidden_by_node(graph_nodes[src_idx], graph_nodes[dst_idx])
        n_forbidden_applied += 1

print("Running PC algorithm...")
print(f"  - Data shape: {X.shape}")
print("  - Independence test: fisherz")
print("  - Significance level: α = 0.05")
print(f"  - Background knowledge: {n_forbidden_applied} forbidden edges")

# NOTE(review): Fisher-z assumes roughly linear-Gaussian data, while several
# columns here are ordinal bin codes; a discrete CI test may fit better.
# Kept as-is to preserve the original analysis.
cg_pc = pc(
    X,
    alpha=0.05,
    indep_test='fisherz',
    background_knowledge=bk,
    verbose=False,
    show_progress=False,
    node_names=feature_names,
)

print("\nPC Algorithm completed!")

# Extract oriented edges. causal-learn encodes i --> j as graph[i, j] == -1
# (tail at i) and graph[j, i] == 1 (arrowhead at j). Undirected (-1/-1) and
# bidirected (1/1) edges have no single direction, so they are counted but
# not listed. Each directed edge is visited exactly once.
pc_graph = cg_pc.G
pc_edges = []
pc_unoriented = 0
for i, src in enumerate(feature_names):
    for j, dst in enumerate(feature_names):
        if pc_graph.graph[i, j] == -1 and pc_graph.graph[j, i] == 1:  # src --> dst
            pc_edges.append((src, dst))
        elif i < j and pc_graph.graph[i, j] != 0 and pc_graph.graph[i, j] == pc_graph.graph[j, i]:
            pc_unoriented += 1

print(f"Discovered {len(pc_edges)} directed edges")
print(f"  ({pc_unoriented} undirected/bidirected edges not included)")

# ==================== HILL-CLIMB (GES) ALGORITHM ====================
print("\n" + "="*70)
print("RUNNING GES ALGORITHM (Score-Based, similar to Hill-Climb)")
print("="*70)

print("Running GES algorithm...")
print("  - Scoring method: BIC")
print(f"  - Data shape: {X.shape}")

# Run GES (Greedy Equivalence Search - score-based like Hill-Climb).
# NOTE: no background knowledge is passed, so FORBIDDEN_EDGES are not
# enforced here and may appear in the GES output.
record_ges = ges(X, score_func='local_score_BIC', maxP=None)

print("\nGES Algorithm completed!")

# Extract oriented edges. In causal-learn's graph encoding, i --> j is stored
# as graph[i, j] == -1 (tail at i) and graph[j, i] == 1 (arrowhead at j).
# Undirected edges (-1 on both sides) have no single direction, so they are
# counted but not listed.
ges_graph = record_ges['G']
ges_edges = []
ges_undirected = 0
for i, src in enumerate(feature_names):
    for j, dst in enumerate(feature_names):
        if ges_graph.graph[i, j] == -1 and ges_graph.graph[j, i] == 1:  # src --> dst
            ges_edges.append((src, dst))
        elif i < j and ges_graph.graph[i, j] == -1 and ges_graph.graph[j, i] == -1:
            ges_undirected += 1

print(f"Discovered {len(ges_edges)} directed edges")
print(f"  ({ges_undirected} undirected edges not included)")

# ==================== COMPARISON & CONSENSUS ====================
print("\n" + "="*70)
print("COMPARING PC vs GES RESULTS")
print("="*70)

# Edges are (source, target) tuples, so an edge counts as consensus only
# when both algorithms produced it with the same orientation.
pc_set, ges_set = set(pc_edges), set(ges_edges)
consensus_edges = pc_set & ges_set
pc_only = pc_set - ges_set
ges_only = ges_set - pc_set

print(f"\nConsensus edges (both algorithms agree): {len(consensus_edges)}")
print(f"PC-only edges: {len(pc_only)}")
print(f"GES-only edges: {len(ges_only)}")

print("\nConsensus edges:")
for cause, effect in sorted(consensus_edges):
    print(f"  {cause} → {effect}")

# ==================== BUILD FINAL CAUSAL GRAPH ====================
print("\n" + "="*70)
print("BUILDING FINAL CAUSAL GRAPH")
print("="*70)

# Only consensus edges go into the final graph. They are sorted so the saved
# edge list is reproducible: set iteration order of string tuples depends on
# per-process hash randomization.
final_edges = sorted(consensus_edges)

# Edges found by exactly one algorithm. NOTE: these are NOT added to the
# final graph; they are kept only for manual / domain-expert review.
candidate_edges = pc_set.union(ges_set) - consensus_edges

print(f"\nFinal causal graph contains {len(final_edges)} edges")

# Create NetworkX directed graph over all features plus the outcome node
G = nx.DiGraph()
G.add_nodes_from(SELECTED_FEATURES)
G.add_edges_from(final_edges)
G.add_node(TARGET)

print(f"\nAnalyzing causal relationships with outcome ({TARGET})...")

# NOTE: Feature -> outcome edges come from a bivariate Spearman correlation
# screen (p < 0.01 and |rho| > 0.1), not from the causal discovery
# algorithms, so they are associational rather than causal evidence.
correlations = []
for feature in SELECTED_FEATURES:
    corr, pval = stats.spearmanr(df_discrete[feature], df_discrete[TARGET])
    correlations.append((feature, corr, pval))

# A constant column yields NaN rho/p; give it the lowest sort key so the
# ordering stays well-defined (its NaN p-value never passes the filter).
def _abs_rho_key(item):
    return -1.0 if np.isnan(item[1]) else abs(item[1])

for feature, corr, pval in sorted(correlations, key=_abs_rho_key, reverse=True):
    if pval < 0.01 and abs(corr) > 0.1:
        G.add_edge(feature, TARGET)
        print(f"  {feature} → {TARGET} (ρ={corr:.3f}, p={pval:.2e})")

# ==================== GRAPH ANALYSIS ====================
print("\n" + "="*70)
print("CAUSAL GRAPH ANALYSIS")
print("="*70)

print(f"\nNodes: {G.number_of_nodes()}")
print(f"Edges: {G.number_of_edges()}")

# Root causes: nodes of G that have no parents, in node-insertion order
root_causes = [node for node, n_parents in G.in_degree() if n_parents == 0]
print(f"\nRoot causes (no incoming edges): {len(root_causes)}")
for root in root_causes:
    print(f"  - {root}")

# Parents of the outcome node
if TARGET in G:
    direct_causes = list(G.predecessors(TARGET))
    print(f"\nDirect causes of {TARGET}: {len(direct_causes)}")
    for parent in direct_causes:
        print(f"  - {parent}")

# Rank features by number of outgoing edges; the sort is stable, so ties
# keep SELECTED_FEATURES order.
influence = sorted(
    ((node, G.out_degree(node)) for node in SELECTED_FEATURES),
    key=lambda pair: pair[1],
    reverse=True,
)
print("\nMost influential features (highest out-degree):")
for name, n_children in influence[:5]:
    print(f"  - {name}: {n_children} outgoing edges")

# ==================== VISUALIZATION ====================
print("\n" + "="*70)
print("GENERATING VISUALIZATIONS")
print("="*70)

fig, axes = plt.subplots(1, 2, figsize=(20, 8))
ax1, ax2 = axes

# One graph per algorithm over the feature nodes only (no outcome node),
# drawn side by side for comparison.
G_pc = nx.DiGraph()
G_pc.add_nodes_from(SELECTED_FEATURES)
G_pc.add_edges_from(pc_edges)

G_ges = nx.DiGraph()
G_ges.add_nodes_from(SELECTED_FEATURES)
G_ges.add_edges_from(ges_edges)

for panel_ax, panel_graph, panel_edges, panel_color, panel_name in (
    (ax1, G_pc, pc_edges, 'lightblue', 'PC Algorithm'),
    (ax2, G_ges, ges_edges, 'lightcoral', 'GES Algorithm'),
):
    # Seeded spring layout -> identical node positions on every run
    panel_pos = nx.spring_layout(panel_graph, k=2, iterations=50, seed=SEED)
    nx.draw_networkx_nodes(panel_graph, panel_pos, node_color=panel_color,
                           node_size=1500, alpha=0.9, ax=panel_ax)
    nx.draw_networkx_labels(panel_graph, panel_pos, font_size=8,
                            font_weight='bold', ax=panel_ax)
    nx.draw_networkx_edges(panel_graph, panel_pos, edge_color='gray',
                           arrows=True, arrowsize=20, ax=panel_ax)
    panel_ax.set_title(f'{panel_name}\n{len(panel_edges)} edges',
                       fontsize=14, fontweight='bold')
    panel_ax.axis('off')

plt.tight_layout()
plt.savefig('causal_graphs_comparison.png', dpi=300, bbox_inches='tight')
print("Saved: causal_graphs_comparison.png")

# Final consensus graph with outcome
fig, ax = plt.subplots(1, 1, figsize=(14, 10))

# Node colour encodes role: outcome, SOC expert feature, or additional feature
node_colors = [
    'gold' if node == TARGET
    else 'lightblue' if node in SOC_FEATURES
    else 'lightgreen'
    for node in G.nodes()
]

pos = nx.spring_layout(G, k=3, iterations=100, seed=SEED)
nx.draw_networkx_nodes(G, pos, node_color=node_colors,
                       node_size=2000, alpha=0.9, ax=ax)
nx.draw_networkx_labels(G, pos, font_size=9, font_weight='bold', ax=ax)
# Edges are drawn as slight arcs (arc3 connection style, rad=0.1)
nx.draw_networkx_edges(G, pos, edge_color='gray',
                       arrows=True, arrowsize=20,
                       connectionstyle='arc3,rad=0.1', ax=ax)

# Legend mirroring the node colour scheme above
from matplotlib.patches import Patch
legend_elements = [
    Patch(facecolor=swatch, label=caption)
    for swatch, caption in (
        ('gold', 'Outcome'),
        ('lightblue', 'SOC Expert Features'),
        ('lightgreen', 'Additional Features'),
    )
]
ax.legend(handles=legend_elements, loc='upper right', fontsize=10)
ax.set_title(f'Final Causal Graph\n{G.number_of_edges()} edges',
             fontsize=16, fontweight='bold')
ax.axis('off')

plt.tight_layout()
plt.savefig('final_causal_graph.png', dpi=300, bbox_inches='tight')
print("Saved: final_causal_graph.png")

# ==================== SAVE RESULTS ====================
print("\n" + "="*70)
print("SAVING RESULTS")
print("="*70)

# Save edge lists: one (Source, Target) row per directed edge
pd.DataFrame(pc_edges, columns=['Source', 'Target']).to_csv(
    'pc_edges.csv', index=False
)
pd.DataFrame(ges_edges, columns=['Source', 'Target']).to_csv(
    'ges_edges.csv', index=False
)
# final_edges holds the feature-to-feature consensus edges only; the
# Feature -> outcome edges were added to G directly, so they appear in the
# adjacency matrix and pickle below but not in this CSV.
pd.DataFrame(final_edges, columns=['Source', 'Target']).to_csv(
    'final_causal_edges.csv', index=False
)

# Save graph as adjacency matrix over every node of G (including the outcome)
adj_matrix = nx.to_pandas_adjacency(G, dtype=int)
adj_matrix.to_csv('causal_adjacency_matrix.csv')

# Save NetworkX graph object. nx.write_gpickle was removed in NetworkX 3.0;
# pickling the graph directly with the highest protocol writes the same
# format, readable with pickle.load.
with open('causal_graph.gpickle', 'wb') as fh:
    pickle.dump(G, fh, pickle.HIGHEST_PROTOCOL)

print("\nSaved files:")
print("  - causal_discovery_data.csv (preprocessed data)")
print("  - pc_edges.csv (PC algorithm results)")
print("  - ges_edges.csv (GES algorithm results)")
print("  - final_causal_edges.csv (consensus edges)")
print("  - causal_adjacency_matrix.csv (adjacency matrix)")
print("  - causal_graph.gpickle (NetworkX graph object)")
print("  - causal_graphs_comparison.png (visualization)")
print("  - final_causal_graph.png (final graph visualization)")

# ==================== CAUSAL QUERY INTERFACE ====================
# Section banner: rule line, title, rule line (preceded by a blank line)
print("\n" + "=" * 70 + "\nCAUSAL QUERY INTERFACE (for Step 4 integration)\n" + "=" * 70)

def find_root_causes(graph, target_feature):
    """Return every ancestor of ``target_feature`` in ``graph``.

    "Root causes" here means all nodes with a directed path into
    ``target_feature`` (direct and indirect causes), not only parentless
    nodes. The result is sorted (by string form) so repeated runs produce
    the same order; iterating the raw ancestor set would not.

    Args:
        graph: a NetworkX directed graph.
        target_feature: node whose ancestors are requested.

    Returns:
        Sorted list of ancestor nodes; ``[]`` when ``target_feature`` is not
        a node of ``graph``.
    """
    if target_feature not in graph:
        return []
    return sorted(nx.ancestors(graph, target_feature), key=str)

def find_causal_path(graph, source, target):
    """Return one shortest directed path from ``source`` to ``target``.

    Returns:
        The node list ``[source, ..., target]``, or ``None`` when either node
        is absent from ``graph`` or no directed path connects them.
    """
    if source in graph and target in graph:
        try:
            return nx.shortest_path(graph, source, target)
        except nx.NetworkXNoPath:
            pass
    return None

def get_direct_causes(graph, feature):
    """Return the parents of ``feature`` (nodes with an edge into it).

    An unknown ``feature`` yields an empty list rather than an error.
    """
    return list(graph.predecessors(feature)) if feature in graph else []

def get_direct_effects(graph, feature):
    """Return the children of ``feature`` (nodes it has an edge into).

    An unknown ``feature`` yields an empty list rather than an error.
    """
    return list(graph.successors(feature)) if feature in graph else []

# Example queries against the final graph G
print("\nExample Causal Queries:")

# Query 1: all ancestors of SCAS
if 'SCAS' in G:
    root_causes_scas = find_root_causes(G, 'SCAS')
    print(f"\n1. Root causes of SCAS: {root_causes_scas}")

# Query 2: direct causes (parents) of the outcome node
if TARGET in G:
    direct_causes_label = get_direct_causes(G, TARGET)
    print(f"\n2. Direct causes of {TARGET}: {direct_causes_label}")

# Query 3: one shortest directed path SignatureMatchesPerDay -> outcome
if 'SignatureMatchesPerDay' in G and TARGET in G:
    path = find_causal_path(G, 'SignatureMatchesPerDay', TARGET)
    if path:
        print(f"\n3. Causal path SignatureMatchesPerDay → {TARGET}:")
        print(f"   {' → '.join(path)}")
    else:
        print(f"\n3. No causal path SignatureMatchesPerDay → {TARGET} in the final graph")

print("\n" + "="*70)
print("STEP 2 COMPLETE!")
print("="*70)
print("\nNext steps:")
print("  1. Review causal graphs (causal_graphs_comparison.png)")
print("  2. Validate edges with domain experts")
print("  3. Proceed to Step 4: Hybrid Explanation Generation")
# The graph is a plain pickle (nx.read_gpickle no longer exists in NetworkX 3.x)
print("\nTo use in Step 4:")
print("  import pickle")
print("  with open('causal_graph.gpickle', 'rb') as f:")
print("      G = pickle.load(f)")