In [6]:
!pip install torch

Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-curand-cu12==10.3.5.147 (from torch)
  Downloading nvidia_curand_cu12-10.3.5

In [7]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [8]:
!pip install xlsxwriter

Collecting xlsxwriter
  Downloading XlsxWriter-3.2.3-py3-none-any.whl.metadata (2.7 kB)
Downloading XlsxWriter-3.2.3-py3-none-any.whl (169 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m169.4/169.4 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: xlsxwriter
Successfully installed xlsxwriter-3.2.3


Calculate Number of Neighbors within a specific radius

Purpose:
Count nearby vehicles within a 10m radius using spatial indexing.
Key Components:

KD-Tree Optimization: Enables O(n log n) neighbor searches instead of O(n²).

Timestamp Grouping: Processes vehicles existing at the same moment.

Self-Exclusion: Removes the vehicle itself from neighbor counts.

Error Handling: Validates required columns before processing.



In [32]:
# --- Calculate Number of Neighbors within a specific radius (Optimized with KD-Tree) ---
def calculate_neighbors_count(df, distance_threshold=10):
    """
    Calculate the number of neighbors within a specified distance threshold
    for each vehicle at each timestamp based on their positions using a KD-tree.

    Args:
        df (pd.DataFrame): DataFrame containing vehicle data with 'Time (s)', 'Vehicle ID',
                           'Position X (m)', and 'Position Y (m)' columns.
        distance_threshold (float): The maximum distance in meters to consider a vehicle a neighbor.

    Returns:
        pd.Series: A Series containing the number of neighbors for each row,
                   aligned with the input DataFrame's index.
    """
    neighbor_counts = pd.Series(0, index=df.index) # Initialize Series for neighbor counts

    # Ensure necessary columns exist
    required_cols = ['Time (s)', 'Vehicle ID', 'Position X (m)', 'Position Y (m)']
    if not all(col in df.columns for col in required_cols):
        missing = [col for col in required_cols if col not in df.columns]
        print(f"Error: Required columns for neighbor calculation are missing: {missing}. Skipping neighbor count calculation.")
        return neighbor_counts # Return all zeros if columns are missing

    # Process data for each timestamp
    timestamps = df['Time (s)'].unique()
    # print(f"Processing {len(timestamps)} unique timestamps for neighbor count...") # Can be verbose

    # Group by timestamp for efficient processing
    grouped_by_time = df.groupby('Time (s)')

    for ts, vehicles_at_ts in grouped_by_time:
        if len(vehicles_at_ts) <= 1:
            # If 0 or 1 vehicle, no neighbors to calculate for this timestamp
            continue

        # Extract positions for KD-tree and get original dataframe indices
        # Ensure positions are float type for KDTree
        positions = vehicles_at_ts[['Position X (m)', 'Position Y (m)']].values.astype(float)
        original_indices = vehicles_at_ts.index.tolist()

        # Build KD-tree
        tree = KDTree(positions)

        # Query the KD-tree for neighbors within the distance threshold
        # query_ball_point returns a list of lists, where each inner list contains the indices
        # of neighbors for the corresponding point in the input positions array.
        # The indices are relative to the positions array (0 to len(vehicles_at_ts)-1).
        neighbor_indices_list = tree.query_ball_point(positions, distance_threshold)

        # Count neighbors for each vehicle at this timestamp and update the Series
        for i_local, original_df_index in enumerate(original_indices): # Iterate through the original indices
            # Get the indices of neighbors for the i_local-th vehicle in the vehicles_at_ts subset
            neighbors_relative_indices = neighbor_indices_list[i_local]

            # Exclude the vehicle itself from its neighbors list
            # The query_ball_point includes the point itself, so we need to remove index i_local
            valid_neighbors_relative_indices = [idx for idx in neighbors_relative_indices if idx != i_local]

            numNeighbors = len(valid_neighbors_relative_indices)

            # Update the neighbor_counts Series using the original index
            neighbor_counts[original_df_index] = numNeighbors

    # print("Neighbor count calculation complete for this file.") # Can be verbose
    return neighbor_counts

Helper Function to Convert Lat/Long to X/Y

Purpose:
Convert GPS coordinates to local Cartesian coordinates for distance calculations.
Key Components:

Equirectangular Projection: Approximates Earth as a cylinder for small areas.

Reference Point: Uses dataset's mean lat/long as (0,0) origin.

Earth Radius: 6,371 km (standard average value).

In [31]:
# --- Helper Function to Convert Lat/Long to X/Y ---
def latlong_to_xy(lat, long, ref_lat=None, ref_long=None):
    """
    Convert latitude and longitude to X/Y coordinates in meters
    using a reference point as origin (0,0). Uses equirectangular projection
    for simplicity, which is suitable for small areas.

    Args:
        lat (pd.Series): Series of latitude values.
        long (pd.Series): Series of longitude values.
        ref_lat (float, optional): Reference latitude for the origin. Defaults to the mean latitude.
        ref_long (float, optional): Reference longitude for the origin. Defaults to the mean longitude.

    Returns:
        tuple: (x_coords, y_coords) in meters.
    """
    if ref_lat is None:
        ref_lat = lat.mean()  # Use mean latitude as reference
    if ref_long is None:
        ref_long = long.mean()  # Use mean longitude as reference

    # Earth radius in meters (using a common average)
    R = 6371000.0

    # Convert degrees to radians
    lat_rad = np.radians(lat)
    long_rad = np.radians(long)
    ref_lat_rad = np.radians(ref_lat)
    ref_long_rad = np.radians(ref_long)

    # Calculate X (East-West position) using equirectangular projection
    x = R * np.cos(ref_lat_rad) * (long_rad - ref_long_rad)

    # Calculate Y (North-South position) using equirectangular projection
    y = R * (lat_rad - ref_lat_rad)

    return x, y


Main Function

Workflow:

Input: Raw vehicular CSV data

Processing:

1.Time normalization

2.Coordinate conversion

3.Neighbor analysis

4.Risk assessment

Output: Enhanced CSV with risk labels

Error Handling: Robust validation at each stage

In [None]:
import pandas as pd
import numpy as np
import glob
import os
import sys
import math
from scipy.spatial import KDTree # Import KDTree for efficient neighbor search
import numpy.linalg as la # For calculating Euclidean distance efficiently
import traceback # Import traceback for detailed error reporting

# --- Configuration ---
# Directory containing the multiple Kaggle dataset files (assuming CSV format)
KAGGLE_DATASETS_DIR = 'Drive Dir Location' # <<< ADJUST THIS PATH

# File pattern to search for within the directory
KAGGLE_FILE_PATTERN = 'vehicular_dataset_*.csv' # Matches files like vehicular_dataset_YYYY-MM-DD.csv

# Suffix to add to the original filename for the new processed file
OUTPUT_FILE_SUFFIX = '_with_risk_label.csv' # Example: vehicular_dataset_2025-01-06_with_risk_label.csv

# --- Column Mapping (Map Kaggle dataset columns to necessary columns) ---
# IMPORTANT: These keys MUST match the EXACT column names in the downloaded Kaggle CSV files.
# The values are the names used in this script.
COLUMN_MAPPING = {
    'VehicleID': 'Vehicle ID',
    'Timestamp': 'Time (s)',
    'Latitude': 'Latitude',
    'Longitude': 'Longitude',
    'Speed': 'Speed (m/s)',
}




# --- Main Execution for Risk Labeling and File Saving ---
if __name__ == "__main__":
    # --- Start of Main Try Block ---
    try:
        full_pattern = os.path.join(KAGGLE_DATASETS_DIR, KAGGLE_FILE_PATTERN)
        print(f"Attempting to find CSV files from: {os.path.abspath(KAGGLE_DATASETS_DIR)} matching pattern '{KAGGLE_FILE_PATTERN}'")

        # Find all CSV files matching the pattern
        csv_files = glob.glob(full_pattern)

        print(f"Files found by glob: {csv_files}") # Print files found

        if not csv_files:
            print(f"No CSV files found matching '{KAGGLE_FILE_PATTERN}' in '{KAGGLE_DATASETS_DIR}'.")
            print("Please ensure your dataset files are in this directory and match the pattern.")
            sys.exit(1)

        print(f"\nFound {len(csv_files)} CSV files. Processing and saving to new files...")

        for csv_file_path in csv_files:
            file_name = os.path.basename(csv_file_path)
            print(f"\n--- Processing file: {file_name} ---")

            # --- Start of File Processing Try Block ---
            try:
                # Load the dataset
                df = pd.read_csv(csv_file_path)
                print(f"Successfully loaded {len(df)} rows.")

                # --- Select and Rename necessary columns based on mapping ---
                cols_to_select = list(COLUMN_MAPPING.keys())

                missing_cols_in_kaggle = [col for col in cols_to_select if col not in df.columns]
                if missing_cols_in_kaggle:
                     print(f"Error: Required columns for processing NOT found in {file_name}: {missing_cols_in_kaggle}. Skipping this file.")
                     continue # Skip to the next file

                # Create a copy to work with and add new columns
                df_processed = df[cols_to_select].copy()
                df_processed.rename(columns=COLUMN_MAPPING, inplace=True)

                # --- Data Preprocessing ---
                # 1. Convert timestamp to numeric (needed for grouping by time in neighbor calculation)
                print("Processing timestamp...")
                try:
                    df_processed['Time (s)'] = pd.to_datetime(df_processed['Time (s)'])
                    if df_processed['Time (s)'].dt.tz is not None:
                        df_processed['Time (s)'] = df_processed['Time (s)'].dt.tz_localize(None)
                        # print("Removed timezone information from 'Time (s)'.") # Can be verbose
                    df_processed['Time (s)'] = df_processed['Time (s)'].astype(np.int64) // 10**9 # Convert nanoseconds to seconds
                    print("Converted 'Time (s)' to seconds since epoch (numeric).")
                    time_conversion_successful = True
                except Exception as e:
                    print(f"Error converting 'Time (s)' to numeric in {file_name}: {e}.")
                    traceback.print_exc()
                    print("Skipping neighbor calculation and risk labeling for this file due to time conversion failure.")
                    time_conversion_successful = False
                    # Add placeholder columns to prevent errors if we continue
                    df_processed['Position X (m)'] = 0.0
                    df_processed['Position Y (m)'] = 0.0
                    df_processed[f'Number of Neighbors (10m)'] = 0
                    df_processed['Risk Label'] = -1 # Use a value outside 0/1
                    # We will still attempt to save the file with placeholders if possible


                # 2. Convert latitude/longitude to X/Y coordinates (needed for neighbor calculation)
                # Only proceed if Latitude and Longitude columns are present AND time conversion was successful
                if time_conversion_successful and 'Latitude' in df_processed.columns and 'Longitude' in df_processed.columns:
                    print("Converting lat/long to X/Y coordinates...")
                    ref_lat = df_processed['Latitude'].mean()
                    ref_long = df_processed['Longitude'].mean()
                    x_coords, y_coords = latlong_to_xy(df_processed['Latitude'], df_processed['Longitude'], ref_lat, ref_long)
                    df_processed['Position X (m)'] = x_coords
                    df_processed['Position Y (m)'] = y_coords
                    print("Added 'Position X (m)' and 'Position Y (m)' columns.")
                    position_conversion_successful = True
                else:
                    print("Skipping lat/long to X/Y conversion or neighbor calculation: Required 'Latitude' or 'Longitude' columns not found or time conversion failed.")
                    position_conversion_successful = False
                    # Add placeholder columns if conversion is skipped
                    df_processed['Position X (m)'] = 0.0
                    df_processed['Position Y (m)'] = 0.0
                    df_processed[f'Number of Neighbors (10m)'] = 0
                    df_processed['Risk Label'] = -1 # Use a value outside 0/1


                # 3. Calculate number of neighbors within 10m (needed for risk labeling)
                # Only proceed if Position X, Position Y, Time (s), and Vehicle ID are present AND time/position conversion was successful
                neighbor_count_col_name = f'Number of Neighbors (10m)'
                if position_conversion_successful and all(col in df_processed.columns for col in ['Position X (m)', 'Position Y (m)', 'Time (s)', 'Vehicle ID']):
                    print(f"Calculating {neighbor_count_col_name}...")
                    neighbor_counts = calculate_neighbors_count(df_processed, distance_threshold=10) # Calculate neighbors within 10m
                    df_processed[neighbor_count_col_name] = neighbor_counts # Add the calculated series
                    print(f"Added '{neighbor_count_col_name}' column.")
                    neighbor_calculation_successful = True
                else:
                    print(f"Skipping {neighbor_count_col_name} calculation: Required position, time, or vehicle ID columns not found or conversions failed.")
                    neighbor_calculation_successful = False
                    if neighbor_count_col_name not in df_processed.columns:
                         df_processed[neighbor_count_col_name] = 0 # Add placeholder if calculation skipped
                    df_processed['Risk Label'] = -1 # Use a value outside 0/1


                # --- Apply Risk Labeling Rule ---
                print("Applying risk labeling rule...")
                # Only proceed if Speed and the calculated neighbor count column exist AND neighbor calculation was successful
                if neighbor_calculation_successful and 'Speed (m/s)' in df_processed.columns and neighbor_count_col_name in df_processed.columns:
                    # Risk Label = 1 if Speed > 35 AND Number of Neighbors (10m) > 0, else 0
                    df_processed['Risk Label'] = np.where(
                        (df_processed['Speed (m/s)'] > 35) & (df_processed[neighbor_count_col_name] > 0),
                        1, # Risky
                        0  # Safe
                    )
                    print("Risk labeling complete.")

                    # Print counts of each risk label for this file
                    risk_label_counts = df_processed['Risk Label'].value_counts()
                    safe_count = risk_label_counts.get(0, 0)
                    risky_count = risk_label_counts.get(1, 0)
                    print(f"Generated Risk Labels for {file_name}: {safe_count} safe (0) and {risky_count} risky (1) instances")

                else:
                    print("Skipping risk labeling: Required 'Speed (m/s)' or neighbor count column not found or neighbor calculation failed.")
                    # Add placeholder column if labeling is skipped
                    if 'Risk Label' not in df_processed.columns:
                         df_processed['Risk Label'] = -1 # Use a value outside 0/1


                # --- Add the new columns to the original DataFrame structure ---
                # This step is crucial to add only the new columns ('Number of Neighbors (10m)' and 'Risk Label')
                # to the original DataFrame structure, keeping all its original columns.
                # We use the index to align the data correctly.
                # Ensure the new columns exist in df_processed before trying to select them
                new_cols_to_merge = []
                if neighbor_count_col_name in df_processed.columns:
                    new_cols_to_merge.append(neighbor_count_col_name)
                if 'Risk Label' in df_processed.columns:
                    new_cols_to_merge.append('Risk Label')

                if new_cols_to_merge:
                    new_columns_df = df_processed[new_cols_to_merge].copy()

                    # Ensure the indices match before merging
                    if not df.index.equals(new_columns_df.index):
                         print(f"Warning: Index mismatch between original and processed dataframes for {file_name}. Attempting reindex.")
                         new_columns_df = new_columns_df.reindex(df.index) # Reindex to match original

                    # Add the new columns to the original DataFrame.
                    # Use .get() to avoid KeyError if column already exists
                    for col in new_cols_to_merge:
                        if col not in df.columns:
                             df[col] = new_columns_df[col]
                        else:
                             # If column exists, overwrite it (shouldn't happen in this logic unless rerun)
                             df[col] = new_columns_df[col]
                             print(f"Warning: Overwriting existing column '{col}' in {file_name}.")
                else:
                     print("No new columns were successfully generated to merge.")


                # --- Define the output file path ---
                # Get the directory and the base filename without extension
                dir_name = os.path.dirname(csv_file_path)
                base_name = os.path.basename(csv_file_path)
                name_without_ext, ext = os.path.splitext(base_name)

                # Construct the new file path
                output_file_name = f"{name_without_ext}{OUTPUT_FILE_SUFFIX}"
                output_file_path = os.path.join(dir_name, output_file_name)

                # --- Save the updated DataFrame to a NEW file ---
                print(f"Saving updated data to NEW file: {output_file_path}")
                # Use index=False to avoid writing the DataFrame index as a column
                df.to_csv(output_file_path, index=False)
                print(f"Successfully saved updated data for {file_name} to {output_file_name}.")

            # --- End of File Processing Try Block ---
            except Exception as e:
                print(f"\nAn unexpected error occurred while processing {file_name}: {e}")
                traceback.print_exc() # Print full traceback for better debugging
                print(f"Skipping processing and saving for {file_name}.")

    # --- Main Except Blocks (for initial file finding) ---
    except FileNotFoundError:
        print(f"Error: Dataset directory or files not found at {os.path.abspath(KAGGLE_DATASETS_DIR)}")
        print("Please check the KAGGLE_DATASETS_DIR and KAGGLE_FILE_PATTERN variables.")
    except Exception as e:
        print(f"\nAn unexpected error occurred during initial file finding: {e}")
        traceback.print_exc() # Print full traceback for better debugging
        print("\nPlease review the error message and traceback for details.")
    finally:
        print("\n--- Script Execution Finished ---")



Attempting to find CSV files from: /content/drive/My Drive/Lokm/simulation_results/Dataset matching pattern 'vehicular_dataset_*.csv'
Files found by glob: ['/content/drive/My Drive/Lokm/simulation_results/Dataset/vehicular_dataset_2025-01-07.csv']

Found 1 CSV files. Processing and saving to new files...

--- Processing file: vehicular_dataset_2025-01-07.csv ---
Successfully loaded 4320000 rows.
Processing timestamp...
Converted 'Time (s)' to seconds since epoch (numeric).
Converting lat/long to X/Y coordinates...
Added 'Position X (m)' and 'Position Y (m)' columns.
Calculating Number of Neighbors (10m)...
Added 'Number of Neighbors (10m)' column.
Applying risk labeling rule...
Risk labeling complete.
Generated Risk Labels for vehicular_dataset_2025-01-07.csv: 4319970 safe (0) and 30 risky (1) instances
Saving updated data to NEW file: /content/drive/My Drive/Lokm/simulation_results/Dataset/vehicular_dataset_2025-01-07_with_risk_label.csv
Successfully saved updated data for vehicular_d

Helper Function to calculate coordinates in meter

Purpose:
Convert GPS coordinates to local Cartesian (X/Y) coordinates in meters for distance calculations.
Key Components:

Equirectangular Projection: Simplifies Earth's curvature for small areas.

Reference Point: Uses mean lat/long as (0,0) if not provided.

Earth Radius: Average value (R = 6371000 m) for approximation.

Equations:

X = R * cos(ref_lat) * (long - ref_long) (East-West)

Y = R * (lat - ref_lat) (North-South)



In [9]:
# --- Helper Functions ---
def latlong_to_xy(lat, long, ref_lat=None, ref_long=None):
    """
    Convert latitude and longitude to X/Y coordinates in meters
    using a reference point as origin (0,0). Uses equirectangular projection
    for simplicity, which is suitable for small areas.

    Args:
        lat (pd.Series): Series of latitude values.
        long (pd.Series): Series of longitude values.
        ref_lat (float, optional): Reference latitude for the origin. Defaults to the mean latitude.
        ref_long (float, optional): Reference longitude for the origin. Defaults to the mean longitude.

    Returns:
        tuple: (x_coords, y_coords) in meters.
    """
    if ref_lat is None:
        ref_lat = lat.mean()  # Use mean latitude as reference
    if ref_long is None:
        ref_long = long.mean()  # Use mean longitude as reference

    # Earth radius in meters (using a common average)
    R = 6371000.0

    # Convert degrees to radians
    lat_rad = np.radians(lat)
    long_rad = np.radians(long)
    ref_lat_rad = np.radians(ref_lat)
    ref_long_rad = np.radians(ref_long)

    # Calculate X (East-West position) using equirectangular projection
    x = R * np.cos(ref_lat_rad) * (long_rad - ref_long_rad)

    # Calculate Y (North-South position) using equirectangular projection
    y = R * (lat_rad - ref_lat_rad)

    return x, y


Calculate Number of Neighbors & Average Distance Function

Purpose:
Calculate the number of neighbors and average distance within a threshold for ML features.
Key Components:

KD-Tree: Efficient spatial indexing for fast neighbor searches.

Grouping by Timestamp: Processes vehicles at the same time.

Edge Handling: Skips timestamps with 0-1 vehicles, excludes self-neighbors.

Output Metrics:

Number of Neighbors: Count within distance_threshold.

Average Distance: Mean distance to neighbors (or 2 * threshold if none).

In [10]:

# --- Calculate Number of Neighbors & Average Distance Function (Optimized with KD-Tree) ---
# This is a general neighbor calculation for ML features, distinct from the 10m one for Risk Labeling
def calculate_neighbor_metrics_for_ml(df, distance_threshold=100):
    """
    Calculate the number of neighbors and average distance to neighbors for each vehicle
    at each timestamp based on their positions using a KD-tree for efficiency.
    This is for generating ML features, using a potentially different threshold than the risk label.

    Args:
        df (pd.DataFrame): DataFrame containing vehicle data with 'Time (s)', 'Vehicle ID',
                           'Position X (m)', and 'Position Y (m)' columns.
        distance_threshold (float): The maximum distance in meters to consider a vehicle a neighbor
                                    for the purpose of ML features.

    Returns:
        pd.DataFrame: The input DataFrame with 'Number of Neighbors' and 'Average Distance to Neighbors (m)' columns added.
    """
    print(f"Starting neighbor metrics calculation for ML features with distance threshold {distance_threshold}m...")

    # Ensure necessary columns exist
    required_cols = ['Time (s)', 'Vehicle ID', 'Position X (m)', 'Position Y (m)']
    if not all(col in df.columns for col in required_cols):
        missing = [col for col in required_cols if col not in df.columns]
        print(f"Error: Required columns for ML neighbor calculation are missing: {missing}. Skipping calculation.")
        # Add placeholder columns to prevent downstream errors
        if 'Number of Neighbors' not in df.columns:
             df['Number of Neighbors'] = 0
        if 'Average Distance to Neighbors (m)' not in df.columns:
             df['Average Distance to Neighbors (m)'] = 0.0
        return df # Return dataframe with placeholders

    # Initialize new columns with default values
    df['Number of Neighbors'] = 0
    df['Average Distance to Neighbors (m)'] = 0.0

    # Use a large value for default average distance if no neighbors
    default_avg_distance = distance_threshold * 2

    # Process data for each timestamp
    timestamps = df['Time (s)'].unique()
    print(f"Processing {len(timestamps)} unique timestamps for ML neighbor metrics...")

    # Group by timestamp for efficient processing
    grouped_by_time = df.groupby('Time (s)')

    for i, (ts, vehicles_at_ts) in enumerate(grouped_by_time):
        if (i + 1) % 100 == 0: # Print progress every 100 timestamps
            print(f"  Processing timestamp {i+1}/{len(timestamps)} (Time: {ts}s)")

        if len(vehicles_at_ts) <= 1:
            # If 0 or 1 vehicle, no neighbors to calculate for this timestamp
            continue

        # Extract positions for KD-tree and get original dataframe indices
        # Ensure positions are float type for KDTree
        positions = vehicles_at_ts[['Position X (m)', 'Position Y (m)']].values.astype(float)
        original_indices = vehicles_at_ts.index.tolist()

        # Build KD-tree
        tree = KDTree(positions)

        # Query the KD-tree for neighbors within the distance threshold
        # query_ball_point returns a list of lists, where each inner list contains the indices
        # of neighbors for the corresponding point in the input positions array.
        # The indices are relative to the positions array (0 to len(vehicles_at_ts)-1).
        neighbor_indices_list = tree.query_ball_point(positions, distance_threshold)

        # Calculate neighbor metrics for each vehicle at this timestamp
        for i_local, original_df_index in enumerate(original_indices): # Iterate through the original indices
            # Get the indices of neighbors for the i_local-th vehicle in the vehicles_at_ts subset
            neighbors_relative_indices = neighbor_indices_list[i_local]

            # Exclude the vehicle itself from its neighbors list
            # The query_ball_point includes the point itself, so we need to remove index i_local
            valid_neighbors_relative_indices = [idx for idx in neighbors_relative_indices if idx != i_local]

            numNeighbors = len(valid_neighbors_relative_indices)
            avgDistanceToNeighbors = 0.0 # Initialize average distance

            if numNeighbors > 0:
                # Get the positions of the actual neighbors using relative indices
                neighbor_positions = positions[valid_neighbors_relative_indices]

                # Get the position of the current vehicle
                current_vehicle_pos = positions[i_local]

                # Calculate distances to these neighbors using numpy.linalg.norm
                distances_to_neighbors = la.norm(neighbor_positions - current_vehicle_pos, axis=1)
                avgDistanceToNeighbors = np.mean(distances_to_neighbors)
            else:
                avgDistanceToNeighbors = default_avg_distance # Assign default if no neighbors

            # Update the original dataframe using the original index
            df.at[original_df_index, 'Number of Neighbors'] = numNeighbors
            df.at[original_df_index, 'Average Distance to Neighbors (m)'] = avgDistanceToNeighbors


    print("ML neighbor metrics calculation complete.")
    return df


Dataset Classes and Function to create Test and Train Data and DataLoader

Purpose:
Convert preprocessed data into PyTorch-compatible datasets.
Key Components:

Tensor Conversion: Converts NumPy arrays to PyTorch tensors.

Feature-Label Split:

Features: Vehicle ID, position, speed, acceleration, neighbor metrics.

Labels: Risk Label (0 = safe, 1 = risky).



In [19]:

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# --- Dataset Class ---
class V2VDataset(Dataset):
    """Custom Dataset for V2V log data."""
    def __init__(self, features, labels):
        # Convert numpy arrays to PyTorch tensors
        self.features = torch.tensor(features, dtype=torch.float32)
        # Labels are now discrete class indices (0 or 1), use dtype as long for classification
        self.labels = torch.tensor(labels, dtype=torch.long)

    def __len__(self):
        return len(self.features)

    def __getitem__(self, idx):
        return self.features[idx], self.labels[idx]


Data Loading and Preprocessing

Purpose:
Load CSV files, engineer features, and prepare train/test splits.
Key Steps:

File Loading:

Combines multiple CSV files into a single DataFrame.

Validates required columns (e.g., Risk Label).

Timestamp Conversion:

Converts datetime to numeric (seconds since epoch).

Position Calculation:

Uses latlong_to_xy to derive Position X/Y (m).

Acceleration Calculation:

Computes acceleration from speed differences over time.

Neighbor Metrics:

Calls calculate_neighbor_metrics_for_ml for ML features.

Feature Scaling:

Standardizes features using StandardScaler.

Train/Test Split:

80/20 stratified split to preserve class balance.

In [12]:

# --- Data Loading and Preprocessing ---
def load_and_preprocess_data(data_dir, file_pattern, column_mapping, neighbor_distance_threshold=100):
    """
    Loads data from multiple CSV files, preprocesses it, and splits into train/test sets
    for binary classification. Performs feature engineering within this function.

    Args:
        data_dir (str): Directory containing the CSV files.
        file_pattern (str): Pattern to match CSV files (e.e.g., '*.csv').
        column_mapping (dict): Mapping from original CSV column names to desired names.
        neighbor_distance_threshold (float): Distance threshold for calculating neighbor metrics for ML features.

    Returns:
        tuple: (train_dataloader, test_dataloader, scaler, y_test, feature_names)
               train_dataloader: DataLoader for the training set.
               test_dataloader: DataLoader for the testing data.
               scaler: The fitted StandardScaler object.
               y_test: The true labels of the test set (for confusion matrix).
               feature_names: List of names of the features used for training.
               Returns (None, None, None, None, None) if data loading or processing fails.
    """
    full_pattern = os.path.join(data_dir, file_pattern)
    print(f"Attempting to find and load CSV files from: {os.path.abspath(data_dir)} matching pattern '{file_pattern}'")

    csv_files = glob.glob(full_pattern)

    if not csv_files:
        print(f"No CSV files found matching '{file_pattern}' in '{data_dir}'.")
        return None, None, None, None, None

    all_dataframes = []
    print(f"Found {len(csv_files)} CSV files. Loading and combining...")

    for csv_file_path in csv_files:
        file_name = os.path.basename(csv_file_path)
        print(f"Loading file: {file_name}")
        try:
            df_single = pd.read_csv(csv_file_path)
            all_dataframes.append(df_single)
            print(f"Loaded {len(df_single)} rows from {file_name}.")
        except Exception as e:
            print(f"Error loading file {csv_file_path}: {e}. Skipping.")
            traceback.print_exc()
            continue

    if not all_dataframes:
        print("No data loaded from any of the files. Exiting.")
        return None, None, None, None, None

    combined_df = pd.concat(all_dataframes, ignore_index=True)
    print(f"\nCombined data from all files. Total rows: {len(combined_df)}")
    print(f"Original columns in combined data: {combined_df.columns.tolist()}")


    # --- Select and Rename necessary columns based on mapping ---
    cols_to_select = list(column_mapping.keys())

    missing_cols_in_source = [col for col in cols_to_select if col not in combined_df.columns]
    if missing_cols_in_source:
         print(f"Error: The following columns specified in COLUMN_MAPPING are NOT found in the combined dataset: {missing_cols_in_source}")
         print("Please verify the column names in the CSV files and update COLUMN_MAPPING.")
         return None, None, None, None, None

    df_processed = combined_df[cols_to_select].copy()
    df_processed.rename(columns=column_mapping, inplace=True)
    print(f"Columns after selecting and renaming: {df_processed.columns.tolist()}")

    # --- Data Preprocessing ---
    # 1. Convert timestamp to numeric
    print("Processing timestamp...")
    try:
        df_processed['Time (s)'] = pd.to_datetime(df_processed['Time (s)'])
        if df_processed['Time (s)'].dt.tz is not None:
            df_processed['Time (s)'] = df_processed['Time (s)'].dt.tz_localize(None)
        df_processed['Time (s)'] = df_processed['Time (s)'].astype(np.int64) // 10**9
        print("Converted 'Time (s)' to seconds since epoch (numeric).")
        time_conversion_successful = True
    except Exception as e:
        print(f"Error converting 'Time (s)' to numeric: {e}.")
        traceback.print_exc()
        print("Skipping acceleration and neighbor metrics calculation due to time conversion failure.")
        time_conversion_successful = False
        # Add placeholder columns if conversion fails
        if 'Position X (m)' not in df_processed.columns: df_processed['Position X (m)'] = 0.0
        if 'Position Y (m)' not in df_processed.columns: df_processed['Position Y (m)'] = 0.0
        if 'Calculated Acceleration (m/s^2)' not in df_processed.columns: df_processed['Calculated Acceleration (m/s^2)'] = 0.0
        if 'Number of Neighbors' not in df_processed.columns: df_processed['Number of Neighbors'] = 0
        if 'Average Distance to Neighbors (m)' not in df_processed.columns: df_processed['Average Distance to Neighbors (m)'] = 0.0


    # 2. Convert latitude/longitude to X/Y coordinates
    if time_conversion_successful and 'Latitude' in df_processed.columns and 'Longitude' in df_processed.columns:
        print("Converting lat/long to X/Y coordinates...")
        ref_lat = df_processed['Latitude'].mean()
        ref_long = df_processed['Longitude'].mean()
        x_coords, y_coords = latlong_to_xy(df_processed['Latitude'], df_processed['Longitude'], ref_lat, ref_long)
        df_processed['Position X (m)'] = x_coords
        df_processed['Position Y (m)'] = y_coords # Corrected typo here
        print("Added 'Position X (m)' and 'Position Y (m)' columns.")
        position_conversion_successful = True
    else:
        print("Skipping lat/long to X/Y conversion: Required 'Latitude' or 'Longitude' columns not found or time conversion failed.")
        position_conversion_successful = False
        if 'Position X (m)' not in df_processed.columns: df_processed['Position X (m)'] = 0.0
        if 'Position Y (m)' not in df_processed.columns: df_processed['Position Y (m)'] = 0.0
        if 'Calculated Acceleration (m/s^2)' not in df_processed.columns: df_processed['Calculated Acceleration (m/s^2)'] = 0.0
        if 'Number of Neighbors' not in df_processed.columns: df_processed['Number of Neighbors'] = 0
        if 'Average Distance to Neighbors (m)' not in df_processed.columns: df_processed['Average Distance to Neighbors (m)'] = 0.0


    # 3. Calculate acceleration from speed
    if time_conversion_successful and 'Speed (m/s)' in df_processed.columns and 'Time (s)' in df_processed.columns and 'Vehicle ID' in df_processed.columns:
         print("Calculating acceleration...")
         df_sorted = df_processed.sort_values(by=['Vehicle ID', 'Time (s)']).copy()
         df_sorted['Delta_Time'] = df_sorted.groupby('Vehicle ID')['Time (s)'].diff()
         df_sorted['Delta_Speed'] = df_sorted.groupby('Vehicle ID')['Speed (m/s)'].diff()
         df_sorted['Calculated Acceleration (m/s^2)'] = df_sorted.apply(
             lambda row: row['Delta_Speed'] / row['Delta_Time'] if pd.notna(row['Delta_Time']) and row['Delta_Time'] > 1e-9 else 0, axis=1
         )
         df_sorted.replace([np.inf, -np.inf], np.nan, inplace=True)
         df_sorted['Calculated Acceleration (m/s^2)'].fillna(0, inplace=True)
         df_processed['Calculated Acceleration (m/s^2)'] = df_sorted['Calculated Acceleration (m/s^2)']
         print("Acceleration calculation complete.")
    else:
         print("Skipping acceleration calculation: Required columns not found or time conversion failed.")
         if 'Calculated Acceleration (m/s^2)' not in df_processed.columns: df_processed['Calculated Acceleration (m/s^2)'] = 0.0


    # 4. Calculate general neighbor metrics for ML features
    if position_conversion_successful and all(col in df_processed.columns for col in ['Position X (m)', 'Position Y (m)', 'Time (s)', 'Vehicle ID']):
        df_processed = calculate_neighbor_metrics_for_ml(df_processed, neighbor_distance_threshold)
        print("Added 'Number of Neighbors' and 'Average Distance to Neighbors (m)' columns for ML features.")
    else:
        print("Skipping ML neighbor metrics calculation: Required position, time, or vehicle ID columns not found or conversions failed.")
        if 'Number of Neighbors' not in df_processed.columns: df_processed['Number of Neighbors'] = 0
        if 'Average Distance to Neighbors (m)' not in df_processed.columns: df_processed['Average Distance to Neighbors (m)'] = 0.0


    # --- Select Final Features and Labels for ML ---
    # These are the columns that will be used as input to the ML model
    ml_feature_columns = [
        'Vehicle ID', 'Time (s)', 'Position X (m)', 'Position Y (m)',
        'Speed (m/s)', 'Calculated Acceleration (m/s^2)',
        'Number of Neighbors', 'Average Distance to Neighbors (m)'
    ]
    ml_label_column = 'Risk Label' # <<< Use the Risk Label column

    # Check if the final required columns for ML are present after all processing
    missing_ml_features = [col for col in ml_feature_columns if col not in df_processed.columns]
    if missing_ml_features:
        print(f"Error: Missing required ML feature columns after processing: {missing_ml_features}")
        return None, None, None, None, None
    if ml_label_column not in df_processed.columns:
        print(f"Error: Missing ML label column '{ml_label_column}' after processing.")
        return None, None, None, None, None

    features = df_processed[ml_feature_columns].values
    labels = df_processed[ml_label_column].values # Use the Risk Label as target

    # --- Filter Labels for Binary Classification (0 or 1) ---
    # The Risk Label should already be 0 or 1, but this is a safeguard
    valid_labels_mask = np.isin(labels, [0, 1])
    if not np.all(valid_labels_mask):
        print(f"Warning: Filtering out {len(labels) - np.sum(valid_labels_mask)} rows with labels outside of [0, 1] for binary classification.")
        features = features[valid_labels_mask]
        labels = labels[valid_labels_mask]
        print(f"Remaining data points after filtering: {len(labels)}")

    if len(labels) < 2 or len(np.unique(labels)) < 2:
        print(f"Insufficient data or only one class ({np.unique(labels)}) remaining after filtering ({len(labels)} data points). Cannot train/evaluate.")
        return None, None, None, None, None


    # --- Handle Missing Values (NaNs) ---
    nan_count = np.isnan(features).sum()
    if nan_count > 0:
        print(f"Warning: Found {nan_count} NaN values in features after processing. Filling NaNs with 0.")
        features = np.nan_to_num(features, nan=0.0)

    # --- Feature Scaling ---
    scaler = StandardScaler()
    features_scaled = scaler.fit_transform(features)

    # --- Split Data ---
    try:
        X_train, X_test, y_train, y_test = train_test_split(
            features_scaled, labels, test_size=0.2, random_state=42, stratify=labels # Stratify helps maintain label distribution
        )
    except ValueError as ve:
        print(f"Error splitting data: {ve}. This might happen if a class has too few samples after filtering.")
        return None, None, None, None, None


    # --- Create PyTorch Datasets and Dataloaders ---
    train_dataset = V2VDataset(X_train, y_train)
    test_dataset = V2VDataset(X_test, y_test)

    # Define batch size for training
    batch_size = 64 # You can adjust this

    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False) # No need to shuffle test data

    print("Data loading and preprocessing complete.")

    # Return feature names for plotting feature importance
    return train_dataloader, test_dataloader, scaler, y_test, ml_feature_columns # Added feature names to return


