In [None]:
import numpy as np
from scipy import signal
import matplotlib.pyplot as plt
import corner
import emcee
import json

In [None]:
# MCMC settings file; mcmc_dict is currently only referenced by the commented-out line below.
pfile = "../tests/files/real_mini_mcmc.json"
with open(pfile) as f:
    mcmc_dict = json.load(f)

# Select from dictionary the necessary parameters to be changed
# labels = mcmc_dict['mcmc']['parameter_order']

# labels = [r"$n_{\mathrm{sfr}}$", r"$\alpha$", r"log$_{10}(\mu_{\rm{host}})$", r"log$_{10}(\sigma_{\rm{host}})$", r"log$_{10}(E_{\mathrm{max}})$", r"$\gamma$", r"$H_0$"]
# filenames = ["../mcmc/Hoffmann2023_CRAFT_no_191228", "../mcmc/Hoffmann2023_CRAFT_no_191228_2"]
# filenames = ["../mcmc/Hoffmann2023_CRAFT2", "../mcmc/Hoffmann2023_CRAFT3"]
# filenames = ["../mcmc/Hoffmann2023_exact_no_191228", "../mcmc/Hoffmann2023_exact_no_191228_2", 
#              "../mcmc/Hoffmann2023_exact_no_191228_3", "../mcmc/Hoffmann2023_exact_no_191228_4", 
#              "../mcmc/Hoffmann2023_exact_no_191228_5", "../mcmc/Hoffmann2023_exact_no_191228_6",
#              "../mcmc/Hoffmann2023_exact_no_191228_7", "../mcmc/Hoffmann2023_exact_no_191228_8"]
# filenames = ["../mcmc/Hoffmann2023_exact", "../mcmc/Hoffmann2023_exact2", "../mcmc/Hoffmann2023_exact3", "../mcmc/Hoffmann2023_exact5", 
#             "../mcmc/Hoffmann2023_exact6", "../mcmc/Hoffmann2023_exact7", "../mcmc/Hoffmann2023_exact8"]

# labels = [r"$n_{\mathrm{sfr}}$", r"$\alpha$", r"log$_{10}(\mu_{\rm{host}})$", r"log$_{10}(\sigma_{\rm{host}})$", r"log$_{10}(E_{\mathrm{max}})$", r"$\gamma$", r"$H_0$", r"DM$_{\rm{halo}}$"]
# labels = ["sfr_n", "alpha", "lmean", "lsigma", "lEmax", "gamma", "H0", "DMhalo"]
# filenames = ['../mcmc/DSA5', '../mcmc/DSA7']
# filenames = ['../mcmc/DSA2', '../mcmc/DSA3 ']

# labels = [r"log$_{10}(F)$", r"$n_{\mathrm{sfr}}$", r"$\alpha$", r"log$_{10}(\mu_{\rm{host}})$", r"log$_{10}(\sigma_{\rm{host}})$", 
#           r"log$_{10}(E_{\mathrm{max}})$", r"$\gamma$", r"$H_0$", r"log$_{10}(E_{\rm{min}})$", r"DM$_{\rm{halo}}$"]
# labels = ["logF", "sfr_n", "alpha", "lmean", "lsigma", "lEmax", "gamma", "H0", "lEmin", "DMhalo"]
# filenames = ['../mcmc/DSA_FAST_CRAFT', '../mcmc/DSA_FAST_CRAFT_2']

# Active run. The LaTeX labels are kept under their own name instead of being immediately
# overwritten: later cells use `labels` as parameter keys (best_fit, params) for zdm, so
# `labels` must hold the plain parameter names.
latex_labels = [r"$n_{\mathrm{sfr}}$", r"$\alpha$", r"log$_{10}(\mu_{\rm{host}})$", r"log$_{10}(\sigma_{\rm{host}})$", 
                r"log$_{10}(E_{\mathrm{max}})$", r"$\gamma$", r"$H_0$", r"log$_{10}(E_{\rm{min}})$", r"DM$_{\rm{halo}}$"]
labels = ["sfr_n", "alpha", "lmean", "lsigma", "lEmax", "gamma", "H0", "lEmin", "DMhalo"]
filenames = ['../mcmc/DSA_FAST_CRAFT_3']

