In [1]:
import torch
import argparse
from torchsummary import summary
import os
from utils.dataloaders import get_dataloader, get_subnet_dataloader
from utils.train_eval import evaluate
from utils.functions import reconstruction_model
from utils.io import load_weights

from utils.train_eval import get_accuracy
from utils.utils import count_net_flops

from bn_fold import bn_fold, fuse_conv_bn
from torch import nn
from fxpmath import Fxp
import torchvision
from torchvision import transforms

In [2]:
model_ckpt = "./weights/lenet_cifar10_cmsis.pth"
image_size = 32
workers = 4
batch_size = 50
from models import LeNet

In [3]:
model = LeNet(num_channels=3, num_classes=10, model='cmsis')
print(model)
summary(model, (3, 32, 32), device='cpu')

LeNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=1024, out_features=10, bias=True)
)
------------------------

In [4]:
model = load_weights(model, model_ckpt)
print(model)


LeNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=1024, out_features=10, bias=True)
)


In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
flatten_model = reconstruction_model(model, device)
def get_cifar10_loader():
    print('=> loading cifar10 data...')
    normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])
    train_dataset = torchvision.datasets.CIFAR10(
            root='E:/2_Quantization/torch2cmsis/examples/cifar/data/data_cifar10',
            train=True,
            download=True,
            transform=transforms.Compose([
            transforms.RandomCrop(32, padding=4),
            transforms.RandomHorizontalFlip(),
            transforms.ToTensor(),
            normalize,
            ]))
    trainloader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)

    test_dataset = torchvision.datasets.CIFAR10(
            root='E:/2_Quantization/torch2cmsis/examples/cifar/data/data_cifar10',
            train=False,
            download=True,
            transform=transforms.Compose([
            transforms.ToTensor(),
            normalize,
            ]))
    testloader = torch.utils.data.DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return trainloader, testloader
trainloader, testloader = get_cifar10_loader()

=> loading cifar10 data...
Files already downloaded and verified
Files already downloaded and verified


In [6]:
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")    

print(f"Before accuracy: {get_accuracy(model.to(device), testloader):.2f}%",
        f"MAC+BN={count_net_flops(model, (1, 3, image_size, image_size), True):,}")
# model = fuse_conv_bn(model)

                                                       

Before accuracy: 85.03% MAC+BN=11,504,640




In [7]:
model

LeNet(
  (quant): QuantStub()
  (dequant): DeQuantStub()
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3): Conv2d(32, 64, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn3): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=1024, out_features=10, bias=True)
)

In [8]:
from torch.quantization import get_default_qconfig, quantize_dynamic, QConfig, HistogramObserver, prepare, fuse_modules
# Custom quantization configuration 정의
import math
from torch.quantization import MinMaxObserver
import torch


class PowerOfTwoActivationObserver(HistogramObserver):
    def __init__(self, bits=8, *args, **kwargs):
        super(PowerOfTwoActivationObserver, self).__init__(*args, **kwargs)
        self.bits = bits
        self.quant_min = 0
        self.quant_max = 2 ** bits - 1
        self.dtype = torch.qint8  # 추가: dtype을 torch.quint8로 설정
        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0))

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val
        if min_val == max_val:
            scale = 1.0
            zero_point = 0
        else:
            max_val = max(abs(min_val), abs(max_val))
            scale = max_val / ((2 ** self.bits) - 1)
            scale_pow2 = 2 ** torch.floor(torch.log2(scale))
            zero_point = 0

        self.scale.copy_(scale_pow2)
        self.zero_point.copy_(zero_point)
        return self.scale, self.zero_point

    def forward(self, x):
        self.min_val = min(self.min_val, x.min())
        self.max_val = max(self.max_val, x.max())
        return x

    def extra_repr(self):
        return "min_val={}, max_val={}, scale={}, zero_point={}, bits={}".format(
            self.min_val, self.max_val, self.scale, self.zero_point, self.bits
        )
