# Run Quantized ResNet50 via MIGraphX
This notebook walks through the post-training quantization (PTQ) workflow for running a quantized model with torch_migraphx.

## 1. Use PyTorch's Quantization API to perform quantization
We will closely follow the steps in the [PyTorch docs](https://pytorch.org/docs/stable/quantization.html#prototype-fx-graph-mode-quantization) for FX graph mode quantization.

In [None]:
import torch
import torch.ao.quantization.quantize_fx as quantize_fx
from torchvision import models

In [None]:
model_fp32 = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1).eval()
input_fp32 = torch.randn(2, 3, 28, 28)

torch_fp32_out = model_fp32(input_fp32)

### Use the Quantization API to prepare and calibrate the model
Torch-MIGraphX provides a qconfig mapping and a backend config with the recommended settings for MIGraphX-compatible quantization. Other configs also work as long as they produce symmetric quantization, which is currently the only scheme MIGraphX supports.
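
To make "symmetric" concrete, here is a minimal pure-Python sketch of symmetric per-tensor int8 quantization. It is illustrative only and is not the PyTorch observer or MIGraphX implementation. The names `symmetric_scale`, `quantize`, and `dequantize`, and the clamped range of -127 to 127, are assumptions made for this sketch. The key property is that the zero point is fixed at 0, so the float value 0.0 always maps to the integer code 0.

```python
# Minimal sketch of symmetric per-tensor int8 quantization (illustrative only;
# not the torch_migraphx or PyTorch observer implementation).

QMIN, QMAX = -127, 127  # symmetric int8 range assumed for this sketch


def symmetric_scale(values):
    """Derive a scale from the largest magnitude seen during calibration."""
    max_abs = max(abs(v) for v in values)
    # Guard against all-zero data, which would otherwise give a zero scale.
    return max_abs / QMAX if max_abs > 0 else 1.0


def quantize(values, scale):
    """Map floats to int8 codes; the zero point is fixed at 0."""
    return [max(QMIN, min(QMAX, round(v / scale))) for v in values]


def dequantize(codes, scale):
    """Map int8 codes back to approximate float values."""
    return [c * scale for c in codes]


# Toy calibration data, made up for illustration.
calibration_data = [-1.27, -0.5, 0.0, 0.25, 1.0]
scale = symmetric_scale(calibration_data)
codes = quantize(calibration_data, scale)
restored = dequantize(codes, scale)

print("scale:", scale)
print("codes:", codes)
print("restored:", restored)
```

Because only a scale is stored per tensor, the quantized range is centered on zero. An asymmetric scheme would also carry a non-zero zero point.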



In [None]:
from torch_migraphx.fx.quantization import (
    get_migraphx_backend_config,
    get_migraphx_qconfig_mapping,
)

In [None]:
# Prepare
qconfig_mapping = get_migraphx_qconfig_mapping()
backend_config = get_migraphx_backend_config()

model_prepared = quantize_fx.prepare_fx(
    model_fp32,
    qconfig_mapping,
    (input_fp32, ),
    backend_config=backend_config,
)

# Pseudo-calibrate with random data (use representative real inputs in practice)
for _ in range(100):
    inp = torch.randn(2, 3, 28, 28)
    model_prepared(inp)
    

# Convert to quantized model
model_quantized = quantize_fx.convert_fx(
    model_prepared,
    qconfig_mapping=qconfig_mapping,
    backend_config=backend_config,
)

# Reference torch int8 cpu output
torch_qout = model_quantized(input_fp32)

## 2. Lower Quantized Model to MIGraphX
This step is the same as lowering any other model through the FX tracing path. Note that the accuracy check generally needs to be suppressed when lowering a quantized model: the lowering pass compares the PyTorch INT8 result with the MIGraphX INT8 result, and in practice the two implementations can differ significantly for some values. In this ResNet example, as in most others, the MIGraphX result tends to be closer to the FP32 reference output than the PyTorch INT8 result.

In [None]:
from torch_migraphx.fx import lower_to_mgx

In [None]:
mgx_model = lower_to_mgx(
    model_quantized,
    (input_fp32, ),
    suppress_accuracy_check=True,
)

# MIGraphX int8 output
mgx_out = mgx_model(input_fp32.cuda())

Compare the outputs:

In [None]:
print(f"PyTorch FP32 (Gold Value):\n{torch_fp32_out}")
print(f"PyTorch INT8 (CPU Impl):\n{torch_qout}")
print(f"MIGraphX INT8:\n{mgx_out}")
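
Printing raw tensors makes it hard to judge which INT8 result is closer to the FP32 reference. The following is a minimal sketch of two summary metrics, written against plain nested lists of logits so that it runs without a GPU. The helper names `max_abs_error` and `top1_agreement` are hypothetical and are not part of torch_migraphx, and the toy logits are made up for illustration.

```python
# Sketch of summarizing how far a candidate output is from a reference.
# Inputs are plain nested lists of logits with shape [batch][classes].


def max_abs_error(reference, candidate):
    """Largest elementwise absolute difference between two logit batches."""
    return max(
        abs(r - c)
        for ref_row, cand_row in zip(reference, candidate)
        for r, c in zip(ref_row, cand_row)
    )


def top1_agreement(reference, candidate):
    """Fraction of samples whose argmax class matches the reference."""

    def argmax(row):
        return max(range(len(row)), key=row.__getitem__)

    matches = sum(argmax(r) == argmax(c) for r, c in zip(reference, candidate))
    return matches / len(reference)


# Toy logits, made up for illustration (not real model outputs).
fp32_logits = [[0.1, 2.0, -1.0], [1.5, 0.2, 0.3]]
int8_logits = [[0.0, 1.8, -0.9], [0.4, 0.2, 0.6]]

print(f"max abs error: {max_abs_error(fp32_logits, int8_logits):.3f}")
print(f"top-1 agreement: {top1_agreement(fp32_logits, int8_logits):.2f}")
```

With the notebook's tensors, the nested lists could be obtained with, for example, `torch_fp32_out.detach().cpu().tolist()`. Since the input here is random noise, the error magnitude is more informative than the top-1 agreement.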

## 3. Performance
Let's run a quick test to measure the performance gain from quantization.

In [None]:
import copy
from torch_migraphx.fx.utils import LowerPrecision

In [None]:
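# We will use this function to benchmark all modules:
def benchmark_module(model, inputs, iterations=100):
    """Return the average latency of model(*inputs) in milliseconds."""
    # Warm-up run so one-time initialization is not included in the timing
    model(*inputs)
    torch.cuda.synchronize()

    start_event = torch.cuda.Event(enable_timing=True)
    end_event = torch.cuda.Event(enable_timing=True)

    start_event.record()
    for _ in range(iterations):
        model(*inputs)
    end_event.record()
    # Wait for all queued GPU work to finish before reading the timers
    torch.cuda.synchronize()

    # elapsed_time returns milliseconds
    return start_event.elapsed_time(end_event) / iterations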

In [None]:
# Benchmark Torch FP32 for baseline
torch_fp32_time = benchmark_module(model_fp32.cuda(), [input_fp32.cuda()])

In [None]:
# Benchmark MIGraphX FP32
mgx_module_fp32 = lower_to_mgx(copy.deepcopy(model_fp32), [input_fp32])
mgx_fp32_time = benchmark_module(mgx_module_fp32, [input_fp32.cuda()])

In [None]:
# Benchmark MIGraphX FP16
mgx_module_fp16 = lower_to_mgx(copy.deepcopy(model_fp32), [input_fp32], lower_precision=LowerPrecision.FP16)
mgx_fp16_time = benchmark_module(mgx_module_fp16, [input_fp32.cuda()])

In [None]:
# Benchmark MIGraphX INT8
mgx_int8_time = benchmark_module(mgx_model, [input_fp32.cuda()])

In [None]:
print(f"{torch_fp32_time=:0.4f}ms")
print(f"{mgx_fp32_time=:0.4f}ms")
print(f"{mgx_fp16_time=:0.4f}ms")
print(f"{mgx_int8_time=:0.4f}ms")
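
The raw latencies are easier to interpret as speedups relative to a baseline. Below is a small sketch of that arithmetic. The helpers `speedup` and `format_speedups` are hypothetical names introduced here, and the numbers in `example_timings` are placeholders chosen only to exercise the code, not measurements.

```python
def speedup(baseline_ms, candidate_ms):
    """Return how many times faster the candidate is than the baseline."""
    if candidate_ms <= 0:
        raise ValueError("candidate_ms must be positive")
    return baseline_ms / candidate_ms


def format_speedups(timings_ms, baseline_key):
    """Build one report line per entry, relative to the chosen baseline."""
    baseline = timings_ms[baseline_key]
    return [
        f"{name}: {ms:.4f} ms ({speedup(baseline, ms):.2f}x vs {baseline_key})"
        for name, ms in timings_ms.items()
    ]


# Placeholder numbers chosen only to exercise the helper; not measurements.
example_timings = {"baseline": 8.0, "candidate": 2.0}
report = format_speedups(example_timings, "baseline")
for line in report:
    print(line)
```

In this notebook, the dictionary could be built from `torch_fp32_time`, `mgx_fp32_time`, `mgx_fp16_time`, and `mgx_int8_time`, with the Torch FP32 entry as the baseline.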