# QAT for MLP
### Quantization-aware training (QAT)

In [17]:

import torch
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.models import MLP
from torchvision.utils import make_grid
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline

In [18]:
class QuantizedMLP(nn.Module):
    """Wrap a floating-point model between a QuantStub and a DeQuantStub.

    The stubs mark where tensors enter and leave the quantized region of
    the network: ``quant`` is applied to the input, the wrapped model runs
    next, and ``dequant`` is applied to the result.

    Args:
        model_fp32: The floating-point module to wrap; it is stored as the
            ``model_fp32`` submodule.
    """

    def __init__(self, model_fp32):
        super().__init__()
        # Entry marker: floating point -> quantized (applied to inputs only).
        self.quant = torch.ao.quantization.QuantStub()
        # Exit marker: quantized -> floating point (applied to outputs only).
        self.dequant = torch.ao.quantization.DeQuantStub()
        # The wrapped FP32 network.
        self.model_fp32 = model_fp32

    def forward(self, x):
        """Run ``x`` through quant -> wrapped model -> dequant."""
        quantized_in = self.quant(x)
        model_out = self.model_fp32(quantized_in)
        return self.dequant(model_out)
    
def evaluate_model(model, test_loader, device, criterion=None):
    """Evaluate ``model`` on every batch yielded by ``test_loader``.

    The model is switched to eval mode and moved to ``device``; both side
    effects persist after the call. All forward passes run under
    ``torch.no_grad()`` so that no autograd graph is built while evaluating.

    Args:
        model: Module mapping a batch of inputs to per-class scores of
            shape ``(batch, num_classes)`` (the prediction is the argmax
            over dimension 1).
        test_loader: Iterable of ``(inputs, labels)`` batches that exposes
            a sized ``dataset`` attribute.
        device: Device the model and every batch are moved to.
        criterion: Optional callable ``criterion(outputs, labels)``
            returning a scalar tensor. When ``None`` the loss is reported
            as 0.

    Returns:
        A tuple ``(eval_loss, eval_accuracy)``. ``eval_loss`` is the sum of
        ``batch_loss * batch_size`` divided by ``len(test_loader.dataset)``;
        ``eval_accuracy`` is the number of correct predictions divided by
        the same length (a 0-dim tensor once at least one batch was seen).
    """
    model.eval()
    model.to(device)

    running_loss = 0
    running_corrects = 0

    # Evaluation never calls backward(), so gradient tracking is disabled.
    with torch.no_grad():
        for inputs, labels in test_loader:
            inputs = inputs.to(device)
            labels = labels.to(device)

            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)

            if criterion is not None:
                loss = criterion(outputs, labels).item()
            else:
                loss = 0

            # Weight the batch loss by batch size so that the final division
            # by the dataset length stays correct for a smaller last batch.
            running_loss += loss * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

    num_samples = len(test_loader.dataset)
    eval_loss = running_loss / num_samples
    eval_accuracy = running_corrects / num_samples

    return eval_loss, eval_accuracy

def train_model(model, train_loader, test_loader, device, learning_rate=1e-2, num_epochs=20):
    """Train ``model`` with SGD and cross-entropy loss, evaluating after each epoch.

    After every epoch one line with the train/eval loss and accuracy is
    printed. The model is moved to ``device`` and, because each epoch ends
    with an evaluation, it is left in eval mode when at least one epoch ran.

    Args:
        model: Module mapping a batch of inputs to per-class scores.
        train_loader: Iterable of ``(inputs, labels)`` training batches that
            exposes a sized ``dataset`` attribute.
        test_loader: Loader handed to ``evaluate_model`` after each epoch.
        device: Device the model and every batch are moved to.
        learning_rate: SGD learning rate. The default (1e-2) was not
            carefully tuned.
        num_epochs: Number of passes over ``train_loader``. The default (20)
            was not carefully tuned.

    Returns:
        The same ``model`` object, trained in place.
    """
    criterion = nn.CrossEntropyLoss()

    model.to(device)

    optimizer = optim.SGD(model.parameters(), lr=learning_rate, momentum=0.9, weight_decay=1e-5)

    for epoch in range(num_epochs):

        # Training
        model.train()

        running_loss = 0
        running_corrects = 0

        for inputs, labels in train_loader:

            inputs = inputs.to(device)
            labels = labels.to(device)

            # zero the parameter gradients
            optimizer.zero_grad()

            # forward + backward + optimize
            outputs = model(inputs)
            _, preds = torch.max(outputs, 1)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()

            # Weight the batch loss by batch size so the epoch average stays
            # correct when the last batch is smaller.
            running_loss += loss.item() * inputs.size(0)
            running_corrects += torch.sum(preds == labels)

        train_loss = running_loss / len(train_loader.dataset)
        train_accuracy = running_corrects / len(train_loader.dataset)

        # Evaluation
        model.eval()
        eval_loss, eval_accuracy = evaluate_model(model=model, test_loader=test_loader, device=device, criterion=criterion)

        print("Epoch: {:02d} Train Loss: {:.3f} Train Acc: {:.3f} Eval Loss: {:.3f} Eval Acc: {:.3f}".format(epoch, train_loss, train_accuracy, eval_loss, eval_accuracy))

    return model

