In [None]:
from burstfit.data import BurstData
from burstfit.model import Model, SgramModel
from burstfit.utils.plotter import plot_me, plot_2d_fit
from burstfit.utils.functions import pulse_fn_vec, sgram_fn_vec, gauss_norm, gauss, sgram_fn
from burstfit.io import BurstIO
from burstfit.fit import BurstFit
import logging
import pylab as plt
import numpy as np

# Root logger at DEBUG (presumably to see burstfit's debug output); matplotlib is
# raised to INFO so its own debug messages do not flood the notebook.
logging_format = "%(asctime)s - %(funcName)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(
    level=logging.DEBUG,
    format=logging_format,
)
logger = logging.getLogger('matplotlib')
logger.setLevel(logging.INFO)

%matplotlib inline

Load the candidates (as `canddata` objects) from the pickle file saved at native time resolution.

In [None]:
# Candidates pickle to read. Only load files you trust: it is presumably unpickled,
# and unpickling can execute arbitrary code.
pkl = 'pkl_file_with_native_resolution_canddata.pkl'
# NOTE(review): `candidates` is not imported anywhere in this notebook; presumably
# `from rfpipe import candidates` -- TODO add it to the import cell.
cand_iter = candidates.iter_cands(pkl, select='canddata')
cc = list(cand_iter)

Choose the indices that select the canddata to fit.

In [None]:
# TODO: choose the candidate to fit. `ind` selects an entry of `cc` and `ind2`
# the canddata within that entry. Defining them here keeps Restart & Run All working.
ind = 0
ind2 = 0
cd = cc[ind][ind2]
print(cd.state.inttime)

In [None]:
# Optional (start, stop) bounds of a channel slice to mask in fit(). Channels are
# indexed in the row order of fit()'s flipped spectrogram (reversed relative to
# cd.state.freq). None disables masking. Pass it to fit() as mask_chans=mask_chans
# for it to take effect.
mask_chans = None

In [None]:
def sgram_fn_2(
    metadata,
    pulse_function,
    spectra_function,
    spectra_params,
    pulse_params,
    other_params,
):
    """
    Spectrogram model: a pulse per channel, shifted by residual dispersion and
    scaled by a spectral profile.

    The pulse function is called once per channel with a scalar ``mu``, so it does
    not need to be vectorised over ``mu`` (e.g. ``gauss_decimate_64``, which branches
    on its scalar arguments). Channel 0 (``fstart``) is the dispersion reference: its
    pulse sits exactly at ``pulse_params["mu"]``.

    Args:
        metadata: (nt, nf, dispersed_at_dm, tsamp, fstart, foff). Frequencies are
            presumably in MHz and tsamp in seconds (the 4148808 ms MHz^2 dispersion
            constant is divided by 1000 and by tsamp to get samples).
        pulse_function: Called as ``pulse_function(times, S, mu, sigma)``.
        spectra_function: Called as ``spectra_function(chans, **spectra_params)``.
        spectra_params: Keyword arguments for ``spectra_function``.
        pulse_params: Dict with keys "S", "mu" and "sigma".
        other_params: ``[dm]``, the DM to model. The data are dispersed at
            ``dispersed_at_dm``, so the shift applied is for ``dispersed_at_dm - dm``.

    Returns:
        np.ndarray of shape (nf, nt).
    """
    nt, nf, dispersed_at_dm, tsamp, fstart, foff = metadata
    [dm] = other_params
    nt = int(nt)
    nf = int(nf)

    for key in ("mu", "S", "sigma"):
        assert key in pulse_params, f"pulse_params is missing '{key}'"

    chans = np.arange(nf)
    times = np.arange(nt)
    freqs = fstart + foff * chans
    spectra_from_fit = spectra_function(chans, **spectra_params)

    model_dm = dispersed_at_dm - dm
    # Per-channel pulse centre in samples, relative to the first channel.
    mus = (
        pulse_params["mu"]
        + 4148808.0 * model_dm * (1 / (freqs[0]) ** 2 - 1 / (freqs) ** 2) / 1000 / tsamp
    )

    model = np.zeros(shape=(nf, nt))
    for chan, mu_chan in enumerate(mus):
        model[chan, :] = pulse_function(
            times, pulse_params["S"], mu_chan, pulse_params["sigma"]
        )

    return model * spectra_from_fit[:, None]

