In [1]:
# ============================================================================
# V2X SIMULATION ANALYSIS & ATTACK DETECTION NOTEBOOK
# ============================================================================
# Author: AI Assistant
# Date: 2024
# Description: Machine learning pipeline for V2X attack detection using
#              simulation data from highway, mixed, and urban scenarios
# ============================================================================

# %% [markdown]
# # V2X Simulation Analysis & Attack Detection
# 
# ## Overview
# This notebook analyzes V2X simulation data to detect malicious vehicles using machine learning.
# It processes data from multiple scenarios (highway, mixed, urban) and densities (50, 100, 150 vehicles/km).
# 
# ## Key Features:
# 1. **Network Statistics**: PDR, neighbor analysis, message statistics
# 2. **Feature Engineering**: 5-second window features (message counts, RSSI stats, neighbor features)
# 3. **ML Pipeline**: Random Forest with cross-validation and hyperparameter tuning
# 4. **Trust Scoring**: Exponential smoothing model based on neighbor consistency
# 5. **Visualization**: ROC curves, trust convergence plots, attack detection curves
# 6. **Export**: LaTeX tables, model saving, comprehensive results
# 
# ## Setup Instructions:
# 1. Update `ROOT` variable below to point to your dataset
# 2. Run all cells sequentially
# 3. Check `/mnt/data/` for output files and figures

# %% [markdown]
# ## 1. SETUP & CONFIGURATION

# %%
import os
import sys
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from datetime import datetime, timedelta
from pathlib import Path
import json
import pickle
from typing import Dict, List, Tuple, Optional, Any
import gc

# Machine Learning
from sklearn.model_selection import train_test_split, cross_val_score, StratifiedKFold, GridSearchCV
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import (accuracy_score, precision_score, recall_score, f1_score,
                             roc_auc_score, roc_curve, confusion_matrix, classification_report,
                             ConfusionMatrixDisplay)
from sklearn.pipeline import Pipeline

# Statistical Analysis
from scipy import stats
from scipy.signal import welch
import scipy.spatial.distance as dist

# Visualization
import matplotlib
matplotlib.use('Agg')  # For headless environments
plt.style.use('seaborn-v0_8-darkgrid')

# Progress bars
from tqdm.auto import tqdm

# Suppress warnings
warnings.filterwarnings('ignore')

# Set random seed for reproducibility
RANDOM_SEED = 42
np.random.seed(RANDOM_SEED)

# %% [markdown]
# ### Configuration Variables

# %%
# ============================================================================
# CONFIGURATION
# ============================================================================

# Set this to your dataset root directory
ROOT = "/home/jeanhuit/Documents/Workspace/simulation/results/"
# Alternative: Use symlink from /mnt/data/results to your actual data

# Output directories
FIG_DIR = Path("/home/jeanhuit/Documents/Workspace/simulation/figures")
OUTPUT_DIR = Path("/home/jeanhuit/Documents/Workspace/simulation/output")
MODEL_PATH = Path("/home/jeanhuit/Documents/Workspace/simulation/model")

# Create output directories if they don't exist
os.makedirs(FIG_DIR, exist_ok=True)
os.makedirs(OUTPUT_DIR, exist_ok=True)

# Analysis parameters
WINDOW_SIZE = 5  # seconds
SLIDING_STEP = 1  # second for sliding window

# ML parameters
TEST_SIZE = 0.3
CV_FOLDS = 5
RF_N_ESTIMATORS = 100

# Trust score parameters
TRUST_ALPHA = 0.3  # Exponential smoothing factor
CONSISTENCY_THRESHOLD = 0.7

# Scenarios and densities to process
SCENARIOS = ["highway", "mixed", "urban"]
DENSITIES = [50, 100, 150]
RUNS = ["run-1"]  # Can be extended to multiple runs

print(f"Configuration loaded:")
print(f"  Root directory: {ROOT}")
print(f"  Output directory: {OUTPUT_DIR}")
print(f"  Figure directory: {FIG_DIR}")
print(f"  Window size: {WINDOW_SIZE}s")
print(f"  Scenarios: {SCENARIOS}")
print(f"  Densities: {DENSITIES}")

# %% [markdown]
# ### Utility Functions

# %%
# ============================================================================
# UTILITY FUNCTIONS
# ============================================================================

def check_file_exists(filepath: str) -> bool:
    """Check if a file exists and print status."""
    exists = os.path.exists(filepath)
    if not exists:
        print(f"  Warning: File not found - {filepath}")
    return exists

def load_csv_safe(filepath: str, **kwargs) -> pd.DataFrame:
    """Load CSV file with error handling."""
    try:
        if os.path.exists(filepath):
            df = pd.read_csv(filepath, **kwargs)
            print(f"  Loaded: {os.path.basename(filepath)} ({len(df)} rows)")
            return df
        else:
            print(f"  Warning: File not found - {filepath}")
            return pd.DataFrame()
    except Exception as e:
        print(f"  Error loading {filepath}: {e}")
        return pd.DataFrame()

def save_figure(fig, filename: str, dpi: int = 300):
    """Save figure to FIG_DIR."""
    filepath = os.path.join(FIG_DIR, filename)
    fig.savefig(filepath, dpi=dpi, bbox_inches='tight')
    plt.close(fig)
    print(f"  Figure saved: {filepath}")

def save_dataframe(df: pd.DataFrame, filename: str):
    """Save dataframe to OUTPUT_DIR."""
    filepath = os.path.join(OUTPUT_DIR, filename)
    df.to_csv(filepath, index=False)
    print(f"  Data saved: {filepath} ({len(df)} rows)")

def export_latex_table(df: pd.DataFrame, filename: str, caption: str = "", label: str = ""):
    """Export dataframe as LaTeX table."""
    filepath = os.path.join(OUTPUT_DIR, filename)
    latex_str = df.to_latex(index=False, caption=caption, label=label)
    with open(filepath, 'w') as f:
        f.write(latex_str)
    print(f"  LaTeX table saved: {filepath}")

def compute_pdr(sent: int, received: int) -> float:
    """Compute Packet Delivery Ratio."""
    if sent == 0:
        return 0.0
    return received / sent

# %% [markdown]
# ## 2. DATA LOADING & NETWORK STATISTICS

# %%
# ============================================================================
# DATA LOADER CLASS
# ============================================================================

