In [1]:
import sys

# Put the parent directory on the import path exactly once; presumably this is
# what lets the `source.*` imports further down resolve from this notebook.
_PARENT_DIR = '..'
if sys.path.count(_PARENT_DIR) == 0:
    sys.path.append(_PARENT_DIR)

In [2]:
import torch
from torchvision.models import ResNet

from tqdm import tqdm
import numpy as np
import os

from aimet_torch.model_preparer import prepare_model
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel, QuantParams
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters
from aimet_torch.cross_layer_equalization import equalize_model
from aimet_torch.bias_correction import correct_bias

from source.data import get_training_dataloader, get_test_dataloader
from source.models import BasicBlock, ResNet18Quant

2023-01-06 16:18:31,055 - root - INFO - AIMET


In [3]:
# Use the GPU whenever PyTorch reports CUDA as available, otherwise the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
device

'cpu'

In [4]:
def accuracy(model, dataloader, device='gpu'):
    """Print the top-1 accuracy of ``model`` over ``dataloader``.

    Args:
        model: Callable network; it is switched to eval mode before scoring.
        dataloader: Iterable of ``(images, labels)`` batches that exposes a
            ``dataset`` attribute whose ``len()`` is the number of samples.
        device: Where every batch is moved before the forward pass. The legacy
            alias ``'gpu'`` is treated as ``'cuda'``; any other value
            (``'cuda'``, ``'cpu'``, a ``torch.device``) is passed to
            ``Tensor.to`` unchanged.

    Returns:
        None. The result is printed as ``Acc: tensor(...)``.
    """
    # Only the literal 'gpu' used to trigger a transfer, so callers passing
    # device='cuda' left the batch on the CPU while the model sat on the GPU.
    target = 'cuda' if device == 'gpu' else device

    model.eval()
    correct = 0.0
    with torch.no_grad():
        for images, labels in tqdm(dataloader):
            images = images.to(target)
            labels = labels.to(target)
            outputs = model(images)
            _, preds = outputs.max(1)  # index of the largest logit per sample
            correct += preds.eq(labels).sum()

    # `correct` is still a plain float if the loader yielded no batch;
    # as_tensor makes the division and the printed format uniform in that case.
    acc = torch.as_tensor(correct, dtype=torch.float32) / len(dataloader.dataset)
    print('Acc:', acc)

In [5]:
# Three-component mean / std constants handed to the data-loader factories
# below (presumably per-channel statistics of the CIFAR-100 training set).
CIFAR100_TRAIN_MEAN = (
    0.5070751592371323,
    0.48654887331495095,
    0.4409178433670343,
)
CIFAR100_TRAIN_STD = (
    0.2673342858792401,
    0.2564384629170883,
    0.27615047132568404,
)

In [6]:
# Both splits are built with the same data root, normalisation constants and
# loader options, so the shared arguments are spelled out once.
_norm_stats = (CIFAR100_TRAIN_MEAN, CIFAR100_TRAIN_STD)
_loader_kwargs = dict(num_workers=4, batch_size=64, shuffle=True)

train_loader = get_training_dataloader('../data', *_norm_stats, **_loader_kwargs)

# NOTE(review): the test split is shuffled as well; that does not change the
# accuracy figure but makes the evaluation order non-deterministic - confirm
# it is intended.
test_loader = get_test_dataloader('../data', *_norm_stats, **_loader_kwargs)

Files already downloaded and verified
Files already downloaded and verified


In [7]:
def pass_calibration_data(sim_model, device):
    """Run forward passes over batches of the global ``train_loader``.

    Used as the ``forward_pass_callback`` for ``sim.compute_encodings``.
    The loop prints the running sample count after every batch and stops as
    soon as that count exceeds 1000 samples.

    Args:
        sim_model: Model to run; it is put in eval mode first.
        device: Device each input batch is moved to before the forward pass.
    """
    sample_budget = 1000
    per_batch = train_loader.batch_size

    sim_model.eval()
    with torch.no_grad():
        # Labels are unpacked but never used; only the forward pass matters here.
        for batch_idx, (batch_inputs, _) in enumerate(train_loader, start=1):
            sim_model(batch_inputs.to(device))

            seen = batch_idx * per_batch
            print(seen)
            if seen > sample_budget:
                break

