# Solving with priority
In each call, search for the empty cell with the fewest legal values. Still only uses the original cache.

In [1]:
#imports
import time
import copy
import json
import csv

In [2]:
#Helper functions

#to check if a proposed number is valid for the sudoku solution
def is_valid(sudoku, row, col, num):
    """Return True if `num` is absent from the row, column and 3x3 block of (row, col).

    Args:
        sudoku: 9x9 grid indexable as ``sudoku[r][c]``.
        row: Row index of the cell being tested.
        col: Column index of the cell being tested.
        num: Value proposed for the cell.

    Returns:
        False as soon as `num` is found in the cell's row, column or block
        (checked in that order); True otherwise.
    """
    # Same row
    if any(sudoku[row][k] == num for k in range(9)):
        return False

    # Same column
    if any(sudoku[k][col] == num for k in range(9)):
        return False

    # Same 3x3 block: (top, left) is the block's upper-left corner
    top = row - row % 3
    left = col - col % 3
    clash_in_block = any(
        sudoku[r][c] == num
        for r in range(top, top + 3)
        for c in range(left, left + 3)
    )
    return not clash_in_block

def find_empty_location_priority(sudoku, cache):
    """Pick the empty cell that has the fewest cached candidates.

    Args:
        sudoku: 9x9 grid; a value of 0 marks an empty cell.
        cache: Mapping of ``(row, col)`` to that cell's list of candidates.

    Returns:
        The key of the first cache entry (in dict order) with the strictly
        smallest candidate list among cells that are still 0 in `sudoku`,
        or None when no entry qualifies. Lists of length 10 or more are
        never selected.
    """
    best_cell = None
    fewest = 10
    for cell, candidates in cache.items():
        size = len(candidates)
        if size >= fewest:
            continue
        if sudoku[cell[0]][cell[1]] != 0:
            continue
        best_cell, fewest = cell, size
    return best_cell

def create_cache(sudoku):
    """Build the initial candidate cache for every empty cell.

    Args:
        sudoku: 9x9 grid; a value of 0 marks an empty cell.

    Returns:
        Dict mapping ``(row, col)`` of each empty cell (row-major order) to
        the ascending list of values 1-9 that `is_valid` accepts there.
    """
    return {
        (r, c): [value for value in range(1, 10) if is_valid(sudoku, r, c, value)]
        for r in range(9)
        for c in range(9)
        if sudoku[r][c] == 0
    }

