Skip to content

Commit

Permalink
set seed of mcmc run
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed May 11, 2021
1 parent 3769034 commit 66b1d8a
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 4 deletions.
3 changes: 2 additions & 1 deletion hq/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,7 +257,8 @@ def run_mcmc(self):
logger.info(f"Theano flags set to: {os.environ['THEANO_FLAGS']}")

from .run_mcmc import run_mcmc # noqa
run_mcmc(args.run_path, index=args.index, overwrite=args.overwrite)
run_mcmc(args.run_path, index=args.index, overwrite=args.overwrite,
seed=args.index)

sys.exit(0)

Expand Down
8 changes: 5 additions & 3 deletions hq/cli/run_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from hq.config import Config


def run_mcmc(run_path, index, overwrite=False):
def run_mcmc(run_path, index, seed=None, overwrite=False):
# Load the analyzed joker samplings file, only keep unimodal:
c = Config(run_path / 'config.yml')

Expand Down Expand Up @@ -64,7 +64,8 @@ def run_mcmc(run_path, index, overwrite=False):
logger.debug(f"{source_id}: Starting MCMC sampling")
trace = pmx.sample(start=mcmc_init, chains=2, cores=1,
tune=c.tune, draws=c.draws,
return_inferencedata=True)
return_inferencedata=True,
random_seed=seed)

init_samples = inferencedata_to_samples(fixed_s_prior, trace, data)
df = trace.posterior.to_dataframe()
Expand Down Expand Up @@ -104,7 +105,8 @@ def run_mcmc(run_path, index, overwrite=False):
trace = pmx.sample(start=mcmc_init, chains=4, cores=1,
tune=c.tune, draws=c.draws,
return_inferencedata=True,
discard_tuned_samples=False)
discard_tuned_samples=False,
random_seed=seed)

trace.to_netcdf(this_cache_path / 'samples.nc')
logger.debug(f"{source_id}: Finished MCMC sampling "
Expand Down

0 comments on commit 66b1d8a

Please sign in to comment.