Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/gt4py/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@
Grid,
ParameterInfo,
CartesianSpace,
StencilObject,
)
from .stencil_object import StencilObject

from . import config
from . import gtscript
Expand Down
4 changes: 2 additions & 2 deletions src/gt4py/backend/concepts.py
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ class BaseGenerator(abc.ABC):
from numpy import dtype
{{ imports }}

from gt4py import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo, StencilObject
from gt4py.stencil_object import AccessKind, Boundary, DomainInfo, FieldInfo, ParameterInfo, StencilObject

{{ module_members }}

Expand Down Expand Up @@ -339,7 +339,7 @@ def __call__(self, {{ stencil_signature }}, domain=None, origin=None, exec_info=
{{synchronization}}
{%- endfilter %}

self.call_run(
self._call_run(
field_args=field_args,
parameter_args=parameter_args,
domain=domain,
Expand Down
7 changes: 7 additions & 0 deletions src/gt4py/backend/debug_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

from gt4py import backend as gt_backend
from gt4py import ir as gt_ir
from gt4py import definitions as gt_definitions
Expand Down Expand Up @@ -207,6 +209,10 @@ def debug_is_compatible_layout(field):
return sum(field.shape) > 0


def debug_is_compatible_type(field):
return isinstance(field, np.ndarray)


@gt_backend.register
class DebugBackend(gt_backend.BaseBackend):
name = "debug"
Expand All @@ -216,6 +222,7 @@ class DebugBackend(gt_backend.BaseBackend):
"device": "cpu",
"layout_map": debug_layout,
"is_compatible_layout": debug_is_compatible_layout,
"is_compatible_type": debug_is_compatible_type,
}

GENERATOR_CLASS = DebugGenerator
6 changes: 6 additions & 0 deletions src/gt4py/backend/gt_cpu_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,10 @@ def x86_is_compatible_layout(field):
return True


def gtcpu_is_compatible_type(field):
return isinstance(field, np.ndarray)


def make_mc_layout_map(mask):
ctr = reversed(range(sum(mask)))
if len(mask) < 3:
Expand Down Expand Up @@ -136,6 +140,7 @@ class GTX86Backend(GTCPUBackend):
"device": "cpu",
"layout_map": make_x86_layout_map,
"is_compatible_layout": x86_is_compatible_layout,
"is_compatible_type": gtcpu_is_compatible_type,
}

_CPU_ARCHITECTURE = "x86"
Expand All @@ -151,6 +156,7 @@ class GTMCBackend(GTCPUBackend):
"device": "cpu",
"layout_map": make_mc_layout_map,
"is_compatible_layout": mc_is_compatible_layout,
"is_compatible_type": gtcpu_is_compatible_type,
}

_CPU_ARCHITECTURE = "mc"
6 changes: 6 additions & 0 deletions src/gt4py/backend/gt_cuda_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import numpy as np

from gt4py import backend as gt_backend
from gt4py import storage as gt_storage
from . import pyext_builder


Expand Down Expand Up @@ -47,6 +48,10 @@ def cuda_is_compatible_layout(field):
return True


def cuda_is_compatible_type(field):
return isinstance(field, gt_storage.storage.GPUStorage)


@gt_backend.register
class GTCUDABackend(gt_backend.BaseGTBackend):
GENERATOR_CLASS = PythonGTCUDAGenerator
Expand All @@ -57,6 +62,7 @@ class GTCUDABackend(gt_backend.BaseGTBackend):
"device": "gpu",
"layout_map": cuda_layout,
"is_compatible_layout": cuda_is_compatible_layout,
"is_compatible_type": cuda_is_compatible_type,
}

@classmethod
Expand Down
7 changes: 7 additions & 0 deletions src/gt4py/backend/numpy_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import numpy as np

from gt4py import backend as gt_backend
from gt4py import ir as gt_ir
from gt4py import definitions as gt_definitions
Expand Down Expand Up @@ -296,6 +298,10 @@ def numpy_is_compatible_layout(field):
return sum(field.shape) > 0


def numpy_is_compatible_type(field):
return isinstance(field, np.ndarray)


@gt_backend.register
class NumPyBackend(gt_backend.BaseBackend):
name = "numpy"
Expand All @@ -305,6 +311,7 @@ class NumPyBackend(gt_backend.BaseBackend):
"device": "cpu",
"layout_map": numpy_layout,
"is_compatible_layout": numpy_is_compatible_layout,
"is_compatible_type": numpy_is_compatible_type,
}

GENERATOR_CLASS = NumPyGenerator
6 changes: 0 additions & 6 deletions src/gt4py/backend/templates/computation.src.in
Original file line number Diff line number Diff line change
Expand Up @@ -129,12 +129,6 @@ data_store_t<T, Id> make_data_store(const BufferInfo& bi,
" != " + std::to_string(3) + "]");
}

for (int i = 0; i < 2 /*3*/; ++i) {
if (2*origin[i] + compute_domain_shape[i] > bi.shape[i])
throw std::runtime_error(
"Given shape and origin exceed buffer dimension");
}

// ptr, dims and strides are "outer domain" (i.e., compute domain + halo
// region). The halo region is only defined through `make_grid` (and
// currently, in the storage info)
Expand Down
Loading