In [3]:
import numpy as np
import time

# Simulated Annealing

Solve the Sudoku problem with Simulated Annealing. You can design your own algorithm or simply refer to [Metaheuristics_can_solve_Sudoku_puzzles](https://www.researchgate.net/publication/220403361_Metaheuristics_can_solve_Sudoku_puzzles). 

The code provided below starts with making a problem instance and ends by visualizing the running process of SA.

In [4]:
# making a problem instance
def make_grid_python(n):
    grid = np.empty((n**2, n**2), int)
    x = 0
    for i in range(n):
        for j in range(n):
            for k in range(n**2):
                grid[n*i+j, k] = x%(n**2) + 1
                x += 1
            x += n
        x += 1
    return grid

def make_grid_numpy(n):
    return np.fromfunction(lambda i, j: (i*n+i//n+j)%(n**2)+1, (n**2, n**2), dtype=int)

# a comparison between native python and numpy
# vary n to see their performances
n = 10
%timeit make_grid_python(n)
%timeit make_grid_numpy(n)

# test
grid = make_grid_numpy(3)
grid

3.77 ms ± 10.4 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)
108 µs ± 74.5 ns per loop (mean ± std. dev. of 7 runs, 10000 loops each)


array([[1, 2, 3, 4, 5, 6, 7, 8, 9],
       [4, 5, 6, 7, 8, 9, 1, 2, 3],
       [7, 8, 9, 1, 2, 3, 4, 5, 6],
       [2, 3, 4, 5, 6, 7, 8, 9, 1],
       [5, 6, 7, 8, 9, 1, 2, 3, 4],
       [8, 9, 1, 2, 3, 4, 5, 6, 7],
       [3, 4, 5, 6, 7, 8, 9, 1, 2],
       [6, 7, 8, 9, 1, 2, 3, 4, 5],
       [9, 1, 2, 3, 4, 5, 6, 7, 8]], dtype=int32)

In [5]:
class Sudoku:
    @classmethod
    def create(cls, n, seed=303):
        rng = np.random.default_rng(seed)
        init_grid = make_grid_numpy(n)
        mask = rng.integers(0, 2, size=init_grid.shape)
        grid = init_grid * mask
        print(grid)

        return cls(n, mask, grid, seed)

    def __init__(self, n, mask, grid, seed) -> None:
        self.seed = seed
        self.mask = mask
        self.grid = grid
        self.n = n
        self.all = set(range(1, n**2+1))

    def value(self):
        final_value = 0
        for i in range(self.n**2):
            for num in range(1, self.n**2+1):
                cnt_col = 0
                cnt_row = 0
                for j in range(self.n**2):
                    if self.grid[i][j] == num:
                        cnt_row += 1
                    if self.grid[j][i] == num:
                        cnt_col += 1
                if cnt_col == 1:
                    final_value += 1
                if cnt_row == 1:
                    final_value += 1
        for i in range(self.n):
            for j in range(self.n):
                for num in range(1, self.n**2+1):
                    cnt = 0
                    for inc_col in range(self.n):
                        for inc_row in range(self.n):
                            if self.grid[3*i+inc_col][3*j+inc_row] == num:
                                cnt += 1
                    if cnt == 1:
                        final_value += 1
        return final_value
    
    def local_search(self):
        for i in range(self.n**2):
            cnt_num_row = np.zeros(self.n**2+1)
            cnt_num_col = np.zeros(self.n**2+1)
            for num in range(1, self.n**2+1):
                for j in range(self.n**2):
                    if self.grid[i][j] == num:
                        cnt_num_row[num] += 1
                    if self.grid[j][i] == num:
                        cnt_num_col[num] += 1
            target = -1
            source = -1
            for num in range(1, self.n**2+1):
                if cnt_num_row[num] < 1:
                    target = num
                if cnt_num_row[num] > 1:
                    source = num
                if target != -1 and source != -1:
                    for j in range(self.n**2):
                        if self.grid[i][j] == source:
                            self.grid[i][j] = target
                            return self
            target = -1
            source = -1
            for num in range(1, self.n**2+1):
                if cnt_num_col[num] < 1:
                    target = num
                if cnt_num_col[num] > 1:
                    source = num
                if target != -1 and source != -1:
                    for j in range(self.n**2):
                        if self.grid[j][i] == source:
                            self.grid[j][i] = target
                            return self
        for i in range(self.n):
            for j in range(self.n):
                cnt_num_cube = np.zeros(self.n**2+1)
                for num in range(1, self.n**2+1):
                    for inc_col in range(self.n):
                        for inc_row in range(self.n):
                            if self.grid[3*i+inc_col][3*j+inc_row] == num:
                                cnt_num_cube[num] += 1
                target = -1
                source = -1
                for num in range(1, self.n**2+1):
                    if cnt_num_cube[num] < 1:
                        target = num
                    if cnt_num_cube[num] > 1:
                        source = num
                    if target != -1 and source != -1:
                        for inc_col in range(self.n):
                            for inc_row in range(self.n):
                                if self.grid[3*i+inc_col][3*j+inc_row] == source:
                                    self.grid[3*i+inc_col][3*j+inc_row] = target
                                    return self

    def init_solution(self):
        rng = np.random.default_rng(self.seed)
        n = self.n
        grid = self.grid.reshape(n, n, n, n).transpose(0, 2, 1, 3)
        for I in np.ndindex(n, n):
            idx = grid[I]==0
            grid[I][idx] = rng.permutation(list(self.all-set(grid[I].flat)))
        return self
        
    def __repr__(self) -> str:
        return self.grid.__repr__()

# test
sudoku = Sudoku.create(3)
sudoku = sudoku.init_solution()

[[0 0 3 0 0 6 0 0 9]
 [4 5 0 7 0 9 1 0 0]
 [7 0 0 0 2 0 4 5 0]
 [2 3 4 0 6 7 0 0 0]
 [5 6 7 0 0 0 0 3 0]
 [0 9 0 0 0 4 0 0 0]
 [0 4 5 0 0 8 0 1 2]
 [6 7 8 0 0 0 3 0 5]
 [9 1 2 3 0 5 0 7 8]]
array([[9, 9, 9, 9, 9, 9, 9, 9, 9],
       [4, 5, 1, 7, 8, 9, 1, 2, 3],
       [7, 6, 8, 5, 2, 1, 4, 5, 7],
       [2, 3, 4, 3, 6, 7, 4, 9, 1],
       [5, 6, 7, 9, 1, 5, 5, 3, 8],
       [8, 9, 1, 8, 2, 4, 7, 2, 6],
       [3, 4, 5, 1, 9, 8, 4, 1, 2],
       [6, 7, 8, 7, 2, 4, 3, 6, 5],
       [9, 1, 2, 3, 6, 5, 9, 7, 8]])
157


In [12]:
def simulated_annealing(initial:Sudoku, schedule, halt, log_interval=200):
    state = initial.init_solution()
    t = 0           # time step
    T = schedule(t) # temperature
    f = [state.value()] # a recording of values
    while not halt(T):
        T = schedule(t)
        new_state = state.local_search()
        new_value = new_state.value()
        # TODO: implement the replacement here
        
        

        # update time and temperature
        if t % log_interval == 0:
            print(f"step {t}: T={T}, current_value={state.value()}")
        t += 1
        T = schedule(t)
    print(f"step {t}: T={T}, current_value={state.value()}")
    return state, f

In [13]:
import matplotlib.pyplot as plt

# define your own schedule and halt condition
# run the algorithm on different n with different settings
n = 4
solution, record = simulated_annealing(
    initial=Sudoku.create(n), 
    schedule=lambda t: 0.999**t, 
    halt=lambda T: T<1e-7
)
solution, solution.value()

[[ 0  0  3  0  0  6  0  0  9 10 11  0 13  0 15 16]
 [ 0  0  7  0  0  0 11  0 13 14  0 16  1  2  0  4]
 [ 9  0  0  0 13 14 15  0  0  0  0  4  0  0  7  0]
 [ 0  0 15  0  0  0  0  4  5  0  0  8  0 10 11 12]
 [ 2  3  0  0  0  7  0  9 10 11 12 13  0 15  0  1]
 [ 6  7  8  0 10 11 12  0 14 15 16  0  2  0  4  0]
 [ 0  0  0  0 14  0  0  0  2  0  0  0  6  0  0  0]
 [ 0 15 16  1  2  3  4  5  6  0  8  0 10  0  0 13]
 [ 3  4  0  0  0  0  9 10  0 12 13 14 15  0  1  0]
 [ 0  0  9  0  0  0 13  0  0 16  0  2  3  0  0  0]
 [11  0 13 14 15  0  1  0  0  4  5  6  7  8  9 10]
 [15 16  0  2  0  4  5  0  0  0  9 10 11  0  0 14]
 [ 4  0  6  0  8  0  0  0 12 13  0 15 16  0  0  0]
 [ 8  0  0 11  0 13 14  0 16  0  0  3  4  5  6  7]
 [ 0  0 14  0 16  0  2  3  0  0  6  7  8  9  0  0]
 [16  1  2  0  0  5  0  7  0  9 10 11 12 13  0  0]]
step 0: T=1.0, current_value=428
step 200: T=0.8186488294786356, current_value=428
step 400: T=0.6701859060067401, current_value=428
step 600: T=0.5486469074854967, current_value=428


KeyboardInterrupt: 

In [None]:
# visualize the curve
plt.plot(record)
plt.xlabel("time step")
plt.ylabel("value")