In [1]:
import ipdb
import os, sys
import torch
import torch.nn.functional as F
from Examples.common import image_net_config
from Examples.torch.utils.image_net_evaluator import ImageNetEvaluator
from Examples.torch.utils.image_net_trainer import ImageNetTrainer
from Examples.torch.utils.image_net_data_loader import ImageNetDataLoader
from torchvision.models import resnet50
from aimet_torch.model_preparer import prepare_model
from aimet_torch.batch_norm_fold import fold_all_batch_norms
from aimet_common.defs import QuantScheme
from aimet_torch.quantsim import QuantizationSimModel
from aimet_torch.qc_quantize_op import StaticGridQuantWrapper
from functools import partial
from tqdm import tqdm
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"

  param_schemas = callee.param_schemas()
  param_schemas = callee.param_schemas()


2024-11-25 00:40:51,293 - root - INFO - AIMET




In [2]:
TEST_NUM = 100
DATASET_DIR   = '/data/dataset/ImageNet_small'
Calibrate_DIR = '/data/dataset/ImageNet_small'

In [3]:
def hook(name,module, input, output):
    if module not in cached_input_output:
        cached_input_output[module] = []
    # Meanwhile store data in the RAM.
    cached_input_output[module].append((input[0].detach().cpu(), output.detach().cpu()))

In [4]:
class ImageNetDataPipeline:

    @staticmethod
    def get_val_dataloader() -> torch.utils.data.DataLoader:
        """
        Instantiates a validation dataloader for ImageNet dataset and returns it
        """
        data_loader = ImageNetDataLoader(Calibrate_DIR,
                                         image_size=image_net_config.dataset['image_size'],
                                         batch_size=image_net_config.evaluation['batch_size'],
                                         is_training=False,
                                         num_workers=image_net_config.evaluation['num_workers']).data_loader
        return data_loader

    @staticmethod
    def evaluate(model: torch.nn.Module, use_cuda: bool) -> float:
        """
        Given a torch model, evaluates its Top-1 accuracy on the dataset
        :param model: the model to evaluate
        :param use_cuda: whether or not the GPU should be used.
        """
        evaluator = ImageNetEvaluator(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                      batch_size=image_net_config.evaluation['batch_size'],
                                      num_workers=image_net_config.evaluation['num_workers'])

        return evaluator.evaluate(model, iterations=None, use_cuda=use_cuda)

    @staticmethod
    def finetune(model: torch.nn.Module, epochs, learning_rate, learning_rate_schedule, use_cuda):
        """
        Given a torch model, finetunes the model to improve its accuracy
        :param model: the model to finetune
        :param epochs: The number of epochs used during the finetuning step.
        :param learning_rate: The learning rate used during the finetuning step.
        :param learning_rate_schedule: The learning rate schedule used during the finetuning step.
        :param use_cuda: whether or not the GPU should be used.
        """
        trainer = ImageNetTrainer(DATASET_DIR, image_size=image_net_config.dataset['image_size'],
                                  batch_size=image_net_config.train['batch_size'],
                                  num_workers=image_net_config.train['num_workers'])

        trainer.train(model, max_epochs=epochs, learning_rate=learning_rate,
                      learning_rate_schedule=learning_rate_schedule, use_cuda=use_cuda)
def pass_calibration_data(sim_model, use_cuda):
    data_loader = ImageNetDataPipeline.get_val_dataloader()
    batch_size = data_loader.batch_size

    if use_cuda:
        device = torch.device('cuda')
    else:
        device = torch.device('cpu')

    sim_model.eval()
    samples = 1000

    batch_cntr = 0
    with torch.no_grad():
        for input_data, target_data in data_loader:

            inputs_batch = input_data.to(device)
            sim_model(inputs_batch)

            batch_cntr += 1
            if (batch_cntr * batch_size) > samples:
                break

In [None]:
model = resnet50(pretrained=True)
model = prepare_model(model)
use_cuda = False
if torch.cuda.is_available():
    use_cuda = True
    model.to(torch.device('cuda'))
_ = fold_all_batch_norms(model, input_shapes=(1, 3, 224, 224))
dummy_input = torch.rand(1, 3, 224, 224)    # Shape for each ImageNet sample is (3 channels) x (224 height) x (224 width)
if use_cuda:    dummy_input = dummy_input.cuda()

In [None]:
# quntization 실행
sim = QuantizationSimModel(model=model,
                           quant_scheme=QuantScheme.post_training_tf_enhanced,
                           dummy_input=dummy_input,
                           default_output_bw=8,
                           default_param_bw=8)
sim.compute_encodings(forward_pass_callback=pass_calibration_data, forward_pass_callback_args=use_cuda)
os.makedirs('./output/', exist_ok=True)
dummy_input = dummy_input.cpu()
sim.export(path='./output/', filename_prefix='resnet50_after_qat', dummy_input=dummy_input)

In [10]:
module_name=[]
for name,m in sim.model.named_modules():    
    if name =='': continue       
    if not isinstance(m,StaticGridQuantWrapper): continue
    if not (isinstance(m._module_to_wrap, torch.nn.Linear) or isinstance(m._module_to_wrap, torch.nn.Conv2d)): continue
    module_name.append(name)
module_name_fx = [x.replace('.','_') for x in module_name]

