In [2]:
import numpy as np
import os

In [13]:
import numpy as np
# Ground-truth parameter values used to score posterior medians (squared error
# in abc_posterior): TRUE_C is the truth for cx and cy, TRUE_S for s.
TRUE_C = 0.5
TRUE_S = 5e-5

# Column layout of each run's distance array (one row per simulation):
# [cx, cy, s, wasserstein, cramer-von mises, frechet, hausdorff]

# Column of each supported distance metric in a run's distance array.
# These indices are fixed and assume three leading parameter columns (cx, cy, s).
_DISTANCE_COLUMN = {
    "Wasserstein Distance": 3,
    "Cramer-von Mises Distance": 4,
    "Frechet Distance": 5,
    "Hausdorff Distance": 6,
}


def abc_posterior_data(nparams: int, distances: np.ndarray, distance_quantile: float, distance_metric: str) -> np.ndarray:
    """Return the parameter samples accepted by ABC rejection for one distance metric.

    A row of ``distances`` is accepted when its distance under ``distance_metric``
    is at or below the ``distance_quantile`` quantile of that metric's non-NaN
    distances. Rows whose distance is NaN are never accepted.

    Args:
        nparams: Number of leading columns of ``distances`` to return (the
            parameter columns).
        distances: 2-D array with one row per simulation; parameters first,
            distance metrics at the columns listed in ``_DISTANCE_COLUMN``.
        distance_quantile: Acceptance quantile in ``[0, 1]`` (e.g. 0.01 keeps
            the closest 1% of simulations).
        distance_metric: One of the keys of ``_DISTANCE_COLUMN``.

    Returns:
        Array of shape ``(n_accepted, nparams)`` holding the accepted rows.

    Raises:
        ValueError: If ``distance_metric`` is not a supported metric name.
    """
    try:
        index = _DISTANCE_COLUMN[distance_metric]
    except KeyError:
        # Previously an unknown name left `index` unbound (UnboundLocalError).
        raise ValueError(
            f"Unknown distance_metric {distance_metric!r}; "
            f"expected one of {sorted(_DISTANCE_COLUMN)}"
        ) from None

    metric_distances = distances[:, index]
    threshold = np.nanquantile(metric_distances, distance_quantile)
    # NaN <= threshold is False, so rows with a missing distance are dropped here.
    return distances[metric_distances <= threshold, :nparams]

def abc_posterior(nparams: int, distances: np.ndarray, distance_quantile: float, distance_metric: str, true_values=None) -> np.ndarray:
    """Summarise the ABC posterior of each parameter for one metric and quantile.

    The accepted samples come from ``abc_posterior_data`` (same first four
    arguments). Statistics are computed per parameter column and ignore NaNs.

    Args:
        nparams: Number of parameters (leading columns of ``distances``).
        distances: One run's distance array, passed to ``abc_posterior_data``.
        distance_quantile: Acceptance quantile, passed to ``abc_posterior_data``.
        distance_metric: Metric name, passed to ``abc_posterior_data``.
        true_values: Optional sequence of length ``nparams`` with the ground
            truth of each parameter, used for the squared-error column. Defaults
            to ``TRUE_S`` for parameter index 2 and ``TRUE_C`` for every other
            index (the ``[cx, cy, s]`` layout).

    Returns:
        Array of shape ``(nparams, 6)``; row ``i`` holds, for parameter ``i``:
        ``[mean, median, std, 2.5% quantile, 97.5% quantile, squared error of median]``.

    Raises:
        ValueError: If ``true_values`` does not have length ``nparams``.
        Any error raised by ``abc_posterior_data`` is propagated.
    """
    posterior_params = abc_posterior_data(nparams, distances, distance_quantile, distance_metric)

    if true_values is None:
        # Column 2 is s; every other column is a c-parameter (cx, cy).
        truth = np.where(np.arange(nparams) == 2, TRUE_S, TRUE_C)
    else:
        truth = np.asarray(true_values, dtype=float)
        if truth.shape != (nparams,):
            raise ValueError(
                f"true_values must have length {nparams}, got shape {truth.shape}"
            )

    # NaN-aware throughout, consistent with the nanquantile-based summaries.
    posterior_mean = np.nanmean(posterior_params, axis=0)
    posterior_std = np.nanstd(posterior_params, axis=0)
    posterior_median, posterior_lower_bound, posterior_upper_bound = np.nanquantile(
        posterior_params, [0.5, 0.025, 0.975], axis=0
    )
    posterior_sqerr = (truth - posterior_median) ** 2

    return np.column_stack([
        posterior_mean,
        posterior_median,
        posterior_std,
        posterior_lower_bound,
        posterior_upper_bound,
        posterior_sqerr,
    ])

