In [None]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import itertools
import random
import matplotlib
import matplotlib.pyplot as plt
from tqdm import tqdm, tqdm_notebook

import sys
sys.path.append('/Users/andrew/Desktop/sudoku/src/sudoku')

from board import Board
from grid_string import GridString, read_solutions_file
from shuffler import Shuffler
from shuffled_grid import ShuffledGrid
from solutions import Solutions
import utils

In [None]:
# set random seed to 0
np.random.seed(0)
torch.manual_seed(0)
torch.set_default_tensor_type('torch.DoubleTensor')

In [17]:
filename = '/Users/andrew/Desktop/sudoku/data/shuffled_puzzles.txt'
with open(filename) as f:
    lines = f.read().splitlines()
puzzles = {}
for line in lines:
    puzzle, solution = line.split(',')
    puzzles[GridString(puzzle)] = GridString(solution)

In [303]:
def determine_edges(dim_x, dim_y):
    """
    Returns a 2-d array of (max_digit**2, n) where the i_th entry is a list of
        other cells' indices that cell i shares a house with
    """
    max_digit = dim_x*dim_y
    edges = []
    for row in range(max_digit):
        row_edges = []
        for col in range(max_digit):
            # row & column
            col_edges = {(row, i) for i in range(max_digit)}
            col_edges |= {(i, col) for i in range(max_digit)}
            
            # box
            x_min = (row // dim_x) * dim_x
            y_min = (col // dim_y) * dim_y
            col_edges |= set(itertools.product(range(x_min, x_min+dim_x), range(y_min, y_min+dim_y)))
            
            # removing self
            col_edges -= {(row, col)}
            col_edges = [row*max_digit + col for row, col in col_edges]
            row_edges.append(sorted(col_edges))
        edges.append(row_edges)
    edges = torch.tensor(edges)
    shape = edges.shape
    return edges.reshape(max_digit**2, shape[2])

def encode_input(grid_string: GridString):
    return torch.tensor(list(grid_string.traverse_grid()))

def encode_output(grid_string: GridString):
    return torch.tensor(list(grid_string.traverse_grid())) - 1

In [238]:
train_n = 10
train_puzzles = list(puzzles.keys())[0:train_n]
train_solutions = [puzzles[p] for p in train_puzzles]

In [250]:
max_digit = train_puzzles[0].max_digit
num_cells = max_digit**2
cell_vec_dim = max_digit + 1
train_x = torch.cat([encode_input(p) for p in train_puzzles]).reshape(train_n, num_cells)
train_y = torch.cat([encode_output(p) for p in train_solutions]).reshape(train_n, num_cells)

In [460]:
class MLP(nn.Module):
    def __init__(self, layer_sizes):
        super(MLP, self).__init__()
        self.layer_sizes = layer_sizes
        
        self.layers = nn.ModuleList()
        
        prev_layer_size = self.layer_sizes[0]
        for size in self.layer_sizes[1:]:
            self.layers.append(nn.Linear(prev_layer_size, size))
            prev_layer_size = size

    def forward(self, X):
        vector = X
        for layer in self.layers:
            vector = layer(vector)
        return vector

class RRN(nn.Module):
    def __init__(self, dim_x, dim_y, embed_size=16, hidden_layer_size=96):
        super(RRN, self).__init__()
        self.max_digit = dim_x * dim_y
        self.embed_size = embed_size
        self.hidden_layer_size = hidden_layer_size
        
        self.edges = determine_edges(dim_x, dim_y)


        self.embed_layer = nn.Embedding(self.max_digit+1, self.embed_size)
        self.input_mlp = MLP([self.embed_size,
                              self.hidden_layer_size,
                              self.hidden_layer_size,
                              self.hidden_layer_size])
        
        self.f = MLP([2*self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.hidden_layer_size])
        self.g_mlp = MLP([2*self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.hidden_layer_size])
        self.g_lstm = nn.LSTM(self.hidden_layer_size, self.hidden_layer_size)
        self.r = MLP([self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.hidden_layer_size,
                      self.max_digit])
    
    def compute_messages(self, H):
        messages = torch.zeros(H.shape)
        batch_size = H.shape[0]
        num_nodes = H.shape[1]
        for puzzle_index in range(batch_size): # for puzzle in batch
            messages[puzzle_index] = torch.tensor([torch.sum(H[puzzle_index][self.edges[n]]) for n in range(num_nodes)])
        return messages
                    

    def forward(self, grids, iters):
        batch_size = len(grids)
        num_nodes = self.max_digit**2
        edges_per_nodes = self.edges.shape[1]
        
        
        
        embeddings = self.embed_layer(grids)
        X = self.input_mlp(embeddings)
        H = torch.tensor(X)
        g_lstm_h = H.reshape(1, batch_size*num_nodes, self.hidden_layer_size)
        g_lstm_c = torch.randn(1, batch_size*num_nodes, self.hidden_layer_size)
#         g_lstm_h = torch.zeros(1, batch_size, 3)
#         g_lstm_c = torch.zeros(1, batch_size, 3)
        
        outputs = []
        for i in range(iters):
            M = torch.zeros(batch_size, self.max_digit**2, self.hidden_layer_size)
            for node in range(num_nodes):
                msgs = torch.cat([self.f(torch.cat([H[:,node,:], H[:,other,:]], dim=1)) for other in self.edges[node]])
                msgs = msgs.reshape(edges_per_nodes, batch_size, self.hidden_layer_size).permute(1,0,2)
                M[:,n,:] = torch.sum(msgs, dim=1)
            
            input_to_g_lstm = self.g_mlp(torch.cat([X, M], dim=2)).reshape(1, batch_size*num_nodes, self.hidden_layer_size)
            
            _, (g_lstm_h, g_lstm_c) = self.g_lstm(input_to_g_lstm, (g_lstm_h, g_lstm_c))
            H = g_lstm_h.reshape(H.shape)
            output = self.r(H)
            
            outputs.append(output)
                
        return outputs

In [463]:
model = RRN( dim_x=2, dim_y=2, embed_size=16, hidden_layer_size=96)
predictions = [p.permute(0,2,1) for p in model(train_x, 5)]

In [468]:
model = RRN( dim_x=2, dim_y=2, embed_size=4, hidden_layer_size=32)
optimizer = optim.Adam(model.parameters())

def closure():
    optimizer.zero_grad()
    predictions = [p.permute(0,2,1) for p in model(train_x, 32)]
#     print([F.cross_entropy(p, train_y) for p in predictions])
    loss = sum([F.cross_entropy(p, train_y) for p in predictions])
    loss.backward()
    return loss

for i in tqdm_notebook(range(100)):
    print(optimizer.step(closure))

tensor(44.4395, grad_fn=<ThAddBackward>)
tensor(44.4050, grad_fn=<ThAddBackward>)
tensor(44.3720, grad_fn=<ThAddBackward>)
tensor(44.3317, grad_fn=<ThAddBackward>)
tensor(44.2935, grad_fn=<ThAddBackward>)
tensor(44.2646, grad_fn=<ThAddBackward>)
tensor(44.2286, grad_fn=<ThAddBackward>)
tensor(44.1630, grad_fn=<ThAddBackward>)
tensor(44.1055, grad_fn=<ThAddBackward>)
tensor(44.0290, grad_fn=<ThAddBackward>)
tensor(43.9269, grad_fn=<ThAddBackward>)
tensor(43.7616, grad_fn=<ThAddBackward>)
tensor(43.6199, grad_fn=<ThAddBackward>)
tensor(43.3903, grad_fn=<ThAddBackward>)
tensor(43.1125, grad_fn=<ThAddBackward>)
tensor(42.7898, grad_fn=<ThAddBackward>)
tensor(42.3937, grad_fn=<ThAddBackward>)
tensor(41.8639, grad_fn=<ThAddBackward>)
tensor(41.3130, grad_fn=<ThAddBackward>)
tensor(40.6224, grad_fn=<ThAddBackward>)
tensor(39.8641, grad_fn=<ThAddBackward>)
tensor(39.0155, grad_fn=<ThAddBackward>)
tensor(38.1631, grad_fn=<ThAddBackward>)
tensor(37.2379, grad_fn=<ThAddBackward>)
tensor(36.2883, 

In [None]:
for cell in determine_edges(2, 2):
    print(cell)

In [272]:
set(itertools.product(range(3, 6), range(1, 4)))

{(3, 1), (3, 2), (3, 3), (4, 1), (4, 2), (4, 3), (5, 1), (5, 2), (5, 3)}

In [293]:
train_solutions[0]

1234341223414123

In [305]:
torch.tensor(range(2*3*4)).reshape(2,3,4)

tensor([[[ 0,  1,  2,  3],
         [ 4,  5,  6,  7],
         [ 8,  9, 10, 11]],

        [[12, 13, 14, 15],
         [16, 17, 18, 19],
         [20, 21, 22, 23]]])

In [354]:
m = torch.tensor(range(2*3*4)).reshape(2,3,4)
z = torch.zeros(m.shape)


# print(m[0])
# print(m[0][0])

# m = torch.cat([m]*2)
# m = m.reshape(2,2,3,4)
# m = m.permute(1,2,0,3)
# print(m.shape)

tensor([[ 0,  1,  2,  3],
        [12, 13, 14, 15]])

In [347]:
m[:,1,:]

tensor([[ 4,  5,  6,  7],
        [16, 17, 18, 19]])

In [350]:
torch.cat([m[:,0,:], m[:,1,:]], dim=1)

tensor([[ 0,  1,  2,  3,  4,  5,  6,  7],
        [12, 13, 14, 15, 16, 17, 18, 19]])

In [351]:
M = torch.zeros(batch_size, self.max_digit**2, self.hidden_layer_size)
for node in range(num_nodes):
    M[:,n,:] = torch.sum([self.f(torch.cat([H[:,node,:], H[:,other,:]], dim=1)) for other in self.edges[node]])


In [404]:
m = torch.tensor(range(3*4*5)).reshape(3,4,5)
e = torch.tensor([[0,1], [1,2], [2,3], [3,0]])
z = torch.zeros(3,4,5)
for n in range(4):
    # reshape()
    a = torch.cat([m[:,other] for other in e[n]]).reshape(2, 3, 5).permute(1,0,2)
    z[:,n,:] = torch.sum(a, dim=1)
    
m

tensor([[[ 0,  1,  2,  3,  4],
         [ 5,  6,  7,  8,  9],
         [10, 11, 12, 13, 14],
         [15, 16, 17, 18, 19]],

        [[20, 21, 22, 23, 24],
         [25, 26, 27, 28, 29],
         [30, 31, 32, 33, 34],
         [35, 36, 37, 38, 39]],

        [[40, 41, 42, 43, 44],
         [45, 46, 47, 48, 49],
         [50, 51, 52, 53, 54],
         [55, 56, 57, 58, 59]]])

In [402]:
z

tensor([[[  5.,   7.,   9.,  11.,  13.],
         [ 15.,  17.,  19.,  21.,  23.],
         [ 25.,  27.,  29.,  31.,  33.],
         [ 15.,  17.,  19.,  21.,  23.]],

        [[ 45.,  47.,  49.,  51.,  53.],
         [ 55.,  57.,  59.,  61.,  63.],
         [ 65.,  67.,  69.,  71.,  73.],
         [ 55.,  57.,  59.,  61.,  63.]],

        [[ 85.,  87.,  89.,  91.,  93.],
         [ 95.,  97.,  99., 101., 103.],
         [105., 107., 109., 111., 113.],
         [ 95.,  97.,  99., 101., 103.]]])

In [394]:
m[:,0,:]

tensor([[ 0,  1,  2,  3],
        [12, 13, 14, 15]])