In [1]:
import torch
import argparse
from torchsummary import summary
import os
from utils.dataloaders import get_dataloader, get_subnet_dataloader
from utils.train_eval import evaluate
from utils.functions import reconstruction_model
from utils.io import load_weights

from utils.train_eval import get_accuracy
from utils.utils import count_net_flops

from bn_fold import bn_fold
from torch import nn
from fxpmath import Fxp
import torchvision
from torchvision import transforms

In [2]:
# Checkpoint to load. The commented-out line is the visual-wake-words variant.
# model_ckpt = "./weights/mcu_vggrepc1_vww.pth"
model_ckpt = "./weights/mcu_vggrepopt_cifar10.pth"
# NOTE(review): machine-specific absolute path; no cell visible in this notebook reads `data_dir`.
data_dir = "E:/1_TinyML/tiny/benchmark/training/visual_wake_words/vw_coco2014_96"
image_size = 32  # input resolution handed to count_net_flops below
workers = 4  # NOTE(review): the loaders built below hard-code num_workers=2 and do not read this
batch_size = 50  # read as a global by get_cifar10_loader()
from models.model_q import MCU_VGGRep, MCU_VGGRepC1

### Model Load

In [3]:
# Build the float model and restore its trained weights. The printed module names
# (STAGE*_CONV / STAGE*_RELU / LINEAR) are the ones the fusion list and hooks refer to later.
model = MCU_VGGRepC1(num_classes=10)
model = load_weights(model, model_ckpt)
print(model)

MCU_VGGRepC1(
  (quant): QuantStub()
  (STAGE0_CONV): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (STAGE0_RELU): ReLU()
  (STAGE1_0_CONV): Conv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (STAGE1_0_RELU): ReLU()
  (STAGE2_0_CONV): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (STAGE2_0_RELU): ReLU()
  (STAGE3_0_CONV): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (STAGE3_0_RELU): ReLU()
  (STAGE4_0_CONV): Conv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
  (STAGE4_0_RELU): ReLU()
  (GAP21): AdaptiveAvgPool2d(output_size=1)
  (FLATTEN22): Flatten(start_dim=1, end_dim=-1)
  (LINEAR): Linear(in_features=128, out_features=10, bias=True)
  (dequant): DeQuantStub()
)


### Dataset Load

In [4]:
# Use the GPU when one is available; `device` is reused by the calibration and hook cells.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# NOTE(review): `flatten_model` is not referenced again in the visible notebook. Whether
# reconstruction_model() also changes `model` itself cannot be seen here — confirm before removing.
flatten_model = reconstruction_model(model, device)
def get_cifar10_loader():
    """Build the CIFAR-10 train and test DataLoaders.

    Reads the notebook-level ``batch_size``. Training images are augmented with a
    padded random crop and a random horizontal flip; both splits are converted to
    tensors and normalized with the same per-channel mean/std.

    Returns:
        (trainloader, testloader): a shuffled loader over the training split and an
        unshuffled loader over the test split.
    """
    print('=> loading cifar10 data...')
    cifar_root = 'E:/2_Quantization/torch2cmsis/examples/cifar/data/data_cifar10'
    normalize = transforms.Normalize(mean=[0.491, 0.482, 0.447], std=[0.247, 0.243, 0.262])

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    test_transform = transforms.Compose([
        transforms.ToTensor(),
        normalize,
    ])

    train_dataset = torchvision.datasets.CIFAR10(
        root=cifar_root, train=True, download=True, transform=train_transform)
    test_dataset = torchvision.datasets.CIFAR10(
        root=cifar_root, train=False, download=True, transform=test_transform)

    trainloader = torch.utils.data.DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True, num_workers=2)
    testloader = torch.utils.data.DataLoader(
        test_dataset, batch_size=batch_size, shuffle=False, num_workers=2)
    return trainloader, testloader
# trainloader feeds calibration; testloader feeds every accuracy measurement below.
trainloader, testloader = get_cifar10_loader()

=> loading cifar10 data...
Files already downloaded and verified
Files already downloaded and verified


### Floating Point Evaluation

In [5]:
# Re-select the device, this time with an explicit index ("cuda:0").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")    

# Float baseline: test accuracy plus the operation count reported by count_net_flops
# for a single image_size x image_size RGB input.
print(f"Before accuracy: {get_accuracy(model.to(device), testloader):.2f}%",
        f"MAC+BN={count_net_flops(model, (1, 3, image_size, image_size), False):,}")


                                                        

Before accuracy: 75.94% MAC+BN=480,640




### Configure Quantization

In [6]:
from torch.quantization import get_default_qconfig, quantize_dynamic, QConfig, HistogramObserver, prepare, fuse_modules
# Define the custom quantization configuration
import math
from torch.quantization import MinMaxObserver
import torch


class PowerOfTwoActivationObserver(HistogramObserver):
    """Observer that produces a symmetric, power-of-two scale over an unsigned code range.

    ``forward`` keeps a running min/max of everything it sees. ``calculate_qparams``
    takes ``max(|min|, |max|) / (2**bits - 1)`` and rounds it *down* to the nearest
    power of two; the zero point is always 0.

    Args:
        bits: quantization bit-width used for the scale computation.
        *args, **kwargs: forwarded unchanged to ``HistogramObserver``.
    """

    def __init__(self, bits=8, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bits = bits
        self.quant_min = 0
        self.quant_max = 2 ** bits - 1
        # Stored dtype is signed qint8 (an earlier comment said quint8, which did not match the code).
        self.dtype = torch.qint8
        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0))

    def calculate_qparams(self):
        """Update and return the ``(scale, zero_point)`` buffers.

        Falls back to ``scale=1.0`` when the observer has not seen usable data
        (never called, constant input, or a non-finite range). The previous code
        raised ``UnboundLocalError`` for constant input and produced an infinite
        scale when uncalibrated.
        """
        min_val, max_val = self.min_val, self.max_val
        max_abs = max(abs(min_val), abs(max_val))
        degenerate = (
            bool(min_val == max_val)
            or not bool(torch.isfinite(max_abs))
            or bool(max_abs == 0)
        )
        if degenerate:
            self.scale.fill_(1.0)
        else:
            scale = max_abs / ((2 ** self.bits) - 1)
            # floor() rounds the scale down, so the largest magnitudes may saturate.
            self.scale.copy_(2 ** torch.floor(torch.log2(scale)))
        self.zero_point.fill_(0)
        return self.scale, self.zero_point

    def forward(self, x):
        """Fold ``x`` into the running min/max and return ``x`` unchanged."""
        if x.numel() == 0:
            # min()/max() are undefined on empty tensors; nothing to record.
            return x
        observed = x.detach()
        self.min_val = min(self.min_val, observed.min())
        self.max_val = max(self.max_val, observed.max())
        return x

    def extra_repr(self):
        return "min_val={}, max_val={}, scale={}, zero_point={}, bits={}".format(
            self.min_val, self.max_val, self.scale, self.zero_point, self.bits
        )
