In [1]:
# Optuna Hyperparameter Optimization for Drone Localization RL Agent
import optuna
import torch
import numpy as np
import json
import time
from datetime import datetime
from pathlib import Path
import logging
from typing import Dict, Any

# The trainer below requires DroneLocalizationEnvironment to be in scope: define it in an earlier cell, or uncomment the import on the next line if it lives in a module
# from drone_localization_agent import DroneLocalizationEnvironment, PPODroneAgent, DroneLocalizationNetwork

class OptimizedDroneLocalizationNetwork(torch.nn.Module):
    """Two-branch localization network with tunable architecture.

    A reference ("TIF") image and a query crop are each encoded by their own
    ResNet-18 based encoder into a ``feature_dim`` vector. The crop vector
    queries the TIF vector through multi-head attention; the attended vector
    and the crop vector are concatenated, passed through a small MLP, and fed
    to two linear heads: per-cell location logits and a scalar value.

    Args:
        grid_size: Side length of the square location grid. The location
            head has ``grid_size * grid_size`` outputs.
        feature_dim: Width of the encoder outputs and of the hidden layers.
        num_heads: Number of attention heads; must divide ``feature_dim``.
        dropout_rate: Dropout probability used in the encoders and the MLP.

    Raises:
        ValueError: If ``feature_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, grid_size: int, feature_dim: int, num_heads: int, dropout_rate: float):
        super().__init__()

        # Fail fast, before two pretrained ResNet-18 backbones are built:
        # torch.nn.MultiheadAttention rejects an embedding size that is not
        # a multiple of the head count.
        if num_heads <= 0 or feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.grid_size = grid_size
        self.num_locations = grid_size * grid_size

        # Separate (non weight-sharing) encoders for the two inputs.
        self.tif_encoder = self._create_cnn_encoder(feature_dim, dropout_rate)
        self.crop_encoder = self._create_cnn_encoder(feature_dim, dropout_rate)

        # Cross-attention: crop features are the query, TIF features are
        # key and value (see forward()).
        self.cross_attention = torch.nn.MultiheadAttention(
            feature_dim, num_heads=num_heads, batch_first=True
        )

        # MLP over the concatenated [attended, crop] features.
        self.spatial_reasoning = torch.nn.Sequential(
            torch.nn.Linear(feature_dim * 2, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(feature_dim, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate)
        )

        # Output heads: one logit per grid cell, and one scalar value.
        self.location_head = torch.nn.Linear(feature_dim, self.num_locations)
        self.value_head = torch.nn.Linear(feature_dim, 1)

    def _create_cnn_encoder(self, feature_dim: int, dropout_rate: float):
        """Build one image encoder that outputs a ``feature_dim`` vector.

        The encoder is a torchvision ResNet-18 (ImageNet weights) without its
        last two child modules, followed by adaptive pooling to 8x8, a
        flatten, and a linear projection with ReLU and dropout.
        """
        # Imported here so torchvision is only needed when a network is built.
        import torchvision.models as models
        resnet = models.resnet18(weights='IMAGENET1K_V1')

        encoder = torch.nn.Sequential(
            # Keep the convolutional trunk; drop the final pooling and
            # classification layers.
            *list(resnet.children())[:-2],
            torch.nn.AdaptiveAvgPool2d((8, 8)),
            torch.nn.Flatten(),
            # 512 channels * (8 * 8) pooled positions.
            torch.nn.Linear(512 * 64, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate)
        )

        return encoder

    def forward(self, tif_image: torch.Tensor, crop_image: torch.Tensor):
        """Score every grid cell for the given (TIF, crop) image pair.

        Args:
            tif_image: Batch of reference images for ``tif_encoder``.
            crop_image: Batch of query crops for ``crop_encoder``.

        Returns:
            Tuple ``(location_probs, value, location_logits)`` where
            ``location_probs`` is the softmax of ``location_logits`` over the
            ``num_locations`` axis and ``value`` has one column per sample.
        """
        # Each encoder yields a (batch, feature_dim) tensor.
        tif_features = self.tif_encoder(tif_image)
        crop_features = self.crop_encoder(crop_image)

        # Add a sequence axis of length 1 for the batch_first attention.
        # NOTE(review): with a single key/value token the attention weights
        # are always 1, so this reduces to a learned projection of the TIF
        # features - confirm that is intended.
        tif_attended, _ = self.cross_attention(
            crop_features.unsqueeze(1),
            tif_features.unsqueeze(1),
            tif_features.unsqueeze(1)
        )

        # Concatenate the attended TIF features with the crop features.
        combined_features = torch.cat([
            tif_attended.squeeze(1),
            crop_features
        ], dim=1)

        spatial_features = self.spatial_reasoning(combined_features)

        location_logits = self.location_head(spatial_features)
        location_probs = torch.nn.functional.softmax(location_logits, dim=1)
        value = self.value_head(spatial_features)

        return location_probs, value, location_logits

class OptimizedDroneLocalizationTrainer:
    """Build an environment and agent from one hyperparameter set and train it.

    ``trial_params`` must provide every key read by this class:
    ``grid_size``, ``feature_dim``, ``num_heads``, ``dropout_rate``,
    ``learning_rate``, ``gamma``, ``eps_clip``, ``k_epochs``,
    ``entropy_coef`` and ``value_coef``.
    """

    def __init__(self, trial_params: Dict[str, Any]):
        self.params = trial_params

        # DroneLocalizationEnvironment is not defined in this block; it must
        # already be defined or imported when the trainer is constructed.
        self.env = DroneLocalizationEnvironment(
            tif_image_path=None,
            crops_metadata_path=None,
            grid_size=trial_params['grid_size']
        )

        self.agent = self._create_optimized_agent(trial_params)

        # Metric histories; refreshed by every train_and_evaluate() call.
        self.episode_rewards = []
        self.similarity_scores = []
        self.convergence_episodes = []

    def _create_optimized_agent(self, params: Dict[str, Any]):
        """Create the network, its Adam optimizer and the agent wrapping both.

        The network is placed on CUDA when available, otherwise on the CPU.
        """
        network = OptimizedDroneLocalizationNetwork(
            grid_size=params['grid_size'],
            feature_dim=params['feature_dim'],
            num_heads=params['num_heads'],
            dropout_rate=params['dropout_rate']
        ).to('cuda' if torch.cuda.is_available() else 'cpu')

        optimizer = torch.optim.Adam(network.parameters(), lr=params['learning_rate'])

        class OptimizedAgent:
            """Agent holding the network, optimizer and PPO settings."""

            def __init__(self, network, optimizer, params):
                self.network = network
                self.optimizer = optimizer
                self.grid_size = params['grid_size']
                # The device is taken from a parameter so it always matches
                # wherever the network was moved.
                self.device = network.location_head.weight.device

                # PPO hyperparameters
                self.gamma = params['gamma']
                self.eps_clip = params['eps_clip']
                self.k_epochs = params['k_epochs']
                self.entropy_coef = params['entropy_coef']
                self.value_coef = params['value_coef']

                # Experiences collected since the last policy update.
                self.memory = []

            def select_top3_actions(self, tif_image, crop_image):
                """Return the three most probable grid cells for an image pair.

                Returns:
                    ``(locations, probabilities, logits)``: a list of three
                    ``(grid_x, grid_y)`` tuples, their probabilities, and the
                    squeezed location logits for the sample.
                """
                tif_tensor = self._preprocess_image(tif_image).to(self.device)
                crop_tensor = self._preprocess_image(crop_image).to(self.device)

                with torch.no_grad():
                    location_probs, value, logits = self.network(tif_tensor, crop_tensor)

                    top_probs, top_indices = torch.topk(location_probs.squeeze(), k=3)

                    locations = []
                    probabilities = []

                    for prob_t, index_t in zip(top_probs, top_indices):
                        idx = index_t.item()

                        # The flat index is row-major: row = y, column = x.
                        grid_y = idx // self.grid_size
                        grid_x = idx % self.grid_size

                        locations.append((grid_x, grid_y))
                        probabilities.append(prob_t.item())

                return locations, probabilities, logits.squeeze()

            def _preprocess_image(self, image):
                """Resize to 224x224, scale by 1/255 and return a 1xCxHxW tensor.

                Assumes ``image`` is an HxWxC array with values in 0-255 -
                TODO confirm against the environment.
                """
                import cv2
                resized = cv2.resize(image, (224, 224))
                tensor = torch.from_numpy(resized).float() / 255.0
                tensor = tensor.permute(2, 0, 1).unsqueeze(0)
                return tensor

            def store_experience(self, state, action_logits, reward, value):
                """Append one experience to the update buffer.

                ``value`` may be a one-element tensor or a number. It is
                stored as a plain float so the return/advantage lists built
                in update_policy() contain scalars only, not a mix of (1, 1)
                tensors and floats.
                """
                self.memory.append({
                    'state': state,
                    'action_logits': action_logits.detach().cpu(),
                    'reward': reward,
                    'value': float(value)
                })

            def update_policy(self):
                """Run ``k_epochs`` simplified PPO updates, then clear memory.

                Does nothing while fewer than 16 experiences are buffered.
                """
                if len(self.memory) < 16:  # Minimum batch size
                    return

                # One-step bootstrapped returns and GAE (lambda = 0.95),
                # built back-to-front and reversed once at the end.
                # NOTE(review): every experience comes from its own
                # env.reset() in train_and_evaluate(), so bootstrapping from
                # the next entry's value treats independent episodes as one
                # trajectory - confirm this is intended.
                returns = []
                advantages = []
                gae = 0.0
                next_value = 0.0  # nothing to bootstrap from after the last entry
                for experience in reversed(self.memory):
                    reward = experience['reward']
                    value = experience['value']

                    target = reward + self.gamma * next_value
                    returns.append(target)

                    gae = (target - value) + self.gamma * 0.95 * gae
                    advantages.append(gae)

                    next_value = value
                returns.reverse()
                advantages.reverse()

                returns = torch.tensor(returns, dtype=torch.float32).to(self.device)
                advantages = torch.tensor(advantages, dtype=torch.float32).to(self.device)

                # Normalize advantages
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                for _ in range(self.k_epochs):
                    self.optimizer.zero_grad()

                    for i, experience in enumerate(self.memory):
                        state = experience['state']
                        old_logits = experience['action_logits'].to(self.device)

                        tif_tensor = self._preprocess_image(state['tif_image']).to(self.device)
                        crop_tensor = self._preprocess_image(state['crop_image']).to(self.device)

                        _, value, new_logits = self.network(tif_tensor, crop_tensor)

                        # NOTE(review): the ratio is taken over the whole
                        # location distribution, not over the probability
                        # of the action actually taken - confirm intended.
                        old_probs = torch.nn.functional.softmax(old_logits, dim=0)
                        new_probs = torch.nn.functional.softmax(new_logits.squeeze(), dim=0)

                        ratio = (new_probs + 1e-8) / (old_probs + 1e-8)

                        surr1 = ratio * advantages[i]
                        surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages[i]
                        policy_loss = -torch.min(surr1, surr2).mean()

                        value_loss = torch.nn.functional.mse_loss(value.squeeze(), returns[i])

                        entropy = -torch.sum(new_probs * torch.log(new_probs + 1e-8))

                        sample_loss = (policy_loss +
                                       self.value_coef * value_loss -
                                       self.entropy_coef * entropy)

                        # Backpropagate per sample: gradients accumulate in
                        # .grad exactly as for the summed loss, but only one
                        # sample's autograd graph is alive at a time.
                        sample_loss.backward()

                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                    self.optimizer.step()

                # Clear memory
                self.memory = []

        return OptimizedAgent(network, optimizer, params)

    def train_and_evaluate(self, num_episodes: int = 300, update_frequency: int = 16) -> float:
        """Train the agent and return the scalar objective for Optuna.

        Args:
            num_episodes: Maximum number of episodes to run.
            update_frequency: The policy is updated every this many episodes.

        Returns:
            ``0.4 * mean reward + 0.4 * mean best similarity + 0.2 * max
            confidence``. The means cover the last (up to) 100 episodes; the
            confidence is the highest top-1 probability over 10 fresh resets.
        """
        print(f"   Training with params: {self.params}")

        episode_rewards = []
        similarity_scores = []
        # Expose the histories on the instance for inspection after training.
        self.episode_rewards = episode_rewards
        self.similarity_scores = similarity_scores

        for episode in range(num_episodes):
            tif_image, crop_image, state = self.env.reset()

            locations, probabilities, action_logits = self.agent.select_top3_actions(tif_image, crop_image)

            reward, similarities = self.env.calculate_reward(locations, probabilities)

            # Value estimate for this state (a separate forward pass).
            with torch.no_grad():
                tif_tensor = self.agent._preprocess_image(tif_image).to(self.agent.device)
                crop_tensor = self.agent._preprocess_image(crop_image).to(self.agent.device)
                _, value, _ = self.agent.network(tif_tensor, crop_tensor)

            self.agent.store_experience(state, action_logits, reward, value)

            if (episode + 1) % update_frequency == 0:
                self.agent.update_policy()

            episode_rewards.append(reward)
            similarity_scores.append(max(similarities))

            # Early stopping: compare the last 50 episodes with the 50 before.
            if episode > 100 and episode % 50 == 0:
                recent_avg = np.mean(episode_rewards[-50:])
                older_avg = np.mean(episode_rewards[-100:-50])

                # Require a gain of 1% of the older average's magnitude.
                # abs() keeps the test meaningful when rewards are negative;
                # `older_avg * 1.01` would move the threshold the wrong way.
                if recent_avg - older_avg <= 0.01 * abs(older_avg):
                    print(f"   Early stopping at episode {episode} (no improvement)")
                    break

        # Calculate final objective
        final_episodes = min(100, len(episode_rewards))
        final_reward = np.mean(episode_rewards[-final_episodes:])
        final_similarity = np.mean(similarity_scores[-final_episodes:])
        # Highest top-1 probability over 10 freshly reset samples.
        max_confidence = max(
            max(self.agent.select_top3_actions(*self.env.reset()[:2])[1])
            for _ in range(10)
        )

        # Combined objective (reward + similarity + confidence)
        objective = 0.4 * final_reward + 0.4 * final_similarity + 0.2 * max_confidence

        print(f"   Final objective: {objective:.4f} (reward: {final_reward:.3f}, "
              f"similarity: {final_similarity:.3f}, confidence: {max_confidence:.3f})")

        return objective

def objective(trial):
    """Optuna objective: sample a configuration, train it, return its score.

    Args:
        trial: The Optuna trial used to sample hyperparameters.

    Returns:
        The score from ``OptimizedDroneLocalizationTrainer.train_and_evaluate``,
        or ``0.0`` when building or training the trainer raises.

    Raises:
        optuna.TrialPruned: If the sampled ``feature_dim`` is not divisible
            by the sampled ``num_heads``.
    """
    # Sample hyperparameters
    params = {
        # Grid and architecture
        'grid_size': trial.suggest_int('grid_size', 20, 35),
        'feature_dim': trial.suggest_categorical('feature_dim', [256, 384, 512, 768]),
        'num_heads': trial.suggest_categorical('num_heads', [4, 6, 8, 12]),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),

        # Learning parameters
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
        'gamma': trial.suggest_float('gamma', 0.95, 0.999),

        # PPO parameters
        'eps_clip': trial.suggest_float('eps_clip', 0.1, 0.3),
        'k_epochs': trial.suggest_int('k_epochs', 2, 8),
        'entropy_coef': trial.suggest_float('entropy_coef', 0.001, 0.1, log=True),
        'value_coef': trial.suggest_float('value_coef', 0.1, 1.0),
    }

    # torch.nn.MultiheadAttention needs feature_dim % num_heads == 0, which
    # some sampled pairs (e.g. 256 with 6 or 12 heads) violate. Prune those
    # trials up front instead of letting them fail in the network and be
    # recorded with a made-up score of 0.0. This must stay outside the `try`
    # below: TrialPruned is an Exception and would otherwise be swallowed.
    if params['feature_dim'] % params['num_heads'] != 0:
        print(f"   Skipping trial: feature_dim={params['feature_dim']} is not "
              f"divisible by num_heads={params['num_heads']}")
        raise optuna.TrialPruned()

    try:
        trainer = OptimizedDroneLocalizationTrainer(params)
        objective_score = trainer.train_and_evaluate(num_episodes=300)
        return objective_score

    except Exception as e:
        # Deliberate best-effort: a failing configuration must not abort the
        # whole study, so it is reported and given a fixed score.
        # NOTE(review): 0.0 is only a "poor" score if rewards are normally
        # positive - confirm against the environment's reward range.
        print(f"   Trial failed with error: {e}")
        return 0.0  # Return poor score for failed trials

def run_optimization(n_trials: int = 50, timeout_hours: int = 4):
    """Run an Optuna study over ``objective`` and report the results.

    Prints a summary, writes the best parameters to
    ``optuna_results_<timestamp>.json`` and, when plotting is possible,
    saves ``optuna_plots_<timestamp>_history.png`` and
    ``optuna_plots_<timestamp>_importances.png``.

    Args:
        n_trials: Maximum number of trials.
        timeout_hours: Wall-clock budget in hours.

    Returns:
        The ``optuna.Study``. It is returned even when no trial completed
        or when plotting fails.
    """
    # Setup logging. Only attach a stream handler if the Optuna logger has
    # none, so that calling this function again does not multiply every
    # log line.
    optuna_logger = optuna.logging.get_logger("optuna")
    if not any(isinstance(h, logging.StreamHandler) for h in optuna_logger.handlers):
        optuna_logger.addHandler(logging.StreamHandler())
    optuna.logging.set_verbosity(logging.INFO)

    # Create study
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=50)
    )

    print(f"🔍 Starting Optuna Optimization")
    print(f"   Trials: {n_trials}")
    print(f"   Timeout: {timeout_hours} hours")
    print(f"   Start time: {datetime.now()}")

    # Run optimization
    start_time = time.time()
    study.optimize(
        objective,
        n_trials=n_trials,
        timeout=timeout_hours * 3600,
        show_progress_bar=True
    )
    duration_hours = (time.time() - start_time) / 3600

    # One timestamp shared by every file written below.
    stamp = datetime.now().strftime('%Y%m%d_%H%M%S')

    # Results
    print(f"\n🎉 Optimization Complete!")
    print(f"   Duration: {duration_hours:.2f} hours")
    print(f"   Trials completed: {len(study.trials)}")

    # study.best_value / best_params raise when no trial finished (all
    # pruned, or the timeout hit first), so check before using them.
    has_completed_trial = any(
        t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
    )
    if not has_completed_trial:
        print("   No trial completed successfully; nothing to report.")
        return study

    print(f"   Best value: {study.best_value:.4f}")

    print(f"\n🏆 Best Parameters:")
    for key, value in study.best_params.items():
        print(f"   {key}: {value}")

    # Save results
    results = {
        'best_params': study.best_params,
        'best_value': study.best_value,
        'n_trials': len(study.trials),
        'duration_hours': duration_hours,
        'timestamp': datetime.now().isoformat()
    }

    results_file = f"optuna_results_{stamp}.json"
    with open(results_file, 'w', encoding='utf-8') as f:
        json.dump(results, f, indent=2)

    print(f"   💾 Results saved to: {results_file}")

    # Plot optimization history (if possible). Plotting is best-effort and
    # must never prevent the study from being returned.
    try:
        import matplotlib.pyplot as plt
        from optuna.visualization import matplotlib as optuna_mpl

        # The Optuna matplotlib helpers create and return their own Axes;
        # they do not accept an `ax` argument, so each plot gets its own
        # figure and file.
        history_ax = optuna_mpl.plot_optimization_history(study)
        history_ax.set_title('Optimization History')
        history_ax.figure.tight_layout()
        history_ax.figure.savefig(f"optuna_plots_{stamp}_history.png", dpi=300)

        importance_ax = optuna_mpl.plot_param_importances(study)
        importance_ax.set_title('Parameter Importances')
        importance_ax.figure.tight_layout()
        importance_ax.figure.savefig(f"optuna_plots_{stamp}_importances.png", dpi=300)

        plt.show()

    except ImportError:
        print("   📊 Install matplotlib and plotly for visualization")
    except Exception as e:
        # e.g. too few completed trials for the importance evaluation.
        print(f"   📊 Skipping plots: {e}")

    return study

def create_optimized_trainer(best_params: Dict[str, Any]):
    """Print ``best_params`` and return a trainer configured with them.

    Args:
        best_params: Hyperparameter mapping, e.g. ``study.best_params``.

    Returns:
        A new ``OptimizedDroneLocalizationTrainer`` built from ``best_params``.
    """
    print(f"🚀 Creating optimized trainer with best parameters:")
    for name in best_params:
        setting = best_params[name]
        print(f"   {name}: {setting}")

    trainer = OptimizedDroneLocalizationTrainer(best_params)
    return trainer

def quick_optimization_overnight():
    """Run the preset overnight study: at most 50 trials or 8 hours.

    Returns:
        The study returned by ``run_optimization``.
    """
    intro_lines = (
        "🌙 Starting Overnight Optimization...",
        "   This will run for 8 hours or 50 trials (whichever comes first)",
        "   Go to sleep! Results will be ready in the morning 😴",
    )
    for line in intro_lines:
        print(line)

    overnight_study = run_optimization(n_trials=50, timeout_hours=8)

    print(f"\n☀️ Good morning! Optimization complete.")
    print(f"   Use these parameters for your next training:")
    print(f"   {overnight_study.best_params}")

    return overnight_study

if __name__ == "__main__":
    # Usage banner only: nothing is optimized until one of the functions
    # named below is called. The "WHAT IT OPTIMIZES" list mirrors the
    # parameters sampled in objective(); no reward-function weights are
    # sampled there, so they are not advertised.
    print("🔍 OPTUNA HYPERPARAMETER OPTIMIZATION")
    print("="*50)
    print()
    print("🎯 WHAT IT OPTIMIZES:")
    print("   • Grid size (20-35)")
    print("   • Network architecture (feature_dim, num_heads, dropout)")
    print("   • Learning rate (1e-5 to 1e-2)")
    print("   • PPO hyperparameters (eps_clip, k_epochs, entropy)")
    print("   • Discount factor and value-loss weight (gamma, value_coef)")
    print()
    print("⏰ OVERNIGHT MODE:")
    print("   # Perfect for while you sleep!")
    print("   study = quick_optimization_overnight()")
    print()
    print("🔬 MANUAL MODE:")
    print("   # Custom trials and timeout")
    print("   study = run_optimization(n_trials=100, timeout_hours=12)")
    print()
    print("🚀 DEPLOY BEST:")
    print("   # Use best parameters found")
    print("   trainer = create_optimized_trainer(study.best_params)")
    print("   trainer.train_and_evaluate(num_episodes=1000)")
    print()
    print("😴 Sweet dreams! Wake up to optimized hyperparameters!")

  from .autonotebook import tqdm as notebook_tqdm


🔍 OPTUNA HYPERPARAMETER OPTIMIZATION

🎯 WHAT IT OPTIMIZES:
   • Grid size (20-35)
   • Network architecture (feature_dim, num_heads, dropout)
   • Learning rate (1e-5 to 1e-2)
   • PPO hyperparameters (eps_clip, k_epochs, entropy)
   • Reward function weights

⏰ OVERNIGHT MODE:
   # Perfect for while you sleep!
   study = quick_optimization_overnight()

🔬 MANUAL MODE:
   # Custom trials and timeout
   study = run_optimization(n_trials=100, timeout_hours=12)

🚀 DEPLOY BEST:
   # Use best parameters found
   trainer = create_optimized_trainer(study.best_params)
   trainer.train_and_evaluate(num_episodes=1000)

😴 Sweet dreams! Wake up to optimized hyperparameters!


# Memory-safe optimization

In [2]:
# Optuna Hyperparameter Optimization for Drone Localization RL Agent
import optuna
import torch
import numpy as np
import json
import time
from datetime import datetime
from pathlib import Path
import logging
from typing import Dict, Any

# The trainer below requires DroneLocalizationEnvironment to be in scope: define it in an earlier cell, or uncomment the import on the next line if it lives in a module
# from drone_localization_agent import DroneLocalizationEnvironment, PPODroneAgent, DroneLocalizationNetwork

class OptimizedDroneLocalizationNetwork(torch.nn.Module):
    """Two-branch localization network with tunable architecture.

    A reference ("TIF") image and a query crop are each encoded by their own
    ResNet-18 based encoder into a ``feature_dim`` vector. The crop vector
    queries the TIF vector through multi-head attention; the attended vector
    and the crop vector are concatenated, passed through a small MLP, and fed
    to two linear heads: per-cell location logits and a scalar value.

    Args:
        grid_size: Side length of the square location grid. The location
            head has ``grid_size * grid_size`` outputs.
        feature_dim: Width of the encoder outputs and of the hidden layers.
        num_heads: Number of attention heads; must divide ``feature_dim``.
        dropout_rate: Dropout probability used in the encoders and the MLP.

    Raises:
        ValueError: If ``feature_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, grid_size: int, feature_dim: int, num_heads: int, dropout_rate: float):
        super().__init__()

        # Fail fast, before two pretrained ResNet-18 backbones are built:
        # torch.nn.MultiheadAttention rejects an embedding size that is not
        # a multiple of the head count.
        if num_heads <= 0 or feature_dim % num_heads != 0:
            raise ValueError(
                f"feature_dim ({feature_dim}) must be divisible by "
                f"num_heads ({num_heads})"
            )

        self.grid_size = grid_size
        self.num_locations = grid_size * grid_size

        # Separate (non weight-sharing) encoders for the two inputs.
        self.tif_encoder = self._create_cnn_encoder(feature_dim, dropout_rate)
        self.crop_encoder = self._create_cnn_encoder(feature_dim, dropout_rate)

        # Cross-attention: crop features are the query, TIF features are
        # key and value (see forward()).
        self.cross_attention = torch.nn.MultiheadAttention(
            feature_dim, num_heads=num_heads, batch_first=True
        )

        # MLP over the concatenated [attended, crop] features.
        self.spatial_reasoning = torch.nn.Sequential(
            torch.nn.Linear(feature_dim * 2, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate),
            torch.nn.Linear(feature_dim, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate)
        )

        # Output heads: one logit per grid cell, and one scalar value.
        self.location_head = torch.nn.Linear(feature_dim, self.num_locations)
        self.value_head = torch.nn.Linear(feature_dim, 1)

    def _create_cnn_encoder(self, feature_dim: int, dropout_rate: float):
        """Build one image encoder that outputs a ``feature_dim`` vector.

        The encoder is a torchvision ResNet-18 (ImageNet weights) without its
        last two child modules, followed by adaptive pooling to 8x8, a
        flatten, and a linear projection with ReLU and dropout.
        """
        # Imported here so torchvision is only needed when a network is built.
        import torchvision.models as models
        resnet = models.resnet18(weights='IMAGENET1K_V1')

        encoder = torch.nn.Sequential(
            # Keep the convolutional trunk; drop the final pooling and
            # classification layers.
            *list(resnet.children())[:-2],
            torch.nn.AdaptiveAvgPool2d((8, 8)),
            torch.nn.Flatten(),
            # 512 channels * (8 * 8) pooled positions.
            torch.nn.Linear(512 * 64, feature_dim),
            torch.nn.ReLU(),
            torch.nn.Dropout(dropout_rate)
        )

        return encoder

    def forward(self, tif_image: torch.Tensor, crop_image: torch.Tensor):
        """Score every grid cell for the given (TIF, crop) image pair.

        Args:
            tif_image: Batch of reference images for ``tif_encoder``.
            crop_image: Batch of query crops for ``crop_encoder``.

        Returns:
            Tuple ``(location_probs, value, location_logits)`` where
            ``location_probs`` is the softmax of ``location_logits`` over the
            ``num_locations`` axis and ``value`` has one column per sample.
        """
        # Each encoder yields a (batch, feature_dim) tensor.
        tif_features = self.tif_encoder(tif_image)
        crop_features = self.crop_encoder(crop_image)

        # Add a sequence axis of length 1 for the batch_first attention.
        # NOTE(review): with a single key/value token the attention weights
        # are always 1, so this reduces to a learned projection of the TIF
        # features - confirm that is intended.
        tif_attended, _ = self.cross_attention(
            crop_features.unsqueeze(1),
            tif_features.unsqueeze(1),
            tif_features.unsqueeze(1)
        )

        # Concatenate the attended TIF features with the crop features.
        combined_features = torch.cat([
            tif_attended.squeeze(1),
            crop_features
        ], dim=1)

        spatial_features = self.spatial_reasoning(combined_features)

        location_logits = self.location_head(spatial_features)
        location_probs = torch.nn.functional.softmax(location_logits, dim=1)
        value = self.value_head(spatial_features)

        return location_probs, value, location_logits