class V2XDataLoader:
    """Load and manage V2X simulation data from multiple scenarios."""
    
    def __init__(self, root_dir: str):
        self.root_dir = Path(root_dir)
        self.data = {}
        self.stats = {}
        
    def load_scenario(self, scenario: str, density: int, run: str = "run-1"):
        """Load all log files for a specific scenario, density, and run."""
        scenario_path = self.root_dir / scenario / f"density-{density}" / run
        
        if not scenario_path.exists():
            print(f"Warning: Path not found - {scenario_path}")
            return None
        
        print(f"\nLoading {scenario} (density={density}, run={run})...")
        
        # Load all CSV files
        files = {
            'bsm': load_csv_safe(scenario_path / "bsm_log.csv"),
            'rssi': load_csv_safe(scenario_path / "rssi_log.csv"),
            'neighbor': load_csv_safe(scenario_path / "neighbor_log.csv"),
            'sybil': load_csv_safe(scenario_path / "sybil_log.csv"),
            'replay': load_csv_safe(scenario_path / "replay_log.csv"),
            'jammer': load_csv_safe(scenario_path / "jammer_log.csv")
        }
        
        # Store data
        key = f"{scenario}_{density}_{run}"
        self.data[key] = files
        
        # Compute basic statistics
        self._compute_basic_stats(key, files)
        
        return files
    
    def _compute_basic_stats(self, key: str, files: Dict):
        """Compute basic network statistics."""
        stats = {}
        
        # BSM statistics
        bsm_df = files.get('bsm', pd.DataFrame())
        if not bsm_df.empty:
            stats['bsm_sent'] = len(bsm_df)
            stats['bsm_unique_senders'] = bsm_df['sender_id'].nunique() if 'sender_id' in bsm_df.columns else 0
            
            # If there's a receiver_id column
            if 'receiver_id' in bsm_df.columns:
                stats['bsm_received'] = bsm_df['receiver_id'].notna().sum()
                stats['pdr'] = compute_pdr(stats['bsm_sent'], stats['bsm_received'])
            else:
                stats['bsm_received'] = 0
                stats['pdr'] = 0
        
        # RSSI statistics
        rssi_df = files.get('rssi', pd.DataFrame())
        if not rssi_df.empty and 'rssi' in rssi_df.columns:
            stats['rssi_mean'] = rssi_df['rssi'].mean()
            stats['rssi_std'] = rssi_df['rssi'].std()
            stats['rssi_min'] = rssi_df['rssi'].min()
            stats['rssi_max'] = rssi_df['rssi'].max()
        
        # Neighbor statistics
        neighbor_df = files.get('neighbor', pd.DataFrame())
        if not neighbor_df.empty:
            if 'neighbor_count' in neighbor_df.columns:
                stats['avg_neighbors'] = neighbor_df['neighbor_count'].mean()
                stats['max_neighbors'] = neighbor_df['neighbor_count'].max()
            stats['unique_vehicles'] = self._count_unique_vehicles(neighbor_df)
        
        # Attack statistics
        attack_stats = self._compute_attack_stats(files)
        stats.update(attack_stats)
        
        self.stats[key] = stats
        
    def _count_unique_vehicles(self, neighbor_df: pd.DataFrame) -> int:
        """Count unique vehicles from neighbor logs."""
        vehicle_ids = set()
        
        # Check different possible column names
        for col in ['vehicle_id', 'node_id', 'sender_id']:
            if col in neighbor_df.columns:
                vehicle_ids.update(neighbor_df[col].unique())
        
        # Also check neighbor lists if present
        if 'neighbors' in neighbor_df.columns:
            for neighbors in neighbor_df['neighbors'].dropna():
                if isinstance(neighbors, str):
                    try:
                        neighbor_list = eval(neighbors) if '[' in neighbors else neighbors.split(',')
                        vehicle_ids.update([str(n).strip() for n in neighbor_list])
                    except:
                        pass
        
        return len(vehicle_ids)
    
    def _compute_attack_stats(self, files: Dict) -> Dict:
        """Compute attack-related statistics."""
        stats = {}
        
        attack_dfs = {
            'sybil': files.get('sybil'),
            'replay': files.get('replay'),
            'jammer': files.get('jammer')
        }
        
        total_attack_rows = 0
        unique_attackers = set()
        
        for attack_type, df in attack_dfs.items():
            if df is not None and not df.empty:
                rows = len(df)
                stats[f'{attack_type}_events'] = rows
                total_attack_rows += rows
                
                # Count unique attackers
                for col in ['attacker_id', 'malicious_id', 'vehicle_id']:
                    if col in df.columns:
                        unique_attackers.update(df[col].unique())
        
        stats['total_attack_events'] = total_attack_rows
        stats['unique_attackers'] = len(unique_attackers)
        
        return stats
    
    def get_all_scenarios(self, scenarios: List[str], densities: List[int], runs: List[str]):
        """Load all specified scenarios, densities, and runs."""
        all_data = {}
        
        for scenario in scenarios:
            for density in densities:
                for run in runs:
                    key = f"{scenario}_{density}_{run}"
                    if key not in self.data:
                        self.load_scenario(scenario, density, run)
                    all_data[key] = self.data.get(key, {})
        
        return all_data
    
    def get_consolidated_stats(self) -> pd.DataFrame:
        """Create a consolidated DataFrame of all statistics."""
        rows = []
        
        for key, stats in self.stats.items():
            # Parse key
            parts = key.split('_')
            if len(parts) >= 3:
                scenario = parts[0]
                density = int(parts[1])
                run = '_'.join(parts[2:])
                
                row = {
                    'scenario': scenario,
                    'density': density,
                    'run': run
                }
                row.update(stats)
                rows.append(row)
        
        return pd.DataFrame(rows)

# %%
# Initialize data loader
print("Initializing data loader...")
data_loader = V2XDataLoader(ROOT)

# Load all data
print("\n" + "="*60)
print("LOADING ALL SCENARIOS AND DENSITIES")
print("="*60)

all_data = data_loader.get_all_scenarios(SCENARIOS, DENSITIES, RUNS)

# Display loaded data summary
print(f"\nLoaded {len(all_data)} scenario-density-run combinations:")
for key in all_data.keys():
    print(f"  - {key}")

# %%
# Compute and display consolidated statistics
print("\n" + "="*60)
print("NETWORK STATISTICS SUMMARY")
print("="*60)

consolidated_stats = data_loader.get_consolidated_stats()

if not consolidated_stats.empty:
    # Display summary
    print("\nConsolidated Statistics:")
    print(consolidated_stats.to_string(index=False))
    
    # Save statistics
    save_dataframe(consolidated_stats, "consolidated_network_stats.csv")
    
    # Export as LaTeX table
    export_latex_table(
        consolidated_stats,
        "network_stats_latex.tex",
        caption="Network Statistics for All Scenarios and Densities",
        label="tab:network_stats"
    )
    
    # Create summary by scenario
    scenario_summary = consolidated_stats.groupby('scenario').agg({
        'bsm_sent': 'mean',
        'bsm_received': 'mean',
        'pdr': 'mean',
        'avg_neighbors': 'mean',
        'total_attack_events': 'mean',
        'unique_attackers': 'mean'
    }).round(2)
    
    print("\nScenario-wise Summary:")
    print(scenario_summary)
    
    # Create summary by density
    density_summary = consolidated_stats.groupby('density').agg({
        'bsm_sent': 'mean',
        'bsm_received': 'mean',
        'pdr': 'mean',
        'avg_neighbors': 'mean',
        'total_attack_events': 'mean'
    }).round(2)
    
    print("\nDensity-wise Summary:")
    print(density_summary)
    
else:
    print("No statistics available. Check data loading.")

# %% [markdown]
# ## 3. FEATURE ENGINEERING

# %%
# ============================================================================
# FEATURE ENGINEERING CLASS
# ============================================================================