class PowerOfTwoWeightObserver(HistogramObserver):
    def __init__(self, bits=8, *args, **kwargs):
        super(PowerOfTwoWeightObserver, self).__init__(*args, **kwargs)
        self.bits = bits
        self.dtype = torch.qint8  # 추가: dtype을 torch.quint8로 설정

        self.quant_min = -2 ** (bits - 1)
        self.quant_max = 2 ** (bits - 1) - 1
        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0))

    def calculate_qparams(self):
        min_val, max_val = self.min_val, self.max_val
        if min_val == max_val:
            scale = 1.0
            zero_point = 0
        else:
            max_val = max(abs(min_val), abs(max_val))
            scale = max_val / (2 ** (self.bits - 1) - 1)
            scale_pow2 = 2 ** torch.floor(torch.log2(scale))
            zero_point = 0

        self.scale.copy_(scale_pow2)
        self.zero_point.copy_(zero_point)
        return self.scale, self.zero_point

    def forward(self, x):
        self.min_val = min(self.min_val, x.min())
        self.max_val = max(self.max_val, x.max())
        return x

    def extra_repr(self):
        return "min_val={}, max_val={}, scale={}, zero_point={}, bits={}".format(
            self.min_val, self.max_val, self.scale, self.zero_point, self.bits
        )

backend = "x86"
model.qconfig = QConfig(
            activation=PowerOfTwoWeightObserver.with_args(bits=8, 
                                                          qscheme=torch.per_tensor_symmetric,
                                                          dtype=torch.qint8,
                                                          reduce_range=True),
            weight=PowerOfTwoWeightObserver.with_args(bits=8,
                                              qscheme=torch.per_tensor_symmetric,
                                              dtype=torch.qint8, 
                                              reduce_range=True)
            )

# model.qconfig = torch.quantization.QConfig(
#     activation=HistogramObserver.with_args(dtype=torch.qint8, 
#                                            qscheme=torch.per_tensor_symmetric),
#     weight=HistogramObserver.with_args(dtype=torch.qint8,
#                                         qscheme=torch.per_tensor_symmetric)
# )
"""
LeNet(
  (conv1): Conv2d(3, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn1): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu1): ReLU()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): Conv2d(32, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn2): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu2): ReLU()
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3): Conv2d(16, 32, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
  (bn3): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu3): ReLU()
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): Linear(in_features=512, out_features=10, bias=True)
)
"""

fuse_modules(model, [['conv1', 'bn1', 'relu1'],
                     ['conv2', 'bn2', 'relu2'],
                     ['conv3', 'bn3','relu3']], inplace=True)


qmodel = prepare(model, inplace=False)
# model.qconfig = torch.quantization.get_default_qconfig(backend)
# torch.backends.quantized.engine = backend
# model_static_quantized = torch.quantization.prepare(model, inplace=False)
# model_static_quantized = torch.quantization.convert(model_static_quantized, inplace=False)
# model.qconfig




In [9]:
cnt = 0
qmodel = qmodel.to(device)
with torch.inference_mode():
    for img, label in trainloader:
        img = img.to(device)
        label = label.to(device)
        if cnt > 10:
            break
        qmodel(img)
    
qmodel = torch.quantization.convert(qmodel, inplace=True)
qmodel

LeNet(
  (quant): Quantize(scale=tensor([0.0156], device='cuda:0'), zero_point=tensor([0], device='cuda:0'), dtype=torch.qint8)
  (dequant): DeQuantize()
  (conv1): QuantizedConvReLU2d(3, 32, kernel_size=(5, 5), stride=(1, 1), scale=0.125, zero_point=0, padding=(2, 2))
  (bn1): Identity()
  (relu1): Identity()
  (pool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): QuantizedConvReLU2d(32, 32, kernel_size=(5, 5), stride=(1, 1), scale=0.0625, zero_point=0, padding=(2, 2))
  (bn2): Identity()
  (relu2): Identity()
  (pool2): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (conv3): QuantizedConvReLU2d(32, 64, kernel_size=(5, 5), stride=(1, 1), scale=0.0625, zero_point=0, padding=(2, 2))
  (bn3): Identity()
  (relu3): Identity()
  (pool3): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (flatten): Flatten(start_dim=1, end_dim=-1)
  (fc1): QuantizedLinear(in_features=1024, out_features=10, scal

In [10]:
# print(get_accuracy(qmodel, testloader), count_net_flops(model, (1, 3, image_size, image_size)))
print(get_accuracy(qmodel, testloader))

                                                       

82.79000091552734




In [11]:
def print_model_size(mdl):
    torch.save(mdl.state_dict(), "tmp.pt")
    print("%.2f KB" %(os.path.getsize("tmp.pt")/1e3))
    os.remove('tmp.pt')

print_model_size(model)
print_model_size(qmodel)

361.42 KB
96.39 KB
