# Best Response Correctness Harness
Compare the baseline `best_response_exact` against a candidate efficient implementation (`liars_poker.algo.br_exact_efficient`), once one is available.


In [None]:
import os, sys
from pathlib import Path

def find_repo_root(start_dir: "str | os.PathLike[str]", max_depth: int = 6) -> str:
    """Return the repository root at or above *start_dir*, as a string path.

    Walks upward from ``start_dir`` (inclusive), inspecting at most
    ``max_depth`` directories, and returns the first one that contains either
    a ``liars_poker`` directory or a ``pyproject.toml`` file.

    Args:
        start_dir: Directory to start from; any str or path-like object that
            :class:`pathlib.Path` accepts (this notebook passes a ``Path``).
        max_depth: Maximum number of directories to inspect, counting
            ``start_dir`` itself. The default of 6 matches the previous
            hard-coded search depth.

    Returns:
        The resolved directory containing a marker. If no marker is found
        within ``max_depth`` levels, or the filesystem root is reached first,
        returns the resolved ``start_dir``.
    """
    start = Path(start_dir).resolve()
    cur = start
    for _ in range(max_depth):
        if (cur / "liars_poker").is_dir() or (cur / "pyproject.toml").exists():
            return str(cur)
        if cur.parent == cur:
            # Reached the filesystem root; there is nothing further up to inspect.
            break
        cur = cur.parent
    return str(start)

# Resolve the repository root from the notebook's working directory and put it
# at the front of sys.path (once) so `liars_poker` imports from the checkout.
NB_DIR = Path.cwd()
REPO_ROOT = Path(find_repo_root(NB_DIR))
_repo_root_entry = str(REPO_ROOT)
if _repo_root_entry not in sys.path:
    sys.path.insert(0, _repo_root_entry)

from liars_poker import GameSpec, Rules, Policy
from liars_poker.policies.random import RandomPolicy
from liars_poker.policies.tabular import TabularPolicy
from liars_poker.policies.commit_once import CommitOnceMixture
from liars_poker.algo.br_exact import best_response_exact as br_baseline

try:
    from liars_poker.algo.br_exact_efficient import best_response_exact as br_candidate
except Exception as exc:  # noqa: BLE001
    # Deliberately broad: any import-time failure of the candidate module (not
    # only a missing module) disables the candidate checks instead of stopping
    # the notebook. Baseline-only cells keep working.
    print("Candidate implementation not available yet:", exc)
    br_candidate = None
    HAVE_CANDIDATE = False
else:
    HAVE_CANDIDATE = True

from liars_poker.eval.match import eval_both_seats
from liars_poker.infoset import InfoSet


## Comparison Helper
Run both implementations and assert that their exploitability and state values match (the strategy-equality check is currently disabled).


In [None]:
import math
from typing import Dict, Tuple


def extract_state_values(tab_policy: TabularPolicy) -> Dict[Tuple[int, Tuple[int, ...], Tuple[int, ...]], float]:
    """Flatten ``tab_policy.values()`` into a dict keyed by plain tuples.

    Each InfoSet key is replaced by ``(iset.pid, iset.hand, iset.history)``.
    The value stored for it is unchanged.
    """
    return {
        (infoset.pid, infoset.hand, infoset.history): value
        for infoset, value in tab_policy.values().items()
    }


def extract_probs(tab_policy: TabularPolicy) -> Dict[Tuple[int, Tuple[int, ...], Tuple[int, ...]], Dict[int, float]]:
    """Copy ``tab_policy.probs`` into a dict keyed by ``(pid, hand, history)``.

    Each action distribution is shallow-copied with ``dict()``. Callers can
    therefore normalize or edit the result without touching the policy's own
    table.
    """
    return {
        (infoset.pid, infoset.hand, infoset.history): dict(action_dist)
        for infoset, action_dist in tab_policy.probs.items()
    }


def _as_tuple(x):
    """Return *x* as a tuple: iterables are expanded, non-iterables are wrapped."""
    try:
        return tuple(x)
    except TypeError:
        return (x,)


def _assert_exploitability_match(base_exp, cand_exp, tol):
    """Assert two exploitability results have equal length and agree within *tol*."""
    base_t, cand_t = _as_tuple(base_exp), _as_tuple(cand_exp)
    # zip() alone stops at the shorter input, so a length mismatch would pass silently.
    assert len(base_t) == len(cand_t), f"Exploitability length mismatch: {base_exp} vs {cand_exp}"
    assert all(abs(a - b) <= tol for a, b in zip(base_t, cand_t)), f"Exploitability mismatch: {base_exp} vs {cand_exp}"