Model Architecture for both Client and Server

Purpose:
First part of the split learning model (privacy-preserving).
Structure:

Input Layer: 8 features (e.g., position, speed).

Hidden Layer: 64 neurons with ReLU activation.

Output: Activations sent to the server.

Purpose:
Second part of the split learning model (centralized computation).
Structure:

Hidden Layer: 32 neurons with ReLU.

Output Layer: 2 logits (binary classification).

Loss: Cross-entropy loss (no softmax – handled by PyTorch).



In [15]:

# --- Model Architectures ---
# client_model.py (Simulated)

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
class ClientModel(nn.Module):
    def __init__(self, input_dim=8, client_output_dim=64):
        """
        Client-side model (Input layer + 1 hidden layer).

        Args:
            input_dim (int): Dimension of the input feature vector (8 in your case).
            client_output_dim (int): Dimension of the output activations from the client.
        """
        super().__init__()
        # First Linear layer
        self.fc1 = nn.Linear(input_dim, client_output_dim)

    def forward(self, x):
        """
        Forward pass through the client model.

        Args:
            x (torch.Tensor): Input feature tensor.

        Returns:
            torch.Tensor: Output activations after ReLU.
        """
        # Apply Linear layer and then ReLU activation
        activations = F.relu(self.fc1(x))
        return activations

