In [None]:
try:
    import cmasher as cmr
except:
    !pip install cmasher
    import cmasher as cmr

try:
    import afterglowpy as grb
except:
    !pip install afterglowpy
    import afterglowpy as grb
    
try:
    import emcee
except:
    !pip install emcee
    import emcee
    
try:
    import corner
except:
    !pip install corner
    import corner

import numpy as np
import pandas as pd
from vasttools.pipeline import Pipeline
import matplotlib.pyplot as plt
from scipy.optimize import minimize, curve_fit
import scipy.stats
from IPython.display import display, Math
from astropy.coordinates import Distance
from astropy import units as u
from astropy.coordinates import Angle
import pickle
from tqdm import tqdm
import requests
from io import BytesIO

In [None]:
def name_to_time(name):
    """
    Convert a GW event name into the time encoded in it.

        Parameters:
            name (str): event name, either 'GWyymmdd' or 'GWyymmdd_HHMMSS'

        Returns:
            numpy.datetime64: the parsed date (plus time of day when the name carries one)
    """
    # NB: str.lstrip treats 'GW' as a set of characters to strip, not as a literal prefix
    pieces = name.lstrip('GW').split('_')
    if len(pieces) != 1:
        stamp = pieces[0] + pieces[1]
        fmt = '%y%m%d%H%M%S'
    else:
        stamp = pieces[0]
        fmt = '%y%m%d'
    return pd.to_datetime(stamp, format=fmt).to_datetime64()

In [None]:
# Open the vasttools pipeline and load the run named 'combined'; individual
# sources are later fetched from `piperun` with piperun.get_source(...).
pipe = Pipeline()
piperun = pipe.load_run('combined')
# meas = piperun.measurements

In [None]:
def afterglowpy_log_likelihood(theta, Z, Z_static, t, y, yerr):
    """
    Gaussian log-likelihood of the measurements under an afterglowpy lightcurve.

        Parameters:
            theta (sequence): current values of the free parameters, in the key order of Z
            Z (dict): free parameters; only its keys are used, to name the entries of theta
            Z_static (dict): fixed keyword arguments forwarded to grb.fluxDensity
            t (array): observation times passed to grb.fluxDensity
            y (array): measured flux densities
            yerr (array): uncertainties on y

        Returns:
            float: -0.5 * sum((y - model)^2 / yerr^2 - log(1 / yerr^2))
    """
    free_params = dict(zip(Z.keys(), theta))
    # The model is always evaluated at a frequency of 888.0e6
    model = grb.fluxDensity(t, 888.0e6, **Z_static, **free_params)
    inv_sigma2 = 1/yerr**2
    weighted_sq_resid = (y-model)**2*inv_sigma2
    return -0.5*(np.sum(weighted_sq_resid - np.log(inv_sigma2)))



def afterglowpy_log_prior(theta):
    """
    Flat log-prior for the afterglowpy fit.

        Parameters:
            theta (sequence): (thetaObs, thetaCore, n0, p, epsilon_e, epsilon_B)

        Returns:
            float: 0.0 when every parameter lies inside its allowed range, -inf otherwise
    """
    thetaObs, thetaCore, n0, p, epsilon_e, epsilon_B = theta
    half_pi = np.pi/2

    in_bounds = (
        0 <= thetaObs <= half_pi
        and 0 <= thetaCore <= half_pi
        and 0 < n0 < 1
        and 2 < p
        and 0 <= epsilon_e <= 1
        and 0 <= epsilon_B <= 1
    )
    if not in_bounds:
        return -np.inf
    return 0.0



def afterglowpy_log_probability(theta, Z, Z_static, t, y, yerr):
    """
    Log-posterior (prior + likelihood) sampled by the afterglowpy MCMC.

        Returns:
            float: -inf when the prior rejects theta (the likelihood is then not
            evaluated), otherwise log-prior + log-likelihood
    """
    log_prior = afterglowpy_log_prior(theta)
    if np.isfinite(log_prior):
        return log_prior + afterglowpy_log_likelihood(theta, Z, Z_static, t, y, yerr)
    return -np.inf



def afterglowpy_make_chain_plot(chain, chain_cut, save_plots, fname):
    """
    Plot the walker traces of every sampled parameter and mark the burn-in cut.

        Parameters:
            chain (numpy.ndarray): MCMC chain; axis 1 is the iteration and axis 2 the parameter
            chain_cut (int): iteration at which the red dashed burn-in line is drawn
            save_plots (bool): whether to write the figure to '{fname}_afterglowpy_chain.png'
            fname (str): plot prefix name
    """
    niters = chain.shape[1]
    ndim = chain.shape[2]

    fig, axes = plt.subplots(ndim,1,sharex=True)
    fig.set_size_inches(7, 20)

    param_names = ['$\\theta_{\\rm obs}$', '$\\theta_{\\rm core}$','$n_0$','$p$', '$\\epsilon_{e}$', '$\\epsilon_{B}$', '$d_{L}$']

    # plt.subplots hands back a bare Axes (not an array) when ndim == 1
    for i, (ax,param_name) in enumerate(zip(np.atleast_1d(axes),param_names)):
        ax.plot(chain[:,:,i].T,linestyle='-',color='k',alpha=0.3)
        ax.set_ylabel(param_name)
        ax.set_xlim(0,niters)
        ax.axvline(chain_cut,c='r',linestyle='--')

    # Save before showing: once plt.show() has displayed the figure, a later
    # plt.savefig() would write a fresh, empty figure instead of this one.
    if save_plots:
        fig.savefig(f'{fname}_afterglowpy_chain.png', dpi=200)
    plt.show()
    
    
