In [1]:
# Widen the notebook's main content area to 95% of the browser window.
# `IPython.display` is the public import location for `display` and `HTML`;
# importing `display` from `IPython.core.display` is deprecated in recent IPython.
# NOTE: `.container` presumably targets the classic Notebook UI; it may have no
# effect in JupyterLab.
from IPython.display import display, HTML
display(HTML("<style>.container { width:95% !important; }</style>"))


from astropy.io import fits
import dill as pickle
from emcee import EnsembleSampler
import numpy as np
import os
import sys

# The notebook runs from a sub-directory of the repository; put the repository
# root (the parent of the working directory) on the import path so that
# `little_things_lib` can be imported below.
pwd = os.getcwd()
little_things_root_dir = os.path.dirname(pwd)
sys.path.insert(len(sys.path), little_things_root_dir)

from little_things_lib.galaxy_piecewise import Galaxy
from little_things_lib.piecewise_mcmc_fitter import (
    EmceeParameters,
    generate_nwalkers_start_points, 
    lnprob)
from little_things_lib.plotting import plot_posterior_distributions, plot_walker_paths
from datetime import datetime


RAD_PER_ARCSEC = np.pi / (180 * 3600)  # radians in one arcsecond: pi rad / (180 deg * 3600 arcsec/deg)

In [2]:
# Render matplotlib figures inline, directly below the cell that creates them.
%matplotlib inline
import matplotlib.pyplot as plt

## Enter the galaxy parameters in the cell below

In [13]:
galaxy_name = 'NGC2366'

# Uncertainty of the observed 2D velocity field [km/s]
velocity_error_2d = 20

# Angular size of one pixel [deg] and adopted distance to the galaxy [kpc]
DEG_PER_PIXEL = 4.17e-4
DISTANCE = 3400

# Placeholders: any value works for now, these are not used yet
LUMINOSITY = 1e8  # [solar luminosities]
HI_MASS = 1e8  # [solar masses]

In [14]:
# Sampler configuration: 15 free parameters explored by 30 walkers; 100 burn-in
# steps, then 300 production steps thinned by 3, using 4 worker threads.
# NOTE(review): ndim presumably has to match the number of piecewise bins
# (nparams) set further below -- confirm.
mcmc_params = EmceeParameters(
    ndim=15, nwalkers=30,
    nburn=100, niter=300,
    nthin=3, nthreads=4,
)


In [15]:
"""
Expect the data to be provided in following naming convention in 'data' directory:

2D observed velocity field FITS file: <galaxy_name>_1mom.fits
Bbarolo fit parameters text file: <galaxy_name>_ring_parameters.txt
Stellar velocity curve: <galaxy_name>_stellar_velocities.txt
Gas velocity_curve: <galaxy_name>_gas_velocities.txt

"""


data_dir = os.path.join(little_things_root_dir, 'data')

observed_2d_vel_field_fits_file = os.path.join(data_dir, f'{galaxy_name}_1mom.fits')
ring_parameters_file = os.path.join(data_dir, f'{galaxy_name}_ring_parameters.txt')

stellar_velocities_file = os.path.join(data_dir, f'{galaxy_name}_stellar_velocities.txt')
gas_velocities_file = os.path.join(data_dir, f'{galaxy_name}_gas_velocities.txt')


In [16]:
# Tilted-ring fit parameters, one row per ring. The unpacked names document the
# columns read: 1 -> radius [arcsec], 2 -> rotation velocity, 4 -> inclination,
# 5 -> position angle, 4th/3rd/2nd from last -> x centre, y centre, systemic velocity.
# NOTE(review): this looks like a BBarolo ringlog layout -- confirm the columns.
radii_arcsec, test_rotation_curve, inclinations, position_angles, x_centers, y_centers, v_systemics = \
    np.loadtxt(ring_parameters_file, usecols=(1, 2, 4, 5, -4, -3, -2)).T

# Small-angle approximation: radius [kpc] = angle [rad] * distance [kpc].
radii_kpc = radii_arcsec * RAD_PER_ARCSEC * DISTANCE
# Galaxy-wide averages of the per-ring geometry.
avg_inclination = np.mean(inclinations)
avg_position_angle = np.mean(position_angles)
avg_x_center = np.mean(x_centers)
avg_y_center = np.mean(y_centers)
v_systemic = np.mean(v_systemics)

