# ML-Based Air Traffic Control: Conflict Detection & Hallucination Analysis
## Master's Thesis: Simulation and quantification of ML-based hallucination effects on safety margins in en-route control

In [1]:
# BlueSky integration
BLUESKY_AVAILABLE = True
try:
    import bluesky
    BLUESKY_AVAILABLE = True
    print("✅ BlueSky available")
except ImportError:
    !pip -qq install bluesky

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.8/356.8 kB[0m [31m9.4 MB/s[0m eta [36m0:00:00[0m:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

In [2]:
import os
import gc
import json
import warnings
import zipfile
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from collections import defaultdict

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler, LabelEncoder
from sklearn.metrics import confusion_matrix, classification_report, f1_score

# Transformers and NLP
from transformers import DistilBertTokenizer, DistilBertModel, get_linear_schedule_with_warmup
from torch.optim import AdamW

# Suppress warnings
# NOTE: this ignores every warning category for the rest of the session,
# including deprecation and numerical warnings raised by the libraries above.
warnings.filterwarnings('ignore')

# Set random seeds for reproducibility
# SEED is applied to NumPy's global RNG and to PyTorch (CPU, and every CUDA
# device when one is present). Python's built-in `random` module is not seeded here.
SEED = 42
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(SEED)

2025-06-11 09:13:42.330958: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749633222.654170      35 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749633222.752755      35 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


In [3]:
# Configuration for Thesis Research
@dataclass
class ThesisConfig:
    """Central configuration for the ML-hallucination study in ATC.

    A single instance is created right after this class and is read by the
    data loader, the conflict detector, the dataset and the trainer. All
    fields have defaults, so ``ThesisConfig()`` is a complete configuration.
    """
    # Data paths
    # Root folder scanned for SCAT week folders / zip archives.
    scat_data_path: str = "/kaggle/input/swedish-civil-air-traffic-control-scat-dataset"
    output_path: str = "./thesis_results"
    model_save_path: str = "./thesis_models"
    
    # Model configuration for hallucination analysis
    model_name: str = "distilbert-base-uncased"
    max_sequence_length: int = 128
    batch_size: int = 32
    learning_rate: float = 3e-5
    num_epochs: int = 5
    # Passed to the linear warm-up scheduler as `num_warmup_steps`.
    # NOTE(review): if the total number of optimizer steps is below this value
    # the learning rate never reaches `learning_rate` — confirm against dataset size.
    warmup_steps: int = 500
    
    # Conflict detection parameters (real operational thresholds)
    # A pair counts as a conflict only when BOTH separations are below these values.
    conflict_threshold_nm: float = 5.0  # Nautical miles
    conflict_threshold_ft: float = 1000.0  # Feet
    time_horizon_seconds: int = 120  # Look-ahead time
    
    # Hallucination detection parameters
    # NOTE(review): these three values are not read by the code visible in this
    # part of the notebook; presumably consumed by later analysis cells.
    uncertainty_threshold: float = 0.3
    mc_dropout_samples: int = 20
    confidence_threshold: float = 0.8
    
    # Training data envelope parameters
    # Used both to filter track points at load time and to flag scenarios as
    # "outside the training envelope" in the dataset.
    altitude_min: float = 10000  # ft
    altitude_max: float = 50000  # ft
    speed_min: float = 200  # kt
    speed_max: float = 600  # kt
    
    # Device
    # Evaluated once, when the class body is executed (not per instance).
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

# Shared configuration instance read by the cells below
config = ThesisConfig()

# Make sure the result and model folders exist (no error if already present)
Path(config.output_path).mkdir(parents=True, exist_ok=True)
Path(config.model_save_path).mkdir(parents=True, exist_ok=True)

# Startup banner
print(f"🎓 Thesis Configuration loaded. Using device: {config.device}")
print(f"📊 Research Focus: ML Hallucination Quantification in ATC")

🎓 Thesis Configuration loaded. Using device: cpu
📊 Research Focus: ML Hallucination Quantification in ATC


In [4]:
# Enhanced SCAT Data Loader - Real Data Only
class SCATDataLoader:
    """Load recorded SCAT data (flight plans, clearances, tracks) into DataFrames.

    Only recorded data is read; nothing is synthesised. In addition to the
    DataFrames returned by the loaders, the parsed contents of every
    ``airspace.json`` / ``grib_meteo.json`` file encountered are appended to
    ``self.airspace_data`` / ``self.weather_data`` (one entry per file).
    """

    # Per-week auxiliary files that are not individual flight records.
    _AIRSPACE_FILE = "airspace.json"
    _WEATHER_FILE = "grib_meteo.json"

    def __init__(self, config: "ThesisConfig"):
        """Store the configuration and prepare empty auxiliary-data containers."""
        self.config = config
        self.data_path = Path(config.scat_data_path)
        self.logger = self._setup_logger()
        self.airspace_data = []
        self.weather_data = []

    def _setup_logger(self):
        """Return a minimal print-based logger with info/warning/error methods."""
        class SimpleLogger:
            def info(self, msg): print(f"ℹ️ {msg}")
            def warning(self, msg): print(f"⚠️ {msg}")
            def error(self, msg): print(f"❌ {msg}")
        return SimpleLogger()

    def scan_scat_folders(self) -> List[Path]:
        """Return the sorted SCAT sources (folders or .zip files named ``scat*``).

        Returns an empty list (after logging an error) when the configured
        data path does not exist.
        """
        search_path = self.data_path
        if not search_path.exists():
            self.logger.error(f"SCAT data path does not exist: {search_path}")
            return []

        scat_folders = [p for p in search_path.iterdir()
                        if (p.is_dir() and p.name.lower().startswith("scat")) or
                           (p.suffix == ".zip" and p.name.lower().startswith("scat"))]

        self.logger.info(f"Found {len(scat_folders)} SCAT data sources")
        return sorted(scat_folders)

    def load_all_weeks(self, max_weeks: int = 12) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Load up to ``max_weeks`` SCAT sources and concatenate them.

        Args:
            max_weeks: Maximum number of sources to read; a falsy value
                (``0`` / ``None``) means "all of them".

        Returns:
            ``(flights_df, clearances_df, tracks_df)``. Clearances and tracks
            have their ``timestamp`` column parsed, NaT rows dropped, and are
            sorted chronologically. Any of the frames may be empty.
        """
        week_paths = self.scan_scat_folders()
        if max_weeks:
            week_paths = week_paths[:max_weeks]

        flights_all, clearances_all, tracks_all = [], [], []

        for wpath in tqdm(week_paths, desc="📦 Loading weeks"):
            f_df, c_df, t_df = self._load_single_week(wpath)
            flights_all.append(f_df)
            clearances_all.append(c_df)
            tracks_all.append(t_df)

        flights_df = pd.concat(flights_all, ignore_index=True) if flights_all else pd.DataFrame()
        clearances_df = pd.concat(clearances_all, ignore_index=True) if clearances_all else pd.DataFrame()
        tracks_df = pd.concat(tracks_all, ignore_index=True) if tracks_all else pd.DataFrame()

        # Clean and validate timestamps
        tracks_df = self._clean_timestamps(tracks_df)
        clearances_df = self._clean_timestamps(clearances_df)

        self.logger.info(f"✅ Loaded {len(flights_df):,} flights, "
                         f"{len(clearances_df):,} clearances, "
                         f"{len(tracks_df):,} track points from "
                         f"{len(week_paths)} week(s)")
        return flights_df, clearances_df, tracks_df

    def _clean_timestamps(self, df: pd.DataFrame) -> pd.DataFrame:
        """Parse ``timestamp``, drop unparseable rows and sort chronologically.

        The input is returned untouched when it is empty or has no
        ``timestamp`` column; otherwise a new DataFrame is returned.
        """
        if df.empty or 'timestamp' not in df.columns:
            return df

        df = df.copy()
        df['timestamp'] = pd.to_datetime(df['timestamp'], format='mixed', errors='coerce')
        df = df.dropna(subset=['timestamp'])
        df = df.sort_values('timestamp').reset_index(drop=True)
        return df

    def _load_single_week(self, week_path: Path) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:
        """Read every JSON file of one SCAT source (zip archive or folder).

        Auxiliary files are routed by ``_dispatch_json`` and therefore stored
        exactly once per file. Files that cannot be read or parsed are skipped;
        the number of skipped files is reported through the logger instead of
        being dropped silently.

        Returns:
            ``(flights, clearances, tracks)`` DataFrames for this source.
        """
        week_flights, week_clearances, week_tracks = [], [], []
        skipped = 0

        if week_path.suffix == ".zip":
            with zipfile.ZipFile(week_path, "r") as zf:
                json_files = [f for f in zf.namelist() if f.endswith(".json")]
                for jf in tqdm(json_files, desc=f"Unzipping {week_path.name}", leave=False):
                    try:
                        with zf.open(jf) as fp:
                            data = json.load(fp)
                        self._dispatch_json(jf, data, week_flights, week_clearances, week_tracks)
                    except Exception:
                        # Best effort: one bad record must not abort the week.
                        skipped += 1
                # _load_auxiliary_data() is intentionally not called here: the
                # loop above already dispatched airspace/weather files, and a
                # second pass would store each of them twice.

        elif week_path.is_dir():
            json_paths = list(week_path.glob("*.json"))
            for jp in tqdm(json_paths, desc=f"Scanning {week_path.name}", leave=False):
                try:
                    with open(jp, encoding="utf-8") as fp:
                        data = json.load(fp)
                    self._dispatch_json(jp.name, data, week_flights, week_clearances, week_tracks)
                except Exception:
                    skipped += 1

        if skipped:
            self.logger.warning(f"Skipped {skipped} unreadable JSON file(s) in {week_path.name}")

        return (pd.DataFrame(week_flights),
                pd.DataFrame(week_clearances),
                pd.DataFrame(week_tracks))

    def _dispatch_json(self, name, data, flights, clearances, tracks):
        """Route one parsed JSON document to the matching container.

        ``name`` may be a bare file name or an archive member path; only its
        final component is compared, so auxiliary files stored in a
        sub-directory of a zip are recognised as well.
        """
        base_name = Path(name).name
        if base_name == self._AIRSPACE_FILE:
            self.airspace_data.append(data)
        elif base_name == self._WEATHER_FILE:
            self.weather_data.append(data)
        else:
            self._extract_flight_features(data, flights, clearances, tracks)

    def _extract_flight_features(self, flight_data: Dict, flights: list, clearances: list, tracks: list):
        """Append the flight-plan, clearance and track records of one flight.

        The three output lists are extended in place. Documents without an
        ``Id`` / ``id`` key are ignored. Track plots are kept only when they
        carry position (``I062/105``) and flight level (``I062/136``) and pass
        ``_validate_track_data``.
        """
        flight_id = flight_data.get('Id') or flight_data.get('id')
        if not flight_id:
            return

        # Extract flight plan data
        fpl_key = next((k for k in ['Fpl', 'fpl'] if k in flight_data), None)
        fpl = flight_data[fpl_key] if fpl_key else {}

        if 'fpl_base' in fpl:
            for base_data in fpl['fpl_base']:
                flights.append({
                    'flight_id': flight_id,
                    'callsign': base_data.get('Callsign') or base_data.get('callsign'),
                    'aircraft_type': base_data.get('aircraft_type'),
                    'departure': base_data.get('Adep') or base_data.get('adep'),
                    'destination': base_data.get('Ades') or base_data.get('ades'),
                    'timestamp': base_data.get('time_stamp')
                })

        # Extract clearance data
        if 'fpl_clearance' in fpl:
            for clearance in fpl['fpl_clearance']:
                clearances.append({
                    'flight_id': flight_id,
                    'timestamp': clearance.get('time_stamp'),
                    'cleared_flight_level': clearance.get('Cfl') or clearance.get('cfl'),
                    'assigned_speed': clearance.get('assigned_speed_val'),
                    'assigned_heading': clearance.get('assigned_heading_val')
                })

        # Extract track data
        plots_key = next((k for k in ['Plots', 'plots'] if k in flight_data), None)
        if not plots_key:
            return

        for plot in flight_data[plots_key]:
            if 'I062/105' not in plot or 'I062/136' not in plot:
                continue
            try:
                track_record = {
                    'flight_id': flight_id,
                    'timestamp': plot.get('time_of_track'),
                    'latitude': plot['I062/105'].get('lat'),
                    'longitude': plot['I062/105'].get('lon'),
                    # Flight level (hundreds of feet) -> feet
                    'altitude': float(plot['I062/136'].get('measured_flight_level', 0)) * 100
                }

                # Add velocity information if available
                if 'I062/185' in plot:
                    vx = plot['I062/185'].get('vx', 0)
                    vy = plot['I062/185'].get('vy', 0)
                    track_record['ground_speed'] = np.sqrt(vx**2 + vy**2) * 1.94384  # m/s to kt
                    # arctan2(east, north) gives a compass bearing
                    track_record['heading'] = np.degrees(np.arctan2(vx, vy)) % 360

                # Validate data is within operational envelope
                if self._validate_track_data(track_record):
                    tracks.append(track_record)
            except (ValueError, TypeError):
                # Malformed plot (e.g. None where a number is expected): skip it.
                continue

    def _validate_track_data(self, track: Dict) -> bool:
        """Return True when a track point lies inside the configured envelope.

        A missing ``ground_speed`` is treated as 0 kt, so plots without
        velocity information fail the speed check and are rejected.
        Non-numeric values make the comparison raise and yield ``False``.
        """
        try:
            altitude = track.get('altitude', 0)
            speed = track.get('ground_speed', 0)
            lat = track.get('latitude', 0)
            lon = track.get('longitude', 0)

            # Check operational envelope
            if not (self.config.altitude_min <= altitude <= self.config.altitude_max):
                return False
            if not (self.config.speed_min <= speed <= self.config.speed_max):
                return False
            if not (-90 <= lat <= 90 and -180 <= lon <= 180):
                return False

            return True
        except (ValueError, TypeError):
            return False

    def _load_auxiliary_data(self, zip_ref):
        """Load root-level airspace/weather files from an open zip archive.

        Kept for callers that want the auxiliary data only. It is not used by
        ``_load_single_week`` any more, because ``_dispatch_json`` already
        stores these files while the archive is scanned.
        """
        for aux in (self._AIRSPACE_FILE, self._WEATHER_FILE):
            try:
                if aux in zip_ref.namelist():
                    with zip_ref.open(aux) as f:
                        data = json.load(f)
                    target = self.airspace_data if aux == self._AIRSPACE_FILE else self.weather_data
                    target.append(data)
            except Exception as exc:
                # Auxiliary data is optional; report and carry on.
                self.logger.warning(f"Could not load auxiliary file {aux}: {exc}")

In [5]:
# Real Conflict Detection - No Synthetic Data
class RealConflictDetector:
    """Find loss-of-separation events in recorded track data (no synthetic data).

    A pair of aircraft observed at the same timestamp is a conflict when its
    horizontal distance is below ``config.conflict_threshold_nm`` AND its
    vertical distance is below ``config.conflict_threshold_ft``.
    """

    # Fallbacks used when a track carries no usable velocity information.
    DEFAULT_SPEED_KT = 450
    DEFAULT_HEADING_DEG = 90

    def __init__(self, config: "ThesisConfig"):
        """Copy the separation thresholds and look-ahead time from ``config``."""
        self.config = config
        self.conflict_threshold_nm = config.conflict_threshold_nm
        self.conflict_threshold_ft = config.conflict_threshold_ft
        self.time_horizon = config.time_horizon_seconds

    def haversine_distance(self, lat1, lon1, lat2, lon2):
        """Return the great-circle distance in nautical miles.

        Inputs are latitudes/longitudes in degrees (scalars or NumPy arrays).
        """
        R = 3440.065  # Earth radius in nautical miles
        lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
        # Rounding can push `a` marginally outside [0, 1]; clip so that
        # sqrt/arcsin never produce NaN for identical or antipodal points.
        c = 2 * np.arcsin(np.sqrt(np.clip(a, 0.0, 1.0)))
        return R * c

    @staticmethod
    def _value_or_default(row, key, default):
        """Return ``row[key]``, or ``default`` when the key is absent or NaN.

        ``Series.get`` only falls back when the key is missing; a present but
        NaN value would otherwise leak into the scenarios and break the
        integer formatting done by the dataset.
        """
        value = row.get(key, default)
        return default if pd.isna(value) else value

    def _aircraft_at(self, tracks_df: pd.DataFrame, time_sample) -> pd.DataFrame:
        """Return one row per flight for the track points stamped ``time_sample``."""
        current_aircraft = tracks_df[tracks_df['timestamp'] == time_sample]
        return current_aircraft.groupby('flight_id').last().reset_index()

    def _separation(self, ac1, ac2) -> Tuple[float, float]:
        """Return ``(horizontal NM, vertical ft)`` between two track rows."""
        h_dist = self.haversine_distance(
            ac1['latitude'], ac1['longitude'],
            ac2['latitude'], ac2['longitude']
        )
        v_dist = abs(ac1['altitude'] - ac2['altitude'])
        return h_dist, v_dist

    def _build_record(self, time_sample, ac1, ac2, h_dist, v_dist,
                      conflict: bool, clearance_info: Dict) -> Dict:
        """Assemble one scenario row shared by conflict and non-conflict samples."""
        return {
            'timestamp': time_sample,
            'flight_id_1': ac1['flight_id'],
            'flight_id_2': ac2['flight_id'],
            'horizontal_distance': h_dist,
            'vertical_distance': v_dist,
            'conflict': conflict,
            'clearance_issued': clearance_info['issued'],
            'clearance_type': clearance_info['type'],
            'altitude_1': ac1['altitude'],
            'altitude_2': ac2['altitude'],
            'speed_1': self._value_or_default(ac1, 'ground_speed', self.DEFAULT_SPEED_KT),
            'speed_2': self._value_or_default(ac2, 'ground_speed', self.DEFAULT_SPEED_KT),
            'heading_1': self._value_or_default(ac1, 'heading', self.DEFAULT_HEADING_DEG),
            'heading_2': self._value_or_default(ac2, 'heading', self.DEFAULT_HEADING_DEG)
        }

    def detect_real_conflicts(self, tracks_df: pd.DataFrame, clearances_df: pd.DataFrame,
                              max_time_samples: int = 2000) -> pd.DataFrame:
        """Build a scenario table of real conflicts plus matching non-conflicts.

        Args:
            tracks_df: Track points with ``flight_id``, ``timestamp``,
                ``latitude``, ``longitude`` and ``altitude`` columns
                (``ground_speed`` / ``heading`` optional).
            clearances_df: Clearances used to annotate each conflict; may be empty.
            max_time_samples: Upper bound on the number of distinct timestamps
                inspected (sampled without replacement from NumPy's global RNG).

        Returns:
            One row per aircraft pair. Conflicts come first, followed by at
            most the same number of well-separated pairs. Empty when there is
            no usable track data; also empty when no conflict is found,
            because the non-conflict sample is sized to the conflict count.

        Note:
            Aircraft are paired only when their timestamps are exactly equal.
        """
        if tracks_df.empty:
            print("No track data available for conflict detection")
            return pd.DataFrame()

        tracks_df = tracks_df.copy()

        # Ensure timestamps are datetime
        if not pd.api.types.is_datetime64_any_dtype(tracks_df['timestamp']):
            tracks_df['timestamp'] = pd.to_datetime(tracks_df['timestamp'], errors='coerce')

        tracks_df = tracks_df.dropna(subset=['timestamp', 'latitude', 'longitude', 'altitude'])

        if tracks_df.empty:
            print("No valid track data after cleaning")
            return pd.DataFrame()

        # Sample time windows for analysis
        unique_times = tracks_df['timestamp'].unique()
        sample_size = min(len(unique_times), max_time_samples)
        sampled_times = np.random.choice(unique_times, size=sample_size, replace=False)

        print(f"🔍 Analyzing {sample_size} time windows for real conflicts")

        conflicts = []
        for time_sample in tqdm(sampled_times, desc="Detecting real conflicts"):
            aircraft = self._aircraft_at(tracks_df, time_sample)
            if len(aircraft) < 2:
                continue

            # Check all pairs for conflicts
            for i in range(len(aircraft)):
                ac1 = aircraft.iloc[i]
                for j in range(i + 1, len(aircraft)):
                    ac2 = aircraft.iloc[j]
                    h_dist, v_dist = self._separation(ac1, ac2)

                    if h_dist < self.conflict_threshold_nm and v_dist < self.conflict_threshold_ft:
                        clearance_info = self._find_associated_clearances(
                            time_sample, [ac1['flight_id'], ac2['flight_id']], clearances_df
                        )
                        conflicts.append(
                            self._build_record(time_sample, ac1, ac2, h_dist, v_dist,
                                               True, clearance_info)
                        )

        conflicts_df = pd.DataFrame(conflicts)
        print(f"🚨 Found {len(conflicts_df)} real conflicts")

        # Add non-conflict samples from same operational data
        non_conflicts = self._extract_real_non_conflicts(tracks_df, len(conflicts_df))

        # Combine real conflicts and non-conflicts
        return pd.concat([conflicts_df, non_conflicts], ignore_index=True)

    def _find_associated_clearances(self, timestamp: pd.Timestamp, flight_ids: List[str],
                                   clearances_df: pd.DataFrame) -> Dict:
        """Report whether either flight received a clearance within ±5 minutes.

        Returns:
            ``{'issued': bool, 'type': str}`` where ``type`` is ``'none'`` when
            nothing was issued, otherwise the result of
            ``_classify_clearance_type``.
        """
        if clearances_df.empty:
            return {'issued': False, 'type': 'none'}

        # Callers pass numpy.datetime64 values taken from `.unique()`;
        # normalise so the window arithmetic is always done on a Timestamp.
        timestamp = pd.Timestamp(timestamp)

        # Look for clearances within time window
        time_window = pd.Timedelta(minutes=5)
        clearance_mask = (
            (clearances_df['timestamp'] >= timestamp - time_window) &
            (clearances_df['timestamp'] <= timestamp + time_window) &
            (clearances_df['flight_id'].isin(flight_ids))
        )

        relevant_clearances = clearances_df[clearance_mask]

        if len(relevant_clearances) == 0:
            return {'issued': False, 'type': 'none'}

        # Determine clearance type
        clearance_type = self._classify_clearance_type(relevant_clearances)
        return {'issued': True, 'type': clearance_type}

    def _classify_clearance_type(self, clearances: pd.DataFrame) -> str:
        """Name the clearance kind; altitude wins over heading, heading over speed."""
        if clearances['cleared_flight_level'].notna().any():
            return "altitude_change"
        elif clearances['assigned_heading'].notna().any():
            return "heading_change"
        elif clearances['assigned_speed'].notna().any():
            return "speed_change"
        else:
            return "other"

    def _extract_real_non_conflicts(self, tracks_df: pd.DataFrame, num_conflicts: int) -> pd.DataFrame:
        """Collect up to ``num_conflicts`` well-separated aircraft pairs.

        A pair qualifies when it exceeds 1.5x the horizontal OR 1.5x the
        vertical conflict threshold. At most ``2 * num_conflicts`` timestamps
        are inspected, so fewer rows than requested may be returned.
        """
        non_conflicts = []
        no_clearance = {'issued': False, 'type': 'none'}

        # Sample different time windows
        unique_times = tracks_df['timestamp'].unique()
        sample_times = np.random.choice(
            unique_times, size=min(len(unique_times), num_conflicts * 2), replace=False
        )

        for time_sample in sample_times:
            aircraft = self._aircraft_at(tracks_df, time_sample)
            if len(aircraft) < 2:
                continue

            # Find pairs with safe separation
            for i in range(len(aircraft)):
                ac1 = aircraft.iloc[i]
                for j in range(i + 1, len(aircraft)):
                    ac2 = aircraft.iloc[j]
                    h_dist, v_dist = self._separation(ac1, ac2)

                    # Non-conflict: safe separation
                    if h_dist > self.conflict_threshold_nm * 1.5 or v_dist > self.conflict_threshold_ft * 1.5:
                        non_conflicts.append(
                            self._build_record(time_sample, ac1, ac2, h_dist, v_dist,
                                               False, no_clearance)
                        )
                        # Stop as soon as the sample is balanced with the conflicts.
                        if len(non_conflicts) >= num_conflicts:
                            return pd.DataFrame(non_conflicts)

        return pd.DataFrame(non_conflicts)

In [6]:
# Hallucination-Aware Dataset for Thesis
class HallucinationATCDataset(Dataset):
    """Torch dataset turning conflict scenarios into tokenised text samples.

    Each item carries the tokenised scenario description, the conflict label,
    the encoded clearance type and a flag telling whether the scenario lies
    outside the configured training envelope (the target of the
    hallucination head).
    """

    def __init__(self, scenarios_df: pd.DataFrame, tokenizer, max_length: int = 128,
                 clearance_encoder: "Optional[LabelEncoder]" = None,
                 envelope_config: "Optional[ThesisConfig]" = None):
        """Prepare labels and envelope flags on a private copy of ``scenarios_df``.

        Args:
            scenarios_df: Scenario table with the altitude/speed/heading,
                separation, ``conflict`` and ``clearance_type`` columns.
            tokenizer: Callable tokenizer returning ``input_ids`` and
                ``attention_mask`` tensors of shape ``(1, max_length)``.
            max_length: Padded/truncated sequence length.
            clearance_encoder: Optional label encoder. When it is already
                fitted (has ``classes_``) it is reused as is, so that train
                and validation datasets share one label mapping; an unfitted
                or missing encoder is fitted on this dataset. A fitted encoder
                raises on clearance types it has not seen.
            envelope_config: Object providing ``altitude_min/max`` and
                ``speed_min/max``. Defaults to the notebook-level ``config``.
        """
        self.scenarios_df = scenarios_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        # Fall back to the module-level configuration for existing callers.
        self.envelope_config = envelope_config if envelope_config is not None else config

        # Encode clearance types
        clearance_types = self.scenarios_df['clearance_type'].fillna('none')
        if clearance_encoder is not None and hasattr(clearance_encoder, 'classes_'):
            self.clearance_encoder = clearance_encoder
            labels = clearance_encoder.transform(clearance_types)
        else:
            self.clearance_encoder = clearance_encoder if clearance_encoder is not None else LabelEncoder()
            labels = self.clearance_encoder.fit_transform(clearance_types)
        self.scenarios_df['clearance_label'] = labels

        # Add envelope indicators for hallucination analysis
        self.scenarios_df['outside_training_envelope'] = self._detect_envelope_violations()

    def _detect_envelope_violations(self) -> pd.Series:
        """Return a boolean Series: True where any altitude/speed leaves the envelope."""
        env = self.envelope_config
        df = self.scenarios_df
        violations = (
            (df['altitude_1'] < env.altitude_min) |
            (df['altitude_1'] > env.altitude_max) |
            (df['altitude_2'] < env.altitude_min) |
            (df['altitude_2'] > env.altitude_max) |
            (df['speed_1'] < env.speed_min) |
            (df['speed_1'] > env.speed_max) |
            (df['speed_2'] < env.speed_min) |
            (df['speed_2'] > env.speed_max)
        )
        return violations

    def __len__(self):
        """Number of scenarios."""
        return len(self.scenarios_df)

    def __getitem__(self, idx):
        """Return the tensors for scenario ``idx`` as a dict."""
        row = self.scenarios_df.iloc[idx]

        # Create scenario description
        text = self._create_scenario_text(row)

        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )

        return {
            # squeeze(0) drops only the batch axis added by return_tensors='pt';
            # a bare squeeze() would also drop the sequence axis when it has length 1.
            'input_ids': encoding['input_ids'].squeeze(0),
            'attention_mask': encoding['attention_mask'].squeeze(0),
            'conflict_label': torch.tensor(1 if row['conflict'] else 0, dtype=torch.long),
            'clearance_label': torch.tensor(int(row['clearance_label']), dtype=torch.long),
            'envelope_violation': torch.tensor(float(bool(row['outside_training_envelope'])), dtype=torch.float)
        }

    def _create_scenario_text(self, row):
        """Render one scenario row as a short English description.

        Altitudes are shown as flight levels; heading and speed are truncated
        to integers, so these fields must not be NaN.
        """
        text = (
            f"Aircraft A at FL{int(row['altitude_1']/100):03d} "
            f"heading {int(row['heading_1']):03d}° "
            f"speed {int(row['speed_1']):03d} kt; "
            f"Aircraft B at FL{int(row['altitude_2']/100):03d} "
            f"heading {int(row['heading_2']):03d}° "
            f"speed {int(row['speed_2']):03d} kt; "
            f"horizontal separation {row['horizontal_distance']:.1f} NM; "
            f"vertical separation {row['vertical_distance']:.0f} ft."
        )
        return text

In [7]:
# Hallucination-Aware Model Architecture
class HallucinationAwareATCModel(nn.Module):
    """DistilBERT encoder with conflict, resolution and hallucination heads.

    All three heads read the hidden state of the first token. The conflict and
    resolution heads emit logits; the hallucination head ends in a sigmoid and
    therefore emits a score in (0, 1).
    """

    def __init__(self, model_name: str, num_clearance_types: int, dropout_rate: float = 0.1):
        """Load the pretrained encoder and build the three task heads.

        Args:
            model_name: Hugging Face identifier of the DistilBERT checkpoint.
            num_clearance_types: Number of classes of the resolution head.
            dropout_rate: Dropout probability used inside the heads and for
                the optional dropout on the pooled representation.
        """
        super().__init__()

        self.bert = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size

        # Main task heads
        self.conflict_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, 2)
        )

        self.resolution_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_clearance_types)
        )

        # Hallucination detection head
        self.hallucination_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, 1),
            nn.Sigmoid()
        )

        self.dropout_rate = dropout_rate

    def forward(self, input_ids, attention_mask, enable_dropout=False):
        """Run the encoder and all heads.

        Args:
            input_ids: Token ids, batch first.
            attention_mask: Attention mask matching ``input_ids``.
            enable_dropout: When True, dropout is applied to the pooled
                representation regardless of the module's train/eval mode.

        Returns:
            ``(conflict_logits, resolution_logits, hallucination_score)``;
            the score has a trailing dimension of size 1.
        """
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        # Hidden state of the first token serves as the sequence representation.
        pooled_output = outputs.last_hidden_state[:, 0]

        if enable_dropout:
            pooled_output = F.dropout(pooled_output, p=self.dropout_rate, training=True)

        conflict_logits = self.conflict_head(pooled_output)
        resolution_logits = self.resolution_head(pooled_output)
        hallucination_score = self.hallucination_head(pooled_output)

        return conflict_logits, resolution_logits, hallucination_score

    def get_uncertainty_estimates(self, input_ids, attention_mask, n_samples=20):
        """Estimate predictive uncertainty with Monte Carlo dropout.

        Runs ``n_samples`` stochastic forward passes without gradients and
        summarises them. Dropout is switched on for the duration of the call
        and the module's previous train/eval mode is restored afterwards, so
        calling this on an evaluating model does not leave dropout active.

        Returns:
            Dict with ``conflict_mean``, ``conflict_std``, ``conflict_entropy``,
            ``resolution_mean``, ``resolution_std``, ``hallucination_mean`` and
            ``hallucination_std``. The ``*_std`` entries are NaN when
            ``n_samples`` is 1.
        """
        was_training = self.training
        self.train()  # Enable dropout

        conflict_samples = []
        resolution_samples = []
        hallucination_samples = []

        try:
            with torch.no_grad():
                for _ in range(n_samples):
                    conflict_logits, resolution_logits, hallucination_score = self.forward(
                        input_ids, attention_mask, enable_dropout=True
                    )
                    conflict_samples.append(F.softmax(conflict_logits, dim=-1))
                    resolution_samples.append(F.softmax(resolution_logits, dim=-1))
                    hallucination_samples.append(hallucination_score)
        finally:
            # Put the module back into the mode the caller had selected.
            self.train(was_training)

        # Calculate statistics
        conflict_probs = torch.stack(conflict_samples)
        resolution_probs = torch.stack(resolution_samples)
        hallucination_probs = torch.stack(hallucination_samples)

        conflict_mean = conflict_probs.mean(dim=0)

        results = {
            'conflict_mean': conflict_mean,
            'conflict_std': conflict_probs.std(dim=0),
            # Entropy of the mean prediction; epsilon guards log(0).
            'conflict_entropy': -torch.sum(conflict_mean * torch.log(conflict_mean + 1e-8), dim=-1),
            'resolution_mean': resolution_probs.mean(dim=0),
            'resolution_std': resolution_probs.std(dim=0),
            'hallucination_mean': hallucination_probs.mean(dim=0),
            'hallucination_std': hallucination_probs.std(dim=0)
        }

        return results

In [8]:
# Training Pipeline with Hallucination Loss
class HallucinationAwareTrainer:
    """Multi-task training loop: conflict, clearance and hallucination heads.

    The total loss is a weighted sum of two cross-entropy terms (conflict,
    clearance type) and a binary cross-entropy term that trains the
    hallucination head to predict the envelope-violation flag.
    """

    def __init__(self, model, train_loader, val_loader, config):
        """Move the model to ``config.device`` and set up optimizer, schedule and losses.

        Args:
            model: Module returning ``(conflict_logits, resolution_logits,
                hallucination_scores)`` for ``(input_ids, attention_mask)``.
            train_loader: Iterable of training batches (dicts of tensors).
            val_loader: Iterable of validation batches.
            config: Object providing ``device``, ``learning_rate``,
                ``num_epochs`` and ``warmup_steps``.
        """
        self.model = model.to(config.device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.config = config

        self.optimizer = AdamW(self.model.parameters(), lr=config.learning_rate, weight_decay=1e-5)
        total_steps = len(train_loader) * config.num_epochs
        # NOTE(review): when total_steps < config.warmup_steps the schedule
        # never leaves warm-up and the peak learning rate is not reached.
        self.scheduler = get_linear_schedule_with_warmup(
            self.optimizer, num_warmup_steps=config.warmup_steps, num_training_steps=total_steps
        )

        # Loss functions
        self.conflict_loss_fn = nn.CrossEntropyLoss()
        self.resolution_loss_fn = nn.CrossEntropyLoss()
        self.hallucination_loss_fn = nn.BCELoss()

        # Loss weights
        self.conflict_weight = 1.0
        self.resolution_weight = 1.0
        self.hallucination_weight = 0.5

    def _run_batch(self, batch):
        """Forward one batch and compute the weighted multi-task loss.

        Returns a dict with the combined ``loss``, the raw head outputs and
        the target tensors (all on ``config.device``).
        """
        device = self.config.device
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        conflict_true = batch['conflict_label'].to(device)
        clearance_true = batch['clearance_label'].to(device)
        envelope_true = batch['envelope_violation'].to(device)

        conflict_logits, resolution_logits, hallucination_scores = self.model(
            input_ids, attention_mask
        )

        # Drop only the trailing size-1 axis. A bare squeeze() would also drop
        # the batch axis for a batch of one sample, which makes BCELoss reject
        # the shapes and yields a 0-d array that cannot be iterated.
        hallucination_probs = hallucination_scores.squeeze(-1)

        loss = (
            self.conflict_weight * self.conflict_loss_fn(conflict_logits, conflict_true) +
            self.resolution_weight * self.resolution_loss_fn(resolution_logits, clearance_true) +
            self.hallucination_weight * self.hallucination_loss_fn(hallucination_probs, envelope_true)
        )

        return {
            'loss': loss,
            'conflict_logits': conflict_logits,
            'resolution_logits': resolution_logits,
            'hallucination_probs': hallucination_probs,
            'conflict_true': conflict_true,
            'clearance_true': clearance_true,
            'envelope_true': envelope_true,
        }

    def train(self):
        """Train for ``config.num_epochs`` epochs, validating after each one."""
        for epoch in range(self.config.num_epochs):
            self.model.train()
            total_loss = 0.0
            n_batches = 0

            for batch in tqdm(self.train_loader, desc=f"Epoch {epoch+1}"):
                result = self._run_batch(batch)

                # Backward pass
                result['loss'].backward()
                torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
                self.optimizer.step()
                self.scheduler.step()
                self.optimizer.zero_grad()

                total_loss += result['loss'].item()
                n_batches += 1

            # max(..., 1) keeps an empty loader from raising ZeroDivisionError
            avg_loss = total_loss / max(n_batches, 1)
            print(f"Epoch {epoch+1}/{self.config.num_epochs}, Average Loss: {avg_loss:.4f}")

            # Validation
            self.validate()

    def validate(self):
        """Evaluate on the validation loader and print per-task reports.

        The reported loss uses the same task weights as training so that the
        two numbers are comparable.

        Returns:
            The average validation loss as a float, or ``None`` when the
            validation loader yields no batches.
        """
        self.model.eval()
        total_loss = 0.0
        n_batches = 0

        conflict_preds, conflict_labels = [], []
        clearance_preds, clearance_labels = [], []
        hallucination_preds, hallucination_labels = [], []

        with torch.no_grad():
            for batch in self.val_loader:
                result = self._run_batch(batch)
                total_loss += result['loss'].item()
                n_batches += 1

                # Collect predictions
                conflict_preds.extend(torch.argmax(result['conflict_logits'], dim=1).cpu().numpy())
                conflict_labels.extend(result['conflict_true'].cpu().numpy())
                clearance_preds.extend(torch.argmax(result['resolution_logits'], dim=1).cpu().numpy())
                clearance_labels.extend(result['clearance_true'].cpu().numpy())
                hallucination_preds.extend((result['hallucination_probs'] > 0.5).cpu().numpy())
                hallucination_labels.extend(result['envelope_true'].cpu().numpy())

        if n_batches == 0:
            print("Validation skipped: validation loader is empty")
            return None

        avg_loss = total_loss / n_batches
        print(f"Validation Loss: {avg_loss:.4f}")
        print("\n📊 Conflict Detection Report:")
        print(classification_report(conflict_labels, conflict_preds))
        print("\n📊 Clearance Prediction Report:")
        print(classification_report(clearance_labels, clearance_preds))
        print("\n📊 Hallucination Detection Report:")
        print(classification_report(hallucination_labels, hallucination_preds))
        return avg_loss

In [9]:
# Fixed BlueSky Integration and Visualization
class EnhancedBlueSkySimulator:
    """Track-table-backed simulator used to query separation between flights.

    The current set of track points is kept as a DataFrame under
    ``self.state['tracks']``. Distance computations are delegated to a
    ``RealConflictDetector`` built from the module-level ``config`` object,
    which also supplies ``conflict_threshold_nm`` / ``conflict_threshold_ft``.
    """

    def __init__(self):
        self.state = {}
        self.detector = RealConflictDetector(config)
        print("✈️ Enhanced BlueSky Simulator initialized")

    def update_state(self, tracks_df: pd.DataFrame):
        """Replace the simulator's track table with ``tracks_df``.

        For a non-empty frame a copy is taken, the ``timestamp`` column is
        converted to datetime when it is not already (unparseable values
        become NaT), and rows without a valid timestamp are dropped. An empty
        frame is stored unchanged.
        """
        if not tracks_df.empty:
            # Work on a copy so the caller's frame is never modified.
            tracks_df = tracks_df.copy()
            if not pd.api.types.is_datetime64_any_dtype(tracks_df['timestamp']):
                tracks_df['timestamp'] = pd.to_datetime(tracks_df['timestamp'], errors='coerce')
            tracks_df = tracks_df.dropna(subset=['timestamp'])

        self.state['tracks'] = tracks_df
        print(f"📡 Simulator updated with {len(tracks_df)} track points")

    @staticmethod
    def _latest_point(flight_tracks: pd.DataFrame) -> pd.Series:
        """Return the row of ``flight_tracks`` carrying the greatest timestamp.

        Falls back to the last row in storage order when there is no
        ``timestamp`` column. The sort is stable and puts missing timestamps
        first, so among equal timestamps the last stored row wins.
        """
        if 'timestamp' not in flight_tracks.columns:
            return flight_tracks.iloc[-1]
        ordered = flight_tracks.sort_values('timestamp', kind='mergesort', na_position='first')
        return ordered.iloc[-1]

    def get_separation(self, flight_id_1: str, flight_id_2: str) -> Tuple[float, float]:
        """Return ``(horizontal, vertical)`` separation of two flights.

        The most recent track point of each flight (by ``timestamp``) is used.
        Horizontal distance comes from ``self.detector.haversine_distance``;
        vertical distance is the absolute difference of the ``altitude``
        values (a missing ``altitude`` field counts as 0).

        Returns ``(0.0, 0.0)`` when no tracks are loaded or either flight has
        no track points; callers cannot distinguish that from a true zero
        separation.

        NOTE(review): the two "latest" points are not checked to be close in
        time, so the result is only meaningful if both flights were observed
        at about the same moment - confirm against callers.
        """
        if 'tracks' not in self.state or self.state['tracks'].empty:
            return 0.0, 0.0

        tracks = self.state['tracks']
        ac1_tracks = tracks[tracks['flight_id'] == flight_id_1]
        ac2_tracks = tracks[tracks['flight_id'] == flight_id_2]

        if ac1_tracks.empty or ac2_tracks.empty:
            return 0.0, 0.0

        # The table is not guaranteed to be time-ordered, so pick the latest
        # point explicitly instead of relying on row position.
        ac1 = self._latest_point(ac1_tracks)
        ac2 = self._latest_point(ac2_tracks)

        h_dist = self.detector.haversine_distance(
            ac1['latitude'], ac1['longitude'],
            ac2['latitude'], ac2['longitude']
        )
        v_dist = abs(ac1.get('altitude', 0) - ac2.get('altitude', 0))

        return h_dist, v_dist

    def validate_clearance(self, flight_id_1: str, flight_id_2: str, clearance_type: str) -> bool:
        """Report whether the two flights are currently safely separated.

        The pair is considered safe when the horizontal distance exceeds
        ``config.conflict_threshold_nm`` or the vertical distance exceeds
        ``config.conflict_threshold_ft``. ``clearance_type`` is accepted for
        interface compatibility but does not influence the result. When
        ``get_separation`` has no data it yields ``(0.0, 0.0)``, so whether
        the pair is then reported safe depends only on the thresholds.
        """
        h_dist, v_dist = self.get_separation(flight_id_1, flight_id_2)
        safe = h_dist > config.conflict_threshold_nm or v_dist > config.conflict_threshold_ft
        print(f"🔍 Clearance validation: H={h_dist:.1f}NM, V={v_dist:.0f}ft, Safe={safe}")
        return safe

def plot_trajectories(simulator):
    """Save a longitude/latitude plot of up to ten flights' trajectories.

    Reads the track table from ``simulator.state['tracks']``; if it is absent
    or empty a warning is printed and nothing is drawn. Each of the first ten
    distinct ``flight_id`` values with more than one track point is drawn as
    a line in timestamp order. The figure is written to
    ``real_trajectories.png`` under ``config.output_path`` (created if
    missing) and is always closed, even if plotting or saving fails.
    """
    if 'tracks' not in simulator.state or simulator.state['tracks'].empty:
        print("⚠️ No tracks available for plotting")
        return

    tracks_df = simulator.state['tracks']
    flight_ids = tracks_df['flight_id'].unique()[:10]  # Limit for readability

    fig, ax = plt.subplots(figsize=(14, 10))
    try:
        drawn = 0
        for flight_id in flight_ids:
            flight_tracks = tracks_df[tracks_df['flight_id'] == flight_id].sort_values('timestamp')
            # A single point cannot form a trajectory line; skip it.
            if len(flight_tracks) > 1:
                ax.plot(flight_tracks['longitude'], flight_tracks['latitude'],
                        marker='o', markersize=2, alpha=0.7, label=f'Flight {flight_id}')
                drawn += 1

        ax.set_xlabel('Longitude (°)')
        ax.set_ylabel('Latitude (°)')
        ax.set_title('Real Aircraft Trajectories from SCAT Data')
        if drawn:
            # Only request a legend when at least one labelled line exists.
            ax.legend(bbox_to_anchor=(1.05, 1), loc='upper left')
        ax.grid(True, alpha=0.3)
        fig.tight_layout()

        os.makedirs(config.output_path or '.', exist_ok=True)
        fig.savefig(os.path.join(config.output_path, 'real_trajectories.png'),
                    dpi=300, bbox_inches='tight')
    finally:
        # Release the figure regardless of success to avoid leaking memory.
        plt.close(fig)
    print("📈 Trajectory plot saved")

def plot_separation_over_time(simulator, pair: Tuple[str, str],
                              max_time_gap: Optional[pd.Timedelta] = None):
    """Plot horizontal and vertical separation of a flight pair over time.

    Args:
        simulator: object exposing ``state['tracks']`` (a DataFrame with
            ``flight_id``, ``timestamp``, ``latitude``, ``longitude`` and
            ``altitude`` columns).
        pair: the two ``flight_id`` values to compare.
        max_time_gap: optional upper bound on the time difference between a
            sample of the first flight and the sample of the second flight it
            is paired with. ``None`` (default) pairs every sample with the
            nearest one regardless of how far apart in time they are; with a
            bound, samples that have no partner inside it are left out.

    The function prints a warning and returns without plotting when there are
    no tracks, when either flight has no (validly timestamped) points, or
    when no samples can be aligned. Any exception raised while aligning,
    computing or plotting is reported via ``print`` rather than propagated.
    The figure is saved as ``separation_<id1>_<id2>.png`` under
    ``config.output_path`` and is always closed.
    """
    if 'tracks' not in simulator.state or simulator.state['tracks'].empty:
        print("⚠️ No tracks available for separation plotting")
        return

    tracks_df = simulator.state['tracks']
    ac1_tracks = tracks_df[tracks_df['flight_id'] == pair[0]].copy()
    ac2_tracks = tracks_df[tracks_df['flight_id'] == pair[1]].copy()

    if ac1_tracks.empty or ac2_tracks.empty:
        print(f"⚠️ No data for aircraft pair {pair}")
        return

    # merge_asof needs sorted datetime keys without missing values.
    ac1_tracks['timestamp'] = pd.to_datetime(ac1_tracks['timestamp'], errors='coerce')
    ac2_tracks['timestamp'] = pd.to_datetime(ac2_tracks['timestamp'], errors='coerce')

    ac1_tracks = ac1_tracks.dropna(subset=['timestamp']).sort_values('timestamp')
    ac2_tracks = ac2_tracks.dropna(subset=['timestamp']).sort_values('timestamp')

    if ac1_tracks.empty or ac2_tracks.empty:
        print(f"⚠️ No valid timestamps for pair {pair}")
        return

    fig = None
    try:
        columns = ['timestamp', 'latitude', 'longitude', 'altitude']
        merged = pd.merge_asof(
            ac1_tracks[columns],
            ac2_tracks[columns],
            on='timestamp',
            suffixes=('_1', '_2'),
            direction='nearest',
            tolerance=max_time_gap
        )
        if max_time_gap is not None:
            # Rows without a partner inside the tolerance carry NaN on the
            # right-hand side; they cannot yield a separation value.
            merged = merged.dropna(subset=['latitude_2', 'longitude_2'])

        if merged.empty:
            print(f"⚠️ No aligned data for pair {pair}")
            return

        # Calculate separations
        detector = RealConflictDetector(config)
        merged['h_dist'] = merged.apply(
            lambda row: detector.haversine_distance(
                row['latitude_1'], row['longitude_1'],
                row['latitude_2'], row['longitude_2']
            ), axis=1
        )
        merged['v_dist'] = (merged['altitude_1'] - merged['altitude_2']).abs()

        # Plot: horizontal separation on top, vertical separation below.
        fig, (ax_h, ax_v) = plt.subplots(2, 1, figsize=(14, 8))

        ax_h.plot(merged['timestamp'], merged['h_dist'], 'b-', linewidth=2, label='Horizontal Distance')
        ax_h.axhline(y=config.conflict_threshold_nm, color='r', linestyle='--',
                     label=f'Conflict Threshold ({config.conflict_threshold_nm} NM)')
        ax_h.set_ylabel('Distance (NM)')
        ax_h.set_title(f'Separation Analysis: {pair[0]} vs {pair[1]}')
        ax_h.legend()
        ax_h.grid(True, alpha=0.3)

        ax_v.plot(merged['timestamp'], merged['v_dist']/1000, 'g-', linewidth=2, label='Vertical Distance')
        ax_v.axhline(y=config.conflict_threshold_ft/1000, color='r', linestyle='--',
                     label=f'Conflict Threshold ({config.conflict_threshold_ft/1000:.1f} kft)')
        ax_v.set_xlabel('Time')
        ax_v.set_ylabel('Distance (kft)')
        ax_v.legend()
        ax_v.grid(True, alpha=0.3)

        fig.tight_layout()
        os.makedirs(config.output_path or '.', exist_ok=True)
        fig.savefig(os.path.join(config.output_path, f'separation_{pair[0]}_{pair[1]}.png'),
                    dpi=300, bbox_inches='tight')
        print(f"📊 Separation plot saved for {pair}")

    except Exception as e:
        # Best-effort visualisation: report the problem and keep the
        # surrounding pipeline running.
        print(f"❌ Error plotting separation for {pair}: {e}")
    finally:
        # Close the figure even when plotting or saving failed.
        if fig is not None:
            plt.close(fig)

In [None]:
# Main Research Pipeline
def main():
    """Run the end-to-end thesis pipeline on SCAT data.

    Steps, in order: load up to three weeks of SCAT data, derive conflict
    scenarios from the tracks and clearances, build a tokenised dataset with
    an 80/20 random train/validation split, train the hallucination-aware
    model, save its weights under ``config.model_save_path``, then load the
    tracks into the simulator and - if at least one conflict scenario exists
    - validate the first conflict's clearance and write the trajectory and
    separation plots.

    Returns ``None``. The function stops early (after printing the reason)
    when no track data is loaded, when no scenarios are produced, or when the
    dataset is too small to give both splits at least one sample.
    """
    print("🎓 Starting ML Hallucination Research Pipeline")
    print("=" * 60)

    # Load real SCAT data (the flights table is not used further here).
    print("📦 Loading SCAT Data...")
    data_loader = SCATDataLoader(config)
    _flights_df, clearances_df, tracks_df = data_loader.load_all_weeks(max_weeks=3)

    if tracks_df.empty:
        print("❌ No track data loaded. Cannot proceed with analysis.")
        return

    # Detect real conflicts only
    print("🔍 Detecting Real Conflicts...")
    detector = RealConflictDetector(config)
    scenarios_df = detector.detect_real_conflicts(tracks_df, clearances_df)

    if scenarios_df.empty:
        print("❌ No scenarios generated. Cannot proceed with training.")
        return

    print(f"✅ Generated {len(scenarios_df)} real scenarios ({scenarios_df['conflict'].sum()} conflicts)")

    # Prepare ML pipeline
    print("🤖 Initializing ML Pipeline...")
    tokenizer = DistilBertTokenizer.from_pretrained(config.model_name)

    # Create dataset
    dataset = HallucinationATCDataset(scenarios_df, tokenizer, config.max_sequence_length)

    # Split data 80/20; both parts must be non-empty, otherwise the shuffled
    # training loader cannot be built and validation would have no batches.
    train_size = int(0.8 * len(dataset))
    val_size = len(dataset) - train_size
    if train_size == 0 or val_size == 0:
        print(f"❌ Dataset too small to split ({len(dataset)} samples). Cannot proceed with training.")
        return
    train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])

    # Create data loaders
    train_loader = DataLoader(train_dataset, batch_size=config.batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=config.batch_size)

    # Initialize model
    num_clearance_types = scenarios_df['clearance_type'].nunique()
    model = HallucinationAwareATCModel(config.model_name, num_clearance_types)

    # Train model
    print("🎯 Training Hallucination-Aware Model...")
    trainer = HallucinationAwareTrainer(model, train_loader, val_loader, config)
    trainer.train()

    # Save model (make sure the target directory exists first)
    model_path = os.path.join(config.model_save_path, 'hallucination_aware_atc_model.pt')
    os.makedirs(os.path.dirname(model_path) or '.', exist_ok=True)
    torch.save(model.state_dict(), model_path)
    print(f"💾 Model saved to {model_path}")

    # Initialize simulator and test
    print("✈️ Testing with BlueSky Simulator...")
    simulator = EnhancedBlueSkySimulator()
    simulator.update_state(tracks_df)

    # Test validation on real conflicts. The flag column is cast to bool so
    # that 0/1 integer labels are used as a row mask, not as column labels.
    conflict_scenarios = scenarios_df[scenarios_df['conflict'].astype(bool)]
    if not conflict_scenarios.empty:
        sample_conflict = conflict_scenarios.iloc[0]
        flight_pair = (sample_conflict['flight_id_1'], sample_conflict['flight_id_2'])

        clearance_valid = simulator.validate_clearance(
            flight_pair[0], flight_pair[1], sample_conflict['clearance_type']
        )
        print(f"✅ Clearance validation result: {clearance_valid}")

        # Generate visualizations
        print("📊 Generating Visualizations...")
        plot_trajectories(simulator)
        plot_separation_over_time(simulator, flight_pair)

    print("🎓 Research pipeline completed successfully!")
    print(f"📁 Results saved to: {config.output_path}")

# Entry point: run the whole pipeline when this code executes as the main
# module (the captured output below shows it also fires in the notebook).
if __name__ == "__main__":
    main()

🎓 Starting ML Hallucination Research Pipeline
📦 Loading SCAT Data...
ℹ️ Found 13 SCAT data sources


📦 Loading weeks:   0%|          | 0/3 [00:00<?, ?it/s]

Scanning scat20161015_20161021:   0%|          | 0/13140 [00:00<?, ?it/s]

Scanning scat20161112_20161118:   0%|          | 0/12250 [00:00<?, ?it/s]