# Active Inference POMDP Agent for Circuit Discovery

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/SharathSPhD/ActiveCIrcuitDiscovery/blob/main/notebooks/02_active_inference_demo.ipynb)

This notebook demonstrates the **pymdp-based Active Inference POMDP agent** that guides circuit discovery on Gemma-2-2B.

**Key concepts:**
- Multi-factor POMDP with 3 hidden state factors and 3 observation modalities
- Expected Free Energy (EFE) minimisation for intervention selection
- Dirichlet learning of the observation model from real intervention data
- Comparison against a bandit heuristic, greedy, random, and oracle baselines
- All interventions use the `feature_intervention` API

**Requirements:** Free Colab GPU (T4)

In [None]:
# Install dependencies (pin numpy<2 for transformer-lens compatibility)
!pip install -q "numpy>=1.26.0,<2.0"
!pip install -q torch torchvision --index-url https://download.pytorch.org/whl/cu121
!pip install -q transformer-lens einops jaxtyping typeguard
!pip install -q git+https://github.com/safety-research/circuit-tracer.git
!pip install -q git+https://github.com/infer-actively/pymdp.git
!pip install -q plotly scipy

import sys, os

In [None]:
# Model selection: switch between Gemma and Llama.
# Both constants are forwarded to ReplacementModel.from_pretrained in the
# model-loading cell (as model_name= and transcoder_set=).
# NOTE(review): presumably they must be switched together so the transcoders
# match the base model — confirm against circuit-tracer.
MODEL_NAME = "google/gemma-2-2b"  # or "meta-llama/Llama-3.2-1B"
TRANSCODER_SET = "gemma"  # or "llama"

## Step 1: Load Model and Generate Attribution Graph

In [None]:
import torch
import numpy as np
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from collections import defaultdict
from typing import Tuple, Dict, List

from circuit_tracer import ReplacementModel, attribute
from circuit_tracer.graph import prune_graph

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"PyTorch: {torch.__version__}, CUDA: {torch.cuda.is_available()}")

model = ReplacementModel.from_pretrained(
    model_name=MODEL_NAME,
    transcoder_set=TRANSCODER_SET,
    backend="transformerlens",
    device=device,
    dtype=torch.float32,
)
print(f"Loaded: {model.cfg.n_layers} layers, d_model={model.cfg.d_model}")

# Build the attribution graph for the prompt and prune it.
prompt = "When John and Mary went to the store, John gave the bag to"
raw_graph = attribute(prompt=prompt, model=model, max_n_logits=5,
                      desired_logit_prob=0.9, batch_size=256, verbose=True)
pr = prune_graph(raw_graph, node_threshold=0.8, edge_threshold=0.98)

# Per-feature graph importance: summed |edge weight| over the node's column
# plus its row of the adjacency matrix.
# Assumes the first n_sel rows/columns are the selected feature nodes — TODO confirm.
n_sel = len(raw_graph.selected_features)
adj = raw_graph.adjacency_matrix
infl = adj.abs().sum(0)[:n_sel] + adj.abs().sum(1)[:n_sel]
# Normaliser for `imp`; falls back to 1.0 when the maximum influence is 0.
mi = infl.max().item() or 1.0

# Cap on the candidate pool (keeps the oracle baseline affordable).
MAX_CANDIDATES = 50

# Rank the features that survived pruning by influence BEFORE truncating, so
# the pool holds the most important kept features rather than whichever
# MAX_CANDIDATES happen to come first in index order.
kept_mask = pr.node_mask[:n_sel]
infl_values = infl.tolist()
ranked_kept = sorted(torch.where(kept_mask)[0].tolist(),
                     key=lambda idx: infl_values[idx], reverse=True)

candidates = []
for i in ranked_kept[:MAX_CANDIDATES]:
    ft = raw_graph.selected_features[i]
    layer = int(raw_graph.active_features[ft, 0].item())
    pos = int(raw_graph.active_features[ft, 1].item())
    fidx = int(raw_graph.active_features[ft, 2].item())
    # NOTE(review): activation_values is indexed with the selected-feature
    # position `i`, while active_features is indexed with `ft` — confirm
    # which index space activation_values uses.
    act = float(raw_graph.activation_values[i].item())
    imp = float(infl[i].item()) / mi
    candidates.append(dict(layer=layer, pos=pos, fidx=fidx, act=act, imp=imp,
                           fid=f'L{layer}_P{pos}_F{fidx}'))
# Already in descending order; the sort keeps the "sorted by imp" invariant explicit.
candidates.sort(key=lambda x: x['imp'], reverse=True)

# Baseline (no intervention) next-token distribution at the last position.
clean_logits, _ = model.feature_intervention(prompt, [], return_activations=False)
clean_probs = torch.softmax(clean_logits[0, -1, :], -1)
top_id = int(clean_probs.argmax().item())
print(f"\n{len(candidates)} candidate features extracted")
print(f"Clean prediction: '{model.tokenizer.decode([top_id])}'")

In [None]:
# Export graph for circuit-tracer interactive visualization.
# The unpruned attribution graph is passed together with the name
# 'active_inference' and the target directory /tmp/acd_graphs.
# NOTE(review): presumably the second argument is the slug used to name the
# exported files — confirm against circuit-tracer's create_graph_files.
from circuit_tracer.utils.create_graph_files import create_graph_files

create_graph_files(raw_graph, 'active_inference', '/tmp/acd_graphs')
print("Graph files saved to /tmp/acd_graphs")

## Step 2: The Active Inference POMDP Agent

The agent maintains a generative model with:
- **3 hidden state factors**: feature importance (4 levels), layer role (3 levels), causal influence (3 levels)
- **3 observation modalities**: KL divergence magnitude, activation magnitude, graph connectivity
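- **3 actions**: ablation, activation patching, feature steering (in this demo, every selected intervention is executed as a zero-ablation via `ablate_feature`; the action returned by the agent does not change how the intervention is carried out)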

Intervention selection minimises the **Expected Free Energy** (EFE), which decomposes into epistemic (information gain) and pragmatic (preference satisfaction) components. The observation model (A matrix) is learned online via Dirichlet updates from real intervention data.

In [None]:
from src.active_inference.pomdp_agent import ActiveInferencePOMDPAgent

def ablate_feature(model, prompt, feat, clean_probs):
    """Zero-ablate one feature and measure how far the prediction moves.

    The feature is set to 0 through ``model.feature_intervention`` and the
    resulting distribution at the last position of the first batch element
    is compared with ``clean_probs``.

    Args:
        model: Object exposing ``feature_intervention(prompt, interventions,
            return_activations=...)`` that returns ``(logits, activations)``.
        prompt: Prompt forwarded unchanged to ``feature_intervention``.
        feat: Mapping with ``'layer'``, ``'pos'`` and ``'fidx'`` entries
            identifying the feature to ablate.
        clean_probs: 1-D tensor of probabilities from the unmodified model.

    Returns:
        KL(clean || ablated) in nats, always a non-negative Python ``float``.
    """
    iv_logits, _ = model.feature_intervention(
        prompt, [(feat['layer'], feat['pos'], feat['fidx'], 0)],
        return_activations=False)
    # log_softmax is numerically stable and avoids the downward bias that an
    # additive epsilon inside log() introduces.
    iv_log_probs = torch.log_softmax(iv_logits[0, -1, :], dim=-1)
    # kl_div(input=log Q, target=P) computes KL(P || Q).
    kl = torch.nn.functional.kl_div(iv_log_probs, clean_probs, reduction='sum')
    # Clamp round-off negatives; 0.0 (not 0) keeps the return type a float.
    return max(0.0, float(kl.item()))

# Add graph connectivity info to candidates.
# A (layer, pos, fidx) -> selected-feature index table is built once, so each
# candidate is resolved with a single dict lookup instead of rescanning every
# selected feature with per-element .item() calls.
adj = raw_graph.adjacency_matrix
active_rows = raw_graph.active_features.tolist()
sel_index_by_key = {}
for j, ft in enumerate(raw_graph.selected_features.tolist()):
    row = active_rows[ft]
    # setdefault keeps the first matching index, like an early `break` would.
    sel_index_by_key.setdefault((int(row[0]), int(row[1]), int(row[2])), j)

for c in candidates:
    sel_idx = sel_index_by_key.get((c['layer'], c['pos'], c['fidx']))
    if sel_idx is not None:
        # Degree = number of non-zero entries in the node's column / row.
        # NOTE(review): whether a column holds incoming or outgoing edges
        # depends on the adjacency matrix's row/column convention — confirm
        # the in/out labels. The update loop in this cell passes only their sum.
        c['in_degree'] = int((adj[:, sel_idx].abs() > 0).sum().item())
        c['out_degree'] = int((adj[sel_idx, :].abs() > 0).sum().item())
    else:
        # Candidate not found among the selected features: no connectivity.
        c['in_degree'] = 0
        c['out_degree'] = 0

# Intervention budget: at most 20, and never more than the number of candidates.
BUDGET = min(20, len(candidates))
print(f"Running pymdp Active Inference POMDP agent with budget={BUDGET}")

agent = ActiveInferencePOMDPAgent(n_layers=model.cfg.n_layers)
agent.initialize()

observed_fids = set()  # feature ids that have already been intervened on
ai_kls = []            # measured KL per step, in selection order
for step in range(BUDGET):
    # Only features that have not been intervened on yet are offered to the agent.
    unobserved = [c for c in candidates if c['fid'] not in observed_fids]
    if len(unobserved) == 0:
        break

    # NOTE(review): the returned `action` is not used below — every selected
    # feature is ablated regardless of which action the agent chose.
    feat, action, efe = agent.select_intervention(unobserved)
    kl = ablate_feature(model, prompt, feat, clean_probs)

    # Connectivity observation = in-degree + out-degree (0 when not annotated).
    connectivity = feat.get('in_degree', 0) + feat.get('out_degree', 0)
    agent.update_beliefs(
        feat,
        kl_divergence=kl,
        activation_value=feat['act'],
        graph_connectivity=connectivity,
    )

    ai_kls.append(kl)
    observed_fids.add(feat['fid'])

    # Progress line every fifth step.
    if (step + 1) % 5 == 0:
        print(f"  Step {step+1}: {feat['fid']} -> KL={kl:.6f}, EFE={efe:.4f}")

print(f"\nPOMDP agent: mean KL = {np.mean(ai_kls):.6f}, cumulative = {np.sum(ai_kls):.6f}")
print(f"Converged: {agent.is_converged}")
print(f"Entropy history: {[f'{h:.3f}' for h in agent.get_belief_entropy_history()]}")

In [None]:
## Step 3: Run Baselines (Greedy, Random, Oracle)

In [None]:
# Ablate every candidate exactly once. The ablation depends only on the
# feature, so all three baselines below reuse these values instead of
# re-running the model for each ordering.
candidate_kls = [ablate_feature(model, prompt, c, clean_probs) for c in candidates]

# Greedy baseline: features sorted by graph importance (descending)
greedy_order = list(range(len(candidates)))  # already sorted by imp
greedy_kls = [candidate_kls[i] for i in greedy_order[:BUDGET]]

# Random baseline: average over 10 random orderings
np.random.seed(42)
random_trials = []
for _ in range(10):
    perm = np.random.permutation(len(candidates))[:BUDGET]
    random_trials.append([candidate_kls[int(j)] for j in perm])
# Per-step mean across the trials.
random_kls = [float(np.mean([t[i] for t in random_trials])) for i in range(BUDGET)]

# Oracle: rank ALL candidates by their true KL and take the best BUDGET
all_kls = sorted(enumerate(candidate_kls), key=lambda x: x[1], reverse=True)
oracle_kls = [kl for _, kl in all_kls[:BUDGET]]

print(f"Greedy:  mean KL = {np.mean(greedy_kls):.6f}, cumulative = {np.sum(greedy_kls):.6f}")
print(f"Random:  mean KL = {np.mean(random_kls):.6f}, cumulative = {np.sum(random_kls):.6f}")
print(f"Oracle:  mean KL = {np.mean(oracle_kls):.6f}, cumulative = {np.sum(oracle_kls):.6f}")
print(f"AI:      mean KL = {np.mean(ai_kls):.6f}, cumulative = {np.sum(ai_kls):.6f}")

# Efficiency = cumulative KL as a percentage of the oracle's cumulative KL;
# the floor avoids dividing by zero when the oracle total is 0.
oracle_total = max(np.sum(oracle_kls), 1e-10)
ai_eff = np.sum(ai_kls) / oracle_total * 100
greedy_eff = np.sum(greedy_kls) / oracle_total * 100
random_eff = np.sum(random_kls) / oracle_total * 100
print(f"\nOracle efficiency: AI={ai_eff:.1f}%, Greedy={greedy_eff:.1f}%, Random={random_eff:.1f}%")

In [None]:
## Step 4: Visualize Comparison

In [None]:
# Cumulative KL curves: running totals of each strategy's per-step KL.
ai_cum, greedy_cum, random_cum, oracle_cum = (
    np.cumsum(series) for series in (ai_kls, greedy_kls, random_kls, oracle_kls))
steps = list(range(1, BUDGET + 1))

# Left panel: cumulative KL; right panel: per-step KL.
fig = make_subplots(rows=1, cols=2,
    subplot_titles=['Cumulative KL Divergence', 'Per-Step KL Divergence'])

cumulative_series = [
    ('POMDP Agent', ai_cum, '#E91E63'),
    ('Greedy', greedy_cum, '#2196F3'),
    ('Random', random_cum, '#9E9E9E'),
    ('Oracle', oracle_cum, '#4CAF50'),
]
for name, cum, color in cumulative_series:
    cumulative_trace = go.Scatter(
        x=steps, y=cum.tolist(), mode='lines+markers',
        name=name, line=dict(color=color, width=2))
    fig.add_trace(cumulative_trace, row=1, col=1)

# The oracle is not drawn per step; these traces are hidden from the legend.
per_step_series = [
    ('POMDP', ai_kls, '#E91E63'),
    ('Greedy', greedy_kls, '#2196F3'),
    ('Random', random_kls, '#9E9E9E'),
]
for name, kls, color in per_step_series:
    per_step_trace = go.Scatter(
        x=steps, y=kls, mode='lines+markers',
        name=name, line=dict(color=color, width=1.5),
        showlegend=False)
    fig.add_trace(per_step_trace, row=1, col=2)

fig.update_layout(height=450, template='plotly_white',
                  title_text='POMDP Active Inference Agent vs Baselines')
for col in (1, 2):
    fig.update_xaxes(title_text='Intervention Step', row=1, col=col)
for col, y_title in ((1, 'Cumulative KL'), (2, 'KL Divergence')):
    fig.update_yaxes(title_text=y_title, row=1, col=col)
fig.show()

def _history_figure(values, color, trace_name, title, y_title):
    """Build a single-trace line chart of a per-step history."""
    figure = go.Figure()
    figure.add_trace(go.Scatter(
        y=values, mode='lines+markers',
        line=dict(color=color, width=2), name=trace_name))
    figure.update_layout(title=title,
                         xaxis_title='Step', yaxis_title=y_title,
                         template='plotly_white', height=350)
    return figure


# Belief entropy over time
entropy_hist = agent.get_belief_entropy_history()
fig2 = _history_figure(entropy_hist, '#FF5722', 'Belief Entropy',
                       'POMDP Agent Belief Entropy Over Time',
                       'Total Entropy (nats)')
fig2.show()

# EFE history
efe_hist = agent.get_efe_history()
fig3 = _history_figure(efe_hist, '#673AB7', 'EFE',
                       'Expected Free Energy Over Time',
                       'EFE (lower = more informative)')
fig3.show()

# Feature importance ranking from posterior beliefs.
# The ranking is iterated as (feature id, score) pairs and only the first ten
# entries are printed.
# NOTE(review): presumably the list is already sorted by descending
# importance — confirm in ActiveInferencePOMDPAgent.
rankings = agent.get_feature_importance_ranking()
print("\nTop 10 features by inferred importance:")
for fid, score in rankings[:10]:
    print(f"  {fid}: importance = {score:.4f}")