In [None]:
import os
import json
import time
import pandas as pd
from pysat.solvers import Glucose4
from pysat.formula import CNF, IDPool
import concurrent.futures

In [3]:
# -------------------------------
# SAT 求解器部分（纯 SAT 模型）
# -------------------------------
def encode_multi_block_nonogram_to_cnf(m, n, row_constraints, col_constraints):
    """
    将“行多段 + 列多段”的 Nonogram 约束编码为 CNF。
    返回 (cnf, varmap)：
      - cnf: pysat.formula.CNF 对象，包含所有子句
      - varmap: 包含:
          varmap['x'][(i,j)]        -> int (变量ID)，表示 x[i][j]
          varmap['rblock'][(i,b,s)] -> int (变量ID)，表示第 i 行的第 b 段从列 s 开始
          varmap['cblock'][(j,b,s)] -> int (变量ID)，表示第 j 列的第 b 段从行 s 开始
    """
    pool = IDPool()
    cnf = CNF()
    varmap = {'x': {}, 'rblock': {}, 'cblock': {}}

    # 1) x[i][j] 变量
    for i in range(m):
        for j in range(n):
            varmap['x'][(i, j)] = pool.id(f"x_{i}_{j}")

    # 2) 行块变量
    for i in range(m):
        blocks = row_constraints[i]
        for b, length_b in enumerate(blocks):
            for start_col in range(n - length_b + 1):
                varmap['rblock'][(i, b, start_col)] = pool.id(f"rB_{i}_{b}_{start_col}")

    # 3) 列块变量
    for j in range(n):
        blocks = col_constraints[j]
        for b, length_b in enumerate(blocks):
            for start_row in range(m - length_b + 1):
                varmap['cblock'][(j, b, start_row)] = pool.id(f"cB_{j}_{b}_{start_row}")

    # -----------------------------
    # 添加行约束子句
    # -----------------------------
    for i in range(m):
        blocks = row_constraints[i]
        if len(blocks) == 0:
            for j in range(n):
                cnf.append([-varmap['x'][(i, j)]])
            continue
        # (A) 每段 exactly one 起始位置
        for b, length_b in enumerate(blocks):
            possible_ids = []
            for start_col in range(n - length_b + 1):
                possible_ids.append(varmap['rblock'][(i, b, start_col)])
            cnf.append(possible_ids)
            for a in range(len(possible_ids)):
                for c in range(a+1, len(possible_ids)):
                    cnf.append([-possible_ids[a], -possible_ids[c]])
        # (B) 相邻段至少隔 1 列空
        for b in range(len(blocks) - 1):
            length_b = blocks[b]
            for start_s in range(n - length_b + 1):
                rblock_s = varmap['rblock'][(i, b, start_s)]
                for start_t in range(n - blocks[b+1] + 1):
                    if start_t < start_s + length_b + 1:
                        rblock_t = varmap['rblock'][(i, b+1, start_t)]
                        cnf.append([-rblock_s, -rblock_t])
        # (C) 覆盖约束：x[i][col] = 1 当且仅当被某个行段覆盖
        for col in range(n):
            cover_ids = []
            for b, length_b in enumerate(blocks):
                for start_s in range(n - length_b + 1):
                    if start_s <= col < start_s + length_b:
                        cover_ids.append(varmap['rblock'][(i, b, start_s)])
            x_id = varmap['x'][(i, col)]
            if not cover_ids:
                cnf.append([-x_id])
            else:
                cnf.append([-x_id] + cover_ids)
                for cid in cover_ids:
                    cnf.append([-cid, x_id])

    # -----------------------------
    # 添加列约束子句
    # -----------------------------
    for j in range(n):
        blocks = col_constraints[j]
        if len(blocks) == 0:
            for i in range(m):
                cnf.append([-varmap['x'][(i, j)]])
            continue
        # (A) 每段 exactly one 起始位置
        for b, length_b in enumerate(blocks):
            possible_ids = []
            for start_i in range(m - length_b + 1):
                possible_ids.append(varmap['cblock'][(j, b, start_i)])
            cnf.append(possible_ids)
            for a in range(len(possible_ids)):
                for c in range(a+1, len(possible_ids)):
                    cnf.append([-possible_ids[a], -possible_ids[c]])
        # (B) 相邻段至少隔 1 行
        for b in range(len(blocks) - 1):
            length_b = blocks[b]
            for start_s in range(m - length_b + 1):
                cblock_s = varmap['cblock'][(j, b, start_s)]
                for start_t in range(m - blocks[b+1] + 1):
                    if start_t < start_s + length_b + 1:
                        cblock_t = varmap['cblock'][(j, b+1, start_t)]
                        cnf.append([-cblock_s, -cblock_t])
        # (C) 覆盖约束：x[i][j] 与列块变量之间的关系
        for i in range(m):
            cover_ids = []
            for b, length_b in enumerate(blocks):
                for start_s in range(m - length_b + 1):
                    if start_s <= i < start_s + length_b:
                        cover_ids.append(varmap['cblock'][(j, b, start_s)])
            x_id = varmap['x'][(i, j)]
            if not cover_ids:
                cnf.append([-x_id])
            else:
                cnf.append([-x_id] + cover_ids)
                for cid in cover_ids:
                    cnf.append([-cid, x_id])
    return cnf, varmap

In [5]:
def decode_solution(m, n, varmap, model_solution):
    sol_matrix = []
    for i in range(m):
        rowvals = []
        for j in range(n):
            x_id = varmap['x'][(i, j)]
            rowvals.append(1 if model_solution[x_id] else 0)
        sol_matrix.append(rowvals)
    return sol_matrix

