# Tutorial 3: Running Quantization-Aware Training (QAT) on BERT

In [None]:
!git clone https://github.com/tonytarizzo/mase.git
%cd mase
!python -m pip install -e . -vvv
%cd src


In this tutorial, we'll build on Tutorial 2 by taking the BERT model fine-tuned for sequence classification and running Mase's quantization pass. First, we'll run simple Post-Training Quantization (PTQ) and see how much the accuracy drops. Then, we'll run further training iterations on the quantized model (i.e. QAT) and see whether its accuracy approaches that of the original full-precision model.

In [None]:
checkpoint = "prajjwal1/bert-tiny"
tokenizer_checkpoint = "bert-base-uncased"
dataset_name = "imdb"

## Importing the model

If you are starting from scratch, you can create a MaseGraph for BERT by running the following cell.

In [None]:
from transformers import AutoModelForSequenceClassification

from chop import MaseGraph
import chop.passes as passes

model = AutoModelForSequenceClassification.from_pretrained(checkpoint)
model.config.problem_type = "single_label_classification"

mg = MaseGraph(
    model,
    hf_input_names=[
        "input_ids",
        "attention_mask",
        "labels",
    ],
)

mg, _ = passes.init_metadata_analysis_pass(mg)
mg, _ = passes.add_common_metadata_analysis_pass(mg)


If you have previously run the LoRA fine-tuning tutorial (Tutorial 2), run the following cell to import the fine-tuned checkpoint instead.

In [None]:
from pathlib import Path
from chop import MaseGraph

mg = MaseGraph.from_checkpoint(f"{Path.home()}/tutorial_2_lora")



## Post-Training Quantization (PTQ)

Here, we simply quantize the model and evaluate the effect on its accuracy. First, let's evaluate the model accuracy before quantization (if you're coming from Tutorial 2, this should match the post-LoRA evaluation accuracy). As in Tutorial 2, we can use the `get_tokenized_dataset` and `get_trainer` utilities to generate a HuggingFace `Trainer` instance for training and evaluation.

In [None]:
from chop.tools import get_tokenized_dataset, get_trainer

dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)

# Evaluate accuracy
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
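Passing `evaluate_metric="accuracy"` asks the trainer to report classification accuracy: the fraction of examples whose highest-scoring logit matches the label. HuggingFace's `Trainer.evaluate()` prefixes metric names with `eval_`, which is why the cell reads `eval_results['eval_accuracy']`. The function below is a minimal sketch of such a metric in the style of a `compute_metrics` callback that receives `(logits, labels)`; it is an assumption about what `get_trainer` wires up, not its actual implementation.

```python
# Minimal sketch of an accuracy metric in the style of a HuggingFace
# `compute_metrics` callback. Assumption: get_trainer uses something
# equivalent internally; this is NOT its actual code.
from typing import Sequence, Tuple


def accuracy_from_logits(eval_pred: Tuple[Sequence[Sequence[float]], Sequence[int]]) -> dict:
    """Return {'accuracy': ...} given a (logits, labels) pair."""
    logits, labels = eval_pred
    # Predicted class = index of the largest logit for each example
    preds = [max(range(len(row)), key=lambda i: row[i]) for row in logits]
    correct = sum(int(p == y) for p, y in zip(preds, labels))
    return {"accuracy": correct / len(labels)}


# Usage: three examples, two classes (0 = negative, 1 = positive)
logits = [[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]]
labels = [0, 0, 1]
print(accuracy_from_logits((logits, labels)))  # {'accuracy': 0.6666666666666666}
```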

To run the quantization pass, we pass a quantization configuration dictionary as an argument. It defines the quantization mode, numerical format and precision for each operator in the graph. We'll run the quantization in "by type" mode, meaning nodes are quantized according to their `mase_op`; other modes select nodes by name or by a regex on the name. Here, we quantize the input activations, weights and biases of every linear layer to fixed-point with the same precision, while the `default` entry (`"name": None`) leaves all other operators unquantized. A single uniform precision may be sub-optimal, but it works as an example. In future tutorials, we'll see how to run the `search` flow in `Mase` to find quantization configurations that minimize accuracy loss.

In [None]:
import chop.passes as passes

# The linear-layer widths below were changed manually for each combination in the Lab 1 sweep
quantization_config = {
    "by": "type",
    "default": {
        "config": {
            "name": None,
        }
    },
    "linear": {
        "config": {
            "name": "integer",
            # data
            "data_in_width": 16,
            "data_in_frac_width": 8,
            # weight
            "weight_width": 16,
            "weight_frac_width": 8,
            # bias
            "bias_width": 16,
            "bias_frac_width": 8,
        }
    },
}

mg, _ = passes.quantize_transform_pass(
    mg,
    pass_args=quantization_config,
)
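In "by type" mode, each node's `mase_op` selects an entry of the configuration, and anything without a matching entry falls back to `default`. The sketch below illustrates that lookup; `resolve_configs` and the node names are hypothetical, and this mirrors the idea behind `quantize_transform_pass` rather than its actual code.

```python
# Hypothetical sketch of "by type" config resolution. Assumption: this mirrors
# the idea behind quantize_transform_pass, not its actual implementation.
from typing import Any, Dict


def resolve_configs(node_ops: Dict[str, str], pass_args: Dict[str, Any]) -> Dict[str, Dict[str, Any]]:
    """Map each node name to the config for its mase_op, falling back to 'default'."""
    assert pass_args["by"] == "type", "this sketch only covers the 'by type' mode"
    resolved = {}
    for node_name, mase_op in node_ops.items():
        entry = pass_args.get(mase_op, pass_args["default"])
        resolved[node_name] = entry["config"]
    return resolved


# Usage with hypothetical node names and a trimmed-down version of the config above
config = {
    "by": "type",
    "default": {"config": {"name": None}},
    "linear": {"config": {"name": "integer", "weight_width": 16, "weight_frac_width": 8}},
}
nodes = {
    "bert_encoder_layer_0_attention_self_query": "linear",
    "bert_embeddings_layer_norm": "layer_norm",
}
for name, cfg in resolve_configs(nodes, config).items():
    # A config name of None means the node stays in full precision
    print(name, "->", cfg["name"])
```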

Let's evaluate the immediate effect of quantization on the model accuracy.

In [None]:
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")
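Any accuracy change measured here comes from snapping every value onto a fixed-point grid. With `width=16` and `frac_width=8`, the step is 2^-8 = 0.00390625 and the signed range is [-128, 127.99609375]; with `width=4` and `frac_width=2`, the step grows to 0.25 and the range shrinks to [-2, 1.75], which plausibly explains why the 4-bit runs recorded below collapse. The function is a sketch assuming two's-complement range, round-to-nearest and saturation on overflow; Mase's own quantizer may differ in detail.

```python
# Sketch of signed fixed-point quantization with saturation.
# Assumptions: two's-complement range, round-to-nearest (Python's round()
# breaks ties to even), clamping on overflow. Mase's quantizer may differ.


def fixed_point_quantize(x: float, width: int, frac_width: int) -> float:
    scale = 2 ** frac_width          # grid step is 1 / scale
    q_min = -(2 ** (width - 1))      # most negative representable integer
    q_max = 2 ** (width - 1) - 1     # most positive representable integer
    q = round(x * scale)             # snap to the nearest grid point
    q = max(q_min, min(q_max, q))    # saturate instead of wrapping around
    return q / scale


print(fixed_point_quantize(0.123456, 16, 8))  # 0.125
print(fixed_point_quantize(0.123456, 4, 2))   # 0.0
print(fixed_point_quantize(3.7, 4, 2))        # 1.75 (saturated)
```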

We can save the current checkpoint for future reference (optional).

In [None]:
from pathlib import Path

mg.export(f"{Path.home()}/tutorial_3_ptq")

[32mINFO    [0m [34mExporting MaseGraph to /root/tutorial_3_ptq.pt, /root/tutorial_3_ptq.mz[0m
[32mINFO    [0m [34mExporting GraphModule to /root/tutorial_3_ptq.pt[0m
[32mINFO    [0m [34mExporting MaseMetadata to /root/tutorial_3_ptq.mz[0m


## Quantization-Aware Training (QAT)

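With the 16-bit configuration above, PTQ barely changes accuracy, but at lower precision quantization can cause a significant drop: in the results recorded at the end of this notebook, 8-bit PTQ falls from 83.73% to 78.39%, and 4-bit PTQ falls to 50%, chance level for the balanced, binary IMDb task. Next, we'll run QAT to see whether this gap can be closed. To run QAT in Mase, all you need to do is put the quantized model back into your training loop after running the quantization pass.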

In [None]:
# Fine-tune the quantized model (QAT), then evaluate accuracy
trainer.train()
eval_results = trainer.evaluate()
print(f"Evaluation accuracy: {eval_results['eval_accuracy']}")

| Step | Training Loss |
|------|---------------|
| 500  | 0.3857 |
| 1000 | 0.3858 |
| 1500 | 0.3894 |
| 2000 | 0.371  |
| 2500 | 0.3782 |
| 3000 | 0.3828 |


Evaluation accuracy: 0.84468
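The rounding step has zero gradient almost everywhere, so how can training still move quantized weights? A common recipe is the straight-through estimator (STE): the forward pass uses the quantized weight, while the backward pass treats quantization as the identity and applies the gradient to a latent float weight. The single-weight sketch below illustrates that recipe; whether Mase's quantized layers use exactly this estimator is an assumption, and `fake_quantize` and `train_qat` are hypothetical names.

```python
# Sketch of quantization-aware training for one weight with a straight-through
# estimator (STE). Assumption: the common STE recipe, not Mase's implementation.


def fake_quantize(x: float, width: int = 8, frac_width: int = 4) -> float:
    """Round onto a signed fixed-point grid (step 1/16 here) and map back to float."""
    scale = 2 ** frac_width
    q = round(x * scale)
    q = max(-(2 ** (width - 1)), min(2 ** (width - 1) - 1, q))
    return q / scale


def train_qat(xs, ys, w: float, lr: float = 0.1, steps: int = 50):
    """Fit y = w * x while the forward pass only ever sees the quantized weight."""
    for _ in range(steps):
        w_q = fake_quantize(w)  # forward pass uses the quantized weight
        # d(loss)/d(w_q) for loss = mean((w_q * x - y)^2)
        grad_wq = sum(2 * (w_q * x - y) * x for x, y in zip(xs, ys)) / len(xs)
        # Simplest STE variant: treat fake_quantize as identity in the backward
        # pass, so the gradient w.r.t. w_q updates the latent float weight
        w -= lr * grad_wq
    return w, fake_quantize(w)


xs = [1.0, -2.0, 0.5, 1.5]
ys = [0.3 * x for x in xs]  # target weight 0.3 is not on the 1/16 grid
latent_w, quantized_w = train_qat(xs, ys, w=0.0)
print(latent_w, quantized_w)
```

The latent weight keeps accumulating small updates even when they are too small to change its quantized value, so the quantized weight ends up on a grid point next to the target (it may flip between the two neighbours near a rounding boundary).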



After QAT, the quantized model matches or slightly exceeds the full-precision model (0.84468 versus 0.83732 for the 16-bit configuration), while its linear-layer weights need half the bits of 32-bit floats. Finally, save the final checkpoint for future tutorials.

In [None]:
from pathlib import Path

mg.export(f"{Path.home()}/tutorial_3_qat")

[32mINFO    [0m [34mExporting MaseGraph to /root/tutorial_3_qat.pt, /root/tutorial_3_qat.mz[0m
[32mINFO    [0m [34mExporting GraphModule to /root/tutorial_3_qat.pt[0m
[32mINFO    [0m [34mExporting MaseMetadata to /root/tutorial_3_qat.mz[0m


In [None]:
### Plotting function for Lab 1, Part 1

import matplotlib.pyplot as plt
import matplotlib.ticker as mticker

# full_precision_accuracy: unquantized model
# ptq_accuracy: quantised model evaluated without retraining (PTQ)
# qat_accuracy: quantised model evaluated after retraining (QAT)

# Early results were gathered manually and entered below; later tests were run programmatically
results = [
    {
        "test_id": 1,
        "data_in_width": 4,
        "data_in_frac_width": 2,
        "weight_width": 4,
        "weight_frac_width": 2,
        "bias_width": 4,
        "bias_frac_width": 2,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.5,
        "qat_accuracy": 0.5,
    },
    {
        "test_id": 2,
        "data_in_width": 8,
        "data_in_frac_width": 4,
        "weight_width": 8,
        "weight_frac_width": 4,
        "bias_width": 8,
        "bias_frac_width": 4,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.78388,
        "qat_accuracy": 0.84076,
    },
    {
        "test_id": 5,
        "data_in_width": 12,
        "data_in_frac_width": 8,
        "weight_width": 12,
        "weight_frac_width": 8,
        "bias_width": 12,
        "bias_frac_width": 8,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83704,
        "qat_accuracy": 0.84412,
    },
    {
        "test_id": 6,
        "data_in_width": 16,
        "data_in_frac_width": 8,
        "weight_width": 16,
        "weight_frac_width": 8,
        "bias_width": 16,
        "bias_frac_width": 8,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83796,
        "qat_accuracy": 0.84468,
    },
    {
        "test_id": 6,
        "data_in_width": 20,
        "data_in_frac_width": 10,
        "weight_width": 20,
        "weight_frac_width": 10,
        "bias_width": 20,
        "bias_frac_width": 10,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83776,
        "qat_accuracy": 0.845,
    },
    {
        "test_id": 6,
        "data_in_width": 24,
        "data_in_frac_width": 12,
        "weight_width": 24,
        "weight_frac_width": 12,
        "bias_width": 24,
        "bias_frac_width": 12,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83732,
        "qat_accuracy": 0.84488,
    },
    {
        "test_id": 7,
        "data_in_width": 28,
        "data_in_frac_width": 14,
        "weight_width": 28,
        "weight_frac_width": 14,
        "bias_width": 28,
        "bias_frac_width": 14,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83736,
        "qat_accuracy": 0.84504,
    },
    {
        "test_id": 7,
        "data_in_width": 32,
        "data_in_frac_width": 16,
        "weight_width": 32,
        "weight_frac_width": 16,
        "bias_width": 32,
        "bias_frac_width": 16,
        "full_precision_accuracy": 0.83732,
        "ptq_accuracy": 0.83736,
        "qat_accuracy": 0.84492,
    }]

def plot_results_separated(results):
    # Derive the x-axis from the recorded widths instead of a hard-coded list, and keep
    # only runs with both PTQ and QAT results so all lists stay aligned per run
    valid = [res for res in results if res["ptq_accuracy"] is not None and res["qat_accuracy"] is not None]
    fixed_point_widths = [res["data_in_width"] for res in valid]
    ptq_accuracies = [res["ptq_accuracy"] for res in valid]
    qat_accuracies = [res["qat_accuracy"] for res in valid]
    highest_accuracies = [max(ptq, qat) for ptq, qat in zip(ptq_accuracies, qat_accuracies)]

    # Plot 1: PTQ vs QAT
    plt.figure(figsize=(10, 6))
    plt.plot(
        fixed_point_widths[:len(ptq_accuracies)],
        ptq_accuracies,
        marker="s",
        label="PTQ Accuracy",
        linestyle="-.",
        linewidth=2,
        alpha=0.8,
    )
    plt.plot(
        fixed_point_widths[:len(qat_accuracies)],
        qat_accuracies,
        marker="d",
        label="QAT Accuracy",
        linestyle=":",
        linewidth=2,
        alpha=0.8,
    )
    # Add labels, grid, and legend
    plt.grid(visible=True, linestyle="--", alpha=0.5)
    plt.xlabel("Fixed Point Width", fontsize=12)
    plt.ylabel("Accuracy", fontsize=12)
    plt.title("PTQ vs QAT Accuracy by Fixed Point Width", fontsize=14, pad=15)
    plt.legend(fontsize=10, loc="lower right")

    # Format y-axis to show percentages
    plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x*100:.0f}%"))

    # Add value labels on the points
    for x, y in zip(fixed_point_widths[:len(ptq_accuracies)], ptq_accuracies):
        plt.text(x, y + 0.002, f"{y*100:.2f}%", fontsize=8, ha="center")
    for x, y in zip(fixed_point_widths[:len(qat_accuracies)], qat_accuracies):
        plt.text(x, y + 0.002, f"{y*100:.2f}%", fontsize=8, ha="center")

    plt.tight_layout()
    plt.savefig("ptq_vs_qat_accuracy.png", dpi=300)
    plt.show()

    # Plot 2: Highest Accuracy vs Fixed Width
    plt.figure(figsize=(10, 6))
    plt.plot(
        fixed_point_widths[:len(highest_accuracies)],
        highest_accuracies,
        marker="o",
        label="Highest Accuracy",
        linestyle="--",
        linewidth=2,
    )

    # Add labels, grid, and legend
    plt.grid(visible=True, linestyle="--", alpha=0.5)
    plt.xlabel("Fixed Point Width", fontsize=12)
    plt.ylabel("Accuracy", fontsize=12)
    plt.title("Highest Accuracy by Fixed Point Width", fontsize=14, pad=15)

    # Format y-axis to show percentages
    plt.gca().yaxis.set_major_formatter(mticker.FuncFormatter(lambda x, _: f"{x*100:.0f}%"))

    # Add value labels on the points
    for x, y in zip(fixed_point_widths[:len(highest_accuracies)], highest_accuracies):
        plt.text(x, y + 0.002, f"{y*100:.2f}%", fontsize=8, ha="center")

    plt.tight_layout()
    plt.savefig("fixed_point_width_vs_accuracy.png", dpi=300)
    plt.show()

plot_results_separated(results)
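The plots show absolute accuracies; the gap to the full-precision baseline is often easier to read. The helper below tabulates it per width. The three runs are copied from the `results` list above (trimmed to the fields used) so the cell stands alone; in the notebook you can pass `results` directly. Keep in mind that the QAT runs received extra training that the full-precision baseline did not, which may account for part of the positive QAT deltas.

```python
# Tabulate each run's accuracy change relative to the full-precision baseline.
# Data copied from the `results` list above, trimmed to the fields used here.


def accuracy_deltas(runs):
    """Return (width, PTQ delta, QAT delta) tuples, deltas as fractions."""
    rows = []
    for r in runs:
        base = r["full_precision_accuracy"]
        rows.append((r["data_in_width"], r["ptq_accuracy"] - base, r["qat_accuracy"] - base))
    return rows


runs = [
    {"data_in_width": 4, "full_precision_accuracy": 0.83732, "ptq_accuracy": 0.5, "qat_accuracy": 0.5},
    {"data_in_width": 8, "full_precision_accuracy": 0.83732, "ptq_accuracy": 0.78388, "qat_accuracy": 0.84076},
    {"data_in_width": 16, "full_precision_accuracy": 0.83732, "ptq_accuracy": 0.83796, "qat_accuracy": 0.84468},
]
for width, d_ptq, d_qat in accuracy_deltas(runs):
    print(f"{width:2d}-bit: PTQ {d_ptq * 100:+.2f} pts, QAT {d_qat * 100:+.2f} pts")
```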
