# Chain Reasoning, Persistence & Naive Baseline

End-to-end demonstration of multi-hop chain priors, weakest-link diagnostics, and state persistence, compared against a naive cosine top-k baseline.

Outline:
1. Define small factual corpus forming a geography chain with distractors.
2. Embed texts (real model if installed; fallback hash embedding otherwise).
3. Naive cosine top-k retrieval.
4. Lattice with chain prior: receipt, chain verdict, weakest link.
5. Bundle vs naive list.
6. Persistence export/import; verify cached U* reuse.
7. Structural mutation invalidates signature and forces recompute.

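Run the sections in order the first time, because later cells reuse variables such as `emb`, `psi`, and `lat` defined in earlier ones; after that, individual sections can be re-run selectively.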

In [None]:
# 1. Imports & utility helpers
from typing import List, Tuple

import numpy as np

from oscillink import OscillinkLattice
from oscillink.adapters.text import embed_texts

np.random.seed(42)

def cosine_topk(mat: np.ndarray, q: np.ndarray, k: int = 5) -> List[Tuple[int, float]]:
    qn = q / (np.linalg.norm(q) + 1e-12)
    sims = (mat @ qn)
    idx = np.argsort(-sims)[:k]
    return [(int(i), float(sims[i])) for i in idx]

def describe(title: str):
    """Print ``title`` as a ``=== title ===`` banner preceded by a blank line."""
    banner = "\n=== {} ===".format(title)
    print(banner)

In [None]:
# 2. Corpus definition (multi-hop chain + distractors)
# List positions double as node ids in the rest of the notebook, so the order
# of the sentences must not change.
texts = [
    # Chain nodes: Paris -> France -> Europe
    "Paris is in France.",  # 0
    "France is part of Europe.",  # 1
    "Europe is a continent.",  # 2
    # Extra context for the chain
    "A continent contains many countries.",  # 3
    # Distractor cluster
    "Tokyo is in Japan.",  # 4
    "Japan is an island nation.",  # 5
    # Unrelated
    "Basketball is a team sport.",  # 6
    # Supporting facts for Paris, France and Europe respectively
    "The Eiffel Tower is in Paris.",  # 7
    "France borders Spain.",  # 8
    "Europe has multiple climates.",  # 9
]

# Node ids of the chain, in hop order.
chain_indices = [0, 1, 2]

print(f"Total texts: {len(texts)} | Chain nodes: {chain_indices}")
for node_id, sentence in enumerate(texts):
    print("{:2d}: {}".format(node_id, sentence))

In [None]:
# 3. Generate embeddings (real model if available, else deterministic hash fallback)
emb = embed_texts(texts)
print("Embeddings shape:", emb.shape)

# Query embedding: the unweighted mean of the chain-node embeddings, cast to
# float32 and rescaled in place to (approximately) unit length.
psi = emb[chain_indices].mean(axis=0).astype(np.float32)
psi /= np.linalg.norm(psi) + 1e-12
print("Query vector norm:", np.linalg.norm(psi))

# 4. Naive cosine top-k baseline
k = 6
baseline = cosine_topk(emb, psi, k=k)
describe("Naive cosine top-k")
for node_id, sim in baseline:
    print(f"idx={node_id:2d} score={sim:.4f} text={texts[node_id]}")

In [None]:
# 5. Build lattice and add chain prior
lat = OscillinkLattice(
    emb,
    kneighbors=6,
    lamG=1.0,
    lamC=0.6,
    lamQ=4.0,
    deterministic_k=True,
)
lat.set_query(psi)
lat.add_chain(chain_indices, lamP=0.25)
settle_info = lat.settle(max_iters=12, tol=1e-3)

# Gather the three outputs inspected below: the overall receipt, the
# chain-specific receipt, and the top-6 bundle.
rec = lat.receipt()
chain_rec = lat.chain_receipt(chain_indices)
bundle = lat.bundle(k=6)

describe("Lattice receipt summary")
print("ΔH_total:", rec["deltaH_total"], "null_points:", len(rec.get("null_points", [])))
receipt_meta = rec["meta"]
print("ustar iters/res:", receipt_meta.get("ustar_iters"), receipt_meta.get("ustar_res"))

describe("Chain receipt")
print("verdict:", chain_rec["verdict"], "weakest_link:", chain_rec.get("weakest_link"))

# Map bundle indices to texts; rank is 1-based and text is cut to 50 characters.
bundle_view = [
    {
        "rank": rank,
        "id": item["id"],
        "score": round(item["score"], 4),
        "align": round(item["align"], 4),
        "text": texts[item["id"]][:50],
    }
    for rank, item in enumerate(bundle, start=1)
]
describe("Bundle top-6")
for entry in bundle_view:
    print(entry)

In [None]:
# 6. Compare baseline vs bundle coverage for chain
chain_set = set(chain_indices)
baseline_ids = {node_id for node_id, _score in baseline}
bundle_ids = {entry["id"] for entry in bundle}


def coverage(ids):
    """Return the fraction of chain nodes that are present in the set ``ids``."""
    hits = ids & chain_set
    return len(hits) / len(chain_set)


print("Baseline chain coverage:", coverage(baseline_ids))
print("Bundle chain coverage  :", coverage(bundle_ids))
# Chain nodes that each retrieval method failed to return.
print("Missed chain nodes baseline:", list(chain_set - baseline_ids))
print("Missed chain nodes bundle  :", list(chain_set - bundle_ids))

In [None]:
# 7. Persistence: export / import and verify cached U* reuse
from oscillink import OscillinkLattice as OL

state = lat.export_state()
lat2 = OL.from_state(state)

# Re-run receipt & bundle on the restored lattice; expect 0 new solves if
# cache metadata preserved.
rec2 = lat2.receipt()
bundle2 = lat2.bundle(k=6)

solves_original = lat.stats.get("ustar_solves")
solves_copy = lat2.stats.get("ustar_solves")
print("ustar_solves original -> copy:", solves_original, "->", solves_copy)
print("cache hits copy:", lat2.stats.get("ustar_cache_hits"))

# Compare the provenance stored in the exported state with a fresh export
# taken from the restored lattice.
provenance_saved = state["provenance"]
provenance_restored = lat2.export_state()["provenance"]
print("state provenance equal:", provenance_saved == provenance_restored)

In [None]:
# 8. Structural mutation: rewrite the text of chain node 1 (the middle hop)
# and observe signature invalidation.
mutated_texts = list(texts)
mutated_texts[1] = "France is a sovereign state in East Asia."  # incorrect shift
emb_mut = embed_texts(mutated_texts)

# Same hyperparameters, chain prior and query vector as the original lattice;
# only the embeddings differ.
lat_mut = OscillinkLattice(
    emb_mut,
    kneighbors=6,
    lamG=1.0,
    lamC=0.6,
    lamQ=4.0,
    deterministic_k=True,
)
lat_mut.set_query(psi)
lat_mut.add_chain(chain_indices, lamP=0.25)
lat_mut.settle(max_iters=12, tol=1e-3)
rec_mut = lat_mut.receipt()
print("Original deltaH:", rec["deltaH_total"], "Mutated deltaH:", rec_mut["deltaH_total"])

# NOTE(review): assumes both receipts carry meta -> signature -> payload ->
# state_sig; confirm this is populated without extra configuration.
sig_original = rec["meta"]["signature"]["payload"]["state_sig"]
sig_mutated = rec_mut["meta"]["signature"]["payload"]["state_sig"]
print("Signature changed:", sig_original != sig_mutated)