In [1]:
import os
import emcee
import warnings
import numpy as np 
from multiprocessing import Pool

from jjmodel.tools import Timer
from jjmodel.input_ import p, a, inp
from jjmodel.mwdisk import disk_builder
from optimizer import posterior, initialize_params, mcmc_runner
from helpers import (ParHandler, IFMRHandler, MSAgeHandler, 
                     IMFHandler, SFRHandler, PopHandler, MCMCLogger, HessConstructor)
from prior import prior


Parameter file(s) : ok.
Number of parameters =  66 , among them technical =  7

 p(run_mode=0, out_dir='new_metgrid', out_mode=1, nprocess=4, Rsun=8.2, zsun=20, zmax=2000, dz=2, sigmad=29.3, sigmat=4.9, sigmag1=2.0, sigmag2=11.0, sigmadh=51.9, sigmash=0.47, td1=0, td2=7.8, dzeta=0.8, eta=5.6, pkey=1, tt1=0.1, tt2=4, gamma=2, beta=3.5, imfkey=0, a0=1.31, a1=1.5, a2=2.88, a3=2.28, m0=0.49, m1=1.43, m2=6.0, dFeHdt=0, n_FeHdt=1, fehkey=0, FeHd0=-0.7, FeHdp=0.29, rd=0.34, q=-0.72, FeHt0=-0.94, FeHtp=0.04, rt=0.77, t0=0.97, FeHsh=-1.5, dFeHsh=0.4, n_FeHsh=5, alpha=0.375, sige=26.0, sigt=45, sigdh=140, sigsh=100, sigmap=array([3.5, 1.3]), tpk=array([10. , 12.5]), dtp=array([0.7 , 0.25]), sigp=array([26.3, 12.6]))


Configuration finished:  0h 0m 0.0s


In [2]:
# Start the run timer; the final cell stops it and prints the elapsed time.
timer = Timer()
t1 = timer.start()

In [3]:
# Run configuration
# ---------------------------------------------------------------------
# --- General ---
mode_iso = 'Padova'                   # isochrone set for MS stars and giants (WDs always use Montreal)
mode_pop = 'tot'                      # populations to model: 'tot' = all, 'wd' = WDs only, 'ms' = MS + giants only
FeH_scatter = 0.07                    # metallicity scatter added to the thin- and thick-disk AMR
Nmet_dt = 7                           # number of metallicities per age bin
radius = 50                           # radius of the modelled sphere, pc
mag_range = [[-0.4, 1.65], [-1, 18]]  # Hess-diagram xy-ranges in (G-G_RP, M_G), mag
mag_step = [0.02, 0.2]                # bin steps in (G-G_RP, M_G), mag
mag_smooth = [0.06, 0.8]              # smoothing-window size in (G-G_RP, M_G), mag
# --- White dwarfs ---
f_da_teff = False                     # if True, the DA/DB WD fraction is a function of Teff
# File with the parameters of the main-sequence lifetime fits
age_ms_param_file = ('MS_lifetime_padova_new_metgrid/analysis/'
                     'fit_v1_Mbr1.18/tau_ms_params_v1_Mbr1.18.txt')
# --- MCMC setup ---
mode_init = 'blob'                    # walker start: 'blob' around the prior means, or random
blob_f_sig = 0.5                      # blob size (presumably in units of the prior sigma — TODO confirm)
n_max = 10                            # maximum number of iterations
dir_out = 'output/mcmc'               # output directory
save_log = True                       # log every tested parameter combination
# ---------------------------------------------------------------------

In [15]:
# Free parameters of the MCMC fit, grouped by model component.
# The fitted-parameter table printed at the end lists them in this same order,
# so keep the key/list order stable.
par_optim = dict(
    ifmr=['m_br1', 'm_br2', 'alpha1', 'alpha2', 'alpha3'],   # initial-final mass relation
    dcool=['alpha_cool'],                                     # WD cooling (presumably) — TODO confirm
    f_dadb=['f_da'],                                          # DA/DB white-dwarf fraction
    sfr=dict(
        d=['dzeta', 'eta', 'td2', 'sigmap0', 'tpk0'],         # 'd' component SFR (thin disk, presumably)
        t=['gamma', 'beta', 'tt1'],                           # 't' component SFR (thick disk, presumably)
    ),
    imf=['a0', 'a1', 'a2', 'm0', 'm1'],                       # IMF parameters
)

