From ace54c64350a6185f41e068ca19f622478b0df1d Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 10:56:51 +0100 Subject: [PATCH 1/9] Make mutability check in JaxsimDataclass.mutable_context stricter --- src/jaxsim/utils/jaxsim_dataclass.py | 56 +++++++++++++++++++++++++++- 1 file changed, 54 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index c719dfb3c..f4b8a8b01 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -46,22 +46,74 @@ def mutable_context( original_mutability = self._mutability() - def restore_self(): - self._set_mutability(mutability=Mutability.MUTABLE) + original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self) + original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self) + original_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self) + original_structure = jax.tree_util.tree_structure(tree=self) + + def restore_self() -> None: + self._set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION) for f in dataclasses.fields(self_copy): setattr(self, f.name, getattr(self_copy, f.name)) try: self._set_mutability(mutability) yield self + + if mutability is not Mutability.MUTABLE_NO_VALIDATION: + new_structure = jax.tree_util.tree_structure(tree=self) + if original_structure != new_structure: + msg = "Pytree structure has changed from {} to {}" + raise ValueError(msg.format(original_structure, new_structure)) + + new_shapes = JaxsimDataclass.get_leaf_shapes(tree=self) + if original_shapes != new_shapes: + msg = "Leaves shapes have changed from {} to {}" + raise ValueError(msg.format(original_shapes, new_shapes)) + + new_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self) + if original_dtypes != new_dtypes: + msg = "Leaves dtypes have changed from {} to {}" + raise ValueError(msg.format(original_dtypes, new_dtypes)) + + new_weak_types = JaxsimDataclass.get_leaf_weak_types(tree=self) + if original_weak_types != new_weak_types: + msg = "Leaves weak types have changed from {} to {}" + raise ValueError(msg.format(original_weak_types, new_weak_types)) + except Exception as e: if restore_after_exception: restore_self() self._set_mutability(original_mutability) raise e + finally: self._set_mutability(original_mutability) + @staticmethod + def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...]]: + return tuple( # noqa + leaf.shape + for leaf in jax.tree_util.tree_leaves(tree) + if hasattr(leaf, "shape") + ) + + @staticmethod + def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: + return tuple( + leaf.dtype + for leaf in jax.tree_util.tree_leaves(tree) + if hasattr(leaf, "dtype") + ) + + @staticmethod + def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: + return tuple( + leaf.weak_type + for leaf in jax.tree_util.tree_leaves(tree) + if hasattr(leaf, "weak_type") + ) + def is_mutable(self, validate: bool = False) -> bool: """""" From 0fcaede63a9c7d8c3e5f7cea071791de15eaca25 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 11:18:36 +0100 Subject: [PATCH 2/9] Update typing of context managers --- src/jaxsim/utils/jaxsim_dataclass.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index f4b8a8b01..212ddff95 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -2,7 +2,8 @@ import contextlib import copy import dataclasses -from typing import ClassVar, Generator +from collections.abc import Iterator +from typing import ClassVar import jax.flatten_util import jax_dataclasses @@ -25,7 +26,7 @@ class JaxsimDataclass(abc.ABC): __mutability__: ClassVar[Mutability] = Mutability.FROZEN @contextlib.contextmanager - def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]: + def editable(self: Self, validate: bool = True) -> Iterator[Self]: """""" mutability = ( @@ -38,7 +39,7 @@ def editable(self: Self, validate: bool = True) -> Generator[Self, None, None]: @contextlib.contextmanager def mutable_context( self: Self, mutability: Mutability, restore_after_exception: bool = True - ) -> Generator[Self, None, None]: + ) -> Iterator[Self]: """""" if restore_after_exception: From dc3119ac246af63b8a0b9753772220cd97518663 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 11:19:59 +0100 Subject: [PATCH 3/9] Do not automatically copy replaced fields The user is responsible of using copy when needed --- src/jaxsim/utils/jaxsim_dataclass.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 212ddff95..8cea0098a 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -1,6 +1,5 @@ import abc import contextlib -import copy import dataclasses from collections.abc import Iterator from typing import ClassVar @@ -153,7 +152,7 @@ def copy(self: Self) -> Self: def replace(self: Self, validate: bool = True, **kwargs) -> Self: with self.editable(validate=validate) as obj: - _ = [obj.__setattr__(k, copy.copy(v)) for k, v in kwargs.items()] + _ = [obj.__setattr__(k, v) for k, v in kwargs.items()] obj._set_mutability(mutability=self._mutability()) return obj From 7f9b33a65b7e64f5377c381455500ac7de00dcb6 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 11:21:55 +0100 Subject: [PATCH 4/9] Add explicit flatten_fn and unflatten_fn --- src/jaxsim/utils/jaxsim_dataclass.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 8cea0098a..b5004fb5c 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -2,7 +2,7 @@ import contextlib import dataclasses from collections.abc import Iterator -from typing import ClassVar +from typing import Callable, ClassVar, Type import jax.flatten_util import jax_dataclasses @@ -158,4 +158,11 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self: return obj def flatten(self) -> jtp.VectorJax: - return jax.flatten_util.ravel_pytree(self)[0] + return self.flatten_fn()(self) + + @classmethod + def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: + return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0] + + def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]: + return jax.flatten_util.ravel_pytree(self)[1] From 41166b84cb410f164c43c2cbcd3f31b22fb6253b Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 11:23:16 +0100 Subject: [PATCH 5/9] Minor enhancement in editable method --- src/jaxsim/utils/jaxsim_dataclass.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index b5004fb5c..815e97c69 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -32,7 +32,7 @@ def editable(self: Self, validate: bool = True) -> Iterator[Self]: Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION ) - with JaxsimDataclass.mutable_context(self.copy(), mutability=mutability) as obj: + with self.copy().mutable_context(mutability=mutability) as obj: yield obj @contextlib.contextmanager From adf63b0fbb67f368aa735d44247d5e9d5d4f64ff Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 12:13:49 +0100 Subject: [PATCH 6/9] Update mutability handling This change should also optimize memory usage by preventing unnecessary copies. --- src/jaxsim/api/common.py | 2 +- src/jaxsim/utils/jaxsim_dataclass.py | 61 ++++++++++++++++------------ 2 files changed, 37 insertions(+), 26 deletions(-) diff --git a/src/jaxsim/api/common.py b/src/jaxsim/api/common.py index f014906d7..2c1c12bc6 100644 --- a/src/jaxsim/api/common.py +++ b/src/jaxsim/api/common.py @@ -70,7 +70,7 @@ def switch_velocity_representation( # We run this in a mutable context with restoration so that any exception # occurring, we restore the original object in case it was modified. with self.mutable_context( - mutability=self._mutability(), restore_after_exception=True + mutability=self.mutability(), restore_after_exception=True ): yield self diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 815e97c69..a2304be61 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -44,7 +44,7 @@ def mutable_context( if restore_after_exception: self_copy = self.copy() - original_mutability = self._mutability() + original_mutability = self.mutability() original_dtypes = JaxsimDataclass.get_leaf_dtypes(tree=self) original_shapes = JaxsimDataclass.get_leaf_shapes(tree=self) @@ -52,12 +52,12 @@ def mutable_context( original_structure = jax.tree_util.tree_structure(tree=self) def restore_self() -> None: - self._set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION) + self.set_mutability(mutability=Mutability.MUTABLE_NO_VALIDATION) for f in dataclasses.fields(self_copy): setattr(self, f.name, getattr(self_copy, f.name)) try: - self._set_mutability(mutability) + self.set_mutability(mutability) yield self if mutability is not Mutability.MUTABLE_NO_VALIDATION: @@ -84,16 +84,16 @@ def restore_self() -> None: except Exception as e: if restore_after_exception: restore_self() - self._set_mutability(original_mutability) + self.set_mutability(original_mutability) raise e finally: - self._set_mutability(original_mutability) + self.set_mutability(original_mutability) @staticmethod - def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...]]: + def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: return tuple( # noqa - leaf.shape + leaf.shape if hasattr(leaf, "shape") else None for leaf in jax.tree_util.tree_leaves(tree) if hasattr(leaf, "shape") ) @@ -101,7 +101,7 @@ def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...]]: @staticmethod def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: return tuple( - leaf.dtype + leaf.dtype if hasattr(leaf, "dtype") else None for leaf in jax.tree_util.tree_leaves(tree) if hasattr(leaf, "dtype") ) @@ -109,7 +109,7 @@ def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: @staticmethod def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: return tuple( - leaf.weak_type + leaf.weak_type if hasattr(leaf, "weak_type") else False for leaf in jax.tree_util.tree_leaves(tree) if hasattr(leaf, "weak_type") ) @@ -123,38 +123,49 @@ def is_mutable(self, validate: bool = False) -> bool: else self.__mutability__ is Mutability.MUTABLE_NO_VALIDATION ) - def set_mutability(self, mutable: bool = True, validate: bool = False) -> None: - if not mutable: - mutability = Mutability.FROZEN - else: - mutability = ( - Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION - ) - - self._set_mutability(mutability=mutability) - - def _mutability(self) -> Mutability: + def mutability(self) -> Mutability: return self.__mutability__ - def _set_mutability(self, mutability: Mutability) -> None: + def set_mutability(self, mutability: Mutability) -> None: jax_dataclasses._copy_and_mutate._mark_mutable( self, mutable=mutability, visited=set() ) def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: - self.set_mutability(mutable=mutable, validate=validate) + + if mutable: + mutability = ( + Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION + ) + else: + mutability = Mutability.FROZEN + + self.set_mutability(mutability=mutability) return self def copy(self: Self) -> Self: + + # Make a copy calling tree_map. obj = jax.tree_util.tree_map(lambda leaf: leaf, self) - obj._set_mutability(mutability=self._mutability()) + + # Make sure that the copied object and all the copied leaves have the same + # mutability of the original object. + obj.set_mutability(mutability=self.mutability()) + return obj def replace(self: Self, validate: bool = True, **kwargs) -> Self: - with self.editable(validate=validate) as obj: + + mutability = ( + Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION + ) + + with self.mutable_context(mutability=mutability) as obj: _ = [obj.__setattr__(k, v) for k, v in kwargs.items()] - obj._set_mutability(mutability=self._mutability()) + # Make sure that all the new leaves have the same mutability of the object. + obj.set_mutability(mutability=self.mutability()) + return obj def flatten(self) -> jtp.VectorJax: From 6f0343c5733c73cc246f73b294e99bdaf6a91fd2 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 12:14:57 +0100 Subject: [PATCH 7/9] Add docstrings --- src/jaxsim/utils/jaxsim_dataclass.py | 141 ++++++++++++++++++++++++++- 1 file changed, 137 insertions(+), 4 deletions(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index a2304be61..80db98465 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -19,14 +19,26 @@ @jax_dataclasses.pytree_dataclass class JaxsimDataclass(abc.ABC): - """""" + """Class extending `jax_dataclasses.pytree_dataclass` instances with utilities.""" # This attribute is set by jax_dataclasses __mutability__: ClassVar[Mutability] = Mutability.FROZEN @contextlib.contextmanager def editable(self: Self, validate: bool = True) -> Iterator[Self]: - """""" + """ + Context manager to operate on a mutable copy of the object. + + Args: + validate: Whether to validate the output PyTree upon exiting the context. + + Yields: + A mutable copy of the object. + + Note: + This context manager is useful to operate on an r/w copy of a PyTree making + sure that the output object does not trigger JIT recompilations. + """ mutability = ( Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION @@ -39,7 +51,23 @@ def editable(self: Self, validate: bool = True) -> Iterator[Self]: def mutable_context( self: Self, mutability: Mutability, restore_after_exception: bool = True ) -> Iterator[Self]: - """""" + """ + Context manager to temporarily change the mutability of the object. + + Args: + mutability: The mutability to set. + restore_after_exception: + Whether to restore the original object in case of an exception + occurring within the context. + + Yields: + The object with the new mutability. + + Note: + This context manager is useful to operate in place on a PyTree without + the need to make a copy while optionally keeping active the checks on + the PyTree structure, shapes, and dtypes. + """ if restore_after_exception: self_copy = self.copy() @@ -92,6 +120,17 @@ def restore_self() -> None: @staticmethod def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: + """ + Helper method to get the leaf shapes of a PyTree. + + Args: + tree: The PyTree to consider. + + Returns: + A tuple containing the leaf shapes of the PyTree or `None` is the leaf is + not a numpy-like array. + """ + return tuple( # noqa leaf.shape if hasattr(leaf, "shape") else None for leaf in jax.tree_util.tree_leaves(tree) @@ -100,6 +139,17 @@ def get_leaf_shapes(tree: jtp.PyTree) -> tuple[tuple[int, ...] | None]: @staticmethod def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: + """ + Helper method to get the leaf dtypes of a PyTree. + + Args: + tree: The PyTree to consider. + + Returns: + A tuple containing the leaf dtypes of the PyTree or `None` is the leaf is + not a numpy-like array. + """ + return tuple( leaf.dtype if hasattr(leaf, "dtype") else None for leaf in jax.tree_util.tree_leaves(tree) @@ -108,6 +158,16 @@ def get_leaf_dtypes(tree: jtp.PyTree) -> tuple: @staticmethod def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: + """ + Helper method to get the leaf weak types of a PyTree. + + Args: + tree: The PyTree to consider. + + Returns: + A tuple marking whether the leaf contains a JAX array with weak type. + """ + return tuple( leaf.weak_type if hasattr(leaf, "weak_type") else False for leaf in jax.tree_util.tree_leaves(tree) @@ -115,7 +175,15 @@ def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: ) def is_mutable(self, validate: bool = False) -> bool: - """""" + """ + Check whether the object is mutable. + + Args: + validate: Additionally checks if the object also has validation enabled. + + Returns: + True if the object is mutable, False otherwise. + """ return ( self.__mutability__ is Mutability.MUTABLE @@ -124,14 +192,38 @@ def is_mutable(self, validate: bool = False) -> bool: ) def mutability(self) -> Mutability: + """ + Get the mutability type of the object. + + Returns: + The mutability type of the object. + """ + return self.__mutability__ def set_mutability(self, mutability: Mutability) -> None: + """ + Set the mutability of the object in-place. + + Args: + mutability: The desired mutability type. + """ + jax_dataclasses._copy_and_mutate._mark_mutable( self, mutable=mutability, visited=set() ) def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: + """ + Return a mutable reference of the object. + + Args: + mutable: Whether to make the object mutable. + validate: Whether to enable validation on the object. + + Returns: + A mutable reference of the object. + """ if mutable: mutability = ( @@ -144,6 +236,12 @@ def mutable(self: Self, mutable: bool = True, validate: bool = False) -> Self: return self def copy(self: Self) -> Self: + """ + Return a copy of the object. + + Returns: + A copy of the object. + """ # Make a copy calling tree_map. obj = jax.tree_util.tree_map(lambda leaf: leaf, self) @@ -155,6 +253,17 @@ def copy(self: Self) -> Self: return obj def replace(self: Self, validate: bool = True, **kwargs) -> Self: + """ + Return a new object replacing in-place the specified fields with new values. + + Args: + validate: + Whether to validate that the new fields do not alter the PyTree. + **kwargs: The fields to replace. + + Returns: + A reference of the object with the specified fields replaced. + """ mutability = ( Mutability.MUTABLE if validate else Mutability.MUTABLE_NO_VALIDATION @@ -169,11 +278,35 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self: return obj def flatten(self) -> jtp.VectorJax: + """ + Flatten the object into a 1D vector. + + Returns: + A 1D vector containing the flattened object. + """ + return self.flatten_fn()(self) @classmethod def flatten_fn(cls: Type[Self]) -> Callable[[Self], jtp.VectorJax]: + """ + Return a function to flatten the object into a 1D vector. + + Returns: + A function to flatten the object into a 1D vector. + """ + return lambda pytree: jax.flatten_util.ravel_pytree(pytree)[0] def unflatten_fn(self: Self) -> Callable[[jtp.VectorJax], Self]: + """ + Return a function to unflatten a 1D vector into the object. + + Returns: + A function to unflatten a 1D vector into the object. + + Notes: + Due to JAX internals, the function to unflatten a PyTree needs to be + created from an existing instance of the PyTree. + """ return jax.flatten_util.ravel_pytree(self)[1] From 2ca7bd82302ce3ee3fb6891643b864f418c12cb8 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 12:17:56 +0100 Subject: [PATCH 8/9] Add JaxsimDataclass.check_compatibility helper --- src/jaxsim/utils/jaxsim_dataclass.py | 51 +++++++++++++++++++++++++++- 1 file changed, 50 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index 80db98465..deb6dca77 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -1,8 +1,9 @@ import abc import contextlib import dataclasses +import functools from collections.abc import Iterator -from typing import Callable, ClassVar, Type +from typing import Any, Callable, ClassVar, Sequence, Type import jax.flatten_util import jax_dataclasses @@ -174,6 +175,54 @@ def get_leaf_weak_types(tree: jtp.PyTree) -> tuple[bool, ...]: if hasattr(leaf, "weak_type") ) + @staticmethod + def check_compatibility(*trees: Sequence[Any]) -> None: + """ + Check whether the PyTrees are compatible in structure, shape, and dtype. + + Args: + *trees: The PyTrees to compare. + + Raises: + ValueError: If the PyTrees have incompatible structures, shapes, or dtypes. + """ + + target_structure = jax.tree_util.tree_structure(trees[0]) + + compatible_structure = functools.reduce( + lambda compatible, tree: compatible + and jax.tree_util.tree_structure(tree) == target_structure, + trees[1:], + True, + ) + + if not compatible_structure: + raise ValueError("Pytrees have incompatible structures.") + + target_shapes = JaxsimDataclass.get_leaf_shapes(trees[0]) + + compatible_shapes = functools.reduce( + lambda compatible, tree: compatible + and JaxsimDataclass.get_leaf_shapes(tree) == target_shapes, + trees[1:], + True, + ) + + if not compatible_shapes: + raise ValueError("Pytrees have incompatible shapes.") + + target_dtypes = JaxsimDataclass.get_leaf_dtypes(trees[0]) + + compatible_dtypes = functools.reduce( + lambda compatible, tree: compatible + and JaxsimDataclass.get_leaf_dtypes(tree) == target_dtypes, + trees[1:], + True, + ) + + if not compatible_dtypes: + raise ValueError("Pytrees have incompatible dtypes.") + def is_mutable(self, validate: bool = False) -> bool: """ Check whether the object is mutable. From 83fa95cefc6b86bfeab8eb38d38eac846c126566 Mon Sep 17 00:00:00 2001 From: diegoferigo Date: Thu, 21 Mar 2024 12:49:50 +0100 Subject: [PATCH 9/9] JaxsimDataclass.replace returns a shallow copy of the original object --- src/jaxsim/utils/jaxsim_dataclass.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/jaxsim/utils/jaxsim_dataclass.py b/src/jaxsim/utils/jaxsim_dataclass.py index deb6dca77..d791074ef 100644 --- a/src/jaxsim/utils/jaxsim_dataclass.py +++ b/src/jaxsim/utils/jaxsim_dataclass.py @@ -1,5 +1,6 @@ import abc import contextlib +import copy import dataclasses import functools from collections.abc import Iterator @@ -324,7 +325,10 @@ def replace(self: Self, validate: bool = True, **kwargs) -> Self: # Make sure that all the new leaves have the same mutability of the object. obj.set_mutability(mutability=self.mutability()) - return obj + # Return a shallow copy of the object with the new fields replaced. + # Note that the shallow copy of the original object contains exactly the same + # attributes of the original object (in other words, with the same id). + return copy.copy(obj) def flatten(self) -> jtp.VectorJax: """