In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import torch.quantization

# Set device (use CUDA if available)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# MNIST dataset with normalization
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = datasets.MNIST(root="./data", train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root="./data", train=False, transform=transform, download=True)

# Data loaders
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=64, shuffle=False)


In [2]:
class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.quant = torch.ao.quantization.QuantStub()
        self.conv1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.relu1 = nn.ReLU()
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.relu2 = nn.ReLU()
        self.fc1 = nn.Linear(64 * 7 * 7, 128)
        self.relu3 = nn.ReLU()
        self.fc2 = nn.Linear(128, 10)
    
    def forward(self, x):
        x = self.quant(x)
        x = self.pool(self.relu1(self.conv1(x)))
        x = self.pool(self.relu2(self.conv2(x)))
        x = x.reshape(-1, 64 * 7 * 7)
        x = self.relu3(self.fc1(x))
        x = self.fc2(x)
        return x


In [3]:
model = Net().to(device)

# Fuse layers (for better optimization during quantization)
model.fuse_model = lambda: torch.quantization.fuse_modules(model, [["conv1", "relu1"], ["conv2", "relu2"]])
model.fuse_model()

# Specify quantization configuration
model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')

# Prepare for QAT
model_prepared = torch.quantization.prepare_qat(model)




In [4]:
# Loss and optimizer
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model_prepared.parameters(), lr=0.001)

# Training loop
num_epochs = 1
model_prepared.train()
for epoch in range(num_epochs):
    for batch_idx, (images, labels) in enumerate(train_loader):
        images, labels = images.to(device), labels.to(device)

        # Forward pass
        outputs = model_prepared(images)
        loss = criterion(outputs, labels)

        # Backward pass and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        
    print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}")


Epoch [1/5], Loss: 0.0101
Epoch [2/5], Loss: 0.0100
Epoch [3/5], Loss: 0.0046
Epoch [4/5], Loss: 0.1279
Epoch [5/5], Loss: 0.0010


In [5]:
model_quantized = torch.quantization.convert(model_prepared)


In [8]:
model_quantized

Net(
  (quant): Quantize(scale=tensor([0.0157]), zero_point=tensor([64]), dtype=torch.quint8)
  (conv1): QuantizedConv2d(1, 32, kernel_size=(3, 3), stride=(1, 1), scale=0.04050112143158913, zero_point=66, padding=(1, 1))
  (relu1): ReLU()
  (pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
  (conv2): QuantizedConv2d(32, 64, kernel_size=(3, 3), stride=(1, 1), scale=0.1573706567287445, zero_point=85, padding=(1, 1))
  (relu2): ReLU()
  (fc1): QuantizedLinear(in_features=3136, out_features=128, scale=0.6508694887161255, zero_point=82, qscheme=torch.per_channel_affine)
  (relu3): ReLU()
  (fc2): QuantizedLinear(in_features=128, out_features=10, scale=0.3936760127544403, zero_point=72, qscheme=torch.per_channel_affine)
)

In [7]:
# Evaluate quantized model
model_quantized.eval()
correct = 0
total = 0

with torch.no_grad():
    for images, labels in test_loader:
        images, labels = images.to(device), labels.to(device)
        outputs = model_quantized(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print(f"Accuracy of the quantized model on the test set: {100 * correct / total:.2f}%")


RuntimeError: view size is not compatible with input tensor's size and stride (at least one dimension spans across two contiguous subspaces). Use .reshape(...) instead.

In [None]:
model_prepared