# A notebook to test the physics implementation and infrastructure

## First, set up the parameters and our MPI environment:

In [None]:
nx = 20
ny = 20
nz = 10
nhalo = 3
backend = "numpy"

import ipyparallel as ipp

layout = (1, 1)
ntiles = 6
# spinup cluster of MPI-workers
num_ranks = ntiles * layout[0] * layout[1]

cluster = ipp.Cluster(engines="mpi", n=num_ranks).start_and_connect_sync()

# broadcast configuration to all workers
ar = cluster[:].push(
    {
        "ntiles": ntiles,
        "nx": nx,
        "ny": ny,
        "nz": nz,
        "nhalo": nhalo,
        "layout": layout,
        "backend": backend,
    }
)

# start executing cells on the workers in parallel from here on
%autopx

In [None]:
from mpi4py import MPI

# Each worker looks up its own rank in the global communicator and reports it.
mpi_comm = MPI.COMM_WORLD
mpi_rank = mpi_comm.rank
greeting = f"Hello from rank {mpi_rank}"
print(greeting)

## Next, set up the NDSL structures we'll use:

In [None]:
import gt4py.cartesian.gtscript as gtscript
from gt4py.cartesian.gtscript import PARALLEL, computation, interval

from ndsl.dsl.typing import Float, FloatField, FloatFieldIJ
from ndsl import (
    CompilationConfig,
    CubedSphereCommunicator,
    CubedSpherePartitioner,
    GridIndexing,
    Quantity,
    QuantityFactory,
    StencilConfig,
    StencilFactory,
    SubtileGridSizer,
    TilePartitioner,
    WrappedHaloUpdater,
)
from ndsl.constants import X_DIM, Y_DIM, Z_DIM 
from ndsl.typing import Communicator 

In [None]:
# --- Domain decomposition -------------------------------------------------
# A per-tile partitioner for `layout`, lifted to the cubed sphere, and the
# MPI communicator wrapped with that decomposition.
tile_partitioner = TilePartitioner(layout)
partitioner = CubedSpherePartitioner(tile_partitioner)
cs_communicator = CubedSphereCommunicator(mpi_comm, partitioner)

# --- Array sizing and allocation ------------------------------------------
# Derive this rank's array extents from the tile-level parameters.
sizer = SubtileGridSizer.from_tile_params(
    nx_tile=nx,
    ny_tile=ny,
    nz=nz,
    n_halo=nhalo,
    extra_dim_lengths={},
    layout=layout,
    tile_partitioner=partitioner.tile,
    tile_rank=cs_communicator.tile.rank,
)

# useful for easily allocating distributed data storages (fields)
quantity_factory = QuantityFactory.from_backend(
    sizer=sizer,
    backend=backend,
)

# --- Stencil machinery ----------------------------------------------------
compilation_config = CompilationConfig(
    backend=backend,
    communicator=cs_communicator,
)
stencil_config = StencilConfig(
    compare_to_numpy=False,
    compilation_config=compilation_config,
)
grid_indexing = GridIndexing.from_sizer_and_communicator(
    sizer=sizer,
    comm=cs_communicator,
)
stencil_factory = StencilFactory(
    config=stencil_config,
    grid_indexing=grid_indexing,
)

## Set up the fields and the stencil:

In [None]:
from pySHiELD.stencils.physics import forward_euler

def euler_stencil(q: FloatField, qt: FloatField, dt: Float):
    """
    Overwrite ``q`` at every point with ``forward_euler(q, qt, dt)``.

    Args:
        q: field updated in place
        qt: second field passed to ``forward_euler`` (presumably a tendency
            of ``q`` -- confirm against pySHiELD)
        dt: scalar step size passed to ``forward_euler``
    """
    with computation(PARALLEL):
        with interval(...):
            q = forward_euler(q, qt, dt)

In [None]:
# Allocate the fields with NDSL's configured `Float` precision so they always
# match the `FloatField` annotations of `euler_stencil`.  A hard-coded
# dtype="float" is always float64 and diverges from the stencil signature
# whenever NDSL runs at a different float precision.
dims_3d = (X_DIM, Y_DIM, Z_DIM)
qq = quantity_factory.ones(dims=dims_3d, units="none", dtype=Float)
qt = quantity_factory.ones(dims=dims_3d, units="none", dtype=Float)
# Scale qt to 0.2 through its `view` (presumably NDSL's compute-domain view,
# leaving any halo points at 1.0 -- confirm).
qt.view[:] *= 0.2
dt = 0.5  # step size passed to the stencil

# Build the stencil over the compute origin/domain reported by grid_indexing.
test_stencil = stencil_factory.from_origin_domain(
    func=euler_stencil,
    origin=grid_indexing.origin_compute(),
    domain=grid_indexing.domain_compute(),
)

## And run it:

In [None]:
# Every rank applies the stencil to its own portion of qq; rank 0 then prints
# the k=0 slice of its `.data` array (presumably the full storage, halo
# included -- confirm).
test_stencil(qq, qt, dt)

is_root_rank = mpi_rank == 0
if is_root_rank:
    level_zero = qq.data[:, :, 0]
    print(level_zero)