# MPI-aware tasks
The aim is to use tasks that act like MPI-aware libraries, so that we can access the tasks' memory space directly (and avoid going through the file system). In this mode the tasks execute in an `MPI.COMM_WORLD` that includes all the MPI processes of the job allocation.

In [None]:
import sys
from jobqueue_features.clusters import CustomSLURMCluster
from jobqueue_features.decorators import on_cluster, mpi_task
from jobqueue_features.mpi_wrapper import SRUN

When using this mode, we need to tell the cluster which MPI launcher to use for the Dask worker task, via the `mpi_launcher` keyword argument:

In [None]:
# SLURM-backed cluster running in MPI mode; the Dask worker task is started
# through the SRUN launcher so that the tasks share one MPI.COMM_WORLD.
custom_cluster = CustomSLURMCluster(
    name="mpiCluster",
    walltime="00:03:00",
    queue='devel',
    nodes=2,
    mpi_mode=True,
    mpi_launcher=SRUN,
)

In [None]:
@mpi_task(cluster_id=custom_cluster.name)
def task1(task_name):
    """Report how many MPI processes run this task and on which nodes.

    Runs on every rank of ``MPI.COMM_WORLD``. The processor (host) name of
    each rank is gathered onto rank 0.

    :param task_name: label included in the report string.
    :return: the report string. On rank 0 it lists the sorted, de-duplicated
        node names; on every other rank the node list is empty.
    """
    from mpi4py import MPI

    comm = MPI.COMM_WORLD
    size = comm.Get_size()
    name = MPI.Get_processor_name()
    # gather() returns the list of all names on the root rank and None elsewhere.
    gathered = comm.gather(name, root=0)
    # Sorting the unique names keeps the report reproducible between runs (set
    # iteration order of strings depends on hash randomisation). It also means
    # every rank formats a list, rather than a set on root and a list elsewhere.
    all_nodes = sorted(set(gathered)) if gathered else []
    # Presumably the caller only receives rank 0's return value — confirm
    # against the mpi_task decorator.
    return_string = "Running %d tasks of type %s on nodes %s." % (
        size,
        task_name,
        all_nodes,
    )
    # Flushing is required so the line appears in the job log files promptly.
    print(return_string, flush=True)
    return return_string

In [None]:
@mpi_task(cluster_id=custom_cluster.name)
def task2(name, task_name="default"):
    """Greet *name* from the calling MPI rank.

    :param name: who to greet.
    :param task_name: label for the kind of task; ``"default"`` if omitted.
    :return: greeting that includes this process's rank in ``MPI.COMM_WORLD``.
    """
    from mpi4py import MPI

    my_rank = MPI.COMM_WORLD.Get_rank()
    greeting = f"Hi {name!s}, my rank is {my_rank:d} for task of type {task_name!s}"
    # The printed line only shows up in the SLURM job output; flush so it
    # reaches the job log files straight away.
    print(greeting, flush=True)
    return greeting

In [None]:
@on_cluster(cluster=custom_cluster, cluster_id=custom_cluster.name)
def mpi_example():
    """Submit three MPI tasks on ``custom_cluster`` and print their results.

    All three tasks are submitted before any result is awaited; results are
    then printed in submission order. The task calls return objects exposing
    ``.result()`` (presumably futures — confirm against the mpi_task decorator).
    """
    pending = [
        task1("task1"),
        task1("task1, 2nd iteration"),
        task2("Alan", task_name="Task 2"),
    ]
    for job in pending:
        print(job.result())

In [None]:
# Run the example against custom_cluster (the cluster passed to @on_cluster).
mpi_example()