In [12]:
import numpy as np
import random 
import time
import csv
from collections import defaultdict
import itertools
from math import ceil

In [13]:
def output_dimacs(n, solution):
    """Write a solver assignment to ``{n}.out`` as a DIMACS CNF file of unit clauses.

    Parameters
    ----------
    n : int or str
        Identifier used as the file stem (the notebook passes the puzzle number),
        so each puzzle gets its own output file.
    solution : list of [int]
        Assigned literals, each wrapped in a one-element list as returned by
        ``backtracking``. An empty list (no solution found) produces a file that
        contains only the header ``p cnf 0 0``.
    """
    literals = [entry[0] for entry in solution]
    # DIMACS header is `p cnf <highest variable index> <number of clauses>`.
    nr_vars = max((abs(literal) for literal in literals), default=0)
    with open(f'{n}.out', 'w', newline='') as f:
        filewriter = csv.writer(f, delimiter=" ")
        filewriter.writerow(['p', 'cnf', nr_vars, len(literals)])
        for literal in literals:
            # One unit clause per line: the literal followed by the terminating 0.
            filewriter.writerow([literal, 0])

def decode(state):
    """Convert one puzzle line into unit clauses for its given digits.

    Characters are read row by row, left to right: character k describes
    row k // 9 + 1 and column k % 9 + 1. A given digit d becomes the unit
    clause ``[int(f'{row}{col}{d}')]``. Blank cells ('.' or '0') are skipped.

    Surrounding whitespace is stripped first, including a trailing '\\n' or
    '\\r\\n'. The last cell is therefore kept even when the line has no
    newline, and only the first 81 cells are read.

    Returns
    -------
    list of [int]
        One single-literal clause per given digit.
    """
    coords = [str(i) + str(j) for i in range(1, 10) for j in range(1, 10)]
    cells = state.strip()
    # zip stops after the 81 grid coordinates.
    return [[int(coord + cell)] for coord, cell in zip(coords, cells)
            if cell not in ('.', '0')]

def propagate(f, unit_clause):
    """Simplify formula ``f`` under the assumption that literal ``unit_clause`` is true.

    Clauses containing the literal are satisfied and dropped. Clauses
    containing its negation lose that literal. Untouched clauses are kept as
    the same list objects.

    Returns the simplified list of clauses, or -1 as soon as a clause
    becomes empty (a conflict).
    """
    reduced = []
    for clause in f:
        if unit_clause in clause:
            continue
        if -unit_clause not in clause:
            reduced.append(clause)
            continue
        shortened = [lit for lit in clause if lit != -unit_clause]
        if len(shortened) == 0:
            return -1
        reduced.append(shortened)
    return reduced

def unit_propagation(f):
    """Repeatedly assign the first unit clause of ``f`` until none remain.

    Returns
    -------
    (-1, [])
        If propagating some unit literal produced an empty clause.
    (formula, assignment)
        Otherwise: the simplified formula (possibly empty, meaning every
        clause is satisfied) and the applied literals, each wrapped in a
        one-element list, in the order they were applied.
    """
    assigned = []
    while True:
        units = [clause for clause in f if len(clause) == 1]
        if not units:
            return f, assigned
        literal = units[0][0]
        f = propagate(f, literal)
        assigned += [[literal]]
        if f == -1:
            return -1, []
        if not f:
            return f, assigned

def next_literal(f, assignment, literals, heuristic='random'):
    """Pick the next variable to branch on, counting the branch.

    Every call increments the module-level counter ``nr_branches``. The
    notebook resets it to 0 before each puzzle.

    Parameters
    ----------
    f : list of list of int
        Current (already simplified) CNF formula.
    assignment : list of [int]
        Literals assigned so far, each wrapped in a one-element list.
    literals : list of [int]
        Every problem variable, each wrapped in a one-element list.
    heuristic : {'random', 'jw', 'rc'}
        'random' - uniform choice among unassigned variables.
        'jw'     - one-sided Jeroslow-Wang: the variable maximising the sum of
                   2**-len(c) over clauses c that contain it positively.
        'rc'     - the variable whose row digit and column digit (first and
                   second digit of the 3-digit encoding) are shared by the
                   fewest unassigned variables.

    Returns
    -------
    A positive, not yet assigned variable (a numpy integer). Ties go to the
    smallest variable.

    Raises
    ------
    ValueError
        If ``heuristic`` is not one of the names above.
    """
    global nr_branches
    nr_branches += 1
    # Variables not yet assigned in either polarity; setdiff1d returns them sorted.
    free_literals = np.setdiff1d(literals, [[abs(x[0])] for x in assignment])

    if heuristic == 'random':
        return random.choice(free_literals)

    if heuristic == 'jw':
        # Single pass over the formula instead of one full scan per free literal.
        # Each literal's weights are added in clause order, so the scores match
        # the per-literal sums exactly.
        jw_scores = defaultdict(float)
        for clause in f:
            weight = 2.**(-len(clause))
            for literal in set(clause):
                jw_scores[literal] += weight
        # max() keeps the first maximum, i.e. the smallest variable on ties.
        return max(free_literals, key=lambda literal: jw_scores.get(literal, 0.0))

    if heuristic == 'rc':
        # Count how many unassigned variables share each row / column digit.
        row_counts = defaultdict(int)
        col_counts = defaultdict(int)
        for literal in free_literals:
            digits = str(literal)
            row_counts[digits[0]] += 1
            col_counts[digits[1]] += 1
        # min() keeps the first minimum, i.e. the smallest variable on ties.
        return min(free_literals,
                   key=lambda literal: row_counts[str(literal)[0]] + col_counts[str(literal)[1]])

    # Previously an unknown name silently returned None and crashed later in propagate().
    raise ValueError(f"unknown heuristic {heuristic!r}; expected 'random', 'jw' or 'rc'")
    