def afterglowpy_make_corner_plot(good_chain, fname, save_plots):
    """
    Draw the corner plot of the post-burn-in afterglowpy samples.

        Parameters:
            good_chain (numpy.ndarray): chain with burn-in removed; the last axis is the parameter
            fname (str): plot prefix name
            save_plots (bool): whether to write the figure to '{fname}_afterglowpy_corner.png'
    """
    param_names = ['$\\theta_{\\rm obs}$', '$\\theta_{\\rm core}$','$n_0$','$p$', '$\\epsilon_{e}$', '$\\epsilon_{B}$', '$d_{L}$']
    ndim = good_chain.shape[2]
    # Flatten walkers and iterations together; pass only as many labels as there are parameters
    fig = corner.corner(good_chain.reshape((-1, ndim)), labels=param_names[:ndim], quantiles=[0.16, 0.5, 0.84], show_titles=True)

    # Save before showing: once plt.show() has displayed the figure, a later
    # plt.savefig() would write a fresh, empty figure instead of this one.
    if save_plots:
        fig.savefig(f'{fname}_afterglowpy_corner.png', dpi=200)
    plt.show()

        

def afterglowpy_get_starting_pos(starting_vals, nwalkers):
    """
    Build the initial walker positions as a tight Gaussian ball around the starting values.

        Parameters:
            starting_vals (sequence): (thetaObs, thetaCore, n0, p, epsilon_e, epsilon_B)
            nwalkers (int): number of walkers

        Returns:
            list of numpy.ndarray: one position per walker, each the starting
            values plus 1e-4 times a standard-normal draw
    """
    thetaObs, thetaCore, n0, p, epsilon_e, epsilon_B = starting_vals
    centre = np.asarray([thetaObs, thetaCore, n0, p, epsilon_e, epsilon_B])
    n_params = len(starting_vals)

    pos = []
    for _ in range(nwalkers):
        jitter = 1e-4*np.random.randn(n_params)
        pos.append(centre + jitter)
    return pos



def afterglowpy_run_mcmc(Z, Z_static, t, y, yerr, starting_vals, niters, nwalkers):
    """
    Run the emcee ensemble sampler for the afterglowpy model.

        Parameters:
            Z (dict): free parameters (its keys name the sampled dimensions)
            Z_static (dict): fixed keyword arguments for the afterglowpy model
            t (array): observation times
            y (array): measured flux densities
            yerr (array): uncertainties on y
            starting_vals (sequence): starting value of every free parameter
            niters (int): number of MCMC iterations
            nwalkers (int): number of walkers

        Returns:
            emcee.EnsembleSampler: the sampler after run_mcmc has completed
    """
    # The observing frequency is fixed inside afterglowpy_log_likelihood, so
    # no frequency is needed (or passed) here.
    ndim = len(starting_vals)
    pos = afterglowpy_get_starting_pos(starting_vals, nwalkers)

    sampler = emcee.EnsembleSampler(
        nwalkers,
        ndim,
        afterglowpy_log_probability,
        args=(Z, Z_static, t, y, yerr)
    )

    sampler.run_mcmc(pos, niters, progress=True)

    return sampler



