In [38]:
import torch
import time

def measure_time_torch(n_iter, data, model, device):
    """Time ``n_iter`` forward passes of ``model`` on ``device``.

    The model is moved to ``device`` and switched to evaluation mode; both
    changes are applied to the passed-in module and persist after the call.
    ``data`` is copied to ``device`` once, before the timer starts, so the
    host-to-device transfer is not part of the measurement.

    Args:
        n_iter: Number of forward passes executed inside the timed region.
        data: Input tensor handed unchanged to ``model`` on every iteration.
        model: The ``torch.nn.Module`` to benchmark.
        device: Target device, e.g. ``'cpu'`` or ``'cuda:0'``.

    Returns:
        Elapsed wall-clock time in seconds for all ``n_iter`` passes.
    """
    model.to(device)
    # Must be *called*: a bare `model.eval` is only an attribute lookup and
    # leaves dropout / batch-norm layers in training mode.
    model.eval()
    data = data.to(device)

    # CUDA kernels are launched asynchronously, so the clock must only be read
    # once the device has finished all queued work.
    is_cuda = torch.device(device).type == 'cuda'

    with torch.no_grad():
        if is_cuda:
            torch.cuda.synchronize(device)
        start = time.perf_counter()
        for _ in range(n_iter):
            model(data)
        if is_cuda:
            torch.cuda.synchronize(device)
        end = time.perf_counter()

    return end - start


def measure_time_onnx(n_iter, data, ort_sess):
    """Time ``n_iter`` calls of ``ort_sess.run`` on the same input.

    Args:
        n_iter: Number of ``run`` calls executed inside the timed region.
        data: Object exposing ``.numpy()``; the result is fed to the session
            under the input name ``'image'``.
        ort_sess: Session object whose ``run`` method is benchmarked; the
            output named ``'mask'`` is requested on every call.

    Returns:
        Elapsed wall-clock time in seconds for all ``n_iter`` calls.
    """
    # The input never changes between iterations, so convert it and build the
    # feed dict once, outside the timed region, instead of on every pass.
    input_feed = {'image': data.numpy()}

    start = time.perf_counter()
    for _ in range(n_iter):
        ort_sess.run(output_names=['mask'], input_feed=input_feed)
    end = time.perf_counter()

    return end - start

In [32]:
import cv2
import torch
import torchvision.transforms as transforms
import onnx
import onnxruntime as ort

import models
import imagenet.mobilenet

In [21]:
# model
model = models.Model(pretrained=False)
model.load_state_dict(torch.load('trained1_state.pth', map_location='cpu'))

# data
# cv2.imread returns None (it does not raise) when the file is missing or
# unreadable, so fail here with a clear message instead of inside cv2.resize.
image = cv2.imread("test.png")
assert image is not None, "could not read test.png"
# NOTE(review): OpenCV loads channels in BGR order — confirm this matches the
# channel order the model was trained with.
transform = transforms.ToTensor()
data = transform(cv2.resize(image, (224, 224)))[None, ...]

# iters
n_iter = 1000

# Kept as the last expression of the cell so the batch shape is actually shown.
data.shape

# PyTorch personal PC

## CUDA

In [22]:
# Benchmark the PyTorch model on CUDA device 0; the elapsed seconds are shown
# as the cell output.
device = 'cuda:0'
res = measure_time_torch(n_iter=n_iter, data=data, model=model, device=device)
res

22.09052610397339

## CPU

In [23]:
# Benchmark the PyTorch model on the CPU; the elapsed seconds are shown as the
# cell output.
device = 'cpu'
res = measure_time_torch(n_iter=n_iter, data=data, model=model, device=device)
res

158.67240381240845

# ONNX personal PC

## CPU

In [25]:
# measure_time_torch calls model.to(device) on this same object, so the model's
# current device depends on which benchmark cell ran last. Pin it to the device
# of the example input and to eval mode so the export does not rely on that
# hidden state.
model.to(data.device)
model.eval()

# Trace the model with the example input and write the ONNX graph to disk.
torch.onnx.export(
    model, data, "fast_deep.onnx",
    input_names=['image'], output_names=['mask']
)

# Reload the exported file and validate it before creating a session.
model_onnx_spec = onnx.load('fast_deep.onnx')
onnx.checker.check_model(model_onnx_spec)

# ONNX Runtime session restricted to the CPU execution provider.
ort_sess = ort.InferenceSession(
    'fast_deep.onnx',
    providers=['CPUExecutionProvider']
)

  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


In [40]:
# Benchmark the CPU-provider ONNX Runtime session; the elapsed seconds are
# shown as the cell output.
res = measure_time_onnx(n_iter=n_iter, data=data, ort_sess=ort_sess)
res

94.53361821174622

## CUDA

In [42]:
# ONNX Runtime session for the same exported file, requesting only the CUDA
# execution provider.
ort_inference_session = ort.InferenceSession(
    'fast_deep.onnx',
    providers=['CUDAExecutionProvider']
)

# If the CUDA provider cannot be loaded, ONNX Runtime may fall back to the CPU
# provider, which would make the "CUDA" timing below meaningless. Fail loudly.
assert 'CUDAExecutionProvider' in ort_inference_session.get_providers(), (
    "CUDAExecutionProvider is not active for this session"
)

In [43]:
# Benchmark the CUDA-provider ONNX Runtime session; the elapsed seconds are
# shown as the cell output.
res = measure_time_onnx(n_iter=n_iter, data=data, ort_sess=ort_inference_session)
res

20.936627626419067