In [8]:
# Build a ResNet from BasicBlock with a [2, 2, 2, 2] layout and 100 output
# classes, then restore the checkpoint with its tensors mapped onto `device`.
model = ResNet(block=BasicBlock, layers=[2, 2, 2, 2], num_classes=100)
model.load_state_dict(
    torch.load(
        '../models/resnet18_cifar100.sd',
        map_location=torch.device(device),
    )
)
# Inference only: eval mode, and the module lives on the selected device.
model = model.to(device)
model.eval()

In [None]:
%%time
# Baseline: accuracy of the loaded (not yet quantized) model on the test
# loader, passing the notebook-wide `device` setting.
accuracy(model, test_loader, device=device)

# Quantization

In [8]:
# Run AIMET's model preparer and keep the model it returns; the log below
# shows it adding modules for functional `add` nodes and reused ReLU layers.
model = prepare_model(model)

2022-12-21 23:53:18,450 - Quant - INFO - Functional         : Adding new module for node: {add} 
2022-12-21 23:53:18,450 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_0_relu_1} 
2022-12-21 23:53:18,451 - Quant - INFO - Functional         : Adding new module for node: {add_1} 
2022-12-21 23:53:18,452 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_1_relu_1} 
2022-12-21 23:53:18,452 - Quant - INFO - Functional         : Adding new module for node: {add_2} 
2022-12-21 23:53:18,453 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_0_relu_1} 
2022-12-21 23:53:18,453 - Quant - INFO - Functional         : Adding new module for node: {add_3} 
2022-12-21 23:53:18,454 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_1_relu_1} 
2022-12-21 23:53:18,454 - Quant - INFO - Functional         : Adding new module for node: {add_4} 
2022-12-21 23:53:18,455 - Quant - INFO - Reused/Duplicate   : Adding ne



In [9]:
# Fold batch norms across the model (the log below lists Conv/BatchNormalization
# pairs); the return value is discarded.
# NOTE(review): input_shapes is 1x3x224x224 although the data is CIFAR-100
# (presumably 32x32 images) - confirm the shape is only used for graph tracing.
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))

2022-12-21 23:53:19,001 - Utils - INFO - ...... subset to store [Conv_0, BatchNormalization_1]
2022-12-21 23:53:19,002 - Utils - INFO - ...... subset to store [Conv_4, BatchNormalization_5]
2022-12-21 23:53:19,002 - Utils - INFO - ...... subset to store [Conv_7, BatchNormalization_8]
2022-12-21 23:53:19,002 - Utils - INFO - ...... subset to store [Conv_11, BatchNormalization_12]
2022-12-21 23:53:19,003 - Utils - INFO - ...... subset to store [Conv_14, BatchNormalization_15]
2022-12-21 23:53:19,003 - Utils - INFO - ...... subset to store [Conv_18, BatchNormalization_19]
2022-12-21 23:53:19,004 - Utils - INFO - ...... subset to store [Conv_21, BatchNormalization_22]
2022-12-21 23:53:19,004 - Utils - INFO - ...... subset to store [Conv_27, BatchNormalization_28]
2022-12-21 23:53:19,004 - Utils - INFO - ...... subset to store [Conv_30, BatchNormalization_31]
2022-12-21 23:53:19,005 - Utils - INFO - ...... subset to store [Conv_34, BatchNormalization_35]
2022-12-21 23:53:19,005 - Utils - IN

In [10]:
# Random single-sample input (1 x 3 x 224 x 224) handed to QuantizationSimModel.
# It is placed on `device` rather than forced onto CUDA, so the cell also runs
# on CPU-only machines and matches the other sections of this notebook.
# NOTE(review): the data is CIFAR-100, not ImageNet; 224x224 is kept from the
# original cell - confirm the shape is only used for tracing.
dummy_input = torch.rand(1, 3, 224, 224)
dummy_input = dummy_input.to(device)

# Simulate 4-bit parameters and 4-bit activations with the tf_enhanced scheme.
sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=4,
                           default_param_bw=4)

2022-12-21 23:53:22,829 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2022-12-21 23:53:22,849 - Quant - INFO - Unsupported op type Squeeze
2022-12-21 23:53:22,849 - Quant - INFO - Unsupported op type Pad
2022-12-21 23:53:22,849 - Quant - INFO - Unsupported op type Mean
2022-12-21 23:53:22,852 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2022-12-21 23:53:22,853 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2022-12-21 23:53:22,853 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2022-12-21 23:53:22,854 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2022-12-21 23:53:22,854 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2022-12-21 23:53:22,855 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2022-12-21 23:53:22,855 - Utils - INFO - ...... subset to store [Conv_8, Relu_9]
2022-12-21 23:53:22,855 - Utils - INFO - ...... subs

