# Serialization Tests
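Round-trip tests for `save_policy` / `load_policy`. Each test saves a policy to disk, loads it back, and checks that the loaded policy matches the original. The first test uses a simple `TabularPolicy`; the second uses a nested `CommitOnceMixture`.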


In [2]:
import os, sys, shutil, random
from pathlib import Path

def find_repo_root(start_dir: "str | os.PathLike[str]", max_depth: int = 6) -> str:
    """Locate the repository root by walking upward from ``start_dir``.

    A directory counts as the repo root if it contains a ``liars_poker``
    directory or a ``pyproject.toml`` file.

    Args:
        start_dir: Directory to start searching from (``str`` or path-like).
        max_depth: Maximum number of directories to examine, counting
            ``start_dir`` itself. Defaults to 6 (the previous fixed limit).

    Returns:
        The resolved root directory as a string. If no marker is found within
        ``max_depth`` levels (or the filesystem root is reached first), the
        resolved ``start_dir`` is returned as a fallback.
    """
    start = Path(start_dir).resolve()
    cur = start
    for _ in range(max_depth):
        if (cur / "liars_poker").is_dir() or (cur / "pyproject.toml").exists():
            return str(cur)
        if cur.parent == cur:  # filesystem root reached; nothing further up
            break
        cur = cur.parent
    return str(start)

# Locate the repository from the notebook's working directory and prepend it
# to the import path so `liars_poker` can be imported without installing it.
NB_DIR = Path.cwd()
REPO_ROOT = Path(find_repo_root(NB_DIR))
if str(REPO_ROOT) not in sys.path:
    sys.path[:0] = [str(REPO_ROOT)]

# Repo-level directory for generated artifacts; created on first run.
ARTIFACTS_ROOT = REPO_ROOT / "artifacts"
ARTIFACTS_ROOT.mkdir(exist_ok=True, parents=True)


from liars_poker import GameSpec, Rules
from liars_poker.policies.tabular import TabularPolicy
from liars_poker.policies.random import RandomPolicy
from liars_poker.policies.commit_once import CommitOnceMixture
from liars_poker.serialization import save_policy, load_policy
from liars_poker.infoset import InfoSet

# --- Set up the test directory under the repo-level artifacts root ---
# Reusing ARTIFACTS_ROOT places output under <repo>/artifacts/... regardless of
# where the notebook is launched (not e.g. <repo>/notebooks/tests/artifacts/...).
TEST_ROOT = ARTIFACTS_ROOT / "test" / "serialization"

# Clean and recreate so every run starts from an empty directory.
if TEST_ROOT.exists():
    shutil.rmtree(TEST_ROOT)
TEST_ROOT.mkdir(parents=True, exist_ok=True)

print(f"Repo Root found at: {REPO_ROOT}")
print(f"Test Artifacts prepared at: {TEST_ROOT}")


Repo Root found at: C:\Users\adidh\Documents\liars_poker
Test Artifacts prepared at: C:\Users\adidh\Documents\liars_poker\artifacts\test\serialization


## Test 1: Simple Tabular Policy
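Create a `TabularPolicy` with three infosets (custom action probabilities plus state values), save it, load it back, and assert that probabilities, values, and visit counts match exactly.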


In [3]:
# Build spec and rules: 3 ranks, 1 suit, 1-card hands, RankHigh claims only.
spec = GameSpec(ranks=3, suits=1, hand_size=1, claim_kinds=("RankHigh",), suit_symmetry=False)
rules = Rules(spec)

# Three distinct infosets at increasing history depth.
i1 = InfoSet(pid=0, hand=(1,), history=())
i2 = InfoSet(pid=1, hand=(2,), history=(0,))
i3 = InfoSet(pid=0, hand=(3,), history=(0, 1))

# Assign custom (non-uniform) probabilities over each infoset's legal actions,
# so the round-trip is checked against distinctive data.
pol = TabularPolicy()
pol.bind_rules(rules)

# i1: opening move; uses the first three legal actions.
legal1 = rules.legal_actions_for(i1)
pol.set(i1, {legal1[0]: 0.7, legal1[1]: 0.2, legal1[2]: 0.1})

# i2: after one claim; mass on the first and last legal action.
# NOTE(review): presumably the last legal action is CALL - confirm ordering of legal_actions_for.
legal2 = rules.legal_actions_for(i2)
pol.set(i2, {legal2[0]: 0.1, legal2[-1]: 0.9})

# i3: deeper history; deterministic choice of the first legal action.
legal3 = rules.legal_actions_for(i3)
pol.set(i3, {legal3[0]: 1.0})

# Attach state values through the public API rather than writing the
# private `_state_value` dict directly.
pol.set_annotations(values={i1: 0.5, i2: 0.2, i3: -0.1})

save_dir = TEST_ROOT / 'simple'
save_dir.mkdir(parents=True, exist_ok=True)
save_policy(pol, save_dir)

loaded_pol, loaded_spec = load_policy(save_dir)
loaded_rules = Rules(loaded_spec)
loaded_pol.bind_rules(loaded_rules)

# Exact (not approximate) equality: the round-trip must not perturb any value.
assert loaded_pol.probs == pol.probs, 'Probs mismatch'
assert loaded_pol.values() == pol.values(), 'Values mismatch'
assert loaded_pol.visits() == pol.visits(), 'Visits mismatch'
print('Test 1 passed: TabularPolicy round-trip exact match.')


Test 1 passed: TabularPolicy round-trip exact match.


## Test 2: Complex Nested Mixture
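Build a `CommitOnceMixture` of three distinct children: two `TabularPolicy` instances (A and B) with different strategies, and a `RandomPolicy`. Save the mixture, inspect the files written, load it back, and verify the weights, the child order, and the child data.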


In [4]:
# Spec and rules for the mixture: 3 ranks, 2 suits, RankHigh + Pair claims.
mix_spec = GameSpec(ranks=3, suits=2, hand_size=1, claim_kinds=("RankHigh", "Pair"), suit_symmetry=True)
mix_rules = Rules(mix_spec)

# Child A (Tabular): 80/20 over the first two legal opening actions for hand (1,).
a = TabularPolicy(); a.bind_rules(mix_rules)
hand_a = (1,)
iset_a = InfoSet(pid=0, hand=hand_a, history=())
leg_a = mix_rules.legal_actions_for(iset_a)
a.set(iset_a, {leg_a[0]: 0.8, leg_a[1]: 0.2})

# Child B (Tabular): 10/90 for hand (2,), so A and B stay distinguishable after load.
hand_b = (2,)
iset_b = InfoSet(pid=0, hand=hand_b, history=())
b = TabularPolicy(); b.bind_rules(mix_rules)
leg_b = mix_rules.legal_actions_for(iset_b)
b.set(iset_b, {leg_b[0]: 0.1, leg_b[1]: 0.9})

# Child C (Random)
c = RandomPolicy(); c.bind_rules(mix_rules)

weights = [0.5, 0.3, 0.2]
mix = CommitOnceMixture([a, b, c], weights)
mix.bind_rules(mix_rules)

mix_dir = TEST_ROOT / 'mixture'
mix_dir.mkdir(parents=True, exist_ok=True)
save_policy(mix, mix_dir)

# Inspect the file tree. The nested mixture is expected to be stored flat,
# with no per-child subdirectories.
for root, dirs, files in os.walk(mix_dir):
    print(Path(root).relative_to(mix_dir))  # top level prints as '.'
    for f in sorted(files):
        print('  file:', f)
    for d in sorted(dirs):
        print('  dir :', d)

subdirs = sorted(p.name for p in mix_dir.iterdir() if p.is_dir())
assert not subdirs, f'Expected flat layout, found subdirectories: {subdirs}'

loaded_mix, loaded_spec = load_policy(mix_dir)
loaded_mix.bind_rules(Rules(loaded_spec))
assert isinstance(loaded_mix, CommitOnceMixture), 'Mixture type mismatch'

# Verify weights (rounded to tolerate float noise from serialization)
assert [round(w, 3) for w in loaded_mix.weights] == weights, 'Weights mismatch'

# Verify children order and data
loaded_children = list(loaded_mix.policies)
assert len(loaded_children) == 3, 'Child count mismatch'

loaded_a, loaded_b, loaded_c = loaded_children

assert isinstance(loaded_a, TabularPolicy), 'Child A type mismatch'
assert isinstance(loaded_b, TabularPolicy), 'Child B type mismatch'
assert isinstance(loaded_c, RandomPolicy), 'Child C type mismatch'

# Compare strategy distributions; A and B differ, so this also pins child order.
assert loaded_a.probs[iset_a] == a.probs[iset_a], 'Child A distribution mismatch'
assert loaded_b.probs[iset_b] == b.probs[iset_b], 'Child B distribution mismatch'

print('Test 2 passed: Mixture round-trip preserved weights and child data.')


.
  file: blobs.npz
  file: metadata.json
Test 2 passed: Mixture round-trip preserved weights and child data.
