In [1]:
import torch
import torch.nn as nn 
import torch.quantization as tq
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [5]:
qat_config = tq.QConfig(
    activation = tq.FakeQuantize.with_args(
        observer = tq.MovingAverageMinMaxObserver,
        quant_min = 0, quant_max = 255,
        dtype = torch.quint8, 
        qscheme = torch.per_tensor_affine
    ),
    weight = tq.FakeQuantize.with_args(
        observer = tq.MinMaxObserver,
        quant_min = -128, quant_max = 127,
        dtype = torch.qint8,
        qscheme = torch.per_tensor_symmetric
    )
)

In [6]:
# We are going to quantize only the linear layers and not the embedding layers

for name, module in model.named_modules():
    if isinstance(module, nn.Embedding):
        model.qconfig = None 

In [7]:
model.qconfig = qat_config

In [8]:
# we will now fuse the model with fake quantization operations

model.train()
tq.prepare_qat(model, inplace=True)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  tq.prepare_qat(model, inplace=True)


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(
      50257, 768
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (wpe): Embedding(
      1024, 768
      (activation_post_process): FakeQuantize(
        fake_quant_enabled=tensor([1], dtype=torch.uint8), observer_enabled=tensor([1], dtype=torch.uint8), quant_min=0, quant_max=255, dtype=torch.quint8, qscheme=torch.per_tensor_affine, ch_axis=-1, scale=tensor([1.]), zero_point=tensor([0], dtype=torch.int32)
        (activation_post_process): MovingAverageMinMaxObserver(min_val=inf, max_val=-inf)
      )
    )
    (drop): Dropout(p=0.1, inplace=False)
    (h):

In [9]:
# Now fine tuning the quantization aware model 

inputs = tokenizer("Quantization Aware Training in LLMs!", return_tensors="pt")

In [10]:
labels = inputs["input_ids"]

In [None]:
# Setting up the optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)