Notebook created to build and evaluate a RNN model for solving sudoku problems.
Be sure to get the dataset from https://www.kaggle.com/datasets/radcliffe/3-million-sudoku-puzzles-with-ratings

In [17]:
# Different Imports needed
import torch.utils.data as data
import torch
import pandas as pd
import torch.nn as nn
import torch.optim as optim

In [18]:
# Function to create Sudoku datasets from a DataFrame
def create_sudoku_tensors(df, train_split=0.5):
    # Total number of samples in the DataFrame
    s = df.shape[0]

    # Function to one-hot encode a Sudoku puzzle
    def one_hot_encode(s):
        zeros = torch.zeros((1, 81, 9), dtype=torch.float)
        for a in range(81):
            digit = s[a]
            # Check if the character is a digit
            if digit.isdigit():
                # One-hot encode the digit
                zeros[0, a, int(digit) - 1] = 1 if int(digit) > 0 else 0
        return zeros

    # Apply one-hot encoding to puzzle and solution columns
    quizzes_t = df.puzzle.apply(one_hot_encode)
    solutions_t = df.solution.apply(one_hot_encode)

    # Concatenate tensors and split into training and test sets
    quizzes_t = torch.cat(quizzes_t.values.tolist())
    solutions_t = torch.cat(solutions_t.values.tolist())
    randperm = torch.randperm(s)
    train = randperm[:int(train_split * s)]
    test = randperm[int(train_split * s):]

    return data.TensorDataset(quizzes_t[train], solutions_t[train]), \
           data.TensorDataset(quizzes_t[test], solutions_t[test])


