Reference article: [Quantization-Aware Training with PyTorch](https://levelup.gitconnected.com/quantization-aware-training-with-pytorch-38d0bdb0f873)

In [1]:
# %pip (unlike !pip) installs into the environment of the running kernel.
# Pinned to the torchtune version this notebook was executed with.
%pip install -q torchtune==0.6.1

Collecting torchtune
  Downloading torchtune-0.6.1-py3-none-any.whl.metadata (24 kB)
Collecting torchdata==0.11.0 (from torchtune)
  Downloading torchdata-0.11.0-py3-none-any.whl.metadata (6.3 kB)
Collecting datasets (from torchtune)
  Downloading datasets-3.6.0-py3-none-any.whl.metadata (19 kB)
Collecting tiktoken (from torchtune)
  Downloading tiktoken-0.9.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (6.7 kB)
Collecting blobfile>=2 (from torchtune)
  Downloading blobfile-3.0.0-py3-none-any.whl.metadata (15 kB)
Collecting omegaconf (from torchtune)
  Downloading omegaconf-2.3.0-py3-none-any.whl.metadata (3.9 kB)
Collecting pycryptodomex>=3.8 (from blobfile>=2->torchtune)
  Downloading pycryptodomex-3.22.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (3.4 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets->torchtune)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets->torchtune)
  Downloading xxhash-3.

# 1. Eager Mode Quantization

In [1]:
import os, torch, torch.nn as nn, torch.optim as optim

# 1. Model definition with QuantStub/DeQuantStub
class QATCNN(nn.Module):
    """Small CNN wrapped in QuantStub/DeQuantStub for eager-mode QAT.

    The stubs mark where tensors enter and leave the quantized region: they
    pass tensors through unchanged in float mode and are swapped for
    quantize/dequantize ops by ``torch.ao.quantization.convert``.

    Args:
        in_channels: channels of the input image (default 1).
        num_classes: number of output logits (default 10).
        input_size: height/width of the square input image (default 28).
            The classifier's input width is derived from it: both convs
            preserve spatial size (kernel 3, padding 1) and the single 2x2
            max-pool halves it (floor).
    """

    def __init__(self, in_channels=1, num_classes=10, input_size=28):
        super().__init__()
        self.quant   = torch.ao.quantization.QuantStub()
        self.conv1   = nn.Conv2d(in_channels, 16, 3, padding=1)
        self.relu1   = nn.ReLU()
        self.pool    = nn.MaxPool2d(2)
        self.conv2   = nn.Conv2d(16, 32, 3, padding=1)
        self.relu2   = nn.ReLU()
        pooled = input_size // 2  # spatial size after the single max-pool
        self.fc      = nn.Linear(32 * pooled * pooled, num_classes)
        self.dequant = torch.ao.quantization.DeQuantStub()

    def forward(self, x):
        """Map an (N, in_channels, input_size, input_size) batch to (N, num_classes) logits."""
        x = self.quant(x)
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.relu2(self.conv2(x))
        x = x.flatten(1)  # keep the batch dim, flatten channels x H x W
        x = self.fc(x)
        return self.dequant(x)

# 2. QAT preparation
torch.manual_seed(0)  # reproducible weight init and random demo data
model = QATCNN()
# torch.ao.quantization is the current home of the eager-mode quantization
# APIs; torch.quantization is an older alias for the same functions.
model.qconfig = torch.ao.quantization.get_default_qat_qconfig('fbgemm')
# prepare_qat expects a model in training mode (a freshly built module is).
# NOTE(review): fusing conv+relu pairs with
# torch.ao.quantization.fuse_modules_qat before this call is the usual
# recommendation; it is skipped here to keep the demo minimal.
torch.ao.quantization.prepare_qat(model, inplace=True)

# 3. Tiny training loop on random data -- enough forward passes for the
# fake-quant modules to record weight/activation ranges.
BATCH_SIZE, NUM_STEPS = 16, 3
opt = optim.SGD(model.parameters(), lr=1e-2)
crit = nn.CrossEntropyLoss()
for _ in range(NUM_STEPS):
    inp = torch.randn(BATCH_SIZE, 1, 28, 28)
    tgt = torch.randint(0, 10, (BATCH_SIZE,))
    opt.zero_grad()
    loss = crit(model(inp), tgt)
    loss.backward()
    opt.step()

# 4. Convert to real int8. convert() returns a new module by default,
# so the QAT-prepared `model` is left as it is.
model.eval()
int8_model = torch.ao.quantization.convert(model)


# 5. Storage benefit. "fp32.pth" holds the QAT-prepared model: float
# weights plus the fake-quant/observer buffers.
def mb(path):
    """Return the size of the file at `path` in megabytes (10**6 bytes)."""
    return os.path.getsize(path) / 1e6


torch.save(model.state_dict(), "fp32.pth")
torch.save(int8_model.state_dict(), "int8.pth")
print(f"FP32: {mb('fp32.pth'):.2f} MB  vs  INT8: {mb('int8.pth'):.2f} MB")



FP32: 0.29 MB  vs  INT8: 0.07 MB


# 2. FX Graph Mode Quantization

In [2]:
# import torch, torchvision.models as models
# from torch.ao.quantization import get_default_qat_qconfig_mapping
# from torch.ao.quantization.quantize_fx import prepare_qat_fx, convert_fx

# model = models.resnet18(weights=None)     # or weights="IMAGENET1K_V1" for pretrained weights
# model.train()                             # prepare_qat_fx expects a model in training mode

