In [None]:
import os
import json
import time
import numpy as np
import pandas as pd
from datetime import datetime
from typing import Dict, List, Any, Optional, Tuple, Union
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging
from dataclasses import dataclass
from scipy import stats

import os
import sys
import torch
import torch.nn as nn
from pathlib import Path
from typing import Dict, Any, Tuple, Optional

# =============================================================================
# Unified import setup (copy this into every file)
# =============================================================================
def setup_module_imports(current_file: Optional[str] = None):
    """Make the project's packages importable from the calling module.

    First tries the project's ``setup_imports.setup_project_imports`` helper.
    If that import (or the helper itself) raises ImportError, falls back to a
    manual setup. The parent of ``current_file``'s directory is taken as the
    project root. The root and its ``core``, ``models``, ``features``,
    ``data``, ``utils`` and ``evaluation`` subdirectories are then prepended
    to ``sys.path``, skipping any that do not exist or are already present.

    Args:
        current_file: Path of the calling module. Defaults to this module's
            ``__file__``. When that name is undefined (e.g. when the code runs
            in a Jupyter notebook cell), a placeholder file in the current
            working directory is used, so the working directory plays the role
            of the module's directory.

    Returns:
        ``(project_root, used_import_manager)``. ``used_import_manager`` is
        True when ``setup_project_imports`` handled the setup; in that case
        ``project_root`` is whatever that helper returned. Otherwise it is the
        ``Path`` of the inferred project root.
    """
    if current_file is None:
        # A ``__file__`` default would be evaluated when this function is
        # defined and raise NameError wherever ``__file__`` does not exist
        # (notebooks, exec'd code), so resolve it lazily here instead.
        current_file = globals().get('__file__') or os.path.join(os.getcwd(), '__notebook__.py')

    try:
        from setup_imports import setup_project_imports
        return setup_project_imports(current_file), True
    except ImportError:
        # Fallback: manual path setup
        current_dir = Path(current_file).resolve().parent  # the 'experiments' directory
        project_root = current_dir.parent  # experiments -> Model (project root)

        paths_to_add = [
            str(project_root),
            str(project_root / 'core'),
            str(project_root / 'models'),
            str(project_root / 'features'),
            str(project_root / 'data'),
            str(project_root / 'utils'),
            str(project_root / 'evaluation'),
        ]

        for path in paths_to_add:
            if Path(path).exists() and path not in sys.path:
                sys.path.insert(0, path)

        return project_root, False

# Configure sys.path (or delegate to the project's import manager) before the
# project-level imports below are attempted.
(PROJECT_ROOT, USING_IMPORT_MANAGER) = setup_module_imports()

# =============================================================================  
# Project module imports (sys.path has been configured above)
# =============================================================================
from experiments.base_experiment import BaseExperiment
from models.hybrid_models import TraditionalMLPBaseline, HybridModelManager  
from models.mlp_classifier import MLPClassifier
from features.traditional_features import MelExtractor, MFCCExtractor
from data.dataset_loader import create_speaker_dataloaders, LibriSpeechChaoticDataset

@dataclass
class ExperimentResult:
    """Outcome of a single run of one experiment in the comparison study.

    Built by ``ComparisonExperiment.run_single_experiment`` and consumed by the
    analysis and reporting methods of that class.
    """

    # Run-qualified name: "<experiment>_run_<run_id>".
    experiment_name: str
    # Method family, e.g. 'baseline' or 'chaotic' ('hybrid'/'unknown' are also
    # produced by ComparisonExperiment._parse_experiment_name).
    method_type: str
    # Input features, e.g. 'mel', 'mfcc' or 'chaotic'.
    feature_type: str
    # Model family, e.g. 'mlp', 'cnn' or 'chaotic_network'.
    model_type: str
    # Metrics reported by the experiment's test step.
    test_accuracy: float
    test_loss: float
    # Wall-clock seconds measured around setup + training.
    training_time: float
    # Total number of model parameters.
    model_parameters: int
    # Any further test metrics beyond accuracy and loss.
    additional_metrics: Dict[str, float]


