pass custom mpi4py communicator #94

Open · wants to merge 16 commits into master

26 changes: 22 additions & 4 deletions pypolychord/_pypolychord.cpp
@@ -5,19 +5,29 @@
#define NPY_NO_DEPRECATED_API NPY_1_7_API_VERSION
#include <numpy/arrayobject.h>

#ifdef USE_MPI
#include <mpi.h>
#include <mpi4py/mpi4py.h>
#endif

/* Initialize the module */
#ifdef PYTHON3
PyMODINIT_FUNC PyInit__pypolychord(void)
{
import_array();
#ifdef USE_MPI
import_mpi4py();
#endif
return PyModule_Create(&_pypolychordmodule);
}
#else
PyMODINIT_FUNC init_pypolychord(void)
{
Py_InitModule3("_pypolychord", module_methods, module_docstring);
import_array();
#ifdef USE_MPI
import_mpi4py();
#endif
}
#endif

@@ -121,12 +131,12 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args)
Settings S;

PyObject *temp_logl, *temp_prior, *temp_dumper;
PyObject* py_grade_dims, *py_grade_frac, *py_nlives;
PyObject* py_grade_dims, *py_grade_frac, *py_nlives, *py_comm;
char* base_dir, *file_root;


if (!PyArg_ParseTuple(args,
"OOOiiiiiiiiddidiiiiiiiiiiidissO!O!O!i:run",
"OOOiiiiiiiiddidiiiiiiiiiiidissO!O!O!iO:run",
&temp_logl,
&temp_prior,
&temp_dumper,
@@ -163,7 +173,8 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args)
&py_grade_dims,
&PyDict_Type,
&py_nlives,
&S.seed
&S.seed,
&py_comm
)
)
return NULL;
@@ -216,7 +227,14 @@ static PyObject *run_pypolychord(PyObject *, PyObject *args)
python_dumper = temp_dumper;

/* Run PolyChord */
try{ run_polychord(loglikelihood, prior, dumper, S); }
try{
#ifdef USE_MPI
MPI_Comm *comm = PyMPIComm_Get(py_comm);
run_polychord(loglikelihood, prior, dumper, S, *comm);
#else
run_polychord(loglikelihood, prior, dumper, S);
#endif
}
catch (PythonException& e)
{
Py_DECREF(py_grade_frac);Py_DECREF(py_grade_dims);Py_DECREF(python_loglikelihood);Py_DECREF(python_prior);
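The extension's argument list grows by one trailing `O` in the `PyArg_ParseTuple` format string, and in MPI builds the new `py_comm` object is unwrapped with mpi4py's C API (`PyMPIComm_Get`) into the raw `MPI_Comm` handed to `run_polychord`; non-MPI builds parse the argument but ignore it. A sketch (not part of the PR) of what a caller therefore has to supply, namely any `mpi4py.MPI.Comm` instance; passing anything else to an MPI-enabled build would presumably fail inside `PyMPIComm_Get`:

```python
# Illustration (not part of the PR): the object handed to the extension as
# `comm` must be an mpi4py communicator, since the C layer unwraps it with
# PyMPIComm_Get.  Any valid communicator works, e.g. a split of COMM_WORLD.
from mpi4py import MPI

world = MPI.COMM_WORLD
comm = world.Split(color=world.rank % 2, key=world.rank)  # two sub-communicators
assert isinstance(comm, MPI.Comm)  # the property the C code relies on
```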
34 changes: 21 additions & 13 deletions pypolychord/polychord.py
@@ -14,7 +14,7 @@ def default_dumper(live, dead, logweights, logZ, logZerr):


def run_polychord(loglikelihood, nDims, nDerived, settings,
prior=default_prior, dumper=default_dumper):
prior=default_prior, dumper=default_dumper, comm=None):
"""
Runs PolyChord.

@@ -143,13 +143,15 @@ def run_polychord(loglikelihood, nDims, nDerived, settings,
Final output evidence statistics

"""

try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
except ImportError:
rank = 0
rank = 0
if comm is None:
try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
except ImportError:
pass
if comm is not None:
rank = comm.rank

if rank == 0:
Path(settings.cluster_dir).mkdir(parents=True, exist_ok=True)
@@ -207,7 +209,8 @@ def wrap_prior(cube, theta):
settings.grade_frac,
settings.grade_dims,
settings.nlives,
settings.seed)
settings.seed,
comm)

if settings.cube_samples is not None:
settings.read_resume = read_resume
@@ -512,10 +515,9 @@ def run(loglikelihood, nDims, **kwargs):

try:
from mpi4py import MPI
comm = MPI.COMM_WORLD
rank = comm.Get_rank()
default_comm = MPI.COMM_WORLD
except ImportError:
rank = 0
default_comm = None

paramnames = kwargs.pop('paramnames', None)

@@ -552,6 +554,7 @@
'grade_dims': [nDims],
'nlives': {},
'seed': -1,
'comm': default_comm,
}
default_kwargs['grade_frac'] = ([1.0]*len(default_kwargs['grade_dims'])
if 'grade_dims' not in kwargs else
@@ -563,6 +566,11 @@
default_kwargs.update(kwargs)
kwargs = default_kwargs

if kwargs['comm'] is not None:
rank = kwargs['comm'].rank
else:
rank = 0

if rank == 0:
(Path(kwargs['base_dir']) / kwargs['cluster_dir']).mkdir(
parents=True, exist_ok=True)
@@ -631,7 +639,7 @@ def wrap_prior(cube, theta):
kwargs['grade_dims'],
kwargs['nlives'],
kwargs['seed'],
)
kwargs['comm'])

if 'cube_samples' in kwargs:
kwargs['read_resume'] = read_resume
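On the Python side `run_polychord` and `run` now accept an optional `comm`, falling back to `MPI.COMM_WORLD` when mpi4py is importable and to `None` otherwise, with the rank taken from whichever communicator ends up in use. A usage sketch of the new keyword (the toy likelihood, prior and settings are placeholders, not from the PR), splitting `COMM_WORLD` so that two independent PolyChord runs proceed side by side, which is the use case this argument enables:

```python
# Usage sketch for the new `comm` keyword: split COMM_WORLD into two groups
# and run an independent PolyChord job in each.  The toy likelihood, prior and
# settings below are placeholders, not part of the PR.
import numpy as np
from mpi4py import MPI

import pypolychord
from pypolychord.settings import PolyChordSettings

nDims, nDerived = 4, 0

def loglikelihood(theta):
    return -0.5 * float(np.dot(theta, theta)), []   # unit Gaussian, no derived params

def prior(cube):
    return 10.0 * np.asarray(cube) - 5.0            # uniform prior on [-5, 5]^nDims

world = MPI.COMM_WORLD
color = world.rank % 2                              # two groups of processes
comm = world.Split(color=color, key=world.rank)

settings = PolyChordSettings(nDims, nDerived)
settings.file_root = f'gaussian_{color}'            # keep the two runs' output separate

output = pypolychord.run_polychord(loglikelihood, nDims, nDerived, settings,
                                   prior=prior, comm=comm)
```

Leaving `comm` at its default keeps the previous single-communicator behaviour, so existing scripts are unaffected.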
23 changes: 20 additions & 3 deletions setup.py
@@ -12,10 +12,19 @@

import numpy

def check_compiler(default_CC="gcc"):
try:
import mpi4py
except ImportError:
mpi4py_get_include = None
else:
mpi4py_get_include = mpi4py.get_include()


def check_compiler():
"""Checks what compiler is being used (clang, intel, or gcc)."""

CC = default_CC if "CC" not in os.environ else os.environ["CC"]
CC = os.getenv('CC', 'mpicc' if mpi4py_get_include else 'gcc')
os.environ['CC'] = CC
CC_version = subprocess.check_output([CC, "-v"], stderr=subprocess.STDOUT).decode("utf-8").lower()

if "clang" in CC_version:
@@ -107,14 +116,22 @@ def run(self):
subprocess.run(["make", "veryclean"], check=True, env=os.environ)
return super().run()


include_dirs = ['src/polychord', numpy.get_include()]

if "--no-mpi" in sys.argv:
NAME += '_nompi'
DOCLINES[1] = DOCLINES[1] + ' (cannot be used with MPI)'

elif mpi4py_get_include:
CPPRUNTIMELIB_FLAG += ["-DUSE_MPI"]
print(mpi4py_get_include)
include_dirs += [mpi4py_get_include]

pypolychord_module = Extension(
name='_pypolychord',
library_dirs=['lib'],
include_dirs=['src/polychord', numpy.get_include()],
include_dirs=include_dirs,
libraries=['chord',],
extra_link_args=RPATH_FLAG + CPPRUNTIMELIB_FLAG,
extra_compile_args= ["-std=c++11"] + RPATH_FLAG + CPPRUNTIMELIB_FLAG,
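setup.py now detects mpi4py at build time: when it is importable the compiler defaults to `mpicc`, `-DUSE_MPI` is added, and `mpi4py.get_include()` joins the include directories; otherwise the build falls back to a non-MPI extension. A standalone sketch of the same detection (the fallback names mirror the change above) that can be run before building to see what would be picked:

```python
# Standalone sketch of the detection logic introduced above: run it before
# building to see which compiler and include path setup.py would use.
import os

try:
    import mpi4py
    mpi4py_include = mpi4py.get_include()        # directory containing mpi4py/mpi4py.h
except ImportError:
    mpi4py_include = None

cc = os.getenv('CC', 'mpicc' if mpi4py_include else 'gcc')
print('compiler:', cc)
print('mpi4py headers:', mpi4py_include or 'not available (MPI support disabled)')
```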
114 changes: 79 additions & 35 deletions src/polychord/mpi_utils.F90
@@ -22,6 +22,9 @@ module mpi_module
integer, parameter :: tag_run_epoch_babies=10
integer, parameter :: tag_run_stop=11

integer, parameter :: tag_tag_gen=12
integer, parameter :: tag_tag_run=13

type mpi_bundle
integer :: rank
integer :: nprocs
@@ -476,33 +479,45 @@ function catch_seed(seed_point,cholesky,logL,epoch,mpi_information) result(more_points_needed)
real(dp),intent(out),dimension(:) :: seed_point !> The seed point to be caught
real(dp),intent(out),dimension(:,:) :: cholesky !> Cholesky matrix to be caught
real(dp),intent(out) :: logL !> loglikelihood contour to be caught
integer :: tag_run
integer, intent(out) :: epoch
type(mpi_bundle), intent(in) :: mpi_information

logical :: more_points_needed ! whether or not we need more points

integer, dimension(MPI_STATUS_SIZE) :: mpistatus ! status identifier


call MPI_RECV( &!
seed_point, &!
size(seed_point), &!
MPI_DOUBLE_PRECISION, &!
tag_run, &!
1, &!
MPI_INTEGER, &!
mpi_information%root, &!
MPI_ANY_TAG, &!
tag_tag_run, &!
mpi_information%communicator,&!
mpistatus, &!
mpierror &!
)
if(mpistatus(MPI_TAG) == tag_run_stop ) then

if(tag_run == tag_run_stop) then
more_points_needed = .false.
return
else if(mpistatus(MPI_TAG) == tag_run_seed) then
else if(tag_run == tag_run_seed) then
more_points_needed = .true.
else
call halt_program('worker error: unrecognised tag')
end if

call MPI_RECV( &!
seed_point, &!
size(seed_point), &!
MPI_DOUBLE_PRECISION, &!
mpi_information%root, &!
tag_run, &!
mpi_information%communicator,&!
mpistatus, &!
mpierror &!
)

call MPI_RECV( &!
cholesky, &!
size(cholesky,1)*size(cholesky,1),&!
@@ -548,26 +563,34 @@ subroutine throw_seed(seed_point,cholesky,logL,mpi_information,worker_id,epoch,keep_going)
type(mpi_bundle),intent(in) :: mpi_information !> mpi handle
integer, intent(in) :: worker_id !> identity of target worker
integer, intent(in) :: epoch !> epoch of seed
logical, intent(in) :: keep_going !> Further signal whether to keep going
logical, intent(in) :: keep_going !> Further signal whether to keep going

integer :: tag ! tag variable to

tag = tag_run_stop ! Default tag is stop tag
if(keep_going) tag = tag_run_seed ! If we want to keep going then change this to the seed tag
integer :: tag_run ! tag variable to

tag_run = tag_run_stop ! Default tag is stop tag
if(keep_going) tag_run = tag_run_seed ! If we want to keep going then change this to the seed tag

call MPI_SEND( &!
seed_point, &!
size(seed_point), &!
MPI_DOUBLE_PRECISION, &!
tag_run, &!
1, &!
MPI_INTEGER, &!
worker_id, &!
tag, &!
tag_tag_run, &!
mpi_information%communicator,&!
mpierror &!
)

if(.not. keep_going) return ! Stop here if we're wrapping up

call MPI_SEND( &!
seed_point, &!
size(seed_point), &!
MPI_DOUBLE_PRECISION, &!
worker_id, &!
tag_run, &!
mpi_information%communicator, &!
mpierror &!
)
call MPI_SEND( &!
cholesky, &!
size(cholesky,1)*size(cholesky,2),&!
@@ -620,17 +643,14 @@ subroutine no_more_points(mpi_information,worker_id)
type(mpi_bundle), intent(in) :: mpi_information
integer, intent(in) :: worker_id !> Worker to request a new point from


integer :: empty_buffer(0) ! empty buffer to send

call MPI_SEND( &
empty_buffer, &! not sending anything
0, &! size of nothing
MPI_INTEGER, &! sending no integers
worker_id, &! process id to send to
tag_gen_stop, &! continuation tag
mpi_information%communicator,&! mpi handle
mpierror &! error flag
tag_gen_stop, &!
1, &!
MPI_INTEGER, &!
worker_id, &!
tag_tag_gen, &!
mpi_information%communicator,&!
mpierror &!
)

end subroutine no_more_points
@@ -650,6 +670,16 @@ subroutine request_live_point(live_point,mpi_information,worker_id)
integer, intent(in) :: worker_id !> Worker to request a new point from
real(dp), intent(in), dimension(:) :: live_point !> The live point to be sent

call MPI_SEND( &
tag_gen_request, &!
1, &!
MPI_INTEGER, &!
worker_id, &!
tag_tag_gen, &!
mpi_information%communicator,&!
mpierror &!
)


call MPI_SEND( &!
live_point, &! live point being sent
Expand All @@ -675,6 +705,29 @@ function live_point_needed(live_point,mpi_information)
integer, dimension(MPI_STATUS_SIZE) :: mpistatus ! status identifier

logical :: live_point_needed !> Whether we need more points or not
integer :: tag_gen

call MPI_RECV( &!
tag_gen, &!
1, &!
MPI_INTEGER, &!
mpi_information%root, &!
tag_tag_gen, &!
mpi_information%communicator,&!
mpistatus, &!
mpierror &!
)

! If we've recieved a kill signal, then exit this loop
if(tag_gen == tag_gen_stop) then
live_point_needed = .false.
return
else if(tag_gen == tag_gen_request) then
live_point_needed = .true.
else
call halt_program('generate error: unrecognised tag')
end if


call MPI_RECV( &!
live_point, &! live point recieved
Expand All @@ -687,15 +740,6 @@ function live_point_needed(live_point,mpi_information)
mpierror &! error flag
)

! If we've recieved a kill signal, then exit this loop
if(mpistatus(MPI_TAG) == tag_gen_stop ) then
live_point_needed = .false.
else if(mpistatus(MPI_TAG) == tag_gen_request) then
live_point_needed = .true.
else
call halt_program('generate error: unrecognised tag')
end if

end function live_point_needed
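
The mpi_utils.F90 changes replace the old pattern of receiving data with `MPI_ANY_TAG` and inspecting `mpistatus(MPI_TAG)` by an explicit two-step handshake: a control integer is sent first under a dedicated tag (`tag_tag_run` or `tag_tag_gen`), and the payload is only sent or received when that control value asks for it, which also lets `no_more_points` drop its empty-buffer send. For illustration only, the same handshake written with mpi4py rather than Fortran (the numeric tag values are invented for the sketch, not the module's constants):

```python
# Illustration only: the control-message-then-payload handshake adopted above,
# written with mpi4py.  The tag values are invented for the sketch and are not
# the constants defined in mpi_utils.F90.
from mpi4py import MPI
import numpy as np

TAG_TAG_RUN = 13            # channel for the control integer
TAG_RUN_SEED = 7            # "a seed point follows"
TAG_RUN_STOP = 11           # "no more work"

comm = MPI.COMM_WORLD
if comm.size > 1:
    if comm.rank == 0:
        # root: announce what kind of message follows, then send the data itself
        comm.send(TAG_RUN_SEED, dest=1, tag=TAG_TAG_RUN)
        comm.Send(np.zeros(4), dest=1, tag=TAG_RUN_SEED)
    elif comm.rank == 1:
        kind = comm.recv(source=0, tag=TAG_TAG_RUN)
        if kind == TAG_RUN_SEED:             # only post the data receive when promised
            seed = np.empty(4)
            comm.Recv(seed, source=0, tag=TAG_RUN_SEED)
```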

