Apply func non roundtrippable seq #250

Closed
10 changes: 9 additions & 1 deletion src/lightning_utilities/core/apply_func.py
@@ -20,6 +20,14 @@ def is_dataclass_instance(obj: object) -> bool:
     return dataclasses.is_dataclass(obj) and not isinstance(obj, type)


+def can_roundtrip_sequence(obj: Sequence) -> bool:
+    """Check if sequence can be roundtripped."""
+    try:
+        return obj == type(obj)(list(obj))
+    except (TypeError, ValueError):
+        return False
+
+
 def apply_to_collection(
     data: Any,
     dtype: Union[type, Any, Tuple[Union[type, Any]]],
@@ -118,7 +126,7 @@ def _apply_to_collection_slow(
         return elem_type(OrderedDict(out))

     is_namedtuple_ = is_namedtuple(data)
-    is_sequence = isinstance(data, Sequence) and not isinstance(data, str)
+    is_sequence = isinstance(data, Sequence) and not isinstance(data, str) and can_roundtrip_sequence(data)
     if is_namedtuple_ or is_sequence:
         out = []
         for d in data:
10 changes: 10 additions & 0 deletions tests/unittests/core/test_apply_func.py
@@ -359,3 +359,13 @@ class Foo:
     foo = Foo(0)
     result = apply_to_collection(foo, int, lambda x: x + 1, allow_frozen=True)
     assert foo == result
+
+
+def test_apply_to_collection_non_roundtrippable_sequence():
+    class NonRoundtrippableSequence(list):
+        def __init__(self, x: int):
+            super().__init__(range(int(x)))
+
+    val = NonRoundtrippableSequence(3)
+    result = apply_to_collection(val, int, lambda x: x + 1)
+    assert val == result
Contributor

Wouldn't you expect result == [1, 2, 3]? What's a real example where this should become a no-op?
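
For context on the question above, here is a minimal standalone sketch of what the new guard decides for the class used in the test. It copies `can_roundtrip_sequence` from the diff and does not import `lightning_utilities`. Per the test in this PR, a value that fails the guard is no longer treated as a sequence and `apply_to_collection` returns it unchanged rather than `[1, 2, 3]`.

```python
from collections.abc import Sequence


def can_roundtrip_sequence(obj: Sequence) -> bool:
    """Copied from the diff: True if type(obj)(list(obj)) rebuilds an equal object."""
    try:
        return obj == type(obj)(list(obj))
    except (TypeError, ValueError):
        return False


class NonRoundtrippableSequence(list):
    # Same class as in the test: the constructor takes a length, not an iterable.
    def __init__(self, x: int):
        super().__init__(range(int(x)))


val = NonRoundtrippableSequence(3)

# Plain lists and tuples can be rebuilt from a list of their items.
plain_list_ok = can_roundtrip_sequence([0, 1, 2])
plain_tuple_ok = can_roundtrip_sequence((0, 1, 2))

# type(val)(list(val)) calls int([0, 1, 2]), which raises TypeError,
# so the guard reports that this value cannot be roundtripped.
custom_ok = can_roundtrip_sequence(val)
```

The trade-off is visible here: the guard avoids a failure while rebuilding the container, but the function is then never applied to the elements.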

Author

This came up when using: https://github.com/Bayer-Group/pado/blob/635d7b8b57e527254d6302730100a6dab5a2095f/pado/images/ids.py#L126-L351

Where ImageId instances are tuple subclasses but they don't roundtrip, i.e.:

iid = ImageId("a", "b", "c", site="site-1")

ImageId(list(iid)) <- is not allowed
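
A rough sketch of the same situation, using a hypothetical stand-in class rather than the real `ImageId` (`PartsId`, its validation, and the `site` attribute handling are assumptions made for illustration only). It shows why a generic rebuild of the form `type(data)(out)` cannot work for such a type, while a rebuild that knows the constructor signature can.

```python
from typing import Optional


class PartsId(tuple):
    """Hypothetical tuple subclass with a custom constructor; NOT pado's ImageId."""

    def __new__(cls, *parts: str, site: Optional[str] = None) -> "PartsId":
        # Assumed validation: every positional part must be a string.
        if not all(isinstance(p, str) for p in parts):
            raise TypeError("every part must be a str")
        self = super().__new__(cls, parts)
        self.site = site
        return self


pid = PartsId("a", "b", "c", site="site-1")

# Apply some function to each element, as a collection helper would.
out = [p.upper() for p in pid]

# Generic rebuild: passes the whole list as a single positional part.
try:
    PartsId(out)
    rebuilt_generically = True
except TypeError:
    rebuilt_generically = False

# Type-aware rebuild: unpacks the parts and carries the keyword over.
rebuilt_explicitly = PartsId(*out, site=pid.site)
```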

Member

cc: @ap-- ^^

Author

Sorry, it seems my reply sat in a pending state for a month and I still had to click "Submit review" 🤷