diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 67d27aef..9057b346 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -55,7 +55,6 @@ repos: hooks: - id: mypy files: src - args: ["--enable-incomplete-feature=Unpack"] additional_dependencies: - pytest diff --git a/src/galdynamix/utils/_collections.py b/src/galdynamix/utils/_collections.py index a6c88cf0..2925abb4 100644 --- a/src/galdynamix/utils/_collections.py +++ b/src/galdynamix/utils/_collections.py @@ -4,11 +4,12 @@ __all__ = ["ImmutableDict"] from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView -from typing import TypeVar +from typing import Any, TypeVar from jax.tree_util import register_pytree_node_class V = TypeVar("V") +T = TypeVar("T") @register_pytree_node_class @@ -64,6 +65,15 @@ def items(self) -> ItemsView[str, V]: def __repr__(self) -> str: return f"{self.__class__.__name__}({self._data!r})" + def __or__(self, value: Any, /) -> "ImmutableDict[V]": + if not isinstance(value, Mapping): + return NotImplemented + + return type(self)(self._data | dict(value)) + + def __ror__(self, value: Any) -> Any: + return value | self._data + # === PyTree === def tree_flatten(self) -> tuple[tuple[V, ...], tuple[str, ...]]: diff --git a/tests/utils/test_collections.py b/tests/utils/test_collections.py index 36a0ec3d..f2571767 100644 --- a/tests/utils/test_collections.py +++ b/tests/utils/test_collections.py @@ -1,3 +1,6 @@ +from collections import OrderedDict +from types import MappingProxyType + import pytest from galdynamix.utils import ImmutableDict @@ -69,6 +72,20 @@ def test_repr(self): d = ImmutableDict(a=1, b=2) assert repr(d) == "ImmutableDict({'a': 1, 'b': 2})" + def test_or(self): + """Test `__or__`.""" + d = ImmutableDict(a=1, b=2) + assert d | ImmutableDict(c=3) == ImmutableDict(a=1, b=2, c=3) + assert d | {"c": 3} == ImmutableDict(a=1, b=2, c=3) + assert d | OrderedDict([("c", 3)]) == ImmutableDict(a=1, b=2, c=3) + assert d | MappingProxyType({"c": 3}) == ImmutableDict(a=1, b=2, c=3) + + # Reverse order + assert {"c": 3} | d == {"c": 3, "a": 1, "b": 2} + assert OrderedDict([("c", 3)]) | d == OrderedDict( + [("c", 3), ("a", 1), ("b", 2)] + ) + # === Test pytree methods === def test_tree_flatten(self):