class FeatureEngineer:
    """Extract features from V2X data in sliding time windows."""
    
    def __init__(self, window_size: int = 5, step: int = 1):
        self.window_size = window_size
        self.step = step
        
    def create_time_windows(self, df: pd.DataFrame, time_col: str = 'timestamp') -> List[Tuple]:
        """Create sliding time windows."""
        if df.empty:
            return []
        
        if time_col not in df.columns:
            print(f"Warning: Time column '{time_col}' not found in dataframe")
            return []
        
        # Ensure timestamp is numeric
        df = df.copy()
        df[time_col] = pd.to_numeric(df[time_col], errors='coerce')
        df = df.dropna(subset=[time_col])
        
        if df.empty:
            return []
        
        min_time = df[time_col].min()
        max_time = df[time_col].max()
        
        windows = []
        start = min_time
        while start <= max_time - self.window_size:
            end = start + self.window_size
            window_df = df[(df[time_col] >= start) & (df[time_col] < end)]
            if not window_df.empty:
                windows.append((start, end, window_df))
            start += self.step
        
        return windows
    
    def extract_bsm_features(self, bsm_df: pd.DataFrame, vehicle_id: str = None) -> Dict:
        """Extract BSM-based features for a vehicle or overall."""
        features = {}
        
        if bsm_df.empty:
            return self._get_empty_bsm_features()
        
        # Filter for specific vehicle if provided
        if vehicle_id is not None and 'sender_id' in bsm_df.columns:
            vehicle_bsm = bsm_df[bsm_df['sender_id'] == vehicle_id]
        else:
            vehicle_bsm = bsm_df
        
        # Message count features
        features['msg_count'] = len(vehicle_bsm)
        
        # Unique sender/receiver features
        if 'sender_id' in vehicle_bsm.columns:
            features['unique_senders'] = vehicle_bsm['sender_id'].nunique()
        
        if 'receiver_id' in vehicle_bsm.columns:
            features['unique_receivers'] = vehicle_bsm['receiver_id'].nunique()
        
        # Speed features (if available)
        speed_cols = [col for col in vehicle_bsm.columns if 'speed' in col.lower()]
        if speed_cols:
            speed_data = vehicle_bsm[speed_cols[0]]
            features.update({
                'speed_mean': speed_data.mean(),
                'speed_std': speed_data.std(),
                'speed_min': speed_data.min(),
                'speed_max': speed_data.max(),
                'speed_range': speed_data.max() - speed_data.min() if len(speed_data) > 1 else 0
            })
        
        # Temporal features
        if 'timestamp' in vehicle_bsm.columns:
            timestamps = vehicle_bsm['timestamp']
            if len(timestamps) > 1:
                intervals = np.diff(sorted(timestamps))
                features.update({
                    'msg_interval_mean': intervals.mean(),
                    'msg_interval_std': intervals.std(),
                    'msg_interval_min': intervals.min(),
                    'msg_interval_max': intervals.max()
                })
                
                # Spectral features (FFT of message intervals)
                if len(intervals) >= 8:  # Minimum for meaningful FFT
                    try:
                        fft_values = np.abs(np.fft.fft(intervals - intervals.mean()))
                        features['spectral_energy'] = np.sum(fft_values ** 2)
                        features['spectral_entropy'] = self._compute_spectral_entropy(fft_values)
                    except:
                        features['spectral_energy'] = 0
                        features['spectral_entropy'] = 0
        
        return features
    
    def extract_rssi_features(self, rssi_df: pd.DataFrame, vehicle_id: str = None) -> Dict:
        """Extract RSSI-based features."""
        features = {}
        
        if rssi_df.empty:
            return self._get_empty_rssi_features()
        
        # Filter for specific vehicle if provided
        if vehicle_id is not None:
            id_cols = [col for col in rssi_df.columns if 'id' in col or 'vehicle' in col]
            if id_cols:
                vehicle_rssi = rssi_df[rssi_df[id_cols[0]] == vehicle_id]
            else:
                vehicle_rssi = rssi_df
        else:
            vehicle_rssi = rssi_df
        
        if 'rssi' not in vehicle_rssi.columns or vehicle_rssi.empty:
            return self._get_empty_rssi_features()
        
        rssi_values = vehicle_rssi['rssi'].dropna()
        
        if len(rssi_values) == 0:
            return self._get_empty_rssi_features()
        
        # Basic statistics
        features.update({
            'rssi_mean': rssi_values.mean(),
            'rssi_std': rssi_values.std(),
            'rssi_min': rssi_values.min(),
            'rssi_max': rssi_values.max(),
            'rssi_range': rssi_values.max() - rssi_values.min(),
            'rssi_variance': rssi_values.var()
        })
        
        # Advanced statistics
        features.update({
            'rssi_skewness': stats.skew(rssi_values) if len(rssi_values) > 2 else 0,
            'rssi_kurtosis': stats.kurtosis(rssi_values) if len(rssi_values) > 3 else 0,
            'rssi_q1': np.percentile(rssi_values, 25),
            'rssi_q3': np.percentile(rssi_values, 75),
            'rss_iqr': np.percentile(rssi_values, 75) - np.percentile(rssi_values, 25)
        })
        
        # Trend analysis (if enough data points)
        if len(rssi_values) >= 3 and 'timestamp' in vehicle_rssi.columns:
            time_sorted = vehicle_rssi.sort_values('timestamp')
            rssi_sorted = time_sorted['rssi'].values
            
            # Linear trend
            try:
                x = np.arange(len(rssi_sorted))
                slope, _, _, _, _ = stats.linregress(x, rssi_sorted)
                features['rssi_trend_slope'] = slope
                
                # Moving statistics
                if len(rssi_sorted) >= 5:
                    window = min(5, len(rssi_sorted))
                    moving_avg = np.convolve(rssi_sorted, np.ones(window)/window, mode='valid')
                    features['rssi_moving_avg_change'] = moving_avg[-1] - moving_avg[0] if len(moving_avg) > 1 else 0
            except:
                features['rssi_trend_slope'] = 0
                features['rssi_moving_avg_change'] = 0
        
        return features
    
    def extract_neighbor_features(self, neighbor_df: pd.DataFrame, vehicle_id: str = None) -> Dict:
        """Extract neighbor-based features."""
        features = {}
        
        if neighbor_df.empty:
            return self._get_empty_neighbor_features()
        
        # Filter for specific vehicle if provided
        if vehicle_id is not None:
            id_cols = [col for col in neighbor_df.columns if 'id' in col or 'vehicle' in col]
            if id_cols:
                vehicle_neighbors = neighbor_df[neighbor_df[id_cols[0]] == vehicle_id]
            else:
                vehicle_neighbors = neighbor_df
        else:
            vehicle_neighbors = neighbor_df
        
        if vehicle_neighbors.empty:
            return self._get_empty_neighbor_features()
        
        # Count features
        if 'neighbor_count' in vehicle_neighbors.columns:
            counts = vehicle_neighbors['neighbor_count'].dropna()
            if len(counts) > 0:
                features.update({
                    'neighbor_count_mean': counts.mean(),
                    'neighbor_count_std': counts.std(),
                    'neighbor_count_min': counts.min(),
                    'neighbor_count_max': counts.max(),
                    'neighbor_count_change': counts.iloc[-1] - counts.iloc[0] if len(counts) > 1 else 0
                })
        
        # Neighbor consistency (if neighbor lists available)
        if 'neighbors' in vehicle_neighbors.columns:
            neighbor_lists = []
            for neighbors in vehicle_neighbors['neighbors'].dropna():
                if isinstance(neighbors, str):
                    try:
                        if '[' in neighbors:
                            neighbor_list = eval(neighbors)
                        else:
                            neighbor_list = neighbors.split(',')
                        neighbor_lists.append(set(str(n).strip() for n in neighbor_list))
                    except:
                        continue
            
            if len(neighbor_lists) >= 2:
                # Jaccard similarity between consecutive neighbor lists
                jaccard_scores = []
                for i in range(len(neighbor_lists) - 1):
                    if neighbor_lists[i] and neighbor_lists[i+1]:
                        intersection = neighbor_lists[i].intersection(neighbor_lists[i+1])
                        union = neighbor_lists[i].union(neighbor_lists[i+1])
                        jaccard = len(intersection) / len(union) if union else 0
                        jaccard_scores.append(jaccard)
                
                if jaccard_scores:
                    features.update({
                        'neighbor_jaccard_mean': np.mean(jaccard_scores),
                        'neighbor_jaccard_std': np.std(jaccard_scores),
                        'neighbor_jaccard_min': np.min(jaccard_scores),
                        'neighbor_jaccard_max': np.max(jaccard_scores)
                    })
        
        # Neighbor duration (if timestamps available)
        if 'timestamp' in vehicle_neighbors.columns and 'neighbor_count' in vehicle_neighbors.columns:
            timestamps = vehicle_neighbors['timestamp'].values
            counts = vehicle_neighbors['neighbor_count'].values
            
            if len(timestamps) > 1:
                # Rate of neighbor change
                time_diff = np.diff(timestamps)
                count_diff = np.diff(counts)
                valid_mask = time_diff > 0
                
                if np.any(valid_mask):
                    change_rate = count_diff[valid_mask] / time_diff[valid_mask]
                    features['neighbor_change_rate_mean'] = np.mean(np.abs(change_rate))
                    features['neighbor_change_rate_std'] = np.std(change_rate)
        
        return features
    
    def extract_combined_features(self, data_dict: Dict, window_start: float, 
                                 window_end: float, vehicle_id: str = None) -> Dict:
        """Extract all features for a time window."""
        features = {
            'window_start': window_start,
            'window_end': window_end,
            'vehicle_id': vehicle_id if vehicle_id else 'global'
        }
        
        # Filter data for this window
        window_data = {}
        for data_type, df in data_dict.items():
            if df is not None and not df.empty and 'timestamp' in df.columns:
                window_df = df[(df['timestamp'] >= window_start) & (df['timestamp'] < window_end)]
                window_data[data_type] = window_df
            else:
                window_data[data_type] = pd.DataFrame()
        
        # Extract BSM features
        bsm_features = self.extract_bsm_features(window_data.get('bsm', pd.DataFrame()), vehicle_id)
        features.update({f'bsm_{k}': v for k, v in bsm_features.items()})
        
        # Extract RSSI features
        rssi_features = self.extract_rssi_features(window_data.get('rssi', pd.DataFrame()), vehicle_id)
        features.update({f'rssi_{k}': v for k, v in rssi_features.items()})
        
        # Extract neighbor features
        neighbor_features = self.extract_neighbor_features(window_data.get('neighbor', pd.DataFrame()), vehicle_id)
        features.update({f'neighbor_{k}': v for k, v in neighbor_features.items()})
        
        # Combined features
        if vehicle_id:
            # Vehicle-specific consistency score
            rssi_consistency = 1.0 / (1.0 + features.get('rssi_std', 1.0))
            neighbor_consistency = features.get('neighbor_jaccard_mean', 0.5)
            features['consistency_score'] = (rssi_consistency + neighbor_consistency) / 2
        
        return features
    
    def extract_features_all_windows(self, data_dict: Dict, vehicle_ids: List[str] = None) -> pd.DataFrame:
        """Extract features for all vehicles across all time windows."""
        all_features = []
        
        # Get global time range from BSM data
        bsm_df = data_dict.get('bsm', pd.DataFrame())
        if bsm_df.empty:
            print("Warning: No BSM data available for feature extraction")
            return pd.DataFrame()
        
        # Create time windows
        windows = self.create_time_windows(bsm_df)
        print(f"Created {len(windows)} time windows of {self.window_size}s each")
        
        # Extract features for each window
        with tqdm(total=len(windows), desc="Extracting window features") as pbar:
            for window_start, window_end, window_df in windows:
                # Global features (aggregated across all vehicles)
                global_features = self.extract_combined_features(
                    data_dict, window_start, window_end, None
                )
                all_features.append(global_features)
                
                # Per-vehicle features
                if vehicle_ids:
                    for vehicle_id in vehicle_ids:
                        vehicle_features = self.extract_combined_features(
                            data_dict, window_start, window_end, vehicle_id
                        )
                        all_features.append(vehicle_features)
                
                pbar.update(1)
        
        # Convert to DataFrame
        features_df = pd.DataFrame(all_features)
        
        # Fill NaN values
        features_df = features_df.fillna(0)
        
        return features_df
    
    def _get_empty_bsm_features(self) -> Dict:
        """Return empty BSM features dictionary."""
        return {
            'msg_count': 0,
            'unique_senders': 0,
            'unique_receivers': 0,
            'speed_mean': 0,
            'speed_std': 0,
            'speed_min': 0,
            'speed_max': 0,
            'speed_range': 0,
            'msg_interval_mean': 0,
            'msg_interval_std': 0,
            'msg_interval_min': 0,
            'msg_interval_max': 0,
            'spectral_energy': 0,
            'spectral_entropy': 0
        }
    
    def _get_empty_rssi_features(self) -> Dict:
        """Return empty RSSI features dictionary."""
        return {
            'rssi_mean': 0,
            'rssi_std': 0,
            'rssi_min': 0,
            'rssi_max': 0,
            'rssi_range': 0,
            'rssi_variance': 0,
            'rssi_skewness': 0,
            'rssi_kurtosis': 0,
            'rssi_q1': 0,
            'rssi_q3': 0,
            'rss_iqr': 0,
            'rssi_trend_slope': 0,
            'rssi_moving_avg_change': 0
        }
    
    def _get_empty_neighbor_features(self) -> Dict:
        """Return empty neighbor features dictionary."""
        return {
            'neighbor_count_mean': 0,
            'neighbor_count_std': 0,
            'neighbor_count_min': 0,
            'neighbor_count_max': 0,
            'neighbor_count_change': 0,
            'neighbor_jaccard_mean': 0,
            'neighbor_jaccard_std': 0,
            'neighbor_jaccard_min': 0,
            'neighbor_jaccard_max': 0,
            'neighbor_change_rate_mean': 0,
            'neighbor_change_rate_std': 0
        }
    
    def _compute_spectral_entropy(self, fft_values: np.ndarray) -> float:
        """Compute spectral entropy from FFT values."""
        if len(fft_values) == 0 or np.sum(fft_values) == 0:
            return 0
        
        # Normalize to probability distribution
        prob = fft_values / np.sum(fft_values)
        
        # Remove zeros for log calculation
        prob = prob[prob > 0]
        
        # Compute entropy
        entropy = -np.sum(prob * np.log2(prob))
        
        # Normalize by maximum entropy
        max_entropy = np.log2(len(prob))
        
        return entropy / max_entropy if max_entropy > 0 else 0

