In [6]:
%matplotlib inline

import numpy as np
import itertools
import random
import matplotlib
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm, tqdm_notebook
import math

import torch
import torch.nn as nn
import torch.optim as optim

import sys
sys.path.append('/home/ajhnam/sudoku/src/sudoku')

from board import Board
from solutions import Solutions
import utils

In [7]:
# Seed every random source the notebook may draw from (stdlib `random` is imported above too),
# so that Restart Kernel -> Run All reproduces the same results.
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# Tensors created without an explicit dtype default to float64; the model below is also run in double precision.
torch.set_default_tensor_type('torch.DoubleTensor')

In [8]:
# Load the pickled puzzle collection and show how many puzzles exist for each number of hints.
# NOTE: pickle.load can execute arbitrary code - only load files you trust.
with open('solutions5.pickle', 'rb') as pickle_file:
    solutions = pickle.load(pickle_file)
puzzles = solutions.get_puzzles_by_hints()
hint_counts = sorted(puzzles)
for hint_count in hint_counts:
    print(hint_count, len(puzzles[hint_count]))

4 64
5 357
6 883
7 1584
8 2384
9 3309
10 4149
11 4754
12 4841
13 3741
14 1391
15 192
16 12


In [9]:
def is_in_row(board, row, digit):
    """Return whether `digit` occurs anywhere in row `row` of `board`."""
    row_values = board[row]
    return digit in row_values

def is_in_col(board, col, digit):
    """Return whether `digit` occurs anywhere in column `col` of `board` (uses `board[:, col]` slicing)."""
    column_values = board[:, col]
    return digit in column_values

def is_in_box(board, box, digit):
    """Return whether `digit` occurs in the box at index `box`, as returned by `board.get_box_by_index`."""
    box_values = board.get_box_by_index(box).box
    return digit in box_values

def vectorize_cell(board, x, y):
    """
    One-hot encode the digit stored in cell (x, y) as a length-`board.max_digit` vector.

    Digit d sets position d-1; an empty cell (value 0) encodes as all zeros.
    """
    digit = board[x][y]
    one_hot = torch.zeros(board.max_digit)
    if digit:
        one_hot[digit - 1] = 1
    return one_hot

def vectorize_cell_distribution(max_digit, cell_coordinates):
    """
    Creates a uniform distribution amongst cell_coordinates and zeros out other locations.

    Cells are flattened row-major over a max_digit x max_digit board, so (x, y) maps to
    index x*max_digit + y. Duplicate coordinates are counted once. Any iterable of (x, y)
    pairs is accepted (list/set of tuples, or a 2-column array).

    >>> vectorize_cell_distribution(4, [(1,1), (0, 3), (2, 0)])
    tensor([0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])

    Raises:
        ValueError: if cell_coordinates is empty (there is no distribution to build).
        AssertionError: if an entry is not an (x, y) pair or lies outside the board.
    """
    unique_cells = set(map(tuple, cell_coordinates))
    if not unique_cells:
        raise ValueError("cell_coordinates must contain at least one cell")

    vector = torch.zeros(max_digit**2)
    for cell in unique_cells:
        assert len(cell) == 2, "each coordinate must be an (x, y) pair"
        x, y = int(cell[0]), int(cell[1])
        # Reject rather than silently wrap around via negative indexing.
        assert 0 <= x < max_digit and 0 <= y < max_digit, "coordinate outside the board"
        vector[x * max_digit + y] = 1
    # A distribution over every cell of the board is valid, so no upper bound on the count.
    return vector / len(unique_cells)

