In [1]:
%matplotlib inline

import numpy as np
import itertools
import random
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm, tqdm_notebook
import math

import torch
import torch.nn as nn
import torch.optim as optim

import sys
sys.path.append('/Users/andrew/Desktop/sudoku/src/sudoku')

from board import Board
from solutions import Solutions
import andrew_utils as utils

In [2]:
# set random seed to 0
np.random.seed(0)
torch.manual_seed(0)
torch.set_default_tensor_type('torch.DoubleTensor')

In [3]:
solutions = Solutions('/Users/andrew/Desktop/sudoku/data/solutions5.txt')
puzzles = solutions.get_puzzles_by_hints()
for hints in sorted(puzzles):
    print(hints, len(puzzles[hints]))

4 64
5 357
6 883
7 1584
8 2384
9 3309
10 4149
11 4754
12 4841
13 3741
14 1391
15 192
16 12


In [4]:
def is_in_row(board, row, digit):
    return digit in board[row]

def is_in_col(board, col, digit):
    return digit in board[:,col]

def is_in_box(board, box, digit):
    return digit in board.get_box_by_index(box).box

def vectorize_cell(board, x, y):
    vector = torch.zeros(board.max_digit)
    if board[x][y] != 0:
        vector[board[x][y]-1] = 1
    return vector

def vectorize_cell_distribution(max_digit, cell_coordinates):
    """
    Creates a uniform distribution amongst cell_coordinates and zeros out other locations
    >>> vectorize_cell_distribution(4, [(1,1), (0, 3), (2, 0)])
    tensor([0.0000, 0.0000, 0.0000, 0.3333, 0.0000, 0.3333, 0.0000, 0.0000, 0.3333,
        0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000, 0.0000])
    """
    cell_coordinates = np.array(list(set(cell_coordinates)))
    assert cell_coordinates.shape[0] < max_digit**2
    assert cell_coordinates.shape[1] == 2
    vector = torch.zeros(max_digit**2)
    for coord in cell_coordinates:
        vector[coord[0]*max_digit + coord[1]] = 1
    return vector/len(cell_coordinates)

In [5]:
class Query():
    ROW = "ROW"
    COLUMN = "COLUMN"
    BOX = "BOX"
    
    def __init__(self, board: Board, digit: int, house_type: int, house_index: int):
        assert house_type == Query.ROW or house_type == Query.COLUMN or house_type == Query.BOX
        assert 0 < digit <= board.max_digit
        assert 0 <= house_index < board.max_digit
        
        self.board = board
        self.digit = digit
        self.house_type = house_type
        self.house_index = house_index
        
    def __repr__(self):
        return "Is {0} in {1} {2}: {3}".format(self.digit,
                                               self.house_type.lower(),
                                               self.house_index,
                                               'Yes' if self.answer() else 'No')
        
    def vectorize(self):
        digit = torch.zeros(self.board.max_digit)
        house_type = torch.zeros(3)
        house_index = torch.zeros(self.board.max_digit)
        
        digit[self.digit-1] = 1
        if self.house_type == Query.ROW:
            house_type[0] = 1
        elif self.house_type == Query.COLUMN:
            house_type[1] = 1
        else:
            house_type[2] = 1
        house_index[self.house_index] = 1
        
        return torch.cat([digit, house_type, house_index])
    
    def answer(self):
        if self.house_type is Query.ROW:
            return is_in_row(self.board, self.house_index, self.digit)
        if self.house_type is Query.COLUMN:
            return is_in_col(self.board, self.house_index, self.digit)
        else:
            return is_in_box(self.board, self.house_index, self.digit)
    
    def relevant_cells(self):
        if self.house_type == Query.ROW:
            return {(self.house_index, i) for i in range(self.board.max_digit)}
        if self.house_type == Query.COLUMN:
            return {(i, self.house_index) for i in range(self.board.max_digit)}
        return {tuple(t) for t in self.board.get_box_by_index(self.house_index).get_coordinates()}
    
    def is_answerable(self, seen_cells):
        """
        Checks if the query could be answered with just the seen cells
        """
        if self.answer():
            if self.house_type == Query.ROW:
                cell = self.board.find_digit_in_row(self.house_index, self.digit)
            elif self.house_type == Query.COLUMN:
                cell = self.board.find_digit_in_column(self.house_index, self.digit)
            else:
                cell = self.board.find_digit_in_box_by_index(self.house_index, self.digit)
        return not bool(self.relevant_cells() - set(seen_cells))
    
    def vector_dim(max_digit):
        return 3 + 2*max_digit