def fit_afterglowpy(source, event, mean, save_plots=False, fname=None, epoch_indices='all', niters=1000, nwalkers=50):
    """
    Fit the source measurements to an afterglowpy grb lightcurve via MCMC.

        Parameters:
            source (vasttools.source.Source): the VAST source object
            event (str): the GW event name; converted to the event time with name_to_time
            mean (float): the mean luminosity distance of the GW event in Mpc
            save_plots (bool): whether to save the produced plots (default False)
            fname (str): plot prefix name (default None)
            epoch_indices ('all' or list): a list of epoch indices to fit the curve to (default 'all')
            niters (int): number of iterations to run MCMC (default 1000)
            nwalkers (int): number of walkers in MCMC (default 50)
    """
    # Free parameters sampled by the MCMC; the values are the walkers' starting point.
    # NOTE(review): n0 starts at 1.0, which sits on the edge of the 0 < n0 < 1
    # range accepted by afterglowpy_log_prior - confirm this is intended.
    Z = {
        'thetaObs': 0.05,
        'thetaCore': 0.1,
        'n0': 1.0,
        'p': 2.2,
        'epsilon_e': 0.1,
        'epsilon_B': 0.01}

    # Parameters held fixed during the fit
    Z_static = {
    'jetType': -1,
    'specType': 0,
    'xi_N': 1.0,
    'L0': 0.0,
    'q': 0.0,
    'ts': 0.0,
    'E0': 1e+53,
    'd_L': mean*u.Mpc.to('cm')
    }

    # Keep only the measurements taken after the event
    source_meas = source.measurements
    event_time = name_to_time(event)
    after_meas = source_meas.iloc[np.where(source_meas.dateobs>event_time)[0]]

    # Work with plain arrays in both branches so downstream code always sees the same types
    x = ((after_meas.dateobs-event_time)/pd.Timedelta(1, unit='d')).values
    y = after_meas.flux_peak.values
    yerr = after_meas.flux_peak_err.values
    # The isinstance check avoids an ambiguous element-wise comparison when
    # epoch_indices is a numpy array
    if not (isinstance(epoch_indices, str) and epoch_indices == 'all'):
        x, y, yerr = x[epoch_indices], y[epoch_indices], yerr[epoch_indices]

    # Days since the event -> seconds
    t = np.asarray(pd.to_timedelta(x, unit='D').total_seconds())
    nu = np.full(t.shape, 888.0e6)

    starting_vals = np.fromiter(Z.values(), dtype=float)

    sampler = afterglowpy_run_mcmc(Z, Z_static, t, y, yerr, starting_vals, niters, nwalkers)
    chain = sampler.chain

    # Burn-in: the usual 200 steps, or half the run when the run is too short
    # for that to leave any samples
    chain_cut = 200 if niters > 200 else niters // 2
    afterglowpy_make_chain_plot(chain, chain_cut, save_plots=save_plots, fname=fname)

    good_chain = chain[:, chain_cut:, :]
    afterglowpy_make_corner_plot(good_chain, save_plots=save_plots, fname=fname)

    # Overlay the lightcurves of ten random posterior samples on the data
    flat_samples = sampler.get_chain(discard=chain_cut, thin=15, flat=True)
    inds = np.random.randint(len(flat_samples), size=10)
    for ind in inds:
        sample = flat_samples[ind]
        new_y = grb.fluxDensity(t, nu, **Z_static, **dict(zip(Z.keys(), sample)))
        plt.scatter(x, new_y)
    plt.errorbar(x, y, yerr=yerr, fmt=".k", capsize=0)
    plt.xlabel("x")
    plt.ylabel("y")
    plt.show()

In [None]:
def smooth_broken_power_law(t, nu, F_peak, t_peak, delta_1, delta_2, alpha, log_s, nu0=3.0):
    """
    Smoothly broken power law in time with a power-law frequency scaling.

        Parameters:
            t (float or array): time(s) at which to evaluate the model
            nu (float): frequency of the evaluation
            F_peak (float): flux normalisation
            t_peak (float): break time
            delta_1 (float): first temporal index
            delta_2 (float): second temporal index
            alpha (float): spectral index applied to nu/nu0
            log_s (float): log10 of the smoothness parameter s
            nu0 (float): reference frequency (default 3.0)

        Returns:
            float or array: (nu/nu0)^alpha * F_peak * ((t/t_peak)^(-s*delta_1) + (t/t_peak)^(-s*delta_2))^(-1/s)
    """
    s = 10**log_s
    t_ratio = t/t_peak
    spectral_scaling = (nu/nu0)**alpha
    term_1 = t_ratio**(-s*delta_1)
    term_2 = t_ratio**(-s*delta_2)
    return spectral_scaling * F_peak * (term_1 + term_2)**(-1.0/s)



def bpl_log_likelihood(theta, x, nu, y, yerr, inv_sigma2):
    """
    Gaussian log-likelihood of the measurements under the smooth broken power law.

        Parameters:
            theta (sequence): (F_peak, t_peak, delta_1, delta_2, alpha, log_s)
            x (array): observation times
            nu (float): frequency passed to the model
            y (array): measured flux densities
            yerr (array): uncertainties on y (not used here; the weights come from inv_sigma2)
            inv_sigma2 (array): inverse variances used as weights

        Returns:
            float: -0.5 * sum((y - model)^2 * inv_sigma2 - log(inv_sigma2))
    """
    F_peak, t_peak, delta_1, delta_2, alpha, log_s = theta
    model = smooth_broken_power_law(x, nu, F_peak, t_peak, delta_1, delta_2, alpha, log_s)
    weighted_sq_resid = (y-model)**2*inv_sigma2
    return -0.5*(np.sum(weighted_sq_resid - np.log(inv_sigma2)))



