In [112]:
import pandas as pd
import numpy as np
import pymc as pm
import arviz as az
import scipy.stats as sp
import matplotlib.pyplot as plt
import json
import logging
import warnings
import os
import subprocess
import time

# Only let the "pymc" logger emit records at ERROR level or above.
logger = logging.getLogger("pymc")
logger.setLevel(logging.ERROR)

# Suppress every RuntimeWarning and UserWarning from here on.
warnings.simplefilter("ignore", RuntimeWarning)
warnings.simplefilter("ignore", UserWarning)


# Function to get the current git tag
def get_git_tag():
    """Return the output of ``git describe --tags`` for the working directory.

    Returns:
        The stripped, decoded output of ``git describe --tags``, or the
        string ``"No tag found"`` when the command exits with a non-zero
        status or cannot be started at all (for example because no ``git``
        executable is available).
    """
    try:
        output = subprocess.check_output(
            ["git", "describe", "--tags"], stderr=subprocess.DEVNULL
        )
    except (subprocess.CalledProcessError, OSError):
        # CalledProcessError: git ran but reported failure.
        # OSError (e.g. FileNotFoundError): git could not be executed; the tag
        # is only metadata, so fall back instead of aborting the experiment.
        return "No tag found"
    return output.strip().decode()


def generate_iid_samples(mode_means=(1, 2), num_samples=2000, weights=(1, 1), std_of_modes=1, rng=None):
    """Draw independent samples from a mixture of normal distributions.

    Args:
        mode_means: Sequence with the mean of each mixture component.
        num_samples: Number of samples to draw.
        weights: Sequence of non-normalized component weights, one per mode.
            They are normalized to sum to one before use.
        std_of_modes: Either a sequence with one standard deviation per mode,
            or a single scalar that is applied to every mode.
        rng: ``numpy.random.Generator`` to draw from. A fresh, unseeded
            generator is created when omitted.

    Returns:
        A one-dimensional NumPy array of length ``num_samples``.

    Raises:
        ValueError: If ``std_of_modes`` (when it is a sequence) or ``weights``
            does not have one entry per mode.
    """
    if rng is None:
        rng = np.random.default_rng()

    means = np.asarray(mode_means, dtype=float)
    stds = np.asarray(std_of_modes, dtype=float)

    if stds.ndim == 0:
        # A scalar standard deviation (including the default) is shared by
        # all modes.
        stds = np.full(means.shape, float(stds))
    elif len(stds) != len(means):
        raise ValueError("Length of std_of_modes must match the number of modes (means).")

    probabilities = np.asarray(weights, dtype=float)
    if len(probabilities) != len(means):
        raise ValueError("Length of weights must match the number of modes (means).")
    probabilities = probabilities / np.sum(probabilities)

    # Pick the component each sample comes from, then draw from that
    # component's normal distribution.
    chosen_modes = rng.choice(len(means), size=num_samples, p=probabilities)
    return rng.normal(loc=means[chosen_modes], scale=stds[chosen_modes], size=num_samples)


def get_initvals(init_scheme, means, num_chains, rng=None):
    """Build one dictionary of initial values per chain for a given scheme.

    Args:
        init_scheme: One of ``"half_per_mode"``, ``"all_in_middle"``,
            ``"random"``, ``"all_near_first_mode"`` or
            ``"all_near_second_mode"``.
        means: Sequence holding the means of the two modes.
        num_chains: Number of chains; the returned list has this length.
        rng: ``numpy.random.Generator`` used by the ``"random"`` scheme. A
            fresh, unseeded generator is created when omitted.

    Returns:
        A list of ``num_chains`` dictionaries. Every dictionary has a
        ``"mixed_normal"`` entry; all schemes except ``"half_per_mode"`` also
        provide a ``"w"`` entry.

    Raises:
        ValueError: If ``init_scheme`` is not one of the supported names.
    """
    if init_scheme == "half_per_mode":
        # With an odd number of chains the extra chain starts at the first
        # mode, so that exactly num_chains entries are returned.
        chains_at_second_mode = num_chains // 2
        chains_at_first_mode = num_chains - chains_at_second_mode
        return (
            [{"mixed_normal": means[0]} for _ in range(chains_at_first_mode)]
            + [{"mixed_normal": means[1]} for _ in range(chains_at_second_mode)]
        )

    if init_scheme == "all_in_middle":
        # All chains start half-way between the two modes.
        middle_point = sum(means) / 2
        return [{"w": np.array([0.5, 0.5]), "mixed_normal": middle_point} for _ in range(num_chains)]

    if init_scheme == "random":
        if rng is None:
            rng = np.random.default_rng()
        # Weights come from a flat Dirichlet; the start value is uniform
        # between the two modes.
        return [
            {"w": rng.dirichlet(np.ones(2)), "mixed_normal": rng.uniform(means[0], means[1])}
            for _ in range(num_chains)
        ]

    if init_scheme == "all_near_first_mode":
        return [{"w": np.array([0.9, 0.1]), "mixed_normal": means[0]} for _ in range(num_chains)]

    if init_scheme == "all_near_second_mode":
        return [{"w": np.array([0.1, 0.9]), "mixed_normal": means[1]} for _ in range(num_chains)]

    raise ValueError(f"Unknown initialization scheme: {init_scheme}")