In [13]:
class Response():
    MAYBE = "MAYBE"
    TRUE = "TRUE"
    FALSE = "FALSE"
    
    def __init__(self, max_digit, attention_dist, answer_dist):
        self.attention_dist = attention_dist
        self.answer_dist = answer_dist
        
        cell_index = torch.argmax(attention_dist)
        self.x = int(cell_index // max_digit)
        self.y = int(cell_index % max_digit)
        
        answer_index = torch.argmax(answer_dist)
        self.answer_index = answer_index
        if answer_index == 0:
            self.answer = Response.TRUE
        elif answer_index == 1:
            self.answer = Response.FALSE
        elif answer_index == 2:
            self.answer = Response.MAYBE
        assert len(answer_dist) == 3
        
    def __repr__(self):
        return str((self.answer, self.x, self.y))
            
    @staticmethod
    def answer_vector(answer):
        if answer == Response.TRUE or answer == True:
            return torch.tensor([1., 0, 0])
        if answer == Response.FALSE or answer == False:
            return torch.tensor([0, 1., 0])
        if answer == Response.MAYBE:
            return torch.tensor([0, 0., 1])
        assert answer in (Response.TRUE, Response.FALSE, Response.MAYBE, True, False)

In [14]:
class Model1(nn.Module):
    def __init__(self, input_size, hidden_layer_size, output_sizes):
        super(Model1, self).__init__()
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.output_sizes = output_sizes
        
        self.lstm = nn.LSTMCell(input_size, hidden_layer_size)
        self.lstm_h = torch.zeros(1, self.hidden_layer_size, dtype=torch.double)
        self.lstm_c = torch.zeros(1, self.hidden_layer_size, dtype=torch.double)
        
        self.output_layers = nn.ModuleList() 
        for output_size in output_sizes:
            self.output_layers.append(nn.Linear(hidden_layer_size, output_size))
        self.softmax = nn.Softmax()
        
    def reset(self):
        self.lstm_h = torch.zeros(1, self.hidden_layer_size, dtype=torch.double)
        self.lstm_c = torch.zeros(1, self.hidden_layer_size, dtype=torch.double)
        
    def forward(self, x):
        input_layer = x.reshape(1, self.input_size)
        self.lstm_h, self.lstm_c = self.lstm(input_layer, (self.lstm_h, self.lstm_c))
        return (self.softmax(layer(self.lstm_h).squeeze()) for layer in self.output_layers)

In [15]:
class BatchLSTM(nn.Module):
    def __init__(self, input_size, hidden_layer_size, output_sizes):
        super(BatchLSTM, self).__init__()
        self.input_size = input_size
        self.hidden_layer_size = hidden_layer_size
        self.output_sizes = output_sizes
        
        self.lstm = nn.LSTMCell(input_size, hidden_layer_size)
        self.lstm_h = None
        self.lstm_c = None
        
        self.output_layers = nn.ModuleList() 
        for output_size in output_sizes:
            self.output_layers.append(nn.Linear(hidden_layer_size, output_size))
        self.softmax = nn.Softmax(dim=0)
        
    def reset(self, n):
        self.lstm_h = torch.zeros(n, self.hidden_layer_size, dtype=torch.double)
        self.lstm_c = torch.zeros(n, self.hidden_layer_size, dtype=torch.double)
        
    def drop_memory(self, indices):
        keep = list(set(range(len(self.lstm_h))) - set(indices))
        self.lstm_h = self.lstm_h[keep]
        self.lstm_c = self.lstm_c[keep]


    def forward(self, x):
        self.lstm_h, self.lstm_c = self.lstm(x, (self.lstm_h, self.lstm_c))
        return tuple(self.softmax(layer(self.lstm_h).reshape(x.shape[0], size)) for size,layer in zip(self.output_sizes, self.output_layers))
    
    
    
    
    

In [52]:
class QueryModel():
    
    def __init__(self, torch_model):
        self.torch_model = torch_model
        self.optimizer = optim.Adam(model.parameters())
        
    def predict(self, queries):
        n = len(queries)
        self.torch_model.reset(n)
        
        all_query_vectors = [q.vectorize() for q in queries]
        
        continues = np.ones(n)
        responses = [[] for i in range(n)]
        visiting_cell = [None for i in range(n)]
        visited = [set() for i in range(n)]
        relevant_unvisited = [q.relevant_cells() for q in queries]
        
        query_vectors = torch.stack(all_query_vectors)
        cell_vectors = torch.zeros(n, queries[0].board.max_digit)
        iteration = 0
        while np.any(continues):
            cont_indices = np.nonzero(continues)[0]
            index_map = {cont_indices[i]: i for i in range(len(cont_indices))}
            
            inputs = torch.cat((query_vectors, cell_vectors), dim=1)
            attentions, answers = self.torch_model(inputs)
            
            query_vectors = []
            cell_vectors = []
            lstm_mem_drop_indices = set()
            for index in range(n):
                if continues[index]:
                    if visiting_cell[index]: # i.e. not first iteration
                        visited[index].add(visiting_cell[index])
                        relevant_unvisited[index].discard(visiting_cell[index])
                    i = index_map[index]
                    response = Response(queries[i].board.max_digit, attentions[i], answers[i])
                    next_cell = (response.x, response.y)
                    
                    cont = (next_cell not in visited[index]) \
                                        and (response.answer == Response.MAYBE) \
                                        and bool(relevant_unvisited[index])
                    
                    if cont:
                        query_vectors.append(all_query_vectors[index])
                        cell_vectors.append(vectorize_cell(queries[index].board, response.x, response.y))
                    else:
                        lstm_mem_drop_indices.add(i)
                    
                    continues[index] = cont
                    visiting_cell[index] = next_cell
                    responses[index].append(response)
                    
            if query_vectors:
                query_vectors = torch.stack(query_vectors)
            if cell_vectors:
                cell_vectors = torch.stack(cell_vectors)
            self.torch_model.drop_memory(lstm_mem_drop_indices)
            iteration += 1
                    
        return responses

    
    def get_correct_responses(self, query, predicted_responses):
        seen_cells = set()
        relevant_unseen = set(query.relevant_cells())
        correct_responses = []
        
        for response in predicted_responses:
            next_cell = (response.x, response.y)
            
            correct_attention_dist = vectorize_cell_distribution(query.board.max_digit, relevant_unseen)
            correct_answer = query.answer() if query.is_answerable(seen_cells) else Response.MAYBE
            correct_answer_dist = Response.answer_vector(correct_answer)
            correct_responses.append(Response(query.board.max_digit, correct_attention_dist, correct_answer_dist))
            
            seen_cells.add(next_cell)
            relevant_unseen.discard(next_cell)
        return correct_responses
    
    def train(self, queries):
        total_loss = 0
        def closure():
            self.optimizer.zero_grad()
            all_responses = self.predict(queries)
            total_loss = 0
            
            pred_attention_dists = []
            corr_attention_dists = []
            pred_answer_dists = []
            corr_answer_labels = []
            for responses, query in zip(all_responses, queries):
                for response, correct_response in zip(responses, self.get_correct_responses(query, responses)):
                    pred_attention_dists.append(response.attention_dist)
                    corr_attention_dists.append(correct_response.attention_dist)
                    pred_answer_dists.append(response.answer_dist)
                    corr_answer_labels.append(correct_response.answer_index)
            pred_attention_dists = torch.stack(pred_attention_dists)
            corr_attention_dists = torch.stack(corr_attention_dists)
            pred_answer_dists = torch.stack(pred_answer_dists)
            corr_answer_labels = torch.stack(corr_answer_labels)
            attention_loss = nn.functional.kl_div(pred_attention_dists, corr_attention_dists)
            answer_loss = nn.functional.cross_entropy(pred_answer_dists, corr_answer_labels)
            total_loss = attention_loss + answer_loss
            total_loss.backward()
            return total_loss
        
        for query in queries:
            total_loss += self.optimizer.step(closure)
        return total_loss

In [58]:
max_digit = 4
model = BatchLSTM(Query.vector_dim(max_digit) + max_digit, 32, [(max_digit**2), 3]).double()
query_model = QueryModel(model)

boards = puzzles[8]
np.random.shuffle(boards)
losses = []
tqdm_notebook?
for epoch in tqdm_notebook(range(100)):
    epoch_losses = []
    for board in boards[:1000]:
        xs, ys = np.nonzero(board.board)
        queries = [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]
        loss = float(query_model.train(queries))
        epoch_losses.append(loss)
    losses.append(epoch_losses)
    print(sum(epoch_losses))
        
        

7882.35866127842


IndexError: tuple index out of range

In [59]:
board

array([[1, 2, 0, 4],
       [0, 3, 0, 0],
       [0, 1, 0, 3],
       [3, 0, 0, 2]], dtype=int8)

In [63]:
all_responses = query_model.predict(queries)

In [64]:
all_responses

[[('MAYBE', 0, 3),
  ('MAYBE', 0, 1),
  ('MAYBE', 0, 2),
  ('MAYBE', 0, 0),
  ('TRUE', 0, 0)],
 [('MAYBE', 0, 2), ('TRUE', 3, 0)],
 [('MAYBE', 0, 1), ('MAYBE', 0, 2), ('MAYBE', 0, 0), ('TRUE', 3, 0)],
 [('FALSE', 1, 2)],
 [('MAYBE', 2, 3), ('TRUE', 2, 2)],
 [('MAYBE', 2, 0), ('MAYBE', 2, 3), ('TRUE', 2, 2)],
 [('FALSE', 3, 0)],
 [('FALSE', 3, 1)]]

In [38]:
max_digit = 4
model = BatchLSTM(Query.vector_dim(max_digit) + max_digit, 32, [(max_digit**2), 3]).double()
query_model = QueryModel(model)

# boards = puzzles[8]
# np.random.shuffle(boards)
# losses = []

# queries = []
# for board in tqdm_notebook(boards[:1000]):
#     xs, ys = np.nonzero(board.board)
#     queries += [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]

# for epoch in tqdm_notebook(range(10)):
#     attention_losses, answer_losses = query_model.train(queries)

board = puzzles[8][0]
xs, ys = np.nonzero(board.board)
queries = [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]
print(queries)
print(float(query_model.train(queries)))
# attention_losses, answer_losses = query_model.train(queries)

# query_model.predict(queries)

[Is 1 in row 0: Yes, Is 4 in row 0: Yes, Is 2 in row 1: Yes, Is 2 in row 2: Yes, Is 4 in row 2: Yes, Is 4 in row 3: Yes, Is 3 in row 3: Yes, Is 1 in row 3: Yes]
8.020066023877671


In [42]:
input = torch.randn(3, 5, requires_grad=True)
print(input)
target = torch.tensor([1, 0, 4])
print(target)
output = nn.functional.nll_loss(nn.functional.log_softmax(input), target)
output.backward()

tensor([[-0.3223, -0.1777, -0.9039, -2.5052,  0.5468],
        [-0.0024,  0.7793, -1.2069,  0.9433,  0.0334],
        [-1.1096, -0.0093, -0.0123, -0.8477,  0.1106]], requires_grad=True)
tensor([1, 0, 4])




In [38]:
input = torch.randn(3, 5, requires_grad=True)
print(input)
target = torch.randint(5, (3,), dtype=torch.int64)
print(target)
loss = nn.functional.cross_entropy(input, target)
loss.backward()

tensor([[-0.0657,  0.2982, -0.3822,  0.7871, -0.3977],
        [-1.5977, -0.7959,  0.5136, -0.3415, -0.9693],
        [-1.1837,  1.4853, -1.4752, -0.6943, -0.8421]], requires_grad=True)
tensor([1, 3, 0])


In [39]:
max_digit = 4
model = Model1(Query.vector_dim(max_digit) + max_digit, 128, [(max_digit**2), 3]).double()
query_model = QueryModel(model)

boards = puzzles[8]
np.random.shuffle(boards)
losses = []
for epoch in tqdm_notebook(range(10)):
    for board in tqdm_notebook(boards[:1000]):
        xs, ys = np.nonzero(board.board)
        queries = [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]
        attention_losses, answer_losses = query_model.train(queries)
#     print(board)
#     for query in queries:
        
        
#         print(query)
#         predicted_responses = query_model.predict(query)
#         print("Predicted: {0}".format(predicted_responses))
#         correct_responses = query_model.get_correct_responses(query, predicted_responses)
#         print("Correct: {0}".format(correct_responses))
        
#     for query in queries:
#         losses.append(train(model, board, query))




TypeError: reset() takes 1 positional argument but 2 were given

In [84]:
for query in queries:
    print(query)
    predicted_responses = query_model.predict(query)
    print("Predicted: {0}".format(predicted_responses))
    correct_responses = query_model.get_correct_responses(query, predicted_responses)
    print("Correct: {0}".format(correct_responses))

Is 3 in row 0: Yes
Predicted: [('MAYBE', 0, 2), ('MAYBE', 0, 0), ('MAYBE', 0, 1), ('MAYBE', 0, 3)]
Correct: [('MAYBE', 0, 0), ('MAYBE', 0, 0), ('MAYBE', 0, 1), ('MAYBE', 0, 3)]
Is 4 in row 0: Yes
Predicted: [('MAYBE', 0, 3), ('MAYBE', 0, 0), ('MAYBE', 0, 1), ('MAYBE', 0, 2)]
Correct: [('MAYBE', 0, 0), ('MAYBE', 0, 0), ('MAYBE', 0, 1), ('MAYBE', 0, 2)]
Is 4 in row 1: Yes
Predicted: [('MAYBE', 1, 1), ('MAYBE', 1, 2), ('MAYBE', 1, 3), ('MAYBE', 1, 0)]
Correct: [('MAYBE', 1, 0), ('MAYBE', 1, 0), ('MAYBE', 1, 0), ('MAYBE', 1, 0)]
Is 1 in row 1: Yes
Predicted: [('MAYBE', 1, 1), ('MAYBE', 1, 2), ('MAYBE', 1, 3), ('MAYBE', 1, 0)]
Correct: [('MAYBE', 1, 0), ('MAYBE', 1, 0), ('MAYBE', 1, 0), ('MAYBE', 1, 0)]
Is 1 in row 2: Yes
Predicted: [('MAYBE', 2, 1), ('MAYBE', 2, 1)]
Correct: [('MAYBE', 2, 0), ('MAYBE', 2, 0)]
Is 2 in row 2: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 1), ('MAYBE', 2, 0), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 2, 0), ('MAYBE', 2, 0), ('MAYBE', 2, 0), ('MAYBE', 2, 2)]
Is 2 i



