In [9]:
import copy

import numpy as np
import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
from timm.data import Mixup
from timm.loss import SoftTargetCrossEntropy
from timm.models import create_model
from timm.utils import NativeScaler, get_state_dict, ModelEma

import resmlp
from datasets import tfds_data_loader, data_loader
from run_model import evaluate_model, train_one_epoch
from run_model import load_model, save_torchscript_model, load_torchscript_model

## Set Parameters

In [43]:
# --- Model input / checkpoint ---
INPUT_SIZE = 224  # forwarded to the data loader as `input_size`
# NOTE(review): absolute, machine-specific path — consider making it configurable.
DICT_PATH = 'E:/ResMLP_QAT/pytorch/fp32_weights/ResMLP_S24_ReLU_fp32_80.602.pth'

# --- Dataset ---
DATA_NAME = 'imagenet2012'
# Raw string: in a plain literal, '\d' is an invalid escape sequence (it warns
# today and is slated to become an error). The runtime value is unchanged.
DATA_DIR = r'E:\datasets'

# --- Training hyper-parameters ---
BATCH_SIZE = 32
EPOCHS = 5
LR = 1e-4

# Number of data-loader worker processes (8 was the previously used value).
WORKERS = 0 #8

## Speed Up Configs

In [11]:
# Device selection, RNG seeding and cuDNN tuning.

# Seed the NumPy and torch global RNGs with the same fixed value.
seed = 336 # args.seed + utils.get_rank()
np.random.seed(seed)
torch.manual_seed(seed)

# Use the GPU when one is available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Let cuDNN search for the best ops before training starts.
cudnn.benchmark = True

## Quantize

In [50]:
class QuantizedResMLP(nn.Module):
    """Wrap a model between a ``QuantStub`` and a ``DeQuantStub``.

    The input passes through ``quant``, then the wrapped ``module``, then
    ``dequant``.

    Args:
        module: the model to wrap; stored as ``self.module``.
    """

    def __init__(self, module):
        super().__init__()
        # Registration order (quant, dequant, module) is kept as-is because it
        # determines the order of children and state-dict entries.
        self.quant = torch.quantization.QuantStub()
        self.dequant = torch.quantization.DeQuantStub()
        self.module = module

    def forward(self, x):
        """Run ``x`` through quant stub -> wrapped module -> dequant stub."""
        stub_in = self.quant(x)
        wrapped_out = self.module(stub_in)
        return self.dequant(wrapped_out)

# Build the train/val data loaders. The dataset comes from tfds, so no sampler
# is applied here (a distributed variant is still to be done).
data_loader_train, data_loader_val, NUM_CLASSES = tfds_data_loader(
    name=DATA_NAME, root=DATA_DIR, input_size=INPUT_SIZE,
    batch_size=BATCH_SIZE, num_workers=WORKERS)

# Additional data augmentation (mixup / cutmix) for training batches.
mixup_fn = Mixup(
    mixup_alpha=0.8,
    cutmix_alpha=1.0,
    cutmix_minmax=None,
    prob=1.0,
    switch_prob=0.5,
    mode='batch',
    label_smoothing=0.1,
    num_classes=NUM_CLASSES,
)

# Create the float model on the target device and load the weights from DICT_PATH.
float_model = create_model('resmlp_24', num_classes=NUM_CLASSES)
float_model = float_model.to(device)
float_model = load_model(float_model, DICT_PATH, device)

# Fuse `fc1` with `act` inside the "mlp" child of every block.
mlp_children = [
    child
    for block in float_model.blocks.children()
    for child_name, child in block.named_children()
    if child_name == "mlp"
]
for mlp in mlp_children:
    torch.quantization.fuse_modules(mlp, [['fc1', 'act']], inplace=True)

# Wrap the fused model with quant/dequant stubs.
float_model = QuantizedResMLP(module=float_model)

# Attach the default QAT qconfig for the 'qnnpack' backend and show it.
float_model.qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
print(float_model.qconfig)

# Prepare a QAT copy of the model (inplace=False) and switch it to training
# mode; as the last expression, this also displays the prepared model.
print("Training Model with QAT...")
quantized_model = torch.quantization.prepare_qat(float_model, inplace=False)
quantized_model.train()