class OptimizedDroneLocalizationTrainer:
    """Build an environment and agent from one hyperparameter set and train it.

    ``trial_params`` must provide every key read by this class:
    ``grid_size``, ``feature_dim``, ``num_heads``, ``dropout_rate``,
    ``learning_rate``, ``gamma``, ``eps_clip``, ``k_epochs``,
    ``entropy_coef`` and ``value_coef``.
    """

    def __init__(self, trial_params: Dict[str, Any]):
        self.params = trial_params

        # DroneLocalizationEnvironment is not defined in this block; it must
        # already be defined or imported when the trainer is constructed.
        self.env = DroneLocalizationEnvironment(
            tif_image_path=None,
            crops_metadata_path=None,
            grid_size=trial_params['grid_size']
        )

        self.agent = self._create_optimized_agent(trial_params)

        # Metric histories; refreshed by every train_and_evaluate() call.
        self.episode_rewards = []
        self.similarity_scores = []
        self.convergence_episodes = []

    def _create_optimized_agent(self, params: Dict[str, Any]):
        """Create the network, its Adam optimizer and the agent wrapping both.

        The network is placed on CUDA when available, otherwise on the CPU.
        """
        network = OptimizedDroneLocalizationNetwork(
            grid_size=params['grid_size'],
            feature_dim=params['feature_dim'],
            num_heads=params['num_heads'],
            dropout_rate=params['dropout_rate']
        ).to('cuda' if torch.cuda.is_available() else 'cpu')

        optimizer = torch.optim.Adam(network.parameters(), lr=params['learning_rate'])

        class OptimizedAgent:
            """Agent holding the network, optimizer and PPO settings."""

            def __init__(self, network, optimizer, params):
                self.network = network
                self.optimizer = optimizer
                self.grid_size = params['grid_size']
                # The device is taken from a parameter so it always matches
                # wherever the network was moved.
                self.device = network.location_head.weight.device

                # PPO hyperparameters
                self.gamma = params['gamma']
                self.eps_clip = params['eps_clip']
                self.k_epochs = params['k_epochs']
                self.entropy_coef = params['entropy_coef']
                self.value_coef = params['value_coef']

                # Experiences collected since the last policy update.
                self.memory = []

            def select_top3_actions(self, tif_image, crop_image):
                """Return the three most probable grid cells for an image pair.

                Returns:
                    ``(locations, probabilities, logits)``: a list of three
                    ``(grid_x, grid_y)`` tuples, their probabilities, and the
                    squeezed location logits for the sample.
                """
                tif_tensor = self._preprocess_image(tif_image).to(self.device)
                crop_tensor = self._preprocess_image(crop_image).to(self.device)

                with torch.no_grad():
                    location_probs, value, logits = self.network(tif_tensor, crop_tensor)

                    top_probs, top_indices = torch.topk(location_probs.squeeze(), k=3)

                    locations = []
                    probabilities = []

                    for prob_t, index_t in zip(top_probs, top_indices):
                        idx = index_t.item()

                        # The flat index is row-major: row = y, column = x.
                        grid_y = idx // self.grid_size
                        grid_x = idx % self.grid_size

                        locations.append((grid_x, grid_y))
                        probabilities.append(prob_t.item())

                return locations, probabilities, logits.squeeze()

            def _preprocess_image(self, image):
                """Resize to 224x224, scale by 1/255 and return a 1xCxHxW tensor.

                Assumes ``image`` is an HxWxC array with values in 0-255 -
                TODO confirm against the environment.
                """
                import cv2
                resized = cv2.resize(image, (224, 224))
                tensor = torch.from_numpy(resized).float() / 255.0
                tensor = tensor.permute(2, 0, 1).unsqueeze(0)
                return tensor

            def store_experience(self, state, action_logits, reward, value):
                """Append one experience to the update buffer.

                ``value`` may be a one-element tensor or a number. It is
                stored as a plain float so the return/advantage lists built
                in update_policy() contain scalars only, not a mix of (1, 1)
                tensors and floats.
                """
                self.memory.append({
                    'state': state,
                    'action_logits': action_logits.detach().cpu(),
                    'reward': reward,
                    'value': float(value)
                })

            def update_policy(self):
                """Run ``k_epochs`` simplified PPO updates, then clear memory.

                Does nothing while fewer than 16 experiences are buffered.
                """
                if len(self.memory) < 16:  # Minimum batch size
                    return

                # One-step bootstrapped returns and GAE (lambda = 0.95),
                # built back-to-front and reversed once at the end.
                # NOTE(review): every experience comes from its own
                # env.reset() in train_and_evaluate(), so bootstrapping from
                # the next entry's value treats independent episodes as one
                # trajectory - confirm this is intended.
                returns = []
                advantages = []
                gae = 0.0
                next_value = 0.0  # nothing to bootstrap from after the last entry
                for experience in reversed(self.memory):
                    reward = experience['reward']
                    value = experience['value']

                    target = reward + self.gamma * next_value
                    returns.append(target)

                    gae = (target - value) + self.gamma * 0.95 * gae
                    advantages.append(gae)

                    next_value = value
                returns.reverse()
                advantages.reverse()

                returns = torch.tensor(returns, dtype=torch.float32).to(self.device)
                advantages = torch.tensor(advantages, dtype=torch.float32).to(self.device)

                # Normalize advantages
                advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                for _ in range(self.k_epochs):
                    self.optimizer.zero_grad()

                    for i, experience in enumerate(self.memory):
                        state = experience['state']
                        old_logits = experience['action_logits'].to(self.device)

                        tif_tensor = self._preprocess_image(state['tif_image']).to(self.device)
                        crop_tensor = self._preprocess_image(state['crop_image']).to(self.device)

                        _, value, new_logits = self.network(tif_tensor, crop_tensor)

                        # NOTE(review): the ratio is taken over the whole
                        # location distribution, not over the probability
                        # of the action actually taken - confirm intended.
                        old_probs = torch.nn.functional.softmax(old_logits, dim=0)
                        new_probs = torch.nn.functional.softmax(new_logits.squeeze(), dim=0)

                        ratio = (new_probs + 1e-8) / (old_probs + 1e-8)

                        surr1 = ratio * advantages[i]
                        surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages[i]
                        policy_loss = -torch.min(surr1, surr2).mean()

                        value_loss = torch.nn.functional.mse_loss(value.squeeze(), returns[i])

                        entropy = -torch.sum(new_probs * torch.log(new_probs + 1e-8))

                        sample_loss = (policy_loss +
                                       self.value_coef * value_loss -
                                       self.entropy_coef * entropy)

                        # Backpropagate per sample: gradients accumulate in
                        # .grad exactly as for the summed loss, but only one
                        # sample's autograd graph is alive at a time.
                        sample_loss.backward()

                    torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                    self.optimizer.step()

                # Clear memory
                self.memory = []

        return OptimizedAgent(network, optimizer, params)

    def train_and_evaluate(self, num_episodes: int = 300, update_frequency: int = 16) -> float:
        """Train the agent and return the scalar objective for Optuna.

        Args:
            num_episodes: Maximum number of episodes to run.
            update_frequency: The policy is updated every this many episodes.

        Returns:
            ``0.4 * mean reward + 0.4 * mean best similarity + 0.2 * max
            confidence``. The means cover the last (up to) 100 episodes; the
            confidence is the highest top-1 probability over 10 fresh resets.
        """
        print(f"   Training with params: {self.params}")

        episode_rewards = []
        similarity_scores = []
        # Expose the histories on the instance for inspection after training.
        self.episode_rewards = episode_rewards
        self.similarity_scores = similarity_scores

        for episode in range(num_episodes):
            tif_image, crop_image, state = self.env.reset()

            locations, probabilities, action_logits = self.agent.select_top3_actions(tif_image, crop_image)

            reward, similarities = self.env.calculate_reward(locations, probabilities)

            # Value estimate for this state (a separate forward pass).
            with torch.no_grad():
                tif_tensor = self.agent._preprocess_image(tif_image).to(self.agent.device)
                crop_tensor = self.agent._preprocess_image(crop_image).to(self.agent.device)
                _, value, _ = self.agent.network(tif_tensor, crop_tensor)

            self.agent.store_experience(state, action_logits, reward, value)

            if (episode + 1) % update_frequency == 0:
                self.agent.update_policy()

            episode_rewards.append(reward)
            similarity_scores.append(max(similarities))

            # Early stopping: compare the last 50 episodes with the 50 before.
            if episode > 100 and episode % 50 == 0:
                recent_avg = np.mean(episode_rewards[-50:])
                older_avg = np.mean(episode_rewards[-100:-50])

                # Require a gain of 1% of the older average's magnitude.
                # abs() keeps the test meaningful when rewards are negative;
                # `older_avg * 1.01` would move the threshold the wrong way.
                if recent_avg - older_avg <= 0.01 * abs(older_avg):
                    print(f"   Early stopping at episode {episode} (no improvement)")
                    break

        # Calculate final objective
        final_episodes = min(100, len(episode_rewards))
        final_reward = np.mean(episode_rewards[-final_episodes:])
        final_similarity = np.mean(similarity_scores[-final_episodes:])
        # Highest top-1 probability over 10 freshly reset samples.
        max_confidence = max(
            max(self.agent.select_top3_actions(*self.env.reset()[:2])[1])
            for _ in range(10)
        )

        # Combined objective (reward + similarity + confidence)
        objective = 0.4 * final_reward + 0.4 * final_similarity + 0.2 * max_confidence

        print(f"   Final objective: {objective:.4f} (reward: {final_reward:.3f}, "
              f"similarity: {final_similarity:.3f}, confidence: {max_confidence:.3f})")

        return objective

