In [8]:
"""
Bayesian (particle / ABC) seed inference + next-30-roll prediction
for Battle Cats–style deterministic RNG (xorshift <<13, >>17, <<15),

Reads banner paste from a local file: paste.txt

Usage:
  1) Save your banner paste into paste.txt (same folder as this script)
  2) Edit observed_rolls below (your 4 rolls)
  3) Run: python bc_bayes_seed.py
"""

from __future__ import annotations

import math
import random
import re
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Sequence, Tuple
from collections import Counter
from pathlib import Path


# -----------------------------
# Banner model + parser
# -----------------------------

@dataclass(frozen=True)
class Banner:
    rate_cumsum: List[int]               # e.g. [6970, 9470, 9970, 10000]
    units_by_rarity: List[List[str]]     # [Rare[], Super[], Uber[], Legendary[]]
    rerollable_rarities: List[int]       # usually [0] for Rare

RARITY_ORDER = ["Rare", "Super", "Uber", "Legendary"]

def _parse_units_with_indices(units_str: str) -> List[str]:
    """
    Parse "0 Wushu Cat, 1 Matador Cat, 2 Rover Cat, ..." into list ordered by index.
    """
    units_str = units_str.strip()
    if not units_str:
        return []

    pattern = re.compile(r"(\d+)\s+([^,]+)(?:,|$)")
    found = pattern.findall(units_str)
    if not found:
        return []

    idx_to_name: Dict[int, str] = {}
    for idx_s, name in found:
        idx_to_name[int(idx_s)] = name.strip()

    return [idx_to_name[i] for i in sorted(idx_to_name.keys())]

def parse_banner_paste_real(paste: str) -> Banner:
    """
    Parses lines like:
      Rare: 69.7% (25 cats) 0 Wushu Cat, 1 Matador Cat, ...
    The paste can include headers/indentation; we scan and extract the relevant lines.
    """
    lines = [ln.strip() for ln in paste.splitlines() if ln.strip()]
    relevant: Dict[str, Dict[str, Any]] = {}

    line_regex = re.compile(
        r"^(?P<rarity>Rare|Super|Uber|Legendary):\s+"
        r"(?P<rate>[\d.]+)%\s+\(\d+\s+cats\)\s*(?P<units>.*)$"
    )

    for ln in lines:
        m = line_regex.match(ln)
        if not m:
            continue
        rarity = m.group("rarity")
        rate_pct = float(m.group("rate"))
        units_str = m.group("units").strip()

        slots = int(round(rate_pct * 100))  # percent -> slots out of 10000
        units = _parse_units_with_indices(units_str)

        relevant[rarity] = {"slots": slots, "units": units}

    slots_list: List[int] = []
    units_by_rarity: List[List[str]] = []
    for r in RARITY_ORDER:
        if r not in relevant:
            slots_list.append(0)
            units_by_rarity.append([])
        else:
            slots_list.append(int(relevant[r]["slots"]))
            units_by_rarity.append(list(relevant[r]["units"]))

    cumsum: List[int] = []
    s = 0
    for v in slots_list:
        s += v
        cumsum.append(s)

    # Force final threshold to 10000; clamp earlier thresholds to stay monotone
    if cumsum:
        cumsum[-1] = 10000
        for i in range(len(cumsum) - 2, -1, -1):
            cumsum[i] = min(cumsum[i], cumsum[i + 1])

    # Matches your JS worker behavior: immediate-dupe reroll only for Rare
    return Banner(rate_cumsum=cumsum, units_by_rarity=units_by_rarity, rerollable_rarities=[0])


# -----------------------------
# RNG + roll simulation (matches your JS worker)
# -----------------------------

def advance_seed(seed: int) -> int:
    seed &= 0xFFFFFFFF
    seed ^= ((seed << 13) & 0xFFFFFFFF)
    seed ^= (seed >> 17)
    seed ^= ((seed << 15) & 0xFFFFFFFF)
    return seed & 0xFFFFFFFF

def get_rarity(seed: int, rate_cumsum: Sequence[int]) -> int:
    x = seed % 10000
    for i, thr in enumerate(rate_cumsum):
        if x < thr:
            return i
    return len(rate_cumsum) - 1

def get_unit(seed: int, units: Sequence[str], removed_index: int = -1) -> Tuple[int, str]:
    n = len(units)
    if n == 0:
        return -1, "<EMPTY_POOL>"
    if removed_index == -1:
        idx = seed % n
        return idx, units[idx]
    if n <= 1:
        return -1, "<EMPTY_POOL>"
    idx = seed % (n - 1)
    if idx >= removed_index:
        idx += 1
    return idx, units[idx]

