Skip to content

Commit

Permalink
MPI Reduce for ValLocPair (AMReX-Codes#3003)
Browse files Browse the repository at this point in the history
Add ParallelReduce::Min, ParallelReduce::Max, ParallelAllReduce::Min,
and ParallelAllReduce::Max for ValLocPair<TV,TI>, where TV and TI are
types that have corresponding MPI types (e.g., int, Real, IntVect, Box,
etc.).
  • Loading branch information
WeiqunZhang committed Oct 29, 2022
1 parent 3ec0768 commit 735c351
Show file tree
Hide file tree
Showing 7 changed files with 181 additions and 25 deletions.
73 changes: 73 additions & 0 deletions Src/Base/AMReX_ParallelDescriptor.H
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include <AMReX_REAL.H>
#include <AMReX_Array.H>
#include <AMReX_Vector.H>
#include <AMReX_ValLocPair.H>

#ifndef BL_AMRPROF
#include <AMReX_Box.H>
Expand Down Expand Up @@ -211,6 +212,11 @@ while ( false )
extern AMREX_EXPORT MPI_Comm m_comm;
inline MPI_Comm Communicator () noexcept { return m_comm; }

#ifdef AMREX_USE_MPI
extern Vector<MPI_Datatype*> m_mpi_types;
extern Vector<MPI_Op*> m_mpi_ops;
#endif

//! return the number of MPI ranks local to the current Parallel Context
inline int
NProcs () noexcept
Expand Down Expand Up @@ -1479,6 +1485,73 @@ void DoReduce (T* r, MPI_Op op, int cnt, int cpu)
#endif
}

#ifdef AMREX_USE_MPI
namespace ParallelDescriptor {

template<typename TV, typename TI>
struct Mpi_typemap<ValLocPair<TV,TI>>
{
static MPI_Datatype type ()
{
static MPI_Datatype mpi_type = MPI_DATATYPE_NULL;
if (mpi_type == MPI_DATATYPE_NULL) {
using T = ValLocPair<TV,TI>;
static_assert(std::is_trivially_copyable<T>::value,
"To communicate with MPI, ValLocPair must be trivially copyable.");
static_assert(std::is_standard_layout<T>::value,
"To communicate with MPI, ValLocPair must be standard layout");

T vlp[2];
MPI_Datatype types[] = {
Mpi_typemap<TV>::type(),
Mpi_typemap<TI>::type(),
};
int blocklens[] = { 1, 1 };
MPI_Aint disp[2];
BL_MPI_REQUIRE( MPI_Get_address(&vlp[0].value, &disp[0]) );
BL_MPI_REQUIRE( MPI_Get_address(&vlp[0].index, &disp[1]) );
disp[1] -= disp[0];
disp[0] = 0;
BL_MPI_REQUIRE( MPI_Type_create_struct(2, blocklens, disp, types,
&mpi_type) );
MPI_Aint lb, extent;
BL_MPI_REQUIRE( MPI_Type_get_extent(mpi_type, &lb, &extent) );
if (extent != sizeof(T)) {
MPI_Datatype tmp = mpi_type;
BL_MPI_REQUIRE( MPI_Type_create_resized(tmp, 0, sizeof(vlp[0]), &mpi_type) );
BL_MPI_REQUIRE( MPI_Type_free(&tmp) );
}
BL_MPI_REQUIRE( MPI_Type_commit( &mpi_type ) );

m_mpi_types.push_back(&mpi_type);
}
return mpi_type;
}
};

template <typename T, typename F>
MPI_Op Mpi_op ()
{
static MPI_Op mpi_op = MPI_OP_NULL;
if (mpi_op == MPI_OP_NULL) {
static auto user_fn = [] (void *invec, void *inoutvec, int* len,
MPI_Datatype * /*datatype*/)
{
auto in = static_cast<T const*>(invec);
auto out = static_cast<T*>(inoutvec);
for (int i = 0; i < *len; ++i) {
out[i] = F()(in[i],out[i]);
}
};
BL_MPI_REQUIRE( MPI_Op_create(user_fn, 1, &mpi_op) );
m_mpi_ops.push_back(&mpi_op);
}
return mpi_op;
}

}
#endif

}

