# Data Exploration: Global Wheat Dataset Analysis

**CBAM-STN-TPS-YOLO: Enhancing Agricultural Object Detection - GlobalWheat Focus**

**Authors:** Satvik Praveen, Yoonsung Jung  
**Institution:** Texas A&M University  
**Course:** Computer Vision and Deep Learning  
**Date:** April 2025

## Overview

This notebook explores the Global Wheat Dataset for agricultural object detection. We analyze wheat head counts per image, density variation, bounding-box sizes, field and lighting conditions, and spatial distribution, with the goal of informing CBAM-STN-TPS-YOLO training for wheat head detection.

## Key Objectives
1. Load and analyze Global Wheat dataset structure and composition
2. Examine wheat head distribution patterns and density variations
3. Analyze bounding box characteristics specific to wheat heads
4. Explore field condition variations and environmental factors
5. Assess wheat head clustering and overlapping patterns
6. Evaluate dataset quality and identify wheat-specific challenges
7. Generate comprehensive visualizations and wheat-focused summary reports
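
Objective 5 concerns overlapping wheat heads. One simple way to quantify overlap within an image is the pairwise IoU between its boxes. The sketch below assumes labels in the normalized `[class_id, x_center, y_center, width, height]` layout used throughout this notebook; `yolo_to_xyxy`, `pairwise_iou`, `overlap_fraction`, and the 0.1 IoU threshold are hypothetical illustrations, not part of the project's `src` package.

```python
import numpy as np


def yolo_to_xyxy(targets: np.ndarray) -> np.ndarray:
    """Convert rows of [class_id, xc, yc, w, h] (normalized) to [x1, y1, x2, y2]."""
    xc, yc, w, h = targets[:, 1], targets[:, 2], targets[:, 3], targets[:, 4]
    return np.stack([xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2], axis=1)


def pairwise_iou(boxes: np.ndarray) -> np.ndarray:
    """IoU matrix for N boxes in [x1, y1, x2, y2] form (diagonal is 1)."""
    x1 = np.maximum(boxes[:, None, 0], boxes[None, :, 0])
    y1 = np.maximum(boxes[:, None, 1], boxes[None, :, 1])
    x2 = np.minimum(boxes[:, None, 2], boxes[None, :, 2])
    y2 = np.minimum(boxes[:, None, 3], boxes[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = areas[:, None] + areas[None, :] - inter
    return inter / np.maximum(union, 1e-12)


def overlap_fraction(targets: np.ndarray, iou_threshold: float = 0.1) -> float:
    """Fraction of heads whose box overlaps at least one other box above the threshold."""
    if len(targets) < 2:
        return 0.0
    iou = pairwise_iou(yolo_to_xyxy(targets))
    np.fill_diagonal(iou, 0.0)  # ignore each box's overlap with itself
    return float(np.mean(iou.max(axis=1) > iou_threshold))


# Two overlapping heads and one isolated head
example = np.array([
    [0, 0.50, 0.50, 0.04, 0.04],
    [0, 0.52, 0.50, 0.04, 0.04],
    [0, 0.10, 0.10, 0.04, 0.04],
])
print(overlap_fraction(example))
```

Here the first two boxes share a third of their union (IoU of 1/3), so two of the three heads count as overlapping.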

## 1. Setup and Imports

In [None]:
"""
Enhanced setup and imports for CBAM-STN-TPS-YOLO Global Wheat Dataset Analysis
Comprehensive environment setup with error handling and device optimization
"""

# Standard imports
import os
import sys
import warnings
import logging
from pathlib import Path
from typing import Dict, List, Any, Optional, Tuple
from datetime import datetime
from collections import defaultdict, Counter

# Scientific computing
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from tqdm.auto import tqdm

# Computer vision and ML
import cv2
from scipy.spatial.distance import pdist, squareform
from sklearn.cluster import DBSCAN
from sklearn.preprocessing import StandardScaler

# PyTorch ecosystem
import torch
import torch.nn as nn
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

# Data handling
import json
from PIL import Image

# Project imports with error handling
try:
    # Core model components
    sys.path.append('..')
    from src.data.dataset import GlobalWheatDataset, create_agricultural_dataloader
    from src.utils.visualization import Visualizer, plot_training_curves, visualize_predictions
    from src.utils.evaluation import ModelEvaluator, calculate_model_complexity
    from src.utils.config_validator import load_and_validate_config
    
    print("✅ All project imports successful")
    
except ImportError as e:
    print(f"⚠️ Some project imports failed: {e}")
    print("Creating fallback implementations...")
    
    # Fallback implementations for missing modules
    class DummyVisualizer:
        def __init__(self):
            pass
        def plot_training_curves(self, *args, **kwargs):
            pass
    
    Visualizer = DummyVisualizer
    
    # Create dummy dataset class if not available
    if 'GlobalWheatDataset' not in locals():
        class GlobalWheatDataset:
            def __init__(self, data_dir, split='train'):
                self.data_dir = data_dir
                self.split = split
                self.class_names = ['wheat_head']
                self._create_dummy_data()
            
            def _create_dummy_data(self):
                # Create realistic wheat dataset simulation
                np.random.seed(42)
                self.length = 200 if self.split == 'train' else 50
            
            def __len__(self):
                return self.length
            
            def __getitem__(self, idx):
                # Generate realistic wheat field image
                image = torch.clamp(torch.randn(3, 1024, 1024) * 0.3 + 0.5, 0, 1)  # Noise centered at 0.5, clamped to the valid [0, 1] range
                
                # Generate wheat heads with field-like distribution
                num_heads = np.random.poisson(18)  # Average 18 heads per image
                num_heads = max(1, min(num_heads, 40))  # Constrain range
                
                targets = []
                for _ in range(num_heads):
                    # Create clusters of wheat heads (more realistic)
                    if np.random.random() < 0.7 and targets:  # 70% chance of clustering
                        # Place near existing head
                        existing_head = targets[np.random.randint(len(targets))]
                        x = existing_head[1] + np.random.normal(0, 0.1)
                        y = existing_head[2] + np.random.normal(0, 0.1)
                    else:
                        # Random placement
                        x = np.random.uniform(0.1, 0.9)
                        y = np.random.uniform(0.1, 0.9)
                    
                    # Wheat head sizes (small objects)
                    w = np.random.uniform(0.015, 0.06)  # Typical wheat head width
                    h = np.random.uniform(0.015, 0.06)  # Typical wheat head height
                    
                    # Ensure within bounds
                    x = np.clip(x, w/2, 1-w/2)
                    y = np.clip(y, h/2, 1-h/2)
                    
                    targets.append([0, x, y, w, h])  # class_id, x_center, y_center, width, height
                
                targets = torch.tensor(targets, dtype=torch.float32)
                path = f"dummy_wheat_{self.split}_{idx:06d}.jpg"
                
                return image, targets, path

# Setup logging for notebooks
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Enhanced plotting configuration
plt.style.use('seaborn-v0_8-darkgrid')
sns.set_palette("husl")
plt.rcParams.update({
    'figure.figsize': (12, 8),
    'font.size': 12,
    'axes.titlesize': 14,
    'axes.labelsize': 12,
    'xtick.labelsize': 10,
    'ytick.labelsize': 10,
    'legend.fontsize': 11,
    'figure.titlesize': 16,
    'savefig.dpi': 300,
    'savefig.bbox': 'tight'
})

# Device configuration with automatic detection
def setup_device():
    """Setup optimal device configuration with detailed info"""
    if torch.cuda.is_available():
        device = torch.device('cuda')
        gpu_name = torch.cuda.get_device_name()
        gpu_memory = torch.cuda.get_device_properties(0).total_memory / 1e9
        print(f"✅ CUDA available: {gpu_name}")
        print(f"   Memory: {gpu_memory:.1f} GB")
        
        # Release cached allocator memory left over from earlier runs
        torch.cuda.empty_cache()
        
    elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
        device = torch.device('mps')
        print("✅ MPS (Apple Silicon) available")
    else:
        device = torch.device('cpu')
        print("⚠️ Using CPU - analysis will be slower")
    
    return device

device = setup_device()

# Set random seeds for reproducibility
def set_seed(seed=42):
    """Seed Python, NumPy, and PyTorch RNGs for reproducible results"""
    import random
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
        # Trade cuDNN autotuning speed for deterministic kernels
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
    print(f"🎯 Random seed set to {seed}")

set_seed(42)

# Enhanced results directory setup
notebook_results_dir = Path('../results/notebooks/globalwheat_exploration')
notebook_results_dir.mkdir(parents=True, exist_ok=True)

# Create subdirectories for organized results
subdirs = ['visualizations', 'data_analysis', 'statistics', 'samples']
for subdir in subdirs:
    (notebook_results_dir / subdir).mkdir(exist_ok=True)

# Suppress warnings for cleaner output
warnings.filterwarnings('ignore', category=UserWarning)
warnings.filterwarnings('ignore', category=FutureWarning)

print("🚀 Enhanced environment setup complete!")
print(f"📁 Results directory: {notebook_results_dir}")
print(f"🔧 Device: {device}")
print(f"📊 Plotting backend: {plt.get_backend()}")
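
`set_seed` fixes the global NumPy and PyTorch RNGs, and the synthetic dataset in Section 2 reseeds the global NumPy RNG on every `__getitem__` call. An alternative, sketched here assuming NumPy 1.17 or later for `np.random.default_rng`, is to derive a local generator from `base_seed + idx`: each sample stays reproducible regardless of access order, and the global RNG state is left untouched. `sample_wheat_head_count` is a hypothetical helper for illustration.

```python
import numpy as np


def sample_wheat_head_count(idx: int, base_seed: int = 42, mean_heads: float = 18.0) -> int:
    """Draw a per-image wheat head count from a generator local to this index."""
    rng = np.random.default_rng(base_seed + idx)  # independent of the global np.random state
    return max(1, int(rng.poisson(mean_heads)))


# Same index gives the same draw, whatever order the indices are visited in
forward = [sample_wheat_head_count(i) for i in range(5)]
backward = [sample_wheat_head_count(i) for i in reversed(range(5))][::-1]
print(forward == backward)

# The global RNG is untouched: a draw after calling the helper matches a fresh draw
np.random.seed(0)
a = np.random.rand()
np.random.seed(0)
sample_wheat_head_count(3)
b = np.random.rand()
print(a == b)
```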

## 2. Dataset Loading and Overview
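
`validate_dataset_structure` below expects parallel `images/<split>` and `labels/<split>` directories, and the dataset returns targets as rows of `[class_id, x_center, y_center, width, height]`. Assuming the label files follow the usual YOLO text convention (one box per line, coordinates normalized to [0, 1]), a minimal parser and a conversion to pixel corners could look like this. `parse_yolo_label` and `to_pixel_xyxy` are hypothetical helpers, and the 1024 px default is the image size used by this notebook's synthetic data.

```python
import numpy as np


def parse_yolo_label(text: str) -> np.ndarray:
    """Parse YOLO label text into an (N, 5) float array of [cls, xc, yc, w, h]."""
    rows = []
    for line_no, line in enumerate(text.splitlines(), start=1):
        parts = line.split()
        if not parts:
            continue  # skip blank lines
        if len(parts) != 5:
            raise ValueError(f"line {line_no}: expected 5 fields, got {len(parts)}")
        rows.append([float(p) for p in parts])
    return np.array(rows, dtype=np.float32).reshape(-1, 5)  # (0, 5) for an empty file


def to_pixel_xyxy(targets: np.ndarray, img_w: int = 1024, img_h: int = 1024) -> np.ndarray:
    """Convert normalized center-format boxes to pixel [x1, y1, x2, y2]."""
    xc, yc = targets[:, 1] * img_w, targets[:, 2] * img_h
    w, h = targets[:, 3] * img_w, targets[:, 4] * img_h
    return np.stack([xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2], axis=1)


label_text = "0 0.5 0.5 0.05 0.04\n0 0.25 0.75 0.03 0.03\n"
targets = parse_yolo_label(label_text)
print(targets.shape)  # (2, 5)
print(to_pixel_xyxy(targets)[0])
```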

In [None]:
"""
Enhanced dataset loading with comprehensive validation and error handling
"""

def validate_dataset_structure(dataset_path):
    """Validate Global Wheat dataset structure"""
    dataset_path = Path(dataset_path)
    
    required_structure = {
        'images': ['train', 'val', 'test'],
        'labels': ['train', 'val', 'test']
    }
    
    validation_results = {
        'valid': True,
        'issues': [],
        'statistics': {}
    }
    
    for main_dir, subdirs in required_structure.items():
        main_path = dataset_path / main_dir
        if not main_path.exists():
            validation_results['valid'] = False
            validation_results['issues'].append(f"Missing {main_dir} directory")
            continue
            
        for subdir in subdirs:
            sub_path = main_path / subdir
            if sub_path.exists():
                files = list(sub_path.glob('*'))
                validation_results['statistics'][f'{main_dir}_{subdir}'] = len(files)
            else:
                validation_results['issues'].append(f"Missing {main_dir}/{subdir} directory")
    
    return validation_results

# Load Global Wheat datasets with enhanced error handling
datasets = {}
dataset_stats = {}
dataset_paths = {
    'GlobalWheat': '../data/GlobalWheat',
    'GlobalWheat2020': '../data/GlobalWheat2020',  # Alternative path
    'wheat_data': '../data/wheat'  # Another common naming
}

print("🌾 Loading Global Wheat datasets...")

dataset_loaded = False
for dataset_name, dataset_path in dataset_paths.items():
    try:
        if Path(dataset_path).exists():
            print(f"📁 Found dataset at: {dataset_path}")
            
            # Validate dataset structure
            validation = validate_dataset_structure(dataset_path)
            if not validation['valid']:
                print(f"⚠️ Dataset structure issues: {validation['issues']}")
                continue
            
            # Load dataset splits
            try:
                train_ds = GlobalWheatDataset(dataset_path, split='train')
                val_ds = GlobalWheatDataset(dataset_path, split='val')
                
                print(f"✅ {dataset_name} loaded successfully:")
                print(f"   Train: {len(train_ds)} images")
                print(f"   Val: {len(val_ds)} images")
                
                # Test data loading before registering the splits, so a failed
                # candidate path leaves no stale entries in `datasets`
                try:
                    sample_image, sample_targets, sample_path = train_ds[0]
                    print(f"   Sample image shape: {sample_image.shape}")
                    print(f"   Sample targets: {len(sample_targets)} wheat heads")
                except Exception as e:
                    print(f"⚠️ Data loading test failed: {e}")
                    continue
                
                datasets[f'{dataset_name}_train'] = train_ds
                datasets[f'{dataset_name}_val'] = val_ds
                dataset_loaded = True
                break
                    
            except Exception as e:
                print(f"⚠️ Failed to load {dataset_name}: {e}")
                continue
                
    except Exception as e:
        print(f"⚠️ Error checking {dataset_path}: {e}")
        continue

# Fallback to enhanced dummy dataset if no real data found
if not dataset_loaded:
    print("📝 No real Global Wheat dataset found. Creating enhanced dummy dataset...")
    
    class EnhancedDummyWheatDataset:
        def __init__(self, split='train'):
            self.split = split
            self.class_names = ['wheat_head']
            self.split_sizes = {'train': 180, 'val': 45, 'test': 30}
            self.length = self.split_sizes.get(split, 100)
            
            # Create realistic wheat distribution parameters
            self.wheat_params = {
                'avg_heads_per_image': 18,
                'head_size_range': (0.015, 0.06),
                'clustering_probability': 0.7,
                'field_brightness_range': (0.3, 0.8),
                'lighting_variations': ['uniform', 'shadows', 'overexposed']
            }
            
            print(f"📊 Created {split} split with {self.length} synthetic wheat field images")
        
        def __len__(self):
            return self.length
        
        def __getitem__(self, idx):
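            # Per-index seeding keeps each sample reproducible regardless of access
            # order (note: this resets the global NumPy and torch RNG states)
            np.random.seed(42 + idx)
            torch.manual_seed(42 + idx)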
            
            # Generate realistic field conditions
            brightness_base = np.random.uniform(*self.wheat_params['field_brightness_range'])
            lighting_condition = np.random.choice(self.wheat_params['lighting_variations'])
            
            # Create base field image with realistic characteristics
            if lighting_condition == 'shadows':
                # Create shadow patterns
                image = self._create_shadow_field(brightness_base)
            elif lighting_condition == 'overexposed':
                # Create overexposed areas
                image = self._create_bright_field(brightness_base)
            else:
                # Uniform lighting
                image = self._create_uniform_field(brightness_base)
            
            # Generate wheat heads with realistic clustering
            num_heads = max(1, np.random.poisson(self.wheat_params['avg_heads_per_image']))
            targets = self._generate_wheat_heads(num_heads)
            
            path = f"synthetic_wheat_{self.split}_{idx:06d}_{lighting_condition}.jpg"
            
            return image, targets, path
        
        def _create_uniform_field(self, brightness):
            """Create uniform lighting field"""
            image = torch.normal(brightness, 0.1, (3, 1024, 1024))
            # Add field texture
            noise = torch.randn(3, 1024, 1024) * 0.05
            return torch.clamp(image + noise, 0, 1)
        
        def _create_shadow_field(self, brightness):
            """Create field with shadow patterns"""
            image = torch.full((3, 1024, 1024), brightness)
            
            # Add diagonal shadow patterns
            x, y = torch.meshgrid(torch.linspace(0, 1, 1024), torch.linspace(0, 1, 1024), indexing='ij')
            shadow_pattern = torch.sin(x * 10) * torch.cos(y * 8) * 0.2
            
            for c in range(3):
                image[c] = torch.clamp(image[c] + shadow_pattern, 0.2, 1.0)
            
            return image
        
        def _create_bright_field(self, brightness):
            """Create field with bright spots (overexposure)"""
            image = torch.full((3, 1024, 1024), brightness)
            
            # Add bright circular regions
            x, y = torch.meshgrid(torch.linspace(-1, 1, 1024), torch.linspace(-1, 1, 1024), indexing='ij')
            bright_spots = torch.exp(-(x**2 + y**2) * 3) * 0.3
            
            for c in range(3):
                image[c] = torch.clamp(image[c] + bright_spots, 0, 1.0)
            
            return image
        
        def _generate_wheat_heads(self, num_heads):
            """Generate realistic wheat head distributions"""
            targets = []
            cluster_centers = []
            
            # Create 2-4 cluster centers
            num_clusters = np.random.randint(2, 5)
            for _ in range(num_clusters):
                center_x = np.random.uniform(0.2, 0.8)
                center_y = np.random.uniform(0.2, 0.8)
                cluster_centers.append((center_x, center_y))
            
            for i in range(num_heads):
                if np.random.random() < self.wheat_params['clustering_probability'] and cluster_centers:
                    # Place near a cluster center
                    center_x, center_y = cluster_centers[np.random.randint(len(cluster_centers))]
                    x = center_x + np.random.normal(0, 0.1)
                    y = center_y + np.random.normal(0, 0.1)
                else:
                    # Random placement
                    x = np.random.uniform(0.1, 0.9)
                    y = np.random.uniform(0.1, 0.9)
                
                # Perspective effect: in normalized image coordinates y = 0 is the top edge,
                # so heads shrink toward the top (factor 0.7 at y = 0, 1.0 at y = 1)
                size_factor = 0.7 + 0.3 * y
                w = np.random.uniform(*self.wheat_params['head_size_range']) * size_factor
                h = np.random.uniform(*self.wheat_params['head_size_range']) * size_factor
                
                # Ensure bounds
                x = np.clip(x, w/2, 1-w/2)
                y = np.clip(y, h/2, 1-h/2)
                
                targets.append([0, x, y, w, h])
            
            return torch.tensor(targets, dtype=torch.float32)
    
    # Create enhanced dummy datasets
    datasets['GlobalWheat_train'] = EnhancedDummyWheatDataset('train')
    datasets['GlobalWheat_val'] = EnhancedDummyWheatDataset('val')

# Calculate comprehensive dataset statistics
for name, dataset in datasets.items():
    base_name, dataset_type = name.rsplit('_', 1)  # e.g. 'GlobalWheat_train' -> ('GlobalWheat', 'train')
    
    if base_name not in dataset_stats:
        dataset_stats[base_name] = {
            'splits': {},
            'total_images': 0,
            'classes': getattr(dataset, 'class_names', ['wheat_head']),
            'num_classes': len(getattr(dataset, 'class_names', ['wheat_head']))
        }
    
    dataset_stats[base_name]['splits'][dataset_type] = len(dataset)
    dataset_stats[base_name]['total_images'] += len(dataset)

# Display enhanced dataset summary
print("\n" + "="*60)
print("🌾 GLOBAL WHEAT DATASET ANALYSIS OVERVIEW")
print("="*60)

for dataset_name, stats in dataset_stats.items():
    print(f"\n📊 Dataset: {dataset_name}")
    print(f"   🎯 Domain: Wheat head detection in agricultural fields")
    print(f"   📈 Total images: {stats['total_images']}")
    print(f"   🏷️ Classes ({stats['num_classes']}): {stats['classes']}")
    
    print(f"   📁 Data splits:")
    for split, size in stats['splits'].items():
        percentage = (size / stats['total_images']) * 100
        print(f"     {split.capitalize()}: {size} images ({percentage:.1f}%)")
    
    print(f"   🌾 Expected characteristics:")
    print(f"     • High-density wheat heads (10-30 per image)")
    print(f"     • Small object detection challenge")
    print(f"     • Overlapping and clustering patterns")
    print(f"     • Field condition variations")
    print(f"     • Multiple growth stages and orientations")

print(f"\n✅ Dataset loading complete. Ready for wheat-specific analysis!")
print(f"📁 Results will be saved to: {notebook_results_dir}")
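
`validate_dataset_structure` counts files in each split directory, but matching counts do not guarantee that every image has a label file. A sketch of a stem-based pairing check, assuming the YOLO-style layout where `images/<split>/X.jpg` pairs with `labels/<split>/X.txt`; `find_unpaired_files` is a hypothetical helper, and the accepted image extensions are an assumption.

```python
import tempfile
from pathlib import Path

IMAGE_EXTENSIONS = {'.jpg', '.jpeg', '.png'}  # assumed image formats


def find_unpaired_files(dataset_root, split: str) -> dict:
    """Return image stems lacking labels and label stems lacking images for one split."""
    root = Path(dataset_root)
    image_dir, label_dir = root / 'images' / split, root / 'labels' / split
    image_stems = {p.stem for p in image_dir.glob('*') if p.suffix.lower() in IMAGE_EXTENSIONS}
    label_stems = {p.stem for p in label_dir.glob('*.txt')}
    return {
        'images_without_labels': sorted(image_stems - label_stems),
        'labels_without_images': sorted(label_stems - image_stems),
    }


# Build a tiny throwaway layout to demonstrate the check
with tempfile.TemporaryDirectory() as tmp:
    for sub in ('images/train', 'labels/train'):
        (Path(tmp) / sub).mkdir(parents=True)
    (Path(tmp) / 'images/train/a.jpg').touch()
    (Path(tmp) / 'images/train/b.jpg').touch()
    (Path(tmp) / 'labels/train/a.txt').touch()
    (Path(tmp) / 'labels/train/c.txt').touch()
    report = find_unpaired_files(tmp, 'train')

print(report)
```

Depending on the labeling convention, an image with no wheat heads may legitimately have no label file, so `images_without_labels` is best treated as a list to review rather than a list of errors.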

## 3. Wheat Head Distribution Analysis
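
The cell below computes several per-image statistics. For a single image they reduce to a density category from the head count (0 / 1-5 / 6-15 / 16-30 / more than 30, as in the analysis code), the number of DBSCAN clusters over normalized head centers (`eps=0.08`, `min_samples=2`) divided by the number of heads, and the coefficient of variation of pairwise center distances. A compact sketch using the same parameters; `image_metrics` and `density_category` are hypothetical helpers, and the toy coordinates are made up for illustration.

```python
import numpy as np
from scipy.spatial.distance import pdist
from sklearn.cluster import DBSCAN

DENSITY_LABELS = ['very_low', 'low', 'medium', 'high', 'very_high']
DENSITY_EDGES = [1, 6, 16, 31]  # np.digitize bins: 0 | 1-5 | 6-15 | 16-30 | >30


def density_category(num_heads: int) -> str:
    """Map a head count to the notebook's density category."""
    return DENSITY_LABELS[int(np.digitize(num_heads, DENSITY_EDGES))]


def image_metrics(centers: np.ndarray) -> dict:
    """Per-image metrics from an (N, 2) array of normalized head centers."""
    n = len(centers)
    metrics = {'num_heads': n, 'density': density_category(n)}
    if n > 2:  # same minimum as the analysis loop
        labels = DBSCAN(eps=0.08, min_samples=2).fit(centers).labels_
        n_clusters = len(set(labels)) - (1 if -1 in labels else 0)  # label -1 is noise
        distances = pdist(centers)
        metrics['n_clusters'] = n_clusters
        metrics['clusters_per_head'] = n_clusters / n
        metrics['distance_cv'] = float(np.std(distances) / np.mean(distances))
    return metrics


# Two tight groups of three heads plus one isolated head
centers = np.array([
    [0.20, 0.20], [0.22, 0.21], [0.21, 0.23],
    [0.70, 0.70], [0.72, 0.71], [0.71, 0.73],
    [0.95, 0.05],
])
print(image_metrics(centers))
```

With seven heads the image falls in the `medium` band; DBSCAN finds the two groups and labels the isolated head as noise, so it does not add a cluster.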

In [None]:
"""
Enhanced wheat head distribution analysis with advanced metrics and visualizations
"""

def analyze_wheat_distribution_advanced(dataset, dataset_name, max_samples=500):
    """Advanced analysis of wheat head distribution patterns"""
    
    wheat_metrics = {
        'basic_stats': {
            'total_wheat_heads': 0,
            'images_analyzed': 0,
            'heads_per_image': [],
            'valid_images': 0,
            'empty_images': 0
        },
        'density_analysis': {
            'density_categories': {'very_low': 0, 'low': 0, 'medium': 0, 'high': 0, 'very_high': 0},
            'density_histogram': [],
            'density_percentiles': {}
        },
        'spatial_patterns': {
            'x_coordinates': [],
            'y_coordinates': [],
            'spatial_clusters': [],
            'edge_proximity': []
        },
        'size_analysis': {
            'areas': [],
            'widths': [],
            'heights': [],
            'aspect_ratios': [],
            'size_categories': {'tiny': 0, 'small': 0, 'medium': 0, 'large': 0}
        },
        'field_characteristics': {
            'row_patterns': [],
            'clustering_strength': [],
            'uniformity_measures': []
        }
    }
    
    # Evenly spaced (systematic) sampling so the subset spans the whole dataset
    total_size = len(dataset)
    sample_size = min(total_size, max_samples)
    
    if sample_size < total_size:
        # Take every step_size-th index, truncated to sample_size entries
        step_size = total_size // sample_size
        indices = list(range(0, total_size, step_size))[:sample_size]
    else:
        indices = list(range(total_size))
    
    print(f"🌾 Analyzing wheat distribution in {sample_size} images from {dataset_name}...")
    
    # Progress tracking
    failed_loads = 0
    processing_errors = 0
    
    for idx in tqdm(indices, desc="Processing images"):
        try:
            image, targets, path = dataset[idx]
            
            # Validate image and targets
            if image is None or targets is None:
                failed_loads += 1
                continue
                
            wheat_metrics['basic_stats']['images_analyzed'] += 1
            
            # Handle empty images
            if targets.numel() == 0 or len(targets) == 0:
                wheat_metrics['basic_stats']['heads_per_image'].append(0)
                wheat_metrics['basic_stats']['empty_images'] += 1
                wheat_metrics['density_analysis']['density_categories']['very_low'] += 1
                continue
            
            wheat_metrics['basic_stats']['valid_images'] += 1
            num_heads = len(targets)
            wheat_metrics['basic_stats']['total_wheat_heads'] += num_heads
            wheat_metrics['basic_stats']['heads_per_image'].append(num_heads)
            
            # Density categories by heads per image: very_low = 0 (handled above),
            # low = 1-5, medium = 6-15, high = 16-30, very_high = more than 30
            if num_heads <= 5:
                wheat_metrics['density_analysis']['density_categories']['low'] += 1
            elif num_heads <= 15:
                wheat_metrics['density_analysis']['density_categories']['medium'] += 1
            elif num_heads <= 30:
                wheat_metrics['density_analysis']['density_categories']['high'] += 1
            else:
                wheat_metrics['density_analysis']['density_categories']['very_high'] += 1
            
            wheat_metrics['density_analysis']['density_histogram'].append(num_heads)
            
            # Process individual wheat heads
            wheat_positions = []
            for target in targets:
                try:
                    if len(target) >= 5:
                        cls, x_center, y_center, width, height = target[:5]
                        
                        # Spatial analysis
                        wheat_metrics['spatial_patterns']['x_coordinates'].append(float(x_center))
                        wheat_metrics['spatial_patterns']['y_coordinates'].append(float(y_center))
                        wheat_positions.append((float(x_center), float(y_center)))
                        
                        # Edge proximity: distance from the box center to the nearest image border
                        edge_distance = float(min(x_center, y_center, 1 - x_center, 1 - y_center))
                        wheat_metrics['spatial_patterns']['edge_proximity'].append(edge_distance)
                        
                        # Size analysis
                        area = float(width * height)
                        wheat_metrics['size_analysis']['areas'].append(area)
                        wheat_metrics['size_analysis']['widths'].append(float(width))
                        wheat_metrics['size_analysis']['heights'].append(float(height))
                        
                        # Aspect ratio
                        if height > 0:
                            aspect_ratio = float(width / height)
                            wheat_metrics['size_analysis']['aspect_ratios'].append(aspect_ratio)
                        
                        # Size categorization (refined wheat-specific thresholds)
                        if area < 0.0003:  # Very tiny
                            wheat_metrics['size_analysis']['size_categories']['tiny'] += 1
                        elif area < 0.0015:  # Small
                            wheat_metrics['size_analysis']['size_categories']['small'] += 1
                        elif area < 0.006:  # Medium
                            wheat_metrics['size_analysis']['size_categories']['medium'] += 1
                        else:  # Large
                            wheat_metrics['size_analysis']['size_categories']['large'] += 1
                
                except Exception as e:
                    processing_errors += 1
                    continue
            
            # Analyze spatial clustering for this image
            if len(wheat_positions) > 2:
                try:
                    positions_array = np.array(wheat_positions)
                    
                    # Cluster normalized head centers with DBSCAN (eps = 8% of the image side).
                    # "Clustering strength" is the number of clusters divided by the number of
                    # heads; it is 0 when every head is isolated (all DBSCAN noise).
                    clustering = DBSCAN(eps=0.08, min_samples=2).fit(positions_array)
                    n_clusters = len(set(clustering.labels_)) - (1 if -1 in clustering.labels_ else 0)
                    clustering_strength = n_clusters / len(wheat_positions)
                    
                    wheat_metrics['field_characteristics']['clustering_strength'].append(clustering_strength)
                    wheat_metrics['spatial_patterns']['spatial_clusters'].append(n_clusters)
                    
                    # Spacing uniformity: coefficient of variation of pairwise center
                    # distances (higher = less uniform spacing)
                    if len(positions_array) > 1:
                        distances = pdist(positions_array)
                        uniformity = np.std(distances) / np.mean(distances) if np.mean(distances) > 0 else 0
                        wheat_metrics['field_characteristics']['uniformity_measures'].append(uniformity)
                
                except Exception as e:
                    processing_errors += 1
                    continue
                    
        except Exception as e:
            failed_loads += 1
            continue
    
    # Calculate advanced statistics
    if wheat_metrics['basic_stats']['heads_per_image']:
        heads_array = np.array(wheat_metrics['basic_stats']['heads_per_image'])
        wheat_metrics['density_analysis']['density_percentiles'] = {
            'p25': float(np.percentile(heads_array, 25)),
            'p50': float(np.percentile(heads_array, 50)),
            'p75': float(np.percentile(heads_array, 75)),
            'p90': float(np.percentile(heads_array, 90)),
            'p95': float(np.percentile(heads_array, 95))
        }
    
    # Report processing statistics
    print(f"   ✅ Successfully processed: {wheat_metrics['basic_stats']['images_analyzed']} images")
    print(f"   ⚠️ Failed loads: {failed_loads}")
    print(f"   ⚠️ Processing errors: {processing_errors}")
    
    return wheat_metrics

# Analyze all datasets with enhanced metrics
wheat_distribution_results = {}

for name, dataset in datasets.items():
    print(f"\n{'='*50}")
    wheat_metrics = analyze_wheat_distribution_advanced(dataset, name, max_samples=400)
    wheat_distribution_results[name] = wheat_metrics
    
    # Display comprehensive results
    basic_stats = wheat_metrics['basic_stats']
    density_stats = wheat_metrics['density_analysis']
    size_stats = wheat_metrics['size_analysis']
    
    print(f"\n🌾 {name} - COMPREHENSIVE WHEAT ANALYSIS:")
    print(f"   📊 Basic Statistics:")
    print(f"     Total wheat heads: {basic_stats['total_wheat_heads']:,}")
    print(f"     Images analyzed: {basic_stats['images_analyzed']}")
    print(f"     Valid images: {basic_stats['valid_images']}")
    print(f"     Empty images: {basic_stats['empty_images']}")
    
    if basic_stats['heads_per_image']:
        heads_array = np.array(basic_stats['heads_per_image'])
        print(f"   📈 Density Statistics:")
        print(f"     Mean heads/image: {np.mean(heads_array):.2f} ± {np.std(heads_array):.2f}")
        print(f"     Median heads/image: {np.median(heads_array):.1f}")
        print(f"     Range: {np.min(heads_array)} - {np.max(heads_array)} heads")
        
        print(f"   📊 Density Distribution:")
        total_images = sum(density_stats['density_categories'].values())
        for category, count in density_stats['density_categories'].items():
            percentage = (count / total_images * 100) if total_images > 0 else 0
            print(f"     {category.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")
        
        print(f"   📏 Size Analysis:")
        if size_stats['areas']:
            areas_array = np.array(size_stats['areas'])
            print(f"     Mean area: {np.mean(areas_array):.6f} ± {np.std(areas_array):.6f}")
            print(f"     Median area: {np.median(areas_array):.6f}")
            
            total_heads = sum(size_stats['size_categories'].values())
            print(f"   🎯 Size Categories:")
            for category, count in size_stats['size_categories'].items():
                percentage = (count / total_heads * 100) if total_heads > 0 else 0
                print(f"     {category.title()}: {count} ({percentage:.1f}%)")
        
        # Spatial analysis summary
        if wheat_metrics['field_characteristics']['clustering_strength']:
            avg_clustering = np.mean(wheat_metrics['field_characteristics']['clustering_strength'])
            print(f"   🔗 Spatial Characteristics:")
            print(f"     Average clustering strength: {avg_clustering:.3f}")
            
            if wheat_metrics['field_characteristics']['uniformity_measures']:
                avg_uniformity = np.mean(wheat_metrics['field_characteristics']['uniformity_measures'])
                print(f"     Field uniformity measure: {avg_uniformity:.3f}")

# Create enhanced visualization dashboard
fig = plt.figure(figsize=(20, 16))
gs = fig.add_gridspec(4, 3, hspace=0.3, wspace=0.3)

dataset_colors = plt.cm.Set1(np.linspace(0, 1, len(wheat_distribution_results)))
density_labels = ['very_low', 'low', 'medium', 'high', 'very_high']
n_datasets = len(wheat_distribution_results)
bar_width = 0.8 / max(n_datasets, 1)

# Create each top-row panel once: calling fig.add_subplot() on the same grid
# cell inside the loop would stack a new, overlapping Axes per dataset
ax1 = fig.add_subplot(gs[0, 0])  # 1. Heads per image (top row, left)
ax2 = fig.add_subplot(gs[0, 1])  # 2. Density categories (top row, middle)
ax3 = fig.add_subplot(gs[0, 2])  # 3. Head area distribution (top row, right)

for idx, (dataset_name, metrics) in enumerate(wheat_distribution_results.items()):
    color = dataset_colors[idx]
    
    # 1. Heads-per-image histogram with a dashed line at the mean
    if metrics['basic_stats']['heads_per_image']:
        ax1.hist(metrics['basic_stats']['heads_per_image'], bins=30, alpha=0.5,
                 label=dataset_name, color=color, density=True)
        ax1.axvline(np.mean(metrics['basic_stats']['heads_per_image']),
                    color=color, linestyle='--', alpha=0.8)
    
    # 2. Density categories as grouped bars, offset per dataset
    values = [metrics['density_analysis']['density_categories'][cat] for cat in density_labels]
    positions = np.arange(len(density_labels)) + idx * bar_width
    bars = ax2.bar(positions, values, width=bar_width, alpha=0.8, color=color, label=dataset_name)
    
    # Add value labels on bars
    for bar, value in zip(bars, values):
        ax2.text(bar.get_x() + bar.get_width() / 2., bar.get_height() + 0.01 * max(max(values), 1),
                 f'{value}', ha='center', va='bottom', fontweight='bold')
    
    # 3. Area histogram with mean (dashed) and median (dotted) lines
    if metrics['size_analysis']['areas']:
        areas = np.array(metrics['size_analysis']['areas'])
        ax3.hist(areas, bins=30, alpha=0.5, color=color, edgecolor='black', label=dataset_name)
        ax3.axvline(np.mean(areas), color=color, linestyle='--', alpha=0.8)
        ax3.axvline(np.median(areas), color=color, linestyle=':', alpha=0.8)

ax1.set_title('Wheat Heads per Image Distribution', fontweight='bold')
ax1.set_xlabel('Number of Wheat Heads')
ax1.set_ylabel('Density')
ax1.legend()
ax1.grid(True, alpha=0.3)

ax2.set_xticks(np.arange(len(density_labels)) + bar_width * (n_datasets - 1) / 2)
ax2.set_xticklabels([cat.replace('_', '\n') for cat in density_labels])
ax2.set_title('Density Categories', fontweight='bold')
ax2.set_ylabel('Number of Images')
ax2.legend()

ax3.set_title('Wheat Head Area Distribution (dashed = mean, dotted = median)', fontweight='bold')
ax3.set_xlabel('Area (normalized)')
ax3.set_ylabel('Frequency')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Spatial distribution heatmap (second row, spanning two columns)
# Sum one 2D histogram over all datasets; overlaying a semi-transparent imshow per
# dataset with the same colormap would make the datasets indistinguishable anyway.
ax4 = fig.add_subplot(gs[1, :2])
n_spatial_bins = 25
combined_hist = np.zeros((n_spatial_bins, n_spatial_bins))
for dataset_name, metrics in wheat_distribution_results.items():
    x_coords = metrics['spatial_patterns']['x_coordinates']
    y_coords = metrics['spatial_patterns']['y_coordinates']
    if x_coords:
        hist, _, _ = np.histogram2d(x_coords, y_coords, bins=n_spatial_bins, range=[[0, 1], [0, 1]])
        combined_hist += hist  # hist is indexed [x_bin, y_bin]

# Transpose to [y_bin, x_bin] for imshow; origin='upper' keeps y = 0 at the top, as in the images
im = ax4.imshow(combined_hist.T, origin='upper', extent=[0, 1, 1, 0],
                cmap=plt.cm.Reds, interpolation='bilinear')
fig.colorbar(im, ax=ax4, label='Wheat head count')

ax4.set_title('Spatial Distribution Heatmap (All Datasets Combined)', fontweight='bold')
ax4.set_xlabel('X Coordinate (normalized)')
ax4.set_ylabel('Y Coordinate (normalized)')

# 5. Clustering analysis (second row, right)
ax5 = fig.add_subplot(gs[1, 2])
clustering_data = []
clustering_labels = []
for dataset_name, metrics in wheat_distribution_results.items():
    if metrics['field_characteristics']['clustering_strength']:
        clustering_data.extend(metrics['field_characteristics']['clustering_strength'])
        clustering_labels.extend([dataset_name] * len(metrics['field_characteristics']['clustering_strength']))

if clustering_data:
    df_clustering = pd.DataFrame({'Clustering_Strength': clustering_data, 'Dataset': clustering_labels})
    sns.boxplot(data=df_clustering, x='Dataset', y='Clustering_Strength', ax=ax5)
    ax5.set_title('Clustering Strength Distribution', fontweight='bold')
    ax5.set_ylabel('Clustering Strength')
    ax5.tick_params(axis='x', rotation=45)

# 6. Size category comparison (third row, left)
ax6 = fig.add_subplot(gs[2, 0])
size_categories = ['tiny', 'small', 'medium', 'large']
x_pos = np.arange(len(size_categories))
width = 0.8 / len(wheat_distribution_results)

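for idx, (dataset_name, metrics) in enumerate(wheat_distribution_results.items()):
    # .get() tolerates datasets in which a size category never occurred
    values = [metrics['size_analysis']['size_categories'].get(cat, 0) for cat in size_categories]
    ax6.bar(x_pos + idx*width, values, width, label=dataset_name, alpha=0.8)

ax6.set_title('Size Category Comparison', fontweight='bold')
ax6.set_xlabel('Size Category')
ax6.set_ylabel('Count')
# Centre each tick under its group of bars (bar offsets run from 0 to (n - 1) * width)
ax6.set_xticks(x_pos + width * (len(wheat_distribution_results) - 1) / 2)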
ax6.set_xticklabels(size_categories)
ax6.legend()

# 7. Percentile analysis (third row, middle)
ax7 = fig.add_subplot(gs[2, 1])
for idx, (dataset_name, metrics) in enumerate(wheat_distribution_results.items()):
    if 'density_percentiles' in metrics['density_analysis']:
        percentiles = metrics['density_analysis']['density_percentiles']
        if percentiles:
            p_labels = list(percentiles.keys())
            p_values = list(percentiles.values())
            ax7.plot(p_labels, p_values, marker='o', label=dataset_name, linewidth=2, markersize=8)

ax7.set_title('Density Percentile Analysis', fontweight='bold')
ax7.set_xlabel('Percentile')
ax7.set_ylabel('Number of Heads')
ax7.legend()
ax7.grid(True, alpha=0.3)

# 8. Statistical summary (third row, right)
ax8 = fig.add_subplot(gs[2, 2])
ax8.axis('off')

summary_text = "📊 STATISTICAL SUMMARY\n\n"
for dataset_name, metrics in wheat_distribution_results.items():
    basic_stats = metrics['basic_stats']
    if basic_stats['heads_per_image']:
        heads_array = np.array(basic_stats['heads_per_image'])
        summary_text += f"🌾 {dataset_name}:\n"
        summary_text += f"  Images: {basic_stats['images_analyzed']}\n"
        summary_text += f"  Total heads: {basic_stats['total_wheat_heads']:,}\n"
        summary_text += f"  Mean/image: {np.mean(heads_array):.1f}\n"
        summary_text += f"  Std dev: {np.std(heads_array):.1f}\n"
        summary_text += f"  Max density: {np.max(heads_array)}\n\n"

ax8.text(0.05, 0.95, summary_text, transform=ax8.transAxes, fontsize=10,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))

plt.suptitle('🌾 COMPREHENSIVE GLOBAL WHEAT DISTRIBUTION ANALYSIS', fontsize=16, fontweight='bold')
plt.savefig(notebook_results_dir / 'visualizations' / 'comprehensive_wheat_distribution.png', 
            dpi=300, bbox_inches='tight')
plt.show()

# Save detailed results
def _safe_stat(reducer, values):
    """Apply a NumPy reducer (np.mean, np.std, ...) to a list; return 0 for an empty list."""
    return float(reducer(values)) if values else 0

analysis_summary = {}
for dataset_name, metrics in wheat_distribution_results.items():
    basic = metrics['basic_stats']
    heads = basic['heads_per_image']
    areas = metrics['size_analysis']['areas']
    field = metrics['field_characteristics']
    analysis_summary[dataset_name] = {
        'basic_statistics': {
            'total_wheat_heads': basic['total_wheat_heads'],
            'images_analyzed': basic['images_analyzed'],
            'valid_images': basic['valid_images'],
            'empty_images': basic['empty_images']
        },
        'density_statistics': {
            'mean_heads_per_image': _safe_stat(np.mean, heads),
            'std_heads_per_image': _safe_stat(np.std, heads),
            'median_heads_per_image': _safe_stat(np.median, heads),
            'percentiles': metrics['density_analysis']['density_percentiles'],
            'category_distribution': metrics['density_analysis']['density_categories']
        },
        'size_statistics': {
            'mean_area': _safe_stat(np.mean, areas),
            'std_area': _safe_stat(np.std, areas),
            'size_categories': metrics['size_analysis']['size_categories']
        },
        'spatial_statistics': {
            'clustering_strength': _safe_stat(np.mean, field['clustering_strength']),
            'uniformity_measure': _safe_stat(np.mean, field['uniformity_measures'])
        }
    }

# json cannot serialize NumPy scalars such as np.int64 or np.float32; convert them via .item()
with open(notebook_results_dir / 'data_analysis' / 'comprehensive_wheat_distribution.json', 'w') as f:
    json.dump(analysis_summary, f, indent=2,
              default=lambda o: o.item() if isinstance(o, np.generic) else str(o))

print(f"\n💾 Comprehensive wheat distribution analysis saved!")
print(f"📁 Location: {notebook_results_dir / 'data_analysis' / 'comprehensive_wheat_distribution.json'}")
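
The spatial heatmap above relies on `np.histogram2d`, whose output is indexed `[x_bin, y_bin]`; it has to be transposed before `imshow`, which draws array rows along the vertical axis. The sketch below isolates that step, assuming each dataset supplies lists of normalized box-centre coordinates in the same form as `metrics['spatial_patterns']['x_coordinates']` and `['y_coordinates']`. The helper name `combined_spatial_histogram` is hypothetical and not part of the project code.

```python
import numpy as np

def combined_spatial_histogram(coord_sets, bins=25):
    """Sum per-dataset 2D histograms of normalized (x, y) box centres.

    coord_sets: iterable of (x_coords, y_coords) pairs with values in [0, 1].
    Returns an array indexed [y_bin, x_bin], i.e. already transposed for imshow.
    """
    total = np.zeros((bins, bins))
    for x_coords, y_coords in coord_sets:
        if len(x_coords) == 0:
            continue  # skip datasets without annotations
        hist, _, _ = np.histogram2d(x_coords, y_coords, bins=bins, range=[[0, 1], [0, 1]])
        total += hist  # hist is indexed [x_bin, y_bin]
    return total.T     # rows = y bins, columns = x bins

# Toy input: one head near the top-right corner, two heads near the bottom-left
grid = combined_spatial_histogram([([0.95], [0.05]), ([0.15, 0.12], [0.85, 0.88])], bins=10)
print(grid[0, 9], grid[8, 1], grid.sum())
```

Transposing inside the helper keeps the indexing convention in one place; the result can be passed straight to `imshow(..., origin='upper')` so that y = 0 (the top of each image) is drawn at the top.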

## 4. Wheat Head Clustering and Overlap Analysis
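
The overlap analysis in the next cell computes IoU for every pair of boxes with a nested Python loop. The same quantity can be obtained in one step with NumPy broadcasting. This is a minimal sketch, assuming boxes in YOLO format `(x_center, y_center, width, height)` normalized to [0, 1], as in the dataset targets; `pairwise_iou_yolo` is a hypothetical helper, not an existing library function.

```python
import numpy as np

def pairwise_iou_yolo(boxes_xywh):
    """Return the (N, N) IoU matrix for N boxes given as (x_center, y_center, w, h)."""
    b = np.asarray(boxes_xywh, dtype=float)
    # Convert centre/size to corner coordinates (x1, y1, x2, y2)
    x1 = b[:, 0] - b[:, 2] / 2
    y1 = b[:, 1] - b[:, 3] / 2
    x2 = b[:, 0] + b[:, 2] / 2
    y2 = b[:, 1] + b[:, 3] / 2
    areas = (x2 - x1) * (y2 - y1)

    # Broadcast every box against every other box: shapes (N, 1) vs (1, N)
    inter_w = np.clip(np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :]), 0, None)
    inter_h = np.clip(np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :]), 0, None)
    inter = inter_w * inter_h
    union = areas[:, None] + areas[None, :] - inter
    # Guard against zero-area boxes (union == 0) instead of dividing by zero
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)

boxes = [[0.50, 0.50, 0.20, 0.20],   # reference box
         [0.55, 0.50, 0.20, 0.20],   # shifted right by a quarter of its width
         [0.90, 0.90, 0.10, 0.10]]   # far away, no overlap
iou = pairwise_iou_yolo(boxes)
# Upper triangle (k=1) lists each unordered pair once, like the j < k loop
pair_ious = iou[np.triu_indices(len(boxes), k=1)]
print(np.round(pair_ious, 3))
```

The loop in the next cell only records pairs with a positive intersection; the equivalent filter here is `pair_ious[pair_ious > 0]`. For images with a few dozen to a few hundred heads, the N×N matrix stays small, so the memory cost of broadcasting is modest.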

In [None]:
"""
Enhanced wheat head clustering and overlap analysis with advanced spatial algorithms
"""

def analyze_wheat_clustering_advanced(dataset, dataset_name, max_samples=300):
    """Advanced clustering analysis with multiple algorithms and metrics"""
    
    clustering_metrics = {
        'overlap_analysis': {
            'iou_scores': [],
            'overlap_pairs': [],
            'high_overlap_images': 0,
            'overlap_statistics': {}
        },
        'clustering_analysis': {
            'dbscan_results': [],
            'kmeans_results': [],
            'hierarchical_results': [],
            'silhouette_scores': []
        },
        'spatial_metrics': {
            'nearest_neighbor_distances': [],
            'density_gradients': [],
            'spatial_autocorrelation': [],
            'boundary_effects': []
        },
        'pattern_detection': {
            'row_patterns': 0,
            'circular_patterns': 0,
            'random_patterns': 0,
            'grid_patterns': 0
        },
        'density_analysis': {
            'local_densities': [],
            'density_hotspots': [],
            'crowding_indices': [],
            'isolation_scores': []
        },
        'geometric_features': {
            'convex_hull_areas': [],
            'bounding_box_ratios': [],
            'centroid_distances': [],
            'orientation_analyses': []
        }
    }
    
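    sample_size = min(len(dataset), max_samples)
    # Fixed seed so the sampled subset (and every statistic derived from it) is reproducible
    rng = np.random.default_rng(42)
    indices = rng.choice(len(dataset), sample_size, replace=False)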
    
    print(f"🔍 Advanced clustering analysis on {sample_size} images from {dataset_name}...")
    
    processing_stats = {'success': 0, 'errors': 0, 'empty': 0}
    
    for i in tqdm(indices, desc="Analyzing clustering patterns"):
        try:
            image, targets, path = dataset[i]
            
            if targets.numel() == 0 or len(targets) < 2:
                processing_stats['empty'] += 1
                continue
                
            # Extract coordinates and create arrays
            coordinates = []
            boxes = []
            areas = []
            
            for target in targets:
                if len(target) >= 5:
                    cls, x_center, y_center, width, height = target[:5]
                    coordinates.append([float(x_center), float(y_center)])
                    
                    # Convert to corner coordinates for IoU calculation
                    x1 = float(x_center - width/2)
                    y1 = float(y_center - height/2)
                    x2 = float(x_center + width/2)
                    y2 = float(y_center + height/2)
                    boxes.append([x1, y1, x2, y2])
                    areas.append(float(width * height))
            
            if len(coordinates) < 2:
                processing_stats['empty'] += 1
                continue
                
            coordinates = np.array(coordinates)
            boxes = np.array(boxes)
            processing_stats['success'] += 1
            
            # 1. ADVANCED OVERLAP ANALYSIS
            iou_scores = []
            overlap_count = 0
            
            for j in range(len(boxes)):
                for k in range(j+1, len(boxes)):
                    box1, box2 = boxes[j], boxes[k]
                    
                    # Calculate IoU
                    x1_inter = max(box1[0], box2[0])
                    y1_inter = max(box1[1], box2[1])
                    x2_inter = min(box1[2], box2[2])
                    y2_inter = min(box1[3], box2[3])
                    
                    if x2_inter > x1_inter and y2_inter > y1_inter:
                        inter_area = (x2_inter - x1_inter) * (y2_inter - y1_inter)
                        box1_area = (box1[2] - box1[0]) * (box1[3] - box1[1])
                        box2_area = (box2[2] - box2[0]) * (box2[3] - box2[1])
                        union_area = box1_area + box2_area - inter_area
                        
                        if union_area > 0:
                            iou = inter_area / union_area
                            iou_scores.append(iou)
                            
                            if iou > 0.1:  # Significant overlap threshold
                                overlap_count += 1
            
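            # iou_scores holds only intersecting pairs (IoU > 0), so the IoU statistics
            # computed below are conditional on two boxes overlapping at all
            if iou_scores:
                clustering_metrics['overlap_analysis']['iou_scores'].extend(iou_scores)
                if max(iou_scores) > 0.3:  # High overlap threshold
                    clustering_metrics['overlap_analysis']['high_overlap_images'] += 1
            
            # Record the count for every processed image (0 when no pair exceeds IoU 0.1),
            # so the per-image mean is not biased towards images that contain overlaps
            clustering_metrics['overlap_analysis']['overlap_pairs'].append(overlap_count)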
            
            # 2. MULTI-ALGORITHM CLUSTERING ANALYSIS
            if len(coordinates) >= 3:
                
                # DBSCAN Clustering
                try:
                    from sklearn.cluster import DBSCAN
                    # eps is in normalized image units: heads within 8% of the image side are neighbours
                    dbscan = DBSCAN(eps=0.08, min_samples=2)
                    dbscan_labels = dbscan.fit_predict(coordinates)
                    
                    n_clusters = len(set(dbscan_labels)) - (1 if -1 in dbscan_labels else 0)
                    n_noise = list(dbscan_labels).count(-1)
                    
                    clustering_metrics['clustering_analysis']['dbscan_results'].append({
                        'n_clusters': n_clusters,
                        'n_noise': n_noise,
                        'n_points': len(coordinates),
                        'cluster_ratio': n_clusters / len(coordinates) if len(coordinates) > 0 else 0
                    })
                    
                    # Silhouette score on non-noise points only (needs at least 2 clusters)
                    if n_clusters > 1:
                        from sklearn.metrics import silhouette_score
                        non_noise = dbscan_labels != -1
                        valid_labels = dbscan_labels[non_noise]
                        valid_coords = coordinates[non_noise]
                        
                        # silhouette_score raises ValueError unless 2 <= n_labels <= n_samples - 1
                        if 2 <= len(set(valid_labels)) <= len(valid_coords) - 1:
                            sil_score = silhouette_score(valid_coords, valid_labels)
                            clustering_metrics['clustering_analysis']['silhouette_scores'].append(float(sil_score))
                            
                except Exception as e:
                    processing_stats['errors'] += 1
                
                # K-means with a heuristic k = min(5, n_points // 2); k is not tuned per image
                try:
                    from sklearn.cluster import KMeans
                    heuristic_k = min(5, len(coordinates) // 2)
                    if heuristic_k >= 2:
                        kmeans = KMeans(n_clusters=heuristic_k, random_state=42, n_init=10)
                        kmeans.fit(coordinates)
                        
                        clustering_metrics['clustering_analysis']['kmeans_results'].append({
                            'n_clusters': heuristic_k,
                            'inertia': float(kmeans.inertia_),
                            'n_points': len(coordinates)
                        })
                except ValueError:
                    processing_stats['errors'] += 1
            
            # 3. SPATIAL METRICS ANALYSIS
            
            # Nearest-neighbour distance for each head, from the full pairwise distance matrix
            if len(coordinates) > 1:
                dist_matrix = squareform(pdist(coordinates))
                np.fill_diagonal(dist_matrix, np.inf)  # Exclude each point's zero distance to itself
                min_distances = dist_matrix.min(axis=1)
                clustering_metrics['spatial_metrics']['nearest_neighbor_distances'].extend(min_distances.tolist())
            
            # Local density: number of other heads within radius 0.1 (normalized units)
            if len(coordinates) >= 3:
                pairwise = squareform(pdist(coordinates))
                # Subtract 1 to exclude the point itself (its diagonal distance is 0)
                local_densities = ((pairwise <= 0.1).sum(axis=1) - 1).tolist()
                clustering_metrics['density_analysis']['local_densities'].extend(local_densities)
                
                # Density "hotspots": heads at or above this image's 75th-percentile local density
                density_threshold = np.percentile(local_densities, 75)
                hotspots = sum(1 for d in local_densities if d >= density_threshold)
                clustering_metrics['density_analysis']['density_hotspots'].append(hotspots)
            
            # 4. PATTERN DETECTION (heuristic)
            
            # Fit a single robust line y = f(x) through all head centres and use the RANSAC
            # inlier ratio as a rough proxy for row structure. RANSACRegressor's default
            # residual_threshold is the MAD of y, so a sizeable share of points counts as
            # inliers even for random layouts; treat these thresholds as coarse.
            if len(coordinates) >= 4:
                try:
                    from sklearn.linear_model import RANSACRegressor
                    
                    # Sorting by x is not required for the fit
                    ransac = RANSACRegressor(random_state=42)
                    ransac.fit(coordinates[:, 0].reshape(-1, 1), coordinates[:, 1])
                    inlier_ratio = np.sum(ransac.inlier_mask_) / len(coordinates)
                    
                    if inlier_ratio > 0.6:    # Strong single-line (row) structure
                        clustering_metrics['pattern_detection']['row_patterns'] += 1
                    elif inlier_ratio > 0.3:  # Partial structure, recorded under 'grid_patterns'
                        clustering_metrics['pattern_detection']['grid_patterns'] += 1
                    else:
                        clustering_metrics['pattern_detection']['random_patterns'] += 1
                    # Note: 'circular_patterns' is never incremented by this heuristic
                        
                except ValueError:  # Raised when RANSAC finds no valid consensus set
                    clustering_metrics['pattern_detection']['random_patterns'] += 1
            
            # 5. GEOMETRIC FEATURES
            
            if len(coordinates) >= 3:
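                # Convex hull analysis (Qhull raises an error for degenerate, e.g. collinear, points)
                try:
                    from scipy.spatial import ConvexHull
                    hull = ConvexHull(coordinates)
                    hull_area = hull.volume  # In 2D, .volume is the enclosed area (.area is the perimeter)
                    clustering_metrics['geometric_features']['convex_hull_areas'].append(hull_area)
                except Exception:
                    pass  # Skip degenerate point sets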
                
                # Bounding box analysis
                min_x, min_y = np.min(coordinates, axis=0)
                max_x, max_y = np.max(coordinates, axis=0)
                bbox_width = max_x - min_x
                bbox_height = max_y - min_y
                
                if bbox_height > 0:
                    bbox_ratio = bbox_width / bbox_height
                    clustering_metrics['geometric_features']['bounding_box_ratios'].append(bbox_ratio)
                
                # Centroid analysis
                centroid = np.mean(coordinates, axis=0)
                distances_to_centroid = np.sqrt(np.sum((coordinates - centroid)**2, axis=1))
                avg_distance_to_centroid = np.mean(distances_to_centroid)
                clustering_metrics['geometric_features']['centroid_distances'].append(avg_distance_to_centroid)
            
            # 6. BOUNDARY EFFECTS: fraction of heads whose centre lies within 10% of any image edge
            edge_threshold = 0.1
            near_edge = np.any((coordinates <= edge_threshold) | (coordinates >= 1 - edge_threshold), axis=1)
            boundary_ratio = float(np.mean(near_edge))
            clustering_metrics['spatial_metrics']['boundary_effects'].append(boundary_ratio)
                
        except Exception as e:
            processing_stats['errors'] += 1
            continue
    
    # Calculate overlap statistics
    if clustering_metrics['overlap_analysis']['iou_scores']:
        iou_array = np.array(clustering_metrics['overlap_analysis']['iou_scores'])
        clustering_metrics['overlap_analysis']['overlap_statistics'] = {
            'mean_iou': float(np.mean(iou_array)),
            'std_iou': float(np.std(iou_array)),
            'max_iou': float(np.max(iou_array)),
            'high_overlap_ratio': float(np.sum(iou_array > 0.3) / len(iou_array))
        }
    
    print(f"   ✅ Processing stats: {processing_stats['success']} success, "
          f"{processing_stats['errors']} errors, {processing_stats['empty']} empty")
    
    return clustering_metrics

# Analyze clustering patterns with advanced algorithms
clustering_results = {}

for name, dataset in datasets.items():
    print(f"\n{'='*60}")
    clustering_metrics = analyze_wheat_clustering_advanced(dataset, name, max_samples=250)
    clustering_results[name] = clustering_metrics
    
    # Display comprehensive clustering analysis
    print(f"\n🔍 {name} - ADVANCED CLUSTERING ANALYSIS:")
    
    # Overlap analysis summary
    overlap_stats = clustering_metrics['overlap_analysis']['overlap_statistics']
    if overlap_stats:
        print(f"   📊 Overlap Analysis:")
        print(f"     Mean IoU: {overlap_stats['mean_iou']:.3f} ± {overlap_stats['std_iou']:.3f}")
        print(f"     Maximum IoU: {overlap_stats['max_iou']:.3f}")
        print(f"     High overlap ratio: {overlap_stats['high_overlap_ratio']:.3f}")
        print(f"     High overlap images: {clustering_metrics['overlap_analysis']['high_overlap_images']}")
    
    # Clustering analysis summary
    dbscan_results = clustering_metrics['clustering_analysis']['dbscan_results']
    if dbscan_results:
        avg_clusters = np.mean([r['n_clusters'] for r in dbscan_results])
        avg_noise = np.mean([r['n_noise'] for r in dbscan_results])
        print(f"   🔗 DBSCAN Clustering:")
        print(f"     Average clusters per image: {avg_clusters:.2f}")
        print(f"     Average noise points: {avg_noise:.2f}")
        
        if clustering_metrics['clustering_analysis']['silhouette_scores']:
            avg_silhouette = np.mean(clustering_metrics['clustering_analysis']['silhouette_scores'])
            print(f"     Average silhouette score: {avg_silhouette:.3f}")
    
    # Spatial metrics summary
    nn_distances = clustering_metrics['spatial_metrics']['nearest_neighbor_distances']
    if nn_distances:
        print(f"   📏 Spatial Metrics:")
        print(f"     Average nearest neighbor distance: {np.mean(nn_distances):.4f}")
        print(f"     Min nearest neighbor distance: {min(nn_distances):.4f}")
        
        boundary_effects = clustering_metrics['spatial_metrics']['boundary_effects']
        if boundary_effects:
            avg_boundary = np.mean(boundary_effects)
            print(f"     Average boundary effect ratio: {avg_boundary:.3f}")
    
    # Pattern detection summary
    pattern_counts = clustering_metrics['pattern_detection']
    total_patterns = sum(pattern_counts.values())
    if total_patterns > 0:
        print(f"   🎯 Pattern Detection:")
        for pattern, count in pattern_counts.items():
            percentage = (count / total_patterns) * 100
            print(f"     {pattern.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")
    
    # Density analysis summary
    local_densities = clustering_metrics['density_analysis']['local_densities']
    if local_densities:
        print(f"   🌾 Density Analysis:")
        print(f"     Average local density: {np.mean(local_densities):.2f}")
        print(f"     Max local density: {max(local_densities)}")
        
        hotspots = clustering_metrics['density_analysis']['density_hotspots']
        if hotspots:
            print(f"     Average hotspots per image: {np.mean(hotspots):.2f}")

# Create comprehensive clustering visualization dashboard
fig = plt.figure(figsize=(24, 15))
gs = fig.add_gridspec(3, 4, hspace=0.4, wspace=0.3)  # The panels below occupy rows 0-2 only

dataset_colors = plt.cm.Set1(np.linspace(0, 1, len(clustering_results)))

# 1. IoU Distribution (top row, first column)
ax1 = fig.add_subplot(gs[0, 0])
for idx, (dataset_name, metrics) in enumerate(clustering_results.items()):
    iou_scores = metrics['overlap_analysis']['iou_scores']
    if iou_scores:
        ax1.hist(iou_scores, bins=30, alpha=0.7, label=dataset_name, 
                color=dataset_colors[idx], density=True)
ax1.set_title('IoU Score Distribution', fontweight='bold')
ax1.set_xlabel('Intersection over Union')
ax1.set_ylabel('Density')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Cluster Count Distribution (top row, second column)
ax2 = fig.add_subplot(gs[0, 1])
for idx, (dataset_name, metrics) in enumerate(clustering_results.items()):
    dbscan_results = metrics['clustering_analysis']['dbscan_results']
    if dbscan_results:
        cluster_counts = [r['n_clusters'] for r in dbscan_results]
        # Integer-centred bins so each bar sits directly over its cluster count
        ax2.hist(cluster_counts, bins=np.arange(max(cluster_counts) + 2) - 0.5, alpha=0.7,
                label=dataset_name, color=dataset_colors[idx])
ax2.set_title('Clusters per Image Distribution', fontweight='bold')
ax2.set_xlabel('Number of Clusters')
ax2.set_ylabel('Frequency')
ax2.legend()
ax2.grid(True, alpha=0.3)

# 3. Nearest Neighbor Distances (top row, third column)
ax3 = fig.add_subplot(gs[0, 2])
for idx, (dataset_name, metrics) in enumerate(clustering_results.items()):
    nn_distances = metrics['spatial_metrics']['nearest_neighbor_distances']
    if nn_distances:
        ax3.hist(nn_distances, bins=30, alpha=0.7, label=dataset_name,
                color=dataset_colors[idx], density=True)
ax3.set_title('Nearest Neighbor Distance Distribution', fontweight='bold')
ax3.set_xlabel('Distance')
ax3.set_ylabel('Density')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Pattern Detection Summary (top row, fourth column)
ax4 = fig.add_subplot(gs[0, 3])
pattern_data = defaultdict(list)
dataset_names = []
for dataset_name, metrics in clustering_results.items():
    dataset_names.append(dataset_name)
    patterns = metrics['pattern_detection']
    total = sum(patterns.values()) if sum(patterns.values()) > 0 else 1
    for pattern, count in patterns.items():
        pattern_data[pattern].append(count / total * 100)

x = np.arange(len(dataset_names))
width = 0.2
pattern_colors = plt.cm.Set2(np.linspace(0, 1, len(pattern_data)))

for idx, (pattern, percentages) in enumerate(pattern_data.items()):
    ax4.bar(x + idx*width, percentages, width, label=pattern.replace('_', ' ').title(),
           color=pattern_colors[idx], alpha=0.8)

ax4.set_title('Pattern Detection Results', fontweight='bold')
ax4.set_xlabel('Dataset')
ax4.set_ylabel('Percentage')
ax4.set_xticks(x + width * 1.5)
ax4.set_xticklabels([name.split('_')[0] for name in dataset_names], rotation=45)
ax4.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# 5. Local Density Analysis (second row, first two columns)
ax5 = fig.add_subplot(gs[1, :2])
density_data = []
density_labels = []
for dataset_name, metrics in clustering_results.items():
    local_densities = metrics['density_analysis']['local_densities']
    if local_densities:
        density_data.extend(local_densities)
        density_labels.extend([dataset_name] * len(local_densities))

if density_data:
    df_density = pd.DataFrame({'Local_Density': density_data, 'Dataset': density_labels})
    sns.violinplot(data=df_density, x='Dataset', y='Local_Density', ax=ax5)
    ax5.set_title('Local Density Distribution by Dataset', fontweight='bold')
    ax5.set_ylabel('Local Density (neighbors within radius)')
    ax5.tick_params(axis='x', rotation=45)

# 6. Overlap vs Clustering Relationship (second row, third column)
# 'overlap_pairs' and 'dbscan_results' are appended under different conditions inside
# analyze_wheat_clustering_advanced, so equal indices do not refer to the same image.
# Zipping them would pair unrelated images; plot per-dataset mean ± std instead.
ax6 = fig.add_subplot(gs[1, 2])
for idx, (dataset_name, metrics) in enumerate(clustering_results.items()):
    overlap_pairs = metrics['overlap_analysis']['overlap_pairs']
    dbscan_results = metrics['clustering_analysis']['dbscan_results']
    
    if overlap_pairs and dbscan_results:
        cluster_counts = [r['n_clusters'] for r in dbscan_results]
        ax6.errorbar(np.mean(overlap_pairs), np.mean(cluster_counts),
                     xerr=np.std(overlap_pairs), yerr=np.std(cluster_counts),
                     fmt='o', markersize=8, capsize=4, label=dataset_name, color=dataset_colors[idx])

ax6.set_title('Overlap vs Clustering (mean ± std per dataset)', fontweight='bold')
ax6.set_xlabel('Overlapping Pairs per Image')
ax6.set_ylabel('Clusters per Image')
ax6.legend()
ax6.grid(True, alpha=0.3)

# 7. Boundary Effects Analysis (second row, fourth column)
ax7 = fig.add_subplot(gs[1, 3])
boundary_data = []
boundary_labels = []
for dataset_name, metrics in clustering_results.items():
    boundary_effects = metrics['spatial_metrics']['boundary_effects']
    if boundary_effects:
        boundary_data.extend(boundary_effects)
        boundary_labels.extend([dataset_name] * len(boundary_effects))

if boundary_data:
    df_boundary = pd.DataFrame({'Boundary_Ratio': boundary_data, 'Dataset': boundary_labels})
    sns.boxplot(data=df_boundary, x='Dataset', y='Boundary_Ratio', ax=ax7)
    ax7.set_title('Boundary Effects Distribution', fontweight='bold')
    ax7.set_ylabel('Ratio of Objects Near Boundary')
    ax7.tick_params(axis='x', rotation=45)

# 8. Geometric Features Analysis (third row, spanning two columns)
ax8 = fig.add_subplot(gs[2, :2])
for idx, (dataset_name, metrics) in enumerate(clustering_results.items()):
    bbox_ratios = metrics['geometric_features']['bounding_box_ratios']
    centroid_distances = metrics['geometric_features']['centroid_distances']
    
    if bbox_ratios and centroid_distances:
        min_len = min(len(bbox_ratios), len(centroid_distances))
        ax8.scatter(bbox_ratios[:min_len], centroid_distances[:min_len],
                   alpha=0.6, label=dataset_name, s=50, color=dataset_colors[idx])

ax8.set_title('Geometric Features: Bounding Box Ratio vs Centroid Distance', fontweight='bold')
ax8.set_xlabel('Bounding Box Aspect Ratio')
ax8.set_ylabel('Average Distance to Centroid')
ax8.legend()
ax8.grid(True, alpha=0.3)

# 9. Silhouette Score Analysis (third row, third column)
ax9 = fig.add_subplot(gs[2, 2])
silhouette_data = []
silhouette_labels = []
for dataset_name, metrics in clustering_results.items():
    silhouette_scores = metrics['clustering_analysis']['silhouette_scores']
    if silhouette_scores:
        silhouette_data.extend(silhouette_scores)
        silhouette_labels.extend([dataset_name] * len(silhouette_scores))

if silhouette_data:
    df_silhouette = pd.DataFrame({'Silhouette_Score': silhouette_data, 'Dataset': silhouette_labels})
    sns.violinplot(data=df_silhouette, x='Dataset', y='Silhouette_Score', ax=ax9)
    ax9.set_title('Clustering Quality (Silhouette Score)', fontweight='bold')
    ax9.set_ylabel('Silhouette Score')
    ax9.tick_params(axis='x', rotation=45)
    ax9.axhline(y=0, color='red', linestyle='--', alpha=0.5)

# 10. Comprehensive Statistics Table (third row, fourth column)
ax10 = fig.add_subplot(gs[2, 3])
ax10.axis('off')

stats_text = "📊 CLUSTERING STATISTICS\n\n"
for dataset_name, metrics in clustering_results.items():
    stats_text += f"🌾 {dataset_name}:\n"
    
    # Overlap stats
    overlap_stats = metrics['overlap_analysis']['overlap_statistics']
    if overlap_stats:
        stats_text += f"  Avg IoU: {overlap_stats['mean_iou']:.3f}\n"
        stats_text += f"  High overlap: {overlap_stats['high_overlap_ratio']:.3f}\n"
    
    # Clustering stats
    dbscan_results = metrics['clustering_analysis']['dbscan_results']
    if dbscan_results:
        avg_clusters = np.mean([r['n_clusters'] for r in dbscan_results])
        stats_text += f"  Avg clusters: {avg_clusters:.2f}\n"
    
    # Spatial stats
    nn_distances = metrics['spatial_metrics']['nearest_neighbor_distances']
    if nn_distances:
        stats_text += f"  Avg NN dist: {np.mean(nn_distances):.4f}\n"
    
    stats_text += "\n"

ax10.text(0.05, 0.95, stats_text, transform=ax10.transAxes, fontsize=9,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightgreen', alpha=0.8))

plt.suptitle('🔍 COMPREHENSIVE WHEAT HEAD CLUSTERING ANALYSIS', fontsize=16, fontweight='bold')
plt.savefig(notebook_results_dir / 'visualizations' / 'advanced_clustering_analysis.png', 
            dpi=300, bbox_inches='tight')
plt.show()

# Save comprehensive clustering results
clustering_summary = {}
for dataset_name, metrics in clustering_results.items():
    clustering_summary[dataset_name] = {
        'overlap_analysis': {
            'statistics': metrics['overlap_analysis']['overlap_statistics'],
            'high_overlap_images': metrics['overlap_analysis']['high_overlap_images'],
            'total_iou_scores': len(metrics['overlap_analysis']['iou_scores'])
        },
        'clustering_analysis': {
            'dbscan_summary': {
                'avg_clusters': float(np.mean([r['n_clusters'] for r in metrics['clustering_analysis']['dbscan_results']])) if metrics['clustering_analysis']['dbscan_results'] else 0,
                'avg_noise': float(np.mean([r['n_noise'] for r in metrics['clustering_analysis']['dbscan_results']])) if metrics['clustering_analysis']['dbscan_results'] else 0,
                'total_analyses': len(metrics['clustering_analysis']['dbscan_results'])
            },
            'silhouette_summary': {
                'avg_score': float(np.mean(metrics['clustering_analysis']['silhouette_scores'])) if metrics['clustering_analysis']['silhouette_scores'] else 0,
                'total_scores': len(metrics['clustering_analysis']['silhouette_scores'])
            }
        },
        'spatial_metrics': {
            'nearest_neighbor': {
                'avg_distance': float(np.mean(metrics['spatial_metrics']['nearest_neighbor_distances'])) if metrics['spatial_metrics']['nearest_neighbor_distances'] else 0,
                'min_distance': float(min(metrics['spatial_metrics']['nearest_neighbor_distances'])) if metrics['spatial_metrics']['nearest_neighbor_distances'] else 0,
                'std_distance': float(np.std(metrics['spatial_metrics']['nearest_neighbor_distances'])) if metrics['spatial_metrics']['nearest_neighbor_distances'] else 0
            },
            'boundary_effects': {
                'avg_ratio': float(np.mean(metrics['spatial_metrics']['boundary_effects'])) if metrics['spatial_metrics']['boundary_effects'] else 0,
                'total_analyses': len(metrics['spatial_metrics']['boundary_effects'])
            }
        },
        'pattern_detection': metrics['pattern_detection'],
        'density_analysis': {
            'local_density': {
                'avg_density': float(np.mean(metrics['density_analysis']['local_densities'])) if metrics['density_analysis']['local_densities'] else 0,
                'max_density': float(max(metrics['density_analysis']['local_densities'])) if metrics['density_analysis']['local_densities'] else 0,
                'total_points': len(metrics['density_analysis']['local_densities'])
            },
            'hotspots': {
                'avg_hotspots': float(np.mean(metrics['density_analysis']['density_hotspots'])) if metrics['density_analysis']['density_hotspots'] else 0,
                'total_images': len(metrics['density_analysis']['density_hotspots'])
            }
        },
        'geometric_features': {
            'bounding_box': {
                'avg_ratio': float(np.mean(metrics['geometric_features']['bounding_box_ratios'])) if metrics['geometric_features']['bounding_box_ratios'] else 0,
                'std_ratio': float(np.std(metrics['geometric_features']['bounding_box_ratios'])) if metrics['geometric_features']['bounding_box_ratios'] else 0
            },
            'convex_hull': {
                'avg_area': float(np.mean(metrics['geometric_features']['convex_hull_areas'])) if metrics['geometric_features']['convex_hull_areas'] else 0,
                'total_hulls': len(metrics['geometric_features']['convex_hull_areas'])
            },
            'centroid': {
                'avg_distance': float(np.mean(metrics['geometric_features']['centroid_distances'])) if metrics['geometric_features']['centroid_distances'] else 0,
                'total_analyses': len(metrics['geometric_features']['centroid_distances'])
            }
        }
    }

with open(notebook_results_dir / 'data_analysis' / 'advanced_clustering_analysis.json', 'w') as f:
    json.dump(clustering_summary, f, indent=2)

print(f"\n💾 Advanced clustering analysis saved!")
print(f"📁 Location: {notebook_results_dir / 'data_analysis' / 'advanced_clustering_analysis.json'}")
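The summary dictionary above repeats the guard `float(np.mean(values)) if values else 0` for almost every metric. Below is a minimal sketch of a helper that handles that empty-list fallback in one place. `summarize` is a hypothetical name and is not part of the CBAM-STN-TPS-YOLO code. The helper also returns plain Python `float`/`int` values, which `json.dump` always accepts. NumPy integer scalars, by contrast, raise a `TypeError`.

```python
import json
import numpy as np


def summarize(values, stats=("mean", "std", "min", "max")):
    """Return JSON-safe summary statistics for a possibly empty list of numbers.

    Hypothetical helper: empty inputs yield 0.0 for every statistic, mirroring
    the `... if values else 0` fallback used in the summary dictionary.
    """
    arr = np.asarray(values, dtype=float)
    reducers = {"mean": np.mean, "std": np.std, "min": np.min, "max": np.max}
    # Convert each NumPy scalar to a plain float so json.dump never sees NumPy types
    result = {name: (float(reducers[name](arr)) if arr.size else 0.0) for name in stats}
    result["count"] = int(arr.size)
    return result


# Example with made-up nearest-neighbour distances (illustrative values only)
nn_summary = summarize([0.12, 0.08, 0.10])
print(json.dumps(nn_summary, indent=2))
```

Adopting the helper would change the JSON schema. For example, `summarize(nn_distances)` yields `mean`/`std`/`min`/`max`/`count` keys instead of `avg_distance`/`min_distance`/`std_distance`, so any downstream reader of `advanced_clustering_analysis.json` would also need updating.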

## 5. Field Condition and Environmental Analysis

In [None]:
"""
Enhanced field condition and environmental analysis with computer vision techniques
"""

def analyze_field_conditions_advanced(dataset, dataset_name, max_samples=350):
    """Advanced field condition analysis using computer vision techniques"""
    
    field_metrics = {
        'illumination_analysis': {
            'brightness_values': [],
            'contrast_values': [],
            'lighting_uniformity': [],
            'shadow_detection': [],
            'glare_detection': [],
            'lighting_categories': {'uniform': 0, 'shadows': 0, 'mixed': 0, 'overexposed': 0, 'underexposed': 0}
        },
        'texture_analysis': {
            'local_binary_patterns': [],
            'texture_energy': [],
            'texture_homogeneity': [],
            'texture_contrast': [],
            'edge_density': []
        },
        'color_analysis': {
            'color_histograms': {'r': [], 'g': [], 'b': []},
            'color_moments': [],
            'saturation_levels': [],
            'hue_distributions': [],
            'vegetation_indices': []
        },
        'quality_assessment': {
            'sharpness_scores': [],
            'noise_levels': [],
            'blur_detection': [],
            'dynamic_range': [],
            'exposure_assessment': []
        },
        'environmental_factors': {
            'weather_indicators': {'clear': 0, 'cloudy': 0, 'harsh_light': 0},
            'soil_visibility': [],
            'vegetation_density': [],
            'field_maturity': [],
            'background_complexity': []
        },
        'spatial_characteristics': {
            'gradient_magnitude': [],
            'frequency_analysis': [],
            'structural_similarity': [],
            'feature_density': []
        }
    }
    
    sample_size = min(len(dataset), max_samples)
    indices = np.random.choice(len(dataset), sample_size, replace=False)
    
    print(f"🌾 Advanced field condition analysis on {sample_size} images from {dataset_name}...")
    
    processing_stats = {'success': 0, 'errors': 0, 'invalid': 0}
    
    for i in tqdm(indices, desc="Analyzing field conditions"):
        try:
            image, targets, path = dataset[i]
            
            # Convert tensor to a float numpy array in [0, 1].
            # Assumption: negative values mean the loader normalized to [-1, 1];
            # ImageNet-style mean/std normalization would need its own inverse transform.
            if isinstance(image, torch.Tensor):
                img_np = image.permute(1, 2, 0).cpu().numpy()
                if img_np.min() < 0:  # Assumed [-1, 1] range
                    img_np = (img_np + 1) / 2  # Convert to [0, 1]
                elif img_np.max() > 1.0:
                    img_np = img_np / 255.0  # Convert from [0, 255]
                # Otherwise the values are already in [0, 1]
            else:
                img_np = np.asarray(image, dtype=np.float32)
                if img_np.max() > 1.0:
                    img_np = img_np / 255.0
            img_np = np.clip(img_np, 0, 1)
            
            # Ensure a valid (H, W, 3) image
            if img_np.ndim != 3 or img_np.shape[2] != 3:
                processing_stats['invalid'] += 1
                continue
            
            # Convert to different color spaces
            img_rgb = (img_np * 255).astype(np.uint8)
            img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
            img_hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
            img_lab = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2LAB)
            
            processing_stats['success'] += 1
            
            # 1. ADVANCED ILLUMINATION ANALYSIS
            
            # Basic brightness and contrast
            brightness = np.mean(img_gray) / 255.0
            contrast = np.std(img_gray) / 255.0
            field_metrics['illumination_analysis']['brightness_values'].append(brightness)
            field_metrics['illumination_analysis']['contrast_values'].append(contrast)
            
            # Lighting non-uniformity as the coefficient of variation (std / mean);
            # higher values mean LESS uniform lighting
            lighting_uniformity = np.std(img_gray) / max(np.mean(img_gray), 1)
            field_metrics['illumination_analysis']['lighting_uniformity'].append(lighting_uniformity)
            
            # Shadow proxy: the black-hat transform (morphological closing minus the image)
            # highlights dark structures smaller than the 5x5 kernel; a rough indicator, not a true shadow detector
            kernel = np.ones((5,5), np.uint8)
            blackhat = cv2.morphologyEx(img_gray, cv2.MORPH_BLACKHAT, kernel)
            
            shadow_strength = np.mean(blackhat) / 255.0
            field_metrics['illumination_analysis']['shadow_detection'].append(shadow_strength)
            
            # Glare detection using brightness percentiles
            bright_pixels = np.percentile(img_gray, 95)
            glare_strength = (bright_pixels - np.mean(img_gray)) / 255.0
            field_metrics['illumination_analysis']['glare_detection'].append(glare_strength)
            
            # Categorize lighting conditions
            if brightness < 0.25:
                field_metrics['illumination_analysis']['lighting_categories']['underexposed'] += 1
            elif brightness > 0.8 or glare_strength > 0.3:
                # OR, not AND: when brightness > 0.8, glare_strength cannot exceed 0.2
                field_metrics['illumination_analysis']['lighting_categories']['overexposed'] += 1
            elif shadow_strength > 0.1:
                field_metrics['illumination_analysis']['lighting_categories']['shadows'] += 1
            elif lighting_uniformity > 0.8:
                field_metrics['illumination_analysis']['lighting_categories']['mixed'] += 1
            else:
                field_metrics['illumination_analysis']['lighting_categories']['uniform'] += 1
            
            # 2. TEXTURE ANALYSIS
            
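            # Local texture contrast: mean variance of each interior pixel's 8-neighbourhood.
            # This is not a full LBP histogram; it is the neighbourhood-variance (VAR) contrast
            # measure often paired with LBP, vectorized with shifted views instead of a per-pixel loop.
            def calculate_lbp_variance(image):
                """Average variance of the 8 neighbours around every interior pixel"""
                img = image.astype(np.float64)
                neighbours = np.stack([
                    img[:-2, :-2], img[:-2, 1:-1], img[:-2, 2:],   # top-left, top, top-right
                    img[1:-1, 2:], img[2:, 2:], img[2:, 1:-1],     # right, bottom-right, bottom
                    img[2:, :-2], img[1:-1, :-2]                   # bottom-left, left
                ])
                if neighbours[0].size == 0:  # Image smaller than 3x3
                    return 0.0
                return float(np.var(neighbours, axis=0).mean())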
            
            lbp_variance = calculate_lbp_variance(img_gray)
            field_metrics['texture_analysis']['local_binary_patterns'].append(lbp_variance)
            
            # First-order texture energy (uniformity) of the gray-level histogram;
            # note this is not the GLCM-based Haralick energy
            hist, _ = np.histogram(img_gray, bins=256, range=(0, 256))
            normalized_hist = hist / np.sum(hist)
            energy = np.sum(normalized_hist ** 2)
            field_metrics['texture_analysis']['texture_energy'].append(energy)
            
            # Patch homogeneity: 1 / (1 + std) averaged over non-overlapping 16x16 patches
            patch_size = 16
            h, w = img_gray.shape
            homogeneity_values = []
            
            for row in range(0, h - patch_size + 1, patch_size):
                for col in range(0, w - patch_size + 1, patch_size):
                    patch = img_gray[row:row+patch_size, col:col+patch_size]
                    patch_std = np.std(patch)
                    homogeneity_values.append(1.0 / (1.0 + patch_std))
            
            avg_homogeneity = np.mean(homogeneity_values) if homogeneity_values else 0
            field_metrics['texture_analysis']['texture_homogeneity'].append(avg_homogeneity)
            
            # Edge density
            edges = cv2.Canny(img_gray, 50, 150)
            edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
            field_metrics['texture_analysis']['edge_density'].append(edge_density)
            
            # 3. COLOR ANALYSIS
            
            # Color channel statistics
            for ch_idx, channel in enumerate(['r', 'g', 'b']):
                channel_mean = np.mean(img_rgb[:, :, ch_idx]) / 255.0
                field_metrics['color_analysis']['color_histograms'][channel].append(channel_mean)
            
            # Color moments per channel: mean, std, and skewness (third standardized moment)
            color_moments = []
            for ch_idx in range(3):
                channel = img_rgb[:, :, ch_idx] / 255.0
                mean_val = np.mean(channel)
                std_val = np.std(channel)
                skew_val = np.mean(((channel - mean_val) / std_val) ** 3) if std_val > 0 else 0.0
                color_moments.extend([mean_val, std_val, skew_val])
            
            field_metrics['color_analysis']['color_moments'].append(color_moments)
            
            # Saturation analysis
            saturation = img_hsv[:, :, 1]
            avg_saturation = np.mean(saturation) / 255.0
            field_metrics['color_analysis']['saturation_levels'].append(avg_saturation)
            
            # Vegetation index: Normalized Green-Red Difference Index, NGRDI = (G - R) / (G + R).
            # An RGB-only greenness proxy; a true NDVI requires a near-infrared band.
            r_channel = img_rgb[:, :, 0].astype(float)
            g_channel = img_rgb[:, :, 1].astype(float)
            
            # Avoid division by zero
            denominator = g_channel + r_channel
            vegetation_mask = denominator > 10  # Avoid very dark pixels
            
            if np.sum(vegetation_mask) > 0:
                vegetation_index = np.mean((g_channel[vegetation_mask] - r_channel[vegetation_mask]) / 
                                        denominator[vegetation_mask])
                field_metrics['color_analysis']['vegetation_indices'].append(vegetation_index)
            else:
                field_metrics['color_analysis']['vegetation_indices'].append(0)
            
            # 4. QUALITY ASSESSMENT
            
            # Sharpness using Laplacian variance
            laplacian = cv2.Laplacian(img_gray, cv2.CV_64F)
            sharpness = laplacian.var()
            field_metrics['quality_assessment']['sharpness_scores'].append(sharpness)
            
            # Noise estimation using high-frequency content
            # Apply Gaussian filter and measure difference
            blurred = cv2.GaussianBlur(img_gray, (5, 5), 0)
            noise_estimate = np.mean(np.abs(img_gray.astype(float) - blurred.astype(float)))
            field_metrics['quality_assessment']['noise_levels'].append(noise_estimate)
            
            # Blur proxy: mean Sobel gradient magnitude (higher = sharper, i.e. less blur)
            grad_x = cv2.Sobel(img_gray, cv2.CV_64F, 1, 0, ksize=3)
            grad_y = cv2.Sobel(img_gray, cv2.CV_64F, 0, 1, ksize=3)
            gradient_magnitude = np.sqrt(grad_x**2 + grad_y**2)
            avg_gradient = np.mean(gradient_magnitude)
            field_metrics['quality_assessment']['blur_detection'].append(avg_gradient)
            
            # Dynamic range
            dynamic_range = (np.max(img_gray) - np.min(img_gray)) / 255.0
            field_metrics['quality_assessment']['dynamic_range'].append(dynamic_range)
            
            # 5. ENVIRONMENTAL FACTORS
            
            # Crude weather proxy from the gray-level distribution (heuristic thresholds, not validated against weather metadata)
            brightness_std = np.std(img_gray) / 255.0
            brightness_mean = np.mean(img_gray) / 255.0
            
            if brightness_std < 0.15 and brightness_mean > 0.7:
                field_metrics['environmental_factors']['weather_indicators']['harsh_light'] += 1
            elif brightness_std > 0.25:
                field_metrics['environmental_factors']['weather_indicators']['cloudy'] += 1
            else:
                field_metrics['environmental_factors']['weather_indicators']['clear'] += 1
            
            # Soil visibility estimation using brown/tan color detection
            # Convert to HSV and look for soil-like colors
            hue = img_hsv[:, :, 0]
            sat = img_hsv[:, :, 1]
            val = img_hsv[:, :, 2]
            
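            # Soil proxy: brownish hues (10-30 on OpenCV's 0-179 hue scale) with saturation below 128
            # and value above 50, OR any low-saturation pixel. The second term also catches gray/white
            # regions (sky, glare, pale wheat heads), so this ratio likely overestimates visible soil.
            soil_mask = ((hue >= 10) & (hue <= 30) & (sat < 128) & (val > 50)) | (sat < 50)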
            soil_ratio = np.sum(soil_mask) / soil_mask.size
            field_metrics['environmental_factors']['soil_visibility'].append(soil_ratio)
            
            # Vegetation density using green channel dominance
            green_dominance = (g_channel > r_channel) & (g_channel > img_rgb[:, :, 2])
            vegetation_ratio = np.sum(green_dominance) / green_dominance.size
            field_metrics['environmental_factors']['vegetation_density'].append(vegetation_ratio)
            
            # Field maturity estimation using color analysis
            # Young crops: more green; mature crops: more yellow/brown (hue 15-35 on OpenCV's 0-179 scale)
            yellow_pixels = ((hue >= 15) & (hue <= 35) & (sat > 50))
            maturity_indicator = np.sum(yellow_pixels) / yellow_pixels.size
            field_metrics['environmental_factors']['field_maturity'].append(maturity_indicator)
            
            # Background complexity using frequency analysis
            f_transform = np.fft.fft2(img_gray)
            f_shift = np.fft.fftshift(f_transform)
            magnitude_spectrum = np.log(np.abs(f_shift) + 1)
            complexity_score = np.std(magnitude_spectrum)
            field_metrics['environmental_factors']['background_complexity'].append(complexity_score)
            
            # 6. SPATIAL CHARACTERISTICS
            
            # Gradient magnitude statistics
            field_metrics['spatial_characteristics']['gradient_magnitude'].append(np.mean(gradient_magnitude))
            
            # Frequency analysis - high-frequency content outside a central disc
            rows, cols = magnitude_spectrum.shape
            center_y, center_x = rows // 2, cols // 2  # fftshift places the DC term at (rows//2, cols//2)
            high_freq_radius = min(rows, cols) // 4
            
            # Create mask for high frequencies (x indexes columns, y indexes rows)
            y, x = np.ogrid[:rows, :cols]
            mask = ((x - center_x)**2 + (y - center_y)**2) > high_freq_radius**2
            
            high_freq_energy = np.mean(magnitude_spectrum[mask])
            field_metrics['spatial_characteristics']['frequency_analysis'].append(high_freq_energy)
            
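        except Exception as e:
            processing_stats['errors'] += 1
            if processing_stats['errors'] <= 3:  # Report the first few failures instead of hiding them all
                tqdm.write(f"   ⚠️ Skipped an image ({type(e).__name__}: {e})")
            continue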
    
    print(f"   ✅ Processing stats: {processing_stats['success']} success, "
          f"{processing_stats['errors']} errors, {processing_stats['invalid']} invalid")
    
    return field_metrics

# Analyze field conditions with advanced techniques
field_condition_results = {}

for name, dataset in datasets.items():
    print(f"\n{'='*60}")
    field_metrics = analyze_field_conditions_advanced(dataset, name, max_samples=300)
    field_condition_results[name] = field_metrics
    
    # Display comprehensive field analysis
    print(f"\n🌾 {name} - ADVANCED FIELD CONDITION ANALYSIS:")
    
    # Illumination analysis summary
    illumination = field_metrics['illumination_analysis']
    if illumination['brightness_values']:
        print(f"   💡 Illumination Analysis:")
        print(f"     Brightness: {np.mean(illumination['brightness_values']):.3f} ± {np.std(illumination['brightness_values']):.3f}")
        print(f"     Contrast: {np.mean(illumination['contrast_values']):.3f} ± {np.std(illumination['contrast_values']):.3f}")
        print(f"     Lighting non-uniformity (CV, lower = more uniform): {np.mean(illumination['lighting_uniformity']):.3f}")
        print(f"     Shadow strength: {np.mean(illumination['shadow_detection']):.3f}")
        print(f"     Glare strength: {np.mean(illumination['glare_detection']):.3f}")
        
        print(f"   🌤️ Lighting conditions:")
        total_images = sum(illumination['lighting_categories'].values())
        for condition, count in illumination['lighting_categories'].items():
            percentage = (count / total_images * 100) if total_images > 0 else 0
            print(f"     {condition.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")
    
    # Texture analysis summary
    texture = field_metrics['texture_analysis']
    if texture['texture_energy']:
        print(f"   🎨 Texture Analysis:")
        print(f"     Texture energy: {np.mean(texture['texture_energy']):.4f}")
        print(f"     Homogeneity: {np.mean(texture['texture_homogeneity']):.4f}")
        print(f"     Edge density: {np.mean(texture['edge_density']):.4f}")
        print(f"     LBP variance: {np.mean(texture['local_binary_patterns']):.2f}")
    
    # Color analysis summary
    color = field_metrics['color_analysis']
    if color['saturation_levels']:
        print(f"   🌈 Color Analysis:")
        print(f"     Average saturation: {np.mean(color['saturation_levels']):.3f}")
        print(f"     Vegetation index: {np.mean(color['vegetation_indices']):.3f}")
        
        # Color channel analysis
        r_avg = np.mean(color['color_histograms']['r'])
        g_avg = np.mean(color['color_histograms']['g'])
        b_avg = np.mean(color['color_histograms']['b'])
        print(f"     RGB averages: R={r_avg:.3f}, G={g_avg:.3f}, B={b_avg:.3f}")
    
    # Quality assessment summary
    quality = field_metrics['quality_assessment']
    if quality['sharpness_scores']:
        print(f"   📷 Quality Assessment:")
        print(f"     Sharpness score: {np.mean(quality['sharpness_scores']):.2f}")
        print(f"     Noise level: {np.mean(quality['noise_levels']):.2f}")
        print(f"     Mean gradient (higher = sharper): {np.mean(quality['blur_detection']):.2f}")
        print(f"     Dynamic range: {np.mean(quality['dynamic_range']):.3f}")
    
    # Environmental factors summary
    environment = field_metrics['environmental_factors']
    if environment['soil_visibility']:
        print(f"   🌱 Environmental Analysis:")
        print(f"     Soil visibility: {np.mean(environment['soil_visibility']):.3f}")
        print(f"     Vegetation density: {np.mean(environment['vegetation_density']):.3f}")
        print(f"     Field maturity: {np.mean(environment['field_maturity']):.3f}")
        print(f"     Background complexity: {np.mean(environment['background_complexity']):.2f}")
        
        print(f"   ☀️ Weather indicators:")
        total_weather = sum(environment['weather_indicators'].values())
        for weather, count in environment['weather_indicators'].items():
            percentage = (count / total_weather * 100) if total_weather > 0 else 0
            print(f"     {weather.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")

# Create comprehensive field condition visualization dashboard
fig = plt.figure(figsize=(24, 20))
gs = fig.add_gridspec(5, 4, hspace=0.4, wspace=0.3)

dataset_colors = plt.cm.Set1(np.linspace(0, 1, len(field_condition_results)))

# 1. Brightness and Contrast Distribution (top row, first column)
ax1 = fig.add_subplot(gs[0, 0])
for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    brightness = metrics['illumination_analysis']['brightness_values']
    contrast = metrics['illumination_analysis']['contrast_values']
    if brightness and contrast:
        ax1.scatter(brightness, contrast, alpha=0.6, label=dataset_name, 
                   s=30, color=dataset_colors[idx])
ax1.set_title('Brightness vs Contrast Analysis', fontweight='bold')
ax1.set_xlabel('Brightness')
ax1.set_ylabel('Contrast')
ax1.legend()
ax1.grid(True, alpha=0.3)

# 2. Lighting Conditions Distribution (top row, second column)
ax2 = fig.add_subplot(gs[0, 1])
lighting_categories = ['uniform', 'shadows', 'mixed', 'overexposed', 'underexposed']
x_pos = np.arange(len(lighting_categories))
width = 0.8 / len(field_condition_results)

for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    lighting_data = metrics['illumination_analysis']['lighting_categories']
    values = [lighting_data.get(cat, 0) for cat in lighting_categories]
    ax2.bar(x_pos + idx*width, values, width, label=dataset_name, 
           alpha=0.8, color=dataset_colors[idx])

ax2.set_title('Lighting Conditions Distribution', fontweight='bold')
ax2.set_xlabel('Lighting Condition')
ax2.set_ylabel('Number of Images')
ax2.set_xticks(x_pos + width * (len(field_condition_results) - 1) / 2)  # Center each tick under its bar group
ax2.set_xticklabels([cat.replace('_', '\n') for cat in lighting_categories])
ax2.legend(bbox_to_anchor=(1.05, 1), loc='upper left')

# 3. Texture Analysis (top row, third column)
ax3 = fig.add_subplot(gs[0, 2])
for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    energy = metrics['texture_analysis']['texture_energy']
    homogeneity = metrics['texture_analysis']['texture_homogeneity']
    if energy and homogeneity:
        ax3.scatter(energy, homogeneity, alpha=0.6, label=dataset_name,
                   s=30, color=dataset_colors[idx])
ax3.set_title('Texture Energy vs Homogeneity', fontweight='bold')
ax3.set_xlabel('Texture Energy')
ax3.set_ylabel('Texture Homogeneity')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Quality Assessment Radar Chart (top row, fourth column)
ax4 = fig.add_subplot(gs[0, 3], projection='polar')
quality_metrics = ['Sharpness', 'Low Noise', 'No Blur', 'Dynamic Range']
angles = np.linspace(0, 2*np.pi, len(quality_metrics), endpoint=False).tolist()
angles += angles[:1]  # Complete the circle

for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    quality = metrics['quality_assessment']
    if quality['sharpness_scores']:
        # Normalize metrics to a 0-1 scale; the divisors (1000, 50, 100) are ad hoc scale constants
        sharpness_norm = min(np.mean(quality['sharpness_scores']) / 1000, 1)
        noise_norm = max(0, 1 - np.mean(quality['noise_levels']) / 50)  # Invert noise
        blur_norm = min(np.mean(quality['blur_detection']) / 100, 1)
        range_norm = np.mean(quality['dynamic_range'])
        
        values = [sharpness_norm, noise_norm, blur_norm, range_norm]
        values += values[:1]  # Complete the circle
        
        ax4.plot(angles, values, 'o-', linewidth=2, label=dataset_name, 
                color=dataset_colors[idx])
        ax4.fill(angles, values, alpha=0.25, color=dataset_colors[idx])

ax4.set_xticks(angles[:-1])
ax4.set_xticklabels(quality_metrics)
ax4.set_ylim(0, 1)
ax4.set_title('Quality Assessment Profile', fontweight='bold', pad=20)
ax4.legend(bbox_to_anchor=(1.3, 1), loc='upper left')

# 5. Color Analysis (second row, first two columns)
ax5 = fig.add_subplot(gs[1, :2])
for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    color_data = metrics['color_analysis']['color_histograms']
    if all(color_data[c] for c in ['r', 'g', 'b']):
        # Per-image red vs. green channel means; `color=` (not `c=`) avoids matplotlib
        # misreading a single RGBA tuple as per-point values when there are 3 or 4 points
        ax5.scatter(color_data['r'], color_data['g'], color=dataset_colors[idx], alpha=0.6,
                   s=30, label=f'{dataset_name} (R vs G)')

ax5.set_title('Color Space Analysis (Red vs Green Channels)', fontweight='bold')
ax5.set_xlabel('Red Channel Average')
ax5.set_ylabel('Green Channel Average')
ax5.legend()
ax5.grid(True, alpha=0.3)

# 6. Vegetation Analysis (second row, third column)
ax6 = fig.add_subplot(gs[1, 2])
has_vegetation_data = False

for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    veg_indices = metrics['color_analysis']['vegetation_indices']
    sat_levels = metrics['color_analysis']['saturation_levels']
    # Both lists get one value per processed image; truncate defensively in case an error misaligned them
    n_points = min(len(veg_indices), len(sat_levels))
    if n_points > 0:
        ax6.scatter(veg_indices[:n_points], sat_levels[:n_points], alpha=0.6, s=30,
                   label=dataset_name, color=dataset_colors[idx])
        has_vegetation_data = True

ax6.set_title('Vegetation Index vs Saturation', fontweight='bold')
ax6.set_xlabel('Vegetation Index')
ax6.set_ylabel('Saturation Level')
ax6.grid(True, alpha=0.3)
if has_vegetation_data:
    ax6.legend()

# 7. Environmental Factors (second row, fourth column)
ax7 = fig.add_subplot(gs[1, 3])
env_factors = ['soil_visibility', 'vegetation_density', 'field_maturity']
env_data = {factor: [] for factor in env_factors}
env_labels = []

for dataset_name, metrics in field_condition_results.items():
    env_metrics = metrics['environmental_factors']
    for factor in env_factors:
        if env_metrics[factor]:
            env_data[factor].extend(env_metrics[factor])
            if factor == 'soil_visibility':  # Only add labels once
                env_labels.extend([dataset_name] * len(env_metrics[factor]))

# Create box plots for environmental factors
positions = []
box_data = []
box_labels = []

for i, (factor, data) in enumerate(env_data.items()):
    if data:
        positions.append(i)
        box_data.append(data)
        box_labels.append(factor.replace('_', '\n').title())

if box_data:
    bp = ax7.boxplot(box_data, positions=positions, labels=box_labels, patch_artist=True)
    
    # Color the boxes
    colors = plt.cm.Set3(np.linspace(0, 1, len(bp['boxes'])))
    for patch, color in zip(bp['boxes'], colors):
        patch.set_facecolor(color)
        patch.set_alpha(0.7)

ax7.set_title('Environmental Factors Distribution', fontweight='bold')
ax7.set_ylabel('Ratio/Score')
ax7.tick_params(axis='x', rotation=45)

# 8. Weather Indicators (third row, first column)
ax8 = fig.add_subplot(gs[2, 0])
weather_summary = defaultdict(int)
for dataset_name, metrics in field_condition_results.items():
    weather_data = metrics['environmental_factors']['weather_indicators']
    for weather, count in weather_data.items():
        weather_summary[weather] += count

if weather_summary:
    labels = list(weather_summary.keys())
    sizes = list(weather_summary.values())
    colors = plt.cm.Set2(np.linspace(0, 1, len(labels)))
    
    wedges, texts, autotexts = ax8.pie(sizes, labels=[l.replace('_', '\n') for l in labels], 
                                      autopct='%1.1f%%', colors=colors, startangle=90)
    ax8.set_title('Weather Conditions Distribution', fontweight='bold')

# 9. Spatial Characteristics (third row, second column)
ax9 = fig.add_subplot(gs[2, 1])
for idx, (dataset_name, metrics) in enumerate(field_condition_results.items()):
    gradient_mag = metrics['spatial_characteristics']['gradient_magnitude']
    freq_analysis = metrics['spatial_characteristics']['frequency_analysis']
    
    if gradient_mag and freq_analysis:
        ax9.scatter(gradient_mag, freq_analysis, alpha=0.6, label=dataset_name,
                   s=30, color=dataset_colors[idx])

ax9.set_title('Spatial Characteristics', fontweight='bold')
ax9.set_xlabel('Gradient Magnitude')
ax9.set_ylabel('High Frequency Energy')
ax9.legend()
ax9.grid(True, alpha=0.3)

# 10. Quality Metrics Distribution (third row, third and fourth columns)
ax10 = fig.add_subplot(gs[2, 2:])
quality_metrics_data = {}
for dataset_name, metrics in field_condition_results.items():
    quality = metrics['quality_assessment']
    quality_metrics_data[dataset_name] = {
        'sharpness': quality['sharpness_scores'],
        'noise': quality['noise_levels'],
        'blur': quality['blur_detection'],
        'dynamic_range': quality['dynamic_range']
    }

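# Plot each quality metric as a jittered strip. The raw metrics live on very
# different scales (Laplacian variance vs. a 0-1 dynamic range), so each metric
# is min-max scaled to [0, 1] across all datasets before plotting.
metric_names = ['sharpness', 'noise', 'blur', 'dynamic_range']
for i, metric in enumerate(metric_names):
    all_values = [v for data in quality_metrics_data.values() for v in data[metric]]
    lo_val = min(all_values) if all_values else 0.0
    span = (max(all_values) - lo_val) if all_values else 0.0
    span = span if span > 0 else 1.0
    for idx, (dataset_name, data) in enumerate(quality_metrics_data.items()):
        if data[metric]:
            scaled = (np.asarray(data[metric], dtype=float) - lo_val) / span
            # Jitter x positions so points within a metric group don't overlap
            positions = np.random.normal(i, 0.04, len(scaled))
            ax10.scatter(positions, scaled, alpha=0.6, s=20,
                         color=dataset_colors[idx], label=dataset_name if i == 0 else "")

ax10.set_title('Quality Metrics Distribution (min-max scaled per metric)', fontweight='bold')
ax10.set_xlabel('Quality Metric')
ax10.set_ylabel('Scaled Score (0-1)')
ax10.set_xticks(range(len(metric_names)))
ax10.set_xticklabels([name.replace('_', '\n').title() for name in metric_names])
if quality_metrics_data:
    ax10.legend()
ax10.grid(True, alpha=0.3)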

# 11. Comprehensive Statistics Summary (fourth and fifth rows)
ax11 = fig.add_subplot(gs[3:, :])
ax11.axis('off')

# Create comprehensive statistics table
stats_text = "📊 COMPREHENSIVE FIELD CONDITION STATISTICS\n\n"

for dataset_name, metrics in field_condition_results.items():
    stats_text += f"🌾 {dataset_name}:\n"
    
    # Illumination stats
    illum = metrics['illumination_analysis']
    if illum['brightness_values']:
        stats_text += "  💡 Illumination:\n"
        stats_text += f"    Brightness: {np.mean(illum['brightness_values']):.3f} ± {np.std(illum['brightness_values']):.3f}\n"
        stats_text += f"    Contrast: {np.mean(illum['contrast_values']):.3f} ± {np.std(illum['contrast_values']):.3f}\n"
        stats_text += f"    Shadow strength: {np.mean(illum['shadow_detection']):.3f}\n"
        stats_text += f"    Glare strength: {np.mean(illum['glare_detection']):.3f}\n"
    
    # Texture stats
    texture = metrics['texture_analysis']
    if texture['texture_energy']:
        stats_text += "  🎨 Texture:\n"
        stats_text += f"    Energy: {np.mean(texture['texture_energy']):.4f}\n"
        stats_text += f"    Homogeneity: {np.mean(texture['texture_homogeneity']):.4f}\n"
        stats_text += f"    Edge density: {np.mean(texture['edge_density']):.4f}\n"
    
    # Color stats
    color = metrics['color_analysis']
    if color['saturation_levels']:
        stats_text += "  🌈 Color:\n"
        stats_text += f"    Saturation: {np.mean(color['saturation_levels']):.3f}\n"
        stats_text += f"    Vegetation index: {np.mean(color['vegetation_indices']):.3f}\n"
    
    # Quality stats
    quality = metrics['quality_assessment']
    if quality['sharpness_scores']:
        stats_text += "  📷 Quality:\n"
        stats_text += f"    Sharpness: {np.mean(quality['sharpness_scores']):.1f}\n"
        stats_text += f"    Noise: {np.mean(quality['noise_levels']):.2f}\n"
        stats_text += f"    Dynamic range: {np.mean(quality['dynamic_range']):.3f}\n"
    
    # Environmental stats
    env = metrics['environmental_factors']
    if env['soil_visibility']:
        stats_text += "  🌱 Environment:\n"
        stats_text += f"    Soil visibility: {np.mean(env['soil_visibility']):.3f}\n"
        stats_text += f"    Vegetation density: {np.mean(env['vegetation_density']):.3f}\n"
        stats_text += f"    Field maturity: {np.mean(env['field_maturity']):.3f}\n"
    
    stats_text += "\n"

ax11.text(0.02, 0.98, stats_text, transform=ax11.transAxes, fontsize=9,
        verticalalignment='top', fontfamily='monospace',
        bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))

plt.suptitle('🌾 COMPREHENSIVE FIELD CONDITION & ENVIRONMENTAL ANALYSIS', 
            fontsize=16, fontweight='bold')
plt.savefig(notebook_results_dir / 'visualizations' / 'comprehensive_field_analysis.png', 
           dpi=300, bbox_inches='tight')
plt.show()

# Save comprehensive field condition results
def _summary_stats(values, extended=False):
    """Mean and std (plus min/max if extended) as plain floats; zeros when nothing was collected."""
    if not values:
        stats = {'mean': 0, 'std': 0}
        if extended:
            stats.update({'min': 0, 'max': 0})
        return stats
    stats = {'mean': float(np.mean(values)), 'std': float(np.std(values))}
    if extended:
        stats.update({'min': float(min(values)), 'max': float(max(values))})
    return stats

def _mean_or_zero(values):
    """Mean as a plain float, or 0 when nothing was collected."""
    return float(np.mean(values)) if values else 0

field_condition_summary = {}
for dataset_name, metrics in field_condition_results.items():
    illum = metrics['illumination_analysis']
    texture = metrics['texture_analysis']
    color = metrics['color_analysis']
    histograms = color['color_histograms']
    quality = metrics['quality_assessment']
    env = metrics['environmental_factors']
    spatial = metrics['spatial_characteristics']
    
    field_condition_summary[dataset_name] = {
        'illumination_analysis': {
            'brightness_stats': _summary_stats(illum['brightness_values'], extended=True),
            'contrast_stats': _summary_stats(illum['contrast_values']),
            'lighting_uniformity': _summary_stats(illum['lighting_uniformity']),
            'shadow_detection': _summary_stats(illum['shadow_detection']),
            'glare_detection': _summary_stats(illum['glare_detection']),
            'lighting_categories': illum['lighting_categories']
        },
        'texture_analysis': {
            'texture_energy': _summary_stats(texture['texture_energy']),
            'texture_homogeneity': _summary_stats(texture['texture_homogeneity']),
            'edge_density': _summary_stats(texture['edge_density']),
            'lbp_variance': _summary_stats(texture['local_binary_patterns'])
        },
        'color_analysis': {
            'saturation_stats': _summary_stats(color['saturation_levels']),
            'vegetation_index': _summary_stats(color['vegetation_indices']),
            'rgb_averages': {
                'red': _mean_or_zero(histograms['r']),
                'green': _mean_or_zero(histograms['g']),
                'blue': _mean_or_zero(histograms['b'])
            }
        },
        'quality_assessment': {
            'sharpness_stats': _summary_stats(quality['sharpness_scores']),
            'noise_stats': _summary_stats(quality['noise_levels']),
            'blur_stats': _summary_stats(quality['blur_detection']),
            'dynamic_range_stats': _summary_stats(quality['dynamic_range'])
        },
        'environmental_factors': {
            'soil_visibility': _summary_stats(env['soil_visibility']),
            'vegetation_density': _summary_stats(env['vegetation_density']),
            'field_maturity': _summary_stats(env['field_maturity']),
            'background_complexity': _summary_stats(env['background_complexity']),
            'weather_indicators': env['weather_indicators']
        },
        'spatial_characteristics': {
            'gradient_magnitude': _summary_stats(spatial['gradient_magnitude']),
            'frequency_analysis': _summary_stats(spatial['frequency_analysis'])
        }
    }

with open(notebook_results_dir / 'data_analysis' / 'comprehensive_field_conditions.json', 'w') as f:
    json.dump(field_condition_summary, f, indent=2)

print(f"\n💾 Comprehensive field condition analysis saved!")
print(f"📁 Location: {notebook_results_dir / 'data_analysis' / 'comprehensive_field_conditions.json'}")
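
`json.dump` only serialises built-in Python types. The summary above converts every statistic with `float(...)`, but `lighting_categories` and `weather_indicators` are passed through unchanged, so if they hold NumPy scalars or arrays the dump raises `TypeError`. A minimal sketch of a NumPy-aware encoder, assuming those containers hold only dicts, lists, Python scalars and NumPy values; the class name `NumpyJSONEncoder` is hypothetical and not part of this notebook. Note that `default` is not consulted for dictionary keys, so NumPy-typed keys would still need converting by hand.

```python
import json

import numpy as np


class NumpyJSONEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to built-in Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Defer to the base class, which raises TypeError for unsupported types
        return super().default(obj)


# Example: a summary mixing NumPy and built-in values
summary = {
    'lighting_categories': {'bright': np.int64(3), 'dim': 1},
    'weather_indicators': {'haze': np.array([0.1, 0.25])},
    'mean': np.float32(0.5),
}
encoded = json.dumps(summary, cls=NumpyJSONEncoder, indent=2)
print(encoded)
```

With such an encoder defined, the save call would pass it as `json.dump(field_condition_summary, f, indent=2, cls=NumpyJSONEncoder)`.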

## 6. Wheat Head Sample Visualization
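
The next cell flags "overlaps" with a centre-distance proxy: any pair of box centres closer than 0.05 in normalized units. A true box-overlap test needs corner coordinates. The sketch below is not part of the original pipeline: it converts YOLO-style normalized `(x_center, y_center, width, height)` boxes to corners and counts pairs whose intersection-over-union (IoU) exceeds a threshold. The helper names and the `iou_threshold=0.1` default are assumptions for illustration.

```python
import numpy as np


def yolo_to_corners(boxes: np.ndarray) -> np.ndarray:
    """Convert an (N, 4) array of normalized (xc, yc, w, h) boxes to (x1, y1, x2, y2)."""
    xc, yc, w, h = boxes[:, 0], boxes[:, 1], boxes[:, 2], boxes[:, 3]
    return np.stack([xc - w / 2, yc - h / 2, xc + w / 2, yc + h / 2], axis=1)


def pairwise_iou(corners: np.ndarray) -> np.ndarray:
    """Return the (N, N) IoU matrix for boxes given as (x1, y1, x2, y2)."""
    # Intersection rectangle for every pair via broadcasting (N, 1) against (1, N)
    x1 = np.maximum(corners[:, None, 0], corners[None, :, 0])
    y1 = np.maximum(corners[:, None, 1], corners[None, :, 1])
    x2 = np.minimum(corners[:, None, 2], corners[None, :, 2])
    y2 = np.minimum(corners[:, None, 3], corners[None, :, 3])
    intersection = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    areas = (corners[:, 2] - corners[:, 0]) * (corners[:, 3] - corners[:, 1])
    union = areas[:, None] + areas[None, :] - intersection
    # Avoid division by zero for degenerate (zero-area) boxes
    return np.divide(intersection, union, out=np.zeros_like(intersection), where=union > 0)


def count_overlapping_pairs(boxes: np.ndarray, iou_threshold: float = 0.1) -> int:
    """Count unordered box pairs whose IoU exceeds iou_threshold."""
    if len(boxes) < 2:
        return 0
    iou = pairwise_iou(yolo_to_corners(np.asarray(boxes, dtype=float)))
    upper = np.triu(iou, k=1)  # each pair once, excluding self-pairs on the diagonal
    return int(np.sum(upper > iou_threshold))


boxes = np.array([
    [0.30, 0.30, 0.10, 0.10],   # overlaps the next box
    [0.33, 0.30, 0.10, 0.10],
    [0.80, 0.80, 0.05, 0.05],   # isolated
])
print(count_overlapping_pairs(boxes))
```

Unlike the centre-distance proxy, an IoU test accounts for box size: two large heads can overlap with centres well over 0.05 apart, while two tiny heads with nearby centres may not touch at all.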

In [None]:
"""
Enhanced wheat head sample visualization with detailed analysis and annotations
"""

def visualize_wheat_samples_advanced(dataset, dataset_name, num_samples=12):
    """Advanced visualization of wheat samples with detailed annotations and analysis"""
    
    if not hasattr(dataset, '__getitem__'):
        print(f"Skipping {dataset_name} - incompatible dataset format")
        return
    
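    # Intelligent sample selection for diverse representation
    total_samples = len(dataset)
    if total_samples == 0:
        print(f"Skipping {dataset_name} - dataset is empty")
        return
    # Never request more samples than exist; otherwise the fill loop below could never finish
    num_samples = min(num_samples, total_samples)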
    
    # Strategy: Select samples with different wheat head densities
    sample_indices = []
    density_targets = [0, 5, 10, 15, 20, 25, 30, 40]  # Target different densities
    
    # Try to find samples with target densities
    attempts = 0
    max_attempts = min(total_samples, 500)
    
    while len(sample_indices) < num_samples and attempts < max_attempts:
        idx = np.random.randint(0, total_samples)
        
        try:
            _, targets, _ = dataset[idx]
            num_heads = len(targets) if targets.numel() > 0 else 0
            
            # Check if this sample fills a desired density category
            target_density = density_targets[len(sample_indices) % len(density_targets)]
            density_diff = abs(num_heads - target_density)
            
            # Accept a new index near the target density; after 8 picks, accept any unused index
            if idx not in sample_indices and (density_diff <= 5 or len(sample_indices) >= 8):
                sample_indices.append(idx)
                
        except Exception:
            # Skip samples that fail to load
            pass
        
        attempts += 1
    
    # Fill remaining slots with random samples if needed
    while len(sample_indices) < num_samples:
        idx = np.random.randint(0, total_samples)
        if idx not in sample_indices:
            sample_indices.append(idx)
    
    # Create advanced visualization layout
    cols = 4
    rows = (num_samples + cols - 1) // cols
    
    # squeeze=False always returns a 2-D array of axes, whatever the grid shape
    fig, axes = plt.subplots(rows, cols, figsize=(6*cols, 5*rows), squeeze=False)
    axes_flat = axes.flatten()
    
    # Color map for different wheat head sizes
    size_colors = {'tiny': 'red', 'small': 'orange', 'medium': 'yellow', 'large': 'green'}
    
    sample_statistics = []
    
    for i, idx in enumerate(sample_indices):
        if i >= len(axes_flat):
            break
            
        try:
            image, targets, path = dataset[idx]
            
            # Convert to an HxWxC float array in [0, 1] for imshow
            if isinstance(image, torch.Tensor):
                img_np = image.permute(1, 2, 0).cpu().numpy()
                if img_np.min() < 0:
                    img_np = (img_np + 1) / 2      # assume [-1, 1] normalization
                elif img_np.max() > 1.0:
                    img_np = img_np / 255.0        # assume [0, 255] range
                img_np = np.clip(img_np, 0, 1)
            else:
                img_np = np.asarray(image)
                if img_np.max() > 1.0:
                    img_np = img_np / 255.0
            
            # Keep only the first three channels of multi-channel images (grayscale is left as-is)
            if img_np.ndim == 3 and img_np.shape[-1] > 3:
                img_np = img_np[:, :, :3]
            
            axes_flat[i].imshow(img_np)
            
            # Advanced annotation and analysis
            num_heads = len(targets) if targets.numel() > 0 else 0
            sample_stats = {
                'num_heads': num_heads,
                'sizes': [],
                'positions': [],
                'clustering_score': 0,
                'density_category': '',
                'spatial_distribution': '',
                'overlap_detected': False
            }
            
            if num_heads > 0:
                from matplotlib.patches import Rectangle
                
                h, w = img_np.shape[:2]
                positions = []
                areas = []
                overlaps = 0
                
                # Process each wheat head (YOLO rows: class, x_center, y_center, width, height; normalized)
                for j, target in enumerate(targets):
                    if len(target) >= 5:
                        # Convert tensor elements to plain floats for arithmetic and plotting
                        cls, x_center, y_center, width, height = (float(v) for v in target[:5])
                        
                        # Convert to pixel coordinates
                        x1 = (x_center - width/2) * w
                        y1 = (y_center - height/2) * h
                        x2 = (x_center + width/2) * w
                        y2 = (y_center + height/2) * h
                        
                        # Store data for analysis (area as a fraction of the image area)
                        area = width * height
                        areas.append(area)
                        positions.append((x_center, y_center))
                        
                        # Determine size category
                        if area < 0.0003:
                            size_cat = 'tiny'
                        elif area < 0.0015:
                            size_cat = 'small'
                        elif area < 0.006:
                            size_cat = 'medium'
                        else:
                            size_cat = 'large'
                        
                        sample_stats['sizes'].append(size_cat)
                        
                        # Draw bounding box with size-based color
                        color = size_colors[size_cat]
                        rect = Rectangle((x1, y1), x2-x1, y2-y1,
                                         linewidth=1.5, edgecolor=color, facecolor='none', alpha=0.8)
                        axes_flat[i].add_patch(rect)
                        
                        # Label each box with its index
                        axes_flat[i].text(x1, y1-5, f'{j+1}', fontsize=8, color=color,
                                          fontweight='bold', bbox=dict(boxstyle='round,pad=0.2',
                                          facecolor='white', alpha=0.7))
                
                # Analyze spatial distribution
                if len(positions) >= 2:
                    positions_array = np.array(positions)
                    
                    # Clustering score: coefficient of variation of pairwise centre distances
                    distances = pdist(positions_array)
                    mean_distance = np.mean(distances)
                    sample_stats['clustering_score'] = np.std(distances) / mean_distance if mean_distance > 0 else 0
                    
                    # Classify the layout from the spread of centres along each axis
                    x_spread = np.std(positions_array[:, 0])
                    y_spread = np.std(positions_array[:, 1])
                    min_spread, max_spread = min(x_spread, y_spread), max(x_spread, y_spread)
                    
                    if x_spread < 0.2 and y_spread < 0.2:
                        sample_stats['spatial_distribution'] = 'Clustered'
                    elif min_spread == 0 or max_spread / min_spread > 2:
                        sample_stats['spatial_distribution'] = 'Linear'
                    else:
                        sample_stats['spatial_distribution'] = 'Distributed'
                    
                    # Overlap proxy: pairs of centres closer than 0.05 (normalized units);
                    # this does not test actual box intersection
                    overlaps = int(np.sum(distances < 0.05))
                    sample_stats['overlap_detected'] = overlaps > 0
                
                sample_stats['positions'] = positions
            
            # Categorize density
            if num_heads == 0:
                sample_stats['density_category'] = 'Empty'
            elif num_heads <= 5:
                sample_stats['density_category'] = 'Low'
            elif num_heads <= 15:
                sample_stats['density_category'] = 'Medium'
            elif num_heads <= 30:
                sample_stats['density_category'] = 'High'
            else:
                sample_stats['density_category'] = 'Very High'
            
            sample_statistics.append(sample_stats)
            
            # Create comprehensive title with analysis
            size_summary = dict(Counter(sample_stats['sizes']))
            size_text = ', '.join([f"{count} {size}" for size, count in size_summary.items()])
            
            title_lines = [
                f'Sample {i+1}: {num_heads} wheat heads',
                f'Density: {sample_stats["density_category"]}',
                f'Distribution: {sample_stats["spatial_distribution"]}',
                f'Sizes: {size_text if size_text else "None"}',
                f'Clustering: {sample_stats["clustering_score"]:.2f}' if sample_stats["clustering_score"] > 0 else 'Clustering: N/A',
                f'Overlaps: {"Yes" if sample_stats["overlap_detected"] else "No"}',
                f'{Path(path).name}'
            ]
            
            axes_flat[i].set_title('\n'.join(title_lines), fontsize=8, loc='left')
            axes_flat[i].axis('off')
            
        except Exception as e:
            axes_flat[i].text(0.5, 0.5, f'Error loading\nsample {i+1}\n{str(e)[:30]}...', 
                             ha='center', va='center', transform=axes_flat[i].transAxes,
                             bbox=dict(boxstyle='round', facecolor='pink', alpha=0.8))
            axes_flat[i].axis('off')
    
    # Hide unused subplots
    for i in range(len(sample_indices), len(axes_flat)):
        axes_flat[i].axis('off')
    
    # Add legend for size categories
    legend_elements = [plt.Rectangle((0,0),1,1, facecolor='none', edgecolor=color, 
                                   linewidth=2, label=f'{size.title()} wheat heads') 
                      for size, color in size_colors.items()]
    
    # axes_flat is a NumPy array, so test its length rather than its truth value
    if len(axes_flat) > 0:
        axes_flat[0].legend(handles=legend_elements, loc='upper right',
                            bbox_to_anchor=(1, 1), fontsize=8)
    
    plt.suptitle(f'🌾 Advanced Wheat Head Sample Analysis - {dataset_name}', 
                 fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(notebook_results_dir / 'samples' / f'advanced_wheat_samples_{dataset_name.lower()}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()
    
    # Create summary statistics visualization
    if sample_statistics:
        create_sample_analysis_summary(sample_statistics, dataset_name)
    
    return sample_statistics

def create_sample_analysis_summary(sample_statistics, dataset_name):
    """Create summary analysis of the visualized samples"""
    
    fig, axes = plt.subplots(2, 3, figsize=(18, 12))
    
    # Extract data for analysis
    densities = [s['num_heads'] for s in sample_statistics]
    density_categories = [s['density_category'] for s in sample_statistics]
    spatial_distributions = [s['spatial_distribution'] for s in sample_statistics if s['spatial_distribution']]
    clustering_scores = [s['clustering_score'] for s in sample_statistics if s['clustering_score'] > 0]
    
    # Size distribution across all samples
    all_sizes = []
    for s in sample_statistics:
        all_sizes.extend(s['sizes'])
    
    # 1. Density distribution
    axes[0, 0].hist(densities, bins=15, alpha=0.7, color='skyblue', edgecolor='black')
    axes[0, 0].set_title('Sample Density Distribution', fontweight='bold')
    axes[0, 0].set_xlabel('Number of Wheat Heads')
    axes[0, 0].set_ylabel('Frequency')
    axes[0, 0].grid(True, alpha=0.3)
    
    # Add statistics
    if densities:
        mean_density = np.mean(densities)
        axes[0, 0].axvline(mean_density, color='red', linestyle='--', 
                          label=f'Mean: {mean_density:.1f}')
        axes[0, 0].legend()
    
    # 2. Density categories pie chart
    density_counts = Counter(density_categories)
    if density_counts:
        labels = list(density_counts.keys())
        sizes = list(density_counts.values())
        colors = plt.cm.Set3(np.linspace(0, 1, len(labels)))
        
        axes[0, 1].pie(sizes, labels=labels, autopct='%1.1f%%', colors=colors, startangle=90)
        axes[0, 1].set_title('Density Categories Distribution', fontweight='bold')
    
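    # 3. Size distribution (bars ordered and colored by size category)
    if all_sizes:
        size_colors = {'tiny': 'red', 'small': 'orange', 'medium': 'yellow', 'large': 'green'}
        size_counts = Counter(all_sizes)
        size_labels = [size for size in size_colors if size in size_counts]
        size_values = [size_counts[size] for size in size_labels]
        
        bars = axes[0, 2].bar(size_labels, size_values, alpha=0.8,
                              color=[size_colors[size] for size in size_labels])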
        axes[0, 2].set_title('Wheat Head Size Distribution', fontweight='bold')
        axes[0, 2].set_xlabel('Size Category')
        axes[0, 2].set_ylabel('Count')
        
        # Add value labels
        for bar, value in zip(bars, size_values):
            height = bar.get_height()
            axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.1,
                           f'{value}', ha='center', va='bottom', fontweight='bold')
    
    # 4. Spatial distribution patterns
    if spatial_distributions:
        spatial_counts = Counter(spatial_distributions)
        spatial_labels = list(spatial_counts.keys())
        spatial_values = list(spatial_counts.values())
        
        axes[1, 0].bar(spatial_labels, spatial_values, alpha=0.8, 
                      color=plt.cm.Set2(np.linspace(0, 1, len(spatial_labels))))
        axes[1, 0].set_title('Spatial Distribution Patterns', fontweight='bold')
        axes[1, 0].set_xlabel('Pattern Type')
        axes[1, 0].set_ylabel('Count')
        axes[1, 0].tick_params(axis='x', rotation=45)
    
    # 5. Clustering scores
    if clustering_scores:
        axes[1, 1].hist(clustering_scores, bins=10, alpha=0.7, color='lightgreen', edgecolor='black')
        axes[1, 1].set_title('Clustering Score Distribution', fontweight='bold')
        axes[1, 1].set_xlabel('Clustering Score (CV of distances)')
        axes[1, 1].set_ylabel('Frequency')
        axes[1, 1].grid(True, alpha=0.3)
        
        # Add interpretation guide
        axes[1, 1].axvline(0.5, color='orange', linestyle='--', alpha=0.7, label='Moderate clustering')
        axes[1, 1].axvline(1.0, color='red', linestyle='--', alpha=0.7, label='High clustering')
        axes[1, 1].legend()
    
    # 6. Overlap analysis
    overlap_counts = {'With Overlaps': 0, 'No Overlaps': 0}
    for s in sample_statistics:
        if s['overlap_detected']:
            overlap_counts['With Overlaps'] += 1
        else:
            overlap_counts['No Overlaps'] += 1
    
    if sum(overlap_counts.values()) > 0:
        labels = list(overlap_counts.keys())
        sizes = list(overlap_counts.values())
        colors = ['lightcoral', 'lightblue']
        
        axes[1, 2].pie(sizes, labels=labels, autopct='%1.1f%%', colors=colors, startangle=90)
        axes[1, 2].set_title('Overlap Detection Results', fontweight='bold')
    
    plt.suptitle(f'📊 Sample Analysis Summary - {dataset_name}', fontsize=16, fontweight='bold')
    plt.tight_layout()
    plt.savefig(notebook_results_dir / 'samples' / f'sample_analysis_summary_{dataset_name.lower()}.png', 
                dpi=300, bbox_inches='tight')
    plt.show()

def analyze_wheat_head_sizes_comprehensive(dataset, dataset_name, max_samples=500):
    """Comprehensive analysis of wheat head sizes and morphological characteristics"""
    
    size_metrics = {
        'morphological_features': {
            'areas': [],
            'widths': [],
            'heights': [],
            'perimeters': [],
            'aspect_ratios': [],
            'elongation_indices': [],
            'compactness_scores': []
        },
        'size_categories': {
            'micro': 0,      # < 0.0002
            'tiny': 0,       # 0.0002 - 0.0005
            'small': 0,      # 0.0005 - 0.002
            'medium': 0,     # 0.002 - 0.008
            'large': 0,      # 0.008 - 0.02
            'extra_large': 0 # > 0.02
        },
        'shape_analysis': {
            'circular_heads': 0,
            'elongated_heads': 0,
            'irregular_heads': 0
        },
        'spatial_context': {
            'position_size_correlation': {'x_coords': [], 'y_coords': [], 'sizes': []},
            'edge_effects': [],
            'center_bias': []
        },
        'statistical_measures': {
            'size_distribution': [],
            'shape_distribution': [],
            'outlier_detection': []
        }
    }
    
    sample_size = min(len(dataset), max_samples)
    indices = np.random.choice(len(dataset), sample_size, replace=False)
    
    print(f"🔍 Comprehensive wheat head size analysis on {sample_size} images from {dataset_name}...")
    
    for i in tqdm(indices, desc="Analyzing wheat head morphology"):
        try:
            _, targets, _ = dataset[i]
            
            if targets.numel() == 0:
                continue
                
            for target in targets:
                if len(target) >= 5:
                    # YOLO row: class, x_center, y_center, width, height (normalized to [0, 1])
                    cls, x_center, y_center, width, height = (float(v) for v in target[:5])
                    
                    # Basic measurements in normalized image units
                    area = width * height
                    perimeter = 2 * (width + height)  # Perimeter of the axis-aligned bounding box
                    
                    size_metrics['morphological_features']['areas'].append(area)
                    size_metrics['morphological_features']['widths'].append(width)
                    size_metrics['morphological_features']['heights'].append(height)
                    size_metrics['morphological_features']['perimeters'].append(perimeter)
                    
                    # Shape analysis
                    if height > 0 and width > 0:
                        aspect_ratio = width / height
                        size_metrics['morphological_features']['aspect_ratios'].append(aspect_ratio)
                        
                        # Elongation index: 0 for a square box, symmetric for wide and tall boxes
                        elongation = max(aspect_ratio, 1.0 / aspect_ratio) - 1.0
                        size_metrics['morphological_features']['elongation_indices'].append(elongation)
                        
                        # Compactness (isoperimetric quotient 4*pi*A/P^2): 1 for a circle, pi/4 for a square
                        compactness = (4 * np.pi * area) / (perimeter ** 2) if perimeter > 0 else 0
                        size_metrics['morphological_features']['compactness_scores'].append(compactness)
                        
                        # Shape categorization
                        if 0.8 <= aspect_ratio <= 1.2:  # Nearly square
                            size_metrics['shape_analysis']['circular_heads'] += 1
                        elif aspect_ratio > 1.5 or aspect_ratio < 0.67:  # Significantly elongated
                            size_metrics['shape_analysis']['elongated_heads'] += 1
                        else:
                            size_metrics['shape_analysis']['irregular_heads'] += 1
                    
                    # Size categorization with refined thresholds
                    if area < 0.0002:
                        size_metrics['size_categories']['micro'] += 1
                    elif area < 0.0005:
                        size_metrics['size_categories']['tiny'] += 1
                    elif area < 0.002:
                        size_metrics['size_categories']['small'] += 1
                    elif area < 0.008:
                        size_metrics['size_categories']['medium'] += 1
                    elif area < 0.02:
                        size_metrics['size_categories']['large'] += 1
                    else:
                        size_metrics['size_categories']['extra_large'] += 1
                    
                    # Spatial context analysis
                    size_metrics['spatial_context']['position_size_correlation']['x_coords'].append(float(x_center))
                    size_metrics['spatial_context']['position_size_correlation']['y_coords'].append(float(y_center))
                    size_metrics['spatial_context']['position_size_correlation']['sizes'].append(area)
                    
                    # Edge effect analysis
                    edge_distance = min(x_center, y_center, 1-x_center, 1-y_center)
                    size_metrics['spatial_context']['edge_effects'].append(edge_distance)
                    
                    # Center bias analysis
                    center_distance = np.sqrt((x_center - 0.5)**2 + (y_center - 0.5)**2)
                    size_metrics['spatial_context']['center_bias'].append(center_distance)
                    
        except Exception:
            # Skip samples that fail to load or whose targets cannot be parsed
            continue
    
    # Statistical analysis
    areas = size_metrics['morphological_features']['areas']
    if areas:
        # Outlier detection using IQR method
        q1, q3 = np.percentile(areas, [25, 75])
        iqr = q3 - q1
        lower_bound = q1 - 1.5 * iqr
        upper_bound = q3 + 1.5 * iqr
        
        outliers = [a for a in areas if a < lower_bound or a > upper_bound]
        size_metrics['statistical_measures']['outlier_detection'] = len(outliers)
        size_metrics['statistical_measures']['outlier_ratio'] = len(outliers) / len(areas)
    
    return size_metrics
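The size bins and Tukey IQR fences used in `analyze_wheat_head_sizes_comprehensive` can also be written in vectorized form, which makes the thresholds easy to test in isolation. The sketch below assumes the same normalized-area cut-offs as the function above; `bin_areas`, `iqr_outlier_ratio`, `AREA_BIN_EDGES` and `AREA_BIN_NAMES` are hypothetical names, not part of the project code.

```python
import numpy as np

# Upper edges (exclusive) of the normalized-area bins, as fractions of image area
AREA_BIN_EDGES = np.array([0.0002, 0.0005, 0.002, 0.008, 0.02])
AREA_BIN_NAMES = ['micro', 'tiny', 'small', 'medium', 'large', 'extra_large']


def bin_areas(areas):
    """Count boxes per size category using the same cut-offs as the loop above."""
    idx = np.digitize(np.asarray(areas, dtype=float), AREA_BIN_EDGES)
    counts = np.bincount(idx, minlength=len(AREA_BIN_NAMES))
    return dict(zip(AREA_BIN_NAMES, counts.tolist()))


def iqr_outlier_ratio(areas, k=1.5):
    """Fraction of values outside Tukey's fences [Q1 - k*IQR, Q3 + k*IQR]."""
    a = np.asarray(areas, dtype=float)
    if a.size == 0:
        return 0.0
    q1, q3 = np.percentile(a, [25, 75])
    iqr = q3 - q1
    outside = (a < q1 - k * iqr) | (a > q3 + k * iqr)
    return float(outside.mean())
```

With `right=False` (NumPy's default), `np.digitize` assigns an area exactly equal to a cut-off to the next bin up, which matches the strict `<` comparisons in the loop. For example, `bin_areas([0.0001, 0.001, 0.05])` places one box each in `micro`, `small` and `extra_large`.

In [None]: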

# Visualize samples and perform comprehensive analysis
all_sample_statistics = {}
all_size_metrics = {}

for name, dataset in datasets.items():
    print(f"\n{'='*60}")
    print(f"🖼️ Advanced wheat sample visualization for {name}...")
    
    # Visualize samples with advanced annotations
    sample_stats = visualize_wheat_samples_advanced(dataset, name, num_samples=12)
    all_sample_statistics[name] = sample_stats
    
    # Comprehensive size analysis
    print(f"\n🔍 Comprehensive size analysis for {name}...")
    size_metrics = analyze_wheat_head_sizes_comprehensive(dataset, name, max_samples=400)
    all_size_metrics[name] = size_metrics
    
    # Display size analysis results
    morph_features = size_metrics['morphological_features']
    if morph_features['areas']:
        areas = np.array(morph_features['areas'])
        widths = np.array(morph_features['widths'])
        heights = np.array(morph_features['heights'])
        aspect_ratios = np.array(morph_features['aspect_ratios'])
        
        print(f"\n📏 {name} - MORPHOLOGICAL ANALYSIS:")
        print(f"   📊 Area Statistics:")
        print(f"     Mean: {np.mean(areas):.6f} ± {np.std(areas):.6f}")
        print(f"     Median: {np.median(areas):.6f}")
        print(f"     Range: {np.min(areas):.6f} - {np.max(areas):.6f}")
        
        print(f"   📐 Dimension Statistics:")
        print(f"     Width: {np.mean(widths):.4f} ± {np.std(widths):.4f}")
        print(f"     Height: {np.mean(heights):.4f} ± {np.std(heights):.4f}")
        print(f"     Aspect ratio: {np.mean(aspect_ratios):.3f} ± {np.std(aspect_ratios):.3f}")
        
        print(f"   🎯 Size Categories:")
        total_heads = sum(size_metrics['size_categories'].values())
        for category, count in size_metrics['size_categories'].items():
            percentage = (count / total_heads * 100) if total_heads > 0 else 0
            print(f"     {category.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")
        
        print(f"   🔍 Shape Analysis:")
        total_shapes = sum(size_metrics['shape_analysis'].values())
        for shape, count in size_metrics['shape_analysis'].items():
            percentage = (count / total_shapes * 100) if total_shapes > 0 else 0
            print(f"     {shape.replace('_', ' ').title()}: {count} ({percentage:.1f}%)")
        
        # Statistical measures
        outlier_info = size_metrics['statistical_measures']
        if 'outlier_detection' in outlier_info:
            print(f"   📈 Statistical Analysis:")
            print(f"     Outliers detected: {outlier_info['outlier_detection']}")
            print(f"     Outlier ratio: {outlier_info.get('outlier_ratio', 0):.3f}")

# Save comprehensive sample and size analysis results
sample_analysis_summary = {}
for dataset_name, stats in all_sample_statistics.items():
    sample_analysis_summary[dataset_name] = {
        'total_samples_analyzed': len(stats),
        'density_statistics': {
            'mean_heads_per_sample': float(np.mean([s['num_heads'] for s in stats])) if stats else 0.0,
            'density_categories': dict(Counter([s['density_category'] for s in stats])),
            'spatial_patterns': dict(Counter([s['spatial_distribution'] for s in stats if s['spatial_distribution']]))
        },
        'size_analysis': all_size_metrics.get(dataset_name, {}),
        'overlap_detection': {
            'samples_with_overlap': sum(1 for s in stats if s['overlap_detected']),
            'overlap_ratio': sum(1 for s in stats if s['overlap_detected']) / len(stats) if stats else 0
        }
    }

with open(notebook_results_dir / 'samples' / 'comprehensive_sample_analysis.json', 'w') as f:
    json.dump(sample_analysis_summary, f, indent=2, default=str)

print(f"\n💾 Comprehensive sample analysis saved!")
print(f"📁 Sample visualizations: {notebook_results_dir / 'samples'}")
print(f"📄 Analysis data: {notebook_results_dir / 'samples' / 'comprehensive_sample_analysis.json'}")
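The summary above is serialized with `json.dump(..., default=str)`. The `default` hook is only called for objects the encoder cannot handle natively: `np.float64` subclasses Python `float` and is written as a number, but NumPy integers, `np.float32` values and arrays are converted to strings such as `"5"` or `"[0.1 0.2]"`, which must be re-parsed later. A small encoder subclass keeps them numeric. This is a sketch; `NumpyJSONEncoder` is a hypothetical name.

```python
import json

import numpy as np


class NumpyJSONEncoder(json.JSONEncoder):
    """JSON encoder that converts NumPy scalars and arrays to native Python types."""

    def default(self, obj):
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        # Anything else goes to the base class, which raises TypeError
        return super().default(obj)


payload = {'count': np.int64(5), 'ratio': np.float32(0.25), 'areas': np.array([0.1, 0.2])}
encoded = json.dumps(payload, cls=NumpyJSONEncoder)
```

It would be used as `json.dump(sample_analysis_summary, f, indent=2, cls=NumpyJSONEncoder)`. Unlike `default=str`, unrecognised types raise `TypeError`, which surfaces unexpected objects instead of silently stringifying them.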

## 7. Wheat-Specific Challenge Assessment
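The occlusion analysis in this section flags a pair of heads when their center distance falls below 0.8 times the mean of their box sides, which is a proximity heuristic rather than a measure of actual box overlap. Where geometric overlap is wanted, pairwise IoU can be computed directly from YOLO-format boxes `(x_center, y_center, width, height)`. This is a minimal sketch assuming normalized coordinates; `pairwise_iou` and `count_overlapping_pairs` are hypothetical helpers.

```python
import numpy as np


def pairwise_iou(boxes):
    """IoU matrix for YOLO-format boxes (x_center, y_center, width, height)."""
    b = np.asarray(boxes, dtype=float).reshape(-1, 4)
    x1 = b[:, 0] - b[:, 2] / 2
    y1 = b[:, 1] - b[:, 3] / 2
    x2 = b[:, 0] + b[:, 2] / 2
    y2 = b[:, 1] + b[:, 3] / 2
    # Broadcast to (N, N) intersection widths and heights, clipped at zero
    inter_w = np.clip(np.minimum(x2[:, None], x2[None, :]) - np.maximum(x1[:, None], x1[None, :]), 0, None)
    inter_h = np.clip(np.minimum(y2[:, None], y2[None, :]) - np.maximum(y1[:, None], y1[None, :]), 0, None)
    inter = inter_w * inter_h
    area = b[:, 2] * b[:, 3]
    union = area[:, None] + area[None, :] - inter
    # Avoid division by zero for degenerate (zero-area) boxes
    return np.divide(inter, union, out=np.zeros_like(inter), where=union > 0)


def count_overlapping_pairs(boxes, iou_threshold=0.0):
    """Number of unordered box pairs whose IoU exceeds `iou_threshold`."""
    iou = pairwise_iou(boxes)
    upper = np.triu(iou, k=1)  # each pair counted once, diagonal excluded
    return int(np.sum(upper > iou_threshold))
```

An `iou_threshold` of 0 counts any pair whose boxes intersect with positive area; raising it (for example to 0.1) restricts the count to substantial overlaps. For two 0.2 × 0.2 boxes whose centers are 0.1 apart horizontally, the IoU is 0.02 / 0.06 = 1/3.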

In [None]:
"""
Wheat-specific challenge assessment using heuristic image and annotation statistics (OpenCV and NumPy)
"""

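def assess_wheat_challenges_comprehensive(dataset, dataset_name, max_samples=300):
    """Score wheat detection challenges on a random subset of images.

    Covers occlusion, scale, density, visual, geometric and field-specific
    factors using heuristic image and annotation statistics.
    """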
    
    challenge_metrics = {
        'occlusion_analysis': {
            'occlusion_levels': {'none': 0, 'light': 0, 'moderate': 0, 'heavy': 0, 'severe': 0},
            'occlusion_patterns': [],
            'visibility_scores': [],
            'partial_visibility': []
        },
        'scale_challenges': {
            'scale_variations': [],
            'multi_scale_complexity': [],
            'size_consistency': [],
            'perspective_effects': []
        },
        'density_challenges': {
            'crowding_indices': [],
            'object_separation': [],
            'detection_difficulty': {'trivial': 0, 'easy': 0, 'medium': 0, 'hard': 0, 'extreme': 0},
            'spatial_interference': []
        },
        'visual_challenges': {
            'contrast_issues': [],
            'illumination_problems': [],
            'background_interference': [],
            'texture_similarity': [],
            'color_discrimination': []
        },
        'geometric_challenges': {
            'shape_variations': [],
            'orientation_challenges': [],
            'deformation_analysis': [],
            'boundary_clarity': []
        },
        'field_specific_issues': {
            'growth_stage_variations': [],
            'weather_impact_scores': [],
            'soil_interference': [],
            'vegetation_confusion': [],
            'maturity_inconsistency': []
        },
        'detection_complexity': {
            'edge_cases': 0,
            'ambiguous_objects': 0,
            'false_positive_risks': [],
            'annotation_challenges': []
        }
    }
    
    sample_size = min(len(dataset), max_samples)
    indices = np.random.choice(len(dataset), sample_size, replace=False)
    
    print(f"⚠️ Comprehensive wheat challenge assessment on {sample_size} images from {dataset_name}...")
    
    processing_stats = {'success': 0, 'errors': 0, 'skipped': 0}
    
    for i in tqdm(indices, desc="Assessing detection challenges"):
        try:
            image, targets, path = dataset[i]
            
            # Convert and validate image
            if isinstance(image, torch.Tensor):
                img_np = image.permute(1, 2, 0).cpu().numpy()
                if img_np.min() < 0:
                    img_np = (img_np + 1) / 2
                elif img_np.max() <= 1.0:
                    pass
                else:
                    img_np = img_np / 255.0
                img_np = np.clip(img_np, 0, 1)
            else:
                img_np = image
                if img_np.max() > 1.0:
                    img_np = img_np / 255.0
            
            if img_np.shape[-1] != 3:
                processing_stats['skipped'] += 1
                continue
            
            # Convert to analysis formats
            img_rgb = (img_np * 255).astype(np.uint8)
            img_gray = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2GRAY)
            img_hsv = cv2.cvtColor(img_rgb, cv2.COLOR_RGB2HSV)
            
            processing_stats['success'] += 1
            
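            # Images without annotations are tallied as edge cases; the same counter
            # is also incremented later for images with heads near the border
            if targets.numel() == 0: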
                challenge_metrics['detection_complexity']['edge_cases'] += 1
                continue
            
            num_heads = len(targets)
            target_data = []
            
            # Extract target information
            for target in targets:
                if len(target) >= 5:
                    cls, x_center, y_center, width, height = target[:5]
                    area = width * height
                    target_data.append({
                        'center': (float(x_center), float(y_center)),
                        'size': (float(width), float(height)),
                        'area': float(area)
                    })
            
            if not target_data:
                processing_stats['skipped'] += 1
                continue
            
            # 1. OCCLUSION ANALYSIS
            
            # Count close pairs of heads as an occlusion proxy: a pair is flagged when
            # the center distance is below 0.8x the mean of the two boxes' four sides.
            # This is a proximity heuristic, not a true box-overlap (IoU) test.
            # Loop indices p/q avoid shadowing the sample index `i` of the outer loop.
            overlaps = 0
            if len(target_data) > 1:
                separations = []
                
                for p in range(len(target_data)):
                    for q in range(p + 1, len(target_data)):
                        center1 = target_data[p]['center']
                        center2 = target_data[q]['center']
                        size1 = target_data[p]['size']
                        size2 = target_data[q]['size']
                        
                        # Euclidean distance between centers (normalized coordinates)
                        distance = np.sqrt((center1[0] - center2[0])**2 + (center1[1] - center2[1])**2)
                        separations.append(distance)
                        
                        # Flag the pair when the centers are closer than the size-based threshold
                        avg_size = (size1[0] + size1[1] + size2[0] + size2[1]) / 4
                        if distance < avg_size * 0.8:
                            overlaps += 1
                
                avg_separation = np.mean(separations) if separations else 1.0
                challenge_metrics['density_challenges']['object_separation'].append(avg_separation)
            
            # Record every image, so single-head images are counted as 'none'
            challenge_metrics['occlusion_analysis']['occlusion_patterns'].append(overlaps)
            
            # Categorize occlusion level
            if overlaps == 0:
                challenge_metrics['occlusion_analysis']['occlusion_levels']['none'] += 1
            elif overlaps <= 2:
                challenge_metrics['occlusion_analysis']['occlusion_levels']['light'] += 1
            elif overlaps <= 5:
                challenge_metrics['occlusion_analysis']['occlusion_levels']['moderate'] += 1
            elif overlaps <= 10:
                challenge_metrics['occlusion_analysis']['occlusion_levels']['heavy'] += 1
            else:
                challenge_metrics['occlusion_analysis']['occlusion_levels']['severe'] += 1
            
            # 2. SCALE CHALLENGES
            
            # Analyze size variations
            areas = [td['area'] for td in target_data]
            if areas:
                area_cv = np.std(areas) / np.mean(areas) if np.mean(areas) > 0 else 0
                challenge_metrics['scale_challenges']['scale_variations'].append(area_cv)
                
                # Multi-scale complexity (ratio of largest to smallest)
                if len(areas) > 1:
                    scale_ratio = max(areas) / min(areas) if min(areas) > 0 else 1
                    challenge_metrics['scale_challenges']['multi_scale_complexity'].append(scale_ratio)
                
                # Size consistency analysis
                median_area = np.median(areas)
                consistency_score = sum(1 for a in areas if 0.5 * median_area <= a <= 2 * median_area) / len(areas)
                challenge_metrics['scale_challenges']['size_consistency'].append(consistency_score)
            
            # 3. DENSITY CHALLENGES
            
            # Crowding index based on local density
            if len(target_data) >= 2:
                crowding_scores = []
                for target in target_data:
                    center = target['center']
                    neighbors_in_radius = 0
                    radius = 0.15  # Analysis radius
                    
                    for other_target in target_data:
                        # Identity check, so duplicate annotations still count as neighbours
                        if other_target is not target:
                            other_center = other_target['center']
                            distance = np.sqrt((center[0] - other_center[0])**2 + 
                                             (center[1] - other_center[1])**2)
                            if distance <= radius:
                                neighbors_in_radius += 1
                    
                    crowding_scores.append(neighbors_in_radius)
                
                avg_crowding = np.mean(crowding_scores)
                challenge_metrics['density_challenges']['crowding_indices'].append(avg_crowding)
                
                # Detection difficulty assessment
                difficulty_score = 0
                
                # Add difficulty for high density
                if num_heads > 30:
                    difficulty_score += 4
                elif num_heads > 20:
                    difficulty_score += 3
                elif num_heads > 15:
                    difficulty_score += 2
                elif num_heads > 10:
                    difficulty_score += 1
                
                # Add difficulty for small objects
                small_objects = sum(1 for a in areas if a < 0.001)
                small_ratio = small_objects / len(areas) if areas else 0
                if small_ratio > 0.7:
                    difficulty_score += 3
                elif small_ratio > 0.5:
                    difficulty_score += 2
                elif small_ratio > 0.3:
                    difficulty_score += 1
                
                # Add difficulty for high crowding
                if avg_crowding > 5:
                    difficulty_score += 2
                elif avg_crowding > 3:
                    difficulty_score += 1
                
                # Categorize difficulty
                if difficulty_score >= 8:
                    challenge_metrics['density_challenges']['detection_difficulty']['extreme'] += 1
                elif difficulty_score >= 6:
                    challenge_metrics['density_challenges']['detection_difficulty']['hard'] += 1
                elif difficulty_score >= 4:
                    challenge_metrics['density_challenges']['detection_difficulty']['medium'] += 1
                elif difficulty_score >= 2:
                    challenge_metrics['density_challenges']['detection_difficulty']['easy'] += 1
                else:
                    challenge_metrics['density_challenges']['detection_difficulty']['trivial'] += 1
            
            # 4. VISUAL CHALLENGES
            
            # Contrast analysis
            local_contrasts = []
            for target in target_data:
                center = target['center']
                x_pixel = int(center[0] * img_gray.shape[1])
                y_pixel = int(center[1] * img_gray.shape[0])
                
                # Extract local region
                region_size = 32
                x1 = max(0, x_pixel - region_size // 2)
                x2 = min(img_gray.shape[1], x_pixel + region_size // 2)
                y1 = max(0, y_pixel - region_size // 2)
                y2 = min(img_gray.shape[0], y_pixel + region_size // 2)
                
                if x2 > x1 and y2 > y1:
                    local_region = img_gray[y1:y2, x1:x2]
                    local_contrast = np.std(local_region) / 255.0
                    local_contrasts.append(local_contrast)
            
            if local_contrasts:
                avg_local_contrast = np.mean(local_contrasts)
                challenge_metrics['visual_challenges']['contrast_issues'].append(avg_local_contrast)
            
            # Global illumination analysis
            brightness = np.mean(img_gray) / 255.0
            brightness_std = np.std(img_gray) / 255.0
            
            illumination_problem_score = 0
            if brightness < 0.3 or brightness > 0.8:  # Too dark or bright
                illumination_problem_score += 1
            if brightness_std > 0.25:  # High variation (shadows/glare)
                illumination_problem_score += 1
            
            challenge_metrics['visual_challenges']['illumination_problems'].append(illumination_problem_score)
            
            # Background interference using edge density
            edges = cv2.Canny(img_gray, 50, 150)
            edge_density = np.sum(edges > 0) / (edges.shape[0] * edges.shape[1])
            challenge_metrics['visual_challenges']['background_interference'].append(edge_density)
            
            # Texture statistic: mean intensity variance over non-overlapping 16x16
            # patches (a simple local-variance measure, not a Local Binary Pattern).
            # Higher values indicate busier texture. Cropping to a multiple of the
            # patch size keeps every full patch, including the last row and column.
            patch_size = 16
            h_crop = (img_gray.shape[0] // patch_size) * patch_size
            w_crop = (img_gray.shape[1] // patch_size) * patch_size
            if h_crop > 0 and w_crop > 0:
                patches = img_gray[:h_crop, :w_crop].astype(np.float64).reshape(
                    h_crop // patch_size, patch_size, w_crop // patch_size, patch_size)
                texture_variance = float(patches.var(axis=(1, 3)).mean())
            else:
                texture_variance = 0.0
            challenge_metrics['visual_challenges']['texture_similarity'].append(texture_variance)
            
            # 5. GEOMETRIC CHALLENGES
            
            # Shape variation analysis
            if len(target_data) > 1:
                aspect_ratios = []
                for target in target_data:
                    width, height = target['size']
                    if height > 0:
                        aspect_ratios.append(width / height)
                
                if aspect_ratios:
                    shape_variation = np.std(aspect_ratios)
                    challenge_metrics['geometric_challenges']['shape_variations'].append(shape_variation)
            
            # Boundary clarity proxy: mean Sobel gradient magnitude inside each box
            boundary_scores = []
            for target in target_data:
                center = target['center']
                size = target['size']
                
                # Calculate bounding box in pixels
                x_pixel = int(center[0] * img_gray.shape[1])
                y_pixel = int(center[1] * img_gray.shape[0])
                w_pixel = int(size[0] * img_gray.shape[1])
                h_pixel = int(size[1] * img_gray.shape[0])
                
                # Extract region around object
                x1 = max(0, x_pixel - w_pixel // 2)
                x2 = min(img_gray.shape[1], x_pixel + w_pixel // 2)
                y1 = max(0, y_pixel - h_pixel // 2)
                y2 = min(img_gray.shape[0], y_pixel + h_pixel // 2)
                
                if x2 > x1 and y2 > y1:
                    region = img_gray[y1:y2, x1:x2]
                    
                    # Calculate gradient magnitude
                    grad_x = cv2.Sobel(region, cv2.CV_64F, 1, 0, ksize=3)
                    grad_y = cv2.Sobel(region, cv2.CV_64F, 0, 1, ksize=3)
                    gradient_mag = np.sqrt(grad_x**2 + grad_y**2)
                    
                    boundary_clarity = np.mean(gradient_mag)
                    boundary_scores.append(boundary_clarity)
            
            if boundary_scores:
                avg_boundary_clarity = np.mean(boundary_scores)
                challenge_metrics['geometric_challenges']['boundary_clarity'].append(avg_boundary_clarity)
            
            # 6. FIELD-SPECIFIC ISSUES
            
            # Growth stage analysis using color properties
            hue_channel = img_hsv[:, :, 0]
            sat_channel = img_hsv[:, :, 1]
            val_channel = img_hsv[:, :, 2]
            
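            # Infer a coarse growth stage from dominant hues. Thresholds use OpenCV's
            # 8-bit HSV scale (H in 0-179, i.e. degrees / 2; S and V in 0-255); the
            # yellow and brown hue bands overlap, so a pixel may count toward both.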
            green_mask = (hue_channel >= 35) & (hue_channel <= 85) & (sat_channel > 50)
            yellow_mask = (hue_channel >= 15) & (hue_channel <= 35) & (sat_channel > 50)
            brown_mask = (hue_channel >= 5) & (hue_channel <= 25) & (sat_channel > 30)
            
            green_ratio = np.sum(green_mask) / green_mask.size
            yellow_ratio = np.sum(yellow_mask) / yellow_mask.size
            brown_ratio = np.sum(brown_mask) / brown_mask.size
            
            # Growth stage inference
            if green_ratio > 0.4:
                growth_stage = 'early'
            elif yellow_ratio > 0.3:
                growth_stage = 'mature'
            elif brown_ratio > 0.2:
                growth_stage = 'late'
            else:
                growth_stage = 'mixed'
            
            challenge_metrics['field_specific_issues']['growth_stage_variations'].append(growth_stage)
            
            # Weather-impact proxy from global brightness (lighting is measured, not weather itself)
            weather_impact = 0
            if brightness < 0.25:  # Very dark (cloudy/stormy)
                weather_impact += 2
            elif brightness > 0.85:  # Very bright (harsh sun)
                weather_impact += 2
            if brightness_std > 0.3:  # High variation (mixed lighting)
                weather_impact += 1
            
            challenge_metrics['field_specific_issues']['weather_impact_scores'].append(weather_impact)
            
            # Soil interference analysis
            # Look for brown/tan pixels that might interfere with detection
            soil_mask = ((hue_channel >= 10) & (hue_channel <= 30) & 
                        (sat_channel < 100) & (val_channel > 30))
            soil_ratio = np.sum(soil_mask) / soil_mask.size
            challenge_metrics['field_specific_issues']['soil_interference'].append(soil_ratio)
            
            # 7. DETECTION COMPLEXITY ASSESSMENT
            
            # Edge case detection
            edge_distance_threshold = 0.1
            edge_objects = 0
            for target in target_data:
                center = target['center']
                if (center[0] <= edge_distance_threshold or center[0] >= 1-edge_distance_threshold or
                    center[1] <= edge_distance_threshold or center[1] >= 1-edge_distance_threshold):
                    edge_objects += 1
            
            if edge_objects > 0:
                challenge_metrics['detection_complexity']['edge_cases'] += 1
            
            # Ambiguous object detection based on size and isolation
            ambiguous_count = 0
            for target in target_data:
                area = target['area']
                center = target['center']
                
                # Check if object is unusually small or large
                if area < 0.0002 or area > 0.02:
                    ambiguous_count += 1
                    continue
                
                # Check isolation (might be noise if very isolated)
                neighbors = 0
                for other_target in target_data:
                    # Identity check, so duplicate annotations still count as neighbours
                    if other_target is not target:
                        other_center = other_target['center']
                        distance = np.sqrt((center[0] - other_center[0])**2 + 
                                         (center[1] - other_center[1])**2)
                        if distance <= 0.2:  # Within reasonable distance
                            neighbors += 1
                
                if neighbors == 0 and len(target_data) > 3:  # Isolated in dense scene
                    ambiguous_count += 1
            
            if ambiguous_count > 0:
                challenge_metrics['detection_complexity']['ambiguous_objects'] += ambiguous_count
            
            # False positive risk assessment
            fp_risk_score = 0
            
            # High background complexity increases FP risk
            if edge_density > 0.15:
                fp_risk_score += 1
            
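            # Low contrast increases FP risk (computed from this image's
            # `local_contrasts`, so a value from a previous image is never reused)
            if local_contrasts and np.mean(local_contrasts) < 0.1:
                fp_risk_score += 1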
            
            # High soil visibility increases FP risk
            if soil_ratio > 0.3:
                fp_risk_score += 1
            
            challenge_metrics['detection_complexity']['false_positive_risks'].append(fp_risk_score)
            
        except Exception as e:
            processing_stats['errors'] += 1
            continue
    
    print(f"   ✅ Processing stats: {processing_stats['success']} success, "
          f"{processing_stats['errors']} errors, {processing_stats['skipped']} skipped")
    
    return challenge_metrics
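The crowding index computed inside `assess_wheat_challenges_comprehensive` counts, for each head, the other heads whose centers lie within a normalized radius of 0.15, using a nested Python loop. The same quantity can be obtained with NumPy broadcasting. This is a minimal sketch assuming an `(N, 2)` array of normalized centers; `crowding_index` is a hypothetical helper, and returning 0.0 for fewer than two heads is an assumption (the loop above simply skips such images).

```python
import numpy as np


def crowding_index(centers, radius=0.15):
    """Mean number of neighbours within `radius` of each center (self excluded).

    `centers` is an (N, 2) array-like of normalized (x, y) box centers.
    """
    c = np.asarray(centers, dtype=float).reshape(-1, 2)
    if len(c) < 2:
        return 0.0
    diff = c[:, None, :] - c[None, :, :]       # (N, N, 2) pairwise offsets
    dist = np.sqrt((diff ** 2).sum(axis=-1))   # (N, N) Euclidean distances
    within = dist <= radius
    np.fill_diagonal(within, False)            # a head is not its own neighbour
    return float(within.sum(axis=1).mean())
```

Excluding the diagonal by position rather than comparing dictionaries means two identical (duplicate) annotations are counted as neighbours of each other. The `(N, N)` distance matrix needs O(N²) memory, which is modest unless an image contains thousands of boxes.

In [None]: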

# Comprehensive challenge assessment
challenge_results = {}

for name, dataset in datasets.items():
    print(f"\n{'='*70}")
    challenge_metrics = assess_wheat_challenges_comprehensive(dataset, name, max_samples=250)
    challenge_results[name] = challenge_metrics
    
    # Display comprehensive challenge analysis
    print(f"\n⚠️ {name} - COMPREHENSIVE CHALLENGE ASSESSMENT:")
    
    # Occlusion analysis
    occlusion_data = challenge_metrics['occlusion_analysis']
    print(f"   🔍 Occlusion Analysis:")
    total_occlusion = sum(occlusion_data['occlusion_levels'].values())
    for level, count in occlusion_data['occlusion_levels'].items():
        percentage = (count / total_occlusion * 100) if total_occlusion > 0 else 0
        print(f"     {level.title()}: {count} ({percentage:.1f}%)")
    
    if occlusion_data['occlusion_patterns']:
        avg_overlaps = np.mean(occlusion_data['occlusion_patterns'])
        print(f"     Average overlaps per image: {avg_overlaps:.2f}")
    
    # Scale challenges
    scale_data = challenge_metrics['scale_challenges']
    if scale_data['scale_variations']:
        print(f"   📏 Scale Challenges:")
        print(f"     Scale variation (CV): {np.mean(scale_data['scale_variations']):.3f}")
        if scale_data['multi_scale_complexity']:
            print(f"     Multi-scale complexity: {np.mean(scale_data['multi_scale_complexity']):.2f}")
        if scale_data['size_consistency']:
            print(f"     Size consistency: {np.mean(scale_data['size_consistency']):.3f}")
    
    # Density challenges
    density_data = challenge_metrics['density_challenges']
    print(f"   🌾 Density Challenges:")
    total_difficulty = sum(density_data['detection_difficulty'].values())
    for difficulty, count in density_data['detection_difficulty'].items():
        percentage = (count / total_difficulty * 100) if total_difficulty > 0 else 0
        print(f"     {difficulty.title()}: {count} ({percentage:.1f}%)")
    
    if density_data['crowding_indices']:
        avg_crowding = np.mean(density_data['crowding_indices'])
        print(f"     Average crowding index: {avg_crowding:.2f}")
    
    # Visual challenges
    visual_data = challenge_metrics['visual_challenges']
    if visual_data['contrast_issues']:
        print(f"   👁️ Visual Challenges:")
        print(f"     Average local contrast: {np.mean(visual_data['contrast_issues']):.3f}")
        print(f"     Illumination problems: {np.mean(visual_data['illumination_problems']):.2f}")
        print(f"     Background interference: {np.mean(visual_data['background_interference']):.4f}")
    
    # Field-specific issues
    field_data = challenge_metrics['field_specific_issues']
    if field_data['growth_stage_variations']:
        print(f"   🌱 Field-Specific Issues:")
        growth_stages = Counter(field_data['growth_stage_variations'])
        for stage, count in growth_stages.items():
            print(f"     {stage.title()} growth: {count}")
        
        if field_data['weather_impact_scores']:
            avg_weather_impact = np.mean(field_data['weather_impact_scores'])
            print(f"     Average weather impact: {avg_weather_impact:.2f}")
        
        if field_data['soil_interference']:
            avg_soil_interference = np.mean(field_data['soil_interference'])
            print(f"     Average soil interference: {avg_soil_interference:.3f}")
    
    # Detection complexity
    complexity_data = challenge_metrics['detection_complexity']
    print(f"   🎯 Detection Complexity:")
    print(f"     Edge cases: {complexity_data['edge_cases']}")
    print(f"     Ambiguous objects: {complexity_data['ambiguous_objects']}")
    
    if complexity_data['false_positive_risks']:
        avg_fp_risk = np.mean(complexity_data['false_positive_risks'])
        print(f"     Average FP risk score: {avg_fp_risk:.2f}")
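The detection-difficulty score assembled inside the assessment loop adds three terms (head count, share of small boxes, crowding) and maps the total to five labels. Pulling the rules into a standalone function makes them easy to unit-test and tune. The sketch mirrors the thresholds used in the loop; `detection_difficulty` is a hypothetical helper, and its label is a heuristic ranking, not a measured detector error rate.

```python
def detection_difficulty(num_heads, areas, avg_crowding):
    """Return (score, label) using the same thresholds as the assessment loop."""
    score = 0
    # Density term (0-4 points); thresholds checked from highest to lowest
    for threshold, points in ((30, 4), (20, 3), (15, 2), (10, 1)):
        if num_heads > threshold:
            score += points
            break
    # Small-object term (0-3 points): share of boxes below 0.1% of image area
    if areas:
        small_ratio = sum(1 for a in areas if a < 0.001) / len(areas)
        for threshold, points in ((0.7, 3), (0.5, 2), (0.3, 1)):
            if small_ratio > threshold:
                score += points
                break
    # Crowding term (0-2 points)
    if avg_crowding > 5:
        score += 2
    elif avg_crowding > 3:
        score += 1
    # Maximum possible score is 4 + 3 + 2 = 9
    for threshold, label in ((8, 'extreme'), (6, 'hard'), (4, 'medium'), (2, 'easy')):
        if score >= threshold:
            return score, label
    return score, 'trivial'
```

For example, an image with 12 heads, half of them below 0.1% of the image area, and an average crowding of 4 scores 1 + 1 + 1 = 3 and is labelled `easy`.

In [None]: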

# Create comprehensive challenge assessment visualization
fig = plt.figure(figsize=(24, 20))
gs = fig.add_gridspec(5, 4, hspace=0.4, wspace=0.3)

dataset_colors = plt.cm.Set1(np.linspace(0, 1, len(challenge_results)))

# 1. Occlusion Levels Distribution (top row, first column)
ax1 = fig.add_subplot(gs[0, 0])
occlusion_summary = defaultdict(int)
for dataset_name, metrics in challenge_results.items():
    occlusion_data = metrics['occlusion_analysis']['occlusion_levels']
    for level, count in occlusion_data.items():
        occlusion_summary[level] += count

if occlusion_summary:
    levels = list(occlusion_summary.keys())
    counts = list(occlusion_summary.values())
    colors = plt.cm.Reds(np.linspace(0.3, 1, len(levels)))
    
    bars = ax1.bar(levels, counts, color=colors, alpha=0.8)
    ax1.set_title('Occlusion Levels Distribution', fontweight='bold')
    ax1.set_xlabel('Occlusion Level')
    ax1.set_ylabel('Number of Images')
    ax1.tick_params(axis='x', rotation=45)
    
    # Add value labels
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        ax1.text(bar.get_x() + bar.get_width()/2., height + 0.01*max(counts),
                f'{count}', ha='center', va='bottom', fontweight='bold')

# 2. Detection Difficulty Distribution (top row, second column)
ax2 = fig.add_subplot(gs[0, 1])
difficulty_summary = defaultdict(int)
for dataset_name, metrics in challenge_results.items():
    difficulty_data = metrics['density_challenges']['detection_difficulty']
    for level, count in difficulty_data.items():
        difficulty_summary[level] += count

if difficulty_summary:
    difficulties = list(difficulty_summary.keys())
    counts = list(difficulty_summary.values())
    colors = plt.cm.YlOrRd(np.linspace(0.3, 1, len(difficulties)))
    
    bars = ax2.bar(range(len(difficulties)), counts, color=colors, alpha=0.8)
    ax2.set_title('Detection Difficulty Distribution', fontweight='bold')
    ax2.set_xlabel('Difficulty Level')
    ax2.set_ylabel('Number of Images')
    ax2.set_xticks(range(len(difficulties)))
    ax2.set_xticklabels([d.replace('_', '\n') for d in difficulties])
    
    for bar, count in zip(bars, counts):
        height = bar.get_height()
        ax2.text(bar.get_x() + bar.get_width()/2., height + 0.01*max(counts),
                f'{count}', ha='center', va='bottom', fontweight='bold')

# 3. Scale Challenges Analysis (top row, third column)
ax3 = fig.add_subplot(gs[0, 2])
for idx, (dataset_name, metrics) in enumerate(challenge_results.items()):
    scale_variations = metrics['scale_challenges']['scale_variations']
    size_consistency = metrics['scale_challenges']['size_consistency']
    
    if scale_variations and size_consistency:
        ax3.scatter(scale_variations, size_consistency, alpha=0.6, 
                   label=dataset_name, s=50, color=dataset_colors[idx])

ax3.set_title('Scale Variation vs Size Consistency', fontweight='bold')
ax3.set_xlabel('Scale Variation (CV)')
ax3.set_ylabel('Size Consistency Score')
ax3.legend()
ax3.grid(True, alpha=0.3)

# 4. Visual Challenges Radar Chart (top row, fourth column)
ax4 = fig.add_subplot(gs[0, 3], projection='polar')
visual_metrics = ['Contrast', 'Illumination', 'Background', 'Texture']
angles = np.linspace(0, 2*np.pi, len(visual_metrics), endpoint=False).tolist()
angles += angles[:1]

for idx, (dataset_name, metrics) in enumerate(challenge_results.items()):
    visual_data = metrics['visual_challenges']
    
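    # Normalize metrics to a 0-1 scale: rescale by the expected maximum, then clip at 1
    contrast_mean = np.mean(visual_data['contrast_issues']) if visual_data['contrast_issues'] else 0
    illum_mean = np.mean(visual_data['illumination_problems']) if visual_data['illumination_problems'] else 0
    bg_mean = np.mean(visual_data['background_interference']) if visual_data['background_interference'] else 0
    texture_mean = np.mean(visual_data['texture_similarity']) if visual_data['texture_similarity'] else 0
    contrast_norm = 1 - min(contrast_mean, 1)   # low contrast -> larger challenge
    illum_norm = min(illum_mean / 3, 1)         # illumination score assumed to span 0-3
    bg_norm = min(bg_mean, 1)
    texture_norm = min(texture_mean / 1000, 1)  # texture score assumed to span 0-1000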
    
    values = [contrast_norm, illum_norm, bg_norm, texture_norm]
    values += values[:1]
    
    ax4.plot(angles, values, 'o-', linewidth=2, label=dataset_name, 
            color=dataset_colors[idx])
    ax4.fill(angles, values, alpha=0.25, color=dataset_colors[idx])

ax4.set_xticks(angles[:-1])
ax4.set_xticklabels(visual_metrics)
ax4.set_ylim(0, 1)
ax4.set_title('Visual Challenge Profile', fontweight='bold', pad=20)
ax4.legend(bbox_to_anchor=(1.3, 1), loc='upper left')

# 5. Field-Specific Issues (second row, first two columns)
ax5 = fig.add_subplot(gs[1, :2])
field_data = []
field_labels = []
for dataset_name, metrics in challenge_results.items():
    weather_scores = metrics['field_specific_issues']['weather_impact_scores']
    soil_interference = metrics['field_specific_issues']['soil_interference']
    
    if weather_scores:
        field_data.extend(weather_scores)
        field_labels.extend([f'{dataset_name}_weather'] * len(weather_scores))
    
    if soil_interference:
        # Multiply soil interference by 3 so it is visible next to weather scores (display scaling only)
        scaled_soil = [s * 3 for s in soil_interference]
        field_data.extend(scaled_soil)
        field_labels.extend([f'{dataset_name}_soil'] * len(scaled_soil))

if field_data:
    df_field = pd.DataFrame({'Score': field_data, 'Issue_Type': field_labels})
    # Split on the last underscore so dataset names that contain underscores stay intact
    df_field['Dataset'] = df_field['Issue_Type'].str.rsplit('_', n=1).str[0]
    df_field['Issue'] = df_field['Issue_Type'].str.rsplit('_', n=1).str[1]
    
    sns.boxplot(data=df_field, x='Dataset', y='Score', hue='Issue', ax=ax5)
    ax5.set_title('Field-Specific Issues Analysis', fontweight='bold')
    ax5.set_ylabel('Impact Score')
    ax5.tick_params(axis='x', rotation=45)

# 6. Growth Stage Distribution (second row, third column)
ax6 = fig.add_subplot(gs[1, 2])
growth_stages_all = []
for dataset_name, metrics in challenge_results.items():
    growth_stages = metrics['field_specific_issues']['growth_stage_variations']
    growth_stages_all.extend(growth_stages)

if growth_stages_all:
    stage_counts = Counter(growth_stages_all)
    stages = list(stage_counts.keys())
    counts = list(stage_counts.values())
    colors = plt.cm.Greens(np.linspace(0.4, 1, len(stages)))
    
    wedges, texts, autotexts = ax6.pie(counts, labels=stages, autopct='%1.1f%%', 
                                      colors=colors, startangle=90)
    ax6.set_title('Growth Stage Distribution', fontweight='bold')

# 7. Complexity Factors (second row, fourth column)
ax7 = fig.add_subplot(gs[1, 3])
complexity_factors = ['Edge Cases', 'Ambiguous Objects', 'Avg FP Risk']
complexity_data = {factor: [] for factor in complexity_factors}

for dataset_name, metrics in challenge_results.items():
    complexity = metrics['detection_complexity']
    complexity_data['Edge Cases'].append(complexity['edge_cases'])
    complexity_data['Ambiguous Objects'].append(complexity['ambiguous_objects'])
    
    fp_risks = complexity['false_positive_risks']
    avg_fp_risk = np.mean(fp_risks) if fp_risks else 0
    complexity_data['Avg FP Risk'].append(avg_fp_risk * 10)  # Scale for visibility

# Create grouped bar chart
x = np.arange(len(challenge_results))
width = 0.25

for i, (factor, data) in enumerate(complexity_data.items()):
    ax7.bar(x + i*width, data, width, label=factor, alpha=0.8)

ax7.set_title('Detection Complexity Factors', fontweight='bold')
ax7.set_xlabel('Dataset')
ax7.set_ylabel('Count/Score')
ax7.set_xticks(x + width)
ax7.set_xticklabels([name.split('_')[0] for name in challenge_results.keys()], rotation=45)
ax7.legend()

# 8-11. Detailed Analysis Plots (third row)
# Crowding vs Difficulty Relationship
ax8 = fig.add_subplot(gs[2, 0])
for idx, (dataset_name, metrics) in enumerate(challenge_results.items()):
    crowding_indices = metrics['density_challenges']['crowding_indices']
    difficulty_data = metrics['density_challenges']['detection_difficulty']
    
    if crowding_indices:
        # Calculate average difficulty score
        difficulty_weights = {'trivial': 1, 'easy': 2, 'medium': 3, 'hard': 4, 'extreme': 5}
        total_images = sum(difficulty_data.values())
        if total_images > 0:
            avg_difficulty = sum(difficulty_weights[level] * count for level, count in difficulty_data.items()) / total_images
            avg_crowding = np.mean(crowding_indices)
            
            ax8.scatter(avg_crowding, avg_difficulty, s=100, alpha=0.7, 
                       label=dataset_name, color=dataset_colors[idx])

ax8.set_title('Crowding vs Detection Difficulty', fontweight='bold')
ax8.set_xlabel('Average Crowding Index')
ax8.set_ylabel('Average Difficulty Score')
ax8.legend()
ax8.grid(True, alpha=0.3)

# Boundary Clarity Analysis
ax9 = fig.add_subplot(gs[2, 1])
boundary_data = []
boundary_labels = []
for dataset_name, metrics in challenge_results.items():
    boundary_clarity = metrics['geometric_challenges']['boundary_clarity']
    if boundary_clarity:
        boundary_data.extend(boundary_clarity)
        boundary_labels.extend([dataset_name] * len(boundary_clarity))

if boundary_data:
    df_boundary = pd.DataFrame({'Boundary_Clarity': boundary_data, 'Dataset': boundary_labels})
    sns.violinplot(data=df_boundary, x='Dataset', y='Boundary_Clarity', ax=ax9)
    ax9.set_title('Boundary Clarity Distribution', fontweight='bold')
    ax9.set_ylabel('Gradient Magnitude')
    ax9.tick_params(axis='x', rotation=45)

# Object Separation Analysis
ax10 = fig.add_subplot(gs[2, 2])
for idx, (dataset_name, metrics) in enumerate(challenge_results.items()):
    separations = metrics['density_challenges']['object_separation']
    if separations:
        ax10.hist(separations, bins=20, alpha=0.6, label=dataset_name,
                 color=dataset_colors[idx], density=True)

ax10.set_title('Object Separation Distribution', fontweight='bold')
ax10.set_xlabel('Average Separation Distance')
ax10.set_ylabel('Density')
ax10.legend()
ax10.grid(True, alpha=0.3)

# Challenge Correlation Matrix
ax11 = fig.add_subplot(gs[2, 3])
# Create correlation matrix for different challenge metrics
challenge_matrix_data = []
for dataset_name, metrics in challenge_results.items():
    row_data = []
    
    # Occlusion severity (weighted average)
    occlusion_weights = {'none': 0, 'light': 1, 'moderate': 2, 'heavy': 3, 'severe': 4}
    occlusion_data = metrics['occlusion_analysis']['occlusion_levels']
    total_occ = sum(occlusion_data.values())
    if total_occ > 0:
        avg_occlusion = sum(occlusion_weights[level] * count for level, count in occlusion_data.items()) / total_occ
    else:
        avg_occlusion = 0
    row_data.append(avg_occlusion)
    
    # Scale complexity
    scale_vars = metrics['scale_challenges']['scale_variations']
    avg_scale_var = np.mean(scale_vars) if scale_vars else 0
    row_data.append(avg_scale_var)
    
    # Visual challenge score
    visual = metrics['visual_challenges']
    visual_score = (
        np.mean(visual['illumination_problems']) if visual['illumination_problems'] else 0
    ) + (
        (1 - np.mean(visual['contrast_issues'])) if visual['contrast_issues'] else 0
    )
    row_data.append(visual_score)
    
    # Density challenge
    crowding = metrics['density_challenges']['crowding_indices']
    avg_crowding = np.mean(crowding) if crowding else 0
    row_data.append(avg_crowding)
    
    challenge_matrix_data.append(row_data)

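# np.corrcoef treats each dataset as one observation; with fewer than 3 datasets
# the coefficients are undefined (NaN) or trivially +/-1, so the heatmap is skipped
if len(challenge_matrix_data) >= 3:
    challenge_matrix = np.array(challenge_matrix_data)
    correlation_matrix = np.corrcoef(challenge_matrix.T)
    
    labels = ['Occlusion', 'Scale Var', 'Visual', 'Density']
    im = ax11.imshow(correlation_matrix, cmap='coolwarm', vmin=-1, vmax=1)
    ax11.set_title('Challenge Correlation Matrix', fontweight='bold')
    ax11.set_xticks(range(len(labels)))
    ax11.set_yticks(range(len(labels)))
    ax11.set_xticklabels(labels, rotation=45)
    ax11.set_yticklabels(labels)
    
    # Add correlation values
    for i in range(len(labels)):
        for j in range(len(labels)):
            ax11.text(j, i, f'{correlation_matrix[i, j]:.2f}',
                      ha="center", va="center", color="black")
    
    plt.colorbar(im, ax=ax11)
else:
    ax11.axis('off')
    ax11.set_title('Challenge Correlation Matrix', fontweight='bold')
    ax11.text(0.5, 0.5, 'Needs at least 3 datasets', ha='center', va='center',
              transform=ax11.transAxes)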

# 12. Comprehensive Statistics Summary (bottom rows)
ax12 = fig.add_subplot(gs[3:, :])
ax12.axis('off')

summary_text = "⚠️ COMPREHENSIVE WHEAT DETECTION CHALLENGE ASSESSMENT\n\n"

for dataset_name, metrics in challenge_results.items():
    summary_text += f"🌾 {dataset_name}:\n"
    
    # Occlusion summary
    occlusion_data = metrics['occlusion_analysis']['occlusion_levels']
    total_occ = sum(occlusion_data.values())
    if total_occ > 0:
        severe_ratio = (occlusion_data.get('heavy', 0) + occlusion_data.get('severe', 0)) / total_occ
        summary_text += f"  🔍 Occlusion: {severe_ratio*100:.1f}% severe cases\n"
    
    # Difficulty summary
    difficulty_data = metrics['density_challenges']['detection_difficulty']
    total_diff = sum(difficulty_data.values())
    if total_diff > 0:
        hard_ratio = (difficulty_data.get('hard', 0) + difficulty_data.get('extreme', 0)) / total_diff
        summary_text += f"  🎯 Difficulty: {hard_ratio*100:.1f}% hard/extreme cases\n"
    
    # Scale challenges
    scale_data = metrics['scale_challenges']
    if scale_data['scale_variations']:
        avg_scale_var = np.mean(scale_data['scale_variations'])
        summary_text += f"  📏 Scale variation: {avg_scale_var:.3f}\n"
    
    # Visual challenges
    visual_data = metrics['visual_challenges']
    if visual_data['contrast_issues']:
        avg_contrast = np.mean(visual_data['contrast_issues'])
        summary_text += f"  👁️ Average contrast: {avg_contrast:.3f}\n"
    
    # Complexity factors
    complexity_data = metrics['detection_complexity']
    summary_text += f"  🎲 Edge cases: {complexity_data['edge_cases']}\n"
    summary_text += f"  ❓ Ambiguous objects: {complexity_data['ambiguous_objects']}\n"
    
    # Field-specific issues
    field_data = metrics['field_specific_issues']
    if field_data['weather_impact_scores']:
        avg_weather = np.mean(field_data['weather_impact_scores'])
        summary_text += f"  🌤️ Weather impact: {avg_weather:.2f}\n"
    
    summary_text += "\n"

ax12.text(0.02, 0.98, summary_text, transform=ax12.transAxes, fontsize=9,
         verticalalignment='top', fontfamily='monospace',
         bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))

plt.suptitle('⚠️ COMPREHENSIVE WHEAT DETECTION CHALLENGE ASSESSMENT', 
             fontsize=16, fontweight='bold')
plt.savefig(notebook_results_dir / 'visualizations' / 'comprehensive_challenge_assessment.png', 
            dpi=300, bbox_inches='tight')
plt.show()

# Save comprehensive challenge assessment results
challenge_summary = {}
for dataset_name, metrics in challenge_results.items():
    challenge_summary[dataset_name] = {
        'occlusion_analysis': {
            'levels_distribution': metrics['occlusion_analysis']['occlusion_levels'],
            'average_overlaps': float(np.mean(metrics['occlusion_analysis']['occlusion_patterns'])) if metrics['occlusion_analysis']['occlusion_patterns'] else 0,
            'severe_cases_ratio': float((metrics['occlusion_analysis']['occlusion_levels'].get('heavy', 0) + 
                                       metrics['occlusion_analysis']['occlusion_levels'].get('severe', 0)) / 
                                      max(sum(metrics['occlusion_analysis']['occlusion_levels'].values()), 1))
        },
        'scale_challenges': {
            'scale_variation': float(np.mean(metrics['scale_challenges']['scale_variations'])) if metrics['scale_challenges']['scale_variations'] else 0,
            'multi_scale_complexity': float(np.mean(metrics['scale_challenges']['multi_scale_complexity'])) if metrics['scale_challenges']['multi_scale_complexity'] else 0,
            'size_consistency': float(np.mean(metrics['scale_challenges']['size_consistency'])) if metrics['scale_challenges']['size_consistency'] else 0
        },
        'density_challenges': {
            'difficulty_distribution': metrics['density_challenges']['detection_difficulty'],
            'average_crowding': float(np.mean(metrics['density_challenges']['crowding_indices'])) if metrics['density_challenges']['crowding_indices'] else 0,
            'hard_cases_ratio': float((metrics['density_challenges']['detection_difficulty'].get('hard', 0) + 
                                     metrics['density_challenges']['detection_difficulty'].get('extreme', 0)) / 
                                    max(sum(metrics['density_challenges']['detection_difficulty'].values()), 1))
        },
        'visual_challenges': {
            'contrast_issues': float(np.mean(metrics['visual_challenges']['contrast_issues'])) if metrics['visual_challenges']['contrast_issues'] else 0,
            'illumination_problems': float(np.mean(metrics['visual_challenges']['illumination_problems'])) if metrics['visual_challenges']['illumination_problems'] else 0,
            'background_interference': float(np.mean(metrics['visual_challenges']['background_interference'])) if metrics['visual_challenges']['background_interference'] else 0,
            'texture_similarity': float(np.mean(metrics['visual_challenges']['texture_similarity'])) if metrics['visual_challenges']['texture_similarity'] else 0
        },
        'geometric_challenges': {
            'shape_variations': float(np.mean(metrics['geometric_challenges']['shape_variations'])) if metrics['geometric_challenges']['shape_variations'] else 0,
            'boundary_clarity': float(np.mean(metrics['geometric_challenges']['boundary_clarity'])) if metrics['geometric_challenges']['boundary_clarity'] else 0
        },
        'field_specific_issues': {
            'growth_stage_distribution': dict(Counter(metrics['field_specific_issues']['growth_stage_variations'])),
            'weather_impact': float(np.mean(metrics['field_specific_issues']['weather_impact_scores'])) if metrics['field_specific_issues']['weather_impact_scores'] else 0,
            'soil_interference': float(np.mean(metrics['field_specific_issues']['soil_interference'])) if metrics['field_specific_issues']['soil_interference'] else 0
        },
        'detection_complexity': {
            'edge_cases': metrics['detection_complexity']['edge_cases'],
            'ambiguous_objects': metrics['detection_complexity']['ambiguous_objects'],
            'false_positive_risk': float(np.mean(metrics['detection_complexity']['false_positive_risks'])) if metrics['detection_complexity']['false_positive_risks'] else 0
        }
    }

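with open(notebook_results_dir / 'data_analysis' / 'comprehensive_challenge_assessment.json', 'w') as f:
    # Convert any NumPy scalars (e.g. np.int64 counts) to native Python types for JSON
    json.dump(challenge_summary, f, indent=2,
              default=lambda o: o.item() if isinstance(o, np.generic) else str(o))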

print(f"\n💾 Comprehensive challenge assessment saved!")
print(f"📁 Location: {notebook_results_dir / 'data_analysis' / 'comprehensive_challenge_assessment.json'}")

## 🌾 Global Wheat Head Detection - Comprehensive Analysis Summary
**Enhanced CBAM-STN-TPS-YOLO Dataset Exploration**

---

## Executive Overview

This summary consolidates the exploration of the Global Wheat Head Detection dataset, focusing on dense object detection challenges and what they imply for configuring the CBAM-STN-TPS-YOLO architecture. It covers wheat head distribution patterns, clustering characteristics, field conditions, and the detection challenges specific to agricultural wheat head identification.

### Analysis Metadata
- **Analysis Version**: 2.0_enhanced
- **Dataset Focus**: Global Wheat Head Detection - Comprehensive Analysis
- **Domain Characteristics**: Agricultural Computer Vision - Dense Object Detection
- **Analysis Scope**: Multi-dimensional wheat detection challenge assessment
- **Detection Paradigm**: Dense small object detection in agricultural environments

---

## Dataset Overview and Characteristics

### Core Statistics
- **Total Images Processed**: All wheat field images in the loaded dataset splits (per-split counts are reported in the dataset structure analysis)
- **Domain Specialization**: Wheat head detection in natural field environments
- **Detection Paradigm**: Dense small object detection with high overlap scenarios
- **Complexity Level**: High-density small objects with significant clustering

### Primary Detection Challenges

#### Density and Clustering Challenges
- **High-Density Clustering**: Dense wheat head arrangements requiring sophisticated detection
- **Small Object Detection**: Wheat heads typically cover 0.001-0.008 of the image area (0.1-0.8% as normalized box area)
- **Overlapping Wheat Heads**: Significant occlusion patterns in mature wheat fields
- **Multi-Scale Detection**: Varying wheat head sizes within single field images
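To make the normalized-area range concrete, the short calculation below converts normalized box areas into pixel areas and the side length of an equivalent square box. It assumes 1024×1024 source images (the resolution of the Kaggle Global Wheat Detection images) and a 640×640 network input; both resolutions are assumptions to adjust to your own pipeline.

```python
import math

def normalized_area_to_pixels(norm_area, img_w, img_h):
    """Convert a normalized box area (fraction of the image) to pixel area
    and the side length of a square box with the same area."""
    pixel_area = norm_area * img_w * img_h
    return pixel_area, math.sqrt(pixel_area)

# Assumed resolutions: 1024x1024 source images and a 640x640 network input
for size in (1024, 640):
    for a in (0.001, 0.008):
        area, side = normalized_area_to_pixels(a, size, size)
        print(f"{size}px image, normalized area {a}: {area:.0f} px^2 (~{side:.0f} px square)")
```

Under these assumptions the range corresponds to roughly 32 to 92 px per side at 1024×1024, which lies entirely inside the COCO "medium" area range (32² to 96² pixels); after resizing to 640×640 the lower end (about 20 px per side) falls into the "small" range.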

#### Environmental and Field Conditions
- **Field Condition Variations**: Diverse agricultural environments and growing conditions
- **Illumination Challenges**: Variable lighting from shadows to overexposure
- **Background Interference**: Complex soil, vegetation, and field texture variations
- **Growth Stage Adaptations**: Different wheat maturity levels affecting appearance

#### Technical Detection Complexities
- **Occlusion Management**: Occlusion levels ranging from none and light through moderate and heavy to severe
- **Scale Variation Handling**: Multiple wheat head sizes requiring multi-scale processing
- **Background Interference Filtering**: Distinguishing wheat heads from complex backgrounds
- **Environmental Robustness**: Handling weather, soil, and seasonal variations

---

## Comprehensive Analysis Results

### Wheat Head Distribution Analysis

#### Density Characteristics
- **Average Wheat Heads per Image**: 15-30+ objects per field image
- **Density Categories**: Low, medium, high, and extreme density classifications
- **Spatial Distribution**: Non-uniform clustering patterns across field images
- **Density Variation**: High coefficient of variation in wheat head counts

#### Clustering Patterns
- **Clustering Strength**: Quantified spatial clustering using DBSCAN analysis
- **Uniformity Measures**: Assessment of wheat head spatial distribution consistency
- **Overlap Analysis**: IoU-based overlap detection and severity quantification
- **Spatial Entropy**: Measurement of wheat head placement randomness
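The three per-image measures above can be sketched as follows. This is a hypothetical helper, not the notebook's own implementation: it assumes at least one box per image in pixel `[x1, y1, x2, y2]` format, treats the share of heads that DBSCAN assigns to a cluster as the clustering strength, counts box pairs whose IoU exceeds a threshold as overlaps, and computes spatial entropy as the Shannon entropy of head centers over a coarse grid, normalized to [0, 1]. `eps` (in pixels), `min_samples`, the IoU threshold, and the grid size are illustrative values.

```python
import numpy as np
from sklearn.cluster import DBSCAN

def pairwise_iou(boxes):
    """IoU matrix for an (n, 4) array of [x1, y1, x2, y2] boxes."""
    x1 = np.maximum(boxes[:, None, 0], boxes[None, :, 0])
    y1 = np.maximum(boxes[:, None, 1], boxes[None, :, 1])
    x2 = np.minimum(boxes[:, None, 2], boxes[None, :, 2])
    y2 = np.minimum(boxes[:, None, 3], boxes[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    union = area[:, None] + area[None, :] - inter
    return inter / np.maximum(union, 1e-9)

def spatial_metrics(boxes, img_w, img_h, eps=60.0, min_samples=3, iou_thr=0.1, grid=4):
    boxes = np.asarray(boxes, dtype=float)
    centers = np.column_stack([(boxes[:, 0] + boxes[:, 2]) / 2,
                               (boxes[:, 1] + boxes[:, 3]) / 2])

    # Clustering strength: share of heads that DBSCAN does not label as noise (-1)
    labels = DBSCAN(eps=eps, min_samples=min_samples).fit_predict(centers)
    clustering_strength = float(np.mean(labels != -1))
    n_clusters = len(set(labels) - {-1})

    # Overlaps: unordered box pairs whose IoU exceeds the threshold
    iou = pairwise_iou(boxes)
    overlap_pairs = int(np.sum(np.triu(iou, k=1) > iou_thr))

    # Spatial entropy of centers over a grid x grid histogram, normalized to [0, 1]
    hist, _, _ = np.histogram2d(centers[:, 0], centers[:, 1], bins=grid,
                                range=[[0, img_w], [0, img_h]])
    p = hist.ravel() / hist.sum()
    p = p[p > 0]
    entropy = float(-(p * np.log(p)).sum() / np.log(grid * grid))

    return {"clustering_strength": clustering_strength, "n_clusters": n_clusters,
            "overlap_pairs": overlap_pairs, "spatial_entropy": entropy}
```

An entropy near 1 means heads are spread evenly over the grid cells; values near 0 mean they are concentrated in a few cells.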

### Field Condition Assessment

#### Illumination Profile Analysis
- **Brightness Stability**: Variation in lighting conditions across images
- **Lighting Categories**: Classification of easy, moderate, and challenging lighting
- **Shadow Effects**: Impact of shadows on wheat head visibility and detection
- **Contrast Adequacy**: Assessment of image contrast for effective detection
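A minimal sketch of per-image illumination measures, assuming an RGB `uint8` image as a NumPy array. Brightness is taken as the mean grayscale intensity, contrast as RMS contrast (standard deviation of normalized intensity), and shadow or overexposure as the fraction of very dark or very bright pixels. The thresholds used to assign the easy/moderate/challenging labels are hypothetical and would need calibration on the dataset.

```python
import numpy as np

def illumination_profile(rgb, dark_thr=30, bright_thr=225):
    """Simple illumination statistics for an RGB uint8 image of shape (H, W, 3)."""
    # ITU-R BT.601 luma weights for grayscale conversion
    gray = rgb[..., :3].astype(float) @ np.array([0.299, 0.587, 0.114])
    norm = gray / 255.0
    stats = {
        "mean_brightness": float(norm.mean()),
        "rms_contrast": float(norm.std()),
        "shadow_fraction": float(np.mean(gray < dark_thr)),
        "overexposed_fraction": float(np.mean(gray > bright_thr)),
    }
    # Hypothetical rule for the three lighting categories
    if stats["rms_contrast"] < 0.08 or stats["overexposed_fraction"] > 0.2:
        category = "challenging"
    elif stats["shadow_fraction"] > 0.2 or not 0.25 <= stats["mean_brightness"] <= 0.75:
        category = "moderate"
    else:
        category = "easy"
    stats["lighting_category"] = category
    return stats
```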

#### Environmental Factor Analysis
- **Vegetation Density**: Ratio of wheat vegetation to background elements
- **Soil Visibility**: Interference from exposed soil and non-wheat vegetation
- **Background Complexity**: Quantified complexity of field backgrounds
- **Weather Impact**: Effects of environmental conditions on image quality
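Vegetation density and soil visibility can be approximated with the Excess Green index, ExG = 2g - r - b, computed on chromatic coordinates (each channel divided by R + G + B), a widely used index for separating green vegetation from soil. The sketch below is an illustrative substitute, not necessarily how the notebook computed these metrics, and the 0.1 threshold is an assumption. Note that ripe, yellow wheat heads score low on ExG, so this measures green canopy versus background rather than the heads themselves.

```python
import numpy as np

def vegetation_metrics(rgb, exg_thr=0.1):
    """Green-vegetation vs. background ratio using the Excess Green (ExG) index."""
    rgb = rgb[..., :3].astype(float)
    total = rgb.sum(axis=-1, keepdims=True)
    # Chromatic coordinates r, g, b (all-black pixels are left at 0)
    chrom = np.divide(rgb, total, out=np.zeros_like(rgb), where=total > 0)
    r, g, b = chrom[..., 0], chrom[..., 1], chrom[..., 2]
    exg = 2 * g - r - b
    vegetation_ratio = float((exg > exg_thr).mean())
    return {"vegetation_ratio": vegetation_ratio,
            "background_ratio": 1.0 - vegetation_ratio,  # soil plus non-green material
            "mean_exg": float(exg.mean())}
```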

### Detection Challenge Assessment

#### Difficulty Classification
- **Easy Cases**: Clear, well-separated wheat heads with good contrast
- **Moderate Cases**: Some overlapping or challenging lighting conditions
- **Hard Cases**: Significant occlusion or poor environmental conditions
- **Extreme Cases**: Maximum density with severe overlapping and poor conditions

#### Specific Challenge Quantification
- **Occlusion Severity**: Average number of overlapping wheat heads per image
- **Scale Complexity**: Variation in wheat head sizes requiring multi-scale detection
- **Visual Challenges**: Illumination, contrast, and visibility issues
- **Boundary Detection**: Edge effects and partial wheat heads at image borders

---

## Advanced Insights and Pattern Discovery

### Key Statistical Insights

#### Density Statistics
- **Mean Density**: Average wheat heads per image across all datasets
- **Density Standard Deviation**: Variability in wheat head counts
- **Density Range**: Minimum to maximum wheat heads observed
- **High-Density Prevalence**: Percentage of images with challenging densities
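Given one wheat-head count per image, these statistics reduce to a few pandas calls. The sketch assumes a `counts` Series indexed by image id; the counts in the usage line are toy values, not dataset results, and the high-density cutoff of 50 heads is a hypothetical threshold.

```python
import pandas as pd

def density_statistics(counts: pd.Series, high_density_thr: int = 50) -> dict:
    """Summary statistics for per-image wheat head counts."""
    return {
        "mean_density": float(counts.mean()),
        "density_std": float(counts.std(ddof=1)),
        "density_range": (int(counts.min()), int(counts.max())),
        "coefficient_of_variation": float(counts.std(ddof=1) / counts.mean()),
        "high_density_pct": float((counts >= high_density_thr).mean() * 100),
    }

# Toy counts for illustration only
counts = pd.Series([12, 35, 48, 60, 75], index=["img_a", "img_b", "img_c", "img_d", "img_e"])
print(density_statistics(counts))
```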

#### Difficulty Distribution Analysis
- **Mean Difficulty**: Average challenge level across all images
- **Hard Case Prevalence**: Percentage of images classified as difficult
- **Occlusion Consistency**: Reliability of overlap patterns across datasets
- **Challenge Correlation**: Relationship between different difficulty factors

### Pattern Discoveries

#### Spatial Organization Patterns
- Wheat heads exhibit strong clustering patterns in dense field conditions
- Spatial distribution follows agricultural row planting patterns
- Cluster sizes correlate with field maturity and growing conditions
- Edge effects create detection challenges at image boundaries

#### Morphological Patterns
- Scale variations correlate with field perspective and camera distance
- Growth stage variations create morphological detection challenges
- Aspect ratio consistency across different wheat varieties
- Shape deformation patterns under wind and environmental stress

#### Environmental Impact Patterns
- Illumination challenges significantly impact detection difficulty
- Background interference increases with soil visibility
- Seasonal variations affect wheat head appearance and detectability
- Weather conditions create systematic detection challenges

### Optimization Opportunities

#### Architecture-Specific Optimizations
- Multi-scale feature extraction critical for varying wheat head sizes
- Attention mechanisms essential for dense clustering scenarios
- Spatial transformation networks beneficial for perspective variations
- Advanced NMS required for overlapping object management
- Data augmentation crucial for illumination robustness
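For reference, standard greedy NMS is sketched below in NumPy (boxes as `[x1, y1, x2, y2]`). It illustrates the trade-off behind the 0.3-0.4 IoU threshold recommended in the YOLO configuration: a lower threshold removes more duplicate predictions but also suppresses genuinely overlapping neighboring heads, which is why dense-scene alternatives such as Soft-NMS or weighted box fusion are often evaluated as well. This is a generic sketch, not the project's post-processing code.

```python
import numpy as np

def greedy_nms(boxes, scores, iou_thr=0.4):
    """Return indices of boxes kept by greedy non-maximum suppression."""
    boxes = np.asarray(boxes, dtype=float)
    order = np.argsort(scores)[::-1]
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    keep = []
    while order.size > 0:
        i = order[0]
        keep.append(int(i))
        rest = order[1:]
        # IoU of the current highest-scoring box with all remaining boxes
        xx1 = np.maximum(boxes[i, 0], boxes[rest, 0])
        yy1 = np.maximum(boxes[i, 1], boxes[rest, 1])
        xx2 = np.minimum(boxes[i, 2], boxes[rest, 2])
        yy2 = np.minimum(boxes[i, 3], boxes[rest, 3])
        inter = np.clip(xx2 - xx1, 0, None) * np.clip(yy2 - yy1, 0, None)
        iou = inter / (areas[i] + areas[rest] - inter)
        order = rest[iou <= iou_thr]
    return keep

# Two touching heads with IoU of about 0.39
boxes = [[80, 80, 120, 120], [90, 90, 130, 130]]
print(greedy_nms(boxes, [0.9, 0.8], iou_thr=0.3))  # second head suppressed
print(greedy_nms(boxes, [0.9, 0.8], iou_thr=0.5))  # both heads kept
```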

#### Training Strategy Optimizations
- Progressive training from simple to complex field scenarios
- Density-aware loss functions for imbalanced object distributions
- Multi-scale training for handling size variations
- Extensive augmentation for environmental robustness
- Specialized evaluation metrics for dense object scenarios

---

## CBAM-STN-TPS-YOLO Architecture Alignment

### Component Justification and Benefits

#### CBAM (Convolutional Block Attention Module)
**Rationale**: High background interference and dense clustering require attention mechanisms

**Specific Benefits**:
- Channel attention for wheat-soil discrimination
- Spatial attention for crowded scene focus
- Illumination invariance through adaptive attention
- Multi-level attention for different growth stages

**Configuration Recommendations**:
- Reduction ratio 16 (the default in the original CBAM design) as the starting point for channel attention
- Spatial attention kernel size 7 (also the CBAM default) to aggregate context around each wheat head
- Multi-level attention integration across network layers
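A compact PyTorch sketch of a CBAM block with the recommended reduction ratio and kernel size follows. It mirrors the published CBAM design (channel attention from average- and max-pooled descriptors through a shared MLP, followed by spatial attention from channel-wise average and max maps through a 7×7 convolution); it is not taken from the project's code, and the class names are illustrative.

```python
import torch
import torch.nn as nn

class ChannelAttention(nn.Module):
    def __init__(self, channels, reduction=16):
        super().__init__()
        hidden = max(channels // reduction, 1)
        # Shared MLP implemented with 1x1 convolutions
        self.mlp = nn.Sequential(
            nn.Conv2d(channels, hidden, 1, bias=False),
            nn.ReLU(inplace=True),
            nn.Conv2d(hidden, channels, 1, bias=False),
        )

    def forward(self, x):
        avg = self.mlp(torch.mean(x, dim=(2, 3), keepdim=True))
        mx = self.mlp(torch.amax(x, dim=(2, 3), keepdim=True))
        return torch.sigmoid(avg + mx)  # (N, C, 1, 1) channel weights

class SpatialAttention(nn.Module):
    def __init__(self, kernel_size=7):
        super().__init__()
        self.conv = nn.Conv2d(2, 1, kernel_size, padding=kernel_size // 2, bias=False)

    def forward(self, x):
        avg = torch.mean(x, dim=1, keepdim=True)
        mx = torch.amax(x, dim=1, keepdim=True)
        return torch.sigmoid(self.conv(torch.cat([avg, mx], dim=1)))  # (N, 1, H, W)

class CBAM(nn.Module):
    def __init__(self, channels, reduction=16, kernel_size=7):
        super().__init__()
        self.channel_attention = ChannelAttention(channels, reduction)
        self.spatial_attention = SpatialAttention(kernel_size)

    def forward(self, x):
        x = x * self.channel_attention(x)     # "what" to emphasize (e.g. wheat vs. soil)
        return x * self.spatial_attention(x)  # "where" to emphasize in crowded scenes
```

The block preserves the feature-map shape, so it can be inserted after any backbone or neck stage, which is what multi-level integration requires.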

#### STN (Spatial Transformer Network)
**Rationale**: Field perspective variations and camera angle diversity require spatial transformation

**Specific Benefits**:
- Perspective normalization for consistent detection
- Rotation handling for varying field orientations
- Scale compensation for distance variations
- Viewpoint invariance for different camera positions

**Configuration Recommendations**:
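- Affine transformation (6 parameters) to approximate mild perspective changes; full perspective correction would require a projective (homography) transform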
- Localization network with field-specific features
- Progressive transformation for training stability
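A minimal affine STN in PyTorch, following the standard pattern: a small localization network regresses the six affine parameters, its final layer is initialized to the identity transform (so training starts from undistorted inputs, in line with the stability point above), and `F.affine_grid` plus `F.grid_sample` resample the input. The layer sizes are illustrative assumptions, not the project's localization network.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class AffineSTN(nn.Module):
    def __init__(self, in_channels=3):
        super().__init__()
        self.localization = nn.Sequential(
            nn.Conv2d(in_channels, 16, kernel_size=7, stride=2, padding=3),
            nn.ReLU(inplace=True),
            nn.Conv2d(16, 32, kernel_size=5, stride=2, padding=2),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((4, 4)),
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, 64),
            nn.ReLU(inplace=True),
            nn.Linear(64, 6),
        )
        # Initialize the regressor to output the identity transform
        final = self.localization[-1]
        nn.init.zeros_(final.weight)
        with torch.no_grad():
            final.bias.copy_(torch.tensor([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))

    def forward(self, x):
        theta = self.localization(x).view(-1, 2, 3)
        grid = F.affine_grid(theta, x.size(), align_corners=False)
        return F.grid_sample(x, grid, align_corners=False), theta
```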

#### TPS (Thin Plate Spline)
**Rationale**: Irregular wheat head shapes and wind deformation require non-rigid transformation

**Specific Benefits**:
- Non-rigid deformation handling for wind effects
- Irregular boundary adaptation for growth variations
- Shape normalization for consistent feature extraction
- Flexible transformation for natural object variations

**Configuration Recommendations**:
- 20-24 control points for wheat head complexity
- Regularization lambda 0.01-0.1 for shape preservation
- Uniform grid initialization for field structure alignment
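The role of the regularization term can be seen in a small NumPy fit of a 2-D thin plate spline. This sketch uses a uniform 5×4 grid (20 control points, the low end of the recommended range) in normalized [-1, 1] coordinates and the kernel U = r² log r²; adding λ to the kernel diagonal trades exact interpolation of the control points for a smoother warp. Function names are illustrative, not the project's TPS module.

```python
import numpy as np

def tps_kernel(r2):
    """TPS radial basis U = r^2 log(r^2), with U(0) = 0 by continuity."""
    safe = np.where(r2 > 0, r2, 1.0)
    return np.where(r2 > 0, r2 * np.log(safe), 0.0)

def fit_tps(src, dst, lam=0.05):
    """Fit a 2-D TPS mapping control points src (n, 2) onto dst (n, 2).
    lam > 0 relaxes exact interpolation in favor of a smoother warp."""
    n = src.shape[0]
    d2 = np.sum((src[:, None, :] - src[None, :, :]) ** 2, axis=-1)
    K = tps_kernel(d2) + lam * np.eye(n)
    P = np.hstack([np.ones((n, 1)), src])  # affine part [1, x, y]
    L = np.zeros((n + 3, n + 3))
    L[:n, :n], L[:n, n:], L[n:, :n] = K, P, P.T
    rhs = np.vstack([dst, np.zeros((3, 2))])
    params = np.linalg.solve(L, rhs)
    return params[:n], params[n:]  # radial weights (n, 2), affine coefficients (3, 2)

def apply_tps(points, src, weights, affine):
    """Warp arbitrary points (m, 2) with a fitted TPS."""
    d2 = np.sum((points[:, None, :] - src[None, :, :]) ** 2, axis=-1)
    P = np.hstack([np.ones((points.shape[0], 1)), points])
    return P @ affine + tps_kernel(d2) @ weights

# 20 control points on a uniform 5 x 4 grid in normalized [-1, 1] coordinates
xs, ys = np.linspace(-1, 1, 5), np.linspace(-1, 1, 4)
src = np.array([[x, y] for y in ys for x in xs])
```

With `dst == src` the fit returns the identity warp for any λ; with a displaced control point and λ = 0 the spline passes exactly through every target, while larger λ lets it deviate to stay smoother.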

#### YOLO Optimization
**Rationale**: Dense object detection with real-time requirements for agricultural applications

**Specific Benefits**:
- Efficient dense detection for high wheat head counts
- Multi-scale processing for size variations
- End-to-end optimization for field deployment
- Real-time inference for agricultural monitoring

**Configuration Recommendations**:
- Wheat-specific anchor sizes: [8,8], [16,16], [32,32], [48,48], [64,64]
- Dense prediction layers for high-density scenarios
- NMS threshold 0.3-0.4 for overlap management
- Maximum detections set to 100 for extreme density cases
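Because the anchor list above is square, it may be worth checking it against the box shapes actually present in the data. A common approach, introduced with YOLOv2, is k-means on box widths and heights with 1 - IoU as the distance. The sketch assumes box sizes in pixels at the training input resolution and initializes centroids at evenly spaced area quantiles so the result is deterministic; it is an illustration, not the project's anchor procedure.

```python
import numpy as np

def wh_iou(wh, anchors):
    """IoU between (n, 2) box sizes and (k, 2) anchor sizes, all aligned at one center."""
    inter = (np.minimum(wh[:, None, 0], anchors[None, :, 0]) *
             np.minimum(wh[:, None, 1], anchors[None, :, 1]))
    union = (wh[:, 0] * wh[:, 1])[:, None] + (anchors[:, 0] * anchors[:, 1])[None, :] - inter
    return inter / union

def kmeans_anchors(wh, k=5, iters=100):
    """k-means on (width, height) pairs with 1 - IoU as the distance."""
    wh = np.asarray(wh, dtype=float)
    # Deterministic initialization: boxes at evenly spaced area quantiles
    order = np.argsort(wh[:, 0] * wh[:, 1])
    anchors = wh[order[((np.arange(k) + 0.5) * len(wh) / k).astype(int)]].copy()
    for _ in range(iters):
        assign = np.argmax(wh_iou(wh, anchors), axis=1)  # highest IoU = smallest 1 - IoU
        new = np.array([wh[assign == j].mean(axis=0) if np.any(assign == j) else anchors[j]
                        for j in range(k)])
        if np.allclose(new, anchors):
            break
        anchors = new
    return anchors[np.argsort(anchors[:, 0] * anchors[:, 1])]  # sorted by area
```

Comparing the resulting anchors (for example with `k=5` on all training box sizes) against `[8,8]` through `[64,64]` shows whether non-square anchors would fit elongated wheat heads better.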

### Synergy Optimization Strategies

#### CBAM-STN Integration
- CBAM attention guides STN localization network focus
- STN normalized features enhance CBAM effectiveness
- Joint training for optimal transformation learning
- Attention-guided geometric transformation

#### STN-TPS Coordination
- STN global transformation followed by TPS local refinement
- Hierarchical transformation from coarse to fine adjustment
- Shared feature extraction for computational efficiency
- Progressive deformation handling

#### Attention-Guided Detection
- CBAM features inform YOLO detection heads
- Attention maps guide anchor placement optimization
- Multi-level attention for multi-scale detection
- Feature enhancement for dense object scenarios

---

## Training Strategy Recommendations

### Data Preparation and Preprocessing

#### Preprocessing Pipeline
- Scale pixel values to [0,1], then standardize each channel with ImageNet mean and standard deviation
- Resize with aspect ratio preservation for field structure
- Multi-scale training with progressive sizing (416→640→832)
- Quality enhancement for low-contrast field images
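The scaling, standardization, and aspect-preserving resize steps can be sketched in NumPy. The ImageNet per-channel mean (0.485, 0.456, 0.406) and standard deviation (0.229, 0.224, 0.225) are the standard torchvision values; the letterbox helpers only compute the scale and padding and map `[x1, y1, x2, y2]` boxes into the padded frame, leaving the actual pixel resize to an image library such as OpenCV. All function names are illustrative.

```python
import numpy as np

IMAGENET_MEAN = np.array([0.485, 0.456, 0.406])
IMAGENET_STD = np.array([0.229, 0.224, 0.225])

def normalize_image(rgb_uint8):
    """Scale to [0, 1], then standardize each channel with ImageNet statistics."""
    x = rgb_uint8.astype(np.float32) / 255.0
    return (x - IMAGENET_MEAN) / IMAGENET_STD

def letterbox_params(img_w, img_h, target=640):
    """Uniform scale that fits the image inside target x target, plus centering padding."""
    scale = min(target / img_w, target / img_h)
    new_w, new_h = round(img_w * scale), round(img_h * scale)
    pad_x, pad_y = (target - new_w) / 2, (target - new_h) / 2
    return scale, (new_w, new_h), (pad_x, pad_y)

def letterbox_boxes(boxes, scale, pad):
    """Map [x1, y1, x2, y2] pixel boxes into the letterboxed frame."""
    boxes = np.asarray(boxes, dtype=float) * scale
    boxes[:, [0, 2]] += pad[0]
    boxes[:, [1, 3]] += pad[1]
    return boxes
```

The same helpers serve the progressive sizing schedule by calling `letterbox_params` with `target` set to 416, 640, or 832.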

#### Annotation Optimization
- Verify dense annotation consistency across all images
- Handle overlapping bounding boxes with IoU analysis
- Quality control for small object annotations
- Validate wheat head boundary accuracy

#### Data Splitting Strategy
- Stratified split by density categories for balanced training
- Temporal split for growth stage diversity representation
- Geographic split for field condition variety
- Cross-validation for robust performance assessment
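A density-stratified split can be sketched with pandas and scikit-learn. It assumes a DataFrame with one row per image and a `count` column of wheat heads; the density bin edges are hypothetical and should follow the density categories defined earlier in the analysis. Geographic or temporal splits would instead group images by a domain column, for example with scikit-learn's `GroupShuffleSplit`.

```python
import pandas as pd
from sklearn.model_selection import train_test_split

def density_stratified_split(df, test_size=0.2, seed=42,
                             bins=(0, 20, 40, 60, float("inf")),
                             labels=("low", "medium", "high", "extreme")):
    """Split images so each density category keeps its share in train and validation."""
    density_cat = pd.cut(df["count"], bins=list(bins), labels=list(labels), right=False)
    train_df, val_df = train_test_split(df, test_size=test_size, random_state=seed,
                                        stratify=density_cat.astype(str))
    return train_df, val_df
```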

### Augmentation Strategy

#### Geometric Augmentations
- Random rotation (±15°) for field orientation diversity
- Random perspective transform for camera angle variation
- Random scale (0.8-1.2) for distance simulation
- Horizontal/vertical flipping for spatial variation
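Geometric augmentations must transform the boxes together with the pixels. Flipping is the simplest case and is sketched below for `[x1, y1, x2, y2]` pixel boxes; rotation and perspective transforms additionally need boxes re-fitted around the transformed corners and clipped to the image, which this sketch does not cover.

```python
import numpy as np

def flip_with_boxes(image, boxes, horizontal=True):
    """Flip an (H, W, C) image and its [x1, y1, x2, y2] boxes."""
    h, w = image.shape[:2]
    boxes = np.asarray(boxes, dtype=float).copy()
    if horizontal:
        image = image[:, ::-1]
        boxes[:, [0, 2]] = w - boxes[:, [2, 0]]  # new x1 = w - old x2, new x2 = w - old x1
    else:
        image = image[::-1, :]
        boxes[:, [1, 3]] = h - boxes[:, [3, 1]]
    return np.ascontiguousarray(image), boxes
```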

#### Photometric Augmentations
- Color jittering for illumination robustness
- Random brightness/contrast for lighting condition simulation
- HSV augmentation for growth stage appearance variation
- Gaussian noise addition for sensor variation simulation

#### Wheat-Specific Augmentations
- Mosaic augmentation for density increase simulation
- CutMix for realistic occlusion pattern creation
- Random erasing for missing wheat head simulation
- Copy-paste augmentation for rare density scenarios

#### Advanced Field-Aware Augmentations
- Field-structure preserving augmentation maintaining row patterns
- Density-preserving augmentation maintaining wheat head counts
- Weather simulation through atmospheric effect modeling
- Growth stage interpolation for temporal consistency

### Training Schedule and Methodology

#### Progressive Training Phases
1. **Phase 1**: Basic detection on simple, clear images (20 epochs)
2. **Phase 2**: Multi-scale training introduction with moderate complexity (30 epochs)
3. **Phase 3**: Full augmentation pipeline with challenging scenarios (50 epochs)

#### Learning Rate Schedule
- **Warm-up**: Linear increase over 5 epochs up to the base learning rate
- **Main Training**: Cosine annealing with restarts for optimization
- **Fine-tuning**: Reduced learning rate for final convergence
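The warm-up plus cosine-with-restarts schedule can be written as a per-epoch function. The base and minimum learning rates and the 25-epoch restart period are hypothetical values; PyTorch's built-in `LinearLR` and `CosineAnnealingWarmRestarts` schedulers provide equivalent behavior if preferred.

```python
import math

def learning_rate(epoch, base_lr=1e-3, min_lr=1e-5, warmup_epochs=5, restart_period=25):
    """Linear warm-up, then cosine annealing restarted every `restart_period` epochs."""
    if epoch < warmup_epochs:
        # Linear ramp from base_lr / warmup_epochs up to base_lr
        return base_lr * (epoch + 1) / warmup_epochs
    t = (epoch - warmup_epochs) % restart_period
    cosine = 0.5 * (1 + math.cos(math.pi * t / restart_period))
    return min_lr + (base_lr - min_lr) * cosine

print([round(learning_rate(e), 6) for e in range(0, 35, 5)])
```

A reduced-learning-rate fine-tuning stage can reuse the same function with a smaller `base_lr`.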

#### Component Training Strategy
- Pre-train backbone on ImageNet for feature extraction foundation
- Freeze-unfreeze strategy for CBAM integration stability
- Joint fine-tuning for STN-TPS coordination optimization
- Progressive component activation for training stability

---

## Comprehensive Evaluation Framework

### Multi-Dimensional Metric Suite

#### Core Detection Metrics
- **mAP@0.5**: General detection performance assessment
- **mAP@0.75**: Precise localization capability evaluation
- **mAP@[0.5:0.95]**: Comprehensive performance across IoU thresholds
- **Precision/Recall Curves**: Threshold sensitivity analysis

#### Density-Specific Evaluation
- **Small Object AP**: Performance on wheat heads with area < 32² pixels (COCO small)
- **Medium Object AP**: Performance on wheat heads with 32² ≤ area < 96² pixels (COCO medium)
- **Large Object AP**: Performance on wheat heads with area ≥ 96² pixels (COCO large)
- **Crowded Scene AP**: Performance on images with >20 objects
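Assigning ground-truth boxes and images to these evaluation groups is straightforward. The sketch assumes pixel `[x1, y1, x2, y2]` boxes at the evaluation resolution, uses box area in place of COCO's annotation area (the two coincide for box-only annotations), and applies the >20-object crowding criterion above. Function names are illustrative.

```python
import numpy as np

def size_bucket_counts(boxes):
    """Count boxes in the COCO small / medium / large area ranges (pixels)."""
    boxes = np.asarray(boxes, dtype=float).reshape(-1, 4)
    area = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return {
        "small": int(np.sum(area < 32 ** 2)),
        "medium": int(np.sum((area >= 32 ** 2) & (area < 96 ** 2))),
        "large": int(np.sum(area >= 96 ** 2)),
    }

def is_crowded(boxes, min_objects=20):
    """Crowded-scene flag: more than `min_objects` annotated heads in the image."""
    return len(boxes) > min_objects
```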

#### Robustness Assessment Metrics
- **Illumination Invariance Score**: Performance across lighting conditions
- **Scale Robustness Measure**: Consistency across size variations
- **Occlusion Handling Capability**: Performance with overlapping objects
- **Background Interference Resistance**: Robustness to complex backgrounds

#### Efficiency and Deployment Metrics
- **Inference Time**: Milliseconds per processed image
- **Frames Per Second**: Real-time application suitability
- **Memory Usage**: GPU and CPU resource requirements
- **Model Complexity**: Parameters and computational requirements

### Benchmark Protocol

#### Comprehensive Test Scenarios
- **Standard Test Set**: Baseline performance evaluation
- **Cross-Field Generalization**: Performance on unseen field conditions
- **Growth Stage Robustness**: Consistency across wheat maturity levels
- **Weather Condition Stress Test**: Performance under challenging conditions

#### Detailed Ablation Studies
- **Component-wise Ablation**: Individual CBAM, STN, TPS contribution analysis
- **Loss Function Ablation**: Impact of different loss formulations
- **Augmentation Strategy Ablation**: Effectiveness of augmentation techniques
- **Architecture Variant Comparison**: Alternative design choices evaluation

#### Competitive Baseline Comparisons
- Standard YOLO variants (YOLOv5, YOLOv8)
- RetinaNet with Feature Pyramid Networks
- Faster R-CNN with ResNet backbone
- EfficientDet architecture variants
- Specialized agricultural detection systems

---

## Deployment Considerations

### Hardware Requirements and Optimization

#### Computational Specifications
- **GPU Requirements**: NVIDIA RTX 3060 or equivalent for real-time inference
- **CPU Specifications**: 8-core processor for preprocessing pipeline support
- **Memory Requirements**: 16GB RAM for batch processing capabilities
- **Storage Solutions**: SSD for fast dataset access and model loading

#### Mobile and Edge Deployment
- **Edge Device Compatibility**: NVIDIA Jetson series optimization
- **Model Quantization**: INT8 quantization for mobile inference acceleration
- **Optimized Inference Engines**: TensorRT and ONNX runtime integration
- **Power Consumption Optimization**: Efficient inference for field deployment

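
INT8 quantization stores each tensor as 8-bit integers plus a floating-point scale. The sketch below shows symmetric per-tensor quantization in NumPy to make the arithmetic concrete. It is not how TensorRT or PyTorch implement it; those toolchains add calibration data, per-channel scales, and integer compute kernels.

```python
import numpy as np


def quantize_int8(x: np.ndarray):
    """Symmetric per-tensor INT8 quantization, so that x is approximately scale * q."""
    max_abs = float(np.max(np.abs(x)))
    # Map the largest magnitude to 127; guard against an all-zero tensor
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(x / scale), -127, 127).astype(np.int8)
    return q, scale


def dequantize(q: np.ndarray, scale: float) -> np.ndarray:
    """Recover an approximate float32 tensor from INT8 values and the scale."""
    return q.astype(np.float32) * scale
```

The quantized weights occupy a quarter of the float32 storage, and the round-trip error per element is bounded by half the scale.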
#### Field Equipment Integration
- **Ruggedized Camera Systems**: Weather-resistant imaging equipment
- **Robust Computing Units**: Computing hardware built for agricultural field environments
- **Reliable Connectivity**: Field-to-cloud data transmission systems
- **Power Management**: Sustainable power solutions for remote deployment

### Software Optimization Strategies

#### Model Compression Techniques
- **Pruning**: Removing redundant parameters while maintaining accuracy
- **Quantization**: Faster inference through reduced numerical precision
- **Knowledge Distillation**: Training a compact student model to mimic a larger teacher model
- **Dynamic Inference**: Adaptive computation based on scene complexity
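
Knowledge distillation typically adds a loss term that pulls the student's temperature-softened output distribution toward the teacher's; the classic formulation scales the KL divergence by $T^2$ so its gradients stay comparable across temperatures. The NumPy sketch below computes that term for a batch of logits. Applying it to a detector (for instance, to per-anchor objectness treated as a two-way [background, wheat] distribution for the single wheat class) is an assumption here, and the temperature of 4 is an illustrative default.

```python
import numpy as np


def softmax(logits: np.ndarray, T: float = 1.0) -> np.ndarray:
    """Temperature-scaled softmax over the last axis."""
    z = logits / T
    z = z - z.max(axis=-1, keepdims=True)  # subtract the max for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)


def distillation_loss(student_logits: np.ndarray, teacher_logits: np.ndarray, T: float = 4.0) -> float:
    """T^2 * KL(teacher_T || student_T), averaged over samples."""
    p_t = softmax(teacher_logits, T)
    log_p_t = np.log(p_t + 1e-12)
    log_p_s = np.log(softmax(student_logits, T) + 1e-12)
    kl = np.sum(p_t * (log_p_t - log_p_s), axis=-1)
    return float(T * T * kl.mean())
```

In training, this term is usually combined with the ordinary detection loss on ground-truth labels, with a weighting factor chosen on validation data.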

#### Runtime Performance Optimization
- **Batch Processing**: Processing multiple images per forward pass to increase throughput
- **Pipeline Parallelization**: Overlapping preprocessing with inference
- **Memory Management**: Efficient GPU memory utilization
- **Kernel Optimization**: Custom CUDA kernels for performance-critical operations
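
Batch processing and pipeline parallelization can be combined by preprocessing the next batch on worker threads while the current batch runs through the model. The sketch below uses only the standard library; `preprocess` and `infer` are hypothetical callables (for example, image decoding plus resizing, and a batched forward pass). Threads help only when preprocessing releases the GIL, as many OpenCV and NumPy operations do; in a PyTorch training loop, a `DataLoader` with `num_workers > 0` serves the same purpose.

```python
from concurrent.futures import ThreadPoolExecutor
from typing import Callable, Iterator, List, Sequence


def batched(items: Sequence, batch_size: int) -> Iterator[list]:
    """Split a sequence into consecutive batches of at most `batch_size` items."""
    for start in range(0, len(items), batch_size):
        yield list(items[start:start + batch_size])


def run_pipeline(items: Sequence, preprocess: Callable, infer: Callable[[list], list],
                 batch_size: int = 8, workers: int = 4) -> List:
    """Preprocess batch i+1 on worker threads while batch i runs through `infer`."""
    batches = list(batched(items, batch_size))
    outputs: List = []
    if not batches:
        return outputs
    with ThreadPoolExecutor(max_workers=workers) as pool:
        def submit(batch: list) -> list:
            return [pool.submit(preprocess, item) for item in batch]

        pending = submit(batches[0])
        for index in range(len(batches)):
            # Wait for the current batch; collecting futures in order preserves input order
            current = [future.result() for future in pending]
            if index + 1 < len(batches):
                pending = submit(batches[index + 1])  # prefetch the next batch
            outputs.extend(infer(current))
    return outputs
```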

---

## Implementation Roadmap

### Development Phases

#### Phase 1: Architecture Implementation
- Implement the CBAM-STN-TPS-YOLO architecture with wheat-specific parameters
- Integrate the CBAM, STN, and TPS components and tune their interaction
- Establish baseline performance with standard training
- Verify training stability and convergence

#### Phase 2: Advanced Training Pipeline
- Design an augmentation pipeline for robustness to varying field conditions
- Implement a progressive training strategy that schedules images by wheat head density
- Develop wheat-specific loss functions and evaluation metrics
- Optimize hyperparameters for agricultural detection scenarios
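
A density-aware progressive schedule can start from images with few wheat heads and gradually add denser ones. This sketch assumes a per-image wheat head count (for example, the number of bounding boxes per image); the function name, the number of stages, and the cumulative-pool design are illustrative choices rather than a prescribed method.

```python
import math
from typing import Dict, List


def density_curriculum(head_counts: Dict[str, int], num_stages: int) -> List[List[str]]:
    """Return cumulative training pools ordered from sparsest to densest images.

    Stage k (1-based) uses the lowest-density k/num_stages fraction of the images.
    """
    if num_stages < 1:
        raise ValueError("num_stages must be at least 1")
    # Sort by head count; the image ID breaks ties deterministically
    ordered = sorted(head_counts, key=lambda image_id: (head_counts[image_id], image_id))
    stages = []
    for k in range(1, num_stages + 1):
        cutoff = math.ceil(len(ordered) * k / num_stages)
        stages.append(ordered[:cutoff])
    return stages
```

Each stage is a superset of the previous one, so sparse images stay in the pool while dense, heavily clustered ones are introduced later; the final stage covers the full training set.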

#### Phase 3: Evaluation and Validation
- Implement an evaluation framework with wheat-specific metrics
- Conduct ablation studies to validate each component's contribution
- Perform cross-field generalization testing
- Evaluate robustness under diverse environmental conditions

#### Phase 4: Optimization and Deployment
- Optimize the model for the target deployment hardware
- Implement mobile and edge device optimizations
- Develop field deployment protocols and monitoring systems
- Create interfaces and APIs for integration with agricultural systems

#### Phase 5: Field Testing and Refinement
- Conduct real-world field testing across diverse agricultural environments
- Collect performance feedback and identify improvement opportunities
- Iteratively refine the system based on agricultural users' requirements
- Scale deployment across different wheat-growing regions

#### Phase 6: Production and Maintenance
- Deploy the production system with monitoring and maintenance protocols
- Integrate continuous learning to improve the model over time
- Provide training and support systems for agricultural stakeholders
- Monitor and optimize performance over the long term

---

## Key Findings and Recommendations

### Critical Success Factors

#### Technical Requirements
- Multi-scale feature extraction is essential for handling variation in wheat head size
- Attention mechanisms are critical for dense clustering scenarios
- Spatial transformations are necessary to handle variation in field perspective
- Advanced post-processing is required to manage overlapping objects
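
Overlapping wheat heads make post-processing sensitive to the suppression threshold. The greedy non-maximum suppression (NMS) below is the standard baseline, written in NumPy for clarity; `torchvision.ops.nms(boxes, scores, iou_threshold)` provides the same operation for PyTorch tensors. Boxes are assumed to be in `[x1, y1, x2, y2]` format.

```python
import numpy as np


def iou_one_to_many(box: np.ndarray, boxes: np.ndarray) -> np.ndarray:
    """IoU between one [x1, y1, x2, y2] box and an (N, 4) array of boxes."""
    x1 = np.maximum(box[0], boxes[:, 0])
    y1 = np.maximum(box[1], boxes[:, 1])
    x2 = np.minimum(box[2], boxes[:, 2])
    y2 = np.minimum(box[3], boxes[:, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area = (box[2] - box[0]) * (box[3] - box[1])
    areas = (boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1])
    return inter / (area + areas - inter + 1e-9)


def nms(boxes: np.ndarray, scores: np.ndarray, iou_threshold: float = 0.5) -> list:
    """Greedy NMS: keep the highest-scoring box, drop boxes overlapping it, repeat."""
    order = np.argsort(scores)[::-1]
    keep = []
    while order.size > 0:
        best = order[0]
        keep.append(int(best))
        rest = order[1:]
        ious = iou_one_to_many(boxes[best], boxes[rest])
        order = rest[ious <= iou_threshold]  # survivors overlap the kept box too little
    return keep
```

In dense fields, a low threshold suppresses genuinely adjacent heads, while a high one leaves duplicate detections; alternatives such as Soft-NMS decay overlapping scores instead of discarding boxes outright.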

#### Training Optimization
- Progressive training strategy from simple to complex scenarios
- Extensive augmentation pipeline for environmental robustness
- Density-aware loss functions for imbalanced object distributions
- Multi-scale training approach for consistent performance
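
Multi-scale training commonly resizes each batch to a randomly chosen size that is a multiple of the network stride, so the detector sees wheat heads at varied apparent sizes. The sketch assumes a base resolution of 640 pixels, a 0.5 to 1.5 scale range, and a stride of 32; all three are illustrative defaults, and the helper names are hypothetical.

```python
import math
import random
from typing import List


def multiscale_sizes(base: int = 640, low: float = 0.5, high: float = 1.5, stride: int = 32) -> List[int]:
    """Candidate square input sizes: multiples of `stride` within [low*base, high*base]."""
    smallest = math.ceil(base * low / stride) * stride
    largest = int(base * high // stride) * stride
    return list(range(smallest, largest + 1, stride))


def pick_batch_size(rng: random.Random, sizes: List[int]) -> int:
    """Pick one input size per batch; every image in that batch is resized to it."""
    return rng.choice(sizes)
```

Sampling one size per batch (rather than per image) keeps tensors within a batch the same shape, which batched inference requires.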

#### Deployment Considerations
- Edge device optimization is crucial for field applications
- Real-time inference is required for agricultural monitoring
- Robust performance under challenging environmental conditions
- Integration with existing agricultural workflow systems

### Risk Mitigation Strategies

#### Technical Risks
- **False Positive Control**: Robust background discrimination training
- **False Negative Minimization**: Comprehensive occlusion handling strategies
- **Scale Sensitivity Management**: Multi-scale training and evaluation
- **Environmental Robustness**: Extensive augmentation and field testing
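
Tracking false positives and false negatives requires matching predictions to ground truth. The sketch below uses a common greedy rule: predictions are processed in descending score order, and each claims the unmatched ground-truth box with the highest IoU if that IoU reaches the threshold. The 0.5 threshold, the `[x1, y1, x2, y2]` box format, and the function names are assumptions.

```python
import numpy as np


def box_iou(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    """Pairwise IoU matrix between (N, 4) and (M, 4) arrays of [x1, y1, x2, y2] boxes."""
    x1 = np.maximum(a[:, None, 0], b[None, :, 0])
    y1 = np.maximum(a[:, None, 1], b[None, :, 1])
    x2 = np.minimum(a[:, None, 2], b[None, :, 2])
    y2 = np.minimum(a[:, None, 3], b[None, :, 3])
    inter = np.clip(x2 - x1, 0, None) * np.clip(y2 - y1, 0, None)
    area_a = (a[:, 2] - a[:, 0]) * (a[:, 3] - a[:, 1])
    area_b = (b[:, 2] - b[:, 0]) * (b[:, 3] - b[:, 1])
    return inter / (area_a[:, None] + area_b[None, :] - inter + 1e-9)


def count_errors(pred_boxes: np.ndarray, pred_scores: np.ndarray, gt_boxes: np.ndarray,
                 iou_threshold: float = 0.5) -> dict:
    """Count TP/FP/FN by greedy score-ordered matching to unmatched ground truths."""
    if len(pred_boxes) == 0 or len(gt_boxes) == 0:
        return {"tp": 0, "fp": len(pred_boxes), "fn": len(gt_boxes)}
    ious = box_iou(pred_boxes, gt_boxes)
    matched_gt = np.zeros(len(gt_boxes), dtype=bool)
    tp = 0
    for i in np.argsort(pred_scores)[::-1]:
        # Already-matched ground truths cannot be claimed again
        candidates = np.where(~matched_gt, ious[i], -1.0)
        j = int(np.argmax(candidates))
        if candidates[j] >= iou_threshold:
            matched_gt[j] = True
            tp += 1
    return {"tp": tp, "fp": len(pred_boxes) - tp, "fn": len(gt_boxes) - tp}
```

Duplicate detections of one wheat head show up as false positives, and heads missed in occluded clusters show up as false negatives, so the two counts map directly onto the risks listed above.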

#### Deployment Risks
- **Hardware Compatibility**: Thorough testing across target devices
- **Performance Consistency**: Robust evaluation under diverse conditions
- **Integration Challenges**: Careful API design and documentation
- **Maintenance Requirements**: Automated monitoring and update systems

---

## Conclusion

The analysis of the Global Wheat Head Detection dataset reveals significant challenges and opportunities for optimizing the CBAM-STN-TPS-YOLO architecture. Key achievements include:

### Analysis Completeness
- **Multi-dimensional Assessment**: Distribution, clustering, field conditions, and challenge analysis
- **Pattern Discovery**: Identification of wheat-specific detection patterns and challenges
- **Architecture Alignment**: Detailed justification for CBAM-STN-TPS-YOLO components
- **Implementation Guidance**: Comprehensive training and deployment strategies

### Technical Insights
- Dense wheat head clusters require specialized attention mechanisms
- Scale variation necessitates adaptive, multi-scale feature extraction
- Environmental robustness is critical for successful agricultural deployment
- Real-time performance should be achievable with appropriate optimization

### Implementation Readiness
The analysis provides a solid foundation for proceeding to the CBAM-STN-TPS-YOLO implementation, grounded in an understanding of the dataset characteristics, training strategies, and deployment considerations specific to wheat head detection.

---

*Analysis completed with enhanced pipeline providing comprehensive multi-dimensional wheat detection assessment*