# 🌟 Understanding Quantization-Aware Training

Quantization-Aware Training (QAT) is a technique for making deep learning models more robust to quantization. Unlike post-training quantization (PTQ), which converts a trained model to lower precision after training has finished, QAT integrates quantization into the training process itself.

By simulating the effects of quantization during training, QAT allows the model to learn and adapt to the challenges posed by reduced precision, resulting in better performance and less accuracy degradation when deployed in quantized formats. 

## 🔄 Workflow of Quantization-Aware Training:

**Pre-trained Model** → **Quantization Simulation** → **Training with Quantization** → **Quantized Model**

---

## 📖 Detailed Explanation:

1. **Pre-trained Model:** 
   - This is the starting point, typically a model that has been trained using standard floating-point precision. The goal is to prepare this model for quantization through fine-tuning.

2. **Quantization Simulation:** 
   - During this phase, the training process incorporates quantization effects by simulating lower precision (e.g., int8) operations. This is achieved by introducing quantization layers that mimic the behavior of quantized weights and activations.

3. **Training with Quantization:** 
   - The model is trained with quantization-aware layers, enabling it to learn how to minimize the loss and maintain accuracy despite the constraints of reduced precision. This step is crucial for adapting the model to the quantization effects and ensures it remains effective when converted.

4. **Quantized Model:** 
   - At the end of the training process, the model is transformed into a quantized version, optimized to perform well under quantization. The result is a model that can efficiently operate in low-resource environments while maintaining a higher accuracy compared to models trained without awareness of quantization.
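
The quantization simulation in step 2 is often described as a quantize-dequantize round trip: a value is scaled, rounded to an integer, clamped to the integer range, and mapped back to floating point, so that training sees the rounding error. The sketch below illustrates that idea in plain Python. The function name `fake_quantize` and the example scale and zero point are illustrative assumptions, not PyTorch's API.

```python
def fake_quantize(x, scale, zero_point, qmin=0, qmax=255):
    """Quantize x to an integer grid and map it straight back to float.

    Hypothetical helper for illustration only.
    """
    q = round(x / scale) + zero_point      # scale and round to the integer grid
    q = max(qmin, min(qmax, q))            # clamp to the representable range
    return (q - zero_point) * scale        # map back to floating point


original = 0.537
simulated = fake_quantize(original, scale=0.1, zero_point=0)
print("original:", original)
print("after quantize-dequantize:", simulated)
print("rounding error:", original - simulated)
```

Values outside the representable range are clamped, which is why the choice of scale and zero point matters for accuracy.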

---

By implementing **Quantization-Aware Training**, practitioners can achieve more reliable quantized models, effectively balancing the trade-offs between efficiency and accuracy. 🎯


In [1]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

### 📊 Loading the Dataset

In this section, we prepare the dataset that will be critical for training our neural network. We utilize the **MNIST dataset**, a well-known collection of handwritten digits that serves as a benchmark for image classification tasks.

---

🔄 **Reproducibility**: We ensure consistent results across runs by setting a manual seed, which helps in achieving reproducible experiments.

🖼️ **Data Transformations**: The dataset undergoes essential transformations:
- **Tensor Conversion**: We convert images into tensor format to enable compatibility with PyTorch models.
- **Normalization**: Pixel values are shifted and scaled with the MNIST mean (0.1307) and standard deviation (0.3081), so inputs are roughly zero-mean with unit variance, which generally helps training converge.

🚀 **Data Loaders**: We create data loaders for both the training and testing datasets, which facilitate efficient mini-batch processing:
- **Mini-batch Processing**: This technique allows our model to learn from smaller subsets of data, speeding up training and enhancing generalization.

💻 **GPU Utilization**: Lastly, we configure our environment to leverage GPU resources, optimizing computational efficiency when available.

---

This foundational step prepares us for the subsequent stages of model training and quantization, enabling us to effectively implement Quantization Aware Training on our model.
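
As a small worked example of the normalization step, the sketch below applies the same mean and standard deviation to the smallest and largest pixel values that `ToTensor` produces (0.0 and 1.0). The helper name `normalize` is an assumption made for illustration. The resulting range is the one the `QuantStub` observer records later in this notebook (about -0.4242 to 2.8215).

```python
MNIST_MEAN = 0.1307
MNIST_STD = 0.3081


def normalize(pixel):
    """Apply the same shift and scale as transforms.Normalize to one pixel in [0, 1]."""
    return (pixel - MNIST_MEAN) / MNIST_STD


lowest = normalize(0.0)    # a fully black pixel
highest = normalize(1.0)   # a fully white pixel
print(f"normalized input range: [{lowest:.4f}, {highest:.4f}]")
```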


In [2]:
_ = torch.manual_seed(433)

In [3]:
transform = transforms.Compose([
    transforms.ToTensor(), # converting to tensors
    transforms.Normalize((0.1307,), (0.3081,)) # normalize with the MNIST mean and standard deviation
])

# we would be using the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# creating data loaders for mini-batch processing
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# trying to leverage my baby GPU hahahaha ;)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


### 🧠 Simple Neural Network

In this section, we define a simple neural network that incorporates **Quantization Aware Training (QAT)**. This approach allows us to account for quantization effects during the training process, thereby improving the model's performance in a quantized state.

---

🔍 **Network Architecture**: The network comprises several layers:
- **Input Layer**: Flattens the input images (28x28) into a 784-dimensional vector.
- **Hidden Layers**: 
  - **First Hidden Layer**: 50 neurons
  - **Second Hidden Layer**: 80 neurons
  - **Third Hidden Layer**: 30 neurons
- **Output Layer**: 10 neurons corresponding to the digit classes (0-9).

📦 **Quantization Stubs**: 
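- **QuantStub**: Marks the point where the floating-point input enters the quantized part of the network. During training it observes (or fake-quantizes) the input; after conversion it quantizes the input.
- **DeQuantStub**: Marks the point where quantized values are converted back to floating point after passing through the network.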

💡 **Forward Pass**: The forward method includes:
1. **Flattening**: Reshaping the input image to a vector.
2. **Quantization**: Applying the `QuantStub` to the input.
3. **Activation Functions**: Using ReLU activation after each hidden layer to introduce non-linearity.
4. **Final Output**: Applying the `DeQuantStub` before returning the output.

---

### Key Concept:
Unlike Post-Training Quantization (PTQ), where weights are copied and quantized after training, **QAT** involves adding fake quantization layers to the model during training. This allows the model to learn how to compensate for the effects of quantization, resulting in better accuracy once deployed.

This structure sets the stage for effectively implementing QAT and ensuring that our model retains performance even in its quantized form.


In [4]:
class NeuralNetwork(nn.Module):
    def __init__(self, hidden_layer_1 = 50,hidden_layer_2 = 80, hidden_layer_3 = 30):
        super(NeuralNetwork,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_layer_1)
        self.linear2 = nn.Linear(hidden_layer_1, hidden_layer_2)
        self.linear3 = nn.Linear(hidden_layer_2, hidden_layer_3)
        self.linear4 = nn.Linear(hidden_layer_3, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()
        
    def forward(self,img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.relu(self.linear3(x))
        x = self.linear4(x)
        x = self.dequant(x)
        return x

model = NeuralNetwork().to(device)

Unlike PTQ, where the weights are quantized only after training has finished, in QAT we add fake-quantization layers to the model and then train.

In [5]:
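# NOTE: default_qconfig attaches plain MinMaxObserver modules (see the printout below), which
# record value ranges but do not round values during training. To fake-quantize in the forward
# pass, a QAT config such as torch.ao.quantization.get_default_qat_qconfig('fbgemm') is the usual choice.
model.qconfig = torch.ao.quantization.default_qconfig
model.train() # prepare_qat expects the model to be in training mode
model_qat = torch.ao.quantization.prepare_qat(model) # returns a copy with observers inserted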
model_qat

NeuralNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=50, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=50, out_features=80, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=80, out_features=30, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear4): Linear(
    in_features=30, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuant

### ⚙️ Model Training

In this section, we focus on training our neural network using the **MNIST dataset**. The training process is essential for the model to learn and improve its performance on the classification task.

---

🛠️ **Components of Training**:
1. **Optimizer**: We utilize the **Adam** optimizer, known for its efficiency in adjusting learning rates during training.
2. **Loss Function**: The **Cross-Entropy Loss** is employed, suitable for multi-class classification problems like digit recognition.

---

🕒 **Training Loop**:
- **Epochs**: The model is trained for a specified number of epochs, allowing it to learn from the data iteratively.
- **Total Iterations**: We keep track of the total iterations across epochs so that an optional iteration limit can stop training early.

### Training Steps:
1. **Model in Training Mode**: The model is set to training mode at the start of each epoch. This matters for layers such as dropout and batch normalization, which this network does not use.
2. **Loss Calculation**: For each batch of data:
   - Inputs (images) and targets (labels) are moved to the appropriate device (CPU/GPU).
   - The model outputs predictions for the input data.
   - The loss is calculated using the defined loss function.
3. **Backpropagation**: `loss.backward()` computes the gradients, and the optimizer then updates the model parameters to minimize the loss.
4. **Progress Tracking**: The average loss is updated and displayed, providing insight into the model's learning process.

---

💡 **Iteration Limit**: An optional limit on total iterations can be set, enabling early stopping if desired, thus preventing unnecessary training and resource usage.

This structured training approach is crucial for optimizing our neural network and preparing it for effective **Quantization Aware Training (QAT)**.
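
The bookkeeping described above (a running average loss per epoch plus an optional cap on total iterations) can be sketched without any model. In this illustration a plain list of numbers stands in for the per-batch losses of one pass over the data loader; the function name `run_epochs` and the loss values are assumptions, not part of the notebook.

```python
def run_epochs(batch_losses, epochs, total_iterations_limit=None):
    """Mimic the loop's counters; batch_losses stands in for one pass over a loader."""
    total_iterations = 0
    avg_loss = None
    for _ in range(epochs):
        loss_sum = 0.0
        for num_iterations, loss in enumerate(batch_losses, start=1):
            total_iterations += 1
            loss_sum += loss
            avg_loss = loss_sum / num_iterations   # running average within the epoch
            # Stop as soon as the optional limit is reached, even mid-epoch
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return total_iterations, avg_loss
    return total_iterations, avg_loss


print(run_epochs([1.0, 0.5, 0.3], epochs=2))
print(run_epochs([1.0, 0.5, 0.3], epochs=2, total_iterations_limit=4))
```

Note that the average restarts at every epoch, so a limit that falls early in an epoch reports the average of only that epoch's first batches.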


In [6]:
def train(train_loader, model, epochs = 1, total_iterations_limit = None):
    # optimizer and loss function
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    loss_function = nn.CrossEntropyLoss() # since this is a classification problem.

    total_iterations = 0  # Keep track of how many total iterations we've done

    for epoch in range(epochs):
        model.train()

        loss_sum = 0  # Sum of all the losses to calculate the average loss
        num_iterations = 0  # Keep track of the iterations in this epoch
        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')

        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data # 'data' is a batch (x, y), where x is the input (image), and y is the label (digit)
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = model(x.view(-1, 28*28))
            loss = loss_function(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            # If a total iteration limit is set, stop training once the limit is reached
            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return

### 📏 Function to Print the Size of the Model

Understanding the size of our model is crucial for evaluating its deployment feasibility, especially in resource-constrained environments. This section introduces a function that measures the size of the model after training.

---

📝 **Function Overview**:
- **Model Size Calculation**: The function saves the model’s state dictionary temporarily to compute its file size. This provides a direct indication of the model's memory footprint in kilobytes (KB).

#### Key Steps:
1. **Saving Model State**: The model's state dictionary is saved to a temporary file.
2. **Size Retrieval**: The file size is retrieved and printed in kilobytes for clarity.
3. **Cleanup**: The temporary file is deleted to maintain a clean workspace.

---

📂 **Model Loading**:
- If a pre-trained model exists, it is loaded from disk, ensuring we can resume from previous training without starting from scratch.
- If not found, the model undergoes training for a specified number of epochs, followed by saving the newly trained model.

#### Example Output:
During training, the model's performance can be tracked, showcasing the loss over iterations. For instance, you may observe a progress bar indicating the completion of an epoch, reflecting the loss value.

---

This process not only facilitates model evaluation but also prepares us for the next steps in **Quantization Aware Training (QAT)**, ensuring we can optimize the model for efficient deployment.
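
The save, measure, delete pattern from the key steps above can be made a little more defensive. The sketch below is one possible variant, using the standard library's `pickle` as a stand-in for `torch.save` so that it runs without PyTorch. The helper name `size_on_disk_kb` is an assumption, and the `try`/`finally` guarantees the temporary file is removed even if measuring fails.

```python
import os
import pickle
import tempfile


def size_on_disk_kb(obj):
    """Serialize obj to a temporary file, return its size in KB, then clean up."""
    with tempfile.NamedTemporaryFile(delete=False) as f:
        pickle.dump(obj, f)     # stand-in for torch.save(model.state_dict(), f)
        path = f.name
    try:
        return os.path.getsize(path) / 1e3
    finally:
        os.remove(path)         # always delete the temporary file


small = size_on_disk_kb(bytes(1000))
large = size_on_disk_kb(bytes(4000))
print("1000 raw bytes ->", small, "KB on disk")
print("4000 raw bytes ->", large, "KB on disk")
```

The file is slightly larger than the raw payload because the serialization format adds its own header, which is also true of a saved state dictionary.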


In [12]:
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

MODEL_FILENAME = 'simpleNN_qat.pt'

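if Path(MODEL_FILENAME).exists():
    # Load into model_qat: prepare_qat returned a copy, so `model` is a separate, untrained network
    model_qat.load_state_dict(torch.load(MODEL_FILENAME))
    print('Loaded model from disk')
else:
    train(train_loader, model_qat, epochs=1)
    # Save the trained QAT model (weights and observer statistics) to disk
    torch.save(model_qat.state_dict(), MODEL_FILENAME)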

Epoch 1: 100%|████████████████████████████████████████████████████████| 6000/6000 [00:38<00:00, 157.61it/s, loss=0.263]


### 🧪 Time to Test Our Neural Network

After training our model, it's essential to evaluate its performance using a testing phase. This section outlines the process of testing our neural network and calculating its accuracy.

---

🔍 **Testing Overview**:
- **Model Evaluation**: We switch the model to evaluation mode, which is crucial as it disables layers like dropout that are only used during training.
- **No Gradient Calculation**: Using `torch.no_grad()` ensures that we don't compute gradients, thus saving memory and computation time during testing.

#### Testing Steps:
1. **Initialization**: Set counters for correct predictions and total predictions.
2. **Data Iteration**: Iterate over the test dataset using a progress bar to visualize progress.
3. **Model Prediction**:
   - Input test images into the model.
   - Calculate the predicted outputs.
4. **Accuracy Calculation**:
   - Compare predicted labels with actual labels to count correct predictions.
   - Update total predictions accordingly.

---

✅ **Final Accuracy**: At the end of the testing phase, the accuracy is printed, providing insight into how well our model has generalized from the training data.

#### Example Output:
Upon completion of the testing process, you will see the accuracy printed in a user-friendly format, indicating the model's performance.

---

This testing phase is crucial for understanding the model's effectiveness and readiness for deployment, paving the way for further optimizations like quantization.
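
The accuracy calculation in the testing steps comes down to comparing the index of the largest output score with the label. A minimal plain-Python sketch, with made-up scores and a hypothetical helper name `accuracy`:

```python
def accuracy(scores_batch, labels):
    """Fraction of samples whose highest-scoring class equals the label."""
    correct = 0
    for scores, label in zip(scores_batch, labels):
        predicted = max(range(len(scores)), key=lambda i: scores[i])  # argmax
        correct += int(predicted == label)
    return correct / len(labels)


# Three samples, two classes; the third prediction is wrong
scores = [[0.1, 0.9], [0.8, 0.2], [0.3, 0.7]]
labels = [1, 0, 0]
print("accuracy:", accuracy(scores, labels))
```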


In [19]:
def test(model, total_iterations):
    correct,total, iterations = 0,0,0

    model.eval()
    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 8)}')

### 📊 Observer Statistics

In this section, we delve into the statistics of the various layers within our neural network model. Understanding these statistics is crucial for monitoring the performance and effectiveness of our quantization-aware training (QAT).

---


In [14]:
print(f'Check statistics of the various layers')
model_qat

Check statistics of the various layers


NeuralNetwork(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=50, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.5536828637123108, max_val=0.34701329469680786)
    (activation_post_process): MinMaxObserver(min_val=-43.13856506347656, max_val=33.60527420043945)
  )
  (linear2): Linear(
    in_features=50, out_features=80, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.43600958585739136, max_val=0.4428870975971222)
    (activation_post_process): MinMaxObserver(min_val=-31.863323211669922, max_val=27.115581512451172)
  )
  (linear3): Linear(
    in_features=80, out_features=30, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=-0.3675735592842102, max_val=0.3865065574645996)
    (activation_post_process): MinMaxObserver(min_val=-20.646923065185547, max_val=30.78984832763672)
  )
  (linear4): Linear(
    in_features=30, 

---

🔍 **Layer Statistics Overview**:
- Each layer's statistics provide insights into the distribution of activations and weights, essential for evaluating the impact of quantization.
- We use **MinMaxObserver** to track the minimum and maximum values of activations and weights across layers.

#### Layer Breakdown:
- **Quantization Layer**:
  - **QuantStub**: Prepares inputs for quantization and captures activation statistics.
  
- **Linear Layers**:
  - Each linear layer (`linear1`, `linear2`, `linear3`, `linear4`) reports:
    - **Input Features**: The number of features each layer receives.
    - **Output Features**: The number of features each layer produces.
    - **Weight Fake Quantization**: Minimum and maximum observed values to understand the scale of weights.
    - **Activation Post-Process**: Minimum and maximum values of the layer's output before ReLU is applied, which is why the observed minimums are negative.

- **Activation Function**:
  - **ReLU**: Applies the rectified linear unit activation, which is widely used for hidden layers.

---

✅ **Model Structure**: The printed model summary provides a comprehensive overview of the layer statistics, helping identify any potential issues that could arise from quantization.

#### Example Output:
You can observe the statistical summary of each layer in the output printed above, detailing the ranges recorded by the model's components during training.

---

Understanding these statistics aids in fine-tuning our model further, ensuring that it maintains performance after quantization.
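
To make the observer's role concrete, here is a minimal sketch of a running min/max tracker in plain Python. It is not PyTorch's `MinMaxObserver`, only an illustration of the idea. The class name `RunningMinMax` is an assumption, and the initial values mirror the `min_val=inf, max_val=-inf` shown in the printout before training.

```python
import math


class RunningMinMax:
    """Track the smallest and largest value seen across all observed batches."""

    def __init__(self):
        self.min_val = math.inf     # nothing observed yet
        self.max_val = -math.inf

    def observe(self, values):
        self.min_val = min(self.min_val, min(values))
        self.max_val = max(self.max_val, max(values))
        return values               # pass the data through unchanged


observer = RunningMinMax()
observer.observe([0.5, -1.0, 0.2])
observer.observe([2.0, 0.1])
print("observed range:", observer.min_val, "to", observer.max_val)
```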


### 🔍 Quantization of the Model

With our quantization-aware training (QAT) complete, it's time to convert the model to a quantized version using the statistics gathered during training. This process optimizes the model for efficiency without significantly sacrificing accuracy.

#### **Quantization Process:**
- The model is switched to evaluation mode and converted: weights are stored as int8, and each layer receives a scale and zero point derived from the ranges observed during training.

---


In [15]:
model_qat.eval()
model_qat = torch.ao.quantization.convert(model_qat)

print(f'Check statistics of the various layers')
model_qat

Check statistics of the various layers


NeuralNetwork(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=50, scale=0.6042821407318115, zero_point=71, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=50, out_features=80, scale=0.4644008278846741, zero_point=69, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=80, out_features=30, scale=0.40501394867897034, zero_point=51, qscheme=torch.per_tensor_affine)
  (linear4): QuantizedLinear(in_features=30, out_features=10, scale=0.32742488384246826, zero_point=72, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [16]:
# Print the int8 weight matrix of the first layer after quantization
print('Weights after quantization')
print(torch.int_repr(model_qat.linear1.weight()))

Weights after quantization
tensor([[ -3,  -4,  -9,  ...,  -8,   0,  -8],
        [  1,  10,   5,  ...,  -1,   6,  -3],
        [  5,  -7,  -5,  ..., -10,  -7,  -2],
        ...,
        [  1,   7,   5,  ...,   4,   2,   6],
        [ -4,   7,   0,  ...,   4,   3,   5],
        [ -5,   7,   2,  ...,   4,   3,   6]], dtype=torch.int8)


In [17]:
print('Testing the model after quantization')
test(model_qat,None)

Testing the model after quantization


Testing: 100%|████████████████████████████████████████████████████████████████████| 1000/1000 [00:02<00:00, 392.91it/s]

Accuracy: 0.9502





In [18]:
print('Size of the model after quantization')
print_size_of_model(model_qat)

Size of the model after quantization
Size (KB): 52.77




---

**Layer Statistics Post-Quantization:**
After conversion, we check the statistics of the various layers in our quantized model:

- **Quantization Parameters**:
  - Each layer now has a scale and zero point, allowing for efficient computation with reduced precision.
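
As a worked example, the scale and zero point printed for `linear1` can be recomputed from the min/max values its observer recorded. The sketch below uses the standard affine formula with an integer range of 0 to 127. That range is an assumption, chosen because it reproduces the printed values. The helper name `affine_qparams` is also hypothetical.

```python
def affine_qparams(min_val, max_val, qmin=0, qmax=127):
    """Derive (scale, zero_point) for affine quantization from an observed range."""
    # The representable range must include zero
    min_val = min(min_val, 0.0)
    max_val = max(max_val, 0.0)
    scale = (max_val - min_val) / (qmax - qmin)
    scale = max(scale, 1e-12)                      # guard against an all-zero range
    zero_point = qmin - round(min_val / scale)     # integer that represents real 0.0
    zero_point = max(qmin, min(qmax, zero_point))
    return scale, zero_point


# Observer values recorded for linear1's output in this notebook
scale, zero_point = affine_qparams(-43.13856506347656, 33.60527420043945)
print("scale:", scale)
print("zero_point:", zero_point)
```

With these inputs the result agrees with the `scale=0.6042821407318115, zero_point=71` shown for `linear1`, up to float32 rounding.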

#### **Model Summary**:
The model structure includes:
- Quantized layers with `QuantizedLinear`, displaying input and output features along with their respective scales and zero points.
- The activation function remains the same with ReLU.
- DeQuantize layer for converting the quantized outputs back to floating-point format when needed.

---

#### **Model Weights After Quantization:**
To understand the transformation, we also print the weights of the first linear layer after conversion using `torch.int_repr`, which shows the int8 values actually stored in the quantized layer.

---

### ✅ **Model Performance After Quantization**:
After quantization, we tested the model to evaluate its accuracy on the test dataset. Remarkably, the model achieved an accuracy of **95.02%**!

#### **Model Size After Quantization**:
Finally, we measure the size of the quantized model: approximately **52.77 KB**. The float model's size was not measured in this notebook, but the network has 46,070 parameters, so its float32 weights alone would occupy roughly 184 KB. A smaller model is beneficial for deployment in resource-constrained environments.
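
For context on the size figure, the parameter count of this architecture can be worked out by hand. The sketch below does the arithmetic and gives a rough lower bound for each format. It assumes the quantized layers store weights as 1 byte each and keep biases in float32, and it ignores serialization overhead, so the numbers are estimates rather than measured file sizes.

```python
# Layer widths from the notebook: 784 -> 50 -> 80 -> 30 -> 10
layer_sizes = [784, 50, 80, 30, 10]

# One weight per (input, output) pair and one bias per output neuron
weights = sum(n_in * n_out for n_in, n_out in zip(layer_sizes, layer_sizes[1:]))
biases = sum(layer_sizes[1:])

float32_kb = (weights + biases) * 4 / 1e3      # 4 bytes per float32 parameter
int8_kb = (weights * 1 + biases * 4) / 1e3     # assumed: int8 weights, float32 biases

print("weights:", weights, "biases:", biases)
print("estimated float32 payload (KB):", float32_kb)
print("estimated quantized payload (KB):", int8_kb)
```

The gap between the estimated quantized payload and the measured 52.77 KB is plausibly file-format overhead plus the stored scales and zero points.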

---

### 🎉 Conclusion
The quantization process successfully optimized the model, maintaining high accuracy while reducing its size, demonstrating the efficacy of QAT in deep learning model deployment.
