In [None]:
import matplotlib.pyplot as plt
import numpy as np
import warnings
from scipy.stats import norm
# NOTE(review): these filters silence warnings for the whole session. They may be
# masking the SyntaxWarning from '\%' inside non-raw label strings and the matplotlib
# UserWarning from calling legend() on an Axes with no labelled artists; consider
# removing them once those sources are fixed.
warnings.filterwarnings("ignore", category=SyntaxWarning)
warnings.filterwarnings("ignore", category=UserWarning)

# 2*pi, shared by the waveform phase and the Gaussian normalisation constants below.
twopi = 2*np.pi

In [None]:
#===========#
# Functions #
#===========#

# Signal 

def sine_gaussian(time, t0, tau, freq, phase, amp):
    """
    Sinusoid modulated by a normalised Gaussian envelope.

    Parameters
    ----------
    time : float or np.ndarray
        Time(s) at which to evaluate the waveform (arrays broadcast elementwise).
    t0 : float
        Centre of the Gaussian envelope; also the reference time of the phase.
    tau : float
        Standard deviation (width) of the Gaussian envelope.
    freq : float
        Frequency of the sinusoid.
    phase : float
        Phase of the sinusoid at ``time == t0``, in radians.
    amp : float
        Overall multiplicative scale.

    Returns
    -------
    float or np.ndarray
        ``amp * N(time; t0, tau) * sin(2*pi*freq*(time - t0) + phase)``, where
        ``N`` is the Gaussian probability density (so ``amp`` is not the peak height).
    """
    offset = time - t0

    carrier  = np.sin(offset * twopi * freq + phase)
    envelope = (1 / (np.sqrt(twopi) * tau)) * np.exp(-0.5 * (offset / tau)**2) * amp

    return carrier * envelope

# Noise 

def generate_white_noise(mu, sig, size, rng=None):
    """
    Draw i.i.d. Gaussian (white) noise samples.

    Parameters
    ----------
    mu : float
        Mean of the noise.
    sig : float
        Standard deviation of the noise.
    size : int or tuple of int
        Output shape, forwarded to ``normal``.
    rng : np.random.Generator or np.random.RandomState, optional
        Source of randomness. When None (default) the global ``np.random`` state
        is used, exactly as before, so ``np.random.seed`` still controls the draw.
        Passing e.g. ``np.random.default_rng(seed)`` gives reproducible noise
        without touching global state.

    Returns
    -------
    np.ndarray
        Noise samples of shape ``size``.
    """
    source = np.random if rng is None else rng
    return source.normal(mu, sig, size)

def normal_distribution(x, mu, sig):
    """
    Gaussian probability density N(x; mu, sig).

    Parameters
    ----------
    x : float or np.ndarray
        Point(s) at which to evaluate the density.
    mu : float
        Mean.
    sig : float
        Standard deviation.

    Returns
    -------
    float or np.ndarray
        ``exp(-0.5 * ((x - mu)/sig)**2) / (sqrt(2*pi) * sig)``.
    """
    z           = (x - mu) / sig
    coefficient = 1 / (np.sqrt(twopi) * sig)

    return coefficient * np.exp(-0.5 * z**2)

# Inference

def log_normal_distribution(x, mu, sig):
    """
    Natural log of the Gaussian density N(x; mu, sig).

    Computed directly in log space (rather than ``log(normal_distribution(...))``)
    so that far-tail points do not underflow to ``log(0)``.

    Parameters
    ----------
    x : float or np.ndarray
        Point(s) at which to evaluate the log density.
    mu : float
        Mean.
    sig : float
        Standard deviation.

    Returns
    -------
    float or np.ndarray
        ``-0.5*log(2*pi) - log(sig) - 0.5*((x - mu)/sig)**2``.
    """
    standardized = (x - mu) / sig
    log_coefficient = -0.5 * np.log(twopi) - np.log(sig)

    return log_coefficient + (-0.5 * standardized**2)