In [11]:
# Compute the simulator's quantization encodings. pass_calibration_data is the
# forward-pass callback and `device` is forwarded to it as its argument; the
# numbers printed below are that callback's running sample count.
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=device)

64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024


In [12]:
%%time
accuracy(sim.model, test_loader, device='gpu')

100%|██████████| 157/157 [00:02<00:00, 55.95it/s]

Acc: tensor(0.2235, device='cuda:0')
CPU times: user 1.89 s, sys: 259 ms, total: 2.15 s
Wall time: 2.81 s





# AdaRound

In [8]:
# Prepare the model for the AdaRound experiment.
# NOTE(review): the execution counter restarts at In [8] here, which suggests
# this section was run from a fresh kernel right after the model-loading cell.
# On a top-to-bottom run `model` has already been prepared and BN-folded by the
# section above - confirm which starting state is intended.
model = prepare_model(model)

2022-12-21 23:54:11,038 - Quant - INFO - Functional         : Adding new module for node: {add} 
2022-12-21 23:54:11,039 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_0_relu_1} 
2022-12-21 23:54:11,039 - Quant - INFO - Functional         : Adding new module for node: {add_1} 
2022-12-21 23:54:11,041 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_1_relu_1} 
2022-12-21 23:54:11,042 - Quant - INFO - Functional         : Adding new module for node: {add_2} 
2022-12-21 23:54:11,042 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_0_relu_1} 
2022-12-21 23:54:11,044 - Quant - INFO - Functional         : Adding new module for node: {add_3} 
2022-12-21 23:54:11,045 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_1_relu_1} 
2022-12-21 23:54:11,045 - Quant - INFO - Functional         : Adding new module for node: {add_4} 
2022-12-21 23:54:11,046 - Quant - INFO - Reused/Duplicate   : Adding ne



In [11]:
# Fold batch norms before AdaRound; the return value is discarded.
# NOTE(review): 224x224 input shape on a CIFAR-100 model - confirm it is only
# used for tracing.
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))

In [12]:
# AdaRound draws its data from the training loader: as many whole batches as
# fit into 2000 samples, with 10000 as the default iteration count.
adaround_num_batches = 2000 // train_loader.batch_size
params = AdaroundParameters(
    data_loader=train_loader,
    num_batches=adaround_num_batches,
    default_num_iterations=10000,
)

In [13]:
# Ensure the output directory that the AdaRound cell below writes into exists
# (no error if it is already there).
os.makedirs('./cifar100_w4/', exist_ok=True)

In [None]:
%time
dummy_input = torch.rand(1, 3, 224, 224).to(device)

ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path="cifar100_w4", 
                                    filename_prefix='adaround', 
                                    default_param_bw=4,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 5.48 µs