class PosteriorExample:
    """Base class for different posterior types.

    Subclasses build a PyMC model and store it in ``self.model``;
    ``run_sampling`` then draws from that model with a chosen step method.
    """

    def __init__(self):
        self.model = None  # Placeholder for the PyMC model

    def _define_posterior(self):
        """Subclasses should implement this method to define the posterior."""
        raise NotImplementedError("Subclasses must implement _define_posterior()")

    def run_sampling(self, sampler_name, num_samples=2000, tune=1000, num_chains=2, initvals=None, init_scheme=None, run_random_seed=None):
        """Run MCMC sampling on ``self.model`` using the chosen sampler.

        Args:
            sampler_name: ``"Metro"`` (``pm.Metropolis``), ``"HMC"``
                (``pm.NUTS``) or ``"DEMetro"`` (``pm.DEMetropolis``).
            num_samples: Number of draws passed to ``pm.sample``.
            tune: Number of tuning steps passed to ``pm.sample``.
            num_chains: Number of chains passed to ``pm.sample``.
            initvals: Initial values; only forwarded to ``pm.sample`` when
                ``init_scheme`` is not ``None``.
            init_scheme: Switch controlling whether ``initvals`` is forwarded.
            run_random_seed: Random seed passed to ``pm.sample``.

        Returns:
            The object returned by ``pm.sample`` (called with
            ``return_inferencedata=True``).

        Raises:
            ValueError: If ``sampler_name`` is not one of the known names.
        """
        # Reject unknown sampler names before touching the model.
        if sampler_name not in ("Metro", "HMC", "DEMetro"):
            raise ValueError(f"Unknown sampler: {sampler_name}")

        sample_kwargs = {
            "tune": tune,
            "chains": num_chains,
            "return_inferencedata": True,
            "progressbar": False,
            "random_seed": run_random_seed,
        }
        if init_scheme is not None:
            sample_kwargs["initvals"] = initvals

        with self.model:
            # Step methods are created inside the model context.
            if sampler_name == "Metro":
                sampler = pm.Metropolis()
            elif sampler_name == "HMC":
                sampler = pm.NUTS()
            else:
                sampler = pm.DEMetropolis()

            trace = pm.sample(num_samples, step=sampler, **sample_kwargs)

        return trace


class BimodalPosterior(PosteriorExample):
    """Posterior given by a weighted mixture of normal distributions.

    The PyMC model contains a single variable named ``"mixed_normal"``.
    """

    def __init__(self, mode_means=(-2, 2), std_of_modes=(1, 1), weights=(1, 1)):
        """Validate the mode description and build the model.

        Args:
            mode_means: Sequence with the mean of each mode.
            std_of_modes: Sequence with the standard deviation of each mode.
            weights: Sequence of non-normalized weights, one per mode.

        Raises:
            ValueError: If ``std_of_modes`` or ``weights`` does not have one
                entry per mode.
        """
        if len(mode_means) != len(std_of_modes):
            raise ValueError("Each mode must have a corresponding standard deviation.")
        if len(weights) != len(mode_means):
            raise ValueError("Each mode must have a corresponding weight.")
        super().__init__()
        self.model = self._define_posterior(mode_means, std_of_modes, weights)

    def _define_posterior(self, mode_means, std_of_modes, weights):
        """Return a PyMC model holding the ``"mixed_normal"`` mixture variable."""
        # Normalize the weights so they sum to one.
        normalized_weights = np.array(weights) / np.sum(weights)

        with pm.Model() as model:
            # One unnamed normal distribution per mode.
            components = [pm.Normal.dist(mu, sigma) for mu, sigma in zip(mode_means, std_of_modes)]

            # The mixture registers itself with the enclosing model context.
            pm.Mixture("mixed_normal", w=pm.math.constant(normalized_weights), comp_dists=components)

        return model


class CauchyPosterior(PosteriorExample):
    """Posterior consisting of a single Cauchy variable named ``"cauchy"``."""

    def __init__(self, loc=0, scale=1):
        super().__init__()
        self.model = self._define_posterior(loc, scale)

    def _define_posterior(self, loc, scale):
        """Return a PyMC model with ``pm.Cauchy("cauchy", alpha=loc, beta=scale)``."""
        cauchy_model = pm.Model()
        with cauchy_model:
            # The variable registers itself with the active model context.
            pm.Cauchy("cauchy", alpha=loc, beta=scale)
        return cauchy_model