def log_prior_single(param_name, params, prior_bounds):
    """
    Log density of a uniform prior on a single parameter.

    Parameters
    ----------
    param_name : str
        Key of the parameter whose prior is evaluated.
    params : dict
        Current parameter values; only ``params[param_name]`` is read.
    prior_bounds : dict
        Maps parameter names to ``[lower, upper]`` bounds (both inclusive).

    Returns
    -------
    float
        ``-log(upper - lower)`` when the value lies inside the bounds, else ``-inf``.
    """
    value  = params[param_name]
    bounds = prior_bounds[param_name]
    lower, upper = bounds[0], bounds[1]

    inside = lower <= value <= upper
    if not inside:
        return -np.inf
    return -np.log(upper - lower)
    
def log_likelihood_single(time_axis, data, params, noise_parameters):
    """
    Gaussian white-noise log-likelihood of ``data`` given sine-Gaussian ``params``.

    The model waveform is evaluated on the whole time grid in one vectorized
    ``sine_gaussian`` call (instead of one scalar call per sample), and the
    residuals are scored with ``log_normal_distribution`` using the noise mean
    and standard deviation.

    Parameters
    ----------
    time_axis : array_like
        Sample times.
    data : array_like
        Observed data, one value per entry of ``time_axis``.
    params : dict
        Waveform parameters with keys 't0', 'tau', 'freq', 'phase', 'amp'.
    noise_parameters : dict
        Noise model with keys 'mu' (mean) and 'sigma' (standard deviation).

    Returns
    -------
    float
        Sum over samples of the per-sample log density of the residuals.
    """
    model     = sine_gaussian(np.asarray(time_axis), params['t0'], params['tau'], params['freq'], params['phase'], params['amp'])
    residuals = np.asarray(data) - model

    return np.sum(log_normal_distribution(residuals, noise_parameters['mu'], noise_parameters['sigma']))

def log_posterior_single(time_axis, data, params, param_to_sample, prior_bounds, noise_parameters):
    """
    Unnormalised log posterior for a single sampled parameter.

    The prior is evaluated only on ``param_to_sample``; the likelihood uses the
    full ``params`` dict.

    Parameters
    ----------
    time_axis, data : array_like
        Sample times and observed data.
    params : dict
        Full set of waveform parameters.
    param_to_sample : str
        Key of the parameter carrying the (uniform) prior.
    prior_bounds : dict
        ``[lower, upper]`` bounds per parameter.
    noise_parameters : dict
        Noise model with keys 'mu' and 'sigma'.

    Returns
    -------
    float
        ``log_prior + log_likelihood``, or ``-inf`` when the prior rejects the point.
    """
    log_prior = log_prior_single(param_to_sample, params, prior_bounds)

    # Outside the prior the posterior is zero: skip the likelihood, which is the
    # expensive part and may be ill-defined there (e.g. a non-positive tau makes
    # the Gaussian envelope singular), and could otherwise turn -inf into NaN.
    if log_prior == -np.inf:
        return -np.inf

    log_likelihood = log_likelihood_single(time_axis, data, params, noise_parameters)

    return log_prior + log_likelihood

def init_plotting():
    """
    Set the default matplotlib plotting parameters used in this notebook.

    Spine and tick visibility are configured through rcParams rather than via
    ``plt.gca()``. Calling ``plt.gca()`` here would create an empty figure, which
    a notebook then displays. Styling that single throwaway Axes would also not
    carry over to figures created later with ``plt.subplots``.

    Note: ``text.usetex = True`` makes matplotlib render text with an external
    LaTeX installation, which must be available.

    Parameters
    ----------
    None

    Returns
    -------
    int
        Always 1 (kept for backward compatibility); the real effect is the
        rcParams update.
    """
    plt.rcParams['figure.max_open_warning'] = 0

    plt.rcParams['mathtext.fontset']  = 'stix'
    plt.rcParams['font.family']       = 'STIXGeneral'
    plt.rcParams['text.usetex']       = True

    plt.rcParams['font.size']         = 18
    plt.rcParams['axes.linewidth']    = 1
    plt.rcParams['axes.labelsize']    = plt.rcParams['font.size']
    plt.rcParams['axes.titlesize']    = 1.5*plt.rcParams['font.size']
    plt.rcParams['legend.fontsize']   = plt.rcParams['font.size']*0.9
    plt.rcParams['xtick.labelsize']   = plt.rcParams['font.size']
    plt.rcParams['ytick.labelsize']   = plt.rcParams['font.size']
    plt.rcParams['xtick.major.size']  = 3
    plt.rcParams['xtick.minor.size']  = 3
    plt.rcParams['xtick.major.width'] = 1
    plt.rcParams['xtick.minor.width'] = 1
    plt.rcParams['ytick.major.size']  = 3
    plt.rcParams['ytick.minor.size']  = 3
    plt.rcParams['ytick.major.width'] = 1
    plt.rcParams['ytick.minor.width'] = 1

    plt.rcParams['legend.frameon']             = False
    plt.rcParams['legend.loc']                 = 'center left'
    plt.rcParams['contour.negative_linestyle'] = 'solid'

    # Hide the top/right spines and keep ticks only on the bottom/left axes,
    # for every figure created after this call.
    plt.rcParams['axes.spines.right'] = False
    plt.rcParams['axes.spines.top']   = False
    plt.rcParams['xtick.bottom']      = True
    plt.rcParams['xtick.top']         = False
    plt.rcParams['ytick.left']        = True
    plt.rcParams['ytick.right']       = False

    return 1

