In [1]:


from __future__ import annotations
from dataclasses import dataclass
from typing import List, Tuple, Optional, Iterable
import itertools
import math
import random

# ---------------------------------------------------------
# 2×2 board + production rules (classical WM evaluation)
# ---------------------------------------------------------
# Board indices (row‑major):
#   0 1
#   2 3
# Value 0 encodes BLANK.

# Two-bit rule code -> human-readable move name.
RULE_NAMES = {
    0b00: "UP",
    0b01: "DOWN",
    0b10: "LEFT",
    0b11: "RIGHT",
}
# Two-bit rule code -> zero-padded binary string, e.g. 2 -> "10".
RULE_BITS = {code: format(code, "02b") for code in range(4)}


def swap(a: List[int], i: int, j: int) -> List[int]:
    """Return a copy of ``a`` with the items at ``i`` and ``j`` exchanged.

    The input list itself is left untouched.
    """
    out = a.copy()
    out[i], out[j] = a[j], a[i]
    return out


def apply_rule_once(board: List[int], rule: int) -> Optional[List[int]]:
    """Try to slide the blank (value 0) one step according to ``rule``.

    Returns a new board with the blank swapped into its neighbour, or ``None``
    when the move is not legal from the blank's position or ``rule`` is not
    one of 0..3. ``board`` itself is never modified. Raises ``ValueError``
    (from ``list.index``) if ``board`` contains no blank.

    Legal moves (blank index i -> target j):
      UP    (0b00): i in {2,3}, j = i-2
      DOWN  (0b01): i in {0,1}, j = i+2
      LEFT  (0b10): i in {1,3}, j = i-1
      RIGHT (0b11): i in {0,2}, j = i+1
    """
    blank = board.index(0)
    # (rule code, blank positions where the move is legal, offset to target)
    move_table = (
        (0b00, (2, 3), -2),
        (0b01, (0, 1), 2),
        (0b10, (1, 3), -1),
        (0b11, (0, 2), 1),
    )
    for code, legal_from, offset in move_table:
        if rule == code:
            if blank not in legal_from:
                return None
            target = blank + offset
            moved = board.copy()
            moved[blank], moved[target] = moved[target], moved[blank]
            return moved
    return None


def apply_rule_sequence(board: List[int], seq: Iterable[int]) -> Optional[List[int]]:
    """Fold ``apply_rule_once`` over ``seq``, starting from ``board``.

    Returns the final board, or ``None`` as soon as any rule in ``seq`` is
    illegal (iteration stops there). An empty ``seq`` returns ``board``
    itself, not a copy.
    """
    current: Optional[List[int]] = board
    for rule in seq:
        current = apply_rule_once(current, rule)
        if current is None:
            break
    return current


# ---------------------------------------
# Grover search on a finite set (no deps)
# ---------------------------------------

def grover_amplify(N: int, marked: Iterable[int], iterations: Optional[int] = None) -> Tuple[List[float], List[float]]:
    """Classically simulate Grover's algorithm on an N-item search space.

    Starts from the uniform superposition, then applies ``iterations`` rounds
    of (oracle: phase flip on ``marked``) followed by (diffusion: reflection
    about the mean amplitude).

    Args:
        N: size of the search space; must be positive.
        marked: indices in ``range(N)`` flipped by the oracle. Duplicates are
            ignored.
        iterations: number of Grover rounds. When ``None``, uses
            ``floor(pi/4 * sqrt(N/M))`` with ``M`` the number of *distinct*
            marked items. This can be 0 when most items are marked: a single
            round would then over-rotate past the solution subspace and lower
            the success probability (e.g. N=4, M=3 would drop it to 0).

    Returns:
        ``(amplitudes, probabilities)``, two lists of length N. The
        probabilities are the squared (real) amplitudes. With no marked items
        the uniform state is returned unchanged.

    Raises:
        ValueError: if ``N`` is not positive or a marked index lies outside
            ``range(N)``.
    """
    if N <= 0:
        raise ValueError("N must be positive")

    marked_set = set(marked)
    out_of_range = [i for i in marked_set if not 0 <= i < N]
    if out_of_range:
        raise ValueError(f"marked indices out of range(0, {N}): {sorted(out_of_range)}")

    # Start in uniform superposition
    a = [1.0 / math.sqrt(N)] * N

    # Count distinct solutions; duplicates would otherwise skew the
    # automatic iteration count.
    M = len(marked_set)
    if M == 0:
        # No solutions -> the oracle is the identity; the uniform state stays.
        return a, [x * x for x in a]

    if iterations is None:
        # No max(1, ...) clamp: when the formula yields 0, doing zero rounds
        # is optimal.
        iterations = int(math.floor((math.pi / 4.0) * math.sqrt(N / M)))

    for _ in range(iterations):
        # Oracle: phase flip on the marked items
        a = [-x if i in marked_set else x for i, x in enumerate(a)]
        # Diffusion: reflect about the mean
        mean = sum(a) / N
        a = [2 * mean - x for x in a]

    return a, [x * x for x in a]


