Reminder: make sure your environment is set up correctly (see [installation instructions](README.md))

# Training a Neural Network for Perception based solving

In [None]:
%matplotlib inline
import torch 
import torch.nn as nn
import numpy as np
import lightning as pl
from pathlib import Path
import os
import matplotlib.pyplot as plt
import seaborn as sns

For convenience, we will group all hyperparameters involved in building and training our neural network into a single dict. 

These hyperparameters will appear in different cells below. Later you will have the opportunity to play around with them!

In [None]:
# Single place for every tunable value of the pipeline; later cells read from this dict.
hyperparams = dict(
    # --- preprocessing ---
    size=(300, 300),  # target size handed to T.Resize
    # --- architecture ---
    dropout=0.25,
    batchnorm=True,
    pooling_kernel_size=3,
    pooling_stride=3,
    # --- training ---
    train_batch_size=1,
    learning_rate=0.001,
    max_total_epochs=5,
    early_stopping_patience=5,
)

## Preprocessing

In [None]:
import torchvision.transforms as T 
from torchvision.io import read_image

# Preprocessing pipeline: image file path -> resized, standardized float tensor.
base_tsfm = T.Compose(
    [
        T.Lambda(read_image),  # load the image file found at the given path
        T.ToPILImage(),
        # scale down image here to speed up training
        T.Resize(hyperparams['size']),
        T.ToTensor(),
        # Standardize each channel with a fixed mean / std to improve convergence
        # (values end up roughly centred on 0, not strictly inside [-1, 1]).
        T.Normalize(
            mean=[0.5320, 0.5209, 0.5204],
            std=[0.1862917, 0.19031495, 0.18998064],
        ),
        # T.Grayscale(1) # grayscale the image
    ]
)



In [None]:
from PIL import Image
# Load one raw sample and push it through the preprocessing pipeline.
img_path = os.path.join('data', 'visual_sudoku', 'img', '059.jpg')
original = Image.open(img_path)
image_preprocessed = base_tsfm(img_path)
see_torch_image = T.ToPILImage()  # turns a tensor back into a displayable image

fig, axes = plt.subplots(1, 2, figsize=(6, 9))
panels = (
    (original, 'original image'),
    # This is what the neural network actually sees
    (see_torch_image(image_preprocessed), 'preprocessed image'),
)
for ax, (picture, caption) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(caption)
    ax.set_axis_off()
print()

## Neural Network architecture


To build our neural network, we have to define its architecture. 

We want to classify multiple cells in an image. Therefore, we will use a convolutional neural network (CNN). 

More specifically, we take inspiration from previous work to build a 5-layer CNN which takes a full sudoku image as input and provides an 81x10 probability distribution matrix as output. 


<img src="assets/cnn.png" alt="CNN" style="width: 800px;"/>


Following best practices for building deep neural network architectures, we first define helper functions to build this CNN more easily. 

In [None]:
from utils import make_conv_layers, make_fc_layers_global_pooling

These functions help us create dense and convolutional layers from a configuration given as a list. 

Our CNN has two components:
- Feature extractor: mainly composed of convolutional layers. Extract features from the image by mapping the input to a latent space. 
- Classifier: global average pooling layer, followed by a softmax layer to output a probability distribution matrix of the desired shape (81 x 10)