def label_map(param):
    """
    Axis label (LaTeX symbol plus unit) for a waveform parameter.

    Parameters
    ----------
    param : str
        Parameter key, e.g. 't0' or 'freq'.

    Returns
    -------
    str
        LaTeX label with the unit in brackets; 'amp' has no unit. Unknown
        keys are returned unchanged.
    """
    symbols = {'t0': r'$t_0$', 'tau': r'$\tau$', 'freq': r'$f$', 'phase': r'$\phi$', 'amp': r'$A$'}
    units   = {'t0': ' [s]',   'tau': ' [s]',    'freq': ' [Hz]', 'phase': ' [rad]', 'amp': ''}

    if param not in symbols:
        return param
    return symbols[param] + units[param]

def label_map_unitless(param):
    """
    LaTeX symbol (without unit) for a waveform parameter.

    Parameters
    ----------
    param : str
        Parameter key, e.g. 't0' or 'tau'.

    Returns
    -------
    str
        LaTeX symbol; unknown keys are returned unchanged.
    """
    symbols = dict(
        t0    = r'$t_0$',
        tau   = r'$\tau$',
        freq  = r'$f$',
        phase = r'$\phi$',
        amp   = r'$A$',
    )

    try:
        return symbols[param]
    except KeyError:
        return param

def compute_point_estimates(samples, param_to_sample):
    """
    Print the median and 90% credible interval (5th-95th percentile) of samples.

    The interval is reported as asymmetric offsets from the median:
    ``median + (p95 - median) - (median - p5)``.

    Parameters
    ----------
    samples : array_like
        Posterior samples of one parameter.
    param_to_sample : str
        Parameter name used in the printed line.

    Returns
    -------
    None
    """
    median = np.median(samples)
    high   = np.percentile(samples, 95) - median
    low    = median - np.percentile(samples, 5)

    print('\nPoint estimates:\n')
    # '.2f' = two decimals. The previous ':2f' set a minimum field *width* of 2
    # and printed the default six decimals.
    print(f'{param_to_sample} = {median:.2f} + {high:.2f} - {low:.2f}')

    return

import numpy as np

def simple_sigma_clip_burnin(chain: np.ndarray, k: float = 3.0) -> int:
    """
    Heuristic burn-in estimate by sigma clipping.

    Compute the median and (ddof=1) standard deviation of the whole chain, then
    return the index of the first sample whose value lies within ``k * std`` of
    the median. All earlier samples are treated as burn-in.

    Parameters
    ----------
    chain : array_like
        The full MCMC chain (1D). Lists are accepted and converted to an array.
    k : float
        Sigma threshold (default 3).

    Returns
    -------
    burn_in : int
        Number of initial samples to discard. 0 for chains with fewer than two
        samples. If no sample passes the cut (e.g. the chain contains NaN),
        falls back to half the chain length.
    """
    chain = np.asarray(chain)

    # With fewer than two samples the ddof=1 std is undefined; nothing to discard.
    if chain.size < 2:
        return 0

    med   = np.median(chain)
    sigma = np.std(chain, ddof=1)
    mask  = np.abs(chain - med) <= k * sigma  # True where |x - median| <= k*sigma

    if not np.any(mask):
        return len(chain) // 2  # fallback if nothing passes

    burn_in = int(np.argmax(mask))  # argmax of a boolean array = first True index

    print(f"\n\nComputed burn-in at index: {burn_in} (median={med:.4f}, sigma={sigma:.4f}, k={k})")

    return burn_in

