# ML-Based Air Traffic Control: Conflict Detection & Hallucination Analysis
## Master's Thesis: Simulation and quantification of ML-based hallucination effects on safety margins in enroute control
### Enhanced with k-NN based conflict detection for improved performance

In [1]:
# BlueSky integration
BLUESKY_AVAILABLE = True
try:
    import bluesky
    BLUESKY_AVAILABLE = True
    print("✅ BlueSky available")
except ImportError:
    !pip -qq install bluesky

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.8/356.8 kB[0m [31m8.6 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 kB[0m [31m2.8 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

In [2]:
# Enhanced LLM-Based Air Traffic Control with k-NN optimization
# Fixed tensor dimensions and accelerated conflict detection

import os
import gc
import json
import warnings
import zipfile
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, f1_score
from sklearn.neighbors import NearestNeighbors

# Transformers and NLP
from transformers import DistilBertTokenizer, DistilBertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW

# Suppress warnings
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

2025-06-11 10:20:36.571779: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749637236.847399      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749637236.926716      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Enhanced Configuration for Thesis Research
@dataclass
class ThesisConfig:
    """Enhanced configuration for ML hallucination research in ATC with k-NN optimization"""
    # Data paths
    scat_data_path: str = "/kaggle/input/swedish-civil-air-traffic-control-scat-dataset"
    output_path: str = "./thesis_results"
    model_save_path: str = "./thesis_models"
    
    # Model configuration for hallucination analysis
    model_name: str = "distilbert-base-uncased"
    max_sequence_length: int = 128
    batch_size: int = 16  # Reduced for stability
    learning_rate: float = 2e-5  # More conservative
    num_epochs: int = 3  # Reduced for faster iteration
    warmup_steps: int = 100
    
    # k-NN based conflict detection parameters
    k_neighbors: int = 3  # Check only 3 nearest neighbors
    conflict_threshold_nm: float = 5.0  # Nautical miles
    conflict_threshold_ft: float = 1000.0  # Feet
    time_horizon_seconds: int = 120  # Look-ahead time
    spatial_grid_size: float = 10.0  # NM for spatial partitioning
    
    # Enhanced conflict detection parameters - memory optimized
    sample_multiplier: int = 2  # Reduced for streaming approach
    max_time_samples: int = 6000  # Distributed across all weeks
    
    # Hallucination detection parameters
    uncertainty_threshold: float = 0.3
    mc_dropout_samples: int = 10  # Reduced for speed
    confidence_threshold: float = 0.8
    
    # Training data envelope parameters
    altitude_min: float = 10000  # ft
    altitude_max: float = 50000  # ft
    speed_min: float = 200  # kt
    speed_max: float = 600  # kt
    
    # Device
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

config = ThesisConfig()

# Create output directories
os.makedirs(config.output_path, exist_ok=True)
os.makedirs(config.model_save_path, exist_ok=True)

print(f"🎓 Enhanced Thesis Configuration loaded. Using device: {config.device}")
print(f"📊 Research Focus: ML Hallucination Quantification in ATC with k-NN optimization")
print(f"🔍 k-NN neighbors: {config.k_neighbors}, Max samples: {config.max_time_samples}")
print(f"🌊 STREAMING MODE: Processing ALL available weeks one-by-one")
print(f"💾 Memory-Optimized: <5GB RAM usage (vs >50GB for batch processing)")
print(f"⏱️  Processing time: Longer but RAM-safe for complete dataset")

🎓 Enhanced Thesis Configuration loaded. Using device: cpu
📊 Research Focus: ML Hallucination Quantification in ATC with k-NN optimization
🔍 k-NN neighbors: 3, Max samples: 6000
🌊 STREAMING MODE: Processing ALL available weeks one-by-one
💾 Memory-Optimized: <5GB RAM usage (vs >50GB for batch processing)
⏱️  Processing time: Longer but RAM-safe for complete dataset


In [4]:
# Enhanced SCAT Data Loader - Optimized for k-NN
class OptimizedSCATDataLoader:
    """SCAT data loader optimized for k-NN based conflict detection"""

    def __init__(self, config: ThesisConfig):
        self.config = config
        self.data_path = Path(config.scat_data_path)
        self.logger = self._setup_logger()
        self.airspace_data = []
        self.weather_data = []

    def _setup_logger(self):
        class SimpleLogger:
            def info(self, msg): print(f"ℹ️ {msg}")
            def warning(self, msg): print(f"⚠️ {msg}")
            def error(self, msg): print(f"❌ {msg}")
        return SimpleLogger()

    def scan_scat_folders(self) -> List[Path]:
        """Scan for SCAT data sources"""
        search_path = self.data_path
        if not search_path.exists():
            self.logger.error(f"SCAT data path does not exist: {search_path}")
            return []

        scat_folders = [p for p in search_path.iterdir()
                        if (p.is_dir() and p.name.lower().startswith("scat")) or
                           (p.suffix == ".zip" and p.name.lower().startswith("scat"))]

        self.logger.info(f"Found {len(scat_folders)} SCAT data sources")
        return sorted(scat_folders)

    def get_week_paths(self, max_weeks: int = None) -> List[Path]:
        """Get list of week paths for streaming processing"""
        week_paths = self.scan_scat_folders()[:max_weeks] if max_weeks else self.scan_scat_folders()
        return week_paths
    
    def load_single_week_processed(self, week_path: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Load and process a single week with memory optimization"""
        print(f"🗂️  Loading week: {week_path.name}")
        
        # Load single week
        f_df, c_df, t_df = self._load_single_week(week_path)
        
        # Clean and optimize immediately
        t_df = self._clean_and_optimize_tracks(t_df)
        c_df = self._clean_timestamps(c_df)
        
        print(f"   📊 Week stats: {len(f_df):,} flights, {len(c_df):,} clearances, {len(t_df):,} tracks")
        
        return f_df, c_df, t_df

    def load_all_weeks(self, max_weeks: int = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Load multiple weeks with enhanced preprocessing for k-NN - DEPRECATED for memory efficiency"""
        print("⚠️  This method loads all weeks at once - use streaming approach instead")
        week_paths = self.scan_scat_folders()[:max_weeks] if max_weeks else self.scan_scat_folders()
        
        total_weeks = len(week_paths)
        print(f"📅 Processing {total_weeks} weeks of SCAT data...")
        print(f"⏱️  Estimated processing time: {total_weeks * 2}-{total_weeks * 4} minutes")

        flights_all, clearances_all, tracks_all = [], [], []

        for i, wpath in enumerate(tqdm(week_paths, desc="📦 Loading weeks")):
            print(f"🗂️  Processing week {i+1}/{total_weeks}: {wpath.name}")
            f_df, c_df, t_df = self._load_single_week(wpath)
            flights_all.append(f_df)
            clearances_all.append(c_df)
            tracks_all.append(t_df)
            
            # Memory management for large datasets
            if i % 5 == 4:  # Every 5 weeks, force garbage collection
                gc.collect()
                print(f"🧹 Memory cleanup after {i+1} weeks")

        print("🔄 Concatenating all weeks...")
        flights_df = pd.concat(flights_all, ignore_index=True) if flights_all else pd.DataFrame()
        clearances_df = pd.concat(clearances_all, ignore_index=True) if clearances_all else pd.DataFrame()
        tracks_df = pd.concat(tracks_all, ignore_index=True) if tracks_all else pd.DataFrame()

        # Final memory cleanup
        del flights_all, clearances_all, tracks_all
        gc.collect()

        # Enhanced cleaning and preprocessing
        tracks_df = self._clean_and_optimize_tracks(tracks_df)
        clearances_df = self._clean_timestamps(clearances_df)

        self.logger.info(f"✅ Loaded {len(flights_df):,} flights, "
                         f"{len(clearances_df):,} clearances, "
                         f"{len(tracks_df):,} track points from "
                         f"{len(week_paths)} week(s)")
        return flights_df, clearances_df, tracks_df

    def _clean_and_optimize_tracks(self, df: pd.DataFrame) -> pd.DataFrame:
        """Enhanced track cleaning with spatial indexing preparation"""
        if df.empty or 'timestamp' not in df.columns:
            return df
        
        df = df.copy()
        df['timestamp'] = pd.to_datetime(df['timestamp'], format='mixed', errors='coerce')
        df = df.dropna(subset=['timestamp', 'latitude', 'longitude', 'altitude'])
        
        # Remove outliers and invalid data
        df = df[
            (df['latitude'].between(-90, 90)) &
            (df['longitude'].between(-180, 180)) &
            (df['altitude'].between(self.config.altitude_min, self.config.altitude_max))
        ]
        
        # Add spatial grid indices for faster k-NN
        df['lat_grid'] = (df['latitude'] // (self.config.spatial_grid_size / 60)).astype(int)
        df['lon_grid'] = (df['longitude'] // (self.config.spatial_grid_size / 60)).astype(int)
        
        df = df.sort_values(['timestamp', 'flight_id']).reset_index(drop=True)
        return df

    def _clean_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
        """Clean and standardize timestamps"""
        if df.empty or 'timestamp' not in df.columns:
            return df
        
        df = df.copy()
        df['timestamp'] = pd.to_datetime(df['timestamp'], format='mixed', errors='coerce')
        df = df.dropna(subset=['timestamp'])
        df = df.sort_values('timestamp').reset_index(drop=True)
        return df

    def _load_single_week(self, week_path: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Load single week with error handling"""
        week_flights, week_clearances, week_tracks = [], [], []

        if week_path.suffix == ".zip":
            with zipfile.ZipFile(week_path, "r") as zf:
                json_files = [f for f in zf.namelist() if f.endswith(".json")]
                for jf in tqdm(json_files, desc=f"Processing {week_path.name}", leave=False, disable=True):
                    try:
                        with zf.open(jf) as fp:
                            data = json.load(fp)
                        self._dispatch_json(jf, data, week_flights, week_clearances, week_tracks)
                    except Exception as e:
                        continue
                self._load_auxiliary_data(zf)

        elif week_path.is_dir():
            json_paths = list(week_path.glob("*.json"))
            for jp in tqdm(json_paths, desc=f"Processing {week_path.name}", leave=False, disable=True):
                try:
                    with open(jp) as fp:
                        data = json.load(fp)
                    self._dispatch_json(jp.name, data, week_flights, week_clearances, week_tracks)
                except Exception as e:
                    continue

        return (pd.DataFrame(week_flights),
                pd.DataFrame(week_clearances),
                pd.DataFrame(week_tracks))

    def _dispatch_json(self, name, data, flights, clearances, tracks):
        """Route JSON content appropriately"""
        if name in ("airspace.json", "grib_meteo.json"):
            (self.airspace_data if "airspace" in name else self.weather_data).append(data)
        else:
            self._extract_flight_features(data, flights, clearances, tracks)

    def _extract_flight_features(self, flight_data: Dict, flights: list, clearances: list, tracks: list):
        """Extract flight data with validation"""
        flight_id = flight_data.get('Id') or flight_data.get('id')
        if not flight_id:
            return
        
        # Extract flight plan data
        fpl_key = next((k for k in ['Fpl', 'fpl'] if k in flight_data), None)
        if fpl_key and 'fpl_base' in flight_data[fpl_key]:
            for base_data in flight_data[fpl_key]['fpl_base']:
                flight_record = {
                    'flight_id': flight_id,
                    'callsign': base_data.get('Callsign') or base_data.get('callsign'),
                    'aircraft_type': base_data.get('aircraft_type'),
                    'departure': base_data.get('Adep') or base_data.get('adep'),
                    'destination': base_data.get('Ades') or base_data.get('ades'),
                    'timestamp': base_data.get('time_stamp')
                }
                flights.append(flight_record)
        
        # Extract clearance data
        if fpl_key and 'fpl_clearance' in flight_data[fpl_key]:
            for clearance in flight_data[fpl_key]['fpl_clearance']:
                clearance_record = {
                    'flight_id': flight_id,
                    'timestamp': clearance.get('time_stamp'),
                    'cleared_flight_level': clearance.get('Cfl') or clearance.get('cfl'),
                    'assigned_speed': clearance.get('assigned_speed_val'),
                    'assigned_heading': clearance.get('assigned_heading_val')
                }
                clearances.append(clearance_record)
        
        # Extract track data with validation
        plots_key = next((k for k in ['Plots', 'plots'] if k in flight_data), None)
        if plots_key:
            for plot in flight_data[plots_key]:
                if 'I062/105' in plot and 'I062/136' in plot:
                    try:
                        track_record = {
                            'flight_id': flight_id,
                            'timestamp': plot.get('time_of_track'),
                            'latitude': plot['I062/105'].get('lat'),
                            'longitude': plot['I062/105'].get('lon'),
                            'altitude': float(plot['I062/136'].get('measured_flight_level', 0)) * 100
                        }
                        
                        # Add velocity information
                        if 'I062/185' in plot:
                            vx = plot['I062/185'].get('vx', 0)
                            vy = plot['I062/185'].get('vy', 0)
                            track_record['ground_speed'] = np.sqrt(vx**2 + vy**2) * 1.94384
                            track_record['heading'] = np.degrees(np.arctan2(vx, vy)) % 360
                        else:
                            track_record['ground_speed'] = 450.0  # Default speed
                            track_record['heading'] = 90.0      # Default heading
                        
                        # Validate and add
                        if self._validate_track_data(track_record):
                            tracks.append(track_record)
                    except (ValueError, TypeError):
                        continue

    def _validate_track_data(self, track: Dict) -> bool:
        """Validate track data is within operational envelope"""
        try:
            altitude = track.get('altitude', 0)
            speed = track.get('ground_speed', 0)
            lat = track.get('latitude', 0)
            lon = track.get('longitude', 0)
            
            # Check operational envelope
            if not (self.config.altitude_min <= altitude <= self.config.altitude_max):
                return False
            if not (self.config.speed_min <= speed <= self.config.speed_max):
                return False
            if not (-90 <= lat <= 90 and -180 <= lon <= 180):
                return False
            
            return True
        except (ValueError, TypeError):
            return False

    def _load_auxiliary_data(self, zip_ref):
        """Load auxiliary data"""
        for aux in ("airspace.json", "grib_meteo.json"):
            try:
                if aux in zip_ref.namelist():
                    with zip_ref.open(aux) as f:
                        data = json.load(f)
                        (self.airspace_data if "airspace" in aux else self.weather_data).append(data)
            except Exception:
                pass

In [5]:
# k-NN Enhanced Conflict Detector
class KNNConflictDetector:
    """k-NN based conflict detector for improved performance and extensive detection"""
    
    def __init__(self, config: ThesisConfig):
        self.config = config
        self.k_neighbors = config.k_neighbors
        self.conflict_threshold_nm = config.conflict_threshold_nm
        self.conflict_threshold_ft = config.conflict_threshold_ft
        self.time_horizon = config.time_horizon_seconds
    
    def haversine_distance(self, lat1, lon1, lat2, lon2):
        """Calculate distance between two points in nautical miles"""
        R = 3440.065  # Earth radius in nautical miles
        lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
        c = 2 * np.arcsin(np.sqrt(a))
        return R * c
    
    def detect_conflicts_streaming(self, data_loader: 'OptimizedSCATDataLoader', 
                                  max_weeks: int = None) -> pd.DataFrame:
        """Memory-efficient streaming conflict detection - one week at a time"""
        
        week_paths = data_loader.get_week_paths(max_weeks)
        total_weeks = len(week_paths)
        
        print(f"🌊 Streaming conflict detection across {total_weeks} weeks")
        print(f"💾 Memory-efficient: Processing one week at a time")
        print(f"⚡ Expected RAM usage: <5GB per week vs >{total_weeks*2}GB for all weeks")
        
        all_scenarios = []
        total_conflicts = 0
        total_scenarios = 0
        
        for week_idx, week_path in enumerate(tqdm(week_paths, desc="🔄 Streaming weeks")):
            print(f"\n📅 Week {week_idx+1}/{total_weeks}: {week_path.name}")
            
            try:
                # Load single week
                flights_df, clearances_df, tracks_df = data_loader.load_single_week_processed(week_path)
                
                if tracks_df.empty:
                    print(f"   ⚠️  No track data in {week_path.name}, skipping...")
                    continue
                
                # Detect conflicts for this week only
                week_scenarios = self._detect_week_conflicts(tracks_df, clearances_df, week_idx)
                
                if not week_scenarios.empty:
                    week_conflicts = week_scenarios['conflict'].sum()
                    total_conflicts += week_conflicts
                    total_scenarios += len(week_scenarios)
                    
                    print(f"   🚨 Found {week_conflicts} conflicts in {len(week_scenarios)} scenarios")
                    all_scenarios.append(week_scenarios)
                else:
                    print(f"   ℹ️  No scenarios generated for {week_path.name}")
                
                # Aggressive memory cleanup after each week
                del flights_df, clearances_df, tracks_df, week_scenarios
                gc.collect()
                
                # Memory status
                if week_idx % 3 == 2:  # Every 3 weeks
                    print(f"   🧹 Memory cleanup - processed {week_idx+1}/{total_weeks} weeks")
                
            except Exception as e:
                print(f"   ❌ Error processing {week_path.name}: {e}")
                continue
        
        # Combine all scenarios
        if all_scenarios:
            print(f"\n🔄 Combining scenarios from {len(all_scenarios)} weeks...")
            final_scenarios = pd.concat(all_scenarios, ignore_index=True)
            
            # Final cleanup
            del all_scenarios
            gc.collect()
            
            print(f"✅ Streaming detection complete!")
            print(f"   📊 Total scenarios: {len(final_scenarios):,}")
            print(f"   🚨 Total conflicts: {total_conflicts:,}")
            print(f"   ✅ Non-conflicts: {len(final_scenarios) - total_conflicts:,}")
            
            return final_scenarios
        else:
            print("❌ No scenarios generated from any week")
            return pd.DataFrame()
    
    def _detect_week_conflicts(self, tracks_df: pd.DataFrame, clearances_df: pd.DataFrame, 
                              week_idx: int) -> pd.DataFrame:
        """Detect conflicts for a single week with memory optimization"""
        
        if tracks_df.empty:
            return pd.DataFrame()
        
        # Reduce sample size per week to manage memory
        week_sample_size = min(len(tracks_df['timestamp'].unique()), 
                              self.config.max_time_samples // 10)  # Distribute across weeks
        
        # Enhanced sampling for this week
        unique_times = tracks_df['timestamp'].unique()
        time_aircraft_counts = tracks_df.groupby('timestamp')['flight_id'].nunique()
        
        # Weight by aircraft density
        weights = time_aircraft_counts.values
        weights = weights / weights.sum()
        
        sampled_times = np.random.choice(
            unique_times, 
            size=week_sample_size, 
            replace=False,
            p=weights
        )
        
        conflicts = []
        
        for time_sample in sampled_times:
            current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
            
            if len(current_aircraft) < 2:
                continue
            
            aircraft_groups = current_aircraft.groupby('flight_id').last().reset_index()
            
            if len(aircraft_groups) < 2:
                continue
            
            # k-NN conflict detection for this time
            conflicts_at_time = self._detect_conflicts_knn_at_time(
                aircraft_groups, time_sample, clearances_df
            )
            conflicts.extend(conflicts_at_time)
        
        week_conflicts_df = pd.DataFrame(conflicts)
        
        # Add non-conflicts for this week
        if not week_conflicts_df.empty:
            num_conflicts = len(week_conflicts_df)
            non_conflicts = self._extract_week_non_conflicts(
                tracks_df, num_conflicts * 2  # Reduced multiplier for memory
            )
            week_conflicts_df = pd.concat([week_conflicts_df, non_conflicts], ignore_index=True)
        
        return week_conflicts_df
    
    def _extract_week_non_conflicts(self, tracks_df: pd.DataFrame, num_samples: int) -> pd.DataFrame:
        """Extract non-conflict scenarios for a single week"""
        non_conflicts = []
        
        # Limit samples for memory efficiency
        max_samples = min(num_samples, 1000)  # Cap per week
        
        unique_times = tracks_df['timestamp'].unique()
        sample_times = np.random.choice(
            unique_times, 
            size=min(len(unique_times), max_samples), 
            replace=False
        )
        
        for time_sample in sample_times:
            current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
            aircraft_groups = current_aircraft.groupby('flight_id').last().reset_index()
            
            if len(aircraft_groups) < 2:
                continue
            
            # Use k-NN for safe separations
            coords = aircraft_groups[['latitude', 'longitude']].values
            if len(coords) < 2:
                continue
                
            try:
                knn = NearestNeighbors(n_neighbors=min(3, len(aircraft_groups)), metric='haversine')
                knn.fit(np.radians(coords))
                distances, indices = knn.kneighbors(np.radians(coords))
                distances_nm = distances * 3440.065
                
                processed_pairs = set()
                
                for i, aircraft_idx in enumerate(aircraft_groups.index):
                    if len(non_conflicts) >= max_samples:
                        break
                        
                    aircraft = aircraft_groups.loc[aircraft_idx]
                    
                    for j in range(1, min(len(indices[i]), 3)):
                        neighbor_pos = indices[i][j]
                        neighbor_idx = aircraft_groups.index[neighbor_pos]
                        neighbor = aircraft_groups.loc[neighbor_idx]
                        
                        pair_key = tuple(sorted([aircraft['flight_id'], neighbor['flight_id']]))
                        if pair_key in processed_pairs:
                            continue
                        processed_pairs.add(pair_key)
                        
                        h_dist = distances_nm[i][j]
                        v_dist = abs(aircraft['altitude'] - neighbor['altitude'])
                        
                        # Safe separation
                        if h_dist > self.config.conflict_threshold_nm * 1.5 or v_dist > self.config.conflict_threshold_ft * 1.5:
                            non_conflict_record = {
                                'timestamp': time_sample,
                                'flight_id_1': aircraft['flight_id'],
                                'flight_id_2': neighbor['flight_id'],
                                'horizontal_distance': h_dist,
                                'vertical_distance': v_dist,
                                'conflict': False,
                                'clearance_issued': False,
                                'clearance_type': 'none',
                                'altitude_1': aircraft['altitude'],
                                'altitude_2': neighbor['altitude'],
                                'speed_1': aircraft.get('ground_speed', 450),
                                'speed_2': neighbor.get('ground_speed', 450),
                                'heading_1': aircraft.get('heading', 90),
                                'heading_2': neighbor.get('heading', 90)
                            }
                            non_conflicts.append(non_conflict_record)
                            
                            if len(non_conflicts) >= max_samples:
                                break
                    
                    if len(non_conflicts) >= max_samples:
                        break
                        
            except Exception as e:
                continue
                
            if len(non_conflicts) >= max_samples:
                break
        
        return pd.DataFrame(non_conflicts)
    
    def detect_real_conflicts_knn(self, tracks_df: pd.DataFrame, clearances_df: pd.DataFrame) -> pd.DataFrame:
        """Enhanced conflict detection using k-NN optimization - LEGACY METHOD"""
        print("⚠️  Using legacy batch processing - consider streaming for large datasets")
        
        if tracks_df.empty:
            print("No track data available for conflict detection")
            return pd.DataFrame()
        
        conflicts = []
        tracks_df = tracks_df.copy()
        
        # Ensure timestamps are datetime
        if not pd.api.types.is_datetime64_any_dtype(tracks_df['timestamp']):
            tracks_df['timestamp'] = pd.to_datetime(tracks_df['timestamp'], errors='coerce')
        
        tracks_df = tracks_df.dropna(subset=['timestamp', 'latitude', 'longitude', 'altitude'])
        
        if tracks_df.empty:
            print("No valid track data after cleaning")
            return pd.DataFrame()
        
        # Enhanced sampling strategy
        unique_times = tracks_df['timestamp'].unique()
        sample_size = min(len(unique_times), self.config.max_time_samples)
        
        # Weighted sampling - prefer times with more aircraft
        time_aircraft_counts = tracks_df.groupby('timestamp')['flight_id'].nunique()
        weights = time_aircraft_counts.values
        weights = weights / weights.sum()  # Normalize
        
        sampled_times = np.random.choice(
            unique_times, 
            size=sample_size, 
            replace=False,
            p=weights
        )
        
        print(f"🔍 Enhanced k-NN analysis of {sample_size} time windows for conflicts")
        
        for time_sample in tqdm(sampled_times, desc="k-NN conflict detection"):
            # Get aircraft at this time
            current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
            
            if len(current_aircraft) < 2:
                continue
            
            # Group by flight and get latest position
            aircraft_groups = current_aircraft.groupby('flight_id').last().reset_index()
            
            if len(aircraft_groups) < 2:
                continue
            
            # Use k-NN for efficient conflict detection
            conflicts_at_time = self._detect_conflicts_knn_at_time(
                aircraft_groups, time_sample, clearances_df
            )
            conflicts.extend(conflicts_at_time)
        
        conflicts_df = pd.DataFrame(conflicts)
        print(f"🚨 k-NN detection found {len(conflicts_df)} scenarios ({conflicts_df['conflict'].sum() if not conflicts_df.empty else 0} conflicts)")
        
        # Add extensive non-conflict samples
        if not conflicts_df.empty:
            non_conflicts = self._extract_extensive_non_conflicts(
                tracks_df, len(conflicts_df) * self.config.sample_multiplier
            )
            conflicts_df = pd.concat([conflicts_df, non_conflicts], ignore_index=True)
        
        return conflicts_df
        """Enhanced conflict detection using k-NN optimization"""
        if tracks_df.empty:
            print("No track data available for conflict detection")
            return pd.DataFrame()
        
        conflicts = []
        tracks_df = tracks_df.copy()
        
        # Ensure timestamps are datetime
        if not pd.api.types.is_datetime64_any_dtype(tracks_df['timestamp']):
            tracks_df['timestamp'] = pd.to_datetime(tracks_df['timestamp'], errors='coerce')
        
        tracks_df = tracks_df.dropna(subset=['timestamp', 'latitude', 'longitude', 'altitude'])
        
        if tracks_df.empty:
            print("No valid track data after cleaning")
            return pd.DataFrame()
        
        # Enhanced sampling strategy
        unique_times = tracks_df['timestamp'].unique()
        sample_size = min(len(unique_times), self.config.max_time_samples)
        
        # Weighted sampling - prefer times with more aircraft
        time_aircraft_counts = tracks_df.groupby('timestamp')['flight_id'].nunique()
        weights = time_aircraft_counts.values
        weights = weights / weights.sum()  # Normalize
        
        sampled_times = np.random.choice(
            unique_times, 
            size=sample_size, 
            replace=False,
            p=weights
        )
        
        print(f"🔍 Enhanced k-NN analysis of {sample_size} time windows for conflicts")
        
        for time_sample in tqdm(sampled_times, desc="k-NN conflict detection"):
            # Get aircraft at this time
            current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
            
            if len(current_aircraft) < 2:
                continue
            
            # Group by flight and get latest position
            aircraft_groups = current_aircraft.groupby('flight_id').last().reset_index()
            
            if len(aircraft_groups) < 2:
                continue
            
            # Use k-NN for efficient conflict detection
            conflicts_at_time = self._detect_conflicts_knn_at_time(
                aircraft_groups, time_sample, clearances_df
            )
            conflicts.extend(conflicts_at_time)
        
        conflicts_df = pd.DataFrame(conflicts)
        print(f"🚨 k-NN detection found {len(conflicts_df)} scenarios ({conflicts_df['conflict'].sum() if not conflicts_df.empty else 0} conflicts)")
        
        # Add extensive non-conflict samples
        if not conflicts_df.empty:
            non_conflicts = self._extract_extensive_non_conflicts(
                tracks_df, len(conflicts_df) * self.config.sample_multiplier
            )
            conflicts_df = pd.concat([conflicts_df, non_conflicts], ignore_index=True)
        
        return conflicts_df
    
    def _detect_conflicts_knn_at_time(self, aircraft_df: pd.DataFrame, timestamp: pd.Timestamp, 
                                     clearances_df: pd.DataFrame) -> List[Dict]:
        """Detect conflicts at specific time using k-NN"""
        conflicts = []
        
        if len(aircraft_df) < 2:
            return conflicts
        
        # Prepare coordinates for k-NN
        coords = aircraft_df[['latitude', 'longitude']].values
        
        # Build k-NN model
        knn = NearestNeighbors(n_neighbors=min(self.k_neighbors + 1, len(aircraft_df)), 
                              metric='haversine')
        knn.fit(np.radians(coords))
        
        # Find neighbors for each aircraft
        distances, indices = knn.kneighbors(np.radians(coords))
        
        # Convert distances from radians to nautical miles
        distances_nm = distances * 3440.065
        
        processed_pairs = set()
        
        for i, aircraft_idx in enumerate(aircraft_df.index):
            aircraft = aircraft_df.loc[aircraft_idx]
            
            # Check k nearest neighbors (skip self at index 0)
            for j in range(1, min(len(indices[i]), self.k_neighbors + 1)):
                neighbor_pos = indices[i][j]
                neighbor_idx = aircraft_df.index[neighbor_pos]
                neighbor = aircraft_df.loc[neighbor_idx]
                
                # Avoid duplicate pairs
                pair_key = tuple(sorted([aircraft['flight_id'], neighbor['flight_id']]))
                if pair_key in processed_pairs:
                    continue
                processed_pairs.add(pair_key)
                
                # Calculate separations
                h_dist = distances_nm[i][j]
                v_dist = abs(aircraft['altitude'] - neighbor['altitude'])
                
                # Determine if this is a conflict or safe scenario
                is_conflict = (h_dist < self.config.conflict_threshold_nm and 
                             v_dist < self.config.conflict_threshold_ft)
                
                # For non-conflicts, add some that are close but safe
                include_scenario = is_conflict or (
                    h_dist < self.config.conflict_threshold_nm * 2 and
                    np.random.random() < 0.3  # Sample 30% of near-misses
                )
                
                if include_scenario:
                    # Find associated clearances
                    clearance_info = self._find_associated_clearances(
                        timestamp, [aircraft['flight_id'], neighbor['flight_id']], clearances_df
                    )
                    
                    conflict_record = {
                        'timestamp': timestamp,
                        'flight_id_1': aircraft['flight_id'],
                        'flight_id_2': neighbor['flight_id'],
                        'horizontal_distance': h_dist,
                        'vertical_distance': v_dist,
                        'conflict': is_conflict,
                        'clearance_issued': clearance_info['issued'],
                        'clearance_type': clearance_info['type'],
                        'altitude_1': aircraft['altitude'],
                        'altitude_2': neighbor['altitude'],
                        'speed_1': aircraft.get('ground_speed', 450),
                        'speed_2': neighbor.get('ground_speed', 450),
                        'heading_1': aircraft.get('heading', 90),
                        'heading_2': neighbor.get('heading', 90)
                    }
                    conflicts.append(conflict_record)
        
        return conflicts
    
    def _find_associated_clearances(self, timestamp: pd.Timestamp, flight_ids: List[str], 
                                   clearances_df: pd.DataFrame) -> Dict:
        """Find clearances associated with a conflict"""
        if clearances_df.empty:
            return {'issued': False, 'type': 'none'}
        
        # Look for clearances within time window
        time_window = pd.Timedelta(minutes=5)
        clearance_mask = (
            (clearances_df['timestamp'] >= timestamp - time_window) &
            (clearances_df['timestamp'] <= timestamp + time_window) &
            (clearances_df['flight_id'].isin(flight_ids))
        )
        
        relevant_clearances = clearances_df[clearance_mask]
        
        if len(relevant_clearances) == 0:
            return {'issued': False, 'type': 'none'}
        
        # Determine clearance type
        clearance_type = self._classify_clearance_type(relevant_clearances)
        return {'issued': True, 'type': clearance_type}
    
    def _classify_clearance_type(self, clearances: pd.DataFrame) -> str:
        """Classify the type of clearance"""
        if clearances['cleared_flight_level'].notna().any():
            return "altitude_change"
        elif clearances['assigned_heading'].notna().any():
            return "heading_change"  
        elif clearances['assigned_speed'].notna().any():
            return "speed_change"
        else:
            return "other"
    
    def _extract_extensive_non_conflicts(self, tracks_df: pd.DataFrame, num_samples: int) -> pd.DataFrame:
        """Extract extensive non-conflict scenarios using k-NN"""
        non_conflicts = []
        
        # Sample different time windows
        unique_times = tracks_df['timestamp'].unique()
        sample_times = np.random.choice(
            unique_times, 
            size=min(len(unique_times), num_samples), 
            replace=False
        )
        
        for time_sample in tqdm(sample_times, desc="Generating non-conflicts", disable=True):
            current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
            aircraft_groups = current_aircraft.groupby('flight_id').last().reset_index()
            
            if len(aircraft_groups) < 2:
                continue
            
            # Use k-NN to find safe separations
            coords = aircraft_groups[['latitude', 'longitude']].values
            if len(coords) < 2:
                continue
                
            knn = NearestNeighbors(n_neighbors=min(3, len(aircraft_groups)), metric='haversine')
            knn.fit(np.radians(coords))
            distances, indices = knn.kneighbors(np.radians(coords))
            distances_nm = distances * 3440.065
            
            processed_pairs = set()
            
            for i, aircraft_idx in enumerate(aircraft_groups.index):
                aircraft = aircraft_groups.loc[aircraft_idx]
                
                for j in range(1, min(len(indices[i]), 3)):
                    neighbor_pos = indices[i][j]
                    neighbor_idx = aircraft_groups.index[neighbor_pos]
                    neighbor = aircraft_groups.loc[neighbor_idx]
                    
                    pair_key = tuple(sorted([aircraft['flight_id'], neighbor['flight_id']]))
                    if pair_key in processed_pairs:
                        continue
                    processed_pairs.add(pair_key)
                    
                    h_dist = distances_nm[i][j]
                    v_dist = abs(aircraft['altitude'] - neighbor['altitude'])
                    
                    # Non-conflict: safe separation
                    if h_dist > self.config.conflict_threshold_nm * 1.5 or v_dist > self.config.conflict_threshold_ft * 1.5:
                        non_conflict_record = {
                            'timestamp': time_sample,
                            'flight_id_1': aircraft['flight_id'],
                            'flight_id_2': neighbor['flight_id'],
                            'horizontal_distance': h_dist,
                            'vertical_distance': v_dist,
                            'conflict': False,
                            'clearance_issued': False,
                            'clearance_type': 'none',
                            'altitude_1': aircraft['altitude'],
                            'altitude_2': neighbor['altitude'],
                            'speed_1': aircraft.get('ground_speed', 450),
                            'speed_2': neighbor.get('ground_speed', 450),
                            'heading_1': aircraft.get('heading', 90),
                            'heading_2': neighbor.get('heading', 90)
                        }
                        non_conflicts.append(non_conflict_record)
                        
                        if len(non_conflicts) >= num_samples:
                            break
                
                if len(non_conflicts) >= num_samples:
                    break
                    
            if len(non_conflicts) >= num_samples:
                break
        
        return pd.DataFrame(non_conflicts)

In [6]:
# Fixed Hallucination-Aware Dataset
class FixedHallucinationATCDataset(Dataset):
    """Fixed dataset with proper tensor dimensions"""
    
    def __init__(self, scenarios_df: pd.DataFrame, tokenizer, max_length: int = 128):
        self.scenarios_df = scenarios_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Encode clearance types
        self.clearance_encoder = LabelEncoder()
        self.scenarios_df['clearance_label'] = self.clearance_encoder.fit_transform(
            self.scenarios_df['clearance_type'].fillna('none')
        )
        
        # Add envelope indicators for hallucination analysis
        self.scenarios_df['outside_training_envelope'] = self._detect_envelope_violations()
        
        print(f"📊 Dataset created: {len(self.scenarios_df)} scenarios")
        print(f"   - Conflicts: {self.scenarios_df['conflict'].sum()}")
        print(f"   - Non-conflicts: {(~self.scenarios_df['conflict']).sum()}")
        print(f"   - Envelope violations: {self.scenarios_df['outside_training_envelope'].sum()}")
    
    def _detect_envelope_violations(self) -> pd.Series:
        """Detect scenarios outside typical training envelope"""
        violations = (
            (self.scenarios_df['altitude_1'] < config.altitude_min) |
            (self.scenarios_df['altitude_1'] > config.altitude_max) |
            (self.scenarios_df['altitude_2'] < config.altitude_min) |
            (self.scenarios_df['altitude_2'] > config.altitude_max) |
            (self.scenarios_df['speed_1'] < config.speed_min) |
            (self.scenarios_df['speed_1'] > config.speed_max) |
            (self.scenarios_df['speed_2'] < config.speed_min) |
            (self.scenarios_df['speed_2'] > config.speed_max)
        )
        return violations
    
    def __len__(self):
        return len(self.scenarios_df)
    
    def __getitem__(self, idx):
        row = self.scenarios_df.iloc[idx]
        
        # Create scenario description
        text = self._create_scenario_text(row)
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'conflict_label': torch.tensor(1 if row['conflict'] else 0, dtype=torch.long),
            'clearance_label': torch.tensor(row['clearance_label'], dtype=torch.long),
            'envelope_violation': torch.tensor(float(row['outside_training_envelope']), dtype=torch.float)
        }
    
    def _create_scenario_text(self, row):
        """Create natural language scenario description"""
        text = (
            f"Aircraft A at FL{int(row['altitude_1']/100):03d} "
            f"heading {int(row['heading_1']):03d}° "
            f"speed {int(row['speed_1']):03d} kt; "
            f"Aircraft B at FL{int(row['altitude_2']/100):03d} "
            f"heading {int(row['heading_2']):03d}° "
            f"speed {int(row['speed_2']):03d} kt; "
            f"horizontal separation {row['horizontal_distance']:.1f} NM; "
            f"vertical separation {row['vertical_distance']:.0f} ft."
        )
        return text

In [7]:
# Fixed Hallucination-Aware Model Architecture  
class FixedHallucinationAwareATCModel(nn.Module):
    """Fixed ATC model with proper tensor handling"""
    
    def __init__(self, model_name: str, num_clearance_types: int, dropout_rate: float = 0.1):
        super().__init__()
        
        self.bert = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        
        # Main task heads
        self.conflict_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, 2)
        )
        
        self.resolution_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_clearance_types)
        )
        
        # Fixed hallucination detection head
        self.hallucination_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, 1)  # No sigmoid here
        )
        
        self.dropout_rate = dropout_rate
    
    def forward(self, input_ids, attention_mask, enable_dropout=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]
        
        if enable_dropout:
            pooled_output = F.dropout(pooled_output, p=self.dropout_rate, training=True)
        
        conflict_logits = self.conflict_head(pooled_output)
        resolution_logits = self.resolution_head(pooled_output)
        hallucination_logits = self.hallucination_head(pooled_output)  # Raw logits
        
        return conflict_logits, resolution_logits, hallucination_logits
    
    def get_uncertainty_estimates(self, input_ids, attention_mask, n_samples=10):
        """Estimate model uncertainty using Monte Carlo dropout"""
        self.train()  # Enable dropout
        
        conflict_samples = []
        resolution_samples = []
        hallucination_samples = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                conflict_logits, resolution_logits, hallucination_logits = self.forward(
                    input_ids, attention_mask, enable_dropout=True
                )
                conflict_samples.append(F.softmax(conflict_logits, dim=-1))
                resolution_samples.append(F.softmax(resolution_logits, dim=-1))
                hallucination_samples.append(torch.sigmoid(hallucination_logits))
        
        # Calculate statistics
        conflict_probs = torch.stack(conflict_samples)
        resolution_probs = torch.stack(resolution_samples)
        hallucination_probs = torch.stack(hallucination_samples)
        
        results = {
            'conflict_mean': conflict_probs.mean(dim=0),
            'conflict_std': conflict_probs.std(dim=0),
            'conflict_entropy': -torch.sum(conflict_probs.mean(dim=0) * torch.log(conflict_probs.mean(dim=0) + 1e-8), dim=-1),
            'resolution_mean': resolution_probs.mean(dim=0),
            'resolution_std': resolution_probs.std(dim=0),
            'hallucination_mean': hallucination_probs.mean(dim=0),
            'hallucination_std': hallucination_probs.std(dim=0)
        }
        
        return results

In [8]:
# Fixed Training Pipeline
class FixedHallucinationAwareTrainer:
    """Fixed training pipeline with proper tensor handling"""
    
    def __init__(self, model, train_loader, val_loader, config):
        self.model = model.to(config.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config
        
        self.optimizer = AdamW(model.parameters(), lr=config.learning_rate, weight_decay=1e-5)
        total_steps = len(train_loader) * config.num_epochs
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_steps
        )
        
        # Loss functions
        self.conflict_loss_fn = nn.CrossEntropyLoss()
        self.resolution_loss_fn = nn.CrossEntropyLoss()
        self.hallucination_loss_fn = nn.BCEWithLogitsLoss()  # Use BCEWithLogitsLoss for stability
        
        # Loss weights
        self.conflict_weight = 1.0
        self.resolution_weight = 1.0
        self.hallucination_weight = 0.3  # Reduced weight
    
    def train(self):
        """Training loop with proper tensor handling"""
        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0
            
            for batch_idx, batch in enumerate(tqdm(self.train_loader, desc=f"Epoch {epoch+1}")):
                try:
                    # Move to device
                    input_ids = batch['input_ids'].to(self.config.device)
                    attention_mask = batch['attention_mask'].to(self.config.device)
                    conflict_labels = batch['conflict_label'].to(self.config.device)
                    clearance_labels = batch['clearance_label'].to(self.config.device)
                    envelope_violations = batch['envelope_violation'].to(self.config.device)
                    
                    # Forward pass
                    conflict_logits, resolution_logits, hallucination_logits = self.model(
                        input_ids, attention_mask
                    )
                    
                    # Calculate losses with proper tensor handling
                    conflict_loss = self.conflict_loss_fn(conflict_logits, conflict_labels)
                    resolution_loss = self.resolution_loss_fn(resolution_logits, clearance_labels)
                    
                    # Fix hallucination loss - ensure proper dimensions
                    hallucination_logits_flat = hallucination_logits.squeeze(-1)  # Remove last dim if size 1
                    if hallucination_logits_flat.dim() == 0:  # If scalar, make it 1D
                        hallucination_logits_flat = hallucination_logits_flat.unsqueeze(0)
                    
                    hallucination_loss = self.hallucination_loss_fn(
                        hallucination_logits_flat, envelope_violations
                    )
                    
                    # Combined loss
                    total_batch_loss = (
                        self.conflict_weight * conflict_loss +
                        self.resolution_weight * resolution_loss +
                        self.hallucination_weight * hallucination_loss
                    )
                    
                    # Backward pass
                    total_batch_loss.backward()
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                    self.optimizer.step()
                    self.scheduler.step()
                    self.optimizer.zero_grad()
                    
                    total_loss += total_batch_loss.item()
                    
                except Exception as e:
                    print(f"Error in batch {batch_idx}: {e}")
                    continue
            
            avg_loss = total_loss / len(self.train_loader)
            print(f"Epoch {epoch+1}/{self.config.num_epochs}, Average Loss: {avg_loss:.4f}")
            
            # Validation
            self.validate()
    
    def validate(self):
        """Validation with proper error handling"""
        self.model.eval()
        total_loss = 0
        
        conflict_preds, conflict_labels = [], []
        clearance_preds, clearance_labels = [], []
        hallucination_preds, hallucination_labels = [], []
        
        with torch.no_grad():
            for batch in self.val_loader:
                try:
                    input_ids = batch['input_ids'].to(self.config.device)
                    attention_mask = batch['attention_mask'].to(self.config.device)
                    conflict_true = batch['conflict_label'].to(self.config.device)
                    clearance_true = batch['clearance_label'].to(self.config.device)
                    envelope_true = batch['envelope_violation'].to(self.config.device)
                    
                    conflict_logits, resolution_logits, hallucination_logits = self.model(
                        input_ids, attention_mask
                    )
                    
                    # Calculate losses
                    conflict_loss = self.conflict_loss_fn(conflict_logits, conflict_true)
                    resolution_loss = self.resolution_loss_fn(resolution_logits, clearance_true)
                    
                    hallucination_logits_flat = hallucination_logits.squeeze(-1)
                    if hallucination_logits_flat.dim() == 0:
                        hallucination_logits_flat = hallucination_logits_flat.unsqueeze(0)
                    
                    hallucination_loss = self.hallucination_loss_fn(
                        hallucination_logits_flat, envelope_true
                    )
                    
                    total_loss += conflict_loss.item() + resolution_loss.item() + hallucination_loss.item()
                    
                    # Collect predictions
                    conflict_preds.extend(torch.argmax(conflict_logits, dim=1).cpu().numpy())
                    conflict_labels.extend(conflict_true.cpu().numpy())
                    clearance_preds.extend(torch.argmax(resolution_logits, dim=1).cpu().numpy())
                    clearance_labels.extend(clearance_true.cpu().numpy())
                    hallucination_preds.extend((torch.sigmoid(hallucination_logits_flat) > 0.5).cpu().numpy())
                    hallucination_labels.extend(envelope_true.cpu().numpy())
                    
                except Exception as e:
                    continue
        
        print(f"Validation Loss: {total_loss / len(self.val_loader):.4f}")
        
        if len(conflict_labels) > 0:
            print("\n📊 Conflict Detection Report:")
            print(classification_report(conflict_labels, conflict_preds, zero_division=0))
            print("\n📊 Clearance Prediction Report:")
            print(classification_report(clearance_labels, clearance_preds, zero_division=0))
            print("\n📊 Hallucination Detection Report:")
            print(classification_report(hallucination_labels, hallucination_preds, zero_division=0))

In [9]:
# Enhanced BlueSky Integration with k-NN
class OptimizedBlueSkySimulator:
    """Optimized BlueSky simulator with k-NN and proper error handling"""
    
    def __init__(self):
        self.state = {}
        self.detector = KNNConflictDetector(config)
        print("✈️ Enhanced k-NN BlueSky Simulator initialized")
    
    def update_state(self, tracks_df: pd.DataFrame):
        """Update simulator state with validated track data"""
        if not tracks_df.empty:
            tracks_df = tracks_df.copy()
            if not pd.api.types.is_datetime64_any_dtype(tracks_df['timestamp']):
                tracks_df['timestamp'] = pd.to_datetime(tracks_df['timestamp'], errors='coerce')
            tracks_df = tracks_df.dropna(subset=['timestamp'])
            
        self.state['tracks'] = tracks_df
        print(f"📡 Simulator updated with {len(tracks_df)} track points")
    
    def get_separation(self, flight_id_1: str, flight_id_2: str) -> Tuple[float, float]:
        """Get current separation between aircraft using k-NN"""
        if 'tracks' not in self.state or self.state['tracks'].empty:
            return 0.0, 0.0
        
        tracks = self.state['tracks']
        ac1_tracks = tracks[tracks['flight_id'] == flight_id_1]
        ac2_tracks = tracks[tracks['flight_id'] == flight_id_2]
        
        if ac1_tracks.empty or ac2_tracks.empty:
            return 0.0, 0.0
        
        # Get latest positions
        ac1 = ac1_tracks.iloc[-1]
        ac2 = ac2_tracks.iloc[-1]
        
        h_dist = self.detector.haversine_distance(
            ac1['latitude'], ac1['longitude'], 
            ac2['latitude'], ac2['longitude']
        )
        v_dist = abs(ac1.get('altitude', 0) - ac2.get('altitude', 0))
        
        return h_dist, v_dist
    
    def validate_clearance(self, flight_id_1: str, flight_id_2: str, clearance_type: str) -> bool:
        """Validate clearance maintains safe separation"""
        h_dist, v_dist = self.get_separation(flight_id_1, flight_id_2)
        safe = h_dist > config.conflict_threshold_nm or v_dist > config.conflict_threshold_ft
        print(f"🔍 Clearance validation: H={h_dist:.1f}NM, V={v_dist:.0f}ft, Safe={safe}")
        return safe

def plot_enhanced_trajectories(simulator):
    """Enhanced trajectory plotting"""
    if 'tracks' not in simulator.state or simulator.state['tracks'].empty:
        print("⚠️ No tracks available for plotting")
        return
    
    tracks_df = simulator.state['tracks']
    
    plt.figure(figsize=(16, 12))
    
    # Plot sample of trajectories
    flight_ids = tracks_df['flight_id'].unique()[:15]  # Show more flights
    
    colors = plt.cm.tab20(np.linspace(0, 1, len(flight_ids)))
    
    for i, flight_id in enumerate(flight_ids):
        flight_tracks = tracks_df[tracks_df['flight_id'] == flight_id].sort_values('timestamp')
        if len(flight_tracks) > 1:
            plt.plot(flight_tracks['longitude'], flight_tracks['latitude'], 
                    color=colors[i], marker='o', markersize=1, alpha=0.7, 
                    label=f'Flight {flight_id}', linewidth=1.5)
    
    plt.xlabel('Longitude (°)', fontsize=12)
    plt.ylabel('Latitude (°)', fontsize=12)
    plt.title('Enhanced k-NN Aircraft Trajectories from SCAT Data', fontsize=14)
    plt.legend(bbox_to_anchor=(1.05, 1), loc='upper left', fontsize=8)
    plt.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.savefig(os.path.join(config.output_path, 'enhanced_trajectories.png'), 
               dpi=300, bbox_inches='tight')
    plt.close()
    print("📈 Enhanced trajectory plot saved")

def plot_enhanced_separation(simulator, pair: Tuple[str, str]):
    """Enhanced separation plotting with error handling"""
    if 'tracks' not in simulator.state or simulator.state['tracks'].empty:
        print("⚠️ No tracks available for separation plotting")
        return
    
    tracks_df = simulator.state['tracks']
    ac1_tracks = tracks_df[tracks_df['flight_id'] == pair[0]].copy()
    ac2_tracks = tracks_df[tracks_df['flight_id'] == pair[1]].copy()
    
    if ac1_tracks.empty or ac2_tracks.empty:
        print(f"⚠️ No data for aircraft pair {pair}")
        return
    
    try:
        # Enhanced timestamp handling
        ac1_tracks['timestamp'] = pd.to_datetime(ac1_tracks['timestamp'], errors='coerce')
        ac2_tracks['timestamp'] = pd.to_datetime(ac2_tracks['timestamp'], errors='coerce')
        
        ac1_tracks = ac1_tracks.dropna(subset=['timestamp']).sort_values('timestamp')
        ac2_tracks = ac2_tracks.dropna(subset=['timestamp']).sort_values('timestamp')
        
        if ac1_tracks.empty or ac2_tracks.empty:
            print(f"⚠️ No valid timestamps for pair {pair}")
            return
        
        # Enhanced merge with better error handling
        merged = pd.merge_asof(
            ac1_tracks[['timestamp', 'latitude', 'longitude', 'altitude']].reset_index(drop=True),
            ac2_tracks[['timestamp', 'latitude', 'longitude', 'altitude']].reset_index(drop=True),
            on='timestamp',
            suffixes=('_1', '_2'),
            direction='nearest',
            tolerance=pd.Timedelta(minutes=5)
        )
        
        if merged.empty:
            print(f"⚠️ No aligned data for pair {pair}")
            return
        
        # Calculate separations
        detector = KNNConflictDetector(config)
        merged['h_dist'] = merged.apply(
            lambda row: detector.haversine_distance(
                row['latitude_1'], row['longitude_1'],
                row['latitude_2'], row['longitude_2']
            ) if pd.notna(row['latitude_1']) and pd.notna(row['latitude_2']) else np.nan, 
            axis=1
        )
        merged['v_dist'] = abs(merged['altitude_1'] - merged['altitude_2'])
        
        # Remove NaN values
        merged = merged.dropna(subset=['h_dist', 'v_dist'])
        
        if merged.empty:
            print(f"⚠️ No valid separation data for pair {pair}")
            return
        
        # Enhanced plotting
        fig, (ax1, ax2) = plt.subplots(2, 1, figsize=(16, 10))
        
        # Horizontal separation
        ax1.plot(merged['timestamp'], merged['h_dist'], 'b-', linewidth=2, label='Horizontal Distance')
        ax1.axhline(y=config.conflict_threshold_nm, color='r', linestyle='--', linewidth=2,
                   label=f'Conflict Threshold ({config.conflict_threshold_nm} NM)')
        ax1.set_ylabel('Distance (NM)', fontsize=12)
        ax1.set_title(f'k-NN Enhanced Separation Analysis: {pair[0]} vs {pair[1]}', fontsize=14)
        ax1.legend(fontsize=10)
        ax1.grid(True, alpha=0.3)
        
        # Vertical separation
        ax2.plot(merged['timestamp'], merged['v_dist']/1000, 'g-', linewidth=2, label='Vertical Distance')
        ax2.axhline(y=config.conflict_threshold_ft/1000, color='r', linestyle='--', linewidth=2,
                   label=f'Conflict Threshold ({config.conflict_threshold_ft/1000:.1f} kft)')
        ax2.set_xlabel('Time', fontsize=12)
        ax2.set_ylabel('Distance (kft)', fontsize=12)
        ax2.legend(fontsize=10)
        ax2.grid(True, alpha=0.3)
        
        plt.tight_layout()
        plt.savefig(os.path.join(config.output_path, f'enhanced_separation_{pair[0]}_{pair[1]}.png'), 
                   dpi=300, bbox_inches='tight')
        plt.close()
        print(f"📊 Enhanced separation plot saved for {pair}")
        
    except Exception as e:
        print(f"❌ Error plotting enhanced separation for {pair}: {e}")

In [10]:
# Main Enhanced Research Pipeline
def main():
    """Memory-efficient main pipeline with streaming week-by-week processing"""
    print("🎓 Starting Enhanced ML Hallucination Research Pipeline")
    print("=" * 70)
    print("🌍 STREAMING PROCESSING - ALL AVAILABLE SCAT WEEKS")
    print("💾 Memory-Optimized: One week at a time (RAM-friendly)")
    print("=" * 70)
    
    # Initialize components
    print("📦 Initializing SCAT Data Loader...")
    data_loader = OptimizedSCATDataLoader(config)
    
    # Get week count for estimation
    week_paths = data_loader.get_week_paths(max_weeks=None)
    total_weeks = len(week_paths)
    
    print(f"📅 Found {total_weeks} weeks to process")
    print(f"⚡ Estimated RAM usage: <5GB (vs >{total_weeks*2}GB for batch processing)")
    print(f"⏱️  Estimated time: {total_weeks * 3}-{total_weeks * 6} minutes")
    
    # Streaming conflict detection
    print("🌊 Starting Streaming k-NN Conflict Detection...")
    detector = KNNConflictDetector(config)
    scenarios_df = detector.detect_conflicts_streaming(data_loader, max_weeks=None)
    
    if scenarios_df.empty:
        print("❌ No scenarios generated from streaming processing. Cannot proceed.")
        return
    
    print(f"\n📊 Final Dataset Statistics:")
    print(f"   - Total scenarios: {len(scenarios_df):,}")
    print(f"   - Conflicts: {scenarios_df['conflict'].sum():,}")
    print(f"   - Non-conflicts: {(~scenarios_df['conflict']).sum():,}")
    print(f"   - Conflict rate: {scenarios_df['conflict'].mean()*100:.2f}%")
    
    # Memory cleanup before ML pipeline
    gc.collect()
    print("🧹 Memory cleanup before ML training")
    
    # Initialize ML pipeline
    print("🤖 Initializing Enhanced ML Pipeline...")
    tokenizer = DistilBertTokenizer.from_pretrained(config.model_name)
    
    # Create dataset with memory monitoring
    print("📚 Creating hallucination-aware dataset...")
    dataset = FixedHallucinationATCDataset(scenarios_df, tokenizer, config.max_sequence_length)
    
    # Memory cleanup after dataset creation
    del scenarios_df  # Free the original dataframe
    gc.collect()
    
    # Split data efficiently
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
    
    print(f"   - Training samples: {train_size:,}")
    print(f"   - Validation samples: {val_size:,}")
    
    # Create data loaders with memory-friendly settings
    train_loader = DataLoader(
        train_dataset, 
        batch_size=config.batch_size, 
        shuffle=True, 
        num_workers=0,  # Avoid multiprocessing overhead
        pin_memory=False  # Reduce memory pressure
    )
    val_loader = DataLoader(
        val_dataset, 
        batch_size=config.batch_size, 
        num_workers=0,
        pin_memory=False
    )
    
    # Initialize model
    num_clearance_types = len(dataset.clearance_encoder.classes_)
    model = FixedHallucinationAwareATCModel(config.model_name, num_clearance_types)
    
    print(f"🤖 Model initialized with {num_clearance_types} clearance types")
    
    # Training with memory monitoring
    print("🎯 Training Enhanced Hallucination-Aware Model...")
    trainer = FixedHallucinationAwareTrainer(model, train_loader, val_loader, config)
    trainer.train()
    
    # Save model
    model_path = os.path.join(config.model_save_path, 'streaming_hallucination_aware_atc_model.pt')
    torch.save(model.state_dict(), model_path)
    print(f"💾 Model saved to {model_path}")
    
    # Memory cleanup before testing
    del train_loader, val_loader, train_dataset, val_dataset
    gc.collect()
    
    # Testing phase with lightweight data
    print("✈️ Testing with Sample Data (Memory-Efficient)...")
    
    # Load a small sample for testing
    test_week_path = week_paths[0]  # Use first week for testing
    test_flights, test_clearances, test_tracks = data_loader.load_single_week_processed(test_week_path)
    
    if not test_tracks.empty:
        simulator = OptimizedBlueSkySimulator()
        simulator.update_state(test_tracks)
        
        # Generate sample visualizations
        print("📊 Generating Sample Visualizations...")
        plot_enhanced_trajectories(simulator)
        
        # Test clearance validation if we have data
        if len(test_tracks['flight_id'].unique()) >= 2:
            flight_ids = test_tracks['flight_id'].unique()[:2]
            clearance_valid = simulator.validate_clearance(
                flight_ids[0], flight_ids[1], "altitude_change"
            )
            print(f"✅ Sample clearance validation result: {clearance_valid}")
    
    # Final cleanup
    del test_flights, test_clearances, test_tracks
    gc.collect()
    
    print("🎓 Streaming research pipeline completed successfully!")
    print(f"📁 Results saved to: {config.output_path}")
    print("🌊 Streaming approach enabled processing of complete dataset within RAM limits")
    print("🌍 All SCAT weeks analyzed for comprehensive hallucination research")
    print("=" * 70)

if __name__ == "__main__":
    main()

🎓 Starting Enhanced ML Hallucination Research Pipeline
🌍 STREAMING PROCESSING - ALL AVAILABLE SCAT WEEKS
💾 Memory-Optimized: One week at a time (RAM-friendly)
📦 Initializing SCAT Data Loader...
ℹ️ Found 13 SCAT data sources
📅 Found 13 weeks to process
⚡ Estimated RAM usage: <5GB (vs >26GB for batch processing)
⏱️  Estimated time: 39-78 minutes
🌊 Starting Streaming k-NN Conflict Detection...
ℹ️ Found 13 SCAT data sources
🌊 Streaming conflict detection across 13 weeks
💾 Memory-efficient: Processing one week at a time
⚡ Expected RAM usage: <5GB per week vs >26GB for all weeks


🔄 Streaming weeks:   0%|          | 0/13 [00:00<?, ?it/s]


📅 Week 1/13: scat20161015_20161021
🗂️  Loading week: scat20161015_20161021
   📊 Week stats: 13,208 flights, 93,407 clearances, 4,571,496 tracks
   🚨 Found 0 conflicts in 177 scenarios

📅 Week 2/13: scat20161112_20161118
🗂️  Loading week: scat20161112_20161118
   📊 Week stats: 12,355 flights, 88,803 clearances, 4,328,981 tracks
   🚨 Found 1 conflicts in 108 scenarios

📅 Week 3/13: scat20161210_20161216
🗂️  Loading week: scat20161210_20161216
   📊 Week stats: 12,180 flights, 87,428 clearances, 4,344,936 tracks
   🚨 Found 2 conflicts in 129 scenarios
   🧹 Memory cleanup - processed 3/13 weeks

📅 Week 4/13: scat20170107_20170113
🗂️  Loading week: scat20170107_20170113
   📊 Week stats: 11,281 flights, 76,232 clearances, 3,940,857 tracks
   🚨 Found 0 conflicts in 129 scenarios

📅 Week 5/13: scat20170215_20170221
🗂️  Loading week: scat20170215_20170221
   📊 Week stats: 11,681 flights, 81,891 clearances, 4,169,377 tracks
   🚨 Found 0 conflicts in 111 scenarios

📅 Week 6/13: scat20170304_20170

tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

📚 Creating hallucination-aware dataset...
📊 Dataset created: 2028 scenarios
   - Conflicts: 11
   - Non-conflicts: 2017
   - Envelope violations: 0
   - Training samples: 1,622
   - Validation samples: 406


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

🤖 Model initialized with 2 clearance types
🎯 Training Enhanced Hallucination-Aware Model...


Epoch 1:   0%|          | 0/102 [00:00<?, ?it/s]

Epoch 1/3, Average Loss: 0.8724
Validation Loss: 0.6034

📊 Conflict Detection Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       405
           1       0.00      0.00      0.00         1

    accuracy                           1.00       406
   macro avg       0.50      0.50      0.50       406
weighted avg       1.00      1.00      1.00       406


📊 Clearance Prediction Report:
              precision    recall  f1-score   support

           0       0.00      0.00      0.00        93
           1       0.77      1.00      0.87       313

    accuracy                           0.77       406
   macro avg       0.39      0.50      0.44       406
weighted avg       0.59      0.77      0.67       406


📊 Hallucination Detection Report:
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       406

    accuracy                           1.00       406
   macro avg       1.00      

Epoch 2:   0%|          | 0/102 [00:00<?, ?it/s]

Epoch 2/3, Average Loss: 0.3910
Validation Loss: 0.3286

📊 Conflict Detection Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       405
           1       0.00      0.00      0.00         1

    accuracy                           1.00       406
   macro avg       0.50      0.50      0.50       406
weighted avg       1.00      1.00      1.00       406


📊 Clearance Prediction Report:
              precision    recall  f1-score   support

           0       0.61      0.94      0.74        93
           1       0.98      0.82      0.89       313

    accuracy                           0.85       406
   macro avg       0.79      0.88      0.81       406
weighted avg       0.89      0.85      0.86       406


📊 Hallucination Detection Report:
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       406

    accuracy                           1.00       406
   macro avg       1.00      

Epoch 3:   0%|          | 0/102 [00:00<?, ?it/s]

Epoch 3/3, Average Loss: 0.3138
Validation Loss: 0.3092

📊 Conflict Detection Report:
              precision    recall  f1-score   support

           0       1.00      1.00      1.00       405
           1       0.00      0.00      0.00         1

    accuracy                           1.00       406
   macro avg       0.50      0.50      0.50       406
weighted avg       1.00      1.00      1.00       406


📊 Clearance Prediction Report:
              precision    recall  f1-score   support

           0       0.63      0.77      0.70        93
           1       0.93      0.87      0.90       313

    accuracy                           0.84       406
   macro avg       0.78      0.82      0.80       406
weighted avg       0.86      0.84      0.85       406


📊 Hallucination Detection Report:
              precision    recall  f1-score   support

         0.0       1.00      1.00      1.00       406

    accuracy                           1.00       406
   macro avg       1.00      