From c2f668af4daa09ea284eba170dfce93698bf4cf5 Mon Sep 17 00:00:00 2001
From: Asankhaya Sharma
Date: Mon, 21 Jul 2025 22:21:43 +0800
Subject: [PATCH] Switch to process-based parallelism and remove threading

Replaces the previous thread-based parallel controller with a new
process-based parallel controller for true parallel execution, updating all
references and tests accordingly. Removes the threaded_parallel.py module,
adds process_parallel.py, updates controller logic and shutdown handling,
and introduces new tests for process-based parallelism. Also updates the
README to reflect the change and bumps the package version to 0.1.0.
---
 README.md                       |   2 +-
 openevolve/_version.py          |   2 +-
 openevolve/controller.py        |   6 +-
 openevolve/process_parallel.py  | 539 ++++++++++++++++++++++++++++++++
 openevolve/threaded_parallel.py | 353 ---------------------
 tests/test_checkpoint_resume.py |  24 +-
 tests/test_process_parallel.py  | 190 +++++++++++
 7 files changed, 746 insertions(+), 370 deletions(-)
 create mode 100644 openevolve/process_parallel.py
 delete mode 100644 openevolve/threaded_parallel.py
 create mode 100644 tests/test_process_parallel.py

diff --git a/README.md b/README.md
index 663623b78..dd00115f1 100644
--- a/README.md
+++ b/README.md
@@ -56,7 +56,7 @@ OpenEvolve implements a comprehensive evolutionary coding system with:
 - **Error Recovery**: Robust checkpoint loading with automatic fix for common serialization issues
 
 #### 🚀 **Performance & Scalability**
-- **Threaded Parallelism**: High-throughput asynchronous evaluation pipeline
+- **Process-Based Parallelism**: True parallel execution bypassing Python's GIL for CPU-bound tasks
 - **Resource Management**: Memory limits, timeouts, and resource monitoring
 - **Efficient Storage**: Optimized database with artifact management and cleanup policies
 
diff --git a/openevolve/_version.py b/openevolve/_version.py
index 73b505cfd..57fc3ceb2 100644
--- a/openevolve/_version.py
+++ b/openevolve/_version.py
@@ -1,3 +1,3 @@
 """Version information for openevolve package."""
 
-__version__ = "0.0.20"
+__version__ = "0.1.0"
diff --git a/openevolve/controller.py b/openevolve/controller.py
index 0c0ab4d29..260c3a227 100644
--- a/openevolve/controller.py
+++ b/openevolve/controller.py
@@ -16,7 +16,7 @@
 from openevolve.evaluator import Evaluator
 from openevolve.llm.ensemble import LLMEnsemble
 from openevolve.prompt.sampler import PromptSampler
-from openevolve.threaded_parallel import ImprovedParallelController
+from openevolve.process_parallel import ProcessParallelController
 from openevolve.utils.code_utils import (
     extract_code_language,
 )