def objective(trial):
    """Optuna objective: sample one configuration, train it and return its score.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The value returned by ``train_and_evaluate`` for the sampled
        configuration, or ``0.0`` when building or training the trainer raises.

    Raises:
        optuna.TrialPruned: If the sampled ``feature_dim`` is not divisible by
            the sampled ``num_heads``.
    """

    # Sample hyperparameters
    params = {
        # Grid and architecture
        'grid_size': trial.suggest_int('grid_size', 20, 35),
        'feature_dim': trial.suggest_categorical('feature_dim', [256, 384, 512, 768]),
        'num_heads': trial.suggest_categorical('num_heads', [4, 6, 8, 12]),
        'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.5),

        # Learning parameters
        'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
        'gamma': trial.suggest_float('gamma', 0.95, 0.999),

        # PPO parameters
        'eps_clip': trial.suggest_float('eps_clip', 0.1, 0.3),
        'k_epochs': trial.suggest_int('k_epochs', 2, 8),
        'entropy_coef': trial.suggest_float('entropy_coef', 0.001, 0.1, log=True),
        'value_coef': trial.suggest_float('value_coef', 0.1, 1.0),
    }

    # torch.nn.MultiheadAttention rejects an embedding size that is not a
    # multiple of the head count (e.g. 256 with 6 heads). Prune such samples up
    # front instead of building the model only to record a misleading 0.0.
    if params['feature_dim'] % params['num_heads'] != 0:
        print(f"   Pruning trial: feature_dim={params['feature_dim']} is not "
              f"divisible by num_heads={params['num_heads']}")
        raise optuna.TrialPruned("feature_dim must be divisible by num_heads")

    try:
        # Create trainer with sampled parameters
        trainer = OptimizedDroneLocalizationTrainer(params)

        # Train and evaluate
        return trainer.train_and_evaluate(num_episodes=300)

    except Exception as e:
        # Deliberate best-effort: a failed trial is scored poorly so the study
        # keeps going instead of aborting.
        print(f"   Trial failed with error: {e}")
        return 0.0  # Return poor score for failed trials

