In [1]:
# Enhanced ML Model Testing Framework for ATC Conflict Detection and Analysis
# Focus on identifying when and why the model fails to detect/resolve conflicts
BLUESKY_AVAILABLE = True
try:
    import bluesky
    BLUESKY_AVAILABLE = True
    print("✅ BlueSky available")
except ImportError:
    !pip -qq install bluesky
import os
import gc
import json
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from collections import defaultdict
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    confusion_matrix, classification_report, f1_score, 
    precision_recall_curve, roc_curve, auc
)
from sklearn.calibration import calibration_curve 
from sklearn.preprocessing import StandardScaler, LabelEncoder
from transformers import DistilBertTokenizer, DistilBertModel

# Suppress warnings
warnings.filterwarnings('ignore')

# Enhanced configuration with proper conflict parameters
@dataclass
class EnhancedTestConfig:
    """Enhanced testing configuration with realistic conflict parameters"""
    model_path: str = "/kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt"
    model_name: str = "distilbert-base-uncased"
    max_sequence_length: int = 128
    batch_size: int = 8
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Realistic ATC conflict thresholds (ICAO standards)
    horizontal_separation_min: float = 5.0  # NM
    vertical_separation_min: float = 1000.0  # ft
    
    # Conflict severity levels
    imminent_conflict_horizontal: float = 2.0  # NM
    imminent_conflict_vertical: float = 500.0  # ft
    
    # Training envelope bounds (expanded for better testing)
    altitude_min: float = 5000   # ft (surface to space)
    altitude_max: float = 60000  # ft
    speed_min: float = 100       # kt (very slow to supersonic)
    speed_max: float = 800       # kt
    
    # Test parameters
    mc_dropout_samples: int = 50  # Increased for better uncertainty estimates
    confidence_threshold: float = 0.7
    uncertainty_threshold: float = 0.2
    
    # Scenario generation parameters
    n_normal_scenarios: int = 200
    n_conflict_scenarios: int = 150
    n_edge_cases: int = 100
    n_stress_tests: int = 50
    
    output_path: str = "./enhanced_model_analysis"

config = EnhancedTestConfig()
os.makedirs(config.output_path, exist_ok=True)

print(f"🔧 Enhanced ATC Model Testing Configuration")
print(f"📁 Model path: {config.model_path}")
print(f"🖥️  Device: {config.device}")
print(f"⚠️  Conflict thresholds: {config.horizontal_separation_min}NM / {config.vertical_separation_min}ft")

# Enhanced model architecture (same as before but with improved comments)
class HallucinationAwareATCModel(nn.Module):
    """Enhanced ATC model with hallucination detection"""
    
    def __init__(self, model_name: str, num_clearance_types: int = 5, dropout_rate: float = 0.1):
        super().__init__()
        
        self.bert = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        
        # Conflict detection head (binary classification: safe/conflict)
        self.conflict_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, 2)  # [safe, conflict]
        )
        
        # Resolution recommendation head
        self.resolution_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_clearance_types)
        )
        
        # Hallucination detection head
        self.hallucination_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, 1)  # Sigmoid output for binary classification
        )
        
        self.dropout_rate = dropout_rate
    
    def forward(self, input_ids, attention_mask, enable_dropout=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]
        
        if enable_dropout:
            pooled_output = F.dropout(pooled_output, p=self.dropout_rate, training=True)
        
        conflict_logits = self.conflict_head(pooled_output)
        resolution_logits = self.resolution_head(pooled_output)
        hallucination_logits = self.hallucination_head(pooled_output)
        
        return conflict_logits, resolution_logits, hallucination_logits
    
    def get_uncertainty_estimates(self, input_ids, attention_mask, n_samples=50):
        """Enhanced uncertainty estimation using Monte Carlo dropout"""
        self.train()  # Enable dropout
        
        conflict_samples = []
        resolution_samples = []
        hallucination_samples = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                conflict_logits, resolution_logits, hallucination_logits = self.forward(
                    input_ids, attention_mask, enable_dropout=True
                )
                conflict_samples.append(F.softmax(conflict_logits, dim=-1))
                resolution_samples.append(F.softmax(resolution_logits, dim=-1))
                hallucination_samples.append(torch.sigmoid(hallucination_logits))
        
        # Calculate comprehensive uncertainty metrics
        conflict_probs = torch.stack(conflict_samples)
        resolution_probs = torch.stack(resolution_samples)
        hallucination_probs = torch.stack(hallucination_samples)
        
        # Epistemic uncertainty (model uncertainty)
        conflict_mean = conflict_probs.mean(dim=0)
        conflict_var = conflict_probs.var(dim=0)
        
        # Mutual information (uncertainty about the model)
        conflict_entropy = -torch.sum(conflict_mean * torch.log(conflict_mean + 1e-8), dim=-1)
        
        results = {
            'conflict_mean': conflict_mean,
            'conflict_std': conflict_probs.std(dim=0),
            'conflict_variance': conflict_var,
            'conflict_entropy': conflict_entropy,
            'epistemic_uncertainty': conflict_var.sum(dim=-1),
            'resolution_mean': resolution_probs.mean(dim=0),
            'resolution_std': resolution_probs.std(dim=0),
            'hallucination_mean': hallucination_probs.mean(dim=0),
            'hallucination_std': hallucination_probs.std(dim=0),
            'total_uncertainty': conflict_entropy + conflict_var.sum(dim=-1)
        }
        
        return results