# Handler that flattens the selection and looks up prior means/sigmas for it
par_handler = ParHandler(par_optim, prior)
labels = par_handler.get_flat_param_list()
params_mean, params_sigma = par_handler.get_prior_for_params()

In [16]:
# Finish MCMC setup based on parameter list
ndim = len(labels)       # one dimension per free parameter
nwalkers = 4*ndim        # 4 walkers per dimension (emcee advises at least 2*ndim)
# Worker processes for the MCMC pool. Inside a SLURM job, use the CPUs actually
# allocated to the job so the pool does not oversubscribe the node. Outside
# SLURM, fall back to 8 (the previous hard-coded value).
n_cores = int(os.environ.get("SLURM_CPUS_ON_NODE", 8))

# Create output directory (no error if it already exists)
os.makedirs(dir_out, exist_ok=True)

# Create logger; `logfile` is whatever manage_logfile returns for `save_log`
logger = MCMCLogger(dir_out=dir_out)
logfile = logger.manage_logfile(save_log)

# Save simulation card with the full run configuration
logger.save_simulation_card(par_optim, mode_iso=mode_iso, mode_pop=mode_pop,
                            radius=radius, FeH_scatter=FeH_scatter, Nmet_dt=Nmet_dt,
                            mag_range=mag_range, mag_step=mag_step, mag_smooth=mag_smooth,
                            age_ms_param_file=age_ms_param_file, f_da_teff=f_da_teff,
                            save_log=save_log, logfile=logfile,
                            mode_init=mode_init, blob_f_sig=blob_f_sig,
                            ndim=ndim, nwalkers=nwalkers, n_cores=n_cores, n_max=n_max
                            )

In [17]:
# Model ingredients: IFMR, MS lifetimes, IMF, SFR, populations, Hess diagrams
# ---------------------------------------------------------------------------
ifmr_handler = IFMRHandler()
msage_handler = MSAgeHandler(param_file=age_ms_param_file)
imf_handler = IMFHandler(p)
sfr_handler = SFRHandler(p, a, inp)
pop_handler = PopHandler(p, a, inp)
constructor = HessConstructor(radius, p, a)

# Reference SFR and IMF (presumably for the baseline parameters in `p`);
# the IMF call also returns the mass-bin centres used for indexing later.
SFR_ref = sfr_handler.create_reference_sfr()
imf_ref, (mass_binsc, IMF_ref) = imf_handler.create_reference_imf()

# Observed Hess diagram for the chosen population type
hess_ref = np.loadtxt(f'./data/hess/hess_{mode_pop}.txt')

In [7]:
# Vertical disk structure (jjmodel); prints its progress to the output
disk_builder(p, a, inp, status_progress=True)

# Keyword arguments for the population synthesis
pop_kwargs = par_handler.prepare_population_kwargs(
    FeH_scatter=FeH_scatter, Nmet_dt=Nmet_dt, mode_pop=mode_pop,
)

# Reference population tables, followed by reference copies of the key columns
# (presumably so each MCMC step can start again from the original values)
pop_tabs_ref = pop_handler.create_reference_pop_tabs(imf_ref, mode_iso, **pop_kwargs)
pop_tabs_ref = pop_handler.create_reference_columns(pop_tabs_ref, ['N', 'Mini', 'age', 'age_WD'])

# Index columns that map table rows to reference age bins and initial-mass bins
indt, indm = pop_handler.get_age_mass_idx(pop_tabs_ref, mass_binsc)

# Indices of DA/DB white dwarfs, plus a per-component summary printout
ind_wd = pop_handler.make_wd_idx_dict(pop_tabs_ref)
pop_handler.display_wd_stats(mode_pop, ind_wd)

Results of this run will be saved to already existing folder output/Rsun8.2_new_metgrid

Output directory tree created.

---Local run---

  Process for R = 8.2 kpc: start       
    Process for R = 8.2 kpc: fimax optimized
      Process for R = 8.2 kpc: PE solved   
        Process 8.2   : exit, time: 0h 0m 0.17s


Input data saved.

Output data saved.

---Local run ended sucessfully---


Stellar population synthesis for R = 8.2 kpc:
	thin disk	thick disk	halo


WD indices d:	 309014 235812
WD indices t:	 81552 24550
WD indices sh:	 2003 122