class BetaPosterior(PosteriorExample):
    """Posterior consisting of a single Beta variable named ``"beta"``."""

    def __init__(self, a=2, b=2):
        super().__init__()
        self.model = self._define_posterior(a, b)

    def _define_posterior(self, a, b):
        """Return a PyMC model with ``pm.Beta("beta", alpha=a, beta=b)``."""
        beta_model = pm.Model()
        with beta_model:
            # The variable registers itself with the active model context.
            pm.Beta("beta", alpha=a, beta=b)
        return beta_model
    


# new generalized version

def _draw_iid_reference_samples(posterior_type, iid_kwargs, size, rng):
    """Draw ``size`` samples directly from the target distribution."""
    if posterior_type == "bimodal":
        return generate_iid_samples(num_samples=size, rng=rng, **iid_kwargs)
    if posterior_type == "cauchy":
        return sp.cauchy.rvs(**iid_kwargs, size=size, random_state=rng)
    if posterior_type == "beta":
        return sp.beta.rvs(**iid_kwargs, size=size, random_state=rng)
    raise ValueError(f"Unknown posterior type: {posterior_type}")


def _save_iid_histogram(samples, posterior_type, title, filename, trim_percentiles=(1, 99)):
    """Save a density histogram of ``samples`` (Cauchy samples are trimmed first)."""
    if posterior_type == "cauchy":
        # Trim extreme percentiles for better visualization.
        lower_bound, upper_bound = np.percentile(samples, trim_percentiles)
        samples = samples[(samples >= lower_bound) & (samples <= upper_bound)]

    plt.figure(figsize=(8, 6))
    plt.hist(samples, bins=50, alpha=0.75, density=True, color='blue', edgecolor='black')
    plt.title(title)
    plt.xlabel("Sample Value")
    plt.ylabel("Density")
    plt.grid(True)
    plt.savefig(filename, bbox_inches="tight")
    plt.close()


def _finalize_metric_plot(fig, ax, xlabel, ylabel, title, filename):
    """Label a per-sampler metric plot, save it to ``filename`` and close it."""
    ax.set_xlabel(xlabel)
    ax.set_ylabel(ylabel)
    ax.set_title(title)
    ax.legend(title="Sampler")
    ax.grid(True)
    fig.savefig(filename, bbox_inches="tight")
    plt.close(fig)


