In [1]:
# Comprehensive ML Model Testing Framework for ATC Hallucination Research
# End-to-End Model Evaluation and Validation Pipeline

# BlueSky integration
BLUESKY_AVAILABLE = True
try:
    import bluesky
    BLUESKY_AVAILABLE = True
    print("✅ BlueSky available")
except ImportError:
    !pip -qq install bluesky


import os
import gc
import json
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
from datetime import datetime, timedelta
from typing import Dict, List, Tuple, Optional, Union
from dataclasses import dataclass
from collections import defaultdict
import plotly.graph_objects as go
import plotly.express as px
from plotly.subplots import make_subplots

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from sklearn.metrics import (
    confusion_matrix, classification_report, f1_score, 
    precision_recall_curve, roc_curve, auc
)

from sklearn.calibration import calibration_curve 
from sklearn.preprocessing import StandardScaler, LabelEncoder
from transformers import DistilBertTokenizer, DistilBertModel

# Suppress warnings
warnings.filterwarnings('ignore')

# Configuration for testing
@dataclass
class TestConfig:
    """Testing configuration"""
    model_path: str = "/kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt"
    model_name: str = "distilbert-base-uncased"
    max_sequence_length: int = 128
    batch_size: int = 8
    device: str = "cuda" if torch.cuda.is_available() else "cpu"
    
    # Test parameters
    mc_dropout_samples: int = 20
    confidence_threshold: float = 0.8
    uncertainty_threshold: float = 0.3
    
    # Conflict thresholds
    conflict_threshold_nm: float = 5.0
    conflict_threshold_ft: float = 1000.0
    
    # Training envelope bounds (for hallucination testing)
    altitude_min: float = 10000  # ft
    altitude_max: float = 50000  # ft
    speed_min: float = 200  # kt
    speed_max: float = 600  # kt
    
    # Output paths
    output_path: str = "./model_test_results"

test_config = TestConfig()
os.makedirs(test_config.output_path, exist_ok=True)

print(f"🧪 Model Testing Configuration")
print(f"📁 Model path: {test_config.model_path}")
print(f"🖥️  Device: {test_config.device}")
print(f"📊 Output path: {test_config.output_path}")

