Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[wip] feat[next]: AOT toolchain #1545

Draft
wants to merge 15 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 7 additions & 10 deletions src/gt4py/next/backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,9 @@ def foast_to_foast_closure(
inp: workflow.InputWithArgs[ffront_stages.FoastOperatorDefinition],
) -> ffront_stages.FoastClosure:
from_fieldop = inp.kwargs.pop("from_fieldop")
debug = inp.kwargs.pop("debug", inp.data.debug)
return ffront_stages.FoastClosure(
foast_op_def=inp.data,
foast_op_def=dataclasses.replace(inp.data, debug=debug),
args=inp.args,
kwargs=inp.kwargs,
closure_vars={inp.data.foast_node.id: from_fieldop},
Expand Down Expand Up @@ -70,10 +71,10 @@ class FieldopTransformWorkflow(workflow.NamedStepSequenceWithArgs):
)
)
past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
dataclasses.field(default=past_process_args.past_process_args)
dataclasses.field(default=past_process_args.PastProcessArgs(aot_off=False))
)
past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
dataclasses.field(default_factory=past_to_itir.JITPastToItirFactory)
)

foast_to_itir: workflow.Workflow[ffront_stages.FoastOperatorDefinition, itir.Expr] = (
Expand Down Expand Up @@ -115,21 +116,17 @@ class ProgramTransformWorkflow(workflow.NamedStepSequenceWithArgs):
ffront_stages.PastProgramDefinition, ffront_stages.PastClosure
] = dataclasses.field(
default=lambda inp: ffront_stages.PastClosure(
past_node=inp.data.past_node,
closure_vars=inp.data.closure_vars,
grid_type=inp.data.grid_type,
definition=dataclasses.replace(inp.data, debug=inp.kwargs.pop("debug", inp.data.debug)),
args=inp.args,
kwargs=inp.kwargs,
),
metadata={"takes_args": True},
)
past_transform_args: workflow.Workflow[ffront_stages.PastClosure, ffront_stages.PastClosure] = (
dataclasses.field(
default=past_process_args.past_process_args, metadata={"takes_args": False}
)
dataclasses.field(default=past_process_args.PastProcessArgs(aot_off=False))
)
past_to_itir: workflow.Workflow[ffront_stages.PastClosure, stages.ProgramCall] = (
dataclasses.field(default_factory=past_to_itir.PastToItirFactory)
dataclasses.field(default_factory=past_to_itir.JITPastToItirFactory)
)


Expand Down
12 changes: 7 additions & 5 deletions src/gt4py/next/ffront/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,21 +173,23 @@ def _all_closure_vars(self) -> dict[str, Any]:
@functools.cached_property
def itir(self) -> itir.FencilDefinition:
no_args_past = ffront_stages.PastClosure(
past_node=self.past_stage.past_node,
closure_vars=self.past_stage.closure_vars,
grid_type=self.definition_stage.grid_type,
definition=ffront_stages.PastProgramDefinition(
past_node=self.past_stage.past_node,
closure_vars=self.past_stage.closure_vars,
grid_type=self.definition_stage.grid_type,
),
args=[],
kwargs={},
)
if self.backend is not None and self.backend.transforms_prog is not None:
return self.backend.transforms_prog.past_to_itir(no_args_past).program
return past_to_itir.PastToItirFactory()(no_args_past).program
return past_to_itir.JITPastToItirFactory()(no_args_past).program

def __call__(self, *args, offset_provider: dict[str, Dimension], **kwargs: Any) -> None:
if self.backend is None:
warnings.warn(
UserWarning(
f"Field View Program '{self.itir.id}': Using Python execution, consider selecting a perfomance backend."
f"Field View Program '{self.definition_stage.definition.__name__}': Using Python execution, consider selecting a perfomance backend."
),
stacklevel=2,
)
Expand Down
4 changes: 1 addition & 3 deletions src/gt4py/next/ffront/foast_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def __call__(self, inp: ffront_stages.FoastClosure) -> ffront_stages.PastClosure
)