def simulate_rolls(seed0: int, banner: Banner, n_rolls: int, *, apply_dupe_reroll: bool = True) -> Tuple[List[str], int]:
    seed = seed0 & 0xFFFFFFFF
    out: List[str] = []
    last_roll: Optional[str] = None

    for _ in range(n_rolls):
        seed = advance_seed(seed)
        rarity = get_rarity(seed, banner.rate_cumsum)

        seed = advance_seed(seed)
        units = banner.units_by_rarity[rarity]
        unit_idx, unit_id = get_unit(seed, units)

        if (
            apply_dupe_reroll
            and last_roll is not None
            and unit_id == last_roll
            and rarity in banner.rerollable_rarities
        ):
            seed = advance_seed(seed)
            _, rerolled_id = get_unit(seed, units, removed_index=unit_idx)
            out.append(rerolled_id)
            last_roll = rerolled_id
        else:
            out.append(unit_id)
            last_roll = unit_id

    return out, seed


# -----------------------------
# Bayesian posterior over seeds (particles / ABC)
# -----------------------------

def log_likelihood(predicted: Sequence[str], observed: Sequence[str], *, epsilon: float, banner: Banner) -> float:
    if len(predicted) != len(observed):
        raise ValueError("predicted and observed length mismatch")
    if not (0.0 <= epsilon < 1.0):
        raise ValueError("epsilon must be in [0,1)")

    all_units = [u for pool in banner.units_by_rarity for u in pool]
    q = 1.0 / max(1, len(all_units))

    ll = 0.0
    for p, y in zip(predicted, observed):
        if p == y:
            ll += math.log(max(1e-300, 1.0 - epsilon))
        else:
            ll += math.log(max(1e-300, epsilon * q))
    return ll

def sample_posterior_seeds(
    banner: Banner,
    observed_rolls: Sequence[str],
    *,
    num_particles: int = 120_000,
    epsilon: float = 0.03,
    rng_seed: int = 12345,
    apply_dupe_reroll: bool = True,
) -> Tuple[List[int], List[float]]:
    random.seed(rng_seed)

    m = len(observed_rolls)
    seeds: List[int] = []
    logw: List[float] = []

    for _ in range(num_particles):
        s0 = random.getrandbits(32) or 1

        pred, _ = simulate_rolls(s0, banner, m, apply_dupe_reroll=apply_dupe_reroll)

        if epsilon == 0.0:
            if pred == list(observed_rolls):
                seeds.append(s0)
                logw.append(0.0)
        else:
            seeds.append(s0)
            logw.append(log_likelihood(pred, observed_rolls, epsilon=epsilon, banner=banner))

    if not seeds:
        return [], []

    mx = max(logw)
    w = [math.exp(lw - mx) for lw in logw]
    s = sum(w)
    w = [wi / s for wi in w] if s else [1.0 / len(w)] * len(w)
    return seeds, w

def resample_multinomial(seeds: Sequence[int], weights: Sequence[float], n: int, *, rng_seed: int) -> List[int]:
    random.seed(rng_seed)
    if not seeds:
        return []

    cdf: List[float] = []
    s = 0.0
    for w in weights:
        s += w
        cdf.append(s)

    out: List[int] = []
    for _ in range(n):
        r = random.random()
        lo, hi = 0, len(cdf) - 1
        while lo < hi:
            mid = (lo + hi) // 2
            if r <= cdf[mid]:
                hi = mid
            else:
                lo = mid + 1
        out.append(seeds[lo])
    return out


# -----------------------------
# Predictive distribution for next K rolls
# -----------------------------

@dataclass
class PredictiveSummary:
    per_position_unit_probs: List[Dict[str, float]]
    expected_rarity_counts: Dict[str, float]

