<a href="https://colab.research.google.com/github/Tuan-Lee-23/Vietnamese-News-Generative-Model/blob/main/Dynamic_Quantization_on_GPT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# (experimental) Dynamic Quantization on BERT

**Author**: [Jianyu Huang](https://github.com/jianyuh)

**Reviewed by**: [Raghuraman Krishnamoorthi](https://github.com/raghuramank100)

**Edited by**: [Jessica Lin](https://)

# 3. Apply the dynamic quantization

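We call `torch.quantization.quantize_dynamic` on the model to apply dynamic quantization to the HuggingFace GPT-2 model (`tuanle/VN-News-GPT2`) loaded below. HuggingFace's GPT-2 implements its attention and MLP projections with `Conv1D` rather than `torch.nn.Linear`, and dynamic quantization does not recognize `Conv1D`, so these modules are first converted to `torch.nn.Linear`. Then:

- We specify that we want the `torch.nn.Linear` modules in our model to be quantized;
- We specify that we want weights to be converted to quantized int8 values.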
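To see what these two settings do in isolation, here is a minimal sketch on a small, hypothetical `nn.Sequential` model (a stand-in, not the GPT-2 model used in this notebook). Only the `torch.nn.Linear` layers are replaced by dynamically quantized versions with int8 weights; the `ReLU` is left untouched, and the original float model is not modified.

```python
import torch
import torch.nn as nn

# A small stand-in model (hypothetical; not the GPT-2 model from this notebook).
torch.manual_seed(0)
float_model = nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()

# Quantize only nn.Linear modules, storing their weights as int8.
# quantize_dynamic returns a new model by default; float_model is left unchanged.
quantized = torch.quantization.quantize_dynamic(float_model, {nn.Linear}, dtype=torch.qint8)

# The nn.Linear layers are now dynamically quantized Linear modules.
print(quantized)

x = torch.randn(8, 16)
with torch.no_grad():
    max_abs_diff = (float_model(x) - quantized(x)).abs().max().item()
print(f"max |fp32 - int8| output difference: {max_abs_diff:.4f}")
```

The printed difference gives a feel for the error introduced by int8 weights; activations are quantized on the fly at inference time, which is what makes the scheme "dynamic".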

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM

tokenizer = AutoTokenizer.from_pretrained("tuanle/VN-News-GPT2")

model = AutoModelForCausalLM.from_pretrained("tuanle/VN-News-GPT2")
model.eval()



In [None]:
!pip install -q onnx

In [None]:
import os

import torch

# Conv1D moved to transformers.pytorch_utils in newer releases;
# older releases expose it from transformers.modeling_utils.
try:
    from transformers.pytorch_utils import Conv1D
except ImportError:
    from transformers.modeling_utils import Conv1D


def _conv1d_to_linear(module):
    """Return an nn.Linear that computes the same function as a HuggingFace Conv1D."""
    # Conv1D stores its weight as (in_features, out_features); nn.Linear expects
    # (out_features, in_features), hence the transpose.
    in_size, out_size = module.weight.shape
    linear = torch.nn.Linear(in_size, out_size)
    linear.weight.data = module.weight.data.T.contiguous()
    linear.bias.data = module.bias.data
    return linear


def conv1d_to_linear(model):
    """Recursively replace every Conv1D in `model` with nn.Linear, in place.

    Dynamic quantization does not recognize HuggingFace's Conv1D, only nn.Linear.
    """
    for name in list(model._modules):
        module = model._modules[name]
        if isinstance(module, Conv1D):
            model._modules[name] = _conv1d_to_linear(module)
        else:
            conv1d_to_linear(module)


def get_model_size(model, temp_dir="/tmp"):
    """Return the size in bytes of the model's saved state_dict."""
    tmp_path = os.path.join(temp_dir, "temp_state_dict.pt")
    torch.save(model.state_dict(), tmp_path)
    size = os.path.getsize(tmp_path)
    os.remove(tmp_path)
    return size


def quantize_torch_model(model, dtype=torch.qint8):
    """Usage: quantized_model = quantize_torch_model(model)

    Note: the Conv1D -> nn.Linear conversion modifies `model` in place, while
    quantize_dynamic returns a separate quantized copy (`model` stays FP32).
    """
    conv1d_to_linear(model)
    return torch.quantization.quantize_dynamic(model, {torch.nn.Linear}, dtype=dtype)


quantized_model = quantize_torch_model(model)

print("=" * 75)
print("Model Sizes")
print("=" * 75)

model_size = get_model_size(model=model)
quantized_model_size = get_model_size(model=quantized_model)

# Sizes are divided by 2**20, i.e. reported in MiB (labeled "MB").
print("FP32 Model Size: {:.2f} MB".format(model_size / (2 ** 20)))
print("INT8 Model Size: {:.2f} MB".format(quantized_model_size / (2 ** 20)))

Model Sizes
FP32 Model Size: 486.77 MB
INT8 Model Size: 280.62 MB
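The conversion relies on the fact that HuggingFace's `Conv1D` computes `x @ W + b` with `W` stored as `(in_features, out_features)`, whereas `nn.Linear` computes `x @ W.T + b` with `W` stored as `(out_features, in_features)`; that is why `_conv1d_to_linear` transposes the weight. The sketch below checks this numerically. `Conv1DLike` is a hypothetical pure-PyTorch stand-in written to mirror the Conv1D forward pass (so `transformers` is not needed); it is not the library class.

```python
import torch
import torch.nn as nn


class Conv1DLike(nn.Module):
    """Hypothetical stand-in mirroring the forward pass of HuggingFace's Conv1D."""

    def __init__(self, nf, nx):
        super().__init__()
        self.nf = nf
        # Weight stored as (in_features, out_features), like Conv1D.
        self.weight = nn.Parameter(torch.randn(nx, nf) * 0.02)
        self.bias = nn.Parameter(torch.zeros(nf))

    def forward(self, x):
        size_out = x.size()[:-1] + (self.nf,)
        x = torch.addmm(self.bias, x.view(-1, x.size(-1)), self.weight)
        return x.view(size_out)


def conv1d_like_to_linear(module):
    # Same logic as _conv1d_to_linear above: transpose into nn.Linear's weight layout.
    in_size, out_size = module.weight.shape
    linear = nn.Linear(in_size, out_size)
    linear.weight.data = module.weight.data.T.contiguous()
    linear.bias.data = module.bias.data
    return linear


torch.manual_seed(0)
conv = Conv1DLike(nf=12, nx=8)
with torch.no_grad():
    conv.bias.uniform_(-1, 1)  # non-zero bias so it is checked too
linear = conv1d_like_to_linear(conv)

x = torch.randn(2, 5, 8)  # (batch, sequence, features)
with torch.no_grad():
    same = torch.allclose(conv(x), linear(x), atol=1e-6)
print(same)
```

If the weight were copied without the transpose, the shapes would not even line up for a non-square layer, which is why the check uses different input and output sizes.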


The BERT model used in the original version of this tutorial (bert-base-uncased) has a vocabulary size V of 30522. With an embedding size of 768, the word embedding table takes about 4 (bytes/FP32) * 30522 * 768 ≈ 90 MB. Because only the `torch.nn.Linear` modules are quantized, this table stays in FP32, so quantization reduces the non-embedding part of the model from about 350 MB (FP32 model) to about 90 MB (INT8 model).
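The estimate above can be reproduced with a few lines of arithmetic. The constants are the bert-base-uncased figures quoted in this section, and the 438 MB / 181 MB totals are taken from the results table below; `embedding_mb = 90` is the rounded value used in the text.

```python
# Figures quoted in the text for bert-base-uncased (original BERT tutorial).
BYTES_PER_FP32 = 4
VOCAB_SIZE = 30522
EMBEDDING_DIM = 768

embedding_bytes = BYTES_PER_FP32 * VOCAB_SIZE * EMBEDDING_DIM
print(f"word embedding table: {embedding_bytes:,} bytes "
      f"= {embedding_bytes / 1e6:.1f} MB = {embedding_bytes / 2**20:.1f} MiB")

# Subtract the FP32 embedding table from the total model sizes in the results table.
embedding_mb = 90  # rounded estimate used in the text
fp32_total_mb, int8_total_mb = 438, 181
print("non-embedding part, FP32:", fp32_total_mb - embedding_mb, "MB")
print("non-embedding part, INT8:", int8_total_mb - embedding_mb, "MB")
```

This prints 93,763,584 bytes (93.8 MB, or 89.4 MiB) for the embedding table, then 348 MB and 91 MB for the non-embedding part, matching the rounded ~350 MB and ~90 MB above.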

Running the BERT model locally on a MacBook Pro, inference over all 408 examples in the MRPC dataset takes about 160 seconds without quantization and about 90 seconds with it. The results of running the quantized BERT model on a MacBook Pro are summarized as follows:

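| Precision | F1 score | Model size | Inference time (1 thread) | Inference time (4 threads) |
|-----------|----------|------------|---------------------------|----------------------------|
| FP32      | 0.9019   | 438 MB     | 160 s                     | 85 s                       |
| INT8      | 0.8953   | 181 MB     | 90 s                      | 46 s                       |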

After applying post-training dynamic quantization to the BERT model fine-tuned on the MRPC task, the F1 score is about 0.66 points lower (0.8953 vs. 0.9019). As a comparison, a [recent paper](https://arxiv.org/pdf/1910.06188.pdf) [3] (Table 1) reports 0.8788 with post-training dynamic quantization and 0.8956 with quantization-aware training. The main difference is that we support asymmetric quantization in PyTorch, while that paper supports only symmetric quantization.
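To make the symmetric/asymmetric distinction concrete: symmetric quantization fixes the zero point at 0 and derives the scale from the largest absolute value, while asymmetric (affine) quantization maps the observed [min, max] range onto the full int8 range using a data-dependent zero point. The sketch below uses plain min/max calibration, an assumption chosen for illustration (it is neither PyTorch's observer code nor the paper's exact scheme), on a non-negative tensor, where symmetric quantization leaves half of the int8 codes unused. The helper names are hypothetical.

```python
import torch


def symmetric_qparams(x, qmax=127):
    # Zero point fixed at 0; the scale covers [-max|x|, max|x|].
    return x.abs().max().item() / qmax, 0


def asymmetric_qparams(x, qmin=-128, qmax=127):
    # Map [min(x), max(x)] (extended to include 0) onto [qmin, qmax].
    lo = min(x.min().item(), 0.0)
    hi = max(x.max().item(), 0.0)
    scale = (hi - lo) / (qmax - qmin)
    zero_point = int(round(qmin - lo / scale))
    return scale, zero_point


def roundtrip_error(x, scale, zero_point):
    # Quantize to int8, dequantize back to float, and measure the mean error.
    q = torch.quantize_per_tensor(x, scale, zero_point, torch.qint8)
    return (q.dequantize() - x).abs().mean().item()


# Non-negative values (e.g. post-ReLU activations) are where the schemes differ most.
x = torch.linspace(0.0, 1.0, steps=1001)

sym_scale, sym_zp = symmetric_qparams(x)
asym_scale, asym_zp = asymmetric_qparams(x)
sym_err = roundtrip_error(x, sym_scale, sym_zp)
asym_err = roundtrip_error(x, asym_scale, asym_zp)

print(f"symmetric : scale={sym_scale:.5f}, zero_point={sym_zp}, mean abs error={sym_err:.6f}")
print(f"asymmetric: scale={asym_scale:.5f}, zero_point={asym_zp}, mean abs error={asym_err:.6f}")
```

On this input the asymmetric scheme uses a step of 1/255 instead of 1/127, so its mean round-trip error is roughly half as large; for tensors centered around zero the two schemes give nearly the same scale.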

Note that we set the number of threads to 1 for the single-thread comparison in this tutorial. We also support intra-op parallelization for these quantized INT8 operators: users can enable multi-threading with `torch.set_num_threads(N)`, where `N` is the number of intra-op parallelization threads. One prerequisite for intra-op parallelization is building PyTorch with the right [backend](https://pytorch.org/docs/stable/notes/cpu_threading_torchscript_inference.html#build-options), such as OpenMP, Native, or TBB. You can use `torch.__config__.parallel_info()` to check the parallelization settings. On the same MacBook Pro, using PyTorch with the Native backend for parallelization, evaluating the MRPC dataset takes about 46 seconds (the 4-thread column in the table above).
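A rough way to run this kind of single- versus multi-thread comparison yourself is sketched below. It uses a hypothetical feed-forward block (768 → 3072 → 768, sizes chosen to resemble a Transformer MLP) with random inputs instead of BERT or GPT-2 on MRPC, so the timings are only indicative and will differ from the table; `time_forward` is a helper defined here for illustration.

```python
import time

import torch
import torch.nn as nn


def time_forward(model, x, n_iters=10):
    """Average wall-clock seconds per forward pass (after one warm-up call)."""
    with torch.no_grad():
        model(x)  # warm-up
        start = time.perf_counter()
        for _ in range(n_iters):
            model(x)
    return (time.perf_counter() - start) / n_iters


torch.manual_seed(0)
fp32_block = nn.Sequential(nn.Linear(768, 3072), nn.GELU(), nn.Linear(3072, 768)).eval()
int8_block = torch.quantization.quantize_dynamic(fp32_block, {nn.Linear}, dtype=torch.qint8)
x = torch.randn(64, 768)

# Reports the parallel backend PyTorch was built with and the current thread settings.
print(torch.__config__.parallel_info())

original_threads = torch.get_num_threads()
timings = {}
for n_threads in (1, 4):
    torch.set_num_threads(n_threads)  # number of intra-op threads
    timings[n_threads] = (time_forward(fp32_block, x), time_forward(int8_block, x))
    fp32_s, int8_s = timings[n_threads]
    print(f"{n_threads} thread(s): FP32 {fp32_s * 1e3:.2f} ms, INT8 {int8_s * 1e3:.2f} ms")
torch.set_num_threads(original_threads)  # restore the previous setting
```

Whether INT8 is faster, and by how much, depends on the CPU, the quantization engine, and the build's parallel backend.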

## 3.3 Serialize the quantized model
We can serialize and save the quantized model for future use.

In [None]:
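import os

# `configs.output_dir` is not defined in this notebook; "./output/" is an assumed location.
quantized_output_dir = os.path.join("./output/", "quantized")
os.makedirs(quantized_output_dir, exist_ok=True)  # create if missing, then always save
quantized_model.save_pretrained(quantized_output_dir)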
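Loading the result back needs some care, because a dynamically quantized state dict holds packed int8 weights that only a model with the same quantized structure can accept. One common pattern, sketched below on a hypothetical small model rather than GPT-2 (and using `torch.save`/`torch.load` on the `state_dict` as an alternative to `save_pretrained`), is to rebuild the float model, apply the same `quantize_dynamic` call, and then load the saved state dict into it. The `weights_only` argument of `torch.load` assumes a reasonably recent PyTorch.

```python
import io

import torch
import torch.nn as nn


def build_float_model():
    # Hypothetical architecture; it must match the one that was quantized and saved.
    return nn.Sequential(nn.Linear(16, 32), nn.ReLU(), nn.Linear(32, 4)).eval()


def quantize(model):
    return torch.quantization.quantize_dynamic(model, {nn.Linear}, dtype=torch.qint8)


torch.manual_seed(0)
quantized = quantize(build_float_model())

# Serialize the quantized weights (an in-memory buffer stands in for a file).
buffer = io.BytesIO()
torch.save(quantized.state_dict(), buffer)
buffer.seek(0)

# Rebuild the same quantized structure (with different random weights),
# then overwrite them with the saved ones.
restored = quantize(build_float_model())
# weights_only=False because the state dict holds quantized packed params;
# only do this for files you trust.
restored.load_state_dict(torch.load(buffer, weights_only=False))

x = torch.randn(3, 16)
with torch.no_grad():
    same = torch.allclose(quantized(x), restored(x))
print(same)
```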

# Conclusion
In this tutorial, we demonstrated how to convert a well-known state-of-the-art NLP model like BERT (or, in this notebook, GPT-2) into a dynamically quantized model. Dynamic quantization can reduce the size of the model while having only a limited impact on accuracy.

Thanks for reading! As always, we welcome any feedback, so please create an issue [here](https://github.com/pytorch/pytorch/issues) if you have any.

# References
[1] J. Devlin, M. Chang, K. Lee and K. Toutanova, [BERT: Pre-training of Deep Bidirectional Transformers for Language Understanding](https://arxiv.org/pdf/1810.04805.pdf) (2018).

[2] [HuggingFace Transformers](https://github.com/huggingface/transformers).

[3] O. Zafrir, G. Boudoukh, P. Izsak and M. Wasserblat, [Q8BERT: Quantized 8bit BERT](https://arxiv.org/pdf/1910.06188.pdf) (2019).

