In [1]:
import numpy as np
import torch
import onnx
import onnxruntime
from onnxruntime import quantization
import dataset
import os
import os.path

import torch.utils.data as data
from PIL import Image
import numpy as np
import os
import time

import torch
from PIL import Image
from torch.autograd import Variable
from torchvision import transforms

from mirrornet import MirrorNet
import torch
import torch.nn as nn
import torch.nn.functional as F

In [None]:
# Export the PyTorch MirrorNet to ONNX with the dynamo-based exporter.
# The dummy input must match what the exported model is fed downstream: the calibration
# pipeline (img_transform + QuntizationDataReader) produces RGB tensors resized to 224x224,
# i.e. shape (1, 3, 224, 224) -- not a (1, 1, 32, 32) single-channel placeholder.
# NOTE(review): no trained checkpoint is loaded here, so the exported weights are whatever
# MirrorNet() initialises -- TODO confirm a state_dict should be loaded before export.
# NOTE(review): this writes my_image_classifier.onnx, while the quantization cells read
# model_fp32_path (model_mirrornet.onnx); make sure both refer to the same model.
torch_model = MirrorNet()
torch_model.eval()  # trace the inference-mode graph (relevant if the net has dropout/batch-norm)
torch_input = torch.randn(1, 3, 224, 224)
onnx_program = torch.onnx.dynamo_export(torch_model, torch_input)
onnx_program.save("my_image_classifier.onnx")

In [2]:
# Source FP32 ONNX model and the destination of its pre-processed copy; the copy
# (model_prep_path) is what quantize_static is run on later.
# NOTE(review): model_fp32_path is an absolute, machine-specific path.
model_fp32_path = '/home/ayush/fyp/model_mirrornet.onnx'
model_prep_path = 'model_prep.onnx'

# ORT's pre-quantization pass over the FP32 model, writing the result to model_prep_path.
quantization.shape_inference.quant_pre_process(
    model_fp32_path,
    model_prep_path,
    skip_symbolic_shape=False,  # do not skip symbolic shape inference
)

In [3]:
# Pre-processing applied to every image (and, in this notebook, to every mask):
# resize to 224x224, then convert the PIL image to a tensor.
# NOTE(review): the ImageNet Normalize step is disabled; calibration inputs should match
# whatever preprocessing the exported model expects at inference time -- confirm.
_img_transform_steps = [
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    # transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
]
img_transform = transforms.Compose(_img_transform_steps)

# def tar_transform(f1):
#     to_pil = transforms.ToPILImage()
#     h, w = 384,384
#     return np.array(transforms.Resize((h, w))(to_pil(f1)))

def make_dataset(root):
    """List (image_path, mask_path) pairs under ``root``.

    Expects ``root/image/<stem>.jpg`` files; for each one the mask path is
    ``root/mask/<stem>.png`` (derived from the stem, not checked for existence).
    Only names ending in lowercase ``.jpg`` are picked up.

    Returns:
        A list of ``(image_path, mask_path)`` tuples sorted by stem, so the order
        (and therefore any index-based subset of it) is the same on every machine.
    """
    image_dir = os.path.join(root, 'image')
    mask_dir = os.path.join(root, 'mask')
    # os.listdir returns entries in arbitrary, filesystem-dependent order; sort so that
    # e.g. "first N samples" selects the same images on every run/machine.
    stems = sorted(os.path.splitext(f)[0] for f in os.listdir(image_dir) if f.endswith('.jpg'))
    return [
        (os.path.join(image_dir, stem + '.jpg'), os.path.join(mask_dir, stem + '.png'))
        for stem in stems]


class ImageFolder(data.Dataset):
    """Dataset of (image, mask) pairs found by ``make_dataset(root)``.

    Args:
        root: directory containing ``image/`` (``*.jpg``) and ``mask/`` (``*.png``).
        joint_transform: optional callable ``(img, target) -> (img, target)``, applied first.
        img_transform: optional callable applied to the RGB image.
        target_transform: optional callable applied to the mask.
    """

    def __init__(self, root, joint_transform=None, img_transform=None, target_transform=None):
        self.root = root
        self.imgs = make_dataset(root)  # list of (image_path, mask_path)
        self.joint_transform = joint_transform
        self.img_transform = img_transform
        self.target_transform = target_transform

    def __getitem__(self, index):
        """Load sample ``index`` and return ``(img, target)`` after the configured transforms."""
        img_path, gt_path = self.imgs[index]
        # Use context managers so file handles are released here rather than whenever a
        # lazily-loaded PIL image is garbage-collected. convert()/copy() force the pixel
        # data to load into a new image that no longer depends on the open file.
        with Image.open(img_path) as img_file:
            img = img_file.convert('RGB')
        with Image.open(gt_path) as gt_file:
            target = gt_file.copy()
        if self.joint_transform is not None:
            img, target = self.joint_transform(img, target)
        if self.img_transform is not None:
            img = self.img_transform(img)
        if self.target_transform is not None:
            target = self.target_transform(target)

        return img, target

    def __len__(self):
        """Number of (image, mask) pairs."""
        return len(self.imgs)


