diff --git a/openevolve/utils/code_utils.py b/openevolve/utils/code_utils.py index 397465fb6..2d97ef4ea 100644 --- a/openevolve/utils/code_utils.py +++ b/openevolve/utils/code_utils.py @@ -53,8 +53,7 @@ def apply_diff(original_code: str, diff_text: str) -> str: result_lines = original_lines.copy() # Extract diff blocks - diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE" - diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL) + diff_blocks = extract_diffs(diff_text) # Apply each diff block for search_text, replace_text in diff_blocks: @@ -81,9 +80,9 @@ def extract_diffs(diff_text: str) -> List[Tuple[str, str]]: Returns: List of tuples (search_text, replace_text) """ - diff_pattern = r"<<<<<<< SEARCH\n(.*?)\n=======\n(.*?)\n>>>>>>> REPLACE" + diff_pattern = r"<<<<<<< SEARCH\n(.*?)=======\n(.*?)>>>>>>> REPLACE" diff_blocks = re.findall(diff_pattern, diff_text, re.DOTALL) - return diff_blocks + return [(match[0].rstrip(), match[1].rstrip()) for match in diff_blocks] def parse_full_rewrite(llm_response: str, language: str = "python") -> Optional[str]: diff --git a/tests/test_basic.py b/tests/test_basic.py index 7a746e337..f8a8710b6 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -23,7 +23,7 @@ def test_extract_diffs(self): """Test extracting diffs from a response""" diff_text = """ Let's improve this code: - + <<<<<<< SEARCH def hello(): print("Hello") @@ -31,9 +31,9 @@ def hello(): def hello(): print("Hello, World!") >>>>>>> REPLACE - + Another change: - + <<<<<<< SEARCH x = 1 ======= @@ -43,17 +43,25 @@ def hello(): diffs = extract_diffs(diff_text) self.assertEqual(len(diffs), 2) - self.assertEqual(diffs[0][0].strip(), 'def hello():\n print("Hello")') - self.assertEqual(diffs[0][1].strip(), 'def hello():\n print("Hello, World!")') - self.assertEqual(diffs[1][0].strip(), "x = 1") - self.assertEqual(diffs[1][1].strip(), "x = 2") + self.assertEqual( + diffs[0][0], + """ def hello(): + print("Hello")""", + ) + self.assertEqual( + diffs[0][1], + """ def hello(): + print("Hello, World!")""", + ) + self.assertEqual(diffs[1][0], " x = 1") + self.assertEqual(diffs[1][1], " x = 2") def test_apply_diff(self): """Test applying diffs to code""" original_code = """ def hello(): print("Hello") - + x = 1 y = 2 """ @@ -66,7 +74,7 @@ def hello(): def hello(): print("Hello, World!") >>>>>>> REPLACE - + <<<<<<< SEARCH x = 1 ======= @@ -77,7 +85,7 @@ def hello(): expected_code = """ def hello(): print("Hello, World!") - + x = 2 y = 2 """ @@ -86,8 +94,8 @@ def hello(): # Normalize whitespace for comparison self.assertEqual( - result.replace(" ", "").replace("\n", ""), - expected_code.replace(" ", "").replace("\n", ""), + result, + expected_code, )