# **NDSL Orchestration Basics**

### **Introduction**

When writing code with NDSL, some algorithms or code patterns will not match the stencil paradigm, and shoehorning them into it increases development difficulty. For these cases, NDSL provides a capability called orchestration, which lets developers use regular Python for non-stencil algorithms alongside stencil-based code via [DaCe](https://github.com/spcl/dace). DaCe will also attempt to find optimizations before it outputs C++ code.

In this example, we will explore how to orchestrate a codebase using NDSL.

### **Orchestration Example**

We'll step through a simple example that orchestrates a codebase containing both stencils and regular Python code. First, we'll import the necessary packages.

In [1]:
import numpy as np
from gt4py.cartesian.gtscript import (
    PARALLEL,
    computation,
    interval,
)
from ndsl import (
    StencilFactory,
    DaceConfig,
    orchestrate,
    QuantityFactory,
)
from ndsl.constants import X_DIM, Y_DIM, Z_DIM
from ndsl.dsl.typing import FloatField, Float

from orch_boilerplate import get_one_tile_factory_orchestrated

2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:Constant selected: ConstantVersions.GFS


Next, we'll define a simple stencil that sums the values around a point and applies a weight factor to that sum. Note that, unlike [previous](./01_gt4py_basics.ipynb#Copy_Stencil_example) examples, we are not using the `@stencil` decorator, because this stencil function will be passed to a `StencilFactory` call (`from_origin_domain`) instead.

In [2]:
def localsum_stencil(
    field: FloatField,  # type: ignore
    result: FloatField,  # type: ignore
    weight: Float,  # type: ignore
):
    with computation(PARALLEL), interval(...):
        result = weight * (
            field[1, 0, 0] + field[0, 1, 0] + field[-1, 0, 0] + field[0, -1, 0]
        )
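
To make the stencil's arithmetic concrete, here is a plain NumPy sketch of the same four-neighbor weighted sum. It is not NDSL or GT4Py code: the function name `localsum_numpy` and its `n_halo` argument are hypothetical, and the sketch assumes the array carries at least one halo point on each horizontal side so that the offsets stay in bounds.

```python
import numpy as np


def localsum_numpy(field: np.ndarray, weight: float, n_halo: int = 1) -> np.ndarray:
    """Weighted sum of the four horizontal neighbors of each compute-domain point.

    Hypothetical NumPy stand-in for `localsum_stencil`; points in the halo
    of the returned array are left at zero.
    """
    result = np.zeros_like(field)
    i0, i1 = n_halo, field.shape[0] - n_halo  # compute-domain bounds in x
    j0, j1 = n_halo, field.shape[1] - n_halo  # compute-domain bounds in y
    result[i0:i1, j0:j1, :] = weight * (
        field[i0 + 1 : i1 + 1, j0:j1, :]    # field[1, 0, 0]
        + field[i0:i1, j0 + 1 : j1 + 1, :]  # field[0, 1, 0]
        + field[i0 - 1 : i1 - 1, j0:j1, :]  # field[-1, 0, 0]
        + field[i0:i1, j0 - 1 : j1 - 1, :]  # field[0, -1, 0]
    )
    return result


# Small usage example: a field of ones with a one-point halo.
field = np.ones((5, 5, 2))
out = localsum_numpy(field, weight=0.5)
```

Each shifted slice plays the role of one relative offset in the stencil, which is why the stencil needs valid data one point beyond the compute domain in both horizontal directions.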

Next, we'll define a class that enables orchestration and combines stencils with regular Python code. Orchestration is set up by the `orchestrate` call in `__init__`. The `__call__` method then mixes stencil calls with regular Python code.

In [3]:
class LocalSum:
    def __init__(
        self, stencil_factory: StencilFactory, quantity_factory: QuantityFactory
    ) -> None:
        orchestrate(
            obj=self,
            config=stencil_factory.config.dace_config
            or DaceConfig(None, stencil_factory.backend),
        )
        grid_indexing = stencil_factory.grid_indexing
        self._local_sum = stencil_factory.from_origin_domain(
            localsum_stencil,  # <-- gt4py stencil function wrapped into NDSL
            origin=grid_indexing.origin_compute(),
            domain=grid_indexing.domain_compute(),
        )
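        # NOTE: `dtype` is not defined in this class; it resolves to the
        # module-level `dtype` set in the driver, so the driver settings
        # must be defined before `LocalSum` is instantiated.
        self._tmp_field = quantity_factory.zeros(
            [X_DIM, Y_DIM, Z_DIM], "n/a", dtype=dtype
        )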
        self._n_halo = quantity_factory.sizer.n_halo

    def __call__(self, in_field: FloatField, out_result: FloatField) -> None:
        self._local_sum(in_field, out_result, 2.0) # GT4Py Stencil
        tmp_field = out_result[:, :, :] + 2        # Regular Python code
        self._local_sum(tmp_field, out_result, 2.0) # GT4Py Stencil

Next, we'll create a simple driver that defines the domain and halo size, specifies the backend (`dace:cpu`, in order to use DaCe), and uses the boilerplate code to create the stencil and quantity factory objects. These objects define the computational domain used in this example. After defining quantities (`in_field` and `out_field`) to hold the values and creating a `local_sum` object for our combined stencil/Python calculation, we call `local_sum` to perform the computation. In the output, we can see DaCe orchestrating the code.

In [4]:
# ----- Driver ----- #

if __name__ == "__main__":
    # Settings
    backend = "dace:cpu"
    dtype = np.float64
    origin = (0, 0, 0)
    rebuild = True
    tile_size = (3, 3, 3)

    # Setup
    stencil_factory, qty_factory = get_one_tile_factory_orchestrated(
        nx=tile_size[0],
        ny=tile_size[1],
        nz=tile_size[2],
        nhalo=2,
        backend=backend,
    )
    local_sum = LocalSum(stencil_factory, qty_factory)

    in_field = qty_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a", dtype=dtype)
    in_field.view[:] = 2.0
    out_field = qty_factory.zeros([X_DIM, Y_DIM, Z_DIM], "n/a", dtype=dtype)

    # Run
    local_sum(in_field, out_field)

2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:[DaCeOrchestration.BuildAndRun] Rank 0 reading/writing cache .gt_cache_FV3_A
2024-05-29 13:16:33|INFO|rank 0|ndsl.logging:Building DaCe orchestration
Inlined 2 SDFGs.
Fused 4 states.
Inferred 2 optional arrays.
SDFG 0: Eliminated 1 arrays: {'out_result_0'}.
Fused 2 states.
Inferred 4 optional arrays.
Inlined 2 SDFGs.
2024-05-29 13:16:34|INFO|rank 0|ndsl.logging:[DaCeOrchestration.BuildAndRun] LocalSum___call__:
StorageType.Default:
  Alloc ref 0.01 mb
  Alloc unref 0.00 mb
  Pooled 0.00 mb
  Top lvl alloc: 0.01mb

[DaCe Config] Rank 0 loading SDFG /home/ckung/Documents/Code/SMT-Nebulae-Tutorial/tutorial/NDSL/.gt_cache_FV3_A/dacecache/LocalSum___call__
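
As a rough cross-check of what the orchestrated `__call__` computes, the stencil / Python / stencil sequence can be emulated in plain NumPy. This is only a sketch under stated assumptions, not NDSL output: it assumes the arrays have shape `(nx + 2 * nhalo, ny + 2 * nhalo, nz)` with no extra interface points, that `in_field.view[:] = 2.0` fills only the compute domain, and that the stencil writes only the compute domain of `out_result`. The helper name `apply_localsum` is hypothetical.

```python
import numpy as np

NX, NY, NZ, N_HALO = 3, 3, 3, 2  # tile size and halo width used by the driver


def apply_localsum(field, result, weight, h):
    """Write weight * (sum of four horizontal neighbors) into the compute domain only."""
    i0, i1 = h, field.shape[0] - h
    j0, j1 = h, field.shape[1] - h
    result[i0:i1, j0:j1, :] = weight * (
        field[i0 + 1 : i1 + 1, j0:j1, :]
        + field[i0:i1, j0 + 1 : j1 + 1, :]
        + field[i0 - 1 : i1 - 1, j0:j1, :]
        + field[i0:i1, j0 - 1 : j1 - 1, :]
    )


shape = (NX + 2 * N_HALO, NY + 2 * N_HALO, NZ)
in_field = np.zeros(shape)
in_field[N_HALO:-N_HALO, N_HALO:-N_HALO, :] = 2.0  # assumed effect of in_field.view[:] = 2.0
out_result = np.zeros(shape)

apply_localsum(in_field, out_result, 2.0, N_HALO)   # first stencil call
tmp_field = out_result[:, :, :] + 2                 # regular Python step (whole array, halo included)
apply_localsum(tmp_field, out_result, 2.0, N_HALO)  # second stencil call

compute = out_result[N_HALO:-N_HALO, N_HALO:-N_HALO, :]
print(compute[:, :, 0])  # one horizontal slice of the compute domain
```

Because the halo of `in_field` stays at zero in this emulation, points on the edge of the compute domain see fewer non-zero neighbors than the center point, so the result is not uniform across the tile.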
