Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove src/__init__.py and fix many errors #979

Merged
merged 23 commits into from
Oct 26, 2022
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -86,7 +86,7 @@ repos:
- id: isort

- repo: https://gitlab.com/PyCQA/flake8
rev: 4.0.1
rev: 5.0.4
hooks:
- id: flake8
additional_dependencies:
Expand Down Expand Up @@ -132,7 +132,7 @@ repos:
)$

- repo: https://github.com/pre-commit/mirrors-mypy
rev: v0.812
rev: v0.982
hooks:
- id: mypy
exclude: |
Expand All @@ -144,6 +144,7 @@ repos:
src/gt4py/frontend/nodes.py |
src/gt4py/frontend/node_util.py |
src/gt4py/frontend/gtscript_frontend.py |
src/gt4py/frontend/defir_to_gtir.py |
src/gt4py/utils/meta.py |
tests/definitions.py |
tests/definition_setup.py |
Expand All @@ -158,3 +159,5 @@ repos:
tests/test_unittest/test_gtc/gtir_utils.py |
tests/test_unittest/test_gtc/test_gtir_to_oir.py |
)$
entry: |
mypy --install-types --non-interactive
8 changes: 4 additions & 4 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ check-manifest>=0.40
coverage>=5.0
darglint>=1.6
factory-boy>=3.1
flake8~=4.0
flake8>=5.0.4
flake8-bugbear>=20.11.1
flake8-builtins>=1.5.3
flake8-debugger>=4.0.0
Expand All @@ -12,8 +12,8 @@ flake8-eradicate>=1.0.0
flake8-mutable>=1.2.0
flake8-rst-docstrings>=0.0.14
hypothesis>=4.14
isort~=5.1
mypy>=0.800
isort~=5.10
mypy>=0.980
pre-commit~=2.17
psutil>=5.0
pygments>=2.7
Expand All @@ -26,5 +26,5 @@ setuptools>=40.8.0
seed-isort-config~=2.1
sphinx~=4.4
sphinx_rtd_theme~=1.0
tox~=3.14
tox>=3.14
devtools~=0.6
9 changes: 5 additions & 4 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ include_package_data = True
python_requires = >= 3.8
install_requires =
attrs>=20.3
black>=22.3.0
black>=22.3
cached-property>=1.5
click>=7.1
jinja2>=2.10
Expand All @@ -52,7 +52,7 @@ install_requires =
packaging>=20.0
pybind11>=2.5
tabulate>=0.8
typing-extensions>=3.7
typing-extensions>=4.2
astunparse>=1.6.3;python_version<'3.9'
# ---- eve / gtc ----
boltons>=20.0
Expand Down Expand Up @@ -92,7 +92,7 @@ cuda116 =
cuda117 =
cupy-cuda117
dace =
dace~=0.14
dace>=0.14.1,<0.15
sympy
format =
clang-format>=9.0
Expand Down Expand Up @@ -209,7 +209,8 @@ show_error_codes = True
allow_untyped_defs = False

[mypy-gtc.*]
allow_untyped_defs = False
# TODO: Make this False and fix errors
allow_untyped_defs = True


#-- pytest --
Expand Down
13 changes: 0 additions & 13 deletions src/__init__.py

This file was deleted.

2 changes: 1 addition & 1 deletion src/eve/datamodels.py
Original file line number Diff line number Diff line change
Expand Up @@ -542,7 +542,7 @@ def _make_dataclass_field_from_attr(field_attrib: Attribute) -> dataclasses.Fiel
default=default,
default_factory=default_factory,
init=field_attrib.init,
repr=field_attrib.repr if not callable(field_attrib.repr) else None,
repr=field_attrib.repr if not callable(field_attrib.repr) else False,
hash=field_attrib.hash,
compare=field_attrib.eq,
metadata=field_attrib.metadata,
Expand Down
6 changes: 3 additions & 3 deletions src/eve/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,13 @@
import toolz # noqa: F401 # imported but unused


KeyValue = Tuple[Union[int, str], Any]
TreeIterationItem = Union[Any, Tuple[KeyValue, Any]]
KeyValue = Tuple[Union[int, str], concepts.TreeNode]
TreeIterationItem = Union[concepts.TreeNode, Tuple[KeyValue, concepts.TreeNode]]