def bpl_log_prior(theta, time_peak_limit):
    """
    Flat log-prior for the broken power law fit.

        Parameters:
            theta (sequence): (F_peak, t_peak, delta_1, delta_2, alpha, log_s)
            time_peak_limit (float): exclusive upper bound on t_peak

        Returns:
            float: 0.0 when every parameter lies inside its allowed range, -inf otherwise
    """
    F_peak, t_peak, delta_1, delta_2, alpha, log_s = theta

    allowed = (
        0 < t_peak < time_peak_limit
        and 0 < F_peak
        and 0 < delta_1
        and delta_2 < 0.0
        and -10 < alpha
        and np.isfinite(log_s)
        and log_s < 10
    )
    if not allowed:
        return -np.inf
    return 0.0



def bpl_log_probability(theta, x, nu, y, yerr, inv_sigma2, time_peak_limit):
    """
    Log-posterior (prior + likelihood) sampled by the broken power law MCMC.

        Returns:
            float: -inf when the prior rejects theta (the likelihood is then not
            evaluated), otherwise log-prior + log-likelihood
    """
    log_prior = bpl_log_prior(theta, time_peak_limit)
    if np.isfinite(log_prior):
        return log_prior + bpl_log_likelihood(theta, x, nu, y, yerr, inv_sigma2)
    return -np.inf



def bpl_make_chain_plot(chain, chain_cut, fname, save_plots, param_names=None):
    """
    Plot the walker traces of every sampled parameter and mark the burn-in cut.

        Parameters:
            chain (numpy.ndarray): MCMC chain; axis 1 is the iteration and axis 2 the parameter
            chain_cut (int): iteration at which the red dashed burn-in line is drawn
            fname (str): plot prefix name
            save_plots (bool): whether to write the figure to '{fname}_power_law_chain.png'
            param_names (list of str): y-axis labels, one per parameter; the
                broken power law labels are used when omitted or empty
    """
    niters = chain.shape[1]
    ndim = chain.shape[2]

    fig, axes = plt.subplots(ndim,1,sharex=True)
    fig.set_size_inches(7, 20)

    if param_names is None or len(param_names) == 0:
        # NOTE(review): '\.' is a dot accent in TeX; a thin space '\,' may have
        # been intended in '3\.{\rm GHz}' - the label text is left unchanged here.
        param_names = ['$F_{{\\rm peak}, 3\\.{\\rm GHz}}$', '$t_{{\\rm peak}}$','$\\delta_1$','$\\delta_2$', '$\\alpha$', '$\\log_{10}(s)$']

    # plt.subplots hands back a bare Axes (not an array) when ndim == 1
    for i, (ax,param_name) in enumerate(zip(np.atleast_1d(axes),param_names)):
        ax.plot(chain[:,:,i].T,linestyle='-',color='k',alpha=0.3)
        ax.set_ylabel(param_name)
        ax.set_xlim(0,niters)
        ax.axvline(chain_cut,c='r',linestyle='--')

    # Save before showing: once plt.show() has displayed the figure, a later
    # plt.savefig() would write a fresh, empty figure instead of this one.
    if save_plots:
        fig.savefig(f'{fname}_power_law_chain.png', dpi=200)
    plt.show()

    
    
def bpl_make_corner_plot(good_chain, save_plots, fname):
    """
    Draw the corner plot of the post-burn-in broken power law samples.

        Parameters:
            good_chain (numpy.ndarray): chain with burn-in removed; the last axis is the parameter
            save_plots (bool): whether to write the figure to '{fname}_power_law_corner.png'
            fname (str): plot prefix name
    """
    # NOTE(review): '\.' is a dot accent in TeX; a thin space '\,' may have been
    # intended in '3\.{\rm GHz}' - the label text is left unchanged here.
    param_names = ['$F_{{\\rm peak}, 3\\.{\\rm GHz}}$', '$t_{{\\rm peak}}$','$\\delta_1$','$\\delta_2$', '$\\alpha$', '$\\log_{10}(s)$']
    ndim = good_chain.shape[2]
    # Flatten walkers and iterations together before handing the samples to corner
    fig = corner.corner(good_chain.reshape((-1, ndim)), labels=param_names, quantiles=[0.16, 0.5, 0.84], show_titles=True)

    # Save before showing: once plt.show() has displayed the figure, a later
    # plt.savefig() would write a fresh, empty figure instead of this one.
    if save_plots:
        fig.savefig(f'{fname}_power_law_corner.png', dpi=200)
    plt.show()
    
    
    
def bpl_get_starting_pos(starting_vals, nwalkers, ndim=6):
    """
    Build the initial walker positions as a tight Gaussian ball around the starting values.

        Parameters:
            starting_vals (iterable): (F_peak, t_peak, delta_1, delta_2, alpha, log_s)
            nwalkers (int): number of walkers
            ndim (int): length of the random perturbation drawn per walker (default 6)

        Returns:
            list of numpy.ndarray: one position per walker, each the starting
            values plus 1e-4 times a standard-normal draw
    """
    F_peak, t_peak, delta_1, delta_2, alpha, log_s = starting_vals
    centre = np.asarray([F_peak, t_peak, delta_1, delta_2, alpha, log_s])

    pos = []
    for _ in range(nwalkers):
        jitter = 1e-4*np.random.randn(ndim)
        pos.append(centre + jitter)
    return pos



