In [61]:
import torch
import torchvision.datasets as datasets 
import torchvision.transforms as transforms
import torch.nn as nn
import matplotlib.pyplot as plt
from tqdm import tqdm
from pathlib import Path
import os

In [62]:
# Make torch deterministic
_ = torch.manual_seed(0)

In [31]:
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

# Load the MNIST dataset
mnist_trainset = datasets.MNIST(root='./data', train=True, download=True, transform=transform)
# Create a dataloader for the training
train_loader = torch.utils.data.DataLoader(mnist_trainset, batch_size=10, shuffle=True)

# Load the MNIST test set
mnist_testset = datasets.MNIST(root='./data', train=False, download=True, transform=transform)
test_loader = torch.utils.data.DataLoader(mnist_testset, batch_size=10, shuffle=True)

# Define the device
device = "cpu"

In [42]:
class VerySimpleNet(nn.Module):
    def __init__(self, hidden_size_1=100, hidden_size_2=100):
        super(VerySimpleNet,self).__init__()
        self.quant = torch.quantization.QuantStub()
        self.linear1 = nn.Linear(28*28, hidden_size_1) 
        self.linear2 = nn.Linear(hidden_size_1, hidden_size_2) 
        self.linear3 = nn.Linear(hidden_size_2, 10)
        self.relu = nn.ReLU()
        self.dequant = torch.quantization.DeQuantStub()

    def forward(self, img):
        x = img.view(-1, 28*28)
        x = self.quant(x)
        x = self.relu(self.linear1(x))
        x = self.relu(self.linear2(x))
        x = self.linear3(x)
        x = self.dequant(x)
        return x

net = VerySimpleNet().to(device)


In [43]:
net.qconfig = torch.ao.quantization.default_qconfig
net.train()
net_quantized = torch.ao.quantization.prepare_qat(net) # Insert observers
net_quantized

