# In [None]:
# V2X Simulation ML Analysis Notebook
# Complete Machine Learning Pipeline for Attack Detection in V2X Networks

"""
This notebook performs comprehensive ML analysis on V2X simulation data:
- Multi-scenario analysis (highway, urban, mixed)
- Feature engineering with time-windowing
- Supervised ML for attack detection
- Trust score computation
- Complete evaluation and visualization

Author: Generated for V2X Security Research
Date: 2025
"""

# ============================================================================
# 1. SETUP & CONFIGURATION
# ============================================================================

import os
import sys
import warnings
import pickle
from pathlib import Path
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy import stats
from scipy.fft import fft, fftfreq
from scipy.stats import ttest_ind

from sklearn.model_selection import train_test_split, cross_val_score, GridSearchCV, StratifiedKFold
from sklearn.preprocessing import StandardScaler, RobustScaler, LabelEncoder
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_curve, auc, roc_auc_score
)
import joblib

from tqdm.notebook import tqdm

# Silence library warnings so the notebook output stays readable.
warnings.filterwarnings('ignore')

# Plot appearance shared by every figure produced below.
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 6),
    'font.size': 10,
})

# ============================================================================
# CONFIGURATION
# ============================================================================

# Root directory - CHANGE THIS TO YOUR DATA LOCATION
ROOT_DIR = Path("/home/jeanhuit/Documents/Workspace/simulation/results/")

# Where figures, tables and the trained model are written
FIG_DIR = Path("/home/jeanhuit/Documents/Workspace/simulation/figures")
OUTPUT_DIR = Path("/home/jeanhuit/Documents/Workspace/simulation/output")
MODEL_PATH = Path("/home/jeanhuit/Documents/Workspace/simulation/model")

# Make sure the figure and output directories exist before anything is saved
for _out_dir in (FIG_DIR, OUTPUT_DIR):
    _out_dir.mkdir(parents=True, exist_ok=True)

# Experiment grid: one result folder per scenario / density / run
SCENARIOS = ['highway', 'mixed', 'urban']
DENSITIES = ['density-50', 'density-100', 'density-150']
RUNS = ['run-1']  # Extend with ['run-1', 'run-2', ...] for multi-run

# Feature engineering parameters
WINDOW_SIZE = 5.0  # seconds
TRUST_ALPHA = 0.7  # Exponential smoothing parameter

# ML parameters
TEST_SIZE = 0.3
RANDOM_STATE = 42
CV_FOLDS = 5

# Start-up banner summarising the active configuration
print("\n".join([
    "=" * 80,
    "V2X SIMULATION ML ANALYSIS",
    "=" * 80,
    f"Root Directory: {ROOT_DIR}",
    f"Scenarios: {SCENARIOS}",
    f"Densities: {DENSITIES}",
    f"Window Size: {WINDOW_SIZE}s",
    "=" * 80,
]))

# ============================================================================
# 2. DATA LOADING & NETWORK STATISTICS
# ============================================================================

def load_scenario_data(scenario, density, run='run-1'):
    """
    Read every expected log CSV of one scenario/density/run folder.

    The folder is ``ROOT_DIR / scenario / str(density) / run``. A log that is
    missing on disk, or that fails to parse, is reported on stdout and
    represented by an empty DataFrame so all six keys are always present.

    Returns:
        dict: Maps the log name (file name without '.csv') to its DataFrame
    """
    run_dir = ROOT_DIR / scenario / str(density) / run
    expected_logs = ('bsm_log.csv', 'rssi_log.csv', 'neighbor_log.csv',
                     'sybil_log.csv', 'replay_log.csv', 'jammer_log.csv')

    loaded = {}
    for file in expected_logs:
        log_key = file.replace('.csv', '')
        file_path = run_dir / file

        if not file_path.exists():
            print(f"Warning: {file_path} not found")
            loaded[log_key] = pd.DataFrame()
            continue

        try:
            loaded[log_key] = pd.read_csv(file_path)
        except Exception as e:
            # Best effort: an unreadable log must not abort the whole run.
            print(f"Error loading {file_path}: {e}")
            loaded[log_key] = pd.DataFrame()

    return loaded

def compute_network_statistics(data_dict):
    """
    Summarise message delivery and neighbourhood size for one run.

    Args:
        data_dict: Mapping of log name to DataFrame; only 'bsm_log' and
            'neighbor_log' are read, and either may be absent or empty.

    Returns:
        dict: 'total_sent', 'total_received', 'pdr' (received / sent, 0 when
        nothing was sent), 'avg_neighbors' (mean number of neighbor-log rows
        per vehicle) and 'unique_vehicles'
    """
    summary = {}

    # --- BSM counters -------------------------------------------------------
    bsm = data_dict.get('bsm_log', pd.DataFrame())
    sent_count = 0
    received_count = 0
    if not bsm.empty and 'type' in bsm.columns:
        sent_count = len(bsm[bsm['type'] == 'sent'])
        received_count = len(bsm[bsm['type'] == 'received'])

    summary['total_sent'] = sent_count
    summary['total_received'] = received_count
    if sent_count > 0:
        summary['pdr'] = received_count / sent_count
    else:
        summary['pdr'] = 0

    # --- Neighbour counters -------------------------------------------------
    neighbor = data_dict.get('neighbor_log', pd.DataFrame())
    has_vehicle_ids = (not neighbor.empty) and ('vehicle_id' in neighbor.columns)
    if has_vehicle_ids:
        rows_per_vehicle = neighbor.groupby('vehicle_id').size()
        summary['avg_neighbors'] = rows_per_vehicle.mean()
        summary['unique_vehicles'] = neighbor['vehicle_id'].nunique()
    else:
        summary['avg_neighbors'] = 0
        summary['unique_vehicles'] = 0

    return summary

# Load all data
# Walks the scenario x density x run grid, keeps every loaded log set in
# `all_data` (keyed "<scenario>_<density>_<run>") and collects one row of
# network statistics per configuration.
print("\n[1/10] Loading data from all scenarios...")
all_data = {}
network_stats = []

