diff --git a/README.md b/README.md
index db8db4612..3b7fc73d9 100644
--- a/README.md
+++ b/README.md
@@ -157,6 +157,67 @@ Sample configuration files are available in the `configs/` directory:
 
 See the [Configuration Guide](configs/default_config.yaml) for a full list of options.
 
+## Artifacts Channel
+
+OpenEvolve includes an **artifacts side-channel** that allows evaluators to capture build errors, profiling results, and other execution details to provide better feedback to the LLM in subsequent generations. This feature enhances the evolution process by giving the LLM context about what went wrong and how to fix it.
+
+The artifacts channel operates alongside the traditional fitness metrics.
+
+### Example: Compilation Failure Feedback
+
+```python
+from openevolve.evaluation_result import EvaluationResult
+
+return EvaluationResult(
+    metrics={"compile_ok": 0.0, "score": 0.0},
+    artifacts={
+        "stderr": "SyntaxError: invalid syntax (line 15)",
+        "traceback": "...",
+        "failure_stage": "compilation"
+    }
+)
+```
+
+The next generation prompt will include:
+```
+## Last Execution Output
+### Stderr
+```
+SyntaxError: invalid syntax (line 15)
+```
+### Traceback
+```
+...
+```
+```
+
+### Configuration
+
+Artifacts can be controlled via configuration and environment variables:
+
+```yaml
+# config.yaml
+evaluator:
+  enable_artifacts: true
+
+prompt:
+  include_artifacts: true
+  max_artifact_bytes: 4096  # 4KB limit in prompts
+  artifact_security_filter: true
+```
+
+```bash
+# Environment variable to disable artifacts
+export ENABLE_ARTIFACTS=false
+```
+
+### Benefits
+
+- **Faster convergence** - LLMs can see what went wrong and fix it directly
+- **Better error handling** - Compilation and runtime failures become learning opportunities
+- **Rich debugging context** - Full stack traces and error messages guide improvements
+- **Zero overhead** - When disabled, no performance impact on evaluation
+
 ## Examples
 
 See the `examples/` directory for complete examples of using OpenEvolve on various problems:
diff --git a/examples/circle_packing_with_artifacts/README.md b/examples/circle_packing_with_artifacts/README.md
new file mode 100644
index 000000000..673f854e1
--- /dev/null
+++ b/examples/circle_packing_with_artifacts/README.md
@@ -0,0 +1,365 @@
+# Circle Packing Example with Artifacts
+
+This example demonstrates how OpenEvolve can be used to tackle the challenging mathematical problem of circle packing, a classic problem in computational geometry. Specifically, we focus on packing 26 circles of varying sizes into a unit square to maximize the sum of their radii, replicating one of the tasks from the AlphaEvolve paper.
+
+**Enhanced with Artifacts**: This version showcases OpenEvolve's artifacts feature, which provides detailed execution feedback to help the LLM understand what went wrong and how to fix it.
+
+## Artifacts Enhancement
+
+This example has been enhanced to demonstrate OpenEvolve's **artifacts side-channel**, which captures detailed execution information beyond just numeric metrics. When a program fails or succeeds, the evaluator now provides rich context that gets included in the next generation's prompt.
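+
+As a rough sketch of what this looks like from the evaluator's side, a geometric-validation failure can be reported together with its artifacts (condensed from this example's `evaluator.py`; the helper name and example values here are illustrative):
+
+```python
+from openevolve.evaluation_result import EvaluationResult
+
+def report_validation_failure(overlap_message: str, eval_time: float) -> EvaluationResult:
+    # Condensed sketch: the numeric score stays at zero, while the artifacts
+    # tell the LLM exactly which geometric constraint was violated.
+    return EvaluationResult(
+        metrics={"validity": 0.0, "combined_score": 0.0},
+        artifacts={
+            "overlap_violations": overlap_message,
+            "failure_stage": "geometric_validation",
+            "execution_time": f"{eval_time:.2f}s",
+        },
+    )
+```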
+ +### What Artifacts Capture + +The enhanced evaluator captures different types of information: + +#### 🚨 **Failure Artifacts** +When compilation or runtime errors occur: +``` +## Last Execution Output +### Stderr +``` +Invalid shapes: centers=(25, 2), radii=(26,), expected (26, 2) and (26,) +``` +### Failure_stage +``` +shape_validation +``` +### Suggestion +``` +Check for syntax errors, import issues, or runtime exceptions +``` +``` + +#### ⚠️ **Validation Artifacts** +When geometric constraints are violated: +``` +## Last Execution Output +### Boundary_violations +``` +Circle 23 at (0.950000, 0.950000) with radius 0.055000 is outside unit square +Circle 24 at (0.980000, 0.980000) with radius 0.030000 is outside unit square +``` +### Overlap_violations +``` +Circles 5 and 12 overlap: dist=0.180000, r1+r2=0.190000 +``` +### Validation_report +``` +Valid: False, Violations: 2 boundary, 1 overlaps +``` +``` + +#### ✅ **Success Artifacts** +For excellent solutions (>95% of target): +``` +## Last Execution Output +### Stdout +``` +Excellent packing! Achieved 99.7% of target value +``` +### Radius_stats +``` +Min: 0.045123, Max: 0.167500, Avg: 0.101319 +``` +### Packing_summary +``` +Sum of radii: 2.634292/2.635 = 0.9997 +``` +``` + +#### ⏱️ **Performance Artifacts** +Always included: +``` +### Execution_time +``` +12.45s +``` +``` + +### How Artifacts Help Evolution + +The artifacts provide crucial context that helps the LLM make better decisions: + +1. **Specific Error Fixing**: Instead of just seeing `validity: 0.0`, the LLM sees exactly which circles are problematic and why +2. **Performance Guidance**: Execution time helps the LLM understand if algorithms are too slow +3. **Success Recognition**: When a solution works well, artifacts explain why it's good +4. **Debugging Context**: Full stack traces help fix syntax and runtime errors + +### Example Evolution with Artifacts + +Here's how artifacts help in a typical evolution scenario: + +**Generation N**: Program fails with overlapping circles +```python +# Faulty code +centers[5] = [0.3, 0.3] +centers[6] = [0.3, 0.3] # Same position! +radii[5] = radii[6] = 0.1 +``` + +**Artifacts captured**: +``` +Circles 5 and 6 overlap: dist=0.000000, r1+r2=0.200000 +``` + +**Generation N+1**: LLM sees the overlap details and fixes it +```python +# Fixed code +centers[5] = [0.3, 0.3] +centers[6] = [0.5, 0.3] # Different position! +radii[5] = radii[6] = 0.1 +``` + +This leads to faster convergence because the LLM gets specific, actionable feedback instead of just numeric scores. + +### Backward Compatibility + +The artifacts enhancement is fully backward compatible: +- **Existing evaluators** continue to work unchanged +- **Enhanced evaluators** return `EvaluationResult` with both metrics and artifacts +- **Disable artifacts** by setting `export ENABLE_ARTIFACTS=false` if needed + +To run with artifacts disabled: +```bash +export ENABLE_ARTIFACTS=false +python openevolve-run.py examples/circle_packing_with_artifacts/initial_program.py \ + examples/circle_packing_with_artifacts/evaluator.py \ + --config examples/circle_packing_with_artifacts/config_phase_1.yaml +``` + +## Problem Overview + +The circle packing problem involves placing n non-overlapping circles inside a container (in this case, a unit square) to optimize a specific metric. 
For this example: + +- We pack exactly 26 circles +- Each circle must lie entirely within the unit square +- No circles may overlap +- We aim to maximize the sum of all circle radii + +According to the AlphaEvolve paper, a solution with a sum of radii of approximately 2.635 is achievable for n=26. Our goal was to match or exceed this result. + +## Our Approach + +We structured our evolution in two phases, each with a different configuration to encourage exploration and exploitation at different stages: + +### Phase 1: Initial Exploration + +In the first phase, we focused on exploring different fundamental approaches to the packing problem: + +- Used a constructor-based approach that places circles in strategic positions +- Explored various geometric patterns (concentric rings, grid-based arrangements, etc.) +- Developed simple optimization routines to maximize circle sizes without overlaps + +Configuration highlights: +```yaml +max_iterations: 100 +population_size: 60 +num_islands: 4 +exploitation_ratio: 0.7 +``` + +### Phase 2: Breaking Through the Plateau + +After the initial exploration phase, we observed our solutions plateauing around 2.377. For the second phase, we reconfigured OpenEvolve to encourage more radical innovations: + +- Increased the population size to promote diversity +- Lowered the exploitation ratio to favor exploration +- Updated the system prompt to suggest different optimization techniques +- Allowed for longer and more complex code solutions + +Configuration highlights: +```yaml +max_iterations: 100 +population_size: 70 +num_islands: 5 +exploitation_ratio: 0.6 +``` + +## Evolution Progress + +We tracked the evolution over 470 generations, capturing visualizations at each checkpoint. The progression shows dramatic improvements in the packing strategy: + +### Initial Solution (Generation 0) + +The initial program used a simple constructive approach with a central circle and two concentric rings: + +```python +# Initial attempt +# Place a large circle in the center +centers[0] = [0.5, 0.5] + +# Place 8 circles around it in a ring +for i in range(8): + angle = 2 * np.pi * i / 8 + centers[i + 1] = [0.5 + 0.3 * np.cos(angle), 0.5 + 0.3 * np.sin(angle)] + +# Place 16 more circles in an outer ring +for i in range(16): + angle = 2 * np.pi * i / 16 + centers[i + 9] = [0.5 + 0.7 * np.cos(angle), 0.5 + 0.7 * np.sin(angle)] +``` + +This approach yielded a sum of radii of approximately 0.959. + +![Initial Circle Packing](circle_packing_1.png) + +### Generation 10 Breakthrough + +By generation 10, OpenEvolve had already discovered a more sophisticated approach: + +```python +# Generation 10 +# Parameters for the arrangement (fine-tuned) +r_center = 0.1675 # Central circle radius + +# 1. Place central circle +centers[0] = [0.5, 0.5] +radii[0] = r_center + +# 2. First ring: 6 circles in hexagonal arrangement +r_ring1 = 0.1035 +ring1_distance = r_center + r_ring1 + 0.0005 # Small gap for stability +for i in range(6): + angle = 2 * np.pi * i / 6 + centers[i+1] = [ + 0.5 + ring1_distance * np.cos(angle), + 0.5 + ring1_distance * np.sin(angle) + ] + radii[i+1] = r_ring1 +``` + +The key innovations at this stage included: +- A carefully tuned hexagonal arrangement for the first ring +- Strategic placement of corner circles +- An additional optimization step to maximize each circle's radius + +This approach achieved a sum of radii of approximately 1.795. 
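+
+Both the initial program and this generation-10 arrangement pair their hand-placed layouts with a radius-maximization step. A condensed sketch of that step is shown below (it mirrors `compute_max_radii` in this example's `initial_program.py`; the helper name here is illustrative):
+
+```python
+import numpy as np
+
+def max_radii_for(centers: np.ndarray) -> np.ndarray:
+    """Largest radii for fixed centers: stay in the unit square, avoid overlaps."""
+    n = centers.shape[0]
+    # Start with the distance to the nearest border of the unit square
+    radii = np.minimum.reduce([centers[:, 0], centers[:, 1],
+                               1 - centers[:, 0], 1 - centers[:, 1]])
+    # Shrink any pair of radii that would overlap, proportionally
+    for i in range(n):
+        for j in range(i + 1, n):
+            dist = np.linalg.norm(centers[i] - centers[j])
+            if radii[i] + radii[j] > dist:
+                scale = dist / (radii[i] + radii[j])
+                radii[i] *= scale
+                radii[j] *= scale
+    return radii
+```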
+ +![Generation 10 Packing](circle_packing_10.png) + +### Generation 100: Grid-Based Approach + +By generation 100, OpenEvolve had pivoted to a grid-based approach with variable sized circles: + +```python +# Generation 100 +# Row 1: 5 circles +centers[0] = [0.166, 0.166] +centers[1] = [0.333, 0.166] +centers[2] = [0.500, 0.166] +centers[3] = [0.667, 0.166] +centers[4] = [0.834, 0.166] + +# Row 2: 6 circles (staggered) +centers[5] = [0.100, 0.333] +centers[6] = [0.266, 0.333] +# ... additional circles +``` + +Key innovations: +- Grid-like pattern with staggered rows +- Variable circle sizes based on position (larger in the center) +- More aggressive optimization routine with 50 iterations + +This approach achieved a sum of radii of approximately 2.201. + +![Generation 100 Packing](circle_packing_190.png) + +### Final Solution: Mathematical Optimization + +The breakthrough came when OpenEvolve discovered the power of mathematical optimization techniques. The final solution uses: + +```python +# Final solution with scipy.optimize +def construct_packing(): + # ... initialization code ... + + # Objective function: Negative sum of radii (to maximize) + def objective(x): + centers = x[:2*n].reshape(n, 2) + radii = x[2*n:] + return -np.sum(radii) + + # Constraint: No overlaps and circles stay within the unit square + def constraint(x): + centers = x[:2*n].reshape(n, 2) + radii = x[2*n:] + + # Overlap constraint + overlap_constraints = [] + for i in range(n): + for j in range(i + 1, n): + dist = np.sqrt(np.sum((centers[i] - centers[j])**2)) + overlap_constraints.append(dist - (radii[i] + radii[j])) + # ... boundary constraints ... + + # Optimization using SLSQP + result = minimize(objective, x0, method='SLSQP', bounds=bounds, constraints=constraints) +``` + +The key innovation in the final solution: +- Using `scipy.optimize.minimize` with SLSQP method to find the optimal configuration +- Formulating circle packing as a constrained optimization problem +- Representing both circle positions and radii as optimization variables +- Carefully crafted constraints to enforce non-overlap and boundary conditions + +This approach achieved a sum of radii of 2.634, matching the AlphaEvolve paper's result of 2.635 to within 0.04%! + +![Final Packing Solution](circle_packing_460.png) + +## Results + +Our final solution achieves: + +``` +Sum of radii: 2.634292402141039 +Target ratio: 0.9997314619131079 (99.97% of AlphaEvolve's result) +``` + +This demonstrates that OpenEvolve can successfully reproduce the results from the AlphaEvolve paper on this mathematical optimization problem. + +## Key Observations + +The evolution process demonstrated several interesting patterns: + +1. **Algorithm Transition**: OpenEvolve discovered increasingly sophisticated algorithms, from basic geometric constructions to advanced mathematical optimization techniques. + +2. **Exploration-Exploitation Balance**: The two-phase approach was crucial - initial exploration of different patterns followed by exploitation and refinement of the most promising approaches. + +3. **Breakthrough Discoveries**: The most significant improvements came from fundamental changes in approach (e.g., switching from manual construction to mathematical optimization), not just parameter tuning. + +4. **Code Complexity Evolution**: As the solutions improved, the code grew in complexity, adopting more sophisticated mathematical techniques. 
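+
+To make the optimization-based final approach concrete, here is a compact, self-contained sketch of the SLSQP formulation described above. It is a reconstruction of the idea, not the exact evolved program; the function name, random starting layout, and solver options are illustrative:
+
+```python
+import numpy as np
+from scipy.optimize import minimize
+
+def pack_circles(n: int = 26, seed: int = 0):
+    rng = np.random.default_rng(seed)
+
+    # Decision vector: x = [x1, y1, ..., xn, yn, r1, ..., rn]
+    def objective(x):
+        return -np.sum(x[2 * n:])  # maximize the sum of radii
+
+    def constraints_vec(x):
+        centers = x[:2 * n].reshape(n, 2)
+        radii = x[2 * n:]
+        cons = []
+        # Non-overlap: dist(i, j) - (r_i + r_j) >= 0
+        for i in range(n):
+            for j in range(i + 1, n):
+                d = np.linalg.norm(centers[i] - centers[j])
+                cons.append(d - (radii[i] + radii[j]))
+        # Boundary: each circle stays inside the unit square
+        for i in range(n):
+            xi, yi, ri = centers[i, 0], centers[i, 1], radii[i]
+            cons.extend([xi - ri, 1 - xi - ri, yi - ri, 1 - yi - ri])
+        return np.array(cons)
+
+    # Random interior starting layout with small radii
+    x0 = np.concatenate([rng.uniform(0.1, 0.9, size=2 * n), np.full(n, 0.05)])
+    bounds = [(0.0, 1.0)] * (2 * n) + [(0.0, 0.5)] * n
+
+    result = minimize(
+        objective, x0, method="SLSQP", bounds=bounds,
+        constraints=[{"type": "ineq", "fun": constraints_vec}],
+        options={"maxiter": 500},
+    )
+    centers = result.x[:2 * n].reshape(n, 2)
+    radii = result.x[2 * n:]
+    return centers, radii, float(np.sum(radii))
+
+if __name__ == "__main__":
+    centers, radii, total = pack_circles()
+    print(f"Sum of radii: {total:.6f}")
+```
+
+A single random start like this typically lands well short of the 2.635 target; in the evolved program, the optimizer is fed a carefully constructed initial arrangement, which is what makes the reported 2.634 reachable.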
+ +## Running the Example + +To reproduce our results: + +```bash +# Phase 1: Initial exploration +python openevolve-run.py examples/circle_packing/initial_program.py \ + examples/circle_packing/evaluator.py \ + --config examples/circle_packing/config_phase_1.yaml \ + --iterations 100 + +# Phase 2: Breaking through the plateau +python openevolve-run.py examples/circle_packing/openevolve_output/checkpoints/checkpoint_100/best_program.py \ + examples/circle_packing/evaluator.py \ + --config examples/circle_packing/config_phase_2.yaml \ + --iterations 100 +``` + +To visualize the best solution: + +```python +from examples.circle_packing.openevolve_output.best.best_program import run_packing, visualize + +centers, radii, sum_radii = run_packing() +print(f"Sum of radii: {sum_radii}") +visualize(centers, radii) +``` + +## Conclusion + +This example demonstrates OpenEvolve's ability to discover sophisticated algorithms for mathematical optimization problems. By evolving from simple constructive approaches to advanced numerical optimization techniques, OpenEvolve was able to match the results reported in the AlphaEvolve paper. + +The circle packing problem shows how OpenEvolve can discover not just improvements to existing algorithms, but entirely new algorithmic approaches, transitioning from manual geometric construction to principled mathematical optimization. \ No newline at end of file diff --git a/examples/circle_packing_with_artifacts/config_phase_1.yaml b/examples/circle_packing_with_artifacts/config_phase_1.yaml new file mode 100644 index 000000000..96f1b75e5 --- /dev/null +++ b/examples/circle_packing_with_artifacts/config_phase_1.yaml @@ -0,0 +1,56 @@ +# Configuration for circle packing constructor evolution (n=26) +max_iterations: 100 # Increased iterations +checkpoint_interval: 10 +log_level: "INFO" + +# LLM configuration +llm: + primary_model: "google/gemini-2.0-flash-001" + # primary_model: "llama3.1-8b" + primary_model_weight: 0.8 + secondary_model: "anthropic/claude-3.7-sonnet" + # secondary_model: "llama-4-scout-17b-16e-instruct" + secondary_model_weight: 0.2 + api_base: "https://openrouter.ai/api/v1" + # api_base: "https://api.cerebras.ai/v1" + temperature: 0.7 + top_p: 0.95 + max_tokens: 8192 + timeout: 600 + +# Prompt configuration +prompt: + system_message: | + You are an expert mathematician specializing in circle packing problems and computational geometry. Your task is to improve a constructor function that directly produces a specific arrangement of 26 circles in a unit square, maximizing the sum of their radii. The AlphaEvolve paper achieved a sum of 2.635 for n=26. + + Key geometric insights: + - Circle packings often follow hexagonal patterns in the densest regions + - Maximum density for infinite circle packing is pi/(2*sqrt(3)) ≈ 0.9069 + - Edge effects make square container packing harder than infinite packing + - Circles can be placed in layers or shells when confined to a square + - Similar radius circles often form regular patterns, while varied radii allow better space utilization + - Perfect symmetry may not yield the optimal packing due to edge effects + + Focus on designing an explicit constructor that places each circle in a specific position, rather than an iterative search algorithm. 
+ num_top_programs: 3 + use_template_stochasticity: true + +# Database configuration +database: + population_size: 60 # Increased population for more diversity + archive_size: 25 + num_islands: 4 + elite_selection_ratio: 0.3 + exploitation_ratio: 0.7 + +# Evaluator configuration +evaluator: + timeout: 60 + cascade_evaluation: true + cascade_thresholds: [0.5, 0.75] + parallel_evaluations: 4 + use_llm_feedback: false + +# Evolution settings +diff_based_evolution: false # Use full rewrites instead of diffs +allow_full_rewrites: true # Allow full rewrites for constructor functions diff --git a/examples/circle_packing_with_artifacts/config_phase_2.yaml b/examples/circle_packing_with_artifacts/config_phase_2.yaml new file mode 100644 index 000000000..195c7d399 --- /dev/null +++ b/examples/circle_packing_with_artifacts/config_phase_2.yaml @@ -0,0 +1,58 @@ +# Configuration for breaking through the circle packing plateau +max_iterations: 100 +checkpoint_interval: 10 +log_level: "INFO" + +# LLM configuration +llm: + primary_model: "google/gemini-2.0-flash-001" + # primary_model: "llama3.1-8b" + primary_model_weight: 0.8 + secondary_model: "anthropic/claude-3.7-sonnet" + # secondary_model: "llama-4-scout-17b-16e-instruct" + secondary_model_weight: 0.2 + api_base: "https://openrouter.ai/api/v1" + # api_base: "https://api.cerebras.ai/v1" + temperature: 0.7 + top_p: 0.95 + max_tokens: 8192 + timeout: 600 + +# Prompt configuration +prompt: + system_message: | + You are an expert mathematician specializing in circle packing problems and computational geometry. We're trying to reach the AlphaEvolve target of 2.635 for the sum of radii when packing 26 circles in a unit square. The current implementation has plateaued at 2.377, so we need significant improvements. + + Key insights to explore: + 1. The optimal arrangement likely involves variable-sized circles + 2. A pure hexagonal arrangement may not be optimal due to edge effects + 3. The densest known circle packings often use a hybrid approach + 4. The optimization routine is critically important - simple physics-based models with carefully tuned parameters + 5. Consider strategic placement of circles at square corners and edges + 6. Adjusting the pattern to place larger circles at the center and smaller at the edges + 7. The math literature suggests special arrangements for specific values of n + + Focus on breaking through the plateau by trying fundamentally different approaches - don't just tweak parameters. 
+ num_top_programs: 4 + use_template_stochasticity: true + +# Database configuration +database: + population_size: 70 # Larger population for more diversity + archive_size: 30 + num_islands: 5 + elite_selection_ratio: 0.3 + exploitation_ratio: 0.6 # Slightly lower to encourage exploration + +# Evaluator configuration +evaluator: + timeout: 90 # Extended timeout to allow for more complex optimization + cascade_evaluation: true + cascade_thresholds: [0.5, 0.8] + parallel_evaluations: 4 + use_llm_feedback: false + +# Evolution settings +diff_based_evolution: false +allow_full_rewrites: true # Definitely allow full rewrites +max_code_length: 100000 \ No newline at end of file diff --git a/examples/circle_packing_with_artifacts/evaluator.py b/examples/circle_packing_with_artifacts/evaluator.py new file mode 100644 index 000000000..ea3202546 --- /dev/null +++ b/examples/circle_packing_with_artifacts/evaluator.py @@ -0,0 +1,477 @@ +""" +Evaluator for circle packing example (n=26) with improved timeout handling +Enhanced with artifacts to demonstrate execution feedback +""" + +import importlib.util +import numpy as np +import time +import os +import signal +import subprocess +import tempfile +import traceback +import sys +import pickle + +# Import EvaluationResult for artifacts support +from openevolve.evaluation_result import EvaluationResult + + +class TimeoutError(Exception): + pass + + +def timeout_handler(signum, frame): + """Handle timeout signal""" + raise TimeoutError("Function execution timed out") + + +def validate_packing(centers, radii): + """ + Validate that circles don't overlap and are inside the unit square + + Args: + centers: np.array of shape (n, 2) with (x, y) coordinates + radii: np.array of shape (n) with radius of each circle + + Returns: + Tuple of (is_valid: bool, validation_details: dict) + """ + n = centers.shape[0] + validation_details = { + "total_circles": n, + "boundary_violations": [], + "overlaps": [], + "min_radius": float(np.min(radii)), + "max_radius": float(np.max(radii)), + "avg_radius": float(np.mean(radii)), + } + + # Check if circles are inside the unit square + for i in range(n): + x, y = centers[i] + r = radii[i] + if x - r < -1e-6 or x + r > 1 + 1e-6 or y - r < -1e-6 or y + r > 1 + 1e-6: + violation = ( + f"Circle {i} at ({x:.6f}, {y:.6f}) with radius {r:.6f} is outside unit square" + ) + validation_details["boundary_violations"].append(violation) + print(violation) + + # Check for overlaps + for i in range(n): + for j in range(i + 1, n): + dist = np.sqrt(np.sum((centers[i] - centers[j]) ** 2)) + if dist < radii[i] + radii[j] - 1e-6: # Allow for tiny numerical errors + overlap = ( + f"Circles {i} and {j} overlap: dist={dist:.6f}, r1+r2={radii[i]+radii[j]:.6f}" + ) + validation_details["overlaps"].append(overlap) + print(overlap) + + is_valid = ( + len(validation_details["boundary_violations"]) == 0 + and len(validation_details["overlaps"]) == 0 + ) + validation_details["is_valid"] = is_valid + + return is_valid, validation_details + + +def run_with_timeout(program_path, timeout_seconds=20): + """ + Run the program in a separate process with timeout + using a simple subprocess approach + + Args: + program_path: Path to the program file + timeout_seconds: Maximum execution time in seconds + + Returns: + centers, radii, sum_radii tuple from the program + """ + # Create a temporary file to execute + with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file: + # Write a script that executes the program and saves results + script = f""" +import 
sys +import numpy as np +import os +import pickle +import traceback + +# Add the directory to sys.path +sys.path.insert(0, os.path.dirname('{program_path}')) + +# Debugging info +print(f"Running in subprocess, Python version: {{sys.version}}") +print(f"Program path: {program_path}") + +try: + # Import the program + spec = __import__('importlib.util').util.spec_from_file_location("program", '{program_path}') + program = __import__('importlib.util').util.module_from_spec(spec) + spec.loader.exec_module(program) + + # Run the packing function + print("Calling run_packing()...") + centers, radii, sum_radii = program.run_packing() + print(f"run_packing() returned successfully: sum_radii = {{sum_radii}}") + + # Save results to a file + results = {{ + 'centers': centers, + 'radii': radii, + 'sum_radii': sum_radii + }} + + with open('{temp_file.name}.results', 'wb') as f: + pickle.dump(results, f) + print(f"Results saved to {temp_file.name}.results") + +except Exception as e: + # If an error occurs, save the error instead + print(f"Error in subprocess: {{str(e)}}") + traceback.print_exc() + with open('{temp_file.name}.results', 'wb') as f: + pickle.dump({{'error': str(e)}}, f) + print(f"Error saved to {temp_file.name}.results") +""" + temp_file.write(script.encode()) + temp_file_path = temp_file.name + + results_path = f"{temp_file_path}.results" + + try: + # Run the script with timeout + process = subprocess.Popen( + [sys.executable, temp_file_path], stdout=subprocess.PIPE, stderr=subprocess.PIPE + ) + + try: + stdout, stderr = process.communicate(timeout=timeout_seconds) + exit_code = process.returncode + + # Always print output for debugging purposes + print(f"Subprocess stdout: {stdout.decode()}") + if stderr: + print(f"Subprocess stderr: {stderr.decode()}") + + # Still raise an error for non-zero exit codes, but only after printing the output + if exit_code != 0: + raise RuntimeError(f"Process exited with code {exit_code}") + + # Load the results + if os.path.exists(results_path): + with open(results_path, "rb") as f: + results = pickle.load(f) + + # Check if an error was returned + if "error" in results: + raise RuntimeError(f"Program execution failed: {results['error']}") + + return results["centers"], results["radii"], results["sum_radii"] + else: + raise RuntimeError("Results file not found") + + except subprocess.TimeoutExpired: + # Kill the process if it times out + process.kill() + process.wait() + raise TimeoutError(f"Process timed out after {timeout_seconds} seconds") + + finally: + # Clean up temporary files + if os.path.exists(temp_file_path): + os.unlink(temp_file_path) + if os.path.exists(results_path): + os.unlink(results_path) + + +def evaluate(program_path): + """ + Evaluate the program by running it once and checking the sum of radii + + Args: + program_path: Path to the program file + + Returns: + EvaluationResult with metrics and artifacts + """ + # Target value from the paper + TARGET_VALUE = 2.635 # AlphaEvolve result for n=26 + + try: + # For constructor-based approaches, a single evaluation is sufficient + # since the result is deterministic + start_time = time.time() + + # Use subprocess to run with timeout + centers, radii, reported_sum = run_with_timeout( + program_path, timeout_seconds=600 # Single timeout + ) + + end_time = time.time() + eval_time = end_time - start_time + + # Ensure centers and radii are numpy arrays + if not isinstance(centers, np.ndarray): + centers = np.array(centers) + if not isinstance(radii, np.ndarray): + radii = np.array(radii) + + # 
Validate solution + valid, validation_details = validate_packing(centers, radii) + + # Check shape and size + shape_valid = centers.shape == (26, 2) and radii.shape == (26,) + if not shape_valid: + shape_error = f"Invalid shapes: centers={centers.shape}, radii={radii.shape}, expected (26, 2) and (26,)" + print(shape_error) + + return EvaluationResult( + metrics={ + "sum_radii": 0.0, + "target_ratio": 0.0, + "validity": 0.0, + "eval_time": float(eval_time), + "combined_score": 0.0, + }, + artifacts={ + "stderr": shape_error, + "failure_stage": "shape_validation", + "expected_shapes": "centers: (26, 2), radii: (26,)", + "actual_shapes": f"centers: {centers.shape}, radii: {radii.shape}", + "execution_time": f"{eval_time:.2f}s", + }, + ) + + # Calculate sum + sum_radii = np.sum(radii) if valid else 0.0 + + # Make sure reported_sum matches the calculated sum + sum_mismatch = abs(sum_radii - reported_sum) > 1e-6 + if sum_mismatch: + mismatch_warning = ( + f"Warning: Reported sum {reported_sum} doesn't match calculated sum {sum_radii}" + ) + print(mismatch_warning) + + # Target ratio (how close we are to the target) + target_ratio = sum_radii / TARGET_VALUE if valid else 0.0 + + # Validity score + validity = 1.0 if valid else 0.0 + + # Combined score - higher is better + combined_score = target_ratio * validity + + print( + f"Evaluation: valid={valid}, sum_radii={sum_radii:.6f}, target={TARGET_VALUE}, ratio={target_ratio:.6f}, time={eval_time:.2f}s" + ) + + # Prepare artifacts with packing details + artifacts = { + "execution_time": f"{eval_time:.2f}s", + "packing_summary": f"Sum of radii: {sum_radii:.6f}/{TARGET_VALUE} = {target_ratio:.4f}", + "validation_report": f"Valid: {valid}, Violations: {len(validation_details.get('boundary_violations', []))} boundary, {len(validation_details.get('overlaps', []))} overlaps", + } + + # Add validation details if there are issues + if not valid: + if validation_details.get("boundary_violations"): + artifacts["boundary_violations"] = "\n".join( + validation_details["boundary_violations"] + ) + if validation_details.get("overlaps"): + artifacts["overlap_violations"] = "\n".join(validation_details["overlaps"]) + artifacts["failure_stage"] = "geometric_validation" + + # Add sum mismatch warning if present + if sum_mismatch: + artifacts["sum_mismatch"] = f"Reported: {reported_sum:.6f}, Calculated: {sum_radii:.6f}" + + # Add successful packing stats for good solutions + if valid and target_ratio > 0.95: # Near-optimal solutions + artifacts["stdout"] = f"Excellent packing! 
Achieved {target_ratio:.1%} of target value" + artifacts["radius_stats"] = ( + f"Min: {validation_details['min_radius']:.6f}, Max: {validation_details['max_radius']:.6f}, Avg: {validation_details['avg_radius']:.6f}" + ) + + return EvaluationResult( + metrics={ + "sum_radii": float(sum_radii), + "target_ratio": float(target_ratio), + "validity": float(validity), + "eval_time": float(eval_time), + "combined_score": float(combined_score), + }, + artifacts=artifacts, + ) + + except TimeoutError as e: + error_msg = f"Evaluation timed out: {str(e)}" + print(error_msg) + return EvaluationResult( + metrics={ + "sum_radii": 0.0, + "target_ratio": 0.0, + "validity": 0.0, + "eval_time": 600.0, # Timeout duration + "combined_score": 0.0, + }, + artifacts={ + "stderr": error_msg, + "failure_stage": "execution_timeout", + "timeout_duration": "600s", + "suggestion": "Consider optimizing the packing algorithm for faster convergence", + }, + ) + except Exception as e: + error_msg = f"Evaluation failed completely: {str(e)}" + print(error_msg) + traceback.print_exc() + return EvaluationResult( + metrics={ + "sum_radii": 0.0, + "target_ratio": 0.0, + "validity": 0.0, + "eval_time": 0.0, + "combined_score": 0.0, + }, + artifacts={ + "stderr": error_msg, + "traceback": traceback.format_exc(), + "failure_stage": "program_execution", + "suggestion": "Check for syntax errors, import issues, or runtime exceptions", + }, + ) + + +# Stage-based evaluation for cascade evaluation +def evaluate_stage1(program_path): + """ + First stage evaluation - quick validation check + Enhanced with artifacts for debugging + """ + try: + # Use the simplified subprocess approach + try: + start_time = time.time() + centers, radii, sum_radii = run_with_timeout(program_path, timeout_seconds=600) + eval_time = time.time() - start_time + + # Ensure centers and radii are numpy arrays + if not isinstance(centers, np.ndarray): + centers = np.array(centers) + if not isinstance(radii, np.ndarray): + radii = np.array(radii) + + # Validate solution (shapes and constraints) + shape_valid = centers.shape == (26, 2) and radii.shape == (26,) + if not shape_valid: + shape_error = f"Invalid shapes: centers={centers.shape}, radii={radii.shape}" + print(shape_error) + return EvaluationResult( + metrics={"validity": 0.0, "combined_score": 0.0}, + artifacts={ + "stderr": shape_error, + "failure_stage": "stage1_shape_validation", + "expected_shapes": "centers: (26, 2), radii: (26,)", + "actual_shapes": f"centers: {centers.shape}, radii: {radii.shape}", + "execution_time": f"{eval_time:.2f}s", + }, + ) + + valid, validation_details = validate_packing(centers, radii) + + # Calculate sum + actual_sum = np.sum(radii) if valid else 0.0 + + # Target from paper + target = 2.635 + + # Simple combined score for stage 1 + combined_score = (actual_sum / target) if valid else 0.0 + + # Prepare artifacts for stage 1 + artifacts = { + "execution_time": f"{eval_time:.2f}s", + "stage": "quick_validation", + "packing_summary": f"Sum: {actual_sum:.6f}, Ratio: {actual_sum/target:.4f}", + } + + # Add validation issues if any + if not valid: + artifacts["stderr"] = ( + f"Validation failed: {len(validation_details.get('boundary_violations', []))} boundary violations, {len(validation_details.get('overlaps', []))} overlaps" + ) + artifacts["failure_stage"] = "stage1_geometric_validation" + if validation_details.get("boundary_violations"): + artifacts["boundary_issues"] = validation_details["boundary_violations"][ + 0 + ] # Just first issue + if validation_details.get("overlaps"): 
+ artifacts["overlap_issues"] = validation_details["overlaps"][ + 0 + ] # Just first issue + + # Return evaluation metrics + return EvaluationResult( + metrics={ + "validity": 1.0 if valid else 0.0, + "sum_radii": float(actual_sum), + "target_ratio": float(actual_sum / target if valid else 0.0), + "combined_score": float(combined_score), + }, + artifacts=artifacts, + ) + + except TimeoutError as e: + error_msg = f"Stage 1 evaluation timed out: {e}" + print(error_msg) + return EvaluationResult( + metrics={"validity": 0.0, "combined_score": 0.0}, + artifacts={ + "stderr": error_msg, + "failure_stage": "stage1_timeout", + "timeout_duration": "600s", + "suggestion": "Algorithm may be too slow for stage 1 - consider simpler heuristics", + }, + ) + except Exception as e: + error_msg = f"Stage 1 evaluation failed: {e}" + print(error_msg) + print(traceback.format_exc()) + return EvaluationResult( + metrics={"validity": 0.0, "combined_score": 0.0}, + artifacts={ + "stderr": error_msg, + "traceback": traceback.format_exc(), + "failure_stage": "stage1_execution", + "suggestion": "Check basic syntax and imports before attempting full evaluation", + }, + ) + + except Exception as e: + error_msg = f"Stage 1 evaluation failed completely: {e}" + print(error_msg) + print(traceback.format_exc()) + return EvaluationResult( + metrics={"validity": 0.0, "combined_score": 0.0}, + artifacts={ + "stderr": error_msg, + "traceback": traceback.format_exc(), + "failure_stage": "stage1_critical_failure", + "suggestion": "Major issues detected - check program structure and dependencies", + }, + ) + + +def evaluate_stage2(program_path): + """ + Second stage evaluation - full evaluation + """ + # Full evaluation as in the main evaluate function + return evaluate(program_path) diff --git a/examples/circle_packing_with_artifacts/initial_program.py b/examples/circle_packing_with_artifacts/initial_program.py new file mode 100644 index 000000000..cb4ea397e --- /dev/null +++ b/examples/circle_packing_with_artifacts/initial_program.py @@ -0,0 +1,133 @@ +# EVOLVE-BLOCK-START +"""Constructor-based circle packing for n=26 circles""" +import numpy as np + + +def construct_packing(): + """ + Construct a specific arrangement of 26 circles in a unit square + that attempts to maximize the sum of their radii. 
+ + Returns: + Tuple of (centers, radii, sum_of_radii) + centers: np.array of shape (26, 2) with (x, y) coordinates + radii: np.array of shape (26) with radius of each circle + sum_of_radii: Sum of all radii + """ + # Initialize arrays for 26 circles + n = 26 + centers = np.zeros((n, 2)) + + # Place circles in a structured pattern + # This is a simple pattern - evolution will improve this + + # First, place a large circle in the center + centers[0] = [0.5, 0.5] + + # Place 8 circles around it in a ring + for i in range(8): + angle = 2 * np.pi * i / 8 + centers[i + 1] = [0.5 + 0.3 * np.cos(angle), 0.5 + 0.3 * np.sin(angle)] + + # Place 16 more circles in an outer ring + for i in range(16): + angle = 2 * np.pi * i / 16 + centers[i + 9] = [0.5 + 0.7 * np.cos(angle), 0.5 + 0.7 * np.sin(angle)] + + # Additional positioning adjustment to make sure all circles + # are inside the square and don't overlap + # Clip to ensure everything is inside the unit square + centers = np.clip(centers, 0.01, 0.99) + + # Compute maximum valid radii for this configuration + radii = compute_max_radii(centers) + + # Calculate the sum of radii + sum_radii = np.sum(radii) + + return centers, radii, sum_radii + + +def compute_max_radii(centers): + """ + Compute the maximum possible radii for each circle position + such that they don't overlap and stay within the unit square. + + Args: + centers: np.array of shape (n, 2) with (x, y) coordinates + + Returns: + np.array of shape (n) with radius of each circle + """ + n = centers.shape[0] + radii = np.ones(n) + + # First, limit by distance to square borders + for i in range(n): + x, y = centers[i] + # Distance to borders + radii[i] = min(x, y, 1 - x, 1 - y) + + # Then, limit by distance to other circles + # Each pair of circles with centers at distance d can have + # sum of radii at most d to avoid overlap + for i in range(n): + for j in range(i + 1, n): + dist = np.sqrt(np.sum((centers[i] - centers[j]) ** 2)) + + # If current radii would cause overlap + if radii[i] + radii[j] > dist: + # Scale both radii proportionally + scale = dist / (radii[i] + radii[j]) + radii[i] *= scale + radii[j] *= scale + + return radii + + +# EVOLVE-BLOCK-END + + +# This part remains fixed (not evolved) +def run_packing(): + """Run the circle packing constructor for n=26""" + centers, radii, sum_radii = construct_packing() + return centers, radii, sum_radii + + +def visualize(centers, radii): + """ + Visualize the circle packing + + Args: + centers: np.array of shape (n, 2) with (x, y) coordinates + radii: np.array of shape (n) with radius of each circle + """ + import matplotlib.pyplot as plt + from matplotlib.patches import Circle + + fig, ax = plt.subplots(figsize=(8, 8)) + + # Draw unit square + ax.set_xlim(0, 1) + ax.set_ylim(0, 1) + ax.set_aspect("equal") + ax.grid(True) + + # Draw circles + for i, (center, radius) in enumerate(zip(centers, radii)): + circle = Circle(center, radius, alpha=0.5) + ax.add_patch(circle) + ax.text(center[0], center[1], str(i), ha="center", va="center") + + plt.title(f"Circle Packing (n={len(centers)}, sum={sum(radii):.6f})") + plt.show() + + +if __name__ == "__main__": + centers, radii, sum_radii = run_packing() + print(f"Sum of radii: {sum_radii}") + # AlphaEvolve improved this to 2.635 + + # Uncomment to visualize: + visualize(centers, radii) diff --git a/examples/circle_packing_with_artifacts/requirements.txt b/examples/circle_packing_with_artifacts/requirements.txt new file mode 100644 index 000000000..067d4a6ea --- /dev/null +++ 
b/examples/circle_packing_with_artifacts/requirements.txt @@ -0,0 +1,2 @@ +matplotlib +scipy \ No newline at end of file diff --git a/openevolve/config.py b/openevolve/config.py index 1a1a19338..b4f5f55f5 100644 --- a/openevolve/config.py +++ b/openevolve/config.py @@ -129,6 +129,11 @@ class PromptConfig: use_meta_prompting: bool = False meta_prompt_weight: float = 0.1 + # Artifact rendering + include_artifacts: bool = True + max_artifact_bytes: int = 20 * 1024 # 20KB in prompt + artifact_security_filter: bool = True + @dataclass class DatabaseConfig: @@ -160,6 +165,12 @@ class DatabaseConfig: # Random seed for reproducible sampling random_seed: Optional[int] = None + # Artifact storage + artifacts_base_path: Optional[str] = None # Defaults to db_path/artifacts + artifact_size_threshold: int = 32 * 1024 # 32KB threshold + cleanup_old_artifacts: bool = True + artifact_retention_days: int = 30 + @dataclass class EvaluatorConfig: @@ -185,6 +196,10 @@ class EvaluatorConfig: use_llm_feedback: bool = False llm_feedback_weight: float = 0.1 + # Artifact handling + enable_artifacts: bool = True + max_artifact_storage: int = 100 * 1024 * 1024 # 100MB per program + @dataclass class Config: diff --git a/openevolve/controller.py b/openevolve/controller.py index 466b6d779..84e1683d0 100644 --- a/openevolve/controller.py +++ b/openevolve/controller.py @@ -230,6 +230,9 @@ async def run( # Sample parent and inspirations from current island parent, inspirations = self.database.sample() + # Get artifacts for the parent program if available + parent_artifacts = self.database.get_artifacts(parent.id) + # Build prompt prompt = self.prompt_sampler.build_prompt( current_program=parent.code, @@ -240,6 +243,7 @@ async def run( language=self.language, evolution_round=i, allow_full_rewrite=self.config.allow_full_rewrites, + program_artifacts=parent_artifacts if parent_artifacts else None, ) # Generate code modification @@ -283,6 +287,9 @@ async def run( child_id = str(uuid.uuid4()) child_metrics = await self.evaluator.evaluate_program(child_code, child_id) + # Handle artifacts if they exist + artifacts = self.evaluator.get_pending_artifacts(child_id) + # Create a child program child_program = Program( id=child_id, @@ -300,6 +307,10 @@ async def run( # Add to database (will be added to current island) self.database.add(child_program, iteration=i + 1) + # Store artifacts if they exist + if artifacts: + self.database.store_artifacts(child_id, artifacts) + # Increment generation for current island self.database.increment_island_generation() diff --git a/openevolve/database.py b/openevolve/database.py index 983138d94..48527c384 100644 --- a/openevolve/database.py +++ b/openevolve/database.py @@ -2,6 +2,7 @@ Program database for OpenEvolve """ +import base64 import json import logging import os @@ -45,6 +46,10 @@ class Program: # Metadata metadata: Dict[str, Any] = field(default_factory=dict) + # Artifact storage + artifacts_json: Optional[str] = None # JSON-serialized small artifacts + artifact_dir: Optional[str] = None # Path to large artifact files + def to_dict(self) -> Dict[str, Any]: """Convert to dictionary representation""" return asdict(self) @@ -911,3 +916,166 @@ def log_island_status(self) -> None: f"best={stat['best_score']:.4f}, avg={stat['average_score']:.4f}, " f"diversity={stat['diversity']:.2f}, gen={stat['generation']}" ) + + # Artifact storage and retrieval methods + + def store_artifacts(self, program_id: str, artifacts: Dict[str, Union[str, bytes]]) -> None: + """ + Store artifacts for a program + + 
Args: + program_id: ID of the program + artifacts: Dictionary of artifact name to content + """ + if not artifacts: + return + + program = self.get(program_id) + if not program: + logger.warning(f"Cannot store artifacts: program {program_id} not found") + return + + # Check if artifacts are enabled + artifacts_enabled = os.environ.get("ENABLE_ARTIFACTS", "true").lower() == "true" + if not artifacts_enabled: + logger.debug("Artifacts disabled, skipping storage") + return + + # Split artifacts by size + small_artifacts = {} + large_artifacts = {} + size_threshold = getattr(self.config, "artifact_size_threshold", 32 * 1024) # 32KB default + + for key, value in artifacts.items(): + size = self._get_artifact_size(value) + if size <= size_threshold: + small_artifacts[key] = value + else: + large_artifacts[key] = value + + # Store small artifacts as JSON + if small_artifacts: + program.artifacts_json = json.dumps(small_artifacts, default=self._artifact_serializer) + logger.debug(f"Stored {len(small_artifacts)} small artifacts for program {program_id}") + + # Store large artifacts to disk + if large_artifacts: + artifact_dir = self._create_artifact_dir(program_id) + program.artifact_dir = artifact_dir + for key, value in large_artifacts.items(): + self._write_artifact_file(artifact_dir, key, value) + logger.debug(f"Stored {len(large_artifacts)} large artifacts for program {program_id}") + + def get_artifacts(self, program_id: str) -> Dict[str, Union[str, bytes]]: + """ + Retrieve all artifacts for a program + + Args: + program_id: ID of the program + + Returns: + Dictionary of artifact name to content + """ + program = self.get(program_id) + if not program: + return {} + + artifacts = {} + + # Load small artifacts from JSON + if program.artifacts_json: + try: + small_artifacts = json.loads(program.artifacts_json) + artifacts.update(small_artifacts) + except json.JSONDecodeError as e: + logger.warning(f"Failed to decode artifacts JSON for program {program_id}: {e}") + + # Load large artifacts from disk + if program.artifact_dir and os.path.exists(program.artifact_dir): + disk_artifacts = self._load_artifact_dir(program.artifact_dir) + artifacts.update(disk_artifacts) + + return artifacts + + def _get_artifact_size(self, value: Union[str, bytes]) -> int: + """Get size of an artifact value in bytes""" + if isinstance(value, str): + return len(value.encode("utf-8")) + elif isinstance(value, bytes): + return len(value) + else: + return len(str(value).encode("utf-8")) + + def _artifact_serializer(self, obj): + """JSON serializer for artifacts that handles bytes""" + if isinstance(obj, bytes): + return {"__bytes__": base64.b64encode(obj).decode("utf-8")} + raise TypeError(f"Object of type {type(obj)} is not JSON serializable") + + def _artifact_deserializer(self, dct): + """JSON deserializer for artifacts that handles bytes""" + if "__bytes__" in dct: + return base64.b64decode(dct["__bytes__"]) + return dct + + def _create_artifact_dir(self, program_id: str) -> str: + """Create artifact directory for a program""" + base_path = getattr(self.config, "artifacts_base_path", None) + if not base_path: + base_path = ( + os.path.join(self.config.db_path or ".", "artifacts") + if self.config.db_path + else "./artifacts" + ) + + artifact_dir = os.path.join(base_path, program_id) + os.makedirs(artifact_dir, exist_ok=True) + return artifact_dir + + def _write_artifact_file(self, artifact_dir: str, key: str, value: Union[str, bytes]) -> None: + """Write an artifact to a file""" + # Sanitize filename + safe_key = 
"".join(c for c in key if c.isalnum() or c in "._-") + if not safe_key: + safe_key = "artifact" + + file_path = os.path.join(artifact_dir, safe_key) + + try: + if isinstance(value, str): + with open(file_path, "w", encoding="utf-8") as f: + f.write(value) + elif isinstance(value, bytes): + with open(file_path, "wb") as f: + f.write(value) + else: + # Convert to string and write + with open(file_path, "w", encoding="utf-8") as f: + f.write(str(value)) + except Exception as e: + logger.warning(f"Failed to write artifact {key} to {file_path}: {e}") + + def _load_artifact_dir(self, artifact_dir: str) -> Dict[str, Union[str, bytes]]: + """Load artifacts from a directory""" + artifacts = {} + + try: + for filename in os.listdir(artifact_dir): + file_path = os.path.join(artifact_dir, filename) + if os.path.isfile(file_path): + try: + # Try to read as text first + with open(file_path, "r", encoding="utf-8") as f: + content = f.read() + artifacts[filename] = content + except UnicodeDecodeError: + # If text fails, read as binary + with open(file_path, "rb") as f: + content = f.read() + artifacts[filename] = content + except Exception as e: + logger.warning(f"Failed to read artifact file {file_path}: {e}") + except Exception as e: + logger.warning(f"Failed to list artifact directory {artifact_dir}: {e}") + + return artifacts diff --git a/openevolve/evaluation_result.py b/openevolve/evaluation_result.py new file mode 100644 index 000000000..06b22dd2b --- /dev/null +++ b/openevolve/evaluation_result.py @@ -0,0 +1,54 @@ +""" +Evaluation result structures for OpenEvolve +""" + +import json +from dataclasses import dataclass, field +from typing import Dict, Union + + +@dataclass +class EvaluationResult: + """ + Result of program evaluation containing both metrics and optional artifacts + + This maintains backward compatibility with the existing dict[str, float] contract + while adding a side-channel for arbitrary artifacts (text or binary data). 
+ """ + + metrics: Dict[str, float] # mandatory - existing contract + artifacts: Dict[str, Union[str, bytes]] = field(default_factory=dict) # optional side-channel + + @classmethod + def from_dict(cls, metrics: Dict[str, float]) -> "EvaluationResult": + """Auto-wrap dict returns for backward compatibility""" + return cls(metrics=metrics) + + def to_dict(self) -> Dict[str, float]: + """Backward compatibility - return just metrics""" + return self.metrics + + def has_artifacts(self) -> bool: + """Check if this result contains any artifacts""" + return bool(self.artifacts) + + def get_artifact_keys(self) -> list: + """Get list of artifact keys""" + return list(self.artifacts.keys()) + + def get_artifact_size(self, key: str) -> int: + """Get size of a specific artifact in bytes""" + if key not in self.artifacts: + return 0 + + value = self.artifacts[key] + if isinstance(value, str): + return len(value.encode("utf-8")) + elif isinstance(value, bytes): + return len(value) + else: + return len(str(value).encode("utf-8")) + + def get_total_artifact_size(self) -> int: + """Get total size of all artifacts in bytes""" + return sum(self.get_artifact_size(key) for key in self.artifacts.keys()) diff --git a/openevolve/evaluator.py b/openevolve/evaluator.py index a66207544..e57b01224 100644 --- a/openevolve/evaluator.py +++ b/openevolve/evaluator.py @@ -11,12 +11,14 @@ import sys import tempfile import time +import traceback import uuid from pathlib import Path from typing import Any, Callable, Dict, List, Optional, Tuple, Union import traceback from openevolve.config import EvaluatorConfig +from openevolve.evaluation_result import EvaluationResult from openevolve.llm.ensemble import LLMEnsemble from openevolve.utils.async_utils import TaskPool, run_in_executor from openevolve.prompt.sampler import PromptSampler @@ -51,6 +53,9 @@ def __init__( # Set up evaluation function if file exists self._load_evaluation_function() + # Pending artifacts storage for programs + self._pending_artifacts: Dict[str, Dict[str, Union[str, bytes]]] = {} + logger.info(f"Initialized evaluator with {evaluation_file}") def _load_evaluation_function(self) -> None: @@ -96,6 +101,9 @@ async def evaluate_program( start_time = time.time() program_id_str = f" {program_id}" if program_id else "" + # Check if artifacts are enabled + artifacts_enabled = os.environ.get("ENABLE_ARTIFACTS", "true").lower() == "true" + # Retry logic for evaluation last_exception = None for attempt in range(self.config.max_retries + 1): @@ -108,10 +116,13 @@ async def evaluate_program( # Run evaluation if self.config.cascade_evaluation: # Run cascade evaluation - metrics = await self._cascade_evaluate(temp_file_path) + result = await self._cascade_evaluate(temp_file_path) else: # Run direct evaluation - metrics = await self._direct_evaluate(temp_file_path) + result = await self._direct_evaluate(temp_file_path) + + # Process the result based on type + eval_result = self._process_evaluation_result(result) # Add LLM feedback if configured if self.config.use_llm_feedback and self.llm_ensemble: @@ -119,15 +130,20 @@ async def evaluate_program( # Combine metrics for name, value in feedback_metrics.items(): - metrics[f"llm_{name}"] = value * self.config.llm_feedback_weight + eval_result.metrics[f"llm_{name}"] = value * self.config.llm_feedback_weight + + # Store artifacts if enabled and present + if artifacts_enabled and eval_result.has_artifacts() and program_id: + self._pending_artifacts[program_id] = eval_result.artifacts elapsed = time.time() - start_time 
logger.info( f"Evaluated program{program_id_str} in {elapsed:.2f}s: " - f"{format_metrics_safe(metrics)}" + f"{format_metrics_safe(eval_result.metrics)}" ) - return metrics + # Return just metrics for backward compatibility + return eval_result.metrics except Exception as e: last_exception = e @@ -135,6 +151,14 @@ async def evaluate_program( f"Evaluation attempt {attempt + 1}/{self.config.max_retries + 1} failed for program{program_id_str}: {str(e)}" ) + # Capture failure artifacts if enabled + if artifacts_enabled and program_id: + self._pending_artifacts[program_id] = { + "stderr": str(e), + "traceback": traceback.format_exc(), + "failure_stage": "evaluation", + } + # If this is not the last attempt, wait a bit before retrying if attempt < self.config.max_retries: await asyncio.sleep(1.0) # Wait 1 second before retry @@ -150,6 +174,39 @@ async def evaluate_program( ) return {"error": 0.0} + def _process_evaluation_result(self, result: Any) -> EvaluationResult: + """ + Process evaluation result to handle both dict and EvaluationResult returns + + Args: + result: Raw result from evaluation function + + Returns: + EvaluationResult instance + """ + if isinstance(result, dict): + # Backward compatibility - wrap dict in EvaluationResult + return EvaluationResult.from_dict(result) + elif isinstance(result, EvaluationResult): + # New format - use directly + return result + else: + # Error case - return error metrics + logger.warning(f"Unexpected evaluation result type: {type(result)}") + return EvaluationResult(metrics={"error": 0.0}) + + def get_pending_artifacts(self, program_id: str) -> Optional[Dict[str, Union[str, bytes]]]: + """ + Get and clear pending artifacts for a program + + Args: + program_id: Program ID + + Returns: + Artifacts dictionary or None if not found + """ + return self._pending_artifacts.pop(program_id, None) + @run_in_executor def _direct_evaluate(self, program_path: str) -> Dict[str, float]: """ @@ -176,7 +233,9 @@ def _direct_evaluate(self, program_path: str) -> Dict[str, float]: logger.error(f"Error in direct evaluation: {str(e)}") return {"error": 0.0} - async def _cascade_evaluate(self, program_path: str) -> Dict[str, float]: + async def _cascade_evaluate( + self, program_path: str + ) -> Union[Dict[str, float], EvaluationResult]: """ Run cascade evaluation with increasingly challenging test cases @@ -184,7 +243,7 @@ async def _cascade_evaluate(self, program_path: str) -> Dict[str, float]: program_path: Path to the program file Returns: - Dictionary of metric name to score + Dictionary of metrics or EvaluationResult with metrics and artifacts """ # Import the evaluation module to get cascade functions if they exist try: @@ -202,78 +261,110 @@ async def _cascade_evaluate(self, program_path: str) -> Dict[str, float]: # Run first stage try: stage1_result = await run_in_executor(module.evaluate_stage1)(program_path) - if not isinstance(stage1_result, dict): - logger.warning( - f"Stage 1 evaluation returned non-dictionary result: {stage1_result}" - ) - return {"error": 0.0} + stage1_eval_result = self._process_evaluation_result(stage1_result) except Exception as e: logger.error(f"Error in stage 1 evaluation: {str(e)}") - return {"error": 0.0} + # Capture stage 1 failure as artifacts + return EvaluationResult( + metrics={"stage1_passed": 0.0, "error": 0.0}, + artifacts={ + "stderr": str(e), + "traceback": traceback.format_exc(), + "failure_stage": "stage1", + }, + ) # Check threshold - if not self._passes_threshold(stage1_result, self.config.cascade_thresholds[0]): - 
return stage1_result + if not self._passes_threshold( + stage1_eval_result.metrics, self.config.cascade_thresholds[0] + ): + return stage1_eval_result # Check if second stage exists if not hasattr(module, "evaluate_stage2"): - return stage1_result + return stage1_eval_result # Run second stage try: stage2_result = await run_in_executor(module.evaluate_stage2)(program_path) - if not isinstance(stage2_result, dict): - logger.warning( - f"Stage 2 evaluation returned non-dictionary result: {stage2_result}" - ) - return stage1_result + stage2_eval_result = self._process_evaluation_result(stage2_result) except Exception as e: logger.error(f"Error in stage 2 evaluation: {str(e)}") - return stage1_result + # Capture stage 2 failure, but keep stage 1 results + stage1_eval_result.artifacts.update( + { + "stage2_stderr": str(e), + "stage2_traceback": traceback.format_exc(), + "failure_stage": "stage2", + } + ) + stage1_eval_result.metrics["stage2_passed"] = 0.0 + return stage1_eval_result - # Merge results - result = {} + # Merge results from stage 1 and 2 + merged_metrics = {} # Convert all values to float to avoid type errors - for name, value in stage1_result.items(): + for name, value in stage1_eval_result.metrics.items(): if isinstance(value, (int, float)) and name != "error": - result[name] = float(value) + merged_metrics[name] = float(value) - for name, value in stage2_result.items(): + for name, value in stage2_eval_result.metrics.items(): if isinstance(value, (int, float)) and name != "error": - result[name] = float(value) + merged_metrics[name] = float(value) - # Check threshold + # Merge artifacts + merged_artifacts = {} + merged_artifacts.update(stage1_eval_result.artifacts) + merged_artifacts.update(stage2_eval_result.artifacts) + + merged_result = EvaluationResult(metrics=merged_metrics, artifacts=merged_artifacts) + + # Check threshold for stage 3 if len(self.config.cascade_thresholds) < 2 or not self._passes_threshold( - result, self.config.cascade_thresholds[1] + merged_result.metrics, self.config.cascade_thresholds[1] ): - return result + return merged_result # Check if third stage exists if not hasattr(module, "evaluate_stage3"): - return result + return merged_result # Run third stage try: stage3_result = await run_in_executor(module.evaluate_stage3)(program_path) - if not isinstance(stage3_result, dict): - logger.warning( - f"Stage 3 evaluation returned non-dictionary result: {stage3_result}" - ) - return result + stage3_eval_result = self._process_evaluation_result(stage3_result) except Exception as e: logger.error(f"Error in stage 3 evaluation: {str(e)}") - return result + # Capture stage 3 failure, but keep previous results + merged_result.artifacts.update( + { + "stage3_stderr": str(e), + "stage3_traceback": traceback.format_exc(), + "failure_stage": "stage3", + } + ) + merged_result.metrics["stage3_passed"] = 0.0 + return merged_result - # Merge results - for name, value in stage3_result.items(): + # Merge stage 3 results + for name, value in stage3_eval_result.metrics.items(): if isinstance(value, (int, float)) and name != "error": - result[name] = float(value) + merged_result.metrics[name] = float(value) - return result + merged_result.artifacts.update(stage3_eval_result.artifacts) + + return merged_result except Exception as e: logger.error(f"Error in cascade evaluation: {str(e)}") - return {"error": 0.0} + return EvaluationResult( + metrics={"error": 0.0}, + artifacts={ + "stderr": str(e), + "traceback": traceback.format_exc(), + "failure_stage": "cascade_setup", + }, + 
) async def _llm_evaluate(self, program_code: str) -> Dict[str, float]: """ diff --git a/openevolve/prompt/sampler.py b/openevolve/prompt/sampler.py index 6d543424b..9a5310b8d 100644 --- a/openevolve/prompt/sampler.py +++ b/openevolve/prompt/sampler.py @@ -55,6 +55,7 @@ def build_prompt( evolution_round: int = 0, allow_full_rewrite: bool = False, template_key: Optional[str] = None, + program_artifacts: Optional[Dict[str, Union[str, bytes]]] = None, **kwargs: Any, ) -> Dict[str, str]: """ @@ -70,6 +71,7 @@ def build_prompt( evolution_round: Current evolution round allow_full_rewrite: Whether to allow a full rewrite template_key: Optional override for template key + program_artifacts: Optional artifacts from program evaluation **kwargs: Additional keys to replace in the user prompt Returns: @@ -111,6 +113,11 @@ def build_prompt( previous_programs, top_programs, language ) + # Format artifacts section if enabled and available + artifacts_section = "" + if self.config.include_artifacts and program_artifacts: + artifacts_section = self._render_artifacts(program_artifacts) + # Apply stochastic template variations if enabled if self.config.use_template_stochasticity: user_template = self._apply_template_variations(user_template) @@ -122,6 +129,7 @@ def build_prompt( evolution_history=evolution_history, current_program=current_program, language=language, + artifacts=artifacts_section, **kwargs, ) @@ -396,3 +404,87 @@ def _apply_template_variations(self, template: str) -> str: result = result.replace(f"{{{key}}}", chosen_variation) return result + + def _render_artifacts(self, artifacts: Dict[str, Union[str, bytes]]) -> str: + """ + Render artifacts for prompt inclusion + + Args: + artifacts: Dictionary of artifact name to content + + Returns: + Formatted string for prompt inclusion (empty string if no artifacts) + """ + if not artifacts: + return "" + + sections = [] + + # Process all artifacts using .items() + for key, value in artifacts.items(): + content = self._safe_decode_artifact(value) + # Truncate if too long + if len(content) > self.config.max_artifact_bytes: + content = content[: self.config.max_artifact_bytes] + "\n... 
(truncated)"
+
+            sections.append(f"### {key}\n```\n{content}\n```")
+
+        if sections:
+            return "## Last Execution Output\n\n" + "\n\n".join(sections)
+        else:
+            return ""
+
+    def _safe_decode_artifact(self, value: Union[str, bytes]) -> str:
+        """
+        Safely decode an artifact value to string
+
+        Args:
+            value: Artifact value (string or bytes)
+
+        Returns:
+            String representation of the value
+        """
+        if isinstance(value, str):
+            # Apply security filter if enabled
+            if self.config.artifact_security_filter:
+                return self._apply_security_filter(value)
+            return value
+        elif isinstance(value, bytes):
+            try:
+                decoded = value.decode("utf-8", errors="replace")
+                if self.config.artifact_security_filter:
+                    return self._apply_security_filter(decoded)
+                return decoded
+            except Exception:
+                return f"<binary data: {len(value)} bytes>"
+        else:
+            return str(value)
+
+    def _apply_security_filter(self, text: str) -> str:
+        """
+        Apply security filtering to artifact text
+
+        Args:
+            text: Input text
+
+        Returns:
+            Filtered text with potential secrets/sensitive info removed
+        """
+        import re
+
+        # Remove ANSI escape sequences
+        ansi_escape = re.compile(r"\x1B(?:[@-Z\\-_]|\[[0-?]*[ -/]*[@-~])")
+        filtered = ansi_escape.sub("", text)
+
+        # Basic patterns for common secrets (can be expanded)
+        secret_patterns = [
+            (r"[A-Za-z0-9]{32,}", "<REDACTED>"),  # Long alphanumeric tokens
+            (r"sk-[A-Za-z0-9]{48}", "<REDACTED>"),  # OpenAI-style API keys
+            (r"password[=:]\s*[^\s]+", "password=<REDACTED>"),  # Password assignments
+            (r"token[=:]\s*[^\s]+", "token=<REDACTED>"),  # Token assignments
+        ]
+
+        for pattern, replacement in secret_patterns:
+            filtered = re.sub(pattern, replacement, filtered, flags=re.IGNORECASE)
+
+        return filtered
diff --git a/openevolve/prompt/templates.py b/openevolve/prompt/templates.py
index 82d5a6b03..f5095080a 100644
--- a/openevolve/prompt/templates.py
+++ b/openevolve/prompt/templates.py
@@ -20,6 +20,8 @@
 - Current performance metrics: {metrics}
 - Areas identified for improvement: {improvement_areas}
 
+{artifacts}
+
 # Program Evolution History
 {evolution_history}
 
@@ -64,6 +66,8 @@
 - Current performance metrics: {metrics}
 - Areas identified for improvement: {improvement_areas}
 
+{artifacts}
+
 # Program Evolution History
 {evolution_history}
 
diff --git a/tests/test_artifacts.py b/tests/test_artifacts.py
new file mode 100644
index 000000000..5ccbce55f
--- /dev/null
+++ b/tests/test_artifacts.py
@@ -0,0 +1,325 @@
+"""
+Test suite for artifacts functionality
+"""
+
+import asyncio
+import os
+import tempfile
+import unittest
+from unittest.mock import Mock, patch
+
+from openevolve.config import DatabaseConfig, EvaluatorConfig, PromptConfig
+from openevolve.database import Program, ProgramDatabase
+from openevolve.evaluation_result import EvaluationResult
+from openevolve.evaluator import Evaluator
+from openevolve.prompt.sampler import PromptSampler
+
+
+class TestEvaluationResult(unittest.TestCase):
+    """Test the EvaluationResult dataclass"""
+
+    def test_from_dict_compatibility(self):
+        """Test that dict -> EvaluationResult -> dict roundtrip works"""
+        original_dict = {"accuracy": 0.95, "speed": 0.8}
+
+        # Convert to EvaluationResult
+        eval_result = EvaluationResult.from_dict(original_dict)
+
+        # Check structure
+        self.assertEqual(eval_result.metrics, original_dict)
+        self.assertEqual(eval_result.artifacts, {})
+
+        # Convert back to dict
+        result_dict = eval_result.to_dict()
+        self.assertEqual(result_dict, original_dict)
+
+    def test_evaluation_result_with_artifacts(self):
+        """Test EvaluationResult with artifacts"""
+        metrics = {"accuracy": 0.95}
+        artifacts = {"stderr": 
"compilation error", "stdout": "test output"} + + eval_result = EvaluationResult(metrics=metrics, artifacts=artifacts) + + self.assertEqual(eval_result.metrics, metrics) + self.assertEqual(eval_result.artifacts, artifacts) + self.assertTrue(eval_result.has_artifacts()) + self.assertEqual(eval_result.get_artifact_keys(), ["stderr", "stdout"]) + + def test_artifact_size_calculation(self): + """Test artifact size calculation""" + eval_result = EvaluationResult( + metrics={"score": 1.0}, artifacts={"text": "hello world", "binary": b"binary data"} + ) + + # Text should be encoded to bytes for size calculation + text_size = eval_result.get_artifact_size("text") + self.assertEqual(text_size, len("hello world".encode("utf-8"))) + + # Binary should return length directly + binary_size = eval_result.get_artifact_size("binary") + self.assertEqual(binary_size, len(b"binary data")) + + # Total size + total_size = eval_result.get_total_artifact_size() + self.assertEqual(total_size, text_size + binary_size) + + +class TestDatabaseArtifacts(unittest.TestCase): + """Test artifact storage in the database""" + + def setUp(self): + self.temp_dir = tempfile.mkdtemp() + config = DatabaseConfig(db_path=self.temp_dir) + self.database = ProgramDatabase(config) + + # Create a test program + self.program = Program(id="test_program_1", code="print('hello')", metrics={"score": 0.5}) + self.database.add(self.program) + + def test_store_small_artifacts(self): + """Test storing small artifacts in JSON""" + artifacts = {"stderr": "small error message", "stdout": "small output"} + + self.database.store_artifacts(self.program.id, artifacts) + + # Retrieve artifacts + retrieved = self.database.get_artifacts(self.program.id) + self.assertEqual(retrieved, artifacts) + + # Check that program has artifacts_json set + program = self.database.get(self.program.id) + self.assertIsNotNone(program.artifacts_json) + + def test_store_large_artifacts(self): + """Test storing large artifacts to disk""" + large_content = "x" * (50 * 1024) # 50KB + artifacts = {"large_output": large_content} + + self.database.store_artifacts(self.program.id, artifacts) + + # Retrieve artifacts + retrieved = self.database.get_artifacts(self.program.id) + self.assertEqual(retrieved["large_output"], large_content) + + # Check that program has artifact_dir set + program = self.database.get(self.program.id) + self.assertIsNotNone(program.artifact_dir) + + def test_store_mixed_artifacts(self): + """Test storing both small and large artifacts""" + small_content = "small message" + large_content = "y" * (50 * 1024) # 50KB + + artifacts = {"stderr": small_content, "large_log": large_content} + + self.database.store_artifacts(self.program.id, artifacts) + + # Retrieve all artifacts + retrieved = self.database.get_artifacts(self.program.id) + self.assertEqual(retrieved["stderr"], small_content) + self.assertEqual(retrieved["large_log"], large_content) + + def test_artifacts_disabled(self): + """Test that artifacts are skipped when disabled""" + with patch.dict(os.environ, {"ENABLE_ARTIFACTS": "false"}): + artifacts = {"stderr": "error message"} + + # Should not store artifacts when disabled + self.database.store_artifacts(self.program.id, artifacts) + + # Should return empty dict + retrieved = self.database.get_artifacts(self.program.id) + self.assertEqual(retrieved, {}) + + +class TestEvaluatorArtifacts(unittest.TestCase): + """Test artifact handling in the evaluator""" + + def setUp(self): + # Set up event loop for async operations in tests + try: + self.loop = 
asyncio.get_event_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + # Create a mock evaluation file + self.temp_eval_file = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) + self.temp_eval_file.write( + """ +def evaluate(program_path): + return {"score": 0.5} +""" + ) + self.temp_eval_file.close() + + config = EvaluatorConfig() + self.evaluator = Evaluator(config, self.temp_eval_file.name) + + def tearDown(self): + os.unlink(self.temp_eval_file.name) + # Clean up event loop if we created one + if hasattr(self, "loop") and self.loop and not self.loop.is_closed(): + # Cancel any pending tasks + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + # Run the loop briefly to let cancellations process + if pending: + self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + + def test_evaluate_program_backward_compatibility(self): + """Test that old evaluators still work unchanged""" + + async def run_test(): + result = await self.evaluator.evaluate_program("print('test')", "test_id") + self.assertIsInstance(result, dict) + self.assertIn("score", result) + + asyncio.run(run_test()) + + def test_process_evaluation_result_dict(self): + """Test processing dict results""" + dict_result = {"accuracy": 0.9, "speed": 0.7} + eval_result = self.evaluator._process_evaluation_result(dict_result) + + self.assertIsInstance(eval_result, EvaluationResult) + self.assertEqual(eval_result.metrics, dict_result) + self.assertEqual(eval_result.artifacts, {}) + + def test_process_evaluation_result_evaluation_result(self): + """Test processing EvaluationResult objects""" + original = EvaluationResult(metrics={"score": 0.8}, artifacts={"stderr": "warning message"}) + + result = self.evaluator._process_evaluation_result(original) + self.assertEqual(result, original) + + def test_pending_artifacts(self): + """Test pending artifacts storage and retrieval""" + program_id = "test_program" + artifacts = {"stderr": "error", "stdout": "output"} + + # Store artifacts + self.evaluator._pending_artifacts[program_id] = artifacts + + # Retrieve and check that they're cleared + retrieved = self.evaluator.get_pending_artifacts(program_id) + self.assertEqual(retrieved, artifacts) + + # Should be None after retrieval + second_retrieval = self.evaluator.get_pending_artifacts(program_id) + self.assertIsNone(second_retrieval) + + +class TestPromptArtifacts(unittest.TestCase): + """Test artifact rendering in prompts""" + + def setUp(self): + config = PromptConfig() + self.sampler = PromptSampler(config) + + def test_render_artifacts_all_items(self): + """Test that all artifacts are included using .items() without prioritization""" + artifacts = { + "stderr": "error message", + "stdout": "output message", + "traceback": "stack trace", + "other": "other data", + } + + rendered = self.sampler._render_artifacts(artifacts) + + # All artifacts should be present (no prioritization) + for key in artifacts.keys(): + self.assertIn(key, rendered) + + # Check that all content is included + for value in artifacts.values(): + self.assertIn(value, rendered) + + def test_render_artifacts_generic(self): + """Test that all artifacts are included using .items()""" + artifacts = {"log1": "first log", "log2": "second log", "config": "configuration data"} + + rendered = self.sampler._render_artifacts(artifacts) + + # All artifacts should be present + for key in artifacts.keys(): + self.assertIn(key, rendered) + + def 
test_render_artifacts_truncation(self):
+        """Test artifact truncation for large content"""
+        # Create content larger than 20KB to trigger truncation
+        large_content = "This is a very long log message. " * 700  # Creates ~23KB of text
+        artifacts = {"large_log": large_content}
+
+        rendered = self.sampler._render_artifacts(artifacts)
+
+        # Should contain truncation indicator
+        self.assertIn("(truncated)", rendered)
+
+    def test_render_artifacts_security_filter(self):
+        """Test that security filter redacts potential tokens"""
+        # Create content that looks like a token
+        token_like_content = "x" * 40  # 40 character string that looks like a token
+        artifacts = {"suspicious_log": token_like_content}
+
+        rendered = self.sampler._render_artifacts(artifacts)
+
+        # Should be redacted by security filter
+        self.assertIn("<REDACTED>", rendered)
+
+    def test_safe_decode_artifact_string(self):
+        """Test safe decoding of string artifacts"""
+        text = "hello world"
+        decoded = self.sampler._safe_decode_artifact(text)
+        self.assertEqual(decoded, text)
+
+    def test_safe_decode_artifact_bytes(self):
+        """Test safe decoding of bytes artifacts"""
+        text = "hello world"
+        binary = text.encode("utf-8")
+        decoded = self.sampler._safe_decode_artifact(binary)
+        self.assertEqual(decoded, text)
+
+    def test_safe_decode_artifact_invalid_bytes(self):
+        """Test safe decoding of invalid bytes"""
+        invalid_bytes = b"\xff\xfe\xfd"
+        decoded = self.sampler._safe_decode_artifact(invalid_bytes)
+        # Should not raise exception and should contain some indication of binary data
+        self.assertIsInstance(decoded, str)
+
+    def test_build_prompt_with_artifacts(self):
+        """Test building prompt with artifacts"""
+        artifacts = {"stderr": "compilation error"}
+
+        prompt = self.sampler.build_prompt(
+            current_program="print('test')",
+            parent_program="print('test')",
+            program_metrics={"score": 0.5},
+            previous_programs=[],
+            top_programs=[],
+            program_artifacts=artifacts,
+        )
+
+        # Artifacts should be included in the user message (case-insensitive check)
+        self.assertIn("stderr", prompt["user"].lower())
+        self.assertIn("compilation error", prompt["user"])
+
+    def test_build_prompt_without_artifacts(self):
+        """Test building prompt without artifacts"""
+        prompt = self.sampler.build_prompt(
+            current_program="print('test')",
+            parent_program="print('test')",
+            program_metrics={"score": 0.5},
+            previous_programs=[],
+            top_programs=[],
+        )
+
+        # Should work normally without artifacts
+        self.assertIn("system", prompt)
+        self.assertIn("user", prompt)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tests/test_artifacts_integration.py b/tests/test_artifacts_integration.py
new file mode 100644
index 000000000..852b32c66
--- /dev/null
+++ b/tests/test_artifacts_integration.py
@@ -0,0 +1,317 @@
+"""
+Integration tests for artifacts functionality
+"""
+
+import asyncio
+import os
+import tempfile
+import unittest
+from unittest.mock import Mock, patch
+
+from openevolve.config import Config, DatabaseConfig, EvaluatorConfig, PromptConfig
+from openevolve.database import Program, ProgramDatabase
+from openevolve.evaluation_result import EvaluationResult
+from openevolve.evaluator import Evaluator
+from openevolve.prompt.sampler import PromptSampler
+
+
+class TestArtifactsIntegration(unittest.TestCase):
+    """Test full integration of artifacts feature"""
+
+    def setUp(self):
+        # Set up event loop for async operations in tests
+        try:
+            self.loop = asyncio.get_event_loop()
+        except RuntimeError:
+            self.loop = asyncio.new_event_loop()
+            
asyncio.set_event_loop(self.loop) + + # Create temporary directory for database + self.temp_dir = tempfile.mkdtemp() + + # Create evaluation file that can return EvaluationResult + self.eval_file = tempfile.NamedTemporaryFile(mode="w", suffix=".py", delete=False) + self.eval_file.write( + """ +import traceback +from openevolve.evaluation_result import EvaluationResult + +def evaluate(program_path): + try: + # Try to compile the program + with open(program_path, 'r') as f: + code = f.read() + + compile(code, program_path, 'exec') + + # If compilation succeeds, return good metrics + return EvaluationResult( + metrics={"compile_ok": 1.0, "score": 0.8}, + artifacts={"stdout": "Compilation successful"} + ) + except Exception as e: + # If compilation fails, capture the error + return EvaluationResult( + metrics={"compile_ok": 0.0, "score": 0.0}, + artifacts={ + "stderr": str(e), + "traceback": traceback.format_exc(), + "failure_stage": "compilation" + } + ) + +def evaluate_stage1(program_path): + # Basic compilation check + try: + with open(program_path, 'r') as f: + code = f.read() + + compile(code, program_path, 'exec') + return {"stage1_passed": 1.0, "compile_ok": 1.0} + except Exception as e: + return EvaluationResult( + metrics={"stage1_passed": 0.0, "compile_ok": 0.0}, + artifacts={ + "stderr": str(e), + "failure_stage": "stage1_compilation" + } + ) +""" + ) + self.eval_file.close() + + # Set up config + self.config = Config() + self.config.database.db_path = self.temp_dir + self.config.evaluator.cascade_evaluation = True + self.config.prompt.include_artifacts = True + + # Initialize components + self.database = ProgramDatabase(self.config.database) + self.evaluator = Evaluator(self.config.evaluator, self.eval_file.name) + self.prompt_sampler = PromptSampler(self.config.prompt) + + def tearDown(self): + os.unlink(self.eval_file.name) + # Clean up event loop if we created one + if hasattr(self, "loop") and self.loop and not self.loop.is_closed(): + # Cancel any pending tasks + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + # Run the loop briefly to let cancellations process + if pending: + self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + + def test_compile_failure_artifact_capture(self): + """Test that compilation failures are captured as artifacts""" + + async def run_test(): + # Program with syntax error + bad_code = "print('hello'\n # Missing closing parenthesis" + program_id = "bad_program_1" + + # Evaluate the program + metrics = await self.evaluator.evaluate_program(bad_code, program_id) + + # Should get failure metrics + self.assertEqual(metrics.get("compile_ok"), 0.0) + + # Should have pending artifacts + artifacts = self.evaluator.get_pending_artifacts(program_id) + self.assertIsNotNone(artifacts) + self.assertIn("stderr", artifacts) + # Note: stage1 evaluation doesn't include traceback, only stderr and failure_stage + self.assertIn("failure_stage", artifacts) + + return artifacts + + artifacts = asyncio.run(run_test()) + + # Verify artifact content - should have captured the compilation error + self.assertIn("stderr", artifacts) + self.assertTrue(len(artifacts["stderr"]) > 0, "stderr should not be empty") + self.assertIn("failure_stage", artifacts) + self.assertEqual(artifacts["failure_stage"], "stage1_compilation") + + def test_end_to_end_artifact_flow(self): + """Test full flow: eval failure -> artifact -> prompt -> next gen""" + + async def run_test(): + # 1. 
Create a program with compilation error + bad_code = "def broken_function(\n return 'incomplete'" + program_id = "flow_test_1" + + # 2. Evaluate and get artifacts + metrics = await self.evaluator.evaluate_program(bad_code, program_id) + artifacts = self.evaluator.get_pending_artifacts(program_id) + + # 3. Create program and store in database + program = Program(id=program_id, code=bad_code, language="python", metrics=metrics) + self.database.add(program) + + # 4. Store artifacts + if artifacts: + self.database.store_artifacts(program_id, artifacts) + + # 5. Retrieve artifacts and build prompt + stored_artifacts = self.database.get_artifacts(program_id) + + prompt = self.prompt_sampler.build_prompt( + current_program=bad_code, + parent_program=bad_code, + program_metrics=metrics, + previous_programs=[], + top_programs=[], + program_artifacts=stored_artifacts, + ) + + return prompt, stored_artifacts + + prompt, artifacts = asyncio.run(run_test()) + + # Verify artifacts appear in prompt + self.assertIn("stderr", prompt["user"].lower()) + self.assertIn("Last Execution Output", prompt["user"]) + + # Verify artifacts were stored and retrieved correctly + self.assertIn("stderr", artifacts) + self.assertTrue(len(artifacts["stderr"]) > 0, "stderr should not be empty") + + def test_cascade_evaluation_with_artifacts(self): + """Test cascade evaluation captures artifacts at each stage""" + + async def run_test(): + # Program that will fail at stage 1 + invalid_code = "invalid syntax here" + program_id = "cascade_test_1" + + # Run cascade evaluation + result = await self.evaluator._cascade_evaluate(f"/tmp/test_program.py") + + # Should be an EvaluationResult with artifacts + if isinstance(result, EvaluationResult): + return result + else: + # If it returns a dict, wrap it + return EvaluationResult.from_dict(result) + + # Mock the actual file operations since we're testing the cascade logic + with patch("openevolve.evaluator.run_in_executor") as mock_executor: + # Mock stage1 to return an error with artifacts + mock_executor.return_value = EvaluationResult( + metrics={"stage1_passed": 0.0}, artifacts={"stderr": "Stage 1 compilation error"} + ) + + result = asyncio.run(run_test()) + + # Should have failure metrics and artifacts + self.assertEqual(result.metrics.get("stage1_passed"), 0.0) + self.assertIn("stderr", result.artifacts) + + def test_artifacts_disabled_integration(self): + """Test that the full system works with artifacts disabled""" + + with patch.dict(os.environ, {"ENABLE_ARTIFACTS": "false"}): + + async def run_test(): + # Program with error + bad_code = "invalid syntax" + program_id = "disabled_test_1" + + # Evaluate + metrics = await self.evaluator.evaluate_program(bad_code, program_id) + + # Should not have pending artifacts when disabled + artifacts = self.evaluator.get_pending_artifacts(program_id) + return metrics, artifacts + + metrics, artifacts = asyncio.run(run_test()) + + # Should still get metrics but no artifacts + self.assertIsInstance(metrics, dict) + self.assertIsNone(artifacts) + + def test_successful_evaluation_with_artifacts(self): + """Test that successful evaluations can also have artifacts""" + + async def run_test(): + # Valid Python code + good_code = "print('Hello, world!')" + program_id = "success_test_1" + + # Evaluate + metrics = await self.evaluator.evaluate_program(good_code, program_id) + artifacts = self.evaluator.get_pending_artifacts(program_id) + + return metrics, artifacts + + metrics, artifacts = asyncio.run(run_test()) + + # Should get successful metrics 
+ self.assertEqual(metrics.get("compile_ok"), 1.0) + + # Should have artifacts from successful compilation + if artifacts: + self.assertIn("stdout", artifacts) + self.assertIn("successful", artifacts["stdout"].lower()) + + +class TestArtifactsPersistence(unittest.TestCase): + """Test that artifacts persist correctly across save/load cycles""" + + def setUp(self): + # Set up event loop for async operations in tests + try: + self.loop = asyncio.get_event_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + self.temp_dir = tempfile.mkdtemp() + config = DatabaseConfig(db_path=self.temp_dir) + self.database = ProgramDatabase(config) + + def tearDown(self): + # Clean up event loop if we created one + if hasattr(self, "loop") and self.loop and not self.loop.is_closed(): + # Cancel any pending tasks + pending = asyncio.all_tasks(self.loop) + for task in pending: + task.cancel() + # Run the loop briefly to let cancellations process + if pending: + self.loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) + + def test_save_load_artifacts(self): + """Test that artifacts survive database save/load cycle""" + # Create program with artifacts + program = Program(id="persist_test_1", code="print('test')", metrics={"score": 0.8}) + + artifacts = { + "stderr": "error message", + "stdout": "output message", + "large_log": "x" * (50 * 1024), # Large artifact + } + + # Add program and artifacts + self.database.add(program) + self.database.store_artifacts(program.id, artifacts) + + # Save database + self.database.save() + + # Create new database instance and load + new_database = ProgramDatabase(DatabaseConfig(db_path=self.temp_dir)) + new_database.load(self.temp_dir) + + # Check that artifacts are preserved + loaded_artifacts = new_database.get_artifacts(program.id) + + self.assertEqual(loaded_artifacts["stderr"], artifacts["stderr"]) + self.assertEqual(loaded_artifacts["stdout"], artifacts["stdout"]) + self.assertEqual(loaded_artifacts["large_log"], artifacts["large_log"]) + + +if __name__ == "__main__": + unittest.main()