# Quantization
This notebook demonstrates how to apply the quantization techniques from `src.compression.quantization` to a trained model.

## Setup
* Import the necessary packages.
* Load a model.
* Load a dataset.
* Evaluate the model's performance before quantization.

In [None]:
# Automatically reload edited project modules (e.g. `src.*`) before each cell runs,
# so changes to the quantization code are picked up without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
import importlib
import inspect
import sys

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.datasets as datasets
import torchvision.transforms as transforms

# Add thesis package to path (needed before the `src.*` imports below)
sys.path.append("../")
sys.path.append("../src")

import src.general as general
import src.metrics as metrics
import src.evaluation as eval  # NOTE: shadows the builtin eval() in this notebook
import src.compression.quantization as quant
import src.dataset_models as data
import src.plot as plot


In [None]:
# Get device
device = general.get_device()

# Load the dataset
dataset = data.supported_datasets["MNIST"]

In [None]:
model_state = "../models/mnist.pt"
model_class = "models.mnist"  # NOTE(review): not referenced by any cell in this notebook

# Load the model.
# The loaded object is used directly as a model below, so the checkpoint presumably
# holds a fully pickled nn.Module rather than a state_dict. Since PyTorch 2.6,
# torch.load defaults to weights_only=True, which refuses to unpickle such objects;
# opt out explicitly, but only where the keyword exists, so older PyTorch still works.
# SECURITY: weights_only=False runs pickle code -- only load checkpoints you trust.
load_kwargs = {"map_location": torch.device(device)}
if "weights_only" in inspect.signature(torch.load).parameters:
    load_kwargs["weights_only"] = False
model = torch.load(model_state, **load_kwargs)

### Pre-Quantization Evaluation

In [None]:
# Evaluate model performance before quantization
original_results = eval.get_results(model, dataset)
plot.print_results(**original_results)

## Static Quantization
Post-Training Static Quantization (PTQ) pre-quantizes the model weights, just as dynamic quantization (below) does. However, instead of calibrating activations on the fly, it pre-calibrates their clipping range and keeps it fixed (“static”), using validation data for the calibration.

In [None]:
# Post-training static quantization; `dataset` is passed along, presumably as calibration data.
static_quantized_model = quant.static_quantization(model, dataset)

In [None]:
# Toy network used only to show which layer sequences `quant.get_modules_to_fuse`
# reports as fusable. It deliberately gets its own name: rebinding `model` here would
# make the dynamic-quantization cells below quantize and evaluate this untrained
# 3-channel net on MNIST instead of the trained MNIST model loaded above.
# Only its layer structure matters; it is not a working classifier (a Linear placed
# directly after a Conv2d acts on the conv output's last, width, dimension).
fusion_demo_model = nn.Sequential(
    nn.Conv2d(3, 32, 3, padding=1),
    nn.BatchNorm2d(32),
    nn.ReLU(),
    nn.Conv2d(32, 64, 3, padding=1),
    nn.Linear(64, 128),
    nn.ReLU(),
)

demo_modules_to_fuse = quant.get_modules_to_fuse(fusion_demo_model)
demo_modules_to_fuse

In [None]:
# Ask the quantization helper which module groups of whatever network is
# currently bound to `model` can be fused, and print them.
fusable_groups = quant.get_modules_to_fuse(model)
modules_to_fuse = fusable_groups
print(modules_to_fuse)

In [None]:
# Run the project's `general.test` routine on the statically quantized model.
general.test(static_quantized_model, dataset)

In [None]:
# Evaluate model performance after static quantization
static_quantized_results = eval.get_results(static_quantized_model, dataset)
plot.print_before_after_results(original_results, static_quantized_results)

## Dynamic Quantization
Here the model’s weights are quantized ahead of time, while the activations are quantized on the fly (“dynamic”) during inference.

Currently only Linear and Recurrent (LSTM, GRU, RNN) layers are supported for dynamic quantization.

In [None]:
# Dynamic quantization is applied to the model alone -- no dataset is passed.
dynamic_quantized_model = quant.dynamic_quantization(model)

In [None]:
# Metrics of the dynamically quantized model, reported next to the baseline ones.
dynamic_quantized_results = eval.get_results(dynamic_quantized_model, dataset)
plot.print_before_after_results(original_results, dynamic_quantized_results)