QConfig(activation=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=0, quant_max=255, reduce_range=False){}, weight=functools.partial(<class 'torch.ao.quantization.fake_quantize.FusedMovingAvgObsFakeQuantize'>, observer=<class 'torch.ao.quantization.observer.MovingAverageMinMaxObserver'>, quant_min=-128, quant_max=127, dtype=torch.qint8, qscheme=torch.per_tensor_symmetric){})
Training Model with QAT...


QuantizedResMLP(
  (quant): QuantStub(
    (activation_post_process): FusedMovingAvgObsFakeQuantize(
      fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.quint8, quant_min=0, quant_max=255, qscheme=torch.per_tensor_affine, reduce_range=False
      (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
    )
  )
  (dequant): DeQuantStub()
  (module): resmlp_models(
    (patch_embed): PatchEmbed(
      (proj): Conv2d(
        3, 384, kernel_size=(16, 16), stride=(16, 16)
        (weight_fake_quant): FusedMovingAvgObsFakeQuantize(
          fake_quant_enabled=tensor([1], device='cuda:0'), observer_enabled=tensor([1], device='cuda:0'), scale=tensor([1.], device='cuda:0'), zero_point=tensor([0], device='cuda:0', dtype=torch.int32), dtype=torch.qint8, quant_min=-128, quant_max=127, qscheme=torch.per_tensor_sym

In [49]:
# Set up QAT fine-tuning: loss, optimizer and learning-rate schedule.
print("Training QAT Model...")
# `quantized_model` is already the result of `prepare_qat` (see the cell that
# builds it), so it must not be prepared a second time here: re-preparing an
# already prepared model is redundant and can register observers/hooks again.
quantized_model.train()

# Soft-target loss, since mixup produces soft labels.
# (`nn.CrossEntropyLoss()` was the previous choice.)
criterion = SoftTargetCrossEntropy()
optimizer = optim.SGD(quantized_model.parameters(),
                        lr=LR,
                        momentum=0.9,
                        weight_decay=1e-4)
# NOTE(review): the milestones (100, 150) lie beyond EPOCHS = 5, and no
# `scheduler.step()` call is visible in this notebook — confirm the schedule.
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer,
                                                    milestones=[100, 150],
                                                    gamma=0.1,
                                                    last_epoch=-1)

# The training step itself is currently disabled. While it stays disabled, the
# observers never see data, so the converted model keeps its initial
# quantization parameters.
# train_one_epoch(model=quantized_model, criterion=criterion,
#                   data_loader=data_loader_train, optimizer=optimizer,
#                   device=device, epoch=1, max_norm=None,
#                   model_ema=None, mixup_fn=mixup_fn)

### Convert the trained model to its INT8 version

In [51]:
# Move the model to the CPU and convert it in place: weights become int8 and
# the modules are replaced by their quantized versions.
quantized_model = torch.quantization.convert(quantized_model.cpu(), inplace=True)
# Last expression of the cell: switch to eval mode and display the converted model.
quantized_model.eval()



QuantizedResMLP(
  (quant): Quantize(scale=tensor([1.]), zero_point=tensor([0]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (module): resmlp_models(
    (patch_embed): PatchEmbed(
      (proj): QuantizedConv2d(3, 384, kernel_size=(16, 16), stride=(16, 16), scale=1.0, zero_point=0)
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0): layers_scale_mlp_blocks(
        (norm1): QuantizedLinear(in_features=384, out_features=384, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
        (attn): QuantizedLinear(in_features=196, out_features=196, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
        (drop_path): Identity()
        (norm2): QuantizedLinear(in_features=384, out_features=384, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
        (mlp): Mlp(
          (fc1): QuantizedLinearReLU(in_features=384, out_features=1536, scale=1.0, zero_point=0, qscheme=torch.per_tensor_affine)
          (act): Identity()
          (drop1): QuantizedDrop

In [None]:
# Smoke-test the converted int8 model on the CPU, then export it as TorchScript.
# The int8 conversion is done once, in the previous cell; repeating
# `torch.quantization.convert` here was a copy-paste duplicate. `.cpu()` and
# `.eval()` are kept because they are cheap and idempotent.
quantized_model.cpu()
quantized_model.eval()

# A single random batch is enough to check that the forward pass runs.
# The spatial size follows INPUT_SIZE instead of a second hard-coded 224.
input_fp32 = torch.randn((1, 3, INPUT_SIZE, INPUT_SIZE), dtype=torch.float32, device="cpu")
with torch.no_grad():
    quantized_model(input_fp32)

# NOTE(review): SAVE_PATH is not used by any cell visible here — the model is
# saved under 'qat_weights/qat_Test0.pth' below; confirm whether it is needed.
SAVE_PATH = 'modeltest.pth'
save_torchscript_model(model=quantized_model,
                        model_dir='qat_weights',
                        model_filename='qat_Test0.pth')

### Evaluate Model

In [None]:
# Evaluate the saved TorchScript int8 model on the validation loader.
# The model is loaded and evaluated with device="cpu", so the loss stays on
# the CPU as well; the former `.cuda()` call mixed devices for no benefit.
criterion = nn.CrossEntropyLoss()
quantized_model = load_torchscript_model(model_filepath='qat_weights/qat_Test0.pth', device="cpu")

# Evaluation
quantized_model.eval()
eval_loss, top1_acc, top5_acc = evaluate_model(model=quantized_model,
                                                test_loader=data_loader_val,
                                                device="cpu",
                                                criterion=criterion)
# -1 marks a result that does not belong to any training epoch.
print("Epoch: {:02d} Eval Loss: {:.3f} Top1: {:.3f} Top5: {:.3f}".format(
    -1, eval_loss, top1_acc, top5_acc))

100%|██████████| 1042/1042 [29:09<00:00,  1.68s/it]

Epoch: -1 Eval Loss: 1.842 Top1: 74.376 Top5: 90.640





In [45]:
# Re-import `run_model` after editing it on disk, then rebind the helper names
# so the notebook uses the freshly loaded definitions.
from importlib import reload

import run_model
reload(run_model)
from run_model import (
    evaluate_model,
    train_one_epoch,
    load_model,
    save_torchscript_model,
    load_torchscript_model,
)

In [None]:
import torch.quantization
import torch.quantization._numeric_suite as ns

In [52]:
# Load 'qat_weights/test.pth' into the current `quantized_model` on the CPU.
# NOTE(review): this is a different file from the 'qat_Test0.pth' saved above — confirm.
quantized_model = load_model(
    quantized_model,
    model_filepath='qat_weights/test.pth',
    device="cpu",
)

In [58]:
# Display the module tree of the current quantized model, including the
# per-layer scale / zero_point values.
quantized_model

QuantizedResMLP(
  (quant): Quantize(scale=tensor([0.0358]), zero_point=tensor([125]), dtype=torch.quint8)
  (dequant): DeQuantize()
  (module): resmlp_models(
    (patch_embed): PatchEmbed(
      (proj): QuantizedConv2d(3, 384, kernel_size=(16, 16), stride=(16, 16), scale=0.3109653890132904, zero_point=124)
      (norm): Identity()
    )
    (blocks): ModuleList(
      (0): layers_scale_mlp_blocks(
        (norm1): QuantizedLinear(in_features=384, out_features=384, scale=0.073409304022789, zero_point=124, qscheme=torch.per_tensor_affine)
        (attn): QuantizedLinear(in_features=196, out_features=196, scale=0.08399761468172073, zero_point=116, qscheme=torch.per_tensor_affine)
        (drop_path): Identity()
        (norm2): QuantizedLinear(in_features=384, out_features=384, scale=0.18273833394050598, zero_point=143, qscheme=torch.per_tensor_affine)
        (mlp): Mlp(
          (fc1): QuantizedLinearReLU(in_features=384, out_features=1536, scale=0.10778413712978363, zero_point=0, qs

In [57]:
# Compare the weights of the float model with those of the quantized model
# using the numeric suite.
float_state = float_model.state_dict()
quant_state = quantized_model.state_dict()
wt_compare_dict = ns.compare_weights(float_state, quant_state)

print('keys of wt_compare_dict:')
print(wt_compare_dict.keys())

# Key of one entry to inspect by hand, e.g.
#   wt_compare_dict[mystr].keys()
#   wt_compare_dict[mystr]['float']
#   wt_compare_dict[mystr]['quantized']
mystr = 'module.blocks.0.norm2._packed_params._packed_params'

keys of wt_compare_dict:
dict_keys(['module.patch_embed.proj.weight', 'module.blocks.0.norm1._packed_params._packed_params', 'module.blocks.0.attn._packed_params._packed_params', 'module.blocks.0.norm2._packed_params._packed_params', 'module.blocks.0.mlp.fc1._packed_params._packed_params', 'module.blocks.0.mlp.fc2._packed_params._packed_params', 'module.blocks.0.gamma_1._packed_params._packed_params', 'module.blocks.0.gamma_2._packed_params._packed_params', 'module.blocks.1.norm1._packed_params._packed_params', 'module.blocks.1.attn._packed_params._packed_params', 'module.blocks.1.norm2._packed_params._packed_params', 'module.blocks.1.mlp.fc1._packed_params._packed_params', 'module.blocks.1.mlp.fc2._packed_params._packed_params', 'module.blocks.1.gamma_1._packed_params._packed_params', 'module.blocks.1.gamma_2._packed_params._packed_params', 'module.blocks.2.norm1._packed_params._packed_params', 'module.blocks.2.attn._packed_params._packed_params', 'module.blocks.2.norm2._packed_params

In [38]:
# Display the wrapped inner model (the `module` attribute of the wrapper).
quantized_model.module

RecursiveScriptModule(
  original_name=resmlp_models
  (patch_embed): RecursiveScriptModule(
    original_name=PatchEmbed
    (proj): RecursiveScriptModule(original_name=Conv2d)
    (norm): RecursiveScriptModule(original_name=Identity)
  )
  (blocks): RecursiveScriptModule(
    original_name=ModuleList
    (0): RecursiveScriptModule(
      original_name=layers_scale_mlp_blocks
      (norm1): RecursiveScriptModule(
        original_name=Linear
        (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
      )
      (attn): RecursiveScriptModule(
        original_name=Linear
        (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
      )
      (drop_path): RecursiveScriptModule(original_name=Identity)
      (norm2): RecursiveScriptModule(
        original_name=Linear
        (_packed_params): RecursiveScriptModule(original_name=LinearPackedParams)
      )
      (mlp): RecursiveScriptModule(
        original_name=Mlp
        (fc1): Recursiv

In [37]:
# Print every state-dict entry of the wrapped model together with its shape.
for name, val in quantized_model.module.state_dict().items():
  # Entries of a quantized state dict (e.g. `..._packed_params._packed_params`)
  # may not be plain tensors and then have no `.shape`; show their type
  # instead of raising AttributeError.
  s = getattr(val, "shape", None)
  if s is None:
    s = f"<{type(val).__name__}>"
  print(f"{name}:", end=" ")
  print(f"\t{s}")

In [23]:
# NOTE(review): `model` is not assigned in any cell visible in this notebook,
# so on a fresh kernel this line would presumably raise NameError — confirm
# which model was meant. Displaying a whole state dict also prints every
# weight tensor, which makes the saved output very large.
model.state_dict()

OrderedDict([('patch_embed.proj.weight',
              tensor([[[[ 1.8547e-02,  3.1375e-02, -3.4824e-04,  ...,  1.0358e-02,
                         -2.4780e-02, -3.0519e-02],
                        [-6.4273e-03, -1.6747e-02,  3.3060e-03,  ...,  2.3723e-02,
                         -9.9016e-03,  1.8857e-02],
                        [-2.0897e-02,  3.2766e-02, -7.3521e-04,  ..., -3.0972e-02,
                          1.9634e-02,  3.2349e-02],
                        ...,
                        [ 2.4877e-02,  1.8711e-02,  6.0188e-03,  ..., -1.8882e-02,
                          5.2874e-03, -2.7802e-02],
                        [-1.2207e-02,  1.5945e-02,  1.8371e-02,  ...,  1.2325e-03,
                         -2.8527e-02, -2.4299e-02],
                        [ 8.1906e-03, -1.7239e-02,  3.3641e-03,  ...,  3.0933e-02,
                          2.2067e-02,  1.3474e-02]],
              
                       [[ 1.2587e-02,  2.6582e-02,  6.8919e-03,  ...,  8.6741e-03,
                     