In [17]:
RUN_DIR = "runs"
SAVE_DIR = "results"
NUM_RUNS = 2
NPARAMS = 3  # cx, cy, s
DISTANCE_METRIC = ["Wasserstein Distance", "Cramer-von Mises Distance", "Frechet Distance", "Hausdorff Distance"]
QUANTILES = [0.05, 0.01, 0.001]  # 5%, 1%, 0.1%
# Per-parameter summary columns produced by abc_posterior:
# [mean, median, std, 2.5% quantile, 97.5% quantile, squared error of the median]
N_SUMMARY_STATS = 6

# Only sub-directories of RUN_DIR are models (skips stray files such as .DS_Store).
# Sorted so the processing order, and hence the output of the legacy flat files, is deterministic.
models = sorted(
    entry for entry in os.listdir(RUN_DIR)
    if os.path.isdir(os.path.join(RUN_DIR, entry))
)

for model in models:
    # Load every run once per model rather than once per (metric, quantile) pair.
    # Keeps NUM_RUNS arrays in memory at a time.
    runs = [np.load(os.path.join(RUN_DIR, model, f"run{i+1}.npy")) for i in range(NUM_RUNS)]

    # We generate a result for each distance metric
    for metric in DISTANCE_METRIC:
        metric_path = os.path.join(SAVE_DIR, metric)  # e.g. results/"Wasserstein Distance"
        model_path = os.path.join(metric_path, model)
        # makedirs also creates SAVE_DIR on a fresh checkout, where os.mkdir would fail.
        os.makedirs(model_path, exist_ok=True)

        # Separate results are needed for each quantile for threshold analysis.
        for quantile in QUANTILES:
            posterior = np.zeros((NUM_RUNS, NPARAMS, N_SUMMARY_STATS))
            for i, run_data in enumerate(runs):
                posterior[i] = abc_posterior(NPARAMS, run_data, quantile, metric)

            filename = f"{quantile}posterior.npy"
            # Per-model copy: results of different models no longer overwrite each other.
            np.save(os.path.join(model_path, filename), posterior)
            # Legacy flat path read by later cells. With several models it holds
            # only the last model processed.
            np.save(os.path.join(metric_path, filename), posterior)

In [20]:
# Inspect one saved result. Shape is (NUM_RUNS, NPARAMS, 6): for each run and
# parameter (cx, cy, s): [mean, median, std, 2.5% quantile, 97.5% quantile, squared error].
np.load(os.path.join(SAVE_DIR, "Frechet Distance", "0.05posterior.npy"))

array([[[4.97319459e-01, 4.98929079e-01, 6.86247084e-02, 3.74124847e-01,
         6.39818086e-01, 1.14687159e-06],
        [5.02849360e-01, 4.93020679e-01, 6.07210709e-02, 3.98082143e-01,
         6.19896817e-01, 4.87109221e-05],
        [4.62062442e-04, 4.23059674e-04, 2.58715112e-04, 2.63754158e-05,
         9.37241015e-04, 1.39173520e-07]],

       [[4.90513752e-01, 4.87802399e-01, 6.18235671e-02, 3.86961759e-01,
         5.86677531e-01, 1.48781467e-04],
        [4.66328091e-01, 4.62379362e-01, 5.28740220e-02, 3.85116133e-01,
         5.64957578e-01, 1.41531240e-03],
        [5.34078153e-04, 5.72427732e-04, 2.72185561e-04, 5.69897229e-05,
         9.69241892e-04, 2.72930735e-07]]])