# Enhanced conflict scenario generator with realistic physics
class RealisticATCScenarioGenerator:
    """Generate realistic ATC scenarios with proper conflict geometry"""
    
    def __init__(self, config: EnhancedTestConfig):
        self.config = config
        self.scenario_log = []
        
    def calculate_conflict_geometry(self, ac1_params: Dict, ac2_params: Dict) -> Dict:
        """Calculate realistic conflict geometry between two aircraft"""
        
        # Extract parameters
        lat1, lon1, alt1, hdg1, spd1 = ac1_params['lat'], ac1_params['lon'], ac1_params['alt'], ac1_params['hdg'], ac1_params['spd']
        lat2, lon2, alt2, hdg2, spd2 = ac2_params['lat'], ac2_params['lon'], ac2_params['alt'], ac2_params['hdg'], ac2_params['spd']
        
        # Current separation
        h_dist = self._haversine_distance(lat1, lon1, lat2, lon2)
        v_dist = abs(alt1 - alt2)
        
        # Velocity vectors (simplified to 2D)
        v1_x = spd1 * np.sin(np.radians(hdg1))  # East component
        v1_y = spd1 * np.cos(np.radians(hdg1))  # North component
        v2_x = spd2 * np.sin(np.radians(hdg2))
        v2_y = spd2 * np.cos(np.radians(hdg2))
        
        # Relative velocity
        rel_v_x = v1_x - v2_x
        rel_v_y = v1_y - v2_y
        rel_speed = np.sqrt(rel_v_x**2 + rel_v_y**2)
        
        # Time to closest point of approach (CPA)
        if rel_speed > 0.1:  # Avoid division by zero
            # Relative position
            rel_x = (lat1 - lat2) * 60  # Convert to NM approximately
            rel_y = (lon1 - lon2) * 60
            
            # Time to CPA
            t_cpa = -(rel_x * rel_v_x + rel_y * rel_v_y) / (rel_v_x**2 + rel_v_y**2)
            t_cpa = max(0, t_cpa)  # Can't have negative time
            
            # Distance at CPA
            cpa_x = rel_x + rel_v_x * t_cpa
            cpa_y = rel_y + rel_v_y * t_cpa
            cpa_distance = np.sqrt(cpa_x**2 + cpa_y**2)
        else:
            t_cpa = 0
            cpa_distance = h_dist
        
        # Determine conflict severity
        is_conflict = (h_dist < self.config.horizontal_separation_min and 
                      v_dist < self.config.vertical_separation_min)
        
        is_imminent = (cpa_distance < self.config.imminent_conflict_horizontal and 
                      v_dist < self.config.imminent_conflict_vertical and 
                      t_cpa < 5)  # Within 5 minutes
        
        # Conflict urgency level
        if is_imminent:
            urgency = "IMMINENT"
        elif is_conflict:
            urgency = "CURRENT"
        elif cpa_distance < self.config.horizontal_separation_min and t_cpa < 10:
            urgency = "PREDICTED"
        else:
            urgency = "SAFE"
        
        return {
            'current_h_separation': h_dist,
            'current_v_separation': v_dist,
            'time_to_cpa': t_cpa,
            'cpa_distance': cpa_distance,
            'relative_speed': rel_speed,
            'is_conflict': is_conflict,
            'is_imminent': is_imminent,
            'urgency_level': urgency,
            'conflict_probability': self._calculate_conflict_probability(h_dist, v_dist, cpa_distance, t_cpa)
        }
    
    def _haversine_distance(self, lat1: float, lon1: float, lat2: float, lon2: float) -> float:
        """Calculate haversine distance in nautical miles"""
        R = 3440.065  # Earth radius in nautical miles
        
        lat1, lon1, lat2, lon2 = map(np.radians, [lat1, lon1, lat2, lon2])
        dlat = lat2 - lat1
        dlon = lon2 - lon1
        
        a = np.sin(dlat/2)**2 + np.cos(lat1) * np.cos(lat2) * np.sin(dlon/2)**2
        c = 2 * np.arcsin(np.sqrt(a))
        
        return R * c
    
    def _calculate_conflict_probability(self, h_dist: float, v_dist: float, cpa_dist: float, t_cpa: float) -> float:
        """Calculate conflict probability based on geometry"""
        # Normalize factors
        h_factor = max(0, 1 - h_dist / self.config.horizontal_separation_min)
        v_factor = max(0, 1 - v_dist / self.config.vertical_separation_min)
        cpa_factor = max(0, 1 - cpa_dist / self.config.horizontal_separation_min)
        time_factor = max(0, 1 - t_cpa / 15) if t_cpa > 0 else 0  # 15 minute horizon
        
        # Combined probability
        prob = (h_factor * 0.3 + v_factor * 0.3 + cpa_factor * 0.3 + time_factor * 0.1)
        return min(1.0, prob)
    
    def generate_conflict_scenarios(self, n_scenarios: int = 150) -> pd.DataFrame:
        """Generate realistic conflict scenarios"""
        scenarios = []
        
        print(f"🚨 Generating {n_scenarios} CONFLICT scenarios...")
        
        for i in range(n_scenarios):
            # Create aircraft pairs with high conflict probability
            severity = np.random.choice(['imminent', 'current', 'predicted'], p=[0.3, 0.4, 0.3])
            
            if severity == 'imminent':
                # Very close conflicts (0-2 NM, 0-500 ft)
                h_separation = np.random.uniform(0.1, 2.0)
                v_separation = np.random.uniform(0, 500)
                # Converging headings
                hdg_diff = np.random.uniform(120, 240)  # Head-on or crossing
            elif severity == 'current':
                # Current conflicts (2-5 NM, 500-1000 ft)
                h_separation = np.random.uniform(2.0, 4.9)
                v_separation = np.random.uniform(500, 1000)
                hdg_diff = np.random.uniform(90, 180)
            else:  # predicted
                # Predicted conflicts (5-8 NM but converging)
                h_separation = np.random.uniform(5.0, 8.0)
                v_separation = np.random.uniform(800, 1500)
                hdg_diff = np.random.uniform(45, 135)
            
            # Base aircraft parameters
            base_lat = 59.0 + np.random.uniform(-5, 5)
            base_lon = 18.0 + np.random.uniform(-5, 5)
            base_alt = np.random.uniform(25000, 40000)
            
            # Aircraft 1
            ac1_params = {
                'lat': base_lat,
                'lon': base_lon,
                'alt': base_alt,
                'hdg': np.random.uniform(0, 360),
                'spd': np.random.uniform(350, 550)
            }
            
            # Aircraft 2 positioned for conflict
            bearing = np.random.uniform(0, 360)
            lat_offset = h_separation * np.cos(np.radians(bearing)) / 60
            lon_offset = h_separation * np.sin(np.radians(bearing)) / 60
            
            ac2_params = {
                'lat': base_lat + lat_offset,
                'lon': base_lon + lon_offset,
                'alt': base_alt + (v_separation if np.random.random() > 0.5 else -v_separation),
                'hdg': (ac1_params['hdg'] + hdg_diff) % 360,
                'spd': np.random.uniform(350, 550)
            }
            
            # Calculate conflict geometry
            conflict_info = self.calculate_conflict_geometry(ac1_params, ac2_params)
            
            # Determine appropriate clearance
            clearance_options = ['altitude_change', 'heading_change', 'speed_change']
            if conflict_info['urgency_level'] == 'IMMINENT':
                clearance = np.random.choice(['heading_change', 'speed_change'])  # Immediate actions
            elif conflict_info['current_v_separation'] < 500:
                clearance = 'altitude_change'  # Vertical separation
            else:
                clearance = np.random.choice(clearance_options)
            
            # Check if outside training envelope
            outside_envelope = (
                ac1_params['alt'] < self.config.altitude_min or ac1_params['alt'] > self.config.altitude_max or
                ac2_params['alt'] < self.config.altitude_min or ac2_params['alt'] > self.config.altitude_max or
                ac1_params['spd'] < self.config.speed_min or ac1_params['spd'] > self.config.speed_max or
                ac2_params['spd'] < self.config.speed_min or ac2_params['spd'] > self.config.speed_max
            )
            
            scenario = {
                'scenario_id': f'conflict_{i:03d}',
                'scenario_type': 'conflict',
                'severity': severity,
                'altitude_1': ac1_params['alt'],
                'altitude_2': ac2_params['alt'],
                'speed_1': ac1_params['spd'],
                'speed_2': ac2_params['spd'],
                'heading_1': ac1_params['hdg'],
                'heading_2': ac2_params['hdg'],
                'latitude_1': ac1_params['lat'],
                'longitude_1': ac1_params['lon'],
                'latitude_2': ac2_params['lat'],
                'longitude_2': ac2_params['lon'],
                'horizontal_distance': conflict_info['current_h_separation'],
                'vertical_distance': conflict_info['current_v_separation'],
                'time_to_cpa': conflict_info['time_to_cpa'],
                'cpa_distance': conflict_info['cpa_distance'],
                'relative_speed': conflict_info['relative_speed'],
                'conflict': True,  # All these scenarios are conflicts
                'urgency_level': conflict_info['urgency_level'],
                'conflict_probability': conflict_info['conflict_probability'],
                'clearance_type': clearance,
                'outside_envelope': outside_envelope
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)
    
    def generate_safe_scenarios(self, n_scenarios: int = 200) -> pd.DataFrame:
        """Generate safe scenarios with adequate separation"""
        scenarios = []
        
        print(f"✅ Generating {n_scenarios} SAFE scenarios...")
        
        for i in range(n_scenarios):
            # Ensure safe separations
            safety_margin = np.random.choice(['comfortable', 'adequate', 'minimum'], p=[0.5, 0.3, 0.2])
            
            if safety_margin == 'comfortable':
                h_separation = np.random.uniform(15, 50)
                v_separation = np.random.uniform(3000, 8000)
            elif safety_margin == 'adequate':
                h_separation = np.random.uniform(8, 15)
                v_separation = np.random.uniform(2000, 3000)
            else:  # minimum
                h_separation = np.random.uniform(5.1, 8)
                v_separation = np.random.uniform(1100, 2000)
            
            # Base aircraft parameters
            base_lat = 59.0 + np.random.uniform(-5, 5)
            base_lon = 18.0 + np.random.uniform(-5, 5)
            base_alt = np.random.uniform(20000, 45000)
            
            # Aircraft 1
            ac1_params = {
                'lat': base_lat,
                'lon': base_lon,
                'alt': base_alt,
                'hdg': np.random.uniform(0, 360),
                'spd': np.random.uniform(300, 600)
            }
            
            # Aircraft 2 positioned safely
            bearing = np.random.uniform(0, 360)
            lat_offset = h_separation * np.cos(np.radians(bearing)) / 60
            lon_offset = h_separation * np.sin(np.radians(bearing)) / 60
            
            ac2_params = {
                'lat': base_lat + lat_offset,
                'lon': base_lon + lon_offset,
                'alt': base_alt + (v_separation if np.random.random() > 0.5 else -v_separation),
                'hdg': np.random.uniform(0, 360),  # Random heading (non-converging)
                'spd': np.random.uniform(300, 600)
            }
            
            # Calculate geometry
            conflict_info = self.calculate_conflict_geometry(ac1_params, ac2_params)
            
            # Check if outside training envelope
            outside_envelope = (
                ac1_params['alt'] < self.config.altitude_min or ac1_params['alt'] > self.config.altitude_max or
                ac2_params['alt'] < self.config.altitude_min or ac2_params['alt'] > self.config.altitude_max or
                ac1_params['spd'] < self.config.speed_min or ac1_params['spd'] > self.config.speed_max or
                ac2_params['spd'] < self.config.speed_min or ac2_params['spd'] > self.config.speed_max
            )
            
            scenario = {
                'scenario_id': f'safe_{i:03d}',
                'scenario_type': 'safe',
                'severity': 'none',
                'altitude_1': ac1_params['alt'],
                'altitude_2': ac2_params['alt'],
                'speed_1': ac1_params['spd'],
                'speed_2': ac2_params['spd'],
                'heading_1': ac1_params['hdg'],
                'heading_2': ac2_params['hdg'],
                'latitude_1': ac1_params['lat'],
                'longitude_1': ac1_params['lon'],
                'latitude_2': ac2_params['lat'],
                'longitude_2': ac2_params['lon'],
                'horizontal_distance': conflict_info['current_h_separation'],
                'vertical_distance': conflict_info['current_v_separation'],
                'time_to_cpa': conflict_info['time_to_cpa'],
                'cpa_distance': conflict_info['cpa_distance'],
                'relative_speed': conflict_info['relative_speed'],
                'conflict': False,
                'urgency_level': 'SAFE',
                'conflict_probability': conflict_info['conflict_probability'],
                'clearance_type': 'none',
                'outside_envelope': outside_envelope
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)
    
    def generate_edge_cases(self, n_scenarios: int = 100) -> pd.DataFrame:
        """Generate edge case scenarios to test model limits"""
        scenarios = []
        
        print(f"⚠️ Generating {n_scenarios} EDGE CASE scenarios...")
        
        edge_types = ['extreme_altitude', 'extreme_speed', 'extreme_geometry', 'boundary_conditions']
        
        for i in range(n_scenarios):
            edge_type = np.random.choice(edge_types)
            
            if edge_type == 'extreme_altitude':
                # Test altitude extremes
                alt1 = np.random.choice([
                    np.random.uniform(1000, 4999),    # Very low
                    np.random.uniform(50001, 70000)   # Very high
                ])
                alt2 = np.random.uniform(20000, 40000)  # Normal
                spd1 = np.random.uniform(350, 550)
                spd2 = np.random.uniform(350, 550)
                
            elif edge_type == 'extreme_speed':
                # Test speed extremes
                alt1 = np.random.uniform(25000, 40000)
                alt2 = np.random.uniform(25000, 40000)
                spd1 = np.random.choice([
                    np.random.uniform(50, 150),       # Very slow
                    np.random.uniform(700, 900)       # Very fast
                ])
                spd2 = np.random.uniform(350, 550)
                
            elif edge_type == 'extreme_geometry':
                # Test geometric extremes
                alt1 = np.random.uniform(25000, 40000)
                alt2 = np.random.uniform(25000, 40000)
                spd1 = np.random.uniform(350, 550)
                spd2 = np.random.uniform(350, 550)
                
            else:  # boundary_conditions
                # Test exactly at boundaries
                alt1 = np.random.choice([self.config.altitude_min, self.config.altitude_max])
                alt2 = np.random.uniform(25000, 40000)
                spd1 = np.random.choice([self.config.speed_min, self.config.speed_max])
                spd2 = np.random.uniform(350, 550)
            
            # Random positioning
            base_lat = 59.0 + np.random.uniform(-10, 10)
            base_lon = 18.0 + np.random.uniform(-10, 10)
            
            # Create scenario with edge parameters
            ac1_params = {
                'lat': base_lat,
                'lon': base_lon,
                'alt': alt1,
                'hdg': np.random.uniform(0, 360),
                'spd': spd1
            }
            
            if edge_type == 'extreme_geometry':
                # Extreme separations
                h_separation = np.random.choice([
                    np.random.uniform(0.01, 0.5),     # Collision course
                    np.random.uniform(100, 200)       # Very far
                ])
                v_separation = np.random.choice([
                    np.random.uniform(0, 50),         # Same level
                    np.random.uniform(20000, 40000)   # Very different levels
                ])
            else:
                # Normal separations for other edge cases
                h_separation = np.random.uniform(3, 10)
                v_separation = np.random.uniform(500, 3000)
            
            bearing = np.random.uniform(0, 360)
            lat_offset = h_separation * np.cos(np.radians(bearing)) / 60
            lon_offset = h_separation * np.sin(np.radians(bearing)) / 60
            
            ac2_params = {
                'lat': base_lat + lat_offset,
                'lon': base_lon + lon_offset,
                'alt': alt2 + (v_separation if np.random.random() > 0.5 else -v_separation),
                'hdg': np.random.uniform(0, 360),
                'spd': spd2
            }
            
            # Calculate geometry
            conflict_info = self.calculate_conflict_geometry(ac1_params, ac2_params)
            
            # These are likely outside training envelope
            outside_envelope = True
            
            scenario = {
                'scenario_id': f'edge_{i:03d}',
                'scenario_type': edge_type,
                'severity': 'edge_case',
                'altitude_1': ac1_params['alt'],
                'altitude_2': ac2_params['alt'],
                'speed_1': ac1_params['spd'],
                'speed_2': ac2_params['spd'],
                'heading_1': ac1_params['hdg'],
                'heading_2': ac2_params['hdg'],
                'latitude_1': ac1_params['lat'],
                'longitude_1': ac1_params['lon'],
                'latitude_2': ac2_params['lat'],
                'longitude_2': ac2_params['lon'],
                'horizontal_distance': conflict_info['current_h_separation'],
                'vertical_distance': conflict_info['current_v_separation'],
                'time_to_cpa': conflict_info['time_to_cpa'],
                'cpa_distance': conflict_info['cpa_distance'],
                'relative_speed': conflict_info['relative_speed'],
                'conflict': conflict_info['is_conflict'],
                'urgency_level': conflict_info['urgency_level'],
                'conflict_probability': conflict_info['conflict_probability'],
                'clearance_type': np.random.choice(['altitude_change', 'heading_change', 'speed_change']),
                'outside_envelope': outside_envelope
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)