def _assert_state_values_match(base_vals, cand_vals, tol):
    """Assert every (history, hand) value in *base_vals* exists in *cand_vals* within *tol*.

    Both arguments map history -> {hand: value}. Extra states present only in
    the candidate are not checked.
    """
    for history, hand_map in base_vals.items():
        cand_hand_map = cand_vals.get(history, {})
        for hand, val in hand_map.items():
            cand_val = cand_hand_map.get(hand)
            assert cand_val is not None, f"Missing state {history} for hand {hand}"
            assert math.isclose(val, cand_val, rel_tol=0, abs_tol=tol), f"Value mismatch at {history}, hand {hand}: {val} vs {cand_val}"


def _normalize_dist(dist):
    """Rescale *dist* so its values sum to 1 (all zeros if the total is 0)."""
    total = sum(dist.values())
    return {a: (p / total if total else 0.0) for a, p in dist.items()}


def _assert_strategies_match(base_policy, cand_policy, tol):
    """Assert the candidate's action distributions match the baseline's at every baseline infoset."""
    base_probs = extract_probs(base_policy)
    cand_probs = extract_probs(cand_policy)
    for key, dist in base_probs.items():
        other = cand_probs.get(key)
        assert other is not None, f"Missing infoset {key} in candidate policy"
        # Normalize both sides so tiny differences in unnormalized mass don't fail the check.
        d1 = _normalize_dist(dist)
        d2 = _normalize_dist(other)
        assert set(d1) == set(d2), f"Action set mismatch at {key}: {d1.keys()} vs {d2.keys()}"
        for a, p in d1.items():
            assert math.isclose(p, d2[a], rel_tol=0, abs_tol=tol), f"Prob mismatch at {key}, action {a}: {p} vs {d2[a]}"


def compare_implementations(spec: GameSpec, opp_policy, *, tol=1e-9, check_strategy=False):
    """Compare baseline and candidate exact best responses to *opp_policy* on *spec*.

    The baseline always runs, so baseline errors surface even when no
    candidate is installed. If the candidate is unavailable, a message is
    printed and the function returns without comparing.

    Args:
        spec: Game specification passed to both implementations.
        opp_policy: Opponent policy to best-respond to.
        tol: Absolute tolerance for exploitability, state-value and
            probability comparisons.
        check_strategy: If True, also require identical (normalized) action
            distributions at every baseline infoset. Off by default.
            NOTE(review): exact best responses can presumably differ where
            actions tie in value, so a failure here is not necessarily a bug;
            confirm before relying on it.

    Raises:
        AssertionError: On any exploitability, state-value or (optionally)
            strategy mismatch.
    """
    base_policy, base_br = br_baseline(spec, opp_policy)
    base_policy.bind_rules(base_br.rules)
    base_vals = base_br.state_card_values
    base_exp = base_br.exploitability()

    if not HAVE_CANDIDATE:
        print("Candidate implementation not available; skipping compare.")
        return

    cand_policy, cand_br = br_candidate(spec, opp_policy)
    cand_policy.bind_rules(cand_br.rules)
    cand_vals = cand_br.state_card_values
    cand_exp = cand_br.exploitability()

    _assert_exploitability_match(base_exp, cand_exp, tol)
    # State values: every history/hand seen by the baseline must match.
    _assert_state_values_match(base_vals, cand_vals, tol)
    if check_strategy:
        _assert_strategies_match(base_policy, cand_policy, tol)

    print("All checks passed for spec:", spec)


## Test Scenarios
A suite of game specs and opponent policies chosen to exercise different code branches and card-removal effects.


In [None]:

# Opponent policy helpers
class AlwaysCall(TabularPolicy):
    """Deterministic opponent that calls whenever calling is legal.

    At each infoset it puts all mass on action -1 (call) if that is legal.
    Otherwise it plays the smallest legal action. With no legal actions it
    returns an empty distribution.
    """

    POLICY_KIND = "AlwaysCall"

    def action_probs(self, infoset: InfoSet):
        legal = self._legal_actions(infoset)
        if not legal:
            return {}
        choice = -1 if -1 in legal else min(legal)
        return {choice: 1.0}

