# Tutorial 3: Running Quantization-Aware Training (QAT) on a BERT model

In [None]:
from pathlib import Path
import torch
from chop import MaseGraph
import chop.passes as passes

# Reload the MaseGraph checkpoint stored under ~/tutorial_2
# (presumably the one exported by the previous tutorial -- confirm).
checkpoint_dir = f"{Path.home()}/tutorial_2"
mg = MaseGraph.from_checkpoint(checkpoint_dir)

# Show the model held by the loaded graph.
print(mg.model)

## Post-Training Quantization (PTQ)

Here, we simply quantize the model, without any retraining, and evaluate the effect on its accuracy.

In [None]:
import chop.passes as passes

# Settings for the "linear" entry: the "integer" quantizer, with one
# (width, frac_width) pair each for the input data, the weight and the bias.
linear_quantization = {
    "name": "integer",
    # data
    "data_in_width": 8,
    "data_in_frac_width": 4,
    # weight
    "weight_width": 8,
    "weight_frac_width": 4,
    # bias
    "bias_width": 8,
    "bias_frac_width": 4,
}

# Fallback used for any node that does not have a specific config.
default_quantization = {"name": None}

quantization_config = {
    "by": "type",
    "default": {"config": default_quantization},
    "linear": {"config": linear_quantization},
}

# Run the quantization transform over the graph with the config above.
# The pass hands back a pair; only the transformed graph is kept and the
# second element is discarded.
mg, _ = passes.quantize_transform_pass(mg, pass_args=quantization_config)

In [None]:
from chop.tools import get_tokenized_dataset, get_trainer

# Dataset to tokenize and the checkpoint name passed alongside it.
dataset_name = "imdb"
tokenizer_checkpoint = "bert-base-uncased"

# return_tokenizer=True makes the helper hand back the tokenizer as well.
dataset, tokenizer = get_tokenized_dataset(
    dataset=dataset_name,
    checkpoint=tokenizer_checkpoint,
    return_tokenizer=True,
)

# Trainer built around the graph's current model, with accuracy as its
# evaluation metric.
trainer = get_trainer(
    model=mg.model,
    tokenized_dataset=dataset,
    tokenizer=tokenizer,
    evaluate_metric="accuracy",
)

In [None]:
# Evaluate the quantized model before any further training and report
# the accuracy entry of the returned metrics.
eval_results = trainer.evaluate()
accuracy = eval_results["eval_accuracy"]
print(f"Evaluation accuracy: {accuracy}")

## Quantization-Aware Training (QAT)

In [None]:
from time import time


def train(trainer):
    """Run ``trainer.train()`` and print how long the call took.

    Args:
        trainer: Any object exposing a ``train()`` method.

    The elapsed time in seconds (difference of two ``time()`` readings) is
    printed; nothing is returned.
    """
    began = time()
    trainer.train()
    elapsed = time() - began

    print(f"Training for 1 epoch took {elapsed} seconds")


# Quantization-aware training: run the trainer on the already-quantized
# model (when the cells are executed in order) and print the elapsed time.
train(trainer)

In [None]:
# Evaluate again after the training run and report the accuracy entry of
# the returned metrics.
eval_results = trainer.evaluate()
accuracy = eval_results["eval_accuracy"]
print(f"Evaluation accuracy: {accuracy}")