# Enhanced dataset with realistic scenario descriptions
class EnhancedATCDataset(Dataset):
    """Enhanced dataset with realistic ATC scenario descriptions"""
    
    def __init__(self, scenarios_df: pd.DataFrame, tokenizer, max_length: int = 128):
        self.scenarios_df = scenarios_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Enhanced clearance type encoding
        clearance_types = ['none', 'altitude_change', 'heading_change', 'speed_change', 'emergency_descent']
        self.clearance_encoder = LabelEncoder()
        self.clearance_encoder.fit(clearance_types)
        
        # Handle missing clearance types
        self.scenarios_df['clearance_type'] = self.scenarios_df['clearance_type'].fillna('none')
        
        # Map unknown clearance types to 'none'
        valid_clearances = set(clearance_types)
        self.scenarios_df['clearance_type'] = self.scenarios_df['clearance_type'].apply(
            lambda x: x if x in valid_clearances else 'none'
        )
        
        self.scenarios_df['clearance_label'] = self.clearance_encoder.transform(self.scenarios_df['clearance_type'])
        
        print(f"📊 Enhanced Dataset Created:")
        print(f"   Total scenarios: {len(self.scenarios_df)}")
        print(f"   Conflicts: {self.scenarios_df['conflict'].sum()}")
        print(f"   Safe scenarios: {(~self.scenarios_df['conflict']).sum()}")
        print(f"   Outside envelope: {self.scenarios_df['outside_envelope'].sum()}")
        
        # Print scenario type breakdown
        scenario_breakdown = self.scenarios_df['scenario_type'].value_counts()
        for scenario_type, count in scenario_breakdown.items():
            print(f"   {scenario_type}: {count}")
    
    def __len__(self):
        return len(self.scenarios_df)
    
    def __getitem__(self, idx):
        row = self.scenarios_df.iloc[idx]
        
        # Create enhanced scenario description with more context
        text = self._create_enhanced_scenario_text(row)
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'conflict_label': torch.tensor(1 if row['conflict'] else 0, dtype=torch.long),
            'clearance_label': torch.tensor(row['clearance_label'], dtype=torch.long),
            'envelope_violation': torch.tensor(float(row['outside_envelope']), dtype=torch.float),
            'scenario_id': row['scenario_id'],
            'scenario_type': row['scenario_type'],
            'urgency_level': row['urgency_level'],
            'conflict_probability': torch.tensor(row['conflict_probability'], dtype=torch.float),
            
            # Additional metadata for analysis
            'horizontal_distance': torch.tensor(row['horizontal_distance'], dtype=torch.float),
            'vertical_distance': torch.tensor(row['vertical_distance'], dtype=torch.float),
            'time_to_cpa': torch.tensor(row.get('time_to_cpa', 0), dtype=torch.float),
            'cpa_distance': torch.tensor(row.get('cpa_distance', 0), dtype=torch.float)
        }
    
    def _create_enhanced_scenario_text(self, row):
        """Create detailed scenario description with context"""
        # Basic aircraft information
        text_parts = [
            f"Aircraft A: FL{int(row['altitude_1']/100):03d}, "
            f"heading {int(row['heading_1']):03d}°, "
            f"speed {int(row['speed_1']):03d} knots.",
            
            f"Aircraft B: FL{int(row['altitude_2']/100):03d}, "
            f"heading {int(row['heading_2']):03d}°, "
            f"speed {int(row['speed_2']):03d} knots.",
            
            f"Current separation: {row['horizontal_distance']:.1f} nautical miles horizontal, "
            f"{row['vertical_distance']:.0f} feet vertical."
        ]
        
        # Add time-critical information if available
        if 'time_to_cpa' in row and row['time_to_cpa'] > 0:
            text_parts.append(f"Time to closest approach: {row['time_to_cpa']:.1f} minutes.")
        
        if 'cpa_distance' in row:
            text_parts.append(f"Closest approach distance: {row['cpa_distance']:.1f} nautical miles.")
        
        # Add urgency context
        if hasattr(row, 'urgency_level'):
            urgency_text = {
                'IMMINENT': 'URGENT: Immediate conflict resolution required.',
                'CURRENT': 'CONFLICT: Separation standards violated.',
                'PREDICTED': 'CAUTION: Potential conflict developing.',
                'SAFE': 'NORMAL: Aircraft maintaining safe separation.'
            }
            text_parts.append(urgency_text.get(row['urgency_level'], ''))
        
        return ' '.join(text_parts)