class PowerOfTwoWeightObserver(HistogramObserver):
    """Observer that produces a symmetric, power-of-two scale over a signed code range.

    ``forward`` keeps a running min/max of everything it sees. ``calculate_qparams``
    takes ``max(|min|, |max|) / (2**(bits-1) - 1)`` and rounds it *down* to the nearest
    power of two; the zero point is always 0.

    Args:
        bits: quantization bit-width used for the scale computation.
        *args, **kwargs: forwarded unchanged to ``HistogramObserver``.
    """

    def __init__(self, bits=8, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.bits = bits
        # Stored dtype is signed qint8 (an earlier comment said quint8, which did not match the code).
        self.dtype = torch.qint8

        self.quant_min = -2 ** (bits - 1)
        self.quant_max = 2 ** (bits - 1) - 1
        self.register_buffer('scale', torch.tensor(1.0))
        self.register_buffer('zero_point', torch.tensor(0))

    def calculate_qparams(self):
        """Update and return the ``(scale, zero_point)`` buffers.

        Falls back to ``scale=1.0`` when the observer has not seen usable data
        (never called, constant input, or a non-finite range). The previous code
        raised ``UnboundLocalError`` for constant input and produced an infinite
        scale when uncalibrated.
        """
        min_val, max_val = self.min_val, self.max_val
        max_abs = max(abs(min_val), abs(max_val))
        degenerate = (
            bool(min_val == max_val)
            or not bool(torch.isfinite(max_abs))
            or bool(max_abs == 0)
        )
        if degenerate:
            self.scale.fill_(1.0)
        else:
            scale = max_abs / (2 ** (self.bits - 1) - 1)
            # floor() rounds the scale down, so the largest magnitudes may saturate.
            self.scale.copy_(2 ** torch.floor(torch.log2(scale)))
        self.zero_point.fill_(0)
        return self.scale, self.zero_point

    def forward(self, x):
        """Fold ``x`` into the running min/max and return ``x`` unchanged."""
        if x.numel() == 0:
            # min()/max() are undefined on empty tensors; nothing to record.
            return x
        observed = x.detach()
        self.min_val = min(self.min_val, observed.min())
        self.max_val = max(self.max_val, observed.max())
        return x

    def extra_repr(self):
        return "min_val={}, max_val={}, scale={}, zero_point={}, bits={}".format(
            self.min_val, self.max_val, self.scale, self.zero_point, self.bits
        )

backend = "x86"

# Activations and weights both use the *signed* power-of-two observer
# (PowerOfTwoWeightObserver) with 8 bits, symmetric per-tensor qint8.
_observer_kwargs = dict(bits=8, qscheme=torch.per_tensor_symmetric, dtype=torch.qint8)
model.qconfig = QConfig(
    activation=PowerOfTwoWeightObserver.with_args(**_observer_kwargs),
    weight=PowerOfTwoWeightObserver.with_args(**_observer_kwargs),
)

# Fuse each conv with its ReLU (in place) before observers are inserted.
_stage_prefixes = ['STAGE0', 'STAGE1_0', 'STAGE2_0', 'STAGE3_0', 'STAGE4_0']
_fuse_groups = [[f'{prefix}_CONV', f'{prefix}_RELU'] for prefix in _stage_prefixes]
fuse_modules(model, _fuse_groups, inplace=True)

# prepare() returns an observed copy; `model` keeps the fused float modules.
qmodel = prepare(model, inplace=False)


### Calibration

In [7]:
# Calibration: push a limited number of training batches through the observed model so
# every observer records activation ranges, then convert to the quantized model.
cnt = 0  # left at 0 on purpose: a later hook cell still reads `cnt`
NUM_CALIB_BATCHES = 11  # what the original `cnt > 10` guard was meant to allow

qmodel = qmodel.to(device)
with torch.inference_mode():
    # The original loop never incremented `cnt`, so its guard never fired and the whole
    # training set was consumed. zip() with range() stops after NUM_CALIB_BATCHES batches
    # without fetching an extra one.
    for _batch_idx, (img, label) in zip(range(NUM_CALIB_BATCHES), trainloader):
        img = img.to(device)
        qmodel(img)

qmodel = torch.quantization.convert(qmodel, inplace=True)
qmodel

MCU_VGGRepC1(
  (quant): Quantize(scale=tensor([0.0156], device='cuda:0'), zero_point=tensor([0], device='cuda:0'), dtype=torch.qint8)
  (STAGE0_CONV): QuantizedConvReLU2d(3, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.5, zero_point=0, padding=(1, 1))
  (STAGE0_RELU): Identity()
  (STAGE1_0_CONV): QuantizedConvReLU2d(16, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.5, zero_point=0, padding=(1, 1))
  (STAGE1_0_RELU): Identity()
  (STAGE2_0_CONV): QuantizedConvReLU2d(16, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.5, zero_point=0, padding=(1, 1))
  (STAGE2_0_RELU): Identity()
  (STAGE3_0_CONV): QuantizedConvReLU2d(32, 64, kernel_size=(3, 3), stride=(2, 2), scale=0.25, zero_point=0, padding=(1, 1))
  (STAGE3_0_RELU): Identity()
  (STAGE4_0_CONV): QuantizedConvReLU2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.03125, zero_point=0, padding=(1, 1))
  (STAGE4_0_RELU): Identity()
  (GAP21): AdaptiveAvgPool2d(output_size=1)
  (FLATTEN22): Flatten(start_dim=1, end_dim=-1)
  (LINE

### Quantization Result

In [8]:
get_accuracy(qmodel, testloader)

                                                       

75.25

In [9]:
# Take the image tensor of the first test batch and run it through the quantized model once.
imgs = next(iter(testloader))[0]
print(imgs.shape)
_ = qmodel(imgs.to(device))


torch.Size([50, 3, 32, 32])


In [10]:
# Run the first image of the batch on its own; unsqueeze(0) restores the batch dimension.
img = imgs[0]
img.shape  # no visible effect: only the last expression of a cell is displayed
qmodel(img.unsqueeze(0).to(device))

tensor([[-0.0625, -0.9688, -0.0625,  1.9375, -0.9688,  1.2188,  0.1250, -1.5000,
         -0.9062, -0.6875]], device='cuda:0')

In [11]:
from typing import Tuple
from fxpmath import Fxp
import numpy as np

def qfmt_quanize(x, n_bits=8, signed=True):
    """Quantize a float tensor to Qm.n fixed-point integer codes.

    The integer part gets ``ceil(log2(max|x|))`` bits; what is left of the word
    (minus the sign bit when ``signed``) becomes fractional bits. Codes are
    saturated to the representable range instead of wrapping around.

    Args:
        x: tensor to quantize.
        n_bits: total word length in bits.
        signed: reserve one bit for the sign.

    Returns:
        (x_int, frac_bits): the integer codes and the number of fractional bits as a
        Python int, so that ``x ~= x_int / 2**frac_bits``. Codes are int8 for signed
        words up to 8 bits, uint8 for unsigned words up to 8 bits, int32 otherwise.

    Raises:
        ValueError: if ``x`` contains non-finite values.
    """
    if x.numel() == 0:
        max_abs = torch.zeros((), dtype=torch.float64)
    else:
        max_abs = x.detach().abs().max().to(torch.float64)
    if not bool(torch.isfinite(max_abs)):
        raise ValueError("qfmt_quanize: input contains non-finite values")

    if signed:
        code_min, code_max = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
        magnitude_bits = n_bits - 1
    else:
        code_min, code_max = 0, 2 ** n_bits - 1
        magnitude_bits = n_bits

    # All-zero (or empty) input has no integer part; log2(0) must be avoided.
    int_bits = int(torch.ceil(torch.log2(max_abs)).item()) if bool(max_abs > 0) else 0
    frac_bits = magnitude_bits - int_bits

    # Scale with a Python float: 2 ** <int8 tensor> overflows as soon as frac_bits >= 7.
    # Clamp *before* the narrowing cast so out-of-range codes saturate instead of wrapping.
    codes = torch.clamp(torch.round(x * (2.0 ** frac_bits)), code_min, code_max)

    if n_bits <= 8:
        code_dtype = torch.int8 if signed else torch.uint8
    else:
        code_dtype = torch.int32
    x_int = codes.to(code_dtype)
    return x_int, frac_bits
class HookRecorder:
    """Capture 8-bit fixed-point snapshots of module inputs/outputs via forward hooks.

    For every hooked layer, ``recorder[name]`` holds the integer codes, fractional bit
    count and shape of the first positional input (``input``, ``i_f``, ``input_shape``)
    and of the output (``output``, ``o_f``, ``output_shape``) from the latest forward pass.
    """

    def __init__(self):
        self.recorder = dict() # Get intermediate tensor from the recorder
        self.handlers = list()

    @staticmethod
    def _to_fixed_point(tensor):
        """Return ``(int8 codes, n_frac)`` for ``tensor`` using a signed 8-bit, saturating Fxp.

        The fractional width is whatever Fxp selects for the data, since only the word
        length is passed (presumably fxpmath auto-sizing — confirm against its docs).
        """
        values = tensor.dequantize().detach().cpu().numpy()
        fixed = Fxp(values, signed=True, n_word=8, overflow='saturate')
        n_frac = fixed.n_frac
        codes = torch.tensor(np.array(fixed << fixed.n_frac).astype(np.int8))
        return codes, n_frac

    def _register_hooker(self, name):
        """Create the forward hook that fills ``self.recorder[name]``."""
        self.recorder[name] = dict()

        def named_hooker(module, input: Tuple[torch.Tensor], output: torch.Tensor):
            entry = self.recorder[name]

            x_int, x_frac = self._to_fixed_point(input[0])
            entry["input"] = x_int
            entry["i_f"] = x_frac
            entry['input_shape'] = x_int.shape

            y_int, y_frac = self._to_fixed_point(output)
            entry["output"] = y_int
            entry["o_f"] = y_frac
            entry['output_shape'] = y_int.shape

        return named_hooker

    def register_hookers(self, target_sub_modules, layer_names):
        """Attach one recording hook per ``(module, name)`` pair.

        Every returned handle is stored. Previously the append sat outside the loop, so
        only the last hook could ever be removed and an empty ``layer_names`` raised
        ``NameError``.
        """
        for idx, layer_name in enumerate(layer_names):
            module = target_sub_modules[idx]
            handler = module.register_forward_hook(self._register_hooker(layer_name))
            self.handlers.append(handler)

    def remove_handlers(self):
        """Detach every hook registered through this recorder."""
        for handler in self.handlers:
            handler.remove()
        self.handlers.clear()

    def __del__(self):
        self.remove_handlers()

# Hook the quant stub and every conv/linear layer, then run one image so the recorder fills up.
# The hooks stay attached afterwards (remove_handlers() is commented out), so they also
# fire on every later forward pass of qmodel.
hook = HookRecorder()
hook.register_hookers([qmodel.quant, qmodel.STAGE0_CONV, qmodel.STAGE1_0_CONV, 
                       qmodel.STAGE2_0_CONV, qmodel.STAGE3_0_CONV, 
                       qmodel.STAGE4_0_CONV, qmodel.LINEAR], 
                      ["quant", "STAGE0_CONV", "STAGE1_0_CONV", 
                       "STAGE2_0_CONV", "STAGE3_0_CONV",
                       "STAGE4_0_CONV", "LINEAR"])

qmodel(img.unsqueeze(0).to(device))
# print(hook.recorder)
# hook.remove_handlers()

{'quant': {'input': tensor([[[[ 16,  17,  20,  ...,   5,   0,  -4],
          [ 13,  13,  17,  ...,   5,   0,  -3],
          [ 13,  13,  16,  ...,   7,   2,  -2],
          ...,
          [-29, -42, -47,  ..., -44, -57, -43],
          [-32, -38, -45,  ..., -50, -48, -53],
          [-36, -35, -40,  ..., -51, -46, -52]],

         [[ -5,  -6,  -3,  ..., -14, -16, -19],
          [ -5,  -6,  -4,  ..., -14, -16, -18],
          [ -6,  -7,  -6,  ..., -12, -14, -17],
          ...,
          [  0, -11, -18,  ..., -13, -30, -19],
          [ -3, -10, -19,  ..., -21, -21, -30],
          [ -8,  -9, -17,  ..., -23, -20, -28]],

         [[-31, -32, -30,  ..., -37, -37, -38],
          [-30, -35, -33,  ..., -39, -39, -38],
          [-32, -38, -37,  ..., -38, -38, -38],
          ...,
          [ 30,  16,  11,  ...,  15,  -2,   6],
          [ 25,  16,   8,  ...,   7,   5,  -3],
          [ 22,  16,   8,  ...,   4,   7,  -1]]]], dtype=torch.int8), 'i_f': 5, 'input_shape': torch.Size([1, 3, 32

In [21]:
# Inspect what the hooks captured for the first conv layer: tensor shapes and the
# fractional bit counts chosen for its input (i_f) and output (o_f).
print(hook.recorder.keys())
print(hook.recorder['STAGE0_CONV']['input_shape'])
print(hook.recorder['STAGE0_CONV']['output_shape'])
print(hook.recorder['STAGE0_CONV']['i_f'])
print(hook.recorder['STAGE0_CONV']['o_f'])
print(hook.recorder['STAGE0_CONV'].keys())

dict_keys(['quant', 'STAGE0_CONV', 'STAGE1_0_CONV', 'STAGE2_0_CONV', 'STAGE3_0_CONV', 'STAGE4_0_CONV', 'LINEAR'])
torch.Size([1, 3, 32, 32])
torch.Size([1, 16, 16, 16])
6
1
dict_keys(['input', 'i_f', 'input_shape', 'output', 'o_f', 'output_shape'])


In [None]:
import numpy as np
from fxpmath import Fxp

def simulate(img, qmodel, verbose=False):
    """Exploratory helper: print two views of a slice of the first conv layer's weights.

    Prints the stored integer representation of ``STAGE0_CONV`` weights
    (slice ``[:3, 0, 0, :]``) and, for comparison, the same slice of the dequantized
    weights re-encoded as a signed 8-bit ``Fxp`` and shifted left by its fractional width.

    Args:
        img: image tensor; only used to compute (and optionally print) its integer code.
        qmodel: converted quantized model; moved to the notebook-level ``device``.
        verbose: when True, also print the integer code of ``img``.

    Returns:
        None — there is no active ``return`` statement (the one at the end is commented out).
    """
    # Input scale stored on the converted quant stub.
    quant = qmodel.quant.scale.item()
    # Integer code of the image under that scale: round(img / scale).
    inp = torch.round(img*(1/quant)).int()
    
    """
    (quant): Quantize(scale=tensor([0.0156], device='cuda:0'), zero_point=tensor([0], device='cuda:0'), dtype=torch.qint8)
    (STAGE0_CONV): QuantizedConv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), scale=0.5, zero_point=0, padding=(1, 1))
    (STAGE1_0_CONV): QuantizedConv2d(16, 16, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1))
    (STAGE2_0_CONV): QuantizedConv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), scale=0.5, zero_point=0, padding=(1, 1))
    (STAGE3_0_CONV): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), scale=1.0, zero_point=0, padding=(1, 1))
    (STAGE4_0_CONV): QuantizedConv2d(64, 128, kernel_size=(3, 3), stride=(2, 2), scale=0.0625, zero_point=0, padding=(1, 1))
    (GAP21): AdaptiveAvgPool2d(output_size=1)
    (FLATTEN22): Flatten(start_dim=1, end_dim=-1)
    (LINEAR): QuantizedLinear(in_features=128, out_features=10, scale=0.03125, zero_point=0, qscheme=torch.per_tensor_affine)
    (dequant): DeQuantize()
    """
    if verbose:
        print(inp)
    qmodel.eval()
    with torch.no_grad():
        qmodel = qmodel.to(device)
        
        # Same weight slice, once as stored int8 codes and once as dequantized floats.
        w = qmodel.STAGE0_CONV.weight().int_repr()[:3, 0, 0, :]
        fw = qmodel.STAGE0_CONV.weight().dequantize()[:3, 0, 0, :].detach().cpu().numpy()
        print(w)
        # fw = fw*(1/scale)
        # print(fw)
        # Re-encode the floats as signed 8-bit fixed point and print them shifted by n_frac.
        fw = Fxp(fw, signed=True, n_word=8)
        print(fw<<fw.n_frac)
        # scale_pow2 = 2 ** torch.floor(torch.log2(scale))
        
    #     # print(model.STAGE0_CONV.weight().dequantize()[:, :3, :,:])
    #     print(w, fw.raw())
    #     # print(model.STAGE0_CONV.bias().dequantize()*(1/scale))
    # return pow2s
        
# No visible effect: a bare expression is only displayed when it is the cell's last line.
qmodel  
# NOTE(review): simulate() has no active return statement, so pow2s is always None.
pow2s = simulate(img, qmodel)
# print(pow2s)


In [None]:
# NOTE(review): `a` is not defined anywhere in the visible notebook, so this cell raises
# NameError on Restart & Run All — delete it if it is leftover debugging.
a

In [None]:
# model.qconfig.activation().calculate_qparams()
qmodel

### Deploy

In [None]:

def extra_preprocess(x:torch.Tensor):
    """Return the raw signed 8-bit fixed-point codes of ``x`` as an int8 tensor.

    The tensor is converted to NumPy, wrapped in a signed 8-bit ``Fxp`` and its raw
    integer representation is read back.

    Exercise hint kept from the template: convert the original fp32 input of range
    (0, 1) into int8 format of range (-128, 127).
    """
    from fxpmath import Fxp
    import numpy as np

    fixed = Fxp(x.numpy(), signed=True, n_word=8)
    raw_codes = np.array(fixed.raw(), dtype=np.int8)
    return torch.tensor(raw_codes).to(torch.int8)

In [None]:
# Build a second, non-quantized instance (quant=False) that will receive the dequantized
# weights of qmodel. Both models go to the CPU before their state dicts are read.
plain_model = MCU_VGGRepC1(num_classes=10, quant=False)
qmodel = qmodel.cpu()
plain_model = plain_model.cpu()
quantized_state_dict = qmodel.state_dict()
state_dict = plain_model.state_dict()

def qfmt_quanize(x, n_bits=8, signed=True):
    """Round a float tensor onto a Qm.n fixed-point grid and return the de-quantized values.

    The integer part gets ``ceil(log2(max|x|))`` bits; what is left of the word
    (minus the sign bit when ``signed``) becomes fractional bits. Values outside the
    representable range saturate instead of wrapping around.

    Note: an earlier cell defines a function with the same name that returns the integer
    codes; this definition shadows it and returns floats instead.

    Args:
        x: tensor to quantize.
        n_bits: total word length in bits.
        signed: reserve one bit for the sign.

    Returns:
        (x_float, frac_bits): ``x`` snapped to multiples of ``2**-frac_bits`` and the
        number of fractional bits as a Python int.

    Raises:
        ValueError: if ``x`` contains non-finite values.
    """
    if x.numel() == 0:
        max_abs = torch.zeros((), dtype=torch.float64)
    else:
        max_abs = x.detach().abs().max().to(torch.float64)
    if not bool(torch.isfinite(max_abs)):
        raise ValueError("qfmt_quanize: input contains non-finite values")

    if signed:
        code_min, code_max = -(2 ** (n_bits - 1)), 2 ** (n_bits - 1) - 1
        magnitude_bits = n_bits - 1
    else:
        code_min, code_max = 0, 2 ** n_bits - 1
        magnitude_bits = n_bits

    # All-zero (or empty) input has no integer part; log2(0) must be avoided.
    int_bits = int(torch.ceil(torch.log2(max_abs)).item()) if bool(max_abs > 0) else 0
    frac_bits = magnitude_bits - int_bits

    # Scale with a Python float: 2 ** <int8 tensor> overflows as soon as frac_bits >= 7.
    # Saturate the integer codes; the old code clamped the float result with the wrong
    # bounds after an int8 cast had already wrapped.
    codes = torch.clamp(torch.round(x * (2.0 ** frac_bits)), code_min, code_max)
    x_float = codes * (2.0 ** (-frac_bits))
    return x_float, frac_bits

def input_process(inputs):
    """Apply signed 8-bit ``qfmt_quanize`` to ``inputs`` and return only its first result."""
    quantized, _frac_bits = qfmt_quanize(inputs, 8, True)
    return quantized
    
# Copy weights and biases from the quantized model into the plain float model.
for name, param in state_dict.items():
    if name in quantized_state_dict:
        if "weight" in name or "bias" in name:
            # Fetch the weight/bias tensor of the quantized model
            quantized_param = quantized_state_dict[name]
            
            dequantized_param = quantized_param.dequantize()
            # print(dequantized_param.size(), param.size())
            # Compare sizes along dim 1 (the original comment said "first dimension")
            if dequantized_param.dim()>1 and dequantized_param.size(1) != param.size(1):
                if dequantized_param.size(1) == param.size(1) + 1:
                    
                    # If dim 1 is larger by exactly one, drop one channel.
                    # NOTE(review): the original comment said "remove the first channel",
                    # but [:, :-1] removes the LAST one — confirm which is intended.
                    dequantized_param = dequantized_param[:, :-1]
                else:
                    raise ValueError(f"Unexpected size mismatch in {name}: {dequantized_param.size()} vs {param.size()}")
            # Reshape to the target shape if it still differs
            if dequantized_param.size() != param.size():
                dequantized_param = dequantized_param.view(param.size())
            
            # Copy the weight/bias into the plain model
            param.data.copy_(dequantized_param)
# NOTE(review): only names present in BOTH state dicts are copied; parameters the quantized
# model stores under different keys are silently skipped — verify every layer was transferred.
qmodel = qmodel.to(device)
plain_model = plain_model.to(device)
print(get_accuracy(plain_model, testloader, extra_preprocess=input_process), get_accuracy(qmodel, testloader))


In [None]:
from tqdm import tqdm
import numpy as np
from torch.utils.data import DataLoader
from utils.train_eval import AverageMeter, ProgressMeter, accuracy
import time


@torch.no_grad()
def get_Int_accuracy(model: nn.Module,
                    dataloader: DataLoader,
                    extra_preprocess = None,
                    device:str = 'cuda:0') -> float:
    """Top-1 accuracy (in percent) of ``model`` over every batch of ``dataloader``.

    Args:
        model: network to evaluate; switched to eval mode. It is expected to live on
            ``device`` already (it is not moved here).
        dataloader: iterable of ``(inputs, targets)`` batches.
        extra_preprocess: optional callable, or iterable of callables, applied in order
            to each input batch while it is on the CPU.
        device: device the preprocessed inputs and the targets are moved to.

    Returns:
        Accuracy in the range [0, 100].

    Raises:
        ValueError: if the dataloader yields no samples.
    """
    model.eval()

    if extra_preprocess is None:
        preprocess_steps = []
    elif callable(extra_preprocess):
        preprocess_steps = [extra_preprocess]
    else:
        preprocess_steps = list(extra_preprocess)

    num_samples = 0
    num_correct = 0

    for inputs, targets in tqdm(dataloader, desc="eval", leave=False):
        # Preprocessing runs on the CPU (the steps may go through NumPy).
        inputs = inputs.cpu()
        for preprocess in preprocess_steps:
            inputs = preprocess(inputs)

        inputs = inputs.to(device)
        targets = targets.to(device)

        # Inference and metric update belong INSIDE the loop; they used to be dedented,
        # so only the final batch was ever scored.
        predictions = model(inputs).argmax(dim=1)

        num_samples += targets.size(0)
        num_correct += int((predictions == targets).sum())

    if num_samples == 0:
        raise ValueError("get_Int_accuracy: dataloader yielded no samples")
    return num_correct / num_samples * 100

def evaluate(dataloader, model, criterion, topk=(1, 2)):
    """Evaluate ``model`` on CUDA and report loss plus two top-k accuracies.

    Note: this definition shadows the ``evaluate`` imported from ``utils.train_eval``
    at the top of the notebook.

    Args:
        dataloader: iterable of ``(images, target)`` batches; must support ``len()``.
        model: network to evaluate; moved to CUDA and put in eval mode.
        criterion: loss function called as ``criterion(output, target)``.
        topk: the two ``k`` values to measure. The default ``(1, 2)`` is what the
            original code computed, although it labelled the second meter "Acc@5";
            labels now follow the actual ``k``.

    Returns:
        ``(avg_first, avg_second)``: running averages of the two accuracies.

    Raises:
        ValueError: if ``topk`` does not contain exactly two values.
    """
    if len(topk) != 2:
        raise ValueError(f"topk must contain exactly two values, got {topk!r}")
    k_first, k_second = topk

    batch_time = AverageMeter('Time', ':6.3f')
    losses = AverageMeter('Loss', ':.4e')
    acc_first = AverageMeter(f'Acc@{k_first}', ':6.2f')
    acc_second = AverageMeter(f'Acc@{k_second}', ':6.2f')
    progress = ProgressMeter(
        len(dataloader), [batch_time, losses, acc_first, acc_second], prefix='Test: ')

    # Move the model once instead of on every batch, then switch to evaluate mode.
    model = model.cuda()
    model.eval()

    with torch.no_grad():
        end = time.time()
        for i, (images, target) in enumerate(dataloader):
            images = images.cuda(non_blocking=True)
            target = target.cuda(non_blocking=True)

            # compute output
            output = model(images)
            loss = criterion(output, target)

            # measure accuracy and record loss
            batch_first, batch_second = accuracy(output, target, topk=(k_first, k_second))
            losses.update(loss.item(), images.size(0))
            acc_first.update(batch_first[0], images.size(0))
            acc_second.update(batch_second[0], images.size(0))

            # measure elapsed time
            batch_time.update(time.time() - end)
            end = time.time()

            if i % 50 == 0:
                progress.display(i)

        # TODO: this should also be done with the ProgressMeter
        print(f' * Acc@{k_first} {acc_first.avg:.3f} Acc@{k_second} {acc_second.avg:.3f}')

    return acc_first.avg, acc_second.avg
    
                    


In [None]:
qmodel = qmodel.cuda()
qmodel.STAGE0_CONV

In [None]:
model.qconfig

In [None]:
qmodel

In [None]:
# Scratch probes of individual quantized weights, kept for reference:
# qmodel.STAGE0_CONV.weight().int_repr()[:, :3, :, :]
# print(qmodel.LINEAR.weight().detach().cpu())
# w = qmodel.LINEAR.weight().detach().cpu()
# (1/qmodel.LINEAR.scale)
# print(w.data.int_repr())
import math
from torch.quantization import get_observer_state_dict
# qmodel(next(iter(trainloader))[0])
# w.dequantize()
# NOTE(review): qmodel was already converted in the calibration cell; whether any observer
# state is still attached after conversion is not visible here — confirm this prints
# what you expect.
observer_dict = get_observer_state_dict(qmodel)
print(observer_dict)

In [None]:
qmodel

In [None]:
from torch.ao.nn.quantized.modules.conv import Conv2d
from torch.ao.nn.quantized.modules.linear import Linear
qmodel = qmodel.to(device)
import numpy as np
from collections import defaultdict
def ddict():
    """Return an auto-vivifying dict: missing keys create nested ``ddict`` levels on access."""
    nested = defaultdict(ddict)
    return nested
# GRAPH[layer] collects, for every quantized conv/linear layer of qmodel:
#   scale / zero_point            -> the module's own attributes
#   weight['int']                 -> the quantized weight tensor (still a quantized tensor)
#   weight['float'], bias['float']-> the matching parameters of the float `model`
#   bias['qfloat']                -> the bias stored in the quantized module
GRAPH = ddict()
_float_state = model.state_dict()


def _float_param(layer_name, kind):
    """Return float ``model``'s ``kind`` ('weight'/'bias') for ``layer_name`` as a NumPy array.

    Tries ``<layer>.<kind>`` first, then ``<layer>.0.<kind>``: after fuse_modules(inplace=True)
    a fused Conv+ReLU is a Sequential, so its conv parameters live under the ``.0`` child.
    """
    for key in (f'{layer_name}.{kind}', f'{layer_name}.0.{kind}'):
        if key in _float_state:
            return _float_state[key].detach().cpu().numpy()
    raise KeyError(f'{layer_name}.{kind}')


for name, modules in qmodel.named_modules():
    print(name, type(modules))
    # Conv2d / Linear here are the *quantized* classes imported at the top of this cell;
    # both were handled by duplicated branches before.
    if not isinstance(modules, (Conv2d, Linear)):
        continue

    entry = GRAPH[name]
    entry['scale'] = modules.scale
    entry['zero_point'] = modules.zero_point
    entry['weight']['float'] = _float_param(name, 'weight')
    entry['weight']['int'] = modules.weight().detach().cpu()#.int_repr()

    # bias is a method on quantized modules, so `modules.bias is not None` was always
    # true; the value it returns is what has to be checked.
    q_bias = modules.bias()
    if q_bias is not None:
        entry['bias']['float'] = _float_param(name, 'bias')
        entry['bias']['qfloat'] = q_bias.detach().cpu().numpy()
        entry['bias_scale'] = modules.scale
        entry['bias_zero_point'] = modules.zero_point
# print(qmodel)
# print(get_accuracy(qmodel, val_loader))

In [None]:
from pprint import pprint
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import copy
%matplotlib inline
# Compare the integer codes of the first conv layer's weights against a rescaled version.
# qweight = GRAPH['STAGE0_CONV']['weight']['int'].dequantize()[:,:3,:,:].numpy().reshape(-1)
qweight = GRAPH['STAGE0_CONV']['weight']['int'].int_repr()[:,:3,:,:].numpy().reshape(-1)
# print(qweight.shape)
from fxpmath import Fxp
weight = GRAPH['STAGE0_CONV']['weight']['float'].reshape(-1)
# weight = np.array(Fxp(weight, signed=True, n_word=8, overflow='saturate').raw())

# # print(weight.n_frac)
# NOTE(review): this is the module's `scale` attribute as stored in GRAPH, presumably the
# layer's output scale rather than the weight tensor's own scale — confirm.
scale = GRAPH['STAGE0_CONV']['scale']
zero = GRAPH['STAGE0_CONV']['zero_point']
print(scale, zero)
print(weight.min(), weight.max())
# NOTE(review): this overwrites the float weights loaded above with qweight * 2**scale, so the
# blue series labelled 'Original' below is NOT the float weights. Dequantization would normally
# be (q - zero_point) * scale, as in the commented line — confirm the intent.
weight = (qweight)*(2**scale)
# qweight = (qweight*scale) + zero


plt.figure(figsize=(18,6))
# plt.hist(qweight, bins=100, lpha=0.5, label=f'Scale:{scale}, Zero:{zero}')
# plt.hist(weight, bins=100, color='b', alpha=0.5, label=f'Original')
# plt.hist(qweight, bins=128, color='red')
# plt.hist(weight, bins=128, alpha=0.5, color='blue')
# Red: integer codes; blue: the rescaled values computed above.
plt.scatter(np.arange(len(qweight)), qweight, c='r', label=f'Scale:{scale}, Zero:{zero}')
plt.scatter(np.arange(len(weight)), weight, c='b', label=f'Original')
# plt.legend()
plt.show()
plt.close()

# pprint(GRAPH['LINEAR'])
# test_qint = (GRAPH['STAGE0_CONV']['weight']['int'] - GRAPH['STAGE0_CONV']['zero_point'])*(GRAPH['STAGE0_CONV']['scale'])
# print(test_qint)
# test_qint2 = (GRAPH['STAGE0_CONV']['weight']['float'])
# test_qint[:,:3, :,:] - test_qint2
# print(GRAPH['STAGE0_CONV']['bias']['float'], GRAPH['STAGE0_CONV']['bias']['qfloat'])


In [None]:
def print_model_size(mdl):
    """Print the serialized size of ``mdl.state_dict()`` in megabytes (1 MB = 1e6 bytes).

    The state dict is written into a private temporary directory that is removed
    afterwards. An existing ``tmp.pt`` in the working directory is therefore never
    overwritten or deleted, and nothing is left behind if saving fails.
    """
    import tempfile  # local import: only this helper needs it

    with tempfile.TemporaryDirectory() as scratch_dir:
        scratch_path = os.path.join(scratch_dir, "tmp.pt")
        torch.save(mdl.state_dict(), scratch_path)
        print("%.2f MB" %(os.path.getsize(scratch_path)/1e6))

print_model_size(model)
print_model_size(qmodel)

In [None]:
from torch.ao.nn.quantized.modules.conv import Conv2d as QConv2d
from torch.ao.nn.quantized.modules.linear import Linear as QLinear
from torch import nn
def hook_save_params(module, input, output):
    """Forward hook that stores the latest input/output of ``module`` on the module itself.

    Sets four attributes:
        input_shape:  shape of the first positional input (batch dimension included)
        output_shape: shape of the output (batch dimension included; this used to be the
                      shape of the first sample only, inconsistent with ``input_shape``)
        input:        first sample of the first positional input
        output:       first sample of the output

    Returns nothing, so the module's output is left untouched.
    """
    first_input = input[0]  # hooks receive the positional inputs as a tuple
    module.input_shape = first_input.shape
    module.output_shape = output.shape
    module.input = first_input[0]
    module.output = output[0]


def register_hooks(model:nn.Module):
    """Attach ``hook_save_params`` to every conv, linear and pooling submodule of ``model``.

    Both the float and the quantized conv/linear classes are matched. Handles are not
    kept, so calling this twice registers the hook twice.
    """
    hookable_types = (QConv2d, QLinear, nn.Conv2d, nn.Linear,
                      nn.MaxPool2d, nn.AvgPool2d, nn.AdaptiveAvgPool2d)
    for _name, submodule in model.named_modules():
        if not isinstance(submodule, hookable_types):
            continue
        submodule.register_forward_hook(hook_save_params)


In [None]:
register_hooks(qmodel)

In [None]:
# Run a few test batches so the forward hooks record each layer's latest input/output.
# The loop used to test a `cnt` left over from the calibration cell that was never
# incremented, so the batch limit never took effect; it now carries its own limit.
NUM_HOOK_BATCHES = 11  # what the original `cnt > 10` guard was meant to allow

qmodel = qmodel.to(device)
with torch.inference_mode():
    for _batch_idx, (img, _label) in zip(range(NUM_HOOK_BATCHES), testloader):
        img = img.to(device)
        qmodel(img)

In [None]:
from collections import OrderedDict
import numpy as np
import matplotlib.pyplot as plt
import copy
%matplotlib inline
# Quantize every conv/linear weight and bias of the float `model` to signed 8-bit fixed
# point with fxpmath, write the rounded values into a CPU copy (`q_model`), and remember
# the chosen bit split in `q_param`.
# NOTE(review): q_param stores a (n_int, n_frac) tuple for conv weights but a bare n_frac
# int for conv biases and all linear entries — `(n_frac)` is not a tuple.
q_param = OrderedDict()
q_model = copy.deepcopy(model).cpu()
for name, modules in model.named_modules():
    # Template Fxp (signed, 8-bit word, saturating) that every tensor below is created "like".
    # It is rebuilt on each iteration although it does not depend on the module.
    fxp_ref = Fxp(None, signed=True, n_word=8, overflow='saturate')
    fxp_ref.config.dtype_notation = 'Q'
    fxp_ref.config.op_method = 'repr'
    fxp_ref.config.op_out = Fxp(None, True, n_word=8, overflow='saturate')
    fxp_ref.config.array_output_type = 'array'
    if isinstance(modules, nn.Conv2d):
        weight = modules.weight.detach().cpu().numpy()
        q_weight = Fxp(weight, like = fxp_ref)
        n_int = q_weight.n_int
        n_frac = q_weight.n_frac
        print(f'{name} n_frac: {n_frac}')
        q_weight = q_weight << n_frac # Integer convert
        q_weight = q_weight >> n_frac # Fixed point convert
        
        # q_weight = q_weight.ravel()
        # weight = weight.ravel()
        # plt.figure(figsize=(18,6))
        # plt.scatter(np.arange(len(q_weight)), q_weight, c='r', label=f'Q{n_int}.{n_frac}')
        # plt.scatter(np.arange(len(weight)), weight, c='b', label=f'Original')
        # plt.legend()
        # plt.show()
        # Signed sum of the rounding error (printed again, with a label, a few lines below).
        print((q_weight - weight).sum())
        # break
        # state_dict() tensors are copied into in place to update q_model's parameters.
        q_model.state_dict()[f'{name}.weight'].copy_(torch.Tensor((np.array(q_weight))))
        # print(q_model.state_dict()[f'{name}.weight'])
        q_param[f'{name}.weight'] = (n_int, n_frac)
        # print(torch.IntTensor(np.array(q_weight)))
        
        # print(q_model.state_dict()[f'{name}.weight'])
        # model[name].weight = torch.Tensor(np.array(q_weight))
        print(f'{name}.weight error: {(q_weight - weight).sum()}')
        if modules.bias is not None:
            bias = modules.bias.detach().cpu().numpy()
            q_bias = Fxp(bias, like = fxp_ref)
            n_int = q_bias.n_int
            n_frac = q_bias.n_frac
            q_bias = q_bias << n_frac
            q_bias = q_bias >> n_frac
            
            q_model.state_dict()[f'{name}.bias'].copy_(torch.Tensor(np.array(q_bias)))
            q_param[f'{name}.bias'] = (n_frac)
            print(f'{name}.bias error: {(q_bias - bias).sum()}')
            
        
    # Same procedure for linear layers.
    elif isinstance(modules, nn.Linear):
        weight = modules.weight.detach().cpu().numpy()
        q_weight = Fxp(weight, like = fxp_ref)
        n_int = q_weight.n_int
        n_frac = q_weight.n_frac
        q_weight = q_weight << n_frac
        q_weight = q_weight >> n_frac
        
        q_model.state_dict()[f'{name}.weight'].copy_(torch.Tensor(np.array(q_weight)))
        q_param[f'{name}.weight'] = (n_frac)
        print(f'{name}.weight error: {(q_weight - weight).sum()}')
        if modules.bias is not None:
            bias = modules.bias.detach().cpu().numpy()
            
            q_bias = Fxp(bias, like = fxp_ref)
            n_int = q_bias.n_int
            n_frac = q_bias.n_frac
            
            
            q_bias = q_bias << n_frac
            q_bias = q_bias >> n_frac
            q_model.state_dict()[f'{name}.bias'].copy_(torch.Tensor(np.array(q_bias)))
            q_param[f'{name}.bias'] = (n_frac)
            print(f'{name}.bias error: {(q_bias - bias).sum()}')


In [None]:
from fxpmath import Fxp
# For each conv/linear layer print its scale (also re-encoded as signed 8-bit Fxp) and the
# shapes captured by hook_save_params; pooling layers only get the shapes. The
# input_shape/output_shape/input/output attributes exist only after a forward pass has
# run with the hooks registered.
# NOTE(review): `.scale` is read on every matched module; plain nn.Conv2d / nn.Linear are
# matched too and presumably lack that attribute — fine as long as qmodel holds none.
for name, module in qmodel.named_modules():
    if isinstance(module, (QConv2d, QLinear, nn.Conv2d, nn.Linear)):
        print(module.scale)
        fxp_scale = Fxp(module.scale, signed=True, n_word=8, overflow='saturate')
        print(f'{name} scale: {fxp_scale.n_frac}, INT8={fxp_scale<<fxp_scale.n_frac}')
        print(name, module.input_shape, module.output_shape, module.input.shape, module.output.shape)
    if isinstance(module, (nn.AdaptiveAvgPool2d, nn.AvgPool2d, nn.MaxPool2d)):
        print(name, module.input_shape, module.output_shape, module.input.shape, module.output.shape)

In [None]:
qmodel

In [None]:
qmodel.STAGE0_CONV