In [2]:
import numpy as np 
import scipy.stats

In [5]:
# Define a reasonable mean vector and covariance matrix
mean = [0, 0, 0, 0]

covariance = np.array([
    [1, 0.5, 0.2, 0.1],
    [0.5, 1, 0.3, 0.2],
    [0.2, 0.3, 1, 0.4],
    [0.1, 0.2, 0.4, 1]
])  # Example covariance matrix

In [6]:
samples_main = scipy.stats.multivariate_normal(
    mean, covariance, allow_singular=True
).rvs(size=10000)

In [8]:
num_amplitude_params = 4
samples_re = samples_main[:, :num_amplitude_params:2]  
samples_im = samples_main[:, 1:num_amplitude_params:2]  
sample_abs_amplitudes = np.sqrt(samples_re**2 + samples_im**2) 

In [None]:
log_samples = np.log(samples_main)
samples_weights = np.exp(-np.sum(log_samples, axis=1))

In [9]:
def weighted_quantile(values, quantiles, weights=None):
    values = np.array(values)
    quantiles = np.array(quantiles)
    if weights is None:
        weights = np.ones(values.shape[0])
    weights = np.array(weights)

    # Sort values and weights along the first axis
    sorter = np.argsort(values, axis=0)
    sorted_values = np.take_along_axis(values, sorter, axis=0)
    sorted_weights = np.take_along_axis(weights[:, None], sorter, axis=0)

    # Compute cumulative weights
    cumulative_weights = np.cumsum(sorted_weights, axis=0) - 0.5 * sorted_weights
    cumulative_weights /= cumulative_weights[-1, :]
    
    # Interpolate quantiles
    quantile_values = np.empty((len(quantiles), values.shape[1]))
    for i in range(values.shape[1]):
        quantile_values[:, i] = np.interp(quantiles, cumulative_weights[:, i], sorted_values[:, i])

    return quantile_values

In [None]:
quantiles = [0.1, 0.25, 0.5, 0.75, 0.9]
quantile_vals = weighted_quantile(sample_abs_amplitudes, quantiles, weights=samples_weights)
quantile_vals_test = np.quantile(sample_abs_amplitudes, quantiles, weights=samples_weights, axis=0)

In [11]:
print(quantile_vals)

[[0.42919941 0.43656083]
 [0.70855969 0.734279  ]
 [1.12931666 1.13741718]
 [1.64713044 1.64775776]
 [2.18303938 2.14233664]]


In [12]:
print(quantile_vals_test)

[[0.42931837 0.43665993]
 [0.7085715  0.73434275]
 [1.12936163 1.13742662]
 [1.64716357 1.6477604 ]
 [2.18308499 2.14234549]]
