# Neuro-Symbolic HexaSudoku Solver

A pipeline that solves 16×16 Sudoku puzzles from images using:
1. **Character CNN** - Recognize digits (1-9) and letters (A-F) in cells
2. **Z3 Solver** - Compute valid solution with Sudoku constraints

Key features:
- 16×16 grid with 4×4 boxes
- Hexadecimal notation: 1-9 and A-F for values 10-15
- No size detection needed (always 16×16)
- 1600×1600 pixel images

In [None]:
!pip3 install z3-solver
!pip3 install opencv-python

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import pandas as pd
from z3 import *
import os
import time
import json

## Constants

In [None]:
SIZE = 16  # cells per row and per column of the grid
BOX_SIZE = 4  # cells per side of each box (4x4 boxes)
BOARD_PIXELS = 1600  # side length of a board image in pixels
CELL_SIZE = BOARD_PIXELS // SIZE  # 100px per cell
IMG_SIZE = 28  # CNN input size (cells are resized to IMG_SIZE x IMG_SIZE)
NUM_CLASSES = 16  # 0-9 + A-F (number of CNN output classes)

## Helper Functions

In [None]:
def value_to_char(value):
    """Return the display character for a puzzle cell value.

    0 (an empty cell) is shown as '.', values up to 9 as their decimal
    digit, and larger values as consecutive uppercase letters starting at
    'A' (10 -> 'A', 11 -> 'B', ..., 15 -> 'F', 16 -> 'G').
    """
    if value == 0:
        return '.'
    # Letters continue where the digits stop: 10 is the first letter.
    return str(value) if value <= 9 else chr(value - 10 + ord('A'))

def char_to_value(char):
    """Return the puzzle value for a display character.

    '.' and '0' both mean an empty cell (0). Decimal digits map to their
    integer value; any other character is treated as a letter, case
    insensitively, with A -> 10, B -> 11, ..., F -> 15.
    """
    if char in ('.', '0'):
        return 0
    if not char.isdigit():
        # Offset from 'A', shifted so that the first letter is 10.
        return ord(char.upper()) - ord('A') + 10
    return int(char)

## Neural Network Model

In [None]:
class CNN_v2(nn.Module):
    """Small convolutional classifier for single-character cell images.

    Two 3x3 convolutions (32 then 64 channels), each followed by ReLU and
    2x2 max pooling, feed a 128-unit hidden layer with dropout and a final
    linear layer with `output_dim` outputs (raw logits).

    The first linear layer expects 3136 = 64 * 7 * 7 features, i.e. a
    1-channel 28x28 input after two rounds of 2x2 pooling.
    """

    def __init__(self, output_dim):
        super().__init__()
        # Attribute names and creation order are kept stable so that saved
        # state dicts keep loading into this class.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2, 2)
        self.dropout = nn.Dropout(0.25)
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.fc2 = nn.Linear(128, output_dim)

    def forward(self, x):
        """Map a (batch, 1, 28, 28) tensor to (batch, output_dim) logits."""
        features = self.pool(F.relu(self.conv1(x)))
        features = self.pool(F.relu(self.conv2(features)))
        # Collapse channel and spatial dimensions, keeping the batch axis.
        batch = features.size(0)
        flat = features.view(batch, -1)
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

## Load Trained Model

In [None]:
# Load the HexaSudoku character model from its saved state dict.
# Sets MODEL_LOADED so later cells can skip image evaluation when the
# checkpoint is missing.
model_path = './models/hex_character_cnn.pth'

if os.path.exists(model_path):
    character_model = CNN_v2(output_dim=NUM_CLASSES)
    # The checkpoint is handed straight to load_state_dict, so it only needs
    # tensors and plain containers: weights_only=True avoids running the full
    # (arbitrary-code) unpickler on the file. map_location='cpu' lets a
    # checkpoint saved on a GPU load on a machine without CUDA; inference
    # below runs on CPU tensors anyway.
    state_dict = torch.load(model_path, map_location='cpu', weights_only=True)
    character_model.load_state_dict(state_dict)
    # Inference mode: disables dropout.
    character_model.eval()
    print(f"Loaded HexaSudoku character model ({NUM_CLASSES} classes)")
    MODEL_LOADED = True
