In [None]:
# Install required Python libraries first
!pip install jax jaxlib equinox

# Remove any existing clones to avoid [already exists] errors
!rm -rf /content/pymdp
!rm -rf /content/thrml

# Clone both GitHub repositories
!git clone https://github.com/apashea/pymdp.git /content/pymdp
!git clone https://github.com/extropic-ai/thrml.git /content/thrml

# Install both as editable pip packages to generate proper metadata (required for importlib.metadata.version)
!pip install -e /content/pymdp
!pip install -e /content/thrml

# Now import both libraries
import pymdp
import thrml


Cloning into '/content/pymdp'...
remote: Enumerating objects: 54, done.[K
remote: Counting objects: 100% (54/54), done.[K
remote: Compressing objects: 100% (52/52), done.[K
remote: Total 54 (delta 3), reused 42 (delta 1), pack-reused 0 (from 0)[K
Receiving objects: 100% (54/54), 146.76 KiB | 1.29 MiB/s, done.
Resolving deltas: 100% (3/3), done.
Cloning into '/content/thrml'...
remote: Enumerating objects: 129, done.[K
remote: Counting objects: 100% (52/52), done.[K
remote: Compressing objects: 100% (40/40), done.[K
remote: Total 129 (delta 30), reused 13 (delta 12), pack-reused 77 (from 1)[K
Receiving objects: 100% (129/129), 33.55 MiB | 17.47 MiB/s, done.
Resolving deltas: 100% (39/39), done.
Obtaining file:///content/pymdp
[31mERROR: file:///content/pymdp does not appear to be a Python project: neither 'setup.py' nor 'pyproject.toml' found.[0m[31m
[0mObtaining file:///content/thrml
  Installing build dependencies ... [?25l[?25hdone
  Checking if build backend supports b

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# THRML imports - corrected paths
from thrml.models import CategoricalEBMFactor, CategoricalGibbsConditional, FactorizedEBM
from thrml.pgm import CategoricalNode
from thrml.block_management import Block
from thrml.block_sampling import BlockGibbsSpec, SamplingSchedule, sample_states
from thrml.factor import FactorSamplingProgram

import pymdp
from pymdp import utils
from pymdp.agent import Agent

# ============================================================================
# SETUP: Define the minimal POMDP
# ============================================================================

# Problem size: hidden-state count, observation count, and time horizon
# (the horizon is kept short so the notebook runs quickly).
n_states, n_obs, T = 3, 2, 5

# Root PRNG key; every later random draw is derived from it by splitting.
key = jax.random.key(42)

# ============================================================================
# STEP 1: Define generative model parameters (shared by both THRML and pymdp)
# ============================================================================

# Observation model P(o|s), shape (n_obs, n_states).
# Row = observation, column = hidden state; each column sums to 1.
A_matrix = jnp.array(
    [
        [0.8, 0.1, 0.1],  # P(o=0 | s=0,1,2)
        [0.2, 0.9, 0.9],  # P(o=1 | s=0,1,2)
    ]
)

# Transition model P(s_t+1 | s_t), shape (n_states, n_states).
# B_matrix[i, j] = P(s_t+1=i | s_t=j); each column sums to 1.
B_matrix = jnp.array(
    [
        [0.7, 0.2, 0.1],  # P(s'=0 | s=0,1,2)
        [0.2, 0.6, 0.2],  # P(s'=1 | s=0,1,2)
        [0.1, 0.2, 0.7],  # P(s'=2 | s=0,1,2)
    ]
)

# Initial state prior P(s_0), shape (n_states,); approximately uniform.
D_vector = jnp.array([0.33, 0.33, 0.34])

# Synthetic observation sequence for testing: T integers drawn uniformly
# from [0, n_obs), using a fresh subkey split off the root key.
key, subkey = jax.random.split(key, 2)
observed_data = jax.random.randint(
    subkey, shape=(T,), minval=0, maxval=n_obs, dtype=jnp.uint8
)
print(f"Observed data: {observed_data}")

# ============================================================================
# VALIDATION TEST 1: Matrix Convention Verification
# ============================================================================

print("\n" + "="*60)
print("VALIDATION TEST 1: Matrix Convention Verification")
print("="*60)

# Build the transposed log-likelihood table here so the check really reads
# the (n_states, n_obs) layout. Comparing A_matrix[1, 0] with
# exp(log(A_matrix[1, 0])) would be true by construction and verify nothing.
log_A_T_check = jnp.log(A_matrix + 1e-16).T  # Shape: (n_states, n_obs)
prob_original = A_matrix[1, 0]
prob_via_transpose = jnp.exp(log_A_T_check[0, 1])
# Whole-table round trip: transposing back must recover the original matrix.
a_tables_agree = bool(jnp.allclose(jnp.exp(log_A_T_check).T, A_matrix))

# Check: P(o=1|s=0) should be 0.2 in both representations
print("\nVerifying A matrix conventions:")
print(f"  Original A_matrix[1, 0] (P(o=1|s=0)): {prob_original}")
print(f"  Transposed log_A.T[0, 1] (exp of log P(o=1|s=0)): {prob_via_transpose}")
print(f"  ✓ Match: {bool(jnp.allclose(prob_original, prob_via_transpose)) and a_tables_agree}")

# Verify B matrix.
# B_matrix is indexed [s', s]. Whether a transpose is needed depends on the
# order in which a consumer indexes the two state axes, so report symmetry
# explicitly instead of asserting that no transpose is needed.
print("\nVerifying B matrix conventions:")
print(f"  B_matrix[0, 1] (P(s'=0|s=1)): {B_matrix[0, 1]}")
b_is_symmetric = bool(jnp.allclose(B_matrix, B_matrix.T))
if b_is_symmetric:
    print("  ✓ B_matrix is symmetric, so the [s', s] and [s, s'] layouts coincide (no transpose needed)")
else:
    print("  ✗ WARNING: B_matrix is not symmetric; its axes must be ordered to match each consumer's indexing")

# Verify D vector
print("\nVerifying D vector:")
print(f"  D_vector: {D_vector}")
print(f"  Sum (should be ~1.0): {jnp.sum(D_vector):.6f}")

# ============================================================================
# STEP 2: Build THRML model (smoothing inference)
# ============================================================================

# Create temporal chain of nodes: one hidden-state node and one observation
# node per timestep.
state_nodes = [CategoricalNode() for _ in range(T)]
obs_nodes = [CategoricalNode() for _ in range(T)]

# Convert probability matrices to log-space for THRML.
# The 1e-16 offset keeps log() finite if a probability is exactly zero.
log_A = jnp.log(A_matrix + 1e-16)  # Shape: (n_obs, n_states) = (2, 3)
log_B = jnp.log(B_matrix + 1e-16)  # Shape: (n_states, n_states), indexed [s', s]
log_D = jnp.log(D_vector + 1e-16)  # Shape: (n_states,)

# Observation factor: P(o_t | s_t) for all t
# IMPORTANT: Transpose A matrix to match block order [states, obs]
log_A_transposed = log_A.T  # Shape: (n_states, n_obs) = (3, 2)
log_A_temporal = jnp.tile(log_A_transposed[None, :, :], (T, 1, 1))  # Shape: (T, n_states, n_obs) = (5, 3, 2)

obs_factor = CategoricalEBMFactor(
    [Block(state_nodes), Block(obs_nodes)],
    log_A_temporal  # Shape: (T, n_states, n_obs)
)

# Transition factor: P(s_t+1 | s_t) for t=0..T-2
# Weight tensor shape: (T-1, n_states, n_states)
# The blocks below are ordered [s_t, s_t+1] while log_B is indexed [s', s].
# Applying the same block-order rule as for A, the weights must be indexed
# [s_t, s_t+1], i.e. log_B transposed. For the symmetric B_matrix used here
# this leaves the numbers unchanged; for an asymmetric B it is required.
log_B_transposed = log_B.T  # Indexed [s, s']
log_B_temporal = jnp.tile(log_B_transposed[None, :, :], (T-1, 1, 1))

transition_factor = CategoricalEBMFactor(
    [Block(state_nodes[:-1]), Block(state_nodes[1:])],
    log_B_temporal
)

# Prior factor: P(s_0) - only applies to first state
# Weight tensor shape: (1, n_states)
prior_factor = CategoricalEBMFactor(
    [Block([state_nodes[0]])],
    log_D[None, :]
)

# Combine into EBM
ebm = FactorizedEBM([obs_factor, transition_factor, prior_factor])

# ============================================================================
# VALIDATION TEST 2: Energy Function Verification
# ============================================================================

print("\n" + "=" * 60)
print("VALIDATION TEST 2: Energy Function Verification")
print("=" * 60)

# Evaluate the EBM on a known configuration: every hidden state equal to 0,
# paired with the synthetic observation sequence.
test_states = jnp.zeros(5, dtype=jnp.uint8)
test_energy = ebm.energy(
    [test_states, observed_data],
    [Block(state_nodes), Block(obs_nodes)],
)

# Reference value computed by hand:
# Energy = -log P(states, obs) = -[log P(s_0) + sum_t log P(o_t|s_t) + sum_t log P(s_t+1|s_t)]
# The log-probability terms are collected in order (prior, observations,
# transitions) and subtracted one at a time.
log_prob_terms = [log_D[0]]  # Prior for s_0=0
log_prob_terms.extend(log_A_transposed[0, observed_data[t]] for t in range(T))  # Observation likelihoods
log_prob_terms.extend(log_B[0, 0] for _ in range(T - 1))  # Transitions s_t=0 -> s_t+1=0

manual_energy = -log_prob_terms[0]
for log_prob in log_prob_terms[1:]:
    manual_energy -= log_prob

print(f"\nEnergy for all-zero state sequence:")
print(f"  THRML computed energy: {test_energy:.6f}")
print(f"  Manually computed energy: {manual_energy:.6f}")
print(f"  ✓ Match: {jnp.allclose(test_energy, manual_energy, rtol=1e-5)}")

# ============================================================================
# STEP 3: Set up THRML sampling program
# ============================================================================

# Sampling schedule: burn-in length, number of retained samples, and the
# number of Gibbs sweeps between consecutive retained samples.
schedule = SamplingSchedule(n_warmup=100, n_samples=1000, steps_per_sample=2)

# Gibbs spec: hidden states are free (to be inferred); observations are clamped.
# NOTE(review): all state nodes are placed in a single free block even though
# transition_factor couples neighbouring states - confirm this blocking is
# what THRML's block sampler expects.
gibbs_spec = BlockGibbsSpec(
    clamped_blocks=[Block(obs_nodes)],
    free_super_blocks=[Block(state_nodes)],
)

# Conditional sampler for the categorical state nodes (n_states categories).
sampler = CategoricalGibbsConditional(n_states)

# Sampling program: the spec, one sampler for the single free block, and the
# three factors of the model. The final argument is passed as an empty list.
prog = FactorSamplingProgram(
    gibbs_spec,
    [sampler],
    [obs_factor, transition_factor, prior_factor],
    [],
)

# Random initial hidden-state sequence, drawn with a fresh subkey.
key, subkey = jax.random.split(key)
init_states = jax.random.randint(subkey, (T,), 0, n_states, dtype=jnp.uint8)

# ============================================================================
# STEP 4: Run THRML inference (smoothing)
# ============================================================================

print("\n" + "=" * 60)
print("Running THRML smoothing inference...")
print("=" * 60)

# Draw samples of the hidden-state block with all observations clamped.
key, subkey = jax.random.split(key)
samples = sample_states(
    subkey, prog, schedule, [init_states], [observed_data], [Block(state_nodes)]
)

# Empirical marginals: one-hot encode each sampled state and average over
# the sample axis. samples[0] is expected to have shape (n_samples, T), so
# the result has shape (T, n_states).
state_samples = samples[0]
marginals_thrml = jax.nn.one_hot(state_samples, n_states).mean(axis=0)

print("\nTHRML smoothing posteriors P(s_t | o_0:T-1):")
for t in range(T):
    print(f"  t={t}: {marginals_thrml[t]}")

# ============================================================================
# STEP 5: Build pymdp model and run sequential filtering
# ============================================================================

print("\n" + "=" * 60)
print("Running pymdp sequential filtering...")
print("=" * 60)

# Observation model for pymdp: a one-element object array (single modality).
# A_matrix is copied as-is, without transposing, so the table keeps its
# (n_obs, n_states) = (2, 3) layout; it is then passed through norm_dist.
A_pymdp = utils.obj_array(1)
A_pymdp[0] = utils.norm_dist(np.array(A_matrix))
print(f"is A_pymdp[0] normalized? : {utils.is_normalized(A_pymdp[0])}")

# Transition model for pymdp: B[s', s] (same layout as B_matrix) with a
# trailing action axis of length 1, giving shape (n_states, n_states, 1).
# B_pymdp[0][s', s, 0] = P(s'|s) for the single null action.
B_matrix_with_action = np.array(B_matrix)[:, :, np.newaxis]
B_pymdp = utils.obj_array(1)
B_pymdp[0] = B_matrix_with_action  # Shape: (3, 3, 1)

# Verify shape
print(f"B_pymdp[0] shape: {B_pymdp[0].shape}")  # Should be (3, 3, 1)

# Initial state prior for pymdp.
D_pymdp = utils.obj_array(1)
D_pymdp[0] = np.array(D_vector)

# Agent with the inference algorithm named explicitly.
agent = Agent(A=A_pymdp, B=B_pymdp, D=D_pymdp, inference_algo="VANILLA")

# Feed the observations ONE AT A TIME (sequential filtering) and record the
# posterior over the single hidden-state factor after each step.
# NOTE(review): infer_policies/sample_action results are not used here;
# presumably they are called so the agent has an action before the next
# infer_states call - confirm against the pymdp Agent implementation.
qs_pymdp_history = []
for t, obs in enumerate(observed_data):
    observation = [int(obs)]  # One entry per modality; there is a single modality
    qs = agent.infer_states(observation)
    q_pi, neg_efe = agent.infer_policies()
    action_idx = agent.sample_action()
    qs_pymdp_history.append(qs[0])  # Posterior for state factor 0
    print(f"  t={t}: pymdp filtering P(s_{t} | o_0:{t}) = {qs[0]}")

# ============================================================================
# STEP 6: Compare results
# ============================================================================

print("\n" + "=" * 60)
print("COMPARISON")
print("=" * 60)

# Remind the reader that the two posteriors answer different questions.
for note_line in (
    "\nNote: THRML computes smoothing P(s_t | o_0:T-1) while",
    "      pymdp computes filtering P(s_t | o_0:t)",
    "      These are different inference problems!\n",
):
    print(note_line)

# At the last timestep both methods have conditioned on every observation,
# so that is where the two posteriors are compared.
final_qs_pymdp = qs_pymdp_history[-1]
final_marginal_thrml = marginals_thrml[-1]

print(f"Final timestep (t={T-1}) comparison:")
print(f"  pymdp filtering:  {final_qs_pymdp}")
print(f"  THRML smoothing:  {final_marginal_thrml}")

# Largest absolute per-state difference between the two distributions.
error = jnp.abs(final_marginal_thrml - final_qs_pymdp).max()
print(f"\nMax absolute error at final timestep: {error:.4f}")

# Success criterion: difference below 0.05.
if not (error < 0.05):
    print(f"\n✗ WARNING: Error = {error:.4f} >= 5%")
    print("  Consider increasing n_samples or n_warmup in THRML")
else:
    print("\n✓ SUCCESS: Error < 5%, THRML approximates pymdp well!")

# ============================================================================
# STEP 7: Full comparison across all timesteps (for reference)
# ============================================================================

print("\n" + "=" * 60)
print("FULL TIMESTEP COMPARISON (for reference)")
print("=" * 60)

# For every timestep, report the largest absolute per-state difference
# between the THRML marginal and the pymdp posterior, followed by both.
print("\nTimestep-by-timestep comparison:")
for t in range(T):
    error_t = jnp.abs(marginals_thrml[t] - qs_pymdp_history[t]).max()
    print(f"  t={t}: max_error = {error_t:.4f}")
    print(f"    pymdp filtering:  {qs_pymdp_history[t]}")
    print(f"    THRML smoothing:  {marginals_thrml[t]}")

# Explain why the early timesteps are not expected to agree.
for note_line in (
    "\nNote: Errors at early timesteps are expected to be larger",
    "      because THRML uses future observations (smoothing)",
    "      while pymdp only uses past observations (filtering).",
):
    print(note_line)

# ============================================================================
# VALIDATION TEST 1: Matrix Convention Verification
# ============================================================================

print("\n" + "="*60)
print("VALIDATION TEST 1: Matrix Convention Verification")
print("="*60)

# Probe an entry whose row and column indices differ. Entry [0, 0] occupies
# the same position in a matrix and in its transpose, so checking it could
# never reveal a swapped axis convention. P(o=1 | s=0) lives at A_matrix[1, 0]
# and at [0, 1] in the transposed table, and those two positions of A_matrix
# hold different values.
probe_obs, probe_state = 1, 0

# Manually verify that A matrices represent the same model
print("\nVerifying A matrix conventions:")
print(f"Original A_matrix[o={probe_obs}, s={probe_state}] = {A_matrix[probe_obs, probe_state]:.3f}")
print(f"  This means P(o={probe_obs} | s={probe_state}) = {A_matrix[probe_obs, probe_state]:.3f}")

# THRML table is the transpose of A_matrix, hence indexed [state, obs].
print(f"\nTHRML log_A_transposed[s={probe_state}, o={probe_obs}] = {jnp.exp(log_A_transposed[probe_state, probe_obs]):.3f}")
print(f"  This means P(o={probe_obs} | s={probe_state}) = {jnp.exp(log_A_transposed[probe_state, probe_obs]):.3f}")

# pymdp table was built from A_matrix without transposing, hence indexed [obs, state].
print(f"\npymdp A_pymdp[0][o={probe_obs}, s={probe_state}] = {A_pymdp[0][probe_obs, probe_state]:.3f}")
print(f"  This means P(o={probe_obs} | s={probe_state}) = {A_pymdp[0][probe_obs, probe_state]:.3f}")

# Check if they match
thrml_prob = float(jnp.exp(log_A_transposed[probe_state, probe_obs]))
pymdp_prob = float(A_pymdp[0][probe_obs, probe_state])
original_prob = float(A_matrix[probe_obs, probe_state])

if jnp.allclose(thrml_prob, original_prob, atol=1e-6) and jnp.allclose(pymdp_prob, original_prob, atol=1e-6):
    print("\n✓ Matrix conventions are correctly aligned!")
else:
    print("\n✗ WARNING: Matrix conventions may not be aligned!")
    print(f"  THRML: {thrml_prob:.6f}, pymdp: {pymdp_prob:.6f}, original: {original_prob:.6f}")

# ============================================================================
# VALIDATION TEST 2: Energy Function Verification
# ============================================================================

print("\n" + "="*60)
print("VALIDATION TEST 2: Energy Function Verification")
print("="*60)

# Test energy computation for a known configuration.
# The states cycle 0, 1, 2, 0, ... so that consecutive states differ and the
# check reads off-diagonal transition entries and several columns of the
# observation table. An all-zero sequence would only ever read D[0], A[:, 0]
# and B[0, 0], and could not expose a swapped index order in the transition
# weights. Building the sequence from T also keeps it valid for any horizon.
test_states = (jnp.arange(T) % n_states).astype(jnp.uint8)
test_obs = observed_data  # Use actual observations

print(f"\nTesting energy for state configuration: {test_states}")
print(f"With observations: {test_obs}")

# Compute energy using THRML
energy_thrml = ebm.energy(
    [test_states, test_obs],
    [Block(state_nodes), Block(obs_nodes)]
)

print(f"\nTHRML energy: {energy_thrml:.4f}")

# Manually compute expected energy
# Energy = -sum(log P(o_t | s_t)) - sum(log P(s_t+1 | s_t)) - log P(s_0)
manual_energy = 0.0

# Initial state prior contribution
manual_energy -= float(log_D[test_states[0]])

# Observation contributions (log_A_transposed is indexed [state, obs])
for t in range(T):
    manual_energy -= float(log_A_transposed[test_states[t], test_obs[t]])

# Transition contributions (log_B is indexed [next state, current state])
for t in range(T-1):
    manual_energy -= float(log_B[test_states[t+1], test_states[t]])

print(f"Manually computed energy: {manual_energy:.4f}")

if jnp.allclose(energy_thrml, manual_energy, atol=1e-4):
    print("\n✓ Energy function is correctly implemented!")
else:
    print(f"\n✗ WARNING: Energy mismatch! Difference: {abs(energy_thrml - manual_energy):.6f}")

# ============================================================================
# VALIDATION TEST 3: Multiple Trial Runs (Sampling Convergence)
# ============================================================================

print("\n" + "=" * 60)
print("VALIDATION TEST 3: Multiple Trial Runs (Sampling Convergence)")
print("=" * 60)

print("\nRunning 5 trials with different random seeds...")

n_trials = 5
trial_errors = []

for trial_idx in range(n_trials):
    # Each trial gets its own seed; one derived key initialises the state
    # sequence and another drives the sampler.
    trial_key = jax.random.key(42 + trial_idx)
    trial_key, init_key, samp_key = jax.random.split(trial_key, 3)

    # Random starting configuration for the single free block.
    trial_init_states = [jax.random.randint(init_key, (T,), 0, n_states, dtype=jnp.uint8)]

    # Same program and schedule as the main run, observations clamped.
    trial_samples = sample_states(
        samp_key, prog, schedule, trial_init_states, [observed_data], [Block(state_nodes)]
    )

    # First (only) observed block; expected shape (n_samples, T).
    state_samples = trial_samples[0]

    # Per-timestep empirical marginal: count how often each state value
    # occurs at timestep t, then divide by the number of samples.
    trial_marginals = [
        jnp.bincount(state_samples[:, t], length=n_states) / schedule.n_samples
        for t in range(T)
    ]

    # Largest absolute difference from the pymdp posterior at the last timestep.
    trial_error = jnp.abs(trial_marginals[-1] - qs_pymdp_history[-1]).max()
    trial_errors.append(float(trial_error))

    print(f"  Trial {trial_idx + 1}: final timestep error = {trial_error:.4f}")

# Spread of the final-timestep error across trials.
mean_error = np.mean(trial_errors)
std_error = np.std(trial_errors)

print(f"\nMultiple trial statistics:")
print(f"  Mean error: {mean_error:.4f}")
print(f"  Std error: {std_error:.4f}")

# Pick the verdict first, then print it once.
if mean_error < 0.05 and std_error < 0.02:
    stability_verdict = "\n✓ Sampling is stable across trials!"
elif mean_error < 0.10:
    stability_verdict = "\n⚠ Acceptable stability, but consider more samples"
else:
    stability_verdict = "\n✗ High variance across trials - increase n_samples"
print(stability_verdict)

# ============================================================================
# VALIDATION TEST 4: Increased Sample Size Test
# ============================================================================

print("\n" + "=" * 60)
print("VALIDATION TEST 4: Increased Sample Size Test")
print("=" * 60)

print("\nTesting with 10x more samples (10,000 samples)...")

# Same warm-up and thinning as the main schedule, but ten times as many
# retained samples.
large_schedule = SamplingSchedule(n_warmup=100, n_samples=10000, steps_per_sample=2)

# Fresh keys for this test: one for a new initial state sequence (the earlier
# init_states is deliberately not reused) and one for the sampler.
key, large_init_key, large_samp_key = jax.random.split(key, 3)
large_init_states = [jax.random.randint(large_init_key, (T,), 0, n_states, dtype=jnp.uint8)]

# Run the sampler with the larger schedule, observations clamped.
large_samples = sample_states(
    large_samp_key,
    prog,
    large_schedule,
    large_init_states,
    [observed_data],
    [Block(state_nodes)],
)

# First (only) observed block; expected shape (10000, T).
large_state_samples = large_samples[0]

# Per-timestep empirical marginal: count how often each state value occurs
# at timestep t, then divide by the number of samples.
large_marginals = [
    jnp.bincount(large_state_samples[:, t], length=n_states) / large_schedule.n_samples
    for t in range(T)
]

# Largest absolute difference from the pymdp posterior at the last timestep.
large_error = jnp.abs(large_marginals[-1] - qs_pymdp_history[-1]).max()

print(f"\nFinal timestep error with 10,000 samples: {large_error:.4f}")
print(f"Original error with 1,000 samples: {error:.4f}")

# Pick the verdict first, then print it once.
if large_error < error:
    sample_size_verdict = "\n✓ Error decreased with more samples (expected behavior)"
else:
    sample_size_verdict = "\n⚠ Error did not decrease (may indicate sampling has converged)"
print(sample_size_verdict)


# ============================================================================
# SUMMARY
# ============================================================================

print("\n" + "="*60)
print("VALIDATION SUMMARY")
print("="*60)

# 1. Agreement between THRML and pymdp at the last timestep (threshold 0.05).
print("\n1. Final timestep comparison:")
print(f"   Error: {error:.4f} {'✓ PASS' if error < 0.05 else '✗ FAIL'}")

# 2. Matrix conventions pass only if BOTH the THRML and the pymdp probe agree
#    with the original matrix - the same condition the verification section
#    uses. Checking the THRML probe alone could report PASS after a pymdp
#    mismatch had been flagged.
conventions_aligned = bool(
    jnp.allclose(thrml_prob, original_prob, atol=1e-6)
    and jnp.allclose(pymdp_prob, original_prob, atol=1e-6)
)
print("\n2. Matrix convention verification:")
print(f"   {'✓ PASS' if conventions_aligned else '✗ FAIL'}")

# 3. THRML energy versus the hand-computed energy.
print("\n3. Energy function verification:")
print(f"   {'✓ PASS' if jnp.allclose(energy_thrml, manual_energy, atol=1e-4) else '✗ FAIL'}")

# 4. Stability across independently seeded trials.
print("\n4. Multiple trial stability:")
print(f"   Mean error: {mean_error:.4f}, Std: {std_error:.4f}")
print(f"   {'✓ PASS' if mean_error < 0.05 and std_error < 0.02 else '⚠ ACCEPTABLE' if mean_error < 0.10 else '✗ FAIL'}")

# 5. Effect of drawing ten times as many samples.
print("\n5. Increased sample size test:")
print(f"   Error with 10x samples: {large_error:.4f}")
print(f"   {'✓ PASS' if large_error < error else '⚠ NOTE'}")

print("\n" + "="*60)
print("All validation tests complete!")
print("="*60)

Observed data: [1 0 0 0 1]

VALIDATION TEST 1: Matrix Convention Verification

Verifying A matrix conventions:
  Original A_matrix[1, 0] (P(o=1|s=0)): 0.20000000298023224
  Transposed log_A.T[0, 1] (exp of log P(o=1|s=0)): 0.19999998807907104
  ✓ Match: True

Verifying B matrix conventions:
  B_matrix[0, 1] (P(s'=0|s=1)): 0.20000000298023224
  ✓ THRML uses same B matrix (no transpose needed)

Verifying D vector:
  D_vector: [0.33 0.33 0.34]
  Sum (should be ~1.0): 1.000000

VALIDATION TEST 2: Energy Function Verification

Energy for all-zero state sequence:
  THRML computed energy: 6.423669
  Manually computed energy: 6.423670
  ✓ Match: True

Running THRML smoothing inference...

THRML smoothing posteriors P(s_t | o_0:T-1):
  t=0: [0.30900002 0.45400003 0.23700002]
  t=1: [0.8880001 0.066     0.046    ]
  t=2: [0.95400006 0.032      0.014     ]
  t=3: [0.88000005 0.07300001 0.047     ]
  t=4: [0.29200003 0.462      0.246     ]

Running pymdp sequential filtering...
is A_pymdp[0] norma