# The 27 Sudoku units as lists of (row, col) cells:
# nine 3x3 blocks, then nine columns, then nine rows.
_UNITS = (
    [
        [(3 * (b // 3) + i, 3 * (b % 3) + j) for i in range(3) for j in range(3)]
        for b in range(9)
    ]
    + [[(r, c) for r in range(9)] for c in range(9)]
    + [[(r, c) for c in range(9)] for r in range(9)]
)


def _drop_candidates(cache, cell, doomed, changes):
    """Remove each value in `doomed` from `cache[cell]`, logging removals in `changes`.

    `doomed` must be a separate list (never `cache[cell]` itself) so the
    candidate list is not mutated while it is being iterated.
    """
    candidates = cache[cell]
    for value in doomed:
        if value in candidates:
            candidates.remove(value)
            changes.append((cell, value))


def update_cache(cache, row, col, num):
    """Propagate the placement of `num` at (row, col) through the candidate cache.

    Three pruning passes are applied, all of which only ever remove
    candidates:

    1. Direct elimination: `num` is removed from every cached cell sharing a
       row, column or block with (row, col).
    2. Hidden singles for `num`: in any unit where exactly one cached cell
       still lists `num`, that cell is reduced to `num` alone.
    3. Naked pairs: when two cells of a unit hold the same two candidates
       (in any order), those two values are removed from the unit's other
       cells.

    Args:
        cache: Dict mapping ``(row, col)`` to a list of candidates; mutated
            in place.
        row: Row of the cell that was just filled.
        col: Column of the cell that was just filled.
        num: Value that was just placed.

    Returns:
        List of ``(cell, value)`` tuples, one per removed candidate, in the
        form `revert_cache` consumes to undo this call.
    """
    changes = []

    # 1. Direct elimination from the peers of the filled cell.
    for cell in cache:
        if cell[0] == row and cell[1] == col:
            continue
        shares_unit = (
            cell[0] == row
            or cell[1] == col
            or (cell[0] // 3 == row // 3 and cell[1] // 3 == col // 3)
        )
        if shares_unit:
            _drop_candidates(cache, cell, [num], changes)

    # 2. Hidden singles for `num` across all 27 units (every block, not only
    #    the diagonal ones).
    for unit in _UNITS:
        holders = [cell for cell in unit if cell in cache and num in cache[cell]]
        if len(holders) == 1:
            only = holders[0]
            others = [value for value in cache[only] if value != num]
            _drop_candidates(cache, only, others, changes)

    # 3. Naked pairs. Candidate lists are compared order-insensitively
    #    because revert_cache re-appends values and so reorders the lists.
    for unit in _UNITS:
        open_cells = [cell for cell in unit if cell in cache]
        for idx, first in enumerate(open_cells):
            if len(cache[first]) != 2:
                continue
            for second in open_cells[idx + 1:]:
                if sorted(cache[second]) != sorted(cache[first]):
                    continue
                pair = list(cache[first])
                for other in open_cells:
                    if other == first or other == second:
                        continue
                    doomed = [value for value in cache[other] if value in pair]
                    _drop_candidates(cache, other, doomed, changes)

    return changes


def revert_cache(cache, changes):
    """Undo candidate removals by putting each recorded value back.

    Args:
        cache: Dict mapping ``(row, col)`` to a list of candidates; mutated
            in place.
        changes: Iterable of ``(cell, value)`` pairs. Each value is appended
            to the cell's list, creating the list if the cell is missing.
    """
    for cell, value in changes:
        cache.setdefault(cell, []).append(value)
def solve_sudoku(sudoku, cache):
    """Solve `sudoku` in place by backtracking over the most constrained cell.

    The number of calls is tracked in the function attribute
    ``solve_sudoku.counter``; callers reset it to 0 before each puzzle. Once
    it exceeds 500000 the search gives up and returns False.

    Args:
        sudoku: 9x9 grid with 0 for empty cells; filled in place on success.
        cache: Candidate cache as built by `create_cache`; mutated during the
            search and restored on backtracking.

    Returns:
        True if the grid was completely filled, False if this branch has no
        solution or the call budget was exhausted.
    """
    # Default to 0 so a caller that never seeded the attribute does not hit
    # an AttributeError on the very first call.
    solve_sudoku.counter = getattr(solve_sudoku, "counter", 0) + 1
    if solve_sudoku.counter > 500000:
        return False

    empty_location = find_empty_location_priority(sudoku, cache)
    if empty_location is None:
        return True  # No empty cell left, puzzle solved
    row, col = empty_location

    # Only check values that are legally allowed. The cell's entry is popped
    # from the cache while it is filled, so this list is not touched by
    # update_cache during the recursion.
    values = cache[(row, col)]
    for num in values:
        sudoku[row][col] = num
        cache.pop((row, col))
        changes = update_cache(cache, row, col, num)
        if solve_sudoku(sudoku, cache):
            return True
        # Dead end: undo the placement and every pruning it caused.
        sudoku[row][col] = 0
        cache[(row, col)] = values
        revert_cache(cache, changes)
    return False


def print_sudoku(sudoku):
    """Print the grid to stdout, one row per line, using each row's default repr."""
    for grid_row in sudoku:
        print(grid_row)

def time_solve(sudoku):
    """Solve a copy of `sudoku`, print the outcome and the elapsed time.

    Args:
        sudoku: 9x9 grid with 0 for empty cells; left unmodified because the
            solver works on a deep copy.
    """
    # solve_sudoku keeps its call budget in a function attribute; reset it so
    # every puzzle starts with a fresh budget (and so the attribute exists).
    solve_sudoku.counter = 0
    result = copy.deepcopy(sudoku)
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    start_time = time.perf_counter()
    cache = create_cache(result)
    if solve_sudoku(result, cache):
        print("Sudoku solved successfully!")
        print_sudoku(result)
    else:
        # NOTE: solve_sudoku also returns False when its call budget runs out.
        print("No solution exists.")
    print(f"Time taken: {time.perf_counter() - start_time:.7f} seconds")

In [3]:
def time_solve(sudoku, solution):
    """Solve a copy of `sudoku` and report correctness, duration and call count.

    Args:
        sudoku: 9x9 grid with 0 for empty cells; left unmodified because the
            solver works on a deep copy.
        solution: Expected solved grid, compared with ``==`` to the result.

    Returns:
        Tuple ``(correct, seconds, calls)``: whether the solver succeeded and
        matched `solution`, the time spent building the cache and solving,
        and the number of `solve_sudoku` calls made.
    """
    solve_sudoku.counter = 0  # fresh call budget for this puzzle
    result = copy.deepcopy(sudoku)
    # perf_counter is monotonic and high-resolution, unlike the wall clock.
    start_time = time.perf_counter()
    cache = create_cache(result)
    solved_ok = solve_sudoku(result, cache)
    # Stop the clock before verification so only solving is measured.
    elapsed = time.perf_counter() - start_time
    correct = bool(solved_ok and result == solution)
    return correct, elapsed, solve_sudoku.counter

def string_to_2d_array(sudoku_string):
    """Convert a string of digits into a list of 9-element integer rows.

    Non-digit characters (whitespace, separators) are ignored. Rows are cut
    every nine *digits*, so ignored characters cannot shift the row
    boundaries. A trailing group of fewer than nine digits is discarded.

    Args:
        sudoku_string: Text containing the puzzle digits in row-major order.

    Returns:
        List of rows, each a list of nine ints.
    """
    digits = [int(char) for char in sudoku_string if char.isdigit()]
    full_length = len(digits) - len(digits) % 9
    return [digits[start:start + 9] for start in range(0, full_length, 9)]

# Import from csv: column 0 holds the puzzle, column 1 its solution
raw = []
solved = []
with open('sudoku.csv', newline='') as csvfile:
    reader = csv.reader(csvfile)
    for row in reader:
        # csv.reader yields short/empty lists for blank lines; nothing to parse
        if len(row) < 2:
            continue

        raw_sudoku = string_to_2d_array(row[0])
        solved_sudoku = string_to_2d_array(row[1])

        # A header line or malformed record does not parse into nine rows
        # and would crash the solver, so it is skipped.
        if len(raw_sudoku) != 9 or len(solved_sudoku) != 9:
            continue

        raw.append(raw_sudoku)
        solved.append(solved_sudoku)

# Solve every puzzle; statistics are collected for correct solves only
times = []
call_counts = []
num_failed = 0
for i, (puzzle, expected) in enumerate(zip(raw, solved)):
    print(f"Solving sudoku {i}", end="\r")
    correct, this_time, call_count = time_solve(puzzle, expected)
    if not correct:
        num_failed += 1
    else:
        times.append(this_time)
        call_counts.append(call_count)

# Output the data to a CSV file
if times:
    stats = [
        sum(times) / len(times),
        sum(call_counts) / len(call_counts),
        min(times),
        max(times),
        min(call_counts),
        max(call_counts),
    ]
else:
    # No puzzle was solved: averages and extrema are undefined, so leave the
    # columns blank instead of dividing by zero.
    stats = [""] * 6
with open('sudoku_stats.csv', 'a', newline='') as csvfile:
    writer = csv.writer(csvfile)
    writer.writerow(["Naked pair", *stats, len(times), num_failed])

Solving sudoku 299