In [1]:
import sys
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
from flopco import FlopCo

In [4]:
import torch
from torchvision.models import resnet18

from tqdm import tqdm
import numpy as np
import os
import random
from functools import partial

from aimet_torch.model_preparer import prepare_model
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.adaround.adaround_weight import Adaround, AdaroundParameters

from source.data import get_imagenet_test_loader, get_imagenet_train_val_loaders
from source.eval import accuracy
# from source.admm import build_cp_layer
from source.utils import bncalibrate_model
from source.rank_map import get_rank_map


seed = 42
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
device

device(type='cuda', index=0)

In [7]:
import torchvision
torch.__version__, torchvision.__version__

('1.9.1+cu111', '0.10.1+cu111')

In [4]:
def pass_calibration_data(sim_model, use_cuda, dataloader):
    batch_size = train_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in train_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            print(batch_cntr * batch_size)
            if (batch_cntr * batch_size) >= samples:
                break

In [5]:
train_loader, val_loader = get_imagenet_train_val_loaders(data_root='/gpfs/gpfs0/k.sobolev/ILSVRC-12/',
                                       batch_size=32,
                                       num_workers=4,
                                       pin_memory=True,
                                       val_perc=0.04,
                                       shuffle=True,
                                       random_seed=seed)

In [6]:
test_loader = get_imagenet_test_loader(data_root='/gpfs/gpfs0/k.sobolev/ILSVRC-12/', 
                                       batch_size=32,
                                       num_workers=4,
                                       pin_memory=True,
                                       shuffle=False)

In [7]:
method = 'parafac_epc'
qscheme = 'tensor_affine'
bits = 8
eps = 0.0005
decomp = 'cp3-epc'
rank_map = get_rank_map(eps, decomp)
num_samples = 2048

In [8]:
model_name = f"m={method}_b={bits}_e={eps}_d={decomp}_{qscheme.split('_')[-1]}.calibrated_{num_samples}"
# model = resnet18(pretrained=True).to(device)
print(f"loading model {model_name}")
model = torch.load('../checkpoints/'+model_name)
model.eval()
model = model.to(device)

loading model m=parafac_epc_b=8_e=0.0005_d=cp3-epc_affine.calibrated_2048


In [10]:
model_stats = FlopCo(model.to(device), img_size=(1, 3, 224, 224), device=device)
orig_macs = 0
for x in model_stats.macs.values():
    orig_macs += x[0]
orig_macs

1814073344

In [11]:
for module in ['layer1', 'layer2', 'layer3', 'layer4']:
    for layer_path in [f'{module}.0.conv1', f'{module}.0.conv2', 
#                        f'{module}.0.downsample',
                       f'{module}.1.conv1', f'{module}.1.conv2']:
        # there is no layer1.0.downsample layer
        if layer_path == 'layer1.0.downsample': continue
        # layer1.0.conv1 is crusial
#         if layer_path == 'layer1.0.conv1': continue
            
        lname, lidx, ltype = layer_path.split('.')
        lidx = int(lidx)
        layer = model.__getattr__(lname)[lidx].__getattr__(ltype)
        kernel_size = layer.kernel_size
        stride = layer.stride
        padding = layer.padding
        cin = layer.in_channels
        cout = layer.out_channels
        rank = rank_map[layer_path]
        bias = layer.bias
        if bias is not None: bias = bias.detach()
        
        factor_name = os.path.join(f"../{bits}bit_{qscheme}",
                                   f"factors_{method}_seed{seed}", 
                                   f"{layer_path}_{method}_random_rank_{rank}_mode_")
        print('loading factors:', factor_name)
        A = torch.load(factor_name+'0.pt').float().to(device)
        assert A.dtype == torch.float 
        B = torch.load(factor_name+'1.pt').float().to(device)
        C = torch.load(factor_name+'2.pt').float().to(device)
    
        model.__getattr__(lname)[lidx].__setattr__(
            ltype, build_cp_layer(rank, [A,B,C], bias, cin, cout, kernel_size, padding, stride).to(device))

loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer1.0.conv1_parafac_epc_random_rank_101_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer1.0.conv2_parafac_epc_random_rank_90_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer1.1.conv1_parafac_epc_random_rank_133_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer1.1.conv2_parafac_epc_random_rank_107_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer2.0.conv1_parafac_epc_random_rank_249_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer2.0.conv2_parafac_epc_random_rank_293_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer2.1.conv1_parafac_epc_random_rank_302_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer2.1.conv2_parafac_epc_random_rank_212_mode_
loading factors: ../8bit_tensor_affine/factors_parafac_epc_seed42/layer3.0.conv1_

In [12]:
model_stats = FlopCo(model.to(device), img_size=(1, 3, 224, 224), device=device)
redc_macs = 0
for x in model_stats.macs.values():
    redc_macs += x[0]
# redc_macs / orig_macs
redc_macs / 1814073344

0.5268461466351826

In [13]:
accuracy(model, test_loader, device=device)

100%|██████████| 1562/1562 [01:40<00:00, 15.59it/s]


0.19962387964148529

In [14]:
model = bncalibrate_model(model, train_loader, num_samples=num_samples, device=device)

  0%|          | 65/38435 [00:22<3:43:02,  2.87it/s]


In [15]:
accuracy(model, test_loader, device=device)

100%|██████████| 1562/1562 [01:38<00:00, 15.88it/s]


0.6469070102432779

In [17]:
# torch.save(model, f"m={method}_b={bits}_e={eps}_d={decomp}_{qscheme.split('_')[-1]}.calibrated_{num_samples}")
torch.save(model, '../checkpoints/'+model_name)

In [20]:
model_name

'm=admm_b=8_e=0.001_d=cp3_affine.calibrated_10000'

# Quantization

In [40]:
model = prepare_model(model)

2023-01-23 18:54:54,035 - Quant - INFO - Functional         : Adding new module for node: {add} 
2023-01-23 18:54:54,036 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_0_relu_1} 
2023-01-23 18:54:54,036 - Quant - INFO - Functional         : Adding new module for node: {add_1} 
2023-01-23 18:54:54,037 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_1_relu_1} 
2023-01-23 18:54:54,037 - Quant - INFO - Functional         : Adding new module for node: {add_2} 
2023-01-23 18:54:54,038 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_0_relu_1} 
2023-01-23 18:54:54,038 - Quant - INFO - Functional         : Adding new module for node: {add_3} 
2023-01-23 18:54:54,039 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_1_relu_1} 
2023-01-23 18:54:54,039 - Quant - INFO - Functional         : Adding new module for node: {add_4} 
2023-01-23 18:54:54,040 - Quant - INFO - Reused/Duplicate   : Adding ne



In [41]:
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))

2023-01-23 18:54:58,828 - Utils - INFO - ...... subset to store [Conv_0, BatchNormalization_1]
2023-01-23 18:54:58,829 - Utils - INFO - ...... subset to store [Conv_4, BatchNormalization_5]
2023-01-23 18:54:58,829 - Utils - INFO - ...... subset to store [Conv_9, BatchNormalization_10]
2023-01-23 18:54:58,830 - Utils - INFO - ...... subset to store [Conv_15, BatchNormalization_16]
2023-01-23 18:54:58,830 - Utils - INFO - ...... subset to store [Conv_20, BatchNormalization_21]
2023-01-23 18:54:58,830 - Utils - INFO - ...... subset to store [Conv_26, BatchNormalization_27]
2023-01-23 18:54:58,831 - Utils - INFO - ...... subset to store [Conv_31, BatchNormalization_32]
2023-01-23 18:54:58,831 - Utils - INFO - ...... subset to store [Conv_39, BatchNormalization_40]
2023-01-23 18:54:58,832 - Utils - INFO - ...... subset to store [Conv_44, BatchNormalization_45]
2023-01-23 18:54:58,832 - Utils - INFO - ...... subset to store [Conv_50, BatchNormalization_51]
2023-01-23 18:54:58,832 - Utils - I