else:
    print(f"Model not found at {model_path}")
    print("Run train_cnn.py first to train the model")
    MODEL_LOADED = False

## Cell Extraction and Character Recognition

In [None]:
def extract_cell(img_array, row, col, border=10):
    """
    Crop one cell out of the board image.

    Each cell is a CELL_SIZE x CELL_SIZE square; `border` pixels are shaved
    off every side of the crop so the crop excludes the cell's edges.

    Args:
        img_array: Grayscale image as numpy array (1600x1600)
        row, col: Cell position (0-15)
        border: Pixels to trim from cell edges

    Returns:
        Cell image as numpy array
    """
    top = row * CELL_SIZE
    left = col * CELL_SIZE
    row_span = slice(top + border, top + CELL_SIZE - border)
    col_span = slice(left + border, left + CELL_SIZE - border)
    return img_array[row_span, col_span]

In [None]:
def is_cell_empty(cell_img, threshold=0.98):
    """
    Decide whether a cell crop contains no character.

    A pixel counts as white when its intensity is strictly above 200; the
    cell is empty when the share of white pixels is strictly above
    `threshold`.

    Args:
        cell_img: Grayscale cell image (0-255)
        threshold: Fraction of white pixels to consider empty

    Returns:
        True if cell is empty
    """
    white_mask = cell_img > 200
    white_fraction = white_mask.sum() / cell_img.size
    return white_fraction > threshold

In [None]:
def preprocess_cell(cell_img):
    """
    Turn a raw cell crop into the array the CNN consumes.

    The crop is resized to IMG_SIZE x IMG_SIZE with Lanczos resampling,
    scaled from 0-255 to 0-1 and inverted, so dark ink on a light
    background ends up as high values.

    Args:
        cell_img: Grayscale cell image

    Returns:
        28x28 normalized image ready for CNN
    """
    target_shape = (IMG_SIZE, IMG_SIZE)
    shrunk = Image.fromarray(cell_img).resize(target_shape, Image.LANCZOS)

    # Scale to [0, 1] as float32, then flip so that ink is the high value.
    scaled = np.array(shrunk).astype(np.float32) / 255.0
    return 1.0 - scaled

In [None]:
def recognize_character(cell_img, model):
    """
    Classify the character in one preprocessed cell.

    The image is wrapped into a single-item, single-channel batch, passed
    through `model` without gradient tracking, and the highest-scoring
    class index is taken as the prediction.

    Args:
        cell_img: Preprocessed 28x28 cell image
        model: CNN model

    Returns:
        The predicted class index when it lies in 1-15 (1-9 = digits,
        10-15 = A-F); 0 for any other class, which callers treat as empty.
    """
    # NOTE(review): the solver's value domain is 1-16, but this function can
    # never return 16 - confirm how the 16th symbol is meant to be encoded.
    with torch.no_grad():
        # Add batch and channel axes: (H, W) -> (1, 1, H, W).
        batch = torch.tensor(cell_img, dtype=torch.float32)[None, None]
        logits = model(batch)
        predicted_class = logits.argmax(dim=1).item()

    if 1 <= predicted_class <= 15:
        return predicted_class
    return 0  # Treat as empty

## Z3 HexaSudoku Solver