# labels = [r"log$_{10}(F)$", r"$n_{\mathrm{sfr}}$", r"$\alpha$", r"log$_{10}(\mu_{\rm{host}})$", r"log$_{10}(\sigma_{\rm{host}})$", r"log$_{10}(E_{\mathrm{max}})$", r"$\gamma$", r"$H_0$", r"DM$_{\rm{halo}}$"]
# filenames = ['../mcmc/DSA6']

# labels = ["sfr_n", "alpha", "lmean", "lsigma", "lEmax", "gamma", "H0", "lEmin"]
# filenames = ['../mcmc/FASTnpCRAFT6']
# filenames = ['../mcmc/FASTnpCRAFT2', '../mcmc/FASTnpCRAFT3', '../mcmc/FASTnpCRAFT4']

# labels = ["sfr_n", "alpha", "lmean", "lsigma", "lEmax", "gamma", "H0", "lEmin", "DMhalo"]
# filenames = ['../mcmc/DSA4']

# Load the chain stored in each emcee HDF5 backend (one entry per file in `filenames`).
# read_only=True: this notebook only reads the chains and must never create or modify the files.
samples = []
for filename in filenames:
    reader = emcee.backends.HDFBackend(filename + '.h5', read_only=True)
    samples.append(reader.get_chain())

In [None]:
# Make alpha negative: flip the sign of the alpha column in every chain.
# `a` is the index of the last label equal to the LaTeX alpha label, or -1 if none matches.
# NOTE(review): `labels` currently holds plain names ("alpha"), so this match never fires and no
# sign flip happens. Matching "alpha" as well would also flip best_fit["alpha"], which a later cell
# passes to it.minimise_const_only -- confirm the sign convention there before changing the match.
a = max((idx for idx, name in enumerate(labels) if name == r"$\alpha$"), default=-1)

if a >= 0:
    for chain in samples:
        chain[:, :, a] = -chain[:, :, a]

In [None]:
# Make F linear
# a=-1
# for i, x in enumerate(labels):
#     if x == r"log$_{10}(F)$":
#         a = i

# if a != -1:
#     for sample in samples:
#         sample[:,:,a] = 10**sample[:,:,a]    

In [None]:
# Trace plot of every walker for every parameter, one figure per chain file.
for j, sample in enumerate(samples):
    n_params = sample.shape[2]
    # squeeze=False keeps `axes` iterable even when a chain has a single parameter.
    fig, axes = plt.subplots(n_params, 1, figsize=(20, 30), sharex=True, squeeze=False)
    axes = axes[:, 0]
    # Figure-level title: plt.title() would only title the current Axes (the bottom panel).
    fig.suptitle("Sample: " + filenames[j])
    for i, ax in enumerate(axes):
        for k in range(sample.shape[1]):
            ax.plot(sample[:, k, i], '.-', label=str(k))

        ax.set_ylabel(labels[i])

    axes[-1].set_xlabel("Step number")
    axes[-1].legend()


In [None]:
# Adapted from https://emcee.readthedocs.io/en/stable/tutorials/autocorr/#a-more-realistic-example
def next_pow_two(n):
    """Return the smallest power of two that is >= n (1 whenever n <= 1)."""
    if n <= 1:
        return 1
    whole = int(n)
    if whole < n:  # round a non-integer n up before taking the bit length
        whole += 1
    return 1 << (whole - 1).bit_length()

def autocorr_func_1d(x, norm=True):
    """FFT-based autocorrelation function of a one-dimensional series.

    Parameters
    ----------
    x : array_like
        One-dimensional series (a scalar is promoted to length 1).
    norm : bool
        If True, divide by the zero-lag value so that acf[0] == 1; skipped
        when that value is zero (e.g. for a constant series).

    Returns
    -------
    numpy.ndarray
        Autocorrelation at lags 0 .. len(x) - 1 of the mean-subtracted series.
        Unnormalised values are the lag sums divided by 4 * next_pow_two(len(x)).

    Raises
    ------
    ValueError
        If x is not one-dimensional.
    """
    series = np.atleast_1d(x)
    if series.ndim != 1:
        raise ValueError("invalid dimensions for 1D autocorrelation function")
    length = len(series)
    padded = next_pow_two(length)

    # Zero-pad to 2 * padded so the FFT product gives a linear (not circular) correlation.
    spectrum = np.fft.fft(series - np.mean(series), n=2 * padded)
    acf = np.fft.ifft(spectrum * np.conjugate(spectrum))[:length].real
    acf /= 4 * padded

    if norm and acf[0] != 0:
        acf /= acf[0]
    return acf

