In [5]:
import torch
from torchdistill.models.classification import resnet
from torchdistill.datasets.util import build_data_loader
from torchdistill.common import file_util, yaml_util, module_util
import os
from torch.nn import DataParallel
import argparse
from torch.nn.parallel import DistributedDataParallel
from torchdistill.misc.log import set_basic_log_config, setup_log_file, SmoothedValue, MetricLogger
from examples.torchvision.change_targets import transform_targets
from torchdistill.common.constant import def_logger
import time
import datetime

logger = def_logger.getChild(__name__)

def get_argparser():
    parser = argparse.ArgumentParser(description='Knowledge distillation for image classification models')
    parser.add_argument('--config', required=True, help='yaml file path')
    parser.add_argument('-test_only', action='store_true', help='only test the models')
    parser.add_argument('-fuse', action='store_true', help='fuse the layers')
    parser.add_argument('-cuda', action='store_true', help='run on cuda')

    return parser

def compute_accuracy(outputs, targets, topk=(1,)):
    """Computes the accuracy over the k top predictions for the specified values of k"""
    with torch.no_grad():
        maxk = max(topk)
        batch_size = targets.size(0)
        _, preds = outputs.topk(maxk, 1, True, True)
        preds = preds.t()
        corrects = preds.eq(targets[None])
        result_list = []
        for k in topk:
            correct_k = corrects[:k].flatten().sum(dtype=torch.float32)
            result_list.append(correct_k * (100.0 / batch_size))
        return result_list


@torch.inference_mode()
def evaluate(model, data_loader, device, device_ids, distributed, log_freq=1000, title=None, header='Test:',target_classes=None):
    model.to(device)
    if distributed:
        model = DistributedDataParallel(model, device_ids=device_ids)
    elif device.type.startswith('cuda'):
        model = DataParallel(model, device_ids=device_ids)

    if title is not None:
        logger.info(title)

    model.eval()
    metric_logger = MetricLogger(delimiter='  ')
    for image, target in metric_logger.log_every(data_loader, log_freq, header):
        image = image.to(device, non_blocking=True)
        if target_classes != None:
            target = transform_targets(target)
        target = target.to(device, non_blocking=True)
        output = model(image)
        acc1, acc5 = compute_accuracy(output, target, topk=(1, 5))
        # FIXME need to take into account that the datasets
        # could have been padded in distributed setup
        batch_size = image.shape[0]
        metric_logger.meters['acc1'].update(acc1.item(), n=batch_size)
        metric_logger.meters['acc5'].update(acc5.item(), n=batch_size)

    # gather the stats from all processes
    metric_logger.synchronize_between_processes()
    top1_accuracy = metric_logger.acc1.global_avg
    top5_accuracy = metric_logger.acc5.global_avg
    logger.info(' * Acc@1 {:.4f}\tAcc@5 {:.4f}\n'.format(top1_accuracy, top5_accuracy))
    return metric_logger.acc1.global_avg

def load_model(model_config,device):
    model = resnet.resnet(20,10,False,False)
    state_dict = model_config["src_ckpt"]
    model.load_state_dict(torch.load(state_dict,device,weights_only=False)["model"])
    return model

In [None]:
config = "configs/sample/cifar10/quantization/cifar10_resnet20_quantize.yaml"

config = yaml_util.load_yaml_file(os.path.expanduser(config))
# Load your pre-trained model

model_config = config["model"]
model = load_model(model_config).eval()
# Use 'qnnpack' for ARM CPUs: torch.quantization.get_default_qconfig('qnnpack')

calibration_data_loader_config = config["test"]["test_data_loader"]

dataset_dict = config['datasets']
calibration_data_loader = build_data_loader(dataset_dict[calibration_data_loader_config['dataset_id']],calibration_data_loader_config,False)

In [15]:
#QUANTIZING    
quantize = True
fuse = False
if quantize:
    start_time = time.time()
    logger.info("Quantizing model...")
    if fuse:
        print("Fusing model")
        model_fused = torch.quantization.fuse_modules(model, [['conv1', 'bn1', 'relu']])
    else:
        print('Model will not be fused')
        model_fused = model
    # Set up quantization configuration
    torch.backends.quantized.engine = "fbgemm"
    model_fused.qconfig = torch.quantization.get_default_qconfig('fbgemm')  # for x86 CPUs

    # Prepare model for quantization
    model_prepared = torch.quantization.prepare(model_fused)

    # Calibrate with sample data
    with torch.no_grad():
        for data, _ in calibration_data_loader:
            model_prepared(data)

    quantized_model = torch.quantization.convert(model_prepared)
    logger.info(f"Model is quantized")
    total_time = time.time() - start_time
    total_time_str = str(datetime.timedelta(seconds=int(total_time)))
    logger.info('Quantization time {}'.format(total_time_str))
 

Model will not be fused




In [None]:
quantized_model.eval()
val_data_loader_config = config['quantize']["val_data_loader"]
val_data_loader = build_data_loader(dataset_dict[val_data_loader_config['dataset_id']],val_data_loader_config,False)

cuda = False
if cuda:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else: 
    device = torch.device("cpu")
val_top1_accuracy = evaluate(quantized_model, val_data_loader, device, None, False,log_freq=config["quantize"]["log_freq"],header="Validation")

