-
Notifications
You must be signed in to change notification settings - Fork 89
/
mpi.jl
30 lines (27 loc) · 1.37 KB
/
mpi.jl
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
# Convenience functions for working with MPI
using MPI
"""
Initialize MPI. Must be called before doing any non-trivial MPI work
(even in the single-process case). Unlike the MPI.Init() function,
this can be called multiple times.
"""
function mpi_ensure_initialized()
# MPI Thread level 3 means that the environment is multithreaded, but that only
# one thread will call MPI at once
# see https://www.open-mpi.org/doc/current/man3/MPI_Init_thread.3.php#toc7
# TODO look more closely at interaction between MPI and threads
MPI.Initialized() || MPI.Init_thread(MPI.ThreadLevel(3))
end
"""
Number of processors used in MPI. Can be called without ensuring initialization.
"""
mpi_nprocs(comm=MPI.COMM_WORLD) = (mpi_ensure_initialized(); MPI.Comm_size(comm))
"""
    mpi_master(comm=MPI.COMM_WORLD)

Return whether `MPI.Comm_rank(comm)` is zero. Initialization is ensured internally
before the rank is queried.
"""
function mpi_master(comm=MPI.COMM_WORLD)
    mpi_ensure_initialized()
    rank = MPI.Comm_rank(comm)
    return rank == 0
end
"""
    mpi_sum(arr, comm::MPI.Comm)

Reduce `arr` over `comm` with `+` by forwarding to `MPI.Allreduce` and return its result.
"""
function mpi_sum(arr, comm::MPI.Comm)
    return MPI.Allreduce(arr, +, comm)
end

"""
    mpi_sum!(arr, comm::MPI.Comm)

Variant of `mpi_sum` that forwards to the mutating `MPI.Allreduce!` with `+` and
returns whatever that call returns.
"""
function mpi_sum!(arr, comm::MPI.Comm)
    return MPI.Allreduce!(arr, +, comm)
end
"""
    mpi_min(arr, comm::MPI.Comm)

Reduce `arr` over `comm` with `min` by forwarding to `MPI.Allreduce` and return its
result.
"""
function mpi_min(arr, comm::MPI.Comm)
    return MPI.Allreduce(arr, min, comm)
end

"""
    mpi_min!(arr, comm::MPI.Comm)

Variant of `mpi_min` that forwards to the mutating `MPI.Allreduce!` with `min` and
returns whatever that call returns.
"""
function mpi_min!(arr, comm::MPI.Comm)
    return MPI.Allreduce!(arr, min, comm)
end
"""
    mpi_max(arr, comm::MPI.Comm)

Reduce `arr` over `comm` with `max` by forwarding to `MPI.Allreduce` and return its
result.
"""
function mpi_max(arr, comm::MPI.Comm)
    return MPI.Allreduce(arr, max, comm)
end

"""
    mpi_max!(arr, comm::MPI.Comm)

Variant of `mpi_max` that forwards to the mutating `MPI.Allreduce!` with `max` and
returns whatever that call returns.
"""
function mpi_max!(arr, comm::MPI.Comm)
    return MPI.Allreduce!(arr, max, comm)
end
"""
    mpi_mean(arr, comm::MPI.Comm)

Return `mpi_sum(arr, comm)` divided elementwise by `mpi_nprocs(comm)`.
"""
function mpi_mean(arr, comm::MPI.Comm)
    total = mpi_sum(arr, comm)
    return total ./ mpi_nprocs(comm)
end

"""
    mpi_mean!(arr, comm::MPI.Comm)

Call `mpi_sum!(arr, comm)`, then divide `arr` elementwise in place by
`mpi_nprocs(comm)`. The value of the in-place division is returned.
"""
function mpi_mean!(arr, comm::MPI.Comm)
    mpi_sum!(arr, comm)
    nprocs = mpi_nprocs(comm)
    return arr ./= nprocs
end