# Automated windowing procedure following Sokal (1989)
def auto_window(taus, c):
    """Choose the summation window for the integrated autocorrelation time.

    Builds the mask ``M < c * taus[M]`` over lags M. If the mask is True
    anywhere, returns ``np.argmin`` of it: the first lag at which the
    condition fails (0 if it never fails). If the mask is True nowhere
    (including empty input), returns ``len(taus) - 1``.
    """
    still_inside = np.arange(len(taus)) < c * taus
    if not np.any(still_inside):
        return len(taus) - 1
    return np.argmin(still_inside)

def autocorr(y, c=5.0):
    """Integrated autocorrelation time averaged over several chains.

    Parameters
    ----------
    y : numpy.ndarray
        2-D array; each row is treated as an independent chain (e.g. one walker).
    c : float
        Window constant passed to auto_window.

    Returns
    -------
    float
        The row-averaged autocorrelation function integrated as 2 * cumsum - 1,
        evaluated at the window chosen by auto_window.
    """
    mean_acf = sum((autocorr_func_1d(row) for row in y), np.zeros(y.shape[1])) / len(y)
    taus = 2.0 * np.cumsum(mean_acf) - 1.0
    return taus[auto_window(taus, c)]

# def autocorr_ml(y, thin=1, c=5.0):
#     # Compute the initial estimate of tau using the standard method
#     init = autocorr(y, c=c)
#     z = y[:, ::thin]
#     N = z.shape[1]

#     # Build the GP model
#     tau = max(1.0, init / thin)
#     kernel = terms.RealTerm(
#         np.log(0.9 * np.var(z)),
#         -np.log(tau),
#         bounds=[(-5.0, 5.0), (-np.log(N), 0.0)],
#     )
#     kernel += terms.RealTerm(
#         np.log(0.1 * np.var(z)),
#         -np.log(0.5 * tau),
#         bounds=[(-5.0, 5.0), (-np.log(N), 0.0)],
#     )
#     gp = celerite.GP(kernel, mean=np.mean(z))
#     gp.compute(np.arange(z.shape[1]))

    # # Define the objective
    # def nll(p):
    #     # Update the GP model
    #     gp.set_parameter_vector(p)

    #     # Loop over the chains and compute likelihoods
    #     v, g = zip(*(gp.grad_log_likelihood(z0, quiet=True) for z0 in z))

    #     # Combine the datasets
    #     return -np.sum(v), -np.sum(g, axis=0)

    # # Optimize the model
    # p0 = gp.get_parameter_vector()
    # bounds = gp.get_parameter_bounds()
    # soln = minimize(nll, p0, jac=True, bounds=bounds)
    # gp.set_parameter_vector(soln.x)

    # # Compute the maximum likelihood tau
    # a, c = kernel.coefficients[:2]
    # tau = thin * 2 * np.sum(a / c) / np.sum(a)
    # return tau

In [None]:
# Reject walkers with bad autocorrelation values
def auto_corr_rej(samples, burnin=0):
    """Discard walkers whose chain is numerically flat in any parameter.

    A walker is rejected when, for at least one parameter, the maximum of its
    unnormalised autocorrelation function (autocorr_func_1d with norm=False)
    computed after burn-in is below 1e-10, e.g. a walker that never moves.

    Parameters
    ----------
    samples : list of numpy.ndarray
        Chains indexed as [step, walker, parameter].
    burnin : int or sequence of int
        Steps dropped from the start of each chain. A single int applies to
        every chain; a sequence (list, tuple, array) gives one value per chain,
        matching std_rej.

    Returns
    -------
    list of numpy.ndarray
        The chains with burn-in removed and rejected walkers dropped.
    """
    if np.ndim(burnin) == 0:
        burnin = [burnin] * len(samples)

    good_samples = []

    for j, sample in enumerate(samples):
        start = burnin[j]
        good_walkers = []
        bad_walkers = []

        for i in range(sample.shape[1]):
            # any() short-circuits at the first parameter with a flat autocorrelation.
            bad = any(
                np.max(autocorr_func_1d(sample[start:, i, k], norm=False)) < 1e-10
                for k in range(sample.shape[2])
            )
            if bad:
                bad_walkers.append(i)
            else:
                good_walkers.append(i)

        print("Discarded walkers for sample " + str(j) + ": " + str(bad_walkers))

        good_samples.append(sample[start:, good_walkers, :])

    return good_samples

