added nautilus sampler (no mpi yet)
Adam Carnall authored and committed on Jul 10, 2023
1 parent bc44327 commit 41eeef8
Showing 2 changed files with 61 additions and 23 deletions.
17 changes: 10 additions & 7 deletions bagpipes/catalogue/fit_catalogue.py
@@ -129,7 +129,7 @@ def __init__(self, IDs, fit_instructions, load_data, spectrum_exists=True,
         utils.make_dirs(run=run)
 
     def fit(self, verbose=False, n_live=400, mpi_serial=False,
-            track_backlog=False):
+            track_backlog=False, sampler="multinest"):
         """ Run through the catalogue fitting each object.
 
         Parameters
@@ -181,7 +181,8 @@ def fit(self, verbose=False, n_live=400, mpi_serial=False,
                 continue
 
             # If not fit the object and update the output catalogue
-            self._fit_object(self.IDs[i], verbose=verbose, n_live=n_live)
+            self._fit_object(self.IDs[i], verbose=verbose, n_live=n_live,
+                             sampler=sampler)
 
             self.done[i] = True
 
@@ -195,7 +196,7 @@ def fit(self, verbose=False, n_live=400, mpi_serial=False,
                   self.done.shape[0], "objects completed.")
 
     def _fit_mpi_serial(self, verbose=False, n_live=400,
-                        track_backlog=False):
+                        track_backlog=False, sampler="multinest"):
         """ Run through the catalogue fitting multiple objects at once
         on different cores. """
 
@@ -230,7 +231,7 @@ def _fit_mpi_serial(self, verbose=False, n_live=400,
 
                 # Load posterior for finished object to update catalogue
                 self._fit_object(oldID, use_MPI=False, verbose=False,
-                                 n_live=n_live)
+                                 n_live=n_live, sampler=sampler)
 
                 save_cat = Table.from_pandas(self.cat)
                 save_cat.write("pipes/cats/" + self.run + ".fits",
@@ -260,7 +261,7 @@ def _fit_mpi_serial(self, verbose=False, n_live=400,
 
             self.n_posterior = 5  # hacky, these don't get used
             self._fit_object(ID, use_MPI=False, verbose=False,
-                             n_live=n_live)
+                             n_live=n_live, sampler=sampler)
 
             comm.send([ID, rank], dest=0)  # Tell 0 object is done
 
@@ -282,7 +283,8 @@ def _set_redshift(self, ID):
         else:
             self.fit_instructions["redshift"] = self.redshifts[ind]
 
-    def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True):
+    def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True,
+                    sampler="multinest"):
         """ Fit the specified object and update the catalogue. """
 
         # Set the correct redshift for this object
@@ -305,7 +307,8 @@ def _fit_object(self, ID, verbose=False, n_live=400, use_MPI=True):
                            time_calls=self.time_calls,
                            n_posterior=self.n_posterior)
 
-        self.obj_fit.fit(verbose=verbose, n_live=n_live, use_MPI=use_MPI)
+        self.obj_fit.fit(verbose=verbose, n_live=n_live, use_MPI=use_MPI,
+                         sampler=sampler)
 
         if rank == 0:
             if self.vars is None:
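For orientation, here is a minimal sketch of how the new keyword is used at the catalogue level, following the standard bagpipes fit_catalogue pattern. The IDs, load_data function, fit_instructions model and run name are illustrative placeholders, not part of this commit; only the sampler keyword comes from the change above.

# Hypothetical usage sketch (not from this commit): catalogue fitting with
# the new sampler keyword, which defaults to "multinest".
import numpy as np
import bagpipes as pipes


def load_data(ID):
    # Placeholder: return a spectrum as columns of wavelength (angstroms),
    # flux and flux error. Real data loading would go here.
    wavs = np.arange(4000., 8000., 5.)
    flux = np.ones_like(wavs)
    flux_err = 0.1*np.ones_like(wavs)
    return np.c_[wavs, flux, flux_err]


# Simple exponential star-formation-history model, as in the bagpipes docs.
exp = {"age": (0.1, 13.), "tau": (0.3, 10.), "massformed": (1., 13.),
       "metallicity": (0., 2.5)}
fit_instructions = {"exponential": exp, "redshift": (0., 10.)}

IDs = ["1", "2", "3"]
cat_fit = pipes.fit_catalogue(IDs, fit_instructions, load_data,
                              photometry_exists=False, run="nautilus_test")

# sampler is the keyword added in this commit.
cat_fit.fit(n_live=400, sampler="nautilus")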
67 changes: 51 additions & 16 deletions bagpipes/fitting/fit.py
@@ -14,13 +14,28 @@
 except (ImportError, RuntimeError, SystemExit) as e:
     print("Bagpipes: PyMultiNest import failed, fitting will be unavailable.")
 
+try:
+    from nautilus import Sampler
+
+except (ImportError, RuntimeError, SystemExit) as e:
+    pass
+
 # detect if run through mpiexec/mpirun
 try:
     from mpi4py import MPI
     rank = MPI.COMM_WORLD.Get_rank()
+    size = MPI.COMM_WORLD.Get_size()
+    from mpi4py.futures import MPIPoolExecutor
+
+    if size == 1:
+        pool = None
+
+    else:
+        pool = MPIPoolExecutor(size)
 
 except ImportError:
     rank = 0
+    pool = None
 
 from .. import utils
 from .. import plotting
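The pool defined in this hunk is only created when more than one MPI rank is available, and, as the commit message notes, MPI support for the nautilus path is not fully wired up yet. For reference, a self-contained sketch of how an mpi4py.futures pool is normally driven; this is assumed generic usage, not code from this commit.

# Standalone mpi4py.futures sketch (not from this commit).
# Launch with e.g.:  mpiexec -n 4 python -m mpi4py.futures pool_demo.py
from mpi4py.futures import MPIPoolExecutor


def lnlike(theta):
    # Placeholder likelihood; in bagpipes this role is played by
    # fitted_model.lnlike evaluated at a parameter vector.
    return -0.5*sum(t**2 for t in theta)


if __name__ == "__main__":
    with MPIPoolExecutor(max_workers=4) as pool:
        points = [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]]
        print(list(pool.map(lnlike, points)))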
@@ -97,7 +112,7 @@ def __init__(self, galaxy, fit_instructions, run=".", time_calls=False,
         self.fitted_model = fitted_model(galaxy, self.fit_instructions,
                                          time_calls=time_calls)
 
-    def fit(self, verbose=False, n_live=400, use_MPI=True):
+    def fit(self, verbose=False, n_live=400, use_MPI=True, sampler="multinest"):
         """ Fit the specified model to the input galaxy data.
 
         Parameters
@@ -126,33 +141,53 @@ def fit(self, verbose=False, n_live=400, use_MPI=True):
 
         with warnings.catch_warnings():
             warnings.simplefilter("ignore")
-            pmn.run(self.fitted_model.lnlike,
-                    self.fitted_model.prior.transform,
-                    self.fitted_model.ndim, n_live_points=n_live,
-                    importance_nested_sampling=False, verbose=verbose,
-                    sampling_efficiency="model",
-                    outputfiles_basename=self.fname, use_MPI=use_MPI)
+
+            if sampler == "multinest":
+                pmn.run(self.fitted_model.lnlike,
+                        self.fitted_model.prior.transform,
+                        self.fitted_model.ndim, n_live_points=n_live,
+                        importance_nested_sampling=False, verbose=verbose,
+                        sampling_efficiency="model",
+                        outputfiles_basename=self.fname, use_MPI=use_MPI)
+
+            elif sampler == "nautilus":
+                n_sampler = Sampler(self.fitted_model.prior.transform,
+                                    self.fitted_model.lnlike, n_live=n_live,
+                                    n_networks=1, pool=pool,
+                                    n_dim=self.fitted_model.ndim,
+                                    filepath=self.fname + "temp.h5")
+
+                n_sampler.run(verbose=verbose)
 
         if rank == 0 or not use_MPI:
             runtime = time.time() - start_time
 
             print("\nCompleted in " + str("%.1f" % runtime) + " seconds.\n")
 
             # Load MultiNest outputs and save basic quantities to file.
-            samples2d = np.loadtxt(self.fname + "post_equal_weights.dat")
-            lnz_line = open(self.fname + "stats.dat").readline().split()
+            if sampler == "multinest":
+                samples2d = np.loadtxt(self.fname + "post_equal_weights.dat")
+                lnz_line = open(self.fname + "stats.dat").readline().split()
+                self.results["samples2d"] = samples2d[:, :-1]
+                self.results["lnlike"] = samples2d[:, -1]
+                self.results["lnz"] = float(lnz_line[-3])
+                self.results["lnz_err"] = float(lnz_line[-1])
+
+            elif sampler == "nautilus":
+                samples2d, log_w, log_l = n_sampler.posterior(equal_weight=True)
+                self.results["samples2d"] = samples2d
+                self.results["lnlike"] = log_l
+                self.results["lnz"] = n_sampler.evidence()
+                self.results["lnz_err"] = -99
+
+            self.results["median"] = np.median(samples2d, axis=0)
+            self.results["conf_int"] = np.percentile(self.results["samples2d"],
+                                                     (16, 84), axis=0)
 
             file = h5py.File(self.fname[:-1] + ".h5", "w")
 
             file.attrs["fit_instructions"] = str(self.fit_instructions)
 
-            self.results["samples2d"] = samples2d[:, :-1]
-            self.results["lnlike"] = samples2d[:, -1]
-            self.results["lnz"] = float(lnz_line[-3])
-            self.results["lnz_err"] = float(lnz_line[-1])
-            self.results["median"] = np.median(samples2d, axis=0)
-            self.results["conf_int"] = np.percentile(self.results["samples2d"],
-                                                     (16, 84), axis=0)
             for k in self.results.keys():
                 file.create_dataset(k, data=self.results[k])
 
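Putting both files together, a single-object fit with the new option might look like the sketch below. The galaxy construction reuses the placeholder load_data and fit_instructions from the catalogue sketch above; only the sampler keyword and the results keys are taken from this commit.

# Hypothetical single-object sketch (not from this commit), reusing the
# placeholder load_data and fit_instructions defined earlier.
import bagpipes as pipes

galaxy = pipes.galaxy("1", load_data, photometry_exists=False)

fit = pipes.fit(galaxy, fit_instructions, run="nautilus_test")

# sampler="multinest" keeps the previous behaviour; "nautilus" takes the
# new code path (serial only for now, per the commit message).
fit.fit(verbose=True, n_live=400, sampler="nautilus")

# Quantities written to the output .h5 file by both samplers:
print(fit.results["lnz"], fit.results["lnz_err"])
print(fit.results["samples2d"].shape, fit.results["median"])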