In [None]:
def solve_hexasudoku(given_cells):
    """
    Solve a 16x16 HexaSudoku with the Z3 constraint solver.

    Every cell is an integer variable in 1..SIZE. The given clues are
    pinned, and every row, column and BOX_SIZE x BOX_SIZE box must hold
    pairwise distinct values.

    Args:
        given_cells: dict {(row, col): value} for given clues (values 1-16)

    Returns:
        Solution grid (16x16 list) or None if unsolvable
    """
    indices = range(SIZE)
    X = [[Int(f"x_{i}_{j}") for j in indices] for i in indices]

    solver = Solver()

    # Domain: each cell holds a value from 1 to SIZE.
    for row_vars in X:
        for var in row_vars:
            solver.add(And(var >= 1, var <= SIZE))

    # Clues: pin the given cells.
    for (i, j), val in given_cells.items():
        solver.add(X[i][j] == val)

    # Rows: all values distinct.
    for row_vars in X:
        solver.add(Distinct(list(row_vars)))

    # Columns: all values distinct.
    for j in indices:
        solver.add(Distinct([row_vars[j] for row_vars in X]))

    # Boxes: all values distinct within each BOX_SIZE x BOX_SIZE block.
    box_origins = range(0, BOX_SIZE * BOX_SIZE, BOX_SIZE)
    offsets = range(BOX_SIZE)
    for top in box_origins:
        for left in box_origins:
            block = [X[top + di][left + dj] for di in offsets for dj in offsets]
            solver.add(Distinct(block))

    if solver.check() != sat:
        return None

    # Read the concrete value of every cell out of the satisfying model.
    model = solver.model()
    return [[model.evaluate(var).as_long() for var in row_vars] for row_vars in X]

## Full Pipeline