def run_experiment(
    posterior_type,
    config_descr,
    runs,
    varying_attribute,
    varying_values,
    num_samples,
    num_chains,
    init_scheme=None,
    base_random_seed=None,
    **posterior_kwargs
):
    """Run all samplers over a sweep of one attribute and store the results.

    For every run and every entry of ``varying_values`` the "Metro", "HMC"
    and "DEMetro" samplers are run on the selected posterior. Each trace is
    compared with samples drawn directly from the target distribution
    (Wasserstein distance), and R-hat, ESS and wall-clock time are recorded.
    Traces, trace plots, per-run CSV files and plots, and results averaged
    over all runs are written below
    ``<experiment_root_folder>/<config_descr>_with_<runs>_runs``.

    The module-level name ``experiment_root_folder`` must be defined before
    this function is called. Folders are created with ``os.makedirs`` without
    ``exist_ok``, so existing results are never overwritten.

    Args:
        posterior_type: ``"bimodal"``, ``"cauchy"`` or ``"beta"``.
        config_descr: Name of the configuration; used in folder names, file
            names and plot titles.
        runs: Number of repetitions (at least 1).
        varying_attribute: Name of the attribute that is swept. May be
            ``"num_samples"``, ``"num_chains"``, ``"init_scheme"`` or a
            parameter of the chosen posterior.
        varying_values: Non-empty sequence of values for that attribute.
        num_samples: Total number of samples; each chain draws
            ``num_samples // num_chains``. Ignored when it is the varying
            attribute.
        num_chains: Number of chains. Ignored when it is the varying
            attribute.
        init_scheme: Forwarded to ``run_sampling``. Ignored when it is the
            varying attribute.
        base_random_seed: Seed of the generator that produces the reference
            samples and the per-run sampling seeds.
        **posterior_kwargs: Fixed posterior parameters. Keys that are not
            parameters of the chosen posterior are stored in the metadata
            file but are not passed to the posterior class.

    Raises:
        ValueError: If ``posterior_type`` is unknown, a fixed posterior
            parameter is missing, ``runs`` is smaller than 1 or
            ``varying_values`` is empty.
    """
    import ast  # local import: only needed to parse tuples read back from CSV

    print(f"\n===== Config {config_descr} started! =====\n")

    # Initialize random number generator
    rng = np.random.default_rng(base_random_seed)

    # === Select Posterior Type ===
    posterior_param_names = {
        "bimodal": ("mode_means", "std_of_modes", "weights"),
        "cauchy": ("loc", "scale"),
        "beta": ("a", "b"),
    }
    if posterior_type not in posterior_param_names:
        raise ValueError(f"Unknown posterior type: {posterior_type}")
    param_names = posterior_param_names[posterior_type]

    # Every posterior parameter except the varying one must be supplied.
    required_keys = [k for k in param_names if k != varying_attribute]
    if not all(k in posterior_kwargs for k in required_keys):
        raise ValueError(f"{posterior_type.capitalize()} posterior requires {required_keys}")

    if runs < 1:
        raise ValueError("runs must be at least 1")
    if len(varying_values) == 0:
        raise ValueError("varying_values must not be empty")

    if posterior_type == "bimodal":
        posterior_cls = BimodalPosterior
        posterior_var_name = "mixed_normal"
    elif posterior_type == "cauchy":
        posterior_cls = CauchyPosterior
        posterior_var_name = "cauchy"
    else:
        posterior_cls = BetaPosterior
        posterior_var_name = "beta"

    # Parameters of the reference distribution; the varying one (if any) is a
    # placeholder that is filled in per value below.
    iid_kwargs = {k: posterior_kwargs.get(k, "varies") for k in param_names}

    # Create configuration folder inside the experiment root
    config_folder = os.path.join(experiment_root_folder, f"{config_descr}_with_{runs}_runs")
    os.makedirs(config_folder)

    # Define folder for saving histograms
    iid_histogram_folder = os.path.join(config_folder, "iid_histograms")
    os.makedirs(iid_histogram_folder)

    # === Precompute IID reference samples ===
    # The reference sample set depends on the varying value only when a
    # posterior parameter or the sample count is swept.
    iid_varies = varying_attribute in iid_kwargs or varying_attribute == "num_samples"
    iid_samples_per_value = []  # parallel to varying_values (values may be unhashable)
    fixed_iid_samples = None

    if iid_varies:
        for value in varying_values:
            if varying_attribute == "num_samples":
                iid_size = value
            else:
                iid_kwargs[varying_attribute] = value
                iid_size = num_samples

            samples = _draw_iid_reference_samples(posterior_type, iid_kwargs, iid_size, rng)
            iid_samples_per_value.append(samples)

            _save_iid_histogram(
                samples,
                posterior_type,
                f"IID Samples Histogram ({varying_attribute}={value})",
                os.path.join(iid_histogram_folder, f"iid_hist_{varying_attribute}_{value}.pdf"),
            )
    else:
        # The posterior is fixed, so one reference sample set serves all values.
        fixed_iid_samples = _draw_iid_reference_samples(posterior_type, iid_kwargs, num_samples, rng)
        _save_iid_histogram(
            fixed_iid_samples,
            posterior_type,
            "IID Samples Histogram (Fixed Posterior)",
            os.path.join(iid_histogram_folder, "iid_hist_fixed.pdf"),
        )

    # === Experiment Setup ===
    samples_per_chain = "varies" if varying_attribute in ["num_samples", "num_chains"] else num_samples // num_chains

    experiment_metadata = {
        "config_descr": config_descr,
        "runs": runs,
        "posterior_type": posterior_type,
        "varying_attribute": varying_attribute,
        "varying_values": varying_values,
        "num_samples": num_samples,
        "num_chains": num_chains,
        "samples_per_chain": samples_per_chain,
        "init_scheme": init_scheme,
        "base_random_seed": base_random_seed,
        "git_tag": get_git_tag(),
    }
    experiment_metadata.update(posterior_kwargs)  # Add posterior-specific parameters

    # Save metadata
    metadata_filename = os.path.join(config_folder, f"metadata_config_{config_descr}.json")
    with open(metadata_filename, "w") as f:
        json.dump(experiment_metadata, f, indent=4)

    sampler_names = ["Metro", "HMC", "DEMetro"]

    # Define fixed colors for each sampler
    sampler_colors = {
        "Metro": "blue",
        "HMC": "red",
        "DEMetro": "green"
    }

    # Axis label shared by the per-run and the global plots.
    attribute_label = "Mode Distance" if varying_attribute == "mode_means" else varying_attribute.replace("_", " ").title()

    # === Run the Experiment ===
    for run_id in range(1, runs + 1):
        print(f"\n===== Running {config_descr} - Run {run_id} =====\n")

        # One seed per run, shared by all samplers and all varying values.
        run_random_seed = int(rng.integers(1_000_000))

        run_folder = os.path.join(config_folder, f"run_{run_id}")
        results_folder = os.path.join(run_folder, "results")
        traces_folder = os.path.join(run_folder, "traces_and_trace_plots")
        plots_folder = os.path.join(run_folder, "plots_of_run")

        os.makedirs(run_folder)
        os.makedirs(results_folder)
        os.makedirs(traces_folder)
        os.makedirs(plots_folder)

        results = []

        for value_index, value in enumerate(varying_values):

            var_attr_folder = os.path.join(traces_folder, f"{varying_attribute}_{value}")
            os.makedirs(var_attr_folder)

            # Only genuine posterior parameters are handed to the posterior
            # class; the varying one is set explicitly because it is normally
            # absent from posterior_kwargs.
            current_posterior_kwargs = {k: posterior_kwargs[k] for k in param_names if k in posterior_kwargs}
            if varying_attribute in param_names:
                current_posterior_kwargs[varying_attribute] = value

            current_num_samples = value if varying_attribute == "num_samples" else num_samples
            current_num_chains = value if varying_attribute == "num_chains" else num_chains
            current_init_scheme = value if varying_attribute == "init_scheme" else init_scheme
            samples_per_chain = current_num_samples // current_num_chains

            model = posterior_cls(**current_posterior_kwargs)

            # NOTE(review): no initial values are generated here, so
            # run_sampling is always called with its default initvals=None
            # and init_scheme currently only acts as a switch.

            # Reference samples for the current varying value
            iid_samples = iid_samples_per_value[value_index] if iid_varies else fixed_iid_samples

            # Run sampling for all samplers
            for sampler_name in sampler_names:
                print(f"Running {sampler_name} with {varying_attribute} = {value}")

                # Measure computation time
                start_time = time.time()
                trace = model.run_sampling(
                    sampler_name, num_samples=samples_per_chain, num_chains=current_num_chains,
                    init_scheme=current_init_scheme, run_random_seed=run_random_seed
                )
                runtime = time.time() - start_time

                # Save trace to NetCDF file
                trace_filename = os.path.join(var_attr_folder, f"{sampler_name}_trace.nc")
                az.to_netcdf(trace, trace_filename)

                # Save trace plot
                trace_plot_filename = os.path.join(var_attr_folder, f"{sampler_name}_trace_plot.pdf")
                az.plot_trace(trace, compact=True)
                plt.savefig(trace_plot_filename, bbox_inches="tight")
                plt.close()

                # Compute Wasserstein distance between all draws and the reference samples
                ws_distance = sp.wasserstein_distance(trace.posterior[posterior_var_name].values.flatten(), iid_samples)

                # Compute R-hat and ESS
                r_hat = az.rhat(trace)[posterior_var_name].item()
                ess = az.ess(trace)[posterior_var_name].item()

                results.append({
                    varying_attribute: value,
                    "sampler": sampler_name,
                    "wasserstein_distance": ws_distance,
                    "r_hat": r_hat,
                    "ess": ess,
                    "runtime": runtime
                })

        # Convert results to DataFrame
        df_results = pd.DataFrame(results)

        # Sequence-valued attributes cannot be plotted or sorted directly:
        # mode means are reduced to the distance between the modes, anything
        # else is shown as text.
        if isinstance(df_results[varying_attribute].iloc[0], (tuple, list)):
            if varying_attribute == "mode_means":
                df_results["mode_distance"] = df_results[varying_attribute].apply(lambda x: abs(x[1] - x[0]))
                varying_attribute_for_plot = "mode_distance"
            else:
                df_results[varying_attribute] = df_results[varying_attribute].apply(str)
                varying_attribute_for_plot = varying_attribute
        else:
            varying_attribute_for_plot = varying_attribute

        # Sort the DataFrame by the final chosen attribute
        df_results = df_results.sort_values(varying_attribute_for_plot, ascending=True)

        # One CSV file per sampler
        samplers_in_results = df_results["sampler"].unique()
        for sampler in samplers_in_results:
            df_sampler = df_results[df_results["sampler"] == sampler]
            df_sampler.to_csv(os.path.join(results_folder, f"{sampler}_results.csv"), index=False)

        # One plot per metric, with one line per sampler
        run_plots = [
            ("wasserstein_distance", "Wasserstein Distance",
             f"Wasserstein Distance for Samplers (config =_{config_descr})", f"Wasserstein_run_{run_id}.pdf"),
            ("r_hat", "R-hat",
             f"R-hat for Samplers (config =_{config_descr})", f"R-hat_run_{run_id}.pdf"),
            ("ess", "ESS",
             f"ESS for Samplers (config =_{config_descr})", f"ESS_run_{run_id}.pdf"),
            ("runtime", "Computation Time (seconds)",
             f"Computation Time for Samplers (config =_{config_descr})", f"ComputationTime_run_{run_id}.pdf"),
        ]
        for column, ylabel, title, filename in run_plots:
            fig, ax = plt.subplots(figsize=(10, 6))
            for sampler in samplers_in_results:
                df_sampler = df_results[df_results["sampler"] == sampler]
                ax.plot(df_sampler[varying_attribute_for_plot], df_sampler[column],
                        marker="o", linestyle="-", label=sampler,
                        color=sampler_colors.get(sampler, "black"))
            _finalize_metric_plot(fig, ax, attribute_label, ylabel, title, os.path.join(plots_folder, filename))

    print("\n===== All Runs Completed Successfully! =====\n")

    # ===== GLOBAL RESULTS FOLDER =====
    global_folder = os.path.join(config_folder, "global_results")
    global_results_folder = os.path.join(global_folder, "results")
    global_plots_folder = os.path.join(global_folder, "plots")

    os.makedirs(global_folder)
    os.makedirs(global_results_folder)
    os.makedirs(global_plots_folder)

    # Collect all results from all runs
    df_all_runs = []

    for run_id in range(1, runs + 1):
        results_folder = os.path.join(config_folder, f"run_{run_id}", "results")

        for sampler in sampler_names:
            df_run = pd.read_csv(os.path.join(results_folder, f"{sampler}_results.csv"))
            df_run["run_id"] = run_id
            df_run["sampler"] = sampler
            df_all_runs.append(df_run)

    # Combine all results into a single data frame
    df_all_runs = pd.concat(df_all_runs, ignore_index=True)

    if varying_attribute == "mode_means":
        # The CSV stores each pair of means as text; parse it as a literal
        # instead of evaluating arbitrary code.
        parsed_means = df_all_runs[varying_attribute].apply(ast.literal_eval)
        df_all_runs["mode_distance"] = parsed_means.apply(lambda m: abs(m[1] - m[0]))
        varying_attribute_for_global_plot = "mode_distance"
    else:
        varying_attribute_for_global_plot = varying_attribute
    df_all_runs = df_all_runs.sort_values(varying_attribute_for_global_plot, ascending=True)

    # (result column, CSV column suffix, y-label, title, file name)
    global_plots = [
        ("wasserstein_distance", "ws", "Average Wasserstein Distance",
         f"Averaged Wasserstein Distance ({runs} Runs, config = {config_descr})", "Wasserstein_global_plot.pdf"),
        ("r_hat", "rhat", "Average R-hat",
         f"Averaged R-hat Values ({runs} Runs, config = {config_descr})", "Rhat_global_plot.pdf"),
        ("ess", "ess", "Average ESS",
         f"Averaged ESS ({runs} Runs,  config = {config_descr})", "ESS_global_plot.pdf"),
        ("runtime", "time", "Average Computation Time (seconds)",
         f"Averaged Computation Time ({runs} Runs, config = {config_descr})", "Time_global_plot.pdf"),
    ]

    # Initialize global plots
    global_figures = {column: plt.subplots(figsize=(10, 6)) for column, *_ in global_plots}

    for sampler in sampler_names:
        df_sampler = df_all_runs[df_all_runs["sampler"] == sampler]
        color = sampler_colors.get(sampler, "black")

        global_columns = {}
        for column, suffix, *_ in global_plots:
            # Pivot table: rows = varying attribute, columns = run_id.
            # Every metric uses the same index so the rows of the summary
            # table below line up.
            pivot = df_sampler.pivot_table(index=varying_attribute_for_global_plot, columns="run_id", values=column)

            # Mean and standard deviation across runs for the error bars
            metric_mean, metric_std = pivot.mean(axis=1), pivot.std(axis=1)

            _, ax = global_figures[column]
            ax.errorbar(metric_mean.index, metric_mean, yerr=metric_std, fmt="o-", label=sampler, color=color, capsize=5)

            if not global_columns:
                global_columns[varying_attribute] = metric_mean.index
            global_columns[f"global_avg_{suffix}"] = metric_mean.values
            global_columns[f"global_avg_{suffix}_std"] = metric_std.values

        # Save global averages
        df_global_avg = pd.DataFrame(global_columns)
        sampler_csv_filename = os.path.join(global_results_folder, f"Global_results_{sampler}.csv")
        df_global_avg.to_csv(sampler_csv_filename, index=False)

    # ===== Save Global Plots =====
    for column, _, ylabel, title, filename in global_plots:
        fig, ax = global_figures[column]
        _finalize_metric_plot(fig, ax, attribute_label, ylabel, title, os.path.join(global_plots_folder, filename))

    print(f"\n===== Config {config_descr} Completed Successfully! =====\n")



