In [10]:
import numpy as np
import random 

In [11]:
# Load all Sudoku rule clauses from the rules file (assumed DIMACS CNF: one clause
# per line as signed integer literals terminated by "0", plus a "p cnf ..." header).
# Blank lines, the "p" header and "c" comment lines carry no clause and are skipped.
# A blank line would otherwise become an empty, never-satisfiable clause, and a
# misplaced header/comment would crash the int conversion.
with open('sudoku-rules.txt', 'r') as rules_file:
    clauses = []
    for line in rules_file:
        tokens = line.split()
        if not tokens or tokens[0] in ('p', 'c'):
            continue
        if tokens[-1] == '0':
            # Drop only a genuine clause terminator, never a real literal.
            tokens = tokens[:-1]
        clauses.append(np.array(tokens, dtype=int))

# Every variable encodes "cell (row, col) holds digit" as the three-digit number
# row*100 + col*10 + digit with row, col, digit in 1..9 (the same encoding that
# visualize_solution decodes). Each is wrapped in a list, like a unit clause.
literals = [[100 * row + 10 * col + digit]
            for row in range(1, 10)
            for col in range(1, 10)
            for digit in range(1, 10)]

# The puzzle's givens in the same row/col/digit encoding, each as a unit clause.
start_state = [[given] for given in (168, 175, 225, 231, 318, 419, 444, 465, 493,
                                    689, 692, 727, 732, 828, 886, 956, 961, 973)]

In [12]:
def propagate(f, unit_clause):
    """Simplify a CNF formula under the assumption that `unit_clause` is true.

    Parameters
    ----------
    f : iterable of clauses
        Each clause is a sequence (list or numpy array) of signed integer literals.
    unit_clause : int
        The literal being set to true.

    Returns
    -------
    list or int
        The simplified clause list. Clauses containing `unit_clause` are dropped
        because they are satisfied. Clauses containing its negation lose that
        literal, and all other clauses are kept unchanged. Returns -1 as soon as
        a clause becomes empty, which is a conflict.
    """
    negated = -unit_clause
    simplified = []
    for clause in f:
        if unit_clause in clause:
            # Already satisfied by the assumed literal.
            continue
        if negated not in clause:
            simplified.append(clause)
            continue
        shortened = [lit for lit in clause if lit != negated]
        if len(shortened) == 0:
            # Every literal of this clause is now false: conflict.
            return -1
        simplified.append(shortened)
    return simplified

def unit_propagation(f):
    """Repeatedly apply unit clauses to a CNF formula until none remain.

    Parameters
    ----------
    f : list of clauses
        Each clause is a sequence of signed integer literals.

    Returns
    -------
    tuple
        (formula, assignment), where `assignment` lists each propagated literal
        wrapped as [literal]. On a conflict reported by `propagate`, returns
        (-1, []). When the formula becomes empty, returns it together with the
        literals assigned so far.
    """
    assignment = []
    while True:
        # Take the first remaining clause of length one, if there is any.
        unit = next((clause for clause in f if len(clause) == 1), None)
        if unit is None:
            return f, assignment
        literal = unit[0]
        f = propagate(f, literal)
        assignment.append([literal])
        if f == -1:
            return -1, []
        if not f:
            return f, assignment
    