# Read the primary-HDU image inside a context manager so the FITS file is closed
# again (chaining fits.open(...)[0].data leaves the HDUList and its handle open).
with fits.open(observed_2d_vel_field_fits_file) as hdul:
    observed_2d_vel_field = hdul[0].data

# Baryonic rotation curves: two columns each (radius, velocity).
stellar_radii, stellar_vel = np.loadtxt(stellar_velocities_file, unpack=True)
gas_radii, gas_vel = np.loadtxt(gas_velocities_file, unpack=True)

# Testing only: to run with synthetic curves instead of real data, comment out
# the two np.loadtxt calls above and uncomment the four lines below.
#stellar_radii = np.linspace(0, 7, 41)
#gas_radii = np.linspace(0, 7, 41)
#stellar_vel = np.linspace(0, 1, 41)
#gas_vel = np.linspace(0, 1, 41)

In [17]:
# Split the tilted rings into nparams radial bins for the piecewise model.
# Edges sit at evenly spaced ring indices (truncated to int), so every bin covers
# about the same number of rings. TODO: consider a logarithmic index spacing.
nparams=15
indices=np.linspace(0,len(radii_kpc)-1,nparams+1)
bin_edges = list(radii_kpc[indices.astype(int)])
# Widen the outermost edges by 0.001 kpc, presumably so the first and last ring
# radii do not sit exactly on a bin edge.
bin_edges[0]=bin_edges[0]-.001
bin_edges[-1]=bin_edges[-1]+.001
# Fewer than nparams + 1 rings (or unsorted radii) would repeat edges, i.e. create
# empty zero-width bins; fail here rather than deep inside the fit.
assert np.all(np.diff(bin_edges) > 0), (
    f"bin edges are not strictly increasing: need at least {nparams + 1} rings "
    f"sorted by radius, got {len(radii_kpc)}")

In [18]:
# Build the Galaxy model from the observed velocity field and the radial bins.
galaxy = Galaxy(
    distance_to_galaxy=DISTANCE,  # [kpc] Look this up for the galaxy
    deg_per_pixel=DEG_PER_PIXEL,  # 'CDELT1' and 'CDELT2' in the FITS file header (use absolute value)
    bin_edges=np.array(bin_edges),
    # Pass the name as a plain str: a 0-d numpy array breaks str-only operations
    # such as os.path.join or string concatenation.
    # NOTE(review): confirm Galaxy does not rely on receiving an ndarray here.
    galaxy_name=galaxy_name,
    vlos_2d_data=observed_2d_vel_field,
    v_error_2d=velocity_error_2d,
    output_dir='output',
    luminosity=LUMINOSITY,
    HI_mass=HI_MASS)

# Tilted-ring geometry: per-ring arrays, except the systemic velocity, which is
# the average over all rings.
tilted_ring_params = {
    'v_systemic': v_systemic,
    'radii': radii_kpc,   #sets galaxy.radii
    'inclination': inclinations,
    'position_angle': position_angles,
    'x_pix_center': x_centers,
    'y_pix_center': y_centers
}

galaxy.set_tilted_ring_parameters(**tilted_ring_params)

# Attach the stellar and gas rotation curves loaded from the data tables.
galaxy.interpolate_baryonic_rotation_curve(
    baryon_type='stellar',
    rotation_curve_radii=stellar_radii,
    rotation_curve_velocities=stellar_vel)

galaxy.interpolate_baryonic_rotation_curve(
    baryon_type='gas',
    rotation_curve_radii=gas_radii,
    rotation_curve_velocities=gas_vel)



### The cell below should be modified for the piecewise model

In [19]:

# initialize MCMC start position and bounds:
# the prior bounds are derived from the radial bin edges and the tilted-ring
# rotation curve, then one start point is generated per walker.
galaxy.set_piecewise_prior_bounds(bin_edges, test_rotation_curve)

start_pos = generate_nwalkers_start_points(mcmc_params.nwalkers, galaxy)

In [None]:
# set to False to greatly reduce size of saved output
# (forwarded to lnprob as its last extra argument)
save_blob = False

# initialize sampler; emcee calls lnprob(theta, galaxy, save_blob) for each walker
sampler = EnsembleSampler(
    mcmc_params.nwalkers,
    mcmc_params.ndim,
    lnprob,
    args=[galaxy, save_blob],
    threads=mcmc_params.nthreads)

