# Импорт

In [1]:
import numpy as np
import pandas as pd

from matplotlib import pyplot as plt
from keras import Model, Sequential
from keras.callbacks import EarlyStopping
from keras.layers import Dense, Dropout, Flatten, Input
from tensorflow.keras.utils import to_categorical

# Методы:

### Загрузка данных игр в судоку из таблицы

In [2]:
def load_data(nb_train=50000, nb_test=320, full=False):

    if full:
        sudokus = pd.read_csv('sudoku.csv').values
    else:
        sudokus = next(
            pd.read_csv('sudoku.csv', chunksize=(nb_train + nb_test))
        ).values
        
    quizzes, solutions = sudokus.T
    flatX = np.array([np.reshape([int(d) for d in flatten_grid], (9, 9))
                      for flatten_grid in quizzes])
    flaty = np.array([np.reshape([int(d) for d in flatten_grid], (9, 9))
                      for flatten_grid in solutions])
    
    return (flatX[:nb_train], flaty[:nb_train]), (flatX[nb_train:], flaty[nb_train:])

### Рассчет точности решений

In [3]:
def diff(grids_true, grids_pred):
    return (grids_true != grids_pred).sum((1, 2))

### Удаление случайных ячеек в решенных играх

In [4]:
def delete_digits(X, n_delet=1):
    grids = X.argmax(3)
    for grid in grids:
        grid.flat[np.random.randint(0, 81, n_delet)] = 0
        
    return to_categorical(grids)

### Поиск решения

In [5]:
def batch_smart_solve(grids, solver):
    grids = grids.copy()
    for _ in range((grids == 0).sum((1, 2)).max()):
        preds = np.array(solver.predict(to_categorical(grids)))
        probs = preds.max(2).T
        values = preds.argmax(2).T + 1
        zeros = (grids == 0).reshape((grids.shape[0], 81))

        for grid, prob, value, zero in zip(grids, probs, values, zeros):
            if any(zero):
                where = np.where(zero)[0]
                confidence_position = where[prob[zero].argmax()]
                confidence_value = value[confidence_position]
                grid.flat[confidence_position] = confidence_value
    return grids

# Подготовка данных к обучению

In [17]:
load_data(nb_train=50000, nb_test=320, full=False)