# server_model.py (Simulated)
class ServerModel(nn.Module):
    def __init__(self, server_input_dim=64, hidden_dim=32, output_dim=2):
        """
        Server-side model (Remaining hidden layers + output layer).
        output_dim=2 is correct for binary classification with CrossEntropyLoss.

        Args:
            server_input_dim (int): Dimension of the input activations from the client.
            hidden_dim (int): Dimension of the server's hidden layer.
            output_dim (int): Dimension of the final output (2 for binary classification).
        """
        super().__init__()
        # Second Linear layer
        self.fc2 = nn.Linear(server_input_dim, hidden_dim)
        # Output Linear layer (outputs logits for 2 classes)
        self.out = nn.Linear(hidden_dim, output_dim)

    def forward(self, x):
        """
        Forward pass through the server model.

        Args:
            x (torch.Tensor): Input activations from the client.

        Returns:
            torch.Tensor: Final output logits for each class.
        """
        # Apply second Linear layer and ReLU
        x = F.relu(self.fc2(x))
        # Apply output Linear layer (no activation here, CrossEntropyLoss expects logits)
        output = self.out(x)
        return output


Training Function

Purpose:
Simulate split learning training while protecting raw data.
Key Steps:

Forward Pass:

Client computes activations (client_model).

Server computes logits (server_model).

