# QAT PROJECT: SST-2 DATASET

This file aims to load a model trained with pytorch and convert it to onnx format, the final objective is to apply QAT and observe the trade-offs between the baseline model and the optimized model.

In [18]:
from config import (
    FINE_TUNED_MODEL_SAVE_PATH,
    TOKENIZED_DATASET_SAVE_PATH, 
    TOKENIZER_SAVE_PATH, 
    PER_DEVICE_EVAL_BATCH_SIZE, 
    PER_DEVICE_TRAIN_BATCH_SIZE,
    SUBSET_SIZE, 
    NUM_PROCESSES_FOR_MAP, 
    MAX_SEQUENCE_LENGTH, 
    MODEL_NAME,
    QUANTIZED_QAT_MODEL_SAVE_PATH,
    #ONNX_MODEL_SAVE_PATH
)
import torch

if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

print(f"Using device: {device}")


SUBSET_SIZE = 100

Using device: cpu


In [6]:
from transformers import DistilBertForSequenceClassification

model = DistilBertForSequenceClassification.from_pretrained(FINE_TUNED_MODEL_SAVE_PATH)

print("Model loaded successfully from:", FINE_TUNED_MODEL_SAVE_PATH)

Model loaded successfully from: ./fine_tuned_baseline_model


In [7]:
from src.evaluate_models import evaluate_pytorch_model
from src.data_preparation import load_and_preprocess_data, get_subsetted_datasets

sst2_ds, tokenized_ds, parent_tokenizer = load_and_preprocess_data(
    model_name=MODEL_NAME,
    tokenizer_save_path=TOKENIZER_SAVE_PATH,
    tokenized_dataset_save_path=TOKENIZED_DATASET_SAVE_PATH,
    max_length=MAX_SEQUENCE_LENGTH,
    num_processes_for_map=NUM_PROCESSES_FOR_MAP
)

tok_train_ds, tok_val_ds = get_subsetted_datasets(
    tokenized_ds=tokenized_ds,
    train_subset_size=SUBSET_SIZE,
)

Loading SST-2 dataset...
Loading tokenizer from local path: ./distilbert_tokenizer_local
Loading tokenized dataset from: ./SST2_tokenized_dataset

Using a SUBSET for training (Train size: 100).
Final subset sizes: Train=100, Eval=10


In [8]:
print("\nStarting evaluation of the baseline model...")
evaluate_pytorch_model(
    model_path=FINE_TUNED_MODEL_SAVE_PATH,
    eval_dataset=tok_val_ds,
    batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    tokenizer=parent_tokenizer
)
print("Baseline model evaluation complete!")


Starting evaluation of the baseline model...
Evaluation device: cpu


Evaluating:   0%|          | 0/1 [00:00<?, ?it/s]

PyTorch Model Accuracy: 0.9000
Average Inference Time per Batch: 0.5241 seconds
Model Size: 255.43 MB
Baseline model evaluation complete!


In [9]:
from optimum.exporters.onnx import main_export

def export_to_onnx(model_name_or_path, output, tokenizer, task="sequence-classification", opset=17):
    """
    Export a model to ONNX format.
    
    Args:
        model_name_or_path (str): Path to the model.
        output (str): Output directory for the ONNX model.
        task (str): Task type for the model.
        tokenizer: Tokenizer used for the model.
        opset (int): ONNX opset version.
    """
    main_export(
        model_name_or_path=model_name_or_path,
        output=output,
        task=task,
        tokenizer=tokenizer,
        opset=opset
    )

In [12]:
from torch.quantization import (
    QConfig,
    FakeQuantize,
    PerChannelMinMaxObserver,
    MovingAverageMinMaxObserver
)
from torch.ao.quantization.qconfig_mapping import QConfigMapping
from torch.ao.quantization import get_default_qconfig, prepare_qat, convert


model.train()

# Exclude embeddings from quantization
model.distilbert.embeddings.qconfig = None


custom_qconfig = QConfig(
    activation=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, 
        quant_min=0, 
        quant_max=255, 
        dtype=torch.quint8
    ),
    weight=FakeQuantize.with_args(
        observer=MovingAverageMinMaxObserver, 
        quant_min=-128, 
        quant_max=127, 
        dtype=torch.qint8,
        qscheme=torch.per_tensor_symmetric)
)

model.qconfig = custom_qconfig

# qat_model = prepare_qat_fx(model)
qat_model = prepare_qat(model, inplace=False)
qat_model.to(device)

print("PyTorch model prepared for Quantization-Aware Training.")

PyTorch model prepared for Quantization-Aware Training.


In [14]:
from transformers import TrainingArguments, Trainer
from src.utils import compute_metrics

