From 8b82a314e18d0bda1008ec87ae846063990fe8e1 Mon Sep 17 00:00:00 2001 From: jvm Date: Sun, 18 May 2025 15:57:45 +0200 Subject: [PATCH] More resilient regex in utils.code_utils.extract_diffs and removed redundant implementation in apply_diff() function --- openevolve/utils/code_utils.py | 75 ++++++++++++++++--------------- tests/test_basic.py | 80 +++++++++++++++++++--------------- 2 files changed, 81 insertions(+), 74 deletions(-) diff --git a/openevolve/utils/code_utils.py b/openevolve/utils/code_utils.py index 44c8e922a..45d70c498 100644 --- a/openevolve/utils/code_utils.py +++ b/openevolve/utils/code_utils.py @@ -8,20 +8,20 @@ def parse_evolve_blocks(code: str) -> List[Tuple[int, int, str]]: """ Parse evolve blocks from code - + Args: code: Source code with evolve blocks - + Returns: List of tuples (start_line, end_line, block_content) """ lines = code.split("\n") blocks = [] - + in_block = False start_line = -1 block_content = [] - + for i, line in enumerate(lines): if "# EVOLVE-BLOCK-START" in line: in_block = True @@ -32,102 +32,101 @@ def parse_evolve_blocks(code: str) -> List[Tuple[int, int, str]]: blocks.append((start_line, i, "\n".join(block_content))) elif in_block: block_content.append(line) - + return blocks def apply_diff(original_code: str, diff_text: str) -> str: """ Apply a diff to the original code - + Args: original_code: Original source code diff_text: Diff in the SEARCH/REPLACE format - + Returns: Modified code """ # Split into lines for easier processing original_lines = original_code.split("\n") result_lines = original_lines.copy() - + # Extract diff blocks - diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE" - diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL) - + diff_blocks = extract_diffs(diff_text) + # Apply each diff block for search_text, replace_text in diff_blocks: search_lines = search_text.split("\n") replace_lines = replace_text.split("\n") - + # Find where the search pattern starts in the original code for i in range(len(result_lines) - len(search_lines) + 1): if result_lines[i:i+len(search_lines)] == search_lines: # Replace the matched section result_lines[i:i+len(search_lines)] = replace_lines break - + return "\n".join(result_lines) def extract_diffs(diff_text: str) -> List[Tuple[str, str]]: """ Extract diff blocks from the diff text - + Args: diff_text: Diff in the SEARCH/REPLACE format - + Returns: List of tuples (search_text, replace_text) """ - diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE" + diff_pattern = r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE" diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL) - return diff_blocks + return [(match[0].rstrip(), match[1].rstrip()) for match in diff_blocks] def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]: """ Extract a full rewrite from an LLM response - + Args: llm_response: Response from the LLM language: Programming language - + Returns: Extracted code or None if not found """ code_block_pattern = r"```" + language + r"\n(.*?)```" matches = re.findall(code_block_pattern, llm_response, re.DOTALL) - + if matches: return matches[0].strip() - + # Fallback to any code block code_block_pattern = r"```(.*?)```" matches = re.findall(code_block_pattern, llm_response, re.DOTALL) - + if matches: return matches[0].strip() - + return None def format_diff_summary(diff_blocks: List[Tuple[str, str]]) -> str: """ Create a human-readable summary of the diff - + Args: diff_blocks: List of (search_text, replace_text) tuples - + Returns: Summary string """ summary = [] - + for i, (search_text, replace_text) in enumerate(diff_blocks): search_lines = search_text.strip().split("\n") replace_lines = replace_text.strip().split("\n") - + # Create a short summary if len(search_lines) == 1 and len(replace_lines) == 1: summary.append(f"Change {i+1}: '{search_lines[0]}' to '{replace_lines[0]}'") @@ -135,34 +134,34 @@ def format_diff_summary(diff_blocks: List[Tuple[str, str]]) -> str: search_summary = f"{len(search_lines)} lines" if len(search_lines) > 1 else search_lines[0] replace_summary = f"{len(replace_lines)} lines" if len(replace_lines) > 1 else replace_lines[0] summary.append(f"Change {i+1}: Replace {search_summary} with {replace_summary}") - + return "\n".join(summary) def calculate_edit_distance(code1: str, code2: str) -> int: """ Calculate the Levenshtein edit distance between two code snippets - + Args: code1: First code snippet code2: Second code snippet - + Returns: Edit distance (number of operations needed to transform code1 into code2) """ if code1 == code2: return 0 - + # Simple implementation of Levenshtein distance m, n = len(code1), len(code2) dp = [[0 for _ in range(n + 1)] for _ in range(m + 1)] - + for i in range(m + 1): dp[i][0] = i - + for j in range(n + 1): dp[0][j] = j - + for i in range(1, m + 1): for j in range(1, n + 1): cost = 0 if code1[i-1] == code2[j-1] else 1 @@ -171,17 +170,17 @@ def calculate_edit_distance(code1: str, code2: str) -> int: dp[i][j-1] + 1, # insertion dp[i-1][j-1] + cost, # substitution ) - + return dp[m][n] def extract_code_language(code: str) -> str: """ Try to determine the language of a code snippet - + Args: code: Code snippet - + Returns: Detected language or "unknown" """ @@ -198,5 +197,5 @@ def extract_code_language(code: str) -> str: return "rust" elif re.search(r"^(SELECT|CREATE TABLE|INSERT INTO)", code, re.MULTILINE): return "sql" - + return "unknown" diff --git a/tests/test_basic.py b/tests/test_basic.py index 55c55168b..0a2ce9fa7 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -17,12 +17,12 @@ class TestCodeUtils(unittest.TestCase): """Tests for code utilities""" - + def test_extract_diffs(self): """Test extracting diffs from a response""" diff_text = """ Let's improve this code: - + <<<<<<< SEARCH def hello(): print("Hello") @@ -30,33 +30,41 @@ def hello(): def hello(): print("Hello, World!") >>>>>>> REPLACE - + Another change: - + <<<<<<< SEARCH x = 1 ======= x = 2 >>>>>>> REPLACE """ - + diffs = extract_diffs(diff_text) self.assertEqual(len(diffs), 2) - self.assertEqual(diffs[0][0].strip(), "def hello():\n print(\"Hello\")") - self.assertEqual(diffs[0][1].strip(), "def hello():\n print(\"Hello, World!\")") - self.assertEqual(diffs[1][0].strip(), "x = 1") - self.assertEqual(diffs[1][1].strip(), "x = 2") - + self.assertEqual( + diffs[0][0], + """ def hello(): + print("Hello")""", + ) + self.assertEqual( + diffs[0][1], + """ def hello(): + print("Hello, World!")""", + ) + self.assertEqual(diffs[1][0], " x = 1") + self.assertEqual(diffs[1][1], " x = 2") + def test_apply_diff(self): """Test applying diffs to code""" original_code = """ def hello(): print("Hello") - + x = 1 y = 2 """ - + diff_text = """ <<<<<<< SEARCH def hello(): @@ -65,40 +73,40 @@ def hello(): def hello(): print("Hello, World!") >>>>>>> REPLACE - + <<<<<<< SEARCH x = 1 ======= x = 2 >>>>>>> REPLACE """ - + expected_code = """ def hello(): print("Hello, World!") - + x = 2 y = 2 """ - + result = apply_diff(original_code, diff_text) - + # Normalize whitespace for comparison self.assertEqual( - result.replace(" ", "").replace("\n", ""), - expected_code.replace(" ", "").replace("\n", "") + result, + expected_code, ) class TestProgramDatabase(unittest.TestCase): """Tests for program database""" - + def setUp(self): """Set up test database""" config = Config() config.database.in_memory = True self.db = ProgramDatabase(config.database) - + def test_add_and_get(self): """Test adding and retrieving a program""" program = Program( @@ -107,15 +115,15 @@ def test_add_and_get(self): language="python", metrics={"score": 0.5}, ) - + self.db.add(program) - + retrieved = self.db.get("test1") self.assertIsNotNone(retrieved) self.assertEqual(retrieved.id, "test1") self.assertEqual(retrieved.code, "def test(): pass") self.assertEqual(retrieved.metrics["score"], 0.5) - + def test_get_best_program(self): """Test getting the best program""" program1 = Program( @@ -124,21 +132,21 @@ def test_get_best_program(self): language="python", metrics={"score": 0.5}, ) - + program2 = Program( id="test2", code="def test2(): pass", language="python", metrics={"score": 0.7}, ) - + self.db.add(program1) self.db.add(program2) - + best = self.db.get_best_program() self.assertIsNotNone(best) self.assertEqual(best.id, "test2") - + def test_sample(self): """Test sampling from the database""" program1 = Program( @@ -147,31 +155,31 @@ def test_sample(self): language="python", metrics={"score": 0.5}, ) - + program2 = Program( id="test2", code="def test2(): pass", language="python", metrics={"score": 0.7}, ) - + self.db.add(program1) self.db.add(program2) - + parent, inspirations = self.db.sample() - + self.assertIsNotNone(parent) self.assertIn(parent.id, ["test1", "test2"]) class TestPromptSampler(unittest.TestCase): """Tests for prompt sampler""" - + def setUp(self): """Set up test prompt sampler""" config = Config() self.prompt_sampler = PromptSampler(config.prompt) - + def test_build_prompt(self): """Test building a prompt""" current_program = "def test(): pass" @@ -191,7 +199,7 @@ def test_build_prompt(self): "metrics": {"score": 0.6}, } ] - + prompt = self.prompt_sampler.build_prompt( current_program=current_program, parent_program=parent_program, @@ -199,7 +207,7 @@ def test_build_prompt(self): previous_programs=previous_programs, top_programs=top_programs, ) - + self.assertIn("system", prompt) self.assertIn("user", prompt) self.assertIn("def test(): pass", prompt["user"])