In [19]:
# Per-image preprocessing: convert to a tensor, then normalize with
# mean=0.5 and std=0.5.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]
)

# MNIST train/test splits, downloaded into data/ when not already present.
train_dataset = datasets.MNIST(root='data/', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='data/', train=False, transform=transform, download=True)

# Batch loaders: the training split is reshuffled every epoch, the test
# split keeps its order.
train_loader = DataLoader(train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=64, shuffle=False)

In [20]:
# Layer widths handed to MLP: 28 * 28 = 784 input features, four hidden
# widths, and 10 outputs.
parameters = [
    28 * 28,  # input
    512, 256, 128, 64,
    10,  # output
]

# Instantiate the floating-point network.
model = MLP(parameters)

# Module-level loss and Adam optimizer.
# NOTE(review): train_model() creates its own CrossEntropyLoss and SGD
# optimizer, so these two objects are not used by that function.
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# This notebook section trains and evaluates on the CPU.
cpu_device = torch.device("cpu:0")
model.to(cpu_device)

# Wrap the network between quant/dequant stubs.
model_fp32 = QuantizedMLP(model)

# Eval mode is what module fusion would require; no fusion is performed in
# this notebook. Left as the final expression of the cell.
model_fp32.eval()



QuantizedMLP(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (model_fp32): MLP(
    (linears): ModuleList(
      (0): Linear(in_features=784, out_features=512, bias=True)
      (1): Linear(in_features=512, out_features=256, bias=True)
      (2): Linear(in_features=256, out_features=128, bias=True)
      (3): Linear(in_features=128, out_features=64, bias=True)
      (4): Linear(in_features=64, out_features=10, bias=True)
    )
    (relu): ReLU()
    (soft): Softmax(dim=None)
  )
)

In [21]:
# Print the module hierarchy: every top-level child, the children of any
# child whose name contains "model" (one tab deep), and the children of any
# of those whose name contains "linears" (two tabs deep).
for module_name, module in model_fp32.named_children():
    print(module_name)
    if "model" not in module_name:
        continue

    for basic_block_name, basic_block in module.named_children():
        print(f"\t{basic_block_name}")
        if "linears" not in basic_block_name:
            continue

        for block_name, block in basic_block.named_children():
            print(f"\t\t{block_name}")

quant
dequant
model_fp32
	linears
		0
		1
		2
		3
		4
	relu
	soft


In [22]:
# Attach a global qconfig, which describes what kind of observers and
# fake-quantize modules to attach. Use 'x86' for server inference and
# 'qnnpack' for mobile inference; the older 'fbgemm' is still available, but
# 'x86' is the recommended default for server inference. Other choices
# (symmetric vs. asymmetric quantization, MinMax vs. L2Norm calibration) can
# also be specified through the qconfig.
model_fp32.qconfig = torch.ao.quantization.get_default_qat_qconfig('x86')

# No module fusion is done here; where an architecture allows it (e.g.
# [['conv', 'bn', 'relu']]), activations would be fused into the preceding
# layers manually with torch.ao.quantization.fuse_modules at this point.

# Prepare the model for QAT. The model needs to be in train mode for the QAT
# logic to work; prepare_qat then inserts the observers and fake_quants that
# watch weight and activation tensors during training.
model_fp32.train()
model_fp32_prepared = torch.ao.quantization.prepare_qat(model_fp32)

# Run the training loop on the prepared model. Left as the final expression
# of the cell.
train_model(
    model=model_fp32_prepared,
    train_loader=train_loader,
    test_loader=test_loader,
    device=cpu_device,
)





Epoch: 00 Train Loss: 1.001 Train Acc: 0.797 Eval Loss: 0.276 Eval Acc: 0.909
Epoch: 01 Train Loss: 0.199 Train Acc: 0.938 Eval Loss: 0.234 Eval Acc: 0.923
Epoch: 02 Train Loss: 0.137 Train Acc: 0.957 Eval Loss: 0.121 Eval Acc: 0.963
Epoch: 03 Train Loss: 0.104 Train Acc: 0.967 Eval Loss: 0.098 Eval Acc: 0.969
Epoch: 04 Train Loss: 0.086 Train Acc: 0.973 Eval Loss: 0.090 Eval Acc: 0.970
Epoch: 05 Train Loss: 0.070 Train Acc: 0.978 Eval Loss: 0.091 Eval Acc: 0.972
Epoch: 06 Train Loss: 0.060 Train Acc: 0.981 Eval Loss: 0.080 Eval Acc: 0.975
Epoch: 07 Train Loss: 0.054 Train Acc: 0.982 Eval Loss: 0.077 Eval Acc: 0.977
Epoch: 08 Train Loss: 0.046 Train Acc: 0.985 Eval Loss: 0.084 Eval Acc: 0.975
Epoch: 09 Train Loss: 0.042 Train Acc: 0.986 Eval Loss: 0.081 Eval Acc: 0.977
Epoch: 10 Train Loss: 0.037 Train Acc: 0.988 Eval Loss: 0.071 Eval Acc: 0.979
Epoch: 11 Train Loss: 0.031 Train Acc: 0.990 Eval Loss: 0.076 Eval Acc: 0.976
Epoch: 12 Train Loss: 0.029 Train Acc: 0.990 Eval Loss: 0.078 Ev