def backtracking(formula, assignment, literals):
    """Search for a satisfying assignment of a CNF formula (DPLL).

    Unit propagation runs first. If the formula is neither solved nor in
    conflict, the search branches on a randomly chosen variable: it tries the
    variable as true, and if that fails, as false.

    Parameters
    ----------
    formula : list of clauses, or -1
        Each clause is a sequence of signed integer literals. The conflict
        sentinel -1 returned by `propagate` is accepted and treated as
        unsatisfiable.
    assignment : list of [literal] lists
        Literals already fixed on the current search path.
    literals : sequence of [variable] lists
        Pool of candidate decision variables. Only pool variables that still
        occur in the reduced formula are branched on.

    Returns
    -------
    list of [literal] lists
        The full assignment (decisions and propagated literals) when the
        formula is satisfied, or [] when this branch is unsatisfiable.

    Raises
    ------
    IndexError
        If the formula is non-empty but none of its variables are in `literals`.
    """
    # propagate() can hand a conflict (-1) straight to a recursive call;
    # without this guard unit_propagation would try to iterate over an int.
    if formula == -1:
        return []

    formula, unit_assignment = unit_propagation(formula)
    assignment = assignment + unit_assignment
    if formula == -1:
        return []
    if not formula:
        return assignment

    # Branch only on variables that still occur in the reduced formula. A
    # variable that no longer appears cannot change the formula, so both of its
    # branches would explore the identical sub-problem and double the work.
    # Variables still in the formula are necessarily unassigned.
    allowed = set(np.ravel(literals).tolist())
    candidates = sorted({abs(int(lit)) for clause in formula for lit in clause} & allowed)
    variable = random.choice(candidates)

    solution = backtracking(propagate(formula, variable), assignment + [[variable]], literals)
    if not solution:
        solution = backtracking(propagate(formula, -variable), assignment + [[-variable]], literals)
    return solution

In [13]:
# Seed the RNG so the branching order (and therefore the whole run) is reproducible.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)

# The rules plus the givens (as unit clauses) form the formula to satisfy.
solution = backtracking(clauses + start_state, [], literals)
# backtracking returns [] when no satisfying assignment exists; fail loudly here
# instead of rendering an empty grid further down.
assert solution, "backtracking found no satisfying assignment for the puzzle"

In [14]:
# Flatten the nested solution list so it can easily be displayed.
def flatten(*items):
    """Recursively yield the non-list, non-tuple leaves of the given items.

    Lists and tuples at any depth are expanded in order. Every other object
    (including strings and numpy scalars) is yielded as-is.
    """
    for item in items:
        if isinstance(item, (tuple, list)):
            yield from flatten(*item)
        else:
            yield item
# Turn the [[literal], [literal], ...] assignment into a flat list of literals.
solution = [*flatten(solution)]

def visualize_solution(state):
    """Build a 9x9 grid from a flat list of signed literals.

    Only positive literals are drawn. Each one is read as the decimal digits
    row, column, value (for example 168 puts 8 at row 1, column 6, 1-based).
    Cells not set by any literal stay 0. The grid has numpy's default float dtype.
    """
    grid = np.zeros((9, 9))
    positives = (lit for lit in state if lit > 0)
    for lit in positives:
        text = str(lit)
        row, col, digit = int(text[0]) - 1, int(text[1]) - 1, int(text[2])
        grid[row, col] = digit
    return grid

solution_grid = visualize_solution(solution)

# Sanity-check the solver: in a solved Sudoku every row, column and 3x3 box
# holds each digit 1-9 exactly once, and every given is kept.
expected_digits = set(range(1, 10))
assert all(set(solution_grid[i, :]) == expected_digits for i in range(9)), "row check failed"
assert all(set(solution_grid[:, j]) == expected_digits for j in range(9)), "column check failed"
assert all(
    set(solution_grid[r:r + 3, c:c + 3].ravel()) == expected_digits
    for r in range(0, 9, 3)
    for c in range(0, 9, 3)
), "box check failed"
assert all(
    solution_grid[given // 100 - 1, given // 10 % 10 - 1] == given % 10
    for (given,) in start_state
), "a given clue was not respected"

solution_grid

array([[2., 6., 9., 1., 4., 8., 5., 3., 7.],
       [4., 5., 1., 3., 7., 2., 9., 8., 6.],
       [8., 3., 7., 9., 5., 6., 4., 2., 1.],
       [9., 2., 6., 4., 8., 5., 7., 1., 3.],
       [3., 1., 8., 7., 2., 9., 6., 4., 5.],
       [7., 4., 5., 6., 1., 3., 8., 9., 2.],
       [6., 7., 2., 8., 3., 4., 1., 5., 9.],
       [1., 8., 3., 5., 9., 7., 2., 6., 4.],
       [5., 9., 4., 2., 6., 1., 3., 7., 8.]])