def run_optimization(n_trials: int = 50, timeout_hours: int = 8):
    """Run the Optuna study, print a summary and save results and plots.

    The summary is written to ``optuna_results_<timestamp>.json``. When
    matplotlib is available, the optimization history and the parameter
    importances are saved as two PNG files sharing the same timestamp.

    Args:
        n_trials: Maximum number of trials to run.
        timeout_hours: Wall-clock budget for the study, in hours.

    Returns:
        The finished ``optuna.Study``.
    """

    # Setup logging
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler())
    optuna.logging.set_verbosity(logging.INFO)

    # Create study
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(n_startup_trials=5, n_warmup_steps=50)
    )

    print(f"🔍 Starting Optuna Optimization")
    print(f"   Trials: {n_trials}")
    print(f"   Timeout: {timeout_hours} hours")
    print(f"   Start time: {datetime.now()}")

    # Run optimization
    start_time = time.time()
    study.optimize(
        objective,
        n_trials=n_trials,
        timeout=timeout_hours * 3600,
        show_progress_bar=True
    )
    duration_hours = (time.time() - start_time) / 3600

    # One timestamp for every artefact of this run so the files pair up.
    finished_at = datetime.now()
    file_stamp = finished_at.strftime('%Y%m%d_%H%M%S')

    # study.best_value / study.best_params raise ValueError when no trial
    # finished, so only read them if at least one trial completed.
    has_completed = any(
        t.state == optuna.trial.TrialState.COMPLETE for t in study.trials
    )
    best_value = study.best_value if has_completed else None
    best_params = study.best_params if has_completed else {}

    # Results
    print(f"\n🎉 Optimization Complete!")
    print(f"   Duration: {duration_hours:.2f} hours")
    print(f"   Trials completed: {len(study.trials)}")
    if has_completed:
        print(f"   Best value: {best_value:.4f}")
    else:
        print(f"   Best value: n/a (no trial completed)")

    print(f"\n🏆 Best Parameters:")
    for key, value in best_params.items():
        print(f"   {key}: {value}")

    # Save results
    results = {
        'best_params': best_params,
        'best_value': best_value,
        'n_trials': len(study.trials),
        'duration_hours': duration_hours,
        'timestamp': finished_at.isoformat()
    }

    results_file = f"optuna_results_{file_stamp}.json"
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"   💾 Results saved to: {results_file}")

    # Plot optimization history (if possible)
    try:
        import matplotlib.pyplot as plt

        # The matplotlib-backend plot functions accept no ``ax`` argument: each
        # call creates its own figure and returns the Axes it drew on.
        history_ax = optuna.visualization.matplotlib.plot_optimization_history(study)
        history_ax.set_title('Optimization History')
        history_ax.figure.tight_layout()
        history_ax.figure.savefig(f"optuna_plots_{file_stamp}_history.png", dpi=300)

        importance_ax = optuna.visualization.matplotlib.plot_param_importances(study)
        importance_ax.set_title('Parameter Importances')
        importance_ax.figure.tight_layout()
        importance_ax.figure.savefig(f"optuna_plots_{file_stamp}_importances.png", dpi=300)

        plt.show()

    except ImportError:
        print("   📊 Install matplotlib and plotly for visualization")
    except Exception as e:
        # Plotting is best-effort (e.g. too few completed trials for the
        # importance plot); never lose the finished study over it.
        print(f"   📊 Skipping plots: {e}")

    return study