QuantizedMLP(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0157]), zero_point=tensor([64], dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=127, qscheme=torch.per_tensor_affine, reduce_range=True
      (activation_post_process): MovingAverageMinMaxObserver(min_val=-1.0, max_val=1.0)
    )
  )
  (dequant): DeQuantStub()
  (model_fp32): MLP(
    (linears): ModuleList(
      (0): Linear(
        in_features=784, out_features=512, bias=True
        (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
          fake_quant_enabled=tensor([1]), observer_enabled=tensor([1]), scale=tensor([0.0009, 0.0009, 0.0013, 0.0016, 0.0003, 0.0013, 0.0008, 0.0004, 0.0013,
                  0.0008, 0.0010, 0.0008, 0.0008, 0.0010, 0.0006, 0.0008, 0.0010, 0.0007,
                  0.0007, 0.0012, 0.0003, 0.0012, 0.0005, 0.0013, 0.0003, 0.0010, 0.0009,
                  0.0011

In [23]:
# Convert the observed (QAT-trained) model to a quantized model. The model is
# put into eval mode first. Conversion quantizes the weights, computes and
# stores the scale and bias value to be used with each activation tensor,
# fuses modules where appropriate, and replaces key operators with quantized
# implementations.
model_int8 = torch.ao.quantization.convert(model_fp32_prepared.eval())

# When model_int8 is run, the relevant calculations happen in int8.

In [24]:
# Accuracy of the converted INT8 model on the test loader. No criterion is
# passed, so the returned loss is discarded. An FP32 baseline and a
# model-equivalence assertion are deliberately not run here, since the
# quantized outputs might deviate a lot from the floating-point ones.
_, int8_eval_accuracy = evaluate_model(
    model=model_int8,
    test_loader=test_loader,
    device=cpu_device,
    criterion=None,
)

print("INT8 evaluation accuracy: {:.3f}".format(int8_eval_accuracy))

INT8 evaluation accuracy: 0.982


In [None]:
import os
import torch
import torchvision.datasets as datasets
from torch.quantization import quantize_dynamic
import torchvision.transforms as transforms
from src.models import MLP
from src import utils as u

# Layer widths handed to MLP: 28 * 28 = 784 input features, four hidden
# widths, and 10 outputs.
param = [
    28 * 28,  # input
    512, 256, 128, 64,
    10,  # output
]

# Per-image preprocessing: convert to a tensor, then normalize with
# mean=0.5 and std=0.5.
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.5,), (0.5,)),
])
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
test_dataset = datasets.MNIST(root='data/', train=False, transform=transform, download=True)
test_loader = torch.utils.data.DataLoader(dataset=test_dataset, batch_size=100, shuffle=False)

# Load the pretrained weights. map_location='cpu' lets a checkpoint that was
# saved from a CUDA run load on a CPU-only machine; the float model is moved
# to `device` further down.
modeldict = torch.load('models/mlp.ckpt', map_location='cpu')
model = MLP(param)
model.load_state_dict(modeldict)

# Dynamically quantize the Linear layers to int8. The float `model` keeps
# being used separately below as the baseline.
quantized_model = quantize_dynamic(model, {torch.nn.Linear}, dtype=torch.qint8)

model.to(device)
quantized_model.eval()
model.eval()

# Compare test accuracy: the quantized model is fed the un-moved (CPU)
# batches, the float model is fed batches moved to `device`.
with torch.no_grad():
    correctq = 0
    totalq = 0
    total = 0
    correct = 0
    for images, labels in test_loader:
        images_cuda = images.to(device)
        labels_cuda = labels.to(device)

        # Quantized model.
        outputsq = quantized_model(images)
        _, predictedq = torch.max(outputsq, 1)
        totalq += labels.size(0)
        correctq += (predictedq == labels).sum().item()

        # Float baseline.
        outputs = model(images_cuda)
        _, predicted = torch.max(outputs, 1)
        total += labels_cuda.size(0)
        correct += (predicted == labels_cuda).sum().item()

    print('Accuracy of the quantized model on the test images: {} %'.format(100 * correctq / totalq))
    print('Accuracy of the model on the test images: {} %'.format(100 * correct / total))

# Report the size of the float model, then of the quantized model.
u.print_model_size(model)
u.print_model_size(quantized_model)

Accuracy of the quantized model on the test images: 97.75 %
Accuracy of the model on the test images: 97.75 %
2.30 MB
0.58 MB