for scenario in tqdm(SCENARIOS, desc="Scenarios"):
    for density in DENSITIES:
        for run in RUNS:
            key = f"{scenario}_{density}_{run}"
            try:
                data = load_scenario_data(scenario, density, run)
                all_data[key] = data

                # Deliberately NOT named `stats`: at module level that name is
                # the `scipy.stats` import, and rebinding it to a dict breaks
                # later `stats.entropy(...)` calls.
                run_stats = compute_network_statistics(data)
                run_stats.update({
                    'scenario': scenario,
                    'density': density,
                    'run': run
                })
                network_stats.append(run_stats)
            except Exception as e:
                # Best effort: report the failing configuration and carry on.
                print(f"Error loading {key}: {e}")

# Create network statistics table
network_stats_df = pd.DataFrame(network_stats)
print("\n" + "=" * 80)
print("NETWORK STATISTICS SUMMARY")
print("=" * 80)
print(network_stats_df.to_string())
print("=" * 80)

# Save network statistics
network_stats_df.to_csv(OUTPUT_DIR / "network_statistics.csv", index=False)

# ============================================================================
# 3. FEATURE ENGINEERING
# ============================================================================

def create_time_windows(df, time_col='time', window_size=5.0):
    """
    Map each row's timestamp to the index of the fixed-width window it falls in.

    Args:
        df: DataFrame holding the timestamps
        time_col: Column containing the timestamps
        window_size: Width of one window, in the same unit as the timestamps

    Returns:
        Series: Integer window index per row (floor of time / window_size);
        an empty integer Series when df is empty or lacks time_col
    """
    has_timestamps = (not df.empty) and (time_col in df.columns)
    if has_timestamps:
        window_index = df[time_col] // window_size
        return window_index.astype(int)

    return pd.Series([], dtype=int)

def extract_message_features(bsm_df, window_size=5.0):
    """
    Summarise BSM traffic per vehicle and time window.

    Adds a 'window' column to bsm_df in place, then emits one row per
    (vehicle_id, window) with the message count and, when the matching
    columns exist, sent/received counts, distinct senders and speed statistics.

    Returns:
        DataFrame: Features per vehicle per window (empty when bsm_df is empty
        or lacks 'time' / 'vehicle_id')
    """
    if bsm_df.empty:
        return pd.DataFrame()

    # Both columns are needed to form the (vehicle, window) groups
    absent = [col for col in ('time', 'vehicle_id') if col not in bsm_df.columns]
    if absent:
        return pd.DataFrame()

    bsm_df['window'] = create_time_windows(bsm_df, window_size=window_size)

    # Optional columns are a property of the whole frame, so check them once
    has_type = 'type' in bsm_df.columns
    has_sender = 'sender_id' in bsm_df.columns
    has_speed = 'speed' in bsm_df.columns

    rows = []
    for (vehicle_id, window), chunk in bsm_df.groupby(['vehicle_id', 'window']):
        row = {
            'vehicle_id': vehicle_id,
            'window': window,
            'time_start': window * window_size,
            'msg_count_total': len(chunk),
        }

        if has_type:
            kinds = chunk['type']
            row['msg_count_sent'] = int((kinds == 'sent').sum())
            row['msg_count_received'] = int((kinds == 'received').sum())

        if has_sender:
            row['unique_senders'] = chunk['sender_id'].nunique()

        if has_speed:
            speeds = chunk['speed']
            row['speed_mean'] = speeds.mean()
            row['speed_std'] = speeds.std()
            row['speed_min'] = speeds.min()
            row['speed_max'] = speeds.max()

        rows.append(row)

    return pd.DataFrame(rows)

def extract_neighbor_features(neighbor_df, window_size=5.0):
    """
    Summarise the neighbor log per vehicle and time window.

    Adds a 'window' column to neighbor_df in place. Each output row carries
    the number of log rows in the window and, when the matching columns
    exist, the distinct neighbor count and duration statistics.

    Returns:
        DataFrame: Neighbor features per vehicle per window (empty when
        neighbor_df is empty or lacks 'time' / 'vehicle_id')
    """
    if neighbor_df.empty:
        return pd.DataFrame()

    absent = [col for col in ('time', 'vehicle_id') if col not in neighbor_df.columns]
    if absent:
        return pd.DataFrame()

    neighbor_df['window'] = create_time_windows(neighbor_df, window_size=window_size)

    # Optional columns are the same for every group, so check them once
    has_neighbor_id = 'neighbor_id' in neighbor_df.columns
    has_duration = 'duration' in neighbor_df.columns

    rows = []
    for (vehicle_id, window), chunk in neighbor_df.groupby(['vehicle_id', 'window']):
        row = {
            'vehicle_id': vehicle_id,
            'window': window,
            'neighbor_count': len(chunk),
        }

        if has_neighbor_id:
            row['unique_neighbors'] = chunk['neighbor_id'].nunique()

        if has_duration:
            durations = chunk['duration']
            row['neighbor_duration_mean'] = durations.mean()
            row['neighbor_duration_std'] = durations.std()
            row['neighbor_duration_max'] = durations.max()

        rows.append(row)

    return pd.DataFrame(rows)

def extract_rssi_features(rssi_df, window_size=5.0):
    """
    Summarise RSSI samples per vehicle and time window.

    Adds a 'window' column to rssi_df in place. Each output row holds the
    mean / std / min / max / variance of the window's samples (NumPy
    population statistics) and 'rssi_trend', the slope of a straight-line fit
    over the sample order (0 when the window has a single sample).

    Returns:
        DataFrame: RSSI features per vehicle per window (empty when rssi_df is
        empty or lacks 'time' / 'vehicle_id' / 'rssi')
    """
    if rssi_df.empty:
        return pd.DataFrame()

    for needed in ('time', 'vehicle_id', 'rssi'):
        if needed not in rssi_df.columns:
            return pd.DataFrame()

    rssi_df['window'] = create_time_windows(rssi_df, window_size=window_size)

    rows = []
    for (vehicle_id, window), chunk in rssi_df.groupby(['vehicle_id', 'window']):
        samples = chunk['rssi'].values
        n_samples = len(samples)

        # Slope of a degree-1 fit against the sample position within the window
        if n_samples > 1:
            trend = np.polyfit(np.arange(n_samples), samples, 1)[0]
        else:
            trend = 0

        rows.append({
            'vehicle_id': vehicle_id,
            'window': window,
            'rssi_mean': samples.mean(),
            'rssi_std': samples.std(),
            'rssi_min': samples.min(),
            'rssi_max': samples.max(),
            'rssi_variance': samples.var(),
            'rssi_trend': trend,
        })

    return pd.DataFrame(rows)