# %%
# Initialize feature engineer
print("\nInitializing feature engineer...")
feature_engineer = FeatureEngineer(window_size=WINDOW_SIZE, step=SLIDING_STEP)

# Extract features for each scenario
print("\n" + "="*60)
print("FEATURE EXTRACTION")
print("="*60)

all_features = {}
vehicle_id_cache = {}

for key, data in all_data.items():
    print(f"\nProcessing {key}...")
    
    # Get vehicle IDs from BSM data
    bsm_df = data.get('bsm', pd.DataFrame())
    if not bsm_df.empty and 'sender_id' in bsm_df.columns:
        vehicle_ids = bsm_df['sender_id'].unique()[:20]  # Limit to first 20 for performance
        vehicle_id_cache[key] = vehicle_ids
    else:
        vehicle_ids = None
    
    # Extract features
    features_df = feature_engineer.extract_features_all_windows(data, vehicle_ids)
    
    if not features_df.empty:
        all_features[key] = features_df
        print(f"  Extracted {len(features_df)} feature rows")
        
        # Save features for this scenario
        save_dataframe(features_df, f"features_{key}.csv")
    else:
        print(f"  No features extracted for {key}")

# Combine all features
if all_features:
    combined_features = pd.concat(all_features.values(), ignore_index=True)
    print(f"\nTotal feature rows: {len(combined_features)}")
    print(f"Total feature columns: {len(combined_features.columns)}")
    
    # Display feature summary
    print("\nFeature columns:")
    for i, col in enumerate(combined_features.columns):
        print(f"  {i+1:3d}. {col}")
    
    # Save combined features
    save_dataframe(combined_features, "all_features_combined.csv")
    
else:
    print("Warning: No features were extracted")
    combined_features = pd.DataFrame()

# %% [markdown]
# ## 4. LABEL GENERATION

# %%
# ============================================================================
# LABEL GENERATOR CLASS
# ============================================================================