In [42]:
dummy_input = torch.rand(1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=model,
#                            quant_scheme=QuantScheme.post_training_tf,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=8)

2023-01-23 18:55:05,749 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-23 18:55:05,769 - Quant - INFO - Unsupported op type Squeeze
2023-01-23 18:55:05,769 - Quant - INFO - Unsupported op type Pad
2023-01-23 18:55:05,770 - Quant - INFO - Unsupported op type Mean
2023-01-23 18:55:05,776 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-23 18:55:05,776 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-23 18:55:05,777 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-23 18:55:05,777 - Utils - INFO - ...... subset to store [Conv_3, Relu_4]
2023-01-23 18:55:05,777 - Utils - INFO - ...... subset to store [Add_8, Relu_9]
2023-01-23 18:55:05,778 - Utils - INFO - ...... subset to store [Add_8, Relu_9]
2023-01-23 18:55:05,778 - Utils - INFO - ...... subset to store [Conv_12, Relu_13]
2023-01-23 18:55:05,779 - Utils - INFO - ...... su

In [43]:
sim.compute_encodings(forward_pass_callback=pass_calibration_data,
                      forward_pass_callback_args=True)

500
1000


In [44]:
# post_training_tf_enchanced
accuracy(sim.model, test_loader, device='cuda')

100%|██████████| 100/100 [01:43<00:00,  1.03s/it]


0.00096

# AdaRound

In [9]:
model = prepare_model(model)

2023-01-26 18:08:06,007 - Quant - INFO - Functional         : Adding new module for node: {add} 
2023-01-26 18:08:06,008 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_0_relu_1} 
2023-01-26 18:08:06,009 - Quant - INFO - Functional         : Adding new module for node: {add_1} 
2023-01-26 18:08:06,009 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer1_1_relu_1} 
2023-01-26 18:08:06,010 - Quant - INFO - Functional         : Adding new module for node: {add_2} 
2023-01-26 18:08:06,010 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_0_relu_1} 
2023-01-26 18:08:06,011 - Quant - INFO - Functional         : Adding new module for node: {add_3} 
2023-01-26 18:08:06,011 - Quant - INFO - Reused/Duplicate   : Adding new module for node: {layer2_1_relu_1} 
2023-01-26 18:08:06,012 - Quant - INFO - Functional         : Adding new module for node: {add_4} 
2023-01-26 18:08:06,012 - Quant - INFO - Reused/Duplicate   : Adding ne



In [10]:
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))

2023-01-26 18:08:07,396 - Utils - INFO - ...... subset to store [Conv_0, BatchNormalization_1]
2023-01-26 18:08:07,398 - Utils - INFO - ...... subset to store [Conv_6, BatchNormalization_7]
2023-01-26 18:08:07,398 - Utils - INFO - ...... subset to store [Conv_11, BatchNormalization_12]
2023-01-26 18:08:07,399 - Utils - INFO - ...... subset to store [Conv_17, BatchNormalization_18]
2023-01-26 18:08:07,399 - Utils - INFO - ...... subset to store [Conv_22, BatchNormalization_23]
2023-01-26 18:08:07,399 - Utils - INFO - ...... subset to store [Conv_28, BatchNormalization_29]
2023-01-26 18:08:07,400 - Utils - INFO - ...... subset to store [Conv_33, BatchNormalization_34]
2023-01-26 18:08:07,400 - Utils - INFO - ...... subset to store [Conv_41, BatchNormalization_42]
2023-01-26 18:08:07,401 - Utils - INFO - ...... subset to store [Conv_46, BatchNormalization_47]
2023-01-26 18:08:07,401 - Utils - INFO - ...... subset to store [Conv_52, BatchNormalization_53]
2023-01-26 18:08:07,401 - Utils - 

In [13]:
params = AdaroundParameters(data_loader=val_loader, 
                            num_batches=2048//val_loader.batch_size, 
                            default_num_iterations=20000)

In [14]:
dummy_input = torch.rand(1, 3, 224, 224).cuda()
# adaround_path = f"adaround_e={eps}_d={decomp}_{qscheme.split('_')[-1]}.calibrated_{num_samples}"
adaround_path = 'adaround'

In [22]:
!rm -r ./adaround/

In [23]:
os.makedirs(adaround_path, exist_ok=True)

In [24]:
bits = 4

In [25]:
%time
dummy_input = torch.rand(1, 3, 224, 224)  
dummy_input = dummy_input.cuda()

ada_model = Adaround.apply_adaround(model, dummy_input, params,
                                    path=adaround_path, 
                                    filename_prefix='adaround', 
                                    default_param_bw=bits,
                                    default_quant_scheme=QuantScheme.post_training_tf_enhanced)