In [None]:
class FullImageCNN(nn.Module):
    """Five-layer CNN that classifies every cell of a sudoku from the full image.

    Similar to https://github.com/Kyubyong/sudoku or the SudokuNet used in
    NeurASP. The network is a convolutional feature extractor followed by a
    global-pooling classifier whose last layer is a softmax over the final
    dimension.

    Args:
        grid_shape: shape of the puzzle grid, forwarded to the classifier
            builder as ``out_shape``. Defaults to (9, 9).
        n_classes: number of classes predicted per cell. Defaults to 10.
    """

    def __init__(self, grid_shape=(9, 9), n_classes=10) -> None:
        super().__init__()
        self.grid_shape = grid_shape
        self.n_classes = n_classes

        # Backbone: one tuple per convolutional layer (five in total).
        # NOTE(review): tuple layout is presumably
        # (out_channels, kernel_size, stride, padding) - confirm in utils.
        backbone_spec = [
            (32, 4, 2, 0),
            (64, 3, 2, 0),
            (128, 3, 2, 0),
            (256, 2, 2, 0),
            (512, 2, 2, 0),
        ]
        self.feat_extract = nn.Sequential(
            *make_conv_layers(
                backbone_spec,
                in_channels=3,  # channels of the input image (3 for RGB, 1 for grayscale)
                p=hyperparams['dropout'],  # dropout rate
                pool_ks=hyperparams['pooling_kernel_size'],
                pool_str=hyperparams['pooling_stride'],
                # whether to add batchnorm in-between convolutional layers
                batch_norm=hyperparams['batchnorm'],
            )
        )

        # Classifier head. `in_dim` has to equal the width of the last
        # convolutional layer of the backbone (512 in the spec above).
        head_layers = make_fc_layers_global_pooling(
            in_dim=512,
            out_shape=grid_shape,
            num_classes=n_classes,
        )
        # Softmax over the last dimension turns the scores of each cell into a
        # probability distribution (81 x 10 for the default arguments).
        self.classifier = nn.Sequential(*head_layers, nn.Softmax(-1))

    def forward(self, x):
        """Return ``{'predictions': ...}`` with the classifier output for ``x``."""
        features = self.feat_extract(x)
        predictions = self.classifier(features)
        return {'predictions': predictions}


In [None]:
# Build the network for a 9x9 grid with 10 classes per cell.
full_image_cnn = FullImageCNN(grid_shape=(9, 9), n_classes=10)
full_image_cnn  # bare expression so the notebook displays the module summary

## Train - Validate - Test

We now split our dataset into a training, a validation and a test set. For that, we define a LightningDataModule. 

In [None]:
# Thin wrapper pairing every image path with its label
class ImgDataset:
    """Indexable dataset yielding ``(transformed image, target dict)`` pairs.

    Args:
        dirpath_images: indexable collection of image file paths.
        img_transform: callable applied to an image path to produce the sample.
        labels: indexable collection of labels, aligned with ``dirpath_images``.
    """

    def __init__(self, dirpath_images:str, img_transform:torch.nn.Module, labels, ) -> None:
        self.imgs_path = dirpath_images
        self.tf = img_transform
        self.labels = labels

    def __len__(self):
        # One sample per image path.
        return len(self.imgs_path)

    def __getitem__(self, index):
        path = self.imgs_path[index]
        image = self.tf(path)
        # The numeric id is the file name up to its first dot ("059.jpg" -> 59).
        sample_id = int(os.path.basename(path).partition('.')[0])
        return image, {'id': sample_id, 'label': self.labels[index]}

In [None]:
from lightning.pytorch.utilities.types import EVAL_DATALOADERS, TRAIN_DATALOADERS
from torch.utils.data import Subset, DataLoader
# Seed torch's global random generator so that later random operations
# (e.g. the random train/val/test split of the data module) are repeatable.
torch.manual_seed(243)