In [None]:
def extract_puzzle_from_image(filename, model):
    """
    Read the given clues of a HexaSudoku puzzle from a board image.

    The image is loaded as grayscale and scanned cell by cell in row-major
    order. Cells judged empty are skipped; every other cell is preprocessed
    and classified, and kept only when the classifier returns a non-zero
    value.

    Args:
        filename: Path to puzzle image (1600x1600)
        model: Character recognition CNN

    Returns:
        given_cells dict {(row, col): value}

    Raises:
        ValueError: if the image cannot be loaded.
    """
    img = cv.imread(filename, cv.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {filename}")

    given_cells = {}
    for index in range(SIZE * SIZE):
        row, col = divmod(index, SIZE)
        cell = extract_cell(img, row, col)
        if is_cell_empty(cell):
            continue

        value = recognize_character(preprocess_cell(cell), model)
        if value > 0:  # 0 means the classifier saw nothing usable
            given_cells[(row, col)] = value

    return given_cells

In [None]:
def solve_from_image(filename, model):
    """
    Run the whole pipeline on one board image: recognize clues, then solve.

    Prints the number of clues that were detected before solving.

    Args:
        filename: Path to puzzle image
        model: Character recognition CNN

    Returns:
        Solution grid (16x16) or None
    """
    clues = extract_puzzle_from_image(filename, model)
    print(f"Detected {len(clues)} given clues")
    return solve_hexasudoku(clues)

## Direct Solving from JSON

In [None]:
def solve_from_json(puzzle):
    """
    Solve a HexaSudoku given as a grid of numbers.

    Every non-zero entry of the grid becomes a clue at its (row, col)
    position; zeros are left for the solver to fill.

    Args:
        puzzle: 16x16 2D list where 0 = empty cell, 1-16 = filled

    Returns:
        Solution grid or None
    """
    clues = {
        (r, c): puzzle[r][c]
        for r in range(SIZE)
        for c in range(SIZE)
        if puzzle[r][c] != 0
    }
    return solve_hexasudoku(clues)

## Load Puzzle Data

In [None]:
# Read the puzzle dataset from disk; fall back to an empty dataset when the
# file is missing so the later cells can detect it and skip their work.
puzzle_path = './puzzles/puzzles_dict.json'

if not os.path.exists(puzzle_path):
    print("No puzzle file found. Run SymbolicPuzzleGenerator first.")
    puzzles_ds = {}
else:
    with open(puzzle_path, 'r') as f:
        puzzles_ds = json.load(f)
    print(f"Loaded puzzles: {list(puzzles_ds.keys())}")
    # Only the 16x16 puzzles are used by this notebook.
    if '16' in puzzles_ds:
        print(f"Number of 16x16 puzzles: {len(puzzles_ds['16'])}")

## Demo: Solve Single Puzzle from JSON

In [None]:
def print_grid(grid):
    """Print a grid in hex notation with separators between boxes.

    A dashed line is printed before every BOX_SIZE-th row (except the
    first) and a '| ' marker before every BOX_SIZE-th column (except the
    first). Each value is followed by a single space.
    """
    for r, row in enumerate(grid):
        if r and r % BOX_SIZE == 0:
            print('-' * 37)
        pieces = []
        for c, val in enumerate(row):
            if c and c % BOX_SIZE == 0:
                pieces.append('| ')
            pieces.append(value_to_char(val) + ' ')
        print(''.join(pieces))

# Demo: solve the first 16x16 puzzle from the JSON dataset and compare the
# result with the stored reference solution.
if '16' in puzzles_ds and len(puzzles_ds['16']) > 0:
    demo_puzzle = puzzles_ds['16'][0]['puzzle']
    demo_solution_expected = puzzles_ds['16'][0]['solution']

    clues = sum(1 for row in demo_puzzle for cell in row if cell != 0)
    print(f"Puzzle ({clues} clues):")
    print_grid(demo_puzzle)

    # perf_counter is monotonic, so the measured interval is not affected by
    # wall-clock adjustments.
    start = time.perf_counter()
    solution = solve_from_json(demo_puzzle)
    elapsed = time.perf_counter() - start

    if solution is None:
        # The solver returns None for an unsatisfiable puzzle; printing it as
        # a grid would raise, so report the failure instead.
        print(f"\nNo solution found (checked in {elapsed:.2f}s)")
    else:
        print(f"\nSolution (computed in {elapsed:.2f}s):")
        print_grid(solution)

    print(f"\nMatch expected: {solution == demo_solution_expected}")
else:
    print("No puzzles loaded.")

## Evaluation Functions

In [None]:
def count_clues_in_puzzle(puzzle):
    """Return how many cells of the puzzle grid are filled (non-zero)."""
    filled = [cell for row in puzzle for cell in row if cell != 0]
    return len(filled)

def evaluate_json_solver(puzzles_ds):
    """
    Evaluate Z3 solver on all puzzles (no image processing).

    Each 16x16 puzzle is solved from its ground-truth grid, timed, and
    compared with its stored solution. Progress is printed every 25 puzzles
    and a summary at the end.

    Args:
        puzzles_ds: dict of puzzle lists keyed by size; only the '16' entry
            is used. Each puzzle entry is a dict with 'puzzle' and
            'solution' grids.

    Returns:
        List with one dict per puzzle (keys: 'puzzle_index', 'num_clues',
        'solved', 'correct', 'time_ms'). Empty when there is no '16' entry
        or it contains no puzzles.
    """
    if '16' not in puzzles_ds:
        print("No 16x16 puzzles found")
        return []

    puzzles = puzzles_ds['16']
    results = []

    print(f"Evaluating {len(puzzles)} puzzles from JSON...")

    for i, puzzle_data in enumerate(puzzles):
        # perf_counter is monotonic, so it is the right clock for intervals.
        start = time.perf_counter()
        solution = solve_from_json(puzzle_data['puzzle'])
        elapsed = time.perf_counter() - start

        results.append({
            'puzzle_index': i,
            'num_clues': count_clues_in_puzzle(puzzle_data['puzzle']),
            'solved': solution is not None,
            'correct': solution == puzzle_data['solution'],
            'time_ms': elapsed * 1000,
        })

        if (i + 1) % 25 == 0:
            correct_count = sum(1 for r in results if r['correct'])
            print(f"  Progress: {i+1}/{len(puzzles)}, Correct: {correct_count}/{i+1}")

    # An empty puzzle list has nothing to summarize; returning here avoids
    # dividing by zero in the averages below.
    if not results:
        print("\nNo puzzles to evaluate")
        return results

    # Summary
    correct_count = sum(1 for r in results if r['correct'])
    avg_time = sum(r['time_ms'] for r in results) / len(results)
    print(f"\nFinal: {correct_count}/{len(results)} ({100*correct_count/len(results):.1f}%)")
    print(f"Average time: {avg_time:.1f}ms")

    return results

In [None]:
def evaluate_image_solver(puzzles_ds, model):
    """
    Evaluate neurosymbolic solver on all board images.

    For puzzle i the image './board_images/board16_{i}.png' is read, its
    clues are recognized with `model`, the puzzle is solved with Z3 and the
    result is compared with the stored solution. Failures are recorded per
    puzzle instead of aborting the run. Progress is printed every 25 puzzles
    and a summary at the end.

    Args:
        puzzles_ds: dict of puzzle lists keyed by size; only the '16' entry
            is used. Each puzzle entry is a dict with 'puzzle' and
            'solution' grids.
        model: Character recognition CNN

    Returns:
        List with one dict per puzzle. 'error_type' is one of 'none',
        'image_not_found', 'unsolvable', 'character_recognition',
        'wrong_solution' or 'exception'. Empty when there is no '16' entry
        or it contains no puzzles.
    """
    if '16' not in puzzles_ds:
        print("No 16x16 puzzles found")
        return []

    puzzles = puzzles_ds['16']
    results = []

    print(f"Evaluating {len(puzzles)} puzzles from images...")

    for i, puzzle_data in enumerate(puzzles):
        filename = f'./board_images/board16_{i}.png'
        result = {
            'puzzle_index': i,
            'filename': filename,
            'num_clues_expected': count_clues_in_puzzle(puzzle_data['puzzle']),
            'num_clues_detected': 0,
            'clues_match': False,
            'solved': False,
            'solution_correct': False,
            'solve_time_ms': 0.0,
            'error_type': 'none',
            'error_message': ''
        }

        if not os.path.exists(filename):
            # Nothing to time or solve; solve_time_ms stays at 0.0.
            result['error_type'] = 'image_not_found'
            result['error_message'] = f'Image not found: {filename}'
        else:
            # perf_counter is monotonic, so it is the right clock for intervals.
            start_time = time.perf_counter()
            try:
                _solve_board_image(filename, model, puzzle_data['solution'], result)
            except Exception as e:
                # Per-puzzle boundary: record the failure and keep evaluating.
                result['error_type'] = 'exception'
                result['error_message'] = str(e)
            result['solve_time_ms'] = (time.perf_counter() - start_time) * 1000

        results.append(result)

        # Reported for every puzzle, including ones whose image is missing.
        if (i + 1) % 25 == 0:
            correct = sum(1 for r in results if r['solution_correct'])
            print(f"  Progress: {i+1}/{len(puzzles)}, Correct: {correct}/{i+1}")

    # An empty puzzle list has nothing to summarize; returning here avoids
    # dividing by zero in the percentages below.
    if not results:
        print("\nNo puzzles to evaluate")
        return results

    # Summary
    correct = sum(1 for r in results if r['solution_correct'])
    avg_time = sum(r['solve_time_ms'] for r in results) / len(results)
    print(f"\nFinal: {correct}/{len(results)} ({100*correct/len(results):.1f}%)")
    print(f"Average time: {avg_time:.1f}ms")

    return results


def _solve_board_image(filename, model, expected, result):
    """Recognize and solve one board image, recording the outcome in `result`."""
    # Extract clues from image
    given_cells = extract_puzzle_from_image(filename, model)
    result['num_clues_detected'] = len(given_cells)
    # Only the number of clues is compared here, not their values.
    result['clues_match'] = (result['num_clues_detected'] == result['num_clues_expected'])

    # Solve with Z3
    solution = solve_hexasudoku(given_cells)
    if solution is None:
        result['error_type'] = 'unsolvable'
        result['error_message'] = 'Z3 could not find a valid solution'
        return

    result['solved'] = True
    result['solution_correct'] = (solution == expected)
    if result['solution_correct']:
        return

    # A wrong answer with a wrong clue count is attributed to recognition;
    # with the right count it is reported as a plain mismatch.
    if not result['clues_match']:
        result['error_type'] = 'character_recognition'
        result['error_message'] = f'Detected {result["num_clues_detected"]} clues, expected {result["num_clues_expected"]}'
    else:
        result['error_type'] = 'wrong_solution'
        result['error_message'] = 'Solution does not match expected'

## Run Evaluation (JSON)

In [None]:
# Run the solver-only evaluation when a puzzle dataset was loaded;
# otherwise keep an empty result list.
json_results = evaluate_json_solver(puzzles_ds) if puzzles_ds else []

## Run Evaluation (Images)

In [None]:
# Run the image-based evaluation; it needs both the puzzle dataset and a
# loaded character model (plus the board images on disk).
if not (puzzles_ds and MODEL_LOADED):
    print("Cannot run image evaluation: missing puzzles or model")
    image_results = []
else:
    image_results = evaluate_image_solver(puzzles_ds, character_model)

## Save Results

In [None]:
# Write the per-puzzle image evaluation results to CSV and print a summary.
# Skipped entirely when no image evaluation results exist.
if image_results:
    os.makedirs('./results', exist_ok=True)

    detailed_df = pd.DataFrame(image_results)
    detailed_df.to_csv('./results/detailed_evaluation.csv', index=False)
    print("Saved detailed results to ./results/detailed_evaluation.csv")

    # Summary header
    print("\n" + "=" * 60)
    print("SUMMARY STATISTICS")
    print("=" * 60)

    total = len(image_results)
    correct = len([r for r in image_results if r['solution_correct']])
    clue_match = len([r for r in image_results if r['clues_match']])
    avg_time = sum(r['solve_time_ms'] for r in image_results) / total

    print(f"\n16x16 HexaSudoku Puzzles:")
    print(f"  Accuracy: {correct}/{total} ({100*correct/total:.1f}%)")
    print(f"  Clue Match: {clue_match}/{total}")
    print(f"  Avg Time: {avg_time:.1f}ms")

    # Count how often each error type occurred, ignoring clean runs.
    errors = detailed_df.loc[detailed_df['error_type'] != 'none', 'error_type'].value_counts()
    if not errors.empty:
        print(f"  Errors:")
        for error_type, count in errors.items():
            print(f"    - {error_type}: {count}")

    print("\n" + "=" * 60)

## Visualization

In [None]:
def visualize_solution(puzzle, solution):
    """
    Draw the puzzle and its solution as two side-by-side grids.

    Box boundaries are drawn with thicker lines. In the solution panel,
    values that were already given in the puzzle are black and filled-in
    values are blue; the puzzle panel is all black. Zero cells are left
    blank.
    """
    fig, axes = plt.subplots(1, 2, figsize=(16, 8))
    panels = ((axes[0], puzzle, 'Puzzle'), (axes[1], solution, 'Solution'))

    for ax, grid, title in panels:
        ax.set_xlim(0, SIZE)
        ax.set_ylim(0, SIZE)
        ax.set_aspect('equal')
        ax.set_title(title, fontsize=14)
        ax.axis('off')

        # Grid lines, thicker on box boundaries.
        for k in range(SIZE + 1):
            width = 0.5 if k % BOX_SIZE else 2
            ax.axhline(y=k, color='black', linewidth=width)
            ax.axvline(x=k, color='black', linewidth=width)

        # Cell characters; row 0 is drawn at the top of the panel.
        for r in range(SIZE):
            for c in range(SIZE):
                val = grid[r][c]
                if val == 0:
                    continue
                is_clue = title == 'Puzzle' or puzzle[r][c] != 0
                ax.text(
                    c + 0.5,
                    SIZE - r - 0.5,
                    value_to_char(val),
                    ha='center',
                    va='center',
                    fontsize=12,
                    color='black' if is_clue else 'blue',
                )

    plt.tight_layout()
    plt.show()

# Show the first 16x16 puzzle next to its stored reference solution.
if len(puzzles_ds.get('16', ())) > 0:
    puzzle = puzzles_ds['16'][0]['puzzle']
    solution = puzzles_ds['16'][0]['solution']
    visualize_solution(puzzle=puzzle, solution=solution)