def make_deterministic_raise(rules):
    """Build an opponent that always makes the smallest legal raise.

    Args:
        rules: Object providing ``legal_actions_for(infoset)``. It is captured
            by the closure and used for every legal-action lookup.

    Returns:
        A new ``AlwaysRaise`` policy instance. At each infoset it:
        - puts all mass on the smallest legal action other than -1 (call);
        - calls when -1 is the only legal action;
        - returns an empty distribution when nothing is legal, matching
          ``AlwaysCall``.
    """
    class AlwaysRaise(TabularPolicy):
        POLICY_KIND = "AlwaysRaise"

        def action_probs(self, infoset: InfoSet):
            legal = rules.legal_actions_for(infoset)
            if not legal:
                # No legal actions (presumably a terminal history): return an
                # empty distribution rather than claiming an illegal call.
                return {}
            raise_only = [a for a in legal if a != -1]
            if not raise_only:
                return {-1: 1.0}
            return {min(raise_only): 1.0}

    return AlwaysRaise()

# Game specs ranging from a 2-card deck up to 4x4 decks (the 4x4 specs use suit symmetry).
specs = [
    GameSpec(ranks=2, suits=1, hand_size=1, claim_kinds=("RankHigh",)),
    GameSpec(ranks=3, suits=1, hand_size=1, claim_kinds=("RankHigh",)),
    GameSpec(ranks=3, suits=2, hand_size=1, claim_kinds=("RankHigh", "Pair")),
    GameSpec(ranks=3, suits=2, hand_size=2, claim_kinds=("RankHigh", "Pair")),
    GameSpec(ranks=4, suits=4, hand_size=1, claim_kinds=("RankHigh", "Pair"), suit_symmetry=True),
    GameSpec(ranks=4, suits=4, hand_size=2, claim_kinds=("RankHigh", "Pair"), suit_symmetry=True),
]

# For each spec, build four opponents, each bound to Rules(spec):
#   r   - RandomPolicy
#   c   - AlwaysCall
#   ar  - AlwaysRaise (smallest legal raise; see make_deterministic_raise)
#   mix - CommitOnceMixture of r and ar with weights 0.7 / 0.3
# NOTE: the loop variables (spec, r, c, ar, mix) stay in the notebook's global
# scope. The spot-check cell further down uses `spec` and `c` from the last iteration.
# NOTE(review): a fresh Rules(spec) is built for every bind; if Rules is
# immutable, one instance per spec would do. Confirm before sharing.
opp_policies = []
for spec in specs:
    r = RandomPolicy(); r.bind_rules(Rules(spec))
    c = AlwaysCall(); c.bind_rules(Rules(spec))
    ar = make_deterministic_raise(Rules(spec)); ar.bind_rules(Rules(spec))
    # Mixed: 70% random, 30% deterministic raise via CommitOnceMixture
    mix = CommitOnceMixture([r, ar], [0.7, 0.3]); mix.bind_rules(Rules(spec))
    opp_policies.append((spec, [r, c, ar, mix]))

# Exact comparison for every (spec, opponent) pair; an AssertionError stops the run.
for spec, policies in opp_policies:
    print("Running spec", spec)
    for opp in policies:
        # NOTE(review): presumably resets per-episode opponent state (e.g. the
        # component CommitOnceMixture commits to). Confirm against the Policy API.
        opp.begin_episode()
        compare_implementations(spec, opp, tol=1e-9)



In [None]:
# Debug: display the value currently bound to `spec` (set by an earlier loop).
spec

In [None]:
# Debug: display the value currently bound to `opp` (set by an earlier loop).
opp

In [None]:
# import numpy as np
# from liars_poker.core import GameSpec
# from liars_poker.infoset import InfoSet
# from liars_poker.algo.br_exact_efficient import BestResponseComputerEfficient  # adjust import if needed
# from liars_poker.core import possible_starting_hands

# spec = GameSpec(ranks=2, suits=1, hand_size=1, claim_kinds=('RankHigh',), suit_symmetry=False)
# opp = AlwaysCall()
# opp.bind_rules(Rules(spec))

# # Baseline-ish check: what does the policy say at root for each hand?
# hands = tuple(possible_starting_hands(spec))
# print("hands:", hands)

# # Make the candidate BR object (this constructs the tabular fast-path index)
# br = BestResponseComputerEfficient(spec, opp)

# print("Type(opp):", type(opp))
# print("len(opp.probs) at construction time:", len(getattr(opp, "probs", {})))
# print("Candidate has _opp_tabular_index:", br._opp_tabular_index is not None)
# if br._opp_tabular_index is not None:
#     # How many (pid,history) keys actually exist in the index?
#     print("Opp tabular index keys:", len(br._opp_tabular_index))