class FullImageSudokuDataModule(pl.LightningDataModule):
    """Data module for the Visual Sudoku solver.

    Expects the data in the following layout::

        data_dir
        ├── img
        │   ├── 001.jpg
        │   └── 002.jpg
        └── label
            ├── 001.npy
            └── 002.npy

    Files of both folders are sorted on the integer in front of the first dot
    of their name, so image ``NNN.*`` is paired with label ``NNN.npy``.

    The samples are split randomly into a training, a validation and a test
    subset; each subset is served by its own dataloader.

    Args:
        data_dir (str, optional): path to data. Defaults to '.'.
        img_transform (nn.Module, optional): preprocessing pipeline. Defaults to the Identity function.
        train_batch_size (int, optional): batch size during training. Defaults to 1.
        n_train (int, optional): number of training samples. Defaults to 72.
        n_val (int, optional): number of validation samples. Defaults to 11.
            All remaining samples form the test subset.

    Raises:
        NotADirectoryError: if ``data_dir/img`` or ``data_dir/label`` is missing.
        ValueError: if images and labels differ in number, or if there are
            fewer samples than ``n_train + n_val``.
    """
    def __init__(self, data_dir:str = '.', img_transform:torch.nn.Module=torch.nn.Identity(), train_batch_size=1, n_train=72, n_val=11) -> None:
        super().__init__()
        self.data_dir = Path(data_dir).resolve()
        self.train_batch_size = train_batch_size
        # Explicit raise rather than assert: asserts vanish under `python -O`.
        for subfolder in ('img', 'label'):
            if not (self.data_dir / subfolder).is_dir():
                raise NotADirectoryError(f'bad data format, {data_dir}/{subfolder}')
        self.img_transforms = img_transform
        self.imgs_fname = self._numbered_files('img')
        self.labels = np.array([np.load(path) for path in self._numbered_files('label')])
        if len(self.imgs_fname) != len(self.labels):
            raise ValueError(
                f'found {len(self.imgs_fname)} images but {len(self.labels)} labels in {data_dir}'
            )

        # train val test split
        n_test = len(self.imgs_fname) - n_train - n_val
        if n_test < 0:
            raise ValueError(
                f'n_train + n_val = {n_train + n_val} exceeds the {len(self.imgs_fname)} available samples'
            )
        # random_split over arange: the `.indices` of each subset are sample positions
        self.train_subset, self.val_subset, self.test_subset = torch.utils.data.random_split(np.arange(len(self.labels)), [n_train, n_val, n_test])
        self.img_dataset = ImgDataset(self.imgs_fname, self.img_transforms, torch.from_numpy(self.labels))

    def _numbered_files(self, subfolder):
        """Return the full paths of ``data_dir/subfolder`` sorted by numeric file id."""
        folder = self.data_dir / subfolder
        names = sorted(os.listdir(folder), key=lambda n: int(n.split('.')[0]))
        return [os.path.join(folder, n) for n in names]

    def test_dataloader(self) -> EVAL_DATALOADERS:
        # held-out test samples only (never seen during training)
        return DataLoader(Subset(self.img_dataset, self.test_subset.indices), batch_size=1, shuffle=False)

    def train_dataloader(self) -> TRAIN_DATALOADERS:
        # batch size could be increased here
        return DataLoader(Subset(self.img_dataset, self.train_subset.indices), batch_size=self.train_batch_size, shuffle=True)

    def val_dataloader(self) -> EVAL_DATALOADERS:
        # validation samples only, so early stopping monitors generalization
        return DataLoader(Subset(self.img_dataset, self.val_subset.indices), batch_size=1, shuffle=False)

## Learning and evaluation

We use lightning to train and evaluate our architecture. 

Wrapping our CNN into a LightningModule enriches it with useful features, thus reducing the amount of glue code.

In [None]:
from collections import Counter
from functools import wraps
import time