# Enhanced model evaluator with failure analysis
class ComprehensiveFailureAnalyzer:
    """Comprehensive analysis of model failures and edge cases"""
    
    def __init__(self, model_path: str, config: EnhancedTestConfig):
        self.config = config
        self.model = self._load_model(model_path)
        self.tokenizer = DistilBertTokenizer.from_pretrained(config.model_name)
        self.results = {}
        self.failure_cases = []
        
    def _load_model(self, model_path: str):
        """Load the trained model"""
        try:
            print(f"🔄 Loading model from {model_path}")
            
            # Initialize model
            model = HallucinationAwareATCModel(
                model_name=self.config.model_name,
                num_clearance_types=5  # Enhanced clearance types
            )
            
            # Load state dict
            state_dict = torch.load(model_path, map_location=self.config.device)
            model.load_state_dict(state_dict)
            model.to(self.config.device)
            model.eval()
            
            print("✅ Model loaded successfully")
            return model
            
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            print("💡 Attempting to load with compatibility mode...")
            
            # Try with fewer clearance types if the model was trained differently
            try:
                model = HallucinationAwareATCModel(
                    model_name=self.config.model_name,
                    num_clearance_types=2  # Fallback
                )
                state_dict = torch.load(model_path, map_location=self.config.device)
                model.load_state_dict(state_dict)
                model.to(self.config.device)
                model.eval()
                print("✅ Model loaded in compatibility mode")
                return model
            except Exception as e2:
                print(f"❌ Failed to load model: {e2}")
                raise
    
    def analyze_comprehensive_failures(self, test_scenarios: pd.DataFrame):
        """Comprehensive failure analysis"""
        print("🔍 Starting Comprehensive Failure Analysis")
        print("=" * 60)
        
        # Create dataset
        test_dataset = EnhancedATCDataset(test_scenarios, self.tokenizer, self.config.max_sequence_length)
        test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False)
        
        print("1️⃣ Collecting Model Predictions...")
        predictions = self._collect_all_predictions(test_loader)
        
        print("2️⃣ Analyzing Failure Patterns...")
        failure_analysis = self._analyze_failure_patterns(predictions, test_scenarios)
        
        print("3️⃣ Analyzing Edge Case Performance...")
        edge_case_analysis = self._analyze_edge_case_performance(predictions, test_scenarios)
        
        print("4️⃣ Analyzing Uncertainty Patterns...")
        uncertainty_analysis = self._analyze_uncertainty_patterns(predictions, test_scenarios)
        
        print("5️⃣ Analyzing Conflict Severity Impact...")
        severity_analysis = self._analyze_conflict_severity_impact(predictions, test_scenarios)
        
        print("6️⃣ Generating Failure Report...")
        self._generate_failure_report(failure_analysis, edge_case_analysis, uncertainty_analysis, severity_analysis)
        
        # Store results
        self.results = {
            'predictions': predictions,
            'failure_analysis': failure_analysis,
            'edge_case_analysis': edge_case_analysis,
            'uncertainty_analysis': uncertainty_analysis,
            'severity_analysis': severity_analysis
        }
        
        print("✅ Comprehensive failure analysis completed!")
        return self.results
    
    def _collect_all_predictions(self, test_loader):
        """Collect all model predictions with detailed metadata"""
        all_predictions = []
        
        self.model.eval()
        
        for batch_idx, batch in enumerate(test_loader):
            input_ids = batch['input_ids'].to(self.config.device)
            attention_mask = batch['attention_mask'].to(self.config.device)
            
            # Standard prediction
            with torch.no_grad():
                conflict_logits, clearance_logits, hallucination_logits = self.model(input_ids, attention_mask)
                
                conflict_probs = F.softmax(conflict_logits, dim=-1)
                clearance_probs = F.softmax(clearance_logits, dim=-1)
                hallucination_probs = torch.sigmoid(hallucination_logits.squeeze(-1))
                
                conflict_predictions = torch.argmax(conflict_logits, dim=-1)
                clearance_predictions = torch.argmax(clearance_logits, dim=-1)
                
                # Get uncertainty estimates
                uncertainty_estimates = self.model.get_uncertainty_estimates(
                    input_ids, attention_mask, n_samples=self.config.mc_dropout_samples
                )
            
            # Collect batch results
            batch_size = len(batch['scenario_id'])
            for i in range(batch_size):
                prediction_data = {
                    # Basic info
                    'scenario_id': batch['scenario_id'][i],
                    'scenario_type': batch['scenario_type'][i],
                    'urgency_level': batch['urgency_level'][i],
                    
                    # Ground truth
                    'true_conflict': bool(batch['conflict_label'][i].item()),
                    'true_clearance': batch['clearance_label'][i].item(),
                    'envelope_violation': bool(batch['envelope_violation'][i].item()),
                    'true_conflict_probability': batch['conflict_probability'][i].item(),
                    
                    # Geometric data
                    'horizontal_distance': batch['horizontal_distance'][i].item(),
                    'vertical_distance': batch['vertical_distance'][i].item(),
                    'time_to_cpa': batch['time_to_cpa'][i].item(),
                    'cpa_distance': batch['cpa_distance'][i].item(),
                    
                    # Model predictions
                    'pred_conflict': bool(conflict_predictions[i].item()),
                    'pred_clearance': clearance_predictions[i].item(),
                    'conflict_confidence': torch.max(conflict_probs[i]).item(),
                    'conflict_prob': conflict_probs[i, 1].item(),  # Probability of conflict
                    'hallucination_score': hallucination_probs[i].item(),
                    
                    # Uncertainty metrics
                    'epistemic_uncertainty': uncertainty_estimates['epistemic_uncertainty'][i].item(),
                    'total_uncertainty': uncertainty_estimates['total_uncertainty'][i].item(),
                    'conflict_entropy': uncertainty_estimates['conflict_entropy'][i].item(),
                    
                    # Performance metrics
                    'conflict_correct': bool(conflict_predictions[i].item()) == bool(batch['conflict_label'][i].item()),
                    'clearance_correct': clearance_predictions[i].item() == batch['clearance_label'][i].item(),
                }
                
                all_predictions.append(prediction_data)
        
        return all_predictions
    
    def _analyze_failure_patterns(self, predictions: List[Dict], scenarios: pd.DataFrame):
        """Analyze patterns in model failures"""
        
        # Separate different types of failures
        false_positives = [p for p in predictions if not p['true_conflict'] and p['pred_conflict']]
        false_negatives = [p for p in predictions if p['true_conflict'] and not p['pred_conflict']]
        correct_predictions = [p for p in predictions if p['conflict_correct']]
        
        analysis = {
            'total_predictions': len(predictions),
            'correct_predictions': len(correct_predictions),
            'false_positives': len(false_positives),
            'false_negatives': len(false_negatives),
            'accuracy': len(correct_predictions) / len(predictions),
            'false_positive_rate': len(false_positives) / len([p for p in predictions if not p['true_conflict']]),
            'false_negative_rate': len(false_negatives) / len([p for p in predictions if p['true_conflict']]),
        }
        
        # Analyze failure characteristics
        if false_positives:
            fp_analysis = {
                'avg_confidence': np.mean([fp['conflict_confidence'] for fp in false_positives]),
                'avg_uncertainty': np.mean([fp['total_uncertainty'] for fp in false_positives]),
                'avg_horizontal_distance': np.mean([fp['horizontal_distance'] for fp in false_positives]),
                'scenario_types': pd.Series([fp['scenario_type'] for fp in false_positives]).value_counts().to_dict()
            }
            analysis['false_positive_analysis'] = fp_analysis
            
            print(f"   🚨 FALSE POSITIVES: {len(false_positives)} cases")
            print(f"      Average confidence: {fp_analysis['avg_confidence']:.3f}")
            print(f"      Average separation: {fp_analysis['avg_horizontal_distance']:.1f} NM")
        
        if false_negatives:
            fn_analysis = {
                'avg_confidence': np.mean([fn['conflict_confidence'] for fn in false_negatives]),
                'avg_uncertainty': np.mean([fn['total_uncertainty'] for fn in false_negatives]),
                'avg_horizontal_distance': np.mean([fn['horizontal_distance'] for fn in false_negatives]),
                'urgency_levels': pd.Series([fn['urgency_level'] for fn in false_negatives]).value_counts().to_dict(),
                'scenario_types': pd.Series([fn['scenario_type'] for fn in false_negatives]).value_counts().to_dict()
            }
            analysis['false_negative_analysis'] = fn_analysis
            
            print(f"   🚨 FALSE NEGATIVES: {len(false_negatives)} cases")
            print(f"      Average confidence: {fn_analysis['avg_confidence']:.3f}")
            print(f"      Average separation: {fn_analysis['avg_horizontal_distance']:.1f} NM")
            print(f"      Urgency breakdown: {fn_analysis['urgency_levels']}")
        
        # Store failure cases for detailed analysis
        self.failure_cases = {
            'false_positives': false_positives,
            'false_negatives': false_negatives
        }
        
        return analysis
    
    def _analyze_edge_case_performance(self, predictions: List[Dict], scenarios: pd.DataFrame):
        """Analyze performance on edge cases"""
        
        # Group by scenario type
        scenario_performance = {}
        
        for scenario_type in scenarios['scenario_type'].unique():
            type_predictions = [p for p in predictions if p['scenario_type'] == scenario_type]
            
            if type_predictions:
                correct = [p for p in type_predictions if p['conflict_correct']]
                envelope_violations = [p for p in type_predictions if p['envelope_violation']]
                
                performance = {
                    'total_cases': len(type_predictions),
                    'correct_cases': len(correct),
                    'accuracy': len(correct) / len(type_predictions),
                    'avg_confidence': np.mean([p['conflict_confidence'] for p in type_predictions]),
                    'avg_uncertainty': np.mean([p['total_uncertainty'] for p in type_predictions]),
                    'envelope_violations': len(envelope_violations),
                    'avg_hallucination_score': np.mean([p['hallucination_score'] for p in type_predictions])
                }
                
                scenario_performance[scenario_type] = performance
                
                print(f"   📊 {scenario_type}: {performance['accuracy']:.1%} accuracy, "
                      f"{performance['avg_confidence']:.3f} confidence, "
                      f"{performance['envelope_violations']} envelope violations")
        
        return scenario_performance
    
    def _analyze_uncertainty_patterns(self, predictions: List[Dict], scenarios: pd.DataFrame):
        """Analyze uncertainty patterns in predictions"""
        
        # Group by correctness
        correct_preds = [p for p in predictions if p['conflict_correct']]
        incorrect_preds = [p for p in predictions if not p['conflict_correct']]
        
        # Group by envelope status
        inside_envelope = [p for p in predictions if not p['envelope_violation']]
        outside_envelope = [p for p in predictions if p['envelope_violation']]
        
        analysis = {
            'correct_predictions': {
                'avg_uncertainty': np.mean([p['total_uncertainty'] for p in correct_preds]),
                'avg_confidence': np.mean([p['conflict_confidence'] for p in correct_preds]),
                'count': len(correct_preds)
            },
            'incorrect_predictions': {
                'avg_uncertainty': np.mean([p['total_uncertainty'] for p in incorrect_preds]),
                'avg_confidence': np.mean([p['conflict_confidence'] for p in incorrect_preds]),
                'count': len(incorrect_preds)
            },
            'inside_envelope': {
                'avg_uncertainty': np.mean([p['total_uncertainty'] for p in inside_envelope]),
                'avg_hallucination_score': np.mean([p['hallucination_score'] for p in inside_envelope]),
                'count': len(inside_envelope)
            },
            'outside_envelope': {
                'avg_uncertainty': np.mean([p['total_uncertainty'] for p in outside_envelope]),
                'avg_hallucination_score': np.mean([p['hallucination_score'] for p in outside_envelope]),
                'count': len(outside_envelope)
            }
        }
        
        print(f"   🎯 UNCERTAINTY ANALYSIS:")
        print(f"      Correct predictions: {analysis['correct_predictions']['avg_uncertainty']:.3f} uncertainty")
        print(f"      Incorrect predictions: {analysis['incorrect_predictions']['avg_uncertainty']:.3f} uncertainty")
        print(f"      Inside envelope: {analysis['inside_envelope']['avg_hallucination_score']:.3f} hallucination score")
        print(f"      Outside envelope: {analysis['outside_envelope']['avg_hallucination_score']:.3f} hallucination score")
        
        return analysis
    
    def _analyze_conflict_severity_impact(self, predictions: List[Dict], scenarios: pd.DataFrame):
        """Analyze how conflict severity affects model performance"""
        
        severity_performance = {}
        
        for urgency in ['SAFE', 'PREDICTED', 'CURRENT', 'IMMINENT']:
            urgency_predictions = [p for p in predictions if p['urgency_level'] == urgency]
            
            if urgency_predictions:
                correct = [p for p in urgency_predictions if p['conflict_correct']]
                
                performance = {
                    'total_cases': len(urgency_predictions),
                    'correct_cases': len(correct),
                    'accuracy': len(correct) / len(urgency_predictions),
                    'avg_confidence': np.mean([p['conflict_confidence'] for p in urgency_predictions]),
                    'avg_uncertainty': np.mean([p['total_uncertainty'] for p in urgency_predictions]),
                    'false_negatives': len([p for p in urgency_predictions if p['true_conflict'] and not p['pred_conflict']]),
                    'false_positives': len([p for p in urgency_predictions if not p['true_conflict'] and p['pred_conflict']])
                }
                
                severity_performance[urgency] = performance
                
                print(f"   ⚡ {urgency}: {performance['accuracy']:.1%} accuracy, "
                      f"FN: {performance['false_negatives']}, FP: {performance['false_positives']}")
        
        return severity_performance
    
    def _generate_failure_report(self, failure_analysis, edge_case_analysis, uncertainty_analysis, severity_analysis):
        """Generate comprehensive failure analysis report"""
        
        report_lines = [
            "# 🔍 COMPREHENSIVE MODEL FAILURE ANALYSIS REPORT",
            "=" * 70,
            f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}",
            "",
            "## 📊 OVERALL PERFORMANCE SUMMARY",
            f"- Total Predictions: {failure_analysis['total_predictions']}",
            f"- Overall Accuracy: {failure_analysis['accuracy']:.1%}",
            f"- False Positive Rate: {failure_analysis['false_positive_rate']:.1%}",
            f"- False Negative Rate: {failure_analysis['false_negative_rate']:.1%}",
            "",
            "## 🚨 CRITICAL FAILURE ANALYSIS",
        ]
        
        # False Negatives (Most Critical)
        if 'false_negative_analysis' in failure_analysis:
            fn_analysis = failure_analysis['false_negative_analysis']
            report_lines.extend([
                f"### ❌ FALSE NEGATIVES (Missed Conflicts): {failure_analysis['false_negatives']} cases",
                f"- Average Model Confidence: {fn_analysis['avg_confidence']:.3f}",
                f"- Average Separation: {fn_analysis['avg_horizontal_distance']:.1f} NM",
                f"- Urgency Level Breakdown:",
            ])
            for urgency, count in fn_analysis['urgency_levels'].items():
                report_lines.append(f"  - {urgency}: {count} cases")
            report_lines.append("")
        
        # False Positives
        if 'false_positive_analysis' in failure_analysis:
            fp_analysis = failure_analysis['false_positive_analysis']
            report_lines.extend([
                f"### ⚠️ FALSE POSITIVES (False Alarms): {failure_analysis['false_positives']} cases",
                f"- Average Model Confidence: {fp_analysis['avg_confidence']:.3f}",
                f"- Average Separation: {fp_analysis['avg_horizontal_distance']:.1f} NM",
                ""
            ])
        
        # Edge Case Performance
        report_lines.extend([
            "## 🎯 EDGE CASE PERFORMANCE",
        ])
        for scenario_type, performance in edge_case_analysis.items():
            report_lines.extend([
                f"### {scenario_type.upper()}",
                f"- Cases: {performance['total_cases']}",
                f"- Accuracy: {performance['accuracy']:.1%}",
                f"- Average Confidence: {performance['avg_confidence']:.3f}",
                f"- Envelope Violations: {performance['envelope_violations']}",
                f"- Hallucination Score: {performance['avg_hallucination_score']:.3f}",
                ""
            ])
        
        # Conflict Severity Analysis
        report_lines.extend([
            "## ⚡ CONFLICT SEVERITY IMPACT",
        ])
        for severity, performance in severity_analysis.items():
            report_lines.extend([
                f"### {severity}",
                f"- Accuracy: {performance['accuracy']:.1%} ({performance['correct_cases']}/{performance['total_cases']})",
                f"- False Negatives: {performance['false_negatives']}",
                f"- False Positives: {performance['false_positives']}",
                f"- Average Confidence: {performance['avg_confidence']:.3f}",
                ""
            ])
        
        # Critical Findings
        report_lines.extend([
            "## 🚨 CRITICAL FINDINGS FOR THESIS",
            "",
            "### Most Dangerous Failure Modes:",
        ])
        
        # Identify most dangerous failures
        if 'false_negative_analysis' in failure_analysis:
            fn_analysis = failure_analysis['false_negative_analysis']
            high_confidence_fn = [fn for fn in self.failure_cases['false_negatives'] 
                                if fn['conflict_confidence'] > 0.8]
            if high_confidence_fn:
                report_lines.append(f"- {len(high_confidence_fn)} HIGH-CONFIDENCE false negatives (model very sure but wrong)")
        
        # Hallucination detection effectiveness
        inside_score = uncertainty_analysis['inside_envelope']['avg_hallucination_score']
        outside_score = uncertainty_analysis['outside_envelope']['avg_hallucination_score']
        
        if outside_score <= inside_score:
            report_lines.append(f"- Hallucination detection NOT working (inside: {inside_score:.3f}, outside: {outside_score:.3f})")
        else:
            report_lines.append(f"- Hallucination detection working (inside: {inside_score:.3f}, outside: {outside_score:.3f})")
        
        report_lines.extend([
            "",
            "### Recommendations for Thesis:",
            "1. Focus on false negative reduction (safety-critical)",
            "2. Improve hallucination detection mechanism",
            "3. Add uncertainty-based safety margins",
            "4. Implement ensemble methods for critical decisions",
            ""
        ])
        
        # Save report
        with open(os.path.join(self.config.output_path, 'failure_analysis_report.md'), 'w') as f:
            f.write('\n'.join(report_lines))
        
        # Save detailed failure cases
        import json
        with open(os.path.join(self.config.output_path, 'failure_cases_detailed.json'), 'w') as f:
            json.dump(self.failure_cases, f, indent=2, default=str)

