From f51907014891d58934891a94c4a484def7d76c74 Mon Sep 17 00:00:00 2001 From: yhl48 Date: Wed, 24 Apr 2024 16:11:44 +0100 Subject: [PATCH 01/14] copied _pytree.py from PyTorch --- src/litdata/utilities/_pytree.py | 1595 ++++++++++++++++++++++++++++++ 1 file changed, 1595 insertions(+) create mode 100644 src/litdata/utilities/_pytree.py diff --git a/src/litdata/utilities/_pytree.py b/src/litdata/utilities/_pytree.py new file mode 100644 index 000000000..14b6115b8 --- /dev/null +++ b/src/litdata/utilities/_pytree.py @@ -0,0 +1,1595 @@ +""" +Code taken from https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + +Contains utility functions for working with nested python data structures. + +A *pytree* is a Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead. +To improve the performance we can move parts of the implementation to C++. 
+""" + +import dataclasses +import functools +import importlib +import json +import sys +import threading +import types +import warnings +from collections import defaultdict, deque, namedtuple, OrderedDict +from typing import ( + Any, + Callable, + cast, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generic, + Hashable, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + OrderedDict as GenericOrderedDict, + overload, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, +) + + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + + +class KeyEntry(Protocol): + def __hash__(self) -> int: + ... + + def __eq__(self, other: object) -> bool: + ... + + def __str__(self) -> str: + ... + + def get(self, parent: Any) -> Any: + ... + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] +KeyPath = Tuple[KeyEntry, ...] 
FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]]


# A NodeDef holds the registered type and three callables:
# - flatten_fn should take the collection and return a flat list of values.
#   It can also return some context that is used in reconstructing the
#   collection.
# - unflatten_fn should take a flat list of values and some context
#   (returned by flatten_fn). It returns the collection by reconstructing
#   it from the list and the context.
# - flatten_with_keys_fn, which is a callable that takes a
#   pytree and returns a list of (keypath, value) pairs and a context.
class NodeDef(NamedTuple):
    type: Type[Any]
    flatten_fn: FlattenFunc
    unflatten_fn: UnflattenFunc
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc]


# Guards mutations of the registries below.
_NODE_REGISTRY_LOCK = threading.Lock()
SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {}


# _SerializeNodeDef holds the following:
# - typ: the type of the node (e.g., "Dict", "List", etc)
# - serialized_type_name: the fully qualified name of the type, e.g.
#   "collections.OrderedDict"
# - to_dumpable_context takes the context produced by flatten_fn and returns a
#   json dumpable representation of it
# - from_dumpable_context takes that json dumpable representation and returns
#   the original context
class _SerializeNodeDef(NamedTuple):
    typ: Type[Any]
    serialized_type_name: str
    to_dumpable_context: Optional[ToDumpableContextFn]
    from_dumpable_context: Optional[FromDumpableContextFn]


SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {}
SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {}