class LabelGenerator:
    """Generate labels for ML training from attack logs."""
    
    def __init__(self, window_size: int = 5):
        self.window_size = window_size
    
    def generate_labels(self, data_dict: Dict, windows: List[Tuple]) -> pd.DataFrame:
        """Generate labels for time windows based on attack logs."""
        labels = []
        
        # Get attack dataframes
        attack_dfs = {
            'sybil': data_dict.get('sybil', pd.DataFrame()),
            'replay': data_dict.get('replay', pd.DataFrame()),
            'jammer': data_dict.get('jammer', pd.DataFrame())
        }
        
        # Process each window
        for window_start, window_end, window_df in windows:
            label_row = {
                'window_start': window_start,
                'window_end': window_end
            }
            
            # Check each attack type
            attacks_present = []
            
            for attack_type, attack_df in attack_dfs.items():
                if attack_df.empty:
                    continue
                
                # Check if attack occurred in this window
                if 'timestamp' in attack_df.columns:
                    window_attacks = attack_df[
                        (attack_df['timestamp'] >= window_start) & 
                        (attack_df['timestamp'] < window_end)
                    ]
                    
                    if not window_attacks.empty:
                        attacks_present.append(attack_type)
                        
                        # Get attacker IDs if available
                        for id_col in ['attacker_id', 'malicious_id', 'vehicle_id']:
                            if id_col in window_attacks.columns:
                                attackers = window_attacks[id_col].unique()
                                label_row[f'{attack_type}_attackers'] = ','.join(map(str, attackers))
                                break
            
            # Create binary label (malicious vs benign)
            label_row['is_malicious'] = 1 if attacks_present else 0
            
            # Create multi-class label
            if attacks_present:
                if len(attacks_present) == 1:
                    label_row['attack_type'] = attacks_present[0]
                else:
                    label_row['attack_type'] = 'mixed'
            else:
                label_row['attack_type'] = 'benign'
            
            # Count total attacks in window
            label_row['attack_count'] = len(attacks_present)
            
            labels.append(label_row)
        
        return pd.DataFrame(labels)
    
    def generate_vehicle_labels(self, data_dict: Dict, windows: List[Tuple], 
                               vehicle_ids: List[str]) -> pd.DataFrame:
        """Generate per-vehicle labels for time windows."""
        all_labels = []
        
        # Get attack dataframes
        attack_dfs = {
            'sybil': data_dict.get('sybil', pd.DataFrame()),
            'replay': data_dict.get('replay', pd.DataFrame()),
            'jammer': data_dict.get('jammer', pd.DataFrame())
        }
        
        # Process each vehicle in each window
        for vehicle_id in tqdm(vehicle_ids, desc="Generating vehicle labels"):
            for window_start, window_end, window_df in windows:
                label_row = {
                    'window_start': window_start,
                    'window_end': window_end,
                    'vehicle_id': vehicle_id
                }
                
                # Check if this vehicle was attacking in this window
                is_attacker = False
                attack_types = []
                
                for attack_type, attack_df in attack_dfs.items():
                    if attack_df.empty:
                        continue
                    
                    # Check if this vehicle appears as attacker
                    for id_col in ['attacker_id', 'malicious_id', 'vehicle_id']:
                        if id_col in attack_df.columns:
                            # Filter for this vehicle and time window
                            vehicle_attacks = attack_df[
                                (attack_df[id_col] == vehicle_id) &
                                (attack_df['timestamp'] >= window_start) & 
                                (attack_df['timestamp'] < window_end)
                            ]
                            
                            if not vehicle_attacks.empty:
                                is_attacker = True
                                attack_types.append(attack_type)
                                break
                
                # Create labels
                label_row['is_malicious'] = 1 if is_attacker else 0
                label_row['attack_type'] = ','.join(attack_types) if attack_types else 'benign'
                label_row['attack_count'] = len(attack_types)
                
                all_labels.append(label_row)
        
        return pd.DataFrame(all_labels)
    
    def merge_features_labels(self, features_df: pd.DataFrame, labels_df: pd.DataFrame) -> pd.DataFrame:
        """Merge features and labels based on window times and vehicle IDs."""
        if features_df.empty or labels_df.empty:
            return pd.DataFrame()
        
        # Determine merge columns
        merge_cols = ['window_start', 'window_end']
        if 'vehicle_id' in features_df.columns and 'vehicle_id' in labels_df.columns:
            merge_cols.append('vehicle_id')
        
        # Merge
        merged_df = pd.merge(
            features_df, 
            labels_df, 
            on=merge_cols,
            how='left'  # Keep all features, fill missing labels
        )
        
        # Fill missing labels (assume benign)
        if 'is_malicious' in merged_df.columns:
            merged_df['is_malicious'] = merged_df['is_malicious'].fillna(0)
        
        if 'attack_type' in merged_df.columns:
            merged_df['attack_type'] = merged_df['attack_type'].fillna('benign')
        
        if 'attack_count' in merged_df.columns:
            merged_df['attack_count'] = merged_df['attack_count'].fillna(0)
        
        return merged_df

# %%
# Generate labels
print("\n" + "="*60)
print("LABEL GENERATION")
print("="*60)

label_generator = LabelGenerator(window_size=WINDOW_SIZE)
all_labels = {}

for key, data in all_data.items():
    print(f"\nGenerating labels for {key}...")
    
    # Get BSM data for time windows
    bsm_df = data.get('bsm', pd.DataFrame())
    if bsm_df.empty:
        print(f"  No BSM data for {key}")
        continue
    
    # Create time windows
    windows = feature_engineer.create_time_windows(bsm_df)
    if not windows:
        print(f"  No time windows for {key}")
        continue
    
    # Generate global labels
    global_labels = label_generator.generate_labels(data, windows)
    
    # Generate vehicle labels if vehicle IDs available
    vehicle_ids = vehicle_id_cache.get(key)
    if vehicle_ids is not None:
        vehicle_labels = label_generator.generate_vehicle_labels(data, windows, vehicle_ids)
        
        # Combine labels
        labels_df = pd.concat([global_labels, vehicle_labels], ignore_index=True)
    else:
        labels_df = global_labels
    
    if not labels_df.empty:
        all_labels[key] = labels_df
        print(f"  Generated {len(labels_df)} label rows")
        print(f"  Malicious windows: {labels_df['is_malicious'].sum()} ({labels_df['is_malicious'].mean():.1%})")
        
        # Save labels
        save_dataframe(labels_df, f"labels_{key}.csv")
    else:
        print(f"  No labels generated for {key}")

# Combine all labels
if all_labels:
    combined_labels = pd.concat(all_labels.values(), ignore_index=True)
    print(f"\nTotal label rows: {len(combined_labels)}")
    
    # Display label distribution
    print("\nLabel Distribution:")
    print(f"  Benign windows: {len(combined_labels[combined_labels['is_malicious'] == 0])}")
    print(f"  Malicious windows: {len(combined_labels[combined_labels['is_malicious'] == 1])}")
    
    if 'attack_type' in combined_labels.columns:
        print("\nAttack Type Distribution:")
        print(combined_labels['attack_type'].value_counts())
    
    # Save combined labels
    save_dataframe(combined_labels, "all_labels_combined.csv")
    
else:
    print("Warning: No labels were generated")
    combined_labels = pd.DataFrame()

# %% [markdown]
# ## 5. EXPLORATORY DATA ANALYSIS

# %%
# ============================================================================
# EXPLORATORY DATA ANALYSIS
# ============================================================================

print("\n" + "="*60)
print("EXPLORATORY DATA ANALYSIS")
print("="*60)

# Merge features and labels for analysis
if not combined_features.empty and not combined_labels.empty:
    print("\nMerging features and labels...")
    merged_data = label_generator.merge_features_labels(combined_features, combined_labels)
    
    print(f"Merged dataset shape: {merged_data.shape}")
    print(f"Columns: {len(merged_data.columns)}")
    
    # Save merged data
    save_dataframe(merged_data, "merged_features_labels.csv")
    
    # Basic statistics
    print("\nDataset Statistics:")
    print(f"Total samples: {len(merged_data)}")
    
    if 'is_malicious' in merged_data.columns:
        malicious_count = merged_data['is_malicious'].sum()
        malicious_percent = malicious_count / len(merged_data) * 100
        print(f"Malicious samples: {malicious_count} ({malicious_percent:.2f}%)")
    
    # Feature statistics
    print("\nFeature Statistics (first 20 numeric features):")
    numeric_cols = merged_data.select_dtypes(include=[np.number]).columns
    for col in numeric_cols[:20]:
        if col not in ['window_start', 'window_end', 'is_malicious', 'attack_count']:
            print(f"  {col:30s}: mean={merged_data[col].mean():8.4f}, "
                  f"std={merged_data[col].std():8.4f}, "
                  f"range=[{merged_data[col].min():8.4f}, {merged_data[col].max():8.4f}]")
    
    # Correlation analysis
    print("\nComputing feature correlations with labels...")
    if 'is_malicious' in merged_data.columns:
        # Select numeric features only
        feature_cols = [col for col in numeric_cols 
                       if col not in ['window_start', 'window_end', 'is_malicious', 'attack_count']
                       and not col.startswith('attack')]
        
        if feature_cols:
            correlations = []
            for col in feature_cols:
                corr = merged_data[col].corr(merged_data['is_malicious'])
                correlations.append((col, corr))
            
            # Sort by absolute correlation
            correlations.sort(key=lambda x: abs(x[1]), reverse=True)
            
            print("\nTop 10 features correlated with malicious label:")
            for col, corr in correlations[:10]:
                print(f"  {col:30s}: {corr:8.4f}")
    
    # Create correlation heatmap for top features
    if 'is_malicious' in merged_data.columns and len(feature_cols) >= 5:
        print("\nCreating correlation heatmap...")
        top_features = [col for col, _ in correlations[:15]] + ['is_malicious']
        corr_matrix = merged_data[top_features].corr()
        
        fig, ax = plt.subplots(figsize=(12, 10))
        sns.heatmap(corr_matrix, annot=True, fmt='.2f', cmap='coolwarm', 
                   center=0, square=True, ax=ax)
        ax.set_title('Feature Correlation Matrix (Top 15 Features)', fontsize=14)
        save_figure(fig, "feature_correlation_heatmap.png")
        
        # Plot feature distributions by class
        print("\nCreating feature distribution plots...")
        if len(feature_cols) >= 4:
            fig, axes = plt.subplots(2, 2, figsize=(15, 10))
            axes = axes.flatten()
            
            for i, col in enumerate(feature_cols[:4]):
                # Separate by class
                benign_data = merged_data[merged_data['is_malicious'] == 0][col].dropna()
                malicious_data = merged_data[merged_data['is_malicious'] == 1][col].dropna()
                
                # Plot histograms
                axes[i].hist(benign_data, alpha=0.5, label='Benign', bins=30, density=True)
                axes[i].hist(malicious_data, alpha=0.5, label='Malicious', bins=30, density=True)
                axes[i].set_xlabel(col)
                axes[i].set_ylabel('Density')
                axes[i].set_title(f'Distribution of {col}')
                axes[i].legend()
                axes[i].grid(True, alpha=0.3)
            
            plt.tight_layout()
            save_figure(fig, "feature_distributions_by_class.png")
    
