diff --git a/.github/workflows/ci-checks.yml b/.github/workflows/ci-checks.yml index c7cfaea25..8dd2315a2 100644 --- a/.github/workflows/ci-checks.yml +++ b/.github/workflows/ci-checks.yml @@ -30,7 +30,7 @@ jobs: artifact-name: dist-packages-${{ github.sha }} testing-matrix: | { - "os": ["ubuntu-latest", "macos-latest", "windows-latest"], + "os": ["ubuntu-latest", "macos-13", "windows-latest"], "python-version": ["3.8", "3.10"] } diff --git a/.github/workflows/ci-testing.yml b/.github/workflows/ci-testing.yml index ff1d9cab9..b6f0d707b 100644 --- a/.github/workflows/ci-testing.yml +++ b/.github/workflows/ci-testing.yml @@ -16,7 +16,7 @@ jobs: strategy: fail-fast: false matrix: - os: [ubuntu-latest, macOS-latest, windows-latest] + os: [ubuntu-latest, macos-13, windows-latest] python-version: [3.9] requires: ["oldest", "latest"] diff --git a/pyproject.toml b/pyproject.toml index 6b07d5d84..63c61d93e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,7 +72,8 @@ lint.ignore = [ exclude = [ ".git", "docs", - "_notebooks" + "_notebooks", + "src/litdata/utilities/_pytree.py", ] lint.ignore-init-module-imports = true diff --git a/src/litdata/processing/data_processor.py b/src/litdata/processing/data_processor.py index 26d38ad0a..c8a79ca92 100644 --- a/src/litdata/processing/data_processor.py +++ b/src/litdata/processing/data_processor.py @@ -39,7 +39,6 @@ _INDEX_FILENAME, _IS_IN_STUDIO, _LIGHTNING_CLOUD_AVAILABLE, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.imports import RequirementCache from litdata.processing.readers import BaseReader, StreamingDataLoaderReader @@ -49,6 +48,7 @@ from litdata.streaming.client import S3Client from litdata.streaming.dataloader import StreamingDataLoader from litdata.streaming.resolver import _resolve_dir +from litdata.utilities._pytree import tree_flatten, tree_unflatten, treespec_loads from litdata.utilities.broadcast import broadcast_object from litdata.utilities.packing import _pack_greedily @@ -57,9 +57,6 @@ if _TQDM_AVAILABLE: from tqdm.auto import tqdm as _tqdm -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten, tree_unflatten, treespec_loads - if _LIGHTNING_CLOUD_AVAILABLE: from lightning_cloud.openapi import V1DatasetType diff --git a/src/litdata/processing/functions.py b/src/litdata/processing/functions.py index 391a5ea1e..143874251 100644 --- a/src/litdata/processing/functions.py +++ b/src/litdata/processing/functions.py @@ -22,7 +22,7 @@ import torch -from litdata.constants import _IS_IN_STUDIO, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _IS_IN_STUDIO from litdata.processing.data_processor import DataChunkRecipe, DataProcessor, DataTransformRecipe from litdata.processing.readers import BaseReader from litdata.processing.utilities import optimize_dns_context @@ -34,9 +34,7 @@ _execute, _resolve_dir, ) - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten +from litdata.utilities._pytree import tree_flatten def _get_indexed_paths(data: Any) -> Dict[int, str]: diff --git a/src/litdata/streaming/cache.py b/src/litdata/streaming/cache.py index 5d00b97e2..a72abf8e5 100644 --- a/src/litdata/streaming/cache.py +++ b/src/litdata/streaming/cache.py @@ -17,7 +17,6 @@ from litdata.constants import ( _INDEX_FILENAME, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.item_loader import BaseItemLoader from litdata.streaming.reader import BinaryReader @@ -56,9 +55,6 @@ def __init__( """ super().__init__() - if not _TORCH_GREATER_EQUAL_2_1_0: - raise ModuleNotFoundError("PyTorch version 2.1 or higher is required to use the cache.") - input_dir = _resolve_dir(input_dir) self._cache_dir = input_dir.path assert self._cache_dir diff --git a/src/litdata/streaming/config.py b/src/litdata/streaming/config.py index 23402df55..51a0df21e 100644 --- a/src/litdata/streaming/config.py +++ b/src/litdata/streaming/config.py @@ -15,15 +15,13 @@ import os from typing import Any, Dict, List, Optional, Tuple -from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _INDEX_FILENAME from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.downloader import get_downloader_cls from litdata.streaming.item_loader import BaseItemLoader, PyTreeLoader, TokensLoader from litdata.streaming.sampler import ChunkedIndex from litdata.streaming.serializers import Serializer - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_unflatten, treespec_loads +from litdata.utilities._pytree import tree_unflatten, treespec_loads class ChunksConfig: diff --git a/src/litdata/streaming/dataloader.py b/src/litdata/streaming/dataloader.py index 72c360d16..9eabdef34 100644 --- a/src/litdata/streaming/dataloader.py +++ b/src/litdata/streaming/dataloader.py @@ -33,7 +33,7 @@ ) from torch.utils.data.sampler import BatchSampler, Sampler -from litdata.constants import _DEFAULT_CHUNK_BYTES, _TORCH_GREATER_EQUAL_2_1_0, _VIZ_TRACKER_AVAILABLE +from litdata.constants import _DEFAULT_CHUNK_BYTES, _VIZ_TRACKER_AVAILABLE from litdata.streaming import Cache from litdata.streaming.combined import ( __NUM_SAMPLES_YIELDED_KEY__, @@ -42,11 +42,9 @@ ) from litdata.streaming.dataset import StreamingDataset from litdata.streaming.sampler import CacheBatchSampler +from litdata.utilities._pytree import tree_flatten from litdata.utilities.env import _DistributedEnv -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import tree_flatten - logger = logging.Logger(__name__) diff --git a/src/litdata/streaming/item_loader.py b/src/litdata/streaming/item_loader.py index 04b9b2b8c..216f76b17 100644 --- a/src/litdata/streaming/item_loader.py +++ b/src/litdata/streaming/item_loader.py @@ -22,12 +22,9 @@ from litdata.constants import ( _TORCH_DTYPES_MAPPING, - _TORCH_GREATER_EQUAL_2_1_0, ) from litdata.streaming.serializers import Serializer - -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import PyTree, tree_unflatten +from litdata.utilities._pytree import PyTree, tree_unflatten class BaseItemLoader(ABC): diff --git a/src/litdata/streaming/writer.py b/src/litdata/streaming/writer.py index 42d381010..21809a62c 100644 --- a/src/litdata/streaming/writer.py +++ b/src/litdata/streaming/writer.py @@ -21,16 +21,14 @@ import numpy as np import torch -from litdata.constants import _INDEX_FILENAME, _TORCH_GREATER_EQUAL_2_1_0 +from litdata.constants import _INDEX_FILENAME from litdata.processing.utilities import get_worker_rank from litdata.streaming.compression import _COMPRESSORS, Compressor from litdata.streaming.serializers import Serializer, _get_serializers +from litdata.utilities._pytree import PyTree, tree_flatten, treespec_dumps from litdata.utilities.env import _DistributedEnv, _WorkerEnv from litdata.utilities.format import _convert_bytes_to_int, _human_readable_bytes -if _TORCH_GREATER_EQUAL_2_1_0: - from torch.utils._pytree import PyTree, tree_flatten, treespec_dumps - @dataclass class Item: diff --git a/src/litdata/utilities/_pytree.py b/src/litdata/utilities/_pytree.py new file mode 100644 index 000000000..c5cca32a8 --- /dev/null +++ b/src/litdata/utilities/_pytree.py @@ -0,0 +1,1550 @@ +""" +Code taken from https://github.com/pytorch/pytorch/blob/main/torch/utils/_pytree.py + +Contains utility functions for working with nested python data structures. + +A *pytree* is Python nested data structure. It is a tree in the sense that +nodes are Python collections (e.g., list, tuple, dict) and the leaves are +Python values. Furthermore, a pytree should not contain reference cycles. + +pytrees are useful for working with nested collections of Tensors. For example, +one can use `tree_map` to map a function over all Tensors inside some nested +collection of Tensors and `tree_leaves` to get a flat list of all Tensors +inside some nested collection. pytrees are helpful for implementing nested +collection support for PyTorch APIs. + +This pytree implementation is not very performant due to Python overhead +To improve the performance we can move parts of the implementation to C++. +""" + +import dataclasses +import functools +import importlib +import json +import sys +import threading +import types +import warnings +from collections import OrderedDict, defaultdict, deque, namedtuple +from typing import ( + Any, + Callable, + DefaultDict, + Deque, + Dict, + FrozenSet, + Generic, + Hashable, + Iterable, + List, + Mapping, + NamedTuple, + Optional, + Protocol, + Sequence, + Tuple, + Type, + TypeVar, + Union, + cast, + overload, +) +from typing import ( + OrderedDict as GenericOrderedDict, +) + +__all__ = [ + "PyTree", + "Context", + "FlattenFunc", + "UnflattenFunc", + "DumpableContext", + "ToDumpableContextFn", + "FromDumpableContextFn", + "TreeSpec", + "LeafSpec", + "keystr", + "key_get", + "register_pytree_node", + "tree_flatten", + "tree_flatten_with_path", + "tree_unflatten", + "tree_iter", + "tree_leaves", + "tree_leaves_with_path", + "tree_structure", + "tree_map", + "tree_map_with_path", + "tree_map_", + "tree_map_only", + "tree_map_only_", + "tree_all", + "tree_any", + "tree_all_only", + "tree_any_only", + "treespec_dumps", + "treespec_loads", + "treespec_pprint", +] + + +T = TypeVar("T") +S = TypeVar("S") +U = TypeVar("U") +R = TypeVar("R") + + +DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL = 1 +NO_SERIALIZED_TYPE_NAME_FOUND = "NO_SERIALIZED_TYPE_NAME_FOUND" + + +class KeyEntry(Protocol): + def __hash__(self) -> int: ... + + def __eq__(self, other: object) -> bool: ... + + def __str__(self) -> str: ... + + def get(self, parent: Any) -> Any: ... + + +Context = Any +PyTree = Any +FlattenFunc = Callable[[PyTree], Tuple[List[Any], Context]] +UnflattenFunc = Callable[[Iterable[Any], Context], PyTree] +DumpableContext = Any # Any json dumpable text +ToDumpableContextFn = Callable[[Context], DumpableContext] +FromDumpableContextFn = Callable[[DumpableContext], Context] +ToStrFunc = Callable[["TreeSpec", List[str]], str] +MaybeFromStrFunc = Callable[[str], Optional[Tuple[Any, Context, str]]] +KeyPath = Tuple[KeyEntry, ...] +FlattenWithKeysFunc = Callable[[PyTree], Tuple[List[Tuple[KeyEntry, Any]], Any]] + + +# A NodeDef holds two callables: +# - flatten_fn should take the collection and return a flat list of values. +# It can also return some context that is used in reconstructing the +# collection. +# - unflatten_fn should take a flat list of values and some context +# (returned by flatten_fn). It returns the collection by reconstructing +# it from the list and the context. +# - flatten_with_keys_fn, which is a callable that takes a +# pytree and returns a list of (keypath, value) pairs and a context. +class NodeDef(NamedTuple): + type: Type[Any] + flatten_fn: FlattenFunc + unflatten_fn: UnflattenFunc + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] + + +_NODE_REGISTRY_LOCK = threading.Lock() +SUPPORTED_NODES: Dict[Type[Any], NodeDef] = {} + + +# _SerializeNodeDef holds the following: +# - typ: the type of the node (e.g., "Dict", "List", etc) +# - serialized_type_name: the fully qualified name of the type, e.g. "collections.OrderedDict" +# - to_dumpable_context takes a TreeSpec, and returns a serialized string format of the +# context, and the version number +# - from_dumpable_context takes in a string representation of the context, and the +# version, and returns the deserialized context +class _SerializeNodeDef(NamedTuple): + typ: Type[Any] + serialized_type_name: str + to_dumpable_context: Optional[ToDumpableContextFn] + from_dumpable_context: Optional[FromDumpableContextFn] + + +SUPPORTED_SERIALIZED_TYPES: Dict[Type[Any], _SerializeNodeDef] = {} +SERIALIZED_TYPE_TO_PYTHON_TYPE: Dict[str, Type[Any]] = {} + + +def register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + raise ValueError(f"{cls} is already registered as pytree node.") + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + try: + from . import _cxx_pytree as cxx + except ImportError: + pass + else: + cxx._private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + ) + + +def _register_namedtuple( + cls: Type[Any], + *, + serialized_type_name: str, +) -> None: + """Registers a namedtuple as a valid pytree node. By default namedtuples are valid pytree nodes, but they are not + serializable. This API provides the argument `serialized_type_name` which allows these namedtuples to be + serialized. + + Args: + cls: the dataclass type to register + serialized_type_name: The serialized name for the dataclass. This is + required if you want to serialize the pytree TreeSpec containing this + namedtuple. + + """ + _private_register_pytree_node( + cls, + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name=serialized_type_name, + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, + ) + + +def _register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + to_str_fn: Optional[ToStrFunc] = None, # deprecated + maybe_from_str_fn: Optional[MaybeFromStrFunc] = None, # deprecated + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """Register a container-like type as pytree node for the Python pytree only. + + Args: + cls: the type to register + flatten_fn: A callable that takes a pytree and returns a flattened + representation of the pytree and additional context to represent the + flattened pytree. + unflatten_fn: A callable that takes a flattened version of the pytree, + additional context, and returns an unflattened pytree. + serialized_type_name: A keyword argument used to specify the fully qualified + name used when serializing the tree spec. + to_dumpable_context: An optional keyword argument to custom specify how + to convert the context of the pytree to a custom json dumpable + representation. This is used for json serialization, which is being + used in torch.export right now. + from_dumpable_context: An optional keyword argument to custom specify how + to convert the custom json dumpable representation of the context + back to the original context. This is used for json deserialization, + which is being used in torch.export right now. + flatten_with_keys_fn: An optional keyword argument to specify how to + access each pytree leaf's keypath when flattening and tree-mapping. + Like ``flatten_fn``, but in place of a List[leaf], it should return + a List[(keypath, leaf)]. + + """ + warnings.warn( + "torch.utils._pytree._register_pytree_node is deprecated. " + "Please use torch.utils._pytree.register_pytree_node instead.", + stacklevel=2, + ) + + if to_str_fn is not None or maybe_from_str_fn is not None: + warnings.warn( + "to_str_fn and maybe_from_str_fn is deprecated. " + "Please use to_dumpable_context and from_dumpable_context instead." + ) + + _private_register_pytree_node( + cls, + flatten_fn, + unflatten_fn, + serialized_type_name=serialized_type_name, + to_dumpable_context=to_dumpable_context, + from_dumpable_context=from_dumpable_context, + flatten_with_keys_fn=flatten_with_keys_fn, + ) + + +def _private_register_pytree_node( + cls: Type[Any], + flatten_fn: FlattenFunc, + unflatten_fn: UnflattenFunc, + *, + serialized_type_name: Optional[str] = None, + to_dumpable_context: Optional[ToDumpableContextFn] = None, + from_dumpable_context: Optional[FromDumpableContextFn] = None, + flatten_with_keys_fn: Optional[FlattenWithKeysFunc] = None, +) -> None: + """This is an internal function that is used to register a pytree node type for the Python pytree only. + + End-users should use :func:`register_pytree_node` + instead. + + """ + with _NODE_REGISTRY_LOCK: + if cls in SUPPORTED_NODES: + # TODO: change this warning to an error after OSS/internal stabilize + warnings.warn( + f"{cls} is already registered as pytree node. " "Overwriting the previous registration.", + ) + + node_def = NodeDef(cls, flatten_fn, unflatten_fn, flatten_with_keys_fn) + SUPPORTED_NODES[cls] = node_def + + if (to_dumpable_context is None) ^ (from_dumpable_context is None): + raise ValueError( + f"Both to_dumpable_context and from_dumpable_context for {cls} must " "be None or registered." + ) + + if serialized_type_name is None: + serialized_type_name = NO_SERIALIZED_TYPE_NAME_FOUND + + serialize_node_def = _SerializeNodeDef( + cls, + serialized_type_name, + to_dumpable_context, + from_dumpable_context, + ) + SUPPORTED_SERIALIZED_TYPES[cls] = serialize_node_def + SERIALIZED_TYPE_TO_PYTHON_TYPE[serialized_type_name] = cls + + +@dataclasses.dataclass(frozen=True) +class SequenceKey(Generic[T]): + idx: int + + def __str__(self) -> str: + return f"[{self.idx!r}]" + + def get(self, sequence: Sequence[T]) -> T: + return sequence[self.idx] + + +K = TypeVar("K", bound=Hashable) + + +@dataclasses.dataclass(frozen=True) +class MappingKey(Generic[K, T]): + key: K + + def __str__(self) -> str: + return f"[{self.key!r}]" + + def get(self, mapping: Mapping[K, T]) -> T: + return mapping[self.key] + + +@dataclasses.dataclass(frozen=True) +class GetAttrKey: + name: str + + def __str__(self) -> str: + return f".{self.name}" + + def get(self, obj: Any) -> Any: + return getattr(obj, self.name) + + +def _tuple_flatten(d: Tuple[Any, ...]) -> Tuple[List[Any], Context]: + return list(d), None + + +def _tuple_flatten_with_keys(d: Tuple[Any, ...]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _tuple_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _tuple_unflatten(values: Iterable[Any], context: Context) -> Tuple[Any, ...]: + return tuple(values) + + +def _list_flatten(d: List[Any]) -> Tuple[List[Any], Context]: + return d, None + + +def _list_flatten_with_keys(d: List[Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _list_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _list_unflatten(values: Iterable[Any], context: Context) -> List[Any]: + return list(values) + + +def _dict_flatten(d: Dict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _dict_flatten_with_keys(d: Dict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _dict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _dict_unflatten(values: Iterable[Any], context: Context) -> Dict[Any, Any]: + return dict(zip(context, values)) + + +def _namedtuple_flatten(d: NamedTuple) -> Tuple[List[Any], Context]: + return list(d), type(d) + + +def _namedtuple_flatten_with_keys( + d: NamedTuple, +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _namedtuple_flatten(d) + return ( + [(GetAttrKey(field), v) for field, v in zip(context._fields, values)], + context, + ) + + +def _namedtuple_unflatten(values: Iterable[Any], context: Context) -> NamedTuple: + return cast(NamedTuple, context(*values)) + + +def _namedtuple_serialize(context: Context) -> DumpableContext: + if context not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "didn't register a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[context] + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"Can't serialize TreeSpec of namedtuple class {context} because we " + "couldn't find a serializated_type_name. Please register using " + "`_register_namedtuple`." + ) + return serialized_type_name + + +def _namedtuple_deserialize(dumpable_context: DumpableContext) -> Context: + if dumpable_context not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f"Can't deserialize TreeSpec of namedtuple class {dumpable_context} " + "because we couldn't find a serializated name." + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[dumpable_context] + return typ + + +def _ordereddict_flatten(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Any], Context]: + return list(d.values()), list(d.keys()) + + +def _ordereddict_flatten_with_keys(d: GenericOrderedDict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _ordereddict_flatten(d) + return [(MappingKey(k), v) for k, v in zip(context, values)], context + + +def _ordereddict_unflatten( + values: Iterable[Any], + context: Context, +) -> GenericOrderedDict[Any, Any]: + return OrderedDict((key, value) for key, value in zip(context, values)) + + +_odict_flatten = _ordereddict_flatten +_odict_unflatten = _ordereddict_unflatten + + +def _defaultdict_flatten(d: DefaultDict[Any, Any]) -> Tuple[List[Any], Context]: + values, dict_context = _dict_flatten(d) + return values, [d.default_factory, dict_context] + + +def _defaultdict_flatten_with_keys(d: DefaultDict[Any, Any]) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _defaultdict_flatten(d) + _, dict_context = context + return [(MappingKey(k), v) for k, v in zip(dict_context, values)], context + + +def _defaultdict_unflatten( + values: Iterable[Any], + context: Context, +) -> DefaultDict[Any, Any]: + default_factory, dict_context = context + return defaultdict(default_factory, _dict_unflatten(values, dict_context)) + + +def _defaultdict_serialize(context: Context) -> DumpableContext: + default_factory, dict_context = context + json_defaultdict = { + "default_factory_module": default_factory.__module__, + "default_factory_name": default_factory.__qualname__, + "dict_context": dict_context, + } + return json_defaultdict + + +def _defaultdict_deserialize(dumpable_context: DumpableContext) -> Context: + assert isinstance(dumpable_context, dict) + assert set(dumpable_context) == { + "default_factory_module", + "default_factory_name", + "dict_context", + } + + default_factory_module = dumpable_context["default_factory_module"] + default_factory_name = dumpable_context["default_factory_name"] + assert isinstance(default_factory_module, str) + assert isinstance(default_factory_name, str) + module = importlib.import_module(default_factory_module) + default_factory = getattr(module, default_factory_name) + + dict_context = dumpable_context["dict_context"] + return [default_factory, dict_context] + + +def _deque_flatten(d: Deque[Any]) -> Tuple[List[Any], Context]: + return list(d), d.maxlen + + +def _deque_flatten_with_keys( + d: Deque[Any], +) -> Tuple[List[Tuple[KeyEntry, Any]], Context]: + values, context = _deque_flatten(d) + return [(SequenceKey(i), v) for i, v in enumerate(values)], context + + +def _deque_unflatten(values: Iterable[Any], context: Context) -> Deque[Any]: + return deque(values, maxlen=context) + + +_private_register_pytree_node( + tuple, + _tuple_flatten, + _tuple_unflatten, + serialized_type_name="builtins.tuple", + flatten_with_keys_fn=_tuple_flatten_with_keys, +) +_private_register_pytree_node( + list, + _list_flatten, + _list_unflatten, + serialized_type_name="builtins.list", + flatten_with_keys_fn=_list_flatten_with_keys, +) +_private_register_pytree_node( + dict, + _dict_flatten, + _dict_unflatten, + serialized_type_name="builtins.dict", + flatten_with_keys_fn=_dict_flatten_with_keys, +) +_private_register_pytree_node( + namedtuple, # type: ignore[arg-type] + _namedtuple_flatten, + _namedtuple_unflatten, + serialized_type_name="collections.namedtuple", + to_dumpable_context=_namedtuple_serialize, + from_dumpable_context=_namedtuple_deserialize, + flatten_with_keys_fn=_namedtuple_flatten_with_keys, +) +_private_register_pytree_node( + OrderedDict, + _ordereddict_flatten, + _ordereddict_unflatten, + serialized_type_name="collections.OrderedDict", + flatten_with_keys_fn=_ordereddict_flatten_with_keys, +) +_private_register_pytree_node( + defaultdict, + _defaultdict_flatten, + _defaultdict_unflatten, + serialized_type_name="collections.defaultdict", + to_dumpable_context=_defaultdict_serialize, + from_dumpable_context=_defaultdict_deserialize, + flatten_with_keys_fn=_defaultdict_flatten_with_keys, +) +_private_register_pytree_node( + deque, + _deque_flatten, + _deque_unflatten, + serialized_type_name="collections.deque", + flatten_with_keys_fn=_deque_flatten_with_keys, +) + + +STANDARD_DICT_TYPES: FrozenSet[type] = frozenset( + {dict, OrderedDict, defaultdict}, +) +BUILTIN_TYPES: FrozenSet[type] = frozenset( + {tuple, list, dict, namedtuple, OrderedDict, defaultdict, deque}, # type: ignore[arg-type] +) + + +# h/t https://stackoverflow.com/questions/2166818/how-to-check-if-an-object-is-an-instance-of-a-namedtuple +def _is_namedtuple_instance(tree: Any) -> bool: + typ = type(tree) + bases = typ.__bases__ + if len(bases) != 1 or bases[0] != tuple: + return False + fields = getattr(typ, "_fields", None) + if not isinstance(fields, tuple): + return False + return all(type(entry) == str for entry in fields) + + +def _get_node_type(tree: Any) -> Any: + if _is_namedtuple_instance(tree): + return namedtuple + return type(tree) + + +# A leaf is defined as anything that is not a Node. +def _is_leaf(tree: PyTree, is_leaf: Optional[Callable[[PyTree], bool]] = None) -> bool: + return (is_leaf is not None and is_leaf(tree)) or _get_node_type(tree) not in SUPPORTED_NODES + + +# A TreeSpec represents the structure of a pytree. It holds: +# "type": the type of root Node of the pytree +# context: some context that is useful in unflattening the pytree +# children_specs: specs for each child of the root Node +# num_leaves: the number of leaves +@dataclasses.dataclass +class TreeSpec: + type: Any + context: Context + children_specs: List["TreeSpec"] + + num_nodes: int = dataclasses.field(init=False) + num_leaves: int = dataclasses.field(init=False) + num_children: int = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.num_nodes = 1 + sum(spec.num_nodes for spec in self.children_specs) + self.num_leaves = sum(spec.num_leaves for spec in self.children_specs) + self.num_children = len(self.children_specs) + + def __repr__(self, indent: int = 0) -> str: + repr_prefix: str = f"TreeSpec({self.type.__name__}, {self.context}, [" + children_specs_str: str = "" + if self.num_children > 0: + indent += 2 + children_specs_str += self.children_specs[0].__repr__(indent) + children_specs_str += "," if self.num_children > 1 else "" + children_specs_str += ",".join( + ["\n" + " " * indent + child.__repr__(indent) for child in self.children_specs[1:]] + ) + repr_suffix: str = f"{children_specs_str}])" + return repr_prefix + repr_suffix + + def is_leaf(self) -> bool: + return self.num_nodes == 1 and self.num_leaves == 1 + + def _flatten_up_to_helper(self, tree: PyTree, subtrees: List[PyTree]) -> None: + if self.is_leaf(): + subtrees.append(tree) + return + + node_type = _get_node_type(tree) + if self.type not in BUILTIN_TYPES: + # Always require custom node types to match exactly + if node_type != self.type: + raise ValueError( + f"Type mismatch; " f"expected {self.type!r}, but got {node_type!r}.", + ) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if len(child_pytrees) != self.num_children: + raise ValueError( + f"Node arity mismatch; " f"expected {self.num_children}, but got {len(child_pytrees)}.", + ) + if context != self.context: + raise ValueError( + f"Node context mismatch for custom node type {self.type!r}.", + ) + else: + # For builtin dictionary types, we allow some flexibility + # Otherwise, we require exact matches + both_standard_dict = self.type in STANDARD_DICT_TYPES and node_type in STANDARD_DICT_TYPES + if node_type != self.type and not both_standard_dict: + raise ValueError( + f"Node type mismatch; " f"expected {self.type!r}, but got {node_type!r}.", + ) + if len(tree) != self.num_children: + raise ValueError( + f"Node arity mismatch; " f"expected {self.num_children}, but got {len(tree)}.", + ) + + if both_standard_dict: # dictionary types are compatible with each other + dict_context = ( + self.context + if self.type is not defaultdict + # ignore mismatch of `default_factory` for defaultdict + else self.context[1] + ) + expected_keys = dict_context + got_key_set = set(tree) + expected_key_set = set(expected_keys) + if got_key_set != expected_key_set: + missing_keys = expected_key_set.difference(got_key_set) + extra_keys = got_key_set.difference(expected_key_set) + message = "" + if missing_keys: + message += f"; missing key(s): {missing_keys}" + if extra_keys: + message += f"; extra key(s): {extra_keys}" + raise ValueError(f"Node keys mismatch{message}.") + child_pytrees = [tree[key] for key in expected_keys] + else: + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + if ( + context != self.context and self.type is not deque # ignore mismatch of `maxlen` for deque + ): + raise ValueError( + f"Node context mismatch for node type {self.type!r}; " + f"expected {self.context!r}, but got {context!r}.", # namedtuple type mismatch + ) + + for child_pytree, child_spec in zip(child_pytrees, self.children_specs): + child_spec._flatten_up_to_helper(child_pytree, subtrees) + + def flatten_up_to(self, tree: PyTree) -> List[PyTree]: + subtrees: List[PyTree] = [] + self._flatten_up_to_helper(tree, subtrees) + return subtrees + + def unflatten(self, leaves: Iterable[Any]) -> PyTree: + if not isinstance(leaves, (list, tuple)): + leaves = list(leaves) + if len(leaves) != self.num_leaves: + raise ValueError( + f"treespec.unflatten(leaves): `leaves` has length {len(leaves)} " + f"but the spec refers to a pytree that holds {self.num_leaves} " + f"items ({self}).", + ) + if self.is_leaf(): + return leaves[0] + + unflatten_fn = SUPPORTED_NODES[self.type].unflatten_fn + + # Recursively unflatten the children + start = 0 + end = 0 + child_pytrees = [] + for child_spec in self.children_specs: + end += child_spec.num_leaves + child_pytrees.append(child_spec.unflatten(leaves[start:end])) + start = end + + return unflatten_fn(child_pytrees, self.context) + + +class LeafSpec(TreeSpec): + def __init__(self) -> None: + super().__init__(None, None, []) + + def __post_init__(self) -> None: + self.num_nodes = 1 + self.num_leaves = 1 + self.num_children = 0 + + def __repr__(self, indent: int = 0) -> str: + return "*" + + +# All leaves are equivalent, so represent with a single object to save on +# object construction time +_LEAF_SPEC = LeafSpec() + + +def _tree_flatten_helper( + tree: PyTree, + leaves: List[Any], + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + if _is_leaf(tree, is_leaf=is_leaf): + leaves.append(tree) + return _LEAF_SPEC + + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, context = flatten_fn(tree) + + # Recursively flatten the children + children_specs = [_tree_flatten_helper(child, leaves, is_leaf=is_leaf) for child in child_pytrees] + + return TreeSpec(node_type, context, children_specs) + + +def tree_flatten( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Any], TreeSpec]: + """Flattens a pytree into a list of values and a TreeSpec that can be used to reconstruct the pytree.""" + leaves: List[Any] = [] + spec = _tree_flatten_helper(tree, leaves, is_leaf=is_leaf) + return leaves, spec + + +def tree_unflatten(leaves: Iterable[Any], treespec: TreeSpec) -> PyTree: + """Given a list of values and a TreeSpec, builds a pytree. + + This is the inverse operation of `tree_flatten`. + + """ + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"tree_unflatten(leaves, treespec): Expected `treespec` to be " + f"instance of TreeSpec but got item of type {type(treespec)}.", + ) + return treespec.unflatten(leaves) + + +def tree_iter( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Any]: + """Get an iterator over the leaves of a pytree.""" + if _is_leaf(tree, is_leaf=is_leaf): + yield tree + else: + node_type = _get_node_type(tree) + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, _ = flatten_fn(tree) + + # Recursively flatten the children + for child in child_pytrees: + yield from tree_iter(child, is_leaf=is_leaf) + + +def tree_leaves( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Any]: + """Get a list of leaves of a pytree.""" + return list(tree_iter(tree, is_leaf=is_leaf)) + + +def tree_structure( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> TreeSpec: + """Get the TreeSpec for a pytree.""" + return tree_flatten(tree, is_leaf=is_leaf)[1] + + +def tree_map( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Map a multi-input function over pytree args to produce a new pytree. + + See also :func:`tree_map_`. + + >>> tree_map(lambda x: x + 1, {'x': 7, 'y': (42, 64)}) + {'x': 8, 'y': (43, 65)} + >>> tree_map(lambda x: x is None, {'x': 7, 'y': (42, 64), 'z': None}) + {'x': False, 'y': (False, False), 'z': True} + + If multiple inputs are given, the structure of the tree is taken from the first input; + subsequent inputs need only have ``tree`` as a prefix: + + >>> tree_map(lambda x, y: [x] + y, [5, 6], [[7, 9], [1, 2]]) + [[5, 7, 9], [6, 1, 2]] + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(x, *xs)`` where ``x`` is the value at the corresponding leaf in ``tree`` and ``xs`` + is the tuple of values at corresponding nodes in ``rests``. + + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(map(func, *flat_args)) + + +def tree_map_( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but do an inplace call on each leaf and return the original tree. + + See also :func:`tree_map`. + + Args: + func (callable): A function that takes ``1 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. + tree (pytree): A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests (tuple of pytree): A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf (callable, optional): An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns: + The original ``tree`` with the value at each leaf is given by the side-effect of function + ``func(x, *xs)`` (not the return value) where ``x`` is the value at the corresponding leaf + in ``tree`` and ``xs`` is the tuple of values at values at corresponding nodes in ``rests``. + + """ + leaves, treespec = tree_flatten(tree, is_leaf=is_leaf) + flat_args = [leaves] + [treespec.flatten_up_to(r) for r in rests] + tuple(map(func, *flat_args)) # consume and exhaust the iterable + return tree + + +Type2 = Tuple[Type[T], Type[S]] +Type3 = Tuple[Type[T], Type[S], Type[U]] +if sys.version_info >= (3, 10): + TypeAny = Union[Type[Any], Tuple[Type[Any], ...], types.UnionType] +else: + TypeAny = Union[Type[Any], Tuple[Type[Any], ...]] + +Fn2 = Callable[[Union[T, S]], R] +Fn3 = Callable[[Union[T, S, U]], R] +Fn = Callable[[T], R] +FnAny = Callable[[Any], R] + +MapOnlyFn = Callable[[T], Callable[[Any], Any]] + + +# These specializations help with type inference on the lambda passed to this +# function +@overload +def map_only(__type_or_types_or_pred: Type2[T, S]) -> MapOnlyFn[Fn2[T, S, Any]]: ... + + +@overload +def map_only(__type_or_types_or_pred: Type3[T, S, U]) -> MapOnlyFn[Fn3[T, S, U, Any]]: ... + + +@overload +def map_only(__type_or_types_or_pred: Type[T]) -> MapOnlyFn[Fn[T, Any]]: ... + + +# This specialization is needed for the implementations below that call +@overload +def map_only(__type_or_types_or_pred: TypeAny) -> MapOnlyFn[FnAny[Any]]: ... + + +@overload +def map_only(__type_or_types_or_pred: Callable[[Any], bool]) -> MapOnlyFn[FnAny[Any]]: ... + + +def map_only(__type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]]) -> MapOnlyFn[FnAny[Any]]: + """Suppose you are writing a tree_map over tensors, leaving everything else unchanged. Ordinarily you would have + to write: + + def go(t): + if isinstance(t, Tensor): + return ... + else: + return t + + With this function, you only need to write: + + @map_only(Tensor) + def go(t): + return ... + + You can also directly use 'tree_map_only' + + """ + if isinstance(__type_or_types_or_pred, (type, tuple)) or ( + sys.version_info >= (3, 10) and isinstance(__type_or_types_or_pred, types.UnionType) + ): + + def pred(x: Any) -> bool: + return isinstance(x, __type_or_types_or_pred) # type: ignore[arg-type] + + elif callable(__type_or_types_or_pred): + pred = __type_or_types_or_pred # type: ignore[assignment] + else: + raise TypeError("Argument must be a type, a tuple of types, or a callable.") + + def wrapper(func: Callable[[T], Any]) -> Callable[[Any], Any]: + @functools.wraps(func) + def wrapped(x: T) -> Any: + if pred(x): + return func(x) + return x + + return wrapped + + return wrapper + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +def tree_map_only( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type[T], + func: Fn[T, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type2[T, S], + func: Fn2[T, S, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Type3[T, S, U], + func: Fn3[T, S, U, Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +@overload +def tree_map_only_( + __type_or_types_or_pred: Callable[[Any], bool], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: ... + + +def tree_map_only_( + __type_or_types_or_pred: Union[TypeAny, Callable[[Any], bool]], + func: FnAny[Any], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + return tree_map_(map_only(__type_or_types_or_pred)(func), tree, is_leaf=is_leaf) + + +def tree_all( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(map(pred, flat_args)) + + +def tree_any( + pred: Callable[[Any], bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(map(pred, flat_args)) + + +@overload +def tree_all_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +@overload +def tree_all_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +@overload +def tree_all_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +def tree_all_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return all(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +@overload +def tree_any_only( + __type_or_types: Type[T], + pred: Fn[T, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +@overload +def tree_any_only( + __type_or_types: Type2[T, S], + pred: Fn2[T, S, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +@overload +def tree_any_only( + __type_or_types: Type3[T, S, U], + pred: Fn3[T, S, U, bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: ... + + +def tree_any_only( + __type_or_types: TypeAny, + pred: FnAny[bool], + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> bool: + flat_args = tree_iter(tree, is_leaf=is_leaf) + return any(pred(x) for x in flat_args if isinstance(x, __type_or_types)) + + +# Broadcasts a pytree to the provided TreeSpec and returns the flattened +# values. If this is not possible, then this function returns None. +# +# For example, given pytree=0 and spec=TreeSpec(list, None, [LeafSpec(), LeafSpec()]), +# would return [0, 0]. This is useful for part of the vmap implementation: +# a user can pass in vmap(fn, in_dims)(*inputs). `in_dims` should be +# broadcastable to the tree structure of `inputs` and we use +# _broadcast_to_and_flatten to check this. +def _broadcast_to_and_flatten( + tree: PyTree, + treespec: TreeSpec, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Optional[List[Any]]: + assert isinstance(treespec, TreeSpec) + + if _is_leaf(tree, is_leaf=is_leaf): + return [tree] * treespec.num_leaves + if treespec.is_leaf(): + return None + node_type = _get_node_type(tree) + if node_type != treespec.type: + return None + + flatten_fn = SUPPORTED_NODES[node_type].flatten_fn + child_pytrees, ctx = flatten_fn(tree) + + # Check if the Node is different from the spec + if len(child_pytrees) != treespec.num_children or ctx != treespec.context: + return None + + # Recursively flatten the children + result: List[Any] = [] + for child, child_spec in zip(child_pytrees, treespec.children_specs): + flat = _broadcast_to_and_flatten(child, child_spec, is_leaf=is_leaf) + if flat is not None: + result += flat + else: + return None + + return result + + +@dataclasses.dataclass +class _TreeSpecSchema: + """_TreeSpecSchema is the schema used to serialize the TreeSpec. + + It contains the following fields: + - type: A string name of the type. null for the case of a LeafSpec. + - context: Any format which is json dumpable + - children_spec: A list of children serialized specs. + + """ + + type: Optional[str] + context: DumpableContext + children_spec: List["_TreeSpecSchema"] + + +class _ProtocolFn(NamedTuple): + treespec_to_json: Callable[[TreeSpec], DumpableContext] + json_to_treespec: Callable[[DumpableContext], TreeSpec] + + +_SUPPORTED_PROTOCOLS: Dict[int, _ProtocolFn] = {} + + +def _treespec_to_json(treespec: TreeSpec) -> _TreeSpecSchema: + if treespec.is_leaf(): + return _TreeSpecSchema(None, None, []) + + if treespec.type not in SUPPORTED_SERIALIZED_TYPES: + raise NotImplementedError( + f"Serializing {treespec.type} in pytree is not registered.", + ) + + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[treespec.type] + + serialized_type_name = serialize_node_def.serialized_type_name + + if serialized_type_name == NO_SERIALIZED_TYPE_NAME_FOUND: + raise NotImplementedError( + f"No registered serialization name for {treespec.type} found. " + "Please update your _register_pytree_node call with a `serialized_type_name` kwarg." + ) + + if serialize_node_def.to_dumpable_context is None: + try: + serialized_context = json.dumps(treespec.context) + except TypeError as e: + raise TypeError( + "Unable to serialize context. " + "Please make the context json dump-able, or register a " + "custom serializer using _register_pytree_node." + ) from e + else: + serialized_context = serialize_node_def.to_dumpable_context(treespec.context) + + child_schemas = [_treespec_to_json(child) for child in treespec.children_specs] + + return _TreeSpecSchema(serialized_type_name, serialized_context, child_schemas) + + +def _json_to_treespec(json_schema: DumpableContext) -> TreeSpec: + if json_schema["type"] is None and json_schema["context"] is None and len(json_schema["children_spec"]) == 0: + return _LEAF_SPEC + + if json_schema["type"] not in SERIALIZED_TYPE_TO_PYTHON_TYPE: + raise NotImplementedError( + f'Deserializing {json_schema["type"]} in pytree is not registered.', + ) + + typ = SERIALIZED_TYPE_TO_PYTHON_TYPE[json_schema["type"]] + serialize_node_def = SUPPORTED_SERIALIZED_TYPES[typ] + + if serialize_node_def.from_dumpable_context is None: + try: + context = json.loads(json_schema["context"]) + except TypeError as ex: + raise TypeError( + "Unable to deserialize context. " + "Please make the context json load-able, or register a " + "custom serializer using _register_pytree_node.", + ) from ex + else: + context = serialize_node_def.from_dumpable_context(json_schema["context"]) + + children_specs = [] + for child_string in json_schema["children_spec"]: + children_specs.append(_json_to_treespec(child_string)) + + return TreeSpec(typ, context, children_specs) + + +_SUPPORTED_PROTOCOLS[1] = _ProtocolFn(_treespec_to_json, _json_to_treespec) + + +def treespec_dumps(treespec: TreeSpec, protocol: Optional[int] = None) -> str: + if not isinstance(treespec, TreeSpec): + raise TypeError( + f"treespec_dumps(treespec, protocol): Expected `treespec` to be instance of " + f"TreeSpec but got item of type {type(treespec)}.", + ) + + if protocol is None: + protocol = DEFAULT_TREESPEC_SERIALIZATION_PROTOCOL + + if protocol in _SUPPORTED_PROTOCOLS: + json_spec = _SUPPORTED_PROTOCOLS[protocol].treespec_to_json(treespec) + else: + raise ValueError( + f"Unknown protocol {protocol}. " f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + str_spec = json.dumps((protocol, dataclasses.asdict(json_spec))) + return str_spec + + +def treespec_loads(serialized: str) -> TreeSpec: + protocol, json_schema = json.loads(serialized) + + if protocol in _SUPPORTED_PROTOCOLS: + return _SUPPORTED_PROTOCOLS[protocol].json_to_treespec(json_schema) + raise ValueError( + f"Unknown protocol {protocol}. " f"Available protocols: {list(_SUPPORTED_PROTOCOLS.keys())}", + ) + + +class _DummyLeaf: + def __repr__(self) -> str: + return "*" + + +def treespec_pprint(treespec: TreeSpec) -> str: + dummy_tree = tree_unflatten( + [_DummyLeaf() for _ in range(treespec.num_leaves)], + treespec, + ) + return repr(dummy_tree) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def pytree_to_str(treespec: TreeSpec) -> str: + warnings.warn("pytree_to_str is deprecated. Please use treespec_dumps") + return treespec_dumps(treespec) + + +# TODO(angelayi): remove this function after OSS/internal stabilize +def str_to_pytree(json: str) -> TreeSpec: + warnings.warn("str_to_pytree is deprecated. Please use treespec_loads") + return treespec_loads(json) + + +def arg_tree_leaves(*args: PyTree, **kwargs: PyTree) -> List[Any]: + """Get a flat list of arguments to this function. + + A slightly faster version of tree_leaves((args, kwargs)) + + """ + leaves: List[Any] = [] + for a in args: + leaves.extend(tree_iter(a)) + for a in kwargs.values(): + leaves.extend(tree_iter(a)) + return leaves + + +def tree_flatten_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Tuple[List[Tuple[KeyPath, Any]], TreeSpec]: + """Flattens a pytree like :func:`tree_flatten`, but also returns each leaf's key path. + + Args: + tree: a pytree to flatten. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A tuple where the first element is a list of (key path, leaf) pairs, and the + second element is a :class:`TreeSpec` representing the structure of the flattened + tree. + + """ + _, treespec = tree_flatten(tree, is_leaf) + return list(_generate_key_paths((), tree, is_leaf)), treespec + + +def tree_leaves_with_path( + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> List[Tuple[KeyPath, Any]]: + """Gets the leaves of a pytree like ``tree_leaves`` and returns each leaf's key path. + + Args: + tree: a pytree. If it contains a custom type, that type must be + registered with an appropriate `tree_flatten_with_path_fn` when registered + with :func:`register_pytree_node`. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + Returns: + A list of (key path, leaf) pairs. + + """ + return list(_generate_key_paths((), tree, is_leaf)) + + +def _generate_key_paths( + key_path: KeyPath, + tree: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> Iterable[Tuple[KeyPath, Any]]: + if is_leaf and is_leaf(tree): + yield key_path, tree + return + + node_type = _get_node_type(tree) + handler = SUPPORTED_NODES.get(node_type) + if not handler: + # This is a leaf + yield key_path, tree + return + + flatten_with_keys = handler.flatten_with_keys_fn + if flatten_with_keys: + key_children, _ = flatten_with_keys(tree) + for k, c in key_children: + yield from _generate_key_paths((*key_path, k), c, is_leaf) + else: + # We registered this pytree but didn't add a flatten_with_keys_fn, complain. + raise ValueError( + f"Did not find a flatten_with_keys_fn for type: {node_type}. " + "Please pass a flatten_with_keys_fn argument to register_pytree_node." + ) + + +def tree_map_with_path( + func: Callable[..., Any], + tree: PyTree, + *rests: PyTree, + is_leaf: Optional[Callable[[PyTree], bool]] = None, +) -> PyTree: + """Like :func:`tree_map`, but the provided callable takes an additional key path argument. + + Args: + func: A function that takes ``2 + len(rests)`` arguments, to be applied at the + corresponding leaves of the pytrees. The first positional argument + to ``func`` is the key path of the leaf in question. The second + positional argument is the value of the leaf. + tree: A pytree to be mapped over, with each leaf providing the first positional + argument to function ``func``. + rests: A tuple of pytrees, each of which has the same structure as + ``tree`` or has ``tree`` as a prefix. + is_leaf: An extra leaf predicate function that will be called at each + flattening step. The function should have a single argument with signature + ``is_leaf(node) -> bool``. If it returns :data:`True`, the whole subtree being treated + as a leaf. Otherwise, the default pytree registry will be used to determine a node is a + leaf or not. If the function is not specified, the default pytree registry will be used. + + Returns + A new pytree with the same structure as ``tree`` but with the value at each leaf given by + ``func(keypath, x, *xs)`` where ``keypath`` is the key path at the + corresponding leaf in ``tree``, ``x`` is the value at that leaf, and + ``xs`` is the tuple of values at corresponding nodes in ``rests``. + + """ + keypath_leaves, treespec = tree_flatten_with_path(tree, is_leaf) + keypath_leaves = list(zip(*keypath_leaves)) + all_keypath_leaves = keypath_leaves + [treespec.flatten_up_to(r) for r in rests] + return treespec.unflatten(func(*xs) for xs in zip(*all_keypath_leaves)) + + +def keystr(kp: KeyPath) -> str: + """Given a key path, return a pretty-printed representation.""" + return "".join([str(k) for k in kp]) + + +def key_get(obj: Any, kp: KeyPath) -> Any: + """Given an object and a key path, return the value at the key path.""" + for k in kp: + obj = k.get(obj) + return obj diff --git a/status.json b/status.json new file mode 100644 index 000000000..72aa54955 --- /dev/null +++ b/status.json @@ -0,0 +1 @@ +{ "progress": "20.0%" }