2023-01-06 16:19:36,456 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-06 16:19:36,482 - Quant - INFO - Unsupported op type Squeeze
2023-01-06 16:19:36,483 - Quant - INFO - Unsupported op type Pad
2023-01-06 16:19:36,483 - Quant - INFO - Unsupported op type Mean
2023-01-06 16:19:36,489 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 16:19:36,490 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 16:19:36,491 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 16:19:36,493 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 16:19:36,495 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 16:19:36,496 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 16:19:36,497 - Utils - INFO - ...... subset to store [Conv

                                                                                                             

2023-01-06 16:19:41,037 - Quant - INFO - Started Optimizing weight rounding of module: conv1


                                                                                                             

2023-01-06 16:21:19,862 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv1


                                                                                                             

2023-01-06 16:22:47,419 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv2


                                                                                                             

2023-01-06 16:24:16,454 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv1


                                                                                                             

2023-01-06 16:25:58,798 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv2


                                                                                                             

2023-01-06 16:27:41,623 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv1


                                                                                                             

2023-01-06 16:29:34,728 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv2


                                                                                                             

2023-01-06 16:32:04,699 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.downsample.0


                                                                                                             

2023-01-06 16:33:13,959 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv1


                                                                                                             

2023-01-06 16:35:54,293 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv2


                                                                                                             

2023-01-06 16:38:36,065 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv1


                                                                                                             

2023-01-06 16:42:20,217 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv2


                                                                                                             

2023-01-06 16:49:03,751 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.downsample.0


                                                                                                             

2023-01-06 16:50:30,655 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv1


                                                                                                             

2023-01-06 16:57:30,810 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv2


                                                                                                             

2023-01-06 17:04:47,316 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv1


                                                                                                             

2023-01-06 17:17:34,492 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv2


                                                                                                             

2023-01-06 17:42:15,588 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.downsample.0


                                                                                                             

2023-01-06 17:44:38,886 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv1


                                                                                                             

2023-01-06 18:09:53,089 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv2


 88%|████████████████████████████████████████████████████████████▉        | 60/68 [1:50:12<35:20, 265.07s/it]

In [14]:
# Random single-sample input, created and moved to `device` in one step.
dummy_input = torch.rand(1, 3, 224, 224).to(device)

# Wrap the AdaRound-ed model in a quantization simulator:
# 4-bit parameters, 4-bit activations, tf_enhanced scheme.
sim = QuantizationSimModel(
    model=ada_model,
    dummy_input=dummy_input,
    quant_scheme=QuantScheme.post_training_tf_enhanced,
    default_param_bw=4,
    default_output_bw=4,
)

2022-12-22 00:01:17,018 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2022-12-22 00:01:17,035 - Quant - INFO - Unsupported op type Squeeze
2022-12-22 00:01:17,036 - Quant - INFO - Unsupported op type Pad
2022-12-22 00:01:17,036 - Quant - INFO - Unsupported op type Mean
2022-12-22 00:01:17,039 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2022-12-22 00:01:17,040 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2022-12-22 00:01:17,040 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2022-12-22 00:01:17,041 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2022-12-22 00:01:17,041 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2022-12-22 00:01:17,042 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2022-12-22 00:01:17,042 - Utils - INFO - ...... subset to store [Conv_8, Relu_9]
2022-12-22 00:01:17,042 - Utils - INFO - ...... subs

In [15]:
# Load the parameter encodings from cifar100_w4/adaround.encodings and freeze
# them; the log below reports "Setting" and "Freezing" for each weight.
sim.set_and_freeze_param_encodings(encoding_path=os.path.join("cifar100_w4", 'adaround.encodings'))

2022-12-22 00:01:17,181 - Quant - INFO - Setting quantization encodings for parameter: conv1.weight
2022-12-22 00:01:17,181 - Quant - INFO - Freezing quantization encodings for parameter: conv1.weight
2022-12-22 00:01:17,182 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv1.weight
2022-12-22 00:01:17,182 - Quant - INFO - Freezing quantization encodings for parameter: layer1.0.conv1.weight
2022-12-22 00:01:17,183 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv2.weight
2022-12-22 00:01:17,183 - Quant - INFO - Freezing quantization encodings for parameter: layer1.0.conv2.weight
2022-12-22 00:01:17,184 - Quant - INFO - Setting quantization encodings for parameter: layer1.1.conv1.weight
2022-12-22 00:01:17,184 - Quant - INFO - Freezing quantization encodings for parameter: layer1.1.conv1.weight
2022-12-22 00:01:17,184 - Quant - INFO - Setting quantization encodings for parameter: layer1.1.conv2.weight
2022-12-22 00:01:17,185 - Quant -

In [16]:
# Compute encodings with pass_calibration_data as the forward-pass callback and
# `device` as its argument; this runs after the parameter encodings were frozen
# in the previous cell.
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=device)



64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024


In [17]:
# Test-set accuracy of the AdaRound quantization-sim model.
accuracy(sim.model, test_loader, device=device)

CPU times: user 2 µs, sys: 0 ns, total: 2 µs
Wall time: 4.29 µs


100%|██████████| 157/157 [00:02<00:00, 56.59it/s]

Acc: tensor(0.4052, device='cuda:0')





# Cross-Layer Equalization

In [9]:
# Prepare the model for cross-layer equalization.
# NOTE(review): the execution counter (In [9]) suggests a fresh kernel for this
# section; on a top-to-bottom run `model` was already modified by the sections
# above - confirm the intended starting state.
model = prepare_model(model)