In [None]:
def _gauss_decimate(x, S, mu, sigma, decimate_factor):
    """
    Gaussian pulse evaluated on a finer grid and averaged back to the sampling of ``x``.

    Output sample ``k`` averages the ``decimate_factor`` sub-samples at positions
    ``k, k + 1/f, ..., k + (f-1)/f`` (in samples), which models the smearing of a
    pulse narrower than one sample.

    Args:
        x: Time axis; only ``len(x)`` is used (assumed to be ``np.arange(len(x))``).
        S: Pulse area; the output sums to approximately ``S`` when the pulse lies
            well inside the window.
        mu: Pulse centre, in samples.
        sigma: Gaussian width, in samples.
        decimate_factor: Number of sub-samples per output sample.

    Returns:
        np.ndarray of length ``len(x)``.
    """
    n = len(x)

    # Narrower than half a sub-sample: the Gaussian cannot be sampled (sigma == 0
    # would divide by zero), so put all of S into the output sample containing the
    # nearest sub-sample -- the sigma -> 0 limit of the branch below. Pulses whose
    # nearest sub-sample is outside the window give all zeros.
    if decimate_factor * sigma < 0.5:
        pulse = np.zeros(n)
        fine_idx = int(np.around(mu * decimate_factor))
        if 0 <= fine_idx < decimate_factor * n:
            pulse[fine_idx // decimate_factor] = S
        return pulse

    fine_x = np.arange(decimate_factor * n)
    fine_S = S * decimate_factor
    fine_mu = mu * decimate_factor
    fine_sigma = sigma * decimate_factor
    fine_pulse = (fine_S / (np.sqrt(2 * np.pi) * fine_sigma)) * np.exp(
        -(1 / 2) * ((fine_x - fine_mu) / fine_sigma) ** 2
    )
    return fine_pulse.reshape(n, decimate_factor).mean(-1)


def gauss_decimate_64(x, S, mu, sigma):
    """Gaussian pulse with 64 sub-samples per sample (see ``_gauss_decimate``)."""
    return _gauss_decimate(x, S, mu, sigma, decimate_factor=64)


def gauss_decimate_256(x, S, mu, sigma):
    """Gaussian pulse with 256 sub-samples per sample (see ``_gauss_decimate``)."""
    return _gauss_decimate(x, S, mu, sigma, decimate_factor=256)

Fitting function: build the spectrogram from a canddata, fit it with burstfit and optionally run MCMC.

In [None]:
def fit(cd, casa_flux=1, flux_scale=1, mcmc=True, mask_chans=None, n_off_pulse=20):
    """
    Fit a burst model to the dynamic spectrum of a candidate with burstfit.

    Args:
        cd: canddata object. Reads cd.loc, cd.state.{dtarr, dmarr, freq, inttime},
            cd.data (presumably (time, channel, polarisation)) and cd.snr1.
        casa_flux: Independently measured peak flux (e.g. from CASA imfit); it is
            only printed for comparison with the spectrogram peak.
        flux_scale: Factor multiplied into the spectrogram before fitting.
        mcmc: If True, run burstfit's MCMC after the initial fit, saving results
            with outname 'rf_burstfit_mcmc'.
        mask_chans: Optional (start, stop) bounds of a channel slice to mask,
            indexed in the flipped row order of the spectrogram. None disables.
        n_off_pulse: Number of leading time samples (after centring the burst)
            used as off-pulse data for the noise statistics.

    Returns:
        The fitted BurstFit object, with ``radiometer_std`` and
        ``off_pulse_ts_std`` attributes added.
    """
    _, _, dmind, dtind, _ = cd.loc
    width_m = cd.state.dtarr[dtind]
    dm = cd.state.dmarr[dmind]
    freqs = cd.state.freq
    inttime = cd.state.inttime
    # NOTE(review): assumes cd.data is at native time resolution, i.e. one sample
    # lasts `inttime` (not inttime * width_m) -- TODO confirm.
    tsamp_ms = inttime * 1000

    # Sum the last axis and transpose to (channel, time). The flip reverses the
    # channel order, so spectrogram row 0 corresponds to freqs[-1].
    sgram = np.ma.asarray(np.flip(cd.data.real.sum(axis=2).T, axis=0))
    if mask_chans is not None:
        sgram[mask_chans[0]:mask_chans[1], :] = np.ma.masked
    nt = sgram.shape[1]
    i0 = np.argmax(sgram.sum(0))

    print(cd.snr1)
    print(np.mean(sgram), np.std(sgram))

    pulseModel = Model(gauss_decimate_64, param_names=['S', 'mu_t', 'sigma_t'])
    spectraModel = Model(gauss_norm, param_names=['mu_f', 'sigma_f'])
    sgramModel = SgramModel(pulseModel, spectraModel, sgram_fn_2, clip_fac=0)

    # Move the brightest time sample to nt // 2 so the first n_off_pulse samples
    # can serve as off-pulse data.
    sgram = np.roll(sgram, nt // 2 - i0, 1) * flux_scale

    off_pulse_data = sgram[:, :n_off_pulse]
    radiometer_std = np.std(off_pulse_data.mean(0))
    off_pulse_mean = np.mean(off_pulse_data)
    off_pulse_std = np.std(off_pulse_data)
    # Normalise so the off-pulse noise has zero mean and unit variance.
    on_pulse_data = (sgram - off_pulse_mean) / off_pulse_std

    # Flag channels in which more than half of the samples are masked;
    # getmaskarray also copes with arrays that carry no mask at all.
    mask = np.around(np.ma.getmaskarray(sgram).mean(1)).astype('bool')

    # Frequency metadata must follow the flipped row order: row 0 is freqs[-1],
    # stepping towards freqs[0]. freqs is presumably in GHz; burstfit gets MHz.
    foff = (freqs[0] - freqs[-1]) * 1000 / max(len(freqs) - 1, 1)
    bf = BurstFit(
        sgram_model=sgramModel,
        sgram=on_pulse_data,
        width=width_m,
        dm=dm,
        foff=foff,
        fch1=freqs[-1] * 1000,
        tsamp=inttime,
        clip_fac=0,
        mask=mask)
    bf.fitall(profile_bounds=([0, nt//2-10, 0], [1000, nt//2+10, 50]))

    peak_flux_sgram = np.max(sgram.mean(0))
    print(f'Peak flux (CASA): {casa_flux}')
    print(f'Peak flux (canddata): {peak_flux_sgram}')

    bf.radiometer_std = radiometer_std
    bf.off_pulse_ts_std = np.std(bf.ts[:n_off_pulse])

    # Best-fit parameters. Index 2 is presumably S and index 4 sigma_t (in samples),
    # matching the sample order read back in get_param().
    popt = bf.sgram_params['all'][1]['popt']
    fluence = tsamp_ms * popt[2] * radiometer_std / bf.off_pulse_ts_std
    sigma = popt[4] * tsamp_ms
    width = 2.355 * sigma  # Gaussian FWHM = 2 sqrt(2 ln 2) sigma
    peak_fit_flux = fluence / (sigma * np.sqrt(2 * np.pi))

    print(f'Peak flux fitted: {peak_fit_flux}')
    print(f'Fluence (Jy ms) is : {fluence}')
    print(f'Fit width (ms) is: {width}')

    if mcmc:
        mcmc_kwargs = {'nwalkers': 40, 'nsteps': 20000,
                       'skip': 10000, 'ncores': 20,
                       'start_pos_dev': 0.01,
                       'prior_range': 0.5,
                       'save_results': True,
                       'outname': 'rf_burstfit_mcmc'}
        bf.run_mcmc(plot=True, **mcmc_kwargs)

    return bf

Helpers for reading the MCMC outputs and deriving physical parameters

In [None]:
def get_param(h5):
    """
    Read post-burn-in posterior samples from an emcee HDF5 chain file.

    The burn-in is twice the longest integrated autocorrelation time. If emcee
    cannot estimate that (AutocorrError, e.g. the chain is too short, or
    ValueError), the first 75% of the steps are discarded instead. In both cases
    the burn-in is counted in steps.

    Args:
        h5: Path to the emcee HDF5 backend file.

    Returns:
        Tuple of 1-D sample arrays (mu_f, sigma_f, S, mu_t, sigma_t, DM). This
        assumes the chain stores spectra params, then pulse params, then DM --
        TODO confirm against burstfit.
    """
    # Read-only: this function only inspects the chain.
    reader = emcee.backends.HDFBackend(h5, read_only=True)
    try:
        tau = reader.get_autocorr_time()
        burnin = int(2 * np.max(tau))
    except (AutocorrError, ValueError) as err:
        print(f'Could not estimate the autocorrelation time ({err}); '
              'discarding the first 75% of steps as burn-in.')
        n_steps = reader.get_chain().shape[0]
        burnin = int(n_steps * 0.75)
    samples = reader.get_chain(discard=burnin, flat=True)

    print("burn-in: {0}".format(burnin))
    print("flat chain shape: {0}".format(samples.shape))

    mu_fs = samples[:, 0]
    sigma_fs = samples[:, 1]
    Ss = samples[:, 2]
    mu_ts = samples[:, 3]
    sigma_ts = samples[:, 4]
    DMs = samples[:, 5]

    return mu_fs, sigma_fs, Ss, mu_ts, sigma_ts, DMs

def get_bary_time(ra, dec, mjd, dm, max_freq):
    """
    Barycentric (TDB) MJD of a burst, referenced to infinite frequency.

    Args:
        ra: Right ascension as an hms string.
        dec: Declination as a dms string.
        mjd: Topocentric (VLA) UTC arrival time, as an MJD, at ``max_freq``.
        dm: Dispersion measure used to remove the delay at ``max_freq``.
        max_freq: Reference frequency of ``mjd``. Presumably in MHz, since the
            4.148808e6 constant then gives a delay in ms -- TODO confirm.

    Returns:
        float: MJD in the TDB scale at the solar-system barycentre.
    """
    # NOTE(review): `coord` and `time` are not imported in this notebook; presumably
    # astropy.coordinates and astropy.time -- TODO add them to the import cell.

    # Dispersion delay (ms) at max_freq relative to infinite frequency.
    delay_ms = 4.148808*10**6*dm*(1/(max_freq)**2)
    mjd_inf = mjd - delay_ms/(1000*60*60*24)

    source = coord.SkyCoord(ra, dec, frame='icrs')
    site = coord.EarthLocation.of_site('vla')
    t_topo = time.Time(mjd_inf, format='mjd', scale='utc', location=site)
    t_bary = t_topo.tdb + t_topo.light_travel_time(source)
    return t_bary.value

def get_param_and_errors(samples):
    """
    Median and asymmetric 1-sigma errors of a sample distribution.

    Args:
        samples: Array of samples; quantiles are taken along axis 0.

    Returns:
        (median, median - 16th percentile, 84th percentile - median).
    """
    lower, median, upper = np.quantile(samples, [0.16, 0.5, 0.84], axis=0)
    return median, median - lower, upper - median

# Fit! 

In [None]:
# Use the channel mask from the configuration cell rather than a hard-coded None.
fit_result = fit(cd, casa_flux=1, flux_scale=1, mcmc=True, mask_chans=mask_chans)
# Keep the fitted BurstFit object: the cells below read its metadata
# (fch1, foff, tsamp, nf) and noise estimates.
if fit_result is not None:
    bf = fit_result

## Read MCMC outputs and convert to physical parameters

In [None]:
from emcee.autocorr import AutocorrError
import emcee
# Post-burn-in posterior samples, one 1-D array per parameter, read from the chain
# file presumably written by fit(..., mcmc=True).
(
    mu_fs,
    sigma_fs,
    Ss,
    mu_ts,
    sigma_ts,
    DMs,
) = get_param('rf_burstfit_mcmc.h5')

In [None]:
# Observation metadata carried by the fitted BurstFit object.
foff = bf.foff
fch1 = bf.fch1
tsamp = bf.tsamp
tcand = cd.time_top

# fit() rolls the spectrogram so its brightest time sample sits at index nt // 2,
# and mu_t is fitted in that rolled frame. Assuming cd.time_top is the MJD of that
# brightest sample (TODO confirm), the MJD of rolled sample 0 is:
nt = cd.data.shape[0]
nstart = tcand - (nt // 2) * tsamp / (24 * 60 * 60)

In [None]:
# Convert the MCMC samples (sample / channel units) to physical quantities.
# fch1 and foff are presumably in MHz and tsamp in seconds, as passed to BurstFit.
fluences = (tsamp*1000)*Ss*bf.radiometer_std/bf.off_pulse_ts_std  # Jy ms
widths = 2.355*sigma_ts*tsamp*1000  # Gaussian FWHM in ms
nu_0s = (fch1 + mu_fs*foff)/1000  # GHz
# MHz; abs() keeps the width positive when foff < 0 (frequency decreasing with row).
nu_sigs = sigma_fs*np.abs(foff)

# Highest channel-centre frequency (MHz), the dispersion reference used by
# get_bary_time; valid for either sign of foff.
fmax = max(fch1, fch1 + foff * (bf.nf - 1))
_mjds = nstart + mu_ts*tsamp/(24*60*60)

fluence = get_param_and_errors(fluences)
width = get_param_and_errors(widths)
nu_0 = get_param_and_errors(nu_0s)
nu_sig = get_param_and_errors(nu_sigs)
DM = get_param_and_errors(DMs)

# Barycentre the median arrival time and both ends of its error bar.
# `ra` and `dec` (burst position as hms / dms strings) must be defined beforehand.
_mjd = get_param_and_errors(_mjds)
mjd_med, mjd_err_lo, mjd_err_hi = _mjd
mjds = [get_bary_time(ra, dec, mjd_med, DM[0], fmax),
        get_bary_time(ra, dec, mjd_med - mjd_err_lo, DM[0], fmax),
        get_bary_time(ra, dec, mjd_med + mjd_err_hi, DM[0], fmax)]

mjd = [mjds[0], mjds[0] - mjds[1], mjds[2] - mjds[0]]

In [None]:
# Each value is (median, lower error, upper error) from the 16/50/84th percentiles.
for label, value in [
    ('Fluence (Jy ms)', fluence),
    ('Width, FWHM (ms)', width),
    ('nu_0 (GHz)', nu_0),
    ('nu_sig (MHz)', nu_sig),
    ('DM (pc cm^-3)', DM),
    ('Barycentric MJD (TDB)', mjd),
]:
    print(f'{label}: {value}')