In [1]:
# Comprehensive Validation Framework for Drone Localization Agent
import torch
import numpy as np
import cv2
import matplotlib.pyplot as plt
import seaborn as sns
from pathlib import Path
import json
import pandas as pd
from typing import Dict, List, Tuple, Optional
from sklearn.metrics import silhouette_score
from scipy.spatial.distance import cdist
import random

class DroneLocalizationValidator:
    """
    Comprehensive validation framework for similarity-based drone localization
    """
    
    def __init__(self, trainer, validation_crops: int = 10):
        self.trainer = trainer
        self.agent = trainer.agent
        self.env = trainer.env
        self.validation_crops = validation_crops
        
        # Load ground truth data from JSON
        self.ground_truth_data = self._load_ground_truth()
        
        # Validation results
        self.validation_results = {
            'ground_truth_validation': {},  # NEW: Using center_pixel coordinates
            'visual_validation': [],
            'quantitative_metrics': {},
            'human_evaluation': [],
            'synthetic_tests': {},
            'cross_validation': {},
            'transfer_test': {}
        }
        
        print(f"🔍 Validation Framework Initialized")
        print(f"   Validation crops: {validation_crops}")
        print(f"   Ground truth coordinates: {len(self.ground_truth_data)} crops")
    
    def _load_ground_truth(self) -> Dict:
        """Load ground truth center_pixel coordinates from JSON"""
        try:
            metadata_path = self.env.crops_metadata_path
            with open(metadata_path, 'r') as f:
                metadata = json.load(f)
            
            # Create lookup dictionary: filename -> center_pixel
            ground_truth = {}
            for crop in metadata['crops']:
                ground_truth[crop['filename']] = {
                    'center_pixel': crop['center_pixel'],
                    'crop_size': crop.get('crop_size', 256),
                    'crop_index': crop.get('crop_index', 0)
                }
            
            print(f"   ✅ Loaded ground truth for {len(ground_truth)} crops")
            return ground_truth
            
        except Exception as e:
            print(f"   ❌ Could not load ground truth: {e}")
            return {}
    
    def pixel_to_grid(self, pixel_x: int, pixel_y: int) -> Tuple[int, int]:
        """Convert pixel coordinates to grid coordinates"""
        tif_h, tif_w = self.env.tif_image.shape[:2]
        
        grid_x = int((pixel_x / tif_w) * self.env.grid_size)
        grid_y = int((pixel_y / tif_h) * self.env.grid_size)
        
        # Clamp to valid grid range
        grid_x = max(0, min(self.env.grid_size - 1, grid_x))
        grid_y = max(0, min(self.env.grid_size - 1, grid_y))
        
        return grid_x, grid_y
    
    def grid_to_pixel(self, grid_x: int, grid_y: int) -> Tuple[int, int]:
        """Convert grid coordinates to pixel coordinates"""
        tif_h, tif_w = self.env.tif_image.shape[:2]
        
        pixel_x = int((grid_x / self.env.grid_size) * tif_w)
        pixel_y = int((grid_y / self.env.grid_size) * tif_h)
        
        return pixel_x, pixel_y
    
    def calculate_spatial_accuracy(self, true_pixel: Tuple[int, int], 
                                 predicted_grids: List[Tuple[int, int]]) -> Dict:
        """Calculate spatial accuracy metrics"""
        
        # Convert true pixel to grid
        true_grid = self.pixel_to_grid(true_pixel[0], true_pixel[1])
        
        # Calculate distances to all predictions
        distances_grid = []
        distances_pixel = []
        
        for pred_grid in predicted_grids:
            # Grid distance
            grid_dist = np.sqrt((pred_grid[0] - true_grid[0])**2 + (pred_grid[1] - true_grid[1])**2)
            distances_grid.append(grid_dist)
            
            # Convert back to pixel distance for real-world interpretation
            pred_pixel = self.grid_to_pixel(pred_grid[0], pred_grid[1])
            pixel_dist = np.sqrt((pred_pixel[0] - true_pixel[0])**2 + (pred_pixel[1] - true_pixel[1])**2)
            distances_pixel.append(pixel_dist)
        
        # Calculate metrics
        min_distance_grid = min(distances_grid)
        min_distance_pixel = min(distances_pixel)
        best_rank = distances_grid.index(min_distance_grid)  # 0, 1, or 2
        
        return {
            'true_grid': true_grid,
            'predicted_grids': predicted_grids,
            'distances_grid': distances_grid,
            'distances_pixel': distances_pixel,
            'min_distance_grid': min_distance_grid,
            'min_distance_pixel': min_distance_pixel,
            'best_rank': best_rank,
            'top1_accuracy': best_rank == 0,
            'top3_accuracy': True,  # Always true since we check top-3
            'within_1_grid': min_distance_grid <= 1.0,
            'within_2_grid': min_distance_grid <= 2.0,
            'within_50_pixels': min_distance_pixel <= 50,
            'within_100_pixels': min_distance_pixel <= 100
        }
    
    def run_comprehensive_validation(self):
        """Run all validation tests"""
        
        print(f"\n🧪" + "="*60 + "🧪")
        print("    COMPREHENSIVE VALIDATION SUITE")
        print(f"🧪" + "="*60 + "🧪")
        
        # 0. Ground Truth Validation (NEW - MOST IMPORTANT!)
        print(f"\n🎯 0. GROUND TRUTH VALIDATION (Using center_pixel)")
        self.ground_truth_validation()
        
        # 1. Visual Validation
        print(f"\n📸 1. VISUAL VALIDATION")
        self.visual_validation()
        
        # 2. Quantitative Metrics
        print(f"\n📊 2. QUANTITATIVE METRICS")
        self.quantitative_validation()
        
        # 3. Synthetic Test Cases
        print(f"\n🧪 3. SYNTHETIC TEST CASES")
        self.synthetic_validation()
        
        # 4. Cross-Validation
        print(f"\n🔄 4. CROSS-VALIDATION")
        self.cross_validation()
        
        # 5. Clustering Analysis
        print(f"\n🎯 5. CLUSTERING ANALYSIS")
        self.clustering_validation()
        
        # 6. Transfer Test (if multiple TIF files available)
        print(f"\n🚀 6. TRANSFER CAPABILITY TEST")
        self.transfer_validation()
        
        # 7. Generate validation report
        self.generate_validation_report()
        
        return self.validation_results
    
    def ground_truth_validation(self):
        """Validate using ground truth center_pixel coordinates"""
        
        print(f"   🎯 Testing spatial accuracy using ground truth coordinates...")
        
        if not self.ground_truth_data:
            print(f"   ❌ No ground truth data available")
            return
        
        # Test all available crops with ground truth
        all_results = []
        test_crops = list(self.ground_truth_data.keys())[:self.validation_crops]
        
        for crop_filename in test_crops:
            try:
                # Load specific crop
                crop_path = Path("realistic_drone_crops") / crop_filename
                if not crop_path.exists():
                    continue
                
                crop_image = cv2.imread(str(crop_path))
                crop_image = cv2.cvtColor(crop_image, cv2.COLOR_BGR2RGB)
                
                # Set current crop in environment
                self.env.current_crop = crop_image
                self.env.current_metadata = {'filename': crop_filename}
                
                # Get agent predictions
                locations, probabilities, _ = self.agent.select_top3_actions(self.env.tif_image, crop_image)
                
                # Get ground truth
                ground_truth = self.ground_truth_data[crop_filename]
                true_pixel = tuple(ground_truth['center_pixel'])
                
                # Calculate spatial accuracy
                spatial_metrics = self.calculate_spatial_accuracy(true_pixel, locations)
                
                # Also calculate similarity for comparison
                reward, similarities = self.env.calculate_reward(locations, probabilities)
                
                result = {
                    'crop_filename': crop_filename,
                    'true_pixel': true_pixel,
                    'predicted_locations': locations,
                    'probabilities': probabilities,
                    'similarities': similarities,
                    'spatial_metrics': spatial_metrics,
                    'reward': reward
                }
                
                all_results.append(result)
                
            except Exception as e:
                print(f"   ⚠️ Error testing {crop_filename}: {e}")
                continue
        
        if not all_results:
            print(f"   ❌ No successful ground truth tests")
            return
        
        # Aggregate results
        spatial_accuracy_metrics = {
            'total_tested': len(all_results),
            'top1_accuracy': np.mean([r['spatial_metrics']['top1_accuracy'] for r in all_results]),
            'within_1_grid': np.mean([r['spatial_metrics']['within_1_grid'] for r in all_results]),
            'within_2_grid': np.mean([r['spatial_metrics']['within_2_grid'] for r in all_results]),
            'within_50_pixels': np.mean([r['spatial_metrics']['within_50_pixels'] for r in all_results]),
            'within_100_pixels': np.mean([r['spatial_metrics']['within_100_pixels'] for r in all_results]),
            'mean_distance_grid': np.mean([r['spatial_metrics']['min_distance_grid'] for r in all_results]),
            'mean_distance_pixel': np.mean([r['spatial_metrics']['min_distance_pixel'] for r in all_results]),
            'median_distance_pixel': np.median([r['spatial_metrics']['min_distance_pixel'] for r in all_results]),
            'rank_distribution': {
                'rank_0': np.mean([r['spatial_metrics']['best_rank'] == 0 for r in all_results]),
                'rank_1': np.mean([r['spatial_metrics']['best_rank'] == 1 for r in all_results]),
                'rank_2': np.mean([r['spatial_metrics']['best_rank'] == 2 for r in all_results])
            }
        }
        
        self.validation_results['ground_truth_validation'] = {
            'metrics': spatial_accuracy_metrics,
            'detailed_results': all_results
        }
        
        # Print results
        print(f"   📈 Ground Truth Validation Results:")
        print(f"      Crops tested: {spatial_accuracy_metrics['total_tested']}")
        print(f"      Top-1 spatial accuracy: {spatial_accuracy_metrics['top1_accuracy']:.1%}")
        print(f"      Within 1 grid cell: {spatial_accuracy_metrics['within_1_grid']:.1%}")
        print(f"      Within 2 grid cells: {spatial_accuracy_metrics['within_2_grid']:.1%}")
        print(f"      Within 50 pixels: {spatial_accuracy_metrics['within_50_pixels']:.1%}")
        print(f"      Within 100 pixels: {spatial_accuracy_metrics['within_100_pixels']:.1%}")
        print(f"      Mean distance: {spatial_accuracy_metrics['mean_distance_pixel']:.1f} pixels")
        print(f"      Median distance: {spatial_accuracy_metrics['median_distance_pixel']:.1f} pixels")
        
        # Plot spatial accuracy visualization
        self._plot_spatial_accuracy(all_results)
    
    def _plot_spatial_accuracy(self, results: List[Dict]):
        """Plot spatial accuracy visualization"""
        
        fig, axes = plt.subplots(2, 3, figsize=(18, 12))
        
        # 1. Distance distribution
        distances = [r['spatial_metrics']['min_distance_pixel'] for r in results]
        axes[0,0].hist(distances, bins=15, alpha=0.7, color='blue', edgecolor='black')
        axes[0,0].axvline(x=50, color='red', linestyle='--', label='50px threshold')
        axes[0,0].axvline(x=100, color='orange', linestyle='--', label='100px threshold')
        axes[0,0].set_title('Distance to True Location (pixels)')
        axes[0,0].set_xlabel('Distance (pixels)')
        axes[0,0].set_ylabel('Count')
        axes[0,0].legend()
        
        # 2. Rank distribution
        ranks = [r['spatial_metrics']['best_rank'] for r in results]
        rank_counts = [ranks.count(i) for i in range(3)]
        axes[0,1].bar(['Top-1', 'Top-2', 'Top-3'], rank_counts, color=['gold', 'silver', 'bronze'])
        axes[0,1].set_title('Best Prediction Rank Distribution')
        axes[0,1].set_ylabel('Count')
        
        # 3. Spatial scatter plot of errors
        true_positions = np.array([r['true_pixel'] for r in results])
        predicted_positions = []
        
        for r in results:
            best_rank = r['spatial_metrics']['best_rank']
            best_grid = r['predicted_locations'][best_rank]
            best_pixel = self.grid_to_pixel(best_grid[0], best_grid[1])
            predicted_positions.append(best_pixel)
        
        predicted_positions = np.array(predicted_positions)
        
        axes[0,2].scatter(true_positions[:, 0], true_positions[:, 1], 
                         c='red', label='True positions', s=50, alpha=0.7)
        axes[0,2].scatter(predicted_positions[:, 0], predicted_positions[:, 1], 
                         c='blue', label='Best predictions', s=50, alpha=0.7)
        
        # Draw lines connecting true to predicted
        for i in range(len(true_positions)):
            axes[0,2].plot([true_positions[i,0], predicted_positions[i,0]], 
                          [true_positions[i,1], predicted_positions[i,1]], 
                          'gray', alpha=0.3, linewidth=1)
        
        axes[0,2].set_title('True vs Predicted Positions')
        axes[0,2].set_xlabel('X (pixels)')
        axes[0,2].set_ylabel('Y (pixels)')
        axes[0,2].legend()
        axes[0,2].set_aspect('equal')
        
        # 4. Distance vs similarity correlation
        similarities = [max(r['similarities']) for r in results]
        axes[1,0].scatter(distances, similarities, alpha=0.7)
        
        # Calculate correlation
        correlation = np.corrcoef(distances, similarities)[0,1]
        axes[1,0].set_title(f'Distance vs Similarity\n(Correlation: {correlation:.3f})')
        axes[1,0].set_xlabel('Distance to true location (pixels)')
        axes[1,0].set_ylabel('Best similarity score')
        
        # 5. Accuracy by threshold
        thresholds = [25, 50, 75, 100, 150, 200]
        accuracies = []
        
        for threshold in thresholds:
            accuracy = np.mean([d <= threshold for d in distances])
            accuracies.append(accuracy)
        
        axes[1,1].plot(thresholds, accuracies, 'o-', linewidth=2, markersize=8)
        axes[1,1].set_title('Accuracy vs Distance Threshold')
        axes[1,1].set_xlabel('Distance threshold (pixels)')
        axes[1,1].set_ylabel('Accuracy')
        axes[1,1].grid(True, alpha=0.3)
        
        # 6. Individual crop performance
        crop_names = [r['crop_filename'][:20] + '...' if len(r['crop_filename']) > 20 
                     else r['crop_filename'] for r in results[:10]]  # Show first 10
        crop_distances = distances[:10]
        
        colors = ['green' if d <= 50 else 'orange' if d <= 100 else 'red' for d in crop_distances]
        
        bars = axes[1,2].bar(range(len(crop_names)), crop_distances, color=colors, alpha=0.7)
        axes[1,2].set_title('Distance by Crop (First 10)')
        axes[1,2].set_xlabel('Crop')
        axes[1,2].set_ylabel('Distance (pixels)')
        axes[1,2].set_xticks(range(len(crop_names)))
        axes[1,2].set_xticklabels(crop_names, rotation=45, ha='right')
        
        # Add threshold lines
        axes[1,2].axhline(y=50, color='red', linestyle='--', alpha=0.7, label='50px')
        axes[1,2].axhline(y=100, color='orange', linestyle='--', alpha=0.7, label='100px')
        axes[1,2].legend()
        
        plt.tight_layout()
        plt.suptitle('Ground Truth Spatial Accuracy Analysis', fontsize=16, y=1.02)
        plt.show()
        
        return fig
    
    def visual_validation(self):
        """Visual validation - show predictions for human evaluation"""
        
        print(f"   📸 Generating visual validation examples...")
        
        fig, axes = plt.subplots(3, 4, figsize=(20, 15))
        
        for i in range(3):  # Test 3 different crops
            # Reset environment
            tif_image, crop_image, state = self.env.reset()
            
            # Get agent predictions
            locations, probabilities, _ = self.agent.select_top3_actions(tif_image, crop_image)
            
            # Calculate similarities for validation
            reward, similarities = self.env.calculate_reward(locations, probabilities)
            
            # Show original crop
            axes[i, 0].imshow(crop_image)
            axes[i, 0].set_title(f'Original Crop\n{state["metadata"]["filename"]}')
            axes[i, 0].axis('off')
            
            # Show top 3 predictions
            for j, (loc, prob, sim) in enumerate(zip(locations, probabilities, similarities)):
                grid_x, grid_y = loc
                predicted_area = self.env.extract_area_at_grid(grid_x, grid_y)
                
                axes[i, j+1].imshow(predicted_area)
                axes[i, j+1].set_title(f'Pred #{j+1}: Grid({grid_x},{grid_y})\n'
                                     f'Confidence: {prob:.2%}\n'
                                     f'Similarity: {sim:.3f}')
                axes[i, j+1].axis('off')
                
                # Color-code border by quality
                if sim > 0.7:
                    border_color = 'green'
                elif sim > 0.5:
                    border_color = 'orange' 
                else:
                    border_color = 'red'
                
                for spine in axes[i, j+1].spines.values():
                    spine.set_color(border_color)
                    spine.set_linewidth(3)
            
            # Store for analysis
            self.validation_results['visual_validation'].append({
                'crop_file': state["metadata"]["filename"],
                'predictions': locations,
                'probabilities': probabilities,
                'similarities': similarities,
                'reward': reward
            })
        
        plt.tight_layout()
        plt.suptitle('Visual Validation - Agent Predictions\n🟢 Good (>0.7) 🟠 OK (>0.5) 🔴 Poor (<0.5)', 
                     fontsize=16, y=0.98)
        plt.show()
        
        # Analysis
        all_similarities = [item['similarities'] for item in self.validation_results['visual_validation']]
        flat_similarities = [sim for sublist in all_similarities for sim in sublist]
        
        print(f"   📊 Visual Validation Results:")
        print(f"      Average similarity: {np.mean(flat_similarities):.3f}")
        print(f"      Best similarity: {np.max(flat_similarities):.3f}")
        print(f"      Good predictions (>0.7): {sum(1 for s in flat_similarities if s > 0.7)}/{len(flat_similarities)}")
    
    def quantitative_validation(self):
        """Quantitative metrics validation"""
        
        print(f"   📊 Computing quantitative metrics...")
        
        # Test on multiple crops
        all_results = []
        
        for i in range(self.validation_crops):
            tif_image, crop_image, state = self.env.reset()
            locations, probabilities, _ = self.agent.select_top3_actions(tif_image, crop_image)
            reward, similarities = self.env.calculate_reward(locations, probabilities)
            
            result = {
                'crop_id': i,
                'similarities': similarities,
                'probabilities': probabilities,
                'reward': reward,
                'top1_similarity': similarities[0],
                'max_similarity': max(similarities),
                'confidence_accuracy_correlation': np.corrcoef(probabilities, similarities)[0,1]
            }
            all_results.append(result)
        
        # Aggregate metrics
        metrics = {
            'mean_top1_similarity': np.mean([r['top1_similarity'] for r in all_results]),
            'mean_max_similarity': np.mean([r['max_similarity'] for r in all_results]),
            'mean_reward': np.mean([r['reward'] for r in all_results]),
            'std_reward': np.std([r['reward'] for r in all_results]),
            'top1_threshold_70': sum(1 for r in all_results if r['top1_similarity'] > 0.7) / len(all_results),
            'top3_threshold_70': sum(1 for r in all_results if r['max_similarity'] > 0.7) / len(all_results),
            'confidence_correlation': np.mean([r['confidence_accuracy_correlation'] for r in all_results if not np.isnan(r['confidence_accuracy_correlation'])])
        }
        
        self.validation_results['quantitative_metrics'] = metrics
        
        print(f"   📈 Quantitative Results:")
        print(f"      Mean Top-1 Similarity: {metrics['mean_top1_similarity']:.3f}")
        print(f"      Mean Max Similarity: {metrics['mean_max_similarity']:.3f}")
        print(f"      Top-1 Success Rate (>0.7): {metrics['top1_threshold_70']:.1%}")
        print(f"      Top-3 Success Rate (>0.7): {metrics['top3_threshold_70']:.1%}")
        print(f"      Confidence-Accuracy Correlation: {metrics['confidence_correlation']:.3f}")
        
        # Plot metrics distribution
        fig, axes = plt.subplots(2, 2, figsize=(12, 10))
        
        # Similarity distribution
        all_sims = [sim for r in all_results for sim in r['similarities']]
        axes[0,0].hist(all_sims, bins=20, alpha=0.7, color='blue')
        axes[0,0].axvline(x=0.7, color='red', linestyle='--', label='Good threshold')
        axes[0,0].set_title('Similarity Score Distribution')
        axes[0,0].set_xlabel('Similarity Score')
        axes[0,0].legend()
        
        # Reward distribution
        rewards = [r['reward'] for r in all_results]
        axes[0,1].hist(rewards, bins=20, alpha=0.7, color='green')
        axes[0,1].set_title('Reward Distribution')
        axes[0,1].set_xlabel('Reward')
        
        # Confidence vs Similarity scatter
        all_probs = [prob for r in all_results for prob in r['probabilities']]
        axes[1,0].scatter(all_probs, all_sims, alpha=0.6)
        axes[1,0].set_xlabel('Confidence')
        axes[1,0].set_ylabel('Similarity')
        axes[1,0].set_title('Confidence vs Similarity')
        
        # Success rate by ranking
        rank_success = []
        for rank in range(3):
            rank_sims = [r['similarities'][rank] for r in all_results]
            success_rate = sum(1 for s in rank_sims if s > 0.7) / len(rank_sims)
            rank_success.append(success_rate)
        
        axes[1,1].bar(['Top-1', 'Top-2', 'Top-3'], rank_success, color=['gold', 'silver', 'bronze'])
        axes[1,1].set_title('Success Rate by Ranking')
        axes[1,1].set_ylabel('Success Rate (>0.7 similarity)')
        
        plt.tight_layout()
        plt.show()
    
    def synthetic_validation(self):
        """Test with synthetic/known cases"""
        
        print(f"   🧪 Creating synthetic test cases...")
        
        # Test 1: Identity test (crop vs itself)
        identity_results = []
        
        for i in range(5):
            tif_image, crop_image, state = self.env.reset()
            
            # Create a synthetic case: place the crop in a known location
            h, w = tif_image.shape[:2]
            
            # Place crop in center of TIF
            center_x, center_y = w // 2, h // 2
            crop_h, crop_w = crop_image.shape[:2]
            
            # Insert crop into TIF at center
            synthetic_tif = tif_image.copy()
            x1 = center_x - crop_w // 2
            y1 = center_y - crop_h // 2
            x2 = x1 + crop_w
            y2 = y1 + crop_h
            
            if x1 >= 0 and y1 >= 0 and x2 < w and y2 < h:
                synthetic_tif[y1:y2, x1:x2] = crop_image
                
                # Test agent on this synthetic case
                locations, probabilities, _ = self.agent.select_top3_actions(synthetic_tif, crop_image)
                
                # Check if agent finds the center location
                center_grid_x = int((center_x / w) * self.env.grid_size)
                center_grid_y = int((center_y / h) * self.env.grid_size)
                
                # Find closest prediction to center
                distances = []
                for loc in locations:
                    dist = np.sqrt((loc[0] - center_grid_x)**2 + (loc[1] - center_grid_y)**2)
                    distances.append(dist)
                
                min_distance = min(distances)
                best_rank = distances.index(min_distance)
                
                identity_results.append({
                    'min_distance': min_distance,
                    'best_rank': best_rank,
                    'found_exact': min_distance <= 1.0  # Within 1 grid cell
                })
        
        synthetic_metrics = {
            'identity_mean_distance': np.mean([r['min_distance'] for r in identity_results]),
            'identity_success_rate': np.mean([r['found_exact'] for r in identity_results]),
            'identity_top1_rate': np.mean([r['best_rank'] == 0 for r in identity_results])
        }
        
        self.validation_results['synthetic_tests'] = synthetic_metrics
        
        print(f"   🎯 Synthetic Test Results:")
        print(f"      Identity test success rate: {synthetic_metrics['identity_success_rate']:.1%}")
        print(f"      Mean distance to planted crop: {synthetic_metrics['identity_mean_distance']:.2f} grid cells")
        print(f"      Top-1 detection rate: {synthetic_metrics['identity_top1_rate']:.1%}")
    
    def cross_validation(self):
        """Cross-validation across different crop types/altitudes"""
        
        print(f"   🔄 Cross-validation by crop characteristics...")
        
        # Group crops by altitude if available
        altitude_groups = {}
        other_crops = []
        
        for crop_data in self.env.crops_data:
            altitude = crop_data.get('altitude_meters', None)
            if altitude:
                if altitude not in altitude_groups:
                    altitude_groups[altitude] = []
                altitude_groups[altitude].append(crop_data)
            else:
                other_crops.append(crop_data)
        
        print(f"      Found {len(altitude_groups)} altitude groups: {list(altitude_groups.keys())}")
        
        # Test performance by altitude
        altitude_performance = {}
        
        for altitude, crops in altitude_groups.items():
            if len(crops) >= 3:  # Need at least 3 crops for meaningful test
                similarities = []
                
                for crop_data in crops[:5]:  # Test up to 5 crops per altitude
                    # Manually set the crop for testing
                    self.env.current_metadata = crop_data
                    crop_path = Path("realistic_drone_crops") / crop_data['filename']
                    crop_image = cv2.imread(str(crop_path))
                    self.env.current_crop = cv2.cvtColor(crop_image, cv2.COLOR_BGR2RGB)
                    
                    # Get predictions
                    locations, probabilities, _ = self.agent.select_top3_actions(self.env.tif_image, self.env.current_crop)
                    reward, sims = self.env.calculate_reward(locations, probabilities)
                    
                    similarities.extend(sims)
                
                altitude_performance[altitude] = {
                    'mean_similarity': np.mean(similarities),
                    'max_similarity': np.max(similarities),
                    'success_rate': sum(1 for s in similarities if s > 0.7) / len(similarities)
                }
        
        self.validation_results['cross_validation'] = altitude_performance
        
        print(f"   📊 Cross-validation Results:")
        for altitude, perf in altitude_performance.items():
            print(f"      {altitude}m altitude: "
                  f"Mean sim: {perf['mean_similarity']:.3f}, "
                  f"Success: {perf['success_rate']:.1%}")
    
    def clustering_validation(self):
        """Validate that agent groups similar terrains together"""
        
        print(f"   🎯 Analyzing spatial clustering of predictions...")
        
        # Collect predictions for multiple crops
        all_predictions = []
        crop_types = []
        
        for i in range(min(10, len(self.env.crops_data))):
            tif_image, crop_image, state = self.env.reset()
            locations, probabilities, _ = self.agent.select_top3_actions(tif_image, crop_image)
            
            # Store top prediction location
            all_predictions.append(locations[0])
            crop_types.append(state['metadata'].get('altitude_meters', 'unknown'))
        
        # Analyze spatial distribution
        if len(all_predictions) > 3:
            coords = np.array(all_predictions)
            
            # Calculate clustering tendency
            from sklearn.cluster import KMeans
            
            # Try different numbers of clusters
            inertias = []
            k_range = range(2, min(6, len(all_predictions)))
            
            for k in k_range:
                kmeans = KMeans(n_clusters=k, random_state=42, n_init=10)
                kmeans.fit(coords)
                inertias.append(kmeans.inertia_)
            
            # Find elbow point (optimal clusters)
            if len(inertias) > 1:
                diffs = np.diff(inertias)
                optimal_k = np.argmin(diffs) + 2  # +2 because range starts at 2
                
                print(f"      Spatial analysis: {len(all_predictions)} predictions")
                print(f"      Optimal clusters: {optimal_k}")
                print(f"      Spread: {np.std(coords, axis=0)} (grid units)")
    
    def transfer_validation(self):
        """Test transfer capability (placeholder for multiple TIF files)"""
        
        print(f"   🚀 Transfer capability test...")
        print(f"      Note: Requires multiple TIF files for full transfer test")
        print(f"      Current test: Robustness across different regions of same TIF")
        
        # Test on different regions of the same TIF
        regions_tested = 0
        region_performances = []
        
        for region in ['top-left', 'top-right', 'bottom-left', 'bottom-right', 'center']:
            # Modify the environment to focus on different TIF regions
            # This is a simplified transfer test
            performance = self._test_tif_region(region)
            if performance:
                region_performances.append(performance)
                regions_tested += 1
        
        if region_performances:
            transfer_metrics = {
                'regions_tested': regions_tested,
                'mean_region_performance': np.mean(region_performances),
                'region_consistency': 1.0 - np.std(region_performances)  # Higher = more consistent
            }
            
            self.validation_results['transfer_test'] = transfer_metrics
            
            print(f"      Regions tested: {regions_tested}")
            print(f"      Mean performance: {transfer_metrics['mean_region_performance']:.3f}")
            print(f"      Consistency: {transfer_metrics['region_consistency']:.3f}")
    
    def _test_tif_region(self, region: str) -> Optional[float]:
        """Test performance on specific TIF region"""
        # Simplified implementation - would be expanded for full transfer testing
        try:
            # Test a few crops and return average similarity
            similarities = []
            for _ in range(3):
                tif_image, crop_image, state = self.env.reset()
                locations, probabilities, _ = self.agent.select_top3_actions(tif_image, crop_image)
                reward, sims = self.env.calculate_reward(locations, probabilities)
                similarities.extend(sims)
            
            return np.mean(similarities)
        except:
            return None
    
    def generate_validation_report(self):
        """Generate comprehensive validation report"""
        
        print(f"\n📋" + "="*60 + "📋")
        print("    VALIDATION REPORT")
        print(f"📋" + "="*60 + "📋")
        
        # Primary metrics from ground truth validation
        if 'ground_truth_validation' in self.validation_results:
            gt_metrics = self.validation_results['ground_truth_validation']['metrics']
            
            print(f"\n🎯 SPATIAL ACCURACY (GROUND TRUTH):")
            print(f"   Crops tested: {gt_metrics['total_tested']}")
            print(f"   Top-1 Spatial Accuracy: {gt_metrics['top1_accuracy']:.1%}")
            print(f"   Within 1 grid cell (±{self.env.grid_size} pixels): {gt_metrics['within_1_grid']:.1%}")
            print(f"   Within 2 grid cells: {gt_metrics['within_2_grid']:.1%}")
            print(f"   Within 50 pixels: {gt_metrics['within_50_pixels']:.1%}")
            print(f"   Within 100 pixels: {gt_metrics['within_100_pixels']:.1%}")
            print(f"   Mean distance error: {gt_metrics['mean_distance_pixel']:.1f} pixels")
            print(f"   Median distance error: {gt_metrics['median_distance_pixel']:.1f} pixels")
            
            # Primary performance grade based on spatial accuracy
            spatial_accuracy_50px = gt_metrics['within_50_pixels']
            
            if spatial_accuracy_50px >= 0.8:
                grade = "🏆 EXCELLENT"
                interpretation = "Agent accurately localizes crops"
            elif spatial_accuracy_50px >= 0.6:
                grade = "🥇 GOOD"
                interpretation = "Agent shows good spatial understanding"
            elif spatial_accuracy_50px >= 0.4:
                grade = "🥈 FAIR"
                interpretation = "Agent has basic localization ability"
            else:
                grade = "🥉 NEEDS IMPROVEMENT"
                interpretation = "Agent needs more training"
            
            print(f"\n📊 SPATIAL ACCURACY GRADE: {grade}")
            print(f"   {interpretation}")
        
        # Secondary metrics from similarity validation
        if 'quantitative_metrics' in self.validation_results:
            sim_metrics = self.validation_results['quantitative_metrics']
            
            print(f"\n🎨 SIMILARITY PERFORMANCE:")
            print(f"   Mean Similarity Score: {sim_metrics['mean_top1_similarity']:.3f}")
            print(f"   Similarity Success Rate (>0.7): {sim_metrics['top1_threshold_70']:.1%}")
            print(f"   Confidence Calibration: {sim_metrics['confidence_correlation']:.3f}")
        
        # Performance insights
        print(f"\n💡 VALIDATION INSIGHTS:")
        
        if 'ground_truth_validation' in self.validation_results:
            gt_metrics = self.validation_results['ground_truth_validation']['metrics']
            
            # Spatial accuracy insights
            mean_distance = gt_metrics['mean_distance_pixel']
            median_distance = gt_metrics['median_distance_pixel']
            
            if mean_distance <= 50:
                print(f"   ✅ Excellent spatial precision (mean error: {mean_distance:.1f}px)")
            elif mean_distance <= 100:
                print(f"   ⚠️ Good spatial precision (mean error: {mean_distance:.1f}px)")
            else:
                print(f"   ❌ Poor spatial precision (mean error: {mean_distance:.1f}px)")
            
            # Consistency check
            if abs(mean_distance - median_distance) < 20:
                print(f"   ✅ Consistent performance across crops")
            else:
                print(f"   ⚠️ Inconsistent performance (some crops much harder)")
            
            # Ranking analysis
            rank_dist = gt_metrics['rank_distribution']
            if rank_dist['rank_0'] > 0.6:
                print(f"   ✅ Agent confidently picks best locations first")
            else:
                print(f"   ⚠️ Agent ranking could be improved")
        
        # Cross-validation insights
        if 'cross_validation' in self.validation_results:
            cv_results = self.validation_results['cross_validation']
            if cv_results:
                print(f"   📊 Performance varies by altitude/crop type")
                best_altitude = max(cv_results.keys(), key=lambda k: cv_results[k]['mean_similarity'])
                worst_altitude = min(cv_results.keys(), key=lambda k: cv_results[k]['mean_similarity'])
                print(f"      Best: {best_altitude}m altitude")
                print(f"      Challenging: {worst_altitude}m altitude")
        
        # Synthetic test insights
        if 'synthetic_tests' in self.validation_results:
            synthetic = self.validation_results['synthetic_tests']
            if synthetic['identity_success_rate'] > 0.8:
                print(f"   ✅ Strong identity detection (perfect match capability)")
            else:
                print(f"   ⚠️ Identity detection needs improvement")
        
        print(f"\n🔄 VALIDATION SUMMARY:")
        
        if 'ground_truth_validation' in self.validation_results:
            gt_metrics = self.validation_results['ground_truth_validation']['metrics']
            
            print(f"   🎯 SPATIAL LOCALIZATION:")
            print(f"      • Can find location within 50px: {gt_metrics['within_50_pixels']:.0%} of the time")
            print(f"      • Average error: {gt_metrics['mean_distance_pixel']:.0f} pixels")
            print(f"      • Best prediction is correct: {gt_metrics['top1_accuracy']:.0%} of the time")
            
            # Real-world interpretation
            # Assuming 10km TIF is about 1000 pixels wide, each pixel ≈ 10 meters
            tif_width = self.env.tif_image.shape[1]
            approx_meters_per_pixel = 10000 / tif_width  # 10km / width
            error_meters = gt_metrics['mean_distance_pixel'] * approx_meters_per_pixel
            
            print(f"   🌍 REAL-WORLD INTERPRETATION:")
            print(f"      • Average localization error: ~{error_meters:.0f} meters")
            print(f"      • Grid cell size: ~{approx_meters_per_pixel * (tif_width / self.env.grid_size):.0f} meters")
        
        print(f"\n📈 NEXT STEPS:")
        
        if 'ground_truth_validation' in self.validation_results:
            gt_metrics = self.validation_results['ground_truth_validation']['metrics']
            
            if gt_metrics['within_50_pixels'] < 0.5:
                print(f"   • Train for more episodes to improve spatial accuracy")
                print(f"   • Consider smaller grid size for finer localization")
                print(f"   • Review reward function balance")
            
            if gt_metrics['top1_accuracy'] < 0.4:
                print(f"   • Improve confidence calibration")
                print(f"   • Adjust ranking rewards")
            
            if gt_metrics['mean_distance_pixel'] > 200:
                print(f"   • Consider curriculum learning (easy crops first)")
                print(f"   • Increase training data diversity")
        
        print(f"   • Test on additional TIF files for generalization")
        print(f"   • Collect expert human evaluations")
        print(f"   • Consider ensemble methods for improved accuracy")
        
        # Save comprehensive validation results
        validation_file = "comprehensive_validation_results.json"
        with open(validation_file, 'w') as f:
            # Convert numpy types for JSON serialization
            json_results = {}
            for key, value in self.validation_results.items():
                if key == 'ground_truth_validation' and 'detailed_results' in value:
                    # Skip detailed results for JSON (too large), keep metrics
                    json_results[key] = {'metrics': value['metrics']}
                elif isinstance(value, dict):
                    json_results[key] = {k: float(v) if isinstance(v, np.number) else v 
                                       for k, v in value.items()}
                else:
                    json_results[key] = value
            
            json.dump(json_results, f, indent=2, default=str)
        
        print(f"\n💾 Comprehensive validation results saved to: {validation_file}")
        
        return grade, gt_metrics if 'ground_truth_validation' in self.validation_results else None