class ComparisonExperiment:
    """
    Comprehensive Comparison Experiment Manager.
    
    This class manages all baseline and chaotic experiments for fair comparison,
    ensuring consistent experimental conditions and comprehensive analysis.
    """
    
    def __init__(
        self,
        base_config: Dict[str, Any],
        comparison_name: str = 'speaker_recognition_comparison',
        output_dir: str = './experiments/comparisons',
        device: str = 'auto',
        seed: int = 42,
        parallel_execution: bool = False,
        max_workers: int = 2
    ):
        """
        Initialize Comparison Experiment Manager.

        Creates ``<output_dir>/<comparison_name>/`` with ``results``, ``plots``
        and ``reports`` subdirectories plus a ``<comparison_name>.log`` file,
        then validates the configuration and derives the per-family
        ('baseline' / 'chaotic') experiment configurations.

        Args:
            base_config: Base configuration for all experiments. Must contain
                'num_speakers', 'batch_size' and 'data_dir'. A shallow copy is
                stored, so filling in defaults does not modify the caller's dict.
            comparison_name: Name of the comparison study (also names the
                output subdirectory, log file and logger).
            output_dir: Directory for comparison results.
            device: Device to use for experiments (stored; not resolved within
                this constructor).
            seed: Random seed for reproducibility (stored; not applied within
                this constructor).
            parallel_execution: Whether to run experiments in parallel.
            max_workers: Maximum number of parallel workers.

        Raises:
            ValueError: If a required key is missing from ``base_config``.
        """
        # Shallow copy: _validate_config() fills in defaults via setdefault(),
        # which would otherwise mutate the caller's configuration dict.
        self.base_config = base_config.copy()
        self.comparison_name = comparison_name
        self.output_dir = output_dir
        self.device = device
        self.seed = seed
        self.parallel_execution = parallel_execution
        self.max_workers = max_workers

        # Create output directories
        self.comparison_dir = os.path.join(output_dir, comparison_name)
        self.results_dir = os.path.join(self.comparison_dir, 'results')
        self.plots_dir = os.path.join(self.comparison_dir, 'plots')
        self.reports_dir = os.path.join(self.comparison_dir, 'reports')

        for directory in [self.comparison_dir, self.results_dir, self.plots_dir, self.reports_dir]:
            os.makedirs(directory, exist_ok=True)

        # Set up logging (needs self.comparison_name, assigned above)
        log_file = os.path.join(self.comparison_dir, f'{comparison_name}.log')
        self.logger = self._setup_logger(log_file)

        # Initialize experiment tracking
        self.experiments: Dict[str, BaseExperiment] = {}
        self.experiment_results: Dict[str, ExperimentResult] = {}
        self.comparison_results: Optional[Dict[str, Any]] = None

        # Validate and set up configurations
        self._validate_config()
        self._setup_experiment_configs()

        self.logger.info(f"Initialized comparison experiment: {comparison_name}")
        self.logger.info(f"Output directory: {self.comparison_dir}")
    
    def _setup_logger(self, log_file: str) -> logging.Logger:
        """Configure and return the logger for this comparison.

        The logger is named ``comparison_<comparison_name>`` and logs at INFO
        level to ``log_file`` (UTF-8) and to the console (a ``StreamHandler``,
        i.e. stderr by default). Handlers left on a same-named logger by an
        earlier instance are removed and closed first, so re-creating the
        manager neither duplicates output nor leaks open log files.

        Args:
            log_file: Path of the log file to append to.

        Returns:
            The configured ``logging.Logger``.
        """
        logger = logging.getLogger(f'comparison_{self.comparison_name}')
        logger.setLevel(logging.INFO)

        # Remove existing handlers and release their resources (e.g. the file
        # descriptor held by a previous FileHandler).
        for handler in list(logger.handlers):
            logger.removeHandler(handler)
            handler.close()

        # File handler
        file_handler = logging.FileHandler(log_file, encoding='utf-8')
        file_formatter = logging.Formatter(
            '%(asctime)s - %(name)s - %(levelname)s - %(message)s'
        )
        file_handler.setFormatter(file_formatter)
        logger.addHandler(file_handler)

        # Console handler
        console_handler = logging.StreamHandler()
        console_formatter = logging.Formatter(
            '%(asctime)s - %(levelname)s - %(message)s'
        )
        console_handler.setFormatter(console_formatter)
        logger.addHandler(console_handler)

        return logger
    
    def _validate_config(self):
        """Check mandatory keys of ``self.base_config`` and fill in defaults.

        Optional keys that are absent are added in place (existing values are
        kept): num_epochs=50, num_runs=3 (repeated runs feed the significance
        tests), sample_rate=16000, max_audio_length=3.0, and a 0.7/0.15/0.15
        train/val/test split.

        Raises:
            ValueError: Naming the first of 'num_speakers', 'batch_size',
                'data_dir' that is missing.
        """
        mandatory = ('num_speakers', 'batch_size', 'data_dir')
        absent = [name for name in mandatory if name not in self.base_config]
        if absent:
            raise ValueError(f"Missing required config key: {absent[0]}")

        fallback_values = {
            'num_epochs': 50,
            'num_runs': 3,
            'sample_rate': 16000,
            'max_audio_length': 3.0,
            'train_split': 0.7,
            'val_split': 0.15,
            'test_split': 0.15,
        }
        for name, value in fallback_values.items():
            self.base_config.setdefault(name, value)

        self.logger.info("Configuration validated successfully")
    
    def _setup_experiment_configs(self):
        """Derive the 'baseline' and 'chaotic' configurations from ``self.base_config``.

        Each family config is a shallow copy of the base config (via its own
        ``copy()``, so the mapping type is preserved) overlaid with that
        family's hyperparameters. Results go to ``self.experiment_configs``.
        """
        family_overrides = {
            'baseline': {
                'learning_rate': 0.001,
                'weight_decay': 1e-4,
                'hidden_dims': [256, 128, 64],
                'dropout_rate': 0.3,
                'use_batch_norm': True,
                'primary_metric': 'accuracy',
            },
            'chaotic': {
                'learning_rate': 0.0005,  # lower LR chosen for the chaotic systems
                'weight_decay': 1e-5,
                'embedding_dim': 10,
                'mlsa_scales': 5,
                'evolution_time': 0.5,
                'time_step': 0.01,
                'pooling_type': 'comprehensive',
                'speaker_embedding_dim': 128,
                'classifier_type': 'cosine',
                'gradient_clipping': 1.0,
                'primary_metric': 'accuracy',
            },
        }

        self.experiment_configs = {}
        for family, overrides in family_overrides.items():
            derived = self.base_config.copy()
            derived.update(overrides)
            self.experiment_configs[family] = derived
    
    def create_all_experiments(self):
        """Instantiate every experiment of the study and register it in ``self.experiments``.

        Baseline experiments are stored as ``baseline_<name>`` and chaotic
        experiments as ``chaotic_<name>``. The hybrid experiments (traditional
        features + chaotic processing) from ``_create_hybrid_configs`` are
        stored under their config names (``hybrid_<feature>_chaotic``). The
        resulting experiment list is logged.
        """
        # NOTE(review): create_baseline_experiments, create_chaotic_experiments
        # and ChaoticExperiment are not imported in the visible part of this
        # file -- confirm they are defined/imported elsewhere.
        self.logger.info("Creating all comparison experiments...")

        # Create baseline experiments
        baseline_experiments = create_baseline_experiments(self.experiment_configs['baseline'])
        for name, experiment in baseline_experiments.items():
            self.experiments[f'baseline_{name}'] = experiment

        # Create chaotic experiments
        chaotic_experiments = create_chaotic_experiments(self.experiment_configs['chaotic'])
        for name, experiment in chaotic_experiments.items():
            self.experiments[f'chaotic_{name}'] = experiment

        # Hybrid experiments: select on the configured model_type. The config
        # names ('hybrid_mel_chaotic', 'hybrid_mfcc_chaotic') never contain the
        # substring 'traditional_chaotic', so filtering on the name would skip
        # every hybrid experiment.
        for name, config in self._create_hybrid_configs().items():
            if config.get('model_type') == 'traditional_chaotic':
                self.experiments[name] = ChaoticExperiment(config, name)

        self.logger.info(f"Created {len(self.experiments)} experiments:")
        for name in self.experiments:
            self.logger.info(f"  - {name}")
    
    def _create_hybrid_configs(self) -> Dict[str, Dict[str, Any]]:
        """Build configs that feed traditional features into chaotic processing.

        For each of 'mel' and 'mfcc', a copy of the 'chaotic' family config is
        extended with ``model_type='traditional_chaotic'``, the feature type,
        and ``n_mels`` (80 for mel) / ``n_mfcc`` (13 for mfcc). The key that
        does not apply to the feature is set to None.

        Returns:
            Mapping ``'hybrid_<feature>_chaotic'`` -> config.
        """
        chaotic_template = self.experiment_configs['chaotic']
        configs = {}
        for feat in ('mel', 'mfcc'):
            cfg = chaotic_template.copy()
            cfg.update(
                model_type='traditional_chaotic',
                feature_type=feat,
                n_mels=80 if feat == 'mel' else None,
                n_mfcc=13 if feat == 'mfcc' else None,
            )
            configs['hybrid_%s_chaotic' % feat] = cfg
        return configs
    
    def run_single_experiment(
        self,
        experiment_name: str,
        experiment: BaseExperiment,
        run_id: int = 0
    ) -> Optional[ExperimentResult]:
        """Run one setup/train/test cycle of ``experiment`` and package the outcome.

        Args:
            experiment_name: Experiment key (without run suffix). It is also
                used to infer method/feature/model types via
                ``_parse_experiment_name``.
            experiment: Object providing ``setup()`` and ``train(num_epochs)``.
                ``test()`` and ``model`` are optional (see Note).
            run_id: Zero-based repetition index; the result is named
                ``<experiment_name>_run_<run_id>``.

        Returns:
            An ``ExperimentResult``, or None if any step raised. The error and
            its traceback are logged in that case.

        Note:
            If the experiment has no ``test`` method, random accuracy/loss
            values are substituted. If it has no ``model``, a parameter count
            of 1,000,000 is used. These are placeholders, NOT measurements, and
            a warning is logged whenever they are used.
        """
        try:
            self.logger.info(f"Starting {experiment_name} (run {run_id + 1})...")

            # Timing starts before setup(), so training_time includes setup cost.
            # perf_counter is monotonic, unlike time.time().
            start_time = time.perf_counter()

            experiment.setup()
            experiment.train(self.base_config['num_epochs'])

            training_time = time.perf_counter() - start_time

            # Get test results
            if hasattr(experiment, 'test'):
                test_metrics = experiment.test()
            else:
                self.logger.warning(
                    f"{experiment_name} (run {run_id + 1}) has no test() method; "
                    "using RANDOM placeholder accuracy/loss - these are not real results"
                )
                test_metrics = {
                    'accuracy': np.random.uniform(0.85, 0.95),
                    'loss': np.random.uniform(0.1, 0.3)
                }

            # Get model information
            if hasattr(experiment, 'model') and experiment.model is not None:
                model_params = sum(p.numel() for p in experiment.model.parameters())
            else:
                self.logger.warning(
                    f"{experiment_name} (run {run_id + 1}) exposes no model; "
                    "using placeholder parameter count 1000000"
                )
                model_params = 1000000

            method_type, feature_type, model_type = self._parse_experiment_name(experiment_name)

            result = ExperimentResult(
                experiment_name=f"{experiment_name}_run_{run_id}",
                method_type=method_type,
                feature_type=feature_type,
                model_type=model_type,
                test_accuracy=test_metrics.get('accuracy', 0.0),
                test_loss=test_metrics.get('loss', float('inf')),
                training_time=training_time,
                model_parameters=model_params,
                additional_metrics={
                    k: v for k, v in test_metrics.items()
                    if k not in ('accuracy', 'loss')
                }
            )

            self.logger.info(
                f"Completed {experiment_name} (run {run_id + 1}): "
                f"Accuracy: {result.test_accuracy:.4f}, "
                f"Loss: {result.test_loss:.4f}, "
                f"Time: {training_time:.1f}s"
            )

            return result

        except Exception as e:
            # Deliberate best-effort: one failed run must not abort the study.
            # logger.exception also records the traceback for debugging.
            self.logger.exception(f"Failed to run {experiment_name} (run {run_id + 1}): {e}")
            return None
    
    def _parse_experiment_name(self, experiment_name: str) -> Tuple[str, str, str]:
        """Infer ``(method_type, feature_type, model_type)`` from an experiment name.

        Classification is by substring matching, in this order:
          * contains 'baseline'  -> method 'baseline'; feature 'mel'/'mfcc'/'unknown';
                                    model 'mlp'/'cnn'/'unknown'.
          * starts with 'hybrid' -> method 'hybrid'; feature 'mel'/'mfcc'/'unknown';
                                    model 'chaotic_network'.
          * contains 'chaotic'   -> method 'chaotic'. If 'mel'/'mfcc' appears,
                                    that is the feature and the model is
                                    'chaotic_network'. Otherwise the feature is
                                    'chaotic', and the model is 'chaotic_network'
                                    for 'lorenz'/'rossler' names, else 'mlp'.
          * contains 'hybrid'    -> classified like the hybrid case above.
          * otherwise            -> ('unknown', 'unknown', 'unknown').

        Args:
            experiment_name: Experiment key (without the ``_run_<id>`` suffix).

        Returns:
            Tuple ``(method_type, feature_type, model_type)``.
        """
        def traditional_feature() -> str:
            # 'mfcc' does not contain the substring 'mel', so testing 'mel' first is safe.
            if 'mel' in experiment_name:
                return 'mel'
            if 'mfcc' in experiment_name:
                return 'mfcc'
            return 'unknown'

        if 'baseline' in experiment_name:
            if 'mlp' in experiment_name:
                model_type = 'mlp'
            elif 'cnn' in experiment_name:
                model_type = 'cnn'
            else:
                model_type = 'unknown'
            return 'baseline', traditional_feature(), model_type

        # Hybrid configs are named 'hybrid_<feature>_chaotic' (_create_hybrid_configs).
        # They must be recognised before the generic 'chaotic' check below. Otherwise
        # they would be labelled 'chaotic' and pooled into the chaotic-vs-baseline test.
        if experiment_name.startswith('hybrid'):
            return 'hybrid', traditional_feature(), 'chaotic_network'

        if 'chaotic' in experiment_name:
            if 'mel' in experiment_name or 'mfcc' in experiment_name:
                return 'chaotic', traditional_feature(), 'chaotic_network'
            if 'lorenz' in experiment_name or 'rossler' in experiment_name:
                return 'chaotic', 'chaotic', 'chaotic_network'
            return 'chaotic', 'chaotic', 'mlp'

        if 'hybrid' in experiment_name:
            return 'hybrid', traditional_feature(), 'chaotic_network'

        return 'unknown', 'unknown', 'unknown'
    
    def run_all_experiments(self):
        """Execute every registered experiment ``num_runs`` times, then analyse and report.

        Experiments are created on first use. Failed runs (None results) are
        dropped. Successful results are stored in ``self.experiment_results``
        under their run-qualified names, after which ``_analyze_results`` and
        ``_generate_reports`` are invoked.
        """
        self.logger.info("Starting comprehensive comparison experiments...")

        if not self.experiments:
            self.create_all_experiments()

        num_runs = self.base_config['num_runs']
        collected = []

        if not self.parallel_execution:
            self.logger.info("Running experiments sequentially")
            total = len(self.experiments) * num_runs
            done = 0
            for name, exp in self.experiments.items():
                for run_idx in range(num_runs):
                    outcome = self.run_single_experiment(name, exp, run_idx)
                    if outcome:
                        collected.append(outcome)
                    done += 1
                    pct = done / total * 100
                    self.logger.info(f"Progress: {pct:.1f}% ({done}/{total})")
        else:
            self.logger.info(f"Running experiments in parallel with {self.max_workers} workers")
            self._run_experiments_parallel(collected)

        # Index results by their run-qualified experiment name.
        self.experiment_results.update((res.experiment_name, res) for res in collected)

        self.logger.info(f"Completed all experiments. Total results: {len(collected)}")

        self._analyze_results()
        self._generate_reports()
    
    def _run_experiments_parallel(self, all_results: List[ExperimentResult]):
        """Run every (experiment, run) pair on a thread pool.

        Runs of *different* experiments execute concurrently. Runs of the
        *same* experiment object are serialised with a per-experiment lock,
        because every run calls ``setup()``/``train()`` on that one instance
        and concurrent runs would race on its internal state.

        Args:
            all_results: List to which successful ``ExperimentResult`` objects
                are appended (mutated in place, in completion order).
        """
        import threading

        num_runs = self.base_config['num_runs']
        locks = {name: threading.Lock() for name in self.experiments}

        def run_locked(name, experiment, run_id):
            with locks[name]:
                return self.run_single_experiment(name, experiment, run_id)

        with ThreadPoolExecutor(max_workers=self.max_workers) as executor:
            # Submit run-major (run 0 of every experiment, then run 1, ...) so
            # the first jobs the workers pick up belong to different
            # experiments and do not immediately block on the same lock.
            future_to_experiment = {}
            for run_id in range(num_runs):
                for experiment_name, experiment in self.experiments.items():
                    future = executor.submit(run_locked, experiment_name, experiment, run_id)
                    future_to_experiment[future] = (experiment_name, run_id)

            # Collect results
            completed = 0
            total = len(future_to_experiment)

            for future in as_completed(future_to_experiment):
                experiment_name, run_id = future_to_experiment[future]
                try:
                    result = future.result()
                except Exception as e:
                    self.logger.exception(
                        f"Parallel execution failed for {experiment_name} (run {run_id + 1}): {e}"
                    )
                else:
                    if result:
                        all_results.append(result)

                # Count failed futures too, so progress always reaches 100%.
                completed += 1
                progress = (completed / total) * 100
                self.logger.info(f"Progress: {progress:.1f}% ({completed}/{total})")
    
    def _analyze_results(self):
        """Summarise ``self.experiment_results`` and persist the analysis.

        Builds a DataFrame with one row per run (additional metrics become
        extra columns) and computes per-(method, feature, model) summary
        statistics and significance tests. Everything is stored in
        ``self.comparison_results`` and written to
        ``<results_dir>/comparison_results.json`` and
        ``<results_dir>/detailed_results.csv``. If there are no results, a
        warning is logged and the method returns early.
        """
        self.logger.info("Analyzing experimental results...")

        if not self.experiment_results:
            self.logger.warning("No results to analyze")
            return

        # One row per run; extra metrics are spread into their own columns.
        df = pd.DataFrame([
            {
                'experiment_name': result.experiment_name,
                'method_type': result.method_type,
                'feature_type': result.feature_type,
                'model_type': result.model_type,
                'test_accuracy': result.test_accuracy,
                'test_loss': result.test_loss,
                'training_time': result.training_time,
                'model_parameters': result.model_parameters,
                **result.additional_metrics,
            }
            for result in self.experiment_results.values()
        ])

        # Strip the trailing "_run_<id>" suffix so repeated runs of an
        # experiment share one key. Anchored so only the suffix is removed.
        df['base_experiment'] = df['experiment_name'].str.replace(r'_run_\d+$', '', regex=True)

        summary_stats = df.groupby(['method_type', 'feature_type', 'model_type']).agg({
            'test_accuracy': ['mean', 'std', 'min', 'max'],
            'test_loss': ['mean', 'std', 'min', 'max'],
            'training_time': ['mean', 'std'],
            'model_parameters': 'first'  # identical for all runs of a model
        }).round(4)

        statistical_tests = self._perform_statistical_tests(df)

        self.comparison_results = {
            'summary_statistics': summary_stats.to_dict(),
            'statistical_tests': statistical_tests,
            'raw_results': df.to_dict('records'),
            'analysis_timestamp': datetime.now().isoformat(),
            'configuration': {
                'num_runs': self.base_config['num_runs'],
                'num_epochs': self.base_config['num_epochs'],
                'num_speakers': self.base_config['num_speakers']
            }
        }

        # Serialise fully before opening the file, so a serialisation error
        # cannot leave a truncated or empty JSON file behind.
        results_file = os.path.join(self.results_dir, 'comparison_results.json')
        payload = json.dumps(self._convert_for_json(self.comparison_results), indent=2)
        with open(results_file, 'w', encoding='utf-8') as f:
            f.write(payload)

        csv_file = os.path.join(self.results_dir, 'detailed_results.csv')
        df.to_csv(csv_file, index=False)

        self.logger.info("Results analysis completed")
    
    def _perform_statistical_tests(self, df: pd.DataFrame) -> Dict[str, Any]:
        """Run significance tests on per-run test accuracies.

        Args:
            df: Per-run results with at least 'method_type',
                'base_experiment' and 'test_accuracy' columns.

        Returns:
            A dict with these entries:
              * 'chaotic_vs_baseline' (only if both families have results):
                an independent two-sample t-test (scipy default, equal
                variances assumed) of all chaotic vs all baseline accuracies,
                plus both means and 'effect_size' (the raw difference of the
                means, not a standardised effect size).
              * 'pairwise_comparisons': t-tests between every pair of base
                experiments that each have more than one run.
              * 'anova' (only with more than two base experiments, and only
                if scipy succeeds): a one-way ANOVA across base experiments.
            The 'significant' flags use alpha = 0.05. They are plain Python
            bools, so the result is JSON-serialisable.
        """
        alpha = 0.05
        statistical_tests = {}

        chaotic_results = df[df['method_type'] == 'chaotic']['test_accuracy'].values
        baseline_results = df[df['method_type'] == 'baseline']['test_accuracy'].values

        if len(chaotic_results) > 0 and len(baseline_results) > 0:
            t_stat, p_value = stats.ttest_ind(chaotic_results, baseline_results)
            statistical_tests['chaotic_vs_baseline'] = {
                't_statistic': float(t_stat),
                'p_value': float(p_value),
                # bool(): comparing a NumPy float yields np.bool_, which the
                # json module cannot serialise.
                'significant': bool(p_value < alpha),
                'chaotic_mean': float(np.mean(chaotic_results)),
                'baseline_mean': float(np.mean(baseline_results)),
                'effect_size': float(np.mean(chaotic_results) - np.mean(baseline_results))
            }

        # Pairwise comparisons between all base experiments
        method_groups = df.groupby('base_experiment')['test_accuracy'].apply(list).to_dict()

        pairwise_tests = {}
        methods = list(method_groups.keys())

        for i, method1 in enumerate(methods):
            for method2 in methods[i + 1:]:
                group1 = method_groups[method1]
                group2 = method_groups[method2]
                # A t-test needs at least two observations per group.
                if len(group1) > 1 and len(group2) > 1:
                    t_stat, p_value = stats.ttest_ind(group1, group2)
                    pairwise_tests[f'{method1}_vs_{method2}'] = {
                        't_statistic': float(t_stat),
                        'p_value': float(p_value),
                        'significant': bool(p_value < alpha),
                        'mean_diff': float(np.mean(group1) - np.mean(group2))
                    }

        statistical_tests['pairwise_comparisons'] = pairwise_tests

        # ANOVA across all base experiments (with two groups it would just
        # duplicate the t-test).
        if len(method_groups) > 2:
            try:
                f_stat, p_value = stats.f_oneway(*method_groups.values())
                statistical_tests['anova'] = {
                    'f_statistic': float(f_stat),
                    'p_value': float(p_value),
                    'significant': bool(p_value < alpha)
                }
            except Exception as e:
                self.logger.warning(f"ANOVA test failed: {e}")

        return statistical_tests
    
    def _convert_for_json(self, obj: Any) -> Any:
        """Recursively convert ``obj`` into a structure ``json.dump`` can serialise.

        Conversions applied:
          * NumPy booleans/integers/floats become ``bool``/``int``/``float``
            (``np.bool_`` is what comparisons such as ``p_value < 0.05`` yield).
          * NumPy arrays become nested lists.
          * Lists and tuples become lists of converted items.
          * Dicts get converted values and JSON-compatible keys (see
            ``_json_key``). This matters for e.g. ``DataFrame.to_dict()`` on a
            MultiIndex result, whose keys are tuples.
        Anything else is returned unchanged.
        """
        if isinstance(obj, np.bool_):
            return bool(obj)
        if isinstance(obj, np.integer):
            return int(obj)
        if isinstance(obj, np.floating):
            return float(obj)
        if isinstance(obj, np.ndarray):
            return obj.tolist()
        if isinstance(obj, dict):
            return {
                self._json_key(key): self._convert_for_json(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [self._convert_for_json(item) for item in obj]
        return obj

    @staticmethod
    def _json_key(key: Any) -> Any:
        """Map a dict key to a type JSON allows as an object key.

        JSON keys must be str/int/float/bool/None. Tuple keys are joined with
        '|' (e.g. ``('test_accuracy', 'mean')`` -> ``'test_accuracy|mean'``),
        NumPy scalar keys are unwrapped to Python scalars, and any other key is
        converted with ``str``.
        """
        if isinstance(key, tuple):
            return '|'.join(str(part) for part in key)
        if isinstance(key, np.generic):
            key = key.item()
        if key is None or isinstance(key, (str, int, float, bool)):
            return key
        return str(key)
    
    def _generate_reports(self):
        """Produce the text report, the plots and the LaTeX table from the analysis.

        Only logs a warning if ``self.comparison_results`` is empty or unset.
        """
        self.logger.info("Generating comparison reports...")

        if not self.comparison_results:
            self.logger.warning("No results available for report generation")
            return

        report_steps = (
            self._generate_text_report,     # plain-text summary
            self._generate_visualizations,  # PNG figures
            self._generate_latex_table,     # LaTeX table for the paper
        )
        for step in report_steps:
            step()

        self.logger.info("Report generation completed")
    
    def _generate_text_report(self):
        """Write a human-readable summary to ``<reports_dir>/comparison_report.txt``.

        Reads ``self.comparison_results`` (populated by ``_analyze_results``)
        and writes these sections:
          * the experimental configuration;
          * mean ± std accuracy and loss, mean training time and parameter
            count for each (method, feature, model) group;
          * the chaotic-vs-baseline t-test and ANOVA results, when present;
          * a fixed list of key-finding headings.
        """
        report_file = os.path.join(self.reports_dir, 'comparison_report.txt')

        # Explicit UTF-8: the report contains '±', which the platform's default
        # encoding may be unable to represent.
        with open(report_file, 'w', encoding='utf-8') as f:
            f.write("Speaker Recognition Comparison Report\n")
            f.write(f"Generated: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
            f.write("=" * 50 + "\n\n")

            # Configuration summary
            f.write("EXPERIMENTAL CONFIGURATION\n")
            f.write("-" * 25 + "\n")
            config = self.comparison_results['configuration']
            f.write(f"Number of runs per method: {config['num_runs']}\n")
            f.write(f"Training epochs: {config['num_epochs']}\n")
            f.write(f"Number of speakers: {config['num_speakers']}\n\n")

            # Summary statistics
            f.write("SUMMARY RESULTS\n")
            f.write("-" * 15 + "\n")

            if 'raw_results' in self.comparison_results:
                df = pd.DataFrame(self.comparison_results['raw_results'])
                summary = df.groupby(['method_type', 'feature_type', 'model_type']).agg({
                    'test_accuracy': ['mean', 'std'],
                    'test_loss': ['mean', 'std'],
                    'training_time': 'mean',
                    'model_parameters': 'first'
                }).round(4)

                for (method, feature, model), row in summary.iterrows():
                    # iterrows() yields each row as a single float64 Series, so
                    # convert the parameter count back to int; otherwise it is
                    # printed as e.g. "1,000,000.0".
                    n_params = int(row[('model_parameters', 'first')])
                    f.write(f"{method.upper()} - {feature.upper()} + {model.upper()}:\n")
                    f.write(f"  Accuracy: {row[('test_accuracy', 'mean')]:.4f} ± {row[('test_accuracy', 'std')]:.4f}\n")
                    f.write(f"  Loss: {row[('test_loss', 'mean')]:.4f} ± {row[('test_loss', 'std')]:.4f}\n")
                    f.write(f"  Training Time: {row[('training_time', 'mean')]:.1f}s\n")
                    f.write(f"  Parameters: {n_params:,}\n\n")

            # Statistical significance tests
            f.write("STATISTICAL ANALYSIS\n")
            f.write("-" * 19 + "\n")

            if 'statistical_tests' in self.comparison_results:
                stats_tests = self.comparison_results['statistical_tests']

                if 'chaotic_vs_baseline' in stats_tests:
                    test = stats_tests['chaotic_vs_baseline']
                    f.write("Chaotic vs Baseline Methods:\n")
                    f.write(f"  Chaotic mean accuracy: {test['chaotic_mean']:.4f}\n")
                    f.write(f"  Baseline mean accuracy: {test['baseline_mean']:.4f}\n")
                    f.write(f"  Effect size: {test['effect_size']:.4f}\n")
                    f.write(f"  p-value: {test['p_value']:.6f}\n")
                    f.write(f"  Statistically significant: {'Yes' if test['significant'] else 'No'}\n\n")

                if 'anova' in stats_tests:
                    anova = stats_tests['anova']
                    f.write("ANOVA Test (Overall Comparison):\n")
                    f.write(f"  F-statistic: {anova['f_statistic']:.4f}\n")
                    f.write(f"  p-value: {anova['p_value']:.6f}\n")
                    f.write(f"  Significant differences exist: {'Yes' if anova['significant'] else 'No'}\n\n")

            # Conclusions (fixed headings)
            f.write("KEY FINDINGS\n")
            f.write("-" * 12 + "\n")
            f.write("1. Performance comparison between traditional and chaotic methods\n")
            f.write("2. Statistical significance of improvements\n")
            f.write("3. Computational efficiency analysis\n")
            f.write("4. Model complexity comparison\n\n")

            f.write("See visualization plots and detailed results for complete analysis.\n")
    
    def _generate_visualizations(self):
        """Render the comparison plots into ``self.plots_dir`` (PNG, 300 dpi).

        Files written:
          1. accuracy_comparison.png: accuracy box plot by method, hued by feature.
          2. efficiency_vs_performance.png: training time vs accuracy, one
             scatter series per method.
          3. model_complexity.png: parameter count (log scale) and mean
             accuracy bar charts.
          4. performance_heatmap.png: mean accuracy, method x feature type.

        Missing matplotlib/seaborn is logged as a warning, and any other
        failure is logged as an error; neither is re-raised. Figures created
        here are always closed, even when plotting fails part-way.
        """
        try:
            import matplotlib.pyplot as plt
            import seaborn as sns
        except ImportError:
            self.logger.warning("Matplotlib/Seaborn not available for visualization")
            return

        # Remember pre-existing figures so cleanup only closes the ones created
        # here (a notebook session may have its own figures open).
        preexisting_figures = set(plt.get_fignums())
        try:
            if 'raw_results' not in self.comparison_results:
                self.logger.warning("No raw results available for visualization")
                return

            df = pd.DataFrame(self.comparison_results['raw_results'])

            # Style name depends on the matplotlib version: 'seaborn-v0_8' was
            # introduced in 3.6, and the old 'seaborn' name was later removed.
            # Use whichever exists rather than aborting all plots on an unknown style.
            for style_name in ('seaborn-v0_8', 'seaborn'):
                if style_name in plt.style.available:
                    plt.style.use(style_name)
                    break
            sns.set_palette("husl")

            # 1. Accuracy comparison box plot
            plt.figure(figsize=(12, 8))
            sns.boxplot(data=df, x='method_type', y='test_accuracy', hue='feature_type')
            plt.title('Test Accuracy Comparison Across Methods', fontsize=16)
            plt.xlabel('Method Type', fontsize=14)
            plt.ylabel('Test Accuracy', fontsize=14)
            plt.legend(title='Feature Type', fontsize=12)
            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, 'accuracy_comparison.png'), dpi=300)
            plt.close()

            # 2. Training time vs accuracy scatter
            plt.figure(figsize=(10, 8))
            for method in df['method_type'].unique():
                method_data = df[df['method_type'] == method]
                plt.scatter(method_data['training_time'], method_data['test_accuracy'],
                            label=method.title(), alpha=0.7, s=100)

            plt.xlabel('Training Time (seconds)', fontsize=14)
            plt.ylabel('Test Accuracy', fontsize=14)
            plt.title('Training Efficiency vs Performance', fontsize=16)
            plt.legend(fontsize=12)
            plt.grid(True, alpha=0.3)
            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, 'efficiency_vs_performance.png'), dpi=300)
            plt.close()

            # 3. Model complexity comparison
            plt.figure(figsize=(12, 6))
            complexity_data = df.groupby(['method_type', 'feature_type', 'model_type']).agg({
                'model_parameters': 'first',
                'test_accuracy': 'mean'
            }).reset_index()

            plt.subplot(1, 2, 1)
            sns.barplot(data=complexity_data, x='method_type', y='model_parameters', hue='feature_type')
            plt.title('Model Parameters by Method')
            plt.ylabel('Number of Parameters')
            plt.yscale('log')

            plt.subplot(1, 2, 2)
            sns.barplot(data=complexity_data, x='method_type', y='test_accuracy', hue='feature_type')
            plt.title('Average Accuracy by Method')
            plt.ylabel('Test Accuracy')

            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, 'model_complexity.png'), dpi=300)
            plt.close()

            # 4. Detailed performance heatmap
            pivot_accuracy = df.pivot_table(
                values='test_accuracy',
                index='method_type',
                columns='feature_type',
                aggfunc='mean'
            )

            plt.figure(figsize=(8, 6))
            sns.heatmap(pivot_accuracy, annot=True, fmt='.4f', cmap='YlOrRd',
                        cbar_kws={'label': 'Test Accuracy'})
            plt.title('Performance Heatmap: Method vs Feature Type', fontsize=16)
            plt.tight_layout()
            plt.savefig(os.path.join(self.plots_dir, 'performance_heatmap.png'), dpi=300)
            plt.close()

            self.logger.info("Visualizations generated successfully")

        except Exception as e:
            self.logger.error(f"Visualization generation failed: {e}")
        finally:
            # Close any figure left open by a failure part-way through plotting.
            for fig_num in set(plt.get_fignums()) - preexisting_figures:
                plt.close(fig_num)
    
    def _generate_latex_table(self):
        """Write a booktabs-style LaTeX table summarising the comparison runs.

        Rows of ``self.comparison_results['raw_results']`` are grouped by
        ``(method_type, feature_type, model_type)``. Each table row shows the
        mean +/- sample std of test accuracy (in percent) and of test loss, plus
        the first recorded parameter count in thousands. The table is written
        to ``<reports_dir>/results_table.tex`` and uses ``\\toprule`` etc., so
        the including document needs ``\\usepackage{booktabs}``.

        Nothing is written when there are no raw results. Any error during
        generation is logged via ``self.logger.error`` rather than raised.

        Labels are LaTeX-escaped (e.g. ``mfcc_delta`` -> ``MFCC\\_DELTA``) so the
        table compiles. When a group contains a single run, pandas reports its
        std as NaN; the ``$\\pm$ std`` part is then omitted instead of printing
        ``nan``.
        """
        latex_file = os.path.join(self.reports_dir, 'results_table.tex')

        # Single-pass translation of LaTeX text-mode special characters.
        # Unescaped underscores (common in feature names) break compilation.
        latex_escapes = str.maketrans({
            '\\': r'\textbackslash{}',
            '&': r'\&',
            '%': r'\%',
            '$': r'\$',
            '#': r'\#',
            '_': r'\_',
            '{': r'\{',
            '}': r'\}',
            '~': r'\textasciitilde{}',
            '^': r'\textasciicircum{}',
        })

        def fmt_mean_std(mean, std, spec):
            """Render 'mean $\\pm$ std', or only the mean when std is NaN."""
            if pd.isna(std):
                return f"{mean:{spec}}"
            return f"{mean:{spec}} $\\pm$ {std:{spec}}"

        try:
            if not self.comparison_results or 'raw_results' not in self.comparison_results:
                return

            df = pd.DataFrame(self.comparison_results['raw_results'])

            # Aggregate repeated runs of each configuration. pandas' std uses
            # ddof=1, so single-run groups yield NaN (handled by fmt_mean_std).
            summary = df.groupby(['method_type', 'feature_type', 'model_type']).agg({
                'test_accuracy': ['mean', 'std'],
                'test_loss': ['mean', 'std'],
                'model_parameters': 'first'
            }).round(4)

            with open(latex_file, 'w', encoding='utf-8') as f:
                f.write("\\begin{table}[htbp]\n")
                f.write("\\centering\n")
                f.write("\\caption{Comparison of Speaker Recognition Methods}\n")
                f.write("\\label{tab:comparison}\n")
                f.write("\\begin{tabular}{llccc}\n")
                f.write("\\toprule\n")
                f.write("Method & Features & Accuracy (\\%) & Loss & Parameters \\\\\n")
                f.write("\\midrule\n")

                for (method, feature, model), row in summary.iterrows():
                    # Case-transform first, then escape, so escape macros are
                    # not themselves upper-cased.
                    method_name = str(method).title().translate(latex_escapes)
                    feature_name = str(feature).upper()
                    if model != 'unknown':
                        feature_name += f"+{str(model).upper()}"
                    feature_name = feature_name.translate(latex_escapes)

                    accuracy = fmt_mean_std(row[('test_accuracy', 'mean')] * 100,
                                            row[('test_accuracy', 'std')] * 100, '.2f')
                    loss = fmt_mean_std(row[('test_loss', 'mean')],
                                        row[('test_loss', 'std')], '.3f')
                    params = row[('model_parameters', 'first')]
                    params_text = '--' if pd.isna(params) else f"{params / 1000:.0f}K"

                    f.write(f"{method_name} & {feature_name} & {accuracy} & "
                            f"{loss} & {params_text} \\\\\n")

                f.write("\\bottomrule\n")
                f.write("\\end{tabular}\n")
                f.write("\\end{table}\n")

            self.logger.info("LaTeX table generated successfully")

        except Exception as e:
            self.logger.error(f"LaTeX table generation failed: {e}")
    
    def get_best_methods(self, metric: str = 'test_accuracy', top_k: int = 3,
                         higher_is_better: Optional[bool] = None) -> List[Dict[str, Any]]:
        """Rank experiment configurations by their mean value of ``metric``.

        Runs of the same configuration are merged by taking the part of
        ``experiment_name`` before the first ``'_run_'``.

        Args:
            metric: Name of a numeric attribute on each result object, e.g.
                ``'test_accuracy'`` or ``'test_loss'``. Results lacking the
                attribute (or holding ``None``) are skipped.
            top_k: Maximum number of configurations to return.
            higher_is_better: Ranking direction. When ``None`` it is inferred:
                metrics whose name contains "loss" or "error" are ranked
                ascending, all others descending.

        Returns:
            Up to ``top_k`` dicts with keys ``'method'`` (configuration name),
            ``'score'`` (mean over runs) and ``'std'`` (population std over
            runs, ``np.std`` with ddof=0). Empty if there are no results.
        """
        if not self.experiment_results:
            return []

        if higher_is_better is None:
            lowered = metric.lower()
            higher_is_better = not ('loss' in lowered or 'error' in lowered)

        # Collect per-run values grouped by configuration name.
        method_performance: Dict[str, List[float]] = {}
        for result in self.experiment_results.values():
            # Everything before the first '_run_' identifies the configuration
            # shared by all of its repeated runs.
            base_name = result.experiment_name.split('_run_')[0]
            value = getattr(result, metric, None)
            if value is None:
                continue
            method_performance.setdefault(base_name, []).append(value)

        method_means = {name: np.mean(values) for name, values in method_performance.items()}
        ranked = sorted(method_means.items(), key=lambda item: item[1], reverse=higher_is_better)

        return [
            {'method': name, 'score': score, 'std': np.std(method_performance[name])}
            for name, score in ranked[:top_k]
        ]
    
    def save_comparison_summary(self):
        """Persist a compact JSON digest of this comparison.

        The digest holds the comparison name, a timestamp, experiment and run
        counts, the base configuration, and the top-5 methods by accuracy and
        by loss. It also includes the statistical test results when
        ``comparison_results`` contains them. It is written (after conversion
        by ``self._convert_for_json``) to
        ``<comparison_dir>/comparison_summary.json``.
        """
        target_path = os.path.join(self.comparison_dir, 'comparison_summary.json')

        payload = {
            'comparison_name': self.comparison_name,
            'timestamp': datetime.now().isoformat(),
            'total_experiments': len(self.experiments),
            'total_runs': len(self.experiment_results),
            'configuration': self.base_config,
        }
        payload['best_methods'] = {
            'by_accuracy': self.get_best_methods('test_accuracy', 5),
            'by_loss': self.get_best_methods('test_loss', 5),
        }

        # comparison_results may be empty/None before the analysis step runs.
        stats_source = self.comparison_results
        if stats_source and 'statistical_tests' in stats_source:
            payload['statistical_significance'] = stats_source['statistical_tests']

        with open(target_path, 'w') as handle:
            json.dump(self._convert_for_json(payload), handle, indent=2)

        self.logger.info(f"Comparison summary saved to: {target_path}")


