# In [None]:
"""
BatteryMind - Federated Learning Hyperparameter Tuning

Advanced hyperparameter optimization for federated learning systems in battery 
management applications. This notebook provides comprehensive tuning capabilities 
for federated aggregation algorithms, privacy parameters, and client coordination.

Features:
- Multi-objective optimization for accuracy, privacy, and communication efficiency
- Bayesian optimization with federated learning-specific priors
- Privacy-utility tradeoff analysis
- Communication cost optimization
- Client heterogeneity handling
- Differential privacy parameter tuning
- Aggregation algorithm comparison and optimization

Author: BatteryMind Development Team
Version: 1.0.0
"""

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from typing import Dict, List, Optional, Tuple, Any
import logging
import json
import yaml
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

# Optimization libraries
import optuna
from optuna.pruners import MedianPruner
from optuna.samplers import TPESampler
import ray
from ray import tune
from ray.tune.schedulers import ASHAScheduler
from ray.tune.suggest.optuna import OptunaSearch

# Federated learning libraries
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import flwr as fl
from flwr.common import Parameters, Scalar
from flwr.server.strategy import FedAvg, FedProx, FedOpt

# Privacy libraries
from opacus import PrivacyEngine
from opacus.utils.batch_memory_manager import BatchMemoryManager

# BatteryMind specific imports
import sys
sys.path.append('../../')
from federated_learning.server.federated_server import FederatedServer
from federated_learning.client_models.local_trainer import LocalTrainer
from federated_learning.client_models.client_manager import ClientManager
from federated_learning.privacy_preserving.differential_privacy import DifferentialPrivacyEngine
from federated_learning.server.aggregation_algorithms import FedAvgAggregator, FedProxAggregator
from federated_learning.simulation_framework.federated_simulator import FederatedSimulator
from training_data.synthetic_datasets import generate_battery_telemetry_data
from evaluation.metrics.accuracy_metrics import BatteryHealthMetrics
from utils.config_parser import ConfigParser
from utils.logging_utils import setup_logging

# Setup logging
# Module-level logger obtained from the project's setup_logging helper
# (utils.logging_utils); used throughout FederatedOptimizer below.
logger = setup_logging(__name__)