def extract_spectral_features(bsm_df, window_size=5.0):
    """
    Extract spectral features from message intervals using FFT.

    Adds a 'window' column to bsm_df in place. For every (vehicle_id, window)
    group with more than two messages, the inter-message intervals are
    transformed with an FFT and summarised; groups with fewer messages
    produce no row.

    Args:
        bsm_df: BSM log with at least 'time' and 'vehicle_id' columns
        window_size: Window width in seconds

    Returns:
        DataFrame: Spectral features per vehicle per window (empty when bsm_df
        is empty or lacks 'time' / 'vehicle_id')
    """
    # Imported locally so this function does not depend on the module-level
    # name `stats`, which is easy to shadow with an ordinary variable in
    # notebook-style code.
    from scipy.stats import entropy

    if bsm_df.empty:
        return pd.DataFrame()

    # Same guard as the other extractors: both grouping inputs must be present,
    # otherwise the groupby below would raise KeyError.
    required_cols = ['time', 'vehicle_id']
    if not all(col in bsm_df.columns for col in required_cols):
        return pd.DataFrame()

    bsm_df['window'] = create_time_windows(bsm_df, window_size=window_size)

    features = []

    for (vehicle_id, window), group in bsm_df.groupby(['vehicle_id', 'window']):
        times = np.sort(group['time'].values)

        # At least three timestamps are needed to obtain two or more intervals
        if len(times) <= 2:
            continue

        intervals = np.diff(times)

        # Magnitude spectrum of the interval sequence
        fft_vals = np.abs(fft(intervals))

        features.append({
            'vehicle_id': vehicle_id,
            'window': window,
            'spectral_energy': np.sum(fft_vals**2),
            # Small offset keeps the entropy defined when every bin is zero
            'spectral_entropy': entropy(fft_vals + 1e-10),
            'interval_mean': intervals.mean(),
            'interval_std': intervals.std(),
        })

    return pd.DataFrame(features)

def combine_all_features(data_dict, window_size=5.0):
    """
    Build the per-(vehicle, window) feature matrix of one configuration.

    Message features form the base table; neighbor, RSSI and spectral
    features are outer-joined onto it on ('vehicle_id', 'window') whenever
    they are non-empty. Cells left without a value by the joins are set to 0.

    Returns:
        DataFrame: Combined feature matrix
    """
    join_keys = ['vehicle_id', 'window']
    no_log = pd.DataFrame()

    # Run every extractor on its log (a missing log becomes an empty frame)
    msg_feat = extract_message_features(data_dict.get('bsm_log', no_log), window_size)
    neighbor_feat = extract_neighbor_features(data_dict.get('neighbor_log', no_log), window_size)
    rssi_feat = extract_rssi_features(data_dict.get('rssi_log', no_log), window_size)
    spectral_feat = extract_spectral_features(data_dict.get('bsm_log', no_log), window_size)

    # Fold the optional tables onto the message features, in a fixed order
    features = msg_feat
    for extra in (neighbor_feat, rssi_feat, spectral_feat):
        if extra.empty:
            continue
        features = features.merge(extra, on=join_keys, how='outer')

    return features.fillna(0)

print("\n[2/10] Extracting features from all scenarios...")

# Feature matrix per configuration, keyed like `all_data`
all_features = {}

for key, data in tqdm(all_data.items(), desc="Feature Extraction"):
    try:
        features = combine_all_features(data, window_size=WINDOW_SIZE)
        if not features.empty:
            # Keys are "<scenario>_<density>_<run>". Split from the right so a
            # scenario name that itself contains '_' still parses.
            scenario, density, run = key.rsplit('_', 2)
            features['scenario'] = scenario
            # Density labels look like 'density-50'; int() on the whole label
            # raises ValueError, so keep only the numeric suffix.
            features['density'] = int(density.rsplit('-', 1)[-1])
            features['run'] = run
            all_features[key] = features
    except Exception as e:
        # Best effort: report the failing configuration and carry on.
        print(f"Error extracting features for {key}: {e}")

print(f"✓ Extracted features for {len(all_features)} configurations")

# ============================================================================
# 4. LABEL GENERATION
# ============================================================================

def generate_labels(data_dict, window_size=5.0):
    """
    Derive per-(vehicle, window) attack labels from the attack logs.

    The sybil, replay and jammer logs that are non-empty and contain 'time'
    and 'vehicle_id' are tagged in place with an 'attack_type' column and
    stacked. Every (vehicle_id, window) pair present in the stacked log is
    labelled malicious, with the attack type of the pair's first row.

    Returns:
        DataFrame: Labels per vehicle per window (empty when no attack log is
        usable)
    """
    usable_logs = []
    for log_name in ('sybil_log', 'replay_log', 'jammer_log'):
        log_df = data_dict.get(log_name, pd.DataFrame())
        if log_df.empty:
            continue
        if 'time' not in log_df.columns or 'vehicle_id' not in log_df.columns:
            continue
        log_df['attack_type'] = log_name.replace('_log', '')
        usable_logs.append(log_df)

    if not usable_logs:
        return pd.DataFrame()

    attacks = pd.concat(usable_logs, ignore_index=True)
    attacks['window'] = create_time_windows(attacks, window_size=window_size)

    # One label row per vehicle-window pair that shows up in any attack log
    return pd.DataFrame([
        {
            'vehicle_id': vehicle_id,
            'window': window,
            'is_malicious': 1,
            'attack_type': events['attack_type'].iloc[0],  # Primary attack type
        }
        for (vehicle_id, window), events in attacks.groupby(['vehicle_id', 'window'])
    ])

print("\n[3/10] Generating labels...")

# Label table per configuration, keyed like `all_data`
all_labels = {}

for key, data in tqdm(all_data.items(), desc="Label Generation"):
    try:
        labels = generate_labels(data, window_size=WINDOW_SIZE)
    except Exception as e:
        # Best effort: a failing configuration is reported and left unlabelled.
        print(f"Error generating labels for {key}: {e}")
    else:
        all_labels[key] = labels