Loss Calculation:
Cross-entropy loss between logits and labels.

Backward Pass:

Server calculates gradients and sends them to the client.

Client updates its model without exposing raw data.

Optimization:
Separate Adam optimizers for client/server.

Evaluation:
Tracks training/test loss, accuracy, and predictions.

In [16]:

# --- Training Function ---
def train_split_learning(client_model, server_model, train_dataloader, test_dataloader, num_epochs=10):
    """
    Simulates the Split Learning training process for binary classification.

    Args:
        client_model (nn.Module): The client-side PyTorch model.
        server_model (nn.Module): The server-side PyTorch model.
        train_dataloader (DataLoader): DataLoader for the training data.
        test_dataloader (DataLoader): DataLoader for the testing data.
        num_epochs (int): Number of training epochs.

    Returns:
        tuple: (training_losses, test_losses, epochs_list, all_test_labels, all_test_predictions)
               training_losses: List of training losses per epoch.
               test_losses: List of test losses per epoch.
               epochs_list: List of epoch numbers.
               all_test_labels: The true labels of the test set.
               all_test_predictions: The model's predictions on the test set.
    """
    # Define optimizers for client and server models
    client_optimizer = optim.Adam(client_model.parameters(), lr=0.001)
    server_optimizer = optim.Adam(server_model.parameters(), lr=0.001)

    # Define the loss function (Cross-Entropy Loss for binary classification)
    criterion = nn.CrossEntropyLoss()

    # Move models to the appropriate device (CPU or GPU if available)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    client_model.to(device)
    server_model.to(device)

    print(f"Using device: {device}")

    # Lists to store loss values for plotting
    training_losses = []
    test_losses = []
    epochs_list = []


    for epoch in range(num_epochs):
        client_model.train() # Set client model to training mode
        server_model.train() # Set server model to training mode
        running_loss = 0.0

        for i, (inputs, labels) in enumerate(train_dataloader):
            # Move data to the device
            inputs, labels = inputs.to(device), labels.to(device)

            # --- Split Learning Step 1: Client Forward Pass ---
            # Client computes activations
            client_optimizer.zero_grad() # Zero client gradients before forward pass
            client_activations = client_model(inputs)

            # --- Simulate sending activations to server ---
            # Detach activations to cut the graph for the client's part
            detached_activations = client_activations.detach()

            # --- Simulate receiving activations on server and reattaching for server's graph ---
            # Clone the detached activations and set requires_grad=True for the server's backward pass
            server_input = detached_activations.clone().requires_grad_(True)


            # --- Split Learning Step 2: Server Forward Pass and Loss Calculation ---
            # Server computes output logits
            server_optimizer.zero_grad() # Zero server gradients before forward pass
            server_outputs = server_model(server_input)

            # Calculate loss (CrossEntropyLoss expects logits and target class indices)
            loss = criterion(server_outputs, labels)


            # --- Split Learning Step 3: Server Backward Pass ---
            # Server performs backward pass to calculate gradients for server model and the server_input (activations)
            loss.backward() # Compute gradients

            # --- Simulate sending gradients back to client ---
            # Server gets gradients for the cut layer (server_input)
            gradients_to_client = server_input.grad

            # --- Split Learning Step 4: Client Backward Pass ---
            # Client receives gradients and performs backward pass starting from its original activations
            # The gradients from the server are applied to the corresponding activations on the client side
            client_activations.backward(gradients_to_client)

            # --- Split Learning Step 5: Optimizer Steps ---
            # Update model weights using the calculated gradients
            client_optimizer.step()
            server_optimizer.step()

            running_loss += loss.item()

            # Print loss periodically
            if (i + 1) % 100 == 0: # Print every 100 batches
                print(f'Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_dataloader)}], Loss: {loss.item():.4f}')

        # Calculate average training loss for the epoch
        avg_train_loss = running_loss / len(train_dataloader)
        training_losses.append(avg_train_loss)
        epochs_list.append(epoch + 1) # Store epoch number


        # --- Evaluation after each epoch (for monitoring during training) ---
        client_model.eval() # Set client model to evaluation mode
        server_model.eval() # Set server model to evaluation mode
        test_loss = 0.0
        correct_predictions = 0
        total_predictions = 0
        # Do not collect all test labels/predictions here to save memory if test set is large
        # Collection for final confusion matrix will happen after the loop

        with torch.no_grad(): # Disable gradient calculation during evaluation
            for inputs, labels in test_dataloader:
                inputs, labels = inputs.to(device), labels.to(device)

                # Client forward pass
                client_activations = client_model(inputs)

                # Server forward pass (no need to detach/reattach in eval mode)
                server_outputs = server_model(client_activations)

                # Calculate loss
                test_loss += criterion(server_outputs, labels).item()

                # Calculate accuracy
                _, predicted = torch.max(server_outputs.data, 1)
                total_predictions += labels.size(0)
                correct_predictions += (predicted == labels).sum().item()


        # Calculate average test loss for the epoch
        avg_test_loss = test_loss / len(test_dataloader)
        test_losses.append(avg_test_loss)

        accuracy = 100 * correct_predictions / total_predictions

        print(f'Epoch [{epoch+1}/{num_epochs}], Training Loss: {avg_train_loss:.4f}, Test Loss (CrossEntropy): {avg_test_loss:.4f}, Test Accuracy: {accuracy:.2f}%')

    print("Split Learning training finished.")

    # --- Collect all test labels and predictions AFTER training loop for final evaluation ---
    client_model.eval()
    server_model.eval()
    all_test_labels = []
    all_test_predictions = []
    with torch.no_grad():
        for inputs, labels in test_dataloader:
            inputs, labels = inputs.to(device), labels.to(device)
            client_activations = client_model(inputs)
            server_outputs = server_model(client_activations)
            _, predicted = torch.max(server_outputs.data, 1)
            all_test_labels.extend(labels.cpu().numpy())
            all_test_predictions.extend(predicted.cpu().numpy())


    # Return collected data for plotting and evaluation
    return training_losses, test_losses, epochs_list, all_test_labels, all_test_predictions, client_model # Added client_model to return


Main Function to train the Model

Purpose:
Orchestrate end-to-end training.
Key Steps:

Configuration:

Paths, column mappings, hyperparameters.

Data Loading:
Calls load_and_preprocess_data.

Model Initialization:
Client/server models moved to GPU if available.

Training:
Executes train_split_learning.

Output:
Saves losses, predictions, and models for evaluation.