In [100]:
for response in predicted_responses:
    print(np.round(response.attention_dist.detach().numpy(), 3))

[0.    0.001 0.001 0.001 0.    0.    0.    0.    0.    0.    0.001 0.
 0.297 0.202 0.194 0.302]
[0.001 0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.001 0.096 0.901 0.001]
[0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.    0.999 0.    0.001]
[0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.    0.
 0.943 0.    0.    0.055]


In [86]:
query

Is 4 in row 3: Yes

In [61]:
max_digit = 4
model = Model1(Query.vector_dim(max_digit) + max_digit, 128, [(max_digit**2), 3]).double()
query_model = QueryModel(model)

boards = puzzles[8]
np.random.shuffle(boards)
losses = []
for board in boards[:1]:
    xs, ys = np.nonzero(board.board)
    queries = [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]
    print(board)
    for query in queries:
        
        
        print(query)
        predicted_responses = query_model.predict(query)
        print("Predicted: {0}".format(predicted_responses))
        correct_responses = query_model.get_correct_responses(query, predicted_responses)
        print("Correct: {0}".format(correct_responses))
        
#     for query in queries:
#         losses.append(train(model, board, query))

array([[1, 2, 3, 0],
       [4, 0, 0, 2],
       [3, 4, 0, 0],
       [0, 1, 0, 0]], dtype=int8)