2023-01-06 16:18:41,948 - Quant - INFO - Functional         : Adding new module for node: {add} 
2023-01-06 16:18:41,951 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_0_relu_1} 
2023-01-06 16:18:41,953 - Quant - INFO - Functional         : Adding new module for node: {add_1} 
2023-01-06 16:18:41,955 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_1_relu_1} 
2023-01-06 16:18:41,956 - Quant - INFO - Functional         : Adding new module for node: {add_2} 
2023-01-06 16:18:41,957 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_0_relu_1} 
2023-01-06 16:18:41,958 - Quant - INFO - Functional         : Adding new module for node: {add_3} 
2023-01-06 16:18:41,959 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_1_relu_1} 
2023-01-06 16:18:41,960 - Quant - INFO - Functional         : Adding new module for node: {add_4} 
2023-01-06 16:18:41,967 - Quant - INFO - Reused/Duplicate   : Adding ne



In [10]:
# Batch-norm folding is deliberately not run before equalization.
# Original author's note: CLE needs the BN statistics for its procedure. If a
# BN-folded model is provided, CLE still runs the cross-layer scaling (CLS)
# optimization step but skips the high-bias absorption (HBA) step.
# The return value is not captured, so `model` is presumably updated in place.
equalize_model(model, input_shapes=(1, 3, 224, 224))

2023-01-06 16:18:45,310 - Utils - INFO - ...... subset to store [Conv_0, BatchNormalization_1]
2023-01-06 16:18:45,313 - Utils - INFO - ...... subset to store [Conv_4, BatchNormalization_5]
2023-01-06 16:18:45,316 - Utils - INFO - ...... subset to store [Conv_7, BatchNormalization_8]
2023-01-06 16:18:45,318 - Utils - INFO - ...... subset to store [Conv_11, BatchNormalization_12]
2023-01-06 16:18:45,320 - Utils - INFO - ...... subset to store [Conv_14, BatchNormalization_15]
2023-01-06 16:18:45,321 - Utils - INFO - ...... subset to store [Conv_18, BatchNormalization_19]
2023-01-06 16:18:45,322 - Utils - INFO - ...... subset to store [Conv_21, BatchNormalization_22]
2023-01-06 16:18:45,323 - Utils - INFO - ...... subset to store [Conv_27, BatchNormalization_28]
2023-01-06 16:18:45,323 - Utils - INFO - ...... subset to store [Conv_30, BatchNormalization_31]
2023-01-06 16:18:45,324 - Utils - INFO - ...... subset to store [Conv_34, BatchNormalization_35]
2023-01-06 16:18:45,325 - Utils - IN

In [11]:
# Random single-sample input, created and moved to `device` in one step.
dummy_input = torch.rand(1, 3, 224, 224).to(device)

# Quantization simulator for the equalized model.
# NOTE(review): parameters are simulated at 8 bits here, while the other
# sections of this notebook use default_param_bw=4 - confirm this is intended
# before comparing accuracies across sections.
sim = QuantizationSimModel(
    model=model,
    dummy_input=dummy_input,
    quant_scheme=QuantScheme.post_training_tf_enhanced,
    default_param_bw=8,
    default_output_bw=4,
)

2023-01-06 14:55:34,980 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-06 14:55:35,208 - Quant - INFO - Unsupported op type Squeeze
2023-01-06 14:55:35,209 - Quant - INFO - Unsupported op type Pad
2023-01-06 14:55:35,210 - Quant - INFO - Unsupported op type Mean
2023-01-06 14:55:35,220 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 14:55:35,222 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 14:55:35,229 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 14:55:35,230 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 14:55:35,231 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 14:55:35,232 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 14:55:35,233 - Utils - INFO - ...... subset to store [Conv_8, Relu_9]
2023-01-06 14:55:35,234 - Utils - INFO - ...... subs

In [12]:
# Compute encodings for the equalized model; pass_calibration_data is the
# forward-pass callback and `device` is forwarded to it as its argument.
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=device)

64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024


In [13]:
# Test-set accuracy of the quantization-sim model built from the equalized network.
accuracy(sim.model, test_loader, device=device)

100%|██████████████████████████████████████████████████████████████████████| 157/157 [02:03<00:00,  1.27it/s]

Acc: tensor(0.4113)





# Bias Correction

In [11]:
# Original author's note: bias correction is applied to an equalized model.
# Quantization settings used by the bias-correction step: 4-bit weights,
# 4-bit activations, nearest rounding, tf_enhanced scheme.
bc_params = QuantParams(weight_bw=4, act_bw=4, round_mode="nearest",
                        quant_scheme=QuantScheme.post_training_tf_enhanced)

