# Multinomial Regression in PyMC: From Basics to Random Effects

This article covers multinomial regression modeling in PyMC, from simple classification problems to hierarchical models with correlated random effects. We'll explore both the high-level Bambi interface and the low-level PyMC approach, discussing when each is appropriate.

**Prerequisites:** Familiarity with Bayesian basics (priors, posteriors, MCMC) and comfort with Python and basic PyMC usage.

## What is Multinomial Regression?

Multinomial regression extends binary logistic regression to handle outcomes with more than two categories. While logistic regression models the log-odds of a single event, multinomial regression models the log-odds of each category relative to a reference category:

$$\log\left(\frac{P(Y=k)}{P(Y=K)}\right) = X\beta_k$$

where $K$ is the reference category and $k$ ranges over the other categories.
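
Inverting the log-odds equation gives the probabilities themselves. Write $\eta_k = X\beta_k$ and fix the reference logit at zero. Then $P(Y=k) = e^{\eta_k} / \bigl(1 + \sum_{j \neq K} e^{\eta_j}\bigr)$, with $e^{\eta_K} = 1$ for the reference category. Below is a minimal NumPy sketch of that inversion. The helper name `pivot_probs` is illustrative and not part of any library.

```python
import numpy as np

def pivot_probs(eta):
    """Convert log-odds vs. a reference category into probabilities.

    eta: array of shape (..., K-1) with eta_k = X @ beta_k for the
    non-reference categories. The reference category K is appended
    last with its logit fixed at 0.
    """
    eta = np.asarray(eta, dtype=float)
    logits = np.concatenate([eta, np.zeros(eta.shape[:-1] + (1,))], axis=-1)
    logits = logits - logits.max(axis=-1, keepdims=True)  # numerical stability
    expl = np.exp(logits)
    return expl / expl.sum(axis=-1, keepdims=True)

eta_demo = np.array([[1.2, -0.5],   # category 1 favored over the reference
                     [0.0, 0.0]])   # all categories equally likely
p_demo = pivot_probs(eta_demo)
print(p_demo.round(3))

# Taking log(P(Y=k) / P(Y=K)) recovers eta, as in the equation above
print(np.log(p_demo[:, :-1] / p_demo[:, -1:]).round(3))
```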

**Common applications include:**
- Classification with multiple classes (e.g., species identification)
- Choice modeling (e.g., consumer product selection)
- Compositional data analysis (e.g., vote shares, market shares)

## The Identifiability Problem

A key challenge in multinomial regression is identifiability. The model is overparameterized: adding a constant to all category logits doesn't change the probabilities. Two common solutions exist:

1. **Reference category (pivot) approach:** Fix one category's coefficients to zero
2. **Sum-to-zero constraint:** Constrain parameters to sum to zero across categories

We'll demonstrate both approaches, with emphasis on PyMC's `ZeroSumNormal` distribution which implements the second approach elegantly.
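
A few lines of NumPy show both the overparameterization and how each constraint resolves it. Shifting every logit by the same constant leaves the probabilities unchanged. Each constraint then picks one representative from that family of equivalent logit vectors. The `np_softmax` helper is defined locally for this sketch.

```python
import numpy as np

def np_softmax(logits):
    """Softmax over the last axis, stabilized by subtracting the max."""
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

logits_demo = np.array([0.3, 2.0, -1.1])

shifted = logits_demo + 10.0                 # add a constant to every logit
pivot = logits_demo - logits_demo[0]         # 1. reference category 0 fixed at 0
zero_sum = logits_demo - logits_demo.mean()  # 2. logits sum to zero

for name, v in [("original", logits_demo), ("shifted", shifted),
                ("pivot", pivot), ("zero-sum", zero_sum)]:
    print(f"{name:>9}: logits={np.round(v, 3)}, probs={np.round(np_softmax(v), 3)}")
```

All four rows print the same probabilities. Only the logit vectors differ.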

In [None]:
import numpy as np
import pandas as pd
import pymc as pm
import arviz as az
import bambi as bmb
import pytensor.tensor as pt
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.stats import multivariate_normal
from scipy.special import softmax
import sklearn.datasets
import warnings

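# Silence only FutureWarnings so sampler and LOO (Pareto k) warnings stay visible
warnings.filterwarnings('ignore', category=FutureWarning)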
np.random.seed(42)

print(f"PyMC version: {pm.__version__}")
print(f"ArviZ version: {az.__version__}")

## 1. Basic Multinomial Regression with the Iris Dataset

We'll start with the classic Iris dataset, predicting species from sepal and petal measurements. This is an ideal example because:
- It has exactly 3 categories (setosa, versicolor, virginica)
- The features have clear predictive power
- The data is well-behaved for demonstration purposes

In [None]:
# Load and prepare data
iris = sns.load_dataset("iris")

# Center the predictors (good practice for interpretation)
for col in ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']:
    iris[col] = iris[col] - iris[col].mean()

print(f"Dataset shape: {iris.shape}")
print(f"Species: {iris['species'].unique()}")
iris.head()

In [None]:
# Visualize the data
sns.pairplot(iris, hue="species", diag_kind="kde")
plt.suptitle("Iris Dataset: Feature Relationships by Species", y=1.02)
plt.show()

### 1.1 The Bambi Approach (High-Level)

Bambi provides a formula-based interface similar to R's brms. For multinomial regression, we use the `categorical` family. Bambi automatically handles the reference category approach, treating the first category (alphabetically) as the reference.

In [None]:
# Bambi model - one line!
bambi_model = bmb.Model(
    "species ~ sepal_length + sepal_width + petal_length + petal_width", 
    iris, 
    family="categorical",
)

# Fit the model
bambi_idata = bambi_model.fit(random_seed=42)

In [None]:
# View the summary
az.summary(bambi_idata, round_to=2)

**Interpreting the coefficients:**

Bambi uses 'setosa' as the reference category. Each coefficient is the change in the log-odds of that category relative to setosa per 1 cm increase in the predictor, holding the other predictors fixed. For example:
- `petal_length[virginica]` ≈ 3.7 means that a 1 cm increase in petal length raises the log-odds of virginica vs. setosa by about 3.7
- Negative coefficients for `sepal_width` indicate that wider sepals favor setosa over the other species
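
Exponentiating a log-odds coefficient gives an odds ratio, which is often easier to communicate. Using the approximate value quoted above for `petal_length[virginica]`:

$$\frac{\text{odds}(\text{virginica} : \text{setosa} \mid \text{petal length} = x + 1)}{\text{odds}(\text{virginica} : \text{setosa} \mid \text{petal length} = x)} = e^{3.7} \approx 40.4$$

With the other measurements held fixed, each additional centimeter of petal length multiplies the odds of virginica over setosa by roughly 40 under this fit.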

### 1.2 The PyMC Approach with ZeroSumNormal

While Bambi's reference category approach works well, PyMC offers `ZeroSumNormal` - a more elegant solution that treats all categories symmetrically. Instead of fixing one category to zero, we constrain all parameters to sum to zero.

**Advantages of ZeroSumNormal:**
- Symmetric treatment of all categories
- Better interpretability (each coefficient represents deviation from the mean)
- Improved sampling efficiency in some cases
- More natural for compositional data
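
It also helps to know what `sigma` means for this prior. As far as we can tell from PyMC's implementation, `ZeroSumNormal` random draws are generated as follows: sample iid `Normal(0, sigma)` values, then subtract their mean along the constrained axis (the last axis by default). If so, each entry's marginal standard deviation is $\sigma\sqrt{(K-1)/K}$, not $\sigma$. With `sigma=5` and $K = 3$ categories, that is a marginal scale of about 4.08. The NumPy sketch below mimics that assumed construction without calling PyMC.

```python
import numpy as np

rng_prior = np.random.default_rng(42)
sigma_demo, n_feat, k_cat = 5.0, 4, 3

# Assumed construction: iid Normal(0, sigma) draws, centered along the
# category (last) axis, which is the axis ZeroSumNormal constrains by default
raw = rng_prior.normal(0.0, sigma_demo, size=(100_000, n_feat, k_cat))
beta_prior = raw - raw.mean(axis=-1, keepdims=True)

print(np.abs(beta_prior.sum(axis=-1)).max())             # ~0, up to rounding
print(beta_prior.std(axis=0).mean())                     # empirical marginal sd
print(sigma_demo * np.sqrt((k_cat - 1) / k_cat))         # theory: about 4.08
```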

In [None]:
# Prepare data for PyMC
X = iris[['sepal_length', 'sepal_width', 'petal_length', 'petal_width']].values
y = pd.Categorical(iris['species']).codes  # Convert to 0, 1, 2

n_obs = X.shape[0]
n_features = X.shape[1]
n_categories = 3

print(f"Observations: {n_obs}, Features: {n_features}, Categories: {n_categories}")

In [None]:
# PyMC model with ZeroSumNormal
with pm.Model() as pymc_zsn_model:
    # Data containers
    X_data = pm.Data('X_data', X)
    y_obs = pm.Data('y_obs', y)
    
    # Intercepts with sum-to-zero constraint
    alpha = pm.ZeroSumNormal('alpha', sigma=5, shape=n_categories)
    
    # Coefficients with sum-to-zero constraint across categories
    beta = pm.ZeroSumNormal('beta', sigma=5, shape=(n_features, n_categories))
    
    # Linear predictor
    mu = pm.math.dot(X_data, beta) + alpha
    
    # Likelihood (logit_p handles the softmax internally)
    obs = pm.Categorical('obs', logit_p=mu, observed=y_obs)
    
    # Sample
    pymc_zsn_idata = pm.sample(random_seed=42, idata_kwargs={"log_likelihood": True})

In [None]:
az.summary(pymc_zsn_idata, round_to=2)

**Note on ZeroSumNormal interpretation:** With the sum-to-zero constraint, each coefficient is the category's deviation from the average across categories. A positive `beta[2, 2]` (petal_length for virginica) means that increasing petal length raises virginica's logit more than the cross-category average, shifting probability toward virginica.
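
Only differences between category columns affect the likelihood. Subtracting a reference column therefore turns sum-to-zero coefficients into reference-category coefficients without changing any predicted probability. This puts the ZeroSumNormal results on the same footing as Bambi's setosa-referenced estimates (up to differences in priors). Below is a NumPy sketch with toy values, assuming categories sit on the last axis. `zsn_to_pivot` is a hypothetical helper name.

```python
import numpy as np

def zsn_to_pivot(coefs, ref=0):
    """Re-express sum-to-zero coefficients relative to category `ref`.

    coefs: array whose LAST axis indexes categories, e.g. posterior draws
    of beta with shape (chain, draw, n_features, n_categories).
    """
    coefs = np.asarray(coefs, dtype=float)
    return coefs - coefs[..., [ref]]  # [ref] keeps the axis for broadcasting

def probs_from_logits(logits):
    z = logits - logits.max(axis=-1, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

rng = np.random.default_rng(0)
beta_zsn = rng.normal(size=(4, 3))
beta_zsn -= beta_zsn.mean(axis=-1, keepdims=True)  # impose sum-to-zero
alpha_zsn = np.array([0.5, -0.2, -0.3])             # already sums to zero
x_row = rng.normal(size=4)                          # one (centered) observation

beta_piv = zsn_to_pivot(beta_zsn)
alpha_piv = zsn_to_pivot(alpha_zsn)

p_zsn = probs_from_logits(x_row @ beta_zsn + alpha_zsn)
p_piv = probs_from_logits(x_row @ beta_piv + alpha_piv)
print(np.allclose(p_zsn, p_piv))  # True: identical probabilities
print(beta_piv[:, 0])             # reference column is all zeros
```

Applied with `ref=0` to `pymc_zsn_idata.posterior['beta'].values`, this gives setosa-referenced coefficients draw by draw. The priors differ from Bambi's, so the numbers need not match exactly.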

### 1.3 Comparison: Pivot vs ZeroSumNormal

For completeness, let's also build the traditional pivot model to compare sampling efficiency.

In [None]:
# Traditional pivot model (category 0 as reference)
with pm.Model() as pymc_pivot_model:
    X_data = pm.Data('X_data', X)
    y_obs = pm.Data('y_obs', y)
    
    # Parameters only for non-reference categories
    alpha = pm.Normal('alpha', mu=0, sigma=5, shape=n_categories - 1)
    beta = pm.Normal('beta', mu=0, sigma=5, shape=(n_features, n_categories - 1))
    
    # Compute logits for non-pivot categories
    mu_nonpivot = pm.math.dot(X_data, beta) + alpha
    
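    # Add zeros for the pivot category; size the column from the data rather
    # than a hard-coded n_obs so the model still works after pm.set_data
    zeros = pt.zeros((X_data.shape[0], 1))
    mu = pt.concatenate([zeros, mu_nonpivot], axis=1)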
    
    obs = pm.Categorical('obs', logit_p=mu, observed=y_obs)
    
    pymc_pivot_idata = pm.sample(random_seed=42, idata_kwargs={"log_likelihood": True})

In [None]:
# Compare ESS between models
zsn_summary = az.summary(pymc_zsn_idata)
pivot_summary = az.summary(pymc_pivot_idata)

print("ZeroSumNormal model - Mean ESS bulk:", zsn_summary['ess_bulk'].mean().round(0))
print("Pivot model - Mean ESS bulk:", pivot_summary['ess_bulk'].mean().round(0))

### 1.4 Model Diagnostics

Before trusting our results, we should check sampling diagnostics.

In [None]:
# Trace plots for the ZeroSumNormal model
az.plot_trace(pymc_zsn_idata, var_names=['alpha', 'beta'])
plt.tight_layout()
plt.show()

In [None]:
# Check R-hat and divergences
summary = az.summary(pymc_zsn_idata)
print(f"Max R-hat: {summary['r_hat'].max():.3f}")
print(f"Min ESS bulk: {summary['ess_bulk'].min():.0f}")
print(f"Divergences: {pymc_zsn_idata.sample_stats.diverging.sum().values}")

In [None]:
# Posterior predictive check
with pymc_zsn_model:
    ppc = pm.sample_posterior_predictive(pymc_zsn_idata, random_seed=42)

# Compare predicted vs actual class frequencies
pred_classes = ppc.posterior_predictive['obs'].values.flatten()
actual_counts = np.bincount(y, minlength=3) / len(y)
pred_counts = np.bincount(pred_classes, minlength=3) / len(pred_classes)

fig, ax = plt.subplots(figsize=(8, 5))
x = np.arange(3)
width = 0.35
ax.bar(x - width/2, actual_counts, width, label='Observed', alpha=0.8)
ax.bar(x + width/2, pred_counts, width, label='Predicted', alpha=0.8)
ax.set_xticks(x)
ax.set_xticklabels(['setosa', 'versicolor', 'virginica'])
ax.set_ylabel('Proportion')
ax.set_title('Posterior Predictive Check: Class Frequencies')
ax.legend()
plt.show()
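
Matching class frequencies is a fairly weak check. A model that ignored the predictors and drew labels at the observed base rates would also pass it. A sharper (still in-sample) check compares each replicated label with the observed label for the same flower. The sketch assumes the posterior predictive array has shape `(chain, draw, n_obs)`, as `ppc.posterior_predictive['obs'].values` does for this model. `ppc_accuracy` is a hypothetical helper name.

```python
import numpy as np

def ppc_accuracy(pred, y):
    """Per-draw share of observations whose replicated label equals y.

    pred: integer array of shape (chain, draw, n_obs)
    y:    integer array of shape (n_obs,)
    Returns an array of shape (chain * draw,).
    """
    pred = np.asarray(pred)
    flat = pred.reshape(-1, pred.shape[-1])          # (samples, n_obs)
    return (flat == np.asarray(y)[None, :]).mean(axis=1)

# Toy check with fake draws: 2 chains, 3 draws, 4 observations
y_demo = np.array([0, 1, 2, 2])
pred_demo = np.tile(y_demo, (2, 3, 1))
pred_demo[0, 0, 3] = 1                               # one wrong label in one draw
acc = ppc_accuracy(pred_demo, y_demo)
print(acc)
```

In the notebook, `ppc_accuracy(ppc.posterior_predictive['obs'].values, y)` would give one accuracy per posterior draw. Its spread reflects both parameter uncertainty and the randomness of the categorical draws.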

---

## 2. Random Effects Multinomial Regression

Real-world data often has hierarchical structure: students nested in schools, patients in hospitals, or repeated measurements on individuals. Random effects models account for this grouping, allowing us to:

1. **Capture group-level variation** in category preferences
2. **Model overdispersion** in count data
3. **Estimate correlations** between category preferences within groups

### 2.1 Why Random Effects?

Consider a study of consumer preferences across multiple stores. Without random effects, we assume all stores have identical category preferences. With random effects, each store can have its own preference pattern, and we can learn about the correlation structure of these preferences.

### 2.2 Simulated Data Example

We'll generate data with known correlation structure to validate our model's ability to recover true parameters.

In [None]:
# Simulation settings
np.random.seed(33)
num_groups = 100
num_categories = 3

# True covariance structure for random effects
# Using pivot parameterization: only num_categories-1 random effects
true_corr = np.array([
    [1.0, 0.7],
    [0.7, 1.0]
])
true_sds = np.array([0.5, 1.0])
true_cov = np.diag(true_sds) @ true_corr @ np.diag(true_sds)

print("True covariance matrix:")
print(true_cov)

In [None]:
# Generate group-level random effects
mean_effect = np.zeros(num_categories - 1)
group_effects = multivariate_normal.rvs(mean=mean_effect, cov=true_cov, size=num_groups)

# Add pivot logits (zeros for category 0)
group_effects_full = np.hstack([np.zeros((num_groups, 1)), group_effects])

# Verify empirical covariance
emp_cov = np.cov(group_effects.T)
print("Empirical covariance (should be close to true):")
print(emp_cov)

In [None]:
# Generate observations
data = []
for idx in range(num_groups):
    n_group_obs = np.random.poisson(30)  # ~30 observations per group
    probs = softmax(group_effects_full[idx])
    
    for _ in range(n_group_obs):
        observed_cat = np.random.choice(num_categories, p=probs)
        data.append({
            'group_id': idx,
            'cat_id': observed_cat,
        })

simulated_df = pd.DataFrame(data)
print(f"Total observations: {len(simulated_df)}")
print(f"Observations per group: {simulated_df.groupby('group_id').size().describe()}")

In [None]:
# Create aggregated version for multinomial likelihood
agg_df = (
    simulated_df
    .groupby(['group_id', 'cat_id'])
    .size()
    .unstack(fill_value=0)
    .reset_index()
)
agg_df.columns = ['group_id', 'cat_0', 'cat_1', 'cat_2']
agg_df['total'] = agg_df[['cat_0', 'cat_1', 'cat_2']].sum(axis=1)

print(agg_df.head(10))

### 2.3 Bambi's Limitations

Bambi can fit random effects in categorical models, but with an important limitation: it treats the category-specific random effects as **uncorrelated**. It estimates their scale but no correlation between categories, so it cannot capture the kind of covariance we simulated above.

In [None]:
# Add group as categorical for Bambi
simulated_df['group'] = simulated_df['group_id'].astype(str)
simulated_df['cat'] = simulated_df['cat_id'].apply(lambda x: f'cat_{x}')

# Bambi model with random intercepts
bambi_re_model = bmb.Model('cat ~ (1|group)', data=simulated_df, family='categorical')
bambi_re_idata = bambi_re_model.fit(random_seed=42)

In [None]:
# Note: no correlation parameters are estimated, only the group-level scale
az.summary(bambi_re_idata, var_names=['Intercept', '1|group_sigma'])

To properly model **correlated** random effects, we need to use PyMC directly with `LKJCholeskyCov`.

### 2.4 Full PyMC Implementation with Correlated Random Effects

We'll use the `LKJCholeskyCov` prior to estimate the full covariance matrix of random effects. This prior places a uniform distribution on correlation matrices (when `eta=1`) while allowing us to specify priors on the standard deviations.
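
The model below uses the non-centered form `group_effects = z @ chol_cov.T` with standard-normal `z`. Its rationale can be checked with NumPy. If $L$ is the Cholesky factor of $\Sigma$ and $z \sim \mathcal{N}(0, I)$, then $Lz \sim \mathcal{N}(0, LL^\top) = \mathcal{N}(0, \Sigma)$. The sketch reuses the simulation's standard deviations and correlation, hard-coded so that it runs on its own.

```python
import numpy as np

rng_nc = np.random.default_rng(0)
sds_demo = np.array([0.5, 1.0])
corr_demo = np.array([[1.0, 0.7],
                      [0.7, 1.0]])
cov_demo = np.diag(sds_demo) @ corr_demo @ np.diag(sds_demo)

L_demo = np.linalg.cholesky(cov_demo)            # lower triangular, cov = L L^T
z_demo = rng_nc.standard_normal((200_000, 2))    # one row per simulated group
effects_demo = z_demo @ L_demo.T                 # same product as pt.dot(z, chol_cov.T)

print(np.round(np.cov(effects_demo.T), 3))       # empirical covariance
print(np.round(cov_demo, 3))                     # target covariance
```

In the model, `z` is sampled and the scaling is deterministic; `group_effects` is not sampled directly from a multivariate normal. This usually avoids the funnel-shaped posterior geometry that appears when the group-level scales are small.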

In [None]:
# Aggregated multinomial model with correlated random effects
with pm.Model() as re_model:
    # Data
    group_idx = pm.Data("group_idx", agg_df["group_id"].values)
    count_obs = pm.Data("count_obs", agg_df[["cat_0", "cat_1", "cat_2"]].values)
    total_obs = pm.Data("total_obs", agg_df["total"].values)
    
    # Prior on standard deviations
    sd_dist = pm.HalfStudentT.dist(nu=3, sigma=2)
    
    # LKJ prior on correlation + standard deviations
    chol_cov, corr, stds = pm.LKJCholeskyCov(
        "chol_cov",
        n=num_categories - 1,  # Pivot parameterization
        eta=1,  # Uniform on correlations
        sd_dist=sd_dist,
        compute_corr=True
    )
    
    # Non-centered parameterization for random effects
    z = pm.Normal("z", 0, 1, shape=(num_groups, num_categories - 1))
    group_effects_m = pm.Deterministic(
        "group_effects",
        pt.dot(z, chol_cov.T)
    )
    
    # Overall mean (could be zero, but we estimate it)
    mean_eff = pm.Normal("mean_effect", 0, 1, shape=(num_categories - 1))
    
    # Logits with pivot at zero
    logits = pt.concatenate(
        [pt.zeros((num_groups, 1)), group_effects_m + mean_eff],
        axis=1
    )
    
    # Multinomial likelihood: index the group probabilities by each row's group
    p = pm.math.softmax(logits, axis=1)
    pm.Multinomial("obs", p=p[group_idx], n=total_obs, observed=count_obs)
    
    # Higher target_accept shrinks the step size, which reduces divergences
    re_trace = pm.sample(
        random_seed=42, 
        target_accept=0.95,
        idata_kwargs={"log_likelihood": True}
    )

In [None]:
# Check diagnostics
az.summary(re_trace, var_names=["mean_effect", "chol_cov_corr", "chol_cov_stds"], round_to=3)

In [None]:
# Recover the covariance matrix
dat = az.extract(re_trace, var_names=["chol_cov_corr", "chol_cov_stds"])
corr_samples = dat["chol_cov_corr"].values
stds_samples = dat["chol_cov_stds"].values[:, None, :]

# Compute covariance for each sample
cov_samples = corr_samples * stds_samples * stds_samples.transpose(1, 0, 2)
estimated_cov = np.mean(cov_samples, axis=2)

print("True covariance matrix:")
print(true_cov)
print("\nEstimated covariance matrix (posterior mean):")
print(estimated_cov)

In [None]:
# Compare true vs estimated random effects
true_effects = group_effects  # Shape: (num_groups, 2)
posterior_means = re_trace.posterior["group_effects"].mean(dim=["chain", "draw"]).values

fig, axes = plt.subplots(1, 2, figsize=(10, 4))

for i in range(2):
    axes[i].scatter(true_effects[:, i], posterior_means[:, i], alpha=0.6)
    lims = [min(true_effects[:, i].min(), posterior_means[:, i].min()) - 0.2,
            max(true_effects[:, i].max(), posterior_means[:, i].max()) + 0.2]
    axes[i].plot(lims, lims, 'r--', alpha=0.5, label='y=x')
    axes[i].set_xlabel("True Logits")
    axes[i].set_ylabel("Posterior Means")
    axes[i].set_title(f"Category {i+1} Random Effects")
    axes[i].legend()

plt.tight_layout()
plt.show()

### 2.5 Posterior Predictive Checks

Let's verify that our model captures the data-generating process correctly.

In [None]:
# Generate posterior predictions
with re_model:
    re_ppc = pm.sample_posterior_predictive(re_trace, random_seed=42)

In [None]:
# Compare predicted vs observed for selected groups
selected_groups = [0, 1, 3, 4, 5, 6]
num_selected = len(selected_groups)

all_pp_counts = az.extract(
    re_ppc, group="posterior_predictive", var_names=["obs"]
).values  # (num_groups, num_categories, num_samples)

pp_counts = all_pp_counts[selected_groups, :, :]
obs_counts = agg_df.loc[selected_groups, ["cat_0", "cat_1", "cat_2"]].values

fig, axes = plt.subplots(2, 3, figsize=(12, 8))
axes = axes.flatten()
categories = ["cat_0", "cat_1", "cat_2"]

for idx, group in enumerate(selected_groups):
    ax = axes[idx]
    data = [pp_counts[idx, cat_idx, :] for cat_idx in range(len(categories))]
    sns.boxplot(data=data, ax=ax, color='lightblue')
    ax.scatter(range(len(categories)), obs_counts[idx], color='red', s=100, zorder=5, label='Observed')
    ax.set_xticks(range(len(categories)))
    ax.set_xticklabels(categories)
    ax.set_title(f'Group {group}')
    ax.set_ylabel('Counts')
    if idx == 0:
        ax.legend()

plt.suptitle('Posterior Predictive Checks: Predicted (boxplots) vs Observed (red dots)', y=1.02)
plt.tight_layout()
plt.show()

---

## 3. Cross-Validation for Multinomial Models

Model comparison is crucial for choosing between alternative specifications. However, the usual PSIS approximation to leave-one-out cross-validation (LOO-CV) is often unreliable for hierarchical multinomial models.

### 3.1 Why LOO-PSIS Often Fails

`az.loo` approximates leave-one-out cross-validation from a single fit using Pareto Smoothed Importance Sampling (PSIS). The Pareto k diagnostic indicates how reliable this approximation is for each observation:

- k < 0.5: Good
- 0.5 ≤ k < 0.7: Acceptable
- **k ≥ 0.7: Unreliable** - importance sampling fails

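Hierarchical multinomial models like ours often produce many observations with k > 0.7. There are two main reasons:

1. **Each observation carries a whole group.** In the aggregated model, each "observation" is a group's full count vector. Leaving it out removes all the data that inform that group's random effect. The leave-one-out posterior then differs sharply from the full posterior, and importance sampling from the full posterior cannot bridge the gap.
2. **The model is flexible relative to the data.** With one latent random effect per group, individual observations (especially from small groups) are highly influential.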

In [None]:
# Attempt LOO-CV on our random effects model
try:
    loo_result = az.loo(re_trace)
    print(loo_result)
    print(f"\nPareto k > 0.7: {np.sum(loo_result.pareto_k > 0.7)} observations")
except Exception as e:
    print(f"LOO failed: {e}")

### 3.2 K-Fold Cross-Validation as Alternative

When LOO fails, K-fold CV provides a robust alternative. Instead of approximating leave-one-out, we:
1. Split data into K folds
2. Refit the model K times, each time holding out one fold
3. Evaluate predictive accuracy on held-out folds

This requires K full refits. In exchange, it does not rely on the importance-sampling approximation that high Pareto k values flag as unreliable.

### 3.3 Choosing Folds for Grouped Data (Critical!)

**This is a key point that is often overlooked.**

For hierarchical models, you must split by **groups**, not individual observations. Why?

- If observations from the same group appear in both the training and test sets, that group's random effect is estimated from the training data and "leaks" into the test predictions
- This overstates how well the model predicts for *new* groups
- The model appears to generalize well, but it is partly reusing group-specific patterns it has already seen

**Correct approach:** Create folds where each group is entirely in train OR test, never both.
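
Building the folds needs only the group labels. Shuffle the unique groups, split them into K chunks, and send every row of a chunk's groups to the test set. The NumPy sketch below follows those steps. The helper is our own illustration; scikit-learn's `GroupKFold` in `sklearn.model_selection` is a ready-made alternative.

```python
import numpy as np

def group_kfold(group_ids, n_folds=5, seed=42):
    """Yield (train_rows, test_rows) index arrays, holding out whole groups."""
    group_ids = np.asarray(group_ids)
    rng = np.random.default_rng(seed)
    groups = rng.permutation(np.unique(group_ids))   # shuffle groups, not rows
    for test_groups in np.array_split(groups, n_folds):
        test_mask = np.isin(group_ids, test_groups)
        yield np.flatnonzero(~test_mask), np.flatnonzero(test_mask)

# Toy example: 6 groups with uneven numbers of rows
ids_demo = np.array([0, 0, 1, 1, 1, 2, 3, 3, 4, 5, 5, 5])
for train_rows, test_rows in group_kfold(ids_demo, n_folds=3):
    overlap = set(ids_demo[train_rows]) & set(ids_demo[test_rows])
    print("test groups:", sorted(set(ids_demo[test_rows].tolist())), "| overlap:", overlap)
```

Every row lands in exactly one test fold, and no group ever appears on both sides of a split.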

In [None]:
from sklearn.model_selection import KFold

def kfold_cv_multinomial(model_func, data, n_folds=5, random_seed=42):
    """
    Perform K-fold CV with proper group-level folds.
    
    Parameters:
    -----------
    model_func : callable
        Function that takes training data and returns (model, trace)
    data : DataFrame
        Aggregated data with group_id column
    n_folds : int
        Number of folds
    random_seed : int
        Random seed for reproducibility
    
    Returns:
    --------
    dict with fold-wise and overall metrics
    """
    # Get unique groups
    groups = data['group_id'].unique()
    
    # Create group-level folds
    kf = KFold(n_splits=n_folds, shuffle=True, random_state=random_seed)
    
    fold_results = []
    
    for fold_idx, (train_groups_idx, test_groups_idx) in enumerate(kf.split(groups)):
        train_groups = groups[train_groups_idx]
        test_groups = groups[test_groups_idx]
        
        # Split data by groups
        train_data = data[data['group_id'].isin(train_groups)].reset_index(drop=True)
        test_data = data[data['group_id'].isin(test_groups)].reset_index(drop=True)
        
        print(f"\nFold {fold_idx + 1}: {len(train_groups)} train groups, {len(test_groups)} test groups")
        
        # Reindex groups for training
        group_map = {g: i for i, g in enumerate(train_groups)}
        train_data = train_data.copy()
        train_data['group_idx'] = train_data['group_id'].map(group_map)
        
        # Fit model on training data
        model, trace = model_func(train_data, len(train_groups))
        
        # For test data, we need to predict for "new groups": their random
        # effects are drawn from the population distribution, using posterior
        # samples of the hyperparameters
        mean_eff_samples = trace.posterior['mean_effect'].values
        # Packed lower-triangular Cholesky entries, shape (chain, draw, n(n+1)/2)
        chol_samples = trace.posterior['chol_cov'].values
            
        # Compute log-likelihood for test groups
        # For new groups, we marginalize over the random effects distribution
        test_ll = compute_test_loglik(
            test_data, mean_eff_samples, chol_samples, n_samples=100
        )
        
        fold_results.append({
            'fold': fold_idx,
            'n_train_groups': len(train_groups),
            'n_test_groups': len(test_groups),
            'test_elpd': test_ll
        })
    
    return fold_results


def compute_test_loglik(test_data, mean_eff_samples, chol_samples, n_samples=100):
    """
    Compute expected log-likelihood for test groups by marginalizing
    over the random effects distribution.
    """
    from scipy.stats import multinomial
    
    n_chains, n_draws, n_cat_minus_1 = mean_eff_samples.shape
    total_samples = n_chains * n_draws
    
    # Flatten chain and draw dimensions
    mean_flat = mean_eff_samples.reshape(total_samples, n_cat_minus_1)
    chol_flat = chol_samples.reshape(total_samples, -1)
    
    # Sample indices
    sample_idx = np.random.choice(total_samples, size=n_samples, replace=False)
    
    log_liks = []
    
    for group_id in test_data['group_id'].unique():
        group_data = test_data[test_data['group_id'] == group_id]
        counts = group_data[['cat_0', 'cat_1', 'cat_2']].values[0]
        total = counts.sum()
        
        group_ll = []
        
        for idx in sample_idx:
            mean = mean_flat[idx]
            
            # Reconstruct Cholesky
            L = np.zeros((n_cat_minus_1, n_cat_minus_1))
            L[np.tril_indices(n_cat_minus_1)] = chol_flat[idx]
            
            # Sample random effect for new group
            z = np.random.randn(n_cat_minus_1)
            re = L @ z
            
            # Compute probabilities
            logits = np.concatenate([[0], mean + re])
            probs = softmax(logits)
            
            # Log-likelihood
            ll = multinomial.logpmf(counts, n=total, p=probs)
            group_ll.append(ll)
        
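        # Log of the mean likelihood (not the mean of the log-likelihoods),
        # computed stably as logsumexp(group_ll) - log(n) to avoid underflow
        log_liks.append(np.logaddexp.reduce(group_ll) - np.log(len(group_ll)))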
    
    return np.sum(log_liks)

In [None]:
def build_re_model(train_data, n_groups):
    """
    Build and fit the random effects multinomial model.
    """
    n_cat = 3
    
    with pm.Model() as model:
        group_idx = pm.Data("group_idx", train_data["group_idx"].values)
        count_obs = pm.Data("count_obs", train_data[["cat_0", "cat_1", "cat_2"]].values)
        total_obs = pm.Data("total_obs", train_data["total"].values)
        
        sd_dist = pm.HalfStudentT.dist(nu=3, sigma=2)
        chol_cov, _, _ = pm.LKJCholeskyCov(
            "chol_cov", n=n_cat - 1, eta=1, sd_dist=sd_dist
        )
        
        z = pm.Normal("z", 0, 1, shape=(n_groups, n_cat - 1))
        group_effects_m = pm.Deterministic("group_effects", pt.dot(z, chol_cov.T))
        
        mean_effect = pm.Normal("mean_effect", 0, 1, shape=(n_cat - 1))
        
        logits = pt.concatenate(
            [pt.zeros((n_groups, 1)), group_effects_m + mean_effect],
            axis=1
        )
        
        p = pm.math.softmax(logits, axis=1)
        pm.Multinomial("obs", p=p[group_idx], n=total_obs, observed=count_obs)
        
        trace = pm.sample(
            draws=500, tune=500,  # Reduced for speed in CV
            random_seed=42, 
            target_accept=0.95,
            progressbar=False
        )
    
    return model, trace

Because K-fold CV refits the model K times, we'll demonstrate with 3 folds; 5 to 10 folds are common in practice.

In [None]:
# Run 3-fold CV (reduced for demonstration)
print("Running K-fold CV with group-level folds...")
cv_results = kfold_cv_multinomial(build_re_model, agg_df, n_folds=3, random_seed=42)

print("\n" + "="*50)
print("K-Fold CV Results:")
total_elpd = sum(r['test_elpd'] for r in cv_results)
print(f"Total ELPD (sum across folds): {total_elpd:.2f}")
for r in cv_results:
    print(f"  Fold {r['fold']+1}: ELPD = {r['test_elpd']:.2f}")

---

## 4. Model Comparison: Logistic-Normal vs Dirichlet-Multinomial

Two common approaches for modeling overdispersed multinomial data:

1. **Logistic-Normal Multinomial** (what we've been using): Random effects on the logit scale
   - Can model any correlation structure (positive or negative)
   - More flexible but more complex

2. **Dirichlet-Multinomial**: Random effects on the probability simplex
   - Only supports negative correlations between category probabilities
   - Simpler and often samples more efficiently
   - Appropriate when categories compete (increase in one decreases others)

Let's compare both models on our simulated data, whose logits were generated with a positive correlation (0.7) between the two non-reference categories.

In [None]:
# Dirichlet-Multinomial model
with pm.Model() as dm_model:
    group_idx = pm.Data("group_idx", agg_df["group_id"].values)
    count_obs = pm.Data("count_obs", agg_df[["cat_0", "cat_1", "cat_2"]].values)
    total_obs = pm.Data("total_obs", agg_df["total"].values)
    
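    # Concentration parameters of the population-level Dirichlet
    concentration = pm.HalfNormal('concentration', sigma=5, shape=3)
    
    # Group-specific category probabilities drawn from that Dirichlet
    group_p = pm.Dirichlet(
        'group_p', 
        a=concentration, 
        shape=(num_groups, 3)
    )
    
    # Multinomial counts given each group's probabilities. Integrating group_p
    # out analytically gives pm.DirichletMultinomial(a=concentration, ...)
    pm.Multinomial(
        'obs', 
        p=group_p[group_idx], 
        n=total_obs, 
        observed=count_obs
    )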
    
    dm_trace = pm.sample(
        random_seed=42, 
        target_accept=0.95,
        idata_kwargs={"log_likelihood": True}
    )

In [None]:
# Generate posterior predictive samples
with dm_model:
    dm_ppc = pm.sample_posterior_predictive(dm_trace, random_seed=42)

In [None]:
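# Compare correlation structures of category *proportions* across groups.
# Raw counts would mislead: groups with larger totals have more of every
# category, which induces positive correlation between all count columns.
obs_counts = agg_df[["cat_0", "cat_1", "cat_2"]].values
obs_props = obs_counts / obs_counts.sum(axis=1, keepdims=True)
obs_corr = np.corrcoef(obs_props.T)

def replicate_corr(pred_counts, n_reps=200, seed=42):
    """Average across-group correlation of proportions over replicated datasets.

    pred_counts has shape (n_groups, n_categories, n_samples). Each sample is
    one replicated dataset, so its correlation is comparable to obs_corr
    (posterior-mean counts are not, because averaging removes multinomial noise).
    """
    rng = np.random.default_rng(seed)
    draws = rng.choice(pred_counts.shape[2], size=n_reps, replace=False)
    corrs = []
    for s in draws:
        props = pred_counts[:, :, s] / pred_counts[:, :, s].sum(axis=1, keepdims=True)
        corrs.append(np.corrcoef(props.T))
    return np.mean(corrs, axis=0)

# Dirichlet-Multinomial replicated data
dm_pred = az.extract(dm_ppc, group="posterior_predictive", var_names=["obs"]).values
dm_corr = replicate_corr(dm_pred)

# Logistic-Normal replicated data
ln_pred = az.extract(re_ppc, group="posterior_predictive", var_names=["obs"]).values
ln_corr = replicate_corr(ln_pred)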

# Plot comparison
fig, axes = plt.subplots(1, 3, figsize=(15, 4))
categories = ["cat_0", "cat_1", "cat_2"]

for ax, corr_mat, title in zip(
    axes, 
    [obs_corr, dm_corr, ln_corr],
    ['Observed', 'Dirichlet-Multinomial', 'Logistic-Normal']
):
    sns.heatmap(
        corr_mat, annot=True, fmt='.2f',
        xticklabels=categories, yticklabels=categories,
        ax=ax, cmap='RdBu_r', vmin=-1, vmax=1, center=0
    )
    ax.set_title(title)

plt.suptitle('Correlation Structure Comparison', y=1.02)
plt.tight_layout()
plt.show()

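**Key observation:** Under a Dirichlet, every pair of category probabilities is negatively correlated. The Dirichlet-Multinomial therefore cannot represent categories whose probabilities tend to rise and fall together. The sum-to-one constraint forces *some* negative dependence in any model. The Logistic-Normal, however, places a full covariance matrix on the logits and can still let particular pairs move together. Our simulated data were generated with positively correlated logits, so the Logistic-Normal model matches the data-generating process here and is the more appropriate choice.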
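
A quick simulation makes the contrast concrete. For $p \sim \mathrm{Dirichlet}(\alpha)$ with $\alpha_0 = \sum_k \alpha_k$, the off-diagonal covariances are $\operatorname{Cov}(p_i, p_j) = -\alpha_i\alpha_j / \bigl(\alpha_0^2(\alpha_0+1)\bigr)$, which are always negative. Logit-scale correlation does not automatically carry over to the probability scale. The logistic-normal half of the sketch therefore uses a deliberately strong logit correlation. All parameters are arbitrary illustrative values, not fitted ones.

```python
import numpy as np

rng_cmp = np.random.default_rng(1)
n_draws = 50_000

# Dirichlet: every pair of category probabilities is negatively correlated
alpha_dir = np.array([2.0, 3.0, 5.0])
p_dir = rng_cmp.dirichlet(alpha_dir, size=n_draws)
a0 = alpha_dir.sum()
theory_cov = -np.outer(alpha_dir, alpha_dir) / (a0**2 * (a0 + 1))  # use off-diagonal only
print(np.round(np.corrcoef(p_dir.T), 2))

# Logistic-normal: logits (0, u1, u2) with strongly correlated u1, u2
cov_u = np.array([[4.0, 3.9],
                  [3.9, 4.0]])
u = rng_cmp.multivariate_normal(np.zeros(2), cov_u, size=n_draws)
logits_ln = np.column_stack([np.zeros(n_draws), u])
p_ln = np.exp(logits_ln - logits_ln.max(axis=1, keepdims=True))
p_ln /= p_ln.sum(axis=1, keepdims=True)
print(np.round(np.corrcoef(p_ln.T), 2))  # categories 1 and 2 move together
```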

In [None]:
# Attempt LOO comparison (expect warnings)
print("Attempting LOO comparison (expect Pareto k warnings)...\n")

loo_ln = az.loo(re_trace)
loo_dm = az.loo(dm_trace)

print(f"\nLogistic-Normal - Bad Pareto k: {np.sum(loo_ln.pareto_k > 0.7)}")
print(f"Dirichlet-Multinomial - Bad Pareto k: {np.sum(loo_dm.pareto_k > 0.7)}")

In [None]:
# Despite warnings, let's see the comparison
comparison = az.compare({
    'Logistic-Normal': re_trace, 
    'Dirichlet-Multinomial': dm_trace
})
print(comparison[['rank', 'elpd_loo', 'p_loo', 'weight', 'warning']])

**Note:** Given the Pareto k warnings, these LOO estimates should be interpreted with caution. For reliable model comparison, use K-fold CV as demonstrated in Section 3.

---

## 5. Conclusion

### Summary

We've covered the progression from basic to advanced multinomial regression in PyMC:

1. **Basic multinomial regression** can be done elegantly with Bambi for simple cases

2. **ZeroSumNormal** provides a modern approach to the identifiability problem, treating categories symmetrically

3. **Random effects** require PyMC for full flexibility, especially for correlated random effects using `LKJCholeskyCov`

4. **Cross-validation** for hierarchical multinomial models requires special care:
   - LOO-PSIS often fails (check Pareto k diagnostics!)
   - K-fold CV is more robust
   - **Critical:** Use group-level folds, not observation-level

5. **Model selection** between Logistic-Normal and Dirichlet-Multinomial depends on the correlation structure in your data

### Guidelines: Bambi vs PyMC

| Use Bambi when... | Use PyMC when... |
|-------------------|------------------|
| Simple fixed effects | Correlated random effects |
| Quick exploration | Custom likelihood functions |
| Standard priors suffice | Need full control over priors |
| Interpretable formula syntax | Complex hierarchical structures |

### Key Takeaways

1. Always check Pareto k diagnostics when using LOO
2. For hierarchical models, split by groups in cross-validation
3. Use the non-centered parameterization for random effects
4. Consider the true correlation structure when choosing between Logistic-Normal and Dirichlet-Multinomial
5. ZeroSumNormal is often preferable to the pivot approach for interpretability

### Further Reading

- [PyMC Documentation](https://www.pymc.io/)
- [Bambi Documentation](https://bambinos.github.io/bambi/)
- [ArviZ Model Comparison](https://python.arviz.org/en/latest/user_guide/model_comparison.html)
- Gelman et al., "Bayesian Data Analysis" (3rd ed.: Chapter 5 on hierarchical models, Chapter 16 on generalized linear models)
- McElreath, "Statistical Rethinking" (multilevel models: Chapter 12 of the 1st edition, Chapter 13 of the 2nd edition)