diff --git a/bagpipes/catalogue/fit_catalogue.py b/bagpipes/catalogue/fit_catalogue.py
index 759fe37..c1cfa96 100644
--- a/bagpipes/catalogue/fit_catalogue.py
+++ b/bagpipes/catalogue/fit_catalogue.py
@@ -129,7 +129,7 @@ def __init__(self, IDs, fit_instructions, load_data, spectrum_exists=True,
             utils.make_dirs(run=run)
 
     def fit(self, verbose=False, n_live=400, mpi_serial=False,
-            track_backlog=False):
+            track_backlog=False, sampler="multinest"):
         """ Run through the catalogue fitting each object.
 
         Parameters
@@ -181,7 +181,8 @@ def fit(self, verbose=False, n_live=400, mpi_serial=False,
                 continue
 
             # If not fit the object and update the output catalogue
-            self._fit_object(self.IDs[i], verbose=verbose, n_live=n_live)
+            self._fit_object(self.IDs[i], verbose=verbose, n_live=n_live,
+                             sampler=sampler)
 
             self.done[i] = True
 
@@ -195,7 +196,7 @@ def fit(self, verbose=False, n_live=400, mpi_serial=False,
                       self.done.shape[0], "objects completed.")
 
     def _fit_mpi_serial(self, verbose=False, n_live=400,
-                        track_backlog=False):
+                        track_backlog=False, sampler="multinest"):
         """ Run through the catalogue fitting multiple objects at once
         on different cores. """
 
@@ -230,7 +231,7 @@ def _fit_mpi_serial(self, verbose=False, n_live=400,
 
                 # Load posterior for finished object to update catalogue
                 self._fit_object(oldID, use_MPI=False, verbose=False,
-                                 n_live=n_live)
+                                 n_live=n_live, sampler=sampler)
 
                 save_cat = Table.from_pandas(self.cat)
                 save_cat.write("pipes/cats/" + self.run + ".fits",
@@ -260,7 +261,7 @@ def _fit_mpi_serial(self, verbose=False, n_live=400,
                 self.n_posterior = 5  # hacky, these don't get used
 
                 self._fit_object(ID, use_MPI=False, verbose=False,
-                                 n_live=n_live)
+                                 n_live=n_live, sampler=sampler)
 
                 comm.send([ID, rank], dest=0)  # Tell 0 object is done
 
@@ -282,7 +283,8 @@ def _set_redshift(self, ID):
             else:
                 self.fit_instructions["redshift"] = self.redshifts[ind]
 
-    def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True):
+    def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True,
+                    sampler="multinest"):
         """ Fit the specified object and update the catalogue. """
 
         # Set the correct redshift for this object
@@ -305,7 +307,8 @@ def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True):
                              time_calls=self.time_calls,
                              n_posterior=self.n_posterior)
 
-        self.obj_fit.fit(verbose=verbose, n_live=n_live, use_MPI=use_MPI)
+        self.obj_fit.fit(verbose=verbose, n_live=n_live, use_MPI=use_MPI,
+                         sampler=sampler)
 
         if rank == 0:
             if self.vars is None:
diff --git a/bagpipes/fitting/fit.py b/bagpipes/fitting/fit.py
index 7ebd8b4..1a561f0 100644
--- a/bagpipes/fitting/fit.py
+++ b/bagpipes/fitting/fit.py
@@ -14,13 +14,28 @@
 except (ImportError, RuntimeError, SystemExit) as e:
     print("Bagpipes: PyMultiNest import failed, fitting will be unavailable.")
 
+try:
+    from nautilus import Sampler
+
+except (ImportError, RuntimeError, SystemExit) as e:
+    pass
+
 # detect if run through mpiexec/mpirun
 try:
     from mpi4py import MPI
     rank = MPI.COMM_WORLD.Get_rank()
+    size = MPI.COMM_WORLD.Get_size()
+    from mpi4py.futures import MPIPoolExecutor
+
+    if size == 1:
+        pool = None
+
+    else:
+        pool = MPIPoolExecutor(size)
 
 except ImportError:
     rank = 0
+    pool = None
 
 from .. import utils
 from .. import plotting
@@ -97,7 +112,7 @@ def __init__(self, galaxy, fit_instructions, run=".", time_calls=False,
         self.fitted_model = fitted_model(galaxy, self.fit_instructions,
                                          time_calls=time_calls)
 
-    def fit(self, verbose=False, n_live=400, use_MPI=True):
+    def fit(self, verbose=False, n_live=400, use_MPI=True, sampler="multinest"):
         """ Fit the specified model to the input galaxy data.
 
         Parameters
@@ -126,12 +141,23 @@ def fit(self, verbose=False, n_live=400, use_MPI=True):
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
 
-            pmn.run(self.fitted_model.lnlike,
-                    self.fitted_model.prior.transform,
-                    self.fitted_model.ndim, n_live_points=n_live,
-                    importance_nested_sampling=False, verbose=verbose,
-                    sampling_efficiency="model",
-                    outputfiles_basename=self.fname, use_MPI=use_MPI)
+
+            if sampler == "multinest":
+                pmn.run(self.fitted_model.lnlike,
+                        self.fitted_model.prior.transform,
+                        self.fitted_model.ndim, n_live_points=n_live,
+                        importance_nested_sampling=False, verbose=verbose,
+                        sampling_efficiency="model",
+                        outputfiles_basename=self.fname, use_MPI=use_MPI)
+
+            elif sampler == "nautilus":
+                n_sampler = Sampler(self.fitted_model.prior.transform,
+                                    self.fitted_model.lnlike, n_live=n_live,
+                                    n_networks=1, pool=pool,
+                                    n_dim=self.fitted_model.ndim,
+                                    filepath=self.fname + "temp.h5")
+
+                n_sampler.run(verbose=verbose)
 
         if rank == 0 or not use_MPI:
             runtime = time.time() - start_time
@@ -139,20 +165,29 @@ def fit(self, verbose=False, n_live=400, use_MPI=True):
             print("\nCompleted in " + str("%.1f" % runtime) + " seconds.\n")
 
             # Load MultiNest outputs and save basic quantities to file.
-            samples2d = np.loadtxt(self.fname + "post_equal_weights.dat")
-            lnz_line = open(self.fname + "stats.dat").readline().split()
+            if sampler == "multinest":
+                samples2d = np.loadtxt(self.fname + "post_equal_weights.dat")
+                lnz_line = open(self.fname + "stats.dat").readline().split()
+                self.results["samples2d"] = samples2d[:, :-1]
+                self.results["lnlike"] = samples2d[:, -1]
+                self.results["lnz"] = float(lnz_line[-3])
+                self.results["lnz_err"] = float(lnz_line[-1])
+
+            elif sampler == "nautilus":
+                samples2d, log_w, log_l = n_sampler.posterior(equal_weight=True)
+                self.results["samples2d"] = samples2d
+                self.results["lnlike"] = log_l
+                self.results["lnz"] = n_sampler.evidence()
+                self.results["lnz_err"] = -99
+
+            self.results["median"] = np.median(samples2d, axis=0)
+            self.results["conf_int"] = np.percentile(self.results["samples2d"],
+                                                     (16, 84), axis=0)
 
             file = h5py.File(self.fname[:-1] + ".h5", "w")
 
             file.attrs["fit_instructions"] = str(self.fit_instructions)
 
-            self.results["samples2d"] = samples2d[:, :-1]
-            self.results["lnlike"] = samples2d[:, -1]
-            self.results["lnz"] = float(lnz_line[-3])
-            self.results["lnz_err"] = float(lnz_line[-1])
-            self.results["median"] = np.median(samples2d, axis=0)
-            self.results["conf_int"] = np.percentile(self.results["samples2d"],
-                                                     (16, 84), axis=0)
 
             for k in self.results.keys():
                 file.create_dataset(k, data=self.results[k])
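
Usage note: this patch threads a new `sampler` keyword (default `"multinest"`, alternative `"nautilus"`) from `fit_catalogue.fit()` down through `_fit_object()` to `fit.fit()`, so existing scripts are unaffected. A minimal sketch of selecting the Nautilus backend for a single object follows; the `load_data` callback, photometry path, filter-curve path and the exponential SFH priors are illustrative assumptions, not part of this patch.

    import numpy as np
    import bagpipes as pipes

    def load_data(ID):
        # Hypothetical loader: returns an Nx2 array of fluxes and
        # flux errors, one row per filter, for the requested object.
        return np.loadtxt("phot/" + ID + ".txt")  # hypothetical path

    filt_list = ["filters/f160w.txt"]  # hypothetical filter curve file

    galaxy = pipes.galaxy("0001", load_data, spectrum_exists=False,
                          filt_list=filt_list)

    # Simple exponential SFH with broad priors (illustrative values).
    exp = {"age": (0.1, 13.), "tau": (0.3, 10.),
           "massformed": (1., 13.), "metallicity": (0., 2.5)}

    fit_instructions = {"exponential": exp, "redshift": (0., 10.)}

    fit = pipes.fit(galaxy, fit_instructions)

    # New keyword from this patch: "multinest" (default) or "nautilus".
    fit.fit(verbose=True, n_live=400, sampler="nautilus")

The same keyword passes straight through the catalogue interface, e.g. `fit_catalogue.fit(sampler="nautilus")`, including the mpi_serial path, where each worker forwards it to `_fit_object()`.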