# # What actions does candidate think are legal at root?
# actions = br._legal_actions_from_history(tuple())
# print("Candidate root legal actions:", actions)

# # What does the opponent say at root if we query via the proper interface?
# for h in hands:
#     d = opp.prob_dist_at_infoset(InfoSet(pid=0, hand=h, history=tuple()))
#     print(f"opp.prob_dist(pid=0, hand={h}, hist=()):", d)

# # What strategy matrix does the candidate build?
# S = br._get_strategy_matrix(opp, pid=0, history=tuple(), actions=actions)
# print("S shape:", S.shape)
# print("First few rows of S:")
# print(S[:min(5, len(hands))])
# print("Row sums (first few):", S[:min(5, len(hands))].sum(axis=1))

# # Check: does S match prob_dist_at_infoset for each hand?
# max_abs_diff = 0.0
# for j, h in enumerate(hands):
#     d = opp.prob_dist_at_infoset(InfoSet(pid=0, hand=h, history=tuple()))
#     row = np.array([d.get(a, 0.0) for a in actions], dtype=float)
#     if row.sum() > 0: row /= row.sum()
#     max_abs_diff = max(max_abs_diff, float(np.max(np.abs(S[j] - row))))
# print("max abs diff between S rows and prob_dist rows:", max_abs_diff)


In [None]:
# import numpy as np
# from liars_poker.env import resolve_call_winner
# from liars_poker.infoset import CALL

# def check_terminal_consistency(spec):
#     opp = AlwaysCall()
#     br = BestResponseComputerEfficient(spec, opp)
#     rules = br.rules
#     hands = br.hands

#     root_actions = br._legal_actions_from_history(tuple())
#     claim_actions = [a for a in root_actions if a != CALL]
#     if not claim_actions:
#         print("No non-call root actions; skipping.")
#         return

#     # Test each claim action as P1 then P2 calls: history = (claim, CALL)
#     for a in claim_actions:
#         history = (a, CALL)

#         # Candidate's terminal uses: truth = (counts_i[r] + counts_j[r] >= need)
#         # We'll compare to resolve_call_winner for all hand pairs.
#         last_claim_idx = InfoSet.last_claim_idx(history[:-1])
#         req = br._claim_reqs[last_claim_idx]
#         r, need = req.rank, req.need
#         c = br.hand_rank_counts[:, r]
#         T = (c[:, None] + c[None, :]) >= need  # candidate "truth" matrix

#         for i, hi in enumerate(hands):
#             for j, hj in enumerate(hands):
#                 # Baseline truth of claim inferred from winner:
#                 # P1 made claim, P2 called. If claim true => P1 wins; else P2 wins.
#                 winner = resolve_call_winner(spec, history, hi, hj)  # hi=P1, hj=P2
#                 baseline_truth = (winner == 'P1')
#                 if bool(T[i, j]) != bool(baseline_truth):
#                     print("Mismatch!")
#                     print("spec:", spec)
#                     print("claim action:", a, "decoded (rank,need)=", (r, need))
#                     print("P1 hand:", hi, "P2 hand:", hj)
#                     print("candidate truth:", bool(T[i, j]))
#                     print("baseline winner:", winner, "=> baseline_truth:", baseline_truth)
#                     return

#     print("Terminal check passed for one-claim call histories for spec:", spec)

# # Example: one of your problematic specs
# spec2 = GameSpec(ranks=3, suits=2, hand_size=1, claim_kinds=('RankHigh','Pair'), suit_symmetry=False)
# check_terminal_consistency(spec2)


In [None]:
# spec = GameSpec(ranks=2, suits=1, hand_size=1, claim_kinds=('RankHigh',), suit_symmetry=False)
# opp = AlwaysCall()
# opp.bind_rules(Rules(spec))
# hands = tuple(possible_starting_hands(spec))

# print("initial len(opp.probs):", len(opp.probs))
# for h in hands:
#     d = opp.prob_dist_at_infoset(InfoSet(pid=0, hand=h, history=tuple()))
# print("after queries len(opp.probs):", len(opp.probs))
# print("sample keys:", list(opp.probs.keys())[:5])


In [None]:
# spec = GameSpec(ranks=3, suits=2, hand_size=1, claim_kinds=('RankHigh','Pair'), suit_symmetry=False)
# rules = Rules(spec)  # or rules_for_spec(spec) depending on your project
# opp = make_deterministic_raise(rules)