# Reject walkers with small standard deviations
def std_rej(samples, burnin=0, window=100, threshold=1e-2):
    """Discard walkers that barely move just after burn-in.

    For every parameter, each walker's standard deviation over the steps
    ``burnin : burnin + window`` is divided by the largest such value among
    all walkers. A walker whose relative spread is below ``threshold`` in any
    parameter is rejected. Parameters in which no walker moves at all carry
    no relative information and reject nobody.

    Parameters
    ----------
    samples : list of numpy.ndarray
        Chains indexed as [step, walker, parameter].
    burnin : int or sequence of int
        Steps dropped from the start of each chain. A single int applies to
        every chain; a sequence (list, tuple, array) gives one value per chain.
    window : int
        Number of post-burn-in steps used to measure each walker's spread.
    threshold : float
        Minimum relative standard deviation for a walker to be kept.

    Returns
    -------
    list of numpy.ndarray
        The chains with burn-in removed and rejected walkers dropped.
    """
    if np.ndim(burnin) == 0:
        burnin = [burnin] * len(samples)

    good_samples = []

    for i, sample in enumerate(samples):
        start = burnin[i]

        # Spread of every walker in every parameter: shape (n_walkers, n_params).
        sd = np.std(sample[start:start + window], axis=0)
        max_sd = np.max(sd, axis=0)

        # Normalise by the largest spread per parameter; where max_sd == 0 the ratio is
        # undefined, so silence the warning and exclude those parameters explicitly.
        with np.errstate(divide="ignore", invalid="ignore"):
            rel_sd = sd / max_sd
        too_small = (rel_sd < threshold) & (max_sd > 0)
        bad_walkers = np.flatnonzero(too_small.any(axis=1))

        print("Discarded walkers for sample " + str(i) + ": " + str(bad_walkers))

        keep = np.ones(sample.shape[1], dtype=bool)
        keep[bad_walkers] = False

        # Add the new sample with the bad walkers discarded to the good_samples list
        good_samples.append(sample[start:, keep, :])

    return good_samples

In [None]:
# Drop (nearly) stuck walkers while keeping every step; burn-in is removed when the final
# sample is assembled. Alternatives: use the raw chains, or reject via autocorrelation.
good_samples = std_rej(samples=samples, burnin=0)
# good_samples = samples
# _ = auto_corr_rej(samples, burnin=0)

In [None]:
# Estimate the integrated autocorrelation time tau for increasing chain lengths (emcee
# autocorrelation tutorial) and set each sample's burn-in to 1.5 * tau of the full chain.
TAU_PARAM = 0  # index of the parameter whose chain is used for the tau estimate

burnin = []
for j, sample in enumerate(good_samples):
    n_steps = sample.shape[0]
    # Chain lengths spaced logarithmically from 10 steps to the full chain; rint (rather than
    # truncation) ensures the last length is exactly n_steps.
    N = np.rint(np.exp(np.linspace(np.log(10), np.log(n_steps), 10))).astype(int)
    tau_estimates = np.empty(len(N))
    for i, n in enumerate(N):
        # autocorr wants one row per walker: take the first n *steps* (axis 0) of every
        # walker, then transpose to (n_walkers, n).
        tau_estimates[i] = autocorr(sample[:n, :, TAU_PARAM].T)

    # Plot the estimates against the tau = N/50 reliability line
    fig, ax = plt.subplots(1, 1, figsize=(8, 6))
    ax.loglog(N, tau_estimates, "o-", label="estimate")
    ylim = ax.get_ylim()
    ax.plot(N, N / 50.0, "--k", label=r"$\tau = N/50$")
    ax.set_ylim(ylim)
    ax.set_xlabel("number of samples, $N$")
    ax.set_ylabel(r"$\tau$ estimates")
    ax.set_title("Sample: " + filenames[j])
    ax.legend()

    burnin.append(int(1.5 * tau_estimates[-1]))

In [None]:
# good_samples = std_rej(samples, burnin=200)
# _ = auto_corr_rej(samples, burnin=burnin)
# burnin = (np.ones(len(good_samples)) * 100).astype(int)
# print(burnin)