class FederatedOptimizer:
    """
    Advanced hyperparameter optimizer for federated learning systems.

    Builds a heterogeneous set of synthetic client datasets once, then runs an
    Optuna study whose trials each launch a ``FederatedSimulator`` with a
    sampled configuration. Each outcome is scored with a weighted composite of
    global accuracy, convergence speed, client consistency, communication
    cost, privacy budget and robustness to client dropout.
    """

    # Default output directory for the best model/config; override with
    # ``optimization_config['save_dir']``.
    DEFAULT_SAVE_DIR = "../../model-artifacts/hyperparameter_tuning/federated_best"

    def __init__(self,
                 data_config: Dict[str, Any],
                 base_fed_config: Dict[str, Any],
                 optimization_config: Dict[str, Any]):
        """
        Initialize the federated optimizer.

        Args:
            data_config: Configuration for data distribution. Keys read here:
                ``total_batteries`` (default 1000), ``duration_days``
                (default 30) and ``num_clients`` (default 10).
            base_fed_config: Base federated learning configuration. Each
                trial's sampled parameters are layered on top of it before the
                simulator is built.
            optimization_config: Optimization-specific configuration. Keys
                read: ``timeout_hours`` (default 48; ``None`` disables the
                limit), ``save_dir`` (default ``DEFAULT_SAVE_DIR``), and the
                optional Optuna ``storage`` / ``study_name``.
        """
        self.data_config = data_config
        self.base_fed_config = base_fed_config
        self.optimization_config = optimization_config

        # Initialize federated components
        self.simulator = FederatedSimulator(data_config)
        self.metrics = BatteryHealthMetrics()

        # Generate distributed datasets
        self.client_datasets = self._generate_client_datasets()

        # Optimization tracking
        self.best_params = None
        self.best_score = float('-inf')
        self.optimization_history = []
        # Assigned by optimize_federated_params(); read by save_optimization_results().
        self.study = None

        logger.info("FederatedOptimizer initialized")

    def _generate_client_datasets(self) -> Dict[str, Any]:
        """
        Generate heterogeneous per-client datasets for federated learning.

        One synthetic telemetry frame is generated. Each client is then
        assigned one randomly chosen battery type. Gaussian noise is added to
        that client's ``voltage`` and ``current`` columns, and a random 50-100%
        subsample of its rows is kept. The draws are unseeded (``np.random``
        and ``DataFrame.sample`` without ``random_state``), so the split
        differs between instances.

        Returns:
            Mapping ``'client_<i>'`` -> pandas DataFrame.
        """
        logger.info("Generating distributed datasets for federated optimization...")

        # Generate base dataset
        base_data = generate_battery_telemetry_data(
            num_batteries=self.data_config.get('total_batteries', 1000),
            duration_days=self.data_config.get('duration_days', 30)
        )

        num_clients = self.data_config.get('num_clients', 10)
        battery_types = ['lithium_ion', 'lifepo4', 'nimh']
        client_datasets = {}

        # Create heterogeneous data distribution
        for client_id in range(num_clients):
            # Simulate different battery types per client
            client_battery_type = np.random.choice(battery_types)
            client_data = base_data[base_data['battery_type'] == client_battery_type].copy()
            if client_data.empty:
                # Surface this early; an empty client would otherwise fail deep inside training.
                logger.warning(
                    f"client_{client_id}: no rows for battery type '{client_battery_type}'"
                )

            # Add client-specific noise and bias
            noise_level = np.random.uniform(0.01, 0.05)
            client_data['voltage'] += np.random.normal(0, noise_level, len(client_data))
            client_data['current'] += np.random.normal(0, noise_level * 2, len(client_data))

            # Simulate different data sizes per client
            sample_fraction = np.random.uniform(0.5, 1.0)
            client_data = client_data.sample(frac=sample_fraction).reset_index(drop=True)

            client_datasets[f'client_{client_id}'] = client_data

        return client_datasets

    def _objective_function(self, trial) -> float:
        """
        Objective function for federated learning optimization.

        Samples one federated configuration, runs a simulated federated
        training with it and returns the composite score. A simulation that
        raises is logged with its traceback and scored ``-inf``, so one bad
        configuration does not abort the study.

        Args:
            trial: Optuna trial object

        Returns:
            float: Objective value to maximize

        Raises:
            ValueError: if fewer than two client datasets exist (the
                client-count ranges below start at 2).
        """
        num_clients = len(self.client_datasets)
        if num_clients < 2:
            raise ValueError(
                f"Federated optimization needs at least 2 clients, got {num_clients}"
            )

        # Sample the switches first. noise_multiplier and proximal_term are
        # only searched when their switch is on, and the dict literal below
        # cannot read `params` while it is still being built.
        use_dp = trial.suggest_categorical('use_differential_privacy', [True, False])
        aggregation_method = trial.suggest_categorical(
            'aggregation_method', ['fedavg', 'fedprox', 'fedopt']
        )

        params = {
            # Federated learning parameters
            'num_rounds': trial.suggest_int('num_rounds', 10, 100),
            'clients_per_round': trial.suggest_int('clients_per_round', 2, min(10, num_clients)),
            'local_epochs': trial.suggest_int('local_epochs', 1, 10),
            'local_batch_size': trial.suggest_categorical('local_batch_size', [16, 32, 64, 128]),
            'local_learning_rate': trial.suggest_float('local_learning_rate', 1e-5, 1e-2, log=True),

            # Aggregation parameters
            'aggregation_method': aggregation_method,
            'server_learning_rate': trial.suggest_float('server_learning_rate', 1e-3, 1e-1, log=True),
            'server_momentum': trial.suggest_float('server_momentum', 0.0, 0.9),

            # Privacy parameters
            'use_differential_privacy': use_dp,
            'noise_multiplier': trial.suggest_float('noise_multiplier', 0.1, 2.0) if use_dp else 0.0,
            'max_grad_norm': trial.suggest_float('max_grad_norm', 0.1, 10.0),

            # Communication parameters
            'compression_rate': trial.suggest_float('compression_rate', 0.1, 1.0),
            'quantization_bits': trial.suggest_int('quantization_bits', 2, 32),

            # Client selection parameters
            'client_selection_strategy': trial.suggest_categorical(
                'client_selection_strategy', ['random', 'loss_based', 'contribution_based']
            ),
            'min_available_clients': trial.suggest_int('min_available_clients', 2, min(8, num_clients)),

            # Convergence parameters
            'convergence_threshold': trial.suggest_float('convergence_threshold', 1e-6, 1e-3, log=True),
            'patience': trial.suggest_int('patience', 5, 20),

            # Regularization parameters
            'proximal_term': (
                trial.suggest_float('proximal_term', 0.0, 1.0)
                if aggregation_method == 'fedprox' else 0.0
            ),
            'l2_regularization': trial.suggest_float('l2_regularization', 1e-6, 1e-3, log=True),
        }

        try:
            # Trial values override the base configuration key by key.
            server_config = {**(self.base_fed_config or {}), **params}
            simulator = FederatedSimulator(
                client_datasets=self.client_datasets,
                server_config=server_config
            )

            # Run federated training
            training_results = simulator.run_federated_training(
                num_rounds=params['num_rounds'],
                clients_per_round=params['clients_per_round']
            )

            # Calculate composite score
            composite_score = self._calculate_federated_score(training_results, params)

            # Store trial results
            trial_result = {
                'trial_number': trial.number,
                'params': params,
                'score': composite_score,
                'training_results': training_results,
                'convergence_round': training_results.get('convergence_round', params['num_rounds']),
                'communication_cost': training_results.get('total_communication_cost', 0),
                'privacy_budget_used': training_results.get('privacy_budget_used', 0)
            }

            self.optimization_history.append(trial_result)

            # Update best parameters
            if composite_score > self.best_score:
                self.best_score = composite_score
                self.best_params = params.copy()

                # Save best model
                self._save_best_federated_model(training_results, params)

            return composite_score

        except Exception:
            # Best-effort by design: keep the study running, but keep the traceback.
            logger.exception(f"Federated trial {trial.number} failed")
            return float('-inf')

    def _calculate_federated_score(self, training_results: Dict[str, Any],
                                   params: Dict[str, Any]) -> float:
        """
        Calculate composite score for federated learning optimization.

        Weighted sum (weights total 1.0):
        0.35 * final global accuracy,
        0.15 * convergence speed (``1 - convergence_round / 100``, floored at 0),
        0.15 * client consistency,
        0.15 * communication efficiency (``1 - cost / 1000``, floored at 0),
        0.10 * privacy (``1 - budget_used / 10``, floored at 0; DP trials only),
        0.10 * robustness to client dropout.

        Args:
            training_results: Results from federated training
            params: Federated learning parameters

        Returns:
            float: Composite score
        """
        # Primary metrics (accuracy-based)
        global_accuracy = training_results.get('final_global_accuracy', 0) * 0.35
        convergence_speed = max(0, 1 - training_results.get('convergence_round', 100) / 100) * 0.15
        client_consistency = training_results.get('client_consistency_score', 0) * 0.15

        # Communication efficiency
        comm_efficiency = max(0, 1 - training_results.get('total_communication_cost', 1000) / 1000) * 0.15

        # Privacy preservation
        privacy_score = 0.0
        if params.get('use_differential_privacy', False):
            privacy_budget = training_results.get('privacy_budget_used', 0)
            privacy_score = max(0, 1 - privacy_budget / 10.0) * 0.1  # Reward lower privacy budget usage

        # Robustness metrics
        robustness_score = training_results.get('robustness_to_client_dropout', 0) * 0.1

        composite_score = (
            global_accuracy + convergence_speed + client_consistency +
            comm_efficiency + privacy_score + robustness_score
        )

        return composite_score

    def _save_best_federated_model(self, training_results: Dict[str, Any],
                                   params: Dict[str, Any]):
        """
        Save the best federated model weights and its configuration.

        Writes ``best_global_model.pt`` (when ``training_results`` carries a
        ``global_model``) and ``best_federated_config.json``. Both go into
        ``optimization_config['save_dir']``, or ``DEFAULT_SAVE_DIR`` if that
        key is absent.
        """
        save_dir = Path(self.optimization_config.get('save_dir', self.DEFAULT_SAVE_DIR))
        save_dir.mkdir(parents=True, exist_ok=True)

        # Save global model (assumed to be a torch.nn.Module: uses .state_dict()).
        # Compare with None: some Module containers define __len__ and can be falsy.
        global_model = training_results.get('global_model')
        if global_model is not None:
            torch.save(global_model.state_dict(), save_dir / "best_global_model.pt")

        # The model object itself is stored separately above, so keep it out of the JSON.
        serializable_results = {
            key: value for key, value in training_results.items() if key != 'global_model'
        }
        config = {
            'parameters': params,
            'training_results': serializable_results,
            'optimization_timestamp': pd.Timestamp.now().isoformat()
        }

        with open(save_dir / "best_federated_config.json", 'w', encoding='utf-8') as f:
            json.dump(config, f, indent=2, default=str)

        logger.info(f"Best federated model saved with score: {self.best_score:.4f}")

    def optimize_federated_params(self, n_trials: int = 50) -> Dict[str, Any]:
        """
        Optimize federated learning hyperparameters.

        The study is kept on ``self.study``. When ``optimization_config``
        supplies ``storage`` and ``study_name``, an existing study of that name
        is resumed. The two seed configurations are enqueued only into an
        empty study.

        Args:
            n_trials: Number of optimization trials

        Returns:
            Dict containing optimization results
        """
        logger.info(f"Starting federated learning optimization with {n_trials} trials")

        # Create study with federated-specific sampler
        study = optuna.create_study(
            direction='maximize',
            sampler=TPESampler(
                seed=42,
                n_startup_trials=10,
                n_ei_candidates=24
            ),
            pruner=MedianPruner(n_startup_trials=5, n_warmup_steps=5),
            storage=self.optimization_config.get('storage'),
            study_name=self.optimization_config.get('study_name'),
            load_if_exists=True,
        )
        self.study = study

        # Add good starting points for federated learning (fresh studies only,
        # so resuming a stored study does not re-run them).
        if not study.trials:
            study.enqueue_trial({
                'num_rounds': 50,
                'clients_per_round': 5,
                'local_epochs': 3,
                'local_batch_size': 32,
                'local_learning_rate': 0.01,
                'aggregation_method': 'fedavg',
                'server_learning_rate': 0.1,
                'use_differential_privacy': False,
                'compression_rate': 1.0,
                'client_selection_strategy': 'random'
            })

            study.enqueue_trial({
                'num_rounds': 30,
                'clients_per_round': 3,
                'local_epochs': 5,
                'local_batch_size': 64,
                'local_learning_rate': 0.005,
                'aggregation_method': 'fedprox',
                'server_learning_rate': 0.05,
                'use_differential_privacy': True,
                'noise_multiplier': 0.5,
                'compression_rate': 0.5,
                'client_selection_strategy': 'loss_based'
            })

        timeout_hours = self.optimization_config.get('timeout_hours', 48)
        timeout = None if timeout_hours is None else timeout_hours * 3600

        # Optimize
        study.optimize(
            self._objective_function,
            n_trials=n_trials,
            timeout=timeout,
            callbacks=[self._federated_callback]
        )

        # Compile results
        results = {
            'best_params': study.best_params,
            'best_score': study.best_value,
            'n_trials': len(study.trials),
            'optimization_history': self.optimization_history,
            'study_statistics': self._get_federated_study_statistics(study)
        }

        logger.info(f"Federated optimization completed. Best score: {study.best_value:.4f}")
        return results

    def _federated_callback(self, study, trial):
        """Log the running best score every fifth trial."""
        if trial.number % 5 == 0:
            logger.info(f"Federated trial {trial.number}: Best score = {study.best_value:.4f}")

    def _get_federated_study_statistics(self, study) -> Dict[str, Any]:
        """
        Summarize trial counts and (when possible) Optuna parameter importance.

        Importance is computed only when more than 10 trials completed. A
        failure of the importance evaluator is logged and leaves the field
        empty, so the statistics of a long run are not lost.
        """
        completed = [t for t in study.trials if t.state == optuna.trial.TrialState.COMPLETE]

        stats = {
            'total_trials': len(study.trials),
            'completed_trials': len(completed),
            'best_trial_number': study.best_trial.number,
            'parameter_importance': {}
        }

        # Calculate parameter importance if enough trials completed
        if len(completed) > 10:
            try:
                stats['parameter_importance'] = optuna.importance.get_param_importances(study)
            except Exception as e:
                logger.warning(f"Could not compute parameter importance: {e}")

        return stats

    def analyze_privacy_utility_tradeoff(self) -> Dict[str, Any]:
        """
        Analyze privacy-utility tradeoff across optimization trials.

        Compares average scores of trials with and without differential
        privacy. With more than five DP trials, it also reports the correlation
        between privacy budget used and score, and the settings of the
        best-scoring DP trial.

        Returns:
            Dict containing privacy-utility analysis (empty if no history)
        """
        logger.info("Analyzing privacy-utility tradeoff...")

        if not self.optimization_history:
            return {}

        df = pd.DataFrame(self.optimization_history)

        # Filter trials with differential privacy
        dp_trials = df[df['params'].apply(lambda x: x.get('use_differential_privacy', False))]
        no_dp_trials = df[~df['params'].apply(lambda x: x.get('use_differential_privacy', False))]

        analysis = {
            'dp_trials_count': len(dp_trials),
            'no_dp_trials_count': len(no_dp_trials),
            'dp_average_score': dp_trials['score'].mean() if len(dp_trials) > 0 else 0,
            'no_dp_average_score': no_dp_trials['score'].mean() if len(no_dp_trials) > 0 else 0,
            'privacy_utility_correlation': 0.0,
            'optimal_privacy_params': {}
        }

        if len(dp_trials) > 5:
            # Calculate correlation between privacy budget and utility
            privacy_budgets = dp_trials['training_results'].apply(
                lambda x: x.get('privacy_budget_used', 0)
            )
            scores = dp_trials['score']

            # A constant budget has no defined correlation; keep the 0.0 default.
            if privacy_budgets.std() > 0:
                analysis['privacy_utility_correlation'] = privacy_budgets.corr(scores)

            # Find optimal privacy parameters
            best_dp_trial = dp_trials.loc[dp_trials['score'].idxmax()]
            analysis['optimal_privacy_params'] = {
                'noise_multiplier': best_dp_trial['params'].get('noise_multiplier', 0),
                'max_grad_norm': best_dp_trial['params'].get('max_grad_norm', 0),
                'privacy_budget_used': best_dp_trial['training_results'].get('privacy_budget_used', 0),
                'utility_score': best_dp_trial['score']
            }

        return analysis

    def analyze_communication_efficiency(self) -> Dict[str, Any]:
        """
        Analyze communication efficiency across optimization trials.

        Reports communication-cost statistics and their correlation with
        compression rate and quantization bits. It also returns the settings
        of the trial with the highest score per unit of communication cost.

        Returns:
            Dict containing communication efficiency analysis (empty if no history)
        """
        logger.info("Analyzing communication efficiency...")

        if not self.optimization_history:
            return {}

        df = pd.DataFrame(self.optimization_history)

        # Extract communication costs
        comm_costs = df['training_results'].apply(
            lambda x: x.get('total_communication_cost', 0)
        )

        # Extract compression rates
        compression_rates = df['params'].apply(
            lambda x: x.get('compression_rate', 1.0)
        )

        # Extract quantization bits
        quantization_bits = df['params'].apply(
            lambda x: x.get('quantization_bits', 32)
        )

        analysis = {
            'average_communication_cost': comm_costs.mean(),
            'communication_cost_std': comm_costs.std(),
            'compression_efficiency': compression_rates.corr(comm_costs),
            'quantization_efficiency': quantization_bits.corr(comm_costs),
            'optimal_compression_params': {}
        }

        # Score per communication cost; the epsilon avoids division by zero.
        efficiency_scores = df['score'] / (comm_costs + 1e-6)
        best_efficiency_idx = efficiency_scores.idxmax()
        best_trial = df.loc[best_efficiency_idx]

        analysis['optimal_compression_params'] = {
            'compression_rate': best_trial['params'].get('compression_rate', 1.0),
            'quantization_bits': best_trial['params'].get('quantization_bits', 32),
            'communication_cost': best_trial['training_results'].get('total_communication_cost', 0),
            'efficiency_score': efficiency_scores.loc[best_efficiency_idx]
        }

        return analysis

    def analyze_client_heterogeneity_impact(self) -> Dict[str, Any]:
        """
        Analyze the impact of client heterogeneity on federated performance.

        Reports average client consistency and dropout robustness, and the
        per-aggregation-method score and consistency. It also returns the
        settings of the trial with the highest client consistency.

        Returns:
            Dict containing heterogeneity analysis (empty if no history)
        """
        logger.info("Analyzing client heterogeneity impact...")

        if not self.optimization_history:
            return {}

        df = pd.DataFrame(self.optimization_history)

        # Extract client consistency scores
        consistency_scores = df['training_results'].apply(
            lambda x: x.get('client_consistency_score', 0)
        )

        # Extract robustness scores
        robustness_scores = df['training_results'].apply(
            lambda x: x.get('robustness_to_client_dropout', 0)
        )

        # Extract aggregation methods
        aggregation_methods = df['params'].apply(
            lambda x: x.get('aggregation_method', 'fedavg')
        )

        analysis = {
            'average_consistency': consistency_scores.mean(),
            'average_robustness': robustness_scores.mean(),
            'consistency_robustness_correlation': consistency_scores.corr(robustness_scores),
            'aggregation_method_performance': {},
            'optimal_heterogeneity_params': {}
        }

        # Analyze performance by aggregation method
        for method in aggregation_methods.unique():
            method_trials = df[aggregation_methods == method]
            analysis['aggregation_method_performance'][method] = {
                'average_score': method_trials['score'].mean(),
                'average_consistency': method_trials['training_results'].apply(
                    lambda x: x.get('client_consistency_score', 0)
                ).mean(),
                'count': len(method_trials)
            }

        # Find optimal parameters for handling heterogeneity
        best_consistency_idx = consistency_scores.idxmax()
        best_trial = df.loc[best_consistency_idx]

        analysis['optimal_heterogeneity_params'] = {
            'aggregation_method': best_trial['params'].get('aggregation_method', 'fedavg'),
            'local_epochs': best_trial['params'].get('local_epochs', 3),
            'clients_per_round': best_trial['params'].get('clients_per_round', 5),
            'consistency_score': best_trial['training_results'].get('client_consistency_score', 0)
        }

        return analysis

    def visualize_federated_optimization(self, results: Dict[str, Any]) -> None:
        """
        Create a 3x3 grid of plots for the federated optimization results.

        All plots are drawn from ``self.optimization_history``. ``results`` is
        accepted for API compatibility but not read.
        """
        if not self.optimization_history:
            logger.warning("No optimization history available for visualization")
            return

        df = pd.DataFrame(self.optimization_history)

        # Create comprehensive visualization
        fig, axes = plt.subplots(3, 3, figsize=(20, 15))
        fig.suptitle('Federated Learning Hyperparameter Optimization Results', fontsize=16)

        # 1. Score progression
        axes[0, 0].plot(df['trial_number'], df['score'], 'b-', alpha=0.7, label='Trial Score')
        axes[0, 0].plot(df['trial_number'], df['score'].cummax(), 'r-', linewidth=2, label='Best Score')
        axes[0, 0].set_xlabel('Trial Number')
        axes[0, 0].set_ylabel('Composite Score')
        axes[0, 0].set_title('Optimization Progress')
        axes[0, 0].legend()
        axes[0, 0].grid(True, alpha=0.3)

        # 2. Communication cost vs Score
        comm_costs = df['training_results'].apply(lambda x: x.get('total_communication_cost', 0))
        axes[0, 1].scatter(comm_costs, df['score'], alpha=0.6)
        axes[0, 1].set_xlabel('Communication Cost')
        axes[0, 1].set_ylabel('Score')
        axes[0, 1].set_title('Communication Efficiency')
        axes[0, 1].grid(True, alpha=0.3)

        # 3. Convergence rounds distribution
        conv_rounds = df['training_results'].apply(lambda x: x.get('convergence_round', 0))
        axes[0, 2].hist(conv_rounds, bins=20, alpha=0.7, color='green')
        axes[0, 2].set_xlabel('Convergence Round')
        axes[0, 2].set_ylabel('Frequency')
        axes[0, 2].set_title('Convergence Distribution')

        # 4. Aggregation method performance
        agg_methods = df['params'].apply(lambda x: x.get('aggregation_method', 'fedavg'))
        method_scores = df.groupby(agg_methods)['score'].mean()
        axes[1, 0].bar(method_scores.index, method_scores.values, color='orange')
        axes[1, 0].set_xlabel('Aggregation Method')
        axes[1, 0].set_ylabel('Average Score')
        axes[1, 0].set_title('Aggregation Method Performance')

        # 5. Privacy vs Utility (if applicable)
        dp_trials = df[df['params'].apply(lambda x: x.get('use_differential_privacy', False))]
        if len(dp_trials) > 0:
            # Same key the scorer and analyze_privacy_utility_tradeoff read.
            privacy_budgets = dp_trials['training_results'].apply(
                lambda x: x.get('privacy_budget_used', 0)
            )
            dp_scores = dp_trials['score']
            axes[1, 1].scatter(privacy_budgets, dp_scores, alpha=0.6, color='purple')
            axes[1, 1].set_xlabel('Privacy Budget (ε)')
            axes[1, 1].set_ylabel('Utility Score')
            axes[1, 1].set_title('Privacy-Utility Tradeoff')
            axes[1, 1].grid(True, alpha=0.3)
        else:
            axes[1, 1].text(0.5, 0.5, 'No Privacy Trials', ha='center', va='center', transform=axes[1, 1].transAxes)
            axes[1, 1].set_title('Privacy-Utility Tradeoff')

        # 6. Client participation vs performance
        client_participation = df['params'].apply(lambda x: x.get('clients_per_round', 0))
        axes[1, 2].scatter(client_participation, df['score'], alpha=0.6, color='red')
        axes[1, 2].set_xlabel('Clients per Round')
        axes[1, 2].set_ylabel('Score')
        axes[1, 2].set_title('Client Participation Impact')
        axes[1, 2].grid(True, alpha=0.3)

        # 7. Learning rate vs convergence (the sampled key is 'local_learning_rate')
        learning_rates = df['params'].apply(lambda x: x.get('local_learning_rate', 0))
        axes[2, 0].scatter(learning_rates, conv_rounds, alpha=0.6, color='brown')
        if (learning_rates > 0).all():
            # Sampled log-uniformly, so a log axis spreads the points out.
            axes[2, 0].set_xscale('log')
        axes[2, 0].set_xlabel('Learning Rate')
        axes[2, 0].set_ylabel('Convergence Round')
        axes[2, 0].set_title('Learning Rate vs Convergence')
        axes[2, 0].grid(True, alpha=0.3)

        # 8. Local epochs vs global accuracy (key matches _calculate_federated_score)
        local_epochs = df['params'].apply(lambda x: x.get('local_epochs', 0))
        global_acc = df['training_results'].apply(lambda x: x.get('final_global_accuracy', 0))
        axes[2, 1].scatter(local_epochs, global_acc, alpha=0.6, color='cyan')
        axes[2, 1].set_xlabel('Local Epochs')
        axes[2, 1].set_ylabel('Final Accuracy')
        axes[2, 1].set_title('Local Training vs Global Performance')
        axes[2, 1].grid(True, alpha=0.3)

        # 9. Parameter importance heatmap (imshow cannot draw an empty array)
        param_importance = self._calculate_parameter_importance(df)
        if param_importance is not None and not param_importance.empty:
            im = axes[2, 2].imshow(param_importance.values.reshape(-1, 1), cmap='viridis', aspect='auto')
            axes[2, 2].set_yticks(range(len(param_importance)))
            axes[2, 2].set_yticklabels(param_importance.index)
            axes[2, 2].set_xticks([])
            axes[2, 2].set_title('Parameter Importance')
            plt.colorbar(im, ax=axes[2, 2])
        else:
            axes[2, 2].text(0.5, 0.5, 'Insufficient Data', ha='center', va='center', transform=axes[2, 2].transAxes)
            axes[2, 2].set_title('Parameter Importance')

        plt.tight_layout()
        plt.show()

    def _calculate_parameter_importance(self, df: pd.DataFrame) -> Optional[pd.Series]:
        """
        Rank numeric parameters by absolute Pearson correlation with the score.

        Each parameter value is paired with the score of its own row, so
        trials that lack a parameter are skipped for that parameter instead of
        misaligning the data. Booleans count as numeric (0/1). Parameters
        without variation are omitted, and an undefined correlation counts
        as 0.

        Args:
            df: Frame with a ``params`` column of dicts and a ``score`` column.

        Returns:
            Series of importances sorted descending (possibly empty), or
            ``None`` if no numeric parameter exists or the computation fails.
        """
        try:
            values: Dict[str, List[float]] = {}
            paired_scores: Dict[str, List[float]] = {}
            for params, score in zip(df['params'], df['score']):
                for key, value in params.items():
                    if isinstance(value, (int, float)):
                        values.setdefault(key, []).append(float(value))
                        paired_scores.setdefault(key, []).append(float(score))

            if not values:
                return None

            importance = {}
            for param_name, xs in values.items():
                if len(set(xs)) > 1:  # Parameter has variation
                    # Constant scores give 0/0 inside corrcoef; the NaN is mapped to 0 below.
                    with np.errstate(invalid='ignore', divide='ignore'):
                        correlation = np.corrcoef(xs, paired_scores[param_name])[0, 1]
                    importance[param_name] = 0.0 if np.isnan(correlation) else abs(correlation)

            return pd.Series(importance, dtype=float).sort_values(ascending=False)

        except Exception as e:
            logger.error(f"Error calculating parameter importance: {e}")
            return None

    def save_optimization_results(self, filepath: str):
        """
        Save optimization history, best params/score and the study summary as JSON.

        Objects JSON cannot encode are written via ``str``. ``study_summary``
        is ``None`` when no study has been run on this instance.
        """
        # getattr: instances restored without __init__ may have no `study`.
        study = getattr(self, 'study', None)
        results = {
            'optimization_history': self.optimization_history,
            'best_params': self.best_params,
            'best_score': self.best_score,
            'study_summary': study.trials_dataframe().to_dict() if study is not None else None
        }

        with open(filepath, 'w', encoding='utf-8') as f:
            json.dump(results, f, indent=2, default=str)

        logger.info(f"Optimization results saved to {filepath}")

    def load_optimization_results(self, filepath: str):
        """
        Restore optimization history and best params/score from a JSON file
        written by ``save_optimization_results``.

        A missing ``best_score`` falls back to ``-inf``, matching a fresh
        optimizer, so any later trial can become the new best.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            results = json.load(f)

        self.optimization_history = results.get('optimization_history', [])
        self.best_params = results.get('best_params', {})
        self.best_score = results.get('best_score', float('-inf'))

        logger.info(f"Optimization results loaded from {filepath}")

# Initialize the federated optimization framework
print("BatteryMind Federated Learning Hyperparameter Optimization")
print("=" * 60)

# Create optimizer instance
optimizer = FederatedHyperparameterOptimizer(
    n_trials=100,
    direction='maximize',
    storage_url='sqlite:///federated_optimization.db',
    study_name='batterymind_federated_optimization'
)

# Define optimization search space
search_space = {
    'learning_rate': (1e-5, 1e-2, 'log'),
    'local_epochs': (1, 10, 'int'),
    'clients_per_round': (5, 50, 'int'),
    'aggregation_method': ['fedavg', 'fedprox', 'scaffold', 'fedopt'],
    'use_differential_privacy': [True, False],
    'privacy_budget': (0.1, 10.0, 'log'),
    'clip_norm': (0.1, 10.0, 'log'),
    'server_learning_rate': (0.1, 2.0, 'log'),
    'momentum': (0.0, 0.9, 'float'),
    'weight_decay': (1e-6, 1e-3, 'log'),
    'batch_size': (16, 128, 'int'),
    'model_complexity': ['simple', 'medium', 'complex'],
    'communication_rounds': (50, 500, 'int'),
    'client_dropout_rate': (0.0, 0.3, 'float'),
    'adaptive_aggregation': [True, False],
    'personalization_strength': (0.0, 1.0, 'float')
}

print("Search Space Configuration:")
for param, config in search_space.items():
    print(f"  {param}: {config}")

# Run hyperparameter optimization
print("\nStarting Hyperparameter Optimization...")
print("This may take several hours depending on the number of trials and complexity.")

start_time = time.time()
best_params, best_score = optimizer.optimize(search_space)
end_time = time.time()

print(f"\nOptimization completed in {(end_time - start_time)/3600:.2f} hours")
print(f"Best Score: {best_score:.6f}")
print("Best Parameters:")
for param, value in best_params.items():
    print(f"  {param}: {value}")

# Analyze optimization results and print every section that
# analyze_optimization_results() returned.
print("\nAnalyzing Optimization Results...")
analysis = optimizer.analyze_optimization_results()


def _print_metric_rows(section, rows):
    """Print one '  <label>: <value>' line per (label, key, format_spec) row."""
    for label, key, spec in rows:
        print(f"  {label}: {format(section[key], spec)}")


print("\nOptimization Analysis Summary:")
for label, key in (
    ("Total Trials", 'total_trials'),
    ("Successful Trials", 'successful_trials'),
    ("Failed Trials", 'failed_trials'),
):
    print(f"{label}: {analysis[key]}")
print(f"Success Rate: {analysis['success_rate']:.2%}")

# Optional sections: each is printed only when its key is present.
if 'convergence_analysis' in analysis:
    print("\nConvergence Analysis:")
    _print_metric_rows(analysis['convergence_analysis'], (
        ("Average Convergence Round", 'mean_convergence_round', '.1f'),
        ("Std Convergence Round", 'std_convergence_round', '.1f'),
        ("Fastest Convergence", 'min_convergence_round', ''),
        ("Slowest Convergence", 'max_convergence_round', ''),
    ))

if 'communication_analysis' in analysis:
    print("\nCommunication Analysis:")
    _print_metric_rows(analysis['communication_analysis'], (
        ("Average Communication Cost", 'mean_communication_cost', '.2f'),
        ("Communication Efficiency Score", 'efficiency_score', '.4f'),
    ))

if 'privacy_analysis' in analysis:
    print("\nPrivacy Analysis:")
    _print_metric_rows(analysis['privacy_analysis'], (
        ("Privacy-Preserving Trials", 'privacy_trials', ''),
        ("Average Privacy Budget", 'mean_privacy_budget', '.2f'),
        ("Privacy-Utility Tradeoff", 'privacy_utility_tradeoff', '.4f'),
    ))

if 'optimal_heterogeneity_params' in analysis:
    print("\nOptimal Heterogeneity Handling:")
    _print_metric_rows(analysis['optimal_heterogeneity_params'], (
        ("Aggregation Method", 'aggregation_method', ''),
        ("Local Epochs", 'local_epochs', ''),
        ("Clients per Round", 'clients_per_round', ''),
        ("Consistency Score", 'consistency_score', '.4f'),
    ))

# Create and display comprehensive visualizations
print("\nGenerating Optimization Visualizations...")
optimizer.visualize_federated_optimization(analysis)

# Additional analysis: Parameter correlations
print("\nParameter Correlation Analysis:")
if optimizer.optimization_history:
    # NOTE(review): assumes each history entry has a 'params' dict and a
    # numeric 'score' — confirm against the optimizer's record format.
    df = pd.DataFrame(optimizer.optimization_history)

    # One row per trial, holding only that trial's numeric hyperparameters
    # (bools are included because bool is an int subclass). Building rows keeps
    # every value aligned with its own trial and score. A parameter that only
    # some trials define becomes NaN in the other rows. Per-key lists would
    # instead come out ragged: pd.DataFrame rejects them, or they misalign
    # with the scores.
    numeric_rows = [
        {key: value for key, value in trial_params.items() if isinstance(value, (int, float))}
        for trial_params in df['params']
    ]
    # astype(float) maps bool columns (possibly mixed with NaN, hence object
    # dtype) to 0.0/1.0 so .corr() treats every column as numeric.
    param_df = pd.DataFrame(numeric_rows, index=df.index).astype(float)

    if not param_df.columns.empty:
        # Add scores (same index as df, so rows stay aligned)
        param_df['score'] = df['score']
        correlation_matrix = param_df.corr()

        # Display correlation with score, strongest (by magnitude) first
        score_correlations = correlation_matrix['score'].drop('score').sort_values(key=abs, ascending=False)
        print("\nParameter-Score Correlations:")
        for param, corr in score_correlations.items():
            print(f"  {param}: {corr:.4f}")

        # Create correlation heatmap
        plt.figure(figsize=(12, 10))
        sns.heatmap(correlation_matrix, annot=True, cmap='coolwarm', center=0, fmt='.2f')
        plt.title('Parameter Correlation Matrix')
        plt.tight_layout()
        plt.show()

# Performance comparison with baseline methods
print("\nPerformance Comparison with Baseline Methods:")

baseline_methods = {
    'Standard FedAvg': {
        'learning_rate': 0.001,
        'local_epochs': 3,
        'clients_per_round': 10,
        'aggregation_method': 'fedavg',
        'use_differential_privacy': False
    },
    'FedProx': {
        'learning_rate': 0.001,
        'local_epochs': 5,
        'clients_per_round': 10,
        'aggregation_method': 'fedprox',
        'use_differential_privacy': False
    },
    'Private FedAvg': {
        'learning_rate': 0.001,
        'local_epochs': 3,
        'clients_per_round': 10,
        'aggregation_method': 'fedavg',
        'use_differential_privacy': True,
        'privacy_budget': 1.0
    }
}

print("Baseline Method Performance:")
baseline_results = {}
for method_name, params in baseline_methods.items():
    # Simulate baseline performance (in practice, this would run actual training).
    # NOTE(review): these configs specify only a few keys; whatever the private
    # simulator does for the rest should match the optimized run — confirm.
    simulated_score = optimizer._simulate_federated_training(params)
    baseline_results[method_name] = simulated_score
    print(f"  {method_name}: {simulated_score:.4f}")

best_baseline = max(baseline_results.values())
print(f"\nOptimized Method: {best_score:.4f}")
print(f"Best Baseline: {best_baseline:.4f}")
if best_baseline != 0:
    # Divide by |baseline| so a positive percentage always means the optimized
    # configuration scored higher, even if scores can be negative.
    improvement_pct = (best_score - best_baseline) / abs(best_baseline) * 100
    print(f"Improvement: {improvement_pct:.2f}%")
else:
    # A relative improvement over a zero baseline is undefined.
    print("Improvement: undefined (best baseline score is 0)")

# Create performance comparison visualization: one bar per baseline (light
# blue) plus the optimized configuration (red), each labelled with its score.
plt.figure(figsize=(12, 6))
methods = [*baseline_results, 'Optimized']
scores = [*baseline_results.values(), best_score]
colors = ['lightblue' for _ in baseline_results] + ['red']

bars = plt.bar(methods, scores, color=colors, alpha=0.7)
plt.ylabel('Performance Score')
plt.title('Federated Learning Method Performance Comparison')
plt.xticks(rotation=45)

# Place each value label centred just above its bar
for bar, score in zip(bars, scores):
    label_x = bar.get_x() + bar.get_width() / 2
    label_y = bar.get_height() + 0.01
    plt.text(label_x, label_y, f'{score:.4f}', ha='center', va='bottom', fontweight='bold')

plt.tight_layout()
plt.show()

# Save optimization results via the optimizer's own serializer
results_file = 'federated_optimization_results.json'
optimizer.save_optimization_results(results_file)
print("\nOptimization results saved to " + results_file)

# Generate recommendations for production deployment
print("\nRecommendations for Production Deployment:")
print("=" * 50)

# Fallback values used when best_params lacks a key.
_recommended_defaults = {
    'learning_rate': 0.001,
    'local_epochs': 3,
    'clients_per_round': 10,
    'aggregation_method': 'fedavg',
    'communication_rounds': 100,
    'batch_size': 32,
}

print("\n1. Optimal Hyperparameters:")
optimal_config = {
    name: best_params.get(name, fallback)
    for name, fallback in _recommended_defaults.items()
}
for param, value in optimal_config.items():
    print(f"  {param}: {value}")

print("\n2. Privacy Configuration:")
if not best_params.get('use_differential_privacy', False):
    print(f"  Enable Differential Privacy: No")
else:
    print(f"  Enable Differential Privacy: Yes")
    print(f"  Privacy Budget (ε): {best_params.get('privacy_budget', 1.0):.2f}")
    print(f"  Clip Norm: {best_params.get('clip_norm', 1.0):.2f}")

print("\n3. Communication Optimization:")
print(f"  Clients per Round: {best_params.get('clients_per_round', 10)}")
print(f"  Client Dropout Rate: {best_params.get('client_dropout_rate', 0.1):.2f}")
print(f"  Adaptive Aggregation: {best_params.get('adaptive_aggregation', False)}")

print("\n4. Model Configuration:")
print(f"  Model Complexity: {best_params.get('model_complexity', 'medium')}")
print(f"  Weight Decay: {best_params.get('weight_decay', 1e-4):.2e}")
print(f"  Momentum: {best_params.get('momentum', 0.9):.2f}")

print("\n5. Personalization Settings:")
print(f"  Personalization Strength: {best_params.get('personalization_strength', 0.0):.2f}")

# Sections 6-8 are static checklists that do not depend on the search results.
_operational_checklists = (
    ("\n6. Monitoring and Alerting:", (
        "  - Monitor client participation rates",
        "  - Track communication costs per round",
        "  - Set up alerts for convergence failures",
        "  - Monitor privacy budget consumption",
    )),
    ("\n7. Scalability Considerations:", (
        "  - Plan for client heterogeneity",
        "  - Implement fault tolerance mechanisms",
        "  - Consider asynchronous aggregation for large deployments",
        "  - Set up load balancing for server infrastructure",
    )),
    ("\n8. Security Recommendations:", (
        "  - Implement secure aggregation protocols",
        "  - Use authenticated communication channels",
        "  - Regular security audits of federated components",
        "  - Consider homomorphic encryption for sensitive data",
    )),
)
for heading, checklist in _operational_checklists:
    print(heading)
    for line in checklist:
        print(line)

# Final summary banner for the whole run
_summary_rule = "=" * 60
_elapsed_hours = (end_time - start_time) / 3600

print("\n" + _summary_rule)
print("FEDERATED LEARNING OPTIMIZATION COMPLETE")
print(_summary_rule)
print(f"Best Configuration Score: {best_score:.6f}")
print(f"Total Optimization Time: {_elapsed_hours:.2f} hours")
print(f"Trials Evaluated: {len(optimizer.optimization_history)}")
print(f"Success Rate: {analysis['success_rate']:.2%}")
print(_summary_rule)

# Export configuration for deployment


def _json_default(obj):
    """Convert NumPy scalars/arrays to native Python values for json.dump.

    Raises:
        TypeError: for any other non-serializable object, matching json's
            own error.
    """
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


# Tolerate an analysis section that is absent or explicitly None.
_convergence_summary = analysis.get('convergence_analysis') or {}
_communication_summary = analysis.get('communication_analysis') or {}

deployment_config = {
    'model_type': 'federated_battery_health',
    'version': '1.0.0',
    'optimized_hyperparameters': best_params,
    'performance_metrics': {
        'optimization_score': best_score,
        # Metrics the analysis did not produce are exported as null rather than
        # placeholder numbers that would read as measured results.
        'convergence_rounds': _convergence_summary.get('mean_convergence_round'),
        'communication_efficiency': _communication_summary.get('efficiency_score'),
    },
    'deployment_recommendations': {
        'min_clients': 10,
        'recommended_clients': best_params.get('clients_per_round', 10),
        'max_clients': 100,
        'monitoring_frequency': 'every_round',
        'backup_strategy': 'checkpoint_every_10_rounds'
    }
}

with open('federated_deployment_config.json', 'w', encoding='utf-8') as f:
    json.dump(deployment_config, f, indent=2, default=_json_default)

print("Deployment configuration saved to: federated_deployment_config.json")
print("Ready for production deployment!")