#endif /*BL_PARALLELDESCRIPTOR_H*/
15 changes: 15 additions & 0 deletions Src/Base/AMReX_ParallelDescriptor.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@ namespace amrex { namespace ParallelDescriptor {

MPI_Comm m_comm = MPI_COMM_NULL; // communicator for all ranks, probably MPI_COMM_WORLD

#ifdef AMREX_USE_MPI
Vector<MPI_Datatype*> m_mpi_types;
Vector<MPI_Op*> m_mpi_ops;
#endif

int m_MinTag = 1000, m_MaxTag = -1;

const int ioProcessor = 0;
Expand Down Expand Up @@ -357,10 +362,20 @@ EndParallel ()
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_indextype) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_box) );
BL_MPI_REQUIRE( MPI_Type_free(&mpi_type_lull_t) );
for (auto t : m_mpi_types) {
BL_MPI_REQUIRE( MPI_Type_free(t) );
*t = MPI_DATATYPE_NULL;
}
for (auto op : m_mpi_ops) {
BL_MPI_REQUIRE( MPI_Op_free(op) );
*op = MPI_OP_NULL;
}
mpi_type_intvect = MPI_DATATYPE_NULL;
mpi_type_indextype = MPI_DATATYPE_NULL;
mpi_type_box = MPI_DATATYPE_NULL;
mpi_type_lull_t = MPI_DATATYPE_NULL;
m_mpi_types.clear();
m_mpi_ops.clear();
}

