Skip to content

Commit

Permalink
run twice
Browse files: browse the repository at this point in the history
  • Loading branch information
adrn committed May 17, 2021
1 parent d375309 commit 83a20c7
Showing 1 changed file with 55 additions and 70 deletions.
125 changes: 55 additions & 70 deletions hq/cli/run_mcmc.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -17,7 +17,7 @@

def run_mcmc(run_path, index, seed=None, overwrite=False):
# Load the analyzed joker samplings file, only keep unimodal:
c = Config(run_path / 'config.yml')
conf = Config(run_path / 'config.yml')

joker_metadata = at.QTable.read(c.metadata_joker_file)
unimodal_tbl = joker_metadata[joker_metadata['unimodal']]
Expand All @@ -26,88 +26,73 @@ def run_mcmc(run_path, index, seed=None, overwrite=False):
raise ValueError("Index is larger than the number of unimodal sources")

metadata_row = unimodal_tbl[index]
source_id = metadata_row[c.source_id_colname]
source_id = metadata_row[conf.source_id_colname]

# Read MAP sample:
MAP_sample = extract_MAP_sample(metadata_row)
logger.log(1, f"{source_id}: MAP sample loaded")

prior, model = c.get_prior('mcmc')
fixed_s_prior, fixed_s_model = c.get_prior('mcmc', fixed_s=MAP_sample['s'])

mcmc_cache_path = c.cache_path / 'mcmc'
# Make sure the root mcmc path exists:
mcmc_cache_path = conf.cache_path / 'mcmc'
mcmc_cache_path.mkdir(exist_ok=True)

# Read the source data and MAP sample:
data = conf.get_source_data(source_id)
joker_MAP_sample = extract_MAP_sample(metadata_row)
logger.log(1, f"{source_id}: MAP sample loaded")

this_cache_path = mcmc_cache_path / source_id
if this_cache_path.exists() and not overwrite:
samples_file = this_cache_path / 'samples.nc'
if samples_file.exists() and not overwrite:
logger.info(f"{source_id} already done!")
return

# Set up The Joker:
joker = tj.TheJoker(prior)
fixed_s_joker = tj.TheJoker(fixed_s_prior)

# Load the data:
logger.debug(f"{source_id}: Loading all data")
data = c.get_source_data(source_id)
# -------- Initial run ---------
# Fix the excess variance parameter to the MAP value from running The Joker
time0 = time.time()

t0 = time.time()
fixed_s_prior, fixed_s_model = conf.get_prior(
'mcmc',
fixed_s=joker_MAP_sample['s'])

# Run MCMC:
with fixed_s_model:
logger.log(1, f"{source_id}: Setting up fixed s MCMC...")
mcmc_init = fixed_s_joker.setup_mcmc(data, MAP_sample)
joker = tj.TheJoker(fixed_s_prior)

mcmc_init = joker.setup_mcmc(data, joker_MAP_sample,
custom_func=conf.get_custom_init_mcmc())

init_trace = pmx.sample(
start=mcmc_init, chains=2, cores=1,
init='adapt_full',
tune=conf.tune, draws=1000, # MAGIC NUMBER
return_inferencedata=True,
discard_tuned_samples=True,
random_seed=seed,
target_accept=conf.target_accept)

if 'logp' not in fixed_s_model.named_vars:
pm.Deterministic('logp', fixed_s_model.logpt)
init_samples = inferencedata_to_samples(joker.prior, init_trace, data)
tmp_MAP_sample = init_samples[init_samples['ln_posterior'].argmax()]
init_trace.to_netcdf(this_cache_path / 'init_samples.nc')

logger.debug(f"{source_id}: Starting MCMC sampling")
trace = pmx.sample(start=mcmc_init, chains=2, cores=1,
tune=c.tune, draws=c.draws,
return_inferencedata=True,
random_seed=seed)
# -------- Main run ---------

init_samples = inferencedata_to_samples(fixed_s_prior, trace, data)
df = trace.posterior.to_dataframe()
mcmc_MAP_sample = init_samples[df.logp.argmax()]
prior, model = conf.get_prior(
'mcmc',
MAP_sample=tmp_MAP_sample)

with model:
logger.log(1, f"{source_id}: Setting up MCMC...")
mcmc_init = joker.setup_mcmc(data, mcmc_MAP_sample)
logger.log(1, f"{source_id}: ...setup complete")

# HACK:
mcmc_init['lnP'] = np.log(mcmc_init.get('P', 1.))

if 'ln_prior' not in model.named_vars:
ln_prior_var = None
for k in joker.prior._nonlinear_equiv_units:
var = model.named_vars[k]
try:
if ln_prior_var is None:
ln_prior_var = var.distribution.logp(var)
else:
ln_prior_var = ln_prior_var + var.distribution.logp(var)
except Exception as e:
logger.warning("Cannot auto-compute log-prior value for "
f"parameter {var}.")
print(e)
continue

pm.Deterministic('ln_prior', ln_prior_var)
logger.log(1, f"{source_id}: setting up ln_prior in pymc3 model")

if 'logp' not in model.named_vars:
pm.Deterministic('logp', model.logpt)
logger.log(1, f"{source_id}: setting up logp in pymc3 model")

logger.debug(f"{source_id}: Starting MCMC sampling")
trace = pmx.sample(start=mcmc_init, chains=4, cores=1,
tune=c.tune, draws=c.draws,
return_inferencedata=True,
discard_tuned_samples=False,
random_seed=seed)

trace.to_netcdf(this_cache_path / 'samples.nc')
joker = tj.TheJoker(prior)
mcmc_init = joker.setup_mcmc(data, tmp_MAP_sample,
custom_func=conf.get_custom_init_mcmc())

logger.debug(f"{source_id}: Starting initial (fixed s) MCMC sampling")
trace = pmx.sample(
start=mcmc_init, chains=4, cores=4,
init='adapt_full',
tune=conf.tune, draws=conf.draws,
return_inferencedata=True,
discard_tuned_samples=False,
random_seed=seed)

samples = inferencedata_to_samples(joker.prior, trace, data)
mcmc_MAP_sample = samples[samples['ln_posterior'].argmax()]

trace.to_netcdf(samples_file)
logger.debug(f"{source_id}: Finished MCMC sampling "
f"({time.time()-t0:.2f} seconds)")
f"({time.time()-time0:.2f} seconds)")

0 comments on commit 83a20c7

Please sign in to comment.