def generic_iter_children(
node: concepts.TreeNode, *, with_keys: bool = False
) -> Iterable[Union[Any, Tuple[KeyValue, Any]]]:
) -> Iterable[Union[concepts.TreeNode, Tuple[KeyValue, concepts.TreeNode]]]:
"""Create an iterator to traverse values as Eve tree nodes.

Args:
Expand Down
2 changes: 1 addition & 1 deletion src/eve/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,7 +291,7 @@ def _decorator(base_cls: Type) -> Type:
def noninstantiable(cls: Type) -> Type:
original_init = cls.__init__

def _noninstantiable_init(self, *args, **kwargs) -> None:
def _noninstantiable_init(self, *args: Any, **kwargs: Any) -> None: # type: ignore # mypy somehow thinks there are missing type annotations here
if self.__class__ is cls:
raise TypeError(f"Trying to instantiate `{cls.__name__}` non-instantiable class.")
else:
Expand Down
11 changes: 8 additions & 3 deletions src/gt4py/backend/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import pathlib
import time
import warnings
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Dict, List, Optional, Tuple, Type, Union
from typing import TYPE_CHECKING, Any, ClassVar, Dict, List, Optional, Protocol, Tuple, Type, Union

from gt4py import definitions as gt_definitions
from gt4py import utils as gt_utils
Expand Down Expand Up @@ -305,6 +305,11 @@ def make_module_source(self, *, args_data: Optional[ModuleData] = None, **kwargs
return source


class MakeModuleSourceCallable(Protocol):
def __call__(self, *, args_data: Optional[ModuleData] = None, **kwargs: Any) -> str:
...


class PurePythonBackendCLIMixin(CLIBackendMixin):
"""Mixin for CLI support for backends deriving from BaseBackend."""

Expand All @@ -314,11 +319,11 @@ class PurePythonBackendCLIMixin(CLIBackendMixin):
#: In order to use this mixin, the backend class must implement
#: a :py:meth:`make_module_source` method or derive from
#: :py:meth:`BaseBackend`.
make_module_source: Callable
make_module_source: MakeModuleSourceCallable

def generate_computation(self) -> Dict[str, Union[str, Dict]]:
file_name = self.builder.module_path.name
source = self.make_module_source(implementation_ir=self.builder.implementation_ir)
source = self.make_module_source(ir=self.builder.gtir)
return {str(file_name): source}

def generate_bindings(self, language_name: str) -> Dict[str, Union[str, Dict]]:
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/backend/cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@ def __call__(self, stencil_ir: gtir.Stencil) -> Dict[str, Dict[str, str]]:
cuir_node = extent_analysis.CacheExtents().visit(cuir_node)
format_source = self.backend.builder.options.format_source
implementation = cuir_codegen.CUIRCodegen.apply(cuir_node, format_source=format_source)
bindings = CudaBindingsCodegen.apply(
bindings = CudaBindingsCodegen.apply_codegen(
cuir_node,
module_name=self.module_name,
backend=self.backend,
Expand Down Expand Up @@ -124,7 +124,7 @@ def visit_Program(self, node: cuir.Program, **kwargs):
Program = bindings_main_template()

@classmethod
def apply(cls, root, *, module_name="stencil", backend, **kwargs) -> str:
def apply_codegen(cls, root, *, module_name="stencil", backend, **kwargs) -> str:
generated_code = cls(backend).visit(root, module_name=module_name, **kwargs)
if kwargs.get("format_source", True):
generated_code = codegen.format_source("cpp", generated_code, style="LLVM")
Expand Down
8 changes: 3 additions & 5 deletions src/gt4py/backend/dace_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pathlib
import re
import textwrap
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Type
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type

import dace
import dace.data
Expand Down Expand Up @@ -694,8 +694,6 @@ class BaseDaceBackend(BaseGTBackend, CLIBackendMixin):
GT_BACKEND_T = "dace"
PYEXT_GENERATOR_CLASS = DaCeExtGenerator # type: ignore

options = BaseGTBackend.GT_BACKEND_OPTS

def generate(self) -> Type["StencilObject"]:
self.check_options(self.builder.options)

Expand Down Expand Up @@ -731,7 +729,7 @@ class DaceCPUBackend(BaseDaceBackend):

options = BaseGTBackend.GT_BACKEND_OPTS

def generate_extension(self) -> Tuple[str, str]:
def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=False)


Expand All @@ -753,5 +751,5 @@ class DaceGPUBackend(BaseDaceBackend):
"device_sync": {"versioning": True, "type": bool},
}

def generate_extension(self) -> Tuple[str, str]:
def generate_extension(self, **kwargs: Any) -> Tuple[str, str]:
return self.make_extension(stencil_ir=self.builder.gtir, uses_cuda=True)
1 change: 1 addition & 0 deletions src/gt4py/backend/dace_lazy_stencil.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ def __sdfg__(self, *args, **kwargs) -> dace.SDFG:
sdfg_manager = SDFGManager(self.builder)
args_data = make_args_data_from_gtir(self.builder.gtir_pipeline)
arg_names = [arg.name for arg in self.builder.gtir.api_signature]
assert args_data.domain_info is not None
norm_kwargs = DaCeStencilObject.normalize_args(
*args,
backend=self.backend.name,
Expand Down
26 changes: 8 additions & 18 deletions src/gt4py/backend/dace_stencil_object.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,24 +23,16 @@
import dace.frontend.python.common
from dace.frontend.python.common import SDFGClosure, SDFGConvertible

from gt4py import backend as gt_backend
import gt4py.backend
from gt4py.backend.dace_backend import freeze_origin_domain_sdfg
from gt4py.definitions import AccessKind, DomainInfo, FieldInfo
from gt4py.stencil_object import FrozenStencil, StencilObject
from gt4py.stencil_object import ArgsInfo, FrozenStencil, StencilObject
from gt4py.utils import shash


@dataclass
class _ArgsInfo:
device: str
array: dace.data.Array
origin: Optional[Tuple[int]] = None
dimensions: Optional[Tuple[str]] = None


def _extract_array_infos(field_args, device) -> Dict[str, _ArgsInfo]:
def _extract_array_infos(field_args, device) -> Dict[str, Optional[ArgsInfo]]:
return {
name: _ArgsInfo(
name: ArgsInfo(
array=arg,
dimensions=getattr(arg, "__gt_dims__", None),
device=device,
Expand All @@ -50,10 +42,6 @@ def _extract_array_infos(field_args, device) -> Dict[str, _ArgsInfo]:
}


def _extract_stencil_arrays(array_infos: Dict[str, _ArgsInfo]):
return {name: info.array for name, info in array_infos.items()}


def add_optional_fields(
sdfg: dace.SDFG, field_info: Dict[str, Any], parameter_info: Dict[str, Any], **kwargs: Any
) -> dace.SDFG:
Expand Down Expand Up @@ -199,17 +187,19 @@ def normalize_args(
arg_names: Iterable[str],
domain_info: DomainInfo,
field_info: Dict[str, FieldInfo],
domain: Optional[Tuple[int, int, int]] = None,
domain: Optional[Tuple[int, ...]] = None,
origin: Optional[Dict[str, Tuple[int, ...]]] = None,
**kwargs,
):
backend_cls = gt4py.backend.from_name(backend)
assert backend_cls is not None
args_iter = iter(args)
args_as_kwargs = {
name: (kwargs[name] if name in kwargs else next(args_iter)) for name in arg_names
}
arg_infos = _extract_array_infos(
field_args=args_as_kwargs,
device=gt_backend.from_name(backend).storage_info["device"],
device=backend_cls.storage_info["device"],
)

origin = DaCeStencilObject._normalize_origins(arg_infos, field_info, origin)
Expand Down
12 changes: 6 additions & 6 deletions src/gt4py/backend/gtc_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
if TYPE_CHECKING:
from gt4py.stencil_builder import StencilBuilder
from gt4py.stencil_object import StencilObject
from gt4py.storage.storage import Storage


def _get_unit_stride_dim(backend, domain_dim_flags, data_ndim):
Expand Down Expand Up @@ -222,7 +221,7 @@ class BackendCodegen:
TEMPLATE_FILES: Dict[str, str]

@abc.abstractmethod
def __init__(self, class_name: str, module_name: str, backend: str):
def __init__(self, class_name: str, module_name: str, backend: Any):
pass

@abc.abstractmethod
Expand All @@ -233,7 +232,7 @@ def __call__(self, ir: gtir.Stencil) -> Dict[str, Dict[str, str]]:

class BaseGTBackend(gt_backend.BasePyExtBackend, gt_backend.CLIBackendMixin):

GT_BACKEND_OPTS = {
GT_BACKEND_OPTS: Dict[str, Dict[str, Any]] = {
"add_profile_info": {"versioning": True, "type": bool},
"clean": {"versioning": False, "type": bool},
"debug_mode": {"versioning": True, "type": bool},
Expand Down Expand Up @@ -287,6 +286,7 @@ def make_extension(
stencil_ir = self.builder.gtir
# Generate source
gt_pyext_files: Dict[str, Any]
gt_pyext_sources: Dict[str, Any]
if not self.builder.options._impl_opts.get("disable-code-generation", False):
gt_pyext_files = self.make_extension_sources(stencil_ir=stencil_ir)
gt_pyext_sources = {**gt_pyext_files["computation"], **gt_pyext_files["bindings"]}
Expand Down Expand Up @@ -407,7 +407,7 @@ def make_x86_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], ...
return _permute_layout_to_dimensions([lt for lt in layout if lt is not None], dimensions)


def x86_is_compatible_layout(field: "Storage", dimensions: Tuple[str, ...]) -> bool:
def x86_is_compatible_layout(field: np.ndarray, dimensions: Tuple[str, ...]) -> bool:
stride = 0
layout_map = make_x86_layout_map(dimensions)
flattened_layout = [index for index in layout_map if index is not None]
Expand Down Expand Up @@ -439,7 +439,7 @@ def make_mc_layout_map(dimensions: Tuple[str, ...]) -> Tuple[int, ...]:
return _permute_layout_to_dimensions([lt for lt in layout if lt is not None], dimensions)


def mc_is_compatible_layout(field: "Storage", dimensions: Tuple[str, ...]) -> bool:
def mc_is_compatible_layout(field: np.ndarray, dimensions: Tuple[str, ...]) -> bool:
stride = 0
layout_map = make_mc_layout_map(dimensions)
flattened_layout = [index for index in layout_map if index is not None]
Expand All @@ -457,7 +457,7 @@ def make_cuda_layout_map(dimensions: Tuple[str, ...]) -> Tuple[Optional[int], ..
return _permute_layout_to_dimensions(layout, dimensions)


def cuda_is_compatible_layout(field: "Storage", dimensions: Tuple[str, ...]) -> bool:
def cuda_is_compatible_layout(field: np.ndarray, dimensions: Tuple[str, ...]) -> bool:
stride = 0
layout_map = make_cuda_layout_map(dimensions)
flattened_layout = [index for index in layout_map if index is not None]
Expand Down
7 changes: 4 additions & 3 deletions src/gt4py/backend/module_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import numbers
import os
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Set
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Set, cast

import jinja2
import numpy
Expand Down Expand Up @@ -102,7 +102,8 @@ def make_args_data_from_gtir(pipeline: GtirPipeline) -> ModuleData:
)

for decl in (param for param in all_params if isinstance(param, gtir.ScalarDecl)):
access = accesses[decl.name]
access = cast(Literal[AccessKind.NONE, AccessKind.READ], accesses[decl.name])
assert access in {AccessKind.NONE, AccessKind.READ}
dtype = numpy.dtype(decl.dtype.name.lower())
data.parameter_info[decl.name] = ParameterInfo(access=access, dtype=dtype)

Expand Down Expand Up @@ -241,7 +242,7 @@ def generate_sources(self) -> Dict[str, str]:
if self.builder.gtir.sources is not None:
return {
key: gt_utils.text.format_source(value, line_length=self.SOURCE_LINE_LENGTH)
for key, value in self.builder.gtir.sources
for key, value in self.builder.gtir.sources.items()
}
return {}

Expand Down
Loading