## VPRTempo - Quantization Aware Training and Inferencing Tutorial

### By Adam D Hines (https://research.qut.edu.au/qcr/people/adam-hines/)

VPRTempo is based on the following paper. If you use or find this code helpful for your research, please consider citing it:

[Adam D Hines, Peter G Stratton, Michael Milford, & Tobias Fischer. "VPRTempo: A Fast Temporally Encoded Spiking Neural Network for Visual Place Recognition." arXiv, September 2023](https://arxiv.org/abs/2309.10225)

### Introduction

In this tutorial, we take the base VPRTempo model and train and run inference on it using PyTorch's Quantization Aware Training ([QAT](https://pytorch.org/docs/stable/quantization.html)). Functionally, this tutorial is similar to the previous one, but simplified. For a more detailed look at how VPRTempo works, please see [Tutorial 1](https://github.com/AdamDHines/VPRTempo-quant/blob/main/tutorials/1_Introduction.ipynb).

**Note: it does not appear that Apple Silicon is currently a supported backend for QAT**

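To get started, please ensure you have installed the `conda` environment for VPRTempo and activated it in your terminal before launching Jupyter. Running `!conda activate` from inside a notebook cell only affects a temporary subshell, not the running kernel.

In [None]:
# `conda activate vprtempo` must be run in a terminal *before* starting Jupyter.
# Check that the kernel is using the Python interpreter from the vprtempo environment:
import sys
print(sys.executable)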

## 1. Get the Nordland dataset

### 1.1 Download the dataset

Please [download the Nordland datasets](https://webdiis.unizar.es/~jmfacil/pr-nordland/#download-dataset) (Summer, Spring, Fall, & Winter). Two versions are available: full size and downsampled. Either will work, but our paper uses the full-size dataset. If disk space is a concern, please use the downsampled version.

Save the data in the `./VPRTempo-quant/dataset/` subfolder.

### 1.2 Import modules

Once we have downloaded the dataset, we'll start by importing all the necessary modules.

For this tutorial, we use [Jupyter Dynamic Classes](https://alexhagen.github.io/jdc/), so please install it if you have not already.

In [None]:
!pip install jdc

In [None]:
import jdc
import os
import torch
import gc
import sys
sys.path.append('../')
sys.path.append('../src')
sys.path.append('../models')
sys.path.append('../output')
sys.path.append('../dataset')

import blitnet as bn
import numpy as np
import torch.nn as nn
import torch.quantization as quantization

from settings import configure, image_csv, model_logger
from dataset import CustomImageDataset, ProcessImage
from torch.utils.data import DataLoader
from torch.ao.quantization import QuantStub, DeQuantStub
from tqdm import tqdm

### 1.3 Prepare the dataset for the model (optional)

The dataset seasons are downloaded in .zip format and need to be extracted into a single folder. The `nord_sort` function from the `nordland` module has been provided to do this automatically and to rename the images to match those in the `nordland.csv` file.

If you have already done this from the previous tutorial, you can skip this step.

In [None]:
from os import walk
from nordland import nord_sort

# unzip, re-organise, and re-name the Nordland datasets
nord_sort()

## 2. Set up the network

### 2.1 Define and initialize the VPRTempo model class

We'll now import the main network model class `VPRTempo`. Please see [Tutorial 1](https://github.com/AdamDHines/VPRTempo-quant/blob/main/tutorials/1_Introduction.ipynb) for a more detailed look at what this includes.

In [None]:
from VPRTempo import VPRTempo
model = VPRTempo()

### 2.2 Generate unique model name

We will generate a unique model name, which is used to save the trained model and to reload it for inference.

In [None]:
def generate_model_name(model):
    """
    Generate the model name based on its parameters.
    """
    return ("VPRTempo" +
            str(model.input) +
            str(model.feature) +
            str(model.output) +
            str(model.number_modules) +
            "Quantized" +
            '.pth')

model_name = generate_model_name(model)

print(model_name)
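Because the name is built by concatenating the parameters with no separators, two different configurations can produce the same file name. The sketch below reproduces the concatenation on a stand-in object; the `SimpleNamespace` values and the `generate_model_name_sep` variant are hypothetical and only illustrate the point.

```python
from types import SimpleNamespace

def generate_model_name(model) -> str:
    """Same concatenation as the tutorial cell above."""
    return ("VPRTempo" + str(model.input) + str(model.feature) +
            str(model.output) + str(model.number_modules) + "Quantized" + ".pth")

def generate_model_name_sep(model, sep: str = "_") -> str:
    """Hypothetical variant: separators keep each field unambiguous."""
    fields = [model.input, model.feature, model.output, model.number_modules]
    return "VPRTempo" + sep + sep.join(str(f) for f in fields) + sep + "Quantized.pth"

# Two hypothetical configurations that differ only in input/feature sizes
a = SimpleNamespace(input=12, feature=34, output=5, number_modules=1)
b = SimpleNamespace(input=123, feature=4, output=5, number_modules=1)

print(generate_model_name(a))      # VPRTempo123451Quantized.pth
print(generate_model_name(b))      # VPRTempo123451Quantized.pth (same name)
print(generate_model_name_sep(a))  # VPRTempo_12_34_5_1_Quantized.pth
print(generate_model_name_sep(b))  # VPRTempo_123_4_5_1_Quantized.pth
```

For the default VPRTempo settings this collision is unlikely to matter, but separators make it easier to read the parameters back out of a saved file name.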

## 3. Define the DataLoader

### 3.1 Set the DataLoader

Now that we've defined the model, we will set up the DataLoaders. These use `CustomImageDataset` (a custom PyTorch dataset) and `ProcessImage` to load images and process them for training or inference. In brief, images are loaded, gamma corrected, resized, and patch-normalized before being converted into spikes that are propagated through the network.

Since we present the network with one image at a time, the `batch_size` is kept to 1.

In [None]:
from dataset import CustomImageDataset, ProcessImage
from torch.utils.data import DataLoader

image_transform = ProcessImage(model.dims, model.patches)
train_dataset = CustomImageDataset(annotations_file=model.dataset_file,
                                   img_dirs=model.training_dirs,
                                   transform=image_transform,
                                   skip=model.filter,
                                   max_samples=model.number_training_images,
                                   test=False)
# Initialize the data loader
train_loader = DataLoader(train_dataset, 
                          batch_size=1, 
                          shuffle=False,
                          num_workers=8,
                          persistent_workers=True)
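The patch normalization step mentioned above can be illustrated with a short NumPy sketch. It assumes the common definition used in visual place recognition, where each non-overlapping patch is rescaled to zero mean and unit standard deviation. The `patch_normalize` helper is hypothetical and is not the `ProcessImage` implementation; we only assume that `model.patches` controls a comparable patch size.

```python
import numpy as np

def patch_normalize(img: np.ndarray, patch: int) -> np.ndarray:
    """Rescale each non-overlapping patch x patch tile to zero mean and unit std."""
    img = img.astype(np.float64)
    out = np.zeros_like(img)
    h, w = img.shape
    for r in range(0, h, patch):
        for c in range(0, w, patch):
            tile = img[r:r + patch, c:c + patch]
            std = tile.std()
            # Flat tiles (std == 0) carry no contrast information, so leave them at 0
            if std > 0:
                out[r:r + patch, c:c + patch] = (tile - tile.mean()) / std
    return out

# Toy 8x8 "image" split into four 4x4 patches
img = np.arange(64, dtype=np.float64).reshape(8, 8)
norm = patch_normalize(img, patch=4)
print(norm[:4, :4].mean(), norm[:4, :4].std())  # approximately 0.0 and 1.0
```

Normalizing locally like this removes per-region brightness and contrast differences, which helps when matching places across seasons and lighting conditions.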

## 4. Quantization

### 4.1 Model quantization

VPRTempoQuant makes use of Quantization Aware Training (QAT), and only a few simple steps are needed to prepare the model to accommodate this. First, we get the default QAT configuration for the `fbgemm` backend, PyTorch's quantized backend for x86 CPUs.

In [None]:
import torch.quantization as quantization

# Set the quantization configuration
qconfig = quantization.get_default_qat_qconfig('fbgemm')
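Under the hood, the qconfig pairs observers with fake-quantization modules. Observers track the range of values a tensor takes and turn it into a scale and a zero point. The following pure-Python sketch shows that mapping for asymmetric (affine) min/max quantization onto an unsigned 8-bit range. `affine_qparams` is a hypothetical helper, and PyTorch's actual observers add details this sketch ignores (for example moving averages, per-channel weight scales, and reduced ranges).

```python
def affine_qparams(x_min: float, x_max: float, qmin: int = 0, qmax: int = 255):
    """Scale and zero point mapping the float range [x_min, x_max] onto [qmin, qmax]."""
    # Widen the range to include 0.0 so that zero is exactly representable
    x_min, x_max = min(x_min, 0.0), max(x_max, 0.0)
    scale = (x_max - x_min) / (qmax - qmin)
    if scale == 0.0:
        scale = 1.0  # degenerate all-zero tensor; any positive scale works
    zero_point = qmin - round(x_min / scale)
    zero_point = max(qmin, min(qmax, zero_point))  # keep the zero point in range
    return scale, int(zero_point)

# Example: activations observed between -1.0 and 1.55
scale, zero_point = affine_qparams(-1.0, 1.55)
print(scale, zero_point)  # scale is approximately 0.01, zero_point is 100
```

A float value `x` is then stored as the integer `round(x / scale) + zero_point`, so here `-1.0` maps to level 0 and `0.0` maps to level 100.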

Next, we set the model to training mode, move it to the CPU, and attach our quantization configuration.

In [None]:
# Set the model to training mode and move to device
model.train()
model.to('cpu')
model.qconfig = qconfig

Now we prepare the model for QAT, which inserts observers and fake-quantization modules as specified by `qconfig`.

In [None]:
# Apply quantization configurations to the model
model = quantization.prepare_qat(model, inplace=False)
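To see what the fake-quantization modules inserted by `prepare_qat` do during training, here is a pure-Python sketch: values are rounded to the nearest integer level, clamped to the 8-bit range, and mapped back to floats. The forward pass therefore sees int8 rounding and clipping error while the weights themselves stay in FP32. `fake_quantize` is a hypothetical helper; PyTorch's `FakeQuantize` operates on tensors and, in the backward pass, typically lets gradients pass through the rounding (a straight-through estimator), which this sketch does not model.

```python
def fake_quantize(values, scale, zero_point, qmin=0, qmax=255):
    """Quantize to integer levels, clamp to [qmin, qmax], then dequantize to floats."""
    out = []
    for x in values:
        q = round(x / scale) + zero_point      # nearest integer level
        q = max(qmin, min(qmax, q))            # saturate outside the representable range
        out.append((q - zero_point) * scale)   # back to float, now carrying rounding error
    return out

print(fake_quantize([0.0, 0.123, 30.0], scale=0.1, zero_point=0))  # [0.0, 0.1, 25.5]
```

The middle value loses precision to rounding and the last one is clipped at the top of the range, which is exactly the kind of error QAT lets the network adapt to during training.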

At this point, we are ready to start training our network!

## 5. Set up and run the training 

### 5.1 Define and run the training regime

Training loops through the layers one at a time until all of them have been trained. To propagate spikes through the network, each trained layer is appended to a list, so that when the next layer is trained, its input spikes are computed by passing images through the previously trained layers using their learned weights.

Run the below cell to train our `feature_layer` and `output_layer`!

In [None]:
# Keep track of trained layers to pass data through them
trained_layers = [] 

# Training each layer
for layer_name, _ in sorted(model.layer_dict.items(), key=lambda item: item[1]):
    print(f"Training layer: {layer_name}")
    # Retrieve the layer object
    layer = getattr(model, layer_name)
    # Train the layer
    model.train_model(train_loader, layer, prev_layers=trained_layers)
    # After training the current layer, add it to the list of trained layers
    trained_layers.append(layer_name)
    
print('All layers trained successfully')
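The layer-by-layer scheme described above can be sketched in plain Python. This is a minimal illustration, not VPRTempo's `train_model`: `ToyLayer`, `forward_through`, `train_layerwise`, and `toy_step` are hypothetical, and "training" here just records inputs and sets a weight. It shows the key property: each layer is trained on the outputs of the layers already trained before it.

```python
from dataclasses import dataclass, field
from typing import Callable, List, Sequence

@dataclass
class ToyLayer:
    """Hypothetical stand-in for a network layer: multiplies its input by a weight."""
    weight: float = 1.0
    seen: List[float] = field(default_factory=list)  # inputs seen during training

    def __call__(self, x: float) -> float:
        return self.weight * x

def forward_through(layers: Sequence[Callable[[float], float]], x: float) -> float:
    """Propagate x through the already-trained layers, in order."""
    for layer in layers:
        x = layer(x)
    return x

def train_layerwise(layers, data, train_step):
    """Train one layer at a time on the output of all previously trained layers."""
    trained = []  # plays the role of `trained_layers` in the tutorial
    for layer in layers:
        for sample in data:
            x = forward_through(trained, sample)
            train_step(layer, x)
        trained.append(layer)  # from now on this layer is only used for forward passes
    return trained

def toy_step(layer: ToyLayer, x: float) -> None:
    layer.seen.append(x)
    layer.weight = 2.0  # pretend "learning" sets the weight

layers = [ToyLayer(), ToyLayer()]
train_layerwise(layers, data=[1.0, 2.0], train_step=toy_step)
print(layers[0].seen, layers[1].seen)  # [1.0, 2.0] [2.0, 4.0]
```

The second layer never sees the raw inputs; it only sees what the trained first layer produces, mirroring how the `output_layer` is trained on spikes from the trained `feature_layer`.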

### 5.2 Convert and save the model

Now that training has completed, we can convert the QAT model to a fully quantized model. During training, scale and zero-point parameters were estimated for the weights and activations of the model, and conversion now applies them to the layers. Once converted, we save the model for use in inference.

In [None]:
# Switch to evaluation mode, then convert the QAT model to a quantized model
model.eval()
model = quantization.convert(model, inplace=False)
# Save the model
model.save_model(os.path.join('../models', model_name))  
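As a rough illustration of where the size reduction comes from, the following NumPy sketch quantizes a hypothetical FP32 weight matrix to int8 with a single symmetric per-tensor scale. Each weight drops from 4 bytes to 1, at the cost of a rounding error of at most half a quantization step. The matrix shape and the `quantize_per_tensor_int8` helper are assumptions for illustration; PyTorch's converted modules use their own packed storage and may use per-channel scales.

```python
import numpy as np

def quantize_per_tensor_int8(w: np.ndarray):
    """Symmetric per-tensor quantization to int8 levels in [-127, 127] (zero point = 0)."""
    max_abs = float(np.abs(w).max())
    scale = max_abs / 127.0 if max_abs > 0 else 1.0
    q = np.clip(np.round(w / scale), -127, 127).astype(np.int8)
    return q, scale

rng = np.random.default_rng(0)
w = rng.standard_normal((256, 128)).astype(np.float32)  # hypothetical layer weights
q, scale = quantize_per_tensor_int8(w)

print(w.nbytes, q.nbytes)  # 131072 32768
max_err = float(np.abs(q.astype(np.float32) * scale - w).max())
print(max_err <= scale / 2 + 1e-5)  # True: error is bounded by half a quantization step
```

Storing one float scale per tensor is negligible next to the weights themselves, so the overall saving approaches 4x for large layers.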

## 6. Run the inference

### 6.1 Re-initialize the model class, convert to quantization, and load the model

Now we re-initialize the VPRTempo model, set it to evaluation mode, and convert it to a quantized model so that its structure matches the saved quantized weights, which we then load.

In [None]:
# Set the model to evaluation mode and set configuration
model = VPRTempo()
model.model_logger()
model.eval()
model.qconfig = qconfig

# Apply quantization configurations to all layers in layer_dict
for layer_name, _ in model.layer_dict.items():
    getattr(model, layer_name).qconfig = qconfig
# Prepare and convert the model to a quantized model
model = quantization.prepare(model, inplace=False)
model = quantization.convert(model, inplace=False)
# Load the model
model.load_model(os.path.join('../models', model_name))

# Retrieve layer names for inference
layer_names = list(model.layer_dict.keys())

### 6.2 Define the inferencing DataLoader

The testing DataLoader differs from the training one mainly in which directories it imports images from (`model.testing_dirs`) and how many images it loads (`model.number_testing_images`).

In [None]:
# Initialize the image transforms and datasets
image_transform = ProcessImage(model.dims, model.patches)
test_dataset = CustomImageDataset(annotations_file=model.dataset_file, 
                                  img_dirs=model.testing_dirs,
                                  transform=image_transform,
                                  skip=model.filter,
                                  max_samples=model.number_testing_images)
# Initialize the data loader
test_loader = DataLoader(test_dataset, 
                         batch_size=1, 
                         shuffle=False,
                         num_workers=8,
                         persistent_workers=True)

### 6.3 Run the model inference

Now we are ready to run inference on the model!

In [None]:
# Use evaluate method for inference accuracy
model.evaluate(model, test_loader, layers=layer_names)
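We do not reproduce `model.evaluate` here. As a hedged illustration of how a top-1 place-matching accuracy is commonly computed, the sketch below assumes each row of `outputs` holds the output-neuron activations for one query image, that output neuron `j` corresponds to reference place `j`, and that query `i`'s true match is place `i` (Nordland traverses are commonly treated as frame-aligned). `top1_accuracy` and its `tolerance` parameter are hypothetical.

```python
from typing import Sequence

def top1_accuracy(outputs: Sequence[Sequence[float]], tolerance: int = 0) -> float:
    """Fraction of queries whose most active output neuron lies within
    `tolerance` places of the ground-truth place (assumed to be the query index)."""
    correct = 0
    for query_idx, scores in enumerate(outputs):
        predicted = max(range(len(scores)), key=lambda j: scores[j])  # strongest neuron
        if abs(predicted - query_idx) <= tolerance:
            correct += 1
    return correct / len(outputs)

outputs = [[0.9, 0.1, 0.0],
           [0.2, 0.7, 0.1],
           [0.6, 0.3, 0.1]]
print(top1_accuracy(outputs))               # 0.6666666666666666
print(top1_accuracy(outputs, tolerance=2))  # 1.0
```

A small tolerance is often used in place recognition because neighbouring frames show nearly the same place.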

## 7. Conclusions

This tutorial covered how to convert the VPRTempo model to use Quantization Aware Training (QAT) to keep the model lightweight. If you compare the FP32 and Int8 versions of the system, you may notice that the model performs equally well at the reduced bit-depth, with the added benefit of a smaller model size.

To read more about QAT and quantization in general, PyTorch provides many useful articles:

- https://pytorch.org/docs/stable/quantization.html
- https://pytorch.org/blog/quantization-in-practice/

The key benefit of this approach is fast training and inference on CPU architectures, which is critical in resource-limited compute scenarios.