else:
    print("Warning: Cannot perform EDA - features or labels are empty")
    merged_data = pd.DataFrame()

# %% [markdown]
# ## 6. ML PIPELINE & TRAINING

# %%
# ============================================================================
# MACHINE LEARNING PIPELINE
# ============================================================================

print("\n" + "="*60)
print("MACHINE LEARNING PIPELINE")
print("="*60)

if merged_data.empty or 'is_malicious' not in merged_data.columns:
    print("Error: No labeled data available for ML training")
    print("Skipping ML pipeline...")
else:
    # Prepare data for ML
    print("\nPreparing data for ML...")
    
    # Select features (exclude metadata and labels)
    exclude_cols = ['window_start', 'window_end', 'vehicle_id', 
                   'is_malicious', 'attack_type', 'attack_count',
                   'sybil_attackers', 'replay_attackers', 'jammer_attackers']
    
    feature_cols = [col for col in merged_data.columns 
                   if col not in exclude_cols 
                   and merged_data[col].dtype in [np.int64, np.float64]]
    
    print(f"Selected {len(feature_cols)} features for ML")
    
    # Create feature matrix and labels
    X = merged_data[feature_cols].fillna(0).values
    y = merged_data['is_malicious'].values
    
    print(f"Feature matrix shape: {X.shape}")
    print(f"Label distribution: {np.bincount(y.astype(int))}")
    
    # Check class balance
    class_counts = np.bincount(y.astype(int))
    if len(class_counts) > 1:
        print(f"Class balance: {class_counts[0]} benign vs {class_counts[1]} malicious")
        print(f"Imbalance ratio: {class_counts[0]/class_counts[1]:.2f}:1")
    
    # Train-test split
    print("\nSplitting data into train/test sets...")
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=TEST_SIZE, random_state=RANDOM_SEED, stratify=y
    )
    
    print(f"Training set: {X_train.shape[0]} samples")
    print(f"Testing set: {X_test.shape[0]} samples")
    
    # Create and train Random Forest model
    print("\n" + "-"*40)
    print("RANDOM FOREST CLASSIFIER")
    print("-"*40)
    
    # Define pipeline
    pipeline = Pipeline([
        ('scaler', StandardScaler()),
        ('classifier', RandomForestClassifier(
            n_estimators=RF_N_ESTIMATORS,
            random_state=RANDOM_SEED,
            n_jobs=-1,
            class_weight='balanced'  # Handle class imbalance
        ))
    ])
    
    # Cross-validation
    print("\nPerforming cross-validation...")
    cv_scores = cross_val_score(pipeline, X_train, y_train, 
                               cv=CV_FOLDS, scoring='f1_weighted', n_jobs=-1)
    
    print(f"Cross-validation F1 scores: {cv_scores}")
    print(f"Mean CV F1: {cv_scores.mean():.4f} (+/- {cv_scores.std() * 2:.4f})")
    
    # Hyperparameter tuning
    print("\nPerforming hyperparameter tuning...")
    param_grid = {
        'classifier__n_estimators': [50, 100, 200],
        'classifier__max_depth': [None, 10, 20, 30],
        'classifier__min_samples_split': [2, 5, 10],
        'classifier__min_samples_leaf': [1, 2, 4]
    }
    
    grid_search = GridSearchCV(
        pipeline, param_grid, 
        cv=CV_FOLDS, 
        scoring='f1_weighted',
        n_jobs=-1,
        verbose=1
    )
    
    grid_search.fit(X_train, y_train)
    
    print(f"\nBest parameters: {grid_search.best_params_}")
    print(f"Best CV score: {grid_search.best_score_:.4f}")
    
    # Train final model with best parameters
    print("\nTraining final model...")
    best_model = grid_search.best_estimator_
    best_model.fit(X_train, y_train)
    
    # Save model
    print(f"\nSaving model to {MODEL_PATH}...")
    with open(MODEL_PATH, 'wb') as f:
        pickle.dump(best_model, f)
    
    print("Model saved successfully!")
    
    # Feature importance
    print("\nComputing feature importance...")
    if hasattr(best_model.named_steps['classifier'], 'feature_importances_'):
        importances = best_model.named_steps['classifier'].feature_importances_
        indices = np.argsort(importances)[::-1]
        
        print("\nTop 20 most important features:")
        for i in range(min(20, len(feature_cols))):
            idx = indices[i]
            print(f"  {i+1:2d}. {feature_cols[idx]:30s}: {importances[idx]:.6f}")
        
        # Plot feature importance
        fig, ax = plt.subplots(figsize=(12, 8))
        top_n = min(20, len(feature_cols))
        ax.barh(range(top_n), importances[indices[:top_n]][::-1])
        ax.set_yticks(range(top_n))
        ax.set_yticklabels([feature_cols[i] for i in indices[:top_n]][::-1])
        ax.set_xlabel('Feature Importance')
        ax.set_title('Top 20 Feature Importances')
        plt.tight_layout()
        save_figure(fig, "feature_importance.png")
    
    # Make predictions
    print("\nMaking predictions on test set...")
    y_pred = best_model.predict(X_test)
    y_pred_proba = best_model.predict_proba(X_test)[:, 1] if hasattr(best_model, 'predict_proba') else None

# %% [markdown]
# ## 7. MODEL EVALUATION

# %%
# ============================================================================
# MODEL EVALUATION
# ============================================================================

print("\n" + "="*60)
print("MODEL EVALUATION")
print("="*60)

