In [None]:
#使用pytorch做训练后量化
import os
import sys
import time
import numpy as np

import torch
from torch.ao.quantization import get_default_qconfig, QConfigMapping
from torch.ao.quantization.quantize_fx import prepare_fx, convert_fx, fuse_fx
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
from torchvision import datasets
from torchvision.models.resnet import resnet18
import torchvision.transforms as transforms
from torch.quantization import MinMaxObserver
# Set up warnings
import warnings
warnings.filterwarnings(
    action='ignore',
    category=DeprecationWarning,
    module=r'.*'
)
warnings.filterwarnings(
    action='default',
    module=r'torch.ao.quantization'
)

# Specify random seed for repeatable results
_ = torch.manual_seed(191009)


class AverageMeter(object):
    """Track the most recent value and the running weighted average of a metric.

    Attributes (all reset by ``reset``):
        val: the value passed to the latest ``update`` call.
        sum: accumulated ``val * n`` over all updates.
        count: accumulated ``n`` over all updates.
        avg: ``sum / count``, or 0 while ``count`` is 0.
    """
    def __init__(self, name, fmt=':f'):
        # name: label shown by __str__; fmt: format spec applied to val and avg
        self.name = name
        self.fmt = fmt
        self.reset()

    def reset(self):
        """Clear all accumulated statistics."""
        self.val = 0
        self.avg = 0
        self.sum = 0
        self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed with weight ``n`` (e.g. the batch size)."""
        self.val = val
        self.sum += val * n
        self.count += n
        # Guard the division: a zero-weight update on an empty meter would
        # otherwise raise ZeroDivisionError.
        self.avg = self.sum / self.count if self.count else 0

    def __str__(self):
        """Render as ``'<name> <val> (<avg>)'`` using the configured format spec."""
        fmtstr = '{name} {val' + self.fmt + '} ({avg' + self.fmt + '})'
        return fmtstr.format(name=self.name, val=self.val, avg=self.avg)


def accuracy(output, target, topk=(1,)):
    """Return the top-k accuracy (in percent) for every k in ``topk``.

    The top ``max(topk)`` indices along dim 1 of ``output`` are compared with
    ``target``; a sample counts as a hit for k when its target is among its
    first k indices. The result is a list with one 1-element tensor per k.
    """
    with torch.no_grad():
        largest_k = max(topk)
        n_samples = target.size(0)

        # Indices of the highest scores, transposed so that row r holds every
        # sample's (r+1)-th best guess.
        ranked = output.topk(largest_k, dim=1, largest=True, sorted=True).indices.t()
        hits = ranked.eq(target.view(1, -1).expand_as(ranked))

        return [
            hits[:k].reshape(-1).float().sum(0, keepdim=True).mul_(100.0 / n_samples)
            for k in topk
        ]


def evaluate(model, criterion, data_loader,device):
    """Run ``model`` over ``data_loader`` on ``device`` and accumulate accuracy.

    Args:
        model: module to evaluate; it is switched to eval mode and moved to ``device``.
        criterion: callable invoked as ``criterion(output, target)`` for each batch.
        data_loader: iterable yielding ``(image, target)`` batches.
        device: device the model and each batch are moved to.

    Returns:
        ``(top1, top5)`` AverageMeter objects. The second meter keeps its
        historical ``Acc@5`` / ``[top5]`` labels but is fed with top-2
        accuracy (``topk=(1, 2)`` below).
    """
    model.eval()
    model.to(device)
    top1 = AverageMeter('Acc@1', ':6.2f')
    top5 = AverageMeter('Acc@5', ':6.2f')
    with torch.no_grad():
        for cnt, (image, target) in enumerate(data_loader, start=1):
            image = image.to(device)
            # Fix: the labels must be on the same device as the model output,
            # otherwise the loss/accuracy computation fails on a non-CPU device.
            target = target.to(device)
            output = model(image)
            # The loss value is not reported; the call is kept so an
            # incompatible output/target pair still fails loudly.
            _ = criterion(output, target)
            acc1, acc5 = accuracy(output, target, topk=(1, 2))
            top1.update(acc1[0], image.size(0))
            top5.update(acc5[0], image.size(0))
            print('Val[',cnt,"]  [top1]:",acc1,"[top5]:",acc5)

    return top1, top5

def load_model(model_file):
    """Build a ResNet-18 and load its weights from ``model_file`` onto the CPU.

    Args:
        model_file: path (or file-like object) of a saved ``state_dict``.

    Returns:
        The ResNet-18 model with the loaded weights, placed on the CPU.
    """
    model = resnet18(pretrained=False)
    # map_location="cpu" lets a checkpoint saved from a GPU load on a machine
    # without CUDA; the model is meant to live on the CPU anyway.
    state_dict = torch.load(model_file, map_location="cpu")
    model.load_state_dict(state_dict)
    model.to("cpu")
    return model

def print_size_of_model(model):
    """Print the on-disk size, in MB, of ``model`` saved as TorchScript.

    A model that is not already a ``torch.jit.RecursiveScriptModule`` is
    scripted first. The archive is written inside a private temporary
    directory, so an existing ``temp.p`` in the working directory is never
    overwritten or deleted, and nothing is left behind if saving fails.
    """
    import tempfile  # local import: only this helper needs it

    if isinstance(model, torch.jit.RecursiveScriptModule):
        scripted = model
    else:
        scripted = torch.jit.script(model)

    with tempfile.TemporaryDirectory() as tmp_dir:
        tmp_path = os.path.join(tmp_dir, "temp.p")
        torch.jit.save(scripted, tmp_path)
        print("Size (MB):", os.path.getsize(tmp_path)/1e6)

def prepare_data_loaders(data_path):
    """Create the training and validation loaders from image folders.

    Images are read from ``data_path + '/train'`` and ``data_path + '/val'``.
    Batch sizes are taken from the module-level names ``train_batch_size`` and
    ``eval_batch_size``, which must exist by the time this function is called.

    Returns:
        ``(train_loader, val_loader)``: the training loader samples randomly,
        the validation loader iterates sequentially.
    """
    normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])

    # Training pipeline: random crop and flip before tensor conversion.
    train_transform = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        normalize,
    ])
    # Validation pipeline: fixed resize followed by a centre crop.
    val_transform = transforms.Compose([
        transforms.Resize(256),
        transforms.CenterCrop(224),
        transforms.ToTensor(),
        normalize,
    ])

    train_set = datasets.ImageFolder(data_path+'/train', transform=train_transform)
    val_set = datasets.ImageFolder(data_path+'/val', transform=val_transform)

    train_loader = DataLoader(
        train_set,
        batch_size=train_batch_size,
        sampler=torch.utils.data.RandomSampler(train_set),
    )
    val_loader = DataLoader(
        val_set,
        batch_size=eval_batch_size,
        sampler=torch.utils.data.SequentialSampler(val_set),
    )
    return train_loader, val_loader

# --- Paths and batch sizes (machine-specific; adjust before running) ---
data_path = 'E:/Transformer/DataSets/imagenet/Mini_Train'
saved_model_dir = 'Export/Ptq'
float_model_file = 'pretrained_float.pth'

train_batch_size = 8
eval_batch_size = 1

data_loader, data_loader_test = prepare_data_loaders(data_path)
# prepare_fx documents example_inputs as a tuple of positional forward()
# arguments. `(x)` is just a parenthesised tensor, so build a real 1-tuple.
example_inputs = (next(iter(data_loader))[0],)
criterion = nn.CrossEntropyLoss()

# --- Float ViT model ---
from Mymodels import Vit
In_Channels=3
Embed_Dim=384
Picture_Size=224
Patch_Size=16
Num_Class=3
Num_Heads=6
Encoder_Layers=6
float_model=Vit(In_Channels=In_Channels,Out_Channels=Embed_Dim,Picture_Size=Picture_Size,Patch_Size=Patch_Size
,Num_Class=Num_Class,Num_Heads=Num_Heads,Encoder_Layers=Encoder_Layers)
pretrained=True
if pretrained:
    # map_location='cpu' keeps loading independent of the device the
    # checkpoint was saved from; the model is moved to CUDA explicitly below.
    state_dict = torch.load(
        r'E:\Transformer\Transformer_Main\MyTransformer\Export\float\FloatVit_93.3333333333333398.0.pth',
        map_location='cpu')
    float_model.load_state_dict(state_dict)


float_model.eval()
float_model.to('cuda')
# deepcopy the model since we need to keep the original model around
import copy
model_to_quantize = copy.deepcopy(float_model)
model_to_quantize.eval()



In [20]:
# Reference (not executed): the QConfigMapping builder API.
# qconfig_mapping = QConfigMapping.set_global(default_qconfig)
# qconfig_opt=None
# qconfig_mapping = (QConfigMapping()
#     .set_global(qconfig_opt)  # qconfig_opt is an optional qconfig, either a valid qconfig or None
#     .set_object_type(torch.nn.Conv2d, qconfig_opt) # can be a callable...
#     .set_object_type("torch.nn.functional.add", qconfig_opt) # ...or a string of the class name
#     .set_module_name_regex("foo.*bar.*conv[0-9]+", qconfig_opt) # matched in order, first match takes precedence
#     .set_module_name("foo.bar", qconfig_opt)
#     .set_module_name_object_type_order()
# )
#     # priority (in increasing order): global, object_type, module_name_regex, module_name
#     # qconfig == None means fusion and quantization should be skipped for anything
#     # matching the rule (unless a higher priority match is found)
from torch.ao.quantization.backend_config import DTypeConfig,BackendPatternConfig,ObservationType,BackendConfig

# int8 dtype combination for weighted ops: quint8 activations, qint8 weights, float bias.
weighted_int8_dtype_config = DTypeConfig(
  input_dtype=torch.quint8,
  output_dtype=torch.quint8,
  weight_dtype=torch.qint8,
  bias_dtype=torch.float)

# Backend pattern for nn.Linear. Parentheses replace the backslash
# continuations (the last one dangled into a blank line).
linear_pattern_config = (
    BackendPatternConfig(torch.nn.Linear)
    .set_observation_type(ObservationType.OUTPUT_USE_DIFFERENT_OBSERVER_AS_INPUT)
    .add_dtype_config(weighted_int8_dtype_config)
)

linear_backend_config = BackendConfig().set_backend_pattern_config(linear_pattern_config)
qconfig = get_default_qconfig("fbgemm")
# Alternative per-Linear qconfig; currently unused (see the disabled
# set_object_type call on the next statement).
qlinear_cfg=torch.quantization.QConfig(
   activation=MinMaxObserver.with_args(dtype=torch.qint8),
   weight=MinMaxObserver.with_args(dtype=torch.qint8))
qconfig_mapping = QConfigMapping().set_global(qconfig)#.set_object_type(torch.nn.Linear, qlinear_cfg)

# Prepare once. An earlier prepare_fx call without backend_config was removed:
# its result was overwritten immediately by this one.
prepared_model = prepare_fx(model_to_quantize, qconfig_mapping, example_inputs,backend_config=linear_backend_config)

print(prepared_model)

GraphModule(
  (patch_embed): Module(
    (proj): Conv2d(3, 384, kernel_size=(16, 16), stride=(16, 16))
  )
  (Encoders): TransformerEncoder(
    (layers): ModuleList(
      (0): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
        )
        (linear1): Linear(in_features=384, out_features=2048, bias=True)
        (dropout): Dropout(p=0.1, inplace=False)
        (linear2): Linear(in_features=2048, out_features=384, bias=True)
        (norm1): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (norm2): LayerNorm((384,), eps=1e-05, elementwise_affine=True)
        (dropout1): Dropout(p=0.1, inplace=False)
        (dropout2): Dropout(p=0.1, inplace=False)
      )
      (1): TransformerEncoderLayer(
        (self_attn): MultiheadAttention(
          (out_proj): NonDynamicallyQuantizableLinear(in_features=384, out_features=384, bias=True)
        )
        (line



In [16]:
def calibrate(model, data_loader,device):
    """Feed every batch of ``data_loader`` through ``model`` without gradients.

    The model is put in eval mode and moved to ``device``; each batch's image
    and target are moved there too, although only the image is passed to the
    model. A progress line with the zero-based batch index is printed after
    each forward pass.
    """
    model.eval()
    model.to(device)
    with torch.no_grad():
        for step, (image, target) in enumerate(data_loader):
            image, target = image.to(device), target.to(device)
            model(image)
            print("calibrate times",step)
# Calibrate the prepared model on the validation loader, convert it, and
# compare the serialized size of the float and converted models.
calibrate(prepared_model, data_loader_test,'cuda')  # run calibration on sample data
                # Author's note: oddly, calibration has to run on CPU; after calibrating on GPU the next step fails.
                # NOTE(review): the call above passes 'cuda', which contradicts this note - confirm which is intended.
print("===========calibrate end===========")
prepared_model.to('cpu')# Author's note: this has to be switched to CPU before converting, oddly.
# NOTE(review): presumably the converted fbgemm kernels are CPU-only - confirm.
quantized_model = convert_fx(prepared_model)
print(quantized_model)

print("Size of model before quantization")
print_size_of_model(float_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)
# print(quantized_model.parameters)

# test_input=torch.rand(1,3,224,224).to('cuda')
# quantized_model.to('cuda')
# out=quantized_model(test_input)



calibrate times 0
calibrate times 1
calibrate times 2
calibrate times 3
calibrate times 4
calibrate times 5
calibrate times 6
calibrate times 7
calibrate times 8
calibrate times 9
calibrate times 10
calibrate times 11
calibrate times 12
calibrate times 13
calibrate times 14
calibrate times 15
calibrate times 16
calibrate times 17
calibrate times 18
calibrate times 19
calibrate times 20
calibrate times 21
calibrate times 22
calibrate times 23
calibrate times 24
calibrate times 25
calibrate times 26
calibrate times 27
calibrate times 28
calibrate times 29
calibrate times 30
calibrate times 31
calibrate times 32
calibrate times 33
calibrate times 34
calibrate times 35
calibrate times 36
calibrate times 37
calibrate times 38
calibrate times 39
calibrate times 40
calibrate times 41
calibrate times 42
calibrate times 43
calibrate times 44
calibrate times 45
calibrate times 46
calibrate times 47
calibrate times 48
calibrate times 49
calibrate times 50
calibrate times 51
calibrate times 52
cal

In [7]:
# Evaluate the converted model on CPU (author's note: it must run on CPU; on CUDA it hangs).
top1, top2 = evaluate(quantized_model, criterion, data_loader_test,'cpu')
print("FX graph mode quantized model Evaluation accuracy on test dataset: %2.2f, %2.2f"%(top1.avg, top2.avg))
# The meter average is a tensor here, so str() would embed "tensor(...)" in
# the file name; format the plain number instead. Also make sure the output
# directory exists before saving.
os.makedirs("Export", exist_ok=True)
torch.save(quantized_model.state_dict(), "Export/PtqVit_%.2f.pth" % float(top1.avg))

Val[ 1 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 2 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 3 ]  [top1]: tensor([0.]) [top5]: tensor([100.])
Val[ 4 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 5 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 6 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 7 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 8 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 9 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 10 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 11 ]  [top1]: tensor([0.]) [top5]: tensor([100.])
Val[ 12 ]  [top1]: tensor([0.]) [top5]: tensor([100.])
Val[ 13 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 14 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 15 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 16 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 17 ]  [top1]: tensor([100.]) [top5]: tensor([100.])
Val[ 18 ]  [top1]: tensor([100.]) [top5]: tens

# 看看量化了个啥

In [None]:
# Show the patch_embed submodule of the converted model.
print(quantized_model.patch_embed)


: 

In [None]:
# Print the `graph` attribute of the converted model.
print(quantized_model.graph)

: 

# 看看卷积的权重是个啥东东
卷积被量化了

In [None]:
# `weight` is printed here without being called, so this shows the attribute
# itself; it is invoked as `weight()` in the statements below.
print(quantized_model.patch_embed.proj.weight)
# Show the proj module itself.
print(quantized_model.patch_embed.proj)
# dtype of the tensor returned by weight().
print(quantized_model.patch_embed.proj.weight().dtype)
# print(quantized_model.patch_embed.proj.weight())
print("=================================================")
# The same weights after dequantize().
print(quantized_model.patch_embed.proj.weight().dequantize())


: 

# 看看Linear是个什么东东
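事实证明 Linear 没有被量化（推测原因：FX 追踪时把 nn.TransformerEncoder 当作叶子模块整体保留——上面打印的 GraphModule 中 Encoders 仍是完整的 TransformerEncoder——其内部的 Linear 没有进入计算图，因此未被量化；待确认）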

In [5]:
# dtype of linear1.weight in the first encoder layer (the recorded output is torch.float32).
print(quantized_model.Encoders.layers[0].linear1.weight.dtype)

torch.float32


# 导出Onnx看看

In [None]:
def export_onnx(x:torch.Tensor=None,model:torch.nn.Module=None,export_path:str='exported_onnx.onnx'):
    """Export ``model`` to an ONNX file using ``x`` as the example input.

    Args:
        x: example input passed to ``torch.onnx.export``.
        model: the module to export.
        export_path: destination of the ONNX file; ``None`` falls back to
            ``'exported_onnx.onnx'``.

    Raises:
        ValueError: if ``x`` or ``model`` is ``None``. Both default to ``None``
            only for signature compatibility and are required in practice.
    """
    # Fail fast with a clear message instead of an obscure error from inside the exporter.
    if x is None or model is None:
        raise ValueError("export_onnx requires both an example input `x` and a `model`")

    if export_path is None:
        export_path='exported_onnx.onnx'
    torch.onnx.export(model,               # model being run
                x,                         # model input
                export_path,   # where to save the model (can be a file or file-like object)
                opset_version=11,           # the ONNX version to export the model to
                input_names = ['input'],   # the model's input names
                output_names = ['output']  # the model's output names
                )
# Export the converted model with a random 1x3x224x224 example input to the default path.
export_onnx(torch.rand(1,3,224,224),quantized_model)

: 

In [None]:
# Stage 2: symbolically trace the converted model and print the traced graph.
from torch.fx import symbolic_trace
symbolic_traced : torch.fx.GraphModule = symbolic_trace(quantized_model)
print(symbolic_traced.graph)

: 