def bpl_run_mcmc(x, y, yerr, starting_vals, niters, nwalkers, time_peak_limit):
    """
    Run the emcee ensemble sampler for the smooth broken power law model.

        Parameters:
            x (array): observation times
            y (array): measured flux densities
            yerr (array): uncertainties on y
            starting_vals (iterable): starting value of every fitted parameter
            niters (int): number of MCMC iterations
            nwalkers (int): number of walkers
            time_peak_limit (float): upper bound on t_peak enforced by the prior

        Returns:
            emcee.EnsembleSampler: the sampler after run_mcmc has completed
    """
    # Frequency handed to the model; presumably GHz, matching the model's
    # nu0=3.0 reference - TODO confirm
    nu = 0.888
    inv_sigma2 = 1.0/yerr**2

    # Materialise the starting values (callers may pass a dict view) and take
    # the dimensionality from them; `ndim` was previously an undefined name here.
    starting_vals = list(starting_vals)
    ndim = len(starting_vals)

    pos = bpl_get_starting_pos(starting_vals, nwalkers, ndim)

    sampler = emcee.EnsembleSampler(nwalkers, ndim, bpl_log_probability, args=(x, nu, y, yerr, inv_sigma2, time_peak_limit))

    sampler.run_mcmc(pos, niters, progress=True)

    return sampler



def bpl_get_best_params(chain):
    """
    Summarise the chain as a median with upper and lower uncertainties per parameter.

        Parameters:
            chain (numpy.ndarray): 3-D MCMC chain whose last axis is the parameter

        Returns:
            dict: parameter name -> (median, 84th percentile - median, median - 16th percentile)
    """
    ndim = chain.shape[2]
    flat_chain = chain.reshape((-1, ndim))

    lower, median, upper = np.percentile(flat_chain, [16, 50, 84],axis=0)

    param_names = ['F_peak', 't_peak', 'delta_1', 'delta_2', 'alpha', 'log_s']

    return {
        name: (median[i], upper[i]-median[i], median[i]-lower[i])
        for i, name in zip(range(ndim), param_names)
    }
    


def calc_chi2(x, y, yerr, best_params, param_names, model, nu0=3.0):
    """
    Calculates chi-squared between a set of measurements and a model

        Parameters:
            x (array): observation times
            y (array): measured values
            yerr (array): uncertainties on y
            best_params (dict): parameter name -> tuple whose first entry is the best-fit value
            param_names (list of str): names selecting (and ordering) the model arguments
            model (callable): called as model(x, 0.888, *best_fit_values)
            nu0 (float): accepted but not used by this function (default 3.0)

        Returns:
            float: sum((model - y)^2 / yerr^2)
    """
    args = [best_params[param][0] for param in param_names]

    # The model is always evaluated at a frequency of 0.888
    best_fit = model(x, 0.888, *args)

    residuals = best_fit-y
    return np.sum(residuals**2/yerr**2)



def bpl_make_plot(x, y, yerr, save_plots, fname, model=None, params=None, tvals=np.arange(10,400), plot_models=False):
    """
    Make broken power law plot

        Parameters:
            x (array): observation times in days
            y (array): measured flux densities
            yerr (array): uncertainties on y
            save_plots (bool): whether to write the figure to '{fname}_fitted_power_law.png'
            fname (str): plot prefix name
            model (callable): model to overlay via plot_model (default None: data only)
            params (sequence): model parameters passed to plot_model
            tvals (array): times at which the model is evaluated (default np.arange(10,400))
            plot_models (bool): accepted but not used by this function
    """
    fig = plt.figure(figsize=(10,6))
    ax = fig.add_subplot(111)

    ax.errorbar(x, y, yerr, linestyle='')

    ax.set_xscale('log')
    ax.set_yscale('log')

    ax.set_xlabel('Time (days)')

    # Raw string: '\m' is not a valid Python escape sequence
    ax.set_ylabel(r'Flux Density ($\mu$Jy)')

    if model:
        plot_model(model, params, tvals, ax)

    ax.set_xlim(min(x)/1.2,max(x)*1.2)

    # Save before showing: once plt.show() has displayed the figure, a later
    # plt.savefig() would write a fresh, empty figure instead of this one.
    if save_plots:
        fig.savefig(f'{fname}_fitted_power_law.png', dpi=200)
    plt.show()

    

def plot_model(model, params, tvals, ax):
    """
    Add a model lightcurve to a plot

        Parameters:
            model (callable): called as model(tvals, 0.888, *params)
            params (sequence): model parameters
            tvals (array): times at which the model is evaluated
            ax (matplotlib axes): axes to draw on; its lower y limit is set to
                0.2 and its x axis to log scale
    """
    best_fit = model(tvals, 0.888, *params)

    line_style = dict(marker='', linestyle='-', c='k', linewidth=1.5, zorder=0)
    ax.plot(tvals, best_fit, **line_style)
    ax.set_ylim(bottom=0.2)
    ax.set_xscale('log')