print(f"✓ Generated labels for {len(all_labels)} configurations")

# ============================================================================
# 5. EXPLORATORY DATA ANALYSIS
# ============================================================================

print("\n[4/10] Performing exploratory data analysis...")

# Combine all features and labels
combined_features = []
combined_labels = []

for key in all_features:
    feat = all_features[key].copy()
    lab = all_labels.get(key, pd.DataFrame())

    if lab.empty:
        # No attack log for this configuration: every window is benign
        merged = feat.assign(is_malicious=0, attack_type='benign')
    else:
        # Left join keeps every feature row; unmatched rows get NaN labels
        merged = feat.merge(lab, on=['vehicle_id', 'window'], how='left')

    # Rows without a matching label are benign
    merged['is_malicious'] = merged['is_malicious'].fillna(0).astype(int)
    merged['attack_type'] = merged['attack_type'].fillna('benign')

    combined_features.append(merged)

# Create master dataframe
master_df = pd.concat(combined_features, ignore_index=True)

malicious_flag = master_df['is_malicious']
print(f"\n✓ Master dataset shape: {master_df.shape}")
print(f"✓ Benign samples: {(master_df['is_malicious'] == 0).sum()}")
print(f"✓ Malicious samples: {(master_df['is_malicious'] == 1).sum()}")
print(f"\nAttack type distribution:")
print(master_df['attack_type'].value_counts())

# Save combined dataset
master_df.to_csv(OUTPUT_DIR / "master_features.csv", index=False)
print(f"\n✓ Saved master features to {OUTPUT_DIR / 'master_features.csv'}")

# ============================================================================
# 6. ML PIPELINE & TRAINING
# ============================================================================

# Trains two random forests on the master feature table: a binary attack
# detector (cross-validated, optionally grid-searched, then saved) and a
# multiclass attack-type classifier. Both share one train/test split and
# one fitted scaler.
print("\n[5/10] Training ML models...")

# Prepare features and labels
# Identifier, grouping and label columns are excluded; every remaining
# column of master_df is used as a model input.
feature_cols = [col for col in master_df.columns if col not in 
                ['vehicle_id', 'window', 'time_start', 'scenario', 'density', 
                 'run', 'is_malicious', 'attack_type']]

X = master_df[feature_cols].values
y_binary = master_df['is_malicious'].values
y_multiclass = master_df['attack_type'].values

# Encode multiclass labels
# `le.classes_` keeps the original attack-type names for later reports.
le = LabelEncoder()
y_multiclass_encoded = le.fit_transform(y_multiclass)

print(f"Features shape: {X.shape}")
print(f"Feature columns: {feature_cols}")

# Train-test split
# Both label vectors are split together so they stay row-aligned with X.
# Stratification uses the binary label only, so the attack-type proportions
# of the two splits are not controlled.
X_train, X_test, y_train_bin, y_test_bin, y_train_multi, y_test_multi = train_test_split(
    X, y_binary, y_multiclass_encoded, test_size=TEST_SIZE, random_state=RANDOM_STATE, stratify=y_binary
)

print(f"Training set: {X_train.shape[0]} samples")
print(f"Test set: {X_test.shape[0]} samples")

# Feature scaling
# The scaler is fitted on the training split only and re-applied to the test split.
scaler = RobustScaler()
X_train_scaled = scaler.fit_transform(X_train)
X_test_scaled = scaler.transform(X_test)

# Binary classification model
print("\n--- Training Binary Classifier (Attack Detection) ---")

rf_binary = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,
    min_samples_split=10,
    min_samples_leaf=4,
    random_state=RANDOM_STATE,
    n_jobs=-1
)

rf_binary.fit(X_train_scaled, y_train_bin)

# Cross-validation
# Runs on the already-scaled training matrix with CV_FOLDS folds, scored by F1.
cv_scores = cross_val_score(rf_binary, X_train_scaled, y_train_bin, 
                            cv=CV_FOLDS, scoring='f1')
print(f"Cross-validation F1 scores: {cv_scores}")
print(f"Mean CV F1: {cv_scores.mean():.4f} (+/- {cv_scores.std():.4f})")

# Hyperparameter tuning (optional - can be time-consuming)
TUNE_HYPERPARAMETERS = False  # Set to True to enable

if TUNE_HYPERPARAMETERS:
    print("\n--- Hyperparameter Tuning ---")
    # 3 x 3 x 3 x 3 = 81 parameter combinations, each evaluated with 3-fold CV
    param_grid = {
        'n_estimators': [50, 100, 200],
        'max_depth': [10, 20, 30],
        'min_samples_split': [5, 10, 20],
        'min_samples_leaf': [2, 4, 8]
    }
    
    grid_search = GridSearchCV(
        RandomForestClassifier(random_state=RANDOM_STATE, n_jobs=-1),
        param_grid,
        cv=3,
        scoring='f1',
        n_jobs=-1,
        verbose=1
    )
    
    grid_search.fit(X_train_scaled, y_train_bin)
    print(f"Best parameters: {grid_search.best_params_}")
    print(f"Best F1 score: {grid_search.best_score_:.4f}")
    
    # The tuned estimator replaces the baseline model; `cv_scores` above
    # still refers to the baseline configuration.
    rf_binary = grid_search.best_estimator_

# Save model
# NOTE(review): MODEL_PATH has no file suffix and, unlike FIG_DIR and
# OUTPUT_DIR, is not created with mkdir; it is passed to joblib.dump as the
# target file. Confirm this is the intended location.
joblib.dump(rf_binary, MODEL_PATH)
joblib.dump(scaler, OUTPUT_DIR / "scaler.pkl")
print(f"\n✓ Saved model to {MODEL_PATH}")

# Multiclass classification
# Trained on the same scaled split; this model is not written to disk in
# this section.
print("\n--- Training Multiclass Classifier (Attack Type Identification) ---")

rf_multiclass = RandomForestClassifier(
    n_estimators=100,
    max_depth=20,
    random_state=RANDOM_STATE,
    n_jobs=-1
)

rf_multiclass.fit(X_train_scaled, y_train_multi)