num_qat_epochs = 2
qat_output_dir = "./qat_finetuning_output"
qat_learning_rate = 2e-5

print(f"\nStarting QAT fine-tuning for {num_qat_epochs} epochs...")

qat_training_args = TrainingArguments(
    output_dir=qat_output_dir,
    num_train_epochs=num_qat_epochs,
    per_device_train_batch_size=PER_DEVICE_TRAIN_BATCH_SIZE,
    per_device_eval_batch_size=PER_DEVICE_EVAL_BATCH_SIZE,
    learning_rate=qat_learning_rate,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    logging_dir=f"{qat_output_dir}/logs",
    logging_steps=50,
    load_best_model_at_end=True,
    metric_for_best_model="accuracy",
    greater_is_better=True,
    fp16=torch.cuda.is_available(),
    report_to="tensorboard",
)

qat_trainer = Trainer(
    model=qat_model,
    args=qat_training_args,
    train_dataset=tok_train_ds,
    eval_dataset=tok_val_ds,
    processing_class=parent_tokenizer,
    compute_metrics=compute_metrics,
)

qat_trainer.train()
print("QAT fine-tuning complete.")


Starting QAT fine-tuning for 2 epochs...




Epoch,Training Loss,Validation Loss,Accuracy
1,No log,0.081199,0.9
2,No log,0.019116,1.0




QAT fine-tuning complete.


In [44]:
import os
from transformers import AutoTokenizer, AutoConfig

# --- Convert the model to its final quantized version ---
print("\nConverting QAT-trained model to final quantized version...")
qat_model.eval() 
quantized_model_obj = torch.quantization.convert(model, inplace=True)
print("Model converted to final quantized version.")

qat_model_save_path = "./quantized_pytorch_model"

# --- Save the QAT-trained and converted model ---
if not os.path.exists(qat_model_save_path):
    os.makedirs(qat_model_save_path)

torch.save(quantized_model_obj.state_dict(), os.path.join(qat_model_save_path, "pytorch_model.bin"))
print(f"QAT-trained model object saved to: {os.path.join(qat_model_save_path, 'pytorch_model.bin')}")

original_config = AutoConfig.from_pretrained(FINE_TUNED_MODEL_SAVE_PATH)
original_config.save_pretrained(qat_model_save_path)
tokenizer_obj = AutoTokenizer.from_pretrained(FINE_TUNED_MODEL_SAVE_PATH)
tokenizer_obj.save_pretrained(qat_model_save_path)
print("QAT model config and tokenizer saved.")


Converting QAT-trained model to final quantized version...
Model converted to final quantized version.
QAT-trained model object saved to: ./quantized_pytorch_model\pytorch_model.bin
QAT model config and tokenizer saved.


In [34]:
from optimum.onnxruntime.configuration import QuantizationConfig, QuantFormat, QuantType

qat_model.eval()

quantization_config_onnx = QuantizationConfig(
    is_static=True,
    per_channel=True,
    format=QuantFormat.QDQ,
    operators_to_quantize=["MatMul", "Gemm"],
    weights_symmetric=True,
    activations_symmetric=False,
    weights_dtype=QuantType.QInt8,
    activations_dtype=QuantType.QUInt8,
)

In [45]:
from optimum.exporters.onnx import main_export

# Set the output directory for the final ONNX model
output_onnx_dir = "./onnx_models_quantized"
os.makedirs(output_onnx_dir, exist_ok=True)
output_path = os.path.join(output_onnx_dir, "model.onnx")

main_export(
    # Pass the path to the directory containing the FP32 model
    model_name_or_path=qat_model_save_path,
    output=output_onnx_dir,
    task="sequence-classification",
    tokenizer=parent_tokenizer,
    opset=17,
    # This tells the exporter to perform quantization during export
    quantization_config=quantization_config_onnx,
    library_name='transformers',
    framework='pt',
)


In [48]:
from src.evaluate_models import evaluate_onnx_model

onnx_metrics, onnx_inference_time, onnx_model_size = evaluate_onnx_model(
    onnx_model_path=output_onnx_dir + "/model.onnx",
    tokenizer=parent_tokenizer,
    eval_dataset=tok_val_ds,
    use_gpu=torch.cuda.is_available(),
    batch_size=PER_DEVICE_EVAL_BATCH_SIZE
)


Evaluating ONNX model...


Evaluating ONNX Model:   0%|          | 0/1 [00:00<?, ?it/s]

ONNX Model Accuracy on CPU: 0.9000
Average Inference Time per Batch on CPU: 0.8408 seconds
ONNX Model Size: 255.52 MB