# burn in: run_mcmc returns a tuple whose first item is the final walker
# positions. Use that public return value instead of the private
# sampler._last_run_mcmc_result attribute.
burn_in_result = sampler.run_mcmc(start_pos, N=mcmc_params.nburn)
start_pos_after_burn_in = burn_in_result[0]
# discard the burn-in samples so only the production chain is kept
sampler.reset()

In [None]:
# this will break up the fitting procedure into smaller chunks of size batch_size and save progress
dateTimeObj = datetime.now()
timestampStr = dateTimeObj.strftime("%d-%b-%Y")

# emcee restarts its thinning counter at every run_mcmc call. A batch length
# that is not a whole multiple of nthin therefore thins the chain unevenly, and
# may overrun emcee 2.x's pre-allocated chain storage. Keep batches a multiple
# of nthin.
batch_size = 2 * mcmc_params.nthin
assert batch_size % mcmc_params.nthin == 0, "batch_size must be a multiple of nthin"

# Add a final shorter batch when niter is not a multiple of batch_size, instead
# of silently skipping the remaining steps.
n_full_batches, leftover_steps = divmod(mcmc_params.niter, batch_size)
batch_lengths = [batch_size] * n_full_batches + ([leftover_steps] if leftover_steps else [])

mcmc_output = []
steps_done = 0
for batch, n_steps in enumerate(batch_lengths):
    # The first batch starts from the burned-in walker positions; later batches
    # pass None so emcee continues from where the previous run_mcmc call ended.
    batch_start = start_pos_after_burn_in if batch == 0 else None
    # run_mcmc returns a tuple; `+=` appends its items to the mcmc_output list.
    mcmc_output += sampler.run_mcmc(batch_start, n_steps, thin=mcmc_params.nthin)
    steps_done += n_steps

    # A multiprocessing worker pool cannot be pickled: detach it for the dump and
    # always re-attach it, so the sampler is still usable after the loop.
    pool = sampler.pool
    sampler.pool = None
    try:
        with open(f'sampler_{timestampStr}.pkl', 'wb') as f:
            pickle.dump(sampler, f)
    finally:
        sampler.pool = pool
    with open(f'mcmc_output_{timestampStr}.pkl', 'wb') as f:
        pickle.dump(mcmc_output, f)
    print(f"Done with steps {steps_done - n_steps} - {steps_done} out of {mcmc_params.niter}")












#### The two cells below demonstrate how to load your saved results. This is useful in the following cases:

1) You finished a long MCMC fitting run and want to plot the results without having to redo the whole thing.

2) Your computer crashed before it was done running the previous MCMC fit and you want to restart it from the last saved iteration. In this case you can run the second cell below to continue the MCMC fitting. The results of the restarted run will be saved separately from the results of the previous run. The combined results will also be saved.

In [None]:
# example of how to load the pickled objects
# set saved_run_date to the date stamp in the names of your saved files
# NOTE: unpickling (pickle/dill) can execute arbitrary code -- only load files you created.
saved_run_date = '09-Apr-2020'

with open(f'sampler_{saved_run_date}.pkl', 'rb') as f:
    saved_sampler = pickle.load(f)
# final walker positions of the saved run, used as the restart point
restart_pos = saved_sampler._last_run_mcmc_result[0]

with open(f'mcmc_output_{saved_run_date}.pkl', 'rb') as f:
    saved_mcmc_output = pickle.load(f)

In [None]:
# example of how to restart the MCMC fit from the last save point
# assumes you loaded the sampler and mcmc_output from the saved pickles in the example above
# (and re-ran the setup cells that define galaxy, mcmc_params and save_blob)

restart_sampler = EnsembleSampler(
    mcmc_params.nwalkers,
    mcmc_params.ndim,
    lnprob,
    args=[galaxy, save_blob],
    threads=mcmc_params.nthreads)
restart_mcmc_output = []

dateTimeObj = datetime.now()
timestampStr = dateTimeObj.strftime("%d-%b-%Y")

# emcee restarts its thinning counter at every run_mcmc call, so keep each batch
# a whole multiple of nthin to get an evenly thinned chain.
batch_size = 3 * mcmc_params.nthin
assert batch_size % mcmc_params.nthin == 0, "batch_size must be a multiple of nthin"