if __name__ == "__main__":
    # Report how the import bootstrap resolved before doing any real work.
    print(f"✓ Project Root: {PROJECT_ROOT}")
    print(f"✓ Import Manager: {USING_IMPORT_MANAGER}")
    print("✓ Module imports successful")

    # Smoke-test configuration: small speaker set, few epochs, two runs each.
    demo_config = {
        'num_speakers': 10,
        'batch_size': 16,
        'data_dir': './data/test_speaker_data',
        'num_epochs': 5,
        'num_runs': 2,
        'sample_rate': 16000,
        'max_audio_length': 2.0,
    }

    print("Testing ComparisonExperiment...")

    runner = ComparisonExperiment(
        base_config=demo_config,
        comparison_name='test_speaker_comparison',
        output_dir='./test_comparisons',
        parallel_execution=False,  # run experiments one after another for the test
    )

    print("Creating all experiments...")
    runner.create_all_experiments()
    print(f"Created {len(runner.experiments)} experiments")

    print("Running comparison (this may take a while)...")
    runner.run_all_experiments()

    print("Getting best methods...")
    top_methods = runner.get_best_methods('test_accuracy', 3)
    print("Top 3 methods by accuracy:")
    for rank, entry in enumerate(top_methods, start=1):
        print(f"  {rank}. {entry['method']}: {entry['score']:.4f} ± {entry['std']:.4f}")

    print("Saving comparison summary...")
    runner.save_comparison_summary()

    print("Comparison experiment test completed!")
    print(f"Results saved to: {runner.comparison_dir}")