# Train a Quantized MLP on UNSW-NB15 with Brevitas

In this notebook, we will show how to create, train and export a Multi-Layer Perceptron (MLP) with quantized weights and activations using [Brevitas](https://github.com/Xilinx/brevitas).
Specifically, the task at hand will be to label network packets as normal or suspicious (e.g. originating from an attacker, virus, malware or otherwise) by training on a quantized variant of the UNSW-NB15 dataset. 

**You won't need a GPU to train the neural net.** This MLP is small enough to train on a modern x86 CPU.


## A quick introduction to the task and the dataset

*The task:* The goal of [*network intrusion detection*](https://ieeexplore.ieee.org/abstract/document/283931) is to identify, preferably in real time, unauthorized use, misuse, and abuse of computer systems by both system insiders and external penetrators. This may be achieved by a mix of techniques, and machine-learning (ML) based techniques are increasing in popularity. 

*The dataset:* Several datasets are available for use in ML-based methods for intrusion detection.
[UNSW-NB15](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/) is one such dataset, created by the Australian Centre for Cyber Security (ACCS) to provide a comprehensive network-based dataset that reflects modern network traffic scenarios. You can find more details about the dataset on [its homepage](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/).

*Performance considerations:* FPGAs are commonly used for implementing high-performance packet processing systems that still provide a degree of programmability. To avoid introducing bottlenecks on the network, the DNN implementation must be capable of detecting malicious packets at line rate, which can be millions of packets per second and is expected to increase further as next-generation networking solutions provide increased throughput. This is a good reason to consider FPGA acceleration for this particular use case.
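
To get a feel for what "line rate" means in numbers, the sketch below computes the worst-case packet rate of an Ethernet link. It assumes minimum-size 64-byte frames plus the 20 bytes of preamble and inter-frame gap that each frame occupies on the wire. The 10 Gbit/s and 100 Gbit/s link speeds are illustrative assumptions and are not taken from the dataset or from FINN.

```python
# Worst-case (minimum-size frame) packet rate on an Ethernet link.
MIN_FRAME_BYTES = 64       # minimum Ethernet frame size
WIRE_OVERHEAD_BYTES = 20   # 8 bytes preamble/SFD + 12 bytes inter-frame gap

def max_packets_per_second(link_bits_per_second):
    bits_per_packet = (MIN_FRAME_BYTES + WIRE_OVERHEAD_BYTES) * 8
    return link_bits_per_second / bits_per_packet

for gbps in (10, 100):  # illustrative link speeds (assumption)
    pps = max_packets_per_second(gbps * 1e9)
    print(f"{gbps} Gbit/s -> {pps / 1e6:.2f} million packets/s")
```

Under these assumptions, a classifier that inspects every packet has to sustain on the order of tens of millions of inferences per second.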

## Outline
-------------

* [Define the quantized MLP model class](#1create_model)
* [Load the UNSW_NB15 dataset](#load_dataset)
* [Define the Train and Test methods](#train_test)
* [Define Helper Functions and Parameters](#auxilary_&_parameters)
* [Define the Loss and Optimizer](#loss_optimizer)
* [Train, Test and see the Loss](#train_test_loss)
* [Change the model structure after training](#change_model)
* [Brevitas export](#brevitas_export)

## Initial setup

Let's start by making sure we have all the Python packages we'll need for this notebook.

In [1]:
!pip install --user pandas
!pip install --user scikit-learn
!pip install --user tqdm



## Define the Quantized MLP model <a id='1create_model'></a>

We'll now define an MLP model that will be trained to perform inference with quantized weights and activations.
For this, we'll use the quantization-aware training (QAT) capabilities offered by [Brevitas](https://github.com/Xilinx/brevitas).

Our MLP will have four fully-connected (FC) layers in total: three hidden layers with 64 neurons each, and a final output layer with a single output. With the settings defined below, all layers use 1-bit (binary) weights, and the hidden layers use 2-bit quantized activations (implemented with `QuantIdentity`), with batch normalization applied between each FC layer and its activation.

In case you'd like to experiment with different quantization settings or topology parameters, we'll define all these settings as variables.

In [2]:
input_size = 593      
hidden1 = 64      
hidden2 = 64
hidden3 = 64
weight_bit_width = 1
act_bit_width = 2
num_classes = 1    
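
As a rough sanity check of how small this network is, the following sketch counts the FC-layer weights for the topology above and the storage they need at `weight_bit_width = 1`. It only counts weights: biases and batch normalization parameters are deliberately left out, and the variable names are chosen for this illustration only.

```python
# Count FC-layer weights and their storage at a given weight bit width.
# Layer sizes and bit width mirror the topology variables defined above.
layer_sizes = [593, 64, 64, 64, 1]  # input, three hidden layers, output
weight_bits = 1

def count_weights(sizes):
    # one weight per (input, output) pair of each fully-connected layer
    return sum(n_in * n_out for n_in, n_out in zip(sizes[:-1], sizes[1:]))

total_weights = count_weights(layer_sizes)
total_bytes = total_weights * weight_bits / 8
print(total_weights, "weights,", total_bytes, "bytes")
```

The first layer dominates the count because of the 593-bit wide input.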

Now let's define our quantization settings to be used by Brevitas. We'll use a [dependency injection](https://en.wikipedia.org/wiki/Dependency_injection)-based method for setting up the quantization properties, which is also used in the training scripts for the [Brevitas BNN-PYNQ examples](https://github.com/Xilinx/brevitas/tree/master/brevitas_examples/bnn_pynq). This will enable setting up good default settings depending on which bitwidth(s) we end up using for our MLP. 

In [3]:
from dependencies import value

from brevitas.inject import BaseInjector as Injector
from brevitas.core.bit_width import BitWidthImplType
from brevitas.core.quant import QuantType
from brevitas.core.restrict_val import RestrictValueType
from brevitas.core.scaling import ScalingImplType

class CommonQuant(Injector):
    bit_width_impl_type = BitWidthImplType.CONST
    scaling_impl_type = ScalingImplType.CONST
    restrict_scaling_type = RestrictValueType.FP
    scaling_per_output_channel = False
    narrow_range = True
    signed = True

    @value
    def quant_type(bit_width):
        if bit_width is None:
            return QuantType.FP
        elif bit_width == 1:
            return QuantType.BINARY
        else:
            return QuantType.INT

class CommonWeightQuant(CommonQuant):
    scaling_const = 1.0

class CommonActQuant(CommonQuant):
    min_val = -1.0
    max_val = 1.0
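
To make these settings more tangible, the plain-Python sketch below mirrors the `quant_type` decision logic and lists the values that a signed, narrow-range quantizer with a constant range of [-1, 1] can represent. This is an illustration of the idea only, under the assumption of a uniform grid; it is not how Brevitas implements quantization internally, and the helper names are hypothetical.

```python
# Illustration only: mimics the idea behind the settings above in plain Python.

def quant_kind(bit_width):
    # same decision logic as CommonQuant.quant_type
    if bit_width is None:
        return "FP"
    return "BINARY" if bit_width == 1 else "INT"

def representable_levels(bit_width, max_val=1.0):
    """Values a signed, narrow-range quantizer can emit within [-max_val, max_val]."""
    if bit_width == 1:
        return [-max_val, max_val]      # binary: two levels
    top = 2 ** (bit_width - 1) - 1      # narrow range: symmetric around 0
    return [max_val * i / top for i in range(-top, top + 1)]

print(quant_kind(1), representable_levels(1))  # binary weights
print(quant_kind(2), representable_levels(2))  # 2-bit activations
```

With narrow-range signed quantization, 2 bits yield only three levels, because the most negative code is dropped to keep the range symmetric.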

Now we can define our MLP using the layer primitives provided by Brevitas:

In [4]:
from brevitas.nn import QuantLinear, QuantIdentity
import torch.nn as nn

act_layer = QuantIdentity

class QuantMLP(nn.Module):
    def __init__(self, input_size,hidden1, hidden2, hidden3, num_classes):
        super(QuantMLP, self).__init__()
        self.fc1 = QuantLinear(input_size, hidden1, bias=False, weight_bit_width=weight_bit_width, weight_quant=CommonWeightQuant)
        self.batchnorm1 = nn.BatchNorm1d(hidden1)
        self.act1 = act_layer(act_quant=CommonActQuant, bit_width=act_bit_width)
        
        self.fc2 = QuantLinear(hidden1, hidden2, bias=False, weight_bit_width=weight_bit_width, weight_quant=CommonWeightQuant)
        self.batchnorm2 = nn.BatchNorm1d(hidden2)
        self.act2 = act_layer(act_quant=CommonActQuant, bit_width=act_bit_width)
        
        self.fc3 = QuantLinear(hidden2, hidden3, bias=False, weight_bit_width=weight_bit_width, weight_quant=CommonWeightQuant)
        self.batchnorm3 = nn.BatchNorm1d(hidden3)
        self.act3 = act_layer(act_quant=CommonActQuant, bit_width=act_bit_width)
        
        self.fc4 = QuantLinear(hidden3, num_classes, bias=True, weight_bit_width=weight_bit_width, weight_quant=CommonWeightQuant)


    def forward(self, x):
        fc1 = self.fc1(x)
        b1 = self.batchnorm1(fc1)
        act1 = self.act1(b1)
        
        fc2 = self.fc2(act1)
        b2 = self.batchnorm2(fc2)
        act2 = self.act2(b2)

        fc3 = self.fc3(act2)
        b3 = self.batchnorm3(fc3)
        act3 = self.act3(b3)
        
        fc4 = self.fc4(act3)
        return fc4

Note that the MLP's output is not yet quantized. Even though we want the final output of our MLP to be a binary (0/1) value indicating the classification, we've only defined a single-neuron FC layer as the output. While training the network we'll pass that output through a sigmoid function as part of the loss criterion, which [gives better numerical stability](https://pytorch.org/docs/stable/generated/torch.nn.BCEWithLogitsLoss.html). Later on, after we're done training the network, we'll add a quantization node at the end before we export it to FINN.
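
The numerical stability argument can be illustrated without PyTorch. The sketch below compares a naive "sigmoid, then log" binary cross-entropy with the standard stable formulation that works directly on the logit. It is a minimal illustration of the principle; PyTorch's actual implementation may differ in its details.

```python
import math

def bce_naive(logit, target):
    # sigmoid first, then log: breaks once the sigmoid saturates to exactly 0 or 1
    p = 1.0 / (1.0 + math.exp(-logit))
    try:
        return -(target * math.log(p) + (1 - target) * math.log(1 - p))
    except ValueError:  # raised by math.log(0.0)
        return float("inf")

def bce_with_logits(logit, target):
    # algebraically equivalent form that never takes the log of 0
    return max(logit, 0.0) - logit * target + math.log1p(math.exp(-abs(logit)))

print(bce_naive(2.0, 1.0), bce_with_logits(2.0, 1.0))    # agree for moderate logits
print(bce_naive(40.0, 0.0), bce_with_logits(40.0, 0.0))  # naive version blows up
```

For a confidently wrong prediction (large logit, target 0), the naive version returns infinity while the logit-based version returns a finite loss that can still be backpropagated.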

## Load the UNSW_NB15 dataset <a id='load_dataset'></a>

### Dataset quantization <a id='dataset_qnt'></a>

The goal of this notebook is to train a Quantized Neural Network (QNN) to be later deployed as an FPGA accelerator generated by the FINN compiler. Although we can choose a variety of different precisions for the input, [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf) have previously shown we can actually binarize the inputs and still get good accuracy.
Thus, we will create a binarized representation for the dataset by following the procedure defined by [Murovic and Trost](https://ev.fe.uni-lj.si/1-2-2019/Murovic.pdf), which we repeat briefly here:

* Original features have different formats ranging from integers, floating numbers to strings.
* Integers, which for example represent a packet lifetime, are binarized with as many bits as to include the maximum value. 
* Another case is with features formatted as strings (protocols), which are binarized by simply counting the number of all different strings for each feature and coding them in the appropriate number of bits.
* Floating-point numbers are reformatted into fixed-point representation.
* In the end, each sample is transformed into a 593-bit wide binary vector. 
* All vectors are labeled as bad (0) or normal (1).
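
The integer and string steps of this procedure can be sketched in a few lines. This is a simplified illustration with made-up feature values; it does not reproduce the exact bit ordering or feature handling of the Matlab script or of `dataloader_quantized.py`, and the function names are hypothetical.

```python
def int_to_bits(value, num_bits):
    # most-significant bit first
    return [(value >> shift) & 1 for shift in range(num_bits - 1, -1, -1)]

def binarize_int_feature(values):
    # use as many bits as the largest value needs
    num_bits = max(max(values).bit_length(), 1)
    return [int_to_bits(v, num_bits) for v in values]

def binarize_string_feature(values):
    # give each distinct string an index, then encode the index in binary
    categories = sorted(set(values))
    num_bits = max((len(categories) - 1).bit_length(), 1)
    index = {name: i for i, name in enumerate(categories)}
    return [int_to_bits(index[v], num_bits) for v in values]

print(binarize_int_feature([3, 254, 31]))                     # 8 bits per value
print(binarize_string_feature(["tcp", "udp", "arp", "tcp"]))  # 2 bits per value
```

Concatenating the bit vectors of all features of one sample gives a single wide binary input vector of the kind the MLP consumes.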

Following their open-source implementation provided as a Matlab script [here](https://github.com/TadejMurovic/BNN_Deployment/blob/master/cybersecurity_dataset_unswb15.m), we've created a [Python version](dataloader_quantized.py).
The `UNSW_NB15_quantized` class provides the binarized data as a PyTorch dataset, which we later wrap in a `DataLoader`, a Python iterable over a dataset. This is useful because it enables access to the data in batches.

### Download the training and test set from the [official website](https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/)

In [5]:
! wget https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/a%20part%20of%20training%20and%20testing%20set/UNSW_NB15_training-set.csv
! wget https://www.unsw.adfa.edu.au/unsw-canberra-cyber/cybersecurity/ADFA-NB15-Datasets/a%20part%20of%20training%20and%20testing%20set/UNSW_NB15_testing-set.csv

In [6]:
from torch.utils.data import DataLoader, Dataset
from dataloader_quantized import UNSW_NB15_quantized

file_path_train = "UNSW_NB15_training-set.csv"
file_path_test = "UNSW_NB15_testing-set.csv"

train_quantized_dataset = UNSW_NB15_quantized(file_path_train=file_path_train,
                                              file_path_test=file_path_test,
                                              train=True)

test_quantized_dataset = UNSW_NB15_quantized(file_path_train=file_path_train,
                                             file_path_test=file_path_test,
                                             train=False)

## Define Train and Test  methods  <a id='train_test'></a>
The train and test methods use a `DataLoader`, which feeds the model one batch of data per iteration until the entire dataset has been processed. One full pass over the training data is called an `epoch`.

In [7]:
def train(model, train_loader, optimizer, criterion):
    losses = []
    # ensure model is in training mode
    model.train()    
    
    for inputs, target in train_loader:
        optimizer.zero_grad()

        # forward pass
        output = model(inputs.float())
        loss = criterion(output, target.unsqueeze(1))

        # backward pass + run optimizer to update weights
        loss.backward()
        optimizer.step()

        # keep track of loss value (as a plain Python float)
        losses.append(loss.item())
           
    return losses
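
The batch and epoch terminology can be made concrete with a toy example. The sketch below splits a stand-in dataset of 250 samples into batches of 100, twice. The dataset size and the `iterate_batches` helper are hypothetical, and unlike a `DataLoader` with `shuffle=True` this sketch does not reorder the samples between epochs.

```python
def iterate_batches(samples, batch_size):
    # yield consecutive slices; the last batch may be smaller
    for start in range(0, len(samples), batch_size):
        yield samples[start:start + batch_size]

toy_dataset = list(range(250))  # hypothetical stand-in for 250 samples
for epoch in range(2):          # two passes over the data = two epochs
    sizes = [len(batch) for batch in iterate_batches(toy_dataset, batch_size=100)]
    print(f"epoch {epoch}: batch sizes {sizes}")
```

Each call to `train` above corresponds to one such pass, with one optimizer step per batch.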

In [8]:
import torch
from sklearn.metrics import accuracy_score

def test(model, test_loader):
    # ensure model is in eval mode
    model.eval() 
    y_true = []
    y_pred = []
   
    with torch.no_grad():
        for data in test_loader:
            inputs, target = data
            output_orig = model(inputs.float())
            # run the output through sigmoid
            output = torch.sigmoid(output_orig)  
            # compare against a threshold of 0.5 to generate 0/1
            pred = (output.detach().numpy() > 0.5) * 1
            target = target.float()
            y_true.extend(target.tolist()) 
            y_pred.extend(pred.reshape(-1).tolist())
        
    return accuracy_score(y_true, y_pred)
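
Thresholding the sigmoid output at 0.5 is the same as checking the sign of the raw logit, since sigmoid(x) > 0.5 exactly when x > 0. The short check below illustrates this for a few arbitrary example logits; floating-point rounding can make the two disagree only for logits extremely close to zero.

```python
import math

def sigmoid(x):
    return 1.0 / (1.0 + math.exp(-x))

def predict_via_sigmoid(logit):
    return int(sigmoid(logit) > 0.5)

def predict_via_sign(logit):
    # sigmoid(x) > 0.5 exactly when x > 0
    return int(logit > 0.0)

for logit in (-3.2, -0.1, 0.0, 0.4, 5.0):  # arbitrary example logits
    assert predict_via_sigmoid(logit) == predict_via_sign(logit)
    print(logit, predict_via_sign(logit))
```

This equivalence is why a sign-like 1-bit quantizer on the output can stand in for the sigmoid and threshold once training is finished.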

## Define Helper Functions and Parameters <a id='auxilary_&_parameters'></a>

Before we start training our MLP we need to define some hyperparameters: the number of epochs, the batch size and the learning rate.
Moreover, to monitor how the loss evolves over the epochs, we define a helper function that plots it.

In [9]:
import matplotlib.pyplot as plt

num_epochs = 40
batch_size = 100
lr = 0.01

def display_loss_plot(losses, title="Training loss", xlabel="Iterations", ylabel="Loss"):
    x_axis = list(range(len(losses)))
    plt.plot(x_axis, losses)
    plt.title(title)
    plt.xlabel(xlabel)
    plt.ylabel(ylabel)
    plt.show()

## Define the Loss and Optimizer <a id="loss_optimizer"></a>

As mentioned earlier, we'll use a loss criterion which applies a sigmoid function during the training phase (`BCEWithLogitsLoss`). For the testing phase, we're manually computing the sigmoid and thresholding at 0.5 as seen above.

In [10]:
import onnx 
import torch

# dataset loaders
train_quantized_loader = DataLoader(train_quantized_dataset, batch_size=batch_size, shuffle=True)
test_quantized_loader = DataLoader(test_quantized_dataset, batch_size=batch_size, shuffle=True)

# model to train
model = QuantMLP(input_size, hidden1, hidden2, hidden3, num_classes)

# loss criterion and optimizer
criterion = nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999))
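
For readers unfamiliar with Adam, the sketch below performs a single Adam update on one scalar parameter, using the same `lr` and `betas` as above. It follows the textbook update rule and is meant as an illustration only; the starting value 1.0 and gradient 3.0 are arbitrary assumptions, and `torch.optim.Adam` offers further options not modeled here.

```python
import math

def adam_step(param, grad, m, v, t, lr=0.01, betas=(0.9, 0.999), eps=1e-8):
    # update biased first and second moment estimates
    m = betas[0] * m + (1 - betas[0]) * grad
    v = betas[1] * v + (1 - betas[1]) * grad * grad
    # bias correction for step t (1-based)
    m_hat = m / (1 - betas[0] ** t)
    v_hat = v / (1 - betas[1] ** t)
    param = param - lr * m_hat / (math.sqrt(v_hat) + eps)
    return param, m, v

param, m, v = adam_step(param=1.0, grad=3.0, m=0.0, v=0.0, t=1)
print(param)  # close to 1.0 - lr
```

On the first step the update has a magnitude of roughly `lr` regardless of the gradient's scale, which is one reason Adam is fairly forgiving about the choice of learning rate.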

## Train the model and Verify the Loss <a id="train_test_loss"></a>

Now that we have everything defined, we can finally train the quantized MLP on the quantized dataset. After each epoch we test its accuracy, and afterwards we plot how the loss evolved over the epochs.

In [11]:
import numpy as np
from sklearn.metrics import accuracy_score
from tqdm import tqdm, trange

running_loss = []
running_test_acc = []
t = trange(num_epochs, desc="Training loss", leave=True)

for epoch in t:
    loss_epoch = train(model, train_quantized_loader, optimizer, criterion)
    test_acc = test(model, test_quantized_loader)
    t.set_description("Training loss = %f test accuracy = %f" % (np.mean(loss_epoch), test_acc))
    t.refresh()  # show the update immediately
    running_loss.append(loss_epoch)
    running_test_acc.append(test_acc)

Training loss:   0%|          | 0/40 [00:00<?, ?it/s]
Training loss = 0.206721 test accuracy = 0.820483:   2%|▎         | 1/40 [00:11<07:16, 11.19s/it]
Training loss = 0.160921 test accuracy = 0.828171:   5%|▌         | 2/40 [00:22<07:05, 11.20s/it]
Training loss = 0.153925 test accuracy = 0.837244:   8%|▊         | 3/40 [00:33<06:55, 11.22s/it]
Training loss = 0.152099 test accuracy = 0.829690:  10%|█         | 4/40 [00:45<06:45, 11.26s/it]
Training loss = 0.154176 test accuracy = 0.844070:  12%|█▎        | 5/40 [00:56<06:35, 11.29s/it]
Training loss = 0.150484 test accuracy = 0.838678:  15%|█▌        | 6/40 [01:07<06:24, 11.30s/it]
Training loss = 0.150323 test accuracy = 0.818697:  18%|█▊        | 7/40 [01:18<06:11, 11.26s/it]
Training loss = 0.146891 test accuracy = 0.837937:  20%|██        | 8/40 [01:30<06:00, 11.28s/it]
Training loss = 0.144228 test accuracy = 0.813414:  22%|██▎       | 9/40 [01:41<05:50, 11.30s/it]
Training loss = 0.141844 test accuracy = 0.860285:  25%|██▌       | 10/40 [01:52<05:38, 11.30s/it]
Training loss = 0.139459 test accuracy = 0.831839:  28%|██▊       | 11/40 [02:03<05:26, 11.25s/it]
Training loss = 0.142604 test accuracy = 0.810839:  30%|███       | 12/40 [02:15<05:15, 11.28s/it]
Training loss = 0.140008 test accuracy = 0.826617:  32%|███▎      | 13/40 [02:26<05:06, 11.33s/it]
Training loss = 0.142259 test accuracy = 0.808410:  35%|███▌      | 14/40 [02:38<04:54, 11.34s/it]
Training loss = 0.140141 test accuracy = 0.811701:  38%|███▊      | 15/40 [02:49<04:44, 11.37s/it]
Training loss = 0.143105 test accuracy = 0.828973:  40%|████      | 16/40 [03:00<04:32, 11.37s/it]
Training loss = 0.139130 test accuracy = 0.863747:  42%|████▎     | 17/40 [03:12<04:22, 11.40s/it]
Training loss = 0.133940 test accuracy = 0.868156:  45%|████▌     | 18/40 [03:23<04:10, 11.39s/it]
Training loss = 0.135788 test accuracy = 0.830661:  48%|████▊     | 19/40 [03:35<03:59, 11.39s/it]
Training loss = 0.138048 test accuracy = 0.886533:  50%|█████     | 20/40 [03:46<03:46, 11.33s/it]
Training loss = 0.136244 test accuracy = 0.822196:  52%|█████▎    | 21/40 [03:57<03:35, 11.35s/it]
Training loss = 0.136308 test accuracy = 0.823070:  55%|█████▌    | 22/40 [04:09<03:24, 11.38s/it]
Training loss = 0.132924 test accuracy = 0.885609:  57%|█████▊    | 23/40 [04:20<03:14, 11.42s/it]
Training loss = 0.131081 test accuracy = 0.888525:  60%|██████    | 24/40 [04:32<03:02, 11.42s/it]
Training loss = 0.133813 test accuracy = 0.879294:  62%|██████▎   | 25/40 [04:43<02:51, 11.41s/it]
Training loss = 0.133536 test accuracy = 0.876816:  65%|██████▌   | 26/40 [04:54<02:40, 11.43s/it]
Training loss = 0.131778 test accuracy = 0.832362:  68%|██████▊   | 27/40 [05:06<02:28, 11.42s/it]
Training loss = 0.138220 test accuracy = 0.868168:  70%|███████   | 28/40 [05:17<02:17, 11.43s/it]
Training loss = 0.135467 test accuracy = 0.864451:  72%|███████▎  | 29/40 [05:29<02:05, 11.43s/it]
Training loss = 0.136560 test accuracy = 0.868836:  75%|███████▌  | 30/40 [05:40<01:53, 11.39s/it]
Training loss = 0.135325 test accuracy = 0.850289:  78%|███████▊  | 31/40 [05:52<01:43, 11.51s/it]
Training loss = 0.131879 test accuracy = 0.873682:  80%|████████  | 32/40 [06:03<01:32, 11.52s/it]
Training loss = 0.130507 test accuracy = 0.818442:  82%|████████▎ | 33/40 [06:15<01:20, 11.49s/it]
Training loss = 0.132760 test accuracy = 0.831852:  85%|████████▌ | 34/40 [06:26<01:08, 11.48s/it]
Training loss = 0.131189 test accuracy = 0.821698:  88%|████████▊ | 35/40 [06:38<00:57, 11.49s/it]
Training loss = 0.132866 test accuracy = 0.862119:  90%|█████████ | 36/40 [06:49<00:45, 11.45s/it]
Training loss = 0.130355 test accuracy = 0.872516:  92%|█████████▎| 37/40 [07:01<00:34, 11.45s/it]
Training loss = 0.131844 test accuracy = 0.876111:  95%|█████████▌| 38/40 [07:12<00:22, 11.43s/it]
Training loss = 0.131016 test accuracy = 0.872893:  98%|█████████▊| 39/40 [07:23<00:11, 11.42s/it]
Training loss = 0.130985 test accuracy = 0.816001: 100%|██████████| 40/40 [07:35<00:00, 11.43s/it]


In [12]:
test(model, test_quantized_loader)

0.8160010688432201
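
The test accuracy fluctuates noticeably from epoch to epoch, and the final value is lower than several earlier ones. One possible way to see which epoch did best is to search the recorded accuracies, as in the sketch below. The `best_epoch` helper is hypothetical and not part of the original notebook, and the example list contains only the last five accuracies copied from the training log above.

```python
def best_epoch(accuracies):
    # return (1-based position, accuracy) of the highest test accuracy
    idx = max(range(len(accuracies)), key=lambda i: accuracies[i])
    return idx + 1, accuracies[idx]

# last five per-epoch test accuracies copied from the training log above
last_five = [0.862119, 0.872516, 0.876111, 0.872893, 0.816001]
position, acc = best_epoch(last_five)
print(position, acc)
```

In the notebook itself, the same helper could be applied to the full history with `best_epoch(running_test_acc)`.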

In [13]:
import matplotlib.pyplot as plt

loss_per_epoch = [np.mean(epoch_losses) for epoch_losses in running_loss]
display_loss_plot(loss_per_epoch, xlabel="Epochs")

<Figure size 640x480 with 1 Axes>

## Change  the model after training <a id="change_model"></a>

After training, the model's prediction (sigmoid followed by thresholding) is a value in {0,+1}. For implementation purposes, however, we need the output to belong to {-1,+1}. We therefore append a quantized identity layer (`QuantIdentity`) with a 1-bit output to the model.

Furthermore, we change the expected input range: instead of {0,+1}, the model will now accept {-1,+1} as valid input values and map them back to {0,+1} internally.

In [14]:
# save the state_dict
path = "quantized_mlp_unsw_nb15.pt"
torch.save(model.state_dict(), path)

# load the state_dict and create a new model with an additional QuantIdentity layer
new_model_to_be = QuantMLP(input_size, hidden1, hidden2, hidden3, num_classes)
new_model_to_be.load_state_dict(torch.load(path))  
new_model_to_be.eval()

class extended_model(nn.Module):
    def __init__(self, my_pretrained_model):
        super(extended_model, self).__init__()
        self.pretrained = my_pretrained_model
        self.identity = QuantIdentity(act_quant=CommonActQuant, bit_width=1)
    
    def forward(self, x):
        x = (x + torch.tensor([1.0])) / 2.0  # shift from {-1,1} to {0,1}
        out_original = self.pretrained(x)
        out_final = self.identity(out_original)   # output as {-1,1}     
        return out_final

new_model = extended_model(my_pretrained_model=new_model_to_be)
new_model.eval()
# feed the test data as {-1,1}
new_model_input = test_quantized_dataset.data[:, :-1] * 2.0 - torch.tensor([1.0])
new_model_output = new_model(new_model_input)
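
The two range conversions used in this cell are inverses of each other, which the following plain-Python sketch verifies. The helper names are hypothetical and exist only for this illustration.

```python
def to_bipolar(bit):
    # {0, 1} -> {-1, +1}, as applied to the test data before the forward pass
    return bit * 2.0 - 1.0

def to_binary(value):
    # {-1, +1} -> {0, 1}, same arithmetic as the first line of forward()
    return (value + 1.0) / 2.0

for bit in (0.0, 1.0):
    assert to_binary(to_bipolar(bit)) == bit
print([to_bipolar(b) for b in (0.0, 1.0)])
```

The same `to_binary` mapping can be applied to the bipolar model output to compare it against the 0/1 labels of the dataset.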

## Brevitas export <a id="brevitas_export" ></a>

FINN expects an ONNX model as input, and this can be a model trained with [Brevitas](https://github.com/Xilinx/brevitas). We first import the Brevitas ONNX export module and then export the extended model, which already contains the trained weights.

In [15]:
import brevitas.onnx as bo

export_onnx_path = "brevitas_w%d_a%d_UNSW_NB15_model.onnx" % (weight_bit_width, act_bit_width)
input_shape = (1, 593)

# make sure the model is in eval mode before exporting
new_model.eval()
bo.export_finn_onnx(new_model, input_shape, export_onnx_path)

print("Model saved to %s" % export_onnx_path)

Model saved to brevitas_w1_a2_UNSW_NB15_model.onnx




## That's it!
You created, trained and tested a quantized MLP that is ready to be loaded into FINN! Congratulations!