CPU times: user 3 µs, sys: 0 ns, total: 3 µs
Wall time: 6.68 µs
2023-01-26 18:50:09,055 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-26 18:50:09,073 - Quant - INFO - Unsupported op type Squeeze
2023-01-26 18:50:09,074 - Quant - INFO - Unsupported op type Pad
2023-01-26 18:50:09,074 - Quant - INFO - Unsupported op type Mean
2023-01-26 18:50:09,080 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-26 18:50:09,080 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-26 18:50:09,081 - Utils - INFO - ...... subset to store [Conv_5, Relu_6]
2023-01-26 18:50:09,081 - Utils - INFO - ...... subset to store [Conv_5, Relu_6]
2023-01-26 18:50:09,081 - Utils - INFO - ...... subset to store [Add_10, Relu_11]
2023-01-26 18:50:09,082 - Utils - INFO - ...... subset to store [Add_10, Relu_11]
2023-01-26 18:50:09,082 - Utils - INFO - ...... subset to store [

                                       

2023-01-26 18:50:19,719 - Quant - INFO - Started Optimizing weight rounding of module: conv1


                                                 

2023-01-26 18:51:56,959 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv1.conv1


                                                 

2023-01-26 18:52:35,109 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv1.conv2


                                               

2023-01-26 18:53:16,824 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv1.conv3


                                               

2023-01-26 18:54:02,695 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv2.conv1


                                                

2023-01-26 18:54:42,592 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv2.conv2


                                                

2023-01-26 18:55:22,776 - Quant - INFO - Started Optimizing weight rounding of module: layer1.0.conv2.conv3


                                                

2023-01-26 18:56:05,058 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv1.conv1


                                                

2023-01-26 18:56:48,063 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv1.conv2


                                                

2023-01-26 18:57:36,523 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv1.conv3


                                                

2023-01-26 18:58:27,050 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv2.conv1


                                                

2023-01-26 18:59:06,338 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv2.conv2


                                                

2023-01-26 18:59:50,038 - Quant - INFO - Started Optimizing weight rounding of module: layer1.1.conv2.conv3


                                                

2023-01-26 19:00:34,139 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv1.conv1


                                                

2023-01-26 19:01:31,933 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv1.conv2


                                                

2023-01-26 19:02:19,214 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv1.conv3


                                                

2023-01-26 19:03:02,076 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv2.conv1


                                                

2023-01-26 19:03:43,150 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv2.conv2


                                                

2023-01-26 19:04:20,149 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.conv2.conv3


                                                

2023-01-26 19:05:04,004 - Quant - INFO - Started Optimizing weight rounding of module: layer2.0.downsample.0


                                                

2023-01-26 19:05:42,157 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv1.conv1


                                                

2023-01-26 19:06:23,538 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv1.conv2


                                                

2023-01-26 19:07:00,848 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv1.conv3


                                                

2023-01-26 19:07:45,827 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv2.conv1


                                                

2023-01-26 19:08:26,240 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv2.conv2


                                                

2023-01-26 19:09:01,432 - Quant - INFO - Started Optimizing weight rounding of module: layer2.1.conv2.conv3


                                                

2023-01-26 19:09:43,797 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv1.conv1


                                                

2023-01-26 19:10:24,991 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv1.conv2


                                                

2023-01-26 19:10:59,740 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv1.conv3


                                                

2023-01-26 19:11:37,234 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv2.conv1


                                                

2023-01-26 19:12:12,170 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv2.conv2


                                                

2023-01-26 19:12:45,600 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.conv2.conv3


                                                

2023-01-26 19:13:28,417 - Quant - INFO - Started Optimizing weight rounding of module: layer3.0.downsample.0


                                                

2023-01-26 19:14:09,223 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv1.conv1


                                                

2023-01-26 19:14:44,750 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv1.conv2


                                                

2023-01-26 19:15:18,826 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv1.conv3


                                                

2023-01-26 19:16:02,791 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv2.conv1


                                                

2023-01-26 19:16:38,454 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv2.conv2


                                                

2023-01-26 19:17:12,126 - Quant - INFO - Started Optimizing weight rounding of module: layer3.1.conv2.conv3


                                                

2023-01-26 19:17:55,460 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv1.conv1


                                                

2023-01-26 19:18:32,029 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv1.conv2


                                                

2023-01-26 19:19:06,494 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv1.conv3


                                                

2023-01-26 19:19:49,831 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv2.conv1


                                                

2023-01-26 19:20:31,178 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv2.conv2


                                                

2023-01-26 19:21:04,885 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.conv2.conv3


                                                

2023-01-26 19:21:48,607 - Quant - INFO - Started Optimizing weight rounding of module: layer4.0.downsample.0


                                                

2023-01-26 19:22:25,507 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv1.conv1


                                                

2023-01-26 19:23:06,990 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv1.conv2


                                                

2023-01-26 19:23:40,976 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv1.conv3


                                                

2023-01-26 19:24:25,492 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv2.conv1


                                                

2023-01-26 19:25:07,219 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv2.conv2


                                                

2023-01-26 19:25:40,869 - Quant - INFO - Started Optimizing weight rounding of module: layer4.1.conv2.conv3


                                                

2023-01-26 19:26:23,688 - Quant - INFO - Started Optimizing weight rounding of module: fc


100%|██████████| 100/100 [36:38<00:00, 21.99s/it]

2023-01-26 19:26:58,635 - Quant - INFO - Deleting model inputs from location: /tmp/adaround/
2023-01-26 19:26:58,776 - Quant - INFO - Completed Adarounding Model





In [26]:
dummy_input = torch.rand(1, 3, 224, 224)
dummy_input = dummy_input.cuda()

sim = QuantizationSimModel(model=ada_model,
#                            quant_scheme=QuantScheme.post_training_tf,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=bits)

2023-01-26 19:26:59,150 - Quant - INFO - No config file provided, defaulting to config file at /usr/local/lib/python3.8/dist-packages/aimet_common/quantsim_config/default_config.json
2023-01-26 19:26:59,167 - Quant - INFO - Unsupported op type Squeeze
2023-01-26 19:26:59,167 - Quant - INFO - Unsupported op type Pad
2023-01-26 19:26:59,168 - Quant - INFO - Unsupported op type Mean
2023-01-26 19:26:59,174 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-26 19:26:59,175 - Utils - INFO - ...... subset to store [Conv_0, Relu_1]
2023-01-26 19:26:59,175 - Utils - INFO - ...... subset to store [Conv_5, Relu_6]
2023-01-26 19:26:59,175 - Utils - INFO - ...... subset to store [Conv_5, Relu_6]
2023-01-26 19:26:59,176 - Utils - INFO - ...... subset to store [Add_10, Relu_11]
2023-01-26 19:26:59,176 - Utils - INFO - ...... subset to store [Add_10, Relu_11]
2023-01-26 19:26:59,177 - Utils - INFO - ...... subset to store [Conv_14, Relu_15]
2023-01-26 19:26:59,177 - Utils - INFO - .....

In [27]:
sim.set_and_freeze_param_encodings(encoding_path=os.path.join(adaround_path, 'adaround.encodings'))

2023-01-26 19:26:59,203 - Quant - INFO - Setting quantization encodings for parameter: conv1.weight
2023-01-26 19:26:59,203 - Quant - INFO - Freezing quantization encodings for parameter: conv1.weight
2023-01-26 19:26:59,204 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv1.conv1.weight
2023-01-26 19:26:59,204 - Quant - INFO - Freezing quantization encodings for parameter: layer1.0.conv1.conv1.weight
2023-01-26 19:26:59,205 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv1.conv2.weight
2023-01-26 19:26:59,205 - Quant - INFO - Freezing quantization encodings for parameter: layer1.0.conv1.conv2.weight
2023-01-26 19:26:59,205 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv1.conv3.weight
2023-01-26 19:26:59,206 - Quant - INFO - Freezing quantization encodings for parameter: layer1.0.conv1.conv3.weight
2023-01-26 19:26:59,206 - Quant - INFO - Setting quantization encodings for parameter: layer1.0.conv2.conv

In [28]:
sim.compute_encodings(forward_pass_callback=partial(pass_calibration_data, dataloader=train_loader),
                      forward_pass_callback_args=True)

32
64
96
128
160
192
224
256
288
320
352
384
416
448
480
512
544
576
608
640
672
704
736
768
800
832
864
896
928
960
992
1024


In [29]:
# tf_enhanced
accuracy(sim.model, test_loader, device=device)

100%|██████████| 1562/1562 [01:44<00:00, 15.02it/s]


0.6070342509603073