# Introduction

This tutorial demonstrates how to perform quantization-aware training (QAT) on a [DistilBERT](https://huggingface.co/distilbert-base-uncased-finetuned-sst-2-english) model and export the resulting quantized PyTorch model to an ONNX model.

## Prerequisite

### 1. Install packages

In [None]:
!pip install datasets neural-compressor transformers torch onnxruntime onnx

## Run

In [1]:
# Checkpoint identifier passed to `from_pretrained` for both the tokenizer and the model.
model_name_or_path = "distilbert-base-uncased-finetuned-sst-2-english"
# GLUE task name; used as the `load_dataset('glue', ...)` config and as a key into `task_to_keys`.
task = "sst2"

### 1. Prepare dataloader

In [3]:
import torch
import logging
from datasets import load_dataset
from transformers import AutoTokenizer
from transformers.data.data_collator import DataCollatorWithPadding

# Configure root logging once for the notebook, then grab a module-level logger.
logging.basicConfig(
    level=logging.WARN,
    format='%(asctime)s - %(levelname)s - %(name)s -   %(message)s',
    datefmt='%m/%d/%Y %H:%M:%S',
)
logger = logging.getLogger(__name__)

# GLUE task name -> (first text column, second text column or None for
# single-sentence tasks).
task_to_keys = dict(
    cola=("sentence", None),
    mnli=("premise", "hypothesis"),
    mrpc=("sentence1", "sentence2"),
    qnli=("question", "sentence"),
    qqp=("question1", "question2"),
    rte=("sentence1", "sentence2"),
    sst2=("sentence", None),
    stsb=("sentence1", "sentence2"),
    wnli=("sentence1", "sentence2"),
)

class GLUEDataset:
    """Tokenized view of the training split of a GLUE task.

    On construction the ``train`` split of ``glue/<task>`` is loaded, every
    example is tokenized with the tokenizer of ``model_name_or_path``, and the
    original columns are dropped so only the tokenizer outputs (plus ``label``
    when the raw data has one) remain.

    Args:
        task: GLUE task name; must be a key of ``task_to_keys``.
        model_name_or_path: Identifier handed to ``AutoTokenizer.from_pretrained``.
        max_seq_length: Passed to the tokenizer as ``max_length`` (with
            ``truncation=True``).
        data_dir: Forwarded to ``load_dataset`` as ``cache_dir``.
    """

    def __init__(self, task, model_name_or_path, max_seq_length=128, data_dir=None):
        train_split = load_dataset('glue', task, cache_dir=data_dir, split='train')
        tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
        first_key, second_key = task_to_keys[task]
        # Every raw column is removed after mapping; "label" is re-attached below.
        columns_to_drop = train_split[0].keys()

        def tokenize_batch(examples):
            """Tokenize one batch of raw examples and carry the labels over."""
            if second_key is None:
                texts = (examples[first_key],)
            else:
                texts = (examples[first_key], examples[second_key])
            encoded = tokenizer(*texts, padding=True, max_length=max_seq_length, truncation=True)
            if "label" in examples:
                encoded["label"] = examples["label"]
            return encoded

        self.dataset = train_split.map(
            tokenize_batch,
            batched=True,
            load_from_cache_file=True,
            remove_columns=columns_to_drop,
        )

    def __len__(self):
        """Return the number of items in the underlying dataset."""
        return len(self.dataset)

    def __getitem__(self, index):
        """Return whatever the underlying dataset yields for ``index``."""
        return self.dataset[index]

# Generate SST-2 dataloader for DistilBERT model.
# No batch_size is given, so DataLoader's default applies; the collator pads
# each batch using the same tokenizer checkpoint the dataset was built with.
dataset = GLUEDataset(task=task, model_name_or_path=model_name_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_name_or_path)
padding_collator = DataCollatorWithPadding(tokenizer)
dataloader = torch.utils.data.DataLoader(
    dataset,
    collate_fn=padding_collator,
)



### 2. Perform quantization aware training

In [4]:
from transformers import AutoModelForSequenceClassification
from neural_compressor import QuantizationAwareTrainingConfig
from neural_compressor.training import prepare_compression

# training function
def train_func(compression_manager, model, dataloader, epochs=1, iters=10, lr=0.0001):
    """Run a short demonstration training loop between the QAT callbacks.

    NOTE: the "loss" is simply the first logit of the first example in each
    batch; labels are not used. The loop only exercises forward/backward
    passes, it is not a real fine-tuning recipe.

    Args:
        compression_manager: Object exposing ``callbacks.on_train_begin()`` and
            ``callbacks.on_train_end()``.
        model: Module to train; called as ``model(**inputs)`` and expected to
            return an object with a ``logits`` attribute.
        dataloader: Iterable of dict-like batches. A ``'labels'`` entry, when
            present, is ignored; the batch itself is left untouched.
        epochs: Number of passes over ``dataloader``.
        iters: Maximum number of optimizer steps per epoch.
        lr: Learning rate for the SGD optimizer.

    Returns:
        The same ``model`` object that was passed in.
    """
    # Local import keeps this notebook cell self-contained.
    from itertools import islice

    compression_manager.callbacks.on_train_begin()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr)
    for _ in range(epochs):
        model.train()
        # islice caps the epoch at exactly `iters` steps (the previous
        # post-step `idx >= iters` check ran one extra batch).
        for batch in islice(dataloader, iters):
            # Filter instead of pop(): no KeyError when labels are missing and
            # no mutation of batches owned by the caller.
            inputs = {name: value for name, value in batch.items() if name != 'labels'}
            output = model(**inputs)
            loss = output.logits[0][0]
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()
    compression_manager.callbacks.on_train_end()
    return model