In [10]:
class Query():
    """
    A yes/no question about a board: "is `digit` in row/column/box number `house_index`?".

    house_type must be one of the string constants Query.ROW, Query.COLUMN or Query.BOX.
    """
    ROW = "ROW"
    COLUMN = "COLUMN"
    BOX = "BOX"
    
    def __init__(self, board: "Board", digit: int, house_type: str, house_index: int):
        assert house_type in (Query.ROW, Query.COLUMN, Query.BOX)
        assert 0 < digit <= board.max_digit
        assert 0 <= house_index < board.max_digit
        
        self.board = board
        self.digit = digit
        self.house_type = house_type
        self.house_index = house_index
        
    def __repr__(self):
        return "Is {0} in {1} {2}: {3}".format(self.digit,
                                               self.house_type.lower(),
                                               self.house_index,
                                               'Yes' if self.answer() else 'No')
        
    def vectorize(self):
        """
        Encode the query as one-hot(digit) ++ one-hot(house type: row, column, box)
        ++ one-hot(house_index); total length is Query.vector_dim(board.max_digit).
        """
        digit = torch.zeros(self.board.max_digit)
        house_type = torch.zeros(3)
        house_index = torch.zeros(self.board.max_digit)
        
        digit[self.digit-1] = 1
        if self.house_type == Query.ROW:
            house_type[0] = 1
        elif self.house_type == Query.COLUMN:
            house_type[1] = 1
        else:
            house_type[2] = 1
        house_index[self.house_index] = 1
        
        return torch.cat([digit, house_type, house_index])
    
    def answer(self):
        """Return whether the digit actually occurs in the queried house."""
        # Compare by value: `is` only worked when house_type happened to be the interned constant.
        if self.house_type == Query.ROW:
            return is_in_row(self.board, self.house_index, self.digit)
        if self.house_type == Query.COLUMN:
            return is_in_col(self.board, self.house_index, self.digit)
        return is_in_box(self.board, self.house_index, self.digit)
    
    def relevant_cells(self):
        """Return the set of (x, y) coordinates that belong to the queried house."""
        if self.house_type == Query.ROW:
            return {(self.house_index, i) for i in range(self.board.max_digit)}
        if self.house_type == Query.COLUMN:
            return {(i, self.house_index) for i in range(self.board.max_digit)}
        return {tuple(t) for t in self.board.get_box_by_index(self.house_index).get_coordinates()}
    
    def is_answerable(self, seen_cells):
        """
        Checks if the query could be answered with just the seen cells.

        That is the case when every cell of the house has been seen (the answer is then
        known either way), or when the digit itself has been seen in one of the house's
        cells (the answer is then "yes").
        """
        relevant = self.relevant_cells()
        seen_relevant = relevant & set(seen_cells)
        if seen_relevant == relevant:
            return True
        return any(self.board[x][y] == self.digit for x, y in seen_relevant)
    
    @staticmethod
    def vector_dim(max_digit):
        """Length of the vector produced by vectorize() for boards with this max_digit."""
        return 3 + 2*max_digit

In [11]:
class Response():
    """
    One decoded step of model output: the cell the model attends to next and its answer.

    attention_dist is indexed row-major over max_digit x max_digit cells (x = index // max_digit,
    y = index % max_digit); answer_dist has 3 entries ordered (TRUE, FALSE, MAYBE).
    """
    MAYBE = "MAYBE"
    TRUE = "TRUE"
    FALSE = "FALSE"
    
    # Label of each position in answer_dist / answer_vector.
    _ANSWER_ORDER = (TRUE, FALSE, MAYBE)
    
    def __init__(self, max_digit, attention_dist, answer_dist):
        # Validate before decoding so a malformed answer_dist cannot leave self.answer unset.
        assert len(answer_dist) == 3
        self.attention_dist = attention_dist
        self.answer_dist = answer_dist
        
        # Convert to a Python int first: avoids floor division on tensors (deprecated in some torch versions).
        cell_index = int(torch.argmax(attention_dist))
        self.x, self.y = divmod(cell_index, max_digit)
        
        self.answer = Response._ANSWER_ORDER[int(torch.argmax(answer_dist))]
        
    def __repr__(self):
        return str((self.answer, self.x, self.y))
            
    @staticmethod
    def answer_vector(answer):
        """
        One-hot target over (TRUE, FALSE, MAYBE) for `answer`, which may be one of the
        Response constants or a boolean.

        Raises:
            AssertionError: if `answer` is not a recognised value.
        """
        if answer == Response.TRUE or answer == True:
            return torch.tensor([1., 0, 0])
        if answer == Response.FALSE or answer == False:
            return torch.tensor([0, 1., 0])
        if answer == Response.MAYBE:
            return torch.tensor([0, 0., 1])
        raise AssertionError("unrecognised answer: {!r}".format(answer))