def predictive_next_k(
    banner: Banner,
    observed_rolls: Sequence[str],
    *,
    k: int = 30,
    num_particles: int = 120_000,
    epsilon: float = 0.03,
    posterior_paths: int = 10_000,
    rng_seed: int = 12345,
    apply_dupe_reroll: bool = True,
) -> PredictiveSummary:
    seeds, weights = sample_posterior_seeds(
        banner,
        observed_rolls,
        num_particles=num_particles,
        epsilon=epsilon,
        rng_seed=rng_seed,
        apply_dupe_reroll=apply_dupe_reroll,
    )
    if not seeds:
        raise RuntimeError("No posterior mass. Increase epsilon or num_particles.")

    posterior_seeds = resample_multinomial(seeds, weights, posterior_paths, rng_seed=rng_seed + 1)

    unit_to_rarity: Dict[str, int] = {}
    for ri, pool in enumerate(banner.units_by_rarity):
        for u in pool:
            unit_to_rarity.setdefault(u, ri)

    m = len(observed_rolls)
    per_pos_counts: List[Counter] = [Counter() for _ in range(k)]
    rarity_sum = Counter()

    for s0 in posterior_seeds:
        full, _ = simulate_rolls(s0, banner, m + k, apply_dupe_reroll=apply_dupe_reroll)
        future = full[m:]

        for i, u in enumerate(future):
            per_pos_counts[i][u] += 1
            ri = unit_to_rarity.get(u, None)
            if ri is not None:
                rarity_sum[ri] += 1

    n_paths = len(posterior_seeds)

    per_position_probs: List[Dict[str, float]] = []
    for c in per_pos_counts:
        total = sum(c.values())
        per_position_probs.append({u: cnt / total for u, cnt in c.items()} if total else {})

    expected_rarity_counts: Dict[str, float] = {}
    for ri, tot in rarity_sum.items():
        expected_rarity_counts[RARITY_ORDER[ri]] = tot / n_paths

    return PredictiveSummary(per_position_unit_probs=per_position_probs, expected_rarity_counts=expected_rarity_counts)


# -----------------------------
# Pretty printing
# -----------------------------

def topk(probs: Dict[str, float], k: int = 5) -> List[Tuple[str, float]]:
    return sorted(probs.items(), key=lambda x: x[1], reverse=True)[:k]

def print_summary(summary: PredictiveSummary, *, top_k: int = 5, positions: int = 10) -> None:
    print("expected rarity counts over next horizon:")
    for r in RARITY_ORDER:
        if r in summary.expected_rarity_counts:
            print(f"  {r:10s}: {summary.expected_rarity_counts[r]:.2f}")
    print()
    print("per position most likely units:")
    for i, probs in enumerate(summary.per_position_unit_probs[:positions], start=1):
        best = topk(probs, top_k)
        s = ", ".join([f"{u} ({p:.2%})" for u, p in best])
        print(f"  Roll+{i:02d}: {s}")


# -----------------------------
# Main: load paste from paste.txt
# -----------------------------

def load_paste_file(path: str = "paste.txt") -> str:
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(
            f"Could not find {path!r}. Put your banner paste in a file named paste.txt "
            f"in the same folder as this script (or pass a full path)."
        )
    return p.read_text(encoding="utf-8", errors="replace")

if __name__ == "__main__":
    banner_paste = load_paste_file("paste.txt")
    banner = parse_banner_paste_real(banner_paste)

    # Edit these with your observed rolls (must match unit names in the paste exactly)
    observed_rolls = ["Viking Cat", "Tin Cat", "Fortune Teller Cat", "Archer Cat", "Bishop Cat", "Gold Cat"]

    summary = predictive_next_k(
        banner,
        observed_rolls,
        k=30,
        num_particles=300_000,
        epsilon=0.03,
        posterior_paths=10_000,
        rng_seed=12345,
        apply_dupe_reroll=True,
    )

    print_summary(summary, top_k=6, positions=10)


expected rarity counts over next horizon:
  Rare      : 20.71
  Super     : 7.66
  Uber      : 1.49
  Legendary : 0.14

per position most likely units:
  Roll+01: Gardener Cat (10.63%), Pogo Cat (5.45%), Wheel Cat (5.31%), Cat Gunslinger (5.22%), Welterweight Cat (4.77%), Fortune Teller Cat (4.03%)
  Roll+02: Viking Cat (8.68%), Cat Gunslinger (7.28%), Tin Cat (6.66%), Salon Cat (5.38%), Matador Cat (5.22%), Fortune Teller Cat (5.10%)
  Roll+03: Bishop Cat (7.81%), Pirate Cat (6.40%), Thief Cat (5.07%), Salon Cat (5.06%), Swordsman Cat (5.04%), Wheel Cat (4.19%)
  Roll+04: Wheel Cat (6.65%), Stilts Cat (6.58%), Wushu Cat (6.07%), Mer-Cat (5.39%), Welterweight Cat (5.28%), Bishop Cat (5.06%)
  Roll+05: Tin Cat (7.51%), Cat Gunslinger (6.36%), Wheel Cat (5.20%), Gardener Cat (4.81%), Rocker Cat (4.05%), Nerd Cat (3.95%)
  Roll+06: Onmyoji Cat (10.19%), Welterweight Cat (5.22%), Fencer Cat (5.10%), Archer Cat (4.28%), Mer-Cat (4.26%), Swordsman Cat (3.99%)
  Roll+07: Shaman Cat (6.42%), S

