Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions src/art/trajectories.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,43 @@ def __init__(
],
)

def __copy__(self):
    """Return a shallow copy of this group for ``copy.copy()``.

    The result is a new instance of the same class that owns new list
    objects for ``trajectories`` and ``exceptions``. The list elements
    themselves are shared with the original.
    """
    # Build through the constructor with a fresh list so the two groups never
    # share a mutable container.
    new_instance = self.__class__(
        trajectories=self.trajectories[:],
        exceptions=[],  # Filled in below
    )
    # The stored exceptions are already PydanticException objects, so they are
    # assigned directly instead of being passed through the constructor again.
    # NOTE(review): presumably the constructor expects raw exceptions - confirm.
    new_instance.exceptions = self.exceptions[:]
    return new_instance

def __deepcopy__(self, memo: dict[int, Any] | None = None):
"""Support for copy.deepcopy()"""
import copy

# Initialize memo if not provided
if memo is None:
memo = {}

# Check memo to handle circular references
if id(self) in memo:
return memo[id(self)]

# Create a new instance with deep copies
new_instance = self.__class__(
trajectories=copy.deepcopy(self.trajectories, memo),
exceptions=[], # Will be set below
)
# Register in memo before deep copying attributes to handle circular refs
memo[id(self)] = new_instance
# Deep copy exceptions
new_instance.exceptions = copy.deepcopy(self.exceptions, memo)
return new_instance

def __iter__(self) -> Iterator[Trajectory]:  # type: ignore[override]
    """Iterate over the trajectories held by this group."""
    trajectories = self.trajectories
    return iter(trajectories)

Expand Down
193 changes: 193 additions & 0 deletions tests/unit/test_trajectory_copy.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,193 @@
"""Tests for TrajectoryGroup copy and deepcopy functionality."""

import copy

import pytest
from openai.types.chat import ChatCompletionMessage
from openai.types.chat.chat_completion import Choice

from art.trajectories import PydanticException, Trajectory, TrajectoryGroup


@pytest.fixture
def sample_trajectory():
    """Provide a trajectory with one user message and one assistant choice."""
    user_message = {"role": "user", "content": "Hello"}
    assistant_reply = ChatCompletionMessage(
        role="assistant",
        content="Hi there!",
        refusal=None,
    )
    assistant_choice = Choice(
        finish_reason="stop",
        index=0,
        logprobs=None,
        message=assistant_reply,
    )
    return Trajectory(
        messages_and_choices=[user_message, assistant_choice],
        tools=None,
        reward=1.0,
        metrics={"accuracy": 0.95},
        metadata={"test": "value"},
    )


@pytest.fixture
def sample_trajectory_group(sample_trajectory):
    """Provide a group of two trajectories and no exceptions."""
    second_reply = ChatCompletionMessage(
        role="assistant",
        content="I'm doing well!",
        refusal=None,
    )
    second_choice = Choice(
        finish_reason="stop",
        index=0,
        logprobs=None,
        message=second_reply,
    )
    # The second trajectory deliberately omits metrics and metadata.
    second_trajectory = Trajectory(
        messages_and_choices=[
            {"role": "user", "content": "How are you?"},
            second_choice,
        ],
        tools=None,
        reward=0.8,
    )
    members = [sample_trajectory, second_trajectory]
    return TrajectoryGroup(trajectories=members, exceptions=[])


def test_shallow_copy(sample_trajectory_group):
    """copy.copy() gives a new group with new lists but shared elements."""
    original = sample_trajectory_group
    duplicate = copy.copy(original)

    # A distinct group object...
    assert duplicate is not original

    # ...holding its own trajectories list...
    assert duplicate.trajectories is not original.trajectories

    # ...whose elements are the very same trajectory objects.
    for position in (0, 1):
        assert duplicate.trajectories[position] is original.trajectories[position]

    # The exceptions list is likewise a new list with equal contents.
    assert duplicate.exceptions is not original.exceptions
    assert duplicate.exceptions == original.exceptions


def test_deep_copy(sample_trajectory_group):
    """copy.deepcopy() duplicates the group, its lists and its trajectories."""
    original = sample_trajectory_group
    duplicate = copy.deepcopy(original)

    # Distinct group and distinct trajectories list.
    assert duplicate is not original
    assert duplicate.trajectories is not original.trajectories

    # Each trajectory is itself a new object.
    for position in (0, 1):
        assert duplicate.trajectories[position] is not original.trajectories[position]

    # Content carries over unchanged.
    assert len(duplicate.trajectories) == len(original.trajectories)
    for position in (0, 1):
        copied_reward = duplicate.trajectories[position].reward
        assert copied_reward == original.trajectories[position].reward

    # The exceptions list is a new object as well.
    assert duplicate.exceptions is not original.exceptions


def test_deep_copy_with_exceptions():
    """Deep-copying a group that recorded an exception keeps its message."""
    only_trajectory = Trajectory(
        messages_and_choices=[{"role": "user", "content": "test"}],
        tools=None,
        reward=1.0,
    )
    original = TrajectoryGroup(
        trajectories=[only_trajectory],
        exceptions=[ValueError("test error")],
    )

    duplicate = copy.deepcopy(original)

    # New group, new exceptions list.
    assert duplicate is not original
    assert duplicate.exceptions is not original.exceptions

    # Same number of exceptions, same message on the first one.
    assert len(duplicate.exceptions) == len(original.exceptions)
    copied_message = duplicate.exceptions[0].message
    assert copied_message == original.exceptions[0].message


def test_deep_copy_circular_reference():
    """Test that deep copy records itself in, and reuses, a shared memo.

    No actual reference cycle is built here. The test exercises the memo
    bookkeeping that cycle handling relies on: the copy is registered under
    ``id(group)``, and a second deepcopy with the same memo returns that
    same copy.
    """
    group = TrajectoryGroup(
        trajectories=[
            Trajectory(
                messages_and_choices=[{"role": "user", "content": "test"}],
                tools=None,
                reward=1.0,
            )
        ],
        exceptions=[],
    )

    # Share one memo dict across both deepcopy calls.
    memo = {}
    copied = copy.deepcopy(group, memo)

    # The first call must produce a real copy, not the original.
    assert copied is not group

    # The copy must be registered in memo under the original's id.
    assert id(group) in memo
    assert memo[id(group)] is copied

    # Copying again with the same memo should return the same object.
    copied2 = copy.deepcopy(group, memo)
    assert copied2 is copied


def test_deep_copy_preserves_metadata(sample_trajectory_group):
    """Deep copy keeps metrics/metadata equal but stored in new dicts."""
    duplicate = copy.deepcopy(sample_trajectory_group)
    copied_first = duplicate.trajectories[0]
    original_first = sample_trajectory_group.trajectories[0]

    # Values survive the copy.
    assert copied_first.metrics == original_first.metrics
    assert copied_first.metadata == original_first.metadata

    # The dict objects themselves are not shared.
    assert copied_first.metrics is not original_first.metrics
    assert copied_first.metadata is not original_first.metadata


def test_copy_empty_group():
    """Shallow and deep copies of an empty group are new, empty groups."""
    source = TrajectoryGroup(trajectories=[], exceptions=[])

    shallow_duplicate = copy.copy(source)
    assert shallow_duplicate is not source
    assert len(shallow_duplicate.trajectories) == 0

    deep_duplicate = copy.deepcopy(source)
    assert deep_duplicate is not source
    assert len(deep_duplicate.trajectories) == 0