# Can a Graph Neural Network solve Sudoku?

The n-Sudoku graph is a graph with $n^4$ vertices, corresponding to the cells of an $n^2$ by $n^2$ grid. Two distinct vertices are adjacent if and only if they belong to the same row, column, or n-by-n box. In our case $n = 3$.

Solving a sudoku can be seen as a node classification problem over a graph: each blank cell (node) has to be mapped to one of the classes C = {1, ..., 9}.

## Dataset and preprocessing

To train and validate the model, only a subset of 100000 puzzles from the *9 million sudoku dataset* is used.
Each puzzle is first flattened into a vector and then each cell is one-hot-encoded.

To show the power of GNNs, only 20% of this subset is used to train the model.

## Architecture 

A graph neural network is defined by stacking several "aggregation layers", each of which transforms the representation of a given node into a new representation based on the node itself and its neighbours, as defined by the graph.
The aggregation layer used here is the classic self-attention layer, but a mask allows every node to attend only to its neighbours.

## Training 

The model is trained to classify all the nodes in the sudoku graph, not only the blank cells, using a cross-entropy loss, for a maximum of 20 epochs.

The model achieves over 90% accuracy on the validation set after only 2 epochs.

## Conclusion

After 20 epochs the model achieves over 97% accuracy, and there is still room for improvement.

In [None]:
import networkx as nx
import torch
from torch import nn
import numpy as np
import pandas as pd
from random import shuffle, sample
import pytorch_lightning as pl

In [None]:
# Number of puzzles taken from the start of the 9 million sudoku CSV.
num_examples = 100_000
# Fraction of the loaded puzzles assigned to the training set; the rest is
# used for validation.
# NOTE(review): the notebook introduction mentions training on 20% of the
# subset, but this value is 0.5 — confirm which one is intended.
train_split = 0.5

def encode(list_x):
    """One-hot encode a flattened sudoku puzzle.

    Args:
        list_x: the 81 cell values of a puzzle as integers in 0..9
            (0 presumably marks a blank cell — confirm against the dataset).

    Returns:
        A float tensor of shape (81, 10) whose row ``i`` has a single 1 in
        the column given by the value of cell ``i``.
    """
    digits = torch.LongTensor(list_x).reshape(-1)
    cells = torch.arange(81)

    one_hot = torch.zeros((81, 10), dtype=torch.float)
    # Advanced indexing: set entry (cell, digit) for every cell at once.
    one_hot[cells, digits] = 1
    return one_hot

print("loading")
# `nrows` reads exactly the first `num_examples` rows and closes the file
# afterwards (taking `next()` of a chunked reader leaves the reader open).
# The digit strings are read explicitly as text so that every character,
# including leading zeros, is preserved.
df = pd.read_csv(
    '../input/sudoku/sudoku.csv',
    nrows=num_examples,
    dtype={"puzzle": str, "solution": str},
)

# Two parallel arrays of 81-character digit strings.
puzzle, solution = df[["puzzle", "solution"]].values.T
print(f"total row: {df.shape[0]}")

print("creating graphs ...")
# Each sample is (one-hot puzzle of shape 81 x 10, target classes of shape 81).
# Solution digits 1..9 are shifted to class indices 0..8.
graphs = [
    (
        encode([int(d) for d in p]),
        torch.LongTensor([int(d) - 1 for d in s]),
    )
    for p, s in zip(puzzle, solution)
]

print("shuffling")
shuffle(graphs)

print("splitting")
# Split on the number of samples actually loaded, so the train/validation
# ratio stays correct even if the CSV has fewer than `num_examples` rows.
split_index = int(len(graphs) * train_split)
train_9m = graphs[:split_index]
val_9m = graphs[split_index:]
print(len(train_9m), len(val_9m))

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style PyTorch dataset backed by an in-memory list of (x, y) pairs."""

    def __init__(self, dataset):
        # `dataset` is a list of tuples (x, y); it is stored by reference.
        self.dataset = dataset

    def __len__(self):
        """Return the total number of samples."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return the input and the label stored at position `index`."""
        sample = self.dataset[index]
        return sample[0], sample[1]

In [None]:
# Wrap the two splits so that they can be fed to a DataLoader.
training_set = Dataset(dataset=train_9m)
valid_set = Dataset(dataset=val_9m)

