# Grid generation and stencil factory setup

In [None]:
import ipyparallel as ipp

# Start six engines with the "mpi" engine launcher and block until a
# connection to them is available.
# NOTE(review): per the ipyparallel docs, start_and_connect_sync() returns the
# connected Client rather than the Cluster object, so `cluster` presumably
# holds a client — confirm before treating it as a Cluster.
# NOTE(review): n=6 presumably matches the six cubed-sphere tiles for
# layout = (1, 1); it likely has to change together with `layout` — TODO confirm.
cluster = ipp.Cluster(engines="mpi", n=6).start_and_connect_sync()
# Toggle %autopx: per the ipyparallel docs, cells executed after this line are
# sent to the engines instead of running in the local kernel.
%autopx

In [2]:
%%capture
# Number of grid points per tile in the x- and y-directions
# (passed to the grid sizer as nx_tile / ny_tile).
nx = 100
ny = 100
# Number of grid points in the z-direction (passed to the grid sizer as nz;
# presumably the number of vertical levels — TODO confirm).
nz = 79

# (int): number of halo points in the x- and y-directions.
nhalo = 3

# (int, int): how many parts each tile is split into.
layout = (1, 1)

# (str): backend name handed to QuantityFactory, DaceConfig and
# CompilationConfig further down in the notebook.
backend = "numpy"

In [4]:
from mpi4py import MPI

# World communicator used by the rest of the notebook.
mpi_comm = MPI.COMM_WORLD
# Rank of this process within the communicator, and the communicator's size.
mpi_rank = mpi_comm.Get_rank()
mpi_size = mpi_comm.Get_size()

In [5]:
from pace.util import CubedSpherePartitioner, TilePartitioner

# Build a tile partitioner from `layout` and wrap it in the cubed-sphere
# partitioner; the tile-level partitioner is read back later via
# `partitioner.tile`.
partitioner = CubedSpherePartitioner(TilePartitioner(layout))

In [6]:
from pace.util import CubedSphereCommunicator

# Cubed-sphere communicator built from the MPI world communicator and the
# cubed-sphere partitioner.
communicator = CubedSphereCommunicator(mpi_comm, partitioner)

In [7]:
from pace.util import SubtileGridSizer

# Grid sizer constructed from the per-tile resolution, halo width and layout
# configured earlier, together with the tile partitioner and this process's
# rank within its tile communicator.
sizer = SubtileGridSizer.from_tile_params(
    nx_tile=nx,
    ny_tile=ny,
    nz=nz,
    n_halo=nhalo,
    extra_dim_lengths={},  # no extra dimension lengths are supplied
    layout=layout,
    tile_partitioner=partitioner.tile,
    tile_rank=communicator.tile.rank,
)

In [8]:
from pace.util import QuantityFactory

# Quantity factory built from the grid sizer and the configured backend.
quantity_factory = QuantityFactory.from_backend(sizer=sizer, backend=backend)



In [9]:
from pace.util.grid import MetricTerms

# Metric terms are constructed from the quantity factory and the cubed-sphere
# communicator; the damping coefficients and grid data below derive from them.
metric_terms = MetricTerms(
    quantity_factory=quantity_factory,
    communicator=communicator,
)

In [10]:
from pace.util.grid import DampingCoefficients

# Damping coefficients derived from the metric terms.
damping_coefficients = DampingCoefficients.new_from_metric_terms(metric_terms)

  np.sum(p * q, axis=-1)


  np.sum(p * q, axis=-1)


In [None]:
from pace.util.grid import GridData

# Grid data derived from the same metric terms.
grid_data = GridData.new_from_metric_terms(metric_terms)

In [14]:
from pace.dsl.dace.dace_config import DaceConfig, DaCeOrchestration

# DaCe configuration for the chosen backend, with orchestration set to
# DaCeOrchestration.Python.
dace_config = DaceConfig(
    communicator=communicator,
    backend=backend,
    orchestration=DaCeOrchestration.Python,
)

In [16]:
from pace.dsl.stencil_config import CompilationConfig, RunMode

# Stencil compilation settings: rebuild and argument validation are enabled;
# source formatting, device sync and minimal caching are disabled; the run
# mode is RunMode.BuildAndRun.
compilation_config = CompilationConfig(
    backend=backend,
    rebuild=True,
    validate_args=True,
    format_source=False,
    device_sync=False,
    run_mode=RunMode.BuildAndRun,
    use_minimal_caching=False,
    communicator=communicator,
)

In [20]:
from pace.dsl.stencil import StencilConfig

# Bundle the compilation and DaCe configurations into one stencil
# configuration; comparison against numpy is switched off.
stencil_config = StencilConfig(
    compare_to_numpy=False,
    compilation_config=compilation_config,
    dace_config=dace_config,
)

In [21]:
from pace.dsl.stencil import GridIndexing

# Grid indexing built from the grid sizer and the cubed-sphere communicator.
grid_indexing = GridIndexing.from_sizer_and_communicator(
    sizer=sizer,
    cube=communicator,
)

# Disabled experiment, kept for reference: force the domain to a single
# vertical level by overwriting the third entry of the domain tuple.
# domain = grid_indexing.domain
# domain_new = list(domain)
# domain_new[2] = 1
# domain_new = tuple(domain_new)
# grid_indexing.domain = domain_new

In [22]:
from pace.dsl.stencil import StencilFactory

# Stencil factory combining the stencil configuration with the grid indexing.
stencil_factory = StencilFactory(
    config=stencil_config,
    grid_indexing=grid_indexing,
)