@@ -257,7 +257,7 @@ async def run(
 
         # Initialize improved parallel processing
         try:
-            self.parallel_controller = ImprovedParallelController(
+            self.parallel_controller = ProcessParallelController(
                 self.config, self.evaluation_file, self.database
             )
 
@@ -439,7 +439,7 @@ async def _run_evolution_with_checkpoints(
             )
 
             # Check if shutdown was requested
-            if self.parallel_controller.shutdown_flag.is_set():
+            if self.parallel_controller.shutdown_event.is_set():
                 logger.info("Evolution stopped due to shutdown request")
                 return
 
diff --git a/openevolve/process_parallel.py b/openevolve/process_parallel.py
new file mode 100644
index 000000000..285340bdd
--- /dev/null
+++ b/openevolve/process_parallel.py
@@ -0,0 +1,539 @@
+"""
+Process-based parallel controller for true parallelism
+"""
+
+import asyncio
+import logging
+import multiprocessing as mp
+import pickle
+import signal
+import time
+from concurrent.futures import ProcessPoolExecutor, Future
+from dataclasses import dataclass, asdict +from pathlib import Path +from typing import Any, Dict, List, Optional, Tuple + +from openevolve.config import Config +from openevolve.database import Program, ProgramDatabase + +logger = logging.getLogger(__name__) + + +@dataclass +class SerializableResult: + """Result that can be pickled and sent between processes""" + child_program_dict: Optional[Dict[str, Any]] = None + parent_id: Optional[str] = None + iteration_time: float = 0.0 + prompt: Optional[Dict[str, str]] = None + llm_response: Optional[str] = None + artifacts: Optional[Dict[str, Any]] = None + iteration: int = 0 + error: Optional[str] = None + + +def _worker_init(config_dict: dict, evaluation_file: str) -> None: + """Initialize worker process with necessary components""" + global _worker_config + global _worker_evaluation_file + global _worker_evaluator + global _worker_llm_ensemble + global _worker_prompt_sampler + + # Store config for later use + # Reconstruct Config object from nested dictionaries + from openevolve.config import Config, DatabaseConfig, EvaluatorConfig, LLMConfig, PromptConfig, LLMModelConfig + + # Reconstruct model objects + models = [LLMModelConfig(**m) for m in config_dict['llm']['models']] + evaluator_models = [LLMModelConfig(**m) for m in config_dict['llm']['evaluator_models']] + + # Create LLM config with models + llm_dict = config_dict['llm'].copy() + llm_dict['models'] = models + llm_dict['evaluator_models'] = evaluator_models + llm_config = LLMConfig(**llm_dict) + + # Create other configs + prompt_config = PromptConfig(**config_dict['prompt']) + database_config = DatabaseConfig(**config_dict['database']) + evaluator_config = EvaluatorConfig(**config_dict['evaluator']) + + _worker_config = Config( + llm=llm_config, + prompt=prompt_config, + database=database_config, + evaluator=evaluator_config, + **{k: v for k, v in config_dict.items() + if k not in ['llm', 'prompt', 'database', 'evaluator']} + ) + _worker_evaluation_file = evaluation_file + + # These will be lazily initialized on first use + _worker_evaluator = None + _worker_llm_ensemble = None + _worker_prompt_sampler = None + + +def _lazy_init_worker_components(): + """Lazily initialize expensive components on first use""" + global _worker_evaluator + global _worker_llm_ensemble + global _worker_prompt_sampler + + if _worker_llm_ensemble is None: + from openevolve.llm.ensemble import LLMEnsemble + _worker_llm_ensemble = LLMEnsemble(_worker_config.llm.models) + + if _worker_prompt_sampler is None: + from openevolve.prompt.sampler import PromptSampler + _worker_prompt_sampler = PromptSampler(_worker_config.prompt) + + if _worker_evaluator is None: + from openevolve.evaluator import Evaluator + from openevolve.llm.ensemble import LLMEnsemble + from openevolve.prompt.sampler import PromptSampler + + # Create evaluator-specific components + evaluator_llm = LLMEnsemble(_worker_config.llm.evaluator_models) + evaluator_prompt = PromptSampler(_worker_config.prompt) + evaluator_prompt.set_templates("evaluator_system_message") + + _worker_evaluator = Evaluator( + _worker_config.evaluator, + _worker_evaluation_file, + evaluator_llm, + evaluator_prompt, + database=None # No shared database in worker + ) + + +def _run_iteration_worker( + iteration: int, + db_snapshot: Dict[str, Any], + parent_id: str, + inspiration_ids: List[str] +) -> SerializableResult: + """Run a single iteration in a worker process""" + try: + # Lazy initialization + _lazy_init_worker_components() + + # Reconstruct programs from snapshot + 
programs = { + pid: Program(**prog_dict) + for pid, prog_dict in db_snapshot["programs"].items() + } + + parent = programs[parent_id] + inspirations = [programs[pid] for pid in inspiration_ids if pid in programs] + + # Get parent artifacts if available + parent_artifacts = db_snapshot["artifacts"].get(parent_id) + + # Get island-specific programs for context + parent_island = parent.metadata.get("island", db_snapshot["current_island"]) + island_programs = [ + programs[pid] for pid in db_snapshot["islands"][parent_island] + if pid in programs + ] + + # Sort by metrics for top programs + from openevolve.utils.metrics_utils import safe_numeric_average + island_programs.sort( + key=lambda p: p.metrics.get("combined_score", safe_numeric_average(p.metrics)), + reverse=True + ) + + island_top_programs = island_programs[:5] + island_previous_programs = island_programs[:3] + + # Build prompt + prompt = _worker_prompt_sampler.build_prompt( + current_program=parent.code, + parent_program=parent.code, + program_metrics=parent.metrics, + previous_programs=[p.to_dict() for p in island_previous_programs], + top_programs=[p.to_dict() for p in island_top_programs], + inspirations=[p.to_dict() for p in inspirations], + language=_worker_config.language, + evolution_round=iteration, + diff_based_evolution=_worker_config.diff_based_evolution, + program_artifacts=parent_artifacts, + ) + + iteration_start = time.time() + + # Generate code modification (sync wrapper for async) + llm_response = asyncio.run( + _worker_llm_ensemble.generate_with_context( + system_message=prompt["system"], + messages=[{"role": "user", "content": prompt["user"]}], + ) + ) + + # Parse response based on evolution mode + if _worker_config.diff_based_evolution: + from openevolve.utils.code_utils import extract_diffs, apply_diff, format_diff_summary + + diff_blocks = extract_diffs(llm_response) + if not diff_blocks: + return SerializableResult( + error=f"No valid diffs found in response", + iteration=iteration + ) + + child_code = apply_diff(parent.code, llm_response) + changes_summary = format_diff_summary(diff_blocks) + else: + from openevolve.utils.code_utils import parse_full_rewrite + + new_code = parse_full_rewrite(llm_response, _worker_config.language) + if not new_code: + return SerializableResult( + error=f"No valid code found in response", + iteration=iteration + ) + + child_code = new_code + changes_summary = "Full rewrite" + + # Check code length + if len(child_code) > _worker_config.max_code_length: + return SerializableResult( + error=f"Generated code exceeds maximum length ({len(child_code)} > {_worker_config.max_code_length})", + iteration=iteration + ) + + # Evaluate the child program + import uuid + child_id = str(uuid.uuid4()) + child_metrics = asyncio.run( + _worker_evaluator.evaluate_program(child_code, child_id) + ) + + # Get artifacts + artifacts = _worker_evaluator.get_pending_artifacts(child_id) + + # Create child program + child_program = Program( + id=child_id, + code=child_code, + language=_worker_config.language, + parent_id=parent.id, + generation=parent.generation + 1, + metrics=child_metrics, + iteration_found=iteration, + metadata={ + "changes": changes_summary, + "parent_metrics": parent.metrics, + "island": parent_island, + } + ) + + iteration_time = time.time() - iteration_start + + return SerializableResult( + child_program_dict=child_program.to_dict(), + parent_id=parent.id, + iteration_time=iteration_time, + prompt=prompt, + llm_response=llm_response, + artifacts=artifacts, + iteration=iteration + ) + 
+ except Exception as e: + logger.exception(f"Error in worker iteration {iteration}") + return SerializableResult( + error=str(e), + iteration=iteration + ) + + +class ProcessParallelController: + """Controller for process-based parallel evolution""" + + def __init__(self, config: Config, evaluation_file: str, database: ProgramDatabase): + self.config = config + self.evaluation_file = evaluation_file + self.database = database + + self.executor: Optional[ProcessPoolExecutor] = None + self.shutdown_event = mp.Event() + + # Number of worker processes + self.num_workers = config.evaluator.parallel_evaluations + + logger.info(f"Initialized process parallel controller with {self.num_workers} workers") + + def _serialize_config(self, config: Config) -> dict: + """Serialize config object to a dictionary that can be pickled""" + # Manual serialization to handle nested objects properly + return { + 'llm': { + 'models': [asdict(m) for m in config.llm.models], + 'evaluator_models': [asdict(m) for m in config.llm.evaluator_models], + 'api_base': config.llm.api_base, + 'api_key': config.llm.api_key, + 'temperature': config.llm.temperature, + 'top_p': config.llm.top_p, + 'max_tokens': config.llm.max_tokens, + 'timeout': config.llm.timeout, + 'retries': config.llm.retries, + 'retry_delay': config.llm.retry_delay, + }, + 'prompt': asdict(config.prompt), + 'database': asdict(config.database), + 'evaluator': asdict(config.evaluator), + 'max_iterations': config.max_iterations, + 'checkpoint_interval': config.checkpoint_interval, + 'log_level': config.log_level, + 'log_dir': config.log_dir, + 'random_seed': config.random_seed, + 'diff_based_evolution': config.diff_based_evolution, + 'max_code_length': config.max_code_length, + 'language': config.language, + } + + def start(self) -> None: + """Start the process pool""" + # Convert config to dict for pickling + # We need to be careful with nested dataclasses + config_dict = self._serialize_config(self.config) + + # Create process pool with initializer + self.executor = ProcessPoolExecutor( + max_workers=self.num_workers, + initializer=_worker_init, + initargs=(config_dict, self.evaluation_file) + ) + + logger.info(f"Started process pool with {self.num_workers} processes") + + def stop(self) -> None: + """Stop the process pool""" + self.shutdown_event.set() + + if self.executor: + self.executor.shutdown(wait=True) + self.executor = None + + logger.info("Stopped process pool") + + def request_shutdown(self) -> None: + """Request graceful shutdown""" + logger.info("Graceful shutdown requested...") + self.shutdown_event.set() + + def _create_database_snapshot(self) -> Dict[str, Any]: + """Create a serializable snapshot of the database state""" + # Only include necessary data for workers + snapshot = { + "programs": { + pid: prog.to_dict() + for pid, prog in self.database.programs.items() + }, + "islands": [ + list(island) for island in self.database.islands + ], + "current_island": self.database.current_island, + "artifacts": {}, # Will be populated selectively + } + + # Include artifacts for programs that might be selected + # (limit to reduce serialization overhead) + for pid in list(self.database.programs.keys())[:100]: + artifacts = self.database.get_artifacts(pid) + if artifacts: + snapshot["artifacts"][pid] = artifacts + + return snapshot + + async def run_evolution( + self, + start_iteration: int, + max_iterations: int, + target_score: Optional[float] = None, + checkpoint_callback=None, + ): + """Run evolution with process-based parallelism""" + if not 
self.executor: + raise RuntimeError("Process pool not started") + + total_iterations = start_iteration + max_iterations + + logger.info( + f"Starting process-based evolution from iteration {start_iteration} " + f"for {max_iterations} iterations (total: {total_iterations})" + ) + + # Track pending futures + pending_futures: Dict[int, Future] = {} + batch_size = min(self.num_workers * 2, max_iterations) + + # Submit initial batch + for i in range(start_iteration, min(start_iteration + batch_size, total_iterations)): + future = self._submit_iteration(i) + if future: + pending_futures[i] = future + + next_iteration = start_iteration + batch_size + completed_iterations = 0 + + # Island management + programs_per_island = max(1, max_iterations // (self.config.database.num_islands * 10)) + current_island_counter = 0 + + # Process results as they complete + while ( + pending_futures + and completed_iterations < max_iterations + and not self.shutdown_event.is_set() + ): + # Find completed futures + completed_iteration = None + for iteration, future in list(pending_futures.items()): + if future.done(): + completed_iteration = iteration + break + + if completed_iteration is None: + await asyncio.sleep(0.01) + continue + + # Process completed result + future = pending_futures.pop(completed_iteration) + + try: + result = future.result() + + if result.error: + logger.warning(f"Iteration {completed_iteration} error: {result.error}") + elif result.child_program_dict: + # Reconstruct program from dict + child_program = Program(**result.child_program_dict) + + # Add to database + self.database.add(child_program, iteration=completed_iteration) + + # Store artifacts + if result.artifacts: + self.database.store_artifacts(child_program.id, result.artifacts) + + # Log prompts + if result.prompt: + self.database.log_prompt( + template_key=( + "full_rewrite_user" + if not self.config.diff_based_evolution + else "diff_user" + ), + program_id=child_program.id, + prompt=result.prompt, + responses=[result.llm_response] if result.llm_response else [] + ) + + # Island management + if completed_iteration > start_iteration and current_island_counter >= programs_per_island: + self.database.next_island() + current_island_counter = 0 + logger.debug(f"Switched to island {self.database.current_island}") + + current_island_counter += 1 + self.database.increment_island_generation() + + # Check migration + if self.database.should_migrate(): + logger.info(f"Performing migration at iteration {completed_iteration}") + self.database.migrate_programs() + self.database.log_island_status() + + # Log progress + logger.info( + f"Iteration {completed_iteration}: " + f"Program {child_program.id} " + f"(parent: {result.parent_id}) " + f"completed in {result.iteration_time:.2f}s" + ) + + if child_program.metrics: + metrics_str = ", ".join([ + f"{k}={v:.4f}" if isinstance(v, (int, float)) else f"{k}={v}" + for k, v in child_program.metrics.items() + ]) + logger.info(f"Metrics: {metrics_str}") + + # Check for new best + if self.database.best_program_id == child_program.id: + logger.info( + f"🌟 New best solution found at iteration {completed_iteration}: " + f"{child_program.id}" + ) + + # Checkpoint callback + if completed_iteration % self.config.checkpoint_interval == 0: + logger.info(f"Checkpoint interval reached at iteration {completed_iteration}") + self.database.log_island_status() + if checkpoint_callback: + checkpoint_callback(completed_iteration) + + # Check target score + if target_score is not None and child_program.metrics: + 
numeric_metrics = [ + v for v in child_program.metrics.values() + if isinstance(v, (int, float)) + ] + if numeric_metrics: + avg_score = sum(numeric_metrics) / len(numeric_metrics) + if avg_score >= target_score: + logger.info( + f"Target score {target_score} reached at iteration {completed_iteration}" + ) + break + + except Exception as e: + logger.error(f"Error processing result from iteration {completed_iteration}: {e}") + + completed_iterations += 1 + + # Submit next iteration + if next_iteration < total_iterations and not self.shutdown_event.is_set(): + future = self._submit_iteration(next_iteration) + if future: + pending_futures[next_iteration] = future + next_iteration += 1 + + # Handle shutdown + if self.shutdown_event.is_set(): + logger.info("Shutdown requested, canceling remaining evaluations...") + for future in pending_futures.values(): + future.cancel() + + logger.info("Evolution completed") + + return self.database.get_best_program() + + def _submit_iteration(self, iteration: int) -> Optional[Future]: + """Submit an iteration to the process pool""" + try: + # Sample parent and inspirations + parent, inspirations = self.database.sample() + + # Create database snapshot + db_snapshot = self._create_database_snapshot() + + # Submit to process pool + future = self.executor.submit( + _run_iteration_worker, + iteration, + db_snapshot, + parent.id, + [insp.id for insp in inspirations] + ) + + return future + + except Exception as e: + logger.error(f"Error submitting iteration {iteration}: {e}") + return None \ No newline at end of file diff --git a/openevolve/threaded_parallel.py b/openevolve/threaded_parallel.py deleted file mode 100644 index 815de9689..000000000 --- a/openevolve/threaded_parallel.py +++ /dev/null @@ -1,353 +0,0 @@ -""" -Improved parallel processing using threads with shared memory -""" - -import asyncio -import logging -import signal -import threading -import time -from concurrent.futures import ThreadPoolExecutor, Future -from typing import Any, Dict, List, Optional - -from openevolve.config import Config -from openevolve.database import ProgramDatabase -from openevolve.evaluator import Evaluator -from openevolve.llm.ensemble import LLMEnsemble -from openevolve.prompt.sampler import PromptSampler -from openevolve.iteration import run_iteration_with_shared_db - -logger = logging.getLogger(__name__) - - -class ThreadedEvaluationPool: - """ - Thread-based parallel evaluation pool for improved performance - - Uses threads instead of processes to avoid pickling issues while - still providing parallelism for I/O-bound LLM calls. 
- """ - - def __init__(self, config: Config, evaluation_file: str, database: ProgramDatabase): - self.config = config - self.evaluation_file = evaluation_file - self.database = database - - self.num_workers = config.evaluator.parallel_evaluations - self.executor = None - - # Pre-initialize components for each thread - self.thread_local = threading.local() - - logger.info(f"Initializing threaded evaluation pool with {self.num_workers} workers") - - def start(self) -> None: - """Start the thread pool""" - self.executor = ThreadPoolExecutor( - max_workers=self.num_workers, thread_name_prefix="EvalWorker" - ) - logger.info(f"Started thread pool with {self.num_workers} threads") - - def stop(self) -> None: - """Stop the thread pool""" - if self.executor: - self.executor.shutdown(wait=True) - self.executor = None - logger.info("Stopped thread pool") - - def submit_evaluation(self, iteration: int) -> Future: - """ - Submit an evaluation task to the thread pool - - Args: - iteration: Iteration number to evaluate - - Returns: - Future that will contain the result - """ - if not self.executor: - raise RuntimeError("Thread pool not started") - - return self.executor.submit(self._run_evaluation, iteration) - - def _run_evaluation(self, iteration: int): - """Run evaluation in a worker thread""" - # Get or create thread-local components - if not hasattr(self.thread_local, "initialized"): - self._initialize_thread_components() - - try: - # Run the iteration - result = asyncio.run( - run_iteration_with_shared_db( - iteration, - self.config, - self.database, # Shared database (thread-safe reads) - self.thread_local.evaluator, - self.thread_local.llm_ensemble, - self.thread_local.prompt_sampler, - ) - ) - - return result - - except Exception as e: - logger.error(f"Error in thread evaluation {iteration}: {e}") - return None - - def _initialize_thread_components(self) -> None: - """Initialize components for this thread""" - thread_id = threading.get_ident() - logger.debug(f"Initializing components for thread {thread_id}") - - try: - # Initialize LLM components - self.thread_local.llm_ensemble = LLMEnsemble(self.config.llm.models) - self.thread_local.llm_evaluator_ensemble = LLMEnsemble(self.config.llm.evaluator_models) - - # Initialize prompt samplers - self.thread_local.prompt_sampler = PromptSampler(self.config.prompt) - self.thread_local.evaluator_prompt_sampler = PromptSampler(self.config.prompt) - self.thread_local.evaluator_prompt_sampler.set_templates("evaluator_system_message") - - # Initialize evaluator - self.thread_local.evaluator = Evaluator( - self.config.evaluator, - self.evaluation_file, - self.thread_local.llm_evaluator_ensemble, - self.thread_local.evaluator_prompt_sampler, - database=self.database, - ) - - self.thread_local.initialized = True - logger.debug(f"Initialized components for thread {thread_id}") - - except Exception as e: - logger.error(f"Failed to initialize thread components: {e}") - raise - - -class ImprovedParallelController: - """ - Controller for improved parallel processing using shared memory and threads - """ - - def __init__(self, config: Config, evaluation_file: str, database: ProgramDatabase): - self.config = config - self.evaluation_file = evaluation_file - self.database = database - - self.thread_pool = None - self.database_lock = threading.RLock() # For database writes - self.shutdown_flag = threading.Event() # For graceful shutdown - - def start(self) -> None: - """Start the improved parallel system""" - self.thread_pool = ThreadedEvaluationPool(self.config, 
self.evaluation_file, self.database) - self.thread_pool.start() - - logger.info("Started improved parallel controller") - - def stop(self) -> None: - """Stop the improved parallel system""" - self.shutdown_flag.set() # Signal shutdown - - if self.thread_pool: - self.thread_pool.stop() - self.thread_pool = None - - logger.info("Stopped improved parallel controller") - - def request_shutdown(self) -> None: - """Request graceful shutdown (for signal handlers)""" - logger.info("Graceful shutdown requested...") - self.shutdown_flag.set() - - async def run_evolution( - self, - start_iteration: int, - max_iterations: int, - target_score: Optional[float] = None, - checkpoint_callback=None, - ): - """ - Run evolution with improved parallel processing - - Args: - start_iteration: Starting iteration number - max_iterations: Maximum number of iterations - target_score: Target score to achieve - - Returns: - Best program found - """ - total_iterations = start_iteration + max_iterations - - logger.info( - f"Starting improved parallel evolution from iteration {start_iteration} " - f"for {max_iterations} iterations (total: {total_iterations})" - ) - - # Submit initial batch of evaluations - pending_futures = {} - batch_size = min(self.config.evaluator.parallel_evaluations * 2, max_iterations) - - for i in range(start_iteration, min(start_iteration + batch_size, total_iterations)): - future = self.thread_pool.submit_evaluation(i) - pending_futures[i] = future - - next_iteration_to_submit = start_iteration + batch_size - completed_iterations = 0 - - # Island management - programs_per_island = max(1, max_iterations // (self.config.database.num_islands * 10)) - current_island_counter = 0 - - # Process results as they complete - while ( - pending_futures - and completed_iterations < max_iterations - and not self.shutdown_flag.is_set() - ): - # Find completed futures - completed_iteration = None - for iteration, future in list(pending_futures.items()): - if future.done(): - completed_iteration = iteration - break - - if completed_iteration is None: - # No results ready, wait a bit - await asyncio.sleep(0.01) - continue - - # Process completed result - future = pending_futures.pop(completed_iteration) - - try: - result = future.result() - - if result and hasattr(result, "child_program") and result.child_program: - # Thread-safe database update - with self.database_lock: - self.database.add(result.child_program, iteration=completed_iteration) - - # Store artifacts if they exist - if result.artifacts: - self.database.store_artifacts(result.child_program.id, result.artifacts) - - # Log prompts - if hasattr(result, "prompt") and result.prompt: - self.database.log_prompt( - template_key=( - "full_rewrite_user" - if not self.config.diff_based_evolution - else "diff_user" - ), - program_id=result.child_program.id, - prompt=result.prompt, - responses=( - [result.llm_response] if hasattr(result, "llm_response") else [] - ), - ) - - # Manage island evolution - if ( - completed_iteration > start_iteration - and current_island_counter >= programs_per_island - ): - self.database.next_island() - current_island_counter = 0 - logger.debug(f"Switched to island {self.database.current_island}") - - current_island_counter += 1 - - # Increment generation for current island - self.database.increment_island_generation() - - # Check migration - if self.database.should_migrate(): - logger.info(f"Performing migration at iteration {completed_iteration}") - self.database.migrate_programs() - self.database.log_island_status() - - # Log 
progress (outside lock) - logger.info( - f"Iteration {completed_iteration}: " - f"Program {result.child_program.id} " - f"(parent: {result.parent.id if result.parent else 'None'}) " - f"completed in {result.iteration_time:.2f}s" - ) - - if result.child_program.metrics: - metrics_str = ", ".join( - [ - f"{k}={v:.4f}" if isinstance(v, (int, float)) else f"{k}={v}" - for k, v in result.child_program.metrics.items() - ] - ) - logger.info(f"Metrics: {metrics_str}") - - # Check for new best program - if self.database.best_program_id == result.child_program.id: - logger.info( - f"🌟 New best solution found at iteration {completed_iteration}: " - f"{result.child_program.id}" - ) - - # Save checkpoints at intervals - if completed_iteration % self.config.checkpoint_interval == 0: - logger.info( - f"Checkpoint interval reached at iteration {completed_iteration}" - ) - self.database.log_island_status() - if checkpoint_callback: - checkpoint_callback(completed_iteration) - - # Check target score - if target_score is not None and result.child_program.metrics: - numeric_metrics = [ - v - for v in result.child_program.metrics.values() - if isinstance(v, (int, float)) - ] - if numeric_metrics: - avg_score = sum(numeric_metrics) / len(numeric_metrics) - if avg_score >= target_score: - logger.info( - f"Target score {target_score} reached after {completed_iteration} iterations" - ) - break - else: - logger.warning(f"No valid result from iteration {completed_iteration}") - - except Exception as e: - logger.error(f"Error processing result from iteration {completed_iteration}: {e}") - - completed_iterations += 1 - - # Submit next iteration if available - if next_iteration_to_submit < total_iterations: - future = self.thread_pool.submit_evaluation(next_iteration_to_submit) - pending_futures[next_iteration_to_submit] = future - next_iteration_to_submit += 1 - - # Handle shutdown or completion - if self.shutdown_flag.is_set(): - logger.info("Shutdown requested, canceling remaining evaluations...") - # Cancel remaining futures - for iteration, future in pending_futures.items(): - future.cancel() - logger.debug(f"Canceled iteration {iteration}") - else: - # Wait for any remaining futures if not shutting down - for iteration, future in pending_futures.items(): - try: - future.result(timeout=10.0) - except Exception as e: - logger.warning(f"Error waiting for iteration {iteration}: {e}") - - if self.shutdown_flag.is_set(): - logger.info("Evolution interrupted by shutdown") - else: - logger.info("Evolution completed") diff --git a/tests/test_checkpoint_resume.py b/tests/test_checkpoint_resume.py index f593ad7cf..0320ae289 100644 --- a/tests/test_checkpoint_resume.py +++ b/tests/test_checkpoint_resume.py @@ -98,14 +98,14 @@ async def run_test(): # Mock the parallel controller to avoid API calls with patch( - "openevolve.controller.ImprovedParallelController" + "openevolve.controller.ProcessParallelController" ) as mock_controller_class: mock_controller = Mock() mock_controller.run_evolution = AsyncMock(return_value=None) mock_controller.start = Mock(return_value=None) mock_controller.stop = Mock(return_value=None) - mock_controller.shutdown_flag = Mock() - mock_controller.shutdown_flag.is_set.return_value = False + mock_controller.shutdown_event = Mock() + mock_controller.shutdown_event.is_set.return_value = False mock_controller_class.return_value = mock_controller # Run for 0 iterations (just initialization) @@ -154,14 +154,14 @@ async def run_test(): # Mock the parallel controller to avoid API calls with patch( - 
"openevolve.controller.ImprovedParallelController" + "openevolve.controller.ProcessParallelController" ) as mock_controller_class: mock_controller = Mock() mock_controller.run_evolution = AsyncMock(return_value=None) mock_controller.start = Mock(return_value=None) mock_controller.stop = Mock(return_value=None) - mock_controller.shutdown_flag = Mock() - mock_controller.shutdown_flag.is_set.return_value = False + mock_controller.shutdown_event = Mock() + mock_controller.shutdown_event.is_set.return_value = False mock_controller_class.return_value = mock_controller # Run for 0 iterations (just initialization) @@ -209,14 +209,14 @@ async def run_test(): # Mock the parallel controller to avoid API calls with patch( - "openevolve.controller.ImprovedParallelController" + "openevolve.controller.ProcessParallelController" ) as mock_controller_class: mock_controller = Mock() mock_controller.run_evolution = AsyncMock(return_value=None) mock_controller.start = Mock(return_value=None) mock_controller.stop = Mock(return_value=None) - mock_controller.shutdown_flag = Mock() - mock_controller.shutdown_flag.is_set.return_value = False + mock_controller.shutdown_event = Mock() + mock_controller.shutdown_event.is_set.return_value = False mock_controller_class.return_value = mock_controller # Run for 0 iterations (just initialization) @@ -267,14 +267,14 @@ async def run_test(): # Mock the parallel controller to avoid API calls with patch( - "openevolve.controller.ImprovedParallelController" + "openevolve.controller.ProcessParallelController" ) as mock_controller_class: mock_controller = Mock() mock_controller.run_evolution = AsyncMock(return_value=None) mock_controller.start = Mock(return_value=None) mock_controller.stop = Mock(return_value=None) - mock_controller.shutdown_flag = Mock() - mock_controller.shutdown_flag.is_set.return_value = False + mock_controller.shutdown_event = Mock() + mock_controller.shutdown_event.is_set.return_value = False mock_controller_class.return_value = mock_controller # Run for 0 iterations (just initialization) diff --git a/tests/test_process_parallel.py b/tests/test_process_parallel.py new file mode 100644 index 000000000..29c8ecbe4 --- /dev/null +++ b/tests/test_process_parallel.py @@ -0,0 +1,190 @@ +""" +Tests for process-based parallel controller +""" + +import asyncio +import os +import tempfile +import unittest +from unittest.mock import Mock, patch, MagicMock +import time + +# Set dummy API key for testing +os.environ["OPENAI_API_KEY"] = "test" + +from openevolve.config import Config, DatabaseConfig, EvaluatorConfig, LLMConfig, PromptConfig +from openevolve.database import Program, ProgramDatabase +from openevolve.process_parallel import ProcessParallelController, SerializableResult + + +class TestProcessParallel(unittest.TestCase): + """Tests for process-based parallel controller""" + + def setUp(self): + """Set up test environment""" + self.test_dir = tempfile.mkdtemp() + + # Create test config + self.config = Config() + self.config.max_iterations = 10 + self.config.evaluator.parallel_evaluations = 2 + self.config.evaluator.timeout = 10 + self.config.database.num_islands = 2 + self.config.database.in_memory = True + self.config.checkpoint_interval = 5 + + # Create test evaluation file + self.eval_content = """ +def evaluate(program_path): + return {"score": 0.5, "performance": 0.6} +""" + self.eval_file = os.path.join(self.test_dir, "evaluator.py") + with open(self.eval_file, "w") as f: + f.write(self.eval_content) + + # Create test database + self.database = 
ProgramDatabase(self.config.database) + + # Add some test programs + for i in range(3): + program = Program( + id=f"test_{i}", + code=f"def func_{i}(): return {i}", + language="python", + metrics={"score": 0.5 + i * 0.1, "performance": 0.4 + i * 0.1}, + iteration_found=0 + ) + self.database.add(program) + + def tearDown(self): + """Clean up test environment""" + import shutil + shutil.rmtree(self.test_dir, ignore_errors=True) + + def test_controller_initialization(self): + """Test that controller initializes correctly""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + self.assertEqual(controller.num_workers, 2) + self.assertIsNone(controller.executor) + self.assertIsNotNone(controller.shutdown_event) + + def test_controller_start_stop(self): + """Test starting and stopping the controller""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + # Start controller + controller.start() + self.assertIsNotNone(controller.executor) + + # Stop controller + controller.stop() + self.assertIsNone(controller.executor) + self.assertTrue(controller.shutdown_event.is_set()) + + def test_database_snapshot_creation(self): + """Test creating database snapshot for workers""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + snapshot = controller._create_database_snapshot() + + # Verify snapshot structure + self.assertIn("programs", snapshot) + self.assertIn("islands", snapshot) + self.assertIn("current_island", snapshot) + self.assertIn("artifacts", snapshot) + + # Verify programs are serialized + self.assertEqual(len(snapshot["programs"]), 3) + for pid, prog_dict in snapshot["programs"].items(): + self.assertIsInstance(prog_dict, dict) + self.assertIn("id", prog_dict) + self.assertIn("code", prog_dict) + + def test_run_evolution_basic(self): + """Test basic evolution run""" + async def run_test(): + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + # Mock the executor to avoid actually spawning processes + with patch.object(controller, '_submit_iteration') as mock_submit: + # Create mock futures that complete immediately + mock_future1 = asyncio.Future() + mock_result1 = SerializableResult( + child_program_dict={ + "id": "child_1", + "code": "def evolved(): return 1", + "language": "python", + "parent_id": "test_0", + "generation": 1, + "metrics": {"score": 0.7, "performance": 0.8}, + "iteration_found": 1, + "metadata": {"changes": "test", "island": 0} + }, + parent_id="test_0", + iteration_time=0.1, + iteration=1 + ) + mock_future1.set_result(mock_result1) + + mock_submit.return_value = mock_future1 + + # Start controller + controller.start() + + # Run evolution for 1 iteration + result = await controller.run_evolution( + start_iteration=1, + max_iterations=1, + target_score=None + ) + + # Verify iteration was submitted + mock_submit.assert_called_once_with(1) + + # Verify program was added to database + self.assertIn("child_1", self.database.programs) + child = self.database.get("child_1") + self.assertEqual(child.metrics["score"], 0.7) + + # Run the async test + asyncio.run(run_test()) + + def test_request_shutdown(self): + """Test graceful shutdown request""" + controller = ProcessParallelController(self.config, self.eval_file, self.database) + + # Request shutdown + controller.request_shutdown() + + # Verify shutdown event is set + self.assertTrue(controller.shutdown_event.is_set()) + + def test_serializable_result(self): + """Test SerializableResult 
dataclass""" + result = SerializableResult( + child_program_dict={"id": "test", "code": "pass"}, + parent_id="parent", + iteration_time=1.5, + iteration=10, + error=None + ) + + # Verify attributes + self.assertEqual(result.child_program_dict["id"], "test") + self.assertEqual(result.parent_id, "parent") + self.assertEqual(result.iteration_time, 1.5) + self.assertEqual(result.iteration, 10) + self.assertIsNone(result.error) + + # Test with error + error_result = SerializableResult( + error="Test error", + iteration=5 + ) + self.assertEqual(error_result.error, "Test error") + self.assertIsNone(error_result.child_program_dict) + + +if __name__ == "__main__": + unittest.main() \ No newline at end of file