def backtracking(formula, assignment, literals, heuristic='rc'):
    """Recursive DPLL search: unit-propagate, then branch on a literal and its negation.

    Parameters
    ----------
    formula : list of list of int, or -1
        CNF formula. The conflict marker -1 returned by ``propagate`` is
        treated as unsatisfiable.
    assignment : list of [int]
        Literals assigned on the path leading to this call.
    literals : list of [int]
        All problem variables, each wrapped in a one-element list.
    heuristic : str, default 'rc'
        Branching heuristic forwarded to ``next_literal`` ('random', 'jw' or
        'rc'). The default keeps the heuristic that was previously hard-coded.

    Returns
    -------
    list of [int]
        A satisfying assignment, or [] if none exists below this node.
    """
    # A branch whose propagation already hit an empty clause is a dead end;
    # unit_propagation cannot iterate over the -1 marker.
    if formula == -1:
        return []
    formula, unit_assignment = unit_propagation(formula)
    if formula == -1:
        return []
    assignment = assignment + unit_assignment
    if not formula:
        return assignment
    variable = next_literal(formula, assignment, literals, heuristic)
    solution = backtracking(propagate(formula, variable),
                            assignment + [[variable]], literals, heuristic)
    if not solution:
        solution = backtracking(propagate(formula, -variable),
                                assignment + [[-variable]], literals, heuristic)
    return solution

In [15]:
# Load in all clauses of the Sudoku rules (DIMACS CNF). The `p` header and any
# `c` comment lines are skipped, and the terminating 0 of each clause is dropped.
clauses = []
with open('sudoku-rules.txt', 'r') as f:
    for line in f:
        tokens = line.split()
        if not tokens or tokens[0] in ('c', 'p'):
            continue
        clauses.append([int(token) for token in tokens[:-1]])

# Every variable `rcv` (row, column, value) of the 9x9 grid, in ascending order.
literals = [[100 * row + 10 * col + value]
            for row in range(1, 10) for col in range(1, 10) for value in range(1, 10)]

# Number of puzzles from top91 to solve (None = all of them).
N_PUZZLES = 2
with open('top91.sdk.txt', 'r') as f:
    # Ignore blank lines; they would otherwise be solved as an empty grid.
    puzzles = [line for line in f if line.strip()]

runtimes = []
branches = []
solutions = []
for n, puzzle in enumerate(puzzles[:N_PUZZLES], start=1):
    nr_branches = 0  # incremented by next_literal through `global nr_branches`
    start = time.perf_counter()
    solution = backtracking(clauses + decode(puzzle), [], literals)
    end = time.perf_counter()
    # Only the solver is timed; writing the .out file is excluded.
    output_dimacs(n, solution)
    print('seconds', end - start)
    runtimes.append(end - start)
    branches.append(nr_branches)
    solutions.append(solution)

if runtimes:
    print('mean runtime: ', np.mean(runtimes))
    print('mean nr_of branches: ', np.mean(branches))

seconds 0.9550373554229736
seconds 1.0300450325012207
mean runtime:  0.9925411939620972
mean nr_of branches:  18.5


In [None]:
# Flatten arbitrarily nested lists/tuples so the solver output can be displayed.
def flatten(*n):
    """Lazily yield, depth-first, every argument leaf that is not a list or tuple."""
    for item in n:
        if isinstance(item, (tuple, list)):
            yield from flatten(*item)
        else:
            yield item

# Flatten the [[literal], ...] assignment of the last puzzle solved by the loop
# above into a flat list of literals (`solutions` was never defined).
solution = list(flatten(solution))

def visualize_solution(state):
    """Render a Sudoku assignment as a 9x9 grid.

    Parameters
    ----------
    state : iterable
        Literals in the 3-digit `rcv` encoding (row, column, value). Entries
        may be plain integers, or one-element lists/tuples/arrays such as the
        ``[[lit], ...]`` assignment returned by ``backtracking``. Negative
        literals (false assignments) are ignored.

    Returns
    -------
    numpy.ndarray
        9x9 float array with the value of each cell, 0 where no positive
        literal was given.
    """
    sudoku_grid = np.zeros((9, 9))
    for entry in state:
        # Accept the solver's native [[lit], ...] form as well as flat literals.
        literal = entry[0] if isinstance(entry, (list, tuple, np.ndarray)) else entry
        if literal > 0:
            digits = str(literal)
            sudoku_grid[int(digits[0]) - 1, int(digits[1]) - 1] = int(digits[2])
    return sudoku_grid

# Display the solved puzzle as a 9x9 grid (0 marks a cell with no positive literal).
visualize_solution(solution)