# br = BestResponseComputerEfficient(spec, opp)
# actions = br._legal_actions_from_history(tuple())

# print("len(opp.probs) at construction:", len(getattr(opp, "probs", {})))
# print("actions at root:", actions)

# for h in br.hands[:min(5, br.n_hands)]:
#     d = opp.prob_dist_at_infoset(InfoSet(pid=0, hand=h, history=tuple()))
#     print("opp dist example:", d)
#     break

# S = br._get_strategy_matrix(opp, pid=0, history=tuple(), actions=actions)
# print("S row[0]:", S[0])
# print("Is uniform?", np.allclose(S[0], np.ones_like(S[0]) / len(actions)))


In [None]:
# base_policy, base_br = br_baseline(spec, opp)
# base_policy.bind_rules(base_br.rules)
# base_vals = base_br.state_card_values
# base_exp = base_br.exploitability()
# base_br.state_card_values

In [None]:
# cand_policy, cand_br = br_candidate(spec, opp)
# cand_policy.bind_rules(cand_br.rules)
# cand_vals = cand_br.state_card_values
# cand_exp = cand_br.exploitability()
# cand_br.state_card_values

In [None]:
# base_probs = extract_probs(base_policy)
# cand_probs = extract_probs(cand_policy)


## Empirical Seat-by-Seat Comparison
For each spec/opponent pair, compare the sampled win rates of the baseline and candidate best responses in each seat. A chi-square test flags differences that are statistically significant (p < 0.05).


In [None]:

import math
from collections import namedtuple

# Per-seat results for one policy: p1 from the run where it is passed first,
# p2 from the run where it is passed second.
WinStats = namedtuple('WinStats', 'p1 p2')

def seat_eval(spec, br_policy_fn, opp_policy, episodes=500, seed=1234):
    """Compute a best response to *opp_policy* and evaluate it in both seats.

    ``br_policy_fn(spec, opp_policy)`` must return a 2-tuple whose first item
    is the policy. That policy is bound to ``Rules(spec)`` and evaluated twice
    with ``eval_both_seats``:
    - passed first (seed ``seed``), keeping the ``'P1'`` entry;
    - passed second (seed ``seed + 1``), keeping the ``'P2'`` entry.

    NOTE(review): the entries are presumably win rates in [0, 1]; later cells
    print them as "win" and multiply them by the episode count. Confirm.
    """
    br_policy, _br_info = br_policy_fn(spec, opp_policy)
    br_policy.bind_rules(Rules(spec))
    as_first = eval_both_seats(spec, br_policy, opp_policy, episodes=episodes, seed=seed)
    as_second = eval_both_seats(spec, opp_policy, br_policy, episodes=episodes, seed=seed + 1)
    return WinStats(p1=as_first['P1'], p2=as_second['P2'])

def chi2_two_props(x1, n1, x2, n2):
    """Return the p-value of a chi-square test (df=1) for equal proportions.

    Compares the success rate ``x1 / n1`` against ``x2 / n2``. It uses the
    Pearson statistic of the 2x2 contingency table, without continuity
    correction.

    Args:
        x1, x2: Success counts (non-integer values are accepted, e.g.
            ``rate * episodes``).
        n1, n2: Sample sizes.

    Returns:
        The p-value in [0, 1]. Returns ``nan`` if either sample size is
        non-positive. Returns ``1.0`` when the pooled proportion is exactly 0
        or 1, since there is no variation to test.
    """
    if n1 <= 0 or n2 <= 0:
        return float('nan')
    p_pool = (x1 + x2) / (n1 + n2)
    if p_pool in (0, 1):
        return 1.0
    exp1_s, exp1_f = n1 * p_pool, n1 * (1 - p_pool)
    exp2_s, exp2_f = n2 * p_pool, n2 * (1 - p_pool)
    chi2 = (
        (x1 - exp1_s) ** 2 / exp1_s
        + (n1 - x1 - exp1_f) ** 2 / exp1_f
        + (x2 - exp2_s) ** 2 / exp2_s
        + (n2 - x2 - exp2_f) ** 2 / exp2_f
    )
    # With one degree of freedom, chi2 = Z**2 for a standard normal Z. The
    # survival function is therefore P(|Z| >= sqrt(chi2)) = erfc(sqrt(chi2 / 2)).
    # This is exact, needs no SciPy (previously a missing SciPy gave a silent
    # NaN), and avoids the cancellation of `1 - cdf` for tiny p-values.
    return math.erfc(math.sqrt(chi2 / 2))