def fit_bpl(source, event, mean, fname=None, save_plots=False, epoch_indices='all', niters=1000, nwalkers=50):
    """
    Fit the source measurements to a broken power law lightcurve via MCMC.

        Parameters:
            source (vasttools.source.Source): the VAST source object
            event (str): the GW event name; converted to the event time with name_to_time
            mean (float): the mean luminosity distance of the GW event in Mpc (not used by this fit)
            fname (str): plot prefix name (default None)
            save_plots (bool): whether to save the produced plots (default False)
            epoch_indices ('all' or list): a list of epoch indices to fit the curve to (default 'all')
            niters (int): number of iterations to run MCMC (default 1000)
            nwalkers (int): number of walkers in MCMC (default 50)

        Returns:
            float: chi-squared of the median-parameter model against the fitted measurements
    """
    # Keep only the measurements taken after the event
    event_time = name_to_time(event)
    source_meas = source.measurements
    after_meas = source_meas.iloc[np.where(source_meas.dateobs>event_time)[0]]

    # Work with plain arrays in both branches so downstream code always sees the same types
    x = ((after_meas.dateobs-event_time)/pd.Timedelta(1, unit='d')).values
    y = after_meas.flux_peak.values
    yerr = after_meas.flux_peak_err.values
    # The isinstance check avoids an ambiguous element-wise comparison when
    # epoch_indices is a numpy array
    if not (isinstance(epoch_indices, str) and epoch_indices == 'all'):
        x, y, yerr = x[epoch_indices], y[epoch_indices], yerr[epoch_indices]

    # t_peak is bounded above by the epoch following the brightest one. When
    # the brightest epoch is the last one there is no later epoch, so leave
    # t_peak unbounded instead of indexing past the end of x.
    peak_index = int(np.argmax(y))
    if peak_index + 1 < len(x):
        time_peak_limit = x[peak_index+1]
    else:
        time_peak_limit = np.inf

    starting_vals = {
        'F_peak': max(y),
        't_peak': x[peak_index],
        'delta_1': 1,
        'delta_2': -1,
        'alpha': 1,
        'log_s': 1}

    sampler = bpl_run_mcmc(x, y, yerr, list(starting_vals.values()), niters, nwalkers,time_peak_limit)
    chain = sampler.chain

    # Burn-in: the usual 200 steps, or half the run when the run is too short
    # for that to leave any samples
    chain_cut = 200 if niters > 200 else niters // 2
    bpl_make_chain_plot(chain, chain_cut, save_plots=save_plots, fname=fname)

    good_chain = chain[:, chain_cut:, :]
    bpl_make_corner_plot(good_chain, save_plots=save_plots, fname=fname)

    best_params = bpl_get_best_params(good_chain)

    param_names = ['F_peak', 't_peak', 'delta_1', 'delta_2', 'alpha', 'log_s']

    chi2_best = calc_chi2(x, y, yerr, best_params, param_names, smooth_broken_power_law)
    print(chi2_best)

    # Median value of each parameter, in model-argument order
    args = [best_params[param][0] for param in param_names]

    # The model is a power of t/t_peak, so it must only be evaluated at
    # positive times; fall back to the plot's own lower x limit when
    # min(x) - 10 would not be positive.
    t_start = min(x)-10
    if t_start <= 0:
        t_start = min(x)/1.2

    bpl_make_plot(x, y, yerr, tvals = np.arange(t_start, max(x)+100), model=smooth_broken_power_law, params=args, save_plots=save_plots, fname=fname)
    return chi2_best

In [None]:
def make_plot(E_iso_vals, dist_vals, theta_vals, n_vals, fname, save_plots, alpha=0.3):
    """
    Plot the shaded constraint band in the (theta_obs, n) plane for every E_iso value.

    For each E_iso the limits pickled by run_calc are loaded for every distance
    in dist_vals; a band is drawn between the two resulting curves only when
    exactly two of those pickles could be read.

        Parameters:
            E_iso_vals (sequence): isotropic-equivalent energies in erg
            dist_vals (sequence): distances; each element's `.value` is used in the pickle file name
            theta_vals (array): viewing angles in radians (x axis)
            n_vals (array): densities in cm^-3, indexed by the loaded limits (y axis)
            fname (str): prefix of the pickle files and of the saved plot
            save_plots (bool): whether to write the figure to '{fname}_constraints.png'
            alpha (float): opacity of the shaded bands (default 0.3)
    """
    # One colour per E_iso value so none is silently dropped by zip below.
    # Never fewer than three, so plots of up to three energies keep their
    # original colours.
    n_colours = max(3, len(E_iso_vals))
    colours = cmr.take_cmap_colors('cmr.rainforest', n_colours, cmap_range=(0.3,0.7), return_fmt='hex')

    fig = plt.figure(figsize=(10, 10))
    ax = fig.add_subplot(111)

    ax.set_xlim(0,np.pi/2)
    ax.set_ylim(1e-3, 1e1)
    ax.set_yscale('log')

    ax.set_xlabel(r'$\theta_{\rm obs}$ (rad)')
    ax.set_ylabel(r'$n$ (cm$^{-3}$)')

    for E_iso, col in zip(E_iso_vals, colours):
        full_lims = []
        for dist in dist_vals:
            picklefile = 'pickles/{}_Eiso{}_d{}.pickle'.format(fname, E_iso, dist.value)
            lims = get_lims(picklefile)
            if lims is not None:
                full_lims.append(lims)
        # A band needs both of its bounding curves
        if len(full_lims) != 2:
            continue
        label='$E_{{\\rm iso}}=${} erg'.format(E_iso)

        ax.fill_between(theta_vals, n_vals[full_lims[0]], n_vals[full_lims[1]], color=col, alpha=alpha, label=label)

    plt.legend(loc='upper left')
    plt.tight_layout()
    if save_plots:
        plt.savefig(f'{fname}_constraints.png', dpi=200)
    plt.show()
    
    
    
