From 77bdf046ba6b2d93e278ee7bab6a77575e27e870 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Mon, 10 Nov 2025 18:30:05 -0800 Subject: [PATCH 1/4] Make copy.copy work for trajectories --- src/art/trajectories.py | 33 +++++++++++++++++++++++++++++++++ 1 file changed, 33 insertions(+) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 00c53865..7a8a5b77 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -173,6 +173,39 @@ def __init__( ], ) + def __copy__(self): + """Support for copy.copy()""" + import copy + + # Create a new instance using the constructor + # Pass shallow copies of the lists to avoid shared mutation + new_instance = self.__class__( + trajectories=self.trajectories[:], # Shallow copy of list + exceptions=[], # Will be set below + ) + # Manually copy exceptions since they're PydanticException objects + new_instance.exceptions = self.exceptions[:] + return new_instance + + def __deepcopy__(self, memo): + """Support for copy.deepcopy()""" + import copy + + # Check memo to handle circular references + if id(self) in memo: + return memo[id(self)] + + # Create a new instance with deep copies + new_instance = self.__class__( + trajectories=copy.deepcopy(self.trajectories, memo), + exceptions=[], # Will be set below + ) + # Register in memo before deep copying attributes to handle circular refs + memo[id(self)] = new_instance + # Deep copy exceptions + new_instance.exceptions = copy.deepcopy(self.exceptions, memo) + return new_instance + def __iter__(self) -> Iterator[Trajectory]: # type: ignore[override] return iter(self.trajectories) From 8529eb56ed98927aa659ebf76760a0de1da3470b Mon Sep 17 00:00:00 2001 From: arcticfly Date: Tue, 11 Nov 2025 12:40:22 -0800 Subject: [PATCH 2/4] Add space --- src/art/trajectories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 7a8a5b77..7c8c808a 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -183,7 +183,7 @@ def __copy__(self): trajectories=self.trajectories[:], # Shallow copy of list exceptions=[], # Will be set below ) - # Manually copy exceptions since they're PydanticException objects + # Manually copy exceptions since they're PydanticException objects new_instance.exceptions = self.exceptions[:] return new_instance From 5ca7815f71d94890a7f22d4d9031ccb36a8aaed1 Mon Sep 17 00:00:00 2001 From: arcticfly Date: Tue, 11 Nov 2025 12:40:32 -0800 Subject: [PATCH 3/4] Remove space --- src/art/trajectories.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 7c8c808a..7a8a5b77 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -183,7 +183,7 @@ def __copy__(self): trajectories=self.trajectories[:], # Shallow copy of list exceptions=[], # Will be set below ) - # Manually copy exceptions since they're PydanticException objects + # Manually copy exceptions since they're PydanticException objects new_instance.exceptions = self.exceptions[:] return new_instance From f73986681a41673cd15f58fed9005a47c4b96ecd Mon Sep 17 00:00:00 2001 From: arcticfly Date: Tue, 11 Nov 2025 21:14:48 +0000 Subject: [PATCH 4/4] fix types --- src/art/trajectories.py | 6 +- tests/unit/test_trajectory_copy.py | 193 +++++++++++++++++++++++++++++ 2 files changed, 198 insertions(+), 1 deletion(-) create mode 100644 tests/unit/test_trajectory_copy.py diff --git a/src/art/trajectories.py b/src/art/trajectories.py index 7a8a5b77..403419ea 100644 --- a/src/art/trajectories.py +++ b/src/art/trajectories.py @@ -187,10 +187,14 @@ def __copy__(self): new_instance.exceptions = self.exceptions[:] return new_instance - def __deepcopy__(self, memo): + def __deepcopy__(self, memo: dict[int, Any] | None = None): """Support for copy.deepcopy()""" import copy + # Initialize memo if not provided + if memo is None: + memo = {} + # Check memo to handle circular references if id(self) in memo: return memo[id(self)] diff --git a/tests/unit/test_trajectory_copy.py b/tests/unit/test_trajectory_copy.py new file mode 100644 index 00000000..92f5b97a --- /dev/null +++ b/tests/unit/test_trajectory_copy.py @@ -0,0 +1,193 @@ +"""Tests for TrajectoryGroup copy and deepcopy functionality.""" + +import copy + +import pytest +from openai.types.chat import ChatCompletionMessage +from openai.types.chat.chat_completion import Choice + +from art.trajectories import PydanticException, Trajectory, TrajectoryGroup + + +@pytest.fixture +def sample_trajectory(): + """Create a sample trajectory for testing.""" + return Trajectory( + messages_and_choices=[ + {"role": "user", "content": "Hello"}, + Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + role="assistant", + content="Hi there!", + refusal=None, + ), + ), + ], + tools=None, + reward=1.0, + metrics={"accuracy": 0.95}, + metadata={"test": "value"}, + ) + + +@pytest.fixture +def sample_trajectory_group(sample_trajectory): + """Create a sample trajectory group for testing.""" + trajectory2 = Trajectory( + messages_and_choices=[ + {"role": "user", "content": "How are you?"}, + Choice( + finish_reason="stop", + index=0, + logprobs=None, + message=ChatCompletionMessage( + role="assistant", + content="I'm doing well!", + refusal=None, + ), + ), + ], + tools=None, + reward=0.8, + ) + return TrajectoryGroup( + trajectories=[sample_trajectory, trajectory2], + exceptions=[], + ) + + +def test_shallow_copy(sample_trajectory_group): + """Test that shallow copy works correctly.""" + copied = copy.copy(sample_trajectory_group) + + # Should be a different object + assert copied is not sample_trajectory_group + + # Trajectories should be a new list (shallow copy of list) + assert copied.trajectories is not sample_trajectory_group.trajectories + + # But the trajectory objects themselves should be the same (shallow copy) + assert copied.trajectories[0] is sample_trajectory_group.trajectories[0] + assert copied.trajectories[1] is sample_trajectory_group.trajectories[1] + + # Exceptions should be a new list with same contents + assert copied.exceptions is not sample_trajectory_group.exceptions + assert copied.exceptions == sample_trajectory_group.exceptions + + +def test_deep_copy(sample_trajectory_group): + """Test that deep copy works correctly.""" + copied = copy.deepcopy(sample_trajectory_group) + + # Should be a different object + assert copied is not sample_trajectory_group + + # Should have different trajectories list (deep copy) + assert copied.trajectories is not sample_trajectory_group.trajectories + + # Trajectories themselves should be different objects + assert copied.trajectories[0] is not sample_trajectory_group.trajectories[0] + assert copied.trajectories[1] is not sample_trajectory_group.trajectories[1] + + # But should have same content + assert len(copied.trajectories) == len(sample_trajectory_group.trajectories) + assert ( + copied.trajectories[0].reward == sample_trajectory_group.trajectories[0].reward + ) + assert ( + copied.trajectories[1].reward == sample_trajectory_group.trajectories[1].reward + ) + + # Exceptions should also be deep copied + assert copied.exceptions is not sample_trajectory_group.exceptions + + +def test_deep_copy_with_exceptions(): + """Test that deep copy works with exceptions.""" + group = TrajectoryGroup( + trajectories=[ + Trajectory( + messages_and_choices=[{"role": "user", "content": "test"}], + tools=None, + reward=1.0, + ) + ], + exceptions=[ValueError("test error")], + ) + + copied = copy.deepcopy(group) + + # Should be different objects + assert copied is not group + assert copied.exceptions is not group.exceptions + + # Should have same exception content + assert len(copied.exceptions) == len(group.exceptions) + assert copied.exceptions[0].message == group.exceptions[0].message + + +def test_deep_copy_circular_reference(): + """Test that deep copy handles circular references correctly.""" + group = TrajectoryGroup( + trajectories=[ + Trajectory( + messages_and_choices=[{"role": "user", "content": "test"}], + tools=None, + reward=1.0, + ) + ], + exceptions=[], + ) + + # Create a memo dict with a circular reference + memo = {} + copied = copy.deepcopy(group, memo) + + # Should be in memo + assert id(group) in memo + assert memo[id(group)] is copied + + # Copying again with same memo should return the same object + copied2 = copy.deepcopy(group, memo) + assert copied2 is copied + + +def test_deep_copy_preserves_metadata(sample_trajectory_group): + """Test that deep copy preserves trajectory metadata.""" + copied = copy.deepcopy(sample_trajectory_group) + + # Check that metadata is preserved + assert ( + copied.trajectories[0].metrics + == sample_trajectory_group.trajectories[0].metrics + ) + assert ( + copied.trajectories[0].metadata + == sample_trajectory_group.trajectories[0].metadata + ) + + # But should be different dict objects + assert ( + copied.trajectories[0].metrics + is not sample_trajectory_group.trajectories[0].metrics + ) + assert ( + copied.trajectories[0].metadata + is not sample_trajectory_group.trajectories[0].metadata + ) + + +def test_copy_empty_group(): + """Test copying an empty trajectory group.""" + empty_group = TrajectoryGroup(trajectories=[], exceptions=[]) + + shallow = copy.copy(empty_group) + assert shallow is not empty_group + assert len(shallow.trajectories) == 0 + + deep = copy.deepcopy(empty_group) + assert deep is not empty_group + assert len(deep.trajectories) == 0