In [None]:
from __future__ import annotations

from pathlib import Path
from typing import Dict, Iterable, List, Optional, Sequence, Set, Tuple

# --- SAT / CNF primitives ---
# Internal representation:
# - Variables are 1-based integers: 1, 2, 3, ...
#   (instance files number variables from 0; parse_instances shifts them by +1)
# - Literals are signed ints: +v means x_v, -v means ¬x_v
# - A clause is satisfied when at least one of its literals is true; a CNF is
#   satisfied when every clause is (see model_satisfies).
# - An Assignment maps variable -> truth value; it may be partial.
Literal = int
Clause = List[Literal]
CNF = List[Clause]
Assignment = Dict[int, bool]


def parse_instances(text: str) -> List[CNF]:
    """Parse an instances file (see description.md).

    Format:
      - Each clause is one line of whitespace-separated literals.
      - Literals are numbered starting at 0; negation is a leading '-'.
      - SAT problems are separated by a blank (or whitespace-only) line; runs
        of blank lines never produce empty problems.

    Note:
      Because Python ints cannot represent "+0" vs "-0" distinctly, we map
      external variable k (0-based) to internal variable (k+1).

    Args:
      text: full contents of an instances file.

    Returns:
      One CNF (internal representation) per problem, in file order.

    Raises:
      ValueError: for malformed tokens such as "-", "x1", "--3" or "--0".
    """
    problems: List[CNF] = []
    current: CNF = []

    for raw_line in text.splitlines():
        line = raw_line.strip()
        if not line:
            # A blank line closes the current problem (if any).
            if current:
                problems.append(current)
                current = []
            continue

        # `line` is non-empty after strip(), so split() yields at least one
        # token and the resulting clause can never be empty.
        clause: Clause = []
        for tok in line.split():
            neg = tok.startswith("-")
            num_str = tok[1:] if neg else tok
            # Reject a bare "-" and double negation such as "--3": int("-3")
            # would otherwise silently flip the sign and yield an unrelated
            # literal (internal +2).
            if not num_str or num_str.startswith("-"):
                raise ValueError(f"Invalid literal token: {tok!r}")
            ext_var = int(num_str)  # 0-based variable index
            int_var = ext_var + 1   # internal 1-based variable index
            clause.append(-int_var if neg else int_var)
        current.append(clause)

    if current:
        problems.append(current)

    return problems


def read_instances_file(path: str | Path) -> List[CNF]:
    """Load *path* as UTF-8 text and return the problems parse_instances finds."""
    contents = Path(path).read_text(encoding="utf-8")
    return parse_instances(contents)


def _simplify(cnf: Sequence[Clause], assignment: Assignment) -> Optional[CNF]:
    """Simplify CNF under a partial assignment.

    Returns:
      - simplified CNF if consistent
      - [] if all clauses are satisfied (SAT)
      - None if a clause becomes empty (UNSAT)
    """
    simplified: CNF = []

    for clause in cnf:
        satisfied = False
        new_clause: Clause = []

        for lit in clause:
            var = abs(lit)
            if var in assignment:
                val = assignment[var]
                lit_is_true = (val and lit > 0) or ((not val) and lit < 0)
                if lit_is_true:
                    satisfied = True
                    break
                # otherwise literal is false; drop it
            else:
                new_clause.append(lit)

        if satisfied:
            continue
        if not new_clause:
            return None
        simplified.append(new_clause)

    return simplified


def _unit_literals(cnf: Sequence[Clause]) -> List[Literal]:
    return [clause[0] for clause in cnf if len(clause) == 1]


def _pure_literals(cnf: Sequence[Clause]) -> Set[Literal]:
    polarity: Dict[int, int] = {}
    for clause in cnf:
        for lit in clause:
            var = abs(lit)
            sign = 1 if lit > 0 else -1
            if var not in polarity:
                polarity[var] = sign
            elif polarity[var] != sign:
                polarity[var] = 0

    pures: Set[Literal] = set()
    for var, sign in polarity.items():
        if sign == 1:
            pures.add(var)
        elif sign == -1:
            pures.add(-var)
    return pures