In [17]:
def simulate_n(seed0: int, banner, n: int):
    seed = seed0 & 0xFFFFFFFF
    out = []
    last = None

    for _ in range(n):
        seed = advance_seed(seed)
        rarity = get_rarity(seed, banner.rate_cumsum)

        seed = advance_seed(seed)
        units = banner.units_by_rarity[rarity]
        unit_idx, unit_id = get_unit(seed, units)

        if last is not None and unit_id == last and rarity in banner.rerollable_rarities:
            seed = advance_seed(seed)
            _, rerolled = get_unit(seed, units, removed_index=unit_idx)
            out.append(rerolled)
            last = rerolled
        else:
            out.append(unit_id)
            last = unit_id

    return out, seed

# --- your case ---
seed_before = 659650523
observed = ["Viking Cat", "Tin Cat", "Fortune Teller Cat", "Archer Cat", "Bishop Cat", "Gold Cat", "Pirate Cat"]

pred6, seed_after6 = simulate_n(seed_before, banner, 7)
print("matches observed?", pred6 == observed)
print("Seed after 7:", seed_after6)  # should be 3183349844

pred7, seed_after7 = simulate_n(seed_before, banner, 8)
print("Next cat:", pred7[-1])        # should be Pirate Cat
print("Seed after 8:", seed_after7)


matches observed? True
Seed after 7: 2099626319
Next cat: Kasa Jizo
Seed after 8: 844176769


In [18]:
from __future__ import annotations

import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
from collections import Counter

# Import the cython extension (build first!)
import seed_seeker

RARITY_ORDER = ["Rare", "Super", "Uber", "Legendary"]

@dataclass(frozen=True)
class Banner:
    rate_cumsum: List[int]
    units_by_rarity: List[List[str]]
    rerollable_rarities: List[int]

def _parse_units_with_indices(units_str: str) -> List[str]:
    units_str = units_str.strip()
    if not units_str:
        return []
    pat = re.compile(r"(\d+)\s+([^,]+)(?:,|$)")
    found = pat.findall(units_str)
    idx_to_name: Dict[int, str] = {int(i): name.strip() for i, name in found}
    return [idx_to_name[i] for i in sorted(idx_to_name)]

def parse_banner_paste(paste: str) -> Banner:
    lines = [ln.strip() for ln in paste.splitlines() if ln.strip()]
    relevant: Dict[str, Dict[str, Any]] = {}

    line_regex = re.compile(
        r"^(?P<rarity>Rare|Super|Uber|Legendary):\s+"
        r"(?P<rate>[\d.]+)%\s+\(\d+\s+cats\)\s*(?P<units>.*)$"
    )

    for ln in lines:
        m = line_regex.match(ln)
        if not m:
            continue
        rarity = m.group("rarity")
        rate_pct = float(m.group("rate"))
        slots = int(round(rate_pct * 100))
        units = _parse_units_with_indices(m.group("units"))
        relevant[rarity] = {"slots": slots, "units": units}

    slots_list, pools = [], []
    for r in RARITY_ORDER:
        if r not in relevant:
            slots_list.append(0)
            pools.append([])
        else:
            slots_list.append(relevant[r]["slots"])
            pools.append(relevant[r]["units"])

    cumsum = []
    s = 0
    for v in slots_list:
        s += v
        cumsum.append(s)

    # Fix rounding drift
    cumsum[-1] = 10000
    for i in range(len(cumsum)-2, -1, -1):
        cumsum[i] = min(cumsum[i], cumsum[i+1])

    return Banner(rate_cumsum=cumsum, units_by_rarity=pools, rerollable_rarities=[0])

def load_paste(path: str = "paste.txt") -> str:
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Missing {path}")
    return p.read_text(encoding="utf-8", errors="replace")

# ---- deterministic simulator in Python for prediction (fast enough once seeds known) ----

def advance_seed(seed: int) -> int:
    seed &= 0xFFFFFFFF
    seed ^= ((seed << 13) & 0xFFFFFFFF)
    seed ^= (seed >> 17)
    seed ^= ((seed << 15) & 0xFFFFFFFF)
    return seed & 0xFFFFFFFF

def get_rarity(seed: int, rate_cumsum: Sequence[int]) -> int:
    x = seed % 10000
    for i, thr in enumerate(rate_cumsum):
        if x < thr:
            return i
    return len(rate_cumsum) - 1