In [33]:
class Model1(nn.Module):
    """
    A single LSTMCell followed by one linear + softmax head per entry of `output_sizes`.

    The recurrent state (self.lstm_h, self.lstm_c) is kept on the module between forward()
    calls, so call reset() before starting a new sequence.

    Args:
        input_size: length of the vector passed to forward().
        hidden_layer_size: size of the LSTM hidden and cell state.
        output_sizes: one output head per entry; forward() returns one probability vector each.
        device: where the layers are created (default "cuda:1", the GPU this notebook used).
    """
    def __init__(self, input_size, hidden_layer_size, output_sizes, device="cuda:1"):
        super(Model1, self).__init__()
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.output_sizes = output_sizes
        
        self.lstm = nn.LSTMCell(input_size, hidden_layer_size).to(device)
        self.reset()
        
        self.output_layers = nn.ModuleList(
            [nn.Linear(hidden_layer_size, output_size).to(device) for output_size in output_sizes]
        )
        # Explicit dim (the implicit-dim form is deprecated). Each head's output is squeezed to 1-D
        # in forward(), where dim=-1 is the same axis the implicit choice used.
        self.softmax = nn.Softmax(dim=-1)
        
    def reset(self):
        """Zero the LSTM hidden and cell state, matching the LSTM weights' dtype and device."""
        reference = self.lstm.weight_ih
        self.lstm_h = torch.zeros(1, self.hidden_layer_size, dtype=reference.dtype, device=reference.device)
        self.lstm_c = torch.zeros(1, self.hidden_layer_size, dtype=reference.dtype, device=reference.device)
        
    def forward(self, x):
        """
        Advance the LSTM by one step on `x` (reshaped to (1, input_size)) and return a tuple
        holding one probability vector per output head.
        """
        reference = self.lstm.weight_ih
        # Follow the module wherever .to()/.cuda()/.double() moved it instead of a fixed GPU.
        input_layer = x.reshape(1, self.input_size).to(reference)
        state = (self.lstm_h.to(reference), self.lstm_c.to(reference))
        self.lstm_h, self.lstm_c = self.lstm(input_layer, state)
        return tuple(self.softmax(layer(self.lstm_h).squeeze()) for layer in self.output_layers)

