Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 0 additions & 219 deletions tests/test_basic.py

This file was deleted.

93 changes: 93 additions & 0 deletions tests/test_code_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""
Tests for code utilities in openevolve.utils.code_utils
"""

import unittest
from openevolve.utils.code_utils import apply_diff, extract_diffs


class TestCodeUtils(unittest.TestCase):
"""Tests for code utilities"""

def test_extract_diffs(self):
"""Test extracting diffs from a response"""
diff_text = """
Let's improve this code:

<<<<<<< SEARCH
def hello():
print("Hello")
=======
def hello():
print("Hello, World!")
>>>>>>> REPLACE

Another change:

<<<<<<< SEARCH
x = 1
=======
x = 2
>>>>>>> REPLACE
"""

diffs = extract_diffs(diff_text)
self.assertEqual(len(diffs), 2)
self.assertEqual(
diffs[0][0],
""" def hello():
print(\"Hello\")""",
)
self.assertEqual(
diffs[0][1],
""" def hello():
print(\"Hello, World!\")""",
)
self.assertEqual(diffs[1][0], " x = 1")
self.assertEqual(diffs[1][1], " x = 2")

def test_apply_diff(self):
"""Test applying diffs to code"""
original_code = """
def hello():
print("Hello")

x = 1
y = 2
"""

diff_text = """
<<<<<<< SEARCH
def hello():
print("Hello")
=======
def hello():
print("Hello, World!")
>>>>>>> REPLACE

<<<<<<< SEARCH
x = 1
=======
x = 2
>>>>>>> REPLACE
"""

expected_code = """
def hello():
print("Hello, World!")

x = 2
y = 2
"""

result = apply_diff(original_code, diff_text)

# Normalize whitespace for comparison
self.assertEqual(
result,
expected_code,
)


if __name__ == "__main__":
unittest.main()
85 changes: 85 additions & 0 deletions tests/test_database.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,85 @@
"""
Tests for ProgramDatabase in openevolve.database
"""

import unittest
from openevolve.config import Config
from openevolve.database import Program, ProgramDatabase


class TestProgramDatabase(unittest.TestCase):
"""Tests for program database"""

def setUp(self):
"""Set up test database"""
config = Config()
config.database.in_memory = True
self.db = ProgramDatabase(config.database)

def test_add_and_get(self):
"""Test adding and retrieving a program"""
program = Program(
id="test1",
code="def test(): pass",
language="python",
metrics={"score": 0.5},
)

self.db.add(program)

retrieved = self.db.get("test1")
self.assertIsNotNone(retrieved)
self.assertEqual(retrieved.id, "test1")
self.assertEqual(retrieved.code, "def test(): pass")
self.assertEqual(retrieved.metrics["score"], 0.5)

def test_get_best_program(self):
"""Test getting the best program"""
program1 = Program(
id="test1",
code="def test1(): pass",
language="python",
metrics={"score": 0.5},
)

program2 = Program(
id="test2",
code="def test2(): pass",
language="python",
metrics={"score": 0.7},
)

self.db.add(program1)
self.db.add(program2)

best = self.db.get_best_program()
self.assertIsNotNone(best)
self.assertEqual(best.id, "test2")

def test_sample(self):
"""Test sampling from the database"""
program1 = Program(
id="test1",
code="def test1(): pass",
language="python",
metrics={"score": 0.5},
)

program2 = Program(
id="test2",
code="def test2(): pass",
language="python",
metrics={"score": 0.7},
)

self.db.add(program1)
self.db.add(program2)

parent, inspirations = self.db.sample()

self.assertIsNotNone(parent)
self.assertIn(parent.id, ["test1", "test2"])


if __name__ == "__main__":
unittest.main()
Loading