# Function to create a constraint mask for Sudoku puzzles
def create_constraint_mask():
    constraint_mask = torch.zeros((81, 3, 81), dtype=torch.float)

    # Row constraints
    for a in range(81):
        r = 9 * (a // 9)
        for b in range(9):
            constraint_mask[a, 0, r + b] = 1

    # Column constraints
    for a in range(81):
        c = a % 9
        for b in range(9):
            constraint_mask[a, 1, c + 9 * b] = 1

    # Box constraints
    for a in range(81):
        r = a // 9
        c = a % 9
        br = 3 * 9 * (r // 3)
        bc = 3 * (c // 3)
        for b in range(9):
            r = b % 3
            c = 9 * (b // 3)
            constraint_mask[a, 2, br + bc + r + c] = 1

    return constraint_mask


# Function to load the dataset
def load_dataset(subsample=10000):
    dataset = pd.read_csv("sudoku-3m.csv", sep=',')
    my_sample = dataset.sample(subsample)
    train_set, test_set = create_sudoku_tensors(my_sample)
    return train_set, test_set

In [19]:
# Define the SudokuSolver model
class SudokuSolver(nn.Module):
    def __init__(self, constraint_mask, n=9, hidden1=100):
        super(SudokuSolver, self).__init__()
        self.constraint_mask = constraint_mask.view(1, n * n, 3, n * n, 1)
        self.n = n
        self.hidden1 = hidden1

        # Feature vector is the 3 constraints
        self.input_size = 3 * n

        # Define the neural network layers
        self.l1 = nn.Linear(self.input_size, self.hidden1, bias=False)
        self.a1 = nn.ReLU()
        self.l2 = nn.Linear(self.hidden1, n, bias=False)
        self.softmax = nn.Softmax(dim=1)

    # Forward pass of the model
    def forward(self, x):
        n = self.n
        bts = x.shape[0]
        c = self.constraint_mask
        min_empty = (x.sum(dim=2) == 0).sum(dim=1).max()
        x_pred = x.clone()
        for a in range(min_empty):
            # Score empty numbers
            constraints = (x.view(bts, 1, 1, n * n, n) * c).sum(dim=3)
            # Empty cells
            empty_mask = (x.sum(dim=2) == 0)

            f = constraints.reshape(bts, n * n, 3 * n)
            y_ = self.l2(self.a1(self.l1(f[empty_mask])))

            s_ = self.softmax(y_)

            # Score the rows
            x_pred[empty_mask] = s_

            s = torch.zeros_like(x_pred)
            s[empty_mask] = s_
            # Find most probable guess
            score, score_pos = s.max(dim=2)
            mmax = score.max(dim=1)[1]
            # Fill it in
            nz = empty_mask.sum(dim=1).nonzero().view(-1)
            mmax_ = mmax[nz]
            ones = torch.ones(nz.shape[0])
            x.index_put_((nz, mmax_, score_pos[nz, mmax_]), ones)
        return x_pred, x

In [20]:
# Train the model

# Batch size for training
batch_size = 20

# Load the dataset
train_set, test_set = load_dataset()

# Create constraint mask
constraint_mask = create_constraint_mask()

# Create data loaders for training and validation
dataloader_ = data.DataLoader(train_set, batch_size=batch_size, shuffle=True)
dataloader_val_ = data.DataLoader(test_set, batch_size=batch_size, shuffle=True)

# Define loss function
loss = nn.MSELoss()

# Initialize the SudokuSolver model
sudoku_solver = SudokuSolver(constraint_mask)

# Define optimizer
optimizer = optim.Adam(sudoku_solver.parameters(), lr=0.01, weight_decay=0.000)

# Number of epochs for training
epochs = 50

# Lists to store training and validation losses
loss_train = []
loss_val = []

# Training loop
for e in range(epochs):
    total_correct_cells = 0
    total_cells = 0
    for i_batch, ts_ in enumerate(dataloader_):
        sudoku_solver.train()
        optimizer.zero_grad()
        pred, mat = sudoku_solver(ts_[0])
        ls = loss(pred, ts_[1])
        ls.backward()
        optimizer.step()
        loss_train.append(ls.item())
        
        # Calculate accuracy during training
        correct_cells = (pred.argmax(dim=2) == ts_[1].argmax(dim=2)).sum().item()
        total_correct_cells += correct_cells
        total_cells += ts_[1].size(0) * ts_[1].size(1)
        
    # Calculate validation accuracy after each epoch
    sudoku_solver.eval()
    with torch.no_grad():
        total_val_correct_cells = 0
        total_val_cells = 0
        for i_batch, ts_ in enumerate(dataloader_val_):
            test_pred, test_fill = sudoku_solver(ts_[0])
            val_correct_cells = (test_pred.argmax(dim=2) == ts_[1].argmax(dim=2)).sum().item()
            total_val_correct_cells += val_correct_cells
            total_val_cells += ts_[1].size(0) * ts_[1].size(1)
            
        accuracy_val = total_val_correct_cells / total_val_cells * 100
        loss_val.append(accuracy_val)
        
        print(f"Epoch {e}: Validation Accuracy: {accuracy_val}%")


Epoch 0: Validation Accuracy: 58.403950617283954%
Epoch 1: Validation Accuracy: 59.58172839506173%
Epoch 2: Validation Accuracy: 58.8279012345679%
Epoch 3: Validation Accuracy: 58.821728395061726%
Epoch 4: Validation Accuracy: 58.98493827160494%
Epoch 5: Validation Accuracy: 59.77185185185186%
Epoch 6: Validation Accuracy: 59.181975308641974%
Epoch 7: Validation Accuracy: 60.026172839506174%
Epoch 8: Validation Accuracy: 60.52864197530864%
Epoch 9: Validation Accuracy: 60.309382716049384%
Epoch 10: Validation Accuracy: 59.7879012345679%
Epoch 11: Validation Accuracy: 59.55802469135803%
Epoch 12: Validation Accuracy: 60.46691358024692%
Epoch 13: Validation Accuracy: 60.67086419753086%
Epoch 14: Validation Accuracy: 59.95456790123457%
Epoch 15: Validation Accuracy: 60.4641975308642%
Epoch 16: Validation Accuracy: 61.01753086419753%
Epoch 17: Validation Accuracy: 60.8232098765432%
Epoch 18: Validation Accuracy: 60.95901234567901%
Epoch 19: Validation Accuracy: 61.86962962962963%
Epoch 20:

In [16]:
torch.save(sudoku_solver.state_dict(), "rnn_model")