Skip to content

Commit

Permalink
ImmutableDict union (#18)
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Dec 8, 2023
1 parent 283e530 commit be77d5a
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 2 deletions.
1 change: 0 additions & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ repos:
hooks:
- id: mypy
files: src
args: ["--enable-incomplete-feature=Unpack"]
additional_dependencies:
- pytest

Expand Down
12 changes: 11 additions & 1 deletion src/galdynamix/utils/_collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,12 @@
__all__ = ["ImmutableDict"]

from collections.abc import ItemsView, Iterable, Iterator, KeysView, Mapping, ValuesView
from typing import TypeVar
from typing import Any, TypeVar

from jax.tree_util import register_pytree_node_class

V = TypeVar("V")
T = TypeVar("T")


@register_pytree_node_class
Expand Down Expand Up @@ -64,6 +65,15 @@ def items(self) -> ItemsView[str, V]:
def __repr__(self) -> str:
return f"{self.__class__.__name__}({self._data!r})"

def __or__(self, value: Any, /) -> "ImmutableDict[V]":
if not isinstance(value, Mapping):
return NotImplemented

return type(self)(self._data | dict(value))

def __ror__(self, value: Any) -> Any:
return value | self._data

# === PyTree ===

def tree_flatten(self) -> tuple[tuple[V, ...], tuple[str, ...]]:
Expand Down
17 changes: 17 additions & 0 deletions tests/utils/test_collections.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
from collections import OrderedDict
from types import MappingProxyType

import pytest

from galdynamix.utils import ImmutableDict
Expand Down Expand Up @@ -69,6 +72,20 @@ def test_repr(self):
d = ImmutableDict(a=1, b=2)
assert repr(d) == "ImmutableDict({'a': 1, 'b': 2})"

def test_or(self):
"""Test `__or__`."""
d = ImmutableDict(a=1, b=2)
assert d | ImmutableDict(c=3) == ImmutableDict(a=1, b=2, c=3)
assert d | {"c": 3} == ImmutableDict(a=1, b=2, c=3)
assert d | OrderedDict([("c", 3)]) == ImmutableDict(a=1, b=2, c=3)
assert d | MappingProxyType({"c": 3}) == ImmutableDict(a=1, b=2, c=3)

# Reverse order
assert {"c": 3} | d == {"c": 3, "a": 1, "b": 2}
assert OrderedDict([("c", 3)]) | d == OrderedDict(
[("c", 3), ("a", 1), ("b", 2)]
)

# === Test pytree methods ===

def test_tree_flatten(self):
Expand Down

0 comments on commit be77d5a

Please sign in to comment.