In [9]:
import numpy as np
from scipy.special import factorial
import re
import itertools
from enum import Enum
import random
import pickle

import sys
sys.path.append('/Users/andrew/Desktop/sudoku/src/sudoku')

from board import Board
from solutions import Solutions
import utils
from grid_string import GridString, read_solutions_file


In [5]:
# Seed both random number generators from one named constant so the stdlib
# and NumPy generators cannot drift apart, and so the notebook's shuffles are
# reproducible after Restart & Run All.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)

In [6]:
# Read the puzzles file and build a mapping from each puzzle to its solution.
# NOTE(review): this is an absolute, machine-specific path; the cell cannot run
# on another machine without editing it.
filename = '/Users/andrew/Desktop/sudoku/data/shuffled_puzzles.txt'
with open(filename) as f:
    lines = f.read().splitlines()
# Both halves of every line are wrapped in GridString before being stored.
puzzles = {}
for line in lines:
    # The two-name unpacking requires exactly one comma per line; a blank line
    # or a line with extra commas raises ValueError here.
    puzzle, solution = line.split(',')
    puzzles[GridString(puzzle)] = GridString(solution)

In [8]:
class Dataset:
    """
    A dictionary of input -> target pairs partitioned into consecutive splits.

    The keys of ``data`` are shuffled once and cut at the fractional positions
    given by ``split_boundaries``; e.g. boundaries (0.8, 0.9) produce three
    splits holding roughly 80% / 10% / 10% of the keys.
    """

    def __init__(self, data: dict, split_boundaries: tuple, split_data=None):
        """
        data: a non-nested dictionary of primitive key-value pairs where
            the key is the input and
            value is the target output
        split_boundaries: ascending fractions strictly between 0 and 1
            (a tuple or a list); n boundaries produce n + 1 splits
        split_data: optional precomputed list of key lists, one per split;
            when omitted (None) a freshly shuffled split is created

        Raises AssertionError if the boundaries are out of range or unsorted,
        or if there are not more data items than boundaries.
        """
        # Normalise to a list: sorted() always returns a list, so comparing a
        # tuple against it directly would never be equal.
        boundaries = list(split_boundaries)
        assert all(0 < b < 1 for b in boundaries), \
            'split boundaries must lie strictly between 0 and 1'
        assert boundaries == sorted(boundaries), \
            'split boundaries must be in ascending order'
        assert len(data) > len(boundaries), \
            'need more data items than split boundaries'

        self.data = data
        self.split_boundaries = split_boundaries
        # `is not None` so that an explicitly supplied split is always kept.
        if split_data is not None:
            self.split_data = split_data
        else:
            self.split_data = self.create_split_data()

    def create_split_data(self):
        """
        Shuffle the keys of self.data (using NumPy's global random state) and
        cut them into consecutive chunks at the split boundaries.

        Returns a list of lists of keys, one inner list per split.
        """
        inputs = list(self.data)
        np.random.shuffle(inputs)

        split = []
        last_boundary = 0
        # The trailing 1 closes the final split at the end of the key list.
        for boundary in list(self.split_boundaries) + [1]:
            # int() truncates, so each cut position is rounded down.
            next_boundary = int(len(inputs) * boundary)
            split.append(inputs[last_boundary:next_boundary])
            last_boundary = next_boundary
        return split

    def get_input_data(self, index=None):
        """
        Return the keys of split `index`, or a (shallow-copied) list of all
        splits when index is None.
        """
        if index is None:
            return list(self.split_data)
        return self.split_data[index]

    def get_output_data(self, index=None):
        """
        Return the target values of split `index`, in the same order as its
        keys, or a list with the targets of every split when index is None.
        """
        if index is None:
            return [[self.data[k] for k in inputs] for inputs in self.split_data]
        return [self.data[k] for k in self.split_data[index]]

    def save(self, filename):
        """Pickle the data, the boundaries and the current split to `filename`."""
        with open(filename, 'wb') as f:
            pickle.dump({'data': self.data,
                         'split_boundaries': self.split_boundaries,
                         'split_data': self.split_data}, f)

    @staticmethod
    def load(filename):
        """
        Rebuild a Dataset from a file written by save(); the stored split is
        reused rather than reshuffled.
        """
        # SECURITY: pickle.load can execute arbitrary code; only load files
        # that come from a trusted source.
        with open(filename, 'rb') as f:
            raw = pickle.load(f)
        return Dataset(raw['data'], raw['split_boundaries'], raw['split_data'])