def plot_chain_and_posterior(chain, param_to_sample, injected_parameters, title="Chain pre burn-in", params_analytic=None):
    """
    Plot a 1D chain: histogram of its values (top) and its trace (bottom).

    Parameters
    ----------
    chain : array_like
        Samples of ``param_to_sample``.
    param_to_sample : str
        Parameter key, used for labels and to look up the true value.
    injected_parameters : dict
        True parameter values; ``injected_parameters[param_to_sample]`` is marked.
    title : str
        Figure title.
    params_analytic : dict, optional
        Dict with keys 'mu_hat' and 'sigma_hat'. When given and
        ``param_to_sample == 'amp'``, the analytic Gaussian posterior
        N(mu_hat, sigma_hat) is overlaid on the histogram.

    Returns
    -------
    None
    """
    low, high = np.percentile(chain, [5, 95])
    median    = np.median(chain)

    fig, (ax_post, ax_trace) = plt.subplots(2, 1, figsize=(8.5, 6), sharex=False, gridspec_kw={"height_ratios": [2, 1]})
    fig.suptitle(title, fontsize=20)

    _, bins, _ = ax_post.hist(chain, bins=30, density=True, color='royalblue', edgecolor='black', alpha=0.7)

    # Analytic prediction, only available for the linear amplitude parameter.
    if params_analytic is not None and param_to_sample == 'amp':
        mu_hat    = params_analytic['mu_hat']
        sigma_hat = params_analytic['sigma_hat']
        # Cover both the histogram range and +/-5 sigma of the analytic curve.
        xplot = np.linspace(min(bins[0], mu_hat - 5*sigma_hat), max(bins[-1], mu_hat + 5*sigma_hat), 400)
        ax_post.plot(xplot, norm.pdf(xplot, loc=mu_hat, scale=sigma_hat), lw=2, color='darkorange', linestyle='dashed', label="Analytic")
    ax_post.axvline(injected_parameters[param_to_sample], color='firebrick', label='True value')
    ax_post.axvline(median, color='black', linestyle='dashed', label='Median')
    # Dotted, so the CI bounds are distinguishable from the dashed median line.
    # Raw string: '\%' is not a valid Python escape sequence.
    ax_post.axvline(low,  color='black', linestyle='dotted', label=r'$90 \%$ CI')
    ax_post.axvline(high, color='black', linestyle='dotted')
    ax_post.set_xlabel(label_map(param_to_sample))
    ax_post.set_ylabel(f'p({label_map_unitless(param_to_sample)} $|$ d)')
    ax_post.legend(loc='best')

    # increase space between subplots
    fig.subplots_adjust(hspace=0.35)

    # The trace has no labelled artists, so no legend is drawn on it.
    ax_trace.plot(chain, c='firebrick')
    ax_trace.set_xlabel('Iteration')
    ax_trace.set_ylabel(label_map_unitless(param_to_sample))

# Apply the notebook-wide plotting defaults. The trailing ';' stops IPython from
# echoing init_plotting()'s return value (1) as this cell's output.
init_plotting();

In [None]:
# User inputs: all the numbers go here. Below has no arbitrary numbers floating around.

# Run parameters
param_to_sample = 'amp' # Parameter to sample with MCMC. Options: ['t0', 'tau', 'freq', 'phase', 'amp'].
zero_noise      = False # If True, no noise is added to the signal
noise_seed      = 42    # Seed for noise generation (if zero_noise is False).
                        # NOTE(review): only takes effect if the `np.random.seed(noise_seed)`
                        # line in the data-generation cell is enabled (it is commented out there).