def simulate_measurements(prob: List[float], shots: int = 2048) -> dict:
    """Sample ``shots`` measurement outcomes from the distribution ``prob``.

    ``prob[i]`` is treated as the (non-negative) weight of outcome ``i``.
    Weights are normalised by their total, so a vector whose floating-point
    sum is slightly off 1 still yields exactly ``shots`` samples. An outcome
    with zero weight is never drawn.

    Returns:
        dict mapping every outcome index ``0..len(prob)-1`` to its count
        (zero-count outcomes included). An empty ``prob`` returns ``{}``.

    Raises:
        ValueError: if ``prob`` is non-empty, ``shots`` is positive and the
            total weight is not positive.
    """
    counts = {i: 0 for i in range(len(prob))}
    if not counts or shots <= 0:
        return counts

    cumulative = list(itertools.accumulate(prob))
    if cumulative[-1] <= 0:
        raise ValueError("probabilities must have a positive total")

    # random.choices bisects random() * total into the cumulative weights:
    # one random() call per shot, no shots lost to rounding.
    for outcome in random.choices(range(len(prob)), cum_weights=cumulative, k=shots):
        counts[outcome] += 1
    return counts


# ---------------------------------------
# One‑move search (original use‑case)
# ---------------------------------------

def grover_over_rules_one_move(start: List[int], goal: List[int]) -> Tuple[List[float], List[float], List[int]]:
    """Amplify the single rules that turn ``start`` into ``goal`` in one move.

    The oracle marks every rule index 0..3 whose application to ``start``
    yields ``goal``. Grover is then run over those 4 candidates.

    Returns:
        ``(amplitudes, probabilities, marked_rule_indices)``.
    """
    solving_rules = []
    for rule in range(4):
        if apply_rule_once(start, rule) == goal:
            solving_rules.append(rule)
    amplitudes, probabilities = grover_amplify(4, solving_rules)
    return amplitudes, probabilities, solving_rules


def predict_winner(prob: List[float]) -> int:
    """Return the index of the largest entry of ``prob``.

    Ties resolve to the lowest index. An empty ``prob`` raises ``ValueError``
    (from ``max``).
    """
    peak = max(prob)
    return next((idx for idx, value in enumerate(prob) if value == peak), 0)


def fmt_rule(r: int) -> str:
    """Render rule ``r`` as ``"<bits> (<NAME>)"``, e.g. ``"10 (LEFT)"``."""
    bits = RULE_BITS[r]
    name = RULE_NAMES[r]
    return "{} ({})".format(bits, name)


# ---------------------------------------
# Depth‑d sequences (optional extension)
# ---------------------------------------

def grover_over_sequences(start: List[int], goal: List[int], depth: int) -> Tuple[List[float], List[float], List[Tuple[int, ...]]]:
    """Run Grover over every rule sequence of length ``depth``.

    The search space holds all 4**depth sequences, ordered lexicographically
    over rule indices 0..3. A sequence is marked when applying it to
    ``start`` is legal throughout and ends at ``goal``.

    Returns:
        ``(amplitudes, probabilities, solutions)``, where ``solutions`` are the
        marked sequences (as tuples) in search-space order.

    Raises:
        ValueError: if ``depth`` is less than 1.
    """
    if depth <= 0:
        raise ValueError("depth must be >= 1")

    # Every length-`depth` rule string, in lexicographic order.
    candidates = list(itertools.product(range(4), repeat=depth))
    hits = [
        k for k, seq in enumerate(candidates)
        if apply_rule_sequence(start, seq) == goal
    ]

    amplitudes, probabilities = grover_amplify(len(candidates), hits)
    return amplitudes, probabilities, [candidates[k] for k in hits]