# Perform quantization aware training:
# load the FP32 checkpoint, wrap it with a default QAT config, then run the
# demo training loop on the model held by the compression manager.
model = AutoModelForSequenceClassification.from_pretrained(model_name_or_path)
quant_conf = QuantizationAwareTrainingConfig()
compression_manager = prepare_compression(model, quant_conf)
q_model = train_func(
    compression_manager=compression_manager,
    model=compression_manager.model,
    dataloader=dataloader,
)

2023-02-13 16:39:23 [INFO] Fx trace of the entire model failed. We will conduct auto quantization
You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.
2023-02-13 16:39:29 [INFO] |******Mixed Precision Statistics******|
2023-02-13 16:39:29 [INFO] +----------------------+-------+-------+
2023-02-13 16:39:29 [INFO] |       Op Type        | Total |  INT8 |
2023-02-13 16:39:29 [INFO] +----------------------+-------+-------+
2023-02-13 16:39:29 [INFO] |      Embedding       |   2   |   2   |
2023-02-13 16:39:29 [INFO] | quantize_per_tensor  |   51  |   51  |
2023-02-13 16:39:29 [INFO] |      LayerNorm       |   13  |   13  |
2023-02-13 16:39:29 [INFO] |      dequantize      |   51  |   51  |
2023-02-13 16:39:29 [INFO] |        Linear        |   38  |   38  |
2023-02-13 16:39:29 [INFO] |       Dropout        |   6   

### 3. Export to ONNX model

In [5]:
# Get params for export function.
# One batch from the dataloader serves as the example input for export. It is
# bound to `example_batch` rather than `input` so the builtin `input()` is not
# shadowed for the rest of the notebook session.
example_batch = next(iter(dataloader))
# Labels are not fed to the exported graph; tolerate batches without them.
example_batch.pop('labels', None)
# Axes 0 and 1 of every input are exported as dynamic (named) dimensions.
symbolic_names = {0: 'batch_size', 1: 'max_seq_len'}
dynamic_axes = {name: symbolic_names for name in example_batch.keys()}

# Export INT8 PyTorch model to INT8 ONNX model
from neural_compressor.config import Torch2ONNXConfig
int8_onnx_config = Torch2ONNXConfig(
    dtype="int8",
    opset_version=14,
    quant_format="QLinear",  # or "QDQ" to export to QDQ ONNX model
    example_inputs=tuple(example_batch.values()),
    input_names=list(example_batch.keys()),
    output_names=['labels'],
    dynamic_axes=dynamic_axes,
)
q_model.export('distilbert-base-uncased-finetuned-sst-2-english-qat.onnx', int8_onnx_config)

  mask, torch.tensor(torch.finfo(scores.dtype).min)
2023-02-13 16:39:35 [INFO] Weight type: QInt8.
2023-02-13 16:39:35 [INFO] Activation type: QUInt8.
2023-02-13 16:40:05 [INFO] *************************************************************************************************
2023-02-13 16:40:05 [INFO] The INT8 ONNX Model is exported to path: distilbert-base-uncased-finetuned-sst-2-english-qat.onnx
2023-02-13 16:40:05 [INFO] *************************************************************************************************
