In [None]:
import dill
import arviz as az
import pymc as pm
import matplotlib.pyplot as plt
import numpy as np

In [None]:
# Seed handed to posterior predictive sampling inside summary_HDP.
RANDOM_SEED = 42

# The archive unpacks into exactly five objects, in this order.
# NOTE(review): dill, like pickle, can run arbitrary code while loading —
# only open archives produced by a trusted run.
with open("HDP_models_traces.pkl", "rb") as f:
    models, traces, datasets, a_all, b_all = dill.load(f)

In [None]:
def summary_HDP(trace, model, N_sources=1, k=1):
    """
    Summarize and visualize the HDP model results.

    Prints an ArviZ summary table and then shows, in order: trace plots,
    posterior plots, a posterior predictive check and a grouped bar chart of
    the posterior mean of ``pi_norm_{s}`` for every source.

    :param trace: Trace of the model. Its ``posterior`` group must contain
        "α", "β", "σ", "gamma", "alpha0", "beta" as well as ``pi_{s}`` and
        ``pi_norm_{s}`` for every ``s`` in ``range(N_sources)``
    :param model: The model being summarized; used as the context in which
        posterior predictive samples are drawn (seeded with the notebook-level
        ``RANDOM_SEED``)
    :param N_sources: Number of sources
    :param k: Number of components for the model. It fixes the x positions of
        the bar chart, so each ``pi_norm_{s}`` posterior mean must be
        compatible with ``k`` bars
    :return: None
    """
    # Global (source-independent) variables, defined once and reused by every plot.
    global_vars = ["α", "β", "σ", "gamma", "alpha0", "beta"]

    # The summary table additionally lists the per-source weights.
    summary_vars = list(global_vars)
    for s in range(N_sources):
        summary_vars.append(f"pi_{s}")
        summary_vars.append(f"pi_norm_{s}")

    # Print summary of the trace
    print(az.summary(trace, var_names=summary_vars))
    print("Trace summary completed.")

    # Plot trace and posterior distributions
    az.plot_trace(trace, var_names=global_vars)
    plt.show()

    az.plot_posterior(trace, var_names=global_vars)
    plt.show()

    # Posterior predictive check
    with model:
        posterior_predictive = pm.sample_posterior_predictive(trace, random_seed=RANDOM_SEED)

    # plot_ppc rejects num_pp_samples larger than the number of available
    # chain x draw samples, so cap the request for short traces.
    pp_group = posterior_predictive.posterior_predictive
    n_pp_available = pp_group.sizes["chain"] * pp_group.sizes["draw"]
    az.plot_ppc(posterior_predictive, num_pp_samples=min(1000, n_pp_available))
    plt.show()

    # Bar plot of posterior expected mixture weights
    fig, ax = plt.subplots(figsize=(8, 6))
    component_idx = np.arange(k) + 1  # Component indices
    # Each component gets a group of bars centred on its index and spanning
    # 0.8 units, so no bar is clipped by the axis limits and groups never
    # overlap, however many sources there are.
    group_width = 0.8
    bar_width = group_width / max(N_sources, 1)
    for s in range(N_sources):
        weights = trace.posterior[f"pi_norm_{s}"].mean(("chain", "draw"))
        offset = (s - (N_sources - 1) / 2) * bar_width
        ax.bar(
            component_idx + offset,
            weights,
            width=bar_width, label=f"Source {s + 1}",
        )

    ax.set_xticks(component_idx)
    ax.set_xlim(0.5, k + 0.5)
    ax.set_xlabel("Component")
    ax.set_ylabel("Posterior expected mixture weight")
    ax.legend(title="Sources")
    plt.show()

In [None]:
# Walk through every stored setting; each key is a (sources, components) pair
# and each value bundles the data with the true mixing proportions.
for setting, (X, Y, proportions) in datasets.items():
    ns, nc = setting
    print(f"Running model for setting with {ns} sources and {nc} components.")
    print(f"True proportions for sources: {proportions}")
    # NOTE(review): a_all / b_all are indexed by nc - 2, which presumably means
    # the first entry belongs to the 2-component setting — confirm against the
    # notebook that wrote the archive.
    print(f"True regression parameters (a, b): {a_all[nc - 2]}, {b_all[nc - 2]}")
    summary_HDP(traces[setting], models[setting], N_sources=ns, k=nc)