# ---------------------------------------
# Tests (keep the original demo; add more)
# ---------------------------------------
@dataclass
class OneMoveTest:
    """One single-move scenario: which rule (if any) maps ``start`` to ``goal``."""

    name: str  # label printed in the report header
    start: List[int]  # initial 2x2 board, row-major, 0 = blank
    goal: List[int]  # board the winning rule should produce in one move
    expected_bits: Optional[str]  # e.g. "10"; None if no single solution expected


def run_one_move_test(t: OneMoveTest) -> None:
    """Run Grover over the 4 rules for scenario ``t`` and print a report.

    Prints the probability of each rule, the oracle-marked rules and the most
    probable rule. When ``t.expected_bits`` is set, checks that the most
    probable rule's bit string matches it and prints PASS.

    Raises:
        AssertionError: if ``t.expected_bits`` is set and differs from the
            predicted winner's bits.
    """
    _, prob, marked = grover_over_rules_one_move(t.start, t.goal)
    winner = predict_winner(prob)

    print(f"\n[ONE‑MOVE] {t.name}")
    print(f" start={t.start}  goal={t.goal}")
    for r in range(4):
        print(f"  P(rule={fmt_rule(r)}) = {prob[r]:.6f}")
    print(f" marked rules: {[fmt_rule(r) for r in marked] or 'None'}")
    print(f" predicted winner: {fmt_rule(winner)}")

    if t.expected_bits is None:
        print(" (No single expected rule specified)")
        return

    # Explicit raise instead of `assert`: running under `python -O` strips
    # assert statements, which would make this check print PASS unconditionally.
    if RULE_BITS[winner] != t.expected_bits:
        raise AssertionError(
            f"Expected {t.expected_bits} but got {RULE_BITS[winner]} ({RULE_NAMES[winner]})"
        )
    print(" ✅ PASS")


@dataclass
class DepthDTest:
    """One multi-move scenario: search all rule sequences of length ``depth``."""

    name: str  # label printed in the report header
    start: List[int]  # initial 2x2 board, row-major, 0 = blank
    goal: List[int]  # board a marked sequence must end on
    depth: int  # sequence length; the search space has 4**depth entries


def run_depth_d_test(t: DepthDTest) -> None:
    """Run Grover over all 4**depth rule sequences for ``t`` and print a report.

    Prints the search-space size, how many sequences the oracle marked, up to
    eight of them as dash-separated bit strings, and the index of the most
    probable sequence. No assertion is made.
    """
    _, prob, solutions = grover_over_sequences(t.start, t.goal, t.depth)
    space_size = 4 ** t.depth
    top = predict_winner(prob)
    shown_limit = 8

    print(f"\n[DEPTH‑{t.depth}] {t.name}")
    print(f" start={t.start}  goal={t.goal}  search_space=4^{t.depth}={space_size}")
    print(f" #solutions found by oracle: {len(solutions)}")
    if solutions:
        rendered = [
            "-".join(RULE_BITS[rule] for rule in seq)
            for seq in solutions[:shown_limit]
        ]
        extra = len(solutions) - shown_limit
        suffix = f" …(+{extra} more)" if extra > 0 else ""
        print(f" example solution sequences: {', '.join(rendered)}{suffix}")
    print(f" predicted top index: {top}")


