# 03 · Multi-step Planning, Bottlenecks, and Beam / Pruned Search

This notebook explores the computational bottleneck in multi-step active inference planning and shows how **observation pruning** and **beam search** reduce compute while often preserving action quality.

**Key idea**: Evaluating depth-$H$ policies with expected free energy (EFE) branches over both actions and **counterfactual observations**, so the search tree grows roughly like $(|\mathcal A|\,|\mathcal O|)^H$ on small discrete models. We profile wall-time and node expansions, and compare exact recursion, pruned recursion, and beam search.
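
To make the growth rate concrete, the sketch below counts the leaves of a full action-observation tree. The function name `tree_leaves` and the sizes (3 actions, 5 observations) are illustrative assumptions, not values read from the library.

```python
def tree_leaves(n_actions: int, n_obs: int, horizon: int) -> int:
    """Leaves of a full tree that branches over every action and observation."""
    return (n_actions * n_obs) ** horizon

# Illustrative sizes only: 3 actions and 5 observations are assumptions.
counts = {H: tree_leaves(3, 5, H) for H in range(1, 5)}
for H, n in counts.items():
    print(f"H={H}: {n} leaves")
```

With these sizes the count goes from 15 at $H=1$ to 50625 at $H=4$, which is why even small horizons are worth profiling.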

We use the ring-world generative model from the library:
- $A[o,s] = P(o\mid s)$ (likelihood), $B[a][s',s]=P(s'\mid s,a)$ (transition), $C$ (log-preferences).
- Planning APIs (from `persystems.planning`):
  - `choose_action_planner(...)` — depth-$H$ recursion with **observation-branch pruning**.
  - `choose_action_beam(...)` — depth-$H$ **beam search** over action sequences.
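
For readers without the library at hand, here is a minimal NumPy sketch of arrays that follow the shape conventions above. It is not the library's `make_ring_world`; the function name `toy_ring_model`, the three actions (left, stay, right), and the preference values are assumptions chosen for illustration.

```python
import numpy as np

def toy_ring_model(N=5, A_eps=0.15, target=3):
    """Hypothetical ring world with A[o, s], B[a][s', s], and log-preferences C."""
    # Likelihood: the true state is observed with prob 1 - A_eps,
    # and the remaining mass is spread evenly over the other observations.
    A = np.full((N, N), A_eps / (N - 1))
    np.fill_diagonal(A, 1.0 - A_eps)

    # Deterministic transitions for assumed actions: left (-1), stay (0), right (+1).
    B = np.zeros((3, N, N))
    for a, step in enumerate((-1, 0, 1)):
        for s in range(N):
            B[a, (s + step) % N, s] = 1.0

    # Log-preferences over observations: the target is preferred (values are assumed).
    C = np.full(N, -4.0)
    C[target] = 0.0
    return A, B, C

A, B, C = toy_ring_model()
print(A.shape, B.shape, C.shape)
```

Each column of `A` and of every `B[a]` is a probability distribution, which is what the column-stochastic convention $P(o\mid s)$, $P(s'\mid s,a)$ requires.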

Outputs: nodes expanded, pruned branches, wall-time, action agreement across methods, and example EFE vectors $G(a)$.

In [None]:
# CI-friendly params: shrink work if running in GitHub Actions
import os, time
CI = os.getenv("CI", "").lower() in ("1", "true", "yes")
HORIZONS = [1, 2, 3] if CI else [1, 2, 3, 4]
N_TRIALS_PER_H = 2 if CI else 5
PRUNE_LEVELS = [1e-3, 1e-4] if CI else [1e-2, 1e-3, 1e-4]
BEAM_WIDTHS = [8, 16] if CI else [8, 16, 32]
print({"CI": CI, "HORIZONS": HORIZONS, "N_TRIALS_PER_H": N_TRIALS_PER_H})

## Setup: ring-world GM and helpers

In [None]:
import numpy as np
import matplotlib.pyplot as plt
from persystems.gm import GenerativeModel
from persystems.planning import choose_action_planner, choose_action_beam

np.set_printoptions(precision=4, suppress=True)
plt.rcParams['figure.dpi'] = 120

def make_world(N=5, A_eps=0.15, target=3, seed=0):
    rng = np.random.default_rng(seed)
    gm = GenerativeModel.make_ring_world(N=N, A_eps=A_eps, target_idx=target)
    qs0 = np.ones(N)/N
    return gm, qs0, rng

gm, qs0, rng = make_world()
N = gm.A.shape[1]  # A[o, s]: columns index states
print("Actions:", gm.actions, "| N states:", N)

## Benchmark functions
We benchmark wall-time and node counts for:

- **Recursion with optional pruning** (`choose_action_planner`): `obs_prune_eps=0` gives exact enumeration; larger values drop low-probability observation branches.
- **Beam search** (`choose_action_beam`): set `beam_width` and (optionally) a small observation prune threshold.
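
The observation-prune threshold used by both methods can be sketched as follows. This is one plausible reading, not the library's implementation: the helper name `pruned_observations`, the renormalization of the surviving weights, and the fallback to the most likely observation are all assumptions.

```python
import numpy as np

def pruned_observations(qs, A, B_a, obs_prune_eps=1e-3):
    """Return kept observation indices, their weights, and the number pruned."""
    qs_next = B_a @ qs   # predicted state belief after the action
    qo = A @ qs_next     # predicted observation distribution Q(o)
    keep = np.flatnonzero(qo > obs_prune_eps)
    if keep.size == 0:
        # Assumed fallback: never prune every branch.
        keep = np.array([int(np.argmax(qo))])
    weights = qo[keep] / qo[keep].sum()  # renormalize over survivors (assumption)
    return keep, weights, qo.size - keep.size

# Toy check: identity likelihood and transition, so Q(o) equals the belief.
A_toy = np.eye(3)
B_stay = np.eye(3)
qs_toy = np.array([0.9, 0.0995, 0.0005])
keep, weights, n_pruned = pruned_observations(qs_toy, A_toy, B_stay, obs_prune_eps=1e-3)
print(keep, weights, n_pruned)
```

Only the kept observations spawn child nodes, so each pruned branch removes an entire subtree from the recursion.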

In [None]:
def bench_planner(gm, qs, H, prune_eps):
    t0 = time.perf_counter()
    a_idx, comp, Gs, diags = choose_action_planner(qs, gm.A, gm.B, gm.C, horizon=H, obs_prune_eps=prune_eps)
    dt = time.perf_counter() - t0
    return {
        "method": f"planner(prune={prune_eps})", "H": H, "a": a_idx, "dt": dt,
        "nodes": diags.nodes_expanded, "pruned": diags.pruned_obs, "Gs": Gs
    }

def bench_beam(gm, qs, H, beam_width, prune_eps):
    t0 = time.perf_counter()
    a_idx, comp, Gs, diags = choose_action_beam(qs, gm.A, gm.B, gm.C, horizon=H, beam_width=beam_width, obs_prune_eps=prune_eps)
    dt = time.perf_counter() - t0
    return {
        "method": f"beam(K={beam_width}, prune={prune_eps})", "H": H, "a": a_idx, "dt": dt,
        "nodes": diags.nodes_expanded, "pruned": diags.pruned_obs, "Gs": Gs, "beam": beam_width
    }

## Sweep horizon $H$ and compare exact vs. beam
We keep the belief uniform to focus on the planning cost itself. For each $H$, we run multiple trials (they are identical here, but the loop leaves room for stochastic variants) and summarize median wall-time, node counts, and agreement in the chosen action.

In [None]:
records = []
for H in HORIZONS:
    for trial in range(N_TRIALS_PER_H):
        # exact (no pruning)
        rec = bench_planner(gm, qs0, H=H, prune_eps=0.0)
        records.append(rec)
        # pruned variants
        for pe in PRUNE_LEVELS:
            records.append(bench_planner(gm, qs0, H=H, prune_eps=pe))
        # beams
        for bw in BEAM_WIDTHS:
            for pe in PRUNE_LEVELS:
                records.append(bench_beam(gm, qs0, H=H, beam_width=bw, prune_eps=pe))

import pandas as pd
df = pd.DataFrame(records)
df.head()

### Nodes expanded vs. horizon
The node count for exact recursion should grow rapidly with $H$, while the counts for beam and pruned methods should grow more slowly.

In [None]:
def summarize_nodes(df):
    grp = df.groupby(["method", "H"]).agg(nodes_median=("nodes", "median"), nodes_mean=("nodes", "mean")).reset_index()
    return grp

g_nodes = summarize_nodes(df)
plt.figure(figsize=(7,4))
for m, sub in g_nodes.groupby("method"):
    plt.plot(sub["H"], sub["nodes_median"], marker='o', label=m)
plt.yscale('log')
plt.xlabel('horizon H')
plt.ylabel('nodes expanded (median, log scale)')
plt.title('Search complexity vs. horizon')
plt.legend(fontsize=7)
plt.tight_layout(); plt.show()
g_nodes.sort_values(["H", "nodes_median"]).head(10)

### Wall-time vs. horizon
Wall-time tracks the node explosion; pruning/beam often reduce it dramatically with little effect on the chosen action in small problems like this ring world.

In [None]:
g_time = df.groupby(["method", "H"]).agg(dt_med=("dt", "median"), dt_mean=("dt", "mean")).reset_index()
plt.figure(figsize=(7,4))
for m, sub in g_time.groupby("method"):
    plt.plot(sub["H"], sub["dt_med"], marker='o', label=m)
plt.yscale('log')
plt.xlabel('horizon H')
plt.ylabel('wall-time (seconds, median, log scale)')
plt.title('Compute cost vs. horizon')
plt.legend(fontsize=7)
plt.tight_layout(); plt.show()
g_time.sort_values(["H", "dt_med"]).head(10)

### Action agreement across methods
We compare each method’s chosen action to the exact depth-$H$ action (no pruning) at the same $H$.

In [None]:
# Reference: the exact planner's action at each H (one row per H keeps the merge one-to-one)
ref = df[df["method"] == "planner(prune=0.0)"][["H", "a"]].drop_duplicates(subset="H")
ref = ref.rename(columns={"a": "a_exact"})
dfj = df.merge(ref, on="H", how="left")
dfj["agree"] = (dfj["a"] == dfj["a_exact"])  # True where the method matches the exact action

agree_tbl = dfj.groupby(["method","H"]).agg(
    agree_rate=("agree","mean"),
    n=("agree","size")
).reset_index()
agree_tbl.sort_values(["H","agree_rate"], ascending=[True, False]).head(10)

In [None]:
plt.figure(figsize=(7,4))
for m, sub in agree_tbl.groupby("method"):
    plt.plot(sub["H"], 100*sub["agree_rate"], marker='o', label=m)
plt.ylim(0, 105)
plt.xlabel('horizon H')
plt.ylabel('action agreement with exact (%)')
plt.title('Quality vs. compute: agreement with exact planner')
plt.legend(fontsize=7)
plt.tight_layout(); plt.show()

## Example EFE vectors $G(a)$ at depth $H=2$
A quick look at the per-action EFE from the different planners (values may differ slightly depending on the prune threshold and beam width).
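
For intuition about what a vector $G(a)$ contains, the sketch below computes a depth-1 EFE using the common risk-plus-ambiguity decomposition. It is an assumption that the library uses this form and that lower $G$ is better; the function name `one_step_efe` and the two-state toy model are hypothetical, and the multi-step recursion is not reproduced here.

```python
import numpy as np

def one_step_efe(qs, A, B, C):
    """Depth-1 EFE per action: risk (KL to preferences) plus ambiguity."""
    # Normalize log-preferences into a log-distribution over observations.
    log_p_o = C - (C.max() + np.log(np.exp(C - C.max()).sum()))
    # Entropy of P(o | s) for each state s (columns of A).
    H_A = -np.sum(A * np.log(A + 1e-16), axis=0)
    G = np.zeros(len(B))
    for a in range(len(B)):
        qs_next = B[a] @ qs   # predicted state belief
        qo = A @ qs_next      # predicted observation distribution
        risk = np.sum(qo * (np.log(qo + 1e-16) - log_p_o))
        ambiguity = H_A @ qs_next
        G[a] = risk + ambiguity
    return G

# Toy model: perfect observations; action 0 stays, action 1 swaps the two states.
A_toy = np.eye(2)
B_toy = np.array([np.eye(2), np.array([[0.0, 1.0], [1.0, 0.0]])])
C_toy = np.array([0.0, -4.0])   # observation 0 is preferred
qs_toy = np.array([1.0, 0.0])   # the agent is certain it is in state 0
G = one_step_efe(qs_toy, A_toy, B_toy, C_toy)
print(G, "best action:", int(np.argmin(G)))
```

In this toy case ambiguity is zero, so the gap between the two entries of `G` is exactly the gap in log-preferences, and staying in the preferred state has the lower EFE.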

In [None]:
def show_example_vectors(H=2):
    rows = []
    exact = bench_planner(gm, qs0, H=H, prune_eps=0.0)
    rows.append(("exact", exact["a"], exact["Gs"]))
    pruned = bench_planner(gm, qs0, H=H, prune_eps=1e-4)
    rows.append(("pruned(1e-4)", pruned["a"], pruned["Gs"]))
    beam16 = bench_beam(gm, qs0, H=H, beam_width=16, prune_eps=1e-4)
    rows.append(("beam(K=16,1e-4)", beam16["a"], beam16["Gs"]))
    for name, a, Gs in rows:
        print(f"{name:>16s} : best a={a}, Gs={np.round(Gs, 5)}")

show_example_vectors(H=2)

## Takeaways
- Exact depth-$H$ planning branches over both actions and observations, so nodes and wall-time grow exponentially with $H$.
- **Observation pruning** (drop observations with negligible $Q(o)$) and **beam search** (keep the top-$K$ partial plans) can cut compute sharply, and the savings grow with $H$.
- On this small ring world, pruned/beam methods often pick the *same* action as exact depth-$H$ while using far fewer nodes.
- For larger problems, these approximations (plus amortized policies) become essential for tractable active inference planning.
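
The top-$K$ idea in the second bullet can be sketched generically as below. This is not the library's `choose_action_beam`: the function name `beam_search` and the `step_cost` callback are hypothetical stand-ins for however partial plans are scored (for example by accumulated EFE).

```python
import heapq

def beam_search(n_actions, horizon, beam_width, step_cost):
    """Keep the beam_width cheapest partial action sequences at each depth."""
    beam = [(0.0, ())]   # (cumulative cost, action sequence)
    expanded = 0
    for _ in range(horizon):
        candidates = []
        for cost, seq in beam:
            for a in range(n_actions):
                candidates.append((cost + step_cost(seq, a), seq + (a,)))
                expanded += 1
        # nsmallest returns the survivors sorted by cost, cheapest first.
        beam = heapq.nsmallest(beam_width, candidates)
    best_cost, best_seq = beam[0]
    return best_seq, best_cost, expanded

# Toy cost (an assumption): action 1 is free, the other actions cost 1 per step.
best_seq, best_cost, expanded = beam_search(
    n_actions=3, horizon=4, beam_width=2,
    step_cost=lambda seq, a: float(abs(a - 1)),
)
print(best_seq, best_cost, expanded)
```

Here the beam expands 21 candidates where exhaustive enumeration of action sequences would visit $3^4 = 81$ leaves. The trade-off is that a plan whose early steps look expensive can be dropped before its later payoff is seen, which is why agreement with the exact planner is measured rather than assumed.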