In [1]:
#Dynamic Quantization

import torch
import torch.nn as nn
import torch.quantization

class SimpleModel(nn.Module):
    def __init__(self):
        super(SimpleModel, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = SimpleModel()
model.eval()
quantized_model = torch.quantization.quantize_dynamic(
    model, {nn.Linear}, dtype=torch.qint8
)

print("Original model:", model)
print("Quantized model:", quantized_model)


Original model: SimpleModel(
  (fc1): Linear(in_features=128, out_features=64, bias=True)
  (relu): ReLU()
  (fc2): Linear(in_features=64, out_features=10, bias=True)
)
Quantized model: SimpleModel(
  (fc1): DynamicQuantizedLinear(in_features=128, out_features=64, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
  (relu): ReLU()
  (fc2): DynamicQuantizedLinear(in_features=64, out_features=10, dtype=torch.qint8, qscheme=torch.per_tensor_affine)
)


In [2]:
#Static Quantization

import torch.quantization as quant

model = SimpleModel()
model.eval()

model.qconfig = quant.get_default_qconfig('fbgemm')
prepared_model = quant.prepare(model)

def calibration_data():
    for _ in range(100):
        yield torch.randn(1, 128)

for data in calibration_data():
    prepared_model(data)

quantized_model = quant.convert(prepared_model)

print("Quantized model:", quantized_model)




Quantized model: SimpleModel(
  (fc1): QuantizedLinear(in_features=128, out_features=64, scale=0.03221610561013222, zero_point=66, qscheme=torch.per_channel_affine)
  (relu): ReLU()
  (fc2): QuantizedLinear(in_features=64, out_features=10, scale=0.015236616134643555, zero_point=79, qscheme=torch.per_channel_affine)
)


In [3]:
#QAT (Quantization Aware Training)

class QATModel(nn.Module):
    def __init__(self):
        super(QATModel, self).__init__()
        self.fc1 = nn.Linear(128, 64)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(64, 10)

    def forward(self, x):
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x

model = QATModel()
model.train()

model.qconfig = quant.get_default_qat_qconfig('fbgemm')
qat_model = quant.prepare_qat(model)

optimizer = torch.optim.SGD(qat_model.parameters(), lr=0.01, momentum=0.9)

for epoch in range(5):
    optimizer.zero_grad()
    input_data = torch.randn(16, 128)
    output = qat_model(input_data)
    loss = nn.CrossEntropyLoss()(output, torch.randint(0, 10, (16,)))
    loss.backward()
    optimizer.step()

qat_model.eval()
quantized_model = quant.convert(qat_model)

print("Quantized model after QAT:", quantized_model)


Quantized model after QAT: QATModel(
  (fc1): QuantizedLinear(in_features=128, out_features=64, scale=0.02763197384774685, zero_point=60, qscheme=torch.per_channel_affine)
  (relu): ReLU()
  (fc2): QuantizedLinear(in_features=64, out_features=10, scale=0.009817427955567837, zero_point=48, qscheme=torch.per_channel_affine)
)