if 'y_test' in locals() and 'y_pred' in locals():
    # Compute metrics
    print("\nComputing evaluation metrics...")
    
    accuracy = accuracy_score(y_test, y_pred)
    precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
    recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
    f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
    
    print(f"Accuracy:  {accuracy:.4f}")
    print(f"Precision: {precision:.4f}")
    print(f"Recall:    {recall:.4f}")
    print(f"F1-Score:  {f1:.4f}")
    
    # Detailed classification report
    print("\nClassification Report:")
    print(classification_report(y_test, y_pred, target_names=['Benign', 'Malicious'], zero_division=0))
    
    # Confusion matrix
    print("\nConfusion Matrix:")
    cm = confusion_matrix(y_test, y_pred)
    print(cm)
    
    # Plot confusion matrix
    fig, ax = plt.subplots(figsize=(8, 6))
    disp = ConfusionMatrixDisplay(confusion_matrix=cm, display_labels=['Benign', 'Malicious'])
    disp.plot(cmap='Blues', ax=ax)
    ax.set_title('Confusion Matrix')
    save_figure(fig, "confusion_matrix.png")
    
    # ROC Curve and AUC (if probability estimates available)
    if y_pred_proba is not None:
        print("\nComputing ROC/AUC...")
        fpr, tpr, thresholds = roc_curve(y_test, y_pred_proba)
        auc_score = roc_auc_score(y_test, y_pred_proba)
        
        print(f"AUC Score: {auc_score:.4f}")
        
        # Plot ROC curve
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.plot(fpr, tpr, 'b-', label=f'ROC curve (AUC = {auc_score:.3f})')
        ax.plot([0, 1], [0, 1], 'r--', label='Random classifier')
        ax.set_xlabel('False Positive Rate')
        ax.set_ylabel('True Positive Rate')
        ax.set_title('Receiver Operating Characteristic (ROC) Curve')
        ax.legend(loc='lower right')
        ax.grid(True, alpha=0.3)
        save_figure(fig, "roc_curve.png")
    
    # Create metrics dataframe
    metrics_df = pd.DataFrame({
        'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'AUC'],
        'Value': [accuracy, precision, recall, f1, auc_score if 'auc_score' in locals() else np.nan]
    })
    
    # Save metrics
    save_dataframe(metrics_df, "ml_metrics.csv")
    
    # Export metrics as LaTeX table
    export_latex_table(
        metrics_df.round(4),
        "ml_metrics_latex.tex",
        caption="Machine Learning Performance Metrics",
        label="tab:ml_metrics"
    )
    
    # Per-class metrics
    if len(np.unique(y_test)) > 1:
        per_class_metrics = []
        classes = ['Benign', 'Malicious']
        
        for i, class_name in enumerate(classes):
            if i < len(np.unique(y_test)):
                # Create binary labels for this class
                y_test_binary = (y_test == i).astype(int)
                y_pred_binary = (y_pred == i).astype(int)
                
                if len(np.unique(y_test_binary)) > 1:
                    class_accuracy = accuracy_score(y_test_binary, y_pred_binary)
                    class_precision = precision_score(y_test_binary, y_pred_binary, zero_division=0)
                    class_recall = recall_score(y_test_binary, y_pred_binary, zero_division=0)
                    class_f1 = f1_score(y_test_binary, y_pred_binary, zero_division=0)
                    
                    per_class_metrics.append({
                        'Class': class_name,
                        'Accuracy': class_accuracy,
                        'Precision': class_precision,
                        'Recall': class_recall,
                        'F1-Score': class_f1
                    })
        
        if per_class_metrics:
            per_class_df = pd.DataFrame(per_class_metrics)
            print("\nPer-class Metrics:")
            print(per_class_df.to_string(index=False))
            
            save_dataframe(per_class_df, "per_class_metrics.csv")
            
            # Export as LaTeX
            export_latex_table(
                per_class_df.round(4),
                "per_class_metrics_latex.tex",
                caption="Per-class Performance Metrics",
                label="tab:per_class_metrics"
            )
    
    # Scenario-wise evaluation (if scenario information is available)
    print("\n" + "-"*40)
    print("SCENARIO-WISE EVALUATION")
    print("-"*40)
    
    # Try to extract scenario from data
    if 'scenario' in merged_data.columns:
        # We need to map test indices back to original data
        test_indices = merged_data.index[merged_data.index.isin(
            merged_data.sample(frac=TEST_SIZE, random_state=RANDOM_SEED).index
        )]
        
        scenario_results = []
        for scenario in merged_data['scenario'].unique():
            scenario_mask = merged_data['scenario'] == scenario
            scenario_test_mask = scenario_mask & merged_data.index.isin(test_indices)
            
            if scenario_test_mask.any():
                scenario_y_true = merged_data.loc[scenario_test_mask, 'is_malicious'].values
                scenario_y_pred = y_pred[scenario_test_mask[scenario_test_mask].index]
                
                scenario_accuracy = accuracy_score(scenario_y_true, scenario_y_pred)
                scenario_f1 = f1_score(scenario_y_true, scenario_y_pred, average='weighted', zero_division=0)
                
                scenario_results.append({
                    'Scenario': scenario,
                    'Samples': len(scenario_y_true),
                    'Accuracy': scenario_accuracy,
                    'F1-Score': scenario_f1
                })
        
        if scenario_results:
            scenario_df = pd.DataFrame(scenario_results)
            print("\nScenario-wise Performance:")
            print(scenario_df.to_string(index=False))
            
            save_dataframe(scenario_df, "scenario_metrics.csv")
            
            # Plot scenario comparison
            fig, ax = plt.subplots(figsize=(10, 6))
            x = range(len(scenario_df))
            width = 0.35
            
            ax.bar([i - width/2 for i in x], scenario_df['Accuracy'], width, label='Accuracy')
            ax.bar([i + width/2 for i in x], scenario_df['F1-Score'], width, label='F1-Score')
            
            ax.set_xlabel('Scenario')
            ax.set_ylabel('Score')
            ax.set_title('Model Performance by Scenario')
            ax.set_xticks(x)
            ax.set_xticklabels(scenario_df['Scenario'])
            ax.legend()
            ax.grid(True, alpha=0.3)
            
            plt.tight_layout()
            save_figure(fig, "scenario_performance.png")
    
else:
    print("Warning: No model predictions available for evaluation")

# %% [markdown]
# ## 8. TRUST SCORE COMPUTATION

# %%
# ============================================================================
# TRUST SCORE COMPUTATION
# ============================================================================

print("\n" + "="*60)
print("TRUST SCORE COMPUTATION")
print("="*60)

class TrustScoreCalculator:
    """Calculate trust scores using exponential smoothing model."""
    
    def __init__(self, alpha: float = 0.3, threshold: float = 0.7):
        self.alpha = alpha  # Smoothing factor
        self.threshold = threshold  # Consistency threshold
        self.trust_scores = {}
    
    def compute_consistency_score(self, features: Dict) -> float:
        """Compute consistency score from features."""
        consistency_metrics = []
        
        # RSSI consistency (inverse of variance)
        if 'rssi_std' in features:
            rssi_consistency = 1.0 / (1.0 + features['rssi_std'])
            consistency_metrics.append(rssi_consistency)
        
        # Neighbor consistency
        if 'neighbor_jaccard_mean' in features:
            neighbor_consistency = features['neighbor_jaccard_mean']
            consistency_metrics.append(neighbor_consistency)
        
        # Message interval consistency (inverse of std)
        if 'msg_interval_std' in features and features['msg_interval_std'] > 0:
            msg_consistency = 1.0 / (1.0 + features['msg_interval_std'])
            consistency_metrics.append(msg_consistency)
        
        # Speed consistency
        if 'speed_std' in features and features['speed_std'] > 0:
            speed_consistency = 1.0 / (1.0 + features['speed_std'])
            consistency_metrics.append(speed_consistency)
        
        if consistency_metrics:
            return np.mean(consistency_metrics)
        else:
            return 0.5  # Default neutral score
    
    def update_trust_score(self, vehicle_id: str, consistency: float, 
                          previous_trust: float = None) -> float:
        """Update trust score using exponential smoothing."""
        if previous_trust is None:
            # Initialize trust score
            trust = consistency
        else:
            # Exponential smoothing: trust_t = α * consistency_t + (1-α) * trust_{t-1}
            trust = self.alpha * consistency + (1 - self.alpha) * previous_trust
        
        # Store trust score
        self.trust_scores[vehicle_id] = trust
        
        return trust
    
    def compute_trust_scores_over_time(self, features_df: pd.DataFrame) -> pd.DataFrame:
        """Compute trust scores for all vehicles over time."""
        if features_df.empty or 'vehicle_id' not in features_df.columns:
            print("Warning: No vehicle-specific features for trust computation")
            return pd.DataFrame()
        
        # Sort by vehicle and time
        features_sorted = features_df.sort_values(['vehicle_id', 'window_start'])
        
        trust_records = []
        current_trust = {}
        
        # Group by vehicle
        for vehicle_id, group in tqdm(features_sorted.groupby('vehicle_id'), 
                                     desc="Computing trust scores"):
            vehicle_trust = []
            
            for _, row in group.iterrows():
                # Convert row to dictionary
                row_dict = row.to_dict()
                
                # Compute consistency score
                consistency = self.compute_consistency_score(row_dict)
                
                # Get previous trust score
                prev_trust = current_trust.get(vehicle_id, 0.5)  # Start with neutral trust
                
                # Update trust score
                trust = self.update_trust_score(vehicle_id, consistency, prev_trust)
                current_trust[vehicle_id] = trust
                
                # Record trust
                trust_record = {
                    'vehicle_id': vehicle_id,
                    'window_start': row['window_start'],
                    'window_end': row['window_end'],
                    'consistency_score': consistency,
                    'trust_score': trust,
                    'is_trusted': trust >= self.threshold
                }
                
                # Add additional features if available
                for col in ['rssi_std', 'neighbor_jaccard_mean', 'msg_interval_std', 'speed_std']:
                    if col in row_dict:
                        trust_record[col] = row_dict[col]
                
                trust_records.append(trust_record)
                vehicle_trust.append(trust)
            
            # Store final trust score for this vehicle
            if vehicle_trust:
                self.trust_scores[vehicle_id] = vehicle_trust[-1]
        
        trust_df = pd.DataFrame(trust_records)
        
        return trust_df
    
    def plot_trust_convergence(self, trust_df: pd.DataFrame, n_vehicles: int = 10):
        """Plot trust score convergence over time for selected vehicles."""
        if trust_df.empty:
            print("Warning: No trust data to plot")
            return
        
        # Select random vehicles to plot
        vehicle_ids = trust_df['vehicle_id'].unique()
        if len(vehicle_ids) > n_vehicles:
            plot_vehicles = np.random.choice(vehicle_ids, n_vehicles, replace=False)
        else:
            plot_vehicles = vehicle_ids
        
        fig, axes = plt.subplots(2, 1, figsize=(14, 10))
        
        # Plot 1: Individual trust convergence
        ax1 = axes[0]
        for vehicle_id in plot_vehicles:
            vehicle_data = trust_df[trust_df['vehicle_id'] == vehicle_id].sort_values('window_start')
            if not vehicle_data.empty:
                ax1.plot(vehicle_data['window_start'], vehicle_data['trust_score'], 
                        marker='o', markersize=3, label=vehicle_id, alpha=0.7)
        
        ax1.axhline(y=self.threshold, color='r', linestyle='--', label=f'Threshold ({self.threshold})')
        ax1.set_xlabel('Time (window start)')
        ax1.set_ylabel('Trust Score')
        ax1.set_title('Trust Score Convergence Over Time')
        ax1.legend(loc='upper right', fontsize='small')
        ax1.grid(True, alpha=0.3)
        
        # Plot 2: Distribution of final trust scores
        ax2 = axes[1]
        final_trust = trust_df.groupby('vehicle_id')['trust_score'].last()
        
        ax2.hist(final_trust, bins=30, alpha=0.7, edgecolor='black')
        ax2.axvline(x=self.threshold, color='r', linestyle='--', label=f'Threshold ({self.threshold})')
        ax2.set_xlabel('Final Trust Score')
        ax2.set_ylabel('Number of Vehicles')
        ax2.set_title('Distribution of Final Trust Scores')
        ax2.legend()
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        save_figure(fig, "trust_convergence.png")
        
        # Additional plot: Trust vs Consistency
        fig2, ax = plt.subplots(figsize=(10, 6))
        scatter = ax.scatter(trust_df['consistency_score'], trust_df['trust_score'], 
                           c=trust_df['trust_score'], cmap='viridis', alpha=0.6)
        ax.set_xlabel('Consistency Score')
        ax.set_ylabel('Trust Score')
        ax.set_title('Trust Score vs Consistency Score')
        ax.grid(True, alpha=0.3)
        
        # Add colorbar
        plt.colorbar(scatter, ax=ax, label='Trust Score')
        
        # Add diagonal line (trust = consistency for no memory)
        x = np.linspace(0, 1, 100)
        ax.plot(x, x, 'r--', alpha=0.5, label='Trust = Consistency (no memory)')
        ax.legend()
        
        plt.tight_layout()
        save_figure(fig2, "trust_vs_consistency.png")