In [114]:

# List of all possible attributes in a config
REQUIRED_ATTRIBUTES = {
    "mode_means",
    "std_of_modes",
    "weights",
    "num_samples",
    "num_chains",
    "use_init_scheme",
}

# Parameters specific to each supported posterior type.
_POSTERIOR_ATTRIBUTES = {
    "bimodal": ("mode_means", "std_of_modes", "weights"),
    "cauchy": ("loc", "scale"),
    "beta": ("a", "b"),
}

# Sampling attributes every config needs regardless of the posterior type.
_SAMPLING_ATTRIBUTES = ("num_samples", "num_chains")


def validate_config(config):
    """Check that a config defines exactly one varying attribute and fixes the rest.

    The set of required attributes depends on ``config["posterior_type"]``
    (``"bimodal"`` when the key is absent). Bimodal configs must define every
    name in ``REQUIRED_ATTRIBUTES``; cauchy and beta configs must define their
    own posterior parameters plus ``num_samples`` and ``num_chains``. The
    varying attribute must be one of those names (except ``use_init_scheme``)
    or ``"init_scheme"``, and must not also be given a fixed value.

    Args:
        config: Dictionary describing one experiment configuration.

    Returns:
        None. The function only raises on invalid input.

    Raises:
        ValueError: If the config is missing a key, names an unknown
            posterior type or varying attribute, or is inconsistent about
            ``use_init_scheme`` / ``init_scheme``.
    """
    # Used in every message; never raises even if the name itself is missing.
    descr = config.get("config_descr", "Unnamed")

    # Ensure "varying_attribute" is specified
    if "varying_attribute" not in config:
        raise ValueError(f"Config '{descr}' is missing 'varying_attribute'.")

    varying_attr = config["varying_attribute"]

    # Ensure "varying_values" is present
    if "varying_values" not in config:
        raise ValueError(f"Config '{descr}' is missing 'varying_values' for '{varying_attr}'.")

    posterior_type = config.get("posterior_type", "bimodal")
    if posterior_type not in _POSTERIOR_ATTRIBUTES:
        raise ValueError(f"Config '{descr}' has an unknown 'posterior_type': '{posterior_type}'.")

    if posterior_type == "bimodal":
        required = set(REQUIRED_ATTRIBUTES)
    else:
        required = set(_POSTERIOR_ATTRIBUTES[posterior_type]) | set(_SAMPLING_ATTRIBUTES)

    # "use_init_scheme" is a switch, not something that can be swept.
    allowed_varying = (required - {"use_init_scheme"}) | {"init_scheme"}
    if varying_attr not in allowed_varying:
        raise ValueError(f"Config '{descr}' has an invalid 'varying_attribute': '{varying_attr}'.")

    # Validate that `init_scheme` is present only if `use_init_scheme` is True
    if "use_init_scheme" in config:
        if config["use_init_scheme"] and "init_scheme" not in config and varying_attr != "init_scheme":
            raise ValueError(f"Config '{descr}' is missing 'init_scheme' but 'use_init_scheme' is set to True.")
        elif not config["use_init_scheme"] and "init_scheme" in config:
            raise ValueError(f"Config '{descr}' defines 'init_scheme' but 'use_init_scheme' is False.")

    # Sorted so that the first reported problem does not depend on set order.
    for attr in sorted(required):
        if attr == varying_attr:
            # The varying attribute should not have a fixed value
            if attr in config:
                raise ValueError(f"Config '{descr}' incorrectly defines '{attr}' as both fixed and varying.")
        elif attr not in config:
            # All other attributes must be explicitly defined
            raise ValueError(f"Config '{descr}' is missing required fixed attribute '{attr}'.")