def get_unit_index(seed: int, n: int, removed_index: int = -1) -> int:
    if n <= 0:
        return -1
    if removed_index < 0:
        return seed % n
    if n <= 1:
        return -1
    idx = seed % (n - 1)
    if idx >= removed_index:
        idx += 1
    return idx

def simulate_next_k(seed0: int, banner_units_int: List[List[int]], rate_cumsum: List[int], rerollable: List[int],
                    observed_len: int, k: int) -> List[int]:
    seed = seed0 & 0xFFFFFFFF
    last = None
    out: List[int] = []

    reroll_set = set(rerollable)

    for _ in range(observed_len + k):
        seed = advance_seed(seed)
        rarity = get_rarity(seed, rate_cumsum)

        seed = advance_seed(seed)
        pool = banner_units_int[rarity]
        n = len(pool)
        idx = get_unit_index(seed, n, -1)
        unit_id = pool[idx] if idx >= 0 else -1

        if last is not None and rarity in reroll_set and unit_id == last:
            seed = advance_seed(seed)
            idx2 = get_unit_index(seed, n, idx)
            unit_id = pool[idx2] if idx2 >= 0 else -1

        out.append(unit_id)
        last = unit_id

    return out[observed_len:]

# ---- Bayes over exact seeds + prediction mixture ----

def posterior_over_exact_seeds(exact_seeds: List[int], prior: Dict[int, float] | None = None) -> Dict[int, float]:
    if not exact_seeds:
        return {}
    if prior is None:
        w = 1.0 / len(exact_seeds)
        return {s: w for s in exact_seeds}
    # restrict + renormalize
    total = sum(prior.get(s, 0.0) for s in exact_seeds)
    if total <= 0:
        w = 1.0 / len(exact_seeds)
        return {s: w for s in exact_seeds}
    return {s: prior.get(s, 0.0) / total for s in exact_seeds}

def predictive_mixture(
    seeds_posterior: Dict[int, float],
    id_to_name: Dict[int, str],
    banner_units_int: List[List[int]],
    rate_cumsum: List[int],
    rerollable: List[int],
    observed_len: int,
    k: int,
) -> List[Dict[str, float]]:
    """
    Returns per-position probability dicts for next k rolls.
    """
    per_pos = [Counter() for _ in range(k)]

    for s, w in seeds_posterior.items():
        future = simulate_next_k(s, banner_units_int, rate_cumsum, rerollable, observed_len, k)
        for i, uid in enumerate(future):
            per_pos[i][uid] += w

    # normalize each position (should already sum ~1)
    out: List[Dict[str, float]] = []
    for c in per_pos:
        total = sum(c.values())
        probs = {}
        for uid, mass in c.items():
            name = id_to_name.get(uid, f"<id:{uid}>")
            probs[name] = float(mass / total) if total > 0 else 0.0
        out.append(probs)
    return out

def topk(probs: Dict[str, float], k: int = 6) -> List[Tuple[str, float]]:
    return sorted(probs.items(), key=lambda x: x[1], reverse=True)[:k]

if __name__ == "__main__":
    banner = parse_banner_paste(load_paste("paste.txt"))

    # Example observed rolls (edit these)
    observed_rolls = ["Viking Cat", "Tin Cat", "Fortune Teller Cat", "Archer Cat", "Bishop Cat", "Gold Cat"]

    # Build int IDs
    name_to_id: Dict[str, int] = {}
    id_to_name: Dict[int, str] = {}

    def get_id(name: str) -> int:
        if name not in name_to_id:
            new_id = len(name_to_id) + 1
            name_to_id[name] = new_id
            id_to_name[new_id] = name
        return name_to_id[name]

    banner_units_int: List[List[int]] = []
    for pool in banner.units_by_rarity:
        banner_units_int.append([get_id(u) for u in pool])

    observed_int = [get_id(u) for u in observed_rolls]

    # --- EXACT seek ---
    # IMPORTANT: Python+Cython brute force over all 2^32 seeds is still huge.
    # If you already have a seed from elsewhere, you don't need seeking.
    #
    # Here we show how to search a specific range. Tune to your needs.
    start_seed = 0
    end_seed = 1_000_000  # demo range; expand as you like

    exact = seed_seeker.seek_seeds(
        start_seed,
        end_seed,
        banner.rate_cumsum,
        banner_units_int,
        banner.rerollable_rarities,
        observed_int,
        max_found=100,
    )
    print(f"Exact seeds found in range [{start_seed}, {end_seed}): {exact}")

    # Bayes over exact set (uniform posterior by default)
    post = posterior_over_exact_seeds(exact)

    # Predict next 30
    pred = predictive_mixture(
        post,
        id_to_name,
        banner_units_int,
        banner.rate_cumsum,
        banner.rerollable_rarities,
        observed_len=len(observed_int),
        k=30,
    )

    for i in range(10):
        best = topk(pred[i], 6)
        s = ", ".join([f"{u} ({p:.2%})" for u, p in best])
        print(f"Roll+{i+1:02d}: {s}")


