In [1]:
from FaceLandmarkDetection.src.quantization.helper import *

In [2]:
def calibrate(model, data_loader):
    """Run every batch of ``data_loader`` through ``model`` once, discarding outputs.

    The model is switched to eval mode and the forward passes run with
    gradient tracking disabled. This function is passed to ``quantize_jit``
    further down as the calibration callback.

    Args:
        model: module to run; must expose ``eval()`` and be callable.
        data_loader: iterable yielding two-element items ``(image, target)``;
            only the first element is fed to the model.
    """
    model.eval()
    with torch.no_grad():
        for batch in data_loader:
            # Each item must unpack into exactly two parts; the target is unused.
            inputs, _unused_target = batch
            model(inputs)
            
import os
def print_size_of_model(model):
    """Print the on-disk size, in MB, of ``model`` serialized with ``torch.jit.save``.

    A model that is not already a ``torch.jit.RecursiveScriptModule`` is
    scripted with ``torch.jit.script`` before being saved.

    The archive is written into a private temporary directory that is removed
    when the function exits, even if saving or measuring raises. The previous
    implementation wrote ``temp.p`` into the current working directory, which
    overwrote and then deleted any existing file of that name and left the
    file behind on error.

    Args:
        model: a scripted module, or a module that ``torch.jit.script`` accepts.

    Returns:
        None. The size is printed as ``Size (MB): <value>``.
    """
    # Local import keeps this cell self-contained; `os` is imported just above.
    import tempfile

    if isinstance(model, torch.jit.RecursiveScriptModule):
        scripted = model
    else:
        scripted = torch.jit.script(model)

    with tempfile.TemporaryDirectory() as scratch_dir:
        # Keep the original base name so the saved archive is produced the same way.
        archive_path = os.path.join(scratch_dir, "temp.p")
        torch.jit.save(scripted, archive_path)
        print('Size (MB):', os.path.getsize(archive_path)/1e6)

In [3]:
# Size of the model's output vector.
# NOTE(review): presumably 68 landmarks x 2 coordinates — confirm.
num_classes = 136

# Location of the pretrained checkpoint, relative to the working directory.
model_dir = "checkpoints"
model_filename = "resnet18_FLM.pt"
model_filepath = os.path.join(model_dir, model_filename)

# Device handles; only cpu_device is used by the visible cells of this notebook.
cpu_device = torch.device("cpu:0")
cuda_device = torch.device("cuda:0")

In [4]:
# Create an untrained model. Use the num_classes constant from the config cell
# rather than repeating the literal 136, so the two values cannot drift apart.
model = create_model(num_classes=num_classes)
# Load a pretrained model from model_filepath onto the CPU device.
float_model = load_model(model=model, model_filepath=model_filepath, device=cpu_device)

# Datasets and loaders. The second make_loader argument is presumably the
# batch size (64 for training, 32 for validation) — confirm against the helper.
# train_loader is not referenced again in the visible cells of this notebook.
train_dataset, val_dataset = get_data()
train_loader = make_loader(train_dataset, 64)
val_loader = make_loader(val_dataset, 32)

1111


In [5]:
# Put the float model in eval mode, then compile it to TorchScript
# (alternative: torch.jit.trace(float_model, input)).
float_model.eval()
ts_model = torch.jit.script(float_model)

STATIC QUANTIZATION

In [6]:
import torch
from torch.quantization import get_default_qconfig, quantize_jit
# Default quantization config for the 'fbgemm' backend; the bare name on the
# last line echoes it as the cell output.
qconfig = get_default_qconfig(backend='fbgemm')
qconfig

QConfig(activation=functools.partial(<class 'torch.quantization.observer.HistogramObserver'>, reduce_range=True), weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))

In [7]:
# Static quantization of the TorchScript model. quantize_jit is given the
# calibration function and the positional arguments to call it with; here
# calibration runs over the validation loader.
quantized_model = quantize_jit(
    model=ts_model,                # TorchScript model
    qconfig_dict={'': qconfig},    # qconfig dict
    run_fn=calibrate,              # calibration function
    run_args=[val_loader],         # positional arguments to calibration function, typically some sample dataset
)

In [8]:
# Compare the serialized size of the float TorchScript model with that of the
# model currently bound to quantized_model.
print("Size of model before quantization")
print_size_of_model(ts_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)