class FullImageCNNLightning(pl.LightningModule):
    """Lightning wrapper for the full-image sudoku CNN.

    Adds the training loss, the validation / test metrics and the optimizer
    around a network whose forward pass returns ``{'predictions': ...}``.

    Args:
        cnn: the wrapped network.
        lr: learning rate of the optimizer. Defaults to 1e-2.
        num_pred_classes: number of classes per cell. Defaults to 10.
        puzzle_shape: shape of the puzzle grid. Defaults to (9, 9).
        hparams: extra hyperparameters to record with the run. Defaults to
            an empty dict.
    """
    def __init__(self, cnn: nn.Module, lr=1e-2, num_pred_classes=10, puzzle_shape=(9, 9), hparams=None):
        super().__init__()
        self.cnn = cnn
        self.lr = lr
        self.num_pred_classes = num_pred_classes
        self.puzzle_shape = puzzle_shape
        # `None` default instead of a shared mutable `dict()` default
        hparams = {} if hparams is None else hparams
        # NOTE(review): mixing argument names with a dict may make Lightning
        # record only the dict - confirm against the generated hparams.yaml.
        self.save_hyperparameters("lr",  "num_pred_classes", hparams)

    def forward(self, x,):
        """Delegate to the wrapped CNN and return its output dict."""
        return self.cnn(x)

    def training_step(self, batch, batch_idx):
        """Return the mean of all loss terms for one training batch."""
        x, target = batch
        loss_dict = self.compute_loss(self(x), target)
        return torch.stack(list(loss_dict.values())).mean()

    def compute_loss(self, cnn_output:dict, target) -> dict:
        """Binary cross entropy between predictions and one-hot encoded labels.

        Returns:
            dict with the single entry ``'cell_value_cross_entropy'``.
        """
        labels = target['label'].flatten().long()
        # one row of class probabilities per cell
        predictions = cnn_output['predictions'].view(labels.shape[0], -1)
        # Build the one-hot table on the device of the predictions: a CPU
        # table cannot be indexed with labels living on an accelerator.
        one_hot_table = torch.eye(
            self.num_pred_classes, device=predictions.device, dtype=predictions.dtype
        )
        one_hot_target = one_hot_table[labels.to(predictions.device)]
        # cross entropy loss (binary form)
        cell_value_loss = torch.nn.BCELoss()(predictions, one_hot_target)
        return {
            'cell_value_cross_entropy': cell_value_loss
        }

    def validation_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, testing=False)

    def test_step(self, batch, batch_idx):
        return self._shared_eval(batch, batch_idx, testing=True)

    def _shared_eval(self, batch, batch_idx, log=True, testing=False):
        """Compute overall and per-class cell accuracy for one batch.

        Args:
            batch: ``(images, target)`` pair; ``target['label']`` holds the labels.
            batch_idx: index of the batch (unused).
            log: when True (default) every metric is logged at epoch level.
            testing: selects the ``test_`` (True) or ``val_`` (False) metric prefix.

        Returns:
            dict mapping metric names to scalar tensors.
        """
        x, target = batch
        cnn_output = self(x)
        str_eval_type = 'test' if testing else 'val'

        labels = target['label'].flatten()
        pred = torch.argmax(
            cnn_output['predictions'].reshape(labels.shape[0], -1), -1).long()

        eval_output = {
            f'{str_eval_type}_cell_accuracy': (pred == labels).sum() / pred.numel()
        }

        # per-label accuracy
        for class_id in range(self.num_pred_classes):
            mask = labels == class_id
            if not mask.any():
                # class absent from this batch: its accuracy would be 0/0 (NaN)
                continue
            eval_output[f'{str_eval_type}_cell_accuracy_{class_id}'] = (
                pred[mask] == class_id).sum() / mask.sum()

        if log:
            for k, v in eval_output.items():
                self.log(k, v, prog_bar=True, on_step=False, on_epoch=True)
        return eval_output

    def configure_optimizers(self):
        # we will use a variant of the ADAM optimizer
        return torch.optim.AdamW(self.cnn.parameters(), lr=self.lr)


The only remaining step is to use the Trainer, provided by lightning

In [None]:
from lightning import Trainer
from lightning.pytorch.loggers import CSVLogger, TensorBoardLogger
from lightning.pytorch.callbacks import EarlyStopping

# Early stopping: watch the validation cell accuracy and halt training once it
# has stopped increasing for `early_stopping_patience` checks in a row.
es = EarlyStopping(
    monitor='val_cell_accuracy',
    mode='max',
    patience=hyperparams['early_stopping_patience'],
)

# Metrics and hyperparameters of each run go to log/neural_network_p2/
logger = CSVLogger(save_dir='log/', name='neural_network_p2')