if __name__ == "__main__":
    # ---------------- Original demo case (UNCHANGED) ----------------
    # start=[1,0,2,3] → goal=[0,1,2,3]; expected LEFT = 10
    original_demo = OneMoveTest(
        name="Original one‑move demo (LEFT)",
        start=[1, 0, 2, 3],
        goal=[0, 1, 2, 3],
        expected_bits="10",
    )

    # ---------------- Additional one‑move tests ----------------
    tests = [
        original_demo,
        OneMoveTest(
            name="DOWN from top‑left",
            start=[0, 1, 2, 3],  # blank at 0
            goal=[2, 1, 0, 3],  # DOWN -> swap 0 and 2
            expected_bits="01",
        ),
        OneMoveTest(
            name="UP from bottom‑left",
            start=[2, 1, 0, 3],  # blank at 2
            goal=[0, 1, 2, 3],  # UP -> swap 2 and 0
            expected_bits="00",
        ),
        OneMoveTest(
            name="RIGHT from top‑left",
            start=[0, 1, 2, 3],  # blank at 0
            goal=[1, 0, 2, 3],  # RIGHT -> swap 0 and 1
            expected_bits="11",
        ),
        OneMoveTest(
            name="No single‑move solution (identity)",
            start=[0, 1, 2, 3],
            goal=[0, 1, 2, 3],  # requires 0 moves; our domain has no NO‑OP rule
            expected_bits=None,
        ),
    ]

    for t in tests:
        run_one_move_test(t)

    # ---------------- Optional depth‑2 demo (no assertions) ----------------
    # From [1,0,2,3] (blank at index 1): DOWN (01) moves the blank to index 3,
    # then LEFT (10) moves it to index 2, giving [1,3,0,2]. 01-10 is the only
    # length-2 sequence reaching that goal, so exactly one of 16 is marked.
    # (Note: [2,1,3,0] is NOT reachable from [1,0,2,3] in two moves.)
    depth2_examples = [
        DepthDTest(
            name="Depth‑2 example (DOWN then LEFT)",
            start=[1, 0, 2, 3],
            goal=[1, 3, 0, 2],
            depth=2,
        ),
        # Returning to the start in two moves: DOWN-UP (01-00) and
        # RIGHT-LEFT (11-10) both work, so two of 16 sequences are marked.
        DepthDTest(
            name="Depth‑2 back‑and‑forth (identity; 2 solutions)",
            start=[0, 1, 2, 3],
            goal=[0, 1, 2, 3],
            depth=2,
        ),
    ]

    for t in depth2_examples:
        run_depth_d_test(t)

    # ---------------- Example: measurement sampling for original case ----------------
    print("\n[SAMPLING] Original demo simulated shots")
    _, prob, _ = grover_over_rules_one_move(original_demo.start, original_demo.goal)
    counts = simulate_measurements(prob, shots=2048)
    pretty = {RULE_BITS[k]: v for k, v in counts.items()}
    print(" counts (r1r0):", pretty)
    print(" expected winning rule: 10 (LEFT)")



[ONE‑MOVE] Original one‑move demo (LEFT)
 start=[1, 0, 2, 3]  goal=[0, 1, 2, 3]
  P(rule=00 (UP)) = 0.000000
  P(rule=01 (DOWN)) = 0.000000
  P(rule=10 (LEFT)) = 1.000000
  P(rule=11 (RIGHT)) = 0.000000
 marked rules: ['10 (LEFT)']
 predicted winner: 10 (LEFT)
 ✅ PASS

[ONE‑MOVE] DOWN from top‑left
 start=[0, 1, 2, 3]  goal=[2, 1, 0, 3]
  P(rule=00 (UP)) = 0.000000
  P(rule=01 (DOWN)) = 1.000000
  P(rule=10 (LEFT)) = 0.000000
  P(rule=11 (RIGHT)) = 0.000000
 marked rules: ['01 (DOWN)']
 predicted winner: 01 (DOWN)
 ✅ PASS

[ONE‑MOVE] UP from bottom‑left
 start=[2, 1, 0, 3]  goal=[0, 1, 2, 3]
  P(rule=00 (UP)) = 1.000000
  P(rule=01 (DOWN)) = 0.000000
  P(rule=10 (LEFT)) = 0.000000
  P(rule=11 (RIGHT)) = 0.000000
 marked rules: ['00 (UP)']
 predicted winner: 00 (UP)
 ✅ PASS

[ONE‑MOVE] RIGHT from top‑left
 start=[0, 1, 2, 3]  goal=[1, 0, 2, 3]
  P(rule=00 (UP)) = 0.000000
  P(rule=01 (DOWN)) = 0.000000
  P(rule=10 (LEFT)) = 0.000000
  P(rule=11 (RIGHT)) = 1.000000
 marked rules: ['11 (