In [49]:
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms as transforms
from torch.quantization import quantize_dynamic
torch.backends.quantized.engine = 'qnnpack' # https://github.com/pytorch/pytorch/issues/29327
import time

In [50]:
device = torch.device("mps")

In [51]:
# Define a simple Autoencoder
class Autoencoder(nn.Module):
    def __init__(self, input_size=28*28, encoding_dim=3):
        super(Autoencoder, self).__init__()
        # Layer sizes
        self.input_size = input_size
        self.hidden_layer1_size = 128
        self.hidden_layer2_size = 64
        self.encoding_dim = encoding_dim

        # Encoder layers
        self.encoder = nn.Sequential(
            nn.Linear(self.input_size, self.hidden_layer1_size),
            nn.ReLU(True),
            nn.Linear(self.hidden_layer1_size, self.hidden_layer2_size),
            nn.ReLU(True),
            nn.Linear(self.hidden_layer2_size, self.encoding_dim),
            nn.ReLU(True)
        )
        # Decoder layers
        self.decoder = nn.Sequential(
            nn.Linear(self.encoding_dim, self.hidden_layer2_size),
            nn.ReLU(True),
            nn.Linear(self.hidden_layer2_size, self.hidden_layer1_size),
            nn.ReLU(True),
            nn.Linear(self.hidden_layer1_size, self.input_size),
            nn.Tanh()
        )

    def forward(self, x):
        x = self.encoder(x)
        x = self.decoder(x)
        return x

In [52]:
# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.5,), (0.5,))])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, download=True, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=512, shuffle=True, num_workers=2)
val_dataset = torchvision.datasets.MNIST(root='./data', train=False, download=True, transform=transform)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=512, shuffle=False, num_workers=2)

In [53]:

# Initialize the model, loss function, and optimizer
model = Autoencoder().to(device)
criterion = nn.MSELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

In [54]:
# Training loop
def train(model, num_epochs=5):
    model.train()
    for epoch in range(num_epochs):
        start_time = time.time()
        for data in train_loader:
            img, _ = data
            img = img.view(img.size(0), -1).to(device)
            output = model(img)
            loss = criterion(output, img)

            optimizer.zero_grad(set_to_none=True)
            loss.backward()
            optimizer.step()
        
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}, Time: {time.time() - start_time:.2f}s')

# Train the model
train(model)

Epoch [1/5], Loss: 0.2648, Time: 2.41s
Epoch [2/5], Loss: 0.2606, Time: 2.32s
Epoch [3/5], Loss: 0.2788, Time: 2.26s
Epoch [4/5], Loss: 0.2739, Time: 2.28s
Epoch [5/5], Loss: 0.2721, Time: 2.35s


Baseline perf:
Epoch [1/5], Loss: 0.2267, Time: 2.40s
Epoch [2/5], Loss: 0.2169, Time: 2.30s
Epoch [3/5], Loss: 0.1915, Time: 2.26s
Epoch [4/5], Loss: 0.1854, Time: 2.27s
Epoch [5/5], Loss: 0.1928, Time: 2.34s


Post-Training Quantization

In [55]:
def quantize_model(model):
    # Specify the layers to be quantized
    quantized_model = quantize_dynamic(
        model, 
        {nn.Linear},  # Specify the layers you want to quantize
        dtype=torch.qint8,
    )
    return quantized_model
def evaluate_inference_time(model, data_loader, device):
    model= model.to(device)
    model.eval()
    start_time = time.time()
    with torch.no_grad():
        for img, _ in data_loader:
            img = img.view(img.size(0), -1).to(device)
            model(img)
    end_time = time.time()
    return end_time - start_time
def validate_model(model, data_loader, criterion, device):
    model = model.to(device)
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for img, _ in data_loader:
            img = img.view(img.size(0), -1).to(device)
            output = model(img)
            loss = criterion(output, img)
            total_loss += loss.item()
    return total_loss / len(data_loader)


In [56]:
# Assuming you have a validation loader named 'val_loader' and a criterion
# Quantize the model
quantized_autoencoder = quantize_model(model.to('cpu'))

quant_device = 'cpu'  # cpu or mps
# Measure inference time
original_time = evaluate_inference_time(model, val_loader, quant_device)
quantized_time = evaluate_inference_time(quantized_autoencoder, val_loader, quant_device)

# Validate the models
original_val_loss = validate_model(model, val_loader, criterion, quant_device)
quantized_val_loss = validate_model(quantized_autoencoder, val_loader, criterion, quant_device)

print(f"Original Model - Inference Time: {original_time}, Validation Loss: {original_val_loss}")
print(f"Quantized Model - Inference Time: {quantized_time}, Validation Loss: {quantized_val_loss}")


Original Model - Inference Time: 2.173731803894043, Validation Loss: 0.26968853175640106
Quantized Model - Inference Time: 1.3928959369659424, Validation Loss: 0.26968856155872345


On cpu:
Original Model - Inference Time: 1.9714100360870361, Validation Loss: 0.2335045002400875
Quantized Model(qint8)- Inference Time: 1.4053568840026855, Validation Loss: 0.23385296761989594