# ============================================================================
# 7. MODEL EVALUATION
# ============================================================================

# Scores both classifiers on the held-out test split, prints the reports and
# writes the headline metrics and feature importances to OUTPUT_DIR.
print("\n[6/10] Evaluating models...")

# Binary classification predictions
y_pred_bin = rf_binary.predict(X_test_scaled)
# Column 1 of predict_proba is the probability of the positive (malicious) class
y_pred_bin_proba = rf_binary.predict_proba(X_test_scaled)[:, 1]

# Multiclass predictions
y_pred_multi = rf_multiclass.predict(X_test_scaled)

# Binary classification metrics
print("\n" + "=" * 80)
print("BINARY CLASSIFICATION RESULTS (Attack Detection)")
print("=" * 80)

# zero_division=0 matches the per-scenario evaluation further down and makes
# the "no positive prediction" case an explicit 0 instead of a warning.
acc_bin = accuracy_score(y_test_bin, y_pred_bin)
prec_bin = precision_score(y_test_bin, y_pred_bin, zero_division=0)
rec_bin = recall_score(y_test_bin, y_pred_bin, zero_division=0)
f1_bin = f1_score(y_test_bin, y_pred_bin, zero_division=0)

print(f"Accuracy:  {acc_bin:.4f}")
print(f"Precision: {prec_bin:.4f}")
print(f"Recall:    {rec_bin:.4f}")
print(f"F1-Score:  {f1_bin:.4f}")

print("\nClassification Report:")
print(classification_report(y_test_bin, y_pred_bin, 
                          target_names=['Benign', 'Malicious']))

# Confusion matrix
cm_bin = confusion_matrix(y_test_bin, y_pred_bin)
print("\nConfusion Matrix:")
print(cm_bin)

# ROC-AUC
roc_auc_bin = roc_auc_score(y_test_bin, y_pred_bin_proba)
print(f"\nROC-AUC Score: {roc_auc_bin:.4f}")

# Multiclass metrics
print("\n" + "=" * 80)
print("MULTICLASS CLASSIFICATION RESULTS (Attack Type)")
print("=" * 80)

acc_multi = accuracy_score(y_test_multi, y_pred_multi)
prec_multi = precision_score(y_test_multi, y_pred_multi, average='weighted', zero_division=0)
rec_multi = recall_score(y_test_multi, y_pred_multi, average='weighted', zero_division=0)
f1_multi = f1_score(y_test_multi, y_pred_multi, average='weighted', zero_division=0)

print(f"Accuracy:  {acc_multi:.4f}")
print(f"Precision: {prec_multi:.4f}")
print(f"Recall:    {rec_multi:.4f}")
print(f"F1-Score:  {f1_multi:.4f}")

# The split is stratified on the binary label only, so a rare attack type can
# be missing from the test split. Passing the full list of encoded labels
# keeps the report aligned with `le.classes_`; without it
# classification_report raises when the class counts differ.
multi_label_ids = np.arange(len(le.classes_))

print("\nClassification Report:")
print(classification_report(y_test_multi, y_pred_multi, 
                          labels=multi_label_ids,
                          target_names=le.classes_,
                          zero_division=0))

# Save metrics
metrics_df = pd.DataFrame({
    'Model': ['Binary', 'Multiclass'],
    'Accuracy': [acc_bin, acc_multi],
    'Precision': [prec_bin, prec_multi],
    'Recall': [rec_bin, rec_multi],
    'F1-Score': [f1_bin, f1_multi]
})

metrics_df.to_csv(OUTPUT_DIR / "ml_metrics.csv", index=False)

# Feature importance
# Impurity-based importances of the binary model, largest first
feature_importance = pd.DataFrame({
    'feature': feature_cols,
    'importance': rf_binary.feature_importances_
}).sort_values('importance', ascending=False)

feature_importance.to_csv(OUTPUT_DIR / "feature_importance.csv", index=False)

print("\n✓ Top 10 Important Features:")
print(feature_importance.head(10).to_string())

# ============================================================================
# 8. TRUST SCORE COMPUTATION
# ============================================================================

print("\n[7/10] Computing trust scores...")

def _window_consistency(row):
    """Return the mean of the consistency factors available in one feature row."""
    consistency = 0.0
    n_factors = 0

    # Factor 1: Neighbor consistency (more neighbors = more consistent),
    # saturating at 10 distinct neighbors
    if 'unique_neighbors' in row and row['unique_neighbors'] > 0:
        consistency += min(row['unique_neighbors'] / 10.0, 1.0)
        n_factors += 1

    # Factor 2: RSSI stability (low variance = more stable)
    if 'rssi_variance' in row and row['rssi_variance'] >= 0:
        consistency += 1.0 / (1.0 + row['rssi_variance'] / 100.0)
        n_factors += 1

    # Factor 3: Message plausibility (speed variations)
    if 'speed_std' in row and row['speed_std'] >= 0:
        consistency += 1.0 / (1.0 + row['speed_std'] / 10.0)
        n_factors += 1

    if n_factors == 0:
        return 0.5  # Neutral when no factor could be evaluated
    return consistency / n_factors


def compute_trust_scores(features_df, alpha=0.7, window_size=None):
    """
    Compute trust scores using exponential smoothing.

    trust_t = α × consistency_t + (1-α) × trust_{t-1}

    Each vehicle starts at a trust of 1.0 and is updated once per feature row,
    in ascending window order.

    Args:
        features_df: DataFrame with 'vehicle_id' and 'window' columns plus any
            of 'unique_neighbors', 'rssi_variance', 'speed_std'
        alpha: Smoothing parameter (weight of the current window)
        window_size: Window width in seconds used to convert a window index
            into the 'time' column; defaults to the module-level WINDOW_SIZE

    Returns:
        DataFrame: Trust scores per vehicle over time, with columns
        'vehicle_id', 'window', 'time', 'trust_score', 'consistency'
        (empty when features_df is empty)
    """
    if window_size is None:
        window_size = WINDOW_SIZE

    if features_df.empty:
        return pd.DataFrame()

    trust_scores = []

    # One pass over the frame instead of one boolean filter per vehicle;
    # sort=False keeps the vehicles in order of first appearance.
    for vehicle_id, vehicle_rows in features_df.groupby('vehicle_id', sort=False):
        vehicle_data = vehicle_rows.sort_values('window')

        trust = 1.0  # Initial trust

        for _, row in vehicle_data.iterrows():
            consistency = _window_consistency(row)

            # Update trust with exponential smoothing
            trust = alpha * consistency + (1 - alpha) * trust

            trust_scores.append({
                'vehicle_id': vehicle_id,
                'window': row['window'],
                'time': row['window'] * window_size,
                'trust_score': trust,
                'consistency': consistency
            })

    return pd.DataFrame(trust_scores)