def register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to custom specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].

    Raises:
        ValueError: if ``cls`` is already registered, or if only one of
            ``to_dumpable_context`` / ``from_dumpable_context`` is given.
    """
    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            raise ValueError(f"{cls} is already registered as pytree node.")

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )

    # Mirror the registration into the optional C++-backed pytree when that
    # sibling module is importable; its absence is not an error.
    try:
        from . import _cxx_pytree as cxx
    except ImportError:
        pass
    else:
        cxx._private_register_pytree_node(
            cls,
            flatten_fn,
            unflatten_fn,
            serialized_type_name=serialized_type_name,
            to_dumpable_context=to_dumpable_context,
            from_dumpable_context=from_dumpable_context,
        )


def _register_namedtuple(
    cls: Type[Any],
    *,
    serialized_type_name: str,
) -> None:
    """
    Registers a namedtuple as a valid pytree node. By default namedtuples are
    valid pytree nodes, but they are not serializable. This API provides the
    argument `serialized_type_name` which allows these namedtuples to be
    serialized.

    Args:
        cls: the namedtuple type to register
        serialized_type_name: The serialized name for the namedtuple. This is
            required if you want to serialize the pytree TreeSpec containing this
            namedtuple.
    """
    _private_register_pytree_node(
        cls,
        _namedtuple_flatten,
        _namedtuple_unflatten,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=_namedtuple_serialize,
        from_dumpable_context=_namedtuple_deserialize,
        flatten_with_keys_fn=_namedtuple_flatten_with_keys,
    )


def _register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    to_str_fn: Optional[ToStrFunc] = None,  # deprecated
    maybe_from_str_fn: Optional[MaybeFromStrFunc] = None,  # deprecated
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """Register a container-like type as pytree node for the Python pytree only.

    Deprecated entry point: it always emits a deprecation warning and then
    forwards to :func:`_private_register_pytree_node`. ``to_str_fn`` and
    ``maybe_from_str_fn`` are accepted but ignored.

    Args:
        cls: the type to register
        flatten_fn: A callable that takes a pytree and returns a flattened
            representation of the pytree and additional context to represent the
            flattened pytree.
        unflatten_fn: A callable that takes a flattened version of the pytree,
            additional context, and returns an unflattened pytree.
        serialized_type_name: A keyword argument used to specify the fully qualified
            name used when serializing the tree spec.
        to_dumpable_context: An optional keyword argument to custom specify how
            to convert the context of the pytree to a custom json dumpable
            representation. This is used for json serialization, which is being
            used in torch.export right now.
        from_dumpable_context: An optional keyword argument to custom specify how
            to convert the custom json dumpable representation of the context
            back to the original context. This is used for json deserialization,
            which is being used in torch.export right now.
        flatten_with_keys_fn: An optional keyword argument to specify how to
            access each pytree leaf's keypath when flattening and tree-mapping.
            Like ``flatten_fn``, but in place of a List[leaf], it should return
            a List[(keypath, leaf)].
    """
    warnings.warn(
        "torch.utils._pytree._register_pytree_node is deprecated. "
        "Please use torch.utils._pytree.register_pytree_node instead.",
        stacklevel=2,
    )

    if to_str_fn is not None or maybe_from_str_fn is not None:
        # stacklevel=2 so this warning, like the one above, points at the caller.
        warnings.warn(
            "to_str_fn and maybe_from_str_fn is deprecated. "
            "Please use to_dumpable_context and from_dumpable_context instead.",
            stacklevel=2,
        )

    _private_register_pytree_node(
        cls,
        flatten_fn,
        unflatten_fn,
        serialized_type_name=serialized_type_name,
        to_dumpable_context=to_dumpable_context,
        from_dumpable_context=from_dumpable_context,
        flatten_with_keys_fn=flatten_with_keys_fn,
    )


def _private_register_pytree_node(
    cls: Type[Any],
    flatten_fn: FlattenFunc,
    unflatten_fn: UnflattenFunc,
    *,
    serialized_type_name: Optional[str] = None,
    to_dumpable_context: Optional[ToDumpableContextFn] = None,
    from_dumpable_context: Optional[FromDumpableContextFn] = None,
    flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None,
) -> None:
    """This is an internal function that is used to register a pytree node type
    for the Python pytree only. End-users should use :func:`register_pytree_node`
    instead.

    Re-registering an already known ``cls`` emits a warning and overwrites the
    previous entry.

    Raises:
        ValueError: if exactly one of ``to_dumpable_context`` and
            ``from_dumpable_context`` is provided. Nothing is registered in
            that case.
    """
    # Validate first: raising after SUPPORTED_NODES was written would leave the
    # type registered as a node but without a serialization entry.
    if (to_dumpable_context is None) ^ (from_dumpable_context is None):
        raise ValueError(
            f"Both to_dumpable_context and from_dumpable_context for {cls} must "
            "be None or registered."
        )

    if serialized_type_name is None:
        serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND

    with _NODE_REGISTRY_LOCK:
        if cls in SUPPORTED_NODES:
            # TODO: change this warning to an error after OSS/internal stabilize
            warnings.warn(
                f"{cls} is already registered as pytree node. "
                "Overwriting the previous registration.",
            )

        node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn)
        SUPPORTED_NODES[cls] = node_def

        serialize_node_def = _SerializeNodeDef(
            cls,
            serialized_type_name,
            to_dumpable_context,
            from_dumpable_context,
        )
        SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def
        SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls


@dataclasses.dataclass(frozen=True)
class SequenceKey(Generic[T]):
    """Key path entry addressing a child of a sequence by integer index."""

    idx: int

    def __str__(self) -> str:
        return f"[{self.idx!r}]"

    def get(self, sequence: Sequence[T]) -> T:
        return sequence[self.idx]


K = TypeVar("K", bound=Hashable)


@dataclasses.dataclass(frozen=True)
class MappingKey(Generic[K, T]):
    """Key path entry addressing a child of a mapping by key."""

    key: K

    def __str__(self) -> str:
        return f"[{self.key!r}]"

    def get(self, mapping: Mapping[K, T]) -> T:
        return mapping[self.key]


@dataclasses.dataclass(frozen=True)
class GetAttrKey:
    """Key path entry addressing a child of an object by attribute name."""

    name: str

    def __str__(self) -> str:
        return f".{self.name}"

    def get(self, obj: Any) -> Any:
        return getattr(obj, self.name)


def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]:
    return list(d), None


def _tuple_flatten_with_keys(
    d: Tuple[Any, ...]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _tuple_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]:
    return tuple(values)


def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]:
    # Returns the input list itself (no copy) as the list of children.
    return d, None


def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _list_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]:
    return list(values)


def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]:
    # The context is the list of keys, in the same order as the values.
    return list(d.values()), list(d.keys())


def _dict_flatten_with_keys(
    d: Dict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _dict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context


def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]:
    return dict(zip(context, values))


def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]:
    # The context is the concrete namedtuple class, used to rebuild the value.
    return list(d), type(d)


def _namedtuple_flatten_with_keys(
    d: NamedTuple,
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _namedtuple_flatten(d)
    return (
        [(GetAttrKey(field), v) for field, v in zip(context._fields, values)],
        context,
    )


def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple:
    return cast(NamedTuple, context(*values))


def _namedtuple_serialize(context: Context) -> DumpableContext:
    """Return the registered serialized type name for namedtuple class ``context``.

    Raises:
        NotImplementedError: if the class was never registered, or was
            registered without a ``serialized_type_name``.
    """
    if context not in SUPPORTED_SERIALIZED_TYPES:
        raise NotImplementedError(
            f"Can't serialize TreeSpec of namedtuple class {context} because we "
            "didn't register a serializated_type_name. Please register using "
            "`_register_namedtuple`."
        )

    serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context]
    serialized_type_name = serialize_node_def.serialized_type_name

    if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND:
        raise NotImplementedError(
            f"Can't serialize TreeSpec of namedtuple class {context} because we "
            "couldn't find a serializated_type_name. Please register using "
            "`_register_namedtuple`."
        )
    return serialized_type_name


def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context:
    """Map a serialized type name back to the registered namedtuple class.

    Raises:
        NotImplementedError: if no type was registered under that name.
    """
    if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE:
        raise NotImplementedError(
            f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} "
            "because we couldn't find a serializated name."
        )

    typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context]
    return typ


def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]:
    return list(d.values()), list(d.keys())


def _ordereddict_flatten_with_keys(
    d: GenericOrderedDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _ordereddict_flatten(d)
    return [(MappingKey(k), v) for k, v in zip(context, values)], context


def _ordereddict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> GenericOrderedDict[Any, Any]:
    return OrderedDict((key, value) for key, value in zip(context, values))


_odict_flatten = _ordereddict_flatten
_odict_unflatten = _ordereddict_unflatten


def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]:
    # Context is [default_factory, keys]; the keys part matches _dict_flatten.
    values, dict_context = _dict_flatten(d)
    return values, [d.default_factory, dict_context]


def _defaultdict_flatten_with_keys(
    d: DefaultDict[Any, Any]
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _defaultdict_flatten(d)
    _, dict_context = context
    return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context


def _defaultdict_unflatten(
    values: Iterable[Any],
    context: Context,
) -> DefaultDict[Any, Any]:
    default_factory, dict_context = context
    return defaultdict(default_factory, _dict_unflatten(values, dict_context))


def _defaultdict_serialize(context: Context) -> DumpableContext:
    """Convert a defaultdict context into a json dumpable dict.

    The factory is recorded by module and qualified name. A ``defaultdict``
    built without a factory (``default_factory is None``) is recorded with both
    fields set to ``None`` instead of failing on ``None.__module__``.
    """
    default_factory, dict_context = context
    if default_factory is None:
        factory_module = None
        factory_name = None
    else:
        factory_module = default_factory.__module__
        factory_name = default_factory.__qualname__
    json_defaultdict = {
        "default_factory_module": factory_module,
        "default_factory_name": factory_name,
        "dict_context": dict_context,
    }
    return json_defaultdict


def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context:
    """Rebuild the ``[default_factory, keys]`` context from its dumped form.

    NOTE(security): this imports the module named in the payload and resolves
    an attribute on it, so it must only be fed trusted data.
    """
    assert isinstance(dumpable_context, dict)
    assert set(dumpable_context) == {
        "default_factory_module",
        "default_factory_name",
        "dict_context",
    }

    default_factory_module = dumpable_context["default_factory_module"]
    default_factory_name = dumpable_context["default_factory_name"]
    dict_context = dumpable_context["dict_context"]

    # Counterpart of the factory-less case written by _defaultdict_serialize.
    if default_factory_module is None and default_factory_name is None:
        return [None, dict_context]

    assert isinstance(default_factory_module, str)
    assert isinstance(default_factory_name, str)
    default_factory: Any = importlib.import_module(default_factory_module)
    # The name was taken from __qualname__, which is dotted for nested
    # definitions (e.g. "Outer.factory"), so resolve it one attribute at a time.
    for attr_name in default_factory_name.split("."):
        default_factory = getattr(default_factory, attr_name)

    return [default_factory, dict_context]


def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]:
    # The context is the deque's maxlen so the bound survives a round trip.
    return list(d), d.maxlen


def _deque_flatten_with_keys(
    d: Deque[Any],
) -> Tuple[List[Tuple[KeyEntry, Any]], Context]:
    values, context = _deque_flatten(d)
    return [(SequenceKey(i), v) for i, v in enumerate(values)], context


def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]:
    return deque(values, maxlen=context)


_private_register_pytree_node(
    tuple,
    _tuple_flatten,
    _tuple_unflatten,
    serialized_type_name="builtins.tuple",
    flatten_with_keys_fn=_tuple_flatten_with_keys,
)
_private_register_pytree_node(
    list,
    _list_flatten,
    _list_unflatten,
    serialized_type_name="builtins.list",
    flatten_with_keys_fn=_list_flatten_with_keys,
)
_private_register_pytree_node(
    dict,
    _dict_flatten,
    _dict_unflatten,
    serialized_type_name="builtins.dict",
    flatten_with_keys_fn=_dict_flatten_with_keys,
)
# `namedtuple` (the factory function) is the registry key shared by every
# namedtuple class; see _get_node_type.
_private_register_pytree_node(
    namedtuple,  # type: ignore[arg-type]
    _namedtuple_flatten,
    _namedtuple_unflatten,
    serialized_type_name="collections.namedtuple",
    to_dumpable_context=_namedtuple_serialize,
    from_dumpable_context=_namedtuple_deserialize,
    flatten_with_keys_fn=_namedtuple_flatten_with_keys,
)
_private_register_pytree_node(
    OrderedDict,
    _ordereddict_flatten,
    _ordereddict_unflatten,
    serialized_type_name="collections.OrderedDict",
    flatten_with_keys_fn=_ordereddict_flatten_with_keys,
)
_private_register_pytree_node(
    defaultdict,
    _defaultdict_flatten,
    _defaultdict_unflatten,
    serialized_type_name="collections.defaultdict",
    to_dumpable_context=_defaultdict_serialize,
    from_dumpable_context=_defaultdict_deserialize,
    flatten_with_keys_fn=_defaultdict_flatten_with_keys,
)
_private_register_pytree_node(
    deque,
    _deque_flatten,
    _deque_unflatten,
    serialized_type_name="collections.deque",
    flatten_with_keys_fn=_deque_flatten_with_keys,
)


STANDARD_DICT_TYPES: FrozenSet[type] = frozenset(
    {dict, OrderedDict, defaultdict},
)
BUILTIN_TYPES: FrozenSet[type] = frozenset(
    {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque},  # type: ignore[arg-type]
)


# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple
def _is_namedtuple_instance(tree: Any) -> bool:
    """Return True if ``tree``'s class derives directly from ``tuple`` alone
    and has a ``_fields`` tuple made only of ``str`` entries."""
    typ = type(tree)
    bases = typ.__bases__
    if len(bases) != 1 or bases[0] is not tuple:
        return False
    fields = getattr(typ, "_fields", None)
    if not isinstance(fields, tuple):
        return False
    # Exact type check on purpose: str subclasses are not accepted.
    return all(type(entry) is str for entry in fields)


def _get_node_type(tree: Any) -> Any:
    """Return the registry key for ``tree``: ``namedtuple`` for any namedtuple
    instance, otherwise ``type(tree)``."""
    if _is_namedtuple_instance(tree):
        return namedtuple
    return type(tree)


# A leaf is defined as anything that is not a Node.
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool:
    """Return True if ``tree`` is a leaf: either the optional ``is_leaf``
    predicate accepts it, or its node type is not in ``SUPPORTED_NODES``."""
    return (is_leaf is not None and is_leaf(tree)) or _get_node_type(
        tree
    ) not in SUPPORTED_NODES


# A TreeSpec represents the structure of a pytree. It holds:
# "type": the type of root Node of the pytree
# context: some context that is useful in unflattening the pytree
# children_specs: specs for each child of the root Node
# num_nodes: the number of nodes in the tree, this one included
# num_leaves: the number of leaves
# num_children: the number of direct children
@dataclasses.dataclass
class TreeSpec:
    type: Any
    context: Context
    children_specs: List["TreeSpec"]

    # Derived counts, filled in by __post_init__.
    num_nodes: int = dataclasses.field(init=False)
    num_leaves: int = dataclasses.field(init=False)
    num_children: int = dataclasses.field(init=False)

    def __post_init__(self) -> None:
        self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs)
        self.num_leaves = sum(spec.num_leaves for spec in self.children_specs)
        self.num_children = len(self.children_specs)

    def __repr__(self, indent: int = 0) -> str:
        """Render the spec; children after the first go on their own line,
        indented two more spaces per nesting level."""
        repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, ["
        children_specs_str: str = ""
        if self.num_children > 0:
            indent += 2
            children_specs_str += self.children_specs[0].__repr__(indent)
            children_specs_str += "," if self.num_children > 1 else ""
            children_specs_str += ",".join(
                [
                    "\n" + " " * indent + child.__repr__(indent)
                    for child in self.children_specs[1:]
                ]
            )
        repr_suffix: str = f"{children_specs_str}])"
        return repr_prefix + repr_suffix

    def is_leaf(self) -> bool:
        """Return True if this spec describes a single leaf."""
        return self.num_nodes == 1 and self.num_leaves == 1

    def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None:
        """Append to ``subtrees`` the parts of ``tree`` that sit where this
        spec has leaves, checking that ``tree`` matches the spec above them.

        Raises:
            ValueError: on a node type, arity, key set or context mismatch.
        """
        if self.is_leaf():
            subtrees.append(tree)
            return

        node_type = _get_node_type(tree)
        if self.type not in BUILTIN_TYPES:
            # Always require custom node types to match exactly
            if node_type != self.type:
                raise ValueError(
                    f"Type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
            child_pytrees, context = flatten_fn(tree)
            if len(child_pytrees) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(child_pytrees)}.",
                )
            if context != self.context:
                raise ValueError(
                    f"Node context mismatch for custom node type {self.type!r}.",
                )
        else:
            # For builtin dictionary types, we allow some flexibility
            # Otherwise, we require exact matches
            both_standard_dict = (
                self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES
            )
            if node_type != self.type and not both_standard_dict:
                raise ValueError(
                    f"Node type mismatch; "
                    f"expected {self.type!r}, but got {node_type!r}.",
                )
            if len(tree) != self.num_children:
                raise ValueError(
                    f"Node arity mismatch; "
                    f"expected {self.num_children}, but got {len(tree)}.",
                )

            if both_standard_dict:  # dictionary types are compatible with each other
                dict_context = (
                    self.context
                    if self.type is not defaultdict
                    # ignore mismatch of `default_factory` for defaultdict
                    else self.context[1]
                )
                expected_keys = dict_context
                got_key_set = set(tree)
                expected_key_set = set(expected_keys)
                if got_key_set != expected_key_set:
                    missing_keys = expected_key_set.difference(got_key_set)
                    extra_keys = got_key_set.difference(expected_key_set)
                    message = ""
                    if missing_keys:
                        message += f"; missing key(s): {missing_keys}"
                    if extra_keys:
                        message += f"; extra key(s): {extra_keys}"
                    raise ValueError(f"Node keys mismatch{message}.")
                # Children follow the spec's key order, not the tree's own.
                child_pytrees = [tree[key] for key in expected_keys]
            else:
                flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
                child_pytrees, context = flatten_fn(tree)
                if (
                    context != self.context
                    and self.type is not deque  # ignore mismatch of `maxlen` for deque
                ):
                    raise ValueError(
                        f"Node context mismatch for node type {self.type!r}; "
                        f"expected {self.context!r}, but got {context!r}.",  # namedtuple type mismatch
                    )

        for child_pytree, child_spec in zip(child_pytrees, self.children_specs):
            child_spec._flatten_up_to_helper(child_pytree, subtrees)

    def flatten_up_to(self, tree: PyTree) -> List[PyTree]:
        """Flatten ``tree`` only as deep as this spec goes and return the list
        of subtrees found at this spec's leaf positions."""
        subtrees: List[PyTree] = []
        self._flatten_up_to_helper(tree, subtrees)
        return subtrees

    def unflatten(self, leaves: Iterable[Any]) -> PyTree:
        """Rebuild a pytree with this structure from ``leaves``.

        Raises:
            ValueError: if the number of leaves differs from ``num_leaves``.
        """
        if not isinstance(leaves, (list, tuple)):
            # Materialize so the leaves can be measured and sliced.
            leaves = list(leaves)
        if len(leaves) != self.num_leaves:
            raise ValueError(
                f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} "
                f"but the spec refers to a pytree that holds {self.num_leaves} "
                f"items ({self}).",
            )
        if self.is_leaf():
            return leaves[0]

        unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn

        # Recursively unflatten the children
        start = 0
        end = 0
        child_pytrees = []
        for child_spec in self.children_specs:
            end += child_spec.num_leaves
            child_pytrees.append(child_spec.unflatten(leaves[start:end]))
            start = end

        return unflatten_fn(child_pytrees, self.context)