In [20]:
# --- Part 1: ML Training Setup and Execution ---

# --- All Imports ---
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler # Recommended for scaling features

# Import only modules needed for training setup in Part 1
# Plotting imports will be in Part 2
# from sklearn.metrics import confusion_matrix, ConfusionMatrixDisplay, accuracy_score
# import matplotlib.pyplot as plt
# import seaborn as sns
# from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, classification_report
# from sklearn.manifold import TSNE
# from matplotlib.gridspec import GridSpec
# import matplotlib.colors as mcolors

import re
import glob
import os
import json # To potentially save/load aggregated results
import sys
import math
from scipy.spatial import KDTree # Import KDTree for efficient neighbor search
import numpy.linalg as la # For calculating Euclidean distance efficiently
import traceback # Import traceback for detailed error reporting


# --- Configuration ---
# Directory containing the multiple Kaggle dataset files (assuming CSV format)
# These are the files that should now contain the 'Risk Label' column
KAGGLE_DATASETS_DIR = '/content/drive/My Drive/Lokm/simulation_results/Dataset' # <<< ADJUST THIS PATH

# File pattern to search for within the directory
# Assuming the files were processed by the labeling script to add 'Risk Label'
KAGGLE_FILE_PATTERN = 'vehicular_dataset_*.csv' # Matches files created by the updated labeling script

# Define a placeholder directory for plots, will be created if needed in Part 2
PLOT_SAVE_DIR = 'training_plots'

# --- Column Mapping (Map original CSV columns to names used in this script) ---
# IMPORTANT: These keys MUST match the EXACT column names in your updated CSV files.
# The values are the names used internally for processing and ML features.
COLUMN_MAPPING = {
    'VehicleID': 'Vehicle ID',
    'Timestamp': 'Time (s)',
    'Latitude': 'Latitude',
    'Longitude': 'Longitude',
    'Speed': 'Speed (m/s)',
    # The 'Risk Label' column should now be present in your updated CSV files
    'Risk Label': 'Risk Label', # <<< Ensure this matches the column name added by the update script
    # The neighbor count added by the update script (e.g., 'Number of Neighbors (10m)')
    # is NOT used directly as an ML feature here. We recalculate general neighbor metrics.
}





# --- Main Training Execution (Part 1) ---
if __name__ == "__main__":
    print("--- Starting V2V Split Learning Training (Part 1) ---")

    # Define the path to your dataset directory and file pattern
    data_directory = KAGGLE_DATASETS_DIR
    # Use the file pattern that matches the NEW files created by the update script
    file_pattern = KAGGLE_FILE_PATTERN # Should be 'vehicular_dataset_*_with_risk_label_v2.csv'


    # Load and preprocess the data (including feature engineering and splitting)
    # These variables will hold the data and results needed for Part 2
    train_dataloader, test_dataloader, scaler, y_test_original, feature_names = load_and_preprocess_data(
        data_directory,
        file_pattern,
        COLUMN_MAPPING,
        neighbor_distance_threshold=100 # Set the distance threshold for ML neighbor features
    )

    if train_dataloader is not None and test_dataloader is not None:
        # Define model dimensions
        # input_feature_dim is determined by the number of columns in ml_feature_columns
        # We will get the actual dimension from the data after loading
        input_feature_dim = len(feature_names) # Get input dimension from processed data
        client_output_dim = 64 # As per your architecture
        server_hidden_dim = 32 # As per your architecture
        output_label_dim = 2 # 0 for safe, 1 for risky - correct for binary classification

        # Instantiate the client and server models
        # Define device early to pass to models
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        client_model = ClientModel(input_dim=input_feature_dim, client_output_dim=client_output_dim).to(device)
        server_model = ServerModel(server_input_dim=client_output_dim, hidden_dim=server_hidden_dim, output_dim=output_label_dim).to(device)


        print("\nClient Model Architecture:")
        print(client_model)
        print("\nServer Model Architecture:")
        print(server_model)

        # Start the Split Learning training process
        # Store the results in variables that can be accessed in Part 2
        # Note: In a true separate cell scenario, you'd need to ensure these variables
        # are available in the global scope or passed explicitly.
        training_losses, test_losses, epochs_list, all_test_labels, all_test_predictions, trained_client_model = train_split_learning(
            client_model, server_model, train_dataloader, test_dataloader, num_epochs=10 # Set epochs to 10
        )

        print("\n--- Part 1: Training Completed ---")
        print("Training losses, test losses, epoch list, test labels, predictions, and trained client model are stored in variables.")
        print("Proceed to Part 2 for visualization.")


    else:
        print("\n--- Part 1: Data Loading or Preprocessing Failed ---")
        print("Skipping model training.")


--- Starting V2V Split Learning Training (Part 1) ---
Attempting to find and load CSV files from: /content/drive/My Drive/Lokm/simulation_results/Dataset matching pattern 'vehicular_dataset_*.csv'
Found 1 CSV files. Loading and combining...
Loading file: vehicular_dataset_2025-01-07.csv
Loaded 4320000 rows from vehicular_dataset_2025-01-07.csv.

Combined data from all files. Total rows: 4320000
Original columns in combined data: ['VehicleID', 'Timestamp', 'Day', 'Hour', 'Minute', 'Second', 'Latitude', 'Longitude', 'Speed', 'Direction', 'StartingPointLatitude', 'StartingPointLongitude', 'DestinationLatitude', 'DestinationLongitude', 'CPU_Available', 'Memory_Available', 'BatteryLevel', 'TaskType', 'TaskSize', 'TaskPriority', 'NetworkLatency', 'SignalStrength', 'TrafficDensity', 'WeatherCondition', 'RoadCondition', 'VehicleType', 'VehicleAge', 'EngineTemperature', 'FuelLevel', 'TirePressure', 'BrakeFluidLevel', 'CoolantLevel', 'OilLevel', 'WiperFluidLevel', 'HeadlightStatus', 'BrakeLightS

The behavior will change in pandas 3.0. This inplace method will never work because the intermediate object on which we are setting values always behaves as a copy.

For example, when doing 'df[col].method(value, inplace=True)', try using 'df.method({col: value}, inplace=True)' or df[col] = df[col].method(value) instead, to perform the operation inplace on the original object.


  df_sorted['Calculated Acceleration (m/s^2)'].fillna(0, inplace=True)


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Epoch [1/10], Step [41600/54000], Loss: 0.0000
Epoch [1/10], Step [41700/54000], Loss: 0.0000
Epoch [1/10], Step [41800/54000], Loss: 0.0000
Epoch [1/10], Step [41900/54000], Loss: 0.0000
Epoch [1/10], Step [42000/54000], Loss: 0.0000
Epoch [1/10], Step [42100/54000], Loss: 0.0000
Epoch [1/10], Step [42200/54000], Loss: 0.0000
Epoch [1/10], Step [42300/54000], Loss: 0.0000
Epoch [1/10], Step [42400/54000], Loss: 0.0000
Epoch [1/10], Step [42500/54000], Loss: 0.0000
Epoch [1/10], Step [42600/54000], Loss: 0.0000
Epoch [1/10], Step [42700/54000], Loss: 0.0000
Epoch [1/10], Step [42800/54000], Loss: 0.0000
Epoch [1/10], Step [42900/54000], Loss: 0.0000
Epoch [1/10], Step [43000/54000], Loss: 0.0000
Epoch [1/10], Step [43100/54000], Loss: 0.0000
Epoch [1/10], Step [43200/54000], Loss: 0.0000
Epoch [1/10], Step [43300/54000], Loss: 0.0000
Epoch [1/10], Step [43400/54000], Loss: 0.0000
Epoch [1/10], Step [43500/54000], Loss: 0.

Plot Learning curves for visualizatinon

Purpose:
Track model convergence and detect overfitting.
Key Features:

Dual Loss Lines: Training (blue) vs. Validation (red) loss.

Moving Averages: Smoothed trends using adaptive window size.

Generalization Gap: Gray shaded area between losses.

Best Model Marker: Gold star at minimum validation loss.

Final Loss Markers: Horizontal lines with exact values.

Dynamic Scaling: Auto-adjusts axes to show annotations.

In [21]:
# Function to plot detailed learning curves
def plot_learning_curves(training_losses, test_losses, epochs_list, save_path):
    """
    Creates and saves an enhanced visualization of training and validation loss curves.

    Args:
        training_losses: List of training losses per epoch
        test_losses: List of test losses per epoch
        epochs_list: List of epoch numbers
        save_path: Directory to save the plot
    """
    plt.figure(figsize=(12, 7))

    # Create a light background grid for better readability
    plt.grid(True, linestyle='--', alpha=0.7)

    # Plot training and test losses with different line styles and markers
    plt.plot(epochs_list, training_losses, 'b-o', linewidth=2, markersize=8,
             label='Training Loss', alpha=0.8)
    plt.plot(epochs_list, test_losses, 'r-^', linewidth=2, markersize=8,
             label='Validation Loss', alpha=0.8)

    # Add trend lines (moving averages) for clearer visualization of patterns
    if len(epochs_list) > 3:  # Only if we have enough epochs
        window_size = max(2, len(epochs_list) // 5)  # Adaptive window size
        train_ma = np.convolve(training_losses, np.ones(window_size)/window_size, mode='valid')
        test_ma = np.convolve(test_losses, np.ones(window_size)/window_size, mode='valid')
        # Adjust x-axis for moving average plots
        plt.plot(epochs_list[window_size-1:], train_ma, 'b--', linewidth=1.5, alpha=0.5)
        plt.plot(epochs_list[window_size-1:], test_ma, 'r--', linewidth=1.5, alpha=0.5)


    # Highlight the epoch with minimum validation loss
    if test_losses: # Ensure test_losses is not empty
        min_loss_epoch = epochs_list[np.argmin(test_losses)]
        min_loss_value = min(test_losses)
        plt.scatter(min_loss_epoch, min_loss_value, s=200, c='gold',
                    edgecolor='k', marker='*', label=f'Best Model (Epoch {min_loss_epoch})')

        # Add annotations for the best model
        plt.annotate(f'Min Loss: {min_loss_value:.4f}',
                     xy=(min_loss_epoch, min_loss_value),
                     xytext=(min_loss_epoch, min_loss_value*1.2),
                     arrowprops=dict(facecolor='black', shrink=0.05, width=1.5, headwidth=8),
                     fontsize=10, ha='center')

    # Customize the plot
    plt.xlabel('Epochs', fontsize=14, fontweight='bold')
    plt.ylabel('Cross-Entropy Loss', fontsize=14, fontweight='bold') # Corrected line
    plt.title('Training and Validation Loss Over Time', fontsize=16, fontweight='bold')
    plt.legend(fontsize=12, loc='upper right')

    # Add a shaded region showing the gap between training and validation loss (indicates overfitting)
    plt.fill_between(epochs_list, training_losses, test_losses,
                     color='gray', alpha=0.2, label='Generalization Gap')

    # Add horizontal lines showing the final achieved losses (if lists are not empty)
    if training_losses:
        plt.axhline(y=training_losses[-1], color='b', linestyle='-.', alpha=0.3)
        # Text annotations for final loss values
        plt.text(epochs_list[-1] + 0.5, training_losses[-1],
                 f'Final Training: {training_losses[-1]:.4f}',
                 verticalalignment='center', color='blue')
    if test_losses:
        plt.axhline(y=test_losses[-1], color='r', linestyle='-.', alpha=0.3)
        plt.text(epochs_list[-1] + 0.5, test_losses[-1],
                 f'Final Validation: {test_losses[-1]:.4f}',
                 verticalalignment='center', color='red')


    # Adjust limits to show annotations clearly (only if lists are not empty)
    if epochs_list and training_losses and test_losses:
        plt.xlim(0, max(epochs_list) * 1.15)
        plt.ylim(0, max(max(training_losses), max(test_losses)) * 1.3)
    elif epochs_list: # If only epochs_list has data
         plt.xlim(0, max(epochs_list) * 1.15)


    plt.tight_layout()
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'learning_curves.png'), dpi=300, bbox_inches='tight')
    plt.close()



  Creates and saves an enhanced confusion matrix visualization with both absolute and normalized values.