def get_lims(picklefile):
    """
    Load a pickled boolean 'ruled out' grid and return, per row, the index of its first True.

        Parameters:
            picklefile (str): path of the pickle to read

        Returns:
            numpy.ndarray or None: `argmax(axis=1)` of the loaded array, or None
            when the file cannot be opened or unpickled. A row containing no
            True also yields 0, because argmax returns 0 for an all-False row.
    """
    # NOTE: pickle.load can execute arbitrary code; only read pickles this notebook wrote.
    try:
        # The context manager closes the file even if unpickling fails
        with open(picklefile, 'rb') as fh:
            ruled_out = pickle.load(fh)
    except Exception:
        # Best effort: a missing or unreadable pickle means "no limits available".
        # KeyboardInterrupt/SystemExit are deliberately not swallowed.
        return None
    return ruled_out.argmax(axis=1)



def run_calc(t, nu, upper_lims, E_iso_vals, dist_vals, theta_vals, n_vals, Z_static, fname):
    """
    Compute, for every (E_iso, distance) pair, which (theta_obs, n) models are ruled out.

    A model is marked ruled out when its afterglowpy lightcurve reaches or
    exceeds `upper_lims` at two or more of the times in `t`. One boolean grid
    of shape (len(theta_vals), len(n_vals)) is pickled per pair to
    'pickles/{fname}_Eiso{E_iso}_d{dist.value}.pickle'.

        Parameters:
            t (array): observation times passed to grb.fluxDensity
            nu (float or array): observing frequency passed to grb.fluxDensity
            upper_lims (float or array): flux threshold(s) compared against the model
            E_iso_vals (sequence): isotropic-equivalent energies in erg
            dist_vals (sequence): distances supporting `.to(u.cm).value` and `.value`
            theta_vals (sequence): viewing angles
            n_vals (sequence): densities
            Z_static (dict): remaining fixed keyword arguments for grb.fluxDensity
            fname (str): prefix of the pickle file names
    """
    import os

    for E_iso in E_iso_vals:
        for dist in dist_vals:
            print("Calculating lightcurves for E_iso={} ergs, d_L={}".format(E_iso, dist))
            picklefile = 'pickles/{}_Eiso{}_d{}.pickle'.format(fname, E_iso, dist.value)
            # Every cell is assigned in the loops below, so np.empty is safe here
            ruled_out = np.empty(shape=(len(theta_vals), len(n_vals)), dtype=bool)
            for i, theta_obs in enumerate(tqdm(theta_vals)):
                for j, n in enumerate(n_vals):
                    Fnu = grb.fluxDensity(
                        t,
                        nu,
                        E0=E_iso,
                        d_L=dist.to(u.cm).value,
                        thetaObs=theta_obs,
                        n0=n,
                        **Z_static
                        )

                    ruled_out[i,j] = np.count_nonzero(Fnu>=upper_lims) >= 2

            # Create the output directory on first use and close the file
            # deterministically once the grid has been written
            os.makedirs(os.path.dirname(picklefile), exist_ok=True)
            with open(picklefile, 'wb') as fh:
                pickle.dump(ruled_out, fh)
            
            
            
