Skip to content

Commit

Permalink
Try #979:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] committed Oct 24, 2022
2 parents 6127440 + c0e0026 commit b206825
Show file tree
Hide file tree
Showing 58 changed files with 514 additions and 384 deletions.
1 change: 1 addition & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -140,6 +140,7 @@ repos:
src/gt4py/frontend/nodes.py |
src/gt4py/frontend/node_util.py |
src/gt4py/frontend/gtscript_frontend.py |
src/gt4py/frontend/defir_to_gtir.py |
src/gt4py/utils/meta.py |
tests/definitions.py |
tests/definition_setup.py |
Expand Down
5 changes: 3 additions & 2 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ cuda116 =
cuda117 =
cupy-cuda117
dace =
dace~=0.14
dace>=0.14.1,<0.15
sympy
format =
clang-format>=9.0
Expand Down Expand Up @@ -209,7 +209,8 @@ show_error_codes = True
allow_untyped_defs = False

[mypy-gtc.*]
allow_untyped_defs = False
# TODO: Make this False and fix errors
allow_untyped_defs = True


#-- pytest --
Expand Down
13 changes: 0 additions & 13 deletions src/__init__.py

This file was deleted.

160 changes: 154 additions & 6 deletions src/eve/concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,26 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

"""Definitions of basic Eve concepts."""
"""Definitions of basic Eve."""


from __future__ import annotations

import collections.abc
import functools

import pydantic
import pydantic.generics

from . import iterators, utils
from .type_definitions import NOTHING, IntEnum, Str, StrEnum
from . import utils
from .type_definitions import NOTHING, Enum, IntEnum, Str, StrEnum
from .typingx import (
Any,
AnyNoArgCallable,
ClassVar,
Dict,
Generator,
Iterable,
List,
Optional,
Set,
Expand All @@ -41,6 +43,14 @@
)


try:
# For perfomance reasons, try to use cytoolz when possible (using cython)
import cytoolz as toolz
except ModuleNotFoundError:
# Fall back to pure Python toolz
import toolz # noqa: F401 # imported but unused


# -- Fields --
class ImplFieldMetadataDict(TypedDict, total=False):
info: pydantic.fields.FieldInfo
Expand Down Expand Up @@ -153,6 +163,144 @@ def __new__(mcls, name, bases, namespace, **kwargs):
return cls


KeyValue = Tuple[Union[int, str], TreeNode]
TreeIterationItem = Union[TreeNode, Tuple[KeyValue, TreeNode]]


def generic_iter_children(
node: TreeNode, *, with_keys: bool = False
) -> Iterable[Union[TreeNode, Tuple[KeyValue, TreeNode]]]:
"""Create an iterator to traverse values as Eve tree nodes.
Args:
with_keys: Return tuples of (key, object) values where keys are
the reference to the object node in the parent.
Defaults to `False`.
"""
if isinstance(node, BaseNode):
return node.iter_children() if with_keys else node.iter_children_values()
elif isinstance(node, (list, tuple)) or (
isinstance(node, collections.abc.Sequence) and not isinstance(node, (str, bytes))
):
return enumerate(node) if with_keys else iter(node)
elif isinstance(node, (set, collections.abc.Set)):
return zip(node, node) if with_keys else iter(node) # type: ignore # problems with iter(Set)
elif isinstance(node, (dict, collections.abc.Mapping)):
return node.items() if with_keys else node.values()

return iter(())


class TraversalOrder(Enum):
PRE_ORDER = "pre"
POST_ORDER = "post"
LEVELS_ORDER = "levels"


def _iter_tree_pre(
node: TreeNode, *, with_keys: bool = False, __key__: Optional[Any] = None
) -> Generator[TreeIterationItem, None, None]:
"""Create a pre-order tree traversal iterator (Depth-First Search).
Args:
with_keys: Return tuples of (key, object) values where keys are
the reference to the object node in the parent.
Defaults to `False`.
"""
if with_keys:
yield __key__, node
for key, child in generic_iter_children(node, with_keys=True):
yield from _iter_tree_pre(child, with_keys=True, __key__=key)
else:
yield node
for child in generic_iter_children(node, with_keys=False):
yield from _iter_tree_pre(child, with_keys=False)


def _iter_tree_post(
node: TreeNode, *, with_keys: bool = False, __key__: Optional[Any] = None
) -> Generator[TreeIterationItem, None, None]:
"""Create a post-order tree traversal iterator (Depth-First Search).
Args:
with_keys: Return tuples of (key, object) values where keys are
the reference to the object node in the parent.
Defaults to `False`.
"""
if with_keys:
for key, child in generic_iter_children(node, with_keys=True):
yield from _iter_tree_post(child, with_keys=True, __key__=key)
yield __key__, node
else:
for child in generic_iter_children(node, with_keys=False):
yield from _iter_tree_post(child, with_keys=False)
yield node


def _iter_tree_levels(
node: TreeNode,
*,
with_keys: bool = False,
__key__: Optional[Any] = None,
__queue__: Optional[List] = None,
) -> Generator[TreeIterationItem, None, None]:
"""Create a tree traversal iterator by levels (Breadth-First Search).
Args:
with_keys: Return tuples of (key, object) values where keys are
the reference to the object node in the parent.
Defaults to `False`.
"""
__queue__ = __queue__ or []
if with_keys:
yield __key__, node
__queue__.extend(generic_iter_children(node, with_keys=True))
if __queue__:
key, child = __queue__.pop(0)
yield from _iter_tree_levels(child, with_keys=True, __key__=key, __queue__=__queue__)
else:
yield node
__queue__.extend(generic_iter_children(node, with_keys=False))
if __queue__:
child = __queue__.pop(0)
yield from _iter_tree_levels(child, with_keys=False, __queue__=__queue__)


iter_tree_pre = utils.as_xiter(_iter_tree_pre)
iter_tree_post = utils.as_xiter(_iter_tree_post)
iter_tree_levels = utils.as_xiter(_iter_tree_levels)


def iter_tree(
node: TreeNode,
traversal_order: TraversalOrder = TraversalOrder.PRE_ORDER,
*,
with_keys: bool = False,
) -> utils.XIterable[TreeIterationItem]:
"""Create a tree traversal iterator.
Args:
traversal_order: Tree nodes traversal order.
with_keys: Return tuples of (key, object) values where keys are
the reference to the object node in the parent.
Defaults to `False`.
"""
if traversal_order is traversal_order.PRE_ORDER:
return iter_tree_pre(node=node, with_keys=with_keys)
elif traversal_order is traversal_order.POST_ORDER:
return iter_tree_post(node=node, with_keys=with_keys)
elif traversal_order is traversal_order.LEVELS_ORDER:
return iter_tree_levels(node=node, with_keys=with_keys)
else:
raise ValueError(f"Invalid '{traversal_order}' traversal order.")


class BaseNode(pydantic.BaseModel, metaclass=NodeMetaclass):
"""Base class representing an IR node.
Expand Down Expand Up @@ -197,13 +345,13 @@ def iter_children_values(self) -> Generator[Any, None, None]:
yield getattr(self, name)

def iter_tree_pre(self) -> utils.XIterable:
return iterators.iter_tree_pre(self)
return iter_tree_pre(self)

def iter_tree_post(self) -> utils.XIterable:
return iterators.iter_tree_post(self)
return iter_tree_post(self)

def iter_tree_levels(self) -> utils.XIterable:
return iterators.iter_tree_levels(self)
return iter_tree_levels(self)

iter_tree = iter_tree_pre

Expand Down
Loading

0 comments on commit b206825

Please sign in to comment.