# Train a Quantized MLP on MNIST with Brevitas

<font color="red">**Live FINN tutorial:** We recommend clicking **Cell -> Run All** when you start reading this notebook for "latency hiding".</font>

In this notebook, we will show how to create, train and export a quantized Multi-Layer Perceptron (MLP) with quantized weights and activations using [Brevitas](https://github.com/Xilinx/brevitas).
Specifically, the task at hand is to classify handwritten digits into their respective categories (0 through 9) by training on the MNIST dataset, which consists of grayscale images of digits. Our goal is to build a quantized MLP that accurately recognizes these digits.

**You won't need a GPU to train the neural net.** This MLP is small enough to train on a modern x86 CPU. Alternatively, we provide pre-trained parameters for the MLP if you want to skip the training entirely.


## A quick introduction to the task and the dataset

*The task:* The goal of digit classification is to accurately identify handwritten digits from images, a fundamental task in computer vision and machine learning. This is particularly useful in applications such as automated postal code recognition, bank check processing, and digitizing handwritten documents. Machine learning (ML) techniques, especially deep learning, have proven to be highly effective in solving this problem.

*The dataset:* The MNIST dataset is one of the most well-known datasets for training and testing machine learning models on handwritten digit recognition. It contains 60,000 training images and 10,000 test images of the digits 0 to 9. Each image is a 28x28-pixel grayscale image, which keeps the data relatively simple while still posing a meaningful challenge for learning algorithms. MNIST has become a standard benchmark for evaluating image classification algorithms and is widely used in academic research and practical applications. You can find more details about the dataset on [its homepage](http://yann.lecun.com/exdb/mnist/).


*Performance considerations:* FPGAs are commonly used to implement high-performance processing systems that still provide a degree of programmability. Quantized neural networks with few-bit weights and activations, such as the one trained here, map efficiently onto FPGA logic and on-chip memory, which is what the FINN compiler exploits to generate high-throughput, low-latency accelerators. This is a good reason to consider FPGA acceleration for this use case.

## Outline
-------------

* [Load the MNIST Dataset](#load_dataset) 
* [Define a PyTorch Device](#define_pytorch_device)
* [Define the Quantized MLP Model](#define_quantized_mlp)
* [Define Train and Test Methods](#train_test)
* [Train the QNN](#train_qnn)
* [Export to QONNX and Conversion to FINN-ONNX](#export_qonnx)

In [None]:
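import os
import onnx  # import onnx before torch, see the note below
import torch
import ssl

# Directory where the exported ONNX model will be stored
model_dir = os.environ['FINN_ROOT'] + "/notebooks/end2end_example/mnist-dataset"
os.makedirs(model_dir, exist_ok=True)

# Workaround: disable HTTPS certificate verification so that torchvision can
# download MNIST in environments with certificate problems. This weakens
# security; remove this line if the download works without it.
ssl._create_default_https_context = ssl._create_unverified_context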

**This is important -- always import onnx before torch**. This is a workaround for a [known bug](https://github.com/onnx/onnx/issues/2394).

# Load the MNIST Dataset <a id='load_dataset'></a>

### Dataset Preprocessing <a id='dataset_qnt'></a>

The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler.

We prepare the dataset with the following steps:

* Each original sample is a 28x28-pixel grayscale image.
* Pixel values, which range from 0 to 255, are scaled into the range [0, 1] by dividing each value by 255 (`transforms.ToTensor()`).
* The scaled values are then normalized with the commonly used MNIST mean (0.1307) and standard deviation (0.3081), i.e. each pixel x becomes (x - 0.1307) / 0.3081 (`transforms.Normalize`).
* Finally, each image is flattened into a floating-point vector with 784 (28x28) elements, which is the input size of the MLP.
* Each vector keeps its label, the digit shown in the image (0 through 9).

Note that the inputs are therefore floating-point values, not binarized bits; this is why the input datatype is set to `FLOAT32` when the model is exported to FINN.

Following these steps, we prepare the MNIST dataset for training a quantized neural network and for potential deployment using FPGA accelerators.

We can load the MNIST dataset with `torchvision` (downloading it on first use) and apply the preprocessing transformations as follows:

In [None]:
import numpy as np
from torchvision import datasets, transforms

transform = transforms.Compose([
    transforms.ToTensor(),                       # convert image to tensor and scale values to [0, 1]
    transforms.Normalize((0.1307,), (0.3081,)),  # normalize with MNIST mean and std
    transforms.Lambda(lambda x: x.view(-1))      # flatten 1x28x28 -> 784
])

# Download (if needed) into the current directory and apply the transform
train_quantized_dataset = datasets.MNIST('.', train=True, transform=transform, download=True)
test_quantized_dataset = datasets.MNIST('.', train=False, transform=transform, download=True)

print("Samples in each set: train = %d, test = %d" % (len(train_quantized_dataset), len(test_quantized_dataset)))
print("Shape of one input sample: " + str(train_quantized_dataset[0][0].shape))
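To make the `ToTensor` → `Normalize` → flatten pipeline concrete, here is the same arithmetic written out by hand on a synthetic 28x28 `uint8` image using plain PyTorch. This is a sketch for intuition only: `preprocess` is a hypothetical helper, and the input image is made up (all black except one white pixel); the mean 0.1307 and std 0.3081 are the constants used in the cell above.

```python
import torch

# Constants used by transforms.Normalize in the cell above
MNIST_MEAN, MNIST_STD = 0.1307, 0.3081

def preprocess(img_uint8: torch.Tensor) -> torch.Tensor:
    """Hypothetical hand-written equivalent of ToTensor -> Normalize -> flatten
    for a single 28x28 uint8 grayscale image."""
    x = img_uint8.float() / 255.0        # ToTensor: scale to [0, 1]
    x = (x - MNIST_MEAN) / MNIST_STD     # Normalize: subtract mean, divide by std
    return x.view(-1)                    # flatten 28x28 -> 784

# Synthetic "image": all black except one white pixel
img = torch.zeros(28, 28, dtype=torch.uint8)
img[14, 14] = 255

vec = preprocess(img)
print(vec.shape)          # torch.Size([784])
print(vec.min().item())   # black pixel: (0 - 0.1307) / 0.3081
print(vec.max().item())   # white pixel: (1 - 0.1307) / 0.3081
```

Black background pixels end up slightly negative and white pixels well above 1, so the MLP sees roughly zero-centered floating-point inputs rather than raw byte values.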

## Set up DataLoader

We now have access to the preprocessed dataset. We will wrap it in a PyTorch `DataLoader` for easier access in batches.

In [None]:
from torch.utils.data import DataLoader, Dataset

batch_size = 1000

# dataset loaders
train_quantized_loader = DataLoader(train_quantized_dataset, batch_size=batch_size, shuffle=True)
test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=False) 

In [None]:
# Fetch a single batch to check the tensor shapes
x, y = next(iter(train_quantized_loader))
print("Input shape for 1 batch: " + str(x.shape))
print("Label shape for 1 batch: " + str(y.shape))
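The number of batches a `DataLoader` yields per epoch is the dataset size divided by the batch size, rounded up; the last batch is smaller when the division is not exact (unless `drop_last=True`). The sketch below shows this with a small synthetic `TensorDataset` of random data whose size (2,500 samples) is made up so that the last batch is partial. For the real MNIST training set with `batch_size = 1000`, this gives 60 batches per epoch.

```python
import math
import torch
from torch.utils.data import DataLoader, TensorDataset

# Synthetic stand-in for the flattened MNIST tensors (random data, made-up size)
num_samples, batch_size = 2500, 1000
features = torch.randn(num_samples, 784)
labels = torch.randint(0, 10, (num_samples,))
loader = DataLoader(TensorDataset(features, labels), batch_size=batch_size, shuffle=True)

# len(loader) is the number of iterations per epoch: ceil(N / batch_size)
print(len(loader), math.ceil(num_samples / batch_size))  # 3 3

# Shuffling reorders samples, but batches are still cut sequentially,
# so only the final batch is partial
batch_sizes = [x.shape[0] for x, _ in loader]
print(batch_sizes)

# For the real training set: 60,000 samples / 1,000 per batch
print(math.ceil(60000 / 1000))
```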

# Define a PyTorch Device <a id='define_pytorch_device'></a> 

GPUs can significantly speed up the training of deep neural networks. We check whether a GPU is available and, if so, use it as the target device; otherwise, training runs on the CPU.

In [None]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Target device: " + str(device))

# Define the Quantized MLP Model <a id='define_quantized_mlp'></a>

We will now define a Multi-Layer Perceptron (MLP) model to be trained for digit classification on the MNIST dataset, using quantized weights and activations.
For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).

Our MLP will consist of two fully-connected (FC) layers: one hidden layer with 512 neurons and a final output layer with 10 neurons, one per digit class. All layers use 2-bit quantized weights. The hidden FC layer is followed by batch normalization, dropout (p = 0.5) and a 2-bit quantized ReLU activation, in that order.
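To build intuition for what 2-bit weights and 2-bit activations mean numerically, here is a minimal fake-quantization sketch in plain PyTorch. It is an illustration, not Brevitas's implementation: it assumes a symmetric, narrow-range signed grid for weights (so 2 bits give the integer levels -1, 0 and +1, times a scale derived from the largest weight magnitude) and an unsigned grid for the ReLU output (levels 0 to 3 times a fixed scale). In Brevitas, the scales and rounding details depend on the quantizer configuration, and activation scales are learned during training. Both helper functions are hypothetical.

```python
import torch

def fake_quant_signed(w: torch.Tensor, bit_width: int) -> torch.Tensor:
    """Hypothetical symmetric, narrow-range signed quantizer.
    Integer levels are -(2^(b-1)-1) .. 2^(b-1)-1; max|w| maps to the top level."""
    q_max = 2 ** (bit_width - 1) - 1
    scale = w.abs().max() / q_max
    q = torch.clamp(torch.round(w / scale), -q_max, q_max)
    return q * scale  # dequantized ("fake-quantized") values

def fake_quant_relu(x: torch.Tensor, bit_width: int, scale: float) -> torch.Tensor:
    """Hypothetical unsigned quantized ReLU with a fixed scale: levels 0 .. 2^b - 1."""
    q_max = 2 ** bit_width - 1
    q = torch.clamp(torch.round(torch.relu(x) / scale), 0, q_max)
    return q * scale

w = torch.tensor([0.8, -0.1, 0.3, -0.6])
print(fake_quant_signed(w, bit_width=2))  # only -0.8, 0.0 and 0.8 can appear

x = torch.tensor([-1.0, 0.2, 0.7, 1.4, 5.0])
print(fake_quant_relu(x, bit_width=2, scale=0.5))  # values from {0, 0.5, 1.0, 1.5}
```

Large activations saturate at the top level (here 3 x 0.5 = 1.5), which is one reason batch normalization before the quantized ReLU helps keep values within the representable range.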

In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these topology settings as variables.

In [None]:
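input_size = 28*28      # flattened 28x28 image
hidden = 512            # neurons in the hidden layer
weight_bit_width = 2    # bit width of all weights
act_bit_width = 2       # bit width of the hidden activation
num_classes = 10        # digits 0-9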
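A quick back-of-the-envelope count shows why such small bit widths matter for FPGA deployment. The sketch below counts only the weights of the two FC layers (biases and batch normalization parameters are ignored) and compares their storage at 32-bit floating point with the 2-bit setting above. It says nothing about how FINN actually lays out weights in on-chip memory.

```python
# Topology settings, mirroring the cell above
input_size = 28 * 28
hidden = 512
num_classes = 10
weight_bit_width = 2

# Weight counts of the two FC layers (biases and BatchNorm parameters excluded)
weights_fc1 = input_size * hidden    # 784 * 512
weights_fc2 = hidden * num_classes   # 512 * 10
total_weights = weights_fc1 + weights_fc2

bytes_float32 = total_weights * 32 // 8
bytes_quantized = total_weights * weight_bit_width // 8

print(f"weights: {total_weights}")
print(f"float32: {bytes_float32} bytes, {weight_bit_width}-bit: {bytes_quantized} bytes")
print(f"reduction: {bytes_float32 // bytes_quantized}x")
```

With 406,528 weights, storage drops from about 1.6 MB at float32 to about 100 KB at 2 bits, a 16x reduction.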

Now we can define our MLP using the layer primitives provided by Brevitas:

In [None]:
from brevitas.nn import QuantLinear, QuantReLU
import torch.nn as nn

# Setting seeds for reproducibility
torch.manual_seed(0)

model = nn.Sequential(
      QuantLinear(input_size, hidden, bias=True, weight_bit_width=weight_bit_width),
      nn.BatchNorm1d(hidden),
      nn.Dropout(0.5),
      QuantReLU(bit_width=act_bit_width),
      QuantLinear(hidden, num_classes, bias=True, weight_bit_width=weight_bit_width)
)

model.to(device)

Note that the MLP's output is not quantized: the final FC layer produces ten logits, one for each possible digit class (0-9). While training the network, we'll pass these logits through a softmax function as part of the loss criterion.
The model is later exported as is, without an additional quantization node at the output.

# Define Train and Test Methods <a id='train_test'></a>
The train and test methods use a `DataLoader`, which feeds the model a new batch of data in each iteration until the entire dataset has been processed. One full pass over the training data is called an `epoch`.

In [None]:
def train(model, train_loader, optimizer, criterion):
    losses = []
    # ensure model is in training mode (dropout active, BatchNorm uses batch statistics)
    model.train()

    for inputs, target in train_loader:
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()

        # forward pass
        output = model(inputs.float())
        loss = criterion(output, target)

        # backward pass + run optimizer to update weights
        loss.backward()
        optimizer.step()

        # keep track of loss value
        losses.append(loss.item())

    return losses

In [None]:
import torch
from sklearn.metrics import accuracy_score

def test(model, test_loader):    
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
   
    with torch.no_grad():
        for data in test_loader:
            inputs, target = data
            inputs, target = inputs.to(device), target.to(device)
            output = model(inputs.float())
            pred = output.argmax(dim=1, keepdim= True).cpu().numpy()
            target = target.cpu().numpy()
            y_true.extend(target.tolist()) 
            y_pred.extend(pred.reshape(-1).tolist())
        
    return accuracy_score(y_true, y_pred)
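`accuracy_score` simply returns the fraction of predictions that match the labels. If you prefer to avoid the scikit-learn dependency, the same quantity can be computed directly in PyTorch; `batch_accuracy` below is a hypothetical helper, not part of the notebook, and the logits and labels are made up. When aggregating over a whole `DataLoader`, sum the correct counts and divide by the dataset size rather than averaging per-batch accuracies, since the two differ when the last batch is smaller.

```python
import torch

def batch_accuracy(logits: torch.Tensor, target: torch.Tensor) -> float:
    """Fraction of samples whose highest-scoring class equals the label
    (the same quantity as sklearn's accuracy_score(y_true, y_pred))."""
    pred = logits.argmax(dim=1)
    return (pred == target).float().mean().item()

# Three samples, three classes: the first and last predictions are correct
logits = torch.tensor([[2.0, 0.1, -1.0],
                       [0.0, 3.0, 0.5],
                       [1.0, 0.2, 0.9]])
target = torch.tensor([0, 2, 0])
print(batch_accuracy(logits, target))  # 2 of 3 correct
```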

# Train the QNN <a id="train_qnn"></a>

Before we start training our MLP, we need to define some hyperparameters. Moreover, to monitor how the loss evolves over the epochs, we define a helper function that plots it. For this task, we need a loss criterion suitable for multi-class classification.

In our case, we'll use `CrossEntropyLoss` as the loss criterion. It handles multi-class classification directly by combining the softmax activation and the cross-entropy loss in one step, so we don't need to apply a softmax ourselves during training. For testing, the `test` method takes the argmax over the ten output logits to obtain the predicted class label. Because softmax preserves the ordering of the logits, this gives the same label as taking the argmax of the class probabilities.
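A quick numerical check of how `CrossEntropyLoss` relates to the softmax, and of why predictions can be read directly from the logits. The random logits and labels below are made up for illustration: the loss equals `log_softmax` followed by `nll_loss`, and the argmax is the same whether or not the softmax is applied first.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

torch.manual_seed(0)
logits = torch.randn(4, 10)            # 4 samples, 10 digit classes
target = torch.tensor([3, 7, 0, 9])

# CrossEntropyLoss on raw logits ...
ce = nn.CrossEntropyLoss()(logits, target)
# ... equals the negative log-likelihood of the log-softmax probabilities
manual = F.nll_loss(F.log_softmax(logits, dim=1), target)
print(torch.allclose(ce, manual))      # True

# Softmax is monotonic within each row, so argmax(logits) == argmax(probabilities)
probs = F.softmax(logits, dim=1)
print(torch.equal(logits.argmax(dim=1), probs.argmax(dim=1)))  # True
```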

In [None]:
import matplotlib.pyplot as plt

num_epochs = 10
lr = 0.001

def display_loss_plot(losses, title="Training loss", xlabel="Iterations", ylabel="Loss"):
    # plot each value in `losses` against its index
    x_axis = range(len(losses))
    plt.plot(x_axis, losses)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

In [None]:
# loss criterion and optimizer
criterion = nn.CrossEntropyLoss().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))

In [None]:
import numpy as np
from tqdm import trange

# Setting seeds for reproducibility
torch.manual_seed(0)
np.random.seed(0)

running_loss = []
running_test_acc = []
t = trange(num_epochs, desc="Training loss", leave=True)

for epoch in t:
    loss_epoch = train(model, train_quantized_loader, optimizer, criterion)
    test_acc = test(model, test_quantized_loader)
    t.set_description("Training loss = %f test accuracy = %f" % (np.mean(loss_epoch), test_acc))
    t.refresh()  # update the progress bar immediately
    running_loss.append(loss_epoch)
    running_test_acc.append(test_acc)

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt

# average the per-batch losses of each epoch
loss_per_epoch = [np.mean(epoch_losses) for epoch_losses in running_loss]
display_loss_plot(loss_per_epoch, xlabel="Epoch")

In [None]:
# test accuracy after each epoch, converted from a fraction to percent
acc_per_epoch = [100 * acc for acc in running_test_acc]
display_loss_plot(acc_per_epoch, title="Test accuracy", xlabel="Epoch", ylabel="Accuracy [%]")

In [None]:
test(model, test_quantized_loader)

In [None]:
# Save the Brevitas model to disk
torch.save(model.state_dict(), "state_dict_self-trained.pth")
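To reuse saved parameters later, for example in a fresh session, rebuild the same architecture and load the state dict into it. The sketch below shows the round trip with a plain `nn.Sequential` stand-in (ordinary `nn.Linear` and `nn.ReLU` in place of the Brevitas layers) so that it runs without Brevitas; for the quantized MLP you would instead rebuild the Brevitas `nn.Sequential` with the same topology settings before calling `load_state_dict`. The `build_model` helper and the temporary file path are hypothetical.

```python
import os
import tempfile
import torch
import torch.nn as nn

def build_model() -> nn.Sequential:
    # Plain-PyTorch stand-in with the same layer order as the quantized MLP
    return nn.Sequential(
        nn.Linear(784, 512), nn.BatchNorm1d(512), nn.Dropout(0.5),
        nn.ReLU(), nn.Linear(512, 10),
    )

torch.manual_seed(0)
trained = build_model()
path = os.path.join(tempfile.mkdtemp(), "state_dict.pth")
torch.save(trained.state_dict(), path)

# A freshly initialized model receives the saved weights and BatchNorm buffers
restored = build_model()
restored.load_state_dict(torch.load(path, map_location="cpu"))

# eval() disables dropout and makes BatchNorm use its running statistics,
# so both models compute the same function
trained.eval()
restored.eval()
x = torch.randn(2, 784)
with torch.no_grad():
    same = torch.equal(trained(x), restored(x))
print(same)
```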

# Export to QONNX and Conversion to FINN-ONNX <a id="export_qonnx" ></a>


[ONNX](https://onnx.ai/) is an open format built to represent machine learning models, and the FINN compiler expects an ONNX model as input. We'll now export our network to ONNX so that it can be imported and used in FINN in the next notebooks. Note that the particular ONNX representation used by FINN differs from standard ONNX; you can read more about this [here](https://finn.readthedocs.io/en/latest/internals.html#intermediate-representation-finn-onnx).

Below, we export the trained Brevitas network to QONNX, the FINN-compatible ONNX representation that Brevitas can export to, and clean up the exported graph with the QONNX `cleanup` utility. To feed the model into the FINN compiler, we then need to convert it to FINN-ONNX, the intermediate representation the compiler works on. This conversion is itself a FINN compiler transformation, so we first wrap the model in a [ModelWrapper](https://finn.readthedocs.io/en/latest/internals.html#modelwrapper), a wrapper around the ONNX model that provides several helper functions to make it easier to work with. We can then apply the conversion to obtain the model in FINN-ONNX format.

In [None]:
from brevitas.export import export_qonnx
from qonnx.util.cleanup import cleanup as qonnx_cleanup
from qonnx.core.modelwrapper import ModelWrapper
from qonnx.core.datatype import DataType
from finn.transformation.qonnx.convert_qonnx_to_finn import ConvertQONNXtoFINN

ready_model_filename = model_dir + "/mnist-mlp-ready.onnx"
input_shape = (1, 784)  # batch size 1, one flattened 28x28 image

# Move to CPU before export
model.cpu()

# Export the Brevitas model to QONNX
export_qonnx(
    model, export_path=ready_model_filename, input_shape=input_shape
)

# Clean up the exported graph with QONNX's cleanup utility
qonnx_cleanup(ready_model_filename, out_file=ready_model_filename)

# Wrap the ONNX model; a separate name keeps the PyTorch `model` available
finn_model = ModelWrapper(ready_model_filename)

# The inputs are normalized floating-point pixels, not binarized/bipolar values
finn_model.set_tensor_datatype(finn_model.graph.input[0].name, DataType["FLOAT32"])

# Convert QONNX to FINN-ONNX and save the result
finn_model = finn_model.transform(ConvertQONNXtoFINN())
finn_model.save(ready_model_filename)

print("Model saved to %s" % ready_model_filename)

## View the Exported ONNX in Netron

Let's examine the exported ONNX model with [Netron](https://github.com/lutzroeder/netron), which is a visualizer for neural networks and allows interactive investigation of network properties. For example, you can click on the individual nodes and view the properties. Particular things of note:

* The model's input tensor is annotated with `quantization: finn_datatype: FLOAT32`
* Brevitas `QuantLinear` layers are exported to ONNX as `MatMul`. The shape of the first MatMul node's weight parameter is 784x512
* The weight parameters (second inputs) of the MatMul nodes are annotated with a 2-bit integer datatype (`quantization: finn_datatype: INT2`), reflecting `weight_bit_width = 2`
* The quantized activations are exported as `MultiThreshold` nodes with `domain=qonnx.custom_op.general`
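To see why a quantized activation can be expressed as a `MultiThreshold` node, consider its semantics: each output element counts how many thresholds the input reaches. The numpy sketch below uses a hypothetical helper, `multithreshold_sketch`, rather than the actual QONNX custom-op implementation, and assumes a `>=` comparison and a single channel. It shows that a 2-bit unsigned quantized ReLU with a fixed scale `s` and round-to-nearest steps up at `0.5*s`, `1.5*s` and `2.5*s`, so three thresholds reproduce it exactly. In the FINN flow, later streamlining steps can fold affine operations such as batch normalization into the threshold values.

```python
import numpy as np

def multithreshold_sketch(x: np.ndarray, thresholds: np.ndarray) -> np.ndarray:
    """Hypothetical helper: for each element, count how many thresholds it reaches.
    x: shape (N, C); thresholds: shape (C, T), sorted ascending per channel.
    The >= comparison is an assumption of this sketch."""
    return (x[:, :, None] >= thresholds[None, :, :]).sum(axis=2)

# 2-bit unsigned ReLU with scale s: integer outputs 0..3, steps at s*0.5, s*1.5, s*2.5
s = 0.5
thresholds = np.array([[0.5 * s, 1.5 * s, 2.5 * s]])  # one channel, 3 thresholds

x = np.array([[-1.0], [0.2], [0.7], [1.4], [5.0]])
counts = multithreshold_sketch(x, thresholds)
print(counts.ravel())        # integer levels
print((counts * s).ravel())  # rescaled to the quantized ReLU output

# Cross-check against clamp(round(relu(x) / s), 0, 3)
reference = np.clip(np.round(np.maximum(x, 0) / s), 0, 3)
print(np.array_equal(counts, reference))
```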

In [None]:
from finn.util.visualization import showInNetron

showInNetron(ready_model_filename)

## That's it! <a id="thats_it" ></a>
You created, trained and tested a quantized MLP that is ready to be loaded into FINN, congratulations! You can now proceed to the next notebook.