def validate_drone_agent(trainer):
    """
    Main validation function
    """
    
    print(f"🔍 Starting comprehensive validation...")
    
    validator = DroneLocalizationValidator(trainer, validation_crops=15)
    results = validator.run_comprehensive_validation()
    
    return validator, results

if __name__ == "__main__":
    print("🔍 DRONE LOCALIZATION VALIDATION FRAMEWORK")
    print("="*50)
    print()
    print("🎯 VALIDATION METHODS:")
    print("   🥇 Ground truth validation (using center_pixel coordinates)")
    print("   • Visual validation (human-interpretable)")
    print("   • Quantitative metrics (success rates, correlations)")
    print("   • Synthetic test cases (known ground truth)")
    print("   • Cross-validation (by altitude/crop type)")
    print("   • Clustering analysis (spatial coherence)")
    print("   • Transfer testing (generalization)")
    print()
    print("📊 GROUND TRUTH METRICS:")
    print("   • Spatial accuracy (distance to true location)")
    print("   • Top-1/Top-3 localization success")
    print("   • Real-world error in meters")
    print("   • Ranking quality analysis")
    print()
    print("📊 TO RUN:")
    print("   # After training your agent:")
    print("   validator, results = validate_drone_agent(trainer)")
    print()
    print("🏆 Provides definitive spatial accuracy assessment!")

🔍 DRONE LOCALIZATION VALIDATION FRAMEWORK

🎯 VALIDATION METHODS:
   🥇 Ground truth validation (using center_pixel coordinates)
   • Visual validation (human-interpretable)
   • Quantitative metrics (success rates, correlations)
   • Synthetic test cases (known ground truth)
   • Cross-validation (by altitude/crop type)
   • Clustering analysis (spatial coherence)
   • Transfer testing (generalization)

📊 GROUND TRUTH METRICS:
   • Spatial accuracy (distance to true location)
   • Top-1/Top-3 localization success
   • Real-world error in meters
   • Ranking quality analysis

📊 TO RUN:
   # After training your agent:
   validator, results = validate_drone_agent(trainer)

🏆 Provides definitive spatial accuracy assessment!


In [2]:
validator, results = validate_drone_agent(trainer)

NameError: name 'trainer' is not defined