ModuleNotFoundError: No module named 'seed_seeker'

In [19]:
from __future__ import annotations

import argparse
import pickle
import re
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Dict, List, Sequence, Tuple
from collections import Counter

import seed_seeker  # compiled cython module

RARITY_ORDER = ["Rare", "Super", "Uber", "Legendary"]

# -----------------------------
# Banner parsing (paste.txt)
# -----------------------------

@dataclass(frozen=True)
class Banner:
    rate_cumsum: List[int]
    units_by_rarity: List[List[str]]
    rerollable_rarities: List[int]

def _parse_units_with_indices(units_str: str) -> List[str]:
    units_str = units_str.strip()
    if not units_str:
        return []
    pat = re.compile(r"(\d+)\s+([^,]+)(?:,|$)")
    found = pat.findall(units_str)
    idx_to_name: Dict[int, str] = {int(i): name.strip() for i, name in found}
    return [idx_to_name[i] for i in sorted(idx_to_name)]

def parse_banner_paste(paste: str) -> Banner:
    lines = [ln.strip() for ln in paste.splitlines() if ln.strip()]
    relevant: Dict[str, Dict[str, Any]] = {}

    line_regex = re.compile(
        r"^(?P<rarity>Rare|Super|Uber|Legendary):\s+"
        r"(?P<rate>[\d.]+)%\s+\(\d+\s+cats\)\s*(?P<units>.*)$"
    )

    for ln in lines:
        m = line_regex.match(ln)
        if not m:
            continue
        rarity = m.group("rarity")
        rate_pct = float(m.group("rate"))
        slots = int(round(rate_pct * 100))
        units = _parse_units_with_indices(m.group("units"))
        relevant[rarity] = {"slots": slots, "units": units}

    slots_list, pools = [], []
    for r in RARITY_ORDER:
        if r not in relevant:
            slots_list.append(0)
            pools.append([])
        else:
            slots_list.append(relevant[r]["slots"])
            pools.append(relevant[r]["units"])

    cumsum = []
    s = 0
    for v in slots_list:
        s += v
        cumsum.append(s)
    cumsum[-1] = 10000
    for i in range(len(cumsum) - 2, -1, -1):
        cumsum[i] = min(cumsum[i], cumsum[i + 1])

    return Banner(rate_cumsum=cumsum, units_by_rarity=pools, rerollable_rarities=[0])

def load_paste(path: str = "paste.txt") -> str:
    p = Path(path)
    if not p.exists():
        raise FileNotFoundError(f"Missing {path}")
    return p.read_text(encoding="utf-8", errors="replace")


# -----------------------------
# Deterministic simulator (Python) for prediction mixture once posterior small
# -----------------------------

def advance_seed(seed: int) -> int:
    seed &= 0xFFFFFFFF
    seed ^= ((seed << 13) & 0xFFFFFFFF)
    seed ^= (seed >> 17)
    seed ^= ((seed << 15) & 0xFFFFFFFF)
    return seed & 0xFFFFFFFF

def get_rarity(seed: int, rate_cumsum: Sequence[int]) -> int:
    x = seed % 10000
    for i, thr in enumerate(rate_cumsum):
        if x < thr:
            return i
    return len(rate_cumsum) - 1

def get_unit_index(seed: int, n: int, removed_index: int = -1) -> int:
    if n <= 0:
        return -1
    if removed_index < 0:
        return seed % n
    if n <= 1:
        return -1
    idx = seed % (n - 1)
    if idx >= removed_index:
        idx += 1
    return idx

