In [1]:
# Put the notebook's parent directory at the front of `sys.path` so local
# packages (e.g. `estimation`) import the same way they would if Python were
# launched from that directory.
# NOTE(review): assumes the kernel's working directory is the notebook's own
# directory (Jupyter's default) — confirm if the notebook is run elsewhere.
import os
import sys

notebook_dir = os.getcwd()
parent_dir = os.path.dirname(notebook_dir)
already_on_path = parent_dir in sys.path
if not already_on_path:
    # Prepend so local modules take precedence over installed ones of the same name.
    sys.path.insert(0, parent_dir)

In [2]:
import arviz as az
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pymc as pm
import pytensor as pt
from estimation.mixnorm.ksigma.model import tempered_normal_mixture
from estimation.mixnorm.ksigma.pymc_tools import progress_callback

In [3]:
%config InlineBackend.figure_format = 'retina'
# Initialize random number generator
RANDOM_SEED = 8927
rng = np.random.default_rng(RANDOM_SEED)
az.style.use("arviz-darkgrid")
print(f"Running on PyMC v{pm.__version__}")

Running on PyMC v5.23.0


In [4]:
# True parameter values
mu=0
sigma=1
# sample sizes for the different simulations
n_factors = np.array([10, 50, 100, 200, 250, 500, 1000])

In [5]:
# Load the comma-separated simulated observations for n=1000; the template
# path is relative to the kernel's working directory.
pathstr = "../estimation/mixnorm/ksigma/data/observations/n{n}.csv"
data_path = Path(pathstr.format(n=1000))
x_data = np.loadtxt(data_path, delimiter=',')

In [6]:
# Build the tempered normal-mixture model for the loaded observations and
# draw posterior samples with NUTS.
# NOTE(review): beta = 1/log(n) is presumably the WBIC inverse temperature
# (Watanabe, 2013) — confirm against `tempered_normal_mixture`.
n_components = 3
model = tempered_normal_mixture(beta=1/np.log(len(x_data)),
                                data=x_data,
                                n_components=n_components,
                                # presumably a symmetric Dirichlet concentration (< 1 favours sparse weights) — confirm
                                weights_prior_alpha=np.full(n_components, 0.1),
                                mean_prior_cov=pt.tensor.eye(n_components)*2
                               )
with model:
    idata = pm.sample(draws=8000, tune=2000, chains=2,
                      # Fix the seed so the draws are reproducible. An int (not the shared
                      # `rng` Generator) keeps this cell idempotent: sampling from a Generator
                      # advances it, so a re-run would otherwise produce different draws.
                      random_seed=RANDOM_SEED,
                      callback=progress_callback,
                      progressbar=False,  # progress is reported by `progress_callback` instead
                      # PyMC starts at most one process per chain ("2 chains in 2 jobs"),
                      # so only 2 of these cores are actually used.
                      cores=10,
                      # NOTE(review): PyMC's default max_treedepth is 10; 50 effectively removes
                      # the cap. PyMC also warns that >= 4 chains are recommended for robust
                      # convergence diagnostics.
                      max_treedepth=50, target_accept=.995)

Initializing NUTS using jitter+adapt_diag...
Multiprocess sampling (2 chains in 2 jobs)
NUTS: [weights, mus, like]


Chain 1, completed 500 draws. Last draw in 0.009s with acceptance rate:0.996.
Chain 0, completed 500 draws. Last draw in 0.009s with acceptance rate:0.985.
Chain 1, completed 1000 draws. Last draw in 0.009s with acceptance rate:1.000.
Chain 0, completed 1000 draws. Last draw in 0.019s with acceptance rate:0.997.
Chain 1, completed 1500 draws. Last draw in 0.012s with acceptance rate:0.986.
Chain 0, completed 1500 draws. Last draw in 0.020s with acceptance rate:0.997.
Chain 1, completed 2000 draws. Last draw in 0.027s with acceptance rate:0.991.
Chain 0, completed 2000 draws. Last draw in 0.089s with acceptance rate:0.990.
Chain 1, completed 2500 draws. Last draw in 0.010s with acceptance rate:0.996.
Chain 0, completed 2500 draws. Last draw in 0.045s with acceptance rate:0.996.
Chain 1, completed 3000 draws. Last draw in 0.025s with acceptance rate:0.999.
Chain 0, completed 3000 draws. Last draw in 0.038s with acceptance rate:0.998.
Chain 1, completed 3500 draws. Last draw in 0.040s wit

Sampling 2 chains for 2_000 tune and 8_000 draw iterations (4_000 + 16_000 draws total) took 263 seconds.
We recommend running at least 4 chains for robust computation of convergence diagnostics


In [None]:
# Trace plot with all chains merged into a single line per variable; binding the
# returned axes keeps their repr out of the cell output.
trace_axes = az.plot_trace(idata, combined=True)

In [None]:
# Pairwise scatter of the mixture weights and component means; the trailing `;`
# suppresses the axes-array repr, matching the other plotting cells.
az.plot_pair(idata, var_names=["weights", "mus"], kind="scatter");

In [None]:
# Marginal posterior of each weight and mean; the trailing `;` suppresses the
# axes-array repr, matching the other plotting cells.
az.plot_posterior(idata, var_names=["weights", "mus"]);

In [None]:
# Tabular posterior summary (rounded to 2 decimals), shown as the cell's output.
summary_var_names = ["weights", "mus"]
az.summary(idata, var_names=summary_var_names, round_to=2)

In [None]:
# Forest plot of 95% HDIs with chains combined, annotated with R-hat.
az.plot_forest(
    idata,
    var_names=["weights", "mus"],
    combined=True,
    hdi_prob=0.95,
    r_hat=True,
);

In [None]:
# NUTS energy diagnostic plot; binding the returned axes keeps their repr out of the output.
energy_ax = az.plot_energy(idata)

In [None]:
# 3-D scatter of posterior draws, one colour per chain:
# x = mus[0], y = weights[0], z = weights[1].
# Plotting chains separately makes between-chain disagreement visible
# (which combined summaries can hide).
fig = plt.figure(figsize=(10, 8))
ax = fig.add_subplot(111, projection='3d')
colors = ['red', 'blue', 'green', 'yellow']

posterior = idata.posterior
# Read the chain count from the trace rather than hard-coding it, so this cell
# keeps working if the number of sampled chains changes.
n_chains = posterior.sizes["chain"]
# Positional layout (chain, draw, component), as the original indexing assumed.
mus_draws = posterior["mus"].values
weights_draws = posterior["weights"].values

for chain in range(n_chains):
    ax.scatter(
        mus_draws[chain, :, 0],
        weights_draws[chain, :, 0],
        weights_draws[chain, :, 1],
        c=colors[chain % len(colors)],  # cycle colours if there are more chains than colours
        s=5,                            # marker size
        alpha=0.7,                      # transparency
        label=f'Chain {chain}',
    )

ax.set_xlabel("mus[0]")
ax.set_ylabel("weights[0]")
ax.set_zlabel("weights[1]")
ax.set_title("Posterior draws by chain")
# Without this call the per-chain labels above are never displayed.
ax.legend()

# Viewing angle (optional)
ax.view_init(elev=45, azim=45)

# Tight layout ensures labels and titles fit
plt.tight_layout()