
# Quantization-Aware Training (QAT) using Larq for binary (1-bit) quantization with the MNIST dataset

[Larq](https://larq.dev/) is a library for building and training binarized neural networks (BNNs) with TensorFlow and Keras. If you want to perform hardware-aware training of deep models, in which effects such as quantization or noise are taken into account during training, Larq's binarization functionality and its support for training compact models are a good starting point.

Here is an outline of what we will cover:
 1. Installation of Larq and necessary dependencies.
 2. Data preparation (MNIST).
 3. Creation of a base model (without quantization).
 4. Training and evaluation of the base model.
 5. Creation of a quantized model with QAT.
 6. Training and evaluation of the quantized model.
 7. Performance and model size comparison.

This notebook uses [Larq](https://larq.dev/) and the [Keras Sequential API](https://www.tensorflow.org/guide/keras).

Larq's API is built on top of `tf.keras` and is designed to provide an easy-to-use, composable way to design and train BNNs (1-bit) and other types of Quantized Neural Networks (QNNs).

It provides tools specifically designed to aid in BNN development, such as specialized optimizers, training metrics, and profiling tools.

Note that efficient inference with a trained BNN requires an optimized inference engine; Larq provides these for several platforms in [Larq Compute Engine](https://docs.larq.dev/compute-engine).

To create a **Quantized Neural Network (QNN)**, Larq introduces two main components: **[quantized layers](https://docs.larq.dev/larq/api/layers/)** and **[quantizers](https://docs.larq.dev/larq/api/quantizers/)**.

1. **Quantizers**: A quantizer defines two critical aspects:
   - **Transformation of full-precision input to quantized output**: This converts high-precision values (usually 32-bit floating point) to a lower-precision format (e.g., binary or low-bit integer), which reduces memory usage and computational load.
   - **Pseudo-gradient method for backpropagation**: Quantization functions are piecewise constant, so their true gradient is zero almost everywhere and undefined at the jumps. Larq therefore uses an approximate, or "pseudo", gradient in the backward pass, which lets the model keep updating its weights even though the quantized values do not support conventional gradient computation.

2. **Quantized Layers**: These layers use quantizers to handle activations and weights with reduced precision. Each quantized layer requires:
   - **input_quantizer**: Defines how to quantize the incoming activations for the layer. This allows the model to operate on low-precision activations instead of full-precision ones.
   - **kernel_quantizer**: Defines how to quantize the layer’s weights (often referred to as kernels in neural network layers).

   If both `input_quantizer` and `kernel_quantizer` are set to `None`, the layer behaves as a regular full-precision layer, equivalent to its standard TensorFlow/Keras counterpart.

3. **Integration with Models**: These quantized layers can be added to a Keras model just like other layers. Alternatively, you can use them with a custom training loop if you need more control over the training process.

In short, Larq's quantizers reduce precision efficiently, pseudo-gradients keep the model trainable, and the resulting quantized layers integrate seamlessly into standard Keras workflows.
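The division of labor between quantizers and quantized layers can be sketched without any framework. The NumPy example below is a simplified illustration, not Larq's implementation: `sign_quantizer` and `quant_dense_forward` are hypothetical names, only the forward pass is modeled (no pseudo-gradients), and the bias term is omitted.

```python
import numpy as np


def sign_quantizer(x):
    """Forward pass of a binary quantizer: +1 where x >= 0, -1 where x < 0."""
    return np.where(x >= 0, 1.0, -1.0)


def quant_dense_forward(inputs, kernel, input_quantizer=None, kernel_quantizer=None):
    """Toy forward pass of a dense layer with optional quantizers.

    A quantizer set to None leaves its tensor in full precision, so with both
    set to None this reduces to an ordinary matrix multiplication.
    """
    x = inputs if input_quantizer is None else input_quantizer(inputs)
    w = kernel if kernel_quantizer is None else kernel_quantizer(kernel)
    return x @ w


inputs = np.array([[0.5, -0.2, 0.9]])
kernel = np.array([[0.3, -0.7],
                   [-0.1, 0.4],
                   [0.8, -0.05]])

# Both quantizers None: behaves like a regular full-precision dense layer.
full_precision = quant_dense_forward(inputs, kernel)

# Both quantizers set: activations and weights are binarized before the product.
binarized = quant_dense_forward(inputs, kernel,
                                input_quantizer=sign_quantizer,
                                kernel_quantizer=sign_quantizer)
print(full_precision)
print(binarized)
```

With both quantizers left as `None` the result is the ordinary product `inputs @ kernel`; with sign quantizers on both operands every term of the product is ±1, which is what allows BNN inference engines to rely on cheap bitwise operations such as XNOR and popcount.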

## 1. Installation of Larq and necessary dependencies.

In [None]:
!pip install tensorflow==2.10.0
!pip install larq==0.13.1

import tensorflow as tf
import larq as lq

### Download and prepare the MNIST dataset

In [None]:
(train_images, train_labels), (test_images, test_labels) = tf.keras.datasets.mnist.load_data()

train_images = train_images.reshape((60000, 28, 28, 1))
test_images = test_images.reshape((10000, 28, 28, 1))

# Normalize pixel values to be between -1 and 1
train_images, test_images = train_images / 127.5 - 1, test_images / 127.5 - 1
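As a quick check of the normalization above, the affine map x / 127.5 - 1 sends the pixel range [0, 255] onto [-1, 1]: 0 becomes -1, 255 becomes 1, and the midpoint 127.5 becomes 0. The small NumPy example below applies the same expression to a few sample intensities (illustrative values, not taken from MNIST).

```python
import numpy as np

# MNIST pixel intensities are uint8 values in [0, 255].
pixels = np.array([0, 64, 127, 128, 255], dtype=np.uint8)

# Same expression as in the cell above: maps 0 -> -1 and 255 -> +1.
scaled = pixels / 127.5 - 1
print(scaled)
print(scaled.dtype)
```

Dividing a `uint8` array by a Python float yields `float64`; appending `.astype("float32")` would halve the memory taken by the image arrays if that matters in your environment.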

### Create the model

The following will create a simple binarized CNN.

The quantization function
$$
q(x) = \begin{cases}
    -1 & x < 0 \\\
    1 & x \geq 0
\end{cases}
$$
is used in the forward pass to binarize both the activations and the latent full-precision weights. However, its gradient is zero almost everywhere, which would prevent the model from learning.
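The quantization function q(x) is easy to express in NumPy. This is an illustrative sketch with a hypothetical `binarize` helper, not Larq's `ste_sign` code. Note that `np.sign` is not a drop-in replacement, because it maps 0 to 0 whereas q(0) = 1. The finite-difference estimate at the end illustrates why the true gradient is useless for training: it is 0 everywhere except at the step, where it grows without bound as `eps` shrinks.

```python
import numpy as np


def binarize(x):
    """Forward quantization q(x): -1 where x < 0 and +1 where x >= 0."""
    # np.where is used instead of np.sign, which would map 0 to 0.
    return np.where(x >= 0, 1.0, -1.0)


x = np.array([-1.5, -0.3, 0.0, 0.2, 2.0])
q = binarize(x)

# Central finite differences approximate dq/dx: zero away from the step,
# and 2 / (2 * eps) at x = 0, which diverges as eps -> 0.
eps = 1e-3
finite_diff = (binarize(x + eps) - binarize(x - eps)) / (2 * eps)
print(q)
print(finite_diff)
```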

To make the model trainable, the gradient is instead estimated with the Straight-Through Estimator (STE),
which essentially replaces the binarization with a clipped identity on the backward pass:
$$
\frac{\partial q(x)}{\partial x} = \begin{cases}
    1 & \left|x\right| \leq 1 \\\
    0 & \left|x\right| > 1
\end{cases}
$$
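The STE rule above can be written as a one-line backward function. In this NumPy sketch (hypothetical names, not Larq's implementation), `upstream_grad` stands for the gradient of the loss with respect to q(x), and the estimator passes it through unchanged only where |x| ≤ 1.

```python
import numpy as np


def ste_grad(x, upstream_grad, clip_value=1.0):
    """Straight-through estimate of dL/dx for the sign quantizer q(x).

    The true derivative of q is zero almost everywhere, so the STE instead
    multiplies the upstream gradient by 1 where |x| <= clip_value and by 0
    elsewhere (the derivative of a clipped identity).
    """
    return upstream_grad * (np.abs(x) <= clip_value)


x = np.array([-1.5, -0.3, 0.0, 0.2, 2.0])
upstream = np.array([0.1, -0.4, 0.25, 0.5, 0.3])
grad_x = ste_grad(x, upstream)
print(grad_x)
```

Inputs inside [-1, 1] receive the upstream gradient as if q were the identity, while the two inputs outside that range (-1.5 and 2.0) receive no gradient at all.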

In Larq, this is done by setting `input_quantizer="ste_sign"` and `kernel_quantizer="ste_sign"`.
Additionally, the latent full-precision weights are clipped to the range [-1, 1] using `kernel_constraint="weight_clip"`.
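Why clip the latent weights? With the STE, a latent weight whose magnitude exceeds 1 receives zero gradient, so it can drift outside the range and get stuck with a fixed sign. The simulation below, using a hypothetical `train_latent_weight` helper and plain SGD rather than the Adam optimizer used later, illustrates the effect on a single weight. It mimics the idea behind `weight_clip`, not Larq's actual implementation.

```python
import numpy as np


def binarize(w):
    """Forward quantization: +1 for w >= 0, -1 otherwise."""
    return np.where(w >= 0, 1.0, -1.0)


def train_latent_weight(w, grads, lr=0.1, clip=True):
    """Apply SGD to one latent weight using STE pseudo-gradients.

    The STE passes a gradient only while |w| <= 1. If clip is True, the latent
    weight is clipped back to [-1, 1] after every update, which is the idea
    behind a weight-clip constraint.
    """
    for g in grads:
        ste_g = g * (abs(w) <= 1.0)  # clipped-identity pseudo-gradient
        w = w - lr * ste_g
        if clip:
            w = float(np.clip(w, -1.0, 1.0))
    return w


# 10 steps pushing the weight up, then 12 steps pushing it down.
grads = [-1.0] * 10 + [1.0] * 12
w_clipped = train_latent_weight(0.95, grads, clip=True)
w_unclipped = train_latent_weight(0.95, grads, clip=False)
print(w_clipped, binarize(w_clipped))
print(w_unclipped, binarize(w_unclipped))
```

In this run, the clipped weight stays at 1.0 during the upward push and can still flip to -1 once the gradient reverses, whereas the unclipped weight overshoots to about 1.05, receives no further gradient, and keeps its +1 sign permanently.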

In [None]:
# All quantized layers except the first will use the same options
kwargs = dict(input_quantizer="ste_sign",
              kernel_quantizer="ste_sign",
              kernel_constraint="weight_clip")

model = tf.keras.models.Sequential()

# In the first layer we only quantize the weights and not the input
model.add(lq.layers.QuantConv2D(32, (3, 3),
                                kernel_quantizer="ste_sign",
                                kernel_constraint="weight_clip",
                                use_bias=False,
                                input_shape=(28, 28, 1)))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.MaxPooling2D((2, 2)))
model.add(tf.keras.layers.BatchNormalization(scale=False))

model.add(lq.layers.QuantConv2D(64, (3, 3), use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Flatten())

model.add(lq.layers.QuantDense(64, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(lq.layers.QuantDense(10, use_bias=False, **kwargs))
model.add(tf.keras.layers.BatchNormalization(scale=False))
model.add(tf.keras.layers.Activation("softmax"))

Almost all parameters in the network are binarized, i.e., they take values of either -1 or 1; only the batch-normalization parameters remain in full precision. This would make the network extremely fast if it were deployed on dedicated BNN hardware.
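To put "almost all" into numbers, this short pure-Python calculation traces the feature-map sizes and counts the parameters implied by the layer definitions above. The helper names are hypothetical, and the figures are derived by hand from the architecture rather than copied from the output of `lq.models.summary`, which may group them differently.

```python
def conv_out(size, kernel=3):
    """Spatial size after a 'valid' convolution (the Keras default padding)."""
    return size - kernel + 1


def pool_out(size, pool=2):
    """Spatial size after max pooling whose stride equals the pool size."""
    return size // pool


# Trace the feature-map size through the network.
s = pool_out(conv_out(28))  # 28 -> 26 -> 13
s = pool_out(conv_out(s))   # 13 -> 11 -> 5
s = conv_out(s)             # 5 -> 3
flat = s * s * 64           # 3 * 3 * 64 = 576 features after Flatten

# Binarized kernels (all layers use use_bias=False, so there are no biases).
binary_weights = {
    "conv1": 3 * 3 * 1 * 32,
    "conv2": 3 * 3 * 32 * 64,
    "conv3": 3 * 3 * 64 * 64,
    "dense1": flat * 64,
    "dense2": 64 * 10,
}

# BatchNormalization(scale=False) keeps beta plus the moving mean and
# moving variance in full precision: 3 values per channel.
bn_channels = [32, 64, 64, 64, 10]
bn_params = sum(3 * c for c in bn_channels)

n_binary = sum(binary_weights.values())
print(n_binary, bn_params)
print("float32 kernels:", n_binary * 4, "bytes")
print("1-bit kernels:  ", n_binary // 8, "bytes")
```

That is 93,088 binarized kernel weights against 702 full-precision batch-normalization values (234 trainable betas plus 468 moving statistics), so over 99 % of the parameters are binary, and storing the kernels at 1 bit instead of 32 shrinks them by a factor of 32.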

Here is the complete architecture of our model:

In [None]:
lq.models.summary(model)

### Compile and train the model

Note: This may take a few minutes depending on your system.

In [None]:
model.compile(optimizer='adam',
              loss='sparse_categorical_crossentropy',
              metrics=['accuracy'])

model.fit(train_images, train_labels, batch_size=64, epochs=6)

test_loss, test_acc = model.evaluate(test_images, test_labels)
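Because `train_labels` holds integer class indices rather than one-hot vectors, the model is compiled with `sparse_categorical_crossentropy`. The NumPy sketch below uses toy probabilities (not model output) to show what that loss and the `accuracy` metric compute for a small batch; it ignores details such as the numerical clipping Keras applies internally.

```python
import numpy as np

# Toy softmax outputs for a batch of 3 images over 10 classes (illustrative values).
probs = np.full((3, 10), 0.02)
probs[0, 7] = 0.82  # confident and correct (label 7)
probs[1, 2] = 0.82  # confident but wrong (label 1)
probs[2, 0] = 0.82  # confident and correct (label 0)
labels = np.array([7, 1, 0])  # integer class indices, like train_labels

# Sparse categorical cross-entropy: mean of -log(probability of the true class).
true_class_probs = probs[np.arange(len(labels)), labels]
loss = -np.mean(np.log(true_class_probs))

# Accuracy: fraction of samples whose highest-probability class matches the label.
accuracy = np.mean(np.argmax(probs, axis=1) == labels)
print(f"loss={loss:.4f}, accuracy={accuracy:.3f}")
```

The single confident mistake dominates the loss (its true class only gets probability 0.02), while accuracy simply counts 2 correct predictions out of 3.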

### Evaluate the model

In [None]:
print(f"Test accuracy {test_acc * 100:.2f} %")

As you can see, our simple binarized CNN has achieved a test accuracy of around 98 %. Not bad for a few lines of code!

For information on converting Larq models to an optimized format and using or benchmarking them on Android or ARM devices, have a look at [this guide](https://docs.larq.dev/compute-engine/end_to_end/).