In [None]:
# Trace plots after walker rejection, with each sample's burn-in removed.
for j, sample in enumerate(good_samples):
    n_params = sample.shape[2]
    # squeeze=False keeps `axes` iterable even when a chain has a single parameter.
    fig, axes = plt.subplots(n_params, 1, figsize=(20, 30), sharex=True, squeeze=False)
    axes = axes[:, 0]
    # Figure-level title: plt.title() would only title the current Axes (the bottom panel).
    fig.suptitle("Sample: " + filenames[j])
    # Step indices of the plotted points, so the x-axis does not restart at 0 after burn-in.
    steps = np.arange(sample.shape[0])[burnin[j]:]
    for i, ax in enumerate(axes):
        for k in range(sample.shape[1]):
            ax.plot(steps, sample[burnin[j]:, k, i], '.-', label=str(k))

        ax.set_ylabel(labels[i])

    axes[-1].set_xlabel("Step number")
    axes[-1].legend()


In [None]:
# Enforce more restrictive priors on a parameter
def change_priors(sample, param_num, max=np.inf, min=-np.inf):
    """Keep only the rows of `sample` whose column `param_num` lies strictly in (min, max).

    Parameters
    ----------
    sample : numpy.ndarray
        Array indexed as [draw, parameter].
    param_num : int
        Column index of the parameter to restrict.
    max, min : float
        Exclusive upper / lower bounds (unbounded by default). Rows whose value
        is NaN fail both comparisons and are dropped.

    Returns
    -------
    numpy.ndarray
        A new array holding only the rows that satisfy the bounds.
    """
    column = sample[:, param_num]
    rows = np.flatnonzero((column > min) & (column < max))
    return sample[rows, :]

In [None]:
# Pool the post-burn-in draws of all kept walkers from every sample into a single array with one
# row per draw and one column per parameter. Within a sample, rows run step by step and, inside
# each step, walker by walker.
print(burnin)
# burnin = (np.ones(len(good_samples)) * 60).astype(int)
# print(burnin)

final_sample = np.concatenate(
    [sample[burnin[j]:].reshape(-1, sample.shape[2]) for j, sample in enumerate(good_samples)],
    axis=0,
)

# final_sample = change_priors(final_sample, 8, min=20.0)
# final_sample = change_priors(final_sample, 7, max=110.0)
# final_sample = change_priors(final_sample, 9, max=80.0)
# final_sample = change_priors(final_sample, 1, max=1.0, min=-3.5)

print(final_sample.shape)


In [None]:
# Corner plot of the pooled posterior with the 16/50/84 % quantiles marked.
# `titles` is one empty string per parameter.
fig = plt.figure(figsize=(12, 12))
titles = [''] * final_sample.shape[1]
corner.corner(
    final_sample,
    labels=labels,
    show_titles=True,
    titles=titles,
    fig=fig,
    title_kwargs={"fontsize": 15},
    label_kwargs={"fontsize": 15},
    quantiles=[0.16, 0.5, 0.84],
);

In [None]:
# Marginal histogram of each parameter, with the median as point estimate and a central
# credible interval containing a fraction CL of the samples.
nBins = 30  # histogram bins per parameter
CL = 0.68   # central interval mass; 0.68 gives the 16th-84th percentiles
quantile_levels = [0.5 - CL / 2.0, 0.5, 0.5 + CL / 2.0]

best_fit = {}

for i in range(len(labels)):
    values = final_sample[:, i]

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.hist(values, bins=nBins, density=True)
    ax.set_xlabel(labels[i])
    ax.set_ylabel("P(" + labels[i] + ")")

    # One pass for the lower bound, median and upper bound
    lower, best, upper = np.quantile(values, quantile_levels)

    best_fit[labels[i]] = best
    u_lower = best - lower
    u_upper = upper - best
    for marker in (lower, best, upper):
        ax.axvline(marker, color='r')
    print(rf'{labels[i]}: {best} (-{u_lower}/+{u_upper})')

In [None]:
import scipy.stats as st

In [None]:
# How the scatter of the median shrinks with subsample size n: split the pooled sample into
# `step` interleaved subsamples (every step-th draw, ~n draws each), take each subsample's
# median and measure the spread of those medians. A power law is fitted in log-log space and
# evaluated at the full sample size; a 1/sqrt(n) curve is shown for comparison.
n_total = final_sample.shape[0]
nsamps = np.linspace(3, np.log10(n_total / 10), 30)
nsamps = sorted({int(10**x) for x in nsamps})  # int truncation can repeat sizes; keep each once
print("Number of samps: " + str(nsamps))

