# Tutorial: Quantizing a PyTorch Model with Metinor

In this tutorial, we will use Metinor to quantize a PyTorch model. Quantization is the process of converting a model's parameters from high-precision floating-point numbers (typically 32-bit) to lower-precision representations, such as integers or reduced-precision floats. This can reduce the model size and make it faster to run on hardware that supports low-precision (e.g., integer) operations.

In [1]:
# Silence every warning category so the notebook output stays readable.
import warnings

warnings.simplefilter("ignore")

## Create a model

Metinor can quantize subclasses of `nn.Module`. In this tutorial, we will use a pre-trained MobileNet-V2 from `torchvision`.

In [2]:
from torchvision.models import mobilenet_v2, MobileNet_V2_Weights

# Input size (batch, channels, height, width) used later when tracing the
# model graph.
input_shape = (1, 3, 224, 224)
# Load MobileNet-V2 with torchvision's default pre-trained weights; this is the
# model we quantize throughout the tutorial.
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)


def print_model(model, level=0):
    """Recursively print the module tree of *model* with parameter dtypes.

    Each child module is printed on its own line as
    ``name: ClassName(weight=<dtype>, bias=<dtype>)``, indented two spaces per
    nesting level. The ``weight`` / ``bias`` entries are shown only when the
    child has that attribute and it is not ``None``.

    Args:
        model: An ``nn.Module`` (anything exposing ``named_children()``).
        level: Current nesting depth; controls indentation. Callers normally
            leave it at 0.
    """
    indent = "  " * level
    for name, child in model.named_children():
        attrs = []
        for attr in ("weight", "bias"):
            tensor = getattr(child, attr, None)
            if tensor is not None:
                attrs.append(f"{attr}={tensor.dtype}")
        # Join only the entries that are present, so a bias-only module does
        # not render as "Name(, bias=...)".
        print(f"{indent}{name}: {child.__class__.__name__}({', '.join(attrs)})")
        # Recurse only into containers; next() checks for a first child
        # without materialising the whole list of children.
        if next(child.children(), None) is not None:
            print_model(child, level + 1)


# Dump the module tree of the freshly loaded model; every weight and bias
# reports torch.float32 at this point.
print_model(model, level=0)

features: Sequential()
  0: Conv2dNormActivation()
    0: Conv2d(weight=torch.float32)
    1: BatchNorm2d(weight=torch.float32, bias=torch.float32)
    2: ReLU6()
  1: InvertedResidual()
    conv: Sequential()
      0: Conv2dNormActivation()
        0: Conv2d(weight=torch.float32)
        1: BatchNorm2d(weight=torch.float32, bias=torch.float32)
        2: ReLU6()
      1: Conv2d(weight=torch.float32)
      2: BatchNorm2d(weight=torch.float32, bias=torch.float32)
  2: InvertedResidual()
    conv: Sequential()
      0: Conv2dNormActivation()
        0: Conv2d(weight=torch.float32)
        1: BatchNorm2d(weight=torch.float32, bias=torch.float32)
        2: ReLU6()
      1: Conv2dNormActivation()
        0: Conv2d(weight=torch.float32)
        1: BatchNorm2d(weight=torch.float32, bias=torch.float32)
        2: ReLU6()
      2: Conv2d(weight=torch.float32)
      3: BatchNorm2d(weight=torch.float32, bias=torch.float32)
  3: InvertedResidual()
    conv: Sequential()
      0: Conv2dNormActivat

## Import Quantization Functions

In [3]:
# Metinor's quantization API: `quantize` for whole models, `quantize_node` for
# a single node of a model graph, and a helper listing the available strategies.
from metinor.functional.quantization import (
    list_quantization_strategies,
    quantize,
    quantize_node,
)

# Show which quantization strategies can be passed to `quantize`
list_quantization_strategies()

['WeightOnlyQuantization', 'WeightBiasQuantization']

## Complete Model Quantization

### Weight-Only Quantization

