# Tutorial 3: Running Quantization-Aware Training (QAT) on Bert

In this tutorial, we'll build on Tutorial 2 by taking the Bert model fine-tuned for sequence classification and running Mase's quantization pass. First, we'll run simple Post-Training Quantization (PTQ) and measure how much the accuracy drops. Then, we'll run further training iterations on the quantized model (i.e. QAT) and check whether the accuracy of the trained quantized model approaches that of the original (full-precision) model.

```python
checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"
```

## Importing the model

If you are starting from scratch, you can create a MaseGraph for Bert by running the following cell.

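```python
from pathlib import Path

from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

# Directory for the checkpoints exported later in this tutorial
lab1_out_dir = Path("/workspace/labs/lab1/outputs")
lab1_out_dir.mkdir(parents=True, exist_ok=True)

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

# Trace the HuggingFace model into a MaseGraph, declaring its forward() inputs
mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

# Annotate the graph nodes with the metadata used by later passes
mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)
```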

```text
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-tiny and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
`past_key_values` were not specified as input names, but model.config.use_cache = True. Setting model.config.use_cache = False.
INFO     Getting dummy input for prajjwal1/bert-tiny.
```


If you have previously run the tutorial on LoRA fine-tuning (Tutorial 2), run the following cell instead to import the fine-tuned checkpoint.

```python
from pathlib import Path
from chop import MaseGraph

lab0_out_dir = Path("/workspace/labs/lab0/outputs")
lab1_out_dir = Path("/workspace/labs/lab1/outputs")
lab1_out_dir.mkdir(parents=True, exist_ok=True)

mg = MaseGraph.from_checkpoint(f"{lab0_out_dir}/tutorial_2_lora")
```



## Post-Training Quantization (PTQ)

Here, we simply quantize the model and evaluate the effect on its accuracy. First, let's evaluate the model accuracy before quantization (if you're coming from Tutorial 2, this should match the post-LoRA evaluation accuracy). As in Tutorial 2, we can use the `get_tokenized_dataset` and `get_trainer` utilities to create a HuggingFace `Trainer` instance for training and evaluation.

```python
from chop.tools import get_tokenized_dataset, get_trainer

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
```

```text
INFO     Tokenizing dataset imdb with AutoTokenizer for bert-base-uncased.
Map: 100%|██████████| 25000/25000 [00:05<00:00, 4883.96 examples/s]
Map: 100%|██████████| 25000/25000 [00:05<00:00, 4749.33 examples/s]
Map: 100%|██████████| 50000/50000 [00:10<00:00, 4641.88 examples/s]
Evaluation accuracy: 0.83488
```
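Here `evaluate_metric="accuracy"` reports the fraction of test reviews that are classified correctly. The sketch below shows that computation in plain Python. The helper `accuracy_from_logits` is hypothetical, and we assume (as is conventional for sequence classification, though we have not checked the internals of `get_trainer`) that the predicted class is the argmax of the classifier's two logits.

```python
from typing import Sequence


def accuracy_from_logits(logits: Sequence[Sequence[float]], labels: Sequence[int]) -> float:
    """Hypothetical helper: fraction of examples whose highest-scoring class equals the label."""
    if len(logits) != len(labels):
        raise ValueError("logits and labels must have the same length")
    correct = 0
    for row, label in zip(logits, labels):
        # argmax over the class scores (on ties, max() keeps the lowest index)
        predicted = max(range(len(row)), key=lambda i: row[i])
        correct += int(predicted == label)
    return correct / len(labels)


# Toy batch of 4 reviews with 2 classes (0 = negative, 1 = positive)
logits = [[2.1, -1.0], [0.3, 0.9], [-0.5, 0.2], [1.2, 1.5]]
labels = [0, 1, 0, 0]
print(accuracy_from_logits(logits, labels))
```

In this toy batch the first two predictions are correct and the last two are not, so the accuracy is 0.5.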


To run the quantization pass, we pass a quantization configuration dictionary as an argument. It defines the quantization mode, numerical format and precision for each operator in the graph. We'll run the quantization in "by type" mode, meaning nodes are quantized according to their `mase_op`; other modes quantize nodes by name or by regex name. The configuration below quantizes the input activations, weights and biases of every `linear` node to 8-bit fixed-point with 4 fractional bits, while the `default` entry (`"name": None`) leaves all other operators in full precision. Using the same precision everywhere may be sub-optimal, but it works as an example. In future tutorials, we'll see how to run the `search` flow in Mase to find quantization configurations that minimize accuracy loss.
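A minimal sketch of how a "by type" configuration can be resolved for a single node, assuming the lookup falls back to the `default` entry whenever a node's `mase_op` has no entry of its own. The function `resolve_node_config` is hypothetical and is not Mase's implementation; the op names other than `linear` are illustrative placeholders rather than verified `mase_op` strings.

```python
from typing import Any, Dict


def resolve_node_config(mase_op: str, pass_args: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical by-type lookup: use the entry for the node's mase_op, else "default"."""
    if pass_args.get("by") != "type":
        raise ValueError("this sketch only models the 'type' mode")
    entry = pass_args.get(mase_op, pass_args["default"])
    return entry["config"]


pass_args = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {"config": {"name": "integer", "weight_width": 8, "weight_frac_width": 4}},
}

# Placeholder op names, used only to show the fallback behaviour
for op in ["linear", "layer_norm", "softmax"]:
    cfg = resolve_node_config(op, pass_args)
    status = "left in full precision" if cfg["name"] is None else f"quantized as {cfg['name']}"
    print(f"{op}: {status}")
```

Only `linear` has its own entry here, so every other op falls through to `default` and stays unquantized.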

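```python
import chop.passes as passes

quantization_config = {
    "by": "type",  # match config entries against each node's mase_op
    "default": {
        "config": {
            "name": None,  # ops without their own entry are not quantized
        }
    },
    "linear": {
        "config": {
            "name": "integer",  # fixed-point format
            # data (input activations): 8 bits in total, 4 of them fractional
            "data_in_width": 8,
            "data_in_frac_width": 4,
            # weight
            "weight_width": 8,
            "weight_frac_width": 4,
            # bias
            "bias_width": 8,
            "bias_frac_width": 4,
        }
    },
}

# Apply the quantization configuration to the graph
mg, _ = passes.quantize_transform_pass(
    mg,
    pass_args=quantization_config,
)
```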

Let's evaluate the immediate effect of quantization on the model accuracy.

```python
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
```

```text
Evaluation accuracy: 0.7494
```
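The accuracy drop comes from the rounding and clipping introduced by the fixed-point format. The sketch below shows what `width=8, frac_width=4` means numerically. The helper `fixed_point_quantize` is hypothetical and assumes signed two's-complement codes, round-to-nearest and saturation at the range limits; this is a common fixed-point scheme, but we have not verified it against Mase's source. Under these assumptions values live on a grid with step 1/16 = 0.0625 inside [-8, 7.9375], so values with magnitude below 1/32 collapse to zero and anything outside the range saturates.

```python
def fixed_point_quantize(x: float, width: int = 8, frac_width: int = 4) -> float:
    """Round x to a signed fixed-point grid with `width` total bits, `frac_width` fractional."""
    scale = 2 ** frac_width        # 16 grid steps per unit when frac_width = 4
    q_min = -(2 ** (width - 1))    # smallest integer code: -128 for 8 bits
    q_max = 2 ** (width - 1) - 1   # largest integer code: 127 for 8 bits
    q = round(x * scale)           # round to the nearest integer code (ties to even)
    q = max(q_min, min(q_max, q))  # saturate codes outside the representable range
    return q / scale               # map the integer code back to a real value


for value in [0.1, 0.03, -1.234, 7.99, 12.0, -9.5]:
    print(f"{value:>7} -> {fixed_point_quantize(value)}")
```

Small values such as 0.03 vanish and 12.0 is clipped to 7.9375, which illustrates how a single shared 8-bit format can distort some layers more than others.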


We can save the current checkpoint for future reference (optional).

```python
mg.export(f"{lab1_out_dir}/tutorial_3_ptq")
```

```text
INFO     Exporting MaseGraph to /workspace/labs/lab1/outputs/tutorial_3_ptq.pt, /workspace/labs/lab1/outputs/tutorial_3_ptq.mz
INFO     Exporting GraphModule to /workspace/labs/lab1/outputs/tutorial_3_ptq.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to /workspace/labs/lab1/outputs/tutorial_3_ptq.mz
```


## Quantization-Aware Training (QAT)

In the run shown above, PTQ dropped the evaluation accuracy from 0.83488 to 0.7494. Next, we'll run QAT to see whether this gap can be reduced. To run QAT in Mase, all you need to do is put the quantized model back into your training loop after running the quantization pass.
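Why can further training recover accuracy? A common QAT recipe keeps a full-precision "shadow" copy of each weight, uses its quantized value in the forward pass, and treats rounding as the identity in the backward pass (the straight-through estimator, STE), because the true gradient of rounding is zero almost everywhere. The scalar sketch below illustrates that recipe on a single weight; we assume, without having checked, that Mase's quantized layers train in a similar way, and the helpers `quantize` and `qat_step` are hypothetical.

```python
from typing import Sequence


def quantize(w: float, frac_width: int = 4) -> float:
    """Round to a fixed-point grid with `frac_width` fractional bits (no clamping, for brevity)."""
    scale = 2 ** frac_width
    return round(w * scale) / scale


def qat_step(w: float, xs: Sequence[float], ys: Sequence[float], lr: float = 0.05) -> float:
    """One gradient-descent step on mean squared error using the straight-through estimator."""
    w_q = quantize(w)  # the forward pass only ever sees the quantized weight
    n = len(xs)
    # d(loss)/d(w_q) for loss = mean((w_q * x - y) ** 2)
    grad_wq = sum(2 * (w_q * x - y) * x for x, y in zip(xs, ys)) / n
    # STE: treat d(w_q)/d(w) as 1, so the gradient passes through the rounding unchanged
    return w - lr * grad_wq  # update the full-precision shadow weight


xs = [0.5, 1.0, 1.5, 2.0]
ys = [0.8 * x for x in xs]  # target weight 0.8 lies between grid points 0.75 and 0.8125

w = 0.0
for _ in range(200):
    w = qat_step(w, xs, ys)
print(f"shadow weight {w:.4f}, quantized weight used at inference {quantize(w)}")
```

The shadow weight settles near 0.78125, the boundary between the neighbouring grid points 0.75 and 0.8125, so the quantized weight ends up on one of those two points. With the exact gradient of rounding (zero at w = 0) the weight would never move from its starting value.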

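```python
# Continue training the quantized model (QAT), then re-evaluate it.
# `trainer` was created after the quantization pass, so it wraps the quantized mg.model.
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
```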

| Step | Training Loss |
|-----:|--------------:|
| 500 | 0.4057 |
| 1000 | 0.3945 |
| 1500 | 0.4057 |
| 2000 | 0.3878 |
| 2500 | 0.3884 |
| 3000 | 0.3882 |

```text
Evaluation accuracy: 0.83856
```


After QAT, the quantized model reaches an evaluation accuracy of 0.83856, matching (here even slightly exceeding) the 0.83488 of the model before quantization, while its linear-layer weights need only 8 bits each instead of 32. Finally, save the final checkpoint for future tutorials.
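To put a rough number on the memory savings mentioned above, the sketch below counts the weights in the model's linear layers and compares 32-bit floats with 8-bit codes. It assumes the published bert-tiny dimensions (2 encoder layers, hidden size 128, feed-forward size 512, 2 labels), that every linear layer including the pooler and classifier is quantized, and that the weights are exported in an 8-bit format. Biases are not counted, and embeddings and LayerNorm parameters are excluded because the configuration above leaves them in full precision. The function `linear_weight_count` is hypothetical.

```python
def linear_weight_count(hidden: int, intermediate: int, layers: int, num_labels: int) -> int:
    """Count weights (no biases) in the linear layers of a BERT-style classifier."""
    attention = 4 * hidden * hidden           # query, key, value and output projections
    feed_forward = 2 * hidden * intermediate  # up- and down-projection of the FFN
    pooler = hidden * hidden                  # pooler dense layer
    classifier = hidden * num_labels          # classification head
    return layers * (attention + feed_forward) + pooler + classifier


# Assumed prajjwal1/bert-tiny dimensions
n = linear_weight_count(hidden=128, intermediate=512, layers=2, num_labels=2)
fp32_bytes = n * 32 // 8  # 4 bytes per weight
int8_bytes = n * 8 // 8   # 1 byte per weight
print(f"{n} linear weights: {fp32_bytes / 1024:.1f} KiB in fp32 vs {int8_bytes / 1024:.1f} KiB at 8 bits")
```

The linear weights shrink by a factor of 4. For a model this small, the word-embedding table (roughly 30522 × 128 values, assuming the standard uncased BERT vocabulary) is considerably larger than all linear layers combined and is not quantized by this configuration.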

```python
mg.export(f"{lab1_out_dir}/tutorial_3_qat")
```

```text
INFO     Exporting MaseGraph to /workspace/labs/lab1/outputs/tutorial_3_qat.pt, /workspace/labs/lab1/outputs/tutorial_3_qat.mz
INFO     Exporting GraphModule to /workspace/labs/lab1/outputs/tutorial_3_qat.pt
INFO     Saving full model format
INFO     Exporting MaseMetadata to /workspace/labs/lab1/outputs/tutorial_3_qat.mz
```