if (!call_mpi_finalize) {
Expand Down
55 changes: 55 additions & 0 deletions Src/Base/AMReX_ParallelReduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <AMReX_Config.H>

#include <AMReX.H>
#include <AMReX_Functional.H>
#include <AMReX_ParallelDescriptor.H>
#include <AMReX_Print.H>
#include <AMReX_Vector.H>
Expand Down Expand Up @@ -120,6 +121,32 @@ namespace ParallelGather {

namespace ParallelAllReduce {

template<typename TV, typename TI>
void Max (ValLocPair<TV,TI>& vi, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
auto tmp = vi;
using T = ValLocPair<TV,TI>;
MPI_Allreduce(&tmp, &vi, 1,
ParallelDescriptor::Mpi_typemap<T>::type(),
ParallelDescriptor::Mpi_op<T,amrex::Greater<T>>(), comm);
#else
amrex::ignore_unused(vi, comm);
#endif
}

template<typename TV, typename TI>
void Min (ValLocPair<TV,TI>& vi, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
auto tmp = vi;
using T = ValLocPair<TV,TI>;
MPI_Allreduce(&tmp, &vi, 1,
ParallelDescriptor::Mpi_typemap<T>::type(),
ParallelDescriptor::Mpi_op<T,amrex::Less<T>>(), comm);
#else
amrex::ignore_unused(vi, comm);
#endif
}

template<typename T>
void Max (T& v, MPI_Comm comm) {
detail::Reduce(detail::ReduceOp::max, v, -1, comm);
Expand Down Expand Up @@ -174,6 +201,34 @@ namespace ParallelAllReduce {

namespace ParallelReduce {

template<typename TV, typename TI>
void Max (ValLocPair<TV,TI>& vi, int root, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
auto tmp = vi;
using T = ValLocPair<TV,TI>;
MPI_Reduce(&tmp, &vi, 1,
ParallelDescriptor::Mpi_typemap<T>::type(),
ParallelDescriptor::Mpi_op<T,amrex::Greater<T>>(),
root, comm);
#else
amrex::ignore_unused(vi, root, comm);
#endif
}

template<typename TV, typename TI>
void Min (ValLocPair<TV,TI>& vi, int root, MPI_Comm comm) {
#ifdef AMREX_USE_MPI
auto tmp = vi;
using T = ValLocPair<TV,TI>;
MPI_Reduce(&tmp, &vi, 1,
ParallelDescriptor::Mpi_typemap<T>::type(),
ParallelDescriptor::Mpi_op<T,amrex::Less<T>>(),
root, comm);
#else
amrex::ignore_unused(vi, root, comm);
#endif
}

template<typename T>
void Max (T& v, int root, MPI_Comm comm) {
detail::Reduce(detail::ReduceOp::max, v, root, comm);
Expand Down
26 changes: 1 addition & 25 deletions Src/Base/AMReX_Reduce.H
Original file line number Diff line number Diff line change
Expand Up @@ -6,38 +6,14 @@
#include <AMReX_Arena.H>
#include <AMReX_OpenMP.H>
#include <AMReX_MFIter.H>
#include <AMReX_ValLocPair.H>

#include <algorithm>
#include <functional>
#include <limits>

namespace amrex {

template <typename TV, typename TI>
struct ValLocPair
{
TV value;
TI index;

static constexpr ValLocPair<TV,TI> max () {
return ValLocPair<TV,TI>{std::numeric_limits<TV>::max(), TI()};
}

static constexpr ValLocPair<TV,TI> lowest () {
return ValLocPair<TV,TI>{std::numeric_limits<TV>::lowest(), TI()};
}

friend constexpr bool operator< (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
{
return a.value < b.value;
}

friend constexpr bool operator> (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
{
return a.value > b.value;
}
};

namespace Reduce { namespace detail {

#ifdef AMREX_USE_GPU
Expand Down
35 changes: 35 additions & 0 deletions Src/Base/AMReX_ValLocPair.H
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
#ifndef AMREX_VALLOCPAIR_H_
#define AMREX_VALLOCPAIR_H_

#include <limits>

namespace amrex {

template <typename TV, typename TI>
struct ValLocPair
{
TV value;
TI index;

static constexpr ValLocPair<TV,TI> max () {
return ValLocPair<TV,TI>{std::numeric_limits<TV>::max(), TI()};
}

static constexpr ValLocPair<TV,TI> lowest () {
return ValLocPair<TV,TI>{std::numeric_limits<TV>::lowest(), TI()};
}

friend constexpr bool operator< (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
{
return a.value < b.value;
}

friend constexpr bool operator> (ValLocPair<TV,TI> const& a, ValLocPair<TV,TI> const& b)
{
return a.value > b.value;
}
};

}

#endif
1 change: 1 addition & 0 deletions Src/Base/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ target_sources( amrex
AMReX_Utility.cpp
AMReX_FileSystem.H
AMReX_FileSystem.cpp
AMReX_ValLocPair.H
AMReX_Reduce.H
AMReX_Scan.H
AMReX_Partition.H
Expand Down
1 change: 1 addition & 0 deletions Src/Base/Make.package
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ C$(AMREX_BASE)_sources += AMReX_BlockMutex.cpp
C$(AMREX_BASE)_sources += AMReX_ParmParse.cpp AMReX_parmparse_fi.cpp AMReX_Utility.cpp
C$(AMREX_BASE)_headers += AMReX_ParmParse.H AMReX_Utility.H AMReX_BLassert.H AMReX_ArrayLim.H
C$(AMREX_BASE)_headers += AMReX_Functional.H AMReX_Reduce.H AMReX_Scan.H AMReX_Partition.H
C$(AMREX_BASE)_headers += AMReX_ValLocPair.H

C$(AMREX_BASE)_headers += AMReX_FileSystem.H
C$(AMREX_BASE)_sources += AMReX_FileSystem.cpp
Expand Down

0 comments on commit 735c351

Please sign in to comment.