# MPI-aware tasks
We would also like to use tasks that are MPI-aware, so that we can access the memory space of the tasks directly and avoid going through the file system. Such tasks execute within the context of an MPI communicator (by default this is equivalent to `MPI.COMM_WORLD`) that includes all the MPI processes of a particular job allocation.

In [None]:
import sys
from jobqueue_features.clusters import CustomSLURMCluster
from jobqueue_features import (
    on_cluster, 
    mpi_task, 
    get_task_mpi_comm
)
from jobqueue_features.clusters_controller import (
    clusters_controller_singleton as controller,
)

Again, we need to tell the cluster that we need `mpi_mode` and give it enough information so that it can request the required resources.

In [None]:
custom_cluster = CustomSLURMCluster(
    name="mpiCluster", mpi_mode=True, nodes=2,
)

Now let's define a couple of MPI-aware tasks that take different arguments.

Note that we only import `mpi4py` within the context of the task. This is **important** since importing it into our notebook context will mess up our ability to use the remote MPI context.
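
To confirm that the notebook session itself has not picked up `mpi4py`, you can inspect `sys.modules`. This is a minimal sketch using only the standard library; the helper name `mpi4py_loaded` is hypothetical and not part of `jobqueue_features`. A `True` result only tells you that something in the session has imported the package, not which import was responsible.

```python
import sys


def mpi4py_loaded(modules=None):
    """Return True if mpi4py (or one of its submodules) is in the module table."""
    # Default to the interpreter's real module table; a custom mapping can be
    # passed in to exercise the helper without importing anything.
    modules = sys.modules if modules is None else modules
    return any(
        name == "mpi4py" or name.startswith("mpi4py.") for name in modules
    )


print("mpi4py already imported in this session:", mpi4py_loaded())
```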

Let's define one that takes an argument

In [None]:
@mpi_task(cluster=custom_cluster)
def task1(task_name):
    from mpi4py import MPI

    comm = get_task_mpi_comm()
    size = comm.Get_size()
    name = MPI.Get_processor_name()
    all_nodes = comm.gather(name, root=0)
    if all_nodes:
        all_nodes = set(all_nodes)
    else:
        all_nodes = []
    # Since it is a return value it will only get printed by root
    return_string = "Running %d tasks of type %s on nodes %s." % (
        size,
        task_name,
        all_nodes,
    )
    # The flush is required to ensure that the print statements
    # appear in the job log files
    print(return_string)
    sys.stdout.flush()
    return return_string

and another that takes an argument and a kwarg

In [None]:
@mpi_task(cluster=custom_cluster)
def task2(name, task_name="default"):
    from mpi4py import MPI

    comm = get_task_mpi_comm()
    rank = comm.Get_rank()
    # For the non-root ranks this only appears in the SLURM job output
    return_string = "Hi %s, my rank is %d for task of type %s" % (name, rank, task_name)
    # The flush is required to ensure that the print statements appear in
    # the job log files
    print(return_string)
    sys.stdout.flush()
    return return_string
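
In `task1`, the `if all_nodes:` branch exists because mpi4py's `comm.gather` returns the gathered list on the root rank and `None` on every other rank. The following is a minimal sketch in plain Python that imitates this behaviour without needing MPI. The names `simulated_gather` and `summarise_nodes`, as well as the hostnames, are hypothetical and purely illustrative; the sketch also sorts the node names, which `task1` does not do.

```python
def simulated_gather(values_by_rank, root=0):
    """Imitate what each rank gets back from comm.gather(value, root=root).

    The root rank receives the full list; every other rank receives None.
    """
    gathered = list(values_by_rank)
    return [
        gathered if rank == root else None
        for rank in range(len(values_by_rank))
    ]


def summarise_nodes(all_nodes):
    """Deduplicate and sort node names, falling back to [] on non-root ranks."""
    return sorted(set(all_nodes)) if all_nodes else []


# Hypothetical hostnames for four ranks spread over two nodes
hostnames = ["c1", "c1", "c2", "c2"]
per_rank = simulated_gather(hostnames)
summaries = [summarise_nodes(result) for result in per_rank]

for rank, nodes in enumerate(summaries):
    print("rank %d sees nodes %s" % (rank, nodes))
```

Only rank 0 ends up with a non-empty node list, which is why the string built on the root rank is the informative one.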

Now let's run these tasks in the context of our cluster

In [None]:
@on_cluster(cluster=custom_cluster)
def mpi_example():
    t1 = task1("task1")
    t2 = task1("task1, 2nd iteration")
    t3 = task2("Alan", task_name="Task 2")
    print(t1.result())
    print(t2.result())
    print(t3.result())

In [None]:
mpi_example()

What about all the output from the other MPI ranks? Go take a look in the output file from SLURM.

Let's clean up after ourselves

In [None]:
controller._close()

# Scaling up
Let's define a new cluster that can scale, and use it to test out our scalability

In [None]:
mpi_multi_cluster = CustomSLURMCluster(
    name="mpiMultiCluster",
    nodes=1,
    maximum_jobs=2,
    mpi_mode=True
)

Let's define a new task to run there.

In [None]:
@on_cluster(cluster=mpi_multi_cluster)
@mpi_task(cluster=mpi_multi_cluster)
def task(task_name):
    import time
    from mpi4py import MPI

    comm = get_task_mpi_comm()
    size = comm.Get_size()
    name = MPI.Get_processor_name()
    all_nodes = comm.gather(name, root=0)
    if all_nodes:
        all_nodes = list(set(all_nodes))
        all_nodes.sort()
    else:
        all_nodes = []
    # Since it is a return value it will only get printed by root
    return_string = "Running %d tasks of type %s on nodes %s." % (
        size,
        task_name,
        all_nodes,
    )

    # Add a sleep to make the task substantial enough to require scaling
    time.sleep(0.5)
    return return_string

Then submit it 100 times and count where the tasks ran (the tally assumes the node names contain `c1` and `c2`):

In [None]:
from distributed import as_completed

tasks = []
for x in range(100):
    tasks.append(
        task("task-{}".format(x))
    )
    
c1_count = 0
c2_count = 0
for job in as_completed(tasks):
    result = job.result()
    job.cancel()
    if 'c1' in result:
        c1_count += 1
    elif 'c2' in result:
        c2_count += 1
print("c1: {} \nc2: {}".format(c1_count, c2_count))
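
The tallying pattern above can be tried without a cluster. This is a minimal sketch that uses the standard library's `concurrent.futures` as a stand-in for `distributed.as_completed`; `fake_task` and `tally_by_node` are hypothetical names, and the result strings only imitate the format returned by `task`. Unlike the loop above, it matches the quoted node name (so `c1` does not also match `c10`) and counts every node that appears in a result.

```python
from collections import Counter
from concurrent.futures import ThreadPoolExecutor, as_completed


def fake_task(index):
    """Stand-in for the MPI task: pretend even tasks ran on c1, odd on c2."""
    node = "c1" if index % 2 == 0 else "c2"
    return "Running 2 tasks of type task-{} on nodes {}.".format(index, [node])


def tally_by_node(results, node_names=("c1", "c2")):
    """Count how many result strings mention each (quoted) node name."""
    counts = Counter()
    for result in results:
        for node in node_names:
            # Node lists are rendered like ['c1'], so match the quoted form
            if "'{}'".format(node) in result:
                counts[node] += 1
    return counts


with ThreadPoolExecutor(max_workers=4) as pool:
    futures = [pool.submit(fake_task, x) for x in range(10)]
    # Results are consumed in completion order, not submission order
    counts = tally_by_node(f.result() for f in as_completed(futures))

print("c1: {} \nc2: {}".format(counts["c1"], counts["c2"]))
```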

Let's also clean up after ourselves

In [None]:
controller._close()