# Data generation parameters
sampling_rate   = 50 # Hz
duration        = 2  # seconds

# Sampler parameters
N_iterations    = 20000 # Number of MCMC iterations

# Standard deviation of the Gaussian random-walk proposal, one entry per sampleable
# parameter (every option listed for `param_to_sample` needs an entry here).
# The tau/freq/phase values are starting guesses; tune them for a reasonable acceptance ratio.
sigma_proposals = {
    't0'   : 0.2,  # seconds
    'tau'  : 0.01, # seconds
    'freq' : 0.1,  # Hz
    'phase': 0.1,  # radians
    'amp'  : 0.2,  # [AU]
}

injected_parameters = {
    't0'   : 0.0, # seconds
    'tau'  : 0.2, # seconds
    'freq' : 5.0, # Hz
    'phase': 0.0, # radians
    'amp'  : 1.0, # [AU]
}

# Uniform prior bounds [lower, upper] (inclusive) per parameter.
prior_bounds = { 
    't0'   : [  -1.0,   1.0], # seconds
    'tau'  : [   0.1,   0.3], # seconds
    'freq' : [   1.0,  10.0], # Hz
    'phase': [-np.pi, np.pi], # radians
    'amp'  : [   0.0,  10.0], # [AU]
}

# White Gaussian noise model: mean and standard deviation per sample.
noise_parameters = {
    'mu'   : 0,   # [AU]
    'sigma': 0.2, # [AU]
}

In [None]:
# np.random.seed(noise_seed)

# Uniform time grid of `duration` seconds centred on t = 0 (end point excluded).
dt        = 1/sampling_rate
time_axis = np.arange(-duration/2, duration/2, dt)

# The noise draw is the only random call in this cell; it is made even when
# zero_noise is True.
noise  = generate_white_noise(noise_parameters['mu'], noise_parameters['sigma'], len(time_axis))
# injected_parameters' keys (t0, tau, freq, phase, amp) match sine_gaussian's argument names.
signal = sine_gaussian(time_axis, **injected_parameters)

# `data` is always a fresh array, independent of `signal`.
data = signal.copy() if zero_noise else signal + noise

In [None]:
# Plot the injected signal together with the (noisy) data samples.
fig, ax = plt.subplots(figsize=(10, 5))

ax.plot(   time_axis, signal, c='firebrick', label='signal', linewidth=2.5, zorder = -1)
ax.plot(   time_axis, data,   c='royalblue', label='data'  , linewidth=1.5, zorder = -1)
# The markers share the 'data' line's colour; they are left unlabelled so the
# legend does not list 'data' twice.
ax.scatter(time_axis, data,   c='royalblue',                                    s=25)
ax.set_xlabel('Time [s]')
ax.set_ylabel('Data')
ax.legend(loc='upper right')

plt.show()

In [None]:
# Metropolis-Hastings