# %%
# Compute trust scores
print("\nComputing trust scores...")

trust_calculator = TrustScoreCalculator(alpha=TRUST_ALPHA, threshold=CONSISTENCY_THRESHOLD)

# Use vehicle-specific features if available
if 'vehicle_id' in combined_features.columns and combined_features['vehicle_id'].nunique() > 1:
    trust_df = trust_calculator.compute_trust_scores_over_time(combined_features)
    
    if not trust_df.empty:
        print(f"Computed trust scores for {trust_df['vehicle_id'].nunique()} vehicles")
        print(f"Average trust score: {trust_df['trust_score'].mean():.4f}")
        print(f"Trusted vehicles: {trust_df['is_trusted'].sum()} / {trust_df['vehicle_id'].nunique()}")
        
        # Save trust scores
        save_dataframe(trust_df, "trust_scores.csv")
        
        # Plot trust convergence
        print("\nCreating trust convergence plots...")
        trust_calculator.plot_trust_convergence(trust_df, n_vehicles=15)
        
        # Merge trust scores with labels for analysis
        if 'is_malicious' in merged_data.columns:
            # We need to merge based on vehicle_id and window times
            merged_trust = pd.merge(
                trust_df,
                merged_data[['vehicle_id', 'window_start', 'window_end', 'is_malicious', 'attack_type']],
                on=['vehicle_id', 'window_start', 'window_end'],
                how='left'
            )
            
            # Analyze trust vs malicious behavior
            if 'is_malicious' in merged_trust.columns:
                print("\nTrust Analysis by Behavior:")
                benign_trust = merged_trust[merged_trust['is_malicious'] == 0]['trust_score']
                malicious_trust = merged_trust[merged_trust['is_malicious'] == 1]['trust_score']
                
                print(f"  Benign vehicles average trust: {benign_trust.mean():.4f}")
                print(f"  Malicious vehicles average trust: {malicious_trust.mean():.4f}")
                
                # Statistical test
                if len(benign_trust) > 1 and len(malicious_trust) > 1:
                    t_stat, p_value = stats.ttest_ind(benign_trust, malicious_trust, equal_var=False)
                    print(f"  T-test: t={t_stat:.4f}, p={p_value:.6f}")
                    
                    # Plot trust distribution by behavior
                    fig, ax = plt.subplots(figsize=(10, 6))
                    
                    ax.hist(benign_trust, bins=30, alpha=0.5, label='Benign', density=True)
                    ax.hist(malicious_trust, bins=30, alpha=0.5, label='Malicious', density=True)
                    
                    ax.set_xlabel('Trust Score')
                    ax.set_ylabel('Density')
                    ax.set_title('Trust Score Distribution by Behavior')
                    ax.legend()
                    ax.grid(True, alpha=0.3)
                    
                    plt.tight_layout()
                    save_figure(fig, "trust_distribution_by_behavior.png")
    else:
        print("Warning: No trust scores computed")
else:
    print("Warning: Insufficient vehicle-specific data for trust computation")

# %% [markdown]
# ## 9. VISUALIZATION

# %%
# ============================================================================
# COMPREHENSIVE VISUALIZATION
# ============================================================================

print("\n" + "="*60)
print("VISUALIZATION")
print("="*60)

def create_comprehensive_visualizations():
    """Create all visualizations and save to FIG_DIR."""
    
    print("\nCreating comprehensive visualizations...")
    
    # 1. Network Statistics Overview
    if not consolidated_stats.empty:
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # PDR by scenario
        ax = axes[0, 0]
        scenario_pdr = consolidated_stats.groupby('scenario')['pdr'].mean()
        scenario_pdr.plot(kind='bar', ax=ax, color='skyblue')
        ax.set_title('Packet Delivery Ratio by Scenario')
        ax.set_ylabel('PDR')
        ax.set_x

Configuration loaded:
  Root directory: /home/jeanhuit/Documents/Workspace/simulation/results/
  Output directory: /home/jeanhuit/Documents/Workspace/simulation/output
  Figure directory: /home/jeanhuit/Documents/Workspace/simulation/figures
  Window size: 5s
  Scenarios: ['highway', 'mixed', 'urban']
  Densities: [50, 100, 150]
Initializing data loader...

LOADING ALL SCENARIOS AND DENSITIES

Loading highway (density=50, run=run-1)...
  Loaded: bsm_log.csv (14500 rows)
  Loaded: rssi_log.csv (49 rows)
  Loaded: neighbor_log.csv (7250 rows)
  Loaded: sybil_log.csv (80 rows)
  Loaded: replay_log.csv (5 rows)
  Loaded: jammer_log.csv (5000 rows)

Loading highway (density=100, run=run-1)...
  Loaded: bsm_log.csv (29000 rows)
  Loaded: rssi_log.csv (75 rows)
  Loaded: neighbor_log.csv (14500 rows)
  Loaded: sybil_log.csv (80 rows)
  Loaded: replay_log.csv (5 rows)
  Loaded: jammer_log.csv (5000 rows)

Loading highway (density=150, run=run-1)...
  Loaded: bsm_log.csv (43500 rows)
  Loaded: 

KeyError: "Column(s) ['avg_neighbors'] do not exist"