def _choose_branch_literal(cnf: Sequence[Clause]) -> Literal:
    # Simple heuristic: pick a literal from the smallest clause.
    smallest = min(cnf, key=len)
    return smallest[0]


def model_satisfies(cnf: Sequence[Clause], model: Assignment) -> bool:
    """Return True iff every clause of *cnf* has a literal true under *model*.

    Variables missing from *model* are treated as False.
    """
    def _lit_true(lit: Literal) -> bool:
        return lit > 0 if model.get(abs(lit), False) else lit < 0

    return all(any(_lit_true(lit) for lit in clause) for clause in cnf)


def dpll(cnf: Sequence[Clause]) -> Optional[Assignment]:
    """Decide satisfiability of *cnf* with the DPLL procedure.

    *cnf* uses the internal representation (see top of cell). The result is a
    (possibly partial) satisfying assignment, or None when *cnf* is UNSAT.
    """
    working_copy: CNF = list(cnf)
    initial: Assignment = {}
    return _dpll_rec(working_copy, initial)


def _dpll_rec(cnf: CNF, assignment: Assignment) -> Optional[Assignment]:
    """DPLL core: propagate, eliminate pure literals, then branch.

    Args:
      cnf: clauses in the internal representation.
      assignment: current partial assignment. It is MUTATED in place by unit
        propagation and pure-literal elimination; branching works on copies.

    Returns:
      A copy of a (possibly partial) satisfying assignment, or None if *cnf*
      is UNSAT under *assignment*.
    """
    cnf_s = _simplify(cnf, assignment)
    if cnf_s is None:
        return None

    # Alternate unit propagation and pure-literal elimination until neither
    # applies. Doing this in a loop (rather than recursing after every
    # pure-literal round) keeps stack depth proportional to the number of
    # branching decisions only.
    while True:
        if not cnf_s:
            return dict(assignment)  # every clause satisfied

        units = _unit_literals(cnf_s)
        if units:
            for lit in units:
                var = abs(lit)
                val = lit > 0
                # Unit clauses [x] and [-x] in the same round -> conflict.
                if var in assignment and assignment[var] != val:
                    return None
                assignment[var] = val
        else:
            # Pure literals come from the simplified formula, so their
            # variables are unassigned and can never conflict.
            pures = _pure_literals(cnf_s)
            if not pures:
                break
            for lit in pures:
                assignment[abs(lit)] = lit > 0

        cnf_s = _simplify(cnf_s, assignment)
        if cnf_s is None:
            return None

    # Branch on a variable from a shortest clause, trying True then False.
    # Each branch gets its own copy so a failed branch's propagation does not
    # leak into the other.
    lit = _choose_branch_literal(cnf_s)
    var = abs(lit)
    for val in (True, False):
        new_assignment = dict(assignment)
        new_assignment[var] = val
        result = _dpll_rec(cnf_s, new_assignment)
        if result is not None:
            return result
    return None


def format_model_external(model: Assignment) -> str:
    """Render *model* in the external 0-based literal numbering.

    Variables are listed in ascending order, space-separated; a False
    variable gets a leading '-', so internal variable 1 = False prints "-0".
    """
    def _render(var: int) -> str:
        sign = "" if model[var] else "-"
        return f"{sign}{var - 1}"

    return " ".join(map(_render, sorted(model)))


def solve_instances_file(path: str | Path) -> List[Optional[Assignment]]:
    """Solve every problem in the instances file at *path*, in file order.

    Each entry is a satisfying assignment, or None for an UNSAT problem.
    """
    results: List[Optional[Assignment]] = []
    for problem in read_instances_file(path):
        results.append(dpll(problem))
    return results


In [None]:
import unittest
from itertools import product
from pathlib import Path
from typing import Optional


