### Quantizing a deep model’s weights from FP32 (32-bit floating point) to a bit-width of b. We have implemented asymmetric quantization and symmetric quantization methods.

In [4]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import urllib.request
import os
from collections import defaultdict, OrderedDict

def download_url(url):
    filename = url.split('/')[-1]
    if not os.path.exists(filename):
        print(f"Downloading {url}")
        urllib.request.urlretrieve(url, filename)
    return filename

class VGG(nn.Module):
  ARCH = [64, 128, 'M', 256, 256, 'M', 512, 512, 'M', 512, 512, 'M']

  def __init__(self) -> None:
    super().__init__()

    layers = []
    counts = defaultdict(int)

    def add(name: str, layer: nn.Module) -> None:
      layers.append((f"{name}{counts[name]}", layer))
      counts[name] += 1

    in_channels = 3
    for x in self.ARCH:
      if x != 'M':
        # conv-bn-relu
        add("conv", nn.Conv2d(in_channels, x, 3, padding=1, bias=False))
        add("bn", nn.BatchNorm2d(x))
        add("relu", nn.ReLU(True))
        in_channels = x
      else:
        # maxpool
        add("pool", nn.MaxPool2d(2))
    add("avgpool", nn.AvgPool2d(2))
    self.backbone = nn.Sequential(OrderedDict(layers))
    self.classifier = nn.Linear(512, 10)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    # backbone: [N, 3, 32, 32] => [N, 512, 2, 2]
    x = self.backbone(x)

    # avgpool: [N, 512, 2, 2] => [N, 512]
    # x = x.mean([2, 3])
    x = x.view(x.shape[0], -1)

    # classifier: [N, 512] => [N, 10]
    x = self.classifier(x)
    return x

def symmetric_quantize_layer(weight, bit_width):
    """
    Symmetrically quantize a layer's weights
    """
    n_levels = 2 ** (bit_width - 1) - 1
    scale = torch.max(torch.abs(weight)) / n_levels
    
    # Quantize weights
    quantized_weight = torch.clip(torch.round(weight / scale), -n_levels, n_levels)
    
    # Dequantize
    dequantized_weight = quantized_weight * scale
    return dequantized_weight, scale

def asymmetric_quantize_layer(weight, bit_width):
    """
    Asymmetrically quantize a layer's weights
    """
    n_levels = 2 ** bit_width - 1
    w_max = torch.max(weight)
    w_min = torch.min(weight)
    
    scale = (w_max - w_min) / n_levels
    zero_point = torch.round(-w_min / scale)
    
    # Quantize weights
    quantized_weight = torch.clip(torch.round(weight / scale + zero_point), 0, n_levels)
    
    # Dequantize
    dequantized_weight = (quantized_weight - zero_point) * scale
    return dequantized_weight, scale, zero_point

def quantize_model(model, bit_width, quantization_type='symmetric'):
    """
    Quantize all weights in the model
    """
    quantized_state_dict = {}
    scaling_factors = {}
    zero_points = {}
    
    # Quantize each parameter in the model
    for name, param in model.state_dict().items():
        if 'weight' in name:  # Only quantize weights, not biases
            if quantization_type == 'symmetric':
                quantized_weight, scale = symmetric_quantize_layer(param.data, bit_width)
                quantized_state_dict[name] = quantized_weight
                scaling_factors[name] = scale
            else:  # asymmetric
                quantized_weight, scale, zero_point = asymmetric_quantize_layer(param.data, bit_width)
                quantized_state_dict[name] = quantized_weight
                scaling_factors[name] = scale
                zero_points[name] = zero_point
        else:
            quantized_state_dict[name] = param.data
            
    return quantized_state_dict, scaling_factors, zero_points

def evaluate_model(model, test_loader, device):
    """
    Evaluate the model's accuracy
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in test_loader:
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            _, predicted = torch.max(outputs.data, 1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    return 100 * correct / total

def main():
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Load model
    checkpoint_url = "https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth"
    model = VGG().to(device)
    checkpoint = torch.load(download_url(checkpoint_url), map_location=device)
    model.load_state_dict(checkpoint['state_dict'])
    print(f"=> loaded checkpoint '{checkpoint_url}'")
    
    # Original model size
    original_size = sum(p.numel() * 32 for p in model.parameters()) / 8  # in bytes
    print(f"Original model size: {original_size / 1024 / 1024:.2f} MB")
    
    # Quantize model (try both symmetric and asymmetric)
    bit_widths = [8, 4, 2]  # Test different bit widths
    
    for bit_width in bit_widths:
        print(f"\nQuantizing to {bit_width} bits:")
        
        # Symmetric quantization
        quantized_state_dict, scaling_factors, _ = quantize_model(model, bit_width, 'symmetric')
        
        # Create a new model instance for the quantized weights
        quantized_model = VGG().to(device)
        quantized_model.load_state_dict(quantized_state_dict)
        
        # Calculate quantized model size
        quantized_size = sum(p.numel() * bit_width for p in quantized_model.parameters()) / 8  # in bytes
        
        print(f"Symmetric quantization:")
        print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
        print(f"Compression ratio: {original_size / quantized_size:.2f}x")
        
        # Calculate quantization error
        error = 0
        total_params = 0
        for (name, orig_param), (_, quant_param) in zip(model.named_parameters(), 
                                                       quantized_model.named_parameters()):
            if 'weight' in name:
                error += torch.mean((orig_param - quant_param) ** 2).item()
                total_params += 1
        
        print(f"Average quantization MSE: {error/total_params:.6f}")
        
        # Do the same for asymmetric quantization
        quantized_state_dict, scaling_factors, zero_points = quantize_model(model, bit_width, 'asymmetric')
        quantized_model = VGG().to(device)
        quantized_model.load_state_dict(quantized_state_dict)
        
        print(f"\nAsymmetric quantization:")
        print(f"Quantized model size: {quantized_size / 1024 / 1024:.2f} MB")
        print(f"Compression ratio: {original_size / quantized_size:.2f}x")
        
        error = 0
        for (name, orig_param), (_, quant_param) in zip(model.named_parameters(), 
                                                       quantized_model.named_parameters()):
            if 'weight' in name:
                error += torch.mean((orig_param - quant_param) ** 2).item()
        
        print(f"Average quantization MSE: {error/total_params:.6f}")

if __name__ == "__main__":
    main()

=> loaded checkpoint 'https://hanlab18.mit.edu/files/course/labs/vgg.cifar.pretrained.pth'
Original model size: 35.20 MB

Quantizing to 8 bits:
Symmetric quantization:
Quantized model size: 8.80 MB
Compression ratio: 4.00x
Average quantization MSE: 0.000001

Asymmetric quantization:
Quantized model size: 8.80 MB
Compression ratio: 4.00x
Average quantization MSE: 0.000000

Quantizing to 4 bits:
Symmetric quantization:
Quantized model size: 4.40 MB
Compression ratio: 8.00x
Average quantization MSE: 0.000220

Asymmetric quantization:
Quantized model size: 4.40 MB
Compression ratio: 8.00x
Average quantization MSE: 0.000077

Quantizing to 2 bits:
Symmetric quantization:
Quantized model size: 2.20 MB
Compression ratio: 16.00x
Average quantization MSE: 0.010571

Asymmetric quantization:
Quantized model size: 2.20 MB
Compression ratio: 16.00x
Average quantization MSE: 0.001350