Size of model before quantization
Size (MB): 45.075501
Size of model after quantization
Size (MB): 11.355301


In [9]:
# Scores `model` (the object built by create_model and passed to load_model
# earlier) on the validation loader with MSE loss; the returned value is shown
# as the cell output.
# NOTE(review): this evaluates `model`, not `quantized_model`. If the quantized
# model was meant to be scored here, pass quantized_model instead — confirm intent.
evaluate_model(model=model, test_loader=val_loader, device=cpu_device, criterion=nn.MSELoss())

0.04602653190896318

In [8]:
# Time the float TorchScript model and the quantized model on CPU with the same
# input_size and num_samples, then print both results scaled by 1000.
# NOTE(review): assumes measure_inference_latency returns seconds — confirm.
# NOTE(review): input_size starts with 32; if that is a batch dimension, check
# whether the value is per forward pass or per sample before trusting the label.
fp32_cpu_inference_latency = measure_inference_latency(model=ts_model, device=cpu_device, input_size=(32,1,224,224), num_samples=100)
int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(32,1,224,224), num_samples=100)
print("FP32 CPU Inference Latency: {:.2f} ms / sample".format(fp32_cpu_inference_latency * 1000))
print("INT8 CPU Inference Latency: {:.2f} ms /sample".format(int8_cpu_inference_latency * 1000))

FP32 CPU Inference Latency: 1613.90 ms / sample
INT8 CPU Inference Latency: 274.68 ms /sample


DYNAMIC QUANTIZATION

In [9]:
import torch
from torch.quantization import per_channel_dynamic_qconfig
from torch.quantization import quantize_dynamic_jit

In [10]:
# Map the empty-string key to per_channel_dynamic_qconfig; the bare name on the
# last line echoes the dict as the cell output.
qconfig_dict = {'': per_channel_dynamic_qconfig}
qconfig_dict

{'': QConfigDynamic(activation=<class 'torch.quantization.observer.MinMaxDynamicQuantObserver'>, weight=functools.partial(<class 'torch.quantization.observer.PerChannelMinMaxObserver'>, dtype=torch.qint8, qscheme=torch.per_channel_symmetric))}

In [11]:
# Dynamic quantization of the same TorchScript model.
# NOTE: this rebinds quantized_model, so the result of the earlier quantize_jit
# call is no longer reachable under that name.
quantized_model = quantize_dynamic_jit(
    model=ts_model,
    qconfig_dict=qconfig_dict,
)

In [12]:
# Same size comparison as in the static section; quantized_model now refers to
# the output of quantize_dynamic_jit.
print("Size of model before quantization")
print_size_of_model(ts_model)
print("Size of model after quantization")
print_size_of_model(quantized_model)

Size of model before quantization
Size (MB): 45.075181
Size of model after quantization
Size (MB): 44.750093


In [13]:
# MSE loss of the model currently bound to quantized_model on the validation loader (CPU).
int8_eval_loss = evaluate_model(model=quantized_model, test_loader=val_loader, device=cpu_device, criterion=nn.MSELoss())
# Skip this assertion since the values might deviate a lot.
# NOTE(review): the disabled check below refers to `quantized_jit_model`, a name
# that is not defined in the visible cells, and uses input_size=(1,3,32,32)
# while the latency cells use (32,1,224,224). Update both before re-enabling it.
# assert model_equivalence(model_1=model, model_2=quantized_jit_model, device=cpu_device, rtol=1e-01, atol=1e-02, num_tests=100, input_size=(1,3,32,32)), "Quantized model deviates from the original model too much!"

print("INT8 evaluation loss: {:.3f}".format(int8_eval_loss))

INT8 evaluation loss: 0.048


In [None]:
# Time the dynamically quantized model on CPU with the same input_size and
# num_samples as the earlier latency cell, and print the result scaled by 1000.
# NOTE(review): assumes measure_inference_latency returns seconds — confirm.
int8_cpu_inference_latency = measure_inference_latency(model=quantized_model, device=cpu_device, input_size=(32,1,224,224), num_samples=100)
print("INT8 CPU Inference Latency: {:.2f} ms /sample".format(int8_cpu_inference_latency * 1000))

NO SUPPORT FOR QAT IN GRAPH MODE

INT8 CPU Inference Latency: 1666.39 ms /sample


NOTE: Quantization-aware training (QAT) is not supported in graph-mode (TorchScript) quantization.