# Compute trust scores for all scenarios
trust_scores_all = {}

for key, features in all_features.items():
    if features.empty:
        continue

    trust = compute_trust_scores(features, alpha=TRUST_ALPHA)
    # Scenario and density are constant within one configuration, so the
    # first row is representative
    trust['scenario'] = features['scenario'].iloc[0]
    trust['density'] = features['density'].iloc[0]
    trust_scores_all[key] = trust

# Combine all trust scores
combined_trust = pd.concat(list(trust_scores_all.values()), ignore_index=True)
combined_trust.to_csv(OUTPUT_DIR / "trust_scores.csv", index=False)

trust_column = combined_trust['trust_score']
print(f"✓ Computed trust scores for {len(trust_scores_all)} configurations")
print(f"✓ Mean trust score: {combined_trust['trust_score'].mean():.4f}")
print(f"✓ Trust score std: {combined_trust['trust_score'].std():.4f}")

# ============================================================================
# 9. VISUALIZATION
# ============================================================================

# Each figure below follows the same pyplot sequence: open a figure, draw,
# label, save into FIG_DIR at 300 dpi, then close the figure.
print("\n[8/10] Creating visualizations...")

# 1. ROC Curve
# Uses the positive-class probabilities of the binary model on the test split.
plt.figure(figsize=(10, 8))
fpr, tpr, _ = roc_curve(y_test_bin, y_pred_bin_proba)
plt.plot(fpr, tpr, label=f'Binary Classifier (AUC = {roc_auc_bin:.3f})', linewidth=2)
# Diagonal reference line of a classifier that guesses at random
plt.plot([0, 1], [0, 1], 'k--', label='Random Classifier')
plt.xlabel('False Positive Rate', fontsize=12)
plt.ylabel('True Positive Rate', fontsize=12)
plt.title('ROC Curve - Attack Detection', fontsize=14, fontweight='bold')
plt.legend(fontsize=11)
plt.grid(alpha=0.3)
plt.tight_layout()
plt.savefig(FIG_DIR / "roc_curve.png", dpi=300, bbox_inches='tight')
plt.close()