def simulate_next_k_from_seed_before(
    seed_before: int,
    banner_units_int: List[List[int]],
    rate_cumsum: List[int],
    rerollable: List[int],
    observed_len: int,
    k: int,
) -> List[int]:
    seed = seed_before & 0xFFFFFFFF
    last = None
    out: List[int] = []
    reroll_set = set(rerollable)

    for _ in range(observed_len + k):
        seed = advance_seed(seed)
        rarity = get_rarity(seed, rate_cumsum)

        seed = advance_seed(seed)
        pool = banner_units_int[rarity]
        n = len(pool)
        idx = get_unit_index(seed, n, -1)
        unit_id = pool[idx] if idx >= 0 else -1

        if last is not None and rarity in reroll_set and unit_id == last:
            seed = advance_seed(seed)
            idx2 = get_unit_index(seed, n, idx)
            unit_id = pool[idx2] if idx2 >= 0 else -1

        out.append(unit_id)
        last = unit_id

    return out[observed_len:]


# -----------------------------
# Posterior management: over seed_after values (chaining)
# -----------------------------

Posterior = Dict[int, float]  # maps seed_before_next_roll -> probability mass

def posterior_uniform(states: List[int]) -> Posterior:
    if not states:
        return {}
    w = 1.0 / len(states)
    return {s: w for s in states}

def posterior_from_matches_seed_after(matches: List[Tuple[int, int]], prior_over_seed_before: Posterior | None) -> Posterior:
    """
    matches: list of (seed_before, seed_after)
    prior_over_seed_before: posterior from previous iteration over seed_before values.
      If None: treat seed_before prior as uniform over the found matches.

    Returns posterior over seed_after values.
    """
    if not matches:
        return {}

    # If no prior, uniform over matches
    if prior_over_seed_before is None:
        counts = Counter(seed_after for _, seed_after in matches)
        total = sum(counts.values())
        return {s_after: c / total for s_after, c in counts.items()}

    # With prior: mass on a match = prior(seed_before); then aggregate by seed_after
    acc = Counter()
    total_mass = 0.0
    for seed_before, seed_after in matches:
        w = prior_over_seed_before.get(seed_before, 0.0)
        if w > 0:
            acc[seed_after] += w
            total_mass += w

    if total_mass <= 0:
        # fallback: uniform over seed_after in matches
        return posterior_uniform(list({s_after for _, s_after in matches}))

    return {s_after: float(mass / total_mass) for s_after, mass in acc.items()}

def save_posterior(path: str, posterior: Posterior) -> None:
    with open(path, "wb") as f:
        pickle.dump(posterior, f)

def load_posterior(path: str) -> Posterior:
    with open(path, "rb") as f:
        obj = pickle.load(f)
    if not isinstance(obj, dict):
        raise ValueError("Pickle did not contain a dict posterior")
    # ensure int->float
    out: Posterior = {}
    for k, v in obj.items():
        out[int(k)] = float(v)
    return out


# -----------------------------
# Prediction mixture from posterior
# -----------------------------

def predictive_mixture(
    posterior_seed_before: Posterior,
    id_to_name: Dict[int, str],
    banner_units_int: List[List[int]],
    rate_cumsum: List[int],
    rerollable: List[int],
    observed_len: int,
    k: int,
) -> List[Dict[str, float]]:
    per_pos = [Counter() for _ in range(k)]
    for seed_before, w in posterior_seed_before.items():
        future = simulate_next_k_from_seed_before(seed_before, banner_units_int, rate_cumsum, rerollable, observed_len, k)
        for i, uid in enumerate(future):
            per_pos[i][uid] += w

    out: List[Dict[str, float]] = []
    for c in per_pos:
        total = sum(c.values())
        probs: Dict[str, float] = {}
        for uid, mass in c.items():
            probs[id_to_name.get(uid, f"<id:{uid}>")] = float(mass / total) if total > 0 else 0.0
        out.append(probs)
    return out

def topk(probs: Dict[str, float], k: int = 6) -> List[Tuple[str, float]]:
    return sorted(probs.items(), key=lambda x: x[1], reverse=True)[:k]


# -----------------------------
# Main
# -----------------------------