In [7]:
def add_blocking_clause(cnf, varmap, solution_matrix):
    clause = []
    m = len(solution_matrix)
    n = len(solution_matrix[0])
    for i in range(m):
        for j in range(n):
            x_id = varmap['x'][(i, j)]
            if solution_matrix[i][j] == 1:
                clause.append(-x_id)
            else:
                clause.append(x_id)
    cnf.append(clause)

In [9]:
def solve_all_solutions_pure_sat(m, n, row_constraints, col_constraints):
    cnf, varmap = encode_multi_block_nonogram_to_cnf(m, n, row_constraints, col_constraints)
    solutions = []
    while True:
        solver = Glucose4()
        for clause in cnf.clauses:
            solver.add_clause(clause)
        ok = solver.solve()
        if not ok:
            solver.delete()
            break
        model = solver.get_model()
        assignment = {}
        for lit in model:
            if lit > 0:
                assignment[lit] = True
            else:
                assignment[-lit] = False
        sol_matrix = decode_solution(m, n, varmap, assignment)
        solutions.append(sol_matrix)
        add_blocking_clause(cnf, varmap, sol_matrix)
        solver.delete()
    return solutions

In [11]:
# -------------------------------
# 批量处理及记录求解结果
# -------------------------------
def process_puzzle_file(file_path):
    """读取 JSON 文件，返回 puzzle 数据字典"""
    with open(file_path, 'r') as f:
        data = json.load(f)
    return data

In [13]:
def solve_puzzle(puzzle):
    """
    对一个 puzzle 求解，返回求解时间和解的数量。
    puzzle 中应包含 row_hints 或 row_constraints 以及 col_hints 或 col_constraints。
    """
    # 获取行提示和列提示，支持不同的键名
    row_hints = puzzle.get("row_hints", puzzle.get("row_constraints"))
    col_hints = puzzle.get("col_hints", puzzle.get("col_constraints"))
    m = len(row_hints)
    n = len(col_hints)
    start = time.time()
    solutions = solve_all_solutions_pure_sat(m, n, row_hints, col_hints)
    end = time.time()
    return end - start, len(solutions)

In [15]:
def process_and_solve(file_path):
    """
    读取谜题文件，对谜题求解，返回一个记录字典。
    """
    puzzle = process_puzzle_file(file_path)
    pid = puzzle.get("id", os.path.basename(file_path))
    label = puzzle.get("label", None)
    row_hints = puzzle.get("row_hints", puzzle.get("row_constraints"))
    col_hints = puzzle.get("col_hints", puzzle.get("col_constraints"))
    
    solve_time, num_sol = solve_puzzle(puzzle)
    # 返回一个字典记录
    return {
        "puzzle_id": pid,
        "label": label,
        "row_hints": str(row_hints),
        "col_hints": str(col_hints),
        "num_solutions": num_sol,
        "solving_time_sec": solve_time
    }

In [17]:
def main():
    folder = "./mnist_nonograms"  # 你的谜题文件夹
    records = []
    for filename in os.listdir(folder):
        if filename.lower().endswith(".json"):
            file_path = os.path.join(folder, filename)
            puzzle = process_puzzle_file(file_path)
            pid = puzzle.get("id", filename)
            label = puzzle.get("label", None)
            row_hints = puzzle.get("row_hints", puzzle.get("row_constraints"))
            col_hints = puzzle.get("col_hints", puzzle.get("col_constraints"))
            print(f"Processing puzzle {pid} ...")
            solve_time, num_sol = solve_puzzle(puzzle)
            print(f"Puzzle {pid}: {num_sol} solution(s) in {solve_time:.4f} sec")
            records.append({
                "puzzle_id": pid,
                "label": label,
                "row_hints": str(row_hints),
                "col_hints": str(col_hints),
                "num_solutions": num_sol,
                "solving_time_sec": solve_time
            })
    df = pd.DataFrame(records)
    output_file = "sat_solver_results.csv"
    df.to_csv(output_file, index=False)
    print(f"Results saved in {output_file}")

In [19]:
if __name__ == "__main__":
    main()

Processing puzzle 0 ...
Puzzle 0: 1 solution(s) in 0.0685 sec
Processing puzzle 1 ...
Puzzle 1: 1 solution(s) in 0.0589 sec
Processing puzzle 10 ...
Puzzle 10: 1 solution(s) in 0.0657 sec
Processing puzzle 100 ...
Puzzle 100: 1 solution(s) in 0.0272 sec
Processing puzzle 101 ...
Puzzle 101: 1 solution(s) in 0.0401 sec
Processing puzzle 102 ...
Puzzle 102: 1 solution(s) in 0.0152 sec
Processing puzzle 103 ...
Puzzle 103: 1 solution(s) in 0.0346 sec
Processing puzzle 104 ...
Puzzle 104: 1 solution(s) in 0.0150 sec
Processing puzzle 105 ...
Puzzle 105: 1 solution(s) in 0.0203 sec
Processing puzzle 106 ...
Puzzle 106: 1 solution(s) in 0.0721 sec
Processing puzzle 107 ...
Puzzle 107: 1 solution(s) in 0.0411 sec
Processing puzzle 108 ...
Puzzle 108: 1 solution(s) in 0.0826 sec
Processing puzzle 109 ...
Puzzle 109: 1 solution(s) in 0.0554 sec
Processing puzzle 11 ...
Puzzle 11: 2 solution(s) in 0.0786 sec
Processing puzzle 110 ...
Puzzle 110: 1 solution(s) in 0.0472 sec
Processing puzzle 111 

KeyboardInterrupt: 