# Main execution function
def run_enhanced_conflict_analysis():
    """Run enhanced conflict detection analysis"""
    print("🎓 ENHANCED ATC MODEL TESTING FOR THESIS RESEARCH")
    print("🎯 Focus: Identifying Model Failures in Conflict Detection")
    print("=" * 80)
    
    # Generate realistic test scenarios
    print("📝 Generating Realistic Test Scenarios...")
    scenario_generator = RealisticATCScenarioGenerator(config)
    
    # Generate different scenario types
    conflict_scenarios = scenario_generator.generate_conflict_scenarios(config.n_conflict_scenarios)
    safe_scenarios = scenario_generator.generate_safe_scenarios(config.n_normal_scenarios)
    edge_scenarios = scenario_generator.generate_edge_cases(config.n_edge_cases)
    
    # Combine all scenarios
    all_scenarios = pd.concat([conflict_scenarios, safe_scenarios, edge_scenarios], ignore_index=True)
    
    print(f"✅ Generated {len(all_scenarios)} realistic test scenarios:")
    print(f"   🚨 Conflict scenarios: {len(conflict_scenarios)}")
    print(f"   ✅ Safe scenarios: {len(safe_scenarios)}")
    print(f"   ⚠️ Edge case scenarios: {len(edge_scenarios)}")
    
    # Print conflict severity breakdown
    if 'urgency_level' in all_scenarios.columns:
        urgency_counts = all_scenarios['urgency_level'].value_counts()
        print(f"   📊 Urgency breakdown:")
        for urgency, count in urgency_counts.items():
            print(f"      {urgency}: {count}")
    
    # Save scenarios for reference
    all_scenarios.to_csv(os.path.join(config.output_path, 'test_scenarios.csv'), index=False)
    
    # Initialize failure analyzer
    print("🤖 Loading Model for Failure Analysis...")
    try:
        analyzer = ComprehensiveFailureAnalyzer(config.model_path, config)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        return None
    
    # Run comprehensive failure analysis
    print("🔍 Running Comprehensive Failure Analysis...")
    results = analyzer.analyze_comprehensive_failures(all_scenarios)
    
    # Summary of critical findings
    print("\n🚨 CRITICAL FINDINGS SUMMARY")
    print("=" * 50)
    
    failure_analysis = results['failure_analysis']
    print(f"📊 Overall Performance:")
    print(f"   - Accuracy: {failure_analysis['accuracy']:.1%}")
    print(f"   - False Negatives (Missed Conflicts): {failure_analysis['false_negatives']}")
    print(f"   - False Positives (False Alarms): {failure_analysis['false_positives']}")
    
    # Most critical: False negatives
    if failure_analysis['false_negatives'] > 0:
        print(f"\n🚨 SAFETY CONCERN: {failure_analysis['false_negatives']} missed conflicts!")
        print(f"   - False Negative Rate: {failure_analysis['false_negative_rate']:.1%}")
        
        if analyzer.failure_cases['false_negatives']:
            # Find the most dangerous false negatives
            dangerous_fn = [fn for fn in analyzer.failure_cases['false_negatives'] 
                          if fn['urgency_level'] in ['IMMINENT', 'CURRENT']]
            if dangerous_fn:
                print(f"   - {len(dangerous_fn)} missed CRITICAL conflicts")
                avg_confidence = np.mean([fn['conflict_confidence'] for fn in dangerous_fn])
                print(f"   - Model was {avg_confidence:.1%} confident (but wrong!)")
    
    # Edge case performance
    edge_results = results['edge_case_analysis']
    print(f"\n⚠️ EDGE CASE PERFORMANCE:")
    for scenario_type, performance in edge_results.items():
        if performance['accuracy'] < 0.7:  # Flag poor performance
            print(f"   🚨 {scenario_type}: {performance['accuracy']:.1%} accuracy (POOR)")
        else:
            print(f"   ✅ {scenario_type}: {performance['accuracy']:.1%} accuracy")
    
    # Hallucination detection effectiveness
    uncertainty_results = results['uncertainty_analysis']
    inside_score = uncertainty_results['inside_envelope']['avg_hallucination_score']
    outside_score = uncertainty_results['outside_envelope']['avg_hallucination_score']
    
    print(f"\n🧠 HALLUCINATION DETECTION:")
    if outside_score > inside_score + 0.1:  # Threshold for working detection
        print(f"   ✅ Working (Inside: {inside_score:.3f}, Outside: {outside_score:.3f})")
    else:
        print(f"   ❌ NOT working (Inside: {inside_score:.3f}, Outside: {outside_score:.3f})")
        print(f"   💡 Major thesis concern: Model can't detect out-of-envelope scenarios")
    
    print(f"\n📁 Detailed analysis saved to: {config.output_path}")
    print(f"📋 Failure report: failure_analysis_report.md")
    print(f"📊 Test scenarios: test_scenarios.csv")
    print(f"🔍 Detailed failure cases: failure_cases_detailed.json")
    
    print("\n🎓 THESIS IMPLICATIONS:")
    print("1. ✅ Comprehensive test framework created")
    print("2. 🚨 Critical failure modes identified")
    print("3. 📊 Edge case performance quantified")
    print("4. 🧠 Hallucination detection evaluated")
    print("5. 📝 Detailed analysis for thesis defense ready")
    
    return results