In [None]:

# Run empirical comparison for seat-by-seat win rates
# Episodes per seat evaluation; also used to turn seat_eval values into counts.
EPISODES = 800
# Each row: (spec, opponent class name, baseline WinStats, candidate WinStats
# or None, P1-seat p-value or None, P2-seat p-value or None).
results = []
for spec, policies in opp_policies:
    for opp in policies:
        opp.begin_episode()
        base_stats = seat_eval(spec, br_baseline, opp, episodes=EPISODES, seed=111)
        if HAVE_CANDIDATE:
            # The seed differs from the baseline run, so the two samples are not identical draws.
            cand_stats = seat_eval(spec, br_candidate, opp, episodes=EPISODES, seed=222)
            # Multiplying by EPISODES converts the per-seat values into success
            # counts (assumes seat_eval returns rates in [0, 1] — confirm).
            p1_p = chi2_two_props(base_stats.p1 * EPISODES, EPISODES, cand_stats.p1 * EPISODES, EPISODES)
            p2_p = chi2_two_props(base_stats.p2 * EPISODES, EPISODES, cand_stats.p2 * EPISODES, EPISODES)
            results.append((spec, opp.__class__.__name__, base_stats, cand_stats, p1_p, p2_p))
        else:
            results.append((spec, opp.__class__.__name__, base_stats, None, None, None))

def fmt_p(p):
    """Format a p-value to 4 decimals, tagging values below 0.05 with " **LOW**".

    Returns "n/a" for ``None`` or NaN.
    """
    if p is None:
        return "n/a"
    if math.isnan(p):
        return "n/a"
    text = f"{p:.4f}"
    if p < 0.05:
        text += " **LOW**"
    return text

# Print one report section per (spec, opponent) row collected above.
for spec, opp_name, base_stats, cand_stats, p1_p, p2_p in results:
    print(f"Spec={spec}, Opp={opp_name}")
    print(f"  Base: P1 win={base_stats.p1:.3f}, P2 win={base_stats.p2:.3f}")
    if not cand_stats:
        print("  Candidate not available; skipped")
    else:
        print(f"  Cand: P1 win={cand_stats.p1:.3f}, P2 win={cand_stats.p2:.3f}")
        print(f"  chi2 p-values: P1 seat={fmt_p(p1_p)}, P2 seat={fmt_p(p2_p)}")
    print()


In [None]:
# Spot check: exploitability of the candidate best response to the AlwaysCall
# opponent `c`. `spec` and `c` are whatever the scenario-building loop bound last.
if HAVE_CANDIDATE:
    # Don't assign to `_`: IPython uses it for the last displayed output, and
    # rebinding it by hand stops IPython from updating `_`/`__`/`___`.
    pol_, cand_br_check = br_candidate(spec, c)
    print(cand_br_check.exploitability())
else:
    print("Candidate implementation not available; spot check skipped.")

## Performance Benchmark
Compare the wall-clock runtime of the baseline and candidate on a non-trivial spec.


In [None]:
import time
import timeit

# Benchmark spec: 7 ranks x 4 suits, two-card hands, suit symmetry enabled.
spec_bench = GameSpec(ranks=7, suits=4, hand_size=2, claim_kinds=("RankHigh", "Pair"), suit_symmetry=True)
rules_bench = Rules(spec_bench)
opp_bench = RandomPolicy(); opp_bench.bind_rules(rules_bench)

# Timed single-call runs per implementation; the minimum is reported.
# Lower to 1 for a quick (but noisier) single-shot timing.
BENCH_REPEATS = 3

if HAVE_CANDIDATE:

    def run_baseline():
        br_baseline(spec_bench, opp_bench)

    def run_candidate():
        br_candidate(spec_bench, opp_bench)

    # Report the minimum over several runs, the statistic the timeit docs
    # recommend. It also reduces order bias: with a single run, one-off warm-up
    # costs (e.g. any caching inside the shared opponent policy) would be
    # charged only to whichever implementation happens to run first.
    base_time = min(timeit.repeat(run_baseline, number=1, repeat=BENCH_REPEATS))
    cand_time = min(timeit.repeat(run_candidate, number=1, repeat=BENCH_REPEATS))
    speedup = base_time / cand_time if cand_time else float('inf')
    print(f"Baseline time: {base_time:.4f}s, Candidate time: {cand_time:.4f}s, speedup: {speedup:.2f}x")
else:
    print("Candidate implementation not available; benchmark skipped.")
