In [1]:
from glob import glob
import numpy as np
from PIL import Image
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import onnxruntime
import data

In [2]:
img_size = 160

## Model

In [3]:
# model = torchvision.models.mnasnet0_5()
# model.classifier[1] = nn.Linear(1280, 2)
# model.load_state_dict(torch.load('weights/mnasnet0_5_fp16.pth'))

model = torchvision.models.shufflenet_v2_x1_0()
model.fc = nn.Linear(1024, 2)
model.load_state_dict(torch.load('weights/shufflenet_v2_x1_0_fp16.pth'))

model.eval();

## Data

In [4]:
norm_transform = transforms.Normalize(*data.IMAGENET_STATS)

val_transform = transforms.Compose([
        transforms.CenterCrop(img_size),
        transforms.ToTensor(),
        norm_transform,
    ])

In [5]:
paths = glob('/data/img_align_celeba/*')[:1000]
img = Image.open(paths[0])
val_img = val_transform(img).unsqueeze(0)

### Pytorch model

In [6]:
pr = model(val_img)[0]

In [7]:
pr

tensor([ 2.2068, -2.1051], grad_fn=<SelectBackward>)

In [8]:
%%timeit
pr = model(val_img)[0]

12.5 ms ± 49.1 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


### Onnx model

In [9]:
# ort_session = onnxruntime.InferenceSession('weights/mnasnet0_5_fp16.onnx')
ort_session = onnxruntime.InferenceSession('weights/shufflenet_v2_x1_0_fp16.onnx')

In [10]:
val_img_fp16 = val_img.numpy().astype(np.float16)

In [11]:
ort_inputs = {ort_session.get_inputs()[0].name: val_img_fp16}
ort_outs = ort_session.run(None, ort_inputs)

In [12]:
ort_outs

[array([[ 2.209, -2.107]], dtype=float16)]

In [13]:
%%timeit
ort_outs = ort_session.run(None, ort_inputs)

9.84 ms ± 443 µs per loop (mean ± std. dev. of 7 runs, 100 loops each)


## Check for equal results

In [14]:
for path in paths:
    img = Image.open(path)
    val_img = val_transform(img).unsqueeze(0)
    
    outs = model(val_img)[0]
    pr = torch.argmax(outs).item()
    
    val_img_fp16 = val_img.numpy().astype(np.float16)
    ort_inputs = {ort_session.get_inputs()[0].name: val_img_fp16}
    ort_outs = ort_session.run(None, ort_inputs)[0]
    pr_onnx = ort_outs.argmax()
    
    if pr != pr_onnx:
        print('!')

0 diferences for 1000 images

In [15]:
outs

tensor([ 2.2127, -2.1110], grad_fn=<SelectBackward>)

In [16]:
ort_outs

array([[ 2.215, -2.113]], dtype=float16)