In [1]:
%matplotlib inline

import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
import numpy as np
import itertools
import random
import matplotlib
import matplotlib.pyplot as plt
import pickle
from tqdm import tqdm, tqdm_notebook

import sys
sys.path.append('/home/ajhnam/sudoku/src/sudoku')

from board import Board
from solutions import Solutions
import solvertools
import utils

In [2]:
# set random seed to 0
np.random.seed(0)
torch.manual_seed(0)

torch.set_default_tensor_type('torch.DoubleTensor')

In [3]:
def roll(seq, n):
    return torch.cat((seq[-n:], seq[:-n]))

def hstack(a, b):
    return torch.cat([a, b], dim=1)

def round_robin_concat(X):
    """
    Concatenates each row i of X with another row j where i != j. e.g.
    >>> mat, indices = round_robin_concat(torch.stack([torch.ones(4)*i for i in range(4)]))
    >>> mat
    tensor([[0., 0., 0., 0., 3., 3., 3., 3.],
            [1., 1., 1., 1., 0., 0., 0., 0.],
            [2., 2., 2., 2., 1., 1., 1., 1.],
            [3., 3., 3., 3., 2., 2., 2., 2.],
            [0., 0., 0., 0., 2., 2., 2., 2.],
            [1., 1., 1., 1., 3., 3., 3., 3.],
            [2., 2., 2., 2., 0., 0., 0., 0.],
            [3., 3., 3., 3., 1., 1., 1., 1.],
            [0., 0., 0., 0., 1., 1., 1., 1.],
            [1., 1., 1., 1., 2., 2., 2., 2.],
            [2., 2., 2., 2., 3., 3., 3., 3.],
            [3., 3., 3., 3., 0., 0., 0., 0.]])
    >>> mat[indices[(1, 3)]]
    tensor([1., 1., 1., 1., 3., 3., 3., 3.])
    >>> mat[indices[(3, 1)]]
    tensor([3., 3., 3., 3., 1., 1., 1., 1.])
    """
    concat_index = {}
    concat_matrix = []
    n = X.shape[0]
    dim = X.shape[1]
    
    rolled = 1
    last_index = 0
    while len(concat_index) < n*(n-1):
        concat_matrix.append(torch.stack((X, roll(X, rolled)), dim=1).reshape(n, dim*2))
        for i in range(n):
            concat_index[(i, (i-rolled)%n)] = last_index
            last_index += 1
        rolled += 1
    
    concat_matrix = torch.stack(concat_matrix)
    concat_matrix = concat_matrix.reshape(concat_matrix.shape[0]*concat_matrix.shape[1], concat_matrix.shape[2])
    return concat_matrix, concat_index

def board_vec2mat(vectors):
    n, d = vectors.shape[0], int(np.cbrt(vectors.shape[1]))
    return vectors.reshape(n, d**2, d)

In [4]:
with open('solutions5.pickle', 'rb') as f:
    solutions = pickle.load(f)
puzzles = solutions.get_puzzles_by_hints()
for hints in sorted(puzzles):
    print(hints, len(puzzles[hints]))

4 64
5 357
6 883
7 1584
8 2384
9 3309
10 4149
11 4754
12 4841
13 3741
14 1391
15 192
16 12


In [5]:
class MLP(nn.Module):
    def __init__(self, input_size, output_size, hidden_layer_sizes, activation_func):
        super(MLP, self).__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation_func = activation_func
        
        last_layer_size = input_size
        self.layers = nn.ModuleList() 
        for size in hidden_layer_sizes:
            self.layers.append(nn.Linear(last_layer_size, size).cuda())
            last_layer_size = size
        self.layers.append(nn.Linear(last_layer_size, output_size).cuda())

    def forward(self, X):
        vector = X
        for layer in self.layers:
            vector = self.activation_func(layer(vector))
        return vector

In [6]:
class FeedForwardRRN(nn.Module):
    def __init__(self, max_digit, message_size, hidden_layer_sizes, activation_func):
        super(FeedForwardRRN, self).__init__()
        assert len(hidden_layer_sizes) > 0
        self.max_digit = max_digit
        self.message_size = message_size
        self.hidden_layer_sizes = hidden_layer_sizes
        self.activation_func = activation_func
        
        self.f = MLP(2*max_digit, message_size, hidden_layer_sizes, activation_func).cuda()
        self.g = MLP(max_digit + message_size, max_digit, hidden_layer_sizes, activation_func).cuda()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, X):
        """
        Assume that X is an (n, d) matrix where
            n: number of cells in the board (self.max_digit**2)
            d: max_digit
        """
        assert X.shape == (self.max_digit**2, self.max_digit)
        X_paired, pair_indices = round_robin_concat(X)
        pair_messages = self.f(X_paired.cuda())
        messages = []
        for i in range(len(X)):
            message = torch.zeros(self.message_size).cuda()
            for j in range(self.max_digit):
                if i != j:
                    message += pair_messages[pair_indices[(i, j)]].cuda()
            messages.append(message)
        messages = torch.stack(messages)
        pre_output = self.g(hstack(X, messages))
        output = self.softmax(pre_output)
        return output

In [None]:
epochs = 200
rounds = 12

max_digit = 4
model = FeedForwardRRN(max_digit, 64, [32, 32], F.relu)
model = model.double().cuda()
optimizer = optim.Adam(model.parameters())

train, valid_deriv, test_deriv, valid_nonderiv, test_nonderiv = solvertools.generate_dataset(puzzles[4], solutions, [.7, .8])
X, Y = solvertools.generate_XY(train)
X = board_vec2mat(X).cuda()
Y = Y.cuda()

def closure():
    losses = []
    for i in tqdm_notebook(range(len(X)), leave=False):
        total_loss = 0
        x, y = X[i], Y[i]
        optimizer.zero_grad()
        for j in range(rounds):#tqdm_notebook(range(rounds), leave=False):
            prediction = model(x)
            x = prediction
            loss = nn.functional.nll_loss(prediction, y)
            loss.backward(retain_graph=True)
            total_loss += float(loss)
        losses.append(total_loss)
    print(sum(losses))
    return losses

# train_losses = []


for epoch in tqdm_notebook(range(epochs), leave=False):
    train_losses = optimizer.step(closure)
    
plt.plot(range(epochs), train_losses, '-b', label='training')
plt.title("Losses")
plt.legend(loc='upper right')
plt.show()

HBox(children=(IntProgress(value=0, max=200), HTML(value='')))

HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.104336371031


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.2474832557405


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.354013356788


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.456539914503


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.5721468116476


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.7005952713753


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.8338656999913


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2217.966403142499


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2218.1054942260866


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2218.2436132356843


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2218.378737474827


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2218.5088465970616


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2218.64290922255


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)



-2219.1114490511195


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

-2219.303018916415


HBox(children=(IntProgress(value=0, max=739), HTML(value='')))

In [None]:
print(hi)