# 2. Confusion Matrix - Binary
# `cm_bin` was computed during evaluation; rows are true labels and columns
# are predicted labels (see the axis labels set below).
plt.figure(figsize=(8, 6))
sns.heatmap(cm_bin, annot=True, fmt='d', cmap='Blues', 
            xticklabels=['Benign', 'Malicious'],
            yticklabels=['Benign', 'Malicious'],
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Binary Classification', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
plt.tight_layout()
plt.savefig(FIG_DIR / "confusion_matrix_binary.png", dpi=300, bbox_inches='tight')
plt.close()

# 3. Confusion Matrix - Multiclass
# Pass every encoded label explicitly: a class that is absent from both the
# test labels and the predictions would otherwise be dropped from the matrix,
# leaving it smaller than the `le.classes_` tick labels used below.
cm_multi = confusion_matrix(y_test_multi, y_pred_multi,
                            labels=np.arange(len(le.classes_)))
plt.figure(figsize=(10, 8))
sns.heatmap(cm_multi, annot=True, fmt='d', cmap='YlOrRd',
            xticklabels=le.classes_,
            yticklabels=le.classes_,
            cbar_kws={'label': 'Count'})
plt.title('Confusion Matrix - Attack Type Classification', fontsize=14, fontweight='bold')
plt.ylabel('True Label', fontsize=12)
plt.xlabel('Predicted Label', fontsize=12)
# Class names can be long, so slant the x tick labels and keep y labels level
plt.xticks(rotation=45, ha='right')
plt.yticks(rotation=0)
plt.tight_layout()
plt.savefig(FIG_DIR / "confusion_matrix_multiclass.png", dpi=300, bbox_inches='tight')
plt.close()

# 4. Feature Importance
# Horizontal bars for the `top_n` highest-ranked features of the binary model;
# `feature_importance` is already sorted in descending order.
plt.figure(figsize=(12, 8))
top_n = 15
top_features = feature_importance.head(top_n)
plt.barh(range(len(top_features)), top_features['importance'])
plt.yticks(range(len(top_features)), top_features['feature'])
plt.xlabel('Importance', fontsize=12)
plt.title(f'Top {top_n} Most Important Features', fontsize=14, fontweight='bold')
# Flip the axis so the most important feature is drawn at the top
plt.gca().invert_yaxis()
plt.grid(axis='x', alpha=0.3)
plt.tight_layout()
plt.savefig(FIG_DIR / "feature_importance.png", dpi=300, bbox_inches='tight')
plt.close()

# 5. Trust Score Over Time (sample vehicles)
# Plots the first five distinct vehicle IDs found in `combined_trust`.
# NOTE(review): `combined_trust` concatenates every configuration, and rows
# are selected by vehicle_id alone, so one line may mix several
# scenario/density runs if IDs repeat across runs - confirm.
plt.figure(figsize=(14, 6))
sample_vehicles = combined_trust['vehicle_id'].unique()[:5]
for vehicle_id in sample_vehicles:
    vehicle_trust = combined_trust[combined_trust['vehicle_id'] == vehicle_id]
    plt.plot(vehicle_trust['time'], vehicle_trust['trust_score'], 
             marker='o', markersize=3, label=f'Vehicle {vehicle_id}', alpha=0.7)

plt.xlabel('Time (s)', fontsize=12)
plt.ylabel('Trust Score', fontsize=12)
plt.title('Trust Score Evolution Over Time (Sample Vehicles)', fontsize=14, fontweight='bold')
# Legend is anchored outside the axes, to the right of the plot area
plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
plt.grid(alpha=0.3)
plt.ylim([0, 1.05])
plt.tight_layout()
plt.savefig(FIG_DIR / "trust_scores_over_time.png", dpi=300, bbox_inches='tight')
plt.close()

# 6. PDR Comparison Across Scenarios
# One bar group per density label, one bar per scenario.
plt.figure(figsize=(12, 6))
pdr_data = network_stats_df.pivot_table(values='pdr', index='density', columns='scenario')
# pivot_table sorts the string labels lexicographically ('density-100' comes
# before 'density-50'), so restore the configured order for the x axis.
pdr_data = pdr_data.reindex([d for d in DENSITIES if d in pdr_data.index])
pdr_data.plot(kind='bar', ax=plt.gca(), width=0.7)
plt.xlabel('Vehicle Density (vehicles/km)', fontsize=12)
plt.ylabel('Packet Delivery Ratio', fontsize=12)
plt.title('PDR Comparison Across Scenarios', fontsize=14, fontweight='bold')
plt.legend(title='Scenario', fontsize=11)
plt.xticks(rotation=0)
plt.grid(axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(FIG_DIR / "pdr_comparison.png", dpi=300, bbox_inches='tight')
plt.close()

# 7. Attack Detection Performance by Scenario
# Re-scores the binary model on all rows of each scenario in master_df
# (training and test rows alike). Scenarios without rows, or without any
# malicious row, are skipped.
scenario_performance = []
for scenario in SCENARIOS:
    scenario_mask = master_df['scenario'] == scenario
    if scenario_mask.sum() <= 0:
        continue

    scenario_rows = master_df[scenario_mask]
    X_scenario = scenario_rows[feature_cols].values
    y_scenario = scenario_rows['is_malicious'].values

    if len(y_scenario) == 0 or y_scenario.sum() <= 0:
        continue

    X_scenario_scaled = scaler.transform(X_scenario)
    y_pred_scenario = rf_binary.predict(X_scenario_scaled)

    scenario_performance.append({
        'scenario': scenario,
        'accuracy': accuracy_score(y_scenario, y_pred_scenario),
        'precision': precision_score(y_scenario, y_pred_scenario, zero_division=0),
        'recall': recall_score(y_scenario, y_pred_scenario, zero_division=0),
        'f1_score': f1_score(y_scenario, y_pred_scenario, zero_division=0)
    })

if scenario_performance:
    perf_df = pd.DataFrame(scenario_performance)
    # Long format: one row per (scenario, metric) pair for the grouped bar plot
    perf_df_melted = perf_df.melt(
        id_vars='scenario',
        var_name='metric',
        value_name='score',
    )

    plt.figure(figsize=(12, 6))
    sns.barplot(data=perf_df_melted, x='scenario', y='score', hue='metric')
    plt.xlabel('Scenario', fontsize=12)
    plt.ylabel('Score', fontsize=12)
    plt.title('ML Performance by Scenario', fontsize=14, fontweight='bold')
    plt.legend(title='Metric', fontsize=10)
    plt.ylim([0, 1.05])
    plt.grid(axis='y', alpha=0.3)
    plt.tight_layout()
    plt.savefig(FIG_DIR / "performance_by_scenario.png", dpi=300, bbox_inches='tight')
    plt.close()

    perf_df.to_csv(OUTPUT_DIR / "performance_by_scenario.csv", index=False)

print(f"✓ Saved all visualizations to {FIG_DIR}")

# ============================================================================
# 10. RESULTS EXPORT
# ============================================================================

print("\n[9/10] Exporting results...")

# Generate LaTeX table for metrics
def generate_latex_table(df, caption, label):
    """Render df as LaTeX source: no index column, floats with 4 decimals,
    and the given caption and label."""
    return df.to_latex(
        index=False,
        float_format="%.4f",
        caption=caption,
        label=label,
    )

# Binary classification metrics LaTeX
binary_metrics_table = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score', 'ROC-AUC'],
    'Score': [acc_bin, prec_bin, rec_bin, f1_bin, roc_auc_bin]
})
latex_binary = generate_latex_table(
    binary_metrics_table,
    caption="Binary Classification Performance (Attack Detection)",
    label="tab:binary_metrics"
)
(OUTPUT_DIR / "binary_metrics.tex").write_text(latex_binary)

# Multiclass classification metrics LaTeX
multiclass_metrics_table = pd.DataFrame({
    'Metric': ['Accuracy', 'Precision', 'Recall', 'F1-Score'],
    'Score': [acc_multi, prec_multi, rec_multi, f1_multi]
})
latex_multi = generate_latex_table(
    multiclass_metrics_table,
    caption="Multiclass Classification Performance (Attack Type Identification)",
    label="tab:multiclass_metrics"
)
(OUTPUT_DIR / "multiclass_metrics.tex").write_text(latex_multi)

# Network statistics LaTeX (network_stats_df is exported as-is)
latex_network = generate_latex_table(
    network_stats_df,
    caption="Network Statistics Summary",
    label="tab:network_stats"
)
(OUTPUT_DIR / "network_stats.tex").write_text(latex_network)

print("✓ Exported LaTeX tables")

# Save comprehensive summary
# One flat record of dataset sizes and headline metrics, written both as a
# single-row CSV and as a "key: value" text report.
summary = dict(
    timestamp=pd.Timestamp.now(),
    total_samples=len(master_df),
    benign_samples=(master_df['is_malicious'] == 0).sum(),
    malicious_samples=(master_df['is_malicious'] == 1).sum(),
    num_features=len(feature_cols),
    binary_accuracy=acc_bin,
    binary_precision=prec_bin,
    binary_recall=rec_bin,
    binary_f1=f1_bin,
    binary_roc_auc=roc_auc_bin,
    multiclass_accuracy=acc_multi,
    multiclass_f1=f1_multi,
    cv_f1_mean=cv_scores.mean(),
    cv_f1_std=cv_scores.std(),
)

summary_df = pd.DataFrame([summary])
summary_df.to_csv(OUTPUT_DIR / "summary.csv", index=False)

# Assemble the whole text report first, then write it in one call.
rule = "=" * 80
report_lines = [rule + "\n", "V2X ML ANALYSIS SUMMARY\n", rule + "\n\n"]
report_lines.extend(f"{key}: {value}\n" for key, value in summary.items())