def main():
    ap = argparse.ArgumentParser()
    ap.add_argument("--paste", default="paste.txt")
    ap.add_argument("--posterior-pkl", default="posterior.pkl")
    ap.add_argument("--use-pickle", action="store_true", help="Load prior posterior from pickle (if present).")
    ap.add_argument("--save-pickle", action="store_true", help="Save updated posterior to pickle after this update.")
    ap.add_argument("--start", type=int, default=0, help="Start seed (inclusive) for seeking (only used if no prior pickle).")
    ap.add_argument("--end", type=int, default=1_000_000, help="End seed (exclusive) for seeking (only used if no prior pickle).")
    ap.add_argument("--max-found", type=int, default=0, help="Stop after this many matches (0 = all in range).")
    ap.add_argument("--predict-k", type=int, default=30)
    ap.add_argument("--print-pos", type=int, default=10)
    args = ap.parse_args()

    banner = parse_banner_paste(load_paste(args.paste))

    # EDIT HERE or load from elsewhere (keep it simple: hardcode for now)
    observed_rolls = ["Viking Cat", "Tin Cat", "Fortune Teller Cat", "Archer Cat", "Bishop Cat", "Gold Cat"]

    # Map names to ints
    name_to_id: Dict[str, int] = {}
    id_to_name: Dict[int, str] = {}

    def get_id(name: str) -> int:
        if name not in name_to_id:
            new_id = len(name_to_id) + 1
            name_to_id[name] = new_id
            id_to_name[new_id] = name
        return name_to_id[name]

    banner_units_int: List[List[int]] = []
    for pool in banner.units_by_rarity:
        banner_units_int.append([get_id(u) for u in pool])

    observed_int = [get_id(u) for u in observed_rolls]

    # Load prior posterior if requested and exists
    prior: Posterior | None = None
    if args.use_pickle and Path(args.posterior_pkl).exists():
        prior = load_posterior(args.posterior_pkl)
        print(f"Loaded prior posterior from {args.posterior_pkl} with {len(prior)} states.")
    else:
        print("No prior posterior loaded (starting fresh).")

    # Exact match seeking:
    # - If prior exists, we seek from each prior seed_before state directly by running seeker over tiny ranges?
    #   Not necessary: you already have specific seed_before candidates, so you can just simulate forward.
    #   BUT: since you requested Cython seeker usage, we’ll use it in two modes:
    #
    # Mode A: no prior -> brute seek in [start,end)
    # Mode B: prior exists -> for each prior seed_before, check exact match by seeking range [seed_before, seed_before+1)
    #
    matches: List[Tuple[int, int]] = []

    if prior is None:
        # Full range search
        matches = seed_seeker.seek_seeds_before_after(
            args.start, args.end,
            banner.rate_cumsum,
            banner_units_int,
            banner.rerollable_rarities,
            observed_int,
            args.max_found
        )
        print(f"Matches found in range [{args.start},{args.end}): {len(matches)}")
    else:
        # Check each prior state as the candidate "seed_before" for this iteration
        # by searching the 1-seed range [s, s+1)
        for s_before, w in sorted(prior.items(), key=lambda kv: -kv[1]):
            m = seed_seeker.seek_seeds_before_after(
                s_before, s_before + 1,
                banner.rate_cumsum,
                banner_units_int,
                banner.rerollable_rarities,
                observed_int,
                1
            )
            if m:
                matches.extend(m)
        print(f"Matches found from prior support: {len(matches)} / {len(prior)}")

    if not matches:
        raise SystemExit("No exact matching seeds found (check banner paste, roll names, version/reroll rules, or search range).")

    # Build posterior over seed_after (this is the new "seed_before_next_roll" support)
    posterior_after = posterior_from_matches_seed_after(matches, prior_over_seed_before=prior)

    # Save if requested
    if args.save_pickle:
        save_posterior(args.posterior_pkl, posterior_after)
        print(f"Saved posterior with {len(posterior_after)} states to {args.posterior_pkl}")

    # Predict next K from the posterior seed_after states
    # IMPORTANT: We are now predicting starting from seed_after values (seed BEFORE next roll),
    # so observed_len=0 for this prediction horizon.
    pred = predictive_mixture(
        posterior_after,
        id_to_name,
        banner_units_int,
        banner.rate_cumsum,
        banner.rerollable_rarities,
        observed_len=0,
        k=args.predict_k,
    )

    for i in range(min(args.print_pos, len(pred))):
        best = topk(pred[i], 6)
        s = ", ".join([f"{u} ({p:.2%})" for u, p in best])
        print(f"Roll+{i+1:02d}: {s}")

if __name__ == "__main__":
    main()


usage: ipykernel_launcher.py [-h] [--paste PASTE]
                             [--posterior-pkl POSTERIOR_PKL] [--use-pickle]
                             [--save-pickle] [--start START] [--end END]
                             [--max-found MAX_FOUND] [--predict-k PREDICT_K]
                             [--print-pos PRINT_POS]
ipykernel_launcher.py: error: unrecognized arguments: -f /home/chloe/.local/share/jupyter/runtime/kernel-75bfe12f-979b-422b-97b3-cf378b65122a.json


SystemExit: 2

  warn("To exit: use 'exit', 'quit', or Ctrl-D.", stacklevel=1)