In [4]:
# MSD test split; masks go through the same transform as images, although calibration
# only feeds the images (batch[0]) to the model.
ds = ImageFolder("/home/ayush/fyp/ICCV2019_MirrorNet/MSD/test/", img_transform = img_transform, target_transform = img_transform)

# Calibrate on the first N_CALIB samples, capped at the dataset size: Subset does not
# validate indices, so out-of-range ones would only fail later, mid-calibration.
N_CALIB = 500
calib_ds = torch.utils.data.Subset(ds, list(range(min(N_CALIB, len(ds)))))

In [5]:
from mirrornet import MirrorNet
import onnxruntime as ort

# PyTorch reference model in eval mode, moved to the GPU only when one is available,
# so this cell also runs on CPU-only machines.
# NOTE(review): `net` is not used by the quantization cells shown here and no trained
# weights are loaded -- confirm whether it is still needed.
device_ids = [0]
net = MirrorNet()
net.eval()

if torch.cuda.is_available():
    torch.cuda.set_device(device_ids[0])
    net.to(f'cuda:{device_ids[0]}')

# ONNX Runtime session on the FP32 model (used below to look up the model's input name).
# Request the CUDA EP only if torch sees a GPU *and* this ORT build actually ships the
# CUDA EP (onnxruntime vs onnxruntime-gpu); keep the CPU EP as a fallback.
ort_provider = ['CPUExecutionProvider']
if torch.cuda.is_available() and 'CUDAExecutionProvider' in ort.get_available_providers():
    ort_provider = ['CUDAExecutionProvider', 'CPUExecutionProvider']

ort_sess = ort.InferenceSession(model_fp32_path, providers=ort_provider)



In [6]:
class QuntizationDataReader(quantization.CalibrationDataReader):
    """Adapt a PyTorch dataset into an ONNX Runtime calibration data reader.

    Each ``get_next`` call returns ``{input_name: ndarray}`` built from the first
    element of the next DataLoader batch, or ``None`` once the loader is exhausted.
    ``rewind`` restarts iteration from the beginning.
    """

    def __init__(self, torch_ds, batch_size, input_name):
        # shuffle=False: batches follow the dataset's own index order.
        loader = torch.utils.data.DataLoader(torch_ds, batch_size=batch_size, shuffle=False)
        self.torch_dl = loader
        self.input_name = input_name
        self.datasize = len(loader)  # number of batches
        self.rewind()

    def to_numpy(self, pt_tensor):
        """Return ``pt_tensor`` as a host NumPy array, detaching it from autograd if needed."""
        host_tensor = pt_tensor.detach() if pt_tensor.requires_grad else pt_tensor
        return host_tensor.cpu().numpy()

    def get_next(self):
        """Return the next feed dict for calibration, or ``None`` when out of data."""
        batch = next(self.enum_data, None)
        if batch is None:
            return None
        # batch[0] is the model input (for ImageFolder: the image tensor); the rest is ignored.
        return {self.input_name: self.to_numpy(batch[0])}

    def rewind(self):
        """Start a fresh pass over the DataLoader."""
        self.enum_data = iter(self.torch_dl)

# Calibration reader over calib_ds, one sample per batch, keyed by the ONNX model's first input.
calib_input_name = ort_sess.get_inputs()[0].name
qdr = QuntizationDataReader(calib_ds, batch_size=1, input_name=calib_input_name)

In [7]:
# Static-quantization options: weights are always symmetric; activations are symmetric
# only when torch sees a CUDA GPU (same rule as before, expressed as a single dict).
# NOTE(review): symmetric activations are normally dictated by the target execution
# provider rather than by torch seeing a GPU -- confirm this condition.
q_static_opts = {"ActivationSymmetric": torch.cuda.is_available(),
                 "WeightSymmetric": True}

model_int8_path = 'model_quant.onnx'

# qdr is a one-shot iterator; rewind it so re-running this cell (without re-creating the
# reader) calibrates on the whole calibration subset instead of an exhausted iterator.
qdr.rewind()

# quantize_static writes the quantized model to model_output instead of returning it,
# so load the result from disk to get a usable ModelProto.
quantization.quantize_static(model_input=model_prep_path,
                             model_output=model_int8_path,
                             calibration_data_reader=qdr,
                             extra_options=q_static_opts)
quantized_model = onnx.load(model_int8_path)