# posterior_type = "bimodal", "cauchy", "beta"
# varying_attribute = "num_samples", "num_chains", "init_scheme" or posterior specific attribute
# bimodal-specific attributes = "mode_means", "std_of_modes", "weights"
# cauchy specific attributes = "loc", "scale"
# beta specific attributes = "a", "b"
# all but the varying attribute must be fixed and present in the config



# One small bimodal config per attribute that can be varied (all except
# init_scheme); every config uses 2 runs and base seed 42.
Testcases_all_attr = [
    dict(
        config_descr="Weights_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="weights",
        varying_values=[(1, 1), (1, 2), (5, 1)],
        mode_means=(3, 14),
        num_samples=100,
        num_chains=4,
        std_of_modes=(1, 1),
        use_init_scheme=False,
        base_random_seed=42,
    ),
    dict(
        config_descr="Mode_Means_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="mode_means",
        varying_values=[(0, 3), (0, 10), (0, 20)],
        std_of_modes=(1, 1),
        weights=(1, 1),
        num_samples=100,
        num_chains=4,
        use_init_scheme=False,
        base_random_seed=42,
    ),
    dict(
        config_descr="Std_of_Modes_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="std_of_modes",
        varying_values=[(1, 1), (1, 4), (2, 1)],
        mode_means=(3, -3),
        weights=(1, 1),
        num_samples=100,
        num_chains=4,
        use_init_scheme=False,
        base_random_seed=42,
    ),
    dict(
        config_descr="Samples_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="num_samples",
        varying_values=[100, 200, 300],
        mode_means=(3, -3),
        std_of_modes=(1, 1),
        weights=(1, 1),
        num_chains=4,
        use_init_scheme=False,
        base_random_seed=42,
    ),
    dict(
        config_descr="Chains_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="num_chains",
        varying_values=[4, 6, 8],
        mode_means=(3, -3),
        std_of_modes=(1, 1),
        weights=(1, 1),
        num_samples=100,
        use_init_scheme=False,
        base_random_seed=42,
    ),
]



