From 2caa5fa46887ba3dbf81229f4bdbdf99dddd8789 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Mon, 9 Dec 2019 11:24:09 +0100 Subject: [PATCH 1/6] check in call_run if fields are of compatible type and layout, which requires moving the StencilObject class from definitions.py to a new separate file (stencil_object.py) --- src/gt4py/__init__.py | 1 - src/gt4py/backend/concepts.py | 4 +- src/gt4py/backend/debug_backend.py | 7 + src/gt4py/backend/gt_cpu_backend.py | 6 + src/gt4py/backend/gt_cuda_backend.py | 6 + src/gt4py/backend/numpy_backend.py | 7 + src/gt4py/definitions.py | 264 +------------------------ src/gt4py/loader.py | 3 +- src/gt4py/stencil_object.py | 281 +++++++++++++++++++++++++++ 9 files changed, 312 insertions(+), 267 deletions(-) create mode 100644 src/gt4py/stencil_object.py diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index 492062a932..c660f31a34 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -39,7 +39,6 @@ Grid, ParameterInfo, CartesianSpace, - StencilObject, ) from . import config diff --git a/src/gt4py/backend/concepts.py b/src/gt4py/backend/concepts.py index 61161746ca..8bb137eb2e 100644 --- a/src/gt4py/backend/concepts.py +++ b/src/gt4py/backend/concepts.py @@ -276,8 +276,8 @@ class BaseGenerator(abc.ABC): from numpy import dtype {{ imports }} -from gt4py import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo, StencilObject - +from gt4py import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo +from gt4py.stencil_object import StencilObject {{ module_members }} class {{ class_name }}(StencilObject): diff --git a/src/gt4py/backend/debug_backend.py b/src/gt4py/backend/debug_backend.py index a26499fa06..638af2ea8d 100644 --- a/src/gt4py/backend/debug_backend.py +++ b/src/gt4py/backend/debug_backend.py @@ -14,6 +14,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import numpy as np + from gt4py import backend as gt_backend from gt4py import ir as gt_ir from gt4py import definitions as gt_definitions @@ -207,6 +209,10 @@ def debug_is_compatible_layout(field): return sum(field.shape) > 0 +def debug_is_compatible_type(field): + return isinstance(field, np.ndarray) + + @gt_backend.register class DebugBackend(gt_backend.BaseBackend): name = "debug" @@ -216,6 +222,7 @@ class DebugBackend(gt_backend.BaseBackend): "device": "cpu", "layout_map": debug_layout, "is_compatible_layout": debug_is_compatible_layout, + "is_compatible_type": debug_is_compatible_type, } GENERATOR_CLASS = DebugGenerator diff --git a/src/gt4py/backend/gt_cpu_backend.py b/src/gt4py/backend/gt_cpu_backend.py index e5854b13b5..464dacb313 100644 --- a/src/gt4py/backend/gt_cpu_backend.py +++ b/src/gt4py/backend/gt_cpu_backend.py @@ -47,6 +47,10 @@ def x86_is_compatible_layout(field): return True +def gtcpu_is_compatible_type(field): + return isinstance(field, np.ndarray) + + def make_mc_layout_map(mask): ctr = reversed(range(sum(mask))) if len(mask) < 3: @@ -136,6 +140,7 @@ class GTX86Backend(GTCPUBackend): "device": "cpu", "layout_map": make_x86_layout_map, "is_compatible_layout": x86_is_compatible_layout, + "is_compatible_type": gtcpu_is_compatible_type, } _CPU_ARCHITECTURE = "x86" @@ -151,6 +156,7 @@ class GTMCBackend(GTCPUBackend): "device": "cpu", "layout_map": make_mc_layout_map, "is_compatible_layout": mc_is_compatible_layout, + "is_compatible_type": gtcpu_is_compatible_type, } _CPU_ARCHITECTURE = "mc" diff --git a/src/gt4py/backend/gt_cuda_backend.py b/src/gt4py/backend/gt_cuda_backend.py index 8d21b04b20..4567faad34 100644 --- a/src/gt4py/backend/gt_cuda_backend.py +++ b/src/gt4py/backend/gt_cuda_backend.py @@ -19,6 +19,7 @@ import numpy as np from gt4py import backend as gt_backend +from gt4py import storage as gt_storage from . import pyext_builder @@ -47,6 +48,10 @@ def cuda_is_compatible_layout(field): return True +def cuda_is_compatible_type(field): + return isinstance(field, gt_storage.storage.GPUStorage) + + @gt_backend.register class GTCUDABackend(gt_backend.BaseGTBackend): GENERATOR_CLASS = PythonGTCUDAGenerator @@ -57,6 +62,7 @@ class GTCUDABackend(gt_backend.BaseGTBackend): "device": "gpu", "layout_map": cuda_layout, "is_compatible_layout": cuda_is_compatible_layout, + "is_compatible_type": cuda_is_compatible_type, } @classmethod diff --git a/src/gt4py/backend/numpy_backend.py b/src/gt4py/backend/numpy_backend.py index 27aa0fd327..81c54f5117 100644 --- a/src/gt4py/backend/numpy_backend.py +++ b/src/gt4py/backend/numpy_backend.py @@ -14,6 +14,8 @@ # # SPDX-License-Identifier: GPL-3.0-or-later +import numpy as np + from gt4py import backend as gt_backend from gt4py import ir as gt_ir from gt4py import definitions as gt_definitions @@ -296,6 +298,10 @@ def numpy_is_compatible_layout(field): return sum(field.shape) > 0 +def numpy_is_compatible_type(field): + return isinstance(field, np.ndarray) + + @gt_backend.register class NumPyBackend(gt_backend.BaseBackend): name = "numpy" @@ -305,6 +311,7 @@ class NumPyBackend(gt_backend.BaseBackend): "device": "cpu", "layout_map": numpy_layout, "is_compatible_layout": numpy_is_compatible_layout, + "is_compatible_type": numpy_is_compatible_type, } GENERATOR_CLASS = NumPyGenerator diff --git a/src/gt4py/definitions.py b/src/gt4py/definitions.py index 1b30394d55..82c539ca5c 100644 --- a/src/gt4py/definitions.py +++ b/src/gt4py/definitions.py @@ -24,9 +24,8 @@ import time import warnings - -from gt4py import utils as gt_utils import gt4py as gt +from gt4py import utils as gt_utils from gt4py.utils.attrib import ( attribute, attribclass, @@ -704,267 +703,6 @@ def shashed_id(self): return result -class StencilObject(abc.ABC): - """Generic singleton implementation of a stencil function. - - This class is used as base class for the specific subclass generated - at run-time for any stencil definition and a unique set of external symbols. - Instances of this class do not contain any information and thus it is - implemented as a singleton: only one instance per subclass is actually - allocated (and it is immutable). - """ - - def __new__(cls, *args, **kwargs): - if getattr(cls, "_instance", None) is None: - cls._instance = object.__new__(cls) - return cls._instance - - def __setattr__(self, key, value): - raise AttributeError("Attempting a modification of an attribute in a frozen class") - - def __delattr__(self, item): - raise AttributeError("Attempting a deletion of an attribute in a frozen class") - - def __eq__(self, other): - return type(self) == type(other) - - def __str__(self): - result = """ - [backend="{backend}"] - - I/O fields: {fields} - - Parameters: {params} - - Constants: {constants} - - Definition ({func}): -{source} - """.format( - name=self.options["module"] + "." + self.options["name"], - version=self._gt_id_, - backend=self.backend, - fields=self.field_info, - params=self.parameter_info, - constants=self.constants, - func=self.definition_func, - source=self.source, - ) - - return result - - def __hash__(self): - return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little") - - # Those attributes are added to the class at loading time: - # - # _gt_id_ (stencil_id.version) - # definition_func - - @property - @abc.abstractmethod - def backend(self) -> str: - pass - - @property - @abc.abstractmethod - def source(self): - pass - - @property - @abc.abstractmethod - def domain_info(self): - pass - - @property - @abc.abstractmethod - def field_info(self) -> dict: - pass - - @property - @abc.abstractmethod - def parameter_info(self) -> dict: - pass - - @property - @abc.abstractmethod - def constants(self) -> dict: - pass - - @property - @abc.abstractmethod - def options(self) -> dict: - pass - - @abc.abstractmethod - def run(self, *args, **kwargs): - pass - - @abc.abstractmethod - def __call__(self, *args, **kwargs): - pass - - def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): - """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses). - - Note that this function will always try to expand simple parameter values to - complete data structures by repeating the same value as many times as needed. - - Parameters - ---------- - field_args: `dict` - Mapping from field names to actually passed data arrays. - This parameter encapsulates `*args` in the actual stencil subclass - by doing: `{input_name[i]: arg for i, arg in enumerate(args)}` - - parameter_args: `dict` - Mapping from parameter names to actually passed parameter values. - This parameter encapsulates `**kwargs` in the actual stencil subclass - by doing: `{name: value for name, value in kwargs.items()}` - - domain : `Sequence` of `int`, optional - Shape of the computation domain. If `None`, it will be used the - largest feasible domain according to the provided input fields - and origin values (`None` by default). - - origin : `[int * ndims]` or `{'field_name': [int * ndims]}`, optional - If a single offset is passed, it will be used for all fields. - If a `dict` is passed, there could be an entry for each field. - A special key *'_all_'* will represent the value to be used for all - the fields not explicitly defined. If `None` is passed or it is - not possible to assign a value to some field according to the - previous rule, the value will be inferred from the global boundaries - of the field. Note that the function checks if the origin values - are at least equal to the `global_border` attribute of that field, - so a 0-based origin will only be acceptable for fields with - a 0-area support region. - - exec_info : `dict`, optional - Dictionary used to store information about the stencil execution. - (`None` by default). - - Returns - ------- - `None` - - Raises - ------- - ValueError - If invalid data or inconsistent options are specified. - """ - - if exec_info is not None: - exec_info["call_run_start_time"] = time.perf_counter() - used_arg_fields = { - name: field for name, field in field_args.items() if self.field_info[name] is not None - } - for name, field_info in self.field_info.items(): - if field_info is not None and field_args[name] is None: - raise ValueError("Field '{field_name}' is None.".format(field_name=name)) - for name, parameter_info in self.parameter_info.items(): - if parameter_info is not None and parameter_args[name] is None: - raise ValueError( - "Parameter '{parameter_name}' is None.".format(parameter_name=name) - ) - # assert compatibility of fields with stencil - for name, field in used_arg_fields.items(): - if not gt.backend.from_name(self.backend).storage_info["is_compatible_layout"](field): - raise ValueError( - "The layout of the field {} is not compatible with the backend.".format(name) - ) - # ToDo: check if mask is correct: need mask info in stencil object. - - if not field.is_stencil_view: - raise ValueError( - "An incompatible view was passed for field " + name + " to the stencil. " - ) - for name_other, field_other in used_arg_fields.items(): - if field_other.mask == field.mask: - if not field_other.shape == field.shape: - raise ValueError( - "The fields {} and {} have the same mask but different shapes.".format( - name, name_other - ) - ) - - assert isinstance(field_args, dict) and isinstance(parameter_args, dict) - - # Shapes - shapes = {} - - for name, field in used_arg_fields.items(): - # if hasattr(field, "grid_group"): - # # Extract ndarray from gt.storage object - # field = field.data - shapes[name] = Shape(field.shape) - # Origins - if origin is None: - origin = {} - else: - origin = normalize_origin_mapping(origin) - for name, field in used_arg_fields.items(): - origin.setdefault(name, origin["_all_"] if "_all_" in origin else field.default_origin) - - # all_origin = Shape(origin["_all_"]) if "_all_" in origin else None - # for name in field_args.keys(): - # min_origin = Shape(self.field_info[name].boundary.lower_indices) - # if name not in origin: - # origin[name] = all_origin if all_origin else min_origin - # else: - # origin[name] = Shape(origin[name]) - # if not origin[name] >= min_origin: - # raise ValueError( - # "Origin value smaller than global boundary for field '{}'".format(name) - # ) - - # Domain - max_domain = Shape([sys.maxsize] * self.domain_info.ndims) - for name, shape in shapes.items(): - upper_boundary = Index(self.field_info[name].boundary.upper_indices) - max_domain &= shape - (Index(origin[name]) + upper_boundary) - - if domain is None: - domain = max_domain - else: - domain = normalize_domain(domain) - if len(domain) != self.domain_info.ndims: - raise ValueError("Invalid 'domain' value ({})".format(domain)) - - # check domain+halo vs field size - if not domain > Shape.zeros(self.domain_info.ndims): - raise ValueError("Compute domain contains zero sizes ({})".format(domain)) - - if not domain <= max_domain: - raise ValueError( - "Compute domain too large (provided: {}, maximum: {})".format(domain, max_domain) - ) - for name, field in used_arg_fields.items(): - min_origin = self.field_info[name].boundary.lower_indices - if origin[name] < min_origin: - raise ValueError( - "Origin for field {} too small. Must be at least {}, is {}".format( - name, min_origin, origin[name] - ) - ) - min_shape = tuple( - o + d + h - for o, d, h in zip( - origin[name], domain, self.field_info[name].boundary.upper_indices - ) - ) - if min_shape > field.shape: - raise ValueError( - "Shape of field {} is {} but must be at least {} for given domain and origin.".format( - name, field.shape, min_shape - ) - ) - - # if domain != max_domain: - # warnings.warn("Input fields do not match default domain size!", UserWarning) - - # field_args = {k: v.view(np.ndarray) for k, v in field_args.items()} - - self.run( - **field_args, **parameter_args, _domain_=domain, _origin_=origin, exec_info=exec_info - ) - - class GTError(Exception): pass diff --git a/src/gt4py/loader.py b/src/gt4py/loader.py index 5fb9dddc7e..7151cb12a3 100644 --- a/src/gt4py/loader.py +++ b/src/gt4py/loader.py @@ -26,6 +26,7 @@ from gt4py import backend as gt_backend from gt4py import definitions as gt_definitions from gt4py import frontend as gt_frontend +from gt4py.stencil_object import StencilObject def load_stencil(frontend_name, backend_name, definition_func, externals, options): @@ -58,7 +59,7 @@ def load_stencil(frontend_name, backend_name, definition_func, externals, option def gtscript_loader(definition_func, backend, build_options, externals): - if isinstance(definition_func, gt_definitions.StencilObject): + if isinstance(definition_func, StencilObject): definition_func = definition_func.definition_func if not isinstance(definition_func, types.FunctionType): raise ValueError("Invalid stencil definition object ({obj})".format(obj=definition_func)) diff --git a/src/gt4py/stencil_object.py b/src/gt4py/stencil_object.py new file mode 100644 index 0000000000..c738dd555d --- /dev/null +++ b/src/gt4py/stencil_object.py @@ -0,0 +1,281 @@ +import abc +import gt4py.backend as gt_backend +from gt4py.definitions import * + + +class StencilObject(abc.ABC): + """Generic singleton implementation of a stencil function. + + This class is used as base class for the specific subclass generated + at run-time for any stencil definition and a unique set of external symbols. + Instances of this class do not contain any information and thus it is + implemented as a singleton: only one instance per subclass is actually + allocated (and it is immutable). + """ + + def __new__(cls, *args, **kwargs): + if getattr(cls, "_instance", None) is None: + cls._instance = object.__new__(cls) + return cls._instance + + def __setattr__(self, key, value): + raise AttributeError("Attempting a modification of an attribute in a frozen class") + + def __delattr__(self, item): + raise AttributeError("Attempting a deletion of an attribute in a frozen class") + + def __eq__(self, other): + return type(self) == type(other) + + def __str__(self): + result = """ + [backend="{backend}"] + - I/O fields: {fields} + - Parameters: {params} + - Constants: {constants} + - Definition ({func}): +{source} + """.format( + name=self.options["module"] + "." + self.options["name"], + version=self._gt_id_, + backend=self.backend, + fields=self.field_info, + params=self.parameter_info, + constants=self.constants, + func=self.definition_func, + source=self.source, + ) + + return result + + def __hash__(self): + return int.from_bytes(type(self)._gt_id_.encode(), byteorder="little") + + # Those attributes are added to the class at loading time: + # + # _gt_id_ (stencil_id.version) + # definition_func + + @property + @abc.abstractmethod + def backend(self) -> str: + pass + + @property + @abc.abstractmethod + def source(self): + pass + + @property + @abc.abstractmethod + def domain_info(self): + pass + + @property + @abc.abstractmethod + def field_info(self) -> dict: + pass + + @property + @abc.abstractmethod + def parameter_info(self) -> dict: + pass + + @property + @abc.abstractmethod + def constants(self) -> dict: + pass + + @property + @abc.abstractmethod + def options(self) -> dict: + pass + + @abc.abstractmethod + def run(self, *args, **kwargs): + pass + + @abc.abstractmethod + def __call__(self, *args, **kwargs): + pass + + def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): + """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses). + + Note that this function will always try to expand simple parameter values to + complete data structures by repeating the same value as many times as needed. + + Parameters + ---------- + field_args: `dict` + Mapping from field names to actually passed data arrays. + This parameter encapsulates `*args` in the actual stencil subclass + by doing: `{input_name[i]: arg for i, arg in enumerate(args)}` + + parameter_args: `dict` + Mapping from parameter names to actually passed parameter values. + This parameter encapsulates `**kwargs` in the actual stencil subclass + by doing: `{name: value for name, value in kwargs.items()}` + + domain : `Sequence` of `int`, optional + Shape of the computation domain. If `None`, it will be used the + largest feasible domain according to the provided input fields + and origin values (`None` by default). + + origin : `[int * ndims]` or `{'field_name': [int * ndims]}`, optional + If a single offset is passed, it will be used for all fields. + If a `dict` is passed, there could be an entry for each field. + A special key *'_all_'* will represent the value to be used for all + the fields not explicitly defined. If `None` is passed or it is + not possible to assign a value to some field according to the + previous rule, the value will be inferred from the global boundaries + of the field. Note that the function checks if the origin values + are at least equal to the `global_border` attribute of that field, + so a 0-based origin will only be acceptable for fields with + a 0-area support region. + + exec_info : `dict`, optional + Dictionary used to store information about the stencil execution. + (`None` by default). + + Returns + ------- + `None` + + Raises + ------- + ValueError + If invalid data or inconsistent options are specified. + """ + + if exec_info is not None: + exec_info["call_run_start_time"] = time.perf_counter() + used_arg_fields = { + name: field for name, field in field_args.items() if self.field_info[name] is not None + } + for name, field_info in self.field_info.items(): + if field_info is not None and field_args[name] is None: + raise ValueError("Field '{field_name}' is None.".format(field_name=name)) + for name, parameter_info in self.parameter_info.items(): + if parameter_info is not None and parameter_args[name] is None: + raise ValueError( + "Parameter '{parameter_name}' is None.".format(parameter_name=name) + ) + # assert compatibility of fields with stencil + for name, field in used_arg_fields.items(): + if not gt_backend.from_name(self.backend).storage_info["is_compatible_layout"](field): + raise ValueError( + "The layout of the field {} is not compatible with the backend.".format(name) + ) + + if not gt_backend.from_name(self.backend).storage_info["is_compatible_type"](field): + raise ValueError( + "Field '{field}' has type '{type}', which is not compatible with the '{backend}' backend.".format( + field=name, type=type(field), backend=self.backend + ) + ) + # ToDo: check if mask is correct: need mask info in stencil object. + + if not field.is_stencil_view: + raise ValueError( + "An incompatible view was passed for field " + name + " to the stencil. " + ) + for name_other, field_other in used_arg_fields.items(): + if field_other.mask == field.mask: + if not field_other.shape == field.shape: + raise ValueError( + "The fields {} and {} have the same mask but different shapes.".format( + name, name_other + ) + ) + + assert isinstance(field_args, dict) and isinstance(parameter_args, dict) + + # Shapes + shapes = {} + + for name, field in used_arg_fields.items(): + # if hasattr(field, "grid_group"): + # # Extract ndarray from gt.storage object + # field = field.data + shapes[name] = Shape(field.shape) + # Origins + if origin is None: + origin = {} + else: + origin = normalize_origin_mapping(origin) + for name, field in used_arg_fields.items(): + origin.setdefault(name, origin["_all_"] if "_all_" in origin else field.default_origin) + + for name, field in field_args.items(): + if not gt_backend.from_name(self.backend).storage_info["is_compatible_type"](field): + raise TypeError( + "Field '{name}' is of incompatible type {type}.".format( + name=name, type=type(field) + ) + ) + if not gt_backend.from_name(self.backend).storage_info["is_compatible_layout"](field): + raise ValueError("Field '{name}' has an incompatible layout.".format(name=name)) + + # all_origin = Shape(origin["_all_"]) if "_all_" in origin else None + # for name in field_args.keys(): + # min_origin = Shape(self.field_info[name].boundary.lower_indices) + # if name not in origin: + # origin[name] = all_origin if all_origin else min_origin + # else: + # origin[name] = Shape(origin[name]) + # if not origin[name] >= min_origin: + # raise ValueError( + # "Origin value smaller than global boundary for field '{}'".format(name) + # ) + + # Domain + max_domain = Shape([sys.maxsize] * self.domain_info.ndims) + for name, shape in shapes.items(): + upper_boundary = Index(self.field_info[name].boundary.upper_indices) + max_domain &= shape - (Index(origin[name]) + upper_boundary) + + if domain is None: + domain = max_domain + else: + domain = normalize_domain(domain) + if len(domain) != self.domain_info.ndims: + raise ValueError("Invalid 'domain' value ({})".format(domain)) + + # check domain+halo vs field size + if not domain > Shape.zeros(self.domain_info.ndims): + raise ValueError("Compute domain contains zero sizes ({})".format(domain)) + + if not domain <= max_domain: + raise ValueError( + "Compute domain too large (provided: {}, maximum: {})".format(domain, max_domain) + ) + for name, field in used_arg_fields.items(): + min_origin = self.field_info[name].boundary.lower_indices + if origin[name] < min_origin: + raise ValueError( + "Origin for field {} too small. Must be at least {}, is {}".format( + name, min_origin, origin[name] + ) + ) + min_shape = tuple( + o + d + h + for o, d, h in zip( + origin[name], domain, self.field_info[name].boundary.upper_indices + ) + ) + if min_shape > field.shape: + raise ValueError( + "Shape of field {} is {} but must be at least {} for given domain and origin.".format( + name, field.shape, min_shape + ) + ) + + # if domain != max_domain: + # warnings.warn("Input fields do not match default domain size!", UserWarning) + + # field_args = {k: v.view(np.ndarray) for k, v in field_args.items()} + + self.run( + **field_args, **parameter_args, _domain_=domain, _origin_=origin, exec_info=exec_info + ) From 1867dd8f5f252c05c625619ea36e505ab5ea92c7 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Mon, 9 Dec 2019 14:32:52 +0100 Subject: [PATCH 2/6] fixed parameter checks, cleanup of old, commented-out code --- src/gt4py/stencil_object.py | 56 +++++++++---------- src/gt4py/testing/suites.py | 6 +- .../test_default_arguments.py | 8 +-- 3 files changed, 32 insertions(+), 38 deletions(-) diff --git a/src/gt4py/stencil_object.py b/src/gt4py/stencil_object.py index c738dd555d..6398a6e4eb 100644 --- a/src/gt4py/stencil_object.py +++ b/src/gt4py/stencil_object.py @@ -151,7 +151,14 @@ def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): if exec_info is not None: exec_info["call_run_start_time"] = time.perf_counter() used_arg_fields = { - name: field for name, field in field_args.items() if self.field_info[name] is not None + name: field + for name, field in field_args.items() + if name in self.field_info and self.field_info[name] is not None + } + used_arg_params = { + name: param + for name, param in parameter_args.items() + if name in self.parameter_info and self.parameter_info[name] is not None } for name, field_info in self.field_info.items(): if field_info is not None and field_args[name] is None: @@ -174,6 +181,12 @@ def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): field=name, type=type(field), backend=self.backend ) ) + if not field.dtype == self.field_info[name].dtype: + raise TypeError( + "The dtype of field '{field}' is '{is_dtype}' instead of '{should_dtype}'".format( + field=name, is_dtype=field.dtype, should_dtype=self.field_info[name].dtype + ) + ) # ToDo: check if mask is correct: need mask info in stencil object. if not field.is_stencil_view: @@ -189,15 +202,23 @@ def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): ) ) + # assert compatibility of parameters with stencil + for name, parameter in used_arg_params.items(): + if not type(parameter) == self.parameter_info[name].dtype: + raise TypeError( + "The type of parameter '{field}' is '{is_dtype}' instead of '{should_dtype}'".format( + field=name, + is_dtype=type(parameter), + should_dtype=self.parameter_info[name].dtype, + ) + ) + assert isinstance(field_args, dict) and isinstance(parameter_args, dict) # Shapes shapes = {} for name, field in used_arg_fields.items(): - # if hasattr(field, "grid_group"): - # # Extract ndarray from gt.storage object - # field = field.data shapes[name] = Shape(field.shape) # Origins if origin is None: @@ -207,28 +228,6 @@ def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): for name, field in used_arg_fields.items(): origin.setdefault(name, origin["_all_"] if "_all_" in origin else field.default_origin) - for name, field in field_args.items(): - if not gt_backend.from_name(self.backend).storage_info["is_compatible_type"](field): - raise TypeError( - "Field '{name}' is of incompatible type {type}.".format( - name=name, type=type(field) - ) - ) - if not gt_backend.from_name(self.backend).storage_info["is_compatible_layout"](field): - raise ValueError("Field '{name}' has an incompatible layout.".format(name=name)) - - # all_origin = Shape(origin["_all_"]) if "_all_" in origin else None - # for name in field_args.keys(): - # min_origin = Shape(self.field_info[name].boundary.lower_indices) - # if name not in origin: - # origin[name] = all_origin if all_origin else min_origin - # else: - # origin[name] = Shape(origin[name]) - # if not origin[name] >= min_origin: - # raise ValueError( - # "Origin value smaller than global boundary for field '{}'".format(name) - # ) - # Domain max_domain = Shape([sys.maxsize] * self.domain_info.ndims) for name, shape in shapes.items(): @@ -271,11 +270,6 @@ def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): ) ) - # if domain != max_domain: - # warnings.warn("Input fields do not match default domain size!", UserWarning) - - # field_args = {k: v.view(np.ndarray) for k, v in field_args.items()} - self.run( **field_args, **parameter_args, _domain_=domain, _origin_=origin, exec_info=exec_info ) diff --git a/src/gt4py/testing/suites.py b/src/gt4py/testing/suites.py index 8cc4c68010..b9a81c308d 100644 --- a/src/gt4py/testing/suites.py +++ b/src/gt4py/testing/suites.py @@ -22,7 +22,7 @@ import gt4py as gt from gt4py import gtscript from gt4py import storage as gt_storage - +from gt4py.stencil_object import StencilObject from .input_strategies import * from .utils import * @@ -387,7 +387,7 @@ def test_generation(self, test, externals_dict): for k, v in externals_dict.items(): implementation._gt_constants_[k] = v - assert isinstance(implementation, gt.StencilObject) + assert isinstance(implementation, StencilObject) assert implementation.backend == test["backend"] assert all( cls.global_boundaries[name] == field_info.boundary @@ -411,7 +411,7 @@ def test_implementation(self, test, parameters_dict): "Cannot perform validation tests, since there are no valid implementations." ) for implementation in implementation_list: - if not isinstance(implementation, gt.StencilObject): + if not isinstance(implementation, StencilObject): raise RuntimeError("Wrong function got from implementations_db cache!") fields, exec_info = parameters_dict diff --git a/tests/test_integration/test_default_arguments.py b/tests/test_integration/test_default_arguments.py index e9c7fcb2a2..4e06a90a6b 100644 --- a/tests/test_integration/test_default_arguments.py +++ b/tests/test_integration/test_default_arguments.py @@ -96,14 +96,14 @@ def test_default_arguments(backend): tmp = np.asarray(arg3) tmp *= 2 - branch_true(arg1, arg2=None, par1=2.0, par2=5, par3=3.0) + branch_true(arg1, arg2=None, par1=2.0, par2=5.0, par3=3.0) np.testing.assert_equal(arg1, 10 * np.ones((3, 3, 3))) - branch_true(arg1, arg2=None, par1=2.0, par2=5) + branch_true(arg1, arg2=None, par1=2.0, par2=5.0) np.testing.assert_equal(arg1, 100 * np.ones((3, 3, 3))) - branch_false(arg1, arg2, arg3, par1=2.0, par2=5, par3=3.0) + branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0, par3=3.0) np.testing.assert_equal(arg1, 60 * np.ones((3, 3, 3))) try: - branch_false(arg1, arg2, arg3, par1=2.0, par2=5) + branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0) except ValueError: pass else: From 9fb75e2dfd6b72128b93cf5df5d3e31772fda969 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Tue, 10 Dec 2019 11:49:44 +0100 Subject: [PATCH 3/6] removed unnecessary, wrong check from c++, added test to verify correctness --- .../backend/templates/computation.src.in | 6 -- ...lt_arguments.py => test_call_interface.py} | 58 +++++++++++++++++++ 2 files changed, 58 insertions(+), 6 deletions(-) rename tests/test_integration/{test_default_arguments.py => test_call_interface.py} (60%) diff --git a/src/gt4py/backend/templates/computation.src.in b/src/gt4py/backend/templates/computation.src.in index f1a570b96c..b74975912b 100644 --- a/src/gt4py/backend/templates/computation.src.in +++ b/src/gt4py/backend/templates/computation.src.in @@ -129,12 +129,6 @@ data_store_t make_data_store(const BufferInfo& bi, " != " + std::to_string(3) + "]"); } - for (int i = 0; i < 2 /*3*/; ++i) { - if (2*origin[i] + compute_domain_shape[i] > bi.shape[i]) - throw std::runtime_error( - "Given shape and origin exceed buffer dimension"); - } - // ptr, dims and strides are "outer domain" (i.e., compute domain + halo // region). The halo region is only defined through `make_grid` (and // currently, in the storage info) diff --git a/tests/test_integration/test_default_arguments.py b/tests/test_integration/test_call_interface.py similarity index 60% rename from tests/test_integration/test_default_arguments.py rename to tests/test_integration/test_call_interface.py index 4e06a90a6b..ba5cd5f3e7 100644 --- a/tests/test_integration/test_default_arguments.py +++ b/tests/test_integration/test_call_interface.py @@ -43,6 +43,13 @@ def a_stencil( arg1 = arg2 + arg3 * par1 * par2 * par3 +def avg_stencil(in_field: Field[np.float64], out_field: Field[np.float64]): + with computation(PARALLEL), interval(...): + out_field = 0.25 * ( + +in_field[0, 1, 0] + in_field[0, -1, 0] + in_field[1, 0, 0] + in_field[-1, 0, 0] + ) + + @pytest.mark.parametrize( "backend", [ @@ -108,3 +115,54 @@ def test_default_arguments(backend): pass else: assert False + + +@pytest.mark.parametrize( + "backend", + [ + name + for name in gt_backend.REGISTRY.names + if gt_backend.from_name(name).storage_info["device"] != "gpu" + ], +) +def test_halo_checks(backend): + stencil = gtscript.stencil(definition=avg_stencil, backend=backend) + + # test default works + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field) + assert (out_field[1:-1, 1:-1, :] == 1).all() + + # test setting arbitrary, small domain works + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(10, 10, 10)) + assert (out_field[2:12, 2:12, :] == 1).all() + + # test setting domain+origin too large raises + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + with pytest.raises(ValueError): + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) + + # test 2*origin+domain does not raise if still fits (c.f. previous bug in c++ check.) + in_field = gt_storage.ones( + backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) From e300ccbec052444df882305055af8ffff3c747e2 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Tue, 10 Dec 2019 16:09:44 +0100 Subject: [PATCH 4/6] * renamed call_run to _call_run to make clear that it is not part of the gt4py user interface. * accepting numpy integer types for domain, origin in call and storage interfaces --- src/gt4py/__init__.py | 1 + src/gt4py/backend/concepts.py | 6 +- src/gt4py/definitions.py | 10 +- src/gt4py/stencil_object.py | 16 +- src/gt4py/storage/utils.py | 9 +- tests/test_integration/test_call_interface.py | 168 ----------------- tests/test_unittest/test_call_interface.py | 170 ++++++++++++++++++ tests/test_unittest/test_gtscript_frontend.py | 5 +- 8 files changed, 202 insertions(+), 183 deletions(-) delete mode 100644 tests/test_integration/test_call_interface.py diff --git a/src/gt4py/__init__.py b/src/gt4py/__init__.py index c660f31a34..6414b8c324 100644 --- a/src/gt4py/__init__.py +++ b/src/gt4py/__init__.py @@ -40,6 +40,7 @@ ParameterInfo, CartesianSpace, ) +from .stencil_object import StencilObject from . import config from . import gtscript diff --git a/src/gt4py/backend/concepts.py b/src/gt4py/backend/concepts.py index 8bb137eb2e..cbefc0e04d 100644 --- a/src/gt4py/backend/concepts.py +++ b/src/gt4py/backend/concepts.py @@ -276,8 +276,8 @@ class BaseGenerator(abc.ABC): from numpy import dtype {{ imports }} -from gt4py import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo -from gt4py.stencil_object import StencilObject +from gt4py.stencil_object import StencilObject, AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo + {{ module_members }} class {{ class_name }}(StencilObject): @@ -339,7 +339,7 @@ def __call__(self, {{ stencil_signature }}, domain=None, origin=None, exec_info= {{synchronization}} {%- endfilter %} - self.call_run( + self._call_run( field_args=field_args, parameter_args=parameter_args, domain=domain, diff --git a/src/gt4py/definitions.py b/src/gt4py/definitions.py index 82c539ca5c..a1fe0023bb 100644 --- a/src/gt4py/definitions.py +++ b/src/gt4py/definitions.py @@ -64,7 +64,7 @@ def _check_value(cls, value, ndims): @classmethod def is_valid(cls, value, *, ndims=(1, CartesianSpace.ndims)): - if isinstance(ndims, int): + if isinstance(ndims, numbers.Integral): ndims = tuple([ndims] * 2) elif not isinstance(ndims, collections.abc.Sequence) or len(ndims) != 2: raise ValueError("Invalid 'ndims' definition ({})".format(ndims)) @@ -241,7 +241,7 @@ class Index(NumericTuple): @classmethod def _check_value(cls, value, ndims): assert isinstance(value, collections.abc.Sequence), "Invalid sequence" - assert all(isinstance(d, int) for d in value) + assert all(isinstance(d, numbers.Integral) for d in value) assert ndims[0] <= len(value) <= ndims[1] @@ -253,7 +253,7 @@ class Shape(NumericTuple): @classmethod def _check_value(cls, value, ndims): assert isinstance(value, collections.abc.Sequence), "Invalid sequence" - assert all(isinstance(d, int) and d >= 0 for d in value) + assert all(isinstance(d, numbers.Integral) and d >= 0 for d in value) assert ndims[0] <= len(value) <= ndims[1] @@ -734,8 +734,8 @@ def normalize_domain(domain): def normalize_origin(origin): if origin is not None: - if isinstance(origin, int): - origin = Shape.from_k(origin) + if isinstance(origin, numbers.Integral): + origin = Shape.from_k(int(origin)) elif isinstance(origin, collections.abc.Sequence) and Index.is_valid(origin): origin = Shape.from_value(origin) else: diff --git a/src/gt4py/stencil_object.py b/src/gt4py/stencil_object.py index 6398a6e4eb..6ddc2b1284 100644 --- a/src/gt4py/stencil_object.py +++ b/src/gt4py/stencil_object.py @@ -1,6 +1,18 @@ +import sys import abc +import time import gt4py.backend as gt_backend -from gt4py.definitions import * +from gt4py.definitions import ( + AccessKind, + Boundary, + DomainInfo, + FieldInfo, + ParameterInfo, + normalize_domain, + normalize_origin_mapping, + Shape, + Index, +) class StencilObject(abc.ABC): @@ -99,7 +111,7 @@ def run(self, *args, **kwargs): def __call__(self, *args, **kwargs): pass - def call_run(self, field_args, parameter_args, domain, origin, exec_info=None): + def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): """Check and preprocess the provided arguments (called by :class:`StencilObject` subclasses). Note that this function will always try to expand simple parameter values to diff --git a/src/gt4py/storage/utils.py b/src/gt4py/storage/utils.py index 65d2c47534..ae7d90b4c0 100644 --- a/src/gt4py/storage/utils.py +++ b/src/gt4py/storage/utils.py @@ -14,8 +14,9 @@ # # SPDX-License-Identifier: GPL-3.0-or-later -import numpy as np import math +import numbers +import numpy as np import gt4py.utils as gt_util import collections @@ -49,14 +50,14 @@ def normalize_shape(shape, mask=None): "len(shape) must be equal to len(mask) or the number of 'True' entries in mask." ) - if not gt_util.is_iterable_of(shape, int): + if not gt_util.is_iterable_of(shape, numbers.Integral): raise TypeError("shape must be a tuple of ints or pairs of ints.") if any(o <= 0 for o in shape): raise ValueError("shape ({}) contains non-positive value.".format(shape)) new_shape = list(shape) if sum(mask) < len(shape): - new_shape = [h for i, h in enumerate(new_shape) if mask[i]] + new_shape = [int(h) for i, h in enumerate(new_shape) if mask[i]] return tuple(new_shape) @@ -75,7 +76,7 @@ def normalize_default_origin(default_origin, mask=None): "len(default_origin) must be equal to len(mask) or the number of 'True' entries in mask." ) - if not gt_util.is_iterable_of(default_origin, int): + if not gt_util.is_iterable_of(default_origin, numbers.Integral): raise TypeError("default_origin must be a tuple of ints or pairs of ints.") if any(o < 0 for o in default_origin): raise ValueError("default_origin ({}) contains negative value.".format(default_origin)) diff --git a/tests/test_integration/test_call_interface.py b/tests/test_integration/test_call_interface.py deleted file mode 100644 index ba5cd5f3e7..0000000000 --- a/tests/test_integration/test_call_interface.py +++ /dev/null @@ -1,168 +0,0 @@ -# -*- coding: utf-8 -*- -# -# GT4Py - GridTools4Py - GridTools for Python -# -# Copyright (c) 2014-2019, ETH Zurich -# All rights reserved. -# -# This file is part the GT4Py project and the GridTools framework. -# GT4Py is free software: you can redistribute it and/or modify it under -# the terms of the GNU General Public License as published by the -# Free Software Foundation, either version 3 of the License, or any later -# version. See the LICENSE.txt file at the top-level directory of this -# distribution for a copy of the license or check . -# -# SPDX-License-Identifier: GPL-3.0-or-later - -import numpy as np - -import gt4py as gt -import gt4py.gtscript as gtscript -import gt4py.backend as gt_backend -import gt4py.storage as gt_storage -from gt4py.gtscript import Field -import pytest - - -def a_stencil( - arg1: Field[np.float64], - arg2: Field[np.float64], - arg3: Field[np.float64] = None, - *, - par1: np.float64, - par2: np.float64 = 7.0, - par3: np.float64 = None, -): - from __externals__ import BRANCH - - with computation(PARALLEL), interval(...): - - if __INLINED(BRANCH): - arg1 = arg1 * par1 * par2 - else: - arg1 = arg2 + arg3 * par1 * par2 * par3 - - -def avg_stencil(in_field: Field[np.float64], out_field: Field[np.float64]): - with computation(PARALLEL), interval(...): - out_field = 0.25 * ( - +in_field[0, 1, 0] + in_field[0, -1, 0] + in_field[1, 0, 0] + in_field[-1, 0, 0] - ) - - -@pytest.mark.parametrize( - "backend", - [ - name - for name in gt_backend.REGISTRY.names - if gt_backend.from_name(name).storage_info["device"] == "cpu" - ], -) -def test_default_arguments(backend): - branch_true = gtscript.stencil( - backend=backend, definition=a_stencil, externals={"BRANCH": True}, rebuild=True - ) - branch_false = gtscript.stencil( - backend=backend, definition=a_stencil, externals={"BRANCH": False}, rebuild=True - ) - - arg1 = gt_storage.ones( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - arg2 = gt_storage.zeros( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - arg3 = gt_storage.ones( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - tmp = np.asarray(arg3) - tmp *= 2 - - branch_true(arg1, None, arg3, par1=2.0) - np.testing.assert_equal(arg1, 14 * np.ones((3, 3, 3))) - branch_true(arg1, None, par1=2.0) - np.testing.assert_equal(arg1, 196 * np.ones((3, 3, 3))) - branch_false(arg1, arg2, arg3, par1=2.0, par3=2.0) - np.testing.assert_equal(arg1, 56 * np.ones((3, 3, 3))) - try: - branch_false(arg1, arg2, par1=2.0, par3=2.0) - except ValueError: - pass - else: - assert False - - arg1 = gt_storage.ones( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - arg2 = gt_storage.zeros( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - arg3 = gt_storage.ones( - backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) - ) - tmp = np.asarray(arg3) - tmp *= 2 - - branch_true(arg1, arg2=None, par1=2.0, par2=5.0, par3=3.0) - np.testing.assert_equal(arg1, 10 * np.ones((3, 3, 3))) - branch_true(arg1, arg2=None, par1=2.0, par2=5.0) - np.testing.assert_equal(arg1, 100 * np.ones((3, 3, 3))) - branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0, par3=3.0) - np.testing.assert_equal(arg1, 60 * np.ones((3, 3, 3))) - try: - branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0) - except ValueError: - pass - else: - assert False - - -@pytest.mark.parametrize( - "backend", - [ - name - for name in gt_backend.REGISTRY.names - if gt_backend.from_name(name).storage_info["device"] != "gpu" - ], -) -def test_halo_checks(backend): - stencil = gtscript.stencil(definition=avg_stencil, backend=backend) - - # test default works - in_field = gt_storage.ones( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - out_field = gt_storage.zeros( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - stencil(in_field=in_field, out_field=out_field) - assert (out_field[1:-1, 1:-1, :] == 1).all() - - # test setting arbitrary, small domain works - in_field = gt_storage.ones( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - out_field = gt_storage.zeros( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(10, 10, 10)) - assert (out_field[2:12, 2:12, :] == 1).all() - - # test setting domain+origin too large raises - in_field = gt_storage.ones( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - out_field = gt_storage.zeros( - backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - with pytest.raises(ValueError): - stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) - - # test 2*origin+domain does not raise if still fits (c.f. previous bug in c++ check.) - in_field = gt_storage.ones( - backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - out_field = gt_storage.zeros( - backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 - ) - stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) diff --git a/tests/test_unittest/test_call_interface.py b/tests/test_unittest/test_call_interface.py index 21263377fc..15b2dcc78f 100644 --- a/tests/test_unittest/test_call_interface.py +++ b/tests/test_unittest/test_call_interface.py @@ -15,6 +15,7 @@ # SPDX-License-Identifier: GPL-3.0-or-later import numpy as np +import pytest import gt4py.gtscript as gtscript from gt4py.gtscript import Field @@ -194,3 +195,172 @@ def test_domain_selection(): assert np.all(A == 4) assert np.all(B == 7) assert np.all(C == 21) + + +def a_stencil( + arg1: Field[np.float64], + arg2: Field[np.float64], + arg3: Field[np.float64] = None, + *, + par1: np.float64, + par2: np.float64 = 7.0, + par3: np.float64 = None, +): + from __externals__ import BRANCH + + with computation(PARALLEL), interval(...): + + if __INLINED(BRANCH): + arg1 = arg1 * par1 * par2 + else: + arg1 = arg2 + arg3 * par1 * par2 * par3 + + +def avg_stencil(in_field: Field[np.float64], out_field: Field[np.float64]): + with computation(PARALLEL), interval(...): + out_field = 0.25 * ( + +in_field[0, 1, 0] + in_field[0, -1, 0] + in_field[1, 0, 0] + in_field[-1, 0, 0] + ) + + +@pytest.mark.parametrize( + "backend", + [ + name + for name in gt_backend.REGISTRY.names + if gt_backend.from_name(name).storage_info["device"] == "cpu" + ], +) +def test_default_arguments(backend): + branch_true = gtscript.stencil( + backend=backend, definition=a_stencil, externals={"BRANCH": True}, rebuild=True + ) + branch_false = gtscript.stencil( + backend=backend, definition=a_stencil, externals={"BRANCH": False}, rebuild=True + ) + + arg1 = gt_storage.ones( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + arg2 = gt_storage.zeros( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + arg3 = gt_storage.ones( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + tmp = np.asarray(arg3) + tmp *= 2 + + branch_true(arg1, None, arg3, par1=2.0) + np.testing.assert_equal(arg1, 14 * np.ones((3, 3, 3))) + branch_true(arg1, None, par1=2.0) + np.testing.assert_equal(arg1, 196 * np.ones((3, 3, 3))) + branch_false(arg1, arg2, arg3, par1=2.0, par3=2.0) + np.testing.assert_equal(arg1, 56 * np.ones((3, 3, 3))) + try: + branch_false(arg1, arg2, par1=2.0, par3=2.0) + except ValueError: + pass + else: + assert False + + arg1 = gt_storage.ones( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + arg2 = gt_storage.zeros( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + arg3 = gt_storage.ones( + backend=backend, dtype=np.float64, shape=(3, 3, 3), default_origin=(0, 0, 0) + ) + tmp = np.asarray(arg3) + tmp *= 2 + + branch_true(arg1, arg2=None, par1=2.0, par2=5.0, par3=3.0) + np.testing.assert_equal(arg1, 10 * np.ones((3, 3, 3))) + branch_true(arg1, arg2=None, par1=2.0, par2=5.0) + np.testing.assert_equal(arg1, 100 * np.ones((3, 3, 3))) + branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0, par3=3.0) + np.testing.assert_equal(arg1, 60 * np.ones((3, 3, 3))) + try: + branch_false(arg1, arg2, arg3, par1=2.0, par2=5.0) + except ValueError: + pass + else: + assert False + + +@pytest.mark.parametrize( + "backend", + [ + name + for name in gt_backend.REGISTRY.names + if gt_backend.from_name(name).storage_info["device"] != "gpu" + ], +) +def test_halo_checks(backend): + stencil = gtscript.stencil(definition=avg_stencil, backend=backend) + + # test default works + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field) + assert (out_field[1:-1, 1:-1, :] == 1).all() + + # test setting arbitrary, small domain works + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(10, 10, 10)) + assert (out_field[2:12, 2:12, :] == 1).all() + + # test setting domain+origin too large raises + in_field = gt_storage.ones( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(22, 22, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + with pytest.raises(ValueError): + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) + + # test 2*origin+domain does not raise if still fits (c.f. previous bug in c++ check.) + in_field = gt_storage.ones( + backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + out_field = gt_storage.zeros( + backend=backend, shape=(23, 23, 10), default_origin=(1, 1, 0), dtype=np.float64 + ) + stencil(in_field=in_field, out_field=out_field, origin=(2, 2, 0), domain=(20, 20, 10)) + + +def test_np_int_types(): + backend = "numpy" + stencil = gtscript.stencil(definition=avg_stencil, backend=backend) + + # test numpy int types are accepted + in_field = gt_storage.ones( + backend=backend, + shape=(np.int8(23), np.int16(23), np.int32(10)), + default_origin=(np.int64(1), int(1), 0), + dtype=np.float64, + ) + out_field = gt_storage.zeros( + backend=backend, + shape=(np.int8(23), np.int16(23), np.int32(10)), + default_origin=(np.int64(1), int(1), 0), + dtype=np.float64, + ) + stencil( + in_field=in_field, + out_field=out_field, + origin=(np.int8(2), np.int16(2), np.int32(0)), + domain=(np.int64(20), int(20), 10), + ) diff --git a/tests/test_unittest/test_gtscript_frontend.py b/tests/test_unittest/test_gtscript_frontend.py index e05e0a960b..7f9fc32278 100644 --- a/tests/test_unittest/test_gtscript_frontend.py +++ b/tests/test_unittest/test_gtscript_frontend.py @@ -396,5 +396,8 @@ def definition_func( with pytest.raises(ValueError, match=r".*data type descriptor.*"): compile_definition( - definition_func, "test_invalid_external_dtypes", module, dtypes={"dtype": test_dtype} + definition_func, + "test_invalid_external_dtypes", + module, + dtypes={"dtype": test_dtype}, ) From cae085bae8a2937d3949531590294314e0b7ab98 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Tue, 10 Dec 2019 16:57:31 +0100 Subject: [PATCH 5/6] Raise warning if numpy arrays are passed as fields. --- src/gt4py/stencil_object.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/gt4py/stencil_object.py b/src/gt4py/stencil_object.py index 6ddc2b1284..20bff27c53 100644 --- a/src/gt4py/stencil_object.py +++ b/src/gt4py/stencil_object.py @@ -1,6 +1,8 @@ import sys import abc import time +import warnings +import numpy as np import gt4py.backend as gt_backend from gt4py.definitions import ( AccessKind, @@ -193,6 +195,11 @@ def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): field=name, type=type(field), backend=self.backend ) ) + elif isinstance(field, np.ndarray): + warnings.warn( + "NumPy ndarray passed as field. This is discouraged and only works with constraints and only for certain backends.", + RuntimeWarning, + ) if not field.dtype == self.field_info[name].dtype: raise TypeError( "The dtype of field '{field}' is '{is_dtype}' instead of '{should_dtype}'".format( From 0b350b22c47b8f7af06395c57ffbe11958b0efe2 Mon Sep 17 00:00:00 2001 From: Linus Groner Date: Wed, 11 Dec 2019 09:32:26 +0100 Subject: [PATCH 6/6] changed some .format to f-strings, support for NumPy ndarrays instead of tuples for domain and origin parameters --- src/gt4py/backend/concepts.py | 2 +- src/gt4py/definitions.py | 5 ++- src/gt4py/stencil_object.py | 42 +++++++--------------- tests/test_unittest/test_call_interface.py | 25 +++++++++++++ 4 files changed, 43 insertions(+), 31 deletions(-) diff --git a/src/gt4py/backend/concepts.py b/src/gt4py/backend/concepts.py index cbefc0e04d..6c88ea81f9 100644 --- a/src/gt4py/backend/concepts.py +++ b/src/gt4py/backend/concepts.py @@ -276,7 +276,7 @@ class BaseGenerator(abc.ABC): from numpy import dtype {{ imports }} -from gt4py.stencil_object import StencilObject, AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo +from gt4py.stencil_object import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo, StencilObject {{ module_members }} diff --git a/src/gt4py/definitions.py b/src/gt4py/definitions.py index a1fe0023bb..7668477f1d 100644 --- a/src/gt4py/definitions.py +++ b/src/gt4py/definitions.py @@ -724,6 +724,8 @@ def __init__(self, message): def normalize_domain(domain): + if domain is not None: + domain = tuple(domain) if not isinstance(domain, Shape): if not Shape.is_valid(domain): raise ValueError("Invalid 'domain' value ({})".format(domain)) @@ -734,6 +736,7 @@ def normalize_domain(domain): def normalize_origin(origin): if origin is not None: + origin = tuple(origin) if isinstance(origin, numbers.Integral): origin = Shape.from_k(int(origin)) elif isinstance(origin, collections.abc.Sequence) and Index.is_valid(origin): @@ -745,7 +748,7 @@ def normalize_origin(origin): def normalize_origin_mapping(origin_mapping): - origin_mapping = origin_mapping or {} + origin_mapping = origin_mapping if origin_mapping is not None else {} if isinstance(origin_mapping, collections.abc.Mapping): origin_mapping = { key: normalize_origin(value) diff --git a/src/gt4py/stencil_object.py b/src/gt4py/stencil_object.py index 20bff27c53..0d811a4773 100644 --- a/src/gt4py/stencil_object.py +++ b/src/gt4py/stencil_object.py @@ -176,24 +176,20 @@ def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): } for name, field_info in self.field_info.items(): if field_info is not None and field_args[name] is None: - raise ValueError("Field '{field_name}' is None.".format(field_name=name)) + raise ValueError(f"Field '{name}' is None.") for name, parameter_info in self.parameter_info.items(): if parameter_info is not None and parameter_args[name] is None: - raise ValueError( - "Parameter '{parameter_name}' is None.".format(parameter_name=name) - ) + raise ValueError(f"Parameter '{name}' is None.") # assert compatibility of fields with stencil for name, field in used_arg_fields.items(): if not gt_backend.from_name(self.backend).storage_info["is_compatible_layout"](field): raise ValueError( - "The layout of the field {} is not compatible with the backend.".format(name) + f"The layout of the field {name} is not compatible with the backend." ) if not gt_backend.from_name(self.backend).storage_info["is_compatible_type"](field): raise ValueError( - "Field '{field}' has type '{type}', which is not compatible with the '{backend}' backend.".format( - field=name, type=type(field), backend=self.backend - ) + f"Field '{name}' has type '{type(field)}', which is not compatible with the '{self.backend}' backend." ) elif isinstance(field, np.ndarray): warnings.warn( @@ -202,34 +198,26 @@ def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): ) if not field.dtype == self.field_info[name].dtype: raise TypeError( - "The dtype of field '{field}' is '{is_dtype}' instead of '{should_dtype}'".format( - field=name, is_dtype=field.dtype, should_dtype=self.field_info[name].dtype - ) + f"The dtype of field '{name}' is '{field.dtype}' instead of '{self.field_info[name].dtype}'" ) # ToDo: check if mask is correct: need mask info in stencil object. if not field.is_stencil_view: raise ValueError( - "An incompatible view was passed for field " + name + " to the stencil. " + f"An incompatible view was passed for field {name} to the stencil. " ) for name_other, field_other in used_arg_fields.items(): if field_other.mask == field.mask: if not field_other.shape == field.shape: raise ValueError( - "The fields {} and {} have the same mask but different shapes.".format( - name, name_other - ) + f"The fields {name} and {name_other} have the same mask but different shapes." ) # assert compatibility of parameters with stencil for name, parameter in used_arg_params.items(): if not type(parameter) == self.parameter_info[name].dtype: raise TypeError( - "The type of parameter '{field}' is '{is_dtype}' instead of '{should_dtype}'".format( - field=name, - is_dtype=type(parameter), - should_dtype=self.parameter_info[name].dtype, - ) + f"The type of parameter '{name}' is '{type(parameter)}' instead of '{self.parameter_info[name].dtype}'" ) assert isinstance(field_args, dict) and isinstance(parameter_args, dict) @@ -258,23 +246,21 @@ def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): else: domain = normalize_domain(domain) if len(domain) != self.domain_info.ndims: - raise ValueError("Invalid 'domain' value ({})".format(domain)) + raise ValueError(f"Invalid 'domain' value '{domain}'") # check domain+halo vs field size if not domain > Shape.zeros(self.domain_info.ndims): - raise ValueError("Compute domain contains zero sizes ({})".format(domain)) + raise ValueError(f"Compute domain contains zero sizes '{domain}')") if not domain <= max_domain: raise ValueError( - "Compute domain too large (provided: {}, maximum: {})".format(domain, max_domain) + f"Compute domain too large (provided: {domain}, maximum: {max_domain})" ) for name, field in used_arg_fields.items(): min_origin = self.field_info[name].boundary.lower_indices if origin[name] < min_origin: raise ValueError( - "Origin for field {} too small. Must be at least {}, is {}".format( - name, min_origin, origin[name] - ) + f"Origin for field {name} too small. Must be at least {min_origin}, is {origin[name]}" ) min_shape = tuple( o + d + h @@ -284,9 +270,7 @@ def _call_run(self, field_args, parameter_args, domain, origin, exec_info=None): ) if min_shape > field.shape: raise ValueError( - "Shape of field {} is {} but must be at least {} for given domain and origin.".format( - name, field.shape, min_shape - ) + f"Shape of field {name} is {field.shape} but must be at least {min_shape} for given domain and origin." ) self.run( diff --git a/tests/test_unittest/test_call_interface.py b/tests/test_unittest/test_call_interface.py index 15b2dcc78f..7741cc4e4b 100644 --- a/tests/test_unittest/test_call_interface.py +++ b/tests/test_unittest/test_call_interface.py @@ -364,3 +364,28 @@ def test_np_int_types(): origin=(np.int8(2), np.int16(2), np.int32(0)), domain=(np.int64(20), int(20), 10), ) + + +def test_np_array_int_types(): + backend = "numpy" + stencil = gtscript.stencil(definition=avg_stencil, backend=backend) + + # test numpy int types are accepted + in_field = gt_storage.ones( + backend=backend, + shape=np.asarray((23, 23, 10), dtype=np.int64), + default_origin=np.asarray((1, 1, 0), dtype=np.int64), + dtype=np.float64, + ) + out_field = gt_storage.zeros( + backend=backend, + shape=np.asarray((23, 23, 10), dtype=np.int64), + default_origin=np.asarray((1, 1, 0), dtype=np.int64), + dtype=np.float64, + ) + stencil( + in_field=in_field, + out_field=out_field, + origin=np.asarray((2, 2, 0), dtype=np.int64), + domain=np.asarray((20, 20, 10), dtype=np.int64), + )