In [12]:
%%time 
# Run AIMET bias correction with data from train_loader: 500 samples for
# quantization and 1000 samples for the correction itself. The return value is
# not captured, so `model` is presumably updated in place.
correct_bias(model, bc_params, num_quant_samples=500,
             data_loader=train_loader, num_bias_correct_samples=1000)

2023-01-06 16:10:24,724 - Utils - INFO - ...... subset to store [Conv_0]
2023-01-06 16:10:24,725 - Utils - INFO - ...... subset to store [Conv_0]
2023-01-06 16:10:24,726 - Utils - INFO - ...... subset to store [Conv_0]
2023-01-06 16:10:24,727 - Utils - INFO - ...... subset to store [Conv_3]
2023-01-06 16:10:24,727 - Utils - INFO - ...... subset to store [Conv_3]
2023-01-06 16:10:24,728 - Utils - INFO - ...... subset to store [Conv_3]
2023-01-06 16:10:24,728 - Utils - INFO - ...... subset to store [Conv_5]
2023-01-06 16:10:24,729 - Utils - INFO - ...... subset to store [Conv_5]
2023-01-06 16:10:24,730 - Utils - INFO - ...... subset to store [Conv_5]
2023-01-06 16:10:24,733 - Utils - INFO - ...... subset to store [Conv_8]
2023-01-06 16:10:24,735 - Utils - INFO - ...... subset to store [Conv_8]
2023-01-06 16:10:24,737 - Utils - INFO - ...... subset to store [Conv_8]
2023-01-06 16:10:24,739 - Utils - INFO - ...... subset to store [Conv_10]
2023-01-06 16:10:24,740 - Utils - INFO - ...... su

2023-01-06 16:11:14,173 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:11:14,175 - Quant - INFO - Correcting layer layer2.0.conv2 using Empirical Bias Correction
2023-01-06 16:11:25,069 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:11:25,070 - Quant - INFO - Correcting layer layer2.0.downsample.0 using Empirical Bias Correction
2023-01-06 16:11:36,092 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:11:36,094 - Quant - INFO - Correcting layer layer2.1.conv1 using Empirical Bias Correction
2023-01-06 16:11:48,811 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:11:48,812 - Quant - INFO - Correcting layer layer2.1.conv2 using Empirical Bias Correction
2023-01-06 16:12:01,488 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:12:01,489 - Quant - INFO - Correcting layer layer3.0.conv1 using Empirical Bias Correction
2023-01-06 16:12:15,443 - Quant - INFO - Corrected bias for the layer
2023-01-06 16:12:15,448 - Quant - INFO - Correct

In [13]:
# Random single-sample input, created and moved to `device` in one step.
dummy_input = torch.rand(1, 3, 224, 224).to(device)

# Quantization simulator for the bias-corrected model:
# 4-bit parameters, 4-bit activations, tf_enhanced scheme.
sim = QuantizationSimModel(
    model=model,
    dummy_input=dummy_input,
    quant_scheme=QuantScheme.post_training_tf_enhanced,
    default_param_bw=4,
    default_output_bw=4,
)

2023-01-06 16:15:26,372 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-06 16:15:26,391 - Quant - INFO - Unsupported op type Squeeze
2023-01-06 16:15:26,392 - Quant - INFO - Unsupported op type Pad
2023-01-06 16:15:26,393 - Quant - INFO - Unsupported op type Mean
2023-01-06 16:15:26,397 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 16:15:26,399 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-06 16:15:26,400 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 16:15:26,400 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-06 16:15:26,401 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 16:15:26,402 - Utils - INFO - ...... subset to store [Add_6, Relu_7]
2023-01-06 16:15:26,403 - Utils - INFO - ...... subset to store [Conv_8, Relu_9]
2023-01-06 16:15:26,404 - Utils - INFO - ...... subs

In [14]:
# Compute encodings for the bias-corrected model; pass_calibration_data is the
# forward-pass callback and `device` is forwarded to it as its argument.
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=device)

64
128
192
256
320
384
448
512
576
640
704
768
832
896
960
1024


In [15]:
# Test-set accuracy of the quantization-sim model after bias correction.
accuracy(sim.model, test_loader, device=device)

100%|██████████████████████████████████████████████████████████████████████| 157/157 [01:51<00:00,  1.41it/s]

Acc: tensor(0.2384)



