In [1]:
import torch
import torch.nn as nn 
import torch.quantization as tq
from transformers import AutoModelForCausalLM, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from dotenv import load_dotenv
load_dotenv()

True

In [3]:
model_name = "distilgpt2"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

In [4]:
qat_config = tq.QConfig(
    activation = tq.FakeQuantize.with_args(
        observer = tq.MovingAverageMinMaxObserver,
        quant_min = 0, quant_max = 255,
        dtype = torch.quint8, 
        qscheme = torch.per_tensor_affine
    ),
    weight = tq.FakeQuantize.with_args(
        observer = tq.MinMaxObserver,
        quant_min = -128, quant_max = 127,
        dtype = torch.qint8,
        qscheme = torch.per_tensor_symmetric
    )
)

In [5]:
# We are going to quantize only the linear layers and not the embedding layers

for m in model.modules():
    if isinstance(m, nn.Embedding):
        m.qconfig = None
    if isinstance(m, nn.LayerNorm):
        m.qconfig = None

In [6]:
model.qconfig = qat_config

In [7]:
# we will now fuse the model with fake quantization operations

model.train()
tq.prepare_qat(model, inplace=True)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  tq.prepare_qat(model, inplace=True)


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (drop): Dropout(p=0.1, inplace=False)
    (h): ModuleList(
      (0-5): 6 x GPT2Block(
        (ln_1): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (attn): GPT2Attention(
          (c_attn): Conv1D(nf=2304, nx=768)
          (c_proj): Conv1D(nf=768, nx=768)
          (attn_dropout): Dropout(p=0.1, inplace=False)
          (resid_dropout): Dropout(p=0.1, inplace=False)
        )
        (ln_2): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
        (mlp): GPT2MLP(
          (c_fc): Conv1D(nf=3072, nx=768)
          (c_proj): Conv1D(nf=768, nx=3072)
          (act): NewGELUActivation()
          (dropout): Dropout(p=0.1, inplace=False)
        )
      )
    )
    (ln_f): LayerNorm((768,), eps=1e-05, elementwise_affine=True)
  )
  (lm_head): Linear(
    in_features=768, out_features=50257, bias=False
    (weight_fake_quant): FakeQuantize(
      fake_qua

In [8]:
# Now fine tuning the quantization aware model 

inputs = tokenizer("Quantization Aware Training in LLMs!", return_tensors="pt")

In [9]:
labels = inputs["input_ids"]

In [10]:
# Setting up the optimizer

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

In [11]:
for step in range(50):
    ouptuts = model(**inputs, labels=labels)
    loss = ouptuts.loss
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    if step % 10 == 0:
        print(f"Step {step}, Loss: {loss.item()}")

`loss_type=None` was set in the config but it is unrecognized. Using the default loss: `ForCausalLMLoss`.


Step 0, Loss: 9.39015007019043
Step 10, Loss: 2.0682854652404785
Step 20, Loss: 0.5169497132301331
Step 30, Loss: 0.006866047624498606
Step 40, Loss: 0.013810182921588421


In [12]:
# Converting the model to a fully quantized model

qat_model = tq.convert(model.eval(), inplace=False)

For migrations of users: 
1. Eager mode quantization (torch.ao.quantization.quantize, torch.ao.quantization.quantize_dynamic), please migrate to use torchao eager mode quantize_ API instead 
2. FX graph mode quantization (torch.ao.quantization.quantize_fx.prepare_fx,torch.ao.quantization.quantize_fx.convert_fx, please migrate to use torchao pt2e quantization API instead (prepare_pt2e, convert_pt2e) 
3. pt2e quantization has been migrated to torchao (https://github.com/pytorch/ao/tree/main/torchao/quantization/pt2e) 
see https://github.com/pytorch/ao/issues/2259 for more details
  qat_model = tq.convert(model.eval(), inplace=False)


RuntimeError: Didn't find engine for operation quantized::linear_prepack NoQEngine