def mcmc_chain(n_iterations, param_to_sample='amp'):
    """
    Run a 1D random-walk Metropolis sampler over a single waveform parameter.

    All other waveform parameters are held fixed at their ``injected_parameters``
    values. The target is ``log_posterior_single`` evaluated with the module-level
    ``time_axis``, ``data``, ``prior_bounds`` and ``noise_parameters``; proposals
    are Gaussian steps with standard deviation ``sigma_proposals[param_to_sample]``.
    The chain starts from a uniform draw inside the prior bounds.

    Parameters
    ----------
    n_iterations : int
        Total chain length including the initial point, so ``n_iterations - 1``
        proposals are made.
    param_to_sample : str
        Key of the parameter to sample.

    Returns
    -------
    np.ndarray
        1D array of ``max(n_iterations, 1)`` chain states.
    """
    lower, upper   = prior_bounds[param_to_sample][0], prior_bounds[param_to_sample][1]
    sigma_proposal = sigma_proposals[param_to_sample]

    # Starting inside the prior guarantees a finite initial log posterior.
    theta_current                   = np.random.uniform(lower, upper)
    params_current                  = injected_parameters.copy()
    params_current[param_to_sample] = theta_current
    log_post_current                = log_posterior_single(time_axis, data, params_current, param_to_sample, prior_bounds, noise_parameters)

    # Reused for every proposal; only the sampled entry is overwritten.
    params_prop = injected_parameters.copy()

    samples    = [theta_current]
    n_accepted = 0

    for _ in range(1, n_iterations):

        theta_prop = theta_current + np.random.normal(0, sigma_proposal)
        u          = np.random.uniform(0, 1)

        params_prop[param_to_sample] = theta_prop
        log_post_prop                = log_posterior_single(time_axis, data, params_prop, param_to_sample, prior_bounds, noise_parameters)

        # Metropolis rule in log space: accept with probability min(1, exp(log_ratio)),
        # i.e. compare log(u) with min(0, log_ratio) -- the cap is log(1) = 0.
        # np.minimum propagates NaN and NaN fails the comparison, so an ill-defined
        # proposal is rejected; the strict '<' also rejects a -inf (out-of-prior)
        # proposal in the edge case u == 0.
        log_alpha = np.minimum(0.0, log_post_prop - log_post_current)
        if np.log(u) < log_alpha:
            theta_current, log_post_current = theta_prop, log_post_prop
            n_accepted += 1

        samples.append(theta_current)

    samples = np.array(samples)

    # The first entry is the initial draw, so only n_iterations - 1 proposals were made.
    n_proposals      = max(n_iterations - 1, 0)
    acceptance_ratio = n_accepted / n_proposals if n_proposals > 0 else 0.0

    print(f'\n\nAcceptance ratio: {acceptance_ratio:.4f}')

    return samples

In [None]:
# Run the sampler; the returned chain starts with the initial draw from the prior.
chain = mcmc_chain(n_iterations=N_iterations, param_to_sample=param_to_sample)

In [None]:
# Heuristic burn-in cut; everything before the returned index is discarded.
burn_in = simple_sigma_clip_burnin(chain=chain, k=1.5)
samples = chain[burn_in:]

In [None]:
"""
-------------------------------------------------------------------------
Derivation of the analytic posterior for a linear amplitude mu with a flat prior
-------------------------------------------------------------------------

(Here mu denotes the signal amplitude, i.e. the 'amp' parameter -- not the
noise mean noise_parameters['mu'], which is written m_n below.)

Model:
    y = mu * phi + eps,     eps ~ N(m_n, sigma_n^2 I)
where phi is the known unit-amplitude signal template (e.g. sine-Gaussian shape)
and m_n is the known noise mean. Work with the mean-subtracted data
    y' = y - m_n,   so that   y' = mu * phi + eps',   eps' ~ N(0, sigma_n^2 I).

We assume a flat (improper) prior p(mu) ∝ 1. The sampler's prior is uniform on a
finite interval, so the two agree as long as the posterior mass lies well inside
that interval.

Likelihood (up to constants in mu):
    log p(y | mu) = - (1 / (2 sigma_n^2)) * || y' - mu * phi ||^2

Expand the quadratic term:
    || y' - mu * phi ||^2 = y'^T y' - 2 mu (phi^T y') + mu^2 (phi^T phi)

Then:
    log p(y | mu) = - (1 / (2 sigma_n^2)) [ mu^2 (phi^T phi) - 2 mu (phi^T y') ] + const

Define:
    A = (phi^T phi) / sigma_n^2
    B = (phi^T y') / sigma_n^2

=> log posterior (up to const) = -0.5 * [ A mu^2 - 2 B mu ]

Complete the square in mu:
    A mu^2 - 2 B mu = A (mu - B/A)^2 - B^2 / A

Hence:
    p(mu | y) ∝ exp[ -0.5 A (mu - B/A)^2 ]
              = N( mu ; mu_hat, sigma_mu^2 )

where
    mu_hat     = B / A = (phi^T y') / (phi^T phi)
    sigma_mu^2 = 1 / A = sigma_n^2 / (phi^T phi)

Thus:
    mu | y  ~  N( (phi^T y')/(phi^T phi),  sigma_n^2 / (phi^T phi) )

This is the posterior under a flat prior — equivalent to the ordinary
least-squares estimator, with variance set by the noise level and the
template norm.

----------------------------------------------------------------------
"""