with open(OUTPUT_DIR / "summary.txt", 'w') as report:
    report.write("".join(report_lines))

print("✓ Saved comprehensive summary")

# ============================================================================
# 11. STATISTICAL SIGNIFICANCE TESTING
# ============================================================================

print("\n[10/10] Performing statistical significance tests...")

# Compare performance across densities
# Only the F1-score per density is computed here; no hypothesis test is run.
density_performance = []

for density in DENSITIES:
    at_density = master_df['density'] == density
    if at_density.sum() == 0:
        continue  # no rows for this density

    density_rows = master_df[at_density]
    X_raw = density_rows[feature_cols].values
    y_true = density_rows['is_malicious'].values
    if not (y_true.sum() > 0):
        continue  # no malicious rows: density is skipped

    y_pred = rf_binary.predict(scaler.transform(X_raw))
    density_performance.append({
        'density': density,
        'f1_score': f1_score(y_true, y_pred, zero_division=0),
    })

# A comparison only makes sense with at least two densities.
if len(density_performance) >= 2:
    print("\nF1-Score by Density:")
    for entry in density_performance:
        density_value, density_f1 = entry['density'], entry['f1_score']
        print(f"  Density {density_value}: {density_f1:.4f}")

    # Note: For proper statistical testing, would need multiple runs
    print("\n(Note: Statistical significance testing requires multiple runs with different seeds)")

# Closing banner listing where each kind of output was written.
separator = "=" * 80
print(f"\n{separator}")
print("ANALYSIS COMPLETE!")
print(separator)
print("\nOutputs saved to:")
for output_label, output_path in (
    ("Figures", FIG_DIR),
    ("Data/Metrics", OUTPUT_DIR),
    ("Model", MODEL_PATH),
):
    print(f"  - {output_label}: {output_path}")
print(f"\n{separator}")

# ============================================================================
# 12. SUMMARY & NEXT STEPS
# ============================================================================

banner = "=" * 80
print("\n" + banner)
print("KEY FINDINGS")
print(banner)

# Class counts and their share of the dataset, computed once for the report.
total_count = len(master_df)
benign_count = (master_df['is_malicious'] == 0).sum()
malicious_count = (master_df['is_malicious'] == 1).sum()
benign_pct = benign_count / total_count * 100
malicious_pct = malicious_count / total_count * 100
cv_f1_mean = cv_scores.mean()
cv_f1_std = cv_scores.std()

print(f"""
1. Dataset Statistics:
   - Total samples: {total_count:,}
   - Benign: {benign_count:,} ({benign_pct:.1f}%)
   - Malicious: {malicious_count:,} ({malicious_pct:.1f}%)

2. Binary Classification (Attack Detection):
   - Accuracy: {acc_bin:.4f}
   - Precision: {prec_bin:.4f}
   - Recall: {rec_bin:.4f}
   - F1-Score: {f1_bin:.4f}
   - ROC-AUC: {roc_auc_bin:.4f}

3. Multiclass Classification (Attack Type):
   - Accuracy: {acc_multi:.4f}
   - F1-Score: {f1_multi:.4f}

4. Cross-Validation:
   - Mean F1: {cv_f1_mean:.4f} (±{cv_f1_std:.4f})

5. Top 5 Most Important Features:
""")

# First five rows of feature_importance.
# NOTE(review): presumably already sorted by importance -- confirm upstream.
top_features = feature_importance.head(5)
for feature_name, importance in zip(top_features['feature'], top_features['importance']):
    print(f"   - {feature_name}: {importance:.4f}")

trust_scores = combined_trust['trust_score']
# Mean PDR over the network_stats_df rows belonging to each scenario.
pdr_means = {
    name: network_stats_df[network_stats_df['scenario'] == name]['pdr'].mean()
    for name in ('highway', 'mixed', 'urban')
}

print(f"""
6. Trust Scores:
   - Mean: {trust_scores.mean():.4f}
   - Std: {trust_scores.std():.4f}

7. Network Performance (Average PDR):
   - Highway: {pdr_means['highway']:.4f}
   - Mixed: {pdr_means['mixed']:.4f}
   - Urban: {pdr_means['urban']:.4f}
""")

print("\n" + "=" * 80)
print("NEXT STEPS & EXTENSIONS")
print("=" * 80)
# This must be an f-string: the usage notes at the end interpolate the
# configured OUTPUT_DIR, FIG_DIR and MODEL_PATH. As a plain string the
# placeholders were printed literally (e.g. "{OUTPUT_DIR}").
# Any literal brace added to this text later has to be doubled ("{{" / "}}").
print(f"""
1. Multi-Run Analysis:
   - Extend RUNS list to include ['run-1', 'run-2', 'run-3', ...]
   - Compute mean and std of metrics across runs
   - Perform statistical significance testing

2. Hyperparameter Optimization:
   - Set TUNE_HYPERPARAMETERS = True for GridSearchCV
   - Experiment with different algorithms (XGBoost, SVM, Neural Networks)

3. Advanced Features:
   - Add temporal dependencies (LSTM/RNN for sequence modeling)
   - Include spatial features (vehicle positions, trajectories)
   - Experiment with ensemble methods

4. Real-time Deployment:
   - Create streaming pipeline for online learning
   - Implement incremental model updates
   - Deploy with model monitoring

5. Explainability:
   - Add SHAP values for feature importance
   - Implement LIME for local explanations
   - Create decision path visualizations

6. Class Imbalance:
   - Experiment with SMOTE/ADASYN
   - Adjust class weights
   - Try cost-sensitive learning

7. Attack-Specific Models:
   - Train separate models for each attack type
   - Implement hierarchical classification
   - Build attack-specific trust mechanisms

8. Performance Optimization:
   - Profile code for bottlenecks
   - Parallelize feature extraction
   - Use Dask for large-scale data

To use this notebook:
1. Update ROOT_DIR to your data location
2. Run all cells (Cell → Run All)
3. Check outputs in {OUTPUT_DIR}
4. Visualizations in {FIG_DIR}
5. Trained model at {MODEL_PATH}
""")

print("=" * 80 + "\n")