# Single bimodal config that varies the mode means.
Test_bimodal_posterior = [
    dict(
        config_descr="Bimodal_test",
        posterior_type="bimodal",
        runs=2,
        varying_attribute="mode_means",
        varying_values=[(3, 8), (4, 20)],
        std_of_modes=(1, 1),
        weights=(1, 1),
        num_samples=100,
        num_chains=4,
        use_init_scheme=False,
        base_random_seed=42,
    ),
]

# One config each for the Cauchy and the Beta posterior.
Test_new_posterior_types = [
    dict(
        config_descr="Cauchy_test",
        posterior_type="cauchy",
        runs=2,
        varying_attribute="loc",
        varying_values=[0, -2, 5],
        num_chains=4,
        scale=1,
        num_samples=10000,
        base_random_seed=42,
    ),
    dict(
        config_descr="Beta_test",
        posterior_type="beta",
        runs=2,
        varying_attribute="a",
        varying_values=[0.5, 5, 1],
        num_chains=4,
        b=3,
        num_samples=10000,
        base_random_seed=42,
    ),
]


In [110]:
# Choose the experiment to run
experiment = Testcases_all_attr
experiment_name = "all_attr_but_init_scheme"

# Define the root directory for all experiments. run_experiment reads this
# module-level name. Re-using an existing root is allowed; the per-config
# folders inside it are still created without exist_ok, so existing results
# are not overwritten.
experiment_root_folder = f"experiment_{experiment_name}"
os.makedirs(experiment_root_folder, exist_ok=True)