def _find_data_file(rel: str) -> Path:
    """Find data files whether CWD is repo root or dpll/ subdir."""
    candidates = [Path(rel), Path("dpll") / rel, Path(".") / rel]
    for p in candidates:
        if p.exists():
            return p
    raise FileNotFoundError(f"Couldn't find {rel!r}; tried: {candidates}")


def _num_vars(cnf: CNF) -> int:
    return max((abs(l) for clause in cnf for l in clause), default=0)


def _bruteforce_solve(cnf: CNF, *, max_vars: int = 12) -> Optional[Assignment]:
    """Exhaustively search for a satisfying assignment (test oracle).

    Args:
      cnf: formula in the internal representation.
      max_vars: refuse formulas with more variables than this, since the
        search enumerates all 2**n models.

    Returns:
      The first satisfying total model over variables 1..n (in
      ``itertools.product`` order), or None if the formula is UNSAT.

    Raises:
      ValueError: if the formula has more than ``max_vars`` variables.
    """
    n = _num_vars(cnf)
    # Keep brute force bounded.
    if n > max_vars:
        raise ValueError(f"Too many vars for brute force: {n}")

    # No special case for n == 0: product(repeat=0) yields exactly one empty
    # tuple, so the empty model is tested once. That reports SAT ({}) for []
    # but correctly reports UNSAT (None) for a formula with an empty clause,
    # e.g. [[]].
    for bits in product([False, True], repeat=n):
        model = dict(enumerate(bits, start=1))
        if model_satisfies(cnf, model):
            return model
    return None


class TestDPLLFileFormat(unittest.TestCase):
    """End-to-end checks of the parser and solver against the data files."""

    def test_parse_small_instances_has_multiple_problems(self) -> None:
        small_path = _find_data_file("small_instances.txt")
        problems = read_instances_file(small_path)
        self.assertGreaterEqual(len(problems), 2)

    def test_dpll_matches_bruteforce_on_small_instances(self) -> None:
        """DPLL agrees with exhaustive search on SAT/UNSAT and returns valid models."""
        small_path = _find_data_file("small_instances.txt")
        problems = read_instances_file(small_path)

        for i, cnf in enumerate(problems):
            # subTest reports every failing problem, not just the first one.
            with self.subTest(problem=i):
                n = _num_vars(cnf)
                if n > 12:
                    # small_instances should usually be small; but be safe.
                    continue

                brute = _bruteforce_solve(cnf)
                dpll_model = dpll(cnf)

                self.assertEqual(brute is not None, dpll_model is not None, msg=f"problem #{i}")
                if dpll_model is not None:
                    self.assertTrue(model_satisfies(cnf, dpll_model), msg=f"problem #{i}")

    def test_big_instances_first_problem_model_is_valid_if_sat(self) -> None:
        big_path = _find_data_file("big_instances.txt")
        problems = read_instances_file(big_path)
        self.assertGreaterEqual(len(problems), 1)

        cnf0 = problems[0]
        model = dpll(cnf0)
        # Only the SAT case is verifiable here (no oracle for large instances).
        if model is not None:
            self.assertTrue(model_satisfies(cnf0, model))


# Run the file-format tests inline; the final expression's result
# (a TextTestResult) is what the notebook displays.
runner = unittest.TextTestRunner(verbosity=2)
suite = unittest.defaultTestLoader.loadTestsFromTestCase(TestDPLLFileFormat)
runner.run(suite)


test_pure_literal_elimination (__main__.TestDPLL.test_pure_literal_elimination) ... ok
test_sat_requires_branch (__main__.TestDPLL.test_sat_requires_branch) ... ok
test_sat_simple (__main__.TestDPLL.test_sat_simple) ... ok
test_sat_single_unit (__main__.TestDPLL.test_sat_single_unit) ... ok
test_unsat_contradiction (__main__.TestDPLL.test_unsat_contradiction) ... ok

----------------------------------------------------------------------
Ran 5 tests in 0.006s

OK


<unittest.runner.TextTestResult run=5 errors=0 failures=0>