# Mess3 generator smoke test

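This notebook builds a simple `mess3` hidden Markov model, prints a few generated symbols together with the belief state at each step, and then repeats the exercise for a product of two `mess3` processes.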


In [1]:
# pyright: reportMissingImports=false
import jax

from simplexity.generative_processes.builder import build_hidden_markov_model

# mess3 parameters (described as the config defaults: x=0.15, a=0.6)
mess3_kwargs = {"x": 0.15, "a": 0.6}
model = build_hidden_markov_model("mess3", **mess3_kwargs)

# Sampling setup: a fixed seed so the printed sequence is reproducible.
# The key gets a leading axis of size 1 to line up with the batched state below.
sequence_len = 10
key = jax.random.PRNGKey(0)[None, :]

# Starting belief state: the model's normalizing_eigenvector with a leading
# batch axis of size 1.
# NOTE(review): assumes normalizing_eigenvector is an acceptable initial
# state for generate() - TODO confirm against the simplexity API.
state = model.normalizing_eigenvector[None, :]

# Final positional flag is False here: generate() hands back a single
# final state plus the observations (presumably a "return all states"
# switch - TODO confirm the parameter name).
final_state, observations = model.generate(state, key, sequence_len, False)

# Collapse singleton axes so the symbols print as a flat list
generated_symbols = observations.squeeze().tolist()
print("Generated observations:", generated_symbols)


Generated observations: [0, 2, 0, 2, 2, 0, 0, 1, 2, 1]


In [2]:
# pyright: reportMissingImports=false
# Same generation call as before, but with the last flag set to True so
# that the belief state at every step is returned, not just one final state
states_per_step, obs = model.generate(state, key, sequence_len, True)

# Drop the leading batch axis (size 1), then normalize each step's belief
# state by mapping normalize_belief_state over the remaining leading axis
per_step_states = states_per_step.squeeze(0)
beliefs = jax.vmap(model.normalize_belief_state)(per_step_states)

# Convert to plain Python lists once, then print one line per step
symbols = obs.squeeze(0).tolist()
belief_rows = beliefs.tolist()
for t, (o, b) in enumerate(zip(symbols, belief_rows)):
	print(f"t={t:02d} obs={o} belief={b}")


t=00 obs=0 belief=[0.3333333432674408, 0.3333333432674408, 0.3333333432674408]
t=01 obs=2 belief=[0.6000000238418579, 0.20000000298023224, 0.20000000298023224]
t=02 obs=0 belief=[0.31578949093818665, 0.17105263471603394, 0.5131579041481018]
t=03 obs=2 belief=[0.5894569158554077, 0.14816294610500336, 0.2623802125453949]
t=04 obs=2 belief=[0.2984992265701294, 0.145717591047287, 0.555783212184906]
t=05 obs=0 belief=[0.16437214612960815, 0.12040877342224121, 0.7152190804481506]
t=06 obs=0 belief=[0.4870404005050659, 0.14601799845695496, 0.36694154143333435]
t=07 obs=1 belief=[0.6828927993774414, 0.12545859813690186, 0.19164860248565674]
t=08 obs=2 belief=[0.365500271320343, 0.456887811422348, 0.17761191725730896]
t=09 obs=1 belief=[0.23474085330963135, 0.26835331320762634, 0.4969058632850647]


In [None]:
# pyright: reportMissingImports=false
import itertools

from simplexity.generative_processes.builder import build_product_hidden_markov_model

# Product of two mess3 generators, one kwargs dict per component process
component_kwargs = [{"x": 0.15, "a": 0.6}, {"x": 0.2, "a": 0.5}]
component_names = ["mess3"] * len(component_kwargs)
prod_model = build_product_hidden_markov_model(
	process_names=component_names,
	process_kwargs=component_kwargs,
)

# Starting state for the product model, with a leading batch axis of size 1
prod_state = prod_model.normalizing_eigenvector[None, :]

# Short run that returns a single final state plus the observations.
# `key` and `sequence_len` are reused from the earlier cells.
prod_final_state, prod_obs = prod_model.generate(prod_state, key, sequence_len, False)

# Per-component vocabulary sizes used to decode flat product symbols into tuples.
# NOTE(review): hard-coded; assumes each mess3 component emits 3 symbols - TODO confirm.
vocab_sizes = [3, 3]

def unravel_index(idx: int, dims: list[int]) -> tuple[int, ...]:
	"""Convert a flat index into per-dimension coordinates.

	The last entry of ``dims`` varies fastest (row-major order), so with
	``dims == [3, 3]`` the flat index ``5`` decodes to ``(1, 2)``.

	Args:
		idx: Flat index, expected in ``range(prod(dims))``.
		dims: Size of each dimension, in coordinate order.

	Returns:
		A tuple with one coordinate per entry of ``dims``.

	Raises:
		ValueError: If any dimension is not positive, or if ``idx`` lies
			outside ``range(prod(dims))``. Without this check such indices
			would be silently wrapped or truncated into plausible-looking
			coordinates.
	"""
	total = 1
	for d in dims:
		if d <= 0:
			raise ValueError(f"dimension sizes must be positive, got {d}")
		total *= d
	if not 0 <= idx < total:
		raise ValueError(f"index {idx} is out of bounds for dims {list(dims)} (size {total})")

	coords = []
	# Peel coordinates off from the fastest-varying (last) dimension first
	for d in reversed(dims):
		idx, coord = divmod(idx, d)
		coords.append(coord)
	return tuple(reversed(coords))

# Decode each flat product symbol into one coordinate per component process
obs_tuples = []
for flat_symbol in prod_obs.squeeze(0).tolist():
	obs_tuples.append(unravel_index(int(flat_symbol), vocab_sizes))
print("Product observations (tuples):", obs_tuples)

# Second run with the last flag set to True, so per-step states are returned
prod_states_per_step, prod_obs_all = prod_model.generate(prod_state, key, sequence_len, True)

# Drop the leading batch axis (size 1) and normalize each step's belief state
normalize_each_step = jax.vmap(prod_model.normalize_belief_state)
prod_beliefs = normalize_each_step(prod_states_per_step.squeeze(0))

# One line per step: decoded observation tuple and the belief vector's length
prod_symbols = prod_obs_all.squeeze(0).tolist()
prod_belief_rows = prod_beliefs.tolist()
for t, (o, b) in enumerate(zip(prod_symbols, prod_belief_rows)):
	print(f"t={t:02d} obs={unravel_index(int(o), vocab_sizes)} belief_dim={len(b)}")