In [18]:
# Keyword arguments forwarded to `posterior` on every probability evaluation
kwargs_post = par_handler.prepare_posterior_kwargs(
    SFR_ref, IMF_ref, indt, indm,
    save_log=save_log, logfile=logfile,
    mode_pop=mode_pop, ind_wd=ind_wd, f_da_teff=f_da_teff,
    ifmr_handler=ifmr_handler, msage_handler=msage_handler,
)

# Define posterior function for MCMC
def probability_for_mcmc(theta):
    """Log-probability callback handed to ``emcee.EnsembleSampler``.

    Evaluates ``posterior`` for the parameter vector ``theta``. It uses the
    notebook-level model state: prior means/sigmas, jjmodel inputs, reference
    population tables, the observed Hess diagram, binning settings, handlers
    and ``kwargs_post``.

    Any non-finite value (NaN or +/-inf) is replaced by ``-np.inf``. emcee
    then treats the point as zero probability instead of failing on a NaN.

    Returns
    -------
    tuple
        ``(log_post, blob)``. ``blob`` is passed through unchanged; emcee
        stores extra return values as per-step blobs.
    """
    log_post, blob = posterior(theta, params_mean, params_sigma, p, a, inp,
                               pop_tabs_ref, hess_ref,
                               mag_range, mag_step, mag_smooth,
                               sfr_handler, imf_handler, pop_handler, par_handler, constructor,
                               **kwargs_post)
    return (log_post if np.isfinite(log_post) else -np.inf), blob


In [19]:
# Starting positions of the walkers ('blob' or random, as set by `mode_init`)
kwargs_init = par_handler.prepare_initialization_kwargs(
    mode_init, params_mean, params_sigma, labels, blob_f_sig=blob_f_sig,
)
pos = initialize_params(mode_init, nwalkers, ndim, **kwargs_init)


In [23]:
# Smoke test: evaluate the probability once, slightly off the prior means,
# before launching the expensive sampler.
p_test = params_mean + 0.01 * params_sigma
p1, p2 = probability_for_mcmc(p_test)
print(p1, p2)  # value returned by `posterior` and the blob tuple (blob content: TODO confirm)

-0.8258631137186968 (np.float64(-4.999999999999762e-05),)


In [None]:
# Run MCMC!
# The pool is a context manager, so its worker processes are shut down as soon
# as sampling finishes instead of lingering for the rest of the kernel session.
# The original code never closed the pool. Pool.__exit__ calls terminate(),
# which is safe here because mcmc_runner has returned and no work is pending.
# The returned `sampler` keeps its stored chains, so the analysis cells below
# still work. Re-running the sampler later would need a new pool (or pool=None).
with Pool(processes=n_cores) as pool:
    sampler = emcee.EnsembleSampler(nwalkers, ndim, probability_for_mcmc, pool=pool)
    sampler, autocorr = mcmc_runner(sampler, pos, ndim, n_max)


In [11]:
# Check MCMC performance
# -------------------------
# Chains from the sampler (the logger also gathers prior and posterior values)
chains_flat = logger.get_chains(sampler)
# Acceptance fraction and integrated autocorrelation time
logger.get_run_stats(sampler, autocorr)
# Table of best-fit values with +/- errors
logger.get_best_params(chains_flat, labels, params_mean, params_sigma)

Acceptance_fraction =  0.3
Integrated autocorrelation time could not be estimated

               value     +error     -error max_error, %
     m_br1      2.808      0.159       0.14        5.7
     m_br2      3.658      0.193      0.141        5.3
    alpha1      0.087      0.019      0.014       22.1
    alpha2      0.184      0.035      0.031       19.3
    alpha3      0.084      0.016      0.011       18.5
alpha_cool      0.014      0.011      0.035      259.2
      f_da      0.806      0.067      0.044        8.4
     dzeta      0.833       0.05      0.038        6.0
       eta      5.597      0.129      0.153        2.7
       td2      7.928      0.878      0.871       11.1
   sigmap0        3.5      1.097      1.002       31.3
      tpk0      9.941       0.74      0.947        9.5
     gamma      2.041      0.605      0.513       29.6
      beta      3.505      0.763      0.779       22.2
       tt1        0.1      0.028      0.029       29.4
        a0      1.276       0.17    

In [12]:
# Final report: elapsed time since the timer was started at the top of the notebook
print('Output saved')
print('Execution time: {}'.format(timer.stop(t1)))

Output saved
Execution time: 0h 3m 22.04s
