# Neuro-Symbolic Sudoku Solver

A pipeline that solves Sudoku puzzles from images using:
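1. **Grid size detection** - Detect puzzle size (4×4 or 9×9). A Grid CNN architecture is defined below, but this notebook currently detects the size with a Hough-line heuristic (`detect_grid_size`)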
2. **Character CNN** - Recognize digits in cells
3. **Z3 Solver** - Compute valid solution with Sudoku constraints

Key differences from KenKen:
- No cage detection needed (fixed box structure)
- No operators to recognize (digits only)
- Box constraints instead of cage arithmetic

In [None]:
!pip3 install z3-solver
!pip3 install opencv-python
!pip3 install torchvision

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
import cv2 as cv
import pandas as pd
from z3 import *
import os
import time
import json

from torchvision import transforms

In [None]:
# Pipeline-wide size constants.
# NOTE(review): BOARD_SIZE and SCALE_FACTOR are not referenced by any cell visible
# in this notebook - presumably the rendered board edge in pixels and an upscaling
# factor inherited from the KenKen pipeline; confirm or remove.
BOARD_SIZE = 900
# Edge length in pixels that preprocess_cell resizes every cell image to.
IMG_SIZE = 28
SCALE_FACTOR = 2

## Neural Network Models

We reuse the CNN architectures from KenKen but may need to retrain for Sudoku-specific data.

In [None]:
class Grid_CNN(nn.Module):
    """Two-convolution classifier meant to tell 4x4 grids from 9x9 grids.

    The first linear layer expects 262144 flattened features, i.e.
    64 channels x 64 x 64, which is what a single-channel 128x128 input
    produces after the single 2x2 max-pool applied in ``forward``.

    Args:
        output_dim: Number of output logits (one per grid-size class).
    """

    def __init__(self, output_dim):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys) and of the random-init sequence.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(in_features=262144, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=output_dim)

    def forward(self, x):
        """Return raw class logits of shape (batch, output_dim)."""
        hidden = F.relu(self.conv1(x))
        hidden = F.relu(self.conv2(hidden))
        # Only one pooling step: spatial size is halved once.
        hidden = self.dropout(self.pool(hidden))
        flat = hidden.view(hidden.size(0), -1)
        return self.fc2(F.relu(self.fc1(flat)))

In [None]:
class CNN_v2(nn.Module):
    """Character-recognition CNN for single-channel cell images.

    The first linear layer expects 3136 flattened features, i.e.
    64 channels x 7 x 7, which is what a 28x28 input produces after the
    two 2x2 max-pool steps in ``forward``.

    Args:
        output_dim: Number of output logits (one per character class).
    """

    def __init__(self, output_dim):
        super().__init__()
        # Attribute names and creation order are part of the checkpoint
        # format (state_dict keys) and of the random-init sequence.
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.dropout = nn.Dropout(p=0.25)
        self.fc1 = nn.Linear(in_features=3136, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=output_dim)

    def forward(self, x):
        """Return raw class logits of shape (batch, output_dim)."""
        # Two conv -> ReLU -> pool stages.
        for conv in (self.conv1, self.conv2):
            x = self.pool(F.relu(conv(x)))
        flat = x.view(x.size(0), -1)
        # Dropout sits between the two fully connected layers.
        hidden = self.dropout(F.relu(self.fc1(flat)))
        return self.fc2(hidden)

In [None]:
# Image preprocessing chain: single-channel grayscale, resize, center-crop to
# 128x128, then conversion to a tensor.
# NOTE(review): `transform` is not used by any cell visible in this notebook.
# 128x128 matches the input size implied by Grid_CNN.fc1 (262144 = 64*64*64),
# so it is presumably intended for the grid-size model - confirm.
transform = transforms.Compose([
    transforms.Grayscale(num_output_channels=1),
    transforms.Resize(128),
    transforms.CenterCrop(128),
    transforms.ToTensor(),
])

## Load Pre-trained Models

We'll try to use the KenKen models initially. The character model recognizes 14 classes (0-9 plus operators), but we only need 0-9 for Sudoku.

In [None]:
# Try to load the KenKen character model (14 classes; only the digit classes are used here).
kenken_model_path = '../KenKen/models/character_recognition_v2_model_weights.pth'

# Bound up front so the name exists even when no checkpoint is found;
# later cells should still gate on USE_KENKEN_MODEL.
character_model = None

if os.path.exists(kenken_model_path):
    character_model = CNN_v2(output_dim=14)  # KenKen model has 14 classes
    # SECURITY NOTE(review): weights_only=False lets torch.load unpickle arbitrary
    # objects, which can execute code from the checkpoint file. Only load
    # checkpoints from a trusted source; if this file is a plain state_dict,
    # consider switching to weights_only=True.
    # map_location='cpu' keeps loading working on CPU-only machines even if the
    # checkpoint was saved from a GPU; inference in this notebook runs on CPU tensors.
    state_dict = torch.load(kenken_model_path, map_location='cpu', weights_only=False)
    character_model.load_state_dict(state_dict)
    character_model.eval()  # disable dropout for inference
    print("Loaded KenKen character model (using classes 0-9 only)")
    USE_KENKEN_MODEL = True
else:
    print("KenKen model not found. Need to train Sudoku-specific model.")
    USE_KENKEN_MODEL = False

In [None]:
# For grid detection, we can use heuristics based on line detection
# since Sudoku only has 2 sizes (4x4 or 9x9)

def detect_grid_size(filename):
    """Guess whether a board image is a 4x4 or a 9x9 grid by counting horizontal lines.

    The image is edge-detected (Canny), line segments are extracted with a
    probabilistic Hough transform, near-horizontal segments are kept, and
    segments whose y positions lie within 20 pixels of each other are merged.
    A 4x4 grid has 5 horizontal lines and a 9x9 grid has 10, so up to 6
    distinct lines is read as 4x4 and anything above as 9x9.

    Args:
        filename: Path to the puzzle image.

    Returns:
        4 or 9. Falls back to 9 whenever no usable horizontal line is found
        (no Hough segments at all, or none that are near-horizontal).

    Raises:
        ValueError: If the image cannot be read.
    """
    fallback_size = 9  # used whenever line detection yields nothing usable

    img = cv.imread(filename, cv.IMREAD_GRAYSCALE)
    if img is None:
        raise ValueError(f"Could not load image: {filename}")

    # Use Canny edge detection
    edges = cv.Canny(img, 50, 150)

    # Find line segments of any orientation; horizontals are filtered below
    lines = cv.HoughLinesP(edges, 1, np.pi/180, 100, minLineLength=100, maxLineGap=10)

    if lines is None:
        return fallback_size

    # Keep the mid-height of every near-horizontal segment (sorted, de-duplicated)
    h_lines = sorted({
        (y1 + y2) // 2
        for _x1, y1, _x2, y2 in (line[0] for line in lines)
        if abs(y2 - y1) < 10
    })

    if not h_lines:
        # Segments were found but none were horizontal: detection failed, so
        # use the same default as the "no segments" case instead of reporting 4.
        return fallback_size

    # Merge detections that belong to the same grid line (within 20 pixels)
    unique_lines = [h_lines[0]]
    for y in h_lines[1:]:
        if y - unique_lines[-1] > 20:
            unique_lines.append(y)

    # 4x4 has 5 horizontal lines, 9x9 has 10 horizontal lines
    return 4 if len(unique_lines) <= 6 else 9

## Cell Extraction and Digit Recognition

In [None]:
def extract_cell(img_array, size, row, col, border=10):
    """
    Cut one cell out of a board image.

    The image is divided into ``size`` x ``size`` equal cells (integer
    division, so leftover pixels at the right/bottom edge are ignored) and
    the requested cell is shrunk by ``border`` pixels on every side.

    Args:
        img_array: Grayscale image as a 2-D numpy array
        size: Grid size (4 or 9)
        row, col: Zero-based cell position
        border: Pixels to trim from each cell edge

    Returns:
        The cell region as a numpy array (a slice of ``img_array``)
    """
    img_h, img_w = img_array.shape
    step_y, step_x = img_h // size, img_w // size

    # Inset the cell's bounding box by `border` on all four sides.
    rows = slice(row * step_y + border, (row + 1) * step_y - border)
    cols = slice(col * step_x + border, (col + 1) * step_x - border)

    return img_array[rows, cols]

In [None]:
def is_cell_empty(cell_img, threshold=0.98):
    """
    Decide whether a cell contains no digit, i.e. is almost entirely white.

    Args:
        cell_img: Grayscale cell image (0-255)
        threshold: Fraction of bright pixels (value > 200) that must be
                   exceeded for the cell to count as empty
                   (0.98 works for both 4x4 and 9x9 - filled cells max at ~0.96)

    Returns:
        True if the bright-pixel fraction is strictly greater than threshold
    """
    bright_mask = cell_img > 200
    bright_fraction = bright_mask.sum() / cell_img.size
    return bright_fraction > threshold

In [None]:
def preprocess_cell(cell_img):
    """
    Turn a raw cell crop into the array the character CNN consumes.

    Args:
        cell_img: Grayscale cell image

    Returns:
        IMG_SIZE x IMG_SIZE float32 array with values scaled to 0-1
    """
    target_shape = (IMG_SIZE, IMG_SIZE)

    # Lanczos resampling when shrinking the crop to the network's input size.
    shrunk = Image.fromarray(cell_img).resize(target_shape, Image.LANCZOS)

    # 8-bit pixel values -> floats in [0, 1]
    return np.asarray(shrunk).astype(np.float32) / 255.0

In [None]:
def recognize_digit(cell_img, model):
    """
    Recognize the digit in a cell using the character CNN.

    Only the first ten logits (classes 0-9) take part in the decision. The
    KenKen model also has operator classes; letting one of those win would
    make a non-empty Sudoku cell look empty and silently drop a clue, so
    the best-scoring digit class is returned instead.

    Args:
        cell_img: Preprocessed 28x28 cell image (2-D array, values 0-1)
        model: CNN model; called with a (1, 1, H, W) float32 tensor and
               expected to return logits of shape (1, num_classes)

    Returns:
        Predicted digit class as an int in 0-9 (callers treat 0 as "no digit")
    """
    num_digit_classes = 10  # classes 0-9 are digits; any further classes are operators

    with torch.no_grad():
        # Add batch and channel dimensions: (H, W) -> (1, 1, H, W)
        batch = torch.tensor(cell_img, dtype=torch.float32).unsqueeze(0).unsqueeze(0)
        logits = model(batch)
        # Slicing is a no-op for a 10-class model and masks operators for a 14-class one.
        digit_logits = logits[:, :num_digit_classes]
        return torch.argmax(digit_logits, dim=1).item()

## Z3 Sudoku Solver

In [None]:
def solve_sudoku(size, given_cells):
    """
    Solve Sudoku using Z3 constraint solver.

    The box edge is derived from ``size`` (2 for 4x4, 3 for 9x9, 4 for
    16x16, ...) rather than hard-coded, so every box of the grid is
    constrained for any perfect-square size.

    Args:
        size: Grid size; must be a perfect square (e.g. 4 or 9)
        given_cells: dict {(row, col): digit} for given clues

    Returns:
        Solution grid (list of rows of ints) or None if unsolvable

    Raises:
        ValueError: If size is not a positive perfect square, since the
            box structure is undefined in that case.
    """
    box_size = int(round(size ** 0.5)) if size >= 1 else 0
    if box_size * box_size != size:
        raise ValueError(f"Sudoku size must be a positive perfect square, got {size}")

    # Create integer variables for each cell
    X = [[Int(f"x_{i}_{j}") for j in range(size)] for i in range(size)]

    s = Solver()

    # Constraint 1: Each cell contains 1 to size
    for row_vars in X:
        for var in row_vars:
            s.add(And(var >= 1, var <= size))

    # Constraint 2: Given cells
    for (i, j), val in given_cells.items():
        s.add(X[i][j] == val)

    # Constraint 3: Row uniqueness
    for i in range(size):
        s.add(Distinct([X[i][j] for j in range(size)]))

    # Constraint 4: Column uniqueness
    for j in range(size):
        s.add(Distinct([X[i][j] for i in range(size)]))

    # Constraint 5: Box uniqueness (top-left corner of each box steps by box_size)
    for top in range(0, size, box_size):
        for left in range(0, size, box_size):
            s.add(Distinct([
                X[top + di][left + dj]
                for di in range(box_size)
                for dj in range(box_size)
            ]))

    # Solve
    if s.check() != sat:
        print("No solution found")
        return None

    m = s.model()
    # model_completion=True guarantees a concrete value for every cell variable.
    return [
        [m.evaluate(X[i][j], model_completion=True).as_long() for j in range(size)]
        for i in range(size)
    ]

## Full Pipeline

In [None]:
def extract_puzzle_from_image(filename, model):
    """
    Read a board image and recover the grid size and the given clues.

    Args:
        filename: Path to puzzle image
        model: Character recognition CNN

    Returns:
        (size, given_cells) tuple, where given_cells maps (row, col) to the
        recognized digit for every non-empty cell with a digit above 0

    Raises:
        ValueError: If the image cannot be read
    """
    board = cv.imread(filename, cv.IMREAD_GRAYSCALE)
    if board is None:
        raise ValueError(f"Could not load image: {filename}")

    grid_size = detect_grid_size(filename)

    clues = {}
    positions = ((r, c) for r in range(grid_size) for c in range(grid_size))
    for r, c in positions:
        patch = extract_cell(board, grid_size, r, c)

        # Blank cells are skipped without running the CNN.
        if is_cell_empty(patch):
            continue

        value = recognize_digit(preprocess_cell(patch), model)
        if value > 0:  # 0 means "no digit recognized"
            clues[(r, c)] = value

    return grid_size, clues

In [None]:
def solve_from_image(filename, model):
    """
    Run the whole image pipeline: read the board, recognize clues, solve.

    Prints the detected grid size and the number of recognized clues
    before handing the puzzle to the Z3 solver.

    Args:
        filename: Path to puzzle image
        model: Character recognition CNN

    Returns:
        Solution grid, or None when the solver finds no solution
    """
    extracted = extract_puzzle_from_image(filename, model)
    size, given_cells = extracted
    print(f"Detected {size}x{size} puzzle with {len(given_cells)} given cells")

    return solve_sudoku(size, given_cells)

## Alternative: Direct Solving from JSON

Since we generate puzzles programmatically, we can also solve directly from JSON without image processing.

In [None]:
def solve_from_json(puzzle):
    """
    Solve a Sudoku given as a grid of numbers.

    Args:
        puzzle: 2D list (size x size) where 0 = empty cell

    Returns:
        Solution grid or None
    """
    n = len(puzzle)

    # Every non-zero entry becomes a fixed clue keyed by its (row, col).
    clues = {
        (r, c): puzzle[r][c]
        for r in range(n)
        for c in range(n)
        if puzzle[r][c] != 0
    }

    return solve_sudoku(n, clues)

## Demo: Solve Single Puzzle

In [None]:
# Read the generated puzzle dataset; fall back to an empty dict when it is missing
puzzle_path = './puzzles/puzzles_dict.json'

if not os.path.exists(puzzle_path):
    print("No puzzle file found. Run SymbolicPuzzleGenerator first.")
    puzzles_ds = {}
else:
    with open(puzzle_path, 'r') as f:
        puzzles_ds = json.load(f)
    print(f"Loaded puzzles: {list(puzzles_ds.keys())}")

In [None]:
# Demo: Solve the first 4x4 puzzle from JSON and compare with the stored solution
if '4' in puzzles_ds and len(puzzles_ds['4']) > 0:
    demo_puzzle = puzzles_ds['4'][0]['puzzle']
    demo_solution_expected = puzzles_ds['4'][0]['solution']

    print("Puzzle:")
    for row in demo_puzzle:
        print(row)

    solution = solve_from_json(demo_puzzle)

    print("\nSolution:")
    if solution is None:
        # The solver returns None for an unsatisfiable puzzle; iterating it would raise.
        print("(no solution)")
    else:
        for row in solution:
            print(row)

    print("\nExpected:")
    for row in demo_solution_expected:
        print(row)

    print(f"\nMatch: {solution == demo_solution_expected}")

## Evaluate on Full Dataset

In [None]:
def evaluate_solver(puzzles_ds, use_images=False, model=None):
    """
    Evaluate solver on puzzle dataset.

    Each puzzle is solved either from its board image or from its JSON
    grid, and the result is compared with the stored solution. Errors on
    individual puzzles are printed and counted as incorrect so that one bad
    puzzle does not abort the whole run.

    Args:
        puzzles_ds: Dictionary mapping size (as a string, e.g. '4') to a
            list of dicts with 'puzzle' and 'solution' grids
        use_images: Whether to solve from images (requires model)
        model: Character recognition model (if use_images=True). When
            use_images is True but model is None, the JSON path is used.

    Returns:
        Dictionary keyed by integer size with 'accuracy', 'avg_time'
        (seconds), 'correct' and 'total'
    """
    results = {}

    # Decide once: image mode only applies when a model was actually supplied.
    solve_images = use_images and model is not None

    for size_str, puzzles in puzzles_ds.items():
        size = int(size_str)

        correct = 0
        total_time = 0

        print(f"\nEvaluating {size}x{size} puzzles ({len(puzzles)} total)...")

        for i, puzzle_data in enumerate(puzzles):
            # perf_counter is monotonic and high-resolution, unlike wall-clock time.time().
            start = time.perf_counter()

            try:
                if solve_images:
                    filename = f'./board_images/board{size}_{i}.png'
                    solution = solve_from_image(filename, model)
                else:
                    solution = solve_from_json(puzzle_data['puzzle'])

                if solution == puzzle_data['solution']:
                    correct += 1

            except Exception as e:
                # Best-effort evaluation: report and treat the puzzle as incorrect.
                print(f"  Error on puzzle {i}: {e}")

            total_time += time.perf_counter() - start

            if (i + 1) % 25 == 0:
                print(f"  Progress: {i+1}/{len(puzzles)}, Accuracy: {correct}/{i+1}")

        # Guard against an empty puzzle list for this size.
        accuracy = correct / len(puzzles) if puzzles else 0
        avg_time = total_time / len(puzzles) if puzzles else 0

        results[size] = {
            'accuracy': accuracy,
            'avg_time': avg_time,
            'correct': correct,
            'total': len(puzzles)
        }

        print(f"  Final: {correct}/{len(puzzles)} ({accuracy*100:.1f}%), Avg time: {avg_time:.3f}s")

    return results

In [None]:
# Evaluate using JSON (no image processing)
# `results` is only bound when puzzles were loaded; the cell that saves the
# results is guarded by the same `if puzzles_ds` check.
if puzzles_ds:
    results = evaluate_solver(puzzles_ds, use_images=False)

## Save Results

In [None]:
if puzzles_ds:
    # One summary row per grid size (accuracy and mean solve time only)
    summary_rows = []
    for grid_size, metrics in results.items():
        summary_rows.append({
            'size': grid_size,
            'accuracy': metrics['accuracy'],
            'avg_time': metrics['avg_time'],
        })
    results_df = pd.DataFrame(summary_rows)

    # Write the summary, creating the output folder on first use
    os.makedirs('./results', exist_ok=True)
    results_df.to_csv('./results/neurosymbolic_solver.csv', index=False)

    print("\nResults saved to ./results/neurosymbolic_solver.csv")
    print(results_df)

## Visualization

In [None]:
def visualize_solution(puzzle, solution, size):
    """
    Draw the puzzle and its solution next to each other.

    Given clues are drawn in black in both panels; in the solution panel,
    cells that are empty (0) in ``puzzle`` are drawn in blue.

    Args:
        puzzle: 2D grid where 0 = empty cell
        solution: 2D grid of the solved puzzle
        size: Grid size (4 or 9)
    """
    fig, axes = plt.subplots(1, 2, figsize=(12, 6))

    # Box edge decides which grid lines are drawn thick.
    box_size = 2 if size == 4 else 3

    panels = ((axes[0], puzzle, 'Puzzle'), (axes[1], solution, 'Solution'))
    for ax, grid, title in panels:
        ax.set_xlim(0, size)
        ax.set_ylim(0, size)
        ax.set_aspect('equal')
        ax.set_title(title, fontsize=14)
        ax.axis('off')

        # Grid lines: thick on box boundaries, thin elsewhere
        for k in range(size + 1):
            lw = 0.5 if k % box_size else 2
            ax.axhline(y=k, color='black', linewidth=lw)
            ax.axvline(x=k, color='black', linewidth=lw)

        # Cell values; row 0 is drawn at the top, hence the flipped y coordinate
        for i in range(size):
            for j in range(size):
                val = grid[i][j]
                if val == 0:
                    continue
                filled_by_solver = title != 'Puzzle' and puzzle[i][j] == 0
                color = 'blue' if filled_by_solver else 'black'
                ax.text(j + 0.5, size - i - 0.5, str(val),
                        ha='center', va='center', fontsize=16, color=color)

    plt.tight_layout()
    plt.show()

In [None]:
# Visualize a solution
# Shows the first 9x9 puzzle together with its stored reference solution from
# the dataset (not a freshly solved grid). Note that this rebinds the
# notebook-level names `puzzle` and `solution`.
if puzzles_ds and '9' in puzzles_ds and len(puzzles_ds['9']) > 0:
    puzzle = puzzles_ds['9'][0]['puzzle']
    solution = puzzles_ds['9'][0]['solution']
    visualize_solution(puzzle, solution, 9)

## Detailed Image-Based Evaluation

Evaluate the neurosymbolic solver on all board images with detailed metrics tracking.

In [None]:
def count_clues_in_puzzle(puzzle):
    """Return how many cells of the ground-truth puzzle grid are filled (non-zero)."""
    clue_count = 0
    for cells in puzzle:
        for value in cells:
            if value != 0:
                clue_count += 1
    return clue_count

def _evaluate_single_image(filename, size, puzzle_data, model):
    """Run the image pipeline on one board image and return its result record."""
    result = {
        'size': size,
        'puzzle_index': None,  # filled in by the caller
        'filename': filename,
        'size_detected': None,
        'size_correct': False,
        'num_clues_expected': count_clues_in_puzzle(puzzle_data['puzzle']),
        'num_clues_detected': 0,
        'clues_match': False,
        'solved': False,
        'solution_correct': False,
        'solve_time_ms': 0.0,
        'error_type': 'none',
        'error_message': ''
    }

    start_time = time.perf_counter()

    try:
        # Step 1: Load image and detect grid size
        img = cv.imread(filename, cv.IMREAD_GRAYSCALE)
        if img is None:
            result['error_type'] = 'image_load_error'
            result['error_message'] = f'Could not load image: {filename}'
            return result

        detected_size = detect_grid_size(filename)
        result['size_detected'] = detected_size
        result['size_correct'] = (detected_size == size)

        if not result['size_correct']:
            result['error_type'] = 'size_detection'
            result['error_message'] = f'Detected {detected_size}x{detected_size} instead of {size}x{size}'
            return result

        # Step 2: Extract cells and recognize digits
        given_cells = {}
        for row in range(size):
            for col in range(size):
                cell = extract_cell(img, size, row, col)
                if not is_cell_empty(cell):
                    processed = preprocess_cell(cell)
                    digit = recognize_digit(processed, model)
                    if digit > 0:
                        given_cells[(row, col)] = digit

        # clues_match compares clue COUNTS only, not positions or values.
        result['num_clues_detected'] = len(given_cells)
        result['clues_match'] = (result['num_clues_detected'] == result['num_clues_expected'])

        # Step 3: Solve with Z3
        solution = solve_sudoku(size, given_cells)

        if solution is None:
            result['error_type'] = 'unsolvable'
            result['error_message'] = 'Z3 could not find a valid solution'
            return result

        result['solved'] = True

        # Step 4: Validate solution
        expected = puzzle_data['solution']
        result['solution_correct'] = (solution == expected)

        if not result['solution_correct']:
            if not result['clues_match']:
                result['error_type'] = 'digit_recognition'
                result['error_message'] = f'Detected {result["num_clues_detected"]} clues, expected {result["num_clues_expected"]}'
            else:
                result['error_type'] = 'wrong_solution'
                result['error_message'] = 'Solution does not match expected'

    except Exception as e:
        # Anything reaching here is NOT a load failure (that case is handled
        # explicitly above), so it gets its own label and the exception type.
        result['error_type'] = 'unexpected_error'
        result['error_message'] = f'{type(e).__name__}: {e}'
    finally:
        # Runs on every exit path, so early returns are timed as well.
        result['solve_time_ms'] = (time.perf_counter() - start_time) * 1000

    return result


def evaluate_from_images(puzzles_ds, model):
    """
    Evaluate solver on all board images with detailed metrics.

    For every puzzle the image './board_images/board{size}_{i}.png' is
    processed and one result record is produced, whatever the outcome.
    Progress is printed every 25 puzzles and a summary after each size.

    Args:
        puzzles_ds: Dictionary mapping size (as a string) to a list of dicts
            with 'puzzle' and 'solution' grids
        model: Character recognition CNN model

    Returns:
        List of result dictionaries (one per puzzle). 'error_type' is one of
        'none', 'image_load_error', 'size_detection', 'unsolvable',
        'digit_recognition', 'wrong_solution' or 'unexpected_error'.
    """
    results = []

    for size_str, puzzles in puzzles_ds.items():
        size = int(size_str)

        print(f"\nEvaluating {size}x{size} puzzles ({len(puzzles)} total)...")

        size_results = []
        correct = 0  # running count, avoids rescanning all results for progress

        for i, puzzle_data in enumerate(puzzles):
            filename = f'./board_images/board{size}_{i}.png'
            result = _evaluate_single_image(filename, size, puzzle_data, model)
            result['puzzle_index'] = i
            size_results.append(result)

            if result['solution_correct']:
                correct += 1

            if (i + 1) % 25 == 0:
                print(f"  Progress: {i+1}/{len(puzzles)}, Correct: {correct}/{i+1}")

        results.extend(size_results)

        # Print size summary (an empty puzzle list would otherwise divide by zero)
        if size_results:
            avg_time = sum(r['solve_time_ms'] for r in size_results) / len(size_results)
            print(f"  Final: {correct}/{len(size_results)} ({100*correct/len(size_results):.1f}%), Avg time: {avg_time:.1f}ms")
        else:
            print("  Final: 0/0 (no puzzles evaluated)")

    return results

In [None]:
# Run detailed evaluation on all board images
if puzzles_ds and USE_KENKEN_MODEL:
    detailed_results = evaluate_from_images(puzzles_ds, character_model)
else:
    print("Cannot run image evaluation: missing puzzles or model")
    detailed_results = []

In [None]:
# Save detailed results to CSV
if detailed_results:
    detailed_df = pd.DataFrame(detailed_results)
    
    # Ensure results directory exists
    os.makedirs('./results', exist_ok=True)
    
    # Save detailed CSV
    detailed_df.to_csv('./results/detailed_evaluation.csv', index=False)
    print("Saved detailed results to ./results/detailed_evaluation.csv")
    
    # Print summary statistics
    print("\n" + "="*60)
    print("SUMMARY STATISTICS")
    print("="*60)
    
    for size in sorted(detailed_df['size'].unique()):
        size_df = detailed_df[detailed_df['size'] == size]
        total = len(size_df)
        correct = size_df['solution_correct'].sum()
        accuracy = 100 * correct / total
        avg_time = size_df['solve_time_ms'].mean()
        
        print(f"\n{size}x{size} Puzzles:")
        print(f"  Accuracy: {correct}/{total} ({accuracy:.1f}%)")
        print(f"  Avg Time: {avg_time:.1f}ms")
        print(f"  Size Detection: {size_df['size_correct'].sum()}/{total} correct")
        print(f"  Clue Match: {size_df['clues_match'].sum()}/{total}")
        
        # Error breakdown
        errors = size_df[size_df['error_type'] != 'none']['error_type'].value_counts()
        if len(errors) > 0:
            print(f"  Errors:")
            for error_type, count in errors.items():
                print(f"    - {error_type}: {count}")
    
    print("\n" + "="*60)
    print(f"OVERALL: {detailed_df['solution_correct'].sum()}/{len(detailed_df)} " +
          f"({100*detailed_df['solution_correct'].mean():.1f}%) correct")
    print("="*60)