trainer = Trainer(
    logger=logger,
    max_epochs=hyperparams['max_total_epochs'],  # maximum amount of epochs
    log_every_n_steps=5,
    inference_mode=False,
    callbacks=[es],
    enable_progress_bar=True,
)

data_module = FullImageSudokuDataModule(
    data_dir=os.path.join('data', 'visual_sudoku/'),
    img_transform=base_tsfm,
    # number of images per batch, during training
    train_batch_size=hyperparams['train_batch_size'],
)

print('instances in test set:', data_module.test_subset.indices)

ml_model = FullImageCNNLightning(
    cnn=full_image_cnn,
    # learning rate for the gradient-based update
    lr=hyperparams['learning_rate'],
    num_pred_classes=10,
    puzzle_shape=(9, 9),
    hparams=hyperparams,
)

In [None]:

# use the trainer for Training, Validation and Testing
# 1) validation pass before training starts, to get baseline metrics
trainer.validate(ml_model,datamodule=data_module)
# 2) training (interleaved with validation, used by early stopping)
trainer.fit(ml_model, datamodule=data_module)
# 3) final evaluation with the test dataloader
trainer.test(ml_model, datamodule=data_module)


In [None]:
# Folder of the current run: <logger root>/version_<n>; checkpoints sit in its
# 'checkpoints' subfolder.
saved_path = os.path.join(trainer.logger.root_dir, f'version_{trainer.logger.version}')
checkpoint_path = os.path.join(saved_path, 'checkpoints')
report = (
    ('your model was saved in ', os.path.join(checkpoint_path, os.listdir(checkpoint_path)[0])),
    ('Version number : ', trainer.logger.version),
    ('see hyperparameters values in ', os.path.join(saved_path, 'hparams.yaml')),
)
for message, value in report:
    print(message, value)

In [None]:
# Helper function to load weights of a pre-trained CNN
from collections import OrderedDict
def load_from_checkpoint(version_id=0, log_name='neural_network_p1', save_dir='log'):
    """Rebuild a ``FullImageCNN`` and fill it with weights from a saved checkpoint.

    The checkpoint is read from
    ``<save_dir>/<log_name>/version_<version_id>/checkpoints``; when several
    files are present the first one in alphabetical order is used.

    Args:
        version_id: run number of the checkpoint to load. Defaults to 0.
        log_name: name of the logger that produced the run. Defaults to
            'neural_network_p1'; pass e.g. 'neural_network_p2' to load a
            model trained under that logger name.
        save_dir: root folder of the logs. Defaults to 'log'.

    Returns:
        The CNN with the loaded weights, switched to evaluation mode.

    Raises:
        FileNotFoundError: if the checkpoint folder is missing or empty.
    """
    cnn = FullImageCNN()
    dir_chkp_path = os.path.join(save_dir, log_name, f'version_{version_id}', 'checkpoints')
    # sorted: os.listdir returns entries in arbitrary order
    chkp_files = sorted(os.listdir(dir_chkp_path))
    if not chkp_files:
        raise FileNotFoundError(f'no checkpoint file found in {dir_chkp_path}')
    chkp_path = os.path.join(dir_chkp_path, chkp_files[0])
    # map_location='cpu': a checkpoint written on a GPU also loads on a CPU-only machine
    CKPT_state_dict = torch.load(chkp_path, map_location='cpu')

    # The Lightning wrapper stores the network as attribute `cnn`, hence every
    # weight is saved as 'cnn.<name>'. Strip exactly that leading prefix and
    # ignore entries that do not belong to the network.
    prefix = 'cnn.'
    layer_names = cnn.state_dict().keys()
    to_load = OrderedDict(
        (k[len(prefix):], v)
        for k, v in CKPT_state_dict['state_dict'].items()
        if k.startswith(prefix) and k[len(prefix):] in layer_names
    )
    cnn.load_state_dict(to_load)
    return cnn.eval()

### Let's visualize the output