# Optional: call validate_config(config) for every config here before running.

# Config keys that are either passed to run_experiment by name or only steer
# validation. Everything else is forwarded as a posterior parameter, so
# "use_init_scheme" must be excluded: the posterior classes do not accept it.
non_posterior_keys = {
    "config_descr", "runs", "varying_attribute", "varying_values",
    "num_samples", "num_chains", "init_scheme", "use_init_scheme",
    "base_random_seed", "posterior_type",
}

for config in experiment:
    varying_attribute = config["varying_attribute"]
    posterior_kwargs = {k: v for k, v in config.items() if k not in non_posterior_keys}

    run_experiment(
        posterior_type=config["posterior_type"],
        config_descr=config["config_descr"],
        runs=config["runs"],
        varying_attribute=varying_attribute,
        varying_values=config["varying_values"],
        # The varying attribute has no fixed value; "varies" is a placeholder
        # that ends up in the metadata file.
        init_scheme="varies" if varying_attribute == "init_scheme" else config.get("init_scheme", None),
        num_samples="varies" if varying_attribute == "num_samples" else config["num_samples"],
        num_chains="varies" if varying_attribute == "num_chains" else config["num_chains"],
        base_random_seed=config["base_random_seed"],
        **posterior_kwargs,
    )

print("All experiments completed successfully!")



===== Config Weights_test started! =====


===== Running Weights_test - Run 1 =====

Running Metro with weights = (1, 1)
Running HMC with weights = (1, 1)
Running DEMetro with weights = (1, 1)
Running Metro with weights = (1, 2)
Running HMC with weights = (1, 2)
Running DEMetro with weights = (1, 2)
Running Metro with weights = (5, 1)
Running HMC with weights = (5, 1)
Running DEMetro with weights = (5, 1)

===== Running Weights_test - Run 2 =====

Running Metro with weights = (1, 1)
Running HMC with weights = (1, 1)
Running DEMetro with weights = (1, 1)
Running Metro with weights = (1, 2)
Running HMC with weights = (1, 2)
Running DEMetro with weights = (1, 2)
Running Metro with weights = (5, 1)
Running HMC with weights = (5, 1)
Running DEMetro with weights = (5, 1)

===== All Runs Completed Successfully! =====


===== Config Weights_test Completed Successfully! =====


===== Config Mode_Means_test started! =====


===== Running Mode_Means_test - Run 1 =====

Running Metro with mode_me

There were 1 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_samples = 100
Running Metro with num_samples = 200
Running HMC with num_samples = 200


There were 3 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_samples = 200
Running Metro with num_samples = 300
Running HMC with num_samples = 300


There were 3 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_samples = 300

===== Running Samples_test - Run 2 =====

Running Metro with num_samples = 100
Running HMC with num_samples = 100
Running DEMetro with num_samples = 100
Running Metro with num_samples = 200
Running HMC with num_samples = 200
Running DEMetro with num_samples = 200
Running Metro with num_samples = 300
Running HMC with num_samples = 300
Running DEMetro with num_samples = 300

===== All Runs Completed Successfully! =====


===== Config Samples_test Completed Successfully! =====


===== Config Chains_test started! =====


===== Running Chains_test - Run 1 =====

Running Metro with num_chains = 4
Running HMC with num_chains = 4
Running DEMetro with num_chains = 4
Running Metro with num_chains = 6
Running HMC with num_chains = 6
Running DEMetro with num_chains = 6
Running Metro with num_chains = 8
Running HMC with num_chains = 8
Running DEMetro with num_chains = 8

===== Running Chains_test - Run 2 =====

Running Metro with num_chains = 4
Running HMC wi

There were 4 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_chains = 4
Running Metro with num_chains = 6
Running HMC with num_chains = 6


There were 5 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_chains = 6
Running Metro with num_chains = 8
Running HMC with num_chains = 8


There were 4 divergences after tuning. Increase `target_accept` or reparameterize.


Running DEMetro with num_chains = 8

===== All Runs Completed Successfully! =====


===== Config Chains_test Completed Successfully! =====

All experiments completed successfully!