VerySimpleNet(
  (quant): QuantStub(
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear1): Linear(
    in_features=784, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear2): Linear(
    in_features=100, out_features=100, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (linear3): Linear(
    in_features=100, out_features=10, bias=True
    (weight_fake_quant): MinMaxObserver(min_val=inf, max_val=-inf)
    (activation_post_process): MinMaxObserver(min_val=inf, max_val=-inf)
  )
  (relu): ReLU()
  (dequant): DeQuantStub()
)

In [34]:
def train(train_loader, net, epochs=5, total_iterations_limit=None):
    cross_el = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(net.parameters(), lr=0.001)

    total_iterations = 0

    for epoch in range(epochs):
        net.train()

        loss_sum = 0
        num_iterations = 0

        data_iterator = tqdm(train_loader, desc=f'Epoch {epoch+1}')
        if total_iterations_limit is not None:
            data_iterator.total = total_iterations_limit
        for data in data_iterator:
            num_iterations += 1
            total_iterations += 1
            x, y = data
            x = x.to(device)
            y = y.to(device)
            optimizer.zero_grad()
            output = net(x.view(-1, 28*28))
            loss = cross_el(output, y)
            loss_sum += loss.item()
            avg_loss = loss_sum / num_iterations
            data_iterator.set_postfix(loss=avg_loss)
            loss.backward()
            optimizer.step()

            if total_iterations_limit is not None and total_iterations >= total_iterations_limit:
                return
            
def print_size_of_model(model):
    torch.save(model.state_dict(), "temp_delme.p")
    print('Size (KB):', os.path.getsize("temp_delme.p")/1e3)
    os.remove('temp_delme.p')

train(train_loader, net_quantized, epochs=1)

Epoch 1: 100%|██████████| 6000/6000 [00:29<00:00, 200.53it/s, loss=0.224]


In [35]:
def test(model: nn.Module, total_iterations: int = None):
    correct = 0
    total = 0

    iterations = 0

    model.eval()

    with torch.no_grad():
        for data in tqdm(test_loader, desc='Testing'):
            x, y = data
            x = x.to(device)
            y = y.to(device)
            output = model(x.view(-1, 784))
            for idx, i in enumerate(output):
                if torch.argmax(i) == y[idx]:
                    correct +=1
                total +=1
            iterations += 1
            if total_iterations is not None and iterations >= total_iterations:
                break
    print(f'Accuracy: {round(correct/total, 3)}')


In [44]:
net_quantized.eval()
net_quantized = torch.ao.quantization.convert(net_quantized)



In [45]:
print_size_of_model(net)
print_size_of_model(net_quantized)

Size (KB): 361.062
Size (KB): 95.394


In [46]:
print('Weights before quantization')
print(torch.int_repr(net_quantized.linear1.weight()))

Weights before quantization
tensor([[  -1,   68, -105,  ...,   78,   13,    7],
        [ -71,  -53,  -37,  ...,  -72,  -21, -107],
        [ -72,   53, -119,  ...,  -73,    4,   29],
        ...,
        [  79,   92,  -31,  ...,  -50,   18, -114],
        [ -77,  -48,   66,  ...,   -4,   -4,   -6],
        [  51,   32,  -19,  ...,  111,  -65,    7]], dtype=torch.int8)


In [51]:
(net.linear1.weight)

Parameter containing:
tensor([[-0.0003,  0.0192, -0.0294,  ...,  0.0219,  0.0037,  0.0021],
        [-0.0198, -0.0150, -0.0104,  ..., -0.0203, -0.0060, -0.0299],
        [-0.0201,  0.0149, -0.0333,  ..., -0.0203,  0.0012,  0.0080],
        ...,
        [ 0.0221,  0.0258, -0.0088,  ..., -0.0141,  0.0051, -0.0318],
        [-0.0217, -0.0136,  0.0185,  ..., -0.0012, -0.0012, -0.0017],
        [ 0.0142,  0.0089, -0.0053,  ...,  0.0311, -0.0181,  0.0020]],
       requires_grad=True)

In [None]:
# Instanciar o modelo
import torch.ao.quantization


net = VerySimpleNet()

# Aplicar quantização dinâmica no modelo
net_quantized = torch.ao.quantization.quantize_dynamic(
    net,  # Modelo a ser quantizado
    {torch.nn.Linear},  # Camadas a serem quantizadas dinamicamente
    dtype=torch.qint8  # Tipo de quantização
)

# Exibir pesos antes e após a quantização
print("Pesos antes da quantização:")
print(net.linear1.weight)

print("\nPesos após a quantização:")
print(net_quantized.linear1.weight())

print_size_of_model(net)
print_size_of_model(net_quantized)

Pesos antes da quantização:
Parameter containing:
tensor([[ 0.0094, -0.0010, -0.0346,  ...,  0.0212,  0.0081,  0.0045],
        [ 0.0155,  0.0012, -0.0284,  ..., -0.0085,  0.0067, -0.0324],
        [-0.0255, -0.0122, -0.0105,  ..., -0.0346, -0.0156, -0.0003],
        ...,
        [-0.0252, -0.0017,  0.0309,  ...,  0.0141, -0.0328,  0.0246],
        [ 0.0121,  0.0155,  0.0333,  ..., -0.0195,  0.0311,  0.0292],
        [ 0.0343, -0.0196, -0.0254,  ..., -0.0004, -0.0308,  0.0105]],
       requires_grad=True)

Pesos após a quantização:
tensor([[ 0.0095, -0.0011, -0.0345,  ...,  0.0213,  0.0081,  0.0045],
        [ 0.0154,  0.0011, -0.0283,  ..., -0.0084,  0.0067, -0.0325],
        [-0.0255, -0.0120, -0.0106,  ..., -0.0347, -0.0157, -0.0003],
        ...,
        [-0.0252, -0.0017,  0.0308,  ...,  0.0140, -0.0328,  0.0246],
        [ 0.0120,  0.0154,  0.0333,  ..., -0.0196,  0.0311,  0.0291],
        [ 0.0342, -0.0196, -0.0255,  ..., -0.0003, -0.0308,  0.0106]],
       size=(100, 784), dtyp

In [60]:
torch.unique(net_quantized)

TypeError: _unique2(): argument 'input' (position 1) must be Tensor, not VerySimpleNet

In [59]:
# Dequantizar os pesos
weights_dequantized = net_quantized.linear1.weight().dequantize()

print("\nPesos após dequantização (retornando ao ponto flutuante):")
print(weights_dequantized)


Pesos após dequantização (retornando ao ponto flutuante):
tensor([[ 0.0095, -0.0011, -0.0345,  ...,  0.0213,  0.0081,  0.0045],
        [ 0.0154,  0.0011, -0.0283,  ..., -0.0084,  0.0067, -0.0325],
        [-0.0255, -0.0120, -0.0106,  ..., -0.0347, -0.0157, -0.0003],
        ...,
        [-0.0252, -0.0017,  0.0308,  ...,  0.0140, -0.0328,  0.0246],
        [ 0.0120,  0.0154,  0.0333,  ..., -0.0196,  0.0311,  0.0291],
        [ 0.0342, -0.0196, -0.0255,  ..., -0.0003, -0.0308,  0.0106]])


In [63]:
list(net_quantized.named_modules())

[('',
  VerySimpleNet(
    (quant): QuantStub()
    (linear1): DynamicQuantizedLinear(in_features=784, out_features=100, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (linear2): DynamicQuantizedLinear(in_features=100, out_features=100, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (linear3): DynamicQuantizedLinear(in_features=100, out_features=10, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
    (relu): ReLU()
    (dequant): DeQuantStub()
  )),
 ('quant', QuantStub()),
 ('linear1',
  DynamicQuantizedLinear(in_features=784, out_features=100, dtype=torch.qint8, qscheme=torch.per_tensor_affine)),
 ('linear1._packed_params',
  (tensor([[ 0.0095, -0.0011, -0.0345,  ...,  0.0213,  0.0081,  0.0045],
          [ 0.0154,  0.0011, -0.0283,  ..., -0.0084,  0.0067, -0.0325],
          [-0.0255, -0.0120, -0.0106,  ..., -0.0347, -0.0157, -0.0003],
          ...,
          [-0.0252, -0.0017,  0.0308,  ...,  0.0140, -0.0328,  0.0246],
          [ 0.0120,  0.0154,  0.0333,  .