Quantization Aware Training

In [67]:
def prepare_model_for_qat(model):
    # Fuse Conv, bn and relu
    # (Add if you have layers that can be fused)

    # Prepare the model for quantization aware training
    model.train()
    model.qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    torch.quantization.prepare_qat(model, inplace=True)
    return model
def train_with_qat(model, train_loader, num_epochs=5):
    model.train()
    model.to("mps")
    for epoch in range(num_epochs):
        for data in train_loader:
            img, _ = data
            img = img.view(img.size(0), -1)
            output = model(img)
            loss = criterion(output, img)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Freeze quantizer parameters
        if epoch > 3:
            model.apply(torch.quantization.disable_observer)
        if epoch > 2:
            model.apply(torch.nn.intrinsic.qat.freeze_bn_stats)

        # Print loss
        print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item():.4f}')

    # Convert to quantized model
    torch.quantization.convert(model, inplace=True)


In [68]:
# Prepare the model for QAT
# The operator 'aten::_fused_moving_avg_obs_fq_helper' is not currently implemented for the MPS device
qat_model = prepare_model_for_qat(model.to("mps"))

# Train the model with QAT
criterion = nn.MSELoss()
qat_model = qat_model.to("mps")
train_with_qat(qat_model, train_loader)
# Validate the QAT model
qat_val_loss = validate_model(qat_model, val_loader, criterion, quant_device)
qat_inference_time = evaluate_inference_time(qat_model, val_loader, quant_device)

print(f"QAT Model - Inference Time: {qat_inference_time}, Validation Loss: {qat_val_loss}")


NotImplementedError: Could not run 'quantized::linear' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'quantized::linear' is only available for these backends: [MPS, QuantizedCPU, BackendSelect, Python, FuncTorchDynamicLayerBackMode, Functionalize, Named, Conjugate, Negative, ZeroTensor, ADInplaceOrView, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradMPS, AutogradXPU, AutogradHPU, AutogradLazy, AutogradMeta, Tracer, AutocastCPU, AutocastCUDA, FuncTorchBatched, FuncTorchVmapMode, Batched, VmapMode, FuncTorchGradWrapper, PythonTLSSnapshot, FuncTorchDynamicLayerFrontMode, PreDispatch, PythonDispatcher].

MPS: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/mps/MPSFallback.mm:75 [backend fallback]
QuantizedCPU: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/native/quantized/cpu/qlinear.cpp:1137 [kernel]
BackendSelect: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/BackendSelectFallbackKernel.cpp:3 [backend fallback]
Python: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/PythonFallbackKernel.cpp:153 [backend fallback]
FuncTorchDynamicLayerBackMode: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/functorch/DynamicLayer.cpp:498 [backend fallback]
Functionalize: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/FunctionalizeFallbackKernel.cpp:290 [backend fallback]
Named: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/NamedRegistrations.cpp:7 [backend fallback]
Conjugate: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/ConjugateFallback.cpp:17 [backend fallback]
Negative: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/native/NegateFallback.cpp:19 [backend fallback]
ZeroTensor: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/ZeroTensorFallback.cpp:86 [backend fallback]
ADInplaceOrView: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:86 [backend fallback]
AutogradOther: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:53 [backend fallback]
AutogradCPU: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:57 [backend fallback]
AutogradCUDA: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:65 [backend fallback]
AutogradXLA: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:69 [backend fallback]
AutogradMPS: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:77 [backend fallback]
AutogradXPU: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:61 [backend fallback]
AutogradHPU: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:90 [backend fallback]
AutogradLazy: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:73 [backend fallback]
AutogradMeta: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/VariableFallbackKernel.cpp:81 [backend fallback]
Tracer: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/torch/csrc/autograd/TraceTypeManual.cpp:296 [backend fallback]
AutocastCPU: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/autocast_mode.cpp:382 [backend fallback]
AutocastCUDA: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/autocast_mode.cpp:249 [backend fallback]
FuncTorchBatched: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/functorch/LegacyBatchingRegistrations.cpp:710 [backend fallback]
FuncTorchVmapMode: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/functorch/VmapModeRegistrations.cpp:28 [backend fallback]
Batched: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/LegacyBatchingRegistrations.cpp:1075 [backend fallback]
VmapMode: fallthrough registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/VmapModeRegistrations.cpp:33 [backend fallback]
FuncTorchGradWrapper: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/functorch/TensorWrapper.cpp:203 [backend fallback]
PythonTLSSnapshot: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/PythonFallbackKernel.cpp:161 [backend fallback]
FuncTorchDynamicLayerFrontMode: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/functorch/DynamicLayer.cpp:494 [backend fallback]
PreDispatch: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/PythonFallbackKernel.cpp:165 [backend fallback]
PythonDispatcher: registered at /private/var/folders/k1/30mswbxs7r1g6zwn8y4fyt500000gp/T/abs_5ae0635zuj/croot/pytorch-select_1700511177724/work/aten/src/ATen/core/PythonFallbackKernel.cpp:157 [backend fallback]