((array([[[0, 0, 4, ..., 2, 0, 9],
          [0, 0, 5, ..., 0, 0, 1],
          [0, 7, 0, ..., 0, 4, 3],
          ...,
          [6, 0, 0, ..., 1, 0, 5],
          [0, 0, 3, ..., 6, 9, 0],
          [0, 4, 2, ..., 3, 0, 0]],
  
         [[0, 4, 0, ..., 0, 5, 0],
          [1, 0, 7, ..., 9, 6, 0],
          [5, 2, 0, ..., 0, 0, 0],
          ...,
          [0, 9, 0, ..., 5, 4, 3],
          [6, 0, 0, ..., 7, 0, 0],
          [2, 5, 0, ..., 1, 0, 0]],
  
         [[6, 0, 0, ..., 3, 8, 4],
          [0, 0, 8, ..., 0, 7, 2],
          [0, 0, 0, ..., 0, 0, 5],
          ...,
          [3, 1, 0, ..., 0, 5, 0],
          [0, 8, 9, ..., 0, 0, 0],
          [5, 0, 2, ..., 1, 9, 0]],
  
         ...,
  
         [[0, 0, 8, ..., 0, 0, 7],
          [9, 0, 0, ..., 1, 0, 0],
          [0, 1, 0, ..., 4, 0, 0],
          ...,
          [0, 0, 0, ..., 9, 3, 0],
          [5, 9, 0, ..., 0, 0, 0],
          [0, 0, 3, ..., 6, 2, 1]],
  
         [[0, 0, 9, ..., 0, 0, 0],
          [4, 0, 0, ..., 7, 0, 8

In [18]:
input_shape = (9, 9, 10)
(_, ytrain), (Xtest, ytest) = load_data()

Xtrain = to_categorical(ytrain).astype('float32')
Xtest = to_categorical(Xtest).astype('float32')

ytrain = to_categorical(ytrain-1).astype('float32')
ytest = to_categorical(ytest-1).astype('float32')

# Определить модель keras

In [19]:
model = Sequential()
model.add(Dense(64, activation='relu', input_shape=input_shape))
model.add(Dropout(0.4))
model.add(Dense(64, activation='relu'))
model.add(Dropout(0.4))
model.add(Flatten())

grid = Input(shape=input_shape)
features = model(grid)

digit_placeholders = [
    Dense(9, activation='softmax')(features)
    for i in range(81)
]

solver = Model(grid, digit_placeholders)
solver.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy']
)

# Итерационное обучение модели на тестовых данных

In [20]:
history = solver.fit(
    delete_digits(Xtrain, 0),
    [ytrain[:, i, j, :] for i in range(9) for j in range(9)],
    batch_size=64,
    epochs=1,
    verbose=1,
)



## Решение судоку со случайно удаленными значениями

In [21]:
early_stop = EarlyStopping(patience=2, verbose=1)

i = 1
for nb_epochs, nb_delete in zip(
        #[1, 2, 3, 4, 6, 8, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10, 10],
        [1, 1, 2, 3, 5, 5, 5],
        [1, 2, 5, 10, 20, 40, 55]
        #[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        #[1, 2, 3, 4, 6, 8, 10, 12, 15, 20, 25, 30, 35, 40, 45, 50, 55]
):
    print('Pass n° {} ...'.format(i))
    i += 1
    
    history = solver.fit(
        delete_digits(Xtrain, nb_delete),
        [ytrain[:, i, j, :] for i in range(9) for j in range(9)],
        validation_data=(
            delete_digits(Xtrain, nb_delete),
            [ytrain[:, i, j, :] for i in range(9) for j in range(9)]),
        batch_size=64,
        epochs=nb_epochs,
        verbose=1,
        callbacks=[early_stop]
    )

Pass n° 1 ...


Pass n° 2 ...


Pass n° 3 ...
Epoch 1/2


Epoch 2/2


Pass n° 4 ...
Epoch 1/3


Epoch 2/3














Epoch 3/3
























Pass n° 5 ...
Epoch 1/5


Epoch 2/5


Epoch 3/5












Epoch 4/5


















Epoch 00004: early stopping
Pass n° 6 ...
Epoch 1/5










































Epoch 2/5
































































Epoch 3/5


 18/157 [==>...........................] - ETA: 27s - loss: 17.0519 - dense_85_loss: 0.1918 - dense_86_loss: 0.2425 - dense_87_loss: 0.2045 - dense_88_loss: 0.2155 - dense_89_loss: 0.1909 - dense_90_loss: 0.1867 - dense_91_loss: 0.2223 - dense_92_loss: 0.2033 - dense_93_loss: 0.2307 - dense_94_loss: 0.1927 - dense_95_loss: 0.1803 - dense_96_loss: 0.2224 - dense_97_loss: 0.2038 - dense_98_loss: 0.2196 - dense_99_loss: 0.2297 - dense_100_loss: 0.2168 - dense_101_loss: 0.1912 - dense_102_loss: 0.2056 - dense_103_loss: 0.1895 - dense_104_loss: 0.2005 - dense_105_loss: 0.2194 - dense_106_loss: 0.2105 - dense_107_loss: 0.2066 - dense_108_loss: 0.1951 - dense_109_loss: 0.2027 - dense_110_loss: 0.2143 - dense_111_loss: 0.1914 - dense_112_loss: 0.2109 - dense_113_loss: 0.1956 - dense_114_loss: 0.2043 - dense_115_loss: 0.2037 - dense_116_loss: 0.2211 - dense_117_loss: 0.2299 - dense_118_loss: 0.2083 - dense_119_loss: 0.2153 - dense_120_loss: 0.2135 - dense_121_loss: 0.2037 - dense_122_loss: 0.20





































Epoch 4/5


Epoch 00004: early stopping
Pass n° 7 ...
Epoch 1/5


Epoch 2/5


Epoch 3/5


















Epoch 4/5


Epoch 00004: early stopping


# Результаты

In [None]:
quizzes = Xtest.argmax(3)
true_grids = ytest.argmax(3) + 1
smart_guesses = batch_smart_solve(quizzes, solver)

deltas = diff(true_grids, smart_guesses)
accuracy = (deltas == 0).mean()

In [None]:
print(
"""
Решено:\t {}
Корректных:\t {}
Точность:\t {}
""".format(
deltas.shape[0], (deltas==0).sum(), accuracy
)
)

In [None]:
plt.plot(history.history['acc'])
plt.plot(history.history['val_acc'])
plt.title('model accuracy')
plt.ylabel('accuracy')
plt.xlabel('epoch')
plt.legend(['train', 'val'], loc='upper left')
plt.show()