# Recreate the model architecture
class FixedHallucinationAwareATCModel(nn.Module):
    """Recreate the exact model architecture for loading"""
    
    def __init__(self, model_name: str, num_clearance_types: int, dropout_rate: float = 0.1):
        super().__init__()
        
        self.bert = DistilBertModel.from_pretrained(model_name)
        hidden_size = self.bert.config.hidden_size
        
        # Main task heads
        self.conflict_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, 2)
        )
        
        self.resolution_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 2),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 2, num_clearance_types)
        )
        
        # Hallucination detection head
        self.hallucination_head = nn.Sequential(
            nn.Linear(hidden_size, hidden_size // 4),
            nn.ReLU(),
            nn.Dropout(dropout_rate),
            nn.Linear(hidden_size // 4, 1)
        )
        
        self.dropout_rate = dropout_rate
    
    def forward(self, input_ids, attention_mask, enable_dropout=False):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]
        
        if enable_dropout:
            pooled_output = F.dropout(pooled_output, p=self.dropout_rate, training=True)
        
        conflict_logits = self.conflict_head(pooled_output)
        resolution_logits = self.resolution_head(pooled_output)
        hallucination_logits = self.hallucination_head(pooled_output)
        
        return conflict_logits, resolution_logits, hallucination_logits
    
    def get_uncertainty_estimates(self, input_ids, attention_mask, n_samples=20):
        """Estimate model uncertainty using Monte Carlo dropout"""
        self.train()  # Enable dropout
        
        conflict_samples = []
        resolution_samples = []
        hallucination_samples = []
        
        with torch.no_grad():
            for _ in range(n_samples):
                conflict_logits, resolution_logits, hallucination_logits = self.forward(
                    input_ids, attention_mask, enable_dropout=True
                )
                conflict_samples.append(F.softmax(conflict_logits, dim=-1))
                resolution_samples.append(F.softmax(resolution_logits, dim=-1))
                hallucination_samples.append(torch.sigmoid(hallucination_logits))
        
        # Calculate statistics
        conflict_probs = torch.stack(conflict_samples)
        resolution_probs = torch.stack(resolution_samples)
        hallucination_probs = torch.stack(hallucination_samples)
        
        results = {
            'conflict_mean': conflict_probs.mean(dim=0),
            'conflict_std': conflict_probs.std(dim=0),
            'conflict_entropy': -torch.sum(conflict_probs.mean(dim=0) * torch.log(conflict_probs.mean(dim=0) + 1e-8), dim=-1),
            'resolution_mean': resolution_probs.mean(dim=0),
            'resolution_std': resolution_probs.std(dim=0),
            'hallucination_mean': hallucination_probs.mean(dim=0),
            'hallucination_std': hallucination_probs.std(dim=0)
        }
        
        return results

# Test scenario generator
class ATCTestScenarioGenerator:
    """Generate comprehensive test scenarios for model evaluation"""
    
    def __init__(self, config: TestConfig):
        self.config = config
        
    def generate_normal_scenarios(self, n_scenarios: int = 100) -> pd.DataFrame:
        """Generate scenarios within normal operational envelope"""
        scenarios = []
        
        for i in range(n_scenarios):
            # Normal operating parameters
            alt1 = np.random.uniform(20000, 40000)  # Normal cruise altitudes
            alt2 = np.random.uniform(20000, 40000)
            speed1 = np.random.uniform(350, 550)    # Normal cruise speeds
            speed2 = np.random.uniform(350, 550)
            heading1 = np.random.uniform(0, 360)
            heading2 = np.random.uniform(0, 360)
            
            # Various separation scenarios
            if i < n_scenarios // 3:  # Safe separations
                h_dist = np.random.uniform(10, 50)
                v_dist = np.random.uniform(2000, 5000)
                conflict = False
            elif i < 2 * n_scenarios // 3:  # Close but safe
                h_dist = np.random.uniform(6, 10)
                v_dist = np.random.uniform(1200, 2000)
                conflict = False
            else:  # Conflicts
                h_dist = np.random.uniform(1, 4.9)
                v_dist = np.random.uniform(100, 900)
                conflict = True
            
            clearance_type = np.random.choice(['none', 'altitude_change', 'heading_change', 'speed_change'])
            
            scenario = {
                'scenario_id': f'normal_{i:03d}',
                'scenario_type': 'normal',
                'altitude_1': alt1,
                'altitude_2': alt2,
                'speed_1': speed1,
                'speed_2': speed2,
                'heading_1': heading1,
                'heading_2': heading2,
                'horizontal_distance': h_dist,
                'vertical_distance': v_dist,
                'conflict': conflict,
                'clearance_type': clearance_type,
                'outside_envelope': False
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)
    
    def generate_edge_case_scenarios(self, n_scenarios: int = 50) -> pd.DataFrame:
        """Generate edge case scenarios for hallucination testing"""
        scenarios = []
        
        for i in range(n_scenarios):
            scenario_type = np.random.choice(['extreme_altitude', 'extreme_speed', 'extreme_separation'])
            
            if scenario_type == 'extreme_altitude':
                # Outside normal altitude envelope
                alt1 = np.random.choice([
                    np.random.uniform(5000, 9999),    # Too low
                    np.random.uniform(50001, 60000)   # Too high
                ])
                alt2 = np.random.uniform(20000, 40000)  # Normal
                speed1 = np.random.uniform(350, 550)
                speed2 = np.random.uniform(350, 550)
                outside_envelope = True
                
            elif scenario_type == 'extreme_speed':
                # Outside normal speed envelope
                alt1 = np.random.uniform(20000, 40000)
                alt2 = np.random.uniform(20000, 40000)
                speed1 = np.random.choice([
                    np.random.uniform(100, 199),      # Too slow
                    np.random.uniform(601, 800)       # Too fast
                ])
                speed2 = np.random.uniform(350, 550)
                outside_envelope = True
                
            else:  # extreme_separation
                # Extreme separation scenarios
                alt1 = np.random.uniform(20000, 40000)
                alt2 = np.random.uniform(20000, 40000)
                speed1 = np.random.uniform(350, 550)
                speed2 = np.random.uniform(350, 550)
                outside_envelope = False
            
            heading1 = np.random.uniform(0, 360)
            heading2 = np.random.uniform(0, 360)
            
            # Extreme separations
            if scenario_type == 'extreme_separation':
                h_dist = np.random.choice([
                    np.random.uniform(0.1, 0.9),      # Very close
                    np.random.uniform(100, 200)       # Very far
                ])
                v_dist = np.random.choice([
                    np.random.uniform(10, 50),        # Very close vertically
                    np.random.uniform(10000, 20000)   # Very far vertically
                ])
            else:
                h_dist = np.random.uniform(5, 15)
                v_dist = np.random.uniform(1000, 3000)
            
            conflict = h_dist < self.config.conflict_threshold_nm and v_dist < self.config.conflict_threshold_ft
            clearance_type = np.random.choice(['none', 'altitude_change', 'heading_change', 'speed_change'])
            
            scenario = {
                'scenario_id': f'edge_{i:03d}',
                'scenario_type': scenario_type,
                'altitude_1': alt1,
                'altitude_2': alt2,
                'speed_1': speed1,
                'speed_2': speed2,
                'heading_1': heading1,
                'heading_2': heading2,
                'horizontal_distance': h_dist,
                'vertical_distance': v_dist,
                'conflict': conflict,
                'clearance_type': clearance_type,
                'outside_envelope': outside_envelope
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)
    
    def generate_stress_test_scenarios(self, n_scenarios: int = 30) -> pd.DataFrame:
        """Generate stress test scenarios with extreme parameters"""
        scenarios = []
        
        for i in range(n_scenarios):
            # Extreme parameters outside any reasonable envelope
            alt1 = np.random.choice([
                np.random.uniform(1000, 4999),       # Very low
                np.random.uniform(60001, 80000)      # Very high
            ])
            alt2 = np.random.choice([
                np.random.uniform(1000, 4999),
                np.random.uniform(60001, 80000)
            ])
            
            speed1 = np.random.choice([
                np.random.uniform(50, 99),           # Very slow
                np.random.uniform(801, 1000)         # Very fast
            ])
            speed2 = np.random.choice([
                np.random.uniform(50, 99),
                np.random.uniform(801, 1000)
            ])
            
            heading1 = np.random.uniform(0, 360)
            heading2 = np.random.uniform(0, 360)
            
            # Extreme separations
            h_dist = np.random.choice([
                np.random.uniform(0.01, 0.5),        # Collision imminent
                np.random.uniform(500, 1000)         # Extremely far
            ])
            v_dist = np.random.choice([
                np.random.uniform(1, 25),            # Collision imminent
                np.random.uniform(50000, 100000)     # Extremely far
            ])
            
            conflict = h_dist < self.config.conflict_threshold_nm and v_dist < self.config.conflict_threshold_ft
            clearance_type = np.random.choice(['none', 'altitude_change', 'heading_change', 'speed_change'])
            
            scenario = {
                'scenario_id': f'stress_{i:03d}',
                'scenario_type': 'stress_test',
                'altitude_1': alt1,
                'altitude_2': alt2,
                'speed_1': speed1,
                'speed_2': speed2,
                'heading_1': heading1,
                'heading_2': heading2,
                'horizontal_distance': h_dist,
                'vertical_distance': v_dist,
                'conflict': conflict,
                'clearance_type': clearance_type,
                'outside_envelope': True
            }
            scenarios.append(scenario)
        
        return pd.DataFrame(scenarios)

# Test dataset for model evaluation
class ModelTestDataset(Dataset):
    """Dataset for model testing with various scenario types"""
    
    def __init__(self, scenarios_df: pd.DataFrame, tokenizer, max_length: int = 128):
        self.scenarios_df = scenarios_df.reset_index(drop=True)
        self.tokenizer = tokenizer
        self.max_length = max_length
        
        # Encode clearance types (need to match training)
        unique_clearances = ['none', 'altitude_change', 'heading_change', 'speed_change', 'other']
        self.clearance_encoder = LabelEncoder()
        self.clearance_encoder.fit(unique_clearances)
        
        self.scenarios_df['clearance_label'] = self.clearance_encoder.transform(
            self.scenarios_df['clearance_type'].fillna('none')
        )
        
        print(f"📊 Test Dataset: {len(self.scenarios_df)} scenarios")
        print(f"   - Normal: {(self.scenarios_df['scenario_type'] == 'normal').sum()}")
        print(f"   - Edge cases: {(self.scenarios_df['scenario_type'] != 'normal').sum()}")
        print(f"   - Outside envelope: {self.scenarios_df['outside_envelope'].sum()}")
    
    def __len__(self):
        return len(self.scenarios_df)
    
    def __getitem__(self, idx):
        row = self.scenarios_df.iloc[idx]
        
        # Create scenario description
        text = self._create_scenario_text(row)
        
        # Tokenize
        encoding = self.tokenizer(
            text,
            truncation=True,
            padding='max_length',
            max_length=self.max_length,
            return_tensors='pt'
        )
        
        return {
            'input_ids': encoding['input_ids'].squeeze(),
            'attention_mask': encoding['attention_mask'].squeeze(),
            'conflict_label': torch.tensor(1 if row['conflict'] else 0, dtype=torch.long),
            'clearance_label': torch.tensor(row['clearance_label'], dtype=torch.long),
            'envelope_violation': torch.tensor(float(row['outside_envelope']), dtype=torch.float),
            'scenario_id': row['scenario_id'],
            'scenario_type': row['scenario_type']
        }
    
    def _create_scenario_text(self, row):
        """Create natural language scenario description"""
        text = (
            f"Aircraft A at FL{int(row['altitude_1']/100):03d} "
            f"heading {int(row['heading_1']):03d}° "
            f"speed {int(row['speed_1']):03d} kt; "
            f"Aircraft B at FL{int(row['altitude_2']/100):03d} "
            f"heading {int(row['heading_2']):03d}° "
            f"speed {int(row['speed_2']):03d} kt; "
            f"horizontal separation {row['horizontal_distance']:.1f} NM; "
            f"vertical separation {row['vertical_distance']:.0f} ft."
        )
        return text

# Comprehensive model evaluator
class ComprehensiveModelEvaluator:
    """Comprehensive evaluation of the trained ATC model"""
    
    def __init__(self, model_path: str, config: TestConfig):
        self.config = config
        self.model = self._load_model(model_path)
        self.tokenizer = DistilBertTokenizer.from_pretrained(config.model_name)
        self.results = {}
        
    def _load_model(self, model_path: str):
        """Load the trained model"""
        try:
            print(f"🔄 Loading model from {model_path}")
            
            # Initialize model with correct parameters (from training)
            model = FixedHallucinationAwareATCModel(
                model_name=self.config.model_name,
                num_clearance_types=2  # Based on training output
            )
            
            # Load state dict
            state_dict = torch.load(model_path, map_location=self.config.device)
            model.load_state_dict(state_dict)
            model.to(self.config.device)
            model.eval()
            
            print("✅ Model loaded successfully")
            return model
            
        except Exception as e:
            print(f"❌ Error loading model: {e}")
            raise
    
    def evaluate_comprehensive(self, test_scenarios: pd.DataFrame):
        """Run comprehensive evaluation"""
        print("🧪 Starting Comprehensive Model Evaluation")
        print("=" * 60)
        
        # Create test dataset
        test_dataset = ModelTestDataset(test_scenarios, self.tokenizer, self.config.max_sequence_length)
        test_loader = DataLoader(test_dataset, batch_size=self.config.batch_size, shuffle=False)
        
        # Run evaluations
        print("1️⃣ Basic Performance Evaluation...")
        basic_results = self._evaluate_basic_performance(test_loader)
        
        print("2️⃣ Uncertainty Analysis...")
        uncertainty_results = self._evaluate_uncertainty(test_loader)
        
        print("3️⃣ Hallucination Detection Analysis...")
        hallucination_results = self._evaluate_hallucination_detection(test_loader)
        
        print("4️⃣ Edge Case Analysis...")
        edge_case_results = self._evaluate_edge_cases(test_loader, test_scenarios)
        
        print("5️⃣ Calibration Analysis...")
        calibration_results = self._evaluate_calibration(test_loader)
        
        # Combine results
        self.results = {
            'basic_performance': basic_results,
            'uncertainty_analysis': uncertainty_results,
            'hallucination_detection': hallucination_results,
            'edge_case_analysis': edge_case_results,
            'calibration_analysis': calibration_results
        }
        
        print("6️⃣ Generating Comprehensive Report...")
        self._generate_comprehensive_report()
        
        print("✅ Comprehensive evaluation completed!")
        return self.results
    
    def _evaluate_basic_performance(self, test_loader):
        """Evaluate basic model performance"""
        all_conflict_preds = []
        all_conflict_labels = []
        all_clearance_preds = []
        all_clearance_labels = []
        all_scenario_types = []
        all_confidences = []
        
        self.model.eval()
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                
                conflict_logits, clearance_logits, _ = self.model(input_ids, attention_mask)
                
                # Predictions
                conflict_probs = F.softmax(conflict_logits, dim=-1)
                clearance_probs = F.softmax(clearance_logits, dim=-1)
                
                conflict_preds = torch.argmax(conflict_logits, dim=-1)
                clearance_preds = torch.argmax(clearance_logits, dim=-1)
                
                # Collect results
                all_conflict_preds.extend(conflict_preds.cpu().numpy())
                all_conflict_labels.extend(batch['conflict_label'].numpy())
                all_clearance_preds.extend(clearance_preds.cpu().numpy())
                all_clearance_labels.extend(batch['clearance_label'].numpy())
                all_scenario_types.extend(batch['scenario_type'])
                all_confidences.extend(torch.max(conflict_probs, dim=-1)[0].cpu().numpy())
        
        # Calculate metrics
        results = {
            'conflict_accuracy': np.mean(np.array(all_conflict_preds) == np.array(all_conflict_labels)),
            'conflict_f1': f1_score(all_conflict_labels, all_conflict_preds, average='weighted'),
            'clearance_accuracy': np.mean(np.array(all_clearance_preds) == np.array(all_clearance_labels)),
            'clearance_f1': f1_score(all_clearance_labels, all_clearance_preds, average='weighted'),
            'average_confidence': np.mean(all_confidences),
            'predictions': {
                'conflict_preds': all_conflict_preds,
                'conflict_labels': all_conflict_labels,
                'clearance_preds': all_clearance_preds,
                'clearance_labels': all_clearance_labels,
                'scenario_types': all_scenario_types,
                'confidences': all_confidences
            }
        }
        
        print(f"   ✅ Conflict Detection Accuracy: {results['conflict_accuracy']:.3f}")
        print(f"   ✅ Conflict Detection F1: {results['conflict_f1']:.3f}")
        print(f"   ✅ Clearance Prediction Accuracy: {results['clearance_accuracy']:.3f}")
        print(f"   ✅ Average Confidence: {results['average_confidence']:.3f}")
        
        return results
    
    def _evaluate_uncertainty(self, test_loader):
        """Evaluate uncertainty estimation using Monte Carlo dropout"""
        all_uncertainties = []
        all_labels = []
        all_scenario_types = []
        
        for batch in test_loader:
            input_ids = batch['input_ids'].to(self.config.device)
            attention_mask = batch['attention_mask'].to(self.config.device)
            
            # Get uncertainty estimates
            uncertainty_estimates = self.model.get_uncertainty_estimates(
                input_ids, attention_mask, n_samples=self.config.mc_dropout_samples
            )
            
            # Extract uncertainty metrics
            conflict_entropy = uncertainty_estimates['conflict_entropy'].cpu().numpy()
            conflict_std = uncertainty_estimates['conflict_std'].cpu().numpy().mean(axis=1)
            
            all_uncertainties.extend(conflict_entropy + conflict_std)  # Combined uncertainty
            all_labels.extend(batch['conflict_label'].numpy())
            all_scenario_types.extend(batch['scenario_type'])
        
        results = {
            'uncertainties': all_uncertainties,
            'labels': all_labels,
            'scenario_types': all_scenario_types,
            'high_uncertainty_threshold': np.percentile(all_uncertainties, 90),
            'avg_uncertainty_conflict': np.mean([u for u, l in zip(all_uncertainties, all_labels) if l == 1]),
            'avg_uncertainty_safe': np.mean([u for u, l in zip(all_uncertainties, all_labels) if l == 0])
        }
        
        print(f"   ✅ Average Uncertainty (Conflicts): {results['avg_uncertainty_conflict']:.3f}")
        print(f"   ✅ Average Uncertainty (Safe): {results['avg_uncertainty_safe']:.3f}")
        
        return results
    
    def _evaluate_hallucination_detection(self, test_loader):
        """Evaluate hallucination detection capabilities"""
        all_hallucination_preds = []
        all_envelope_violations = []
        all_scenario_types = []
        
        self.model.eval()
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                
                _, _, hallucination_logits = self.model(input_ids, attention_mask)
                hallucination_probs = torch.sigmoid(hallucination_logits.squeeze(-1))
                
                all_hallucination_preds.extend(hallucination_probs.cpu().numpy())
                all_envelope_violations.extend(batch['envelope_violation'].numpy())
                all_scenario_types.extend(batch['scenario_type'])
        
        # Calculate metrics
        binary_preds = np.array(all_hallucination_preds) > 0.5
        results = {
            'hallucination_accuracy': np.mean(binary_preds == np.array(all_envelope_violations)),
            'hallucination_f1': f1_score(all_envelope_violations, binary_preds) if len(set(all_envelope_violations)) > 1 else 0,
            'predictions': all_hallucination_preds,
            'labels': all_envelope_violations,
            'scenario_types': all_scenario_types
        }
        
        print(f"   ✅ Hallucination Detection Accuracy: {results['hallucination_accuracy']:.3f}")
        print(f"   ✅ Hallucination Detection F1: {results['hallucination_f1']:.3f}")
        
        return results
    
    def _evaluate_edge_cases(self, test_loader, test_scenarios):
        """Evaluate performance on edge cases"""
        # Group by scenario type
        scenario_performance = {}
        
        basic_results = self.results.get('basic_performance', {})
        if not basic_results:
            return {}
        
        preds = basic_results['predictions']
        
        for scenario_type in test_scenarios['scenario_type'].unique():
            type_mask = np.array(preds['scenario_types']) == scenario_type
            
            if np.sum(type_mask) > 0:
                type_conflict_labels = np.array(preds['conflict_labels'])[type_mask]
                type_conflict_preds = np.array(preds['conflict_preds'])[type_mask]
                type_confidences = np.array(preds['confidences'])[type_mask]
                
                scenario_performance[scenario_type] = {
                    'count': np.sum(type_mask),
                    'accuracy': np.mean(type_conflict_preds == type_conflict_labels),
                    'avg_confidence': np.mean(type_confidences),
                    'f1_score': f1_score(type_conflict_labels, type_conflict_preds, average='weighted') if len(set(type_conflict_labels)) > 1 else 0
                }
        
        print(f"   ✅ Edge Case Analysis Complete")
        for scenario_type, metrics in scenario_performance.items():
            print(f"      {scenario_type}: Accuracy={metrics['accuracy']:.3f}, Confidence={metrics['avg_confidence']:.3f}")
        
        return scenario_performance
    
    def _evaluate_calibration(self, test_loader):
        """Evaluate model calibration"""
        all_confidences = []
        all_correct = []
        
        self.model.eval()
        with torch.no_grad():
            for batch in test_loader:
                input_ids = batch['input_ids'].to(self.config.device)
                attention_mask = batch['attention_mask'].to(self.config.device)
                
                conflict_logits, _, _ = self.model(input_ids, attention_mask)
                conflict_probs = F.softmax(conflict_logits, dim=-1)
                
                max_probs, preds = torch.max(conflict_probs, dim=-1)
                correct = (preds == batch['conflict_label'].to(self.config.device))
                
                all_confidences.extend(max_probs.cpu().numpy())
                all_correct.extend(correct.cpu().numpy())
        
        # Calculate calibration
        try:
            fraction_of_positives, mean_predicted_value = calibration_curve(
                all_correct, all_confidences, n_bins=10
            )
            calibration_error = np.mean(np.abs(fraction_of_positives - mean_predicted_value))
        except:
            calibration_error = float('nan')
            fraction_of_positives = []
            mean_predicted_value = []
        
        results = {
            'calibration_error': calibration_error,
            'confidences': all_confidences,
            'correct': all_correct,
            'fraction_of_positives': fraction_of_positives,
            'mean_predicted_value': mean_predicted_value
        }
        
        print(f"   ✅ Calibration Error: {calibration_error:.3f}")
        
        return results
    
    def _generate_comprehensive_report(self):
        """Generate comprehensive evaluation report with visualizations"""
        print("📊 Generating comprehensive report...")
        
        # Create visualizations
        self._plot_performance_metrics()
        self._plot_uncertainty_analysis()
        self._plot_edge_case_analysis()
        self._plot_calibration_analysis()
        
        # Generate summary report
        self._generate_summary_report()
        
        print(f"📁 Report saved to: {self.config.output_path}")
    
    def _plot_performance_metrics(self):
        """Plot performance metrics"""
        fig, axes = plt.subplots(2, 2, figsize=(15, 12))
        
        basic_results = self.results['basic_performance']
        preds = basic_results['predictions']
        
        # Confusion matrix for conflict detection
        cm_conflict = confusion_matrix(preds['conflict_labels'], preds['conflict_preds'])
        sns.heatmap(cm_conflict, annot=True, fmt='d', ax=axes[0,0], cmap='Blues')
        axes[0,0].set_title('Conflict Detection Confusion Matrix')
        axes[0,0].set_xlabel('Predicted')
        axes[0,0].set_ylabel('Actual')
        
        # Confidence distribution
        axes[0,1].hist(preds['confidences'], bins=30, alpha=0.7, color='skyblue')
        axes[0,1].set_title('Confidence Distribution')
        axes[0,1].set_xlabel('Confidence')
        axes[0,1].set_ylabel('Frequency')
        
        # Performance by scenario type
        scenario_types = list(set(preds['scenario_types']))
        accuracies = []
        for stype in scenario_types:
            type_mask = np.array(preds['scenario_types']) == stype
            if np.sum(type_mask) > 0:
                type_accuracy = np.mean(
                    np.array(preds['conflict_preds'])[type_mask] == 
                    np.array(preds['conflict_labels'])[type_mask]
                )
                accuracies.append(type_accuracy)
            else:
                accuracies.append(0)
        
        axes[1,0].bar(scenario_types, accuracies, color='lightgreen')
        axes[1,0].set_title('Accuracy by Scenario Type')
        axes[1,0].set_xlabel('Scenario Type')
        axes[1,0].set_ylabel('Accuracy')
        axes[1,0].tick_params(axis='x', rotation=45)
        
        # Confidence vs Accuracy
        conf_bins = np.linspace(0, 1, 11)
        bin_accuracies = []
        for i in range(len(conf_bins)-1):
            mask = (np.array(preds['confidences']) >= conf_bins[i]) & (np.array(preds['confidences']) < conf_bins[i+1])
            if np.sum(mask) > 0:
                bin_acc = np.mean(
                    np.array(preds['conflict_preds'])[mask] == 
                    np.array(preds['conflict_labels'])[mask]
                )
                bin_accuracies.append(bin_acc)
            else:
                bin_accuracies.append(0)
        
        axes[1,1].plot(conf_bins[:-1], bin_accuracies, 'o-', color='red')
        axes[1,1].plot([0, 1], [0, 1], '--', color='gray', alpha=0.5)
        axes[1,1].set_title('Reliability Diagram')
        axes[1,1].set_xlabel('Confidence')
        axes[1,1].set_ylabel('Accuracy')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.config.output_path, 'performance_metrics.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_uncertainty_analysis(self):
        """Plot uncertainty analysis"""
        uncertainty_results = self.results['uncertainty_analysis']
        
        fig, axes = plt.subplots(1, 2, figsize=(15, 6))
        
        # Uncertainty by label
        conflict_uncertainties = [u for u, l in zip(uncertainty_results['uncertainties'], uncertainty_results['labels']) if l == 1]
        safe_uncertainties = [u for u, l in zip(uncertainty_results['uncertainties'], uncertainty_results['labels']) if l == 0]
        
        axes[0].hist(safe_uncertainties, bins=30, alpha=0.7, label='Safe Scenarios', color='green')
        axes[0].hist(conflict_uncertainties, bins=30, alpha=0.7, label='Conflict Scenarios', color='red')
        axes[0].set_title('Uncertainty Distribution by Label')
        axes[0].set_xlabel('Uncertainty')
        axes[0].set_ylabel('Frequency')
        axes[0].legend()
        
        # Uncertainty by scenario type
        scenario_types = list(set(uncertainty_results['scenario_types']))
        scenario_uncertainties = []
        for stype in scenario_types:
            type_uncertainties = [u for u, s in zip(uncertainty_results['uncertainties'], uncertainty_results['scenario_types']) if s == stype]
            if type_uncertainties:
                scenario_uncertainties.append(np.mean(type_uncertainties))
            else:
                scenario_uncertainties.append(0)
        
        axes[1].bar(scenario_types, scenario_uncertainties, color='orange')
        axes[1].set_title('Average Uncertainty by Scenario Type')
        axes[1].set_xlabel('Scenario Type')
        axes[1].set_ylabel('Average Uncertainty')
        axes[1].tick_params(axis='x', rotation=45)
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.config.output_path, 'uncertainty_analysis.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_edge_case_analysis(self):
        """Plot edge case analysis"""
        edge_results = self.results['edge_case_analysis']
        
        scenario_types = list(edge_results.keys())
        accuracies = [edge_results[st]['accuracy'] for st in scenario_types]
        confidences = [edge_results[st]['avg_confidence'] for st in scenario_types]
        counts = [edge_results[st]['count'] for st in scenario_types]
        
        fig, axes = plt.subplots(1, 3, figsize=(18, 6))
        
        # Accuracy by scenario type
        bars1 = axes[0].bar(scenario_types, accuracies, color='lightblue')
        axes[0].set_title('Accuracy by Scenario Type')
        axes[0].set_xlabel('Scenario Type')
        axes[0].set_ylabel('Accuracy')
        axes[0].tick_params(axis='x', rotation=45)
        
        # Add value labels on bars
        for bar, acc in zip(bars1, accuracies):
            height = bar.get_height()
            axes[0].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{acc:.3f}', ha='center', va='bottom')
        
        # Confidence by scenario type
        bars2 = axes[1].bar(scenario_types, confidences, color='lightcoral')
        axes[1].set_title('Average Confidence by Scenario Type')
        axes[1].set_xlabel('Scenario Type')
        axes[1].set_ylabel('Average Confidence')
        axes[1].tick_params(axis='x', rotation=45)
        
        for bar, conf in zip(bars2, confidences):
            height = bar.get_height()
            axes[1].text(bar.get_x() + bar.get_width()/2., height + 0.01,
                        f'{conf:.3f}', ha='center', va='bottom')
        
        # Sample count
        bars3 = axes[2].bar(scenario_types, counts, color='lightgreen')
        axes[2].set_title('Sample Count by Scenario Type')
        axes[2].set_xlabel('Scenario Type')
        axes[2].set_ylabel('Count')
        axes[2].tick_params(axis='x', rotation=45)
        
        for bar, count in zip(bars3, counts):
            height = bar.get_height()
            axes[2].text(bar.get_x() + bar.get_width()/2., height + 1,
                        f'{count}', ha='center', va='bottom')
        
        plt.tight_layout()
        plt.savefig(os.path.join(self.config.output_path, 'edge_case_analysis.png'), dpi=300, bbox_inches='tight')
        plt.close()
    
    def _plot_calibration_analysis(self):
        """Plot calibration analysis"""
        calib_results = self.results['calibration_analysis']
        
        if len(calib_results['fraction_of_positives']) > 0:
            fig, axes = plt.subplots(1, 2, figsize=(12, 6))
            
            # Calibration plot
            axes[0].plot(calib_results['mean_predicted_value'], calib_results['fraction_of_positives'], 'o-', label='Model')
            axes[0].plot([0, 1], [0, 1], '--', color='gray', label='Perfect Calibration')
            axes[0].set_xlabel('Mean Predicted Probability')
            axes[0].set_ylabel('Fraction of Positives')
            axes[0].set_title('Calibration Plot')
            axes[0].legend()
            axes[0].grid(True, alpha=0.3)
            
            # Confidence histogram
            axes[1].hist(calib_results['confidences'], bins=30, alpha=0.7, color='purple')
            axes[1].set_title('Confidence Distribution')
            axes[1].set_xlabel('Confidence')
            axes[1].set_ylabel('Frequency')
            
            plt.tight_layout()
            plt.savefig(os.path.join(self.config.output_path, 'calibration_analysis.png'), dpi=300, bbox_inches='tight')
            plt.close()
    
    def _generate_summary_report(self):
        """Generate summary report"""
        report = []
        report.append("# Comprehensive Model Evaluation Report")
        report.append("=" * 50)
        report.append(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}")
        report.append("")
        
        # Basic Performance
        basic = self.results['basic_performance']
        report.append("## Basic Performance Metrics")
        report.append(f"- Conflict Detection Accuracy: {basic['conflict_accuracy']:.3f}")
        report.append(f"- Conflict Detection F1 Score: {basic['conflict_f1']:.3f}")
        report.append(f"- Clearance Prediction Accuracy: {basic['clearance_accuracy']:.3f}")
        report.append(f"- Clearance Prediction F1 Score: {basic['clearance_f1']:.3f}")
        report.append(f"- Average Model Confidence: {basic['average_confidence']:.3f}")
        report.append("")
        
        # Uncertainty Analysis
        uncertainty = self.results['uncertainty_analysis']
        report.append("## Uncertainty Analysis")
        report.append(f"- Average Uncertainty (Conflicts): {uncertainty['avg_uncertainty_conflict']:.3f}")
        report.append(f"- Average Uncertainty (Safe): {uncertainty['avg_uncertainty_safe']:.3f}")
        report.append(f"- High Uncertainty Threshold (90th percentile): {uncertainty['high_uncertainty_threshold']:.3f}")
        report.append("")
        
        # Hallucination Detection
        hallucination = self.results['hallucination_detection']
        report.append("## Hallucination Detection")
        report.append(f"- Hallucination Detection Accuracy: {hallucination['hallucination_accuracy']:.3f}")
        report.append(f"- Hallucination Detection F1 Score: {hallucination['hallucination_f1']:.3f}")
        report.append("")
        
        # Edge Case Analysis
        edge_cases = self.results['edge_case_analysis']
        report.append("## Edge Case Performance")
        for scenario_type, metrics in edge_cases.items():
            report.append(f"- {scenario_type}:")
            report.append(f"  - Sample Count: {metrics['count']}")
            report.append(f"  - Accuracy: {metrics['accuracy']:.3f}")
            report.append(f"  - Average Confidence: {metrics['avg_confidence']:.3f}")
            report.append(f"  - F1 Score: {metrics['f1_score']:.3f}")
        report.append("")
        
        # Calibration
        calibration = self.results['calibration_analysis']
        report.append("## Model Calibration")
        report.append(f"- Calibration Error: {calibration['calibration_error']:.3f}")
        report.append("")
        
        # Save report
        with open(os.path.join(self.config.output_path, 'evaluation_report.md'), 'w') as f:
            f.write('\n'.join(report))

# BlueSky Integration for Real-time Testing
class BlueSkyModelIntegration:
    """Integration with BlueSky simulator for real-time model testing"""
    
    def __init__(self, model_evaluator: ComprehensiveModelEvaluator):
        self.evaluator = model_evaluator
        self.simulation_log = []
        
    def create_test_scenario(self, aircraft_data: List[Dict]) -> str:
        """Create scenario text from aircraft data"""
        if len(aircraft_data) < 2:
            return ""
        
        ac1, ac2 = aircraft_data[0], aircraft_data[1]
        
        # Calculate separations
        h_dist = self._calculate_horizontal_distance(ac1, ac2)
        v_dist = abs(ac1['altitude'] - ac2['altitude'])
        
        text = (
            f"Aircraft A at FL{int(ac1['altitude']/100):03d} "
            f"heading {int(ac1['heading']):03d}° "
            f"speed {int(ac1['speed']):03d} kt; "
            f"Aircraft B at FL{int(ac2['altitude']/100):03d} "
            f"heading {int(ac2['heading']):03d}° "
            f"speed {int(ac2['speed']):03d} kt; "
            f"horizontal separation {h_dist:.1f} NM; "
            f"vertical separation {v_dist:.0f} ft."
        )
        return text
    
    def _calculate_horizontal_distance(self, ac1: Dict, ac2: Dict) -> float:
        """Calculate horizontal distance between aircraft"""
        # Simplified distance calculation (should use proper haversine for real implementation)
        lat_diff = ac1['latitude'] - ac2['latitude']
        lon_diff = ac1['longitude'] - ac2['longitude']
        return np.sqrt(lat_diff**2 + lon_diff**2) * 60  # Rough conversion to NM
    
    def real_time_conflict_detection(self, aircraft_data: List[Dict]) -> Dict:
        """Perform real-time conflict detection"""
        scenario_text = self.create_test_scenario(aircraft_data)
        
        if not scenario_text:
            return {'error': 'Insufficient aircraft data'}
        
        # Tokenize
        encoding = self.evaluator.tokenizer(
            scenario_text,
            truncation=True,
            padding='max_length',
            max_length=self.evaluator.config.max_sequence_length,
            return_tensors='pt'
        )
        
        input_ids = encoding['input_ids'].to(self.evaluator.config.device)
        attention_mask = encoding['attention_mask'].to(self.evaluator.config.device)
        
        # Get model predictions
        self.evaluator.model.eval()
        with torch.no_grad():
            conflict_logits, clearance_logits, hallucination_logits = self.evaluator.model(
                input_ids, attention_mask
            )
            
            # Get uncertainty estimates
            uncertainty_estimates = self.evaluator.model.get_uncertainty_estimates(
                input_ids, attention_mask, n_samples=10
            )
            
            conflict_probs = F.softmax(conflict_logits, dim=-1)
            clearance_probs = F.softmax(clearance_logits, dim=-1)
            hallucination_prob = torch.sigmoid(hallucination_logits)
            
            result = {
                'conflict_probability': conflict_probs[0, 1].item(),  # Probability of conflict
                'conflict_prediction': bool(conflict_probs[0, 1] > 0.5),
                'confidence': torch.max(conflict_probs[0]).item(),
                'recommended_clearance': int(torch.argmax(clearance_logits[0]).item()),
                'hallucination_risk': hallucination_prob[0].item(),
                'uncertainty': uncertainty_estimates['conflict_entropy'][0].item(),
                'scenario_text': scenario_text,
                'timestamp': datetime.now().isoformat()
            }
            
            # Log the result
            self.simulation_log.append(result)
            
            return result
    
    def simulate_test_flights(self, n_scenarios: int = 20):
        """Simulate test flights for real-time testing"""
        print(f"✈️ Simulating {n_scenarios} test flight scenarios...")
        
        results = []
        
        for i in range(n_scenarios):
            # Generate random aircraft data
            aircraft_data = [
                {
                    'id': f'TEST{i:03d}A',
                    'latitude': 59.0 + np.random.uniform(-2, 2),
                    'longitude': 18.0 + np.random.uniform(-2, 2),
                    'altitude': np.random.uniform(20000, 40000),
                    'speed': np.random.uniform(350, 550),
                    'heading': np.random.uniform(0, 360)
                },
                {
                    'id': f'TEST{i:03d}B',
                    'latitude': 59.0 + np.random.uniform(-2, 2),
                    'longitude': 18.0 + np.random.uniform(-2, 2),
                    'altitude': np.random.uniform(20000, 40000),
                    'speed': np.random.uniform(350, 550),
                    'heading': np.random.uniform(0, 360)
                }
            ]
            
            result = self.real_time_conflict_detection(aircraft_data)
            results.append(result)
            
            if result['conflict_prediction']:
                print(f"   🚨 Scenario {i:03d}: CONFLICT detected (P={result['conflict_probability']:.3f}, U={result['uncertainty']:.3f})")
            else:
                print(f"   ✅ Scenario {i:03d}: SAFE (P={result['conflict_probability']:.3f}, U={result['uncertainty']:.3f})")
        
        # Save simulation log
        import json
        with open(os.path.join(self.evaluator.config.output_path, 'simulation_log.json'), 'w') as f:
            json.dump(self.simulation_log, f, indent=2)
        
        print(f"📁 Simulation log saved to simulation_log.json")
        return results

# Main execution function
def run_comprehensive_model_test():
    """Run the complete model testing pipeline"""
    print("🎓 Starting Comprehensive Model Testing for ML Hallucination Research")
    print("=" * 80)
    
    # Initialize scenario generator
    print("📝 Generating Test Scenarios...")
    scenario_generator = ATCTestScenarioGenerator(test_config)
    
    # Generate different types of scenarios
    normal_scenarios = scenario_generator.generate_normal_scenarios(n_scenarios=100)
    edge_scenarios = scenario_generator.generate_edge_case_scenarios(n_scenarios=50)
    stress_scenarios = scenario_generator.generate_stress_test_scenarios(n_scenarios=30)
    
    # Combine all scenarios
    all_scenarios = pd.concat([normal_scenarios, edge_scenarios, stress_scenarios], ignore_index=True)
    
    print(f"✅ Generated {len(all_scenarios)} test scenarios:")
    print(f"   - Normal scenarios: {len(normal_scenarios)}")
    print(f"   - Edge case scenarios: {len(edge_scenarios)}")
    print(f"   - Stress test scenarios: {len(stress_scenarios)}")
    
    # Initialize model evaluator
    print("🤖 Loading Trained Model...")
    try:
        evaluator = ComprehensiveModelEvaluator(test_config.model_path, test_config)
    except Exception as e:
        print(f"❌ Failed to load model: {e}")
        print("💡 Make sure the model path is correct and the model file exists")
        return
    
    # Run comprehensive evaluation
    print("🧪 Running Comprehensive Evaluation...")
    results = evaluator.evaluate_comprehensive(all_scenarios)
    
    # BlueSky integration testing
    print("✈️ Testing BlueSky Integration...")
    bluesky_integration = BlueSkyModelIntegration(evaluator)
    simulation_results = bluesky_integration.simulate_test_flights(n_scenarios=20)
    
    # Final summary
    print("\n🎯 COMPREHENSIVE TESTING SUMMARY")
    print("=" * 50)
    print(f"✅ Model Performance:")
    print(f"   - Conflict Detection Accuracy: {results['basic_performance']['conflict_accuracy']:.3f}")
    print(f"   - Average Confidence: {results['basic_performance']['average_confidence']:.3f}")
    print(f"   - Hallucination Detection Accuracy: {results['hallucination_detection']['hallucination_accuracy']:.3f}")
    
    print(f"\n📊 Edge Case Analysis:")
    for scenario_type, metrics in results['edge_case_analysis'].items():
        print(f"   - {scenario_type}: Accuracy={metrics['accuracy']:.3f}")
    
    print(f"\n🌡️ Uncertainty Analysis:")
    print(f"   - Avg Uncertainty (Conflicts): {results['uncertainty_analysis']['avg_uncertainty_conflict']:.3f}")
    print(f"   - Avg Uncertainty (Safe): {results['uncertainty_analysis']['avg_uncertainty_safe']:.3f}")
    
    print(f"\n✈️ BlueSky Simulation:")
    conflict_count = sum(1 for r in simulation_results if r['conflict_prediction'])
    print(f"   - {conflict_count}/{len(simulation_results)} scenarios predicted as conflicts")
    print(f"   - Average confidence: {np.mean([r['confidence'] for r in simulation_results]):.3f}")
    
    print(f"\n📁 All results saved to: {test_config.output_path}")
    print("🎓 Comprehensive model testing completed!")
    
    return results, simulation_results

# Execute the comprehensive testing
if __name__ == "__main__":
    results, simulation_results = run_comprehensive_model_test()

[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m356.8/356.8 kB[0m [31m5.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m66.7/66.7 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
[?25h[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.[0m[31m
[0m

2025-06-11 15:00:26.944791: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:477] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
E0000 00:00:1749654027.246912      13 cuda_dnn.cc:8310] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
E0000 00:00:1749654027.331188      13 cuda_blas.cc:1418] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered


🧪 Model Testing Configuration
📁 Model path: /kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt
🖥️  Device: cpu
📊 Output path: ./model_test_results
🎓 Starting Comprehensive Model Testing for ML Hallucination Research
📝 Generating Test Scenarios...
✅ Generated 180 test scenarios:
   - Normal scenarios: 100
   - Edge case scenarios: 50
   - Stress test scenarios: 30
🤖 Loading Trained Model...
🔄 Loading model from /kaggle/input/mt-100625-scat-4/thesis_models/streaming_hallucination_aware_atc_model.pt


config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

✅ Model loaded successfully


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

🧪 Running Comprehensive Evaluation...
🧪 Starting Comprehensive Model Evaluation
📊 Test Dataset: 180 scenarios
   - Normal: 100
   - Edge cases: 80
   - Outside envelope: 66
1️⃣ Basic Performance Evaluation...
   ✅ Conflict Detection Accuracy: 0.750
   ✅ Conflict Detection F1: 0.643
   ✅ Clearance Prediction Accuracy: 0.206
   ✅ Average Confidence: 0.991
2️⃣ Uncertainty Analysis...
   ✅ Average Uncertainty (Conflicts): 0.075
   ✅ Average Uncertainty (Safe): 0.051
3️⃣ Hallucination Detection Analysis...
   ✅ Hallucination Detection Accuracy: 0.633
   ✅ Hallucination Detection F1: 0.000
4️⃣ Edge Case Analysis...
5️⃣ Calibration Analysis...
   ✅ Calibration Error: 0.241
6️⃣ Generating Comprehensive Report...
📊 Generating comprehensive report...
📁 Report saved to: ./model_test_results
✅ Comprehensive evaluation completed!
✈️ Testing BlueSky Integration...
✈️ Simulating 20 test flight scenarios...
   ✅ Scenario 000: SAFE (P=0.004, U=0.031)
   ✅ Scenario 001: SAFE (P=0.004, U=0.031)
   ✅ Scen