NotImplementedError: Could not run 'quantized::conv2d.new' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::conv2d.new' is only available for these backends: [Meta, QuantizedCPU, QuantizedCUDA, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMTIA, AutogradMeta, Tracer, AutocastCPU, AutocastXPU, AutocastMPS, AutocastCUDA, FuncTorchBatched, BatchedNestedTensor, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

Meta: registered at /pytorch/aten/src/ATen/core/MetaFallbackKernel.cpp:23 [backend fallback]
QuantizedCPU: registered at /pytorch/aten/src/ATen/native/quantized/cpu/qconv.cpp:2045 [kernel]
QuantizedCUDA: registered at /pytorch/aten/src/ATen/native/quantized/cudnn/Conv.cpp:391 [kernel]
BackendSelect: fallthrough registered at /pytorch/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:194 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:503 [backend fallback]
Functionalize: registered at /pytorch/aten/src/ATen/FunctionalizeFallbackKernel.cpp:349 [backend fallback]
Named: registered at /pytorch/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /pytorch/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /pytorch/aten/src/ATen/native/NegateFallback.cpp:18 [backend fallback]
ZeroTensor: registered at /pytorch/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:100 [backend fallback]
AutogradOther: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:63 [backend fallback]
AutogradCPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:67 [backend fallback]
AutogradCUDA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:75 [backend fallback]
AutogradXLA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:83 [backend fallback]
AutogradMPS: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:91 [backend fallback]
AutogradXPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:71 [backend fallback]
AutogradHPU: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:104 [backend fallback]
AutogradLazy: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:87 [backend fallback]
AutogradMTIA: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:79 [backend fallback]
AutogradMeta: registered at /pytorch/aten/src/ATen/core/VariableFallbackKernel.cpp:95 [backend fallback]
Tracer: registered at /pytorch/torch/csrc/autograd/TraceTypeManual.cpp:294 [backend fallback]
AutocastCPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:322 [backend fallback]
AutocastXPU: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:465 [backend fallback]
AutocastMPS: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:209 [backend fallback]
AutocastCUDA: fallthrough registered at /pytorch/aten/src/ATen/autocast_mode.cpp:165 [backend fallback]
FuncTorchBatched: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:731 [backend fallback]
BatchedNestedTensor: registered at /pytorch/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:758 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /pytorch/aten/src/ATen/functorch/VmapModeRegistrations.cpp:27 [backend fallback]
Batched: registered at /pytorch/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /pytorch/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /pytorch/aten/src/ATen/functorch/TensorWrapper.cpp:207 [backend fallback]
PythonTLSSnapshot: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:202 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /pytorch/aten/src/ATen/functorch/DynamicLayer.cpp:499 [backend fallback]
PreDispatch: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:206 [backend fallback]
PythonDispatcher: registered at /pytorch/aten/src/ATen/core/PythonFallbackKernel.cpp:198 [backend fallback]


In [6]:
import torch

config = "configs/sample/cifar10/quantization/cifar10_resnet20_quantize.yaml"

config = yaml_util.load_yaml_file(os.path.expanduser(config))
# Load your pre-trained model

model_config = config["model"]
model = load_model(model_config,device=torch.device('cpu')).eval()
# Use 'qnnpack' for ARM CPUs: torch.quantization.get_default_qconfig('qnnpack')

calibration_data_loader_config = config["test"]["test_data_loader"]

dataset_dict = config['datasets']
calibration_data_loader = build_data_loader(dataset_dict[calibration_data_loader_config['dataset_id']],calibration_data_loader_config,False)

In [7]:
quantized_model = torch.quantization.quantize_dynamic(
    model, 
    {torch.nn.Linear, torch.nn.LSTM}, 
    dtype=torch.qint8
)

In [10]:
quantized_model.eval()
val_data_loader_config = config['quantize']["val_data_loader"]
val_data_loader = build_data_loader(dataset_dict[val_data_loader_config['dataset_id']],val_data_loader_config,False)

cuda = False
if cuda:
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
else: 
    device = torch.device("cpu")
val_top1_accuracy = evaluate(quantized_model, val_data_loader, device, None, False,log_freq=config["quantize"]["log_freq"],header="Validation")

In [None]:

print(val_top1_accuracy)
torch.save(quantized_model.state_dict(), 'dynamic_quantized_model_state_dict.pth')
torch.save(quantized_model, 'dynamic_quantized_model.pth')

92.7


In [None]:

loaded_quantized_model = torch.load('dynamic_quantized_model.pth',weights_only=False)

In [None]:

loaded_val_top1_accuracy = evaluate(loaded_quantized_model, val_data_loader, device, None, False,log_freq=config["quantize"]["log_freq"],header="Validation")
print(loaded_val_top1_accuracy)

92.7


In [18]:
loadedw_q_model = load_model(model_config,device=torch.device('cpu')).eval()
loadedw_q_model = torch.load("dynamic_quantized_model_state_dict.pth")

In [None]:

loadedw_val_top1_accuracy = evaluate(loadedw_q_model, val_data_loader, device, None, False,log_freq=config["quantize"]["log_freq"],header="Validation")
print(loadedw_val_top1_accuracy)