In [1]:
import torch
from torchvision.models.segmentation import deeplabv3_mobilenet_v3_large, DeepLabV3_MobileNet_V3_Large_Weights
from SSP.process_voc import VOCSegmentationWithJointTransform, JointTransform
from torch.utils.data import DataLoader


In [2]:
# Build the DeepLabV3-MobileNetV3 architecture and fill it from the pruned checkpoint.
#
# `weights=None` (not the `DeepLabV3_MobileNet_V3_Large_Weights` enum *class*): the
# builder only accepts an enum member or None and raises TypeError for the class.
# No pretrained weights are needed anyway, since the strict load below overwrites
# every parameter and buffer.
CHECKPOINT_PATH = "weights/F_model_weights_pruned_0.4.pth"
# map_location="cpu": export and ORT inference below run on CPU, and this also
# lets a checkpoint saved from a GPU load on a CPU-only machine.
checkpoint_state = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Give the model an auxiliary head only if the checkpoint has one, so the strict
# key match succeeds either way.
has_aux_head = any(key.startswith("aux_classifier.") for key in checkpoint_state)
model = deeplabv3_mobilenet_v3_large(
    weights=None,
    weights_backbone=None,
    num_classes=21,
    aux_loss=has_aux_head,
)
model.load_state_dict(checkpoint_state, strict=True)
# The auxiliary head only feeds the training loss; drop it so the exported ONNX
# graph (and the quantized model) contain just the main "out" segmentation output.
model.aux_classifier = None
model.eval();



In [3]:
# Random input used to trace the model for ONNX export (and reused as the
# ONNX Runtime smoke-test input): (batch, channels, height, width).
EXPORT_INPUT_SHAPE = (1, 3, 256, 256)
dummy_input = torch.randn(EXPORT_INPUT_SHAPE)

In [4]:
# Trace `model` once on `dummy_input` and write graph + weights to model.onnx.
# Only axis 0 (batch) of the named input/output is declared dynamic; height and
# width stay fixed at the traced size, so ONNX Runtime will reject other sizes.
batch_axis_names = {tensor: {0: "batch_size"} for tensor in ("input", "output")}
torch.onnx.export(
    model,
    dummy_input,
    "model.onnx",
    input_names=["input"],
    output_names=["output"],
    dynamic_axes=batch_axis_names,
    opset_version=11,          # ONNX opset targeted by the exporter
    do_constant_folding=True,  # pre-compute constant sub-expressions at export time
    export_params=True,        # embed the trained weights in the file
)

In [5]:
import onnx
exported_path = "model.onnx"
onnx_model = onnx.load(exported_path)
# Structural check only (graph well-formedness, operator/opset consistency);
# check_model raises on failure, so reaching the print means the graph is valid.
# Numerical agreement with PyTorch is not checked here.
onnx.checker.check_model(onnx_model)
print("ONNX model is valid.")

ONNX model is valid.


In [6]:
import onnxruntime as ort
import numpy as np

# Pin the CPU provider explicitly: since ORT 1.9, builds that ship several
# execution providers (e.g. onnxruntime-gpu) refuse to create a session
# without a `providers` list.
ort_session = ort.InferenceSession("model.onnx", providers=["CPUExecutionProvider"])

# Convert PyTorch tensor to NumPy array
input_numpy = dummy_input.numpy()

# Run inference
outputs = ort_session.run(None, {"input": input_numpy})
print(outputs[0].shape)

# A matching shape alone does not prove the export is correct: compare the ONNX
# logits numerically with PyTorch's on the same input (raises AssertionError
# with a mismatch report if they diverge).
with torch.no_grad():
    torch_logits = model(dummy_input)["out"].numpy()
np.testing.assert_allclose(torch_logits, outputs[0], rtol=1e-3, atol=1e-4)

(1, 21, 256, 256)


In [7]:
# Pascal VOC 2012 "train" split with the project's paired image/mask transform;
# it is used in this notebook as the calibration source for static quantization.
# NOTE(review): download=True is passed on every run — confirm the dataset class
# skips re-downloading / re-extracting when data/ is already populated.
dataset = VOCSegmentationWithJointTransform(
    root="data",
    year="2012",
    image_set="train",
    joint_transform=JointTransform(),
    download=True,
)