In [30]:
mm = sim.model
m_inout= dict()
for node in mm.graph.nodes:    
    if node.name not in module_name_fx: continue
    prev_node=[]
    for arg in node.args:
        if not isinstance(arg, torch.fx.Node): continue
        prev_node.append((arg.name).replace('_','.'))
    next_node=[]
    for user in node.users:
        next_node.append((user.name).replace('_','.'))
    m_inout[(node.name).replace('_','.')] = {'prev':prev_node,'next':next_node}

In [31]:
for k in m_inout.keys():
    print(k)
    print(m_inout[k]['prev'])
    print(m_inout[k]['next'])
    print('---------------')

conv1
['x']
['bn1']
---------------
layer1.0.conv1
['maxpool']
['layer1.0.bn1']
---------------
layer1.0.conv2
['layer1.0.relu']
['layer1.0.bn2']
---------------
layer1.0.conv3
['layer1.0.module.relu.1']
['layer1.0.bn3']
---------------
layer1.0.downsample.0
['maxpool']
['layer1.0.downsample.1']
---------------
layer1.1.conv1
['layer1.0.module.relu.2']
['layer1.1.bn1']
---------------
layer1.1.conv2
['layer1.1.relu']
['layer1.1.bn2']
---------------
layer1.1.conv3
['layer1.1.module.relu.1']
['layer1.1.bn3']
---------------
layer1.2.conv1
['layer1.1.module.relu.2']
['layer1.2.bn1']
---------------
layer1.2.conv2
['layer1.2.relu']
['layer1.2.bn2']
---------------
layer1.2.conv3
['layer1.2.module.relu.1']
['layer1.2.bn3']
---------------
layer2.0.conv1
['layer1.2.module.relu.2']
['layer2.0.bn1']
---------------
layer2.0.conv2
['layer2.0.relu']
['layer2.0.bn2']
---------------
layer2.0.conv3
['layer2.0.module.relu.1']
['layer2.0.bn3']
---------------
layer2.0.downsample.0
['layer1.2.module

In [27]:
sim.model.avgpool

StaticGridQuantWrapper(
  (_module_to_wrap): AdaptiveAvgPool2d(output_size=(1, 1))
)

In [46]:
ll

[x,
 conv1,
 bn1,
 relu,
 maxpool,
 layer1_0_conv1,
 layer1_0_bn1,
 layer1_0_relu,
 layer1_0_conv2,
 layer1_0_bn2,
 layer1_0_module_relu_1,
 layer1_0_conv3,
 layer1_0_bn3,
 layer1_0_downsample_0,
 layer1_0_downsample_1,
 layer1_0_module_add,
 layer1_0_module_relu_2,
 layer1_1_conv1,
 layer1_1_bn1,
 layer1_1_relu,
 layer1_1_conv2,
 layer1_1_bn2,
 layer1_1_module_relu_1,
 layer1_1_conv3,
 layer1_1_bn3,
 layer1_1_module_add_1,
 layer1_1_module_relu_2,
 layer1_2_conv1,
 layer1_2_bn1,
 layer1_2_relu,
 layer1_2_conv2,
 layer1_2_bn2,
 layer1_2_module_relu_1,
 layer1_2_conv3,
 layer1_2_bn3,
 layer1_2_module_add_2,
 layer1_2_module_relu_2,
 layer2_0_conv1,
 layer2_0_bn1,
 layer2_0_relu,
 layer2_0_conv2,
 layer2_0_bn2,
 layer2_0_module_relu_1,
 layer2_0_conv3,
 layer2_0_bn3,
 layer2_0_downsample_0,
 layer2_0_downsample_1,
 layer2_0_module_add_3,
 layer2_0_module_relu_2,
 layer2_1_conv1,
 layer2_1_bn1,
 layer2_1_relu,
 layer2_1_conv2,
 layer2_1_bn2,
 layer2_1_module_relu_1,
 layer2_1_conv3,
 laye

In [42]:
ll

[x,
 conv1,
 bn1,
 relu,
 maxpool,
 layer1_0_conv1,
 layer1_0_bn1,
 layer1_0_relu,
 layer1_0_conv2,
 layer1_0_bn2,
 layer1_0_module_relu_1,
 layer1_0_conv3,
 layer1_0_bn3,
 layer1_0_downsample_0,
 layer1_0_downsample_1,
 layer1_0_module_add,
 layer1_0_module_relu_2,
 layer1_1_conv1,
 layer1_1_bn1,
 layer1_1_relu,
 layer1_1_conv2,
 layer1_1_bn2,
 layer1_1_module_relu_1,
 layer1_1_conv3,
 layer1_1_bn3,
 layer1_1_module_add_1,
 layer1_1_module_relu_2,
 layer1_2_conv1,
 layer1_2_bn1,
 layer1_2_relu,
 layer1_2_conv2,
 layer1_2_bn2,
 layer1_2_module_relu_1,
 layer1_2_conv3,
 layer1_2_bn3,
 layer1_2_module_add_2,
 layer1_2_module_relu_2,
 layer2_0_conv1,
 layer2_0_bn1,
 layer2_0_relu,
 layer2_0_conv2,
 layer2_0_bn2,
 layer2_0_module_relu_1,
 layer2_0_conv3,
 layer2_0_bn3,
 layer2_0_downsample_0,
 layer2_0_downsample_1,
 layer2_0_module_add_3,
 layer2_0_module_relu_2,
 layer2_1_conv1,
 layer2_1_bn1,
 layer2_1_relu,
 layer2_1_conv2,
 layer2_1_bn2,
 layer2_1_module_relu_1,
 layer2_1_conv3,
 laye