Purpose:
Evaluate classification performance beyond accuracy.
Key Features:

Dual Visualization: Absolute counts + normalized percentages.

Class Metrics: Precision/Recall/F1 displayed below.

Error Handling: Skips plot for single-class data.

Color Coding: Blue (counts) vs. Red-Yellow-Green (normalized).

In [23]:

def plot_enhanced_confusion_matrix(y_true, y_pred, save_path):
    """
    Creates and saves an enhanced confusion matrix visualization with both absolute
    and normalized values.

    Args:
        y_true: True labels from test set
        y_pred: Predicted labels from model
        save_path: Directory to save the plot
    """
    # Ensure there are enough data points and at least two classes in y_true
    if len(y_true) < 2 or len(np.unique(y_true)) < 2:
        print("Cannot plot confusion matrix: Insufficient data or only one class in test set.")
        plt.close() # Close the figure to prevent it from being displayed empty
        return

    # Calculate confusion matrix
    cm = confusion_matrix(y_true, y_pred)
    # Handle potential division by zero if a row in cm sums to 0
    cm_sum_axis1 = cm.sum(axis=1)[:, np.newaxis]
    cm_normalized = np.zeros_like(cm, dtype=float)
    # Only divide if the sum is not zero
    non_zero_rows = cm_sum_axis1.flatten() != 0
    cm_normalized[non_zero_rows, :] = cm.astype('float')[non_zero_rows, :] / cm_sum_axis1[non_zero_rows, :]


    # Create subplots for both absolute and normalized confusion matrices
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(16, 7))

    # Plot absolute confusion matrix
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', cbar=True, ax=ax1)
    ax1.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    ax1.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax1.set_title('Confusion Matrix (Absolute Counts)', fontsize=14, fontweight='bold')
    ax1.set_xticklabels(['Safe (0)', 'Risky (1)'])
    ax1.set_yticklabels(['Safe (0)', 'Risky (1)'])

    # Plot normalized confusion matrix
    sns.heatmap(cm_normalized, annot=True, fmt='.2%', cmap='RdYlGn', cbar=True, ax=ax2)
    ax2.set_xlabel('Predicted Label', fontsize=12, fontweight='bold')
    ax2.set_ylabel('True Label', fontsize=12, fontweight='bold')
    ax2.set_title('Confusion Matrix (Normalized)', fontsize=14, fontweight='bold')
    ax2.set_xticklabels(['Safe (0)', 'Risky (1)'])
    ax2.set_yticklabels(['Safe (0)', 'Risky (1)'])

    # Calculate and display additional metrics
    accuracy = accuracy_score(y_true, y_pred)
    # Ensure both classes are present in true labels for classification_report to work fully
    if len(np.unique(y_true)) >= 2:
        report = classification_report(y_true, y_pred, output_dict=True)
        # Ensure class 1 (Risky) exists in true labels before trying to access its metrics directly
        if '1' in report:
            precision = report['1']['precision']
            recall = report['1']['recall']
            f1 = report['1']['f1-score']
            metrics_text = f"Accuracy: {accuracy:.4f}\nPrecision: {precision:.4f}\n" \
                           f"Recall: {recall:.4f}\nF1 Score: {f1:.4f}"
        else:
            # Handle case where there are no risky instances in the test set
            metrics_text = f"Accuracy: {accuracy:.4f}\nPrecision: N/A\nRecall: N/A\nF1 Score: N/A\n(No Risky instances in test set)"
    else:
         metrics_text = f"Accuracy: {accuracy:.4f}\n(Insufficient data for full metrics)"


    props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
    plt.figtext(0.5, 0.01, metrics_text, fontsize=12, ha='center',
                bbox=props)

    plt.suptitle('Split Learning Model Performance - Confusion Matrices',
                 fontsize=16, fontweight='bold')
    plt.tight_layout(rect=[0, 0.08, 1, 0.95])  # Adjust for the text box
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'confusion_matrix.png'), dpi=300, bbox_inches='tight')
    plt.close()



Caculate false positive and true positive to plot confusion matrix

In [24]:


def false_positive_rate(y_true, y_pred):
    """Calculate the false positive rate."""
    cm = confusion_matrix(y_true, y_pred)
    # Ensure class 0 exists in true labels before calculating
    if 0 in np.unique(y_true):
        false_positive = cm[0, 1]
        true_negative = cm[0, 0]
        return false_positive / (false_positive + true_negative) if (false_positive + true_negative) > 0 else 0
    else:
        return 0.0 # Cannot calculate FPR if no true negatives


def true_positive_rate(y_true, y_pred):
    """Calculate the true positive rate (recall)."""
    cm = confusion_matrix(y_true, y_pred)
    # Ensure class 1 exists in true labels before calculating
    if 1 in np.unique(y_true):
        true_positive = cm[1, 1]
        false_negative = cm[1, 0]
        return true_positive / (true_positive + false_negative) if (true_positive + false_negative) > 0 else 0
    else:
        return 0.0 # Cannot calculate TPR if no true positives



Creates and saves a ROC curve visualization based on the model predictions.
    This implementation handles binary class labels directly.

Purpose:
Assess trade-off between TPR (detection) and FPR (false alarms).
Key Features:

Dummy Probabilities: Converts class labels to binary probabilities.

AUC Value: Area under curve quantifies overall performance.

Current Model Marker: Red dot shows actual operating point.

Baseline: Diagonal "random guess" line for reference.

In [29]:

def plot_roc_curve(y_true, y_pred_class, save_path):
    """
    Creates and saves a ROC curve visualization based on the model predictions.
    This implementation handles binary class labels directly.

    Args:
        y_true: True binary labels (0 or 1)
        y_pred_class: Predicted binary labels (0 or 1)
        save_path: Directory to save the plot
    """
    # Ensure both classes are present in y_true for roc_curve to work
    if len(np.unique(y_true)) < 2:
        print("Cannot plot ROC curve: Test set contains only one class.")
        plt.close()
        return

    plt.figure(figsize=(10, 8))

    # Convert class predictions to estimated probabilities for ROC curve
    # Note: This is a simplification when only class predictions are available.
    # For actual implementation, using model logits or predicted probabilities is preferred.
    # Creating a dummy probability array based on class predictions.
    y_pred_prob = np.zeros((len(y_pred_class), 2))
    for i, pred in enumerate(y_pred_class):
        if pred in [0, 1]: # Ensure predicted label is 0 or 1
             y_pred_prob[i, pred] = 1.0
        # If prediction is outside [0, 1], this row will have [0, 0] probabilities, which might affect AUC.
        # Ideally, predictions should be strictly 0 or 1.


    # Calculate ROC curve and ROC area for class 1 (Risky)
    fpr, tpr, _ = roc_curve(y_true, y_pred_prob[:, 1])
    roc_auc = auc(fpr, tpr)

    # Plot ROC curve
    plt.plot(fpr, tpr, color='darkorange', lw=2,
             label=f'ROC curve (area = {roc_auc:.3f})')

    # Plot diagonal line (random classifier)
    plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--',
             label='Random Classifier (AUC = 0.5)')

    # Calculate point for current threshold (using the provided helper functions)
    current_fpr = false_positive_rate(y_true, y_pred_class)
    current_tpr = true_positive_rate(y_true, y_pred_class)

    # Mark the current operating point
    plt.scatter(current_fpr, current_tpr, color='red', s=100, zorder=10,
                label=f'Current Model (FPR={current_fpr:.3f}, TPR={current_tpr:.3f})')

    # Customize the plot
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('False Positive Rate', fontsize=12, fontweight='bold')
    plt.ylabel('True Positive Rate', fontsize=12, fontweight='bold')
    plt.title('Receiver Operating Characteristic (ROC) Curve',
              fontsize=14, fontweight='bold')
    plt.legend(loc="lower right", fontsize=10)

    # Add grid for readability
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'roc_curve.png'), dpi=300, bbox_inches='tight')
    plt.close()



Functino to plot Precision Recall Curve

Purpose:
Evaluate performance under class imbalance.
Key Features:

F1 Contours: Dashed lines show F1-score thresholds.

Baseline: Horizontal line for positive class prevalence.

Current Model Marker: Red dot with exact metrics.

Adaptive Annotations: Avoids label overlap on contours.

In [28]:


def plot_precision_recall_curve(y_true, y_pred_class, save_path):
    """
    Creates and saves a Precision-Recall curve visualization.

    Args:
        y_true: True binary labels (0 or 1)
        y_pred_class: Predicted binary labels (0 or 1)
        save_path: Directory to save the plot
    """
    # Ensure both classes are present in y_true for precision_recall_curve to work
    if len(np.unique(y_true)) < 2:
        print("Cannot plot Precision-Recall curve: Test set contains only one class.")
        plt.close()
        return

    plt.figure(figsize=(10, 8))

    # Convert class predictions to estimated probabilities
    # Note: This is a simplification when only class predictions are available.
    # For actual implementation, using model logits or predicted probabilities is preferred.
    # Creating a dummy probability array based on class predictions.
    y_pred_prob = np.zeros((len(y_pred_class), 2))
    for i, pred in enumerate(y_pred_class):
         if pred in [0, 1]: # Ensure predicted label is 0 or 1
              y_pred_prob[i, pred] = 1.0
         # If prediction is outside [0, 1], this row will have [0, 0] probabilities.


    # Calculate precision-recall curve
    precision, recall, _ = precision_recall_curve(y_true, y_pred_prob[:, 1])
    pr_auc = auc(recall, precision)

    # Plot precision-recall curve
    plt.plot(recall, precision, color='blue', lw=2,
             label=f'Precision-Recall curve (area = {pr_auc:.3f})')

    # Calculate current precision and recall
    cm = confusion_matrix(y_true, y_pred_class)
    # Ensure class 1 exists in true labels for calculation
    if 1 in np.unique(y_true):
        current_precision = cm[1, 1] / (cm[1, 1] + cm[0, 1]) if (cm[1, 1] + cm[0, 1]) > 0 else 0
        current_recall = cm[1, 1] / (cm[1, 1] + cm[1, 0]) if (cm[1, 1] + cm[1, 0]) > 0 else 0

        # Mark the current operating point
        plt.scatter(current_recall, current_precision, color='red', s=100, zorder=10,
                    label=f'Current Model (Precision={current_precision:.3f}, Recall={current_recall:.3f})')

    # Plot baseline
    y_baseline = np.sum(y_true) / len(y_true)  # Proportion of positive class
    plt.axhline(y=y_baseline, color='grey', linestyle='--',
                 label=f'Baseline (No Skill): {y_baseline:.3f}')

    # Calculate F1 contours for reference
    f1_scores = np.linspace(0.1, 0.9, 9)
    for f1 in f1_scores:
        x = np.linspace(0.01, 1, 100)
        # f1 = 2 * precision * recall / (precision + recall)
        # Solve for precision in terms of recall and f1
        y = (f1 * x) / (2 * x - f1)
        valid_indices = (y >= 0) & (y <= 1)
        plt.plot(x[valid_indices], y[valid_indices], color='green', alpha=0.2, linestyle=':')
        # Add label for select contours
        # Adjust annotation position to avoid overlap
        if f1 in [0.3, 0.5, 0.7]:
             midpoint = np.argmax(valid_indices & (x > 0.3) & (x < 0.7)) # Find a point in the middle range
             if midpoint > 0:
                 plt.annotate(f'F1={f1}', xy=(x[midpoint], y[midpoint]),
                              xytext=(x[midpoint] + 0.05, y[midpoint]), # Offset text slightly
                              textcoords='data',
                              fontsize=8, alpha=0.6)


    # Customize the plot
    plt.xlim([0.0, 1.0])
    plt.ylim([0.0, 1.05])
    plt.xlabel('Recall', fontsize=12, fontweight='bold')
    plt.ylabel('Precision', fontsize=12, fontweight='bold')
    plt.title('Precision-Recall Curve', fontsize=14, fontweight='bold')
    plt.legend(loc="lower left", fontsize=10)

    # Add grid for readability
    plt.grid(True, linestyle='--', alpha=0.6)

    plt.tight_layout()
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'precision_recall_curve.png'), dpi=300, bbox_inches='tight')
    plt.close()


 Creates and saves a visualization of feature importance based on the weights
    of the first layer of the client model.

  Purpose:
Reveal which input features drive model decisions.
Key Features:

Weight Extraction: First layer weights from client model.

Mean Absolute Impact: Horizontal bar chart of feature influence.

Error Handling: Validates feature/model compatibility.

Interpretation Guide: Text explains weight significance.

In [27]:

def analyze_feature_importance(client_model, feature_names, save_path):
    """
    Creates and saves a visualization of feature importance based on the weights
    of the first layer of the client model.

    Args:
        client_model: Trained client PyTorch model
        feature_names: List of feature names
        save_path: Directory to save the plot
    """
    # Ensure client_model and feature_names are not None
    if client_model is None or feature_names is None or not feature_names:
        print("Cannot plot feature importance: Client model or feature names are not available.")
        return

    plt.figure(figsize=(12, 8))

    # Extract weights from the first layer of the client model
    with torch.no_grad():
        # Ensure weights are on CPU for numpy conversion
        weights = client_model.fc1.weight.cpu().numpy()

    # Calculate feature importance as the mean absolute weight per input feature
    # Ensure the number of feature names matches the input dimension of the model
    if len(feature_names) != weights.shape[1]:
        print(f"Warning: Number of feature names ({len(feature_names)}) does not match model input dimension ({weights.shape[1]}). Skipping feature importance plot.")
        plt.close() # Close the figure to prevent it from being displayed empty
        return

    feature_importance = np.mean(np.abs(weights), axis=0)

    # Sort features by importance
    indices = np.argsort(feature_importance)
    sorted_feature_names = [feature_names[i] for i in indices]
    sorted_importance = feature_importance[indices]

    # Create a horizontal bar chart
    bars = plt.barh(range(len(sorted_feature_names)), sorted_importance,
                    color=plt.cm.viridis(np.linspace(0, 1, len(sorted_feature_names))))

    # Add importance values as text
    for i, (bar, value) in enumerate(zip(bars, sorted_importance)):
        plt.text(value + 0.01, bar.get_y() + bar.get_height()/2,
                 f'{value:.4f}', va='center', fontsize=10)

    # Customize the plot
    plt.xlabel('Mean Absolute Weight', fontsize=12, fontweight='bold')
    plt.ylabel('Features', fontsize=12, fontweight='bold')
    plt.title('Feature Importance in Split Learning Model (Client Layer 1)', fontsize=14, fontweight='bold')
    plt.yticks(range(len(sorted_feature_names)), sorted_feature_names)
    plt.grid(axis='x', linestyle='--', alpha=0.7)

    # Add explanatory text
    plt.figtext(0.5, 0.01,
                "Feature importance based on mean absolute weights from the first layer of the client model.\n"
                "Higher values indicate stronger influence on the model's decisions.",
                ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout(rect=[0, 0.05, 1, 0.95]) # Adjust for the text box
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'feature_importance.png'), dpi=300, bbox_inches='tight')
    plt.close()


 Creates and saves a t-SNE visualization of the feature space after transformation
    by the client model.

Purpose:
Visualize how the model transforms high-dimensional data.
Key Features:

Data Subsampling: Processes first 20 batches for efficiency.

Class-Source Coding: Shapes (circles/X) for train/test, colors for classes.

Perplexity Adjustment: Avoids errors with small datasets.

Cluster Interpretation: Text explains point proximity meaning.
    

In [25]:


def visualize_feature_space(train_dataloader, test_dataloader, client_model, save_path):
    """
    Creates and saves a t-SNE visualization of the feature space after transformation
    by the client model.

    Args:
        train_dataloader: DataLoader containing training data
        test_dataloader: DataLoader containing test data
        client_model: Trained client PyTorch model
        save_path: Directory to save the plot
    """
    # Ensure dataloaders and client_model are not None
    if train_dataloader is None or test_dataloader is None or client_model is None:
        print("Cannot plot t-SNE visualization: Dataloaders or client model not available.")
        return

    # Set the model to evaluation mode
    client_model.eval()

    # Initialize lists to store activations and labels
    activations = []
    labels = []
    sources = []  # To track if the data is from training or testing set

    # Process a subset of training data (to avoid overwhelming t-SNE)
    with torch.no_grad():
        for i, (inputs, batch_labels) in enumerate(train_dataloader):
            if i >= 20:  # Increased limit for potentially better t-SNE representation
                break
            # Ensure inputs are on the correct device before passing to model
            inputs = inputs.to(next(client_model.parameters()).device)
            batch_activations = client_model(inputs).cpu().numpy()
            activations.append(batch_activations)
            labels.append(batch_labels.cpu().numpy())
            sources.extend(['train'] * len(batch_labels))

    # Process a subset of testing data
    with torch.no_grad():
        for i, (inputs, batch_labels) in enumerate(test_dataloader):
            if i >= 20:  # Increased limit for potentially better t-SNE representation
                break
            # Ensure inputs are on the correct device before passing to model
            inputs = inputs.to(next(client_model.parameters()).device)
            batch_activations = client_model(inputs).cpu().numpy()
            activations.append(batch_activations)
            labels.append(batch_labels.cpu().numpy())
            sources.extend(['test'] * len(batch_labels))

    # Combine all data
    if not activations:
        print("No data collected for t-SNE visualization. Skipping plot.")
        return

    activations = np.vstack(activations)
    labels = np.concatenate(labels)

    # Handle case with insufficient samples for t-SNE
    if len(activations) < max(30, len(np.unique(labels)) + 1) and len(activations) > 1: # t-SNE requires samples > perplexity (default 30) and > num_classes
        print(f"Insufficient data points ({len(activations)}) for t-SNE visualization. Need at least {max(30, len(np.unique(labels)) + 1)} samples. Skipping plot.")
        return
    elif len(activations) <= 1:
         print(f"Only {len(activations)} data point(s) available for t-SNE visualization. Need at least 2. Skipping plot.")
         return


    # Apply t-SNE for dimensionality reduction to 2D
    print("Applying t-SNE dimensionality reduction...")
    # Adjust perplexity based on the number of samples if needed
    perplexity_val = min(30, len(activations) - 1) if len(activations) > 1 else 1
    tsne = TSNE(n_components=2, random_state=42, perplexity=perplexity_val)
    try:
        activations_2d = tsne.fit_transform(activations)
    except ValueError as ve:
        print(f"Error during t-SNE transformation: {ve}. This might happen if data is too uniform or insufficient samples for perplexity.")
        print("Skipping t-SNE visualization.")
        return


    # Create the visualization
    plt.figure(figsize=(14, 10))

    # Set up the color scheme for classes
    # Ensure there are enough colors for the unique labels
    unique_labels = np.unique(labels)
    if len(unique_labels) > len(mcolors.CSS4_COLORS): # Fallback if too many unique labels
        colors = plt.cm.get_cmap('viridis', len(unique_labels))(np.arange(len(unique_labels)))
    else:
        # Use a fixed set of colors for up to 2 classes (Safe/Risky)
        class_colors = ['#1f77b4', '#ff7f0e'] # blue for class 0, orange for class 1
        colors = [class_colors[int(label)] for label in unique_labels]


    source_markers = {'train': 'o', 'test': 'X'}

    # Create scatter plot with different colors for classes and markers for train/test
    for label in unique_labels:
        # Ensure label is an integer index for class_colors if using the fixed set
        label_int = int(label)

        for source in ['train', 'test']:
            # Get indices matching both class and source
            indices = np.where((labels == label) & (np.array(sources) == source))[0]
            if len(indices) == 0:
                continue

            # Determine the color for this label
            color_for_label = colors[list(unique_labels).index(label)]


            plt.scatter(
                activations_2d[indices, 0],
                activations_2d[indices, 1],
                c=[color_for_label] * len(indices),
                marker=source_markers[source],
                s=50 if source == 'train' else 100,
                alpha=0.7 if source == 'train' else 0.8,
                edgecolors='w' if source == 'train' else 'k',
                linewidth=0.5 if source == 'train' else 1.0,
                label=f"Class {label_int} ({source})"
            )

    # Add title and labels
    plt.title('t-SNE Visualization of Client Model Feature Space', fontsize=16, fontweight='bold')
    plt.xlabel('t-SNE Component 1', fontsize=12, fontweight='bold')
    plt.ylabel('t-SNE Component 2', fontsize=12, fontweight='bold')

    # Add legend with custom handles (recreate based on actual unique labels and sources found)
    from matplotlib.lines import Line2D
    legend_elements = []
    for label in unique_labels:
         label_int = int(label)
         color_for_label = colors[list(unique_labels).index(label)]
         if f"Class {label_int} (Train)" in [l.get_label() for l in plt.gca().get_legend_handles_labels()[0]]:
             legend_elements.append(Line2D([0], [0], marker=source_markers['train'], color='w', markerfacecolor=color_for_label,
                                            label=f'Class {label_int} (Train)', markersize=10))
         if f"Class {label_int} (Test)" in [l.get_label() for l in plt.gca().get_legend_handles_labels()[0]]:
             legend_elements.append(Line2D([0], [0], marker=source_markers['test'], color='w', markerfacecolor=color_for_label,
                                            markeredgecolor='k', label=f'Class {label_int} (Test)', markersize=10))

    plt.legend(handles=legend_elements, loc='upper right', fontsize=10)


    # Add grid
    plt.grid(True, linestyle='--', alpha=0.3)

    # Add descriptive text
    plt.figtext(0.5, 0.01,
                "This t-SNE visualization shows how the client model transforms input features.\n"
                "Points that are close together represent similar patterns as detected by the model.",
                ha='center', fontsize=10, bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    plt.tight_layout(rect=[0, 0.05, 1, 0.95])  # Adjust for text box
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'tsne_visualization.png'), dpi=300, bbox_inches='tight')
    plt.close()


 Creates and saves a visualization summarizing model performance metrics.

 Purpose:
Consolidate key metrics into one view.
Key Features:

Grid Layout: Combines table, distributions, and bar charts.

Adaptive Tables: Handles missing classes gracefully.

Class Distribution: Bar chart with sample counts.

Metric Comparisons: Grouped bars for precision/recall/F1.