def get_constraints(
    source,
    event,
    mean,
    upper,
    lower,
    fname,
    save_plots=False,
    E_iso_vals=np.array([5e54, 2e55, 5e55]),
    theta_vals=np.deg2rad(np.linspace(0,90,50)),
    n=50
):
    """
    Compute and plot constraints in the (theta_obs, n) plane for a source.

    Lightcurves are evaluated (via run_calc) at the epochs observed after the
    event, for the two distances mean-lower and mean+upper, and the resulting
    limits are drawn with make_plot.

        Parameters:
            source (vasttools.source.Source): the VAST source object
            event (str): the GW event name; converted to the event time with name_to_time
            mean (float): the mean distance of the GW event in Mpc
            upper (float): added to mean to give the far distance
            lower (float): subtracted from mean to give the near distance
            fname (str): prefix of the pickle files and of the saved plot
            save_plots (bool): whether to save the produced plot (default False)
            E_iso_vals (array): isotropic-equivalent energies to test
            theta_vals (array): viewing angles in radians (default 50 values from 0 to 90 degrees)
            n (int): number of log-spaced densities between 1e-3 and 1e1 (default 50)
    """
    # Model parameters held fixed for every lightcurve
    Z_static = dict(
        jetType=-1,
        specType=0,
        xi_N=1.0,
        L0=0.0,
        q=0.0,
        ts=0.0,
        p=2.2,
        epsilon_e=0.1,
        epsilon_B=0.01,
        thetaCore=0.1,
    )

    event_time = name_to_time(event)
    measurements = source.measurements
    after_event = measurements.iloc[np.where(measurements.dateobs>event_time)[0]]

    # Parameter grid: the two bracketing distances and the densities
    dist_vals = np.array([mean-lower, mean+upper])*u.Mpc
    n_vals = np.logspace(-3,1,n)

    # Days since the event -> seconds
    days_since_event = np.array((after_event.dateobs-event_time)/pd.Timedelta(1, unit='d'))
    t = pd.to_timedelta(days_since_event, unit='D').total_seconds()

    nu = 888e6

    # Threshold derived from the median flux error over all of the source's measurements
    # NOTE(review): sigma/50 looks unusually low for an upper limit - confirm the intended factor
    sigma = np.median(measurements['flux_peak_err'])
    upper_lims = sigma/50

    run_calc(t, nu, upper_lims, E_iso_vals, dist_vals, theta_vals, n_vals, Z_static, fname)
    make_plot(E_iso_vals, dist_vals, theta_vals, n_vals, fname, save_plots)

In [None]:
# Download the published candidate spreadsheet as CSV. The timeout stops a
# stalled connection from hanging the notebook indefinitely.
r = requests.get('https://docs.google.com/spreadsheets/d/e/2PACX-1vTPTtxWq4mVNiM5eKL_98a53O6-gQteS7Ab7kdIUqtwxsThLIR7yh60kPTTiwbw0pE45mXoZUYeBCWA/pub?output=csv', timeout=60)
# Fail loudly on an HTTP error instead of parsing an error page as CSV
r.raise_for_status()
df = pd.read_csv(BytesIO(r.content), index_col=0)
# Keep only the rows whose 'Dougal classification' column is 'yes'
interesting = df[(df['Dougal classification'] == 'yes')]
interesting

In [None]:
def plot_lightcurve(source, event, save_fig=False):
    """
    Show the source lightcurve from the event time onwards on log-log axes.

        Parameters:
            source (vasttools.source.Source): the VAST source object
            event (str): the GW event name; converted to the event time with name_to_time
            save_fig (bool): forwarded to source.plot_lightcurve as `save` (default False)
    """
    event_time = name_to_time(event)
    fig = source.plot_lightcurve(start_date=pd.Timestamp(event_time), save=save_fig)
    ax = fig.gca()
    ax.set_xscale('log')
    ax.set_yscale('log')

    # Put the lower y limit just below the faintest positive flux. Skip it
    # when there is no positive flux at all; min() of an empty sequence would raise.
    positive_fluxes = [flux for flux in source.measurements.flux_peak if flux>0]
    if positive_fluxes:
        ax.set_ylim(bottom=min(positive_fluxes)/1.01)

    plt.show()
    plt.close(fig='all')

In [None]:
%matplotlib inline
# Plot the lightcurve of every candidate source; the fits below are opt-in.
for sid in interesting.index.values:
    print(sid)
    source = piperun.get_source(sid)
    meas = source.measurements

    # Event name and distance (with its upper/lower offsets) for this source
    source_df = interesting[interesting.index == sid]
    event = source_df.event.values[0]
    mean, upper, lower = source_df[['Event Distance (Mpc)', 'Event Distance Upper (Mpc)', 'Event Distance Lower (Mpc)']].values[0]

    plot_lightcurve(source, event)

    # Per-source choice of the epochs to fit; everything else uses all epochs
    epoch_indices='all'
    if sid == 5535012:
        epoch_indices = [1,2,3,4,5,6,8,9]
    elif sid == 4895153:
        # A flat list of indices: [range(1,10)] would be a one-element list
        # holding a range, which numpy treats as a 2-D fancy index
        epoch_indices = list(range(1,10))

#     fit_afterglowpy(source, event, mean, fname=f'{sid}_{event}', save_plots=False, epoch_indices=epoch_indices)
#     get_constraints(source, event, mean, upper, lower, fname=f'{sid}_{event}', save_plots=False, n=100)
#     fit_bpl(source, event, mean, fname=f'{sid}_{event}', save_plots=False, epoch_indices=epoch_indices)
    