# 04 · Model Evidence & Bayesian Model Selection (FFX & RFX)

This notebook demonstrates how to compute **marginal log-evidence** for discrete generative models and use it for **Bayesian model selection**:

- **Fixed-effects (FFX):** assumes a single model generated all sequences; compares summed log-evidence.
- **Random-effects (RFX):** assumes sequences may come from *different* models; infers model frequencies with a Dirichlet posterior and computes **exceedance probabilities**.

We construct two similar models that differ only in their observation noise (likelihood $A$) and generate synthetic data from the lower-noise model. We then check that Bayesian model selection (BMS) prefers the true, lower-noise model.


In [None]:
# CI-friendly params
import os
CI = os.getenv("CI", "").lower() in ("1", "true", "yes")
N_SEQ      = 4   if CI else 16   # number of sequences (subjects/episodes)
T_PER_SEQ  = 30  if CI else 60   # length per sequence
NSAMPLES_RFX = 2000 if CI else 20000  # Dirichlet sampling for exceedance
print({"CI": CI, "N_SEQ": N_SEQ, "T_PER_SEQ": T_PER_SEQ, "NSAMPLES_RFX": NSAMPLES_RFX})

## Imports & setup

In [None]:
import numpy as np, pandas as pd, matplotlib.pyplot as plt
from persystems.gm import GenerativeModel
from persystems.bms import log_evidence_discrete, compare_models_ffx, compare_models_rfx
np.set_printoptions(precision=4, suppress=True)
plt.rcParams['figure.dpi'] = 120

## Build two close models and generate data from model 0
We create two ring-world models that differ in the likelihood noise parameter $A_\varepsilon$; lower noise (0.10) should better explain data generated under it than higher noise (0.30). Actions are chosen randomly to avoid policy confounds; you could also plug in an active-inference policy here.

In [None]:
N = 5
gm0 = GenerativeModel.make_ring_world(N=N, A_eps=0.10, target_idx=3)  # true generator
gm1 = GenerativeModel.make_ring_world(N=N, A_eps=0.30, target_idx=3)
models = [
    {"A": gm0.A, "B": gm0.B, "name": "Aeps=0.10"},
    {"A": gm1.A, "B": gm1.B, "name": "Aeps=0.30"},
]

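def gen_sequence(gm, T, rng):
    """Simulate T steps from `gm` under uniformly random actions; returns (acts, obs)."""
    n_obs, n_states = gm.A.shape                           # A[o, s] = p(o | s)
    true_s = int(rng.integers(0, n_states))                # random initial state
    acts, obs = [], []
    for t in range(T):
        a_idx = int(rng.integers(0, len(gm.B)))            # random actions for data gen
        true_s = (true_s + gm.actions[a_idx]) % n_states   # move around the ring
        o = int(rng.choice(n_obs, p=gm.A[:, true_s]))      # noisy observation of the new state
        acts.append(a_idx); obs.append(o)
    return acts, obs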

rng = np.random.default_rng(0)
sequences = [gen_sequence(gm0, T_PER_SEQ, rng) for _ in range(N_SEQ)]
len(sequences), len(sequences[0][0]), len(sequences[0][1])

## Compute marginal log-evidence per model (FFX view)
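We compute the exact marginal log-evidence by forward filtering: along each sequence we accumulate the log predictive likelihoods $\ln Q(o_t)$, where $Q(o_t)$ is the probability of the observation at step $t$ given all earlier observations and actions. Under FFX, we sum log-evidence over all sequences and obtain the posterior over models from a softmax of those sums plus (optional) log-priors.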
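
To make the forward-filtering recursion concrete, here is a minimal NumPy sketch. It is not the `persystems` implementation: the function name `forward_log_evidence`, the uniform initial belief, and the array conventions (`A[o, s]`, `B[a][s_next, s]`, transition applied before each observation) are assumptions chosen to match the data generator in this notebook.

```python
import numpy as np

def forward_log_evidence(A, B, acts, obs, prior=None):
    """Sketch: ln p(o_1..o_T | a_1..a_T) for a discrete state-space model.

    Assumed conventions (illustrative, not taken from persystems):
      A[o, s]         = p(o | s), each column sums to 1
      B[a][s_next, s] = p(s_next | s, a), each column sums to 1
      each step applies the transition for acts[t], then emits obs[t]
    """
    A = np.asarray(A, dtype=float)
    n_states = A.shape[1]
    if prior is None:
        belief = np.full(n_states, 1.0 / n_states)  # assumed uniform initial belief
    else:
        belief = np.asarray(prior, dtype=float)
    log_ev = 0.0
    for a, o in zip(acts, obs):
        predicted = np.asarray(B[a], dtype=float) @ belief  # p(s_t | past observations)
        joint = A[o] * predicted                            # p(o_t, s_t | past observations)
        q_o = joint.sum()                                   # predictive likelihood Q(o_t)
        log_ev += np.log(q_o)                               # accumulate ln Q(o_t)
        belief = joint / q_o                                # filtered posterior over s_t
    return log_ev

# Toy two-state check: action 0 stays put, action 1 swaps the states.
toy_A = np.array([[0.9, 0.1],
                  [0.1, 0.9]])
toy_B = [np.eye(2), np.array([[0.0, 1.0],
                              [1.0, 0.0]])]
toy_log_ev = forward_log_evidence(toy_A, toy_B, acts=[0, 1, 0], obs=[0, 1, 1])
print(toy_log_ev, np.exp(toy_log_ev))
```

Because the transitions in the toy model are deterministic, the evidence can be checked by hand: the two possible state paths contribute $0.5 \cdot 0.9^3$ and $0.5 \cdot 0.1^3$, giving $p(o) = 0.365$.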

In [None]:
# Per-sequence, per-model log evidence
LE = np.zeros((N_SEQ, len(models)))
for s_idx, (acts, obs) in enumerate(sequences):
    for m_idx, m in enumerate(models):
        LE[s_idx, m_idx] = log_evidence_discrete(m["A"], m["B"], acts, obs)

LE_df = pd.DataFrame(LE, columns=[m["name"] for m in models])
LE_df.head()

### FFX: summed evidence and posterior over models
We expect the lower-noise model (Aeps=0.10) to dominate when it generated the data.
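
The FFX computation itself is small enough to write out. The sketch below assumes a per-sequence log-evidence array of shape `(n_sequences, n_models)` like `LE` above; the helper name `ffx_posterior` and the toy numbers are illustrative and not part of `persystems.bms`.

```python
import numpy as np

def ffx_posterior(log_evidence, log_prior=None):
    """Sketch: summed log-evidence per model and the FFX posterior over models."""
    le = np.asarray(log_evidence, dtype=float)   # shape (n_sequences, n_models)
    summed = le.sum(axis=0)                      # pool evidence across sequences
    scores = summed.copy()
    if log_prior is not None:
        scores = scores + np.asarray(log_prior, dtype=float)
    weights = np.exp(scores - scores.max())      # subtract the max for numerical stability
    return summed, weights / weights.sum()

# Hypothetical log-evidence values for two sequences and two models.
toy_le = np.array([[-40.0, -42.0],
                   [-38.0, -41.0]])
toy_summed, toy_post = ffx_posterior(toy_le)
print(toy_summed, toy_post)
```

Subtracting the maximum before exponentiating matters in practice: summed log-evidence grows with the number and length of sequences, so a naive `np.exp(summed)` underflows to zero quickly.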

In [None]:
ffx = compare_models_ffx(models, sequences)
totals = ffx["log_evidence"]
post   = ffx["post"]
ffx_table = pd.DataFrame({"model": [m["name"] for m in models],
                          "sum_log_evidence": totals,
                          "posterior": post})
ffx_table.sort_values("sum_log_evidence", ascending=False)

In [None]:
plt.figure(figsize=(6,4))
x = np.arange(len(models))
plt.bar(x-0.2, totals - totals.max(), width=0.4, label='Δ log evidence (vs max)')
plt.bar(x+0.2, post, width=0.4, label='posterior (FFX)')
plt.xticks(x, [m['name'] for m in models])
plt.title('Fixed-effects BMS: evidence & posterior')
plt.legend(fontsize=8)
plt.tight_layout(); plt.show()

## RFX: model frequencies and exceedance probabilities
Under random-effects BMS, each sequence may come from a different model; we infer a Dirichlet posterior over model frequencies and estimate **exceedance probabilities** (the posterior probability that a given model is more frequent than every other model) via Monte Carlo sampling from that Dirichlet posterior.

We again expect the true model (Aeps=0.10) to have the largest exceedance probability if it explains most sequences better than the alternative.
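
The Monte Carlo step can be sketched in a few lines of NumPy. This assumes the Dirichlet parameters `alpha` are already available; the function name `exceedance_probabilities` and the example `alpha` are hypothetical, and `compare_models_rfx` may compute this differently.

```python
import numpy as np

def exceedance_probabilities(alpha, nsamples=20000, seed=0):
    """Sketch: estimate P(r_k > r_j for all j != k) under r ~ Dirichlet(alpha)."""
    alpha = np.asarray(alpha, dtype=float)
    rng = np.random.default_rng(seed)
    r = rng.dirichlet(alpha, size=nsamples)       # sampled frequency vectors, shape (nsamples, M)
    winners = r.argmax(axis=1)                    # most frequent model in each sample
    return np.bincount(winners, minlength=alpha.size) / nsamples

# Hypothetical posterior: flat prior (alpha0 = 1) plus 12 and 4 sequences' worth of evidence.
toy_alpha = np.array([13.0, 5.0])
toy_mean_freq = toy_alpha / toy_alpha.sum()       # posterior mean frequencies
toy_xp = exceedance_probabilities(toy_alpha)
print(toy_mean_freq, toy_xp)
```

Note that the exceedance probability is usually more extreme than the mean frequency: it answers "which model is most common?" rather than "how common is each model?".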

In [None]:
rfx = compare_models_rfx(models, sequences, alpha0=1.0, nsamples=NSAMPLES_RFX)
alpha = rfx["alpha"]
ex    = rfx["exceedance"]
rmean = rfx["r_mean"]
rfx_table = pd.DataFrame({"model": [m["name"] for m in models],
                          "alpha": alpha,
                          "mean_freq": rmean,
                          "exceedance_prob": ex})
rfx_table.sort_values("exceedance_prob", ascending=False)

In [None]:
fig, ax = plt.subplots(1, 2, figsize=(9,4))
x = np.arange(len(models))
ax[0].bar(x, rmean)
ax[0].set_xticks(x); ax[0].set_xticklabels([m['name'] for m in models])
ax[0].set_title('RFX: posterior mean frequencies')
ax[0].set_ylim(0,1)
ax[1].bar(x, ex)
ax[1].set_xticks(x); ax[1].set_xticklabels([m['name'] for m in models])
ax[1].set_title('RFX: exceedance probabilities')
ax[1].set_ylim(0,1)
plt.tight_layout(); plt.show()

## Sanity: per-sequence responsibilities (which model explains which sequence?)
RFX internally computes soft responsibilities per sequence (posterior over models for that sequence). This is useful to diagnose **mixtures** (e.g., when some sequences truly came from different models).
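
One simple way to read a responsibility matrix is to look at hard assignments and flag sequences where no model clearly wins. The sketch below uses a made-up matrix `toy_resp` and an arbitrary 0.75 threshold; both are illustrative assumptions, not outputs or defaults of `persystems`.

```python
import numpy as np

# Hypothetical responsibilities for four sequences and two models (rows sum to 1).
toy_resp = np.array([[0.98, 0.02],
                     [0.91, 0.09],
                     [0.55, 0.45],
                     [0.10, 0.90]])

hard = toy_resp.argmax(axis=1)                             # best model per sequence
counts = np.bincount(hard, minlength=toy_resp.shape[1])    # sequences won by each model
ambiguous = np.flatnonzero(toy_resp.max(axis=1) < 0.75)    # arbitrary "unclear" threshold

print("assignments:", hard)
print("wins per model:", counts)
print("ambiguous sequences:", ambiguous)
```

If a non-trivial share of sequences is confidently assigned to the alternative model, that points to a genuine mixture rather than noise, which is the case where RFX and FFX conclusions can diverge.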

In [None]:
resp = rfx["responsibilities"]  # shape (S, M)
resp_df = pd.DataFrame(resp, columns=[m["name"] for m in models])
resp_df.index.name = 'sequence'
resp_df.head()

## Takeaways
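- **Log-evidence** is the principled score for comparing generative models; variational free energy upper-bounds the negative log-evidence (equivalently, the ELBO lower-bounds log-evidence), which is how the free energy principle (FEP) connects to model comparison.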
- **FFX** pools evidence across sequences; **RFX** infers model *frequencies* and is robust to heterogeneity.
- On synthetic data generated by the low-noise model, both FFX and RFX (typically) favor **Aeps=0.10** with higher evidence/posterior/exceedance.
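
For reference, the link between free energy and evidence in the first takeaway follows from a standard identity. The notation here is an assumption for illustration: $q(s)$ is an approximate posterior over hidden states $s$, $o$ the observations, and $m$ the model.

$$
F[q] \;=\; \mathbb{E}_{q(s)}\big[\ln q(s) - \ln p(o, s \mid m)\big]
\;=\; \mathrm{KL}\big[q(s)\,\|\,p(s \mid o, m)\big] \;-\; \ln p(o \mid m)
\;\ge\; -\ln p(o \mid m)
$$

The bound is tight when $q(s)$ equals the exact posterior. For small discrete models like the ones in this notebook, exact filtering makes the evidence available directly, so no bound is needed.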