# Post-Training Quantization


In [25]:
import torch
import torchvision.datasets as datasets
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os
import torch.ao.quantization as quantization



In [26]:
# Preprocessing applied to every image: convert to a tensor, then normalise
# with a fixed mean/std (presumably the MNIST statistics -- TODO confirm).
transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize((0.1307,), (0.3081,)),
])

# MNIST train / test splits, downloaded into ./data when missing.
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)

# Shuffled mini-batch loaders (10 samples per batch) over both splits.
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Every model and batch in this notebook is placed on this device.
device = "cpu"

## Modelo Simple

In [27]:
class SimpleNet(nn.Module):
    """Three-layer fully connected classifier for 28x28 inputs.

    Architecture: Linear(784 -> hidden_size_1) -> ReLU ->
    Linear(hidden_size_1 -> hidden_size_2) -> ReLU -> Linear(hidden_size_2 -> 10).

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()

    def forward(self, img):
        """Return the raw class scores (logits) for a batch of images.

        Args:
            img: Tensor whose element count is a multiple of 784; it is
                flattened to ``(-1, 784)`` before the first layer.

        Returns:
            Tensor of shape ``(batch, 10)``.
        """
        # reshape (not view) so non-contiguous inputs, e.g. transposed images,
        # are accepted; for contiguous inputs it behaves exactly like view.
        x = img.reshape(-1, 28*28)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        return x
# Instantiate the float baseline model with default hidden sizes and move it to `device`.
net = SimpleNet().to(device)


In [28]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    """Train ``net`` in place with Adam (lr=0.001) and cross-entropy loss.

    Args:
        train_loader: Iterable yielding ``(inputs, labels)`` batches.
        net: Model to optimise; its parameters are updated in place.
        epochs: Number of passes over ``train_loader``.
        total_iterations_limit: Optional cap on the total number of optimiser
            steps, counted across all epochs. Training returns as soon as the
            cap is reached. ``None`` disables the cap.

    Returns:
        None.
    """
    criterion = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    has_limit = total_iterations_limit is not None
    steps_done = 0  # optimiser steps across all epochs

    for epoch in range(epochs):
        net.train()

        running_loss = 0

        progress = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        # Apply the iteration limit to the progress bar's total
        if has_limit:
            progress.total = total_iterations_limit

        for batch_number, (x, y) in enumerate(progress, start=1):
            steps_done += 1
            x = x.to(device)
            y = y.to(device)

            optimizer.zero_grad()
            loss = criterion(net(x.view(-1, 28*28)), y)

            # Show the mean loss over the batches seen so far in this epoch.
            running_loss += loss.item()
            progress.set_postfix(loss=running_loss / batch_number)

            loss.backward()
            optimizer.step()

            if has_limit and steps_done >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    """Print the serialized size of ``model.state_dict()`` in kilobytes.

    The state dict is saved to a scratch file inside a private temporary
    directory, so nothing in the working directory is created, overwritten or
    deleted, and the scratch file is removed even if saving fails.

    Args:
        model: Object exposing ``state_dict()`` (e.g. an ``nn.Module``).

    Returns:
        None. The size is printed as ``Size (KB): <value>``.
    """
    import tempfile  # local import keeps this helper self-contained

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Keep the original basename in case the serialized size depends on it.
        scratch_path = os.path.join(scratch_dir, "temp_delme.p")
        torch.save(model.state_dict(), scratch_path)
        print('Size (KB):', os.path.getsize(scratch_path)/1e3)

# Checkpoint used to cache the trained float model between notebook runs.
MODEL_FILENAME = 'simplenet_ptq.pt'

if Path(MODEL_FILENAME).exists():
    # map_location places the stored tensors on `device`, so a checkpoint
    # written on another device can still be loaded here.
    net.load_state_dict(torch.load(MODEL_FILENAME, map_location=device))
    print('Loaded model from disk')
else:
    train(train_loader, net, epochs=5)
    # Save the model to disk
    torch.save(net.state_dict(), MODEL_FILENAME)

Loaded model from disk


In [29]:
def test(model: nn.Module, total_iterations: int = None):
    """Evaluate ``model`` on the global ``test_loader`` and print its accuracy.

    Args:
        model: Network to evaluate; it is switched to eval mode.
        total_iterations: Optional maximum number of batches to evaluate.
            ``None`` evaluates the whole loader. At least one batch is always
            processed when the loader is non-empty.

    Returns:
        None. Prints ``Accuracy: <value rounded to 3 decimals>``, or a notice
        when no samples were evaluated.
    """
    correct = 0
    total = 0

    model.eval()

    with torch.no_grad():
        batches = enumerate(tqdm(test_loader, desc='Testing'), start=1)
        for iterations, (x, y) in batches:
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            # Compare the whole batch at once instead of sample by sample.
            predictions = torch.argmax(output, dim=1)
            correct += (predictions == y).sum().item()
            total += y.size(0)
            if total_iterations is not None and iterations >= total_iterations:
                break

    if total == 0:
        # Empty loader: avoid ZeroDivisionError and report it explicitly.
        print('Accuracy: n/a (no samples evaluated)')
        return
    print(f'Accuracy: {round(correct/total, 3)}')

### Observación de los pesos

In [30]:
# Show the first layer's weight matrix and its dtype before quantization
first_layer_weight = net.linear1.weight
print('Weights before quantization')
print(first_layer_weight)
print(first_layer_weight.dtype)

Weights before quantization
Parameter containing:
tensor([[-0.0213, -0.0199, -0.0160,  ..., -0.0369, -0.0306, -0.0087],
        [ 0.0122,  0.0231,  0.0552,  ...,  0.0227,  0.0452,  0.0321],
        [ 0.1020,  0.0692,  0.0715,  ...,  0.0470,  0.0504,  0.1005],
        ...,
        [ 0.0304, -0.0287,  0.0310,  ..., -0.0354,  0.0290,  0.0232],
        [ 0.0377,  0.0404,  0.0560,  ...,  0.0246,  0.0379,  0.0597],
        [ 0.0096,  0.0516,  0.0546,  ...,  0.0365,  0.0424,  0.0619]],
       requires_grad=True)
torch.float32


In [31]:
# Baseline for the size comparison: serialized size of the float model's state_dict.
print('Size of the model before quantization')
print_size_of_model(net)

Size of the model before quantization
Size (KB): 360.559


In [32]:
# Baseline accuracy of the float model on the test loader.
print('Accuracy of the model before quantization: ')
test(net)

Accuracy of the model before quantization: 


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 771.98it/s]

Accuracy: 0.973





# Post-Training Quantized Model

## Insert min-max observers in the model

In [33]:
class QuantizedSimpleNet(nn.Module):
    """``SimpleNet`` variant wrapped in quantization stubs.

    Same three linear layers as the float model, with a ``QuantStub`` applied
    to the flattened input and a ``DeQuantStub`` applied to the final output.
    The linear parameter names match the float model, so its ``state_dict``
    can be loaded directly.

    Args:
        hidden_size_1: Width of the first hidden layer.
        hidden_size_2: Width of the second hidden layer.
    """

    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super().__init__()
        # Use the torch.ao.quantization namespace, like the prepare/convert
        # calls elsewhere in this notebook.
        self.quant = torch.ao.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1)
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2)
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, img):
        """Return class scores for a batch of images.

        Args:
            img: Tensor whose element count is a multiple of 784; it is
                flattened to ``(-1, 784)`` before entering the quant stub.

        Returns:
            Tensor of shape ``(batch, 10)`` produced by the dequant stub.
        """
        # reshape (not view) so non-contiguous inputs are accepted as well.
        x = img.reshape(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

# Build the stub-wrapped network with default hidden sizes directly on `device`.
net_quantized = QuantizedSimpleNet().to(device=device)

In [34]:
# Evaluate the freshly built stub-wrapped network before any trained weights are copied in.
print('Accuracy of the model quantized before quantization: ')
test(net_quantized)

Accuracy of the model quantized before quantization: 


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 798.33it/s]

Accuracy: 0.143





In [35]:

# Start from the trained float weights and switch to eval mode
net_quantized.load_state_dict(net.state_dict())
net_quantized.eval()

# Attach the default quantization configuration to the whole model
net_quantized.qconfig = quantization.default_qconfig
# Insert observers; the result is rebound to the same name
net_quantized = quantization.prepare(net_quantized)
# Bare expression: display the prepared model and its observers
net_quantized

QuantizedSimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

## Calibrar el modelo
Captura estadísticas del set de testing para ajustar las escalas y los zero-points que se utilizarán en la red cuantizada.

In [36]:
# Calibration pass: running the evaluation loop pushes the test loader's batches
# through the prepared model so the inserted observers can record value ranges.
test(net_quantized)

Testing: 100%|██████████| 1000/1000 [00:01<00:00, 733.26it/s]

Accuracy: 0.973





In [37]:
# Display the model again to inspect the min/max values now held by each observer.
net_quantized

QuantizedSimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=-0.4242129623889923, max_val=2.821486711502075)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-79.43531036376953, max_val=48.76089859008789)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (activation_post_process): MinMaxObserver(min_val=-63.34500503540039, max_val=48.45856857299805)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (activation_post_process): MinMaxObserver(min_val=-79.30042266845703, max_val=31.928213119506836)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [38]:
# Convert the prepared (observed) model into its quantized form.
# NOTE(review): the name is rebound to the converted model, so this cell is not
# meant to be re-run without re-running the prepare/calibration cells first.
net_quantized = torch.ao.quantization.convert(net_quantized)

In [39]:
# Display the converted model to inspect each layer's quantization parameters.
print('Check statistics of the various layers')
net_quantized

Check statistics of the various layers


QuantizedSimpleNet(
  (quant): Quantize(scale=tensor([0.0256]), zero_point=tensor([17]), dtype=torch.quint8)
  (linear1): QuantizedLinear(in_features=784, out_features=100, scale=1.0094189643859863, zero_point=79, qscheme=torch.per_tensor_affine)
  (linear2): QuantizedLinear(in_features=100, out_features=100, scale=0.8803430795669556, zero_point=72, qscheme=torch.per_tensor_affine)
  (linear3): QuantizedLinear(in_features=100, out_features=10, scale=0.8758160471916199, zero_point=91, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (dequant): DeQuantize()
)

In [40]:

# Show the integer representation of the first layer's quantized weights
quantized_weight = net_quantized.linear1.weight()
print('Weights after quantization')
print(torch.int_repr(quantized_weight))

Weights after quantization
tensor([[-2, -2, -2,  ..., -4, -3, -1],
        [ 1,  2,  6,  ...,  2,  5,  3],
        [10,  7,  7,  ...,  5,  5, 10],
        ...,
        [ 3, -3,  3,  ..., -4,  3,  2],
        [ 4,  4,  6,  ...,  3,  4,  6],
        [ 1,  5,  6,  ...,  4,  4,  6]], dtype=torch.int8)


In [41]:
# Bare expression so the notebook displays the quantized weight tensor's full repr
# (on a converted layer `weight` is called as a method).
net_quantized.linear1.weight()

tensor([[-0.0195, -0.0195, -0.0195,  ..., -0.0389, -0.0292, -0.0097],
        [ 0.0097,  0.0195,  0.0584,  ...,  0.0195,  0.0487,  0.0292],
        [ 0.0973,  0.0681,  0.0681,  ...,  0.0487,  0.0487,  0.0973],
        ...,
        [ 0.0292, -0.0292,  0.0292,  ..., -0.0389,  0.0292,  0.0195],
        [ 0.0389,  0.0389,  0.0584,  ...,  0.0292,  0.0389,  0.0584],
        [ 0.0097,  0.0487,  0.0584,  ...,  0.0389,  0.0389,  0.0584]],
       size=(100, 784), dtype=torch.qint8,
       quantization_scheme=torch.per_tensor_affine, scale=0.009733841754496098,
       zero_point=0)

In [42]:
# Print the float weights next to the dequantized weights to compare them.
original_weight = net.linear1.weight
dequantized_weight = torch.dequantize(net_quantized.linear1.weight())

print('Original weights: ')
print(original_weight)
print('')
print('Dequantized weights: ')
print(dequantized_weight)
print('')

Original weights: 
Parameter containing:
tensor([[-0.0213, -0.0199, -0.0160,  ..., -0.0369, -0.0306, -0.0087],
        [ 0.0122,  0.0231,  0.0552,  ...,  0.0227,  0.0452,  0.0321],
        [ 0.1020,  0.0692,  0.0715,  ...,  0.0470,  0.0504,  0.1005],
        ...,
        [ 0.0304, -0.0287,  0.0310,  ..., -0.0354,  0.0290,  0.0232],
        [ 0.0377,  0.0404,  0.0560,  ...,  0.0246,  0.0379,  0.0597],
        [ 0.0096,  0.0516,  0.0546,  ...,  0.0365,  0.0424,  0.0619]],
       requires_grad=True)

Dequantized weights: 
tensor([[-0.0195, -0.0195, -0.0195,  ..., -0.0389, -0.0292, -0.0097],
        [ 0.0097,  0.0195,  0.0584,  ...,  0.0195,  0.0487,  0.0292],
        [ 0.0973,  0.0681,  0.0681,  ...,  0.0487,  0.0487,  0.0973],
        ...,
        [ 0.0292, -0.0292,  0.0292,  ..., -0.0389,  0.0292,  0.0195],
        [ 0.0389,  0.0389,  0.0584,  ...,  0.0292,  0.0389,  0.0584],
        [ 0.0097,  0.0487,  0.0584,  ...,  0.0389,  0.0389,  0.0584]])



In [43]:
# Serialized size of the converted model's state_dict, to compare with the float baseline.
print('Size of the model after quantization')
print_size_of_model(net_quantized)

Size of the model after quantization
Size (KB): 94.955


In [44]:
# Evaluate the converted model on the test loader to compare with the float accuracy.
print('Testing the model after quantization')
test(net_quantized)

Testing the model after quantization


Testing: 100%|██████████| 1000/1000 [00:01<00:00, 749.22it/s]

Accuracy: 0.973



