# Fit a Tempered Mixture Model to Normal Samples

This notebook loads generated normal samples and fits a two-component normal mixture model to them with PyMC. The likelihood can optionally be tempered (raised to a power) to obtain a tempered posterior.

## Import Required Libraries

In [None]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import pymc as pm
import arviz as az
from pymc_extensions.pmx import column_stack_vars

## Load the Normal Samples

In [None]:
# Path to the generated samples (relative to this notebook's directory).
DATA_PATH = "../../data/normal_samples.csv"

samples_df = pd.read_csv(DATA_PATH)
data = samples_df["samples"].to_numpy(dtype=float)

# Fail fast on bad input: NaN/inf values passed as `observed` would trigger
# PyMC's automatic imputation and silently change the model being fitted.
assert data.size > 0, f"no samples found in {DATA_PATH}"
assert np.isfinite(data).all(), "samples contain NaN or infinite values"
print(f"Loaded {len(data)} samples.")

## Fit a Tempered 2-Component Normal Mixture Model
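We fit the mixture model with PyMC. The `temperature` parameter $T$ scales the log-likelihood, giving the tempered posterior $p_T(\theta \mid y) \propto p(\theta)\, p(y \mid \theta)^{T}$. Setting $T = 1$ recovers the standard posterior, and $0 < T < 1$ down-weights the data relative to the prior.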

In [None]:
# Temperature T for the tempered posterior p_T(theta | y) ∝ p(theta) * p(y | theta)**T.
# T = 1.0 gives the standard posterior; 0 < T < 1 down-weights the data.
temperature = 1.0
RANDOM_SEED = 42  # fixes the MCMC draws so Restart & Run All reproduces the figures

if not temperature > 0:
    raise ValueError(f"temperature must be positive, got {temperature!r}")

with pm.Model() as model:
    w = pm.Dirichlet('w', a=np.array([1, 1]))
    # NOTE(review): the two means are exchangeable (no ordering constraint), so
    # chains may label the components differently (label switching). Summaries of
    # `mu`/`sigma`/`w` pooled across chains should be read with that in mind.
    mu = pm.Normal('mu', mu=0, sigma=5, shape=2)
    sigma = pm.HalfNormal('sigma', sigma=1, shape=2)
    comp = pm.Normal.dist(mu=mu, sigma=sigma, shape=2)
    if temperature == 1.0:
        # Standard posterior: an ordinary observed likelihood node.
        pm.Mixture('y', w=w, comp_dists=comp, observed=data)
    else:
        # Tempered posterior: add T * log p(y | theta) to the joint density as a
        # Potential, so the tempering is part of the graph pm.sample compiles.
        # There is no observed `y` node in this branch, so posterior-predictive /
        # log-likelihood groups for `y` will not be available.
        mixture = pm.Mixture.dist(w=w, comp_dists=comp, shape=data.shape)
        pm.Potential('y_tempered_loglike', temperature * pm.logp(mixture, data).sum())
    trace = pm.sample(
        1000,
        tune=1000,
        return_inferencedata=True,
        target_accept=0.9,
        random_seed=RANDOM_SEED,
    )


## Visualize Posterior and Mixture Fit

In [None]:
# Trace plots and marginal posterior densities for every model parameter,
# each rendered as its own figure.
param_names = ['mu', 'sigma', 'w']
for plot_fn in (az.plot_trace, az.plot_posterior):
    plot_fn(trace, var_names=param_names)
    plt.show()

# Plot the fitted mixture density over the data histogram.
# The curve is the posterior-averaged density
#     p(x | y) ≈ (1/S) * sum_s sum_k w_k^(s) * Normal(x | mu_k^(s), sigma_k^(s)),
# which, unlike plugging in posterior means of mu/sigma/w, is unaffected by
# label switching (different chains may order the two components differently).
x = np.linspace(data.min() - 1, data.max() + 1, 500)

# Flatten (chain, draw, component) -> (sample, component).
n_components = trace.posterior['w'].shape[-1]
post_w = trace.posterior['w'].values.reshape(-1, n_components)
post_mu = trace.posterior['mu'].values.reshape(-1, n_components)
post_sigma = trace.posterior['sigma'].values.reshape(-1, n_components)

# Broadcast to (sample, component, grid point) and evaluate each normal pdf.
z = (x[None, None, :] - post_mu[:, :, None]) / post_sigma[:, :, None]
comp_pdf = np.exp(-0.5 * z**2) / (post_sigma[:, :, None] * np.sqrt(2 * np.pi))
# Weight and sum over components, then average over posterior samples.
mix_pdf = (post_w[:, :, None] * comp_pdf).sum(axis=1).mean(axis=0)

fig, ax = plt.subplots(figsize=(10, 6))
ax.hist(data, bins=30, density=True, alpha=0.5, label="Data")
ax.plot(x, mix_pdf, label="Fitted Mixture PDF", color="red")
ax.set(title="Mixture Model Fit to Data", xlabel="Sample value", ylabel="Density")
ax.legend()
plt.show()