Is 1 in row 0: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 0, 0), ('MAYBE', 0, 0)]
Is 2 in row 0: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 3, 3), ('MAYBE', 3, 3)]
Correct: [('MAYBE', 0, 0), ('MAYBE', 0, 0), ('MAYBE', 0, 0)]
Is 3 in row 0: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 0, 0), ('MAYBE', 0, 0)]
Is 4 in row 1: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 1, 0), ('MAYBE', 1, 0)]
Is 2 in row 1: Yes
Predicted: [('MAYBE', 3, 3), ('MAYBE', 3, 3)]
Correct: [('MAYBE', 1, 0), ('MAYBE', 1, 0)]
Is 3 in row 2: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 2, 0), ('MAYBE', 2, 0)]
Is 4 in row 2: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 2, 0), ('MAYBE', 2, 0)]
Is 1 in row 3: Yes
Predicted: [('MAYBE', 2, 3), ('MAYBE', 2, 3)]
Correct: [('MAYBE', 3, 0), ('MAYBE', 3, 0

In [10]:
board = puzzles[8][0]
query = Query(board, 1, Query.BOX, 0)

model = Model1(Query.vector_dim(board.max_digit) + board.max_digit, 128, [(board.max_digit**2), 3])
model.double()
train(model, board, query)

([-0.08874258567773603], [-0.024147171330729075])

In [16]:
max_digit = 4
model = Model1(Query.vector_dim(max_digit) + max_digit, 128, [(max_digit**2), 3])
model.double()

boards = puzzles[8]
np.random.shuffle(boards)
losses = []
for board in boards[:1000]:
    xs, ys = np.nonzero(board.board)
    queries = [Query(board, board[x][y], Query.ROW, x) for x,y in zip(xs, ys) if board[x][y] != 0]
    for query in queries:
        losses.append(train(model, board, query))

In [17]:
for a in losses:
    print(a)

([-0.09129058549729697, -0.07319013467700689], [-0.0024329090342329837, -0.00458922210786369])
([-0.08627286265817455, -0.08692785817406051], [-0.008047760423071459, -0.013029441903115488])
([-0.08621655694631328], [-0.012394654422669354])
([-0.08855192160191595], [-0.01931159196121497])
([-0.08897305571127292], [-0.016646359291392937])
([-0.08888267981650408], [-0.023762617030497202])
([-0.08644921569279693], [-0.02580570749783241])
([-0.0866652866934165], [-0.030246554967065662])
([-0.09185245338852221], [-0.029446135647696953])
([-0.09168467684218642], [-0.0278455574469386])
([-0.08632038205457303], [-0.03342139625147553])
([-0.08643589965278273], [-0.03543086419484732])
([-0.08910729234254239], [-0.04254214315781272])
([-0.08687047607438615], [-0.042078847086437686])
([-0.0871015324561315], [-0.04924344769655265])
([-0.08731847397012472], [-0.05353736803740531])
([-0.09166061113919384], [-0.0461673607015183])
([-0.0923246557805334], [-0.05527077410371876])
([-0.09257183064907482], 

In [150]:
model[]

array([[0, 2, 0, 0],
       [3, 4, 0, 0],
       [0, 3, 2, 1],
       [2, 0, 0, 3]], dtype=int8)

In [146]:
queries[3].board.find_digit_in_row(2, 3)

In [144]:
queries[3].house_index

2

In [145]:
queries[3].digit

3

In [149]:
np.where(queries[3].board.board[2] == 3)[0]

array([1])

In [151]:
queries[3].board.board[2]

array([0, 3, 2, 1], dtype=int8)

In [152]:
queries[3].board.find_digit_in_column(2, 3)

(2, 1, 3)

In [153]:
board[2, 1]

3

In [None]:
# def train(model, board, query):
    
    
#     attention, answer = None, None
#     attention_losses, answer_losses = [], []
    
#     def closure():
#         nonlocal attention, answer
#         optimizer.zero_grad()
#         attention, answer = model(input)
#         attention_loss = nn.functional.kl_div(attention, attention_actual)
#         answer_loss = nn.functional.kl_div(answer, answer_actual)
#         loss = attention_loss + answer_loss
#         loss.backward(retain_graph=True)
        
#         attention_losses.append(float(attention_loss))
#         answer_losses.append(float(answer_loss))
#         return loss
    
#     model.reset()

#     relevant_unseen = set(query.relevant_cells())
#     seen_cells = set()
#     query_vector = query.vectorize()
#     cell_vector = torch.zeros(board.max_digit)
#     last_answer = Response.MAYBE

#     while last_answer == Response.MAYBE and relevant_unseen:
#         input = torch.cat([query_vector, cell_vector])

#         attention_actual = vectorize_cell_distribution(board.max_digit, relevant_unseen)
#         answer_actual = query.answer() if query.is_answerable(seen_cells) else Response.MAYBE
#         answer_actual = Response.answer_vector(answer_actual)

#         optimizer.step(closure)
#         response = Response(board.max_digit, attention, answer)

#         next_cell = (response.x, response.y)
#         if next_cell in seen_cells: # Should never re-visit a cell in this task
#             break

#         seen_cells.add(next_cell)
#         relevant_unseen.discard(next_cell)
#         last_answer = response.answer
#         cell_vector = vectorize_cell(board, response.x, response.y)
    
#     return attention_losses, answer_losses