In [3]:
from fit2d import Galaxy, RingModel
from fit2d.mcmc import LinearPrior
from fit2d.mcmc import emcee_lnlike, piecewise_start_points
from fit2d.models import PiecewiseModel

from astropy.io import fits
from datetime import datetime
import joblib
import emcee

In [4]:
# --- Galaxy / data configuration ---
name = "UGC3974"
distance = 8000.  # [kpc]
# TODO: absolute, machine-specific path; make this configurable (env var / relative path).
data_dir = "/home/anna/Desktop/fit2d/data"
observed_2d_vel_field_fits_file = f"{data_dir}/{name}_1mom.fits"
deg_per_pixel = 4.17e-4  # pixel scale of the FITS image, in degrees per pixel
v_systemic = 270.  # NOTE(review): units presumably km/s -- confirm against Galaxy

ring_param_file = f"{data_dir}/{name}_ring_parameters.txt"
# x and y dims are switched in ds9 fits display versus np array shape,
# so the numpy shape is unpacked as (y, x).
# Use a context manager so the FITS file handle is closed once the shape is read.
with fits.open(observed_2d_vel_field_fits_file) as hdul:
    fits_ydim, fits_xdim = hdul[0].data.shape
num_bins = 10  # number of bins in the piecewise rotation-curve model

# --- MCMC configuration ---
mask_sigma = 1.
mcmc_nwalkers = 20
mcmc_niter = 10  # total number of MCMC steps per walker
mcmc_ndim = num_bins  # one sampled parameter per piecewise bin
mcmc_nthreads = 4

random_seed = 1234  # seeds the walker start positions

In [5]:
# Observed galaxy: the 2D velocity-field FITS file (a moment-1 map, judging by the
# filename) together with its pixel scale, distance and systemic velocity.
galaxy = Galaxy(
    name=name,
    distance=distance,
    deg_per_pixel=deg_per_pixel,
    v_systemic=v_systemic,
    observed_2d_vel_field_fits_file=observed_2d_vel_field_fits_file,
)

# Ring geometry read from the ring-parameter file, sized to match the
# dimensions of the observed FITS image.
ring_model = RingModel(
    distance=distance,
    fits_xdim=fits_xdim,
    fits_ydim=fits_ydim,
    ring_param_file=ring_param_file,
)

In [6]:
# Piecewise rotation-curve model with `num_bins` bins. Parameter bounds are set via
# set_bounds(0, 200); the bin edges span the innermost to the outermost ring radius.
piecewise_model = PiecewiseModel(num_bins=num_bins)
piecewise_model.set_bounds(0, 200)
ring_radii_kpc = ring_model.radii_kpc
piecewise_model.set_bin_edges(rmin=ring_radii_kpc[0], rmax=ring_radii_kpc[-1])

# Prior built from the model bounds; its bound method is kept as the prior transform.
prior = LinearPrior(bounds=piecewise_model.bounds)
prior_transform = prior.transform_from_unit_cube

In [7]:
# Initial walker positions for the sampler (presumably drawn within the model
# bounds); seeded with `random_seed` so the starting point is reproducible.
start_positions = piecewise_start_points(
    mcmc_nwalkers,
    piecewise_model.bounds,
    random_seed=random_seed,
)

In [8]:
# Persist the objects that define this fit (model, galaxy, ring geometry and
# prior transform) so they can be reloaded later without re-running these cells.
fit_inputs = dict(
    piecewise_model=piecewise_model,
    galaxy=galaxy,
    ring_model=ring_model,
    prior_transform=prior_transform,
)
with open("fit_inputs.pkl", "wb") as fit_inputs_file:
    joblib.dump(fit_inputs, fit_inputs_file)

In [13]:
# Keyword arguments forwarded to the rotation-curve evaluation inside the likelihood.
rotation_curve_func_kwargs = {"radii_to_interpolate": ring_model.radii_kpc}

# Everything emcee_lnlike needs besides the sampled parameters. The whole dict is
# passed as the single extra positional argument via `args=[lnlike_args]` below.
lnlike_args = {
    "model": piecewise_model,
    "rotation_curve_func_kwargs": rotation_curve_func_kwargs,
    "galaxy": galaxy,
    "ring_model": ring_model,
    "mask_sigma": mask_sigma,
}

# NOTE(review): `threads=` is the emcee 2.x parallelism API (it creates `sampler.pool`,
# which the batched-run cell relies on); emcee 3 deprecates it in favour of `pool=`.
sampler = emcee.EnsembleSampler(
    mcmc_nwalkers,
    mcmc_ndim,
    emcee_lnlike,
    args=[lnlike_args],
    threads=mcmc_nthreads,
)



In [None]:
# Run the MCMC in chunks of `batch_size` steps, checkpointing the sampler and the
# accumulated run_mcmc results to disk after each chunk so progress is not lost.
dateTimeObj = datetime.now()
timestampStr = dateTimeObj.strftime("%d-%b-%Y")
sampler_file = f'sampler_{timestampStr}.pkl'
mcmc_output_file = f'mcmc_output_{timestampStr}.pkl'

batch_size = 4
mcmc_output = []
# Start from an empty chain so re-running this cell does not append to a previous run.
sampler.reset()
for batch_start_step in range(0, mcmc_niter, batch_size):
    # The final batch may be shorter, so exactly `mcmc_niter` steps are run in total
    # (floor division previously dropped the remainder).
    n_steps = min(batch_size, mcmc_niter - batch_start_step)
    # First batch starts from the initial walker positions; later batches pass None
    # so run_mcmc continues from where the previous call stopped.
    batch_start = start_positions if batch_start_step == 0 else None
    # `+=` extends the list with the items of run_mcmc's return value.
    mcmc_output += sampler.run_mcmc(batch_start, n_steps)

    # The worker pool cannot be pickled: detach it only for the dump and always
    # reattach it, so the sampler stays usable for the next batch and later cells.
    pool = getattr(sampler, "pool", None)
    sampler.pool = None
    try:
        with open(sampler_file, 'wb') as f:
            joblib.dump(sampler, f)
    finally:
        sampler.pool = pool
    with open(mcmc_output_file, 'wb') as f:
        joblib.dump(mcmc_output, f)
    print(f"Done with steps {batch_start_step} - {batch_start_step + n_steps} out of {mcmc_niter}")