# Analytic posterior for amp with flat prior

# Unit-amplitude template: the waveform is linear in amp, so data = amp * phi + noise.
phi       = sine_gaussian(time_axis, injected_parameters['t0'], injected_parameters['tau'], injected_parameters['freq'], injected_parameters['phase'], 1.0)
phi_T_phi = np.sum(phi**2)
# Subtract the known noise mean, matching the residual model used by the likelihood.
phi_T_y   = np.sum(phi * (data - noise_parameters['mu']))
mu_hat    = phi_T_y / phi_T_phi
sigma_hat = noise_parameters['sigma'] / np.sqrt(phi_T_phi)

plot_chain_and_posterior(  chain, param_to_sample, injected_parameters, title="Chain pre burn-in"                                                               )
plot_chain_and_posterior(samples, param_to_sample, injected_parameters, title="Samples post burn-in", params_analytic={'mu_hat': mu_hat, 'sigma_hat': sigma_hat})
compute_point_estimates( samples, param_to_sample                                                                                                               )

In [None]:
# Waveform reconstruction, including pointwise 90% credible bands.
#
# Every posterior sample of `param_to_sample` (all other parameters fixed at their
# injected values) defines a full waveform. The median and the band are the pointwise
# 50th and 5th/95th percentiles over those waveforms. This is valid whichever
# parameter was sampled; plugging sample percentiles straight into the `amp` slot
# is only meaningful when param_to_sample == 'amp'.

reconstruction_params                  = injected_parameters.copy()
reconstruction_params[param_to_sample] = np.asarray(samples)[:, np.newaxis]  # shape (n_samples, 1)

# Broadcasting the (1, n_times) grid against the (n_samples, 1) column gives (n_samples, n_times).
waveforms = sine_gaussian(time_axis[np.newaxis, :], reconstruction_params['t0'], reconstruction_params['tau'], reconstruction_params['freq'], reconstruction_params['phase'], reconstruction_params['amp'])

reconstructed_signal_low, reconstructed_signal_median, reconstructed_signal_high = np.percentile(waveforms, [5, 50, 95], axis=0)

# Signal to noise ratio

snr = np.sqrt( np.sum( reconstructed_signal_median**2 ) ) / noise_parameters['sigma']
print(f'\nSignal to noise ratio of the reconstructed signal: SNR = {snr:.2f}\n')

fig, ax = plt.subplots(figsize=(10, 5))
ax.plot(   time_axis, signal, c='firebrick', label='Injected signal', linewidth=2.5, zorder = -1)
ax.scatter(time_axis, data,   c='lightgray', label='Data'           , s=15      )
ax.plot(   time_axis, reconstructed_signal_median, c='royalblue', label='Reconstructed signal (median)', linewidth=2.5)
# Raw string: '\%' is not a valid Python escape sequence.
ax.fill_between(time_axis, reconstructed_signal_low, reconstructed_signal_high, color='royalblue', alpha=0.3, label=r'$90\%$ credible interval')
ax.set_xlabel('Time [s]')
ax.set_ylabel('Signal')
ax.legend(loc='upper right', fontsize=12)
plt.show()

# Compute residuals

residuals = data - reconstructed_signal_median
fig, ax = plt.subplots(figsize=(10, 5))
ax.scatter(time_axis, residuals, c='royalblue', label='Residuals', s=15)
# Show the +/-1 sigma (dashed) and +/-2 sigma (dotted) noise levels.
ax.axhline(   noise_parameters['sigma'], color='black', linestyle='dashed', lw=0.8, label=r'$\sigma_{n}$')
ax.axhline(  -noise_parameters['sigma'], color='black', linestyle='dashed', lw=0.8)
ax.axhline( 2*noise_parameters['sigma'], color='black', linestyle='dotted', lw=0.8, label=r'$2\sigma_{n}$')
ax.axhline(-2*noise_parameters['sigma'], color='black', linestyle='dotted', lw=0.8)
ax.axhline(0, color='black', linestyle='dashed')
ax.set_xlabel('Time [s]')
ax.set_ylabel('Residuals')
ax.legend(loc='upper right')
plt.show()