In [4]:
# Quantize only the weights of the whole model; going by the variable name, the
# "float"/16 arguments presumably select a 16-bit floating-point format.
# NOTE(review): if `quantize` modifies `model` in place, later cells operate on
# an already-quantized model — confirm against Metinor's documentation.
model_quant_weightonly_float16 = quantize(
    model, "WeightOnlyQuantization", "float", 16
)
print_model(model_quant_weightonly_float16)

features: Sequential()
  0: Conv2dNormActivation()
    0: QuantConv2d(weight=torch.float32)
      input_quant: ActQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
      output_quant: ActQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
      weight_quant: WeightQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
        tensor_quant: RescalingIntQuant()
          int_quant: IntQuant()
            float_to_int_impl: RoundSte()
            tensor_clamp_impl: TensorClampSte()
            delay_wrapper: DelayWrapper()
              delay_impl: _NoDelay()
          scaling_impl: StatsFromParameterScaling()
            parameter_list_stats: _ParameterListStats()
              first_tracked_param: _ViewParameterWrapper()
                view_shape_impl: OverTensorView()
              stats: _Stats()
                stats_impl: AbsMax()
            stats_scaling_impl: _StatsScaling()
              affine_rescaling: Identity()
         

### Weight-Bias Quantization

In [5]:
# Quantize both weights and biases; going by the variable name, the "fixed"/8
# arguments presumably select an 8-bit fixed-point (integer) format.
# NOTE(review): this reuses `model` from the previous cell; if `quantize` works
# in place, that model is already quantized — confirm.
model_quant_weightbias_int8 = quantize(
    model, "WeightBiasQuantization", "fixed", 8
)
print_model(model_quant_weightbias_int8)

features: Sequential()
  0: Conv2dNormActivation()
    0: QuantConv2d(weight=torch.float32, bias=torch.float32)
      input_quant: ActQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
      output_quant: ActQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
      weight_quant: WeightQuantProxyFromInjector()
        _zero_hw_sentinel: StatelessBuffer()
        tensor_quant: RescalingIntQuant()
          int_quant: IntQuant()
            float_to_int_impl: RoundSte()
            tensor_clamp_impl: TensorClampSte()
            delay_wrapper: DelayWrapper()
              delay_impl: _NoDelay()
          scaling_impl: StatsFromParameterScaling()
            parameter_list_stats: _ParameterListStats()
              first_tracked_param: _ViewParameterWrapper()
                view_shape_impl: OverTensorView()
              stats: _Stats()
                stats_impl: AbsMax()
            stats_scaling_impl: _StatsScaling()
              affine_rescaling: 

## Layerwise Quantization

In [None]:
from metinor.visualizer import draw_graph, compute_max_depth
from metinor.functional.quantization import quantize_node

# Start again from a fresh copy of the pre-trained model.
# NOTE(review): presumably needed because the earlier `quantize` calls may
# have modified `model` in place — confirm.
model = mobilenet_v2(weights=MobileNet_V2_Weights.DEFAULT)

# Expand the graph down to the model's maximum nesting depth.
depth = compute_max_depth(model)
graph = draw_graph(model, input_size=input_shape, depth=depth, device="cpu")

# Map each node id to its graph node (not the raw nn.Module; the wrapped
# module is reachable via `node.module_unit`).
node_ids = list(graph.id_dict.keys())
node_module_dict = {node_id: graph.find_node_by_id(node_id) for node_id in node_ids}

# Pick the node to quantize once, so the index is not repeated below.
# NOTE(review): index 1 is hard-coded; presumably index 0 is the graph's input
# node rather than a layer — confirm.
target_id = node_ids[1]

node = node_module_dict[target_id]
print("Node Module: ", node.module_unit)

# Quantize just this node's module (weight-only, "float" format, 4 bits).
quantized_module = quantize_node(
    target_id, graph, "WeightOnlyQuantization", "float", 4
)

# Look the node up again to check that the graph now holds the updated module.
node = graph.find_node_by_id(target_id)
print("Updated Node: ", node.module_unit)

# Print the model itself to see whether the change is reflected there too.
print_model(model)