def create_optimized_trainer(best_params: Dict[str, Any]):
    """Echo the given hyperparameters and build a trainer from them.

    Args:
        best_params: Hyperparameter mapping, e.g. ``study.best_params``.

    Returns:
        An ``OptimizedDroneLocalizationTrainer`` constructed with ``best_params``.
    """
    print(f"🚀 Creating optimized trainer with best parameters:")
    for name in best_params:
        print(f"   {name}: {best_params[name]}")

    trainer = OptimizedDroneLocalizationTrainer(best_params)
    return trainer

# Optuna Hyperparameter Optimization for Drone Localization RL Agent
import optuna
import torch
import numpy as np
import json
import time
import gc
import psutil
import os
from datetime import datetime
from pathlib import Path
import logging
from typing import Dict, Any
import traceback

# Memory management utilities
class MemoryManager:
    """Compares CPU/GPU memory usage against fixed budgets and frees what it can.

    All three public methods are best-effort: they catch any exception raised
    while querying or freeing memory, print a warning and carry on.
    """

    def __init__(self, max_gpu_memory_gb: float = 10.0, max_cpu_memory_gb: float = 16.0):
        # Budgets are given in GB but stored and compared in bytes.
        bytes_per_gb = 1024 * 1024 * 1024
        self.max_gpu_memory = max_gpu_memory_gb * bytes_per_gb
        self.max_cpu_memory = max_cpu_memory_gb * bytes_per_gb

    def check_memory_limits(self) -> bool:
        """Return False when the CPU or GPU budget is exceeded, True otherwise.

        A failing check (any exception) also returns True so that the caller
        keeps running when memory cannot be inspected.
        """
        try:
            # NOTE(review): psutil.virtual_memory().used is presumably the
            # system-wide figure rather than this process's usage; confirm.
            used_cpu = psutil.virtual_memory().used
            if used_cpu > self.max_cpu_memory:
                print(f"   ⚠️ CPU memory limit exceeded: {used_cpu / 1024**3:.1f}GB")
                return False

            # The GPU is only queried when CUDA is available.
            if torch.cuda.is_available():
                used_gpu = torch.cuda.memory_allocated()
                if used_gpu > self.max_gpu_memory:
                    print(f"   ⚠️ GPU memory limit exceeded: {used_gpu / 1024**3:.1f}GB")
                    return False

        except Exception as err:
            print(f"   ⚠️ Memory check failed: {err}")
            return True  # Continue if we can't check

        return True

    def cleanup_memory(self):
        """Run the garbage collector and release cached CUDA memory."""
        try:
            # First pass over Python garbage.
            gc.collect()

            # Hand cached CUDA blocks back, then wait for the device.
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
                torch.cuda.synchronize()

            # Second pass, after the CUDA cache was emptied.
            gc.collect()

        except Exception as err:
            print(f"   ⚠️ Memory cleanup failed: {err}")

    def get_memory_info(self) -> Dict[str, float]:
        """Return current usage figures keyed by name.

        Keys: ``cpu_used_gb`` and ``cpu_percent``, plus ``gpu_used_gb`` and
        ``gpu_reserved_gb`` when CUDA is available. Keys whose lookup failed
        are simply missing from the result.
        """
        stats = {}

        try:
            # CPU memory
            vm = psutil.virtual_memory()
            stats['cpu_used_gb'] = vm.used / 1024**3
            stats['cpu_percent'] = vm.percent

            # GPU memory
            if torch.cuda.is_available():
                allocated_bytes = torch.cuda.memory_allocated()
                reserved_bytes = torch.cuda.memory_reserved()
                stats.update(
                    gpu_used_gb=allocated_bytes / 1024**3,
                    gpu_reserved_gb=reserved_bytes / 1024**3,
                )

        except Exception as err:
            print(f"   ⚠️ Memory info failed: {err}")

        return stats