# Only run the production steps the interrupted run did not finish.
# NOTE(review): relies on the emcee 2.x step counter `iterations` (reset to 0 by
# sampler.reset() after burn-in); if it is absent, all niter steps are run again.
steps_already_done = getattr(saved_sampler, 'iterations', 0)
remaining_steps = max(mcmc_params.niter - steps_already_done, 0)
n_full_batches, leftover_steps = divmod(remaining_steps, batch_size)
batch_lengths = [batch_size] * n_full_batches + ([leftover_steps] if leftover_steps else [])

steps_done = steps_already_done
for batch, n_steps in enumerate(batch_lengths):
    # The first batch resumes from the saved walker positions; later batches pass
    # None so emcee continues from where the previous run_mcmc call ended.
    batch_start = restart_pos if batch == 0 else None
    restart_mcmc_output += restart_sampler.run_mcmc(batch_start, n_steps, thin=mcmc_params.nthin)
    steps_done += n_steps

    # Detach the (unpicklable) worker pool of restart_sampler -- the original
    # `sampler` may not even exist after a crash -- and always re-attach it.
    pool = restart_sampler.pool
    restart_sampler.pool = None
    try:
        # The 'restart_' prefix keeps these files separate from the interrupted
        # run's files, which would otherwise be overwritten on the same day.
        with open(f'restart_sampler_{timestampStr}.pkl', 'wb') as f:
            pickle.dump(restart_sampler, f)
    finally:
        restart_sampler.pool = pool
    with open(f'restart_mcmc_output_{timestampStr}.pkl', 'wb') as f:
        pickle.dump(restart_mcmc_output, f)
    print(f"Done with steps {steps_done - n_steps} - {steps_done} out of {mcmc_params.niter}")


# this step adds the MCMC results from the restarted run to the ones that were saved from the previous run.
total_mcmc_output = saved_mcmc_output + restart_mcmc_output
with open(f'total_mcmc_output_{timestampStr}.pkl', 'wb') as f:
    pickle.dump(total_mcmc_output, f)

In [None]:
#### Plot the posterior distributions and walker paths

# To plot a sampler restored from a pickle file instead, uncomment the next line:
#sampler=saved_sampler

# Parameter names in alphabetical order.
# NOTE(review): assumes the sampler's parameter order matches the sorted keys of
# galaxy.bounds -- confirm in lnprob, otherwise the labels are mismatched.
parameter_labels = sorted(galaxy.bounds.keys())

plot_posterior_distributions(sampler, labels=parameter_labels)
plot_walker_paths(sampler, mcmc_params, labels=parameter_labels)

def get_fit_stats(sampler, labels=parameter_labels):
    """Print the mean and standard deviation of every fitted parameter.

    ``labels[i]`` names parameter ``i``; all samples in
    ``sampler.chain[:, :, i]`` are pooled and reported as
    ``<label>: <mean> +/- <std>``.
    """
    for param_index, param_name in enumerate(labels):
        pooled_samples = np.ravel(sampler.chain[:, :, param_index])
        print(f"{param_name}: {np.mean(pooled_samples)} +/- {np.std(pooled_samples)}")


get_fit_stats(sampler)

In [None]:
# Compare the best-fit piecewise rotation curve with the tilted-ring data.
# NOTE(review): needs blobs stored by the sampler (see save_blob in the sampler
# cell) -- confirm lnprob returns them in that case.

# Flatten sampler.blobs (a list of sub-lists) into one array, one row per sample.
blobs = []
for subarr in sampler.blobs:
    blobs += subarr
blobs = np.array(blobs)

# NOTE(review): assumes the total model rotation curve is stored at column
# len(galaxy.bounds) + 1 of each blob -- confirm against lnprob.
rotcurve_index = len(galaxy.bounds) + 1
# Average over samples only (axis=0) so one velocity per bin remains; without an
# axis a 2-D numeric blob column would collapse to a single scalar.
v_tot = np.mean(blobs[:, rotcurve_index], axis=0)

fig, ax = plt.subplots()
ax.plot(radii_kpc, test_rotation_curve, linewidth=2., label="data", color="black")

# Piecewise-constant model: step(where="post") draws from each left bin edge to
# the next one, so the last bin's segment is drawn explicitly.
ax.step(galaxy.bin_edges[:-1], v_tot, where="post", color="black")
ax.plot(galaxy.bin_edges[-2:], [v_tot[-1], v_tot[-1]], color="black", label="total, model")

ax.set_title(f"{galaxy_name} rotation curve")
ax.set_xlabel("r [kpc]")
ax.set_ylabel("v [km/s]")
ax.legend();