Skip to content

Commit

Permalink
first run with s fixed, then with s
Browse files Browse the repository at this point in the history
  • Loading branch information
adrn committed May 11, 2021
1 parent 8e76767 commit 3769034
Showing 1 changed file with 25 additions and 6 deletions.
31 changes: 25 additions & 6 deletions hq/cli/run_mcmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import pymc3 as pm
import thejoker as tj
import pymc3_ext as pmx
from thejoker.samples_helpers import inferencedata_to_samples

# Project
from hq.log import logger
Expand All @@ -25,35 +26,53 @@ def run_mcmc(run_path, index, overwrite=False):
raise ValueError("Index is larger than the number of unimodal sources")

metadata_row = unimodal_tbl[index]
source_id = metadata_row[c.source_id_colname]

# Read MAP sample:
MAP_sample = extract_MAP_sample(metadata_row)
logger.log(1, f"{source_id}: MAP sample loaded")

prior, model = c.get_prior('mcmc')
fixed_s_prior, fixed_s_model = c.get_prior('mcmc', fixed_s=MAP_sample['s'])

mcmc_cache_path = c.cache_path / 'mcmc'
mcmc_cache_path.mkdir(exist_ok=True)

source_id = metadata_row[c.source_id_colname]
this_cache_path = mcmc_cache_path / source_id
if this_cache_path.exists() and not overwrite:
logger.info(f"{source_id} already done!")
return

# Set up The Joker:
joker = tj.TheJoker(prior)
fixed_s_joker = tj.TheJoker(fixed_s_prior)

# Load the data:
logger.debug(f"{source_id}: Loading all data")
data = c.get_source_data(source_id)

t0 = time.time()

# Read MAP sample:
MAP_sample = extract_MAP_sample(metadata_row)
logger.log(1, f"{source_id}: MAP sample loaded")

# Run MCMC:
with fixed_s_model:
logger.log(1, f"{source_id}: Setting up fixed s MCMC...")
mcmc_init = fixed_s_joker.setup_mcmc(data, MAP_sample)

if 'logp' not in fixed_s_model.named_vars:
pm.Deterministic('logp', fixed_s_model.logpt)

logger.debug(f"{source_id}: Starting MCMC sampling")
trace = pmx.sample(start=mcmc_init, chains=2, cores=1,
tune=c.tune, draws=c.draws,
return_inferencedata=True)

init_samples = inferencedata_to_samples(fixed_s_prior, trace, data)
df = trace.posterior.to_dataframe()
mcmc_MAP_sample = init_samples[df.logp.argmax()]

with model:
logger.log(1, f"{source_id}: Setting up MCMC...")
mcmc_init = joker.setup_mcmc(data, MAP_sample)
mcmc_init = joker.setup_mcmc(data, mcmc_MAP_sample)
logger.log(1, f"{source_id}: ...setup complete")

# HACK:
Expand Down

0 comments on commit 3769034

Please sign in to comment.