for i in range(len(labels)):
    std = []
    for n in nsamps:
        step = int(n_total / n)  # number of interleaved subsamples == stride between draws
        best = [np.quantile(final_sample[k::step, i], 0.5) for k in range(step)]
        std.append(np.std(best))

    line = st.linregress(np.log10(nsamps), np.log10(std))
    x = np.linspace(nsamps[0], nsamps[-1], 50)
    y_fit = 10**(line.slope * np.log10(x) + line.intercept)
    # 1/sqrt(n) scaling anchored at the first measured point
    y_ref = std[0] * np.sqrt(nsamps[0] / x)

    print(labels[i] + ": " + str(10**(line.slope * np.log10(n_total) + line.intercept)))
    print("  fitted slope: " + str(line.slope))

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(1, 1, 1)

    ax.plot(nsamps, std, label="measured")
    ax.loglog(x, y_fit, label="power-law fit")
    ax.loglog(x, y_ref, "--", label=r"$\propto 1/\sqrt{N}$")
    ax.set_xlabel("Number of samples")
    ax.set_ylabel("Standard deviation")
    ax.set_title(labels[i])
    ax.legend()

In [None]:
from zdm import survey
from zdm import cosmology as cos
from zdm.craco import loading
from zdm.misc_functions import *
import zdm.iteration as it
from zdm.MCMC import calc_log_posterior

In [None]:
# Imported explicitly rather than relying on `from zdm.misc_functions import *` to provide it.
import pickle

prefix = 'DSA_FAST_CRAFT'

cos.init_dist_measures()
state = loading.set_state()

# get the grid of p(DM|z); get_zdm_grid presumably comes from the zdm.misc_functions star import
zDMgrid, zvals, dmvals = get_zdm_grid(state, new=True, plot=False, method='analytic', save=True, datdir='MCMCData')

# Load the pickled surveys (followed by their names) and grids.
# NOTE: pickle.load can execute arbitrary code -- only load files you produced yourself.
with open('../Pickle/' + prefix + 'surveys.pkl', 'rb') as infile:
    surveys = pickle.load(infile)
    names = pickle.load(infile)

with open('../Pickle/' + prefix + 'grids.pkl', 'rb') as infile:
    grids = pickle.load(infile)

# Make params with the correct parameters but no priors (unbounded min/max for every label)
params = {key: {"min": -np.inf, "max": np.inf} for key in labels}

In [None]:
# Obtain a new rate constant at the best-fit parameters via it.minimise_const_only, store it as
# lC on every grid, then plot each survey's DM distribution and compare expected vs observed
# FRB numbers.
# best_fit["lEmin"] = 30
fit = {}
# best_fit["gamma"] = -1.0
newC, llc = it.minimise_const_only(best_fit, grids, surveys)

for s, g in zip(surveys, grids):
    g.state.FRBdemo.lC = newC

    fig = plt.figure(figsize=(6, 4))
    ax = fig.add_subplot(1, 1, 1)
    ax.set(title=s.name, xlabel="DM", ylabel="P(DM)")
    ax.set_xlim(right=7000)

    rates = g.rates
    dmvals = g.dmvals
    # Collapse the rate grid over axis 0 to get the curve plotted against DM
    pdm = np.sum(rates, axis=0)
    ax.plot(dmvals, pdm)

    # Expected count: grid integral scaled by the rate constant 10**lC
    expected = it.CalculateIntegral(g, s)
    expected *= 10**g.state.FRBdemo.lC
    observed = s.NORM_FRB

    print(s.name + " - expected, observed: " + str(expected) + ", " + str(observed))

In [None]:
# One panel per survey: for each FRB, plot the weights returned by it.calc_DMG_weights over the
# DM grid points it selects. uDMGs is passed to it alongside the survey's DMGs values.
uDMGs = 0.5

fig = plt.figure(figsize=(10, 40))

for j, (s, g) in enumerate(zip(surveys, grids)):
    ax = fig.add_subplot(len(surveys), 1, j + 1)
    ax.set(title=s.name, xlabel='DM', ylabel='Weight')

    # rates and zvals are not used further in this cell
    rates = g.rates
    dmvals = g.dmvals
    zvals = g.zvals
    DMobs = s.DMEGs

    dm_weights, iweights = it.calc_DMG_weights(DMobs, s.DMGs, uDMGs, dmvals)
    for i in range(len(DMobs)):
        frb_label = s.frbs["TNS"][i] + " " + str(s.DMGs[i])
        ax.plot(dmvals[iweights[i]], dm_weights[i], '.-', label=frb_label)

    ax.legend()
    