In [19]:
import ast
import json
import logging
import os
import subprocess
import time
import warnings

import arviz as az
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import pymc as pm
import scipy.stats as sp

# Suppress every RuntimeWarning and UserWarning for the whole process.
for _ignored_category in (RuntimeWarning, UserWarning):
    warnings.simplefilter("ignore", category=_ignored_category)

# Only let ERROR (and above) records through the "pymc" logger.
logger = logging.getLogger("pymc")
logger.setLevel(logging.ERROR)


# Function to get the current git tag
def get_git_tag():
    """Return the output of ``git describe --tags`` for the working directory.

    Returns:
        The decoded, stripped output of the command, or the string
        ``"No tag found"`` when git exits with an error or the ``git``
        executable cannot be started at all.
    """
    try:
        output = subprocess.check_output(
            ["git", "describe", "--tags"], stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but returned a non-zero exit status.
        # OSError (e.g. FileNotFoundError): git could not be launched.
        return "No tag found"
    return output.strip().decode()


def generate_iid_samples(means=(1, 2), num_samples=2000, weights=(1, 1), std_of_modes=1, rng=None):
    """Draw IID samples directly from a mixture of normal distributions.

    Args:
        means: Sequence with the mean of each mode.
        num_samples: Number of samples to draw.
        weights: Unnormalized mixture weight of each mode; normalized
            internally so they sum to one.
        std_of_modes: Either one standard deviation per mode, or a single
            scalar that is applied to every mode.
        rng: ``numpy.random.Generator`` to draw from. A fresh, unseeded
            generator is created when omitted.

    Returns:
        1-D ``numpy.ndarray`` of length ``num_samples``.

    Raises:
        ValueError: If ``std_of_modes`` or ``weights`` does not have one
            entry per mode.
    """
    if rng is None:
        rng = np.random.default_rng()

    means = np.asarray(means)
    num_modes = len(means)

    # A scalar std (including the default) is broadcast to all modes.
    if np.ndim(std_of_modes) == 0:
        stds = np.full(num_modes, std_of_modes)
    else:
        stds = np.asarray(std_of_modes)
    if len(stds) != num_modes:
        raise ValueError("Length of std_of_modes must match the number of modes (means).")

    weights = np.asarray(weights)
    if len(weights) != num_modes:
        raise ValueError("Length of weights must match the number of modes (means).")
    weights = weights / np.sum(weights)

    # First pick the mode of every sample, then draw from that mode's normal.
    chosen_modes = rng.choice(num_modes, size=num_samples, p=weights)
    return rng.normal(loc=means[chosen_modes], scale=stds[chosen_modes], size=num_samples)


def get_initvals(init_scheme, means, num_chains, rng=None):
    """Build one initial-value dict per chain for the chosen scheme.

    Args:
        init_scheme: One of ``"half_per_mode"``, ``"all_in_middle"``,
            ``"random"``, ``"all_near_first_mode"`` or
            ``"all_near_second_mode"``.
        means: Means of the two modes; only ``means[0]`` and ``means[1]``
            are used as starting points.
        num_chains: Number of chains, i.e. length of the returned list.
        rng: ``numpy.random.Generator`` used by the ``"random"`` scheme.
            A fresh, unseeded generator is created when omitted.

    Returns:
        List of ``num_chains`` dicts mapping variable names to start values.

    Raises:
        ValueError: If ``init_scheme`` is not one of the known schemes.
    """
    # NOTE(review): every scheme except "half_per_mode" also sets a "w" entry.
    # The mixture model in this file uses fixed weights and defines no "w"
    # variable, so confirm these entries are still accepted by the sampler.

    if init_scheme == "half_per_mode":
        # Half the chains start at the first mode, the rest at the second.
        # With an odd chain count the extra chain goes to the second mode so
        # that exactly num_chains entries are returned.
        at_first_mode = num_chains // 2
        at_second_mode = num_chains - at_first_mode
        return (
            [{"mixed_normal": means[0]} for _ in range(at_first_mode)]
            + [{"mixed_normal": means[1]} for _ in range(at_second_mode)]
        )

    if init_scheme == "all_in_middle":
        # All chains start in the middle between the two modes.
        middle_point = sum(means) / 2
        return [
            {"w": np.array([0.5, 0.5]), "mixed_normal": middle_point}
            for _ in range(num_chains)
        ]

    if init_scheme == "random":
        # Each chain starts at a random point between the modes.
        if rng is None:
            rng = np.random.default_rng()
        return [
            {"w": rng.dirichlet(np.ones(2)), "mixed_normal": rng.uniform(means[0], means[1])}
            for _ in range(num_chains)
        ]

    if init_scheme == "all_near_first_mode":
        return [
            {"w": np.array([0.9, 0.1]), "mixed_normal": means[0]}
            for _ in range(num_chains)
        ]

    if init_scheme == "all_near_second_mode":
        return [
            {"w": np.array([0.1, 0.9]), "mixed_normal": means[1]}
            for _ in range(num_chains)
        ]

    raise ValueError(f"Unknown initialization scheme: {init_scheme}")



class BimodalPosteriorExample:
    """PyMC model whose single variable ``mixed_normal`` is a normal mixture.

    The mixture weights are fixed constants (normalized from ``weights``);
    they are not sampled.
    """

    def __init__(self, means, std_of_modes=1, weights=(1, 1)):
        """Build the model.

        Args:
            means: Mean of each mixture component.
            std_of_modes: One standard deviation per component, or a single
                scalar applied to every component.
            weights: Unnormalized weight of each component.

        Raises:
            ValueError: If ``std_of_modes`` does not have one entry per mean.
        """
        # A scalar std (including the default) is broadcast to all modes.
        if np.ndim(std_of_modes) == 0:
            std_of_modes = [std_of_modes] * len(means)
        if len(means) != len(std_of_modes):
            raise ValueError("Each mode must have a corresponding standard deviation.")
        self.model = self._bimodal_posterior(weights, means, std_of_modes)

    def _bimodal_posterior(self, weights, means, std_of_modes):
        """Create and return the ``pm.Model`` containing the mixture."""
        with pm.Model() as model:
            normalized_weights = np.array(weights) / np.sum(weights)

            # One unnamed normal per mode, each with its own std.
            components = [
                pm.Normal.dist(mode_mean, mode_std)
                for mode_mean, mode_std in zip(means, std_of_modes)
            ]

            pm.Mixture("mixed_normal", w=normalized_weights, comp_dists=components)

        return model

    def run_sampling(self,
                     sampler_name,
                     num_samples=2000,
                     tune=1000,
                     num_chains=2,
                     initvals=None,
                     use_init_scheme=False,
                     run_random_seed=None
                     ):
        """Sample the model with the named step method.

        Args:
            sampler_name: ``"Metro"`` (Metropolis), ``"HMC"`` (NUTS) or
                ``"DEMetro"`` (DEMetropolis).
            num_samples: Draws per chain.
            tune: Tuning iterations per chain.
            num_chains: Number of chains.
            initvals: Initial values passed to ``pm.sample``; ignored unless
                ``use_init_scheme`` is true.
            use_init_scheme: Whether ``initvals`` should be applied.
            run_random_seed: Forwarded to ``pm.sample`` as ``random_seed``.

        Returns:
            The result of ``pm.sample`` (requested as InferenceData).

        Raises:
            ValueError: If ``sampler_name`` is not recognized.
        """
        step_classes = {
            "Metro": pm.Metropolis,
            "HMC": pm.NUTS,
            "DEMetro": pm.DEMetropolis,
        }
        if sampler_name not in step_classes:
            raise ValueError(f"Unknown sampler: {sampler_name}")

        with self.model:
            # Step methods must be instantiated inside the model context.
            sampler = step_classes[sampler_name]()

            trace = pm.sample(
                num_samples,
                tune=tune,
                step=sampler,
                # None lets PyMC choose its default starting points.
                initvals=initvals if use_init_scheme else None,
                chains=num_chains,
                return_inferencedata=True,
                progressbar=False,
                random_seed=run_random_seed,
            )

        return trace
    

def _finalize_and_save_plot(fig, ax, xlabel, ylabel, title, filename):
    """Apply labels, legend and grid to *ax*, save *fig* to *filename*, close it."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title="Sampler")
    ax.grid(True)
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)


def run_experiment(
    config_descr,
    runs,
    varying_attribute,
    varying_values,
    mode_means,
    std_of_modes,
    weights,
    num_samples,
    num_chains,
    use_init_scheme=False,
    init_scheme=None,
    base_random_seed=None,
):
    """Run one experiment configuration and write all of its artifacts to disk.

    For every run and every value of the varying attribute, the mixture model
    is sampled with the Metropolis, NUTS and DEMetropolis step methods. For
    each sampling the Wasserstein distance to precomputed IID samples, R-hat
    and ESS are recorded. Per-run CSVs, traces and plots are written, followed
    by results averaged over all runs.

    The output goes below the module-level ``experiment_root_folder``, which
    must be defined before this function is called. Folders are created with
    ``os.makedirs`` and must not exist yet.

    Args:
        config_descr: Name of the configuration; used in folder and file
            names, plot titles and log output.
        runs: Number of independent repetitions.
        varying_attribute: One of ``"mode_means"``, ``"std_of_modes"``,
            ``"weights"``, ``"num_samples"``, ``"num_chains"`` or
            ``"init_scheme"``.
        varying_values: Values taken by the varying attribute. They are used
            as dict keys, so they must be hashable.
        mode_means, std_of_modes, weights: Mixture parameters. The one that
            varies is ignored (callers pass a placeholder).
        num_samples: Total number of draws over all chains; also the number
            of IID reference samples.
        num_chains: Number of chains.
        use_init_scheme: Whether to pass explicit initial values.
        init_scheme: Name of the initialization scheme (see ``get_initvals``).
        base_random_seed: Seed for the generator that drives the IID samples,
            the chain seeds and random initial values.

    Raises:
        ValueError: If ``varying_attribute`` is not supported.
    """
    print(f"\n===== Config {config_descr} started! =====\n")

    # Initialize random number generator
    rng = np.random.default_rng(base_random_seed)

    sampler_names = ["Metro", "HMC", "DEMetro"]

    # Fixed color for each sampler
    sampler_colors = {
        "Metro": "blue",
        "HMC": "red",
        "DEMetro": "green"
    }

    # (results column, axis label, per-run file stem, global title, global file name)
    metrics = [
        (
            "wasserstein_distance",
            "Wasserstein Distance",
            "Wasserstein",
            f"Averaged Wasserstein Distance ({runs} Runs, config = {config_descr})",
            "Wasserstein_global_plot.pdf",
        ),
        (
            "r_hat",
            "R-hat",
            "R-hat",
            f"Averaged R-hat Values ({runs} Runs, config = {config_descr})",
            "Rhat_global_plot.pdf",
        ),
        (
            "ess",
            "ESS",
            "ESS",
            f"Averaged ESS ({runs} Runs,  config = {config_descr})",
            "ESS_global_plot.pdf",
        ),
    ]

    # === Precompute IID samples ===
    # Maps each posterior-changing attribute to the generate_iid_samples keyword it overrides.
    iid_keyword_for_attribute = {
        "mode_means": "means",
        "std_of_modes": "std_of_modes",
        "weights": "weights",
        "num_samples": "num_samples",
    }
    base_iid_kwargs = {
        "num_samples": num_samples,
        "means": mode_means,
        "weights": weights,
        "std_of_modes": std_of_modes,
    }
    posterior_is_fixed = varying_attribute in ("init_scheme", "num_chains")

    if posterior_is_fixed:
        # Posterior is fixed; generate samples once
        iid_samples = generate_iid_samples(rng=rng, **base_iid_kwargs)
    elif varying_attribute in iid_keyword_for_attribute:
        # Posterior (or sample count) changes; one reference set per value
        varied_keyword = iid_keyword_for_attribute[varying_attribute]
        iid_samples_dict = {}
        for value in varying_values:
            iid_kwargs = dict(base_iid_kwargs)
            iid_kwargs[varied_keyword] = value
            iid_samples_dict[value] = generate_iid_samples(rng=rng, **iid_kwargs)
    else:
        raise ValueError(f"Unknown varying attribute: {varying_attribute}")

    # === Experiment Setup ===
    if varying_attribute in ("num_samples", "num_chains"):
        samples_per_chain = "varies"  # Recomputed for every value below
    else:
        samples_per_chain = num_samples // num_chains  # Fixed value

    # Create configuration folder inside the experiment root
    config_folder = os.path.join(experiment_root_folder, f"{config_descr}_with_{runs}_runs")
    os.makedirs(config_folder)

    experiment_metadata = {
        "config_descr": config_descr,
        "runs": runs,
        "varying_attribute": varying_attribute,
        "varying_values": varying_values,
        "mode_means": mode_means,
        "std_of_modes": std_of_modes,
        "num_samples": num_samples,
        "num_chains": num_chains,
        "samples_per_chain": samples_per_chain,
        "weights": weights,
        "init_scheme": init_scheme if use_init_scheme else "None",
        "chain_initialization": init_scheme,
        "base_random_seed": base_random_seed,
        "git_tag": get_git_tag(),
    }

    # Save metadata
    metadata_filename = os.path.join(config_folder, f"metadata_config_{config_descr}.json")
    with open(metadata_filename, "w") as f:
        json.dump(experiment_metadata, f, indent=4)

    # x-axis label shared by all per-run and global plots
    if varying_attribute == "mode_means":
        attribute_label = "Mode Distance"
    else:
        attribute_label = varying_attribute.replace("_", " ").title()

    # When the chain count varies, draw enough seeds for the largest count and
    # slice per value, so smaller counts reuse the same leading seeds.
    if varying_attribute == "num_chains":
        seeds_per_run = max(varying_values)
    else:
        seeds_per_run = num_chains

    # === Run the Experiment ===
    for run_id in range(1, runs + 1):

        chain_seeds = rng.integers(0, 1_000_000, size=seeds_per_run)

        print(f"\n===== Running {config_descr} - Run {run_id} =====\n")

        run_folder = os.path.join(config_folder, f"run_{run_id}")
        results_folder = os.path.join(run_folder, "results")
        traces_folder = os.path.join(run_folder, "traces")
        plots_folder = os.path.join(run_folder, "plots")

        os.makedirs(run_folder)
        os.makedirs(results_folder)
        os.makedirs(traces_folder)
        os.makedirs(plots_folder)

        results = []

        for value_index, value in enumerate(varying_values):

            # Adjust the number of seeds if the chain count varies
            if varying_attribute == "num_chains":
                num_chains = value
                samples_per_chain = num_samples // num_chains
                chain_seeds_used = chain_seeds[:num_chains]
            else:
                chain_seeds_used = chain_seeds

            # Substitute the current value for the attribute that varies
            means = value if varying_attribute == "mode_means" else mode_means
            stds = value if varying_attribute == "std_of_modes" else std_of_modes
            mixture_weights = value if varying_attribute == "weights" else weights
            if varying_attribute == "init_scheme":
                init_scheme = value
            elif varying_attribute == "num_samples":
                num_samples = value
                samples_per_chain = num_samples // num_chains

            model = BimodalPosteriorExample(means=means, std_of_modes=stds, weights=mixture_weights)

            # Generate initialization values
            if use_init_scheme:
                initvals = get_initvals(init_scheme, means, num_chains, rng)
            else:
                initvals = None

            # Get IID samples for the current varying value
            if not posterior_is_fixed:
                iid_samples = iid_samples_dict[value]

            # Run sampling for all samplers
            for sampler_name in sampler_names:

                print(f"Running for sampler {sampler_name} with {varying_attribute} = {value}")

                trace = model.run_sampling(
                    sampler_name,
                    num_samples=samples_per_chain,
                    num_chains=num_chains,
                    initvals=initvals,
                    use_init_scheme=use_init_scheme,
                    run_random_seed=chain_seeds_used.tolist(),
                )

                # Save trace to NetCDF file. The file name carries the position
                # of the value in varying_values so that traces of different
                # values do not overwrite each other.
                trace_filename = os.path.join(
                    traces_folder, f"{sampler_name}_trace_{value_index}.nc"
                )
                az.to_netcdf(trace, trace_filename)

                # Compute Wasserstein Distance
                ws_distance = sp.wasserstein_distance(
                    trace.posterior["mixed_normal"].values.flatten(), iid_samples
                )

                # Compute R-hat
                r_hat = az.rhat(trace)["mixed_normal"].item()

                # Compute ESS
                ess = az.ess(trace)["mixed_normal"].item()

                results.append({
                    varying_attribute: value,
                    "sampler": sampler_name,
                    "wasserstein_distance": ws_distance,
                    "r_hat": r_hat,
                    "ess": ess,
                    "chain_seeds": chain_seeds_used.tolist(),
                })

        # Convert results to DataFrame
        df_results = pd.DataFrame(results)

        # Convert tuples to strings for plotting if 'std_of_modes' or 'weights' varies
        if varying_attribute in ["std_of_modes", "weights"]:
            df_results[varying_attribute] = df_results[varying_attribute].apply(lambda x: str(x))

        # Compute mode distance and sort by it if mode_means is varying
        if varying_attribute == "mode_means":
            df_results["mode_distance"] = df_results[varying_attribute].apply(lambda x: abs(x[1] - x[0]))
            df_results = df_results.sort_values("mode_distance", ascending=True)
            varying_attribute_for_plot = "mode_distance"  # Use this for plotting
        else:
            df_results = df_results.sort_values(varying_attribute, ascending=True)
            varying_attribute_for_plot = varying_attribute

        # One figure per metric, shared by all samplers
        figures = {column: plt.subplots(figsize=(10, 6)) for column, *_ in metrics}

        for sampler in df_results["sampler"].unique():

            df_sampler = df_results[df_results["sampler"] == sampler]

            # Save CSV files (read back below for the global results)
            csv_filename = os.path.join(results_folder, f"{sampler}_results.csv")
            df_sampler.to_csv(csv_filename, index=False)

            color = sampler_colors.get(sampler, "black")
            for column, *_ in metrics:
                _, ax = figures[column]
                ax.plot(df_sampler[varying_attribute_for_plot], df_sampler[column],
                        marker="o", linestyle="-", label=sampler,
                        color=color)

        # Finalize and save the per-run plots
        for column, label, file_stem, _, _ in metrics:
            fig, ax = figures[column]
            _finalize_and_save_plot(
                fig,
                ax,
                attribute_label,
                label,
                f"{label} for Samplers (config =_{config_descr})",
                os.path.join(plots_folder, f"{file_stem}_run_{run_id}.pdf"),
            )

    print("\n===== All Runs Completed Successfully! =====\n")

    # ===== GLOBAL RESULTS FOLDER =====
    global_folder = os.path.join(config_folder, "global_results")
    global_results_folder = os.path.join(global_folder, "results")
    global_plots_folder = os.path.join(global_folder, "plots")

    os.makedirs(global_folder)
    os.makedirs(global_results_folder)
    os.makedirs(global_plots_folder)

    # Collect all results from all runs
    df_all_runs = []

    for run_id in range(1, runs + 1):
        run_folder = os.path.join(config_folder, f"run_{run_id}")
        results_folder = os.path.join(run_folder, "results")

        for sampler in sampler_names:
            csv_filename = os.path.join(results_folder, f"{sampler}_results.csv")
            df_run = pd.read_csv(csv_filename)
            df_run["run_id"] = run_id
            df_run["sampler"] = sampler
            df_all_runs.append(df_run)

    # Combine all results into a single data frame
    df_all_runs = pd.concat(df_all_runs, ignore_index=True)

    if varying_attribute == "mode_means":
        # The CSV round trip turned the tuples into strings such as "(0, 8)";
        # parse them as literals instead of evaluating arbitrary code.
        def _mode_distance(text):
            parsed = ast.literal_eval(text)
            return abs(parsed[1] - parsed[0])

        df_all_runs["mode_distance"] = df_all_runs[varying_attribute].apply(_mode_distance)
        df_all_runs = df_all_runs.sort_values("mode_distance", ascending=True)
        varying_attribute_for_global_plot = "mode_distance"
    else:
        df_all_runs = df_all_runs.sort_values(varying_attribute, ascending=True)
        varying_attribute_for_global_plot = varying_attribute

    # Initialize global plots
    figures = {column: plt.subplots(figsize=(10, 6)) for column, *_ in metrics}

    for sampler in sampler_names:
        df_sampler = df_all_runs[df_all_runs["sampler"] == sampler]
        color = sampler_colors.get(sampler, "black")

        # Per metric: mean and standard deviation across runs for each value
        stats = {}
        for column, *_ in metrics:
            # Pivot table: rows = varying attribute, columns = run_id
            pivot = df_sampler.pivot_table(
                index=varying_attribute_for_global_plot, columns="run_id", values=column
            )
            mean, std = pivot.mean(axis=1), pivot.std(axis=1)
            stats[column] = (mean, std)

            # Plot with error bars
            _, ax = figures[column]
            ax.errorbar(mean.index, mean, yerr=std, fmt="o-", label=sampler, color=color, capsize=5)

        ws_mean, ws_std = stats["wasserstein_distance"]
        rhat_mean, rhat_std = stats["r_hat"]
        ess_mean, ess_std = stats["ess"]

        # Save global averages.
        # NOTE: when mode_means varies, the first column keeps the name
        # "mode_means" but holds the mode distances used as the pivot index.
        df_global_avg = pd.DataFrame({
            varying_attribute: ws_mean.index,
            "global_avg_ws": ws_mean.values,
            "global_avg_ws_std": ws_std.values,
            "global_avg_rhat": rhat_mean.values,
            "global_avg_rhat_std": rhat_std.values,
            "global_avg_ess": ess_mean.values,
            "global_avg_ess_std": ess_std.values,
        })

        sampler_csv_filename = os.path.join(global_results_folder, f"Global_results_{sampler}.csv")
        df_global_avg.to_csv(sampler_csv_filename, index=False)

    # Finalize and save the global plots
    for column, label, _, global_title, global_filename in metrics:
        fig, ax = figures[column]
        _finalize_and_save_plot(
            fig,
            ax,
            attribute_label,
            f"Average {label}",
            global_title,
            os.path.join(global_plots_folder, global_filename),
        )

    print(f"\n===== Config {config_descr} Completed Successfully! =====\n")

In [20]:

# Attributes that every config must define as fixed values, except for the
# one named by its "varying_attribute".
REQUIRED_ATTRIBUTES = {
    "mode_means",
    "std_of_modes",
    "weights",
    "num_samples",
    "num_chains",
    "use_init_scheme",
}


def validate_config(config):
    """Check that a config defines exactly one varying attribute and fixes the rest.

    Args:
        config: Dict describing one experiment configuration.

    Returns:
        None. The function only raises on invalid input.

    Raises:
        ValueError: If ``varying_attribute`` or ``varying_values`` is missing,
            ``varying_values`` is empty, the varying attribute is not
            supported, ``init_scheme`` and ``use_init_scheme`` disagree, the
            varying attribute is also given a fixed value, or a required
            fixed attribute is missing.
    """
    # Fall back to a placeholder so a missing name never masks the real
    # problem with a KeyError.
    name = config.get("config_descr", "Unnamed")

    # Ensure "varying_attribute" is specified
    if "varying_attribute" not in config:
        raise ValueError(f"Config '{name}' is missing 'varying_attribute'.")

    varying_attr = config["varying_attribute"]

    # Ensure "varying_values" is present and non-empty
    if "varying_values" not in config:
        raise ValueError(f"Config '{name}' is missing 'varying_values' for '{varying_attr}'.")
    if len(config["varying_values"]) == 0:
        raise ValueError(f"Config '{name}' has empty 'varying_values' for '{varying_attr}'.")

    # "init_scheme" may vary although it is optional as a fixed attribute;
    # "use_init_scheme" must be fixed because the experiment runner has no
    # branch for varying it.
    allowed_varying = (REQUIRED_ATTRIBUTES - {"use_init_scheme"}) | {"init_scheme"}
    if varying_attr not in allowed_varying:
        raise ValueError(f"Config '{name}' has an invalid 'varying_attribute': '{varying_attr}'.")

    # Validate that `init_scheme` is present only if `use_init_scheme` is True
    if "use_init_scheme" in config:
        if config["use_init_scheme"] and "init_scheme" not in config and varying_attr != "init_scheme":
            raise ValueError(f"Config '{name}' is missing 'init_scheme' but 'use_init_scheme' is set to True.")
        elif not config["use_init_scheme"] and "init_scheme" in config:
            raise ValueError(f"Config '{name}' defines 'init_scheme' but 'use_init_scheme' is False.")

    # Check that all other required attributes are fixed (not missing)
    for attr in REQUIRED_ATTRIBUTES:
        if attr == varying_attr:
            # The varying attribute should not have a fixed value
            if attr in config:
                raise ValueError(f"Config '{name}' incorrectly defines '{attr}' as both fixed and varying.")
        elif attr not in config:
            # All other attributes must be explicitly defined
            raise ValueError(f"Config '{name}' is missing required fixed attribute '{attr}'.")




# Single small config: varies mode_means over two values with 2 runs,
# 100 samples and 4 chains, without an explicit initialization scheme.
Rand_test = [
    {
        "config_descr": "Random_test",
        "runs": 2,
        "varying_attribute": "mode_means",
        "varying_values": [(3, 8), (4, 20)],
        "std_of_modes": (1,1),
        "weights": (1,1),
        "num_samples": 100,
        "num_chains": 4,
        "use_init_scheme": False,
        "base_random_seed": 42
    }   
]

# One small config (2 runs, seed 42) per varying attribute:
# weights, mode_means, std_of_modes, num_samples, num_chains and init_scheme.
Testcases_all_attr = [
    # Varies the mixture weights
    {
        "config_descr": "Weights_test",
        "runs": 2,
        "varying_attribute": "weights",    
        "varying_values": [(1,1), (1,2), (5,1)],
        "mode_means": (3,14),
        "num_samples": 100,
        "num_chains": 4,
        "std_of_modes": (1,1),
        "use_init_scheme": False,
        "base_random_seed": 42
    },
    # Varies the mode means (i.e. the distance between the modes)
    {
        "config_descr": "Mode_Means_test",
        "runs": 2,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,3), (0,10), (0, 20)],
        "std_of_modes": (1,1),
        "weights": (1,1),
        "num_samples": 100,
        "num_chains": 4,
        "use_init_scheme": False,
        "base_random_seed": 42
    },

    # Varies the per-mode standard deviations
    {
        "config_descr": "Std_of_Modes_test",
        "runs": 2,
        "varying_attribute": "std_of_modes",    
        "varying_values": [(1,1),(1,4),(2,1)],
        "mode_means": (3,-3),
        "weights": (1,1),
        "num_samples": 100,
        "num_chains": 4,
        "use_init_scheme": False,
        "base_random_seed": 42
    },

    # Varies the total number of samples
    {
        "config_descr": "Samples_test",
        "runs": 2,
        "varying_attribute": "num_samples",    
        "varying_values": [100, 200, 300],
        "mode_means": (3,-3),
        "std_of_modes": (1,1),
        "weights": (1,1),
        "num_chains": 4,
        "use_init_scheme": False,
        "base_random_seed": 42
    },
    
    # Varies the number of chains
    {
        "config_descr": "Chains_test",
        "runs": 2,
        "varying_attribute": "num_chains",    
        "varying_values": [4,6,8],
        "mode_means": (3,-3),
        "std_of_modes": (1,1),
        "weights": (1,1),
        "num_samples": 100,
        "use_init_scheme": False,
        "base_random_seed": 42
    },

    # Varies the chain initialization scheme (requires use_init_scheme=True)
    {
        "config_descr": "Init_Scheme_test",
        "runs": 2,
        "varying_attribute": "init_scheme",    
        "varying_values": ["half_per_mode", "all_in_middle"],
        "mode_means": (3,-3),
        "std_of_modes": (1,1),
        "weights": (1,1),
        "num_samples": 100,
        "num_chains": 4,
        "use_init_scheme": True,
        "base_random_seed": 42
    }
]


# Single-run config with std_of_modes as the varying attribute; only one
# value, (1,1), is listed.
Varying_std_of_modes = [

    {
        "config_descr": "Varying_std_of_modes",
        "runs": 1,
        "varying_attribute": "std_of_modes",
        "varying_values": [(1,1)],
        "mode_means": (3,-3),
        "weights": (1,1),
        "num_samples": 100,
        "num_chains": 4,
        "use_init_scheme": False,
        "base_random_seed": 42
    }
]

# Five configs that all vary mode_means over the same four values
# (10 runs, 40000 samples, 4 chains, seed 42) and differ only in the
# mixture weights and the chain initialization.
DEMetro_vs_others = [
    # Unequal weights (1,8) with half the chains started at each mode
    {
        "config_descr": "Optimal_DEMetro",
        "runs":10,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 40000,
        "num_chains": 4,
        "weights": (1,8),
        "std_of_modes": (1,1),
        "use_init_scheme": True,
        "init_scheme": "half_per_mode",
        "base_random_seed": 42
    },

    # Equal weights, no explicit initialization
    {
        "config_descr": "Fair_case",
        "runs": 10,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 40000,
        "num_chains": 4,
        "weights": (1,1),
        "std_of_modes": (1,1),
        "use_init_scheme": False,
        "base_random_seed": 42
    },

    # Unequal weights (5,1), no explicit initialization
    {   
        "config_descr": "Only_unequal_weights",
        "runs": 10,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 40000,
        "num_chains": 4,
        "weights": (5,1),
        "std_of_modes": (1,1),
        "use_init_scheme": False,
        "base_random_seed": 42
    },

    # Equal weights with half the chains started at each mode
    {
        "config_descr": "Only_init_scheme",
        "runs": 10,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 40000,
        "num_chains": 4,
        "weights": (1,1),
        "std_of_modes": (1,1),
        "use_init_scheme": True,
        "init_scheme": "half_per_mode",
        "base_random_seed": 42
    },

    # Equal weights with all chains started between the modes
    {
        "config_descr": "Init_all_in_middle",
        "runs": 10,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 40000,
        "num_chains": 4,
        "weights": (1,1),
        "std_of_modes": (1,1),
        "use_init_scheme": True,
        "init_scheme": "all_in_middle",
        "base_random_seed": 42
    }

]

# Reduced variant of the "Optimal_DEMetro" and "Fair_case" configs:
# 5 runs and 4000 samples instead of 10 runs and 40000 samples.
DEMetro_vs_others_short = [
    {
        "config_descr": "Optimal_DEMetro",
        "runs": 5,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 4000,
        "num_chains": 4,
        "weights": (1,8),
        "std_of_modes": (1,1),
        "use_init_scheme": True,
        "init_scheme": "half_per_mode",
        "base_random_seed": 42
    },

    {
        "config_descr": "Fair_case",
        "runs": 5,
        "varying_attribute": "mode_means",    
        "varying_values": [(0,8),(0,10),(0,15),(0,20)],
        "num_samples": 4000,
        "num_chains": 4,
        "weights": (1,1),
        "std_of_modes": (1,1),
        "use_init_scheme": False,
        "base_random_seed": 42
    },
]

# Single config that varies the mixture weights over two values while
# starting half the chains at each mode (2 runs, 1000 samples, 4 chains).
Test_fixed_weights = [
    {
        "config_descr": "Fixed_weights",
        "runs": 2,
        "varying_attribute": "weights",
        "varying_values": [(1,1), (3,1)],
        "mode_means": (3,-3),
        "std_of_modes": (1,1),
        "num_samples": 1000,
        "num_chains": 4,
        "use_init_scheme": True,
        "init_scheme": "half_per_mode",
        "base_random_seed": 42
    }
]


In [22]:
# Pick the list of configs to execute and the name used for its output folder
experiment = Rand_test
experiment_name = "random_test_2"

# Root directory that receives the output of every config of this experiment
experiment_root_folder = f"experiment_{experiment_name}"
os.makedirs(experiment_root_folder)

# Validate everything up front so no sampling starts with a broken config
for config in experiment:
    validate_config(config)

print("All configurations are valid. Starting experiments...")

for config in experiment:
    varied = config["varying_attribute"]

    arguments = {
        "config_descr": config["config_descr"],
        "runs": config["runs"],
        "varying_attribute": varied,
        "varying_values": config["varying_values"],
        "base_random_seed": config["base_random_seed"],
    }

    # The attribute that varies has no fixed value in the config; pass the
    # placeholder "varies" for it instead.
    for key in ("mode_means", "std_of_modes", "weights", "use_init_scheme", "num_samples", "num_chains"):
        arguments[key] = "varies" if varied == key else config[key]

    # init_scheme is optional, so fall back to None when it is absent
    if varied == "init_scheme":
        arguments["init_scheme"] = "varies"
    else:
        arguments["init_scheme"] = config.get("init_scheme", None)

    run_experiment(**arguments)

print("All experiments completed successfully!")


All configurations are valid. Starting experiments...

===== Config Random_test started! =====


===== Running Random_test - Run 1 =====

Running for sampler Metro with mode_means = (3, 8)
Running for sampler HMC with mode_means = (3, 8)
Running for sampler DEMetro with mode_means = (3, 8)
Running for sampler Metro with mode_means = (4, 20)
Running for sampler HMC with mode_means = (4, 20)
Running for sampler DEMetro with mode_means = (4, 20)

===== Running Random_test - Run 2 =====

Running for sampler Metro with mode_means = (3, 8)
Running for sampler HMC with mode_means = (3, 8)
Running for sampler DEMetro with mode_means = (3, 8)
Running for sampler Metro with mode_means = (4, 20)
Running for sampler HMC with mode_means = (4, 20)
Running for sampler DEMetro with mode_means = (4, 20)

===== All Runs Completed Successfully! =====


===== Config Random_test Completed Successfully! =====

All experiments completed successfully!