In [None]:
def view_classify(probs, cell_idx, imgtitle=""):
    ''' Plot the predicted class probabilities of one cell as a horizontal bar chart.

    Args:
        probs: tensor of class probabilities; after squeezing, indexing it
            with `cell_idx` must give the probability vector of one cell.
        cell_idx: index of the cell to display, e.g. (row, column).
        imgtitle: optional text put in front of the plot title.
    '''
    # detach + cpu so tensors that track gradients or live on a GPU convert too
    ps = probs.detach().cpu().numpy().squeeze()[cell_idx]
    n_classes = len(ps)
    fig, ax2 = plt.subplots(1,1, figsize=(6,6))
    ax2.barh(np.arange(n_classes), ps)
    ax2.set_aspect(0.1)
    ax2.set_yticks(np.arange(n_classes))
    # class 0 stands for an empty cell, the others for the digits
    ax2.set_yticklabels(['empty'] + list(range(1, n_classes)))
    title = f'Class Probability {cell_idx}'
    if imgtitle:
        title = f'{imgtitle} - {title}'
    ax2.set_title(title)
    ax2.set_xlim(0, 1.1)
    plt.tight_layout()

@torch.no_grad() # disable gradients computation
def show_one_prediction(cnn, torch_image, cell_idx):
    """Run `cnn` on one preprocessed image and plot the prediction for one cell.

    The network is put in evaluation mode for the forward pass (so dropout
    and batch normalization behave deterministically) and its previous mode
    is restored afterwards.

    Args:
        cnn: network returning ``{'predictions': ...}``.
        torch_image: preprocessed image tensor without batch dimension.
        cell_idx: index of the cell to display, e.g. (row, column).
    """
    was_training = cnn.training
    cnn.eval()
    try:
        output = cnn(torch_image.unsqueeze(0)) # add a batch dimension
    finally:
        cnn.train(was_training)
    # use the grid shape stored on the network when available
    grid_shape = getattr(cnn, 'grid_shape', (9, 9))
    view_classify(output['predictions'].reshape(*grid_shape, -1), cell_idx)


In [None]:
# Show the raw image next to the network input, then the prediction for one cell.
fig, axes = plt.subplots(1, 2, figsize=(6, 9))
panels = (
    (original, 'original image'),
    # This is what the neural network actually sees
    (see_torch_image(image_preprocessed), 'preprocessed image'),
)
for ax, (picture, caption) in zip(axes, panels):
    ax.imshow(picture)
    ax.set_title(caption)
    ax.set_axis_off()
# change indices of the cell you want to see
show_one_prediction(full_image_cnn, image_preprocessed, (6, 2))

## Solving the Sudoku

In [None]:
from utils import get_sudoku_model, solve_sudoku

@torch.no_grad() # disable gradient computation
def get_predictions(cnn, image_preprocessed):
    """Return the class probabilities predicted by `cnn` for every cell.

    The network is put in evaluation mode for the forward pass - otherwise
    dropout would randomly perturb the predictions - and its previous mode
    is restored afterwards.

    Args:
        cnn: network returning ``{'predictions': ...}``.
        image_preprocessed: preprocessed image tensor without batch dimension.

    Returns:
        numpy array of shape ``(*grid_shape, n_classes)``, where ``grid_shape``
        is read from ``cnn.grid_shape`` and defaults to (9, 9).
    """
    was_training = cnn.training
    cnn.eval()
    try:
        # unsqueeze: add the batch dimension expected by the network
        output = cnn(image_preprocessed.unsqueeze(0))['predictions']
    finally:
        cnn.train(was_training)
    grid_shape = getattr(cnn, 'grid_shape', (9, 9))
    return output.detach().squeeze().cpu().numpy().reshape(*grid_shape, -1)

Our CNN outputs a probability tensor of size 9 x 9 x 10, representing the distribution over possible values for each cell. By taking the `argmax` probability for each cell, we obtain the class assigned by the CNN for the cell. 