In [26]:
def create_performance_summary(y_true, y_pred, save_path):
    """
    Creates and saves a visualization summarizing model performance metrics.

    Args:
        y_true: True labels from test set
        y_pred: Predicted labels from model
        save_path: Directory to save the plot
    """
    # Ensure there are enough data points and at least two classes in y_true
    if len(y_true) < 2 or len(np.unique(y_true)) < 2:
        print("Cannot create performance summary: Insufficient data or only one class in test set.")
        # Create a placeholder figure to avoid errors
        fig, ax = plt.subplots(figsize=(6, 4))
        ax.text(0.5, 0.5, "Insufficient data for summary table.", horizontalalignment='center', verticalalignment='center')
        ax.axis('off')
        plt.title("Performance Summary (Insufficient Data)")
        os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
        plt.savefig(os.path.join(save_path, 'performance_summary.png'), dpi=300, bbox_inches='tight')
        plt.close()
        return


    report = classification_report(y_true, y_pred, output_dict=True)

    # Extract metrics for each class
    class_metrics = {}
    # Handle cases where a class might not be present in the predictions or true labels
    classes_in_report = [str(c) for c in np.unique(y_true)] # Use true labels to define classes
    for cls in classes_in_report:
         class_metrics[cls] = {
             'Precision': report[cls]['precision'] if cls in report else 0.0,
             'Recall': report[cls]['recall'] if cls in report else 0.0,
             'F1-Score': report[cls]['f1-score'] if cls in report else 0.0,
             'Support': report[cls]['support'] if cls in report else 0
         }

    # Create figure with grid layout
    fig = plt.figure(figsize=(14, 10))
    gs = GridSpec(2, 2, figure=fig)

    # 1. Create a table with model performance metrics
    ax1 = fig.add_subplot(gs[0, 0])
    ax1.axis('tight')
    ax1.axis('off')

    # Create table data
    table_data = [
        ['Metric', 'Overall', 'Class 0 (Safe)', 'Class 1 (Risky)'],
        ['Accuracy', f"{accuracy_score(y_true, y_pred):.4f}", '-', '-'], # Calculate accuracy directly
        ['Precision', f"{report['weighted avg']['precision']:.4f}" if 'weighted avg' in report else '-',
         f"{class_metrics['0']['Precision']:.4f}" if '0' in class_metrics else '-',
         f"{class_metrics['1']['Precision']:.4f}" if '1' in class_metrics else '-'],
        ['Recall', f"{report['weighted avg']['recall']:.4f}" if 'weighted avg' in report else '-',
         f"{class_metrics['0']['Recall']:.4f}" if '0' in class_metrics else '-',
         f"{class_metrics['1']['Recall']:.4f}" if '1' in class_metrics else '-'],
        ['F1-Score', f"{report['weighted avg']['f1-score']:.4f}" if 'weighted avg' in report else '-',
         f"{class_metrics['0']['F1-Score']:.4f}" if '0' in class_metrics else '-',
         f"{class_metrics['1']['F1-Score']:.4f}" if '1' in class_metrics else '-'],
        ['Support', f"{report['weighted avg']['support']:.0f}" if 'weighted avg' in report else '-',
         f"{class_metrics['0']['Support']:.0f}" if '0' in class_metrics else '-',
         f"{class_metrics['1']['Support']:.0f}" if '1' in class_metrics else '-']
    ]

    # Create the table
    table = ax1.table(cellText=table_data, cellLoc='center', loc='center')

    # Style the table
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    table.scale(1.2, 1.2) # Adjust table size

    # Apply bold font to header row
    # Corrected attribute access from .cbox to .cells
    for (row, col), cell in table.get_celld().items():
        if row == 0:
            cell.set_text_props(fontweight='bold')

    ax1.set_title('Model Performance Metrics', fontsize=14, fontweight='bold')


    # 2. Create a bar chart for class distribution in the test set
    ax2 = fig.add_subplot(gs[0, 1])
    # Ensure class names match the potential unique labels
    class_names = ['Safe (0)', 'Risky (1)'] # Assuming binary classification 0/1
    class_counts = pd.Series(y_true).value_counts().sort_index()
    # Only plot for the classes that are actually present
    present_class_names = [name for i, name in enumerate(class_names) if i in class_counts.index]
    present_class_counts = [count for i, count in class_counts.items() if i in class_counts.index]
    present_colors = ['skyblue' if i == 0 else 'salmon' for i in class_counts.index]


    bars = ax2.bar(present_class_names, present_class_counts, color=present_colors)
    ax2.set_ylabel('Number of Samples', fontsize=12, fontweight='bold')
    ax2.set_title('Test Set Class Distribution', fontsize=14, fontweight='bold')
    ax2.grid(axis='y', linestyle='--', alpha=0.7)
    # Add counts on top of bars
    for bar in bars:
        yval = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2.0, yval, int(yval), va='bottom', ha='center') # va: vertical alignment

    # 3. Create a bar chart for Precision, Recall, F1-Score by Class
    ax3 = fig.add_subplot(gs[1, :]) # Span across both columns in the second row

    # Filter metrics to only include classes present in the test set
    metrics_data = {
        'Metric': ['Precision', 'Recall', 'F1-Score'],
    }
    if '0' in class_metrics:
        metrics_data['Class 0 (Safe)'] = [class_metrics['0']['Precision'], class_metrics['0']['Recall'], class_metrics['0']['F1-Score']]
    if '1' in class_metrics:
        metrics_data['Class 1 (Risky)'] = [class_metrics['1']['Precision'], class_metrics['1']['Recall'], class_metrics['1']['F1-Score']]

    metrics_df = pd.DataFrame(metrics_data)

    # Plotting grouped bar chart
    # Exclude the 'Metric' column when plotting
    columns_to_plot = [col for col in metrics_df.columns if col != 'Metric']
    if columns_to_plot: # Only plot if there are metrics columns
        metrics_df.plot(x='Metric', y=columns_to_plot, kind='bar', ax=ax3, colormap='viridis', rot=0)
        ax3.set_ylabel('Score', fontsize=12, fontweight='bold')
        ax3.set_title('Precision, Recall, and F1-Score by Class', fontsize=14, fontweight='bold')
        ax3.legend(title='Class', loc='upper left')
        ax3.grid(axis='y', linestyle='--', alpha=0.7)

        # Add score values on top of bars
        for container in ax3.containers:
            ax3.bar_label(container, fmt='%.2f', label_type='edge')
    else:
        ax3.text(0.5, 0.5, "Insufficient data for metrics bar chart.", horizontalalignment='center', verticalalignment='center')
        ax3.axis('off')


    plt.suptitle('Comprehensive Model Performance Summary', fontsize=18, fontweight='bold')
    plt.tight_layout(rect=[0, 0.03, 1, 0.97]) # Adjust layout to make room for suptitle
    os.makedirs(save_path, exist_ok=True) # Create directory if it doesn't exist
    plt.savefig(os.path.join(save_path, 'performance_summary.png'), dpi=300, bbox_inches='tight')
    plt.close()

Main Function to plot different graphs

Workflow:

Sanity Checks: Validates training results exist.

Directory Setup: Creates output folder for plots.

Plot Sequencing: Generates 6 complementary visualizations.

Error Handling: Skips incompatible plots (e.g., ROC with 1 class).

In [30]:
# --- Part 2: Visualization ---

# --- All Imports for Visualization ---
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
import pandas as pd
import os
from sklearn.metrics import roc_curve, auc, precision_recall_curve, f1_score, classification_report
from sklearn.metrics import confusion_matrix, accuracy_score
from sklearn.manifold import TSNE
import torch # Needed for client_model and checking device
from matplotlib.gridspec import GridSpec
import matplotlib.colors as mcolors






# --- Visualization Execution (Part 2) ---
# This block assumes the variables from Part 1 are available in the environment.

print("--- Starting V2V Split Learning Visualization (Part 2) ---")

# Check if the necessary variables from Part 1 are available and valid
if 'training_losses' in locals() and training_losses is not None and \
   'test_losses' in locals() and test_losses is not None and \
   'epochs_list' in locals() and epochs_list is not None and \
   'all_test_labels' in locals() and all_test_labels is not None and \
   'all_test_predictions' in locals() and all_test_predictions is not None and \
   'trained_client_model' in locals() and trained_client_model is not None and \
   'train_dataloader' in locals() and train_dataloader is not None and \
   'test_dataloader' in locals() and test_dataloader is not None and \
   'feature_names' in locals() and feature_names is not None and \
   'PLOT_SAVE_DIR' in locals() and PLOT_SAVE_DIR is not None:

    # Create directory for plots if it doesn't exist (redundant but safe)
    os.makedirs(PLOT_SAVE_DIR, exist_ok=True)
    print(f"Saving plots to: {os.path.abspath(PLOT_SAVE_DIR)}")


    # Plot learning curves
    print("Generating learning curves plot...")
    plot_learning_curves(training_losses, test_losses, epochs_list, PLOT_SAVE_DIR)
    print(f"Learning curves plot saved to {os.path.join(PLOT_SAVE_DIR, 'learning_curves.png')}")

    # Plot enhanced confusion matrix
    print("Generating enhanced confusion matrix plot...")
    plot_enhanced_confusion_matrix(all_test_labels, all_test_predictions, PLOT_SAVE_DIR)
    print(f"Confusion matrix plot saved to {os.path.join(PLOT_SAVE_DIR, 'confusion_matrix.png')}")

    # Plot ROC and Precision-Recall curves
    # Ensure there are at least two classes in the test labels before plotting ROC/PR
    if len(np.unique(all_test_labels)) >= 2:
        print("Generating ROC curve plot...")
        plot_roc_curve(all_test_labels, all_test_predictions, PLOT_SAVE_DIR)
        print(f"ROC curve plot saved to {os.path.join(PLOT_SAVE_DIR, 'roc_curve.png')}")

        print("Generating Precision-Recall curve plot...")
        plot_precision_recall_curve(all_test_labels, all_test_predictions, PLOT_SAVE_DIR)
        print(f"Precision-Recall curve plot saved to {os.path.join(PLOT_SAVE_DIR, 'precision_recall_curve.png')}")
    else:
        print("Skipping ROC and Precision-Recall plots: Test set contains only one class.")


    # Feature analysis (requires client_model and feature_names)
    print("Generating feature importance plot...")
    # Pass the trained_client_model from Part 1
    analyze_feature_importance(trained_client_model, feature_names, PLOT_SAVE_DIR)
    print(f"Feature importance plot saved to {os.path.join(PLOT_SAVE_DIR, 'feature_importance.png')}")


    # t-SNE visualization (requires dataloaders and client_model)
    print("Generating t-SNE visualization plot...")
    # Pass the dataloaders and trained_client_model from Part 1
    visualize_feature_space(train_dataloader, test_dataloader, trained_client_model, PLOT_SAVE_DIR)
    print(f"t-SNE visualization plot saved to {os.path.join(PLOT_SAVE_DIR, 'tsne_visualization.png')}")


    # Performance summary
    print("Generating performance summary plot...")
    create_performance_summary(all_test_labels, all_test_predictions, PLOT_SAVE_DIR)
    print(f"Performance summary plot saved to {os.path.join(PLOT_SAVE_DIR, 'performance_summary.png')}")

    print("\n--- Part 2: All Visualizations Generated ---")

else:
    print("\n--- Part 2: Training Results Not Available ---")
    print("Skipping visualization. Ensure Part 1 ran successfully.")




--- Starting V2V Split Learning Visualization (Part 2) ---
Saving plots to: /content/training_plots
Generating learning curves plot...
Learning curves plot saved to training_plots/learning_curves.png
Generating enhanced confusion matrix plot...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Confusion matrix plot saved to training_plots/confusion_matrix.png
Generating ROC curve plot...
ROC curve plot saved to training_plots/roc_curve.png
Generating Precision-Recall curve plot...


  y = (f1 * x) / (2 * x - f1)


Precision-Recall curve plot saved to training_plots/precision_recall_curve.png
Generating feature importance plot...
Feature importance plot saved to training_plots/feature_importance.png
Generating t-SNE visualization plot...
Applying t-SNE dimensionality reduction...
t-SNE visualization plot saved to training_plots/tsne_visualization.png
Generating performance summary plot...


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))


Performance summary plot saved to training_plots/performance_summary.png

--- Part 2: All Visualizations Generated ---