return ffront_stages.PastClosure(
past_node=past_def.past_node,
closure_vars=past_def.closure_vars,
grid_type=past_def.grid_type,
definition=past_def,
args=inp.args,
kwargs=inp.kwargs,
)
1 change: 1 addition & 0 deletions src/gt4py/next/ffront/func_to_foast.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@ def func_to_foast(
closure_vars=closure_vars,
grid_type=inp.grid_type,
attributes=inp.attributes,
debug=inp.debug,
)


Expand Down
1 change: 1 addition & 0 deletions src/gt4py/next/ffront/func_to_past.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ def func_to_past(inp: ffront_stages.ProgramDefinition) -> ffront_stages.PastProg
past_node=ProgramParser.apply(source_def, closure_vars, annotations),
closure_vars=closure_vars,
grid_type=inp.grid_type,
debug=inp.debug,
)


Expand Down
33 changes: 17 additions & 16 deletions src/gt4py/next/ffront/past_process_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#
# SPDX-License-Identifier: GPL-3.0-or-later

import dataclasses
from typing import Any, Iterator, Optional

from gt4py.next import common, errors
Expand All @@ -20,25 +21,25 @@
stages as ffront_stages,
type_specifications as ts_ffront,
)
from gt4py.next.otf import workflow
from gt4py.next.type_system import type_info, type_specifications as ts, type_translation


@workflow.make_step
def past_process_args(inp: ffront_stages.PastClosure) -> ffront_stages.PastClosure:
extra_kwarg_names = ["offset_provider", "column_axis"]
extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names}
kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names}
rewritten_args, size_args, kwargs = _process_args(
past_node=inp.past_node, args=list(inp.args), kwargs=kwargs
)
return ffront_stages.PastClosure(
past_node=inp.past_node,
closure_vars=inp.closure_vars,
grid_type=inp.grid_type,
args=tuple([*rewritten_args, *size_args]),
kwargs=kwargs | extra_kwargs,
)
@dataclasses.dataclass(frozen=True)
class PastProcessArgs:
aot_off: bool = False

def __call__(self, inp: ffront_stages.PastClosure) -> ffront_stages.PastClosure:
extra_kwarg_names = ["offset_provider", "column_axis"]
extra_kwargs = {k: v for k, v in inp.kwargs.items() if k in extra_kwarg_names}
kwargs = {k: v for k, v in inp.kwargs.items() if k not in extra_kwarg_names}
rewritten_args, size_args, kwargs = _process_args(
past_node=inp.definition.past_node, args=list(inp.args), kwargs=kwargs
)
return ffront_stages.PastClosure(
definition=inp.definition,
args=(*rewritten_args, *(size_args if self.aot_off else tuple())),
kwargs=kwargs | extra_kwargs,
)


def _validate_args(past_node: past.Program, args: list, kwargs: dict[str, Any]) -> None:
Expand Down
50 changes: 42 additions & 8 deletions src/gt4py/next/ffront/past_to_itir.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,13 +39,15 @@

@dataclasses.dataclass(frozen=True)
class PastToItir(workflow.ChainableWorkflowMixin):
def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
all_closure_vars = transform_utils._get_closure_vars_recursively(inp.closure_vars)
def __call__(self, inp: ffront_stages.AOTFieldviewProgramAst) -> stages.AOTProgram:
all_closure_vars = transform_utils._get_closure_vars_recursively(
inp.definition.closure_vars
)
offsets_and_dimensions = transform_utils._filter_closure_vars_by_type(
all_closure_vars, fbuiltins.FieldOffset, common.Dimension
)
grid_type = transform_utils._deduce_grid_type(
inp.grid_type, offsets_and_dimensions.values()
inp.definition.grid_type, offsets_and_dimensions.values()
)