In [34]:
class QueryModel():
    """
    Trains and runs a recurrent torch model that answers a Query by looking at one cell per step.

    At each step the model emits an attention distribution over the max_digit**2 cells (the
    cell to look at next) and an answer distribution over (TRUE, FALSE, MAYBE). The contents
    of the attended cell are fed back as input on the following step.
    """
    
    def __init__(self, torch_model):
        self.torch_model = torch_model
        # Optimise the wrapped model's own parameters (this previously read the notebook-global `model`).
        self.optimizer = optim.Adam(torch_model.parameters())
        
    def predict(self, query):
        """
        Roll the model forward on `query` and return its Responses, one per step.

        Stops when the model gives a non-MAYBE answer, when it re-visits a cell, or after the
        step that follows feeding in the last relevant cell. That final step is the model's
        chance to answer having seen the whole house.
        """
        self.torch_model.reset()
        max_digit = query.board.max_digit
        
        responses = []
        relevant_unseen = set(query.relevant_cells())
        seen_cells = set()
        query_vector = query.vectorize()
        cell_vector = torch.zeros(max_digit)  # nothing has been looked at yet
        
        while True:
            step_input = torch.cat([query_vector, cell_vector])
            attention, answer = self.torch_model(step_input)
            response = Response(max_digit, attention, answer)
            responses.append(response)
            
            if response.answer != Response.MAYBE or not relevant_unseen:
                break

            next_cell = (response.x, response.y)
            if next_cell in seen_cells: # Should never re-visit a cell in this task
                break

            seen_cells.add(next_cell)
            relevant_unseen.discard(next_cell)
            cell_vector = vectorize_cell(query.board, response.x, response.y)

        return responses
    
    def get_correct_responses(self, query, predicted_responses):
        """
        Build the target Response for every predicted step.

        The attention target is uniform over the relevant cells not yet looked at. The answer
        target is the true answer once the cells seen so far suffice (Query.is_answerable),
        otherwise MAYBE. On a step where no relevant cell is left to look at, the prediction
        itself is used as the attention target, so that step adds no attention loss.
        """
        max_digit = query.board.max_digit
        seen_cells = set()
        relevant_unseen = set(query.relevant_cells())
        correct_responses = []
        
        for response in predicted_responses:
            if relevant_unseen:
                attention_target = vectorize_cell_distribution(max_digit, relevant_unseen)
            else:
                attention_target = response.attention_dist.detach()
            correct_answer = query.answer() if query.is_answerable(seen_cells) else Response.MAYBE
            answer_target = Response.answer_vector(correct_answer)
            correct_responses.append(Response(max_digit, attention_target, answer_target))
            
            next_cell = (response.x, response.y)
            seen_cells.add(next_cell)
            relevant_unseen.discard(next_cell)
        return correct_responses
    
    def train(self, queries):
        """
        Take one optimizer step per query.

        Returns:
            (all_attention_losses, all_answer_losses): one list per query, holding that
            query's per-step loss values as floats.
        """
        all_attention_losses, all_answer_losses = [], []
        for query in queries:
            attention_losses, answer_losses = self._train_on_query(query)
            all_attention_losses.append(attention_losses)
            all_answer_losses.append(answer_losses)
        return all_attention_losses, all_answer_losses
    
    def _train_on_query(self, query):
        """Run one forward pass plus optimizer step on a single query; return its per-step losses."""
        self.optimizer.zero_grad()
        responses = self.predict(query)
        targets = self.get_correct_responses(query, responses)
        
        attention_losses, answer_losses = [], []
        total_loss = 0
        for response, target in zip(responses, targets):
            attention_loss = self._kl_div(response.attention_dist, target.attention_dist)
            answer_loss = self._kl_div(response.answer_dist, target.answer_dist)
            total_loss = total_loss + attention_loss + answer_loss
            attention_losses.append(float(attention_loss))
            answer_losses.append(float(answer_loss))
        
        if responses:
            # One backward pass over the summed loss gives the same gradients as per-step
            # backward(retain_graph=True) calls, without keeping the graph alive.
            total_loss.backward()
            self.optimizer.step()
        return attention_losses, answer_losses
    
    @staticmethod
    def _kl_div(predicted, target):
        """
        KL(target || predicted) for two 1-D probability vectors, as a scalar tensor.

        nn.functional.kl_div expects log-probabilities as its input, so `predicted` (a softmax
        output) is logged first. It is clamped to the smallest positive float so a zero
        probability yields a large-but-finite loss instead of nan.
        """
        target = target.to(device=predicted.device, dtype=predicted.dtype)
        log_predicted = predicted.clamp(min=torch.finfo(predicted.dtype).tiny).log()
        return nn.functional.kl_div(log_predicted, target, reduction='sum')

In [36]:
# Hyper-parameters for this run.
max_digit = 4          # 4x4 boards
hidden_size = 64       # LSTM hidden state size
n_epochs = 10
n_boards = 1000        # boards used per epoch
training_hints = 8     # train on puzzles with this many given digits

model = Model1(Query.vector_dim(max_digit) + max_digit, hidden_size, [(max_digit**2), 3]).double().cuda(1)
query_model = QueryModel(model)

# Shuffle a copy: shuffling puzzles[training_hints] itself reordered it again on every re-run of this cell.
boards = list(puzzles[training_hints])
np.random.shuffle(boards)
losses = []  # (attention_losses, answer_losses) from query_model.train, per board per epoch
for epoch in tqdm_notebook(range(n_epochs)):
    for board in tqdm_notebook(boards[:n_boards]):
        xs, ys = np.nonzero(board.board)
        # One "is digit d in row x?" query per non-empty cell of the board.
        queries = [Query(board, board[x][y], Query.ROW, x) for x, y in zip(xs, ys) if board[x][y] != 0]
        attention_losses, answer_losses = query_model.train(queries)
        losses.append((attention_losses, answer_losses))

HBox(children=(IntProgress(value=0, max=10), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))



HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))

HBox(children=(IntProgress(value=0, max=1000), HTML(value='')))