class LeafSpec(TreeSpec):
    """Spec of a single leaf: no type, no context, no children."""

    def __init__(self) -> None:
        super().__init__(None, None, [])

    def __post_init__(self) -> None:
        # A leaf counts as one node and one leaf even though it has no children.
        self.num_nodes = 1
        self.num_leaves = 1
        self.num_children = 0

    def __repr__(self, indent: int = 0) -> str:
        return "*"


# All leaves are equivalent, so represent with a single object to save on
# object construction time
_LEAF_SPEC = LeafSpec()


def _tree_flatten_helper(
    tree: PyTree,
    leaves: List[Any],
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Append the leaves of ``tree`` to ``leaves`` and return its TreeSpec."""
    if _is_leaf(tree, is_leaf=is_leaf):
        leaves.append(tree)
        return _LEAF_SPEC

    node_type = _get_node_type(tree)
    flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
    child_pytrees, context = flatten_fn(tree)

    # Recursively flatten the children
    children_specs = [
        _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees
    ]

    return TreeSpec(node_type, context, children_specs)


def tree_flatten(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Tuple[List[Any], TreeSpec]:
    """Flattens a pytree into a list of values and a TreeSpec that can be used
    to reconstruct the pytree.
    """
    leaves: List[Any] = []
    spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf)
    return leaves, spec


def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree:
    """Given a list of values and a TreeSpec, builds a pytree.
    This is the inverse operation of `tree_flatten`.

    Raises:
        TypeError: if ``treespec`` is not a :class:`TreeSpec`.
    """
    if not isinstance(treespec, TreeSpec):
        raise TypeError(
            f"tree_unflatten(leaves, treespec): Expected `treespec` to be "
            f"instance of TreeSpec but got item of type {type(treespec)}.",
        )
    return treespec.unflatten(leaves)


def tree_iter(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> Iterable[Any]:
    """Get an iterator over the leaves of a pytree."""
    if _is_leaf(tree, is_leaf=is_leaf):
        yield tree
    else:
        node_type = _get_node_type(tree)
        flatten_fn = SUPPORTED_NODES[node_type].flatten_fn
        child_pytrees, _ = flatten_fn(tree)

        # Recursively flatten the children
        for child in child_pytrees:
            yield from tree_iter(child, is_leaf=is_leaf)


def tree_leaves(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> List[Any]:
    """Get a list of leaves of a pytree."""
    return list(tree_iter(tree, is_leaf=is_leaf))


def tree_structure(
    tree: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> TreeSpec:
    """Get the TreeSpec for a pytree."""
    return tree_flatten(tree, is_leaf=is_leaf)[1]


def tree_map(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Map a multi-input function over pytree args to produce a new pytree.

    See also :func:`tree_map_`.

    >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)})
    {'x': 8, 'y': (43, 65)}
    >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None})
    {'x': False, 'y': (False, False), 'z': True}

    If multiple inputs are given, the structure of the tree is taken from the first input;
    subsequent inputs need only have ``tree`` as a prefix:

    >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]])
    [[5, 7, 9], [6, 1, 2]]

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        A new pytree with the same structure as ``tree`` but with the value at each leaf given by
        ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs``
        is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    return treespec.unflatten(map(func, *flat_args))


def tree_map_(
    func: Callable[..., Any],
    tree: PyTree,
    *rests: PyTree,
    is_leaf: Optional[Callable[[PyTree], bool]] = None,
) -> PyTree:
    """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree.

    See also :func:`tree_map`.

    Args:
        func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the
            corresponding leaves of the pytrees.
        tree (pytree): A pytree to be mapped over, with each leaf providing the first positional
            argument to function ``func``.
        rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as
            ``tree`` or has ``tree`` as a prefix.
        is_leaf (callable, optional): An extra leaf predicate function that will be called at each
            flattening step. The function should have a single argument with signature
            ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree is treated
            as a leaf. Otherwise, the default pytree registry will be used to determine whether a
            node is a leaf or not. If the function is not specified, the default pytree registry
            will be used.

    Returns:
        The original ``tree`` with the value at each leaf given by the side-effect of function
        ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf
        in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``.
    """
    leaves, treespec = tree_flatten(tree, is_leaf=is_leaf)
    flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests]
    # Exhaust the iterator for its side effects only; a zero-length deque
    # discards each return value instead of collecting them all.
    deque(map(func, *flat_args), maxlen=0)
    return tree


Type2 = Tuple[Type[T], Type[S]]
Type3 = Tuple[Type[T], Type[S], Type[U]]
# `types.UnionType` (the type of `int | str`) only exists on Python 3.10+.
if sys.version_info >= (3, 10):
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType]
else:
    TypeAny = Union[Type[Any], Tuple[Type[Any], ...]]

Fn2 = Callable[[Union[T, S]], R]
Fn3 = Callable[[Union[T, S, U]], R]
Fn = Callable[[T], R]
FnAny = Callable[[Any], R]

MapOnlyFn = Callable[[T], Callable[[Any], Any]]


# These specializations help with type inference on the lambda passed to this
# function
@overload
def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]:
    ...
+ + +@overload +def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: + ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: + ... + + +@overload +def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: + ... + + +def map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]] +) -> MapOnlyFn[FnAny[Any]]: + """ + Suppose you are writing a tree_map over tensors, leaving everything + else unchanged. Ordinarily you would have to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + """ + if isinstance(__type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) + and isinstance(__type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(__type_or_types_or_pred): + pred = __type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + @functools.wraps(func) + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... 
+ + +@overload +def tree_map_only( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +def tree_map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + ... 
+ + +def tree_map_only_( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... 
+ + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree, is_leaf=is_leaf): + return [tree] * treespec.num_leaves + if treespec.is_leaf(): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """ + _TreeSpecSchema is the schema used to serialize the TreeSpec + It 
contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if treespec.is_leaf(): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." 
+ ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if ( + json_schema["type"] is None + and json_schema["context"] is None + and len(json_schema["children_spec"]) == 0 + ): + return _LEAF_SPEC + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_specs = [] + for child_string in json_schema["children_spec"]: + children_specs.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. 
" + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " + f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function + + A slightly faster version of tree_leaves((args, kwargs)) + """ + leaves: List[Any] = [] + for a in args: + leaves.extend(tree_iter(a)) + for a in kwargs.values(): + leaves.extend(tree_iter(a)) + return leaves + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. 
+ is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + """ + _, treespec = tree_flatten(tree, is_leaf) + return list(_generate_key_paths((), tree, is_leaf)), treespec + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. 
+ """ + return list(_generate_key_paths((), tree, is_leaf)) + + +def _generate_key_paths( + key_path: KeyPath, + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + + node_type = _get_node_type(tree) + handler = SUPPORTED_NODES.get(node_type) + if not handler: + # This is a leaf + yield key_path, tree + return + + flatten_with_keys = handler.flatten_with_keys_fn + if flatten_with_keys: + key_children, _ = flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths((*key_path, k), c, is_leaf) + else: + # We registered this pytree but didn't add a flatten_with_keys_fn, complain. + raise ValueError( + f"Did not find a flatten_with_keys_fn for type: {node_type}. " + "Please pass a flatten_with_keys_fn argument to register_pytree_node." + ) + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. 
If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ + keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + return "".join([str(k) for k in kp]) + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + for k in kp: + obj = k.get(obj) + return obj From 6f6528c27c02393d3f28803e5fdd67ef8e13dfea Mon Sep 17 00:00:00 2001 From: yhl48 Date: Wed, 24 Apr 2024 16:13:08 +0100 Subject: [PATCH 02/14] imported pytree functions from litdata.utilities._pytree --- src/litdata/processing/data_processor.py | 4 +--- src/litdata/processing/functions.py | 4 +--- src/litdata/streaming/config.py | 4 +--- src/litdata/streaming/dataloader.py | 4 +--- src/litdata/streaming/item_loader.py | 4 +--- src/litdata/streaming/writer.py | 4 +--- 6 files changed, 6 insertions(+), 18 deletions(-) diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 26d38ad0a..58f0e359c 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -51,15 +51,13 @@ from litdata.streaming.resolver import _resolve_dir from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily +from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads 
_TQDM_AVAILABLE = RequirementCache("tqdm") if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads - if _LIGHTNING_CLOUD_AVAILABLE: from lightning_cloud.openapi import V1DatasetType diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 391a5ea1e..ea974b627 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -34,9 +34,7 @@ _execute, _resolve_dir, ) - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten +from litdata.utilities._pytree import tree_flatten def _get_indexed_paths(data: Any) -> Dict[int, str]: diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 23402df55..9670abe17 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -21,9 +21,7 @@ from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_unflatten, treespec_loads +from litdata.utilities._pytree import tree_unflatten, treespec_loads class ChunksConfig: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 72c360d16..d69f0a042 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -43,9 +43,7 @@ from litdata.streaming.dataset import StreamingDataset from litdata.streaming.sampler import CacheBatchSampler from litdata.utilities.env import _DistributedEnv - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten +from litdata.utilities._pytree import tree_unflatten logger = logging.Logger(__name__) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 04b9b2b8c..f2c462fd1 100644 --- 
a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -25,9 +25,7 @@ _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.serializers import Serializer - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import PyTree, tree_unflatten +from litdata.utilities._pytree import PyTree, tree_unflatten class BaseItemLoader(ABC): diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 42d381010..92ecbcfb6 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -27,9 +27,7 @@ from litdata.streaming.serializers import Serializer, _get_serializers from litdata.utilities.env import _DistributedEnv, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps +from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps @dataclass From 58eb24c07adb4cacfcfefe1f9c57eec24dbeacfd Mon Sep 17 00:00:00 2001 From: yhl48 Date: Wed, 24 Apr 2024 16:13:53 +0100 Subject: [PATCH 03/14] removed error message for using torch<2.1.0 --- src/litdata/streaming/cache.py | 3 --- 1 file changed, 3 deletions(-) diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 5d00b97e2..2c068ab20 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -56,9 +56,6 @@ def __init__( """ super().__init__() - if not _TORCH_GREATER_EQUAL_2_1_0: - raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.") - input_dir = _resolve_dir(input_dir) self._cache_dir = input_dir.path assert self._cache_dir From aeabb55f9f8faa296cecc131d9323df5d889c96e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:38:08 +0100 Subject: [PATCH 04/14] update --- pyproject.toml | 3 +- src/litdata/processing/data_processor.py | 3 +- src/litdata/processing/functions.py | 2 +- 
src/litdata/streaming/cache.py | 1 - src/litdata/streaming/config.py | 2 +- src/litdata/streaming/dataloader.py | 4 +- src/litdata/streaming/item_loader.py | 1 - src/litdata/streaming/writer.py | 4 +- src/litdata/utilities/_pytree.py | 193 +++++++++-------------- 9 files changed, 83 insertions(+), 130 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6b07d5d84..63c61d93e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,8 @@ lint.ignore = [ exclude = [ ".git", "docs", - "_notebooks" + "_notebooks", + "src/litdata/utilities/_pytree.py", ] lint.ignore-init-module-imports = true diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 58f0e359c..c8a79ca92 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -39,7 +39,6 @@ _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.imports import RequirementCache from litdata.processing.readers import BaseReader, StreamingDataLoaderReader @@ -49,9 +48,9 @@ from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.resolver import _resolve_dir +from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily -from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads _TQDM_AVAILABLE = RequirementCache("tqdm") diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index ea974b627..143874251 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -22,7 +22,7 @@ import torch -from litdata.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _IS_IN_STUDIO from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, 
DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import optimize_dns_context diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 2c068ab20..a72abf8e5 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -17,7 +17,6 @@ from litdata.constants import ( _INDEX_FILENAME, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.reader import BinaryReader diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 9670abe17..51a0df21e 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -15,7 +15,7 @@ import os from typing import Any, Dict, List, Optional, Tuple -from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _INDEX_FILENAME from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.downloader import get_downloader_cls from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index d69f0a042..9eabdef34 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -33,7 +33,7 @@ ) from torch.utils.data.sampler import BatchSampler, Sampler -from litdata.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE +from litdata.constants import _DEFAULT_CHUNK_BYTES, _VIZ_TRACKER_AVAILABLE from litdata.streaming import Cache from litdata.streaming.combined import ( __NUM_SAMPLES_YIELDED_KEY__, @@ -42,8 +42,8 @@ ) from litdata.streaming.dataset import StreamingDataset from litdata.streaming.sampler import CacheBatchSampler +from litdata.utilities._pytree import tree_flatten from litdata.utilities.env import _DistributedEnv -from litdata.utilities._pytree import tree_unflatten logger = 
logging.Logger(__name__) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index f2c462fd1..216f76b17 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -22,7 +22,6 @@ from litdata.constants import ( _TORCH_DTYPES_MAPPING, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.serializers import Serializer from litdata.utilities._pytree import PyTree, tree_unflatten diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 92ecbcfb6..21809a62c 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -21,13 +21,13 @@ import numpy as np import torch -from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _INDEX_FILENAME from litdata.processing.utilities import get_worker_rank from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.serializers import Serializer, _get_serializers +from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps from litdata.utilities.env import _DistributedEnv, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes -from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps @dataclass diff --git a/src/litdata/utilities/_pytree.py b/src/litdata/utilities/_pytree.py index 14b6115b8..c5cca32a8 100644 --- a/src/litdata/utilities/_pytree.py +++ b/src/litdata/utilities/_pytree.py @@ -25,11 +25,10 @@ import threading import types import warnings -from collections import defaultdict, deque, namedtuple, OrderedDict +from collections import OrderedDict, defaultdict, deque, namedtuple from typing import ( Any, Callable, - cast, DefaultDict, Deque, Dict, @@ -41,16 +40,18 @@ Mapping, NamedTuple, Optional, - OrderedDict as GenericOrderedDict, - overload, Protocol, Sequence, Tuple, Type, TypeVar, Union, + cast, + overload, +) +from typing import ( + OrderedDict as 
GenericOrderedDict, ) - __all__ = [ "PyTree", @@ -98,17 +99,13 @@ class KeyEntry(Protocol): - def __hash__(self) -> int: - ... + def __hash__(self) -> int: ... - def __eq__(self, other: object) -> bool: - ... + def __eq__(self, other: object) -> bool: ... - def __str__(self) -> str: - ... + def __str__(self) -> str: ... - def get(self, parent: Any) -> Any: - ... + def get(self, parent: Any) -> Any: ... Context = Any @@ -195,6 +192,7 @@ def register_pytree_node( access each pytree leaf's keypath when flattening and tree-mapping. Like ``flatten_fn``, but in place of a List[leaf], it should return a List[(keypath, leaf)]. + """ with _NODE_REGISTRY_LOCK: if cls in SUPPORTED_NODES: @@ -230,10 +228,8 @@ def _register_namedtuple( *, serialized_type_name: str, ) -> None: - """ - Registers a namedtuple as a valid pytree node. By default namedtuples are - valid pytree nodes, but they are not serializable. This API provides the - argument `serialized_type_name` which allows these namedtuples to be + """Registers a namedtuple as a valid pytree node. By default namedtuples are valid pytree nodes, but they are not + serializable. This API provides the argument `serialized_type_name` which allows these namedtuples to be serialized. Args: @@ -241,6 +237,7 @@ def _register_namedtuple( serialized_type_name: The serialized name for the dataclass. This is required if you want to serialize the pytree TreeSpec containing this namedtuple. + """ _private_register_pytree_node( cls, @@ -288,6 +285,7 @@ def _register_pytree_node( access each pytree leaf's keypath when flattening and tree-mapping. Like ``flatten_fn``, but in place of a List[leaf], it should return a List[(keypath, leaf)]. + """ warnings.warn( "torch.utils._pytree._register_pytree_node is deprecated. 
" @@ -322,16 +320,17 @@ def _private_register_pytree_node( from_dumpable_context: Optional[FromDumpableContextFn] = None, flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, ) -> None: - """This is an internal function that is used to register a pytree node type - for the Python pytree only. End-users should use :func:`register_pytree_node` + """This is an internal function that is used to register a pytree node type for the Python pytree only. + + End-users should use :func:`register_pytree_node` instead. + """ with _NODE_REGISTRY_LOCK: if cls in SUPPORTED_NODES: # TODO: change this warning to an error after OSS/internal stabilize warnings.warn( - f"{cls} is already registered as pytree node. " - "Overwriting the previous registration.", + f"{cls} is already registered as pytree node. " "Overwriting the previous registration.", ) node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) @@ -339,8 +338,7 @@ def _private_register_pytree_node( if (to_dumpable_context is None) ^ (from_dumpable_context is None): raise ValueError( - f"Both to_dumpable_context and from_dumpable_context for {cls} must " - "be None or registered." + f"Both to_dumpable_context and from_dumpable_context for {cls} must " "be None or registered." ) if serialized_type_name is None: @@ -396,9 +394,7 @@ def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: return list(d), None -def _tuple_flatten_with_keys( - d: Tuple[Any, ...] 
-) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: +def _tuple_flatten_with_keys(d: Tuple[Any, ...]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: values, context = _tuple_flatten(d) return [(SequenceKey(i), v) for i, v in enumerate(values)], context @@ -424,9 +420,7 @@ def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: return list(d.values()), list(d.keys()) -def _dict_flatten_with_keys( - d: Dict[Any, Any] -) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: +def _dict_flatten_with_keys(d: Dict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: values, context = _dict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -488,9 +482,7 @@ def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Co return list(d.values()), list(d.keys()) -def _ordereddict_flatten_with_keys( - d: GenericOrderedDict[Any, Any] -) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: +def _ordereddict_flatten_with_keys(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: values, context = _ordereddict_flatten(d) return [(MappingKey(k), v) for k, v in zip(context, values)], context @@ -511,9 +503,7 @@ def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: return values, [d.default_factory, dict_context] -def _defaultdict_flatten_with_keys( - d: DefaultDict[Any, Any] -) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: +def _defaultdict_flatten_with_keys(d: DefaultDict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: values, context = _defaultdict_flatten(d) _, dict_context = context return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context @@ -654,9 +644,7 @@ def _get_node_type(tree: Any) -> Any: # A leaf is defined as anything that is not a Node. 
def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: - return (is_leaf is not None and is_leaf(tree)) or _get_node_type( - tree - ) not in SUPPORTED_NODES + return (is_leaf is not None and is_leaf(tree)) or _get_node_type(tree) not in SUPPORTED_NODES # A TreeSpec represents the structure of a pytree. It holds: @@ -687,10 +675,7 @@ def __repr__(self, indent: int = 0) -> str: children_specs_str += self.children_specs[0].__repr__(indent) children_specs_str += "," if self.num_children > 1 else "" children_specs_str += ",".join( - [ - "\n" + " " * indent + child.__repr__(indent) - for child in self.children_specs[1:] - ] + ["\n" + " " * indent + child.__repr__(indent) for child in self.children_specs[1:]] ) repr_suffix: str = f"{children_specs_str}])" return repr_prefix + repr_suffix @@ -708,15 +693,13 @@ def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: # Always require custom node types to match exactly if node_type != self.type: raise ValueError( - f"Type mismatch; " - f"expected {self.type!r}, but got {node_type!r}.", + f"Type mismatch; " f"expected {self.type!r}, but got {node_type!r}.", ) flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(tree) if len(child_pytrees) != self.num_children: raise ValueError( - f"Node arity mismatch; " - f"expected {self.num_children}, but got {len(child_pytrees)}.", + f"Node arity mismatch; " f"expected {self.num_children}, but got {len(child_pytrees)}.", ) if context != self.context: raise ValueError( @@ -725,18 +708,14 @@ def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: else: # For builtin dictionary types, we allow some flexibility # Otherwise, we require exact matches - both_standard_dict = ( - self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES - ) + both_standard_dict = self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES if node_type != self.type and not 
both_standard_dict: raise ValueError( - f"Node type mismatch; " - f"expected {self.type!r}, but got {node_type!r}.", + f"Node type mismatch; " f"expected {self.type!r}, but got {node_type!r}.", ) if len(tree) != self.num_children: raise ValueError( - f"Node arity mismatch; " - f"expected {self.num_children}, but got {len(tree)}.", + f"Node arity mismatch; " f"expected {self.num_children}, but got {len(tree)}.", ) if both_standard_dict: # dictionary types are compatible with each other @@ -763,8 +742,7 @@ def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: flatten_fn = SUPPORTED_NODES[node_type].flatten_fn child_pytrees, context = flatten_fn(tree) if ( - context != self.context - and self.type is not deque # ignore mismatch of `maxlen` for deque + context != self.context and self.type is not deque # ignore mismatch of `maxlen` for deque ): raise ValueError( f"Node context mismatch for node type {self.type!r}; " @@ -837,9 +815,7 @@ def _tree_flatten_helper( child_pytrees, context = flatten_fn(tree) # Recursively flatten the children - children_specs = [ - _tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees - ] + children_specs = [_tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees] return TreeSpec(node_type, context, children_specs) @@ -848,9 +824,7 @@ def tree_flatten( tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, ) -> Tuple[List[Any], TreeSpec]: - """Flattens a pytree into a list of values and a TreeSpec that can be used - to reconstruct the pytree. - """ + """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree.""" leaves: List[Any] = [] spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf) return leaves, spec @@ -858,7 +832,9 @@ def tree_flatten( def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: """Given a list of values and a TreeSpec, builds a pytree. 
+ This is the inverse operation of `tree_flatten`. + """ if not isinstance(treespec, TreeSpec): raise TypeError( @@ -939,6 +915,7 @@ def tree_map( A new pytree with the same structure as ``tree`` but with the value at each leaf given by ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] @@ -972,6 +949,7 @@ def tree_map_( The original ``tree`` with the value at each leaf is given by the side-effect of function ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + """ leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] @@ -997,37 +975,29 @@ def tree_map_( # These specializations help with type inference on the lambda passed to this # function @overload -def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: - ... +def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: ... @overload -def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: - ... +def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... @overload -def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: - ... +def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: ... # This specialization is needed for the implementations below that call @overload -def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: - ... +def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: ... @overload -def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: - ... 
+def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: ... -def map_only( - __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]] -) -> MapOnlyFn[FnAny[Any]]: - """ - Suppose you are writing a tree_map over tensors, leaving everything - else unchanged. Ordinarily you would have to write: +def map_only(__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]) -> MapOnlyFn[FnAny[Any]]: + """Suppose you are writing a tree_map over tensors, leaving everything else unchanged. Ordinarily you would have + to write: def go(t): if isinstance(t, Tensor): @@ -1042,10 +1012,10 @@ def go(t): return ... You can also directly use 'tree_map_only' + """ if isinstance(__type_or_types_or_pred, (type, tuple)) or ( - sys.version_info >= (3, 10) - and isinstance(__type_or_types_or_pred, types.UnionType) + sys.version_info >= (3, 10) and isinstance(__type_or_types_or_pred, types.UnionType) ): def pred(x: Any) -> bool: @@ -1074,8 +1044,7 @@ def tree_map_only( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1084,8 +1053,7 @@ def tree_map_only( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1094,8 +1062,7 @@ def tree_map_only( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1104,8 +1071,7 @@ def tree_map_only( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only( @@ -1123,8 +1089,7 @@ def tree_map_only_( func: Fn[T, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1133,8 +1098,7 @@ def tree_map_only_( func: Fn2[T, S, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... 
+) -> PyTree: ... @overload @@ -1143,8 +1107,7 @@ def tree_map_only_( func: Fn3[T, S, U, Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... @overload @@ -1153,8 +1116,7 @@ def tree_map_only_( func: FnAny[Any], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> PyTree: - ... +) -> PyTree: ... def tree_map_only_( @@ -1190,8 +1152,7 @@ def tree_all_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1200,8 +1161,7 @@ def tree_all_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1210,8 +1170,7 @@ def tree_all_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_all_only( @@ -1230,8 +1189,7 @@ def tree_any_only( pred: Fn[T, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1240,8 +1198,7 @@ def tree_any_only( pred: Fn2[T, S, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... @overload @@ -1250,8 +1207,7 @@ def tree_any_only( pred: Fn3[T, S, U, bool], tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None, -) -> bool: - ... +) -> bool: ... def tree_any_only( @@ -1308,12 +1264,13 @@ def _broadcast_to_and_flatten( @dataclasses.dataclass class _TreeSpecSchema: - """ - _TreeSpecSchema is the schema used to serialize the TreeSpec + """_TreeSpecSchema is the schema used to serialize the TreeSpec. + It contains the following fields: - type: A string name of the type. null for the case of a LeafSpec. - context: Any format which is json dumpable - children_spec: A list of children serialized specs. 
+ """ type: Optional[str] @@ -1366,11 +1323,7 @@ def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: - if ( - json_schema["type"] is None - and json_schema["context"] is None - and len(json_schema["children_spec"]) == 0 - ): + if json_schema["type"] is None and json_schema["context"] is None and len(json_schema["children_spec"]) == 0: return _LEAF_SPEC if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: @@ -1417,8 +1370,7 @@ def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) else: raise ValueError( - f"Unknown protocol {protocol}. " - f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + f"Unknown protocol {protocol}. " f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", ) str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) @@ -1431,8 +1383,7 @@ def treespec_loads(serialized: str) -> TreeSpec: if protocol in _SUPPORTED_PROTOCOLS: return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) raise ValueError( - f"Unknown protocol {protocol}. " - f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + f"Unknown protocol {protocol}. " f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", ) @@ -1462,9 +1413,10 @@ def str_to_pytree(json: str) -> TreeSpec: def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: - """Get a flat list of arguments to this function + """Get a flat list of arguments to this function. A slightly faster version of tree_leaves((args, kwargs)) + """ leaves: List[Any] = [] for a in args: @@ -1493,6 +1445,7 @@ def tree_flatten_with_path( A tuple where the first element is a list of (key path, leaf) pairs, and the second element is a :class:`TreeSpec` representing the structure of the flattened tree. 
+ """ _, treespec = tree_flatten(tree, is_leaf) return list(_generate_key_paths((), tree, is_leaf)), treespec @@ -1515,6 +1468,7 @@ def tree_leaves_with_path( leaf or not. If the function is not specified, the default pytree registry will be used. Returns: A list of (key path, leaf) pairs. + """ return list(_generate_key_paths((), tree, is_leaf)) @@ -1576,6 +1530,7 @@ def tree_map_with_path( ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the corresponding leaf in ``tree``, ``x`` is the value at that leaf, and ``xs`` is the tuple of values at corresponding nodes in ``rests``. + """ keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) keypath_leaves = list(zip(*keypath_leaves)) From ab7db4156c92c8c88a5e6a42c1d6390a1e27955c Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:40:55 +0100 Subject: [PATCH 05/14] update --- .github/workflows/ci-testing.yml | 2 +- status.json | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) create mode 100644 status.json diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ff1d9cab9..4383ab3f0 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: [3.9] + python-version: [3.10] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 diff --git a/status.json b/status.json new file mode 100644 index 000000000..72aa54955 --- /dev/null +++ b/status.json @@ -0,0 +1 @@ +{ "progress": "20.0%" } From d68ed0dce3d9e9ab45606fdfecb24c82a089b8ba Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:42:56 +0100 Subject: [PATCH 06/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 4383ab3f0..a0843e878 100644 --- a/.github/workflows/ci-testing.yml +++ 
b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: [3.10] + python-version: [3.11] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 From ddcb442d57d5685fac2f8c5938fb8fcacbc2b46d Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:43:22 +0100 Subject: [PATCH 07/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index a0843e878..d452a0419 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: [3.11] + python-version: [3.10.10] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 From 44026b234dab6307f0b745c72551105a05367509 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:45:46 +0100 Subject: [PATCH 08/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index d452a0419..ad080b9b1 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: [3.10.10] + python-version: ["3.10.10"] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 From 5f3320e5f065428af82ad3de42a740a266fd8a98 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:47:10 +0100 Subject: [PATCH 09/14] update --- .github/workflows/ci-testing.yml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ad080b9b1..ffd1521a3 100644 --- 
a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -16,8 +16,8 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macOS-latest, windows-latest] - python-version: ["3.10.10"] + os: [ubuntu-latest, macOS, windows-latest] + python-version: [3.9] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 From e9aa9e41e2be789d3ca3f4e8d58d8238bb073946 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:48:33 +0100 Subject: [PATCH 10/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ffd1521a3..ff1d9cab9 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macOS, windows-latest] + os: [ubuntu-latest, macOS-latest, windows-latest] python-version: [3.9] requires: ["oldest", "latest"] From 03eab4def4c568e640396d426c9a7ed503362223 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:49:40 +0100 Subject: [PATCH 11/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ff1d9cab9..873400c9d 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macOS-latest, windows-latest] + os: [ubuntu-latest, windows-latest] python-version: [3.9] requires: ["oldest", "latest"] From 695bb86aac6809ff8e7cb38d542194e7c3f8ea69 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:52:16 +0100 Subject: [PATCH 12/14] update --- .github/workflows/ci-checks.yml | 4 ++-- .github/workflows/ci-testing.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-checks.yml 
b/.github/workflows/ci-checks.yml index c7cfaea25..d0af1db36 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -30,8 +30,8 @@ jobs: artifact-name: dist-packages-${{ github.sha }} testing-matrix: | { - "os": ["ubuntu-latest", "macos-latest", "windows-latest"], - "python-version": ["3.8", "3.10"] + "os": ["ubuntu-latest", "windows-latest"], + "python-version": ["3.10"] } check-docs: diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 873400c9d..45b2cf0fe 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, windows-latest] - python-version: [3.9] + python-version: [3.10] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646 From 505e023b33c84ba962ebb3510abd254a8dd68714 Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:53:40 +0100 Subject: [PATCH 13/14] udpate --- .github/workflows/ci-checks.yml | 4 ++-- .github/workflows/ci-testing.yml | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index d0af1db36..8dd2315a2 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -30,8 +30,8 @@ jobs: artifact-name: dist-packages-${{ github.sha }} testing-matrix: | { - "os": ["ubuntu-latest", "windows-latest"], - "python-version": ["3.10"] + "os": ["ubuntu-latest", "macos-13", "windows-latest"], + "python-version": ["3.8", "3.10"] } check-docs: diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index 45b2cf0fe..c1f7cd091 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, windows-latest] + os: [ubuntu-latest, macos-13, windows-latest] python-version: [3.10] requires: ["oldest", "latest"] From 
5ebdb98ac820d52467e28750536e18f075af7e9a Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 24 Apr 2024 16:54:48 +0100 Subject: [PATCH 14/14] update --- .github/workflows/ci-testing.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index c1f7cd091..b6f0d707b 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -17,7 +17,7 @@ jobs: fail-fast: false matrix: os: [ubuntu-latest, macos-13, windows-latest] - python-version: [3.10] + python-version: [3.9] requires: ["oldest", "latest"] # Timeout: https://stackoverflow.com/a/59076067/4521646