# Execute the enhanced analysis
if __name__ == "__main__":
    results = run_enhanced_conflict_analysis()

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.8/356.8 kB[0m [31m5.8 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

2025-06-11 15:25:17.091358: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749655517.381810      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749655517.469039      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🔧 Enhanced ATC Model Testing Configuration
📁 Model path: /kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt
🖥️  Device: cpu
⚠️  Conflict thresholds: 5.0NM / 1000.0ft
🎓 ENHANCED ATC MODEL TESTING FOR THESIS RESEARCH
🎯 Focus: Identifying Model Failures in Conflict Detection
📝 Generating Realistic Test Scenarios...
🚨 Generating 150 CONFLICT scenarios...
✅ Generating 200 SAFE scenarios...
⚠️ Generating 100 EDGE CASE scenarios...
✅ Generated 450 realistic test scenarios:
   🚨 Conflict scenarios: 150
   ✅ Safe scenarios: 200
   ⚠️ Edge case scenarios: 100
   📊 Urgency breakdown:
      SAFE: 271
      PREDICTED: 70
      CURRENT: 67
      IMMINENT: 42
🤖 Loading Model for Failure Analysis...
🔄 Loading model from /kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

❌ Error loading model: Error(s) in loading state_dict for HallucinationAwareATCModel:
	size mismatch for resolution_head.3.weight: copying a param with shape torch.Size([2, 384]) from checkpoint, the shape in current model is torch.Size([5, 384]).
	size mismatch for resolution_head.3.bias: copying a param with shape torch.Size([2]) from checkpoint, the shape in current model is torch.Size([5]).
💡 Attempting to load with compatibility mode...
✅ Model loaded in compatibility mode


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

🔍 Running Comprehensive Failure Analysis...
🔍 Starting Comprehensive Failure Analysis
📊 Enhanced Dataset Created:
   Total scenarios: 450
   Conflicts: 154
   Safe scenarios: 296
   Outside envelope: 100
   safe: 200
   conflict: 150
   extreme_geometry: 29
   boundary_conditions: 26
   extreme_altitude: 23
   extreme_speed: 22
1️⃣ Collecting Model Predictions...
2️⃣ Analyzing Failure Patterns...
   🚨 FALSE NEGATIVES: 154 cases
      Average confidence: 0.991
      Average separation: 3.1 NM
      Urgency breakdown: {'CURRENT': 67, 'IMMINENT': 42, 'SAFE': 30, 'PREDICTED': 15}
3️⃣ Analyzing Edge Case Performance...
   📊 conflict: 0.0% accuracy, 0.991 confidence, 0 envelope violations
   📊 safe: 100.0% accuracy, 0.993 confidence, 0 envelope violations
   📊 extreme_altitude: 100.0% accuracy, 0.990 confidence, 23 envelope violations
   📊 boundary_conditions: 100.0% accuracy, 0.990 confidence, 26 envelope violations
   📊 extreme_geometry: 89.7% accuracy, 0.993 confidence, 29 envelope violat