gt_callables = transform_utils._filter_closure_vars_by_type(
Expand All @@ -54,14 +56,17 @@ def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
lowered_funcs = [gt_callable.__gt_itir__() for gt_callable in gt_callables]

itir_program = ProgramLowering.apply(
inp.past_node, function_definitions=lowered_funcs, grid_type=grid_type
inp.definition.past_node,
function_definitions=lowered_funcs,
grid_type=grid_type,
)

if config.DEBUG or "debug" in inp.kwargs:
if config.DEBUG or inp.definition.debug:
devtools.debug(itir_program)

return stages.ProgramCall(
itir_program, inp.args, inp.kwargs | {"column_axis": _column_axis(all_closure_vars)}
return stages.AOTProgram(
program=itir_program,
argspec=dataclasses.replace(inp.argspec, column_axis=_column_axis(all_closure_vars)),
)


Expand All @@ -70,6 +75,29 @@ class Meta:
model = PastToItir


@dataclasses.dataclass(frozen=True)
class JITPastToItir(workflow.ChainableWorkflowMixin):
inner: PastToItir = dataclasses.field(default_factory=PastToItir)

def __call__(self, inp: ffront_stages.PastClosure) -> stages.ProgramCall:
aot_program = self.inner(
ffront_stages.AOTFieldviewProgramAst(
definition=inp.definition,
argspec=stages.CompileArgSpec.from_concrete(*inp.args, **inp.kwargs),
)
)
return stages.ProgramCall(
program=aot_program.program,
args=inp.args,
kwargs=inp.kwargs | {"column_axis": aot_program.argspec.column_axis},
)


class JITPastToItirFactory(factory.Factory):
class Meta:
model = JITPastToItir


def _column_axis(all_closure_vars: dict[str, Any]) -> Optional[common.Dimension]:
# construct mapping from column axis to scan operators defined on
# that dimension. only one column axis is allowed, but we can use
Expand Down Expand Up @@ -193,15 +221,21 @@ def visit_Program(

params = self.visit(node.params)

implicit_domain = False
if any("domain" not in body_entry.kwargs for body_entry in node.body):
params = params + self._gen_size_params_from_program(node)
implicit_domain = True

closures: list[itir.StencilClosure] = []
for stmt in node.body:
closures.append(self._visit_stencil_call(stmt, **kwargs))

return itir.FencilDefinition(
id=node.id, function_definitions=function_definitions, params=params, closures=closures
id=node.id,
function_definitions=function_definitions,
params=params,
closures=closures,
implicit_domain=implicit_domain,
)

def _visit_stencil_call(self, node: past.Call, **kwargs: Any) -> itir.StencilClosure:
Expand Down
15 changes: 12 additions & 3 deletions src/gt4py/next/ffront/stages.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from gt4py.eve import extended_typing as xtyping
from gt4py.next import common
from gt4py.next.ffront import field_operator_ast as foast, program_ast as past, source_utils
from gt4py.next.otf import stages as otf_stages
from gt4py.next.type_system import type_specifications as ts


Expand All @@ -43,6 +44,7 @@ class FieldOperatorDefinition(Generic[OperatorNodeT]):
grid_type: Optional[common.GridType] = None
node_class: type[OperatorNodeT] = dataclasses.field(default=foast.FieldOperator) # type: ignore[assignment] # TODO(ricoh): understand why mypy complains
attributes: dict[str, Any] = dataclasses.field(default_factory=dict)
debug: bool = False


@dataclasses.dataclass(frozen=True)
Expand All @@ -51,6 +53,7 @@ class FoastOperatorDefinition(Generic[OperatorNodeT]):
closure_vars: dict[str, Any]
grid_type: Optional[common.GridType] = None
attributes: dict[str, Any] = dataclasses.field(default_factory=dict)
debug: bool = False


@dataclasses.dataclass(frozen=True)
Expand All @@ -73,24 +76,30 @@ class FoastClosure(Generic[OperatorNodeT]):
class ProgramDefinition:
definition: types.FunctionType
grid_type: Optional[common.GridType] = None
debug: bool = False


@dataclasses.dataclass(frozen=True)
class PastProgramDefinition:
past_node: past.Program
closure_vars: dict[str, Any]
grid_type: Optional[common.GridType] = None
debug: bool = False


@dataclasses.dataclass(frozen=True)
class PastClosure:
closure_vars: dict[str, Any]
past_node: past.Program
grid_type: Optional[common.GridType]
definition: PastProgramDefinition
args: tuple[Any, ...]
kwargs: dict[str, Any]


@dataclasses.dataclass(frozen=True)
class AOTFieldviewProgramAst:
definition: PastProgramDefinition
argspec: otf_stages.CompileArgSpec


def fingerprint_stage(obj: Any, algorithm: Optional[str | xtyping.HashlibAlgorithm] = None) -> str:
hasher: xtyping.HashlibAlgorithm
if not algorithm:
Expand Down
7 changes: 7 additions & 0 deletions src/gt4py/next/iterator/embedded.py
Original file line number Diff line number Diff line change
Expand Up @@ -1566,6 +1566,13 @@ def fendef_embedded(fun: Callable[..., None], *args: Any, **kwargs: Any):
common.UnitRange(0, 0), # empty: indicates column operation, will update later
)

import inspect

if len(args) < len(inspect.getfullargspec(fun).args):
from gt4py.next.program_processors.runners import gtfn

args = (*args, *gtfn.iter_size_args(args))

with embedded_context.new_context(**context_vars) as ctx:
ctx.run(fun, *args)

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,6 +216,7 @@ class FencilDefinition(Node, ValidatedSymbolTableTrait):
function_definitions: List[FunctionDefinition]
params: List[Sym]
closures: List[StencilClosure]
implicit_domain: bool = False

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in BUILTINS]

Expand All @@ -241,6 +242,7 @@ class Program(Node, ValidatedSymbolTableTrait):
params: List[Sym]
declarations: List[Temporary]
body: List[Stmt]
implicit_domain: bool = False

_NODE_SYMBOLS_: ClassVar[List[Sym]] = [Sym(id=name) for name in GTIR_BUILTINS]

Expand Down
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/transforms/fencil_to_program.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ def visit_FencilDefinition(self, node: itir.FencilDefinition) -> itir.Program:
params=node.params,
declarations=[],
body=self.visit(node.closures),
implicit_domain=node.implicit_domain,
)