We can use these argmax classes as input for our basic sudoku solver:

In [None]:

# Class probabilities for every cell of the sample image
ml_predictions = get_predictions(full_image_cnn, image_preprocessed)
# Most likely class per cell = the puzzle as read by the network
ml_instance = np.argmax(ml_predictions, axis=-1)
# Hand the predicted puzzle to the sudoku solver
sudoku_problem = get_sudoku_model(ml_instance)
results = solve_sudoku(sudoku_problem['model'], sudoku_problem['variables'])
results

Even with a high accuracy (>95%), our CNN may still make prediction errors. Those errors may lead to an infeasible sudoku.
Let's see how many sudoku are: 
1. solved (solver finds a feasible solution)
2. solved **correctly** (the solution found corresponds to the true solution)


In [None]:
from cpmpy.solvers.solver_interface import ExitStatus
# here is a loop over test instances 
# Ids of the instances to evaluate (they select the files <id>.jpg / <id>.npy).
# NOTE(review): hard-coded list - confirm it matches the test split of the data module.
test_set = [6, 35, 91, 96, 10, 80, 11, 72, 51, 20, 75, 82, 21, 4, 41, 14, 88, 56, 79, 94]

count_unsat = 0
count_solved = 0

for instance_id in test_set:
    img_path = os.path.join('data', 'visual_sudoku', 'img', f'{instance_id:03d}.jpg')
    label_path = os.path.join('data', 'visual_sudoku', 'label', f'{instance_id:03d}.npy')
    label = np.load(label_path)
    # preprocessing
    img_preprocessed = base_tsfm(img_path)
    # machine learning predictions
    ml_predictions = get_predictions(full_image_cnn, img_preprocessed)
    # argmax to get predicted puzzle
    ml_instance = ml_predictions.argmax(-1)
    # solve sudoku
    sudoku_problem = get_sudoku_model(ml_instance)
    results = solve_sudoku(sudoku_problem['model'], sudoku_problem['variables'])
    # evaluate the status
    if results['status'] == ExitStatus.UNSATISFIABLE:
        count_unsat += 1
        # an infeasible puzzle has no solution to compare: it cannot be correct
        continue
    # get the true solution using labels as starting clues
    sudoku_problem = get_sudoku_model(label)
    ground_truth = solve_sudoku(sudoku_problem['model'], sudoku_problem['variables'])
    # every cell in ground truth should match with cell in our result
    if np.all(ground_truth['solution'] == results['solution']):
        count_solved += 1
# '.2%' = two decimals ('2%' would only set a minimum width and print six decimals)
print(f'Rate of infeasible puzzles: {count_unsat/len(test_set):.2%}')
print(f'Rate of correctly solved puzzles: {count_solved/len(test_set):.2%}')
    

# Play a bit!

**Try to improve the accuracy of your neural network by trying out different hyperparameter values!**

Change values in the `hyperparams` dictionary at the beginning and just restart the notebook. Each of your trained CNNs is saved in the `log` folder. 

____________

The following section contains advanced challenges, feel free to try them out at the end of the tutorial

# [Challenge] Handling Data Imbalance

Our dataset contains 103 Visual Sudoku instances. Each Sudoku image is of size 300x300, in color, aligned, and well centered. Some are blurry, riddled with creases or other visual artifacts. Some contain handwritten digits, others are slightly shifted or rotated, etc.

<img src="data/visual_sudoku/img/073.jpg" alt="sudoku" style="width: 400px;"/>

Unlike numerical instances that a solver can handle directly, these instances require interpreting the image first. Therefore, we want to train a neural network to infer the content of the sudoku grid. This network should learn to predict values for all cells, whether they are empty or contain a starting clue. 

As such, this task can be viewed as a multi-output classification problem: from a given image, our machine learning model should classify 81 cells, each with a label $\in \{empty,1,\ldots,9\}$.

Let's visualize the distribution of labels in the dataset.