In [8]:
# Calibration batches for static quantization.
# A seeded generator fixes the shuffle order and the per-worker RNG seeds that
# DataLoader derives from it, so calibration (including any random augmentation
# inside the dataset's transform) is reproducible from run to run.
CALIBRATION_SEED = 0
loader = DataLoader(
    dataset,
    batch_size=32,
    shuffle=True,
    num_workers=4,
    generator=torch.Generator().manual_seed(CALIBRATION_SEED),
)

In [12]:
from onnxruntime.quantization import CalibrationDataReader
import torchvision.transforms as T
import torch

class VOCDataReader(CalibrationDataReader):
    """ONNX Runtime calibration reader that serves normalized image batches.

    ``quantize_static`` calls :meth:`get_next` until it returns ``None``. Each call
    returns one ``{input_name: ndarray}`` feed built from the next
    ``(images, targets)`` pair of ``dataloader``; targets are ignored.

    Args:
        dataloader: Source of ``(images, targets)`` batches, e.g. a
            ``torch.utils.data.DataLoader``. ``images`` is assumed to be a batch of
            float image tensors (N, 3, H, W) in [0, 1] — NOTE(review): confirm that
            ``JointTransform`` does not already normalize, otherwise the
            normalization below is applied twice.
        input_name: Name of the ONNX graph input the batch is fed to.
        max_batches: Optional cap on the number of batches served. ``None`` (the
            default) serves every batch the dataloader yields.

    Raises:
        ValueError: if ``max_batches`` is negative.
    """

    def __init__(self, dataloader, input_name, max_batches=None):
        if max_batches is not None and max_batches < 0:
            raise ValueError(f"max_batches must be None or >= 0, got {max_batches}")
        self._source = dataloader
        self._max_batches = max_batches
        self.input_name = input_name
        # Standard ImageNet mean/std per RGB channel.
        self.normalize = T.Normalize(mean=[0.485, 0.456, 0.406],
                                     std=[0.229, 0.224, 0.225])
        self.rewind()

    def rewind(self):
        """Restart serving batches from the beginning of ``dataloader``.

        Only restarts when ``dataloader`` can be iterated more than once (a
        DataLoader can; a one-shot iterator cannot).
        """
        self.dataloader = iter(self._source)
        self._served = 0

    def get_next(self):
        """Return the next feed dict, or ``None`` once the data or the cap is exhausted."""
        if self._max_batches is not None and self._served >= self._max_batches:
            return None
        # Only the fetch can legitimately signal exhaustion; keep the try narrow
        # so unrelated errors are not misread as "end of data".
        try:
            images, _ = next(self.dataloader)
        except StopIteration:
            return None
        self._served += 1
        batch = torch.stack([self.normalize(img) for img in images])
        # The ONNX graph was traced with a float32 input; ORT rejects other dtypes.
        return {self.input_name: batch.float().numpy()}


In [13]:
# Read the graph's first declared input name rather than hard-coding it
# (expected to be "input", the name given at export time).
model_onnx = onnx.load("model.onnx")
graph_inputs = model_onnx.graph.input
input_name = graph_inputs[0].name

In [14]:
from onnxruntime.quantization import QuantFormat, QuantType, quantize_static

reader = VOCDataReader(loader, input_name)

# `quant_format` must be a QuantFormat enum member. onnxruntime selects the
# quantizer with `quant_format is QuantFormat.QOperator`, so a plain string such
# as 'QOperator' never matches and silently falls through to the QDQ path.
# State the format explicitly: QDQ is ORT's recommended format for signed int8
# weights + activations (S8S8). QOperator with S8S8 is slow on x64; if QOperator
# is really wanted, pair it with activation_type=QuantType.QUInt8.
# NOTE(review): per_channel=True usually helps depthwise-conv accuracy
# (MobileNetV3), but with QDQ it needs the model exported at opset >= 13.
quantize_static(
    model_input="model.onnx",
    model_output="model_quantized.onnx",
    calibration_data_reader=reader,
    quant_format=QuantFormat.QDQ,
    weight_type=QuantType.QInt8,
    activation_type=QuantType.QInt8,
)