In [None]:
class GNNTransformer(nn.Module):
    """Stack of masked self-attention layers over the sudoku graph.

    Every cell is a node; in each layer a node attends only to the nodes it
    is adjacent to in ``nx.sudoku_graph()``. The graph has no self-loops
    (only distinct vertices are adjacent), so a node never attends to
    itself: its own state reaches the next layer through the residual
    connection.

    Args:
        hidden_dim: width of the node embeddings and attention layers.
        num_heads: number of attention heads per layer.
        num_layers: number of stacked attention layers.
    """

    def __init__(self, hidden_dim=128, num_heads=4, num_layers=8):
        super().__init__()

        G = nx.sudoku_graph()
        # If a boolean mask is provided, positions with True are not allowed
        # to attend while False values are left unchanged, so the mask is
        # True exactly where two nodes are NOT adjacent.
        not_adjacent = nx.to_numpy_array(G) == 0
        self.register_buffer("A", torch.from_numpy(not_adjacent))

        # node embedding to higher dimension
        self.embedding = nn.Linear(10, hidden_dim)

        # aggregation layers
        self.transformers = nn.ModuleList(
            [nn.MultiheadAttention(hidden_dim, num_heads) for _ in range(num_layers)]
        )
        self.pre_norm_layers = nn.ModuleList(
            [nn.LayerNorm(hidden_dim) for _ in range(num_layers)]
        )
        self.post_norm_layers = nn.ModuleList(
            [nn.LayerNorm(hidden_dim) for _ in range(num_layers)]
        )

        # mlp head producing one logit per digit class
        self.mlp = nn.Sequential(
            nn.Linear(hidden_dim, hidden_dim * 2),
            nn.ReLU(),
            nn.Linear(hidden_dim * 2, 9),
        )

    def forward(self, batch):
        """Compute per-cell class logits.

        Args:
            batch: tensor of size b x 81 x 10 (one-hot encoded puzzles).

        Returns:
            Tensor of size b x 81 x 9 with the logits of every cell.
        """
        x = self.embedding(batch)
        residual = x

        layers = zip(self.transformers, self.pre_norm_layers, self.post_norm_layers)
        for attention, pre_norm, post_norm in layers:
            x = pre_norm(x)

            # attention expects: sequence length (81) x batch x hidden_dim
            x = x.transpose(0, 1)
            # The attention weights are never used, so skip computing them.
            x, _ = attention(x, x, x, attn_mask=self.A, need_weights=False)

            # back to: batch x sequence length (81) x hidden_dim
            x = x.transpose(0, 1)
            x = residual + nn.functional.relu(post_norm(x))

            residual = x

        return self.mlp(x)


In [None]:
class PlModel(pl.LightningModule):
    """Lightning wrapper that trains ``GNNTransformer`` with cross-entropy.

    The loss is computed over every cell in both steps. Training accuracy is
    measured on all cells, validation accuracy only on the cells whose
    one-hot input is class 0.
    """

    def __init__(self):
        super().__init__()

        self.model = GNNTransformer()

        # metric to log
        self.metric = pl.metrics.Accuracy()

        # define loss
        self.loss = torch.nn.CrossEntropyLoss()

    def forward(self, batch):
        """Return the logits of the wrapped model for ``batch``."""
        return self.model(batch)

    def configure_optimizers(self):
        """Return the Adam optimizer used for training."""
        return torch.optim.Adam(self.parameters(), lr=1e-4, weight_decay=0.)

    def _shared_step(self, x, y):
        """Run the model on one batch and return (loss, probabilities, targets).

        ``x`` has shape batch size x num nodes (81) x num features (10) and
        ``y`` has shape batch size x num nodes (81). Both outputs other than
        the loss are flattened to one row per cell.
        """
        targets = y.reshape(-1)

        # logits: batch size x num_nodes (81) x num classes, flattened per cell
        logits = self(x)
        logits = logits.reshape(-1, logits.shape[-1])

        loss = self.loss(logits, targets)
        probabilities = torch.softmax(logits, dim=1)

        return loss, probabilities, targets

    def training_step(self, batch, batch_idx):
        """Compute the training loss and log loss/accuracy over all cells."""
        J, y_pred, y = self._shared_step(batch[0], batch[1])

        # log the epoch average of both metrics to the progress bar and logger
        self.log('train_loss', J, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('train_acc', self.metric(y_pred, y), on_step=False, on_epoch=True, prog_bar=True, logger=True)

        return J

    def validation_step(self, batch, batch_idx):
        """Log the validation loss and the accuracy on the class-0 input cells."""
        x = batch[0]

        # True for the cells whose one-hot input is class 0
        mask = (x.argmax(dim=2) == 0).reshape(-1)

        J, y_pred, y = self._shared_step(x, batch[1])

        # log the epoch average of both metrics to the progress bar and logger
        self.log('val_loss', J, on_step=False, on_epoch=True, prog_bar=True, logger=True)
        self.log('val_acc', self.metric(y_pred[mask], y[mask]), on_step=False, on_epoch=True, prog_bar=True, logger=True)


In [None]:
# Mini-batches of 16 puzzles; only the training loader shuffles its samples.
train_loader = torch.utils.data.DataLoader(
    dataset=training_set,
    batch_size=16,
    shuffle=True,
)
valid_loader = torch.utils.data.DataLoader(
    dataset=valid_set,
    batch_size=16,
    shuffle=False,
)

In [None]:
# Instantiate the Lightning module; its constructor builds the GNNTransformer,
# the accuracy metric and the cross-entropy loss.
model = PlModel()

In [None]:
# Configure the Lightning trainer and launch training with validation.
# - max_epochs: train for at most 20 epochs.
# - progress_bar_refresh_rate: presumably the number of batches between
#   progress bar refreshes — confirm against the installed Lightning version.
# - gradient_clip_val: gradient clipping threshold passed to Lightning.
# - gpus: presumably trains on one GPU, so a CUDA device is expected — confirm.
trainer = pl.Trainer(
    max_epochs=20, 
    progress_bar_refresh_rate=20, 
    gradient_clip_val=0.1, 
    gpus=1,
    )
# Fit `model` on the training loader, evaluating on the validation loader.
trainer.fit(model, train_loader, valid_loader)