## Get a sudoko puzzle

In [176]:
mission = "000300000000500031351496000940000260207904000000025400083240100406003050700000300"
mission = "000002670863000005007508000000080497070000080050030106001906000430050060020070010"
# solution = "195342678863197245247568931312685497674219583958734126781926354439851762526473819"

In [177]:
from typing import Iterable
from dataclasses import dataclass, field
from typing import List


def bit_of(num: int) -> int:
    return 1 << (num - 1) if num > 0 else 0

def val_of(mask: int) -> int:
    if mask == 0:
        return 0
    if mask <= 0 or (mask & (mask - 1)) != 0:
        raise ValueError("Mask must have exactly one bit set")
    return mask.bit_length()

def bits_iter(mask: int) -> Iterable[int]:
    """Yield concrete values encoded in mask (1..n)."""
    i = 0
    while mask:
        if mask & 1:
            yield i + 1
        mask >>= 1
        i += 1

def num_ones(mask: int) -> int:
    count = 0
    while mask:
        mask &= mask - 1
        count += 1
    return count

def print_mask(mask: int) -> None:
    print(f"{mask:09b}")

In [201]:
@dataclass(slots=True)
class SudokuState:
    """
    Immutable-size Sudoku state that tracks:
      - board: list[int] with 0 for empty, or fixed value 1..n
      - opts:  list[int] of n-bit masks (bit k => value k+1 allowed)
    """
    n: int                             # side length (e.g., 9)
    N: int                             # total cells (n*n)
    box: int                           # box side (e.g., 3 for 9×9)
    all_mask: int                      # (1<<n)-1
    board: List[int] = field(default_factory=list)
    opts: List[int] = field(default_factory=list)
    # Neighbours: indices that share row/col/box with a cell (excluding itself)
    neigh: List[frozenset[int]] = field(default_factory=list)

    @classmethod
    def from_string(cls, mission: str) -> "SudokuState":
        N = len(mission)
        n = int(N**.5)
        assert n * n == N, "The sudoku must be a square for this solver to work"

        box = int(n**.5)
        assert n * n == N, "Side length must be a perfect square (e.g. 9)"

        all_mask = (1 << n) - 1
        board: List[int] = [bit_of(int(c)) for c in mission]
        opts: List[int] = [all_mask if board[i] == 0 else 0 for i in range(N)]

        neigh = cls._build_neigh(N, n, box)

        return cls(
            n = n,
            N = N,
            box = box,
            all_mask = all_mask,
            board = board,
            opts = opts,
            neigh = neigh
        )
    
    @staticmethod
    def _build_neigh(N: int, n: int, box: int) -> List[frozenset[int]]:
        neigh: List[frozenset[int]] = []
        for i in range(N):
            # Determine the row and index for i (given the number of columns n)
            r, c = divmod(i, n)

            # Calculate the rows and columns every n steps (for the row and col constraint, respectively)
            row = {j * n + c for j in range(n)}
            col = {k + r * n for k in range(n)}

            # Calculate the box indices (for the box constraint)
            br, bc = (r // box) * box, (c // box) * box
            box_ids = {
                (br + dr) * n + (bc + dc)
                for dr in range(box) for dc in range(box)
            }

            # Add all except the index itself
            ngh = (row | col | box_ids) - {i}
            neigh.append(frozenset(ngh))

        return neigh
    
    # Inspection classes

    def options_mask(self, idx: int) -> int:
        """Raw options bitmask for a cell index."""
        return self.opts[idx]

    def options(self, idx: int) -> List[int]:
        """Concrete options for a cell index."""
        return list(bits_iter(self.opts[idx]))
    
    def is_fixed(self, idx: int) -> bool:
        """True if cell has a no options and a board value."""
        return self.opts[idx] == 0 and self.board[idx] != 0
    
    # Mutations (local, not solving)

    def assign(self, idx: int, mask: int) -> None:
        if num_ones(mask) != 1:
            raise ValueError("The mask contains none or more then 1 ones")
        
        # Update the board and remove all options for this cell
        self.board[idx] = mask
        self.opts[idx] = 0

    def eliminate(self, idx: int, mask: int) -> None:
        self.opts[idx] &= ~mask

    def clone(self) -> "SudokuState":
        "Creates a deep copy of the state"
        return SudokuState(
            n=self.n,
            N=self.N,
            box=self.box,
            all_mask=self.all_mask,
            board=self.board.copy(),
            opts=self.opts.copy(),
            neigh=self.neigh,
        )
    
    # Convenience
    def to_board_string(self) -> str:
        """Serialize the current board to a pretty string with box separators."""
        rows = []
        for i in range(self.n):
            # Build one row with vertical separators
            row_parts = []
            for j in range(self.n):
                row_parts.append(str(val_of(self.board[i * self.n + j])))
                # Add vertical line if we're at the end of a box
                if (j + 1) % self.box == 0 and j + 1 < self.n:
                    row_parts.append("|")
            rows.append(" ".join(row_parts))

            # Add horizontal line if we're at the end of a box
            if (i + 1) % self.box == 0 and i + 1 < self.n:
                rows.append("-" * (2 * self.n + self.box - 1))  # adjust length
        return "\n".join(rows)

    def to_board_with_opts(self) -> str:
        def bits_to_digits(mask: int) -> List[int]:
            return [i + 1 for i in range(self.n) if (mask >> i) & 1]

        def render_cell(idx: int) -> str:
            val = val_of(self.board[idx])
            if val:
                return str(val).center(self.n)
            digs = bits_to_digits(self.opts[idx])
            s = "".join(str(d) for d in digs) or "."
            return s.center(self.n)

        rows = []
        for i in range(self.n):
            row_parts = []
            for j in range(self.n):
                idx = i * self.n + j
                row_parts.append(render_cell(idx))
                if (j + 1) % self.box == 0 and j + 1 < self.n:
                    row_parts.append("|")
            rows.append(" ".join(row_parts))
            if (i + 1) % self.box == 0 and i + 1 < self.n:
                rows.append("-" * (self.n * 2))
        return "\n".join(rows)

In [203]:
state = SudokuState.from_string(mission)

In [204]:
print(state.to_board_string())

0 0 0 | 0 0 2 | 6 7 0
8 6 3 | 0 0 0 | 0 0 5
0 0 7 | 5 0 8 | 0 0 0
--------------------
0 0 0 | 0 8 0 | 4 9 7
0 7 0 | 0 0 0 | 0 8 0
0 5 0 | 0 3 0 | 1 0 6
--------------------
0 0 1 | 9 0 6 | 0 0 0
4 3 0 | 0 5 0 | 0 6 0
0 2 0 | 0 7 0 | 0 1 0


In [205]:
print(state.to_board_with_opts())

123456789 123456789 123456789 | 123456789 123456789     2     |     6         7     123456789
    8         6         3     | 123456789 123456789 123456789 | 123456789 123456789     5    
123456789 123456789     7     |     5     123456789     8     | 123456789 123456789 123456789
------------------
123456789 123456789 123456789 | 123456789     8     123456789 |     4         9         7    
123456789     7     123456789 | 123456789 123456789 123456789 | 123456789     8     123456789
123456789     5     123456789 | 123456789     3     123456789 |     1     123456789     6    
------------------
123456789 123456789     1     |     9     123456789     6     | 123456789 123456789 123456789
    4         3     123456789 | 123456789     5     123456789 | 123456789     6     123456789
123456789     2     123456789 | 123456789     7     123456789 | 123456789     1     123456789


In [215]:
from dataclasses import dataclass
from collections import deque
from typing import Tuple, List

# assumes your helpers exist: bit_of, val_of, bits_iter, num_ones

class SudokuSolver:
    def __init__(self, state: "SudokuState"):
        self.s = state
        self.units = self._units()

    # ------ Helpers ------
    def _allowed_mask(self, idx: int) -> int:
        if self.s.board[idx] != 0:
            return 0
        ban = 0
        for j in self.s.neigh[idx]:
            ban |= self.s.board[j]
        return self.s.all_mask & ~ban
    
    def _assign_and_enqueue(self, idx: int, mask: int, q: deque) -> None:
        self.s.assign(idx, mask)
        for j in self.s.neigh[idx]:
            if self.s.board[j] == 0:
                if j not in q:
                    q.append(j)

    def _units(self) -> Tuple[List[List[int]], List[List[int]], List[List[int]]]:
        n, box = self.s.n, self.s.box
        rows = [[r*n + c for c in range(n)] for r in range(n)]
        cols = [[r*n + c for r in range(n)] for c in range(n)]
        boxes = []
        for br in range(0, n, box):
            for bc in range(0, n, box):
                box_idxs = []
                for dr in range(box):
                    for dc in range(box):
                        box_idxs.append((br+dr)*n + (bc+dc))
                boxes.append(box_idxs)
        return rows, cols, boxes

    # ----- Core propagation loop -----
    def _deduce(self) -> Tuple[bool, bool]:
        """
        Do ONE pass of deduction.
        Returns (changed, contradiction_found).
        Applies: neighbor elimination, naked singles, hidden singles (once).
        Naked Single: Only has one option
        Hidden Single: In a given unit (one row, one column, or one box), if a digit can go in exactly one cell, that cell must take that digit
        """
        changed = False
        rows, cols, boxes = self.units
        q = deque()

        # Seed queue from already-fixed cells
        for i in range(self.s.N):
            if self.s.is_fixed(i):
                for j in self.s.neigh[i]:
                    if self.s.board[j] == 0:
                        q.append(j)

        # helper: tighten options from queue; assign naked singles
        def tighten_from_queue() -> Tuple[bool, bool]:
            nonlocal changed
            while q:
                j = q.popleft()
                if self.s.board[j] != 0:
                    continue
                new_mask = self._allowed_mask(j)
                if new_mask == 0:
                    return changed, True
                if new_mask != self.s.opts[j]:
                    self.s.opts[j] = new_mask
                    changed = True
                if num_ones(new_mask) == 1:
                    self._assign_and_enqueue(j, new_mask, q)
                    changed = True
            return changed, False

        # 1) neighbor eliminations + naked singles
        ch, bad = tighten_from_queue()
        if bad:
            return ch, True
        
        # 2) hidden singles (may enqueue and then tighten once more)
        for unit in (*rows, *cols, *boxes):
            # bit -> candidate indices
            pos = {}
            for idx in unit:
                if self.s.board[idx] == 0:
                    m = self.s.opts[idx]
                    b = m
                    while b:
                        bit = b & -b
                        pos.setdefault(bit, []).append(idx)
                        b &= b - 1

            for bit, where in pos.items():
                if len(where) == 1:
                    idx = where[0]
                    if self.s.board[idx] == 0:
                        self._assign_and_enqueue(idx, bit, q)
                        changed = True

            # Process consequences of any hidden-single assignments done in this unit
            if q:
                ch, bad = tighten_from_queue()
                if bad:
                    return ch, True

        return changed, False

    def _deduce_until_stable(self) -> Tuple[bool, bool]:
        """
        Keep calling `deduce()` until no more changes.
        Returns (progress_made, contradiction_found).
        """
        any_change = False
        while True:
            changed, bad = self._deduce()
            if bad:
                return any_change or changed, True
            if not changed:
                return any_change, False
            any_change = True

    def _search(self) -> bool:
        # run deduction first
        _, bad = self._deduce_until_stable()
        if bad:
            return False
        if self.is_solved():
            return True

        # choose cell with fewest options
        best_idx, best_mask, best_count = -1, 0, 999
        for i in range(self.s.N):
            if self.s.board[i] == 0:
                m = self.s.opts[i]
                c = num_ones(m)
                if c < best_count:
                    best_idx, best_mask, best_count = i, m, c
                    if c == 2:  # good heuristic: break early on 2
                        break

        # try each option (clone state for branch)
        b = best_mask
        while b:
            bit = b & -b
            b &= b - 1
            branch = self.s.clone()
            SudokuSolver(branch).s.assign(best_idx, bit)  # seed assignment
            solver = SudokuSolver(branch)
            if solver._search():
                # copy back winning board/opts
                self.s.board = branch.board
                self.s.opts = branch.opts
                return True
        return False
    
    # --- Convenience checks ---
    def is_solved(self) -> bool:
        s = self.s
        if any(s.board[i] == 0 for i in range(s.N)):
            return False
        rows, cols, boxes = self._units()
        for unit in (*rows, *cols, *boxes):
            seen = 0
            for idx in unit:
                m = s.board[idx]
                if m == 0 or (seen & m):
                    return False
                seen |= m
            if seen != s.all_mask:
                return False
        return True

    # ----- Public ------
    def solve_deductive(self) -> bool:
        _, bad = self._deduce_until_stable()
        if bad:
            raise ValueError("Contradiction during deduction.")
        return self.is_solved()
    
    def solve_search(self) -> bool:
        """Optional: deduction + DFS search. Keeps `SudokuState` clean."""
        ok = self._search()
        return ok

In [None]:
# Pure logic only
state = SudokuState.from_string(mission)
solver = SudokuSolver(state)

solver.solve_search()
print(state.to_board_string())

1 9 5 | 3 4 2 | 6 7 8
8 6 3 | 1 9 7 | 2 4 5
2 4 7 | 5 6 8 | 9 3 1
--------------------
3 1 2 | 6 8 5 | 4 9 7
6 7 4 | 2 1 9 | 5 8 3
9 5 8 | 7 3 4 | 1 2 6
--------------------
7 8 1 | 9 2 6 | 3 5 4
4 3 9 | 8 5 1 | 7 6 2
5 2 6 | 4 7 3 | 8 1 9


True

In [172]:
# Pure logic only
state = SudokuState.from_string(mission)
solver = SudokuSolver(state)

# print(state.to_board_with_opts())
# print("\n\n\n")

solver.deduce()
print(state.to_board_string())
# print(state.to_board_with_opts())

0 0 0 | 3 0 0 | 0 0 0
0 0 0 | 5 0 0 | 0 3 1
3 5 1 | 4 9 6 | 0 0 0
--------------------
9 4 8 | 0 0 0 | 2 6 0
2 0 7 | 9 0 4 | 0 0 0
0 0 0 | 0 2 5 | 4 0 0
--------------------
0 8 3 | 2 4 0 | 1 0 0
4 0 6 | 0 0 3 | 0 5 0
7 0 2 | 0 0 0 | 3 0 0


In [None]:
state = SudokuState.from_string(mission)
solver = SudokuSolver(state)
print(solver.s.to_board_with_opts())
print("\n\n\n")
solver._deduce_until_stable()

print(solver.s.to_board_with_opts())

123456789 123456789 123456789 |     3     123456789 123456789 | 123456789 123456789 123456789
123456789 123456789 123456789 |     5     123456789 123456789 | 123456789     3         1    
    3         5         1     |     4         9         6     | 123456789 123456789 123456789
------------------
    9         4     123456789 | 123456789 123456789 123456789 |     2         6     123456789
    2     123456789     7     |     9     123456789     4     | 123456789 123456789 123456789
123456789 123456789 123456789 | 123456789     2         5     |     4     123456789 123456789
------------------
123456789     8         3     |     2         4     123456789 |     1     123456789 123456789
    4     123456789     6     | 123456789 123456789     3     | 123456789     5     123456789
    7     123456789 123456789 | 123456789 123456789 123456789 |     3     123456789 123456789






AttributeError: 'SudokuSolver' object has no attribute '_deduce_until_stable'