<img src="assets/imbalance.png" alt="imbalance" style="width: 400px;"/>

The `empty` class is far more prevalent than the others. This imbalance in the data can hinder the learning process if not handled carefully.

**Task: There exist many methods to handle data imbalance, but not all of them are practical. One of them consists of assigning different weights to samples when computing the training loss, depending on the inverse popularity of their class in the current batch.**
Does it have any impact on the overall accuracy? 



________


# [Challenge] Printed and Handwritten (part 1)

As you may have noticed, some images contain both printed and handwritten digits. 
We can assume that, as printed values generally make up the starting clues of the puzzle, they are more reliable than handwritten values provided by a player.

<img src="data/visual_sudoku/img/030.jpg" alt="sudoku" style="width: 400px;"/>

A smart hybrid CP-ML solver could exploit this information to improve its rate of correctly solved instances. 
The first step towards designing such a system is to build a machine learning architecture capable of predicting both the value of a cell and its font style (printed or handwritten). 

This can be framed as a *multitask classification problem*. A simple workaround is to train an additional machine learning estimator solely for font style classification.
However, there exist multiple ways to change the current neural network architecture to handle such a problem. 

**Task: train a modified CNN that can predict both cell value and font style.**

*Tip: we provide additional labels about font style for each cell, in the `data/visual_sudoku/style` folder. You can edit the `ImgDataset` class to also provide such labels during training* 

In [None]:
# Thin wrapper pairing every image path with its value and font-style labels
class ImgDatasetStyle:
    """Indexable dataset yielding ``(transformed image, target dict)`` pairs.

    Same as a plain image/label dataset, with an additional font-style label
    per sample.

    Args:
        dirpath_images: indexable collection of image file paths.
        img_transform: callable applied to an image path to produce the sample.
        labels: indexable collection of value labels, aligned with the paths.
        styles: indexable collection of style labels, aligned with the paths.
    """

    def __init__(self, dirpath_images:str, img_transform:torch.nn.Module, labels, styles) -> None:
        self.imgs_path = dirpath_images
        self.tf = img_transform
        self.labels = labels
        self.styles = styles

    def __len__(self):
        # One sample per image path.
        return len(self.imgs_path)

    def __getitem__(self, index):
        path = self.imgs_path[index]
        image = self.tf(path)
        # The numeric id is the file name up to its first dot ("030.jpg" -> 30).
        sample_id = int(os.path.basename(path).partition('.')[0])
        target = {
            'id': sample_id,
            'label': self.labels[index],
            'label_style': self.styles[index],
        }
        return image, target

class FullImageSudokuStyleDataModule(FullImageSudokuDataModule):
    """Visual Sudoku data module that also serves font-style labels.

    On top of the ``img`` and ``label`` folders of the parent class, it reads
    one ``.npy`` file per puzzle from ``data_dir/style`` (sorted on the integer
    in front of the first dot of the file name) and exposes it as
    ``target['label_style']``.

    Args:
        data_dir (str, optional): path to data. Defaults to '.'.
        img_transform (nn.Module, optional): preprocessing pipeline. Defaults to the Identity function.
        train_batch_size (int, optional): batch size during training. Defaults to 1.
    """
    def __init__(self, data_dir:str = '.', img_transform:torch.nn.Module=torch.nn.Identity(), train_batch_size=1) -> None:
        super().__init__(data_dir, img_transform, train_batch_size)
        style_dir = self.data_dir / 'style'
        style_files = sorted(os.listdir(style_dir), key=lambda n: int(n.split('.')[0]))
        self.labels_style = np.array([np.load(os.path.join(style_dir, f)) for f in style_files])
        # Replace the dataset built by the parent with one that carries the style labels.
        self.img_dataset = ImgDatasetStyle(self.imgs_fname, self.img_transforms, torch.from_numpy(self.labels), torch.from_numpy(self.labels_style))

Now, style labels are available in the `compute_loss` function as `target['label_style']`.