class SafeOptimizedDroneLocalizationTrainer:
    """Memory-safe trainer with proper cleanup"""
    
    def __init__(self, trial_params: Dict[str, Any]):
        """Build the environment and agent for one trial.

        Args:
            trial_params: Hyperparameters for this trial. A shallow copy is
                stored, so adjustments made while building the agent do not
                leak back into the caller's dict.

        Raises:
            MemoryError: If the memory limits are exceeded before or during
                initialization.
            Exception: Anything else raised during initialization is re-raised
                after ``cleanup()`` has run.
        """
        # Copy: agent creation may lower feature_dim / num_heads in place.
        self.params = dict(trial_params)
        self.memory_manager = MemoryManager()
        self.env = None
        self.agent = None

        try:
            # Check memory before initialization
            if not self.memory_manager.check_memory_limits():
                raise MemoryError("Memory limits exceeded before initialization")

            # Initialize with memory monitoring
            self._safe_initialize()

        except Exception:
            self.cleanup()
            # Bare raise keeps the original traceback intact.
            raise
    
    def _safe_initialize(self):
        """Create the environment and the agent, checking memory after each step.

        Raises:
            MemoryError: If ``check_memory_limits`` reports a breach after the
                environment or after the agent has been created.
        """
        grid_size = self.params['grid_size']

        # Initialize environment
        self.env = DroneLocalizationEnvironment(
            tif_image_path=None,
            crops_metadata_path=None,
            grid_size=grid_size
        )

        # Check memory after env creation
        if not self.memory_manager.check_memory_limits():
            raise MemoryError("Memory limits exceeded after environment creation")

        # Initialize agent with memory limits
        self.agent = self._create_memory_safe_agent(self.params)

        # Final memory check
        if not self.memory_manager.check_memory_limits():
            raise MemoryError("Memory limits exceeded after agent creation")
    
    def _create_memory_safe_agent(self, params: Dict[str, Any]):
        """Create agent with memory limitations

        Builds an ``OptimizedDroneLocalizationNetwork`` from ``params``, moves
        it to CUDA when a probe forward pass succeeds (CPU otherwise), creates
        an Adam optimizer and wraps everything in a locally defined
        ``MemorySafeAgent``.

        Args:
            params: Hyperparameter dict. NOTE: when less than 8 GB of RAM is
                available, ``feature_dim`` and ``num_heads`` are lowered in
                this dict in place.

        Returns:
            A ``MemorySafeAgent`` instance.

        Raises:
            Exception: Anything raised while building the network is re-raised
                after ``cleanup_memory()`` has been called.
        """

        # Reduce feature dimensions if memory is tight
        available_memory = psutil.virtual_memory().available / 1024**3
        if available_memory < 8:  # Less than 8GB available
            # In-place edit of the caller's dict: the values trained with can
            # differ from the ones originally passed in.
            params['feature_dim'] = min(params['feature_dim'], 256)
            params['num_heads'] = min(params['num_heads'], 4)

        # Create network with memory monitoring
        try:
            network = OptimizedDroneLocalizationNetwork(
                grid_size=params['grid_size'],
                feature_dim=params['feature_dim'],
                num_heads=params['num_heads'],
                dropout_rate=params['dropout_rate']
            )

            # Move to GPU only if we have enough memory
            device = 'cpu'
            if torch.cuda.is_available():
                try:
                    network = network.to('cuda')
                    device = 'cuda'
                    # Test allocation: one random 1x3x224x224 batch through the
                    # TIF encoder; a CUDA OOM here triggers the CPU fallback.
                    test_tensor = torch.randn(1, 3, 224, 224).to('cuda')
                    _ = network.tif_encoder(test_tensor)
                    del test_tensor
                    torch.cuda.empty_cache()
                except RuntimeError as e:
                    # Only out-of-memory errors fall back; other RuntimeErrors propagate.
                    if "out of memory" in str(e).lower():
                        print(f"   ⚠️ GPU OOM, falling back to CPU")
                        network = network.to('cpu')
                        device = 'cpu'
                        torch.cuda.empty_cache()
                    else:
                        raise e

        except Exception as e:
            self.memory_manager.cleanup_memory()
            raise e

        # Create optimizer
        optimizer = torch.optim.Adam(network.parameters(), lr=params['learning_rate'])

        # Create memory-safe agent
        class MemorySafeAgent:
            """Agent wrapper around the network with a bounded experience buffer."""

            def __init__(self, network, optimizer, params, memory_manager):
                self.network = network
                self.optimizer = optimizer
                self.grid_size = params['grid_size']
                # `device` is not a constructor argument; it is read from the
                # enclosing method's scope.
                self.device = device
                self.memory_manager = memory_manager

                # PPO hyperparameters
                self.gamma = params['gamma']
                self.eps_clip = params['eps_clip']
                self.k_epochs = params['k_epochs']
                self.entropy_coef = params['entropy_coef']
                self.value_coef = params['value_coef']

                # Memory-limited storage
                # For the 20-35 grid sizes sampled in this file this is always 32.
                # NOTE(review): grid_size < 20 makes `grid_size // 20` zero and
                # raises ZeroDivisionError here.
                max_memory_size = min(32, 64 // (params['grid_size'] // 20))
                self.memory = []
                self.max_memory_size = max_memory_size

            def select_top3_actions(self, tif_image, crop_image):
                """Return the three most probable grid cells for a TIF/crop pair.

                Returns:
                    ``(locations, probabilities, logits)`` where ``locations``
                    holds ``(grid_x, grid_y)`` tuples decoded from the flat
                    top-k indices, ``probabilities`` the matching values and
                    ``logits`` the network's third output, squeezed.

                Raises:
                    MemoryError: If the memory check fails before processing.
                    Exception: Any failure is re-raised after a memory cleanup.
                """
                try:
                    # Check memory before processing
                    if not self.memory_manager.check_memory_limits():
                        raise MemoryError("Memory limit exceeded during action selection")

                    # Preprocess images
                    tif_tensor = self._preprocess_image(tif_image).to(self.device)
                    crop_tensor = self._preprocess_image(crop_image).to(self.device)

                    with torch.no_grad():
                        location_probs, value, logits = self.network(tif_tensor, crop_tensor)

                        # Get top 3 predictions
                        top_probs, top_indices = torch.topk(location_probs.squeeze(), k=3)

                        # Convert to grid coordinates
                        locations = []
                        probabilities = []

                        for i in range(3):
                            idx = top_indices[i].item()
                            prob = top_probs[i].item()

                            # Flat index -> row (y) and column (x) of the grid.
                            grid_y = idx // self.grid_size
                            grid_x = idx % self.grid_size

                            locations.append((grid_x, grid_y))
                            probabilities.append(prob)

                    # Clean up tensors
                    del tif_tensor, crop_tensor
                    if self.device == 'cuda':
                        torch.cuda.empty_cache()

                    return locations, probabilities, logits.squeeze()

                except Exception as e:
                    self.memory_manager.cleanup_memory()
                    raise e

            def _preprocess_image(self, image):
                """Resize to 224x224, scale by 1/255 and return a batched float tensor.

                The last axis is moved first and a leading batch axis is added.
                """
                # assumes `image` is an H x W x C array that cv2 can resize — TODO confirm
                import cv2
                resized = cv2.resize(image, (224, 224))
                tensor = torch.from_numpy(resized).float() / 255.0
                tensor = tensor.permute(2, 0, 1).unsqueeze(0)
                return tensor

            def store_experience(self, state, action_logits, reward, value):
                """Append one experience, dropping the oldest when the buffer is full."""
                # Limit memory size to prevent OOM
                if len(self.memory) >= self.max_memory_size:
                    self.memory.pop(0)  # Remove oldest experience

                # Tensors are detached and moved to CPU before being kept.
                self.memory.append({
                    'state': state,
                    'action_logits': action_logits.detach().cpu(),
                    'reward': reward,
                    'value': value.detach().cpu()
                })

            def update_policy(self):
                """Run a simplified PPO update over the stored experiences, then clear them.

                Returns without doing anything when fewer than 8 experiences
                are stored or when the memory check fails. Any exception is
                caught and printed, and the buffer is cleared.
                """
                try:
                    if len(self.memory) < 8:  # Reduced minimum batch size
                        return

                    # Check memory before update
                    if not self.memory_manager.check_memory_limits():
                        print("   ⚠️ Skipping policy update due to memory limits")
                        return

                    # Simplified PPO update with memory management
                    returns = []
                    advantages = []
                    gae = 0

                    # Walk the buffer backwards; the value of the following
                    # buffer entry is used as the bootstrap value.
                    # NOTE(review): train_and_evaluate stores one experience per
                    # env.reset(), so consecutive entries may not be consecutive
                    # steps of one trajectory — confirm this bootstrapping is intended.
                    for i in reversed(range(len(self.memory))):
                        if i == len(self.memory) - 1:
                            next_value = 0
                        else:
                            next_value = self.memory[i + 1]['value']

                        reward = self.memory[i]['reward']
                        value = self.memory[i]['value']

                        returns.insert(0, reward + self.gamma * next_value)

                        delta = reward + self.gamma * next_value - value
                        # 0.95 presumably plays the role of the GAE lambda — TODO confirm
                        gae = delta + self.gamma * 0.95 * gae
                        advantages.insert(0, gae)

                    # Convert to tensors
                    returns = torch.tensor(returns, dtype=torch.float32).to(self.device)
                    advantages = torch.tensor(advantages, dtype=torch.float32).to(self.device)

                    # Normalize advantages
                    advantages = (advantages - advantages.mean()) / (advantages.std() + 1e-8)

                    # Reduced PPO epochs for memory safety
                    # k_epochs values above 3 are capped here.
                    safe_k_epochs = min(self.k_epochs, 3)

                    # PPO update with memory monitoring
                    for epoch in range(safe_k_epochs):
                        total_loss = 0

                        # Process in smaller batches if needed
                        batch_size = min(len(self.memory), 8)

                        for i in range(0, len(self.memory), batch_size):
                            batch_end = min(i + batch_size, len(self.memory))
                            batch_loss = 0

                            for j in range(i, batch_end):
                                experience = self.memory[j]
                                state = experience['state']
                                old_logits = experience['action_logits'].to(self.device)

                                # Forward pass
                                tif_tensor = self._preprocess_image(state['tif_image']).to(self.device)
                                crop_tensor = self._preprocess_image(state['crop_image']).to(self.device)

                                location_probs, value, new_logits = self.network(tif_tensor, crop_tensor)

                                # Policy loss
                                old_probs = torch.nn.functional.softmax(old_logits, dim=0)
                                new_probs = torch.nn.functional.softmax(new_logits.squeeze(), dim=0)

                                # The ratio is taken element-wise over the whole
                                # location distribution; no single chosen action is stored.
                                ratio = (new_probs + 1e-8) / (old_probs + 1e-8)

                                surr1 = ratio * advantages[j]
                                surr2 = torch.clamp(ratio, 1 - self.eps_clip, 1 + self.eps_clip) * advantages[j]
                                policy_loss = -torch.min(surr1, surr2).mean()

                                # Value loss
                                value_loss = torch.nn.functional.mse_loss(value.squeeze(), returns[j])

                                # Entropy loss
                                entropy = -torch.sum(new_probs * torch.log(new_probs + 1e-8))

                                batch_loss += (policy_loss + 
                                             self.value_coef * value_loss - 
                                             self.entropy_coef * entropy)

                                # Clean up tensors
                                del tif_tensor, crop_tensor

                            # Update with batch
                            # batch_loss starts as the int 0 and becomes a tensor
                            # once a loss term has been added.
                            if batch_loss != 0:
                                self.optimizer.zero_grad()
                                batch_loss.backward()
                                torch.nn.utils.clip_grad_norm_(self.network.parameters(), 0.5)
                                self.optimizer.step()

                            # Memory cleanup between batches
                            if self.device == 'cuda':
                                torch.cuda.empty_cache()

                        # Check memory between epochs
                        if not self.memory_manager.check_memory_limits():
                            print(f"   ⚠️ Stopping PPO early at epoch {epoch} due to memory")
                            break

                    # Clear memory after update
                    self.memory = []
                    self.memory_manager.cleanup_memory()

                except Exception as e:
                    print(f"   ⚠️ Policy update failed: {e}")
                    self.memory_manager.cleanup_memory()
                    self.memory = []  # Clear memory on failure

        return MemorySafeAgent(network, optimizer, params, self.memory_manager)
    
    def train_and_evaluate(self, num_episodes: int = 200, update_frequency: int = 16) -> float:
        """Memory-safe training with monitoring"""
        
        try:
            print(f"   🔧 Training with params: {self.params}")
            mem_info = self.memory_manager.get_memory_info()
            print(f"   💾 Memory: CPU {mem_info.get('cpu_used_gb', 0):.1f}GB, "
                  f"GPU {mem_info.get('gpu_used_gb', 0):.1f}GB")
            
            episode_rewards = []
            similarity_scores = []
            memory_failures = 0
            
            for episode in range(num_episodes):
                try:
                    # Check memory before each episode
                    if not self.memory_manager.check_memory_limits():
                        print(f"   ⚠️ Memory limit reached at episode {episode}")
                        break
                    
                    # Reset environment
                    tif_image, crop_image, state = self.env.reset()
                    
                    # Agent prediction
                    locations, probabilities, action_logits = self.agent.select_top3_actions(tif_image, crop_image)
                    
                    # Calculate reward
                    reward, similarities = self.env.calculate_reward(locations, probabilities)
                    
                    # Get value estimate
                    with torch.no_grad():
                        tif_tensor = self.agent._preprocess_image(tif_image).to(self.agent.device)
                        crop_tensor = self.agent._preprocess_image(crop_image).to(self.agent.device)
                        _, value, _ = self.agent.network(tif_tensor, crop_tensor)
                        
                        # Clean up
                        del tif_tensor, crop_tensor
                        if self.agent.device == 'cuda':
                            torch.cuda.empty_cache()
                    
                    # Store experience
                    self.agent.store_experience(state, action_logits, reward, value)
                    
                    # Update policy
                    if (episode + 1) % update_frequency == 0:
                        self.agent.update_policy()
                    
                    # Track metrics
                    episode_rewards.append(reward)
                    similarity_scores.append(max(similarities))
                    
                    # Periodic memory cleanup
                    if episode % 50 == 0:
                        self.memory_manager.cleanup_memory()
                    
                    # Early stopping if not improving
                    if episode > 100 and episode % 50 == 0:
                        recent_avg = np.mean(episode_rewards[-25:])
                        older_avg = np.mean(episode_rewards[-50:-25])
                        
                        if recent_avg <= older_avg * 1.01:
                            print(f"   📈 Early stopping at episode {episode} (no improvement)")
                            break
                    
                except MemoryError as e:
                    memory_failures += 1
                    print(f"   ⚠️ Memory error at episode {episode}: {e}")
                    self.memory_manager.cleanup_memory()
                    
                    if memory_failures > 3:
                        print(f"   ❌ Too many memory failures, stopping trial")
                        break
                    
                except Exception as e:
                    print(f"   ⚠️ Episode {episode} failed: {e}")
                    self.memory_manager.cleanup_memory()
                    continue
            
            # Calculate final objective
            if len(episode_rewards) < 10:
                print(f"   ❌ Trial failed - insufficient episodes completed")
                return 0.0
            
            final_episodes = min(50, len(episode_rewards))
            final_reward = np.mean(episode_rewards[-final_episodes:])
            final_similarity = np.mean(similarity_scores[-final_episodes:])
            
            # Test confidence
            try:
                confidence_scores = []
                for _ in range(5):
                    tif_img, crop_img, _ = self.env.reset()
                    _, probs, _ = self.agent.select_top3_actions(tif_img, crop_img)
                    confidence_scores.extend(probs)
                max_confidence = max(confidence_scores) if confidence_scores else 0.0
            except:
                max_confidence = 0.0
            
            # Combined objective
            objective = 0.4 * final_reward + 0.4 * final_similarity + 0.2 * max_confidence
            
            print(f"   ✅ Final objective: {objective:.4f} (reward: {final_reward:.3f}, "
                  f"similarity: {final_similarity:.3f}, confidence: {max_confidence:.3f})")
            
            return objective
            
        except Exception as e:
            print(f"   ❌ Training failed: {e}")
            return 0.0
        
        finally:
            # Always cleanup
            self.cleanup()
    
    def cleanup(self):
        """Cleanup resources"""
        try:
            if hasattr(self, 'agent') and self.agent:
                if hasattr(self.agent, 'memory'):
                    self.agent.memory = []
                del self.agent
            
            if hasattr(self, 'env') and self.env:
                del self.env
            
            self.memory_manager.cleanup_memory()
            
        except Exception as e:
            print(f"   ⚠️ Cleanup failed: {e}")

def safe_objective(trial):
    """Memory-safe Optuna objective function with proper error handling.

    Samples one configuration, trains it with
    ``SafeOptimizedDroneLocalizationTrainer`` and returns its score. Trainer
    resources are always released in the ``finally`` clause.

    Args:
        trial: Optuna trial used to sample the hyperparameters.

    Returns:
        The objective score of the trial, or ``0.0`` when the trial failed
        with an error that is not memory related.

    Raises:
        optuna.TrialPruned: When ``feature_dim`` is not divisible by
            ``num_heads``, when less than 4 GB of RAM is available, when the
            score is below 0.1, or on a (CUDA) out-of-memory error.
    """

    trainer = None
    try:
        # Sample hyperparameters with memory considerations
        params = {
            # Grid and architecture - conservative limits
            'grid_size': trial.suggest_int('grid_size', 20, 35),
            'feature_dim': trial.suggest_categorical('feature_dim', [256, 384, 512]),  # Removed 768
            'num_heads': trial.suggest_categorical('num_heads', [4, 6, 8]),  # Removed 12
            'dropout_rate': trial.suggest_float('dropout_rate', 0.1, 0.4),

            # Learning parameters
            'learning_rate': trial.suggest_float('learning_rate', 1e-5, 1e-2, log=True),
            'gamma': trial.suggest_float('gamma', 0.95, 0.999),

            # PPO parameters
            'eps_clip': trial.suggest_float('eps_clip', 0.1, 0.3),
            'k_epochs': trial.suggest_int('k_epochs', 2, 6),  # Reduced max
            'entropy_coef': trial.suggest_float('entropy_coef', 0.001, 0.1, log=True),
            'value_coef': trial.suggest_float('value_coef', 0.1, 1.0),
        }

        # torch.nn.MultiheadAttention rejects an embedding size that is not a
        # multiple of the head count (256 or 512 with 6 heads). Prune these
        # before any model is built.
        if params['feature_dim'] % params['num_heads'] != 0:
            print(f"   ⚠️ feature_dim={params['feature_dim']} is not divisible by "
                  f"num_heads={params['num_heads']}, pruning trial")
            raise optuna.TrialPruned("feature_dim must be divisible by num_heads")

        # Check available memory before starting
        available_memory = psutil.virtual_memory().available / 1024**3
        if available_memory < 4:  # Less than 4GB available
            print(f"   ⚠️ Low memory ({available_memory:.1f}GB), pruning trial")
            raise optuna.TrialPruned("Insufficient memory")

        # Create trainer with memory safety
        trainer = SafeOptimizedDroneLocalizationTrainer(params)

        # Train and evaluate with shorter episodes for safety
        objective_score = trainer.train_and_evaluate(num_episodes=200)

        # Prune trials with poor early performance
        if objective_score < 0.1:
            raise optuna.TrialPruned("Poor performance")

        return objective_score

    except optuna.TrialPruned:
        # TrialPruned derives from Exception: without this clause the generic
        # handler below would swallow it and report a completed 0.0 trial.
        raise

    except MemoryError as e:
        print(f"   💥 Memory error: {e}")
        raise optuna.TrialPruned("Out of memory")

    except RuntimeError as e:
        if "out of memory" in str(e).lower():
            print(f"   💥 CUDA OOM: {e}")
            raise optuna.TrialPruned("CUDA out of memory")
        else:
            print(f"   ❌ Runtime error: {e}")
            return 0.0

    except Exception as e:
        # Deliberate best-effort boundary: a broken trial must not abort the study.
        print(f"   ❌ Trial failed: {e}")
        print(f"   Stack trace: {traceback.format_exc()}")
        return 0.0

    finally:
        # Always cleanup
        if trainer:
            trainer.cleanup()
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

def run_safe_optimization(n_trials: int = 30, timeout_hours: int = 8):
    """Run the memory-safe Optuna study, print a summary and save it as JSON.

    The summary is written to ``safe_optuna_results_<timestamp>.json``. When
    no trial completed (for example because every trial was pruned), the best
    value is reported as unavailable instead of raising.

    Args:
        n_trials: Maximum number of trials to run.
        timeout_hours: Wall-clock budget for the study, in hours.

    Returns:
        The finished ``optuna.Study``.
    """

    # Setup logging
    optuna.logging.get_logger("optuna").addHandler(logging.StreamHandler())
    optuna.logging.set_verbosity(logging.INFO)

    # Create study with aggressive pruning
    # NOTE: safe_objective never reports intermediate values, so this pruner
    # does not stop trials by itself; pruning happens through the explicit
    # TrialPruned raises in safe_objective.
    study = optuna.create_study(
        direction='maximize',
        sampler=optuna.samplers.TPESampler(seed=42),
        pruner=optuna.pruners.MedianPruner(
            n_startup_trials=3,  # Start pruning early
            n_warmup_steps=25,   # Prune quickly
            interval_steps=10    # Check frequently
        )
    )

    print(f"🔍 Starting Memory-Safe Optuna Optimization")
    print(f"   Trials: {n_trials}")
    print(f"   Timeout: {timeout_hours} hours")
    print(f"   Memory limits: GPU 10GB, CPU 16GB")
    print(f"   Start time: {datetime.now()}")

    # Check system memory
    system_memory = psutil.virtual_memory()
    total_memory = system_memory.total / 1024**3
    available_memory = system_memory.available / 1024**3
    print(f"   System memory: {available_memory:.1f}GB available / {total_memory:.1f}GB total")

    if available_memory < 8:
        print(f"   ⚠️ Warning: Low available memory, consider closing other applications")

    # Run optimization
    start_time = time.time()
    study.optimize(
        safe_objective,
        n_trials=n_trials,
        timeout=timeout_hours * 3600,
        show_progress_bar=True,
        gc_after_trial=True  # Force garbage collection after each trial
    )
    duration_hours = (time.time() - start_time) / 3600

    # One timestamp for both the JSON payload and its file name.
    finished_at = datetime.now()

    trial_states = [t.state for t in study.trials]
    pruned_count = trial_states.count(optuna.trial.TrialState.PRUNED)

    # study.best_value / study.best_params raise ValueError when no trial
    # completed, so only read them if at least one did.
    has_completed = optuna.trial.TrialState.COMPLETE in trial_states
    best_value = study.best_value if has_completed else None
    best_params = study.best_params if has_completed else {}

    # Results
    print(f"\n🎉 Optimization Complete!")
    print(f"   Duration: {duration_hours:.2f} hours")
    print(f"   Trials completed: {len(study.trials)}")
    print(f"   Trials pruned: {pruned_count}")
    if has_completed:
        print(f"   Best value: {best_value:.4f}")
    else:
        print(f"   Best value: n/a (no trial completed)")

    print(f"\n🏆 Best Parameters:")
    for key, value in best_params.items():
        print(f"   {key}: {value}")

    # Save results
    results = {
        'best_params': best_params,
        'best_value': best_value,
        'n_trials': len(study.trials),
        'n_pruned': pruned_count,
        'duration_hours': duration_hours,
        'timestamp': finished_at.isoformat()
    }

    results_file = f"safe_optuna_results_{finished_at.strftime('%Y%m%d_%H%M%S')}.json"
    with open(results_file, 'w') as f:
        json.dump(results, f, indent=2)

    print(f"   💾 Results saved to: {results_file}")

    return study

def safe_overnight_optimization(n_trials: int = 30, timeout_hours: float = 8):
    """Run the memory-safe optimization with a budget sized for an overnight run.

    Prints a short banner, delegates the actual search to
    ``run_safe_optimization``, then prints the best parameters of the
    returned study.

    Args:
        n_trials: Maximum number of trials, forwarded unchanged to
            ``run_safe_optimization``. Defaults to 30.
        timeout_hours: Time budget in hours, forwarded unchanged to
            ``run_safe_optimization``. Defaults to 8.

    Returns:
        The study object returned by ``run_safe_optimization``.
    """
    print("🌙 Starting Safe Overnight Optimization...")
    print("   Memory-safe with aggressive pruning and cleanup")
    print("   Will automatically handle OOM and other errors")
    print("   Sweet dreams! 😴")

    # The defaults reproduce the previously hard-coded 30 trials / 8 hours.
    study = run_safe_optimization(n_trials=n_trials, timeout_hours=timeout_hours)

    print("\n☀️ Good morning! Safe optimization complete.")
    print(f"   Best parameters found: {study.best_params}")

    return study

if __name__ == "__main__":
    # Running the file directly only prints this usage guide; the block
    # contains no calls other than print. The whole banner is emitted with a
    # single print of newline-joined lines ("" entries are the blank lines).
    print(
        "\n".join(
            (
                "🔍 OPTUNA HYPERPARAMETER OPTIMIZATION",
                "=" * 50,
                "",
                "🎯 WHAT IT OPTIMIZES:",
                "   • Grid size (20-35)",
                "   • Network architecture (feature_dim, num_heads, dropout)",
                "   • Learning rate (1e-5 to 1e-2)",
                "   • PPO hyperparameters (eps_clip, k_epochs, entropy)",
                "   • Reward function weights",
                "",
                "⏰ OVERNIGHT MODE:",
                "   # Perfect for while you sleep!",
                "   study = quick_optimization_overnight()",
                "",
                "🔬 MANUAL MODE:",
                "   # Custom trials and timeout",
                "   study = run_optimization(n_trials=100, timeout_hours=12)",
                "",
                "🚀 DEPLOY BEST:",
                "   # Use best parameters found",
                "   trainer = create_optimized_trainer(study.best_params)",
                "   trainer.train_and_evaluate(num_episodes=1000)",
                "",
                "😴 Sweet dreams! Wake up to optimized hyperparameters!",
            )
        )
    )

🔍 OPTUNA HYPERPARAMETER OPTIMIZATION
==================================================

🎯 WHAT IT OPTIMIZES:
   • Grid size (20-35)
   • Network architecture (feature_dim, num_heads, dropout)
   • Learning rate (1e-5 to 1e-2)
   • PPO hyperparameters (eps_clip, k_epochs, entropy)
   • Reward function weights

⏰ OVERNIGHT MODE:
   # Perfect for while you sleep!
   study = quick_optimization_overnight()

🔬 MANUAL MODE:
   # Custom trials and timeout
   study = run_optimization(n_trials=100, timeout_hours=12)

🚀 DEPLOY BEST:
   # Use best parameters found
   trainer = create_optimized_trainer(study.best_params)
   trainer.train_and_evaluate(num_episodes=1000)

😴 Sweet dreams! Wake up to optimized hyperparameters!