# # 1-liner qconfig mapping
# qmap = get_default_qat_qconfig_mapping("fbgemm")
# # Graph rewrite (recent PyTorch versions require example inputs for tracing)
# example_inputs = (torch.randn(1, 3, 224, 224),)
# model_prepared = prepare_qat_fx(model, qmap, example_inputs)

# # Fine-tune for a few epochs
# model_prepared.eval()
# int8_resnet = convert_fx(model_prepared)

# 3. PyTorch 2 Export Quantization

In [3]:
import torch
from torch import nn
from torch._export import capture_pre_autograd_graph
from torch.ao.quantization.quantize_pt2e import (
    prepare_qat_pt2e, convert_pt2e)
from torch.ao.quantization.quantizer.xnnpack_quantizer import (
    XNNPACKQuantizer, get_symmetric_quantization_config)

class Tiny(nn.Module):
    """Minimal one-layer model (Linear 8 -> 4) used for the PT2E QAT demo."""

    def __init__(self):
        super().__init__()
        # Sole layer: 8 input features projected to 4 outputs.
        self.fc = nn.Linear(in_features=8, out_features=4)

    def forward(self, x):
        """Apply the linear layer; the last dimension of `x` must be 8."""
        out = self.fc(x)
        return out

torch.manual_seed(0)  # reproducible init and random demo data

ex_in = (torch.randn(2,8),)
exported = torch.export.export_for_training(Tiny(), ex_in).module()
quantizer = XNNPACKQuantizer().set_global(get_symmetric_quantization_config())
qat_mod = prepare_qat_pt2e(exported, quantizer)

# Fine-tune for a few toy steps so the inserted fake-quant modules actually
# see data. Calling convert_pt2e directly after prepare_qat_pt2e would
# derive qparams from observers that never ran a forward pass.
# Inputs keep the example batch shape (2, 8) used at export time.
torch.ao.quantization.move_exported_model_to_train(qat_mod)
qat_opt = torch.optim.SGD(qat_mod.parameters(), lr=1e-2)
for _ in range(3):
    x = torch.randn(2, 8)
    target = torch.randn(2, 4)
    loss = nn.functional.mse_loss(qat_mod(x), target)
    qat_opt.zero_grad()
    loss.backward()
    qat_opt.step()

int8_mod = convert_pt2e(qat_mod)
torch.ao.quantization.move_exported_model_to_eval(int8_mod)



GraphModule(
  (fc): Module()
)

# 4. Large-Language-Model Int8/Int4 Hybrid QAT Demo (int8 dynamic activations, int4 weights)

In [4]:
# Explicit %pip (a bare `pip` only works via IPython automagic) installs into
# the running kernel's environment; pinned to the version used below.
%pip install -q torchao==0.10.0

Collecting torchao
  Downloading torchao-0.10.0-cp39-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl.metadata (15 kB)
Downloading torchao-0.10.0-cp39-abi3-manylinux_2_24_x86_64.manylinux_2_28_x86_64.whl (5.5 MB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/5.5 MB[0m [31m?[0m eta [36m-:--:--[0m[2K   [91m━━━━━━━━━━[0m[90m╺[0m[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.4/5.5 MB[0m [31m43.7 MB/s[0m eta [36m0:00:01[0m[2K   [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m [32m5.4/5.5 MB[0m [31m92.2 MB/s[0m eta [36m0:00:01[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m5.5/5.5 MB[0m [31m64.4 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: torchao
Successfully installed torchao-0.10.0


In [7]:
import torch
from torchtune.models.llama3 import llama3
from torchao.quantization.prototype.qat import Int8DynActInt4WeightQATQuantizer

# Demo configuration. VOCAB_SIZE is shared by the model, the random token
# generator and the loss reshape, so it is defined once.
VOCAB_SIZE = 4096
BATCH_SIZE, SEQ_LEN, NUM_STEPS = 2, 128, 100
# Prefer CUDA; the CPU fallback runs but is very slow for a model this size.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
torch.manual_seed(0)  # reproducible init and random tokens

model = llama3(vocab_size=VOCAB_SIZE, num_layers=16,
               num_heads=16, num_kv_heads=4,
               embed_dim=2048, max_seq_len=2048).to(DEVICE)

# Insert int8 dynamic-activation / int4-weight fake quantization (torchao's
# QAT quantizer), then train so the weights adapt to the quantization noise.
qat_quant = Int8DynActInt4WeightQATQuantizer()
model = qat_quant.prepare(model).train()

# Micro fine-tune on random tokens/labels: exercises the QAT forward/backward
# path only; it does not teach the model anything useful.
# Named `optimizer` (not `optim`) so the `torch.optim as optim` module
# imported earlier in the notebook is not shadowed.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
lossf = torch.nn.CrossEntropyLoss()
for _ in range(NUM_STEPS):
    ids   = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=DEVICE)
    label = torch.randint(0, VOCAB_SIZE, (BATCH_SIZE, SEQ_LEN), device=DEVICE)
    output = model(ids)  # logits: [BATCH_SIZE, SEQ_LEN, VOCAB_SIZE]
    # Flatten batch and sequence dims; reshape (unlike view) also copes with
    # non-contiguous logits.
    loss = lossf(output.reshape(-1, VOCAB_SIZE), label.reshape(-1))
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

# Convert the fake-quantized layers to quantized ones and save the result.
model_quant = qat_quant.convert(model)
torch.save(model_quant.state_dict(), "llama3_int4int8.pth")