def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -> itir.Program:
Expand All @@ -43,4 +44,5 @@ def visit_FencilWithTemporaries(self, node: global_tmps.FencilWithTemporaries) -
params=node.params,
declarations=node.tmps,
body=self.visit(node.fencil.closures),
implicit_domain=node.fencil.implicit_domain,
)
2 changes: 2 additions & 0 deletions src/gt4py/next/iterator/transforms/global_tmps.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,7 @@ def always_extract_heuristics(_):
+ [ir.Sym(id=AUTO_DOMAIN.fun.id)], # type: ignore[attr-defined] # value is a global constant
closures=list(reversed(closures)),
location=node.location,
implicit_domain=node.implicit_domain,
),
params=node.params,
tmps=[ir.Temporary(id=tmp.id) for tmp in tmps],
Expand Down Expand Up @@ -564,6 +565,7 @@ def update_domains(
params=node.fencil.params[:-1], # remove `_gtmp_auto_domain` param again
closures=list(reversed(closures)),
location=node.fencil.location,
implicit_domain=node.fencil.implicit_domain,
),
params=node.params,
tmps=node.tmps,
Expand Down
3 changes: 3 additions & 0 deletions src/gt4py/next/otf/compilation/build_systems/compiledb.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def __call__(
cmake_flags=self.cmake_extra_flags or [],
language=source.program_source.language,
language_settings=source.program_source.language_settings,
implicit_domain=source.program_source.implicit_domain,
)

if self.renew_compiledb or not (
Expand Down Expand Up @@ -219,6 +220,7 @@ def _cc_prototype_program_source(
cmake_flags: list[str],
language: type[SrcL],
language_settings: languages.LanguageWithHeaderFilesSettings,
implicit_domain: bool,
) -> stages.ProgramSource:
name = _cc_prototype_program_name(deps, build_type.value, cmake_flags)
return stages.ProgramSource(
Expand All @@ -227,6 +229,7 @@ def _cc_prototype_program_source(
library_deps=deps,
language=language,
language_settings=language_settings,
implicit_domain=implicit_domain,
)


Expand Down
Loading
Loading