diff --git a/openevolve/controller.py b/openevolve/controller.py index 00fd311c6..dc03d22d0 100644 --- a/openevolve/controller.py +++ b/openevolve/controller.py @@ -157,6 +157,7 @@ def __init__( self.llm_evaluator_ensemble, self.evaluator_prompt_sampler, database=self.database, + suffix=Path(self.initial_program_path).suffix, ) self.evaluation_file = evaluation_file diff --git a/openevolve/evaluator.py b/openevolve/evaluator.py index 80bcac333..bba5bdfad 100644 --- a/openevolve/evaluator.py +++ b/openevolve/evaluator.py @@ -44,9 +44,11 @@ def __init__( llm_ensemble: Optional[LLMEnsemble] = None, prompt_sampler: Optional[PromptSampler] = None, database: Optional[ProgramDatabase] = None, + suffix: Optional[str]=".py", ): self.config = config self.evaluation_file = evaluation_file + self.program_suffix = suffix self.llm_ensemble = llm_ensemble self.prompt_sampler = prompt_sampler self.database = database @@ -152,7 +154,7 @@ async def evaluate_program( last_exception = None for attempt in range(self.config.max_retries + 1): # Create a temporary file for the program - with tempfile.NamedTemporaryFile(suffix=".py", delete=False) as temp_file: + with tempfile.NamedTemporaryFile(suffix=self.program_suffix, delete=False) as temp_file: temp_file.write(program_code.encode("utf-8")) temp_file_path = temp_file.name