In [1]:
pip install onnxruntime-gpu

Looking in indexes: https://pypi.org/simple, https://pypi.ngc.nvidia.com
Collecting onnxruntime-gpu
  Downloading onnxruntime_gpu-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (108.9 MB)
[K     |████████████████████████████████| 108.9 MB 47.7 MB/s eta 0:00:01        | 19.4 MB 47.6 MB/s eta 0:00:02███▎                     | 35.0 MB 47.6 MB/s eta 0:00:02
Collecting flatbuffers
  Downloading flatbuffers-2.0-py2.py3-none-any.whl (26 kB)
Installing collected packages: flatbuffers, onnxruntime-gpu
Successfully installed flatbuffers-2.0 onnxruntime-gpu-1.11.1
Note: you may need to restart the kernel to use updated packages.


In [21]:
import numpy as np

import torch
import torchvision

In [22]:
# Example batch used as the export input: 10 RGB images of 224x224, created on the GPU.
dummy_input = torch.randn((10, 3, 224, 224), device="cuda")
# AlexNet with pretrained weights, moved to the GPU to match the example batch.
model = torchvision.models.alexnet(pretrained=True).to("cuda")

In [23]:
# Names for the exported graph's inputs: the real image tensor first, followed by
# sixteen placeholder names (learned_0 .. learned_15) for the parameter tensors.
input_names = ["actual_input_1", *(f"learned_{idx}" for idx in range(16))]
# The graph has a single output.
output_names = ["output1"]

In [24]:
# Echo both name lists as the cell result to check them before they are used.
input_names, output_names

(['actual_input_1',
  'learned_0',
  'learned_1',
  'learned_2',
  'learned_3',
  'learned_4',
  'learned_5',
  'learned_6',
  'learned_7',
  'learned_8',
  'learned_9',
  'learned_10',
  'learned_11',
  'learned_12',
  'learned_13',
  'learned_14',
  'learned_15'],
 ['output1'])

In [25]:
# Export the PyTorch model to alexnet.onnx using the dummy batch as example input.
# verbose=True makes the exporter print a description of the exported graph.
torch.onnx.export(
    model,
    dummy_input,
    'alexnet.onnx',
    input_names=input_names,
    output_names=output_names,
    verbose=True,
)

graph(%actual_input_1 : Float(10, 3, 224, 224, strides=[150528, 50176, 224, 1], requires_grad=0, device=cuda:0),
      %learned_0 : Float(64, 3, 11, 11, strides=[363, 121, 11, 1], requires_grad=1, device=cuda:0),
      %learned_1 : Float(64, strides=[1], requires_grad=1, device=cuda:0),
      %learned_2 : Float(192, 64, 5, 5, strides=[1600, 25, 5, 1], requires_grad=1, device=cuda:0),
      %learned_3 : Float(192, strides=[1], requires_grad=1, device=cuda:0),
      %learned_4 : Float(384, 192, 3, 3, strides=[1728, 9, 3, 1], requires_grad=1, device=cuda:0),
      %learned_5 : Float(384, strides=[1], requires_grad=1, device=cuda:0),
      %learned_6 : Float(256, 384, 3, 3, strides=[3456, 9, 3, 1], requires_grad=1, device=cuda:0),
      %learned_7 : Float(256, strides=[1], requires_grad=1, device=cuda:0),
      %learned_8 : Float(256, 256, 3, 3, strides=[2304, 9, 3, 1], requires_grad=1, device=cuda:0),
      %learned_9 : Float(256, strides=[1], requires_grad=1, device=cuda:0),
      %learn

In [26]:
import onnx

In [27]:
model = onnx.load('alexnet.onnx')

In [28]:
onnx.checker.check_model(model)

In [29]:
print(onnx.helper.printable_graph(model.graph))

graph torch_jit (
  %actual_input_1[FLOAT, 10x3x224x224]
) initializers (
  %learned_0[FLOAT, 64x3x11x11]
  %learned_1[FLOAT, 64]
  %learned_2[FLOAT, 192x64x5x5]
  %learned_3[FLOAT, 192]
  %learned_4[FLOAT, 384x192x3x3]
  %learned_5[FLOAT, 384]
  %learned_6[FLOAT, 256x384x3x3]
  %learned_7[FLOAT, 256]
  %learned_8[FLOAT, 256x256x3x3]
  %learned_9[FLOAT, 256]
  %learned_10[FLOAT, 4096x9216]
  %learned_11[FLOAT, 4096]
  %learned_12[FLOAT, 4096x4096]
  %learned_13[FLOAT, 4096]
  %learned_14[FLOAT, 1000x4096]
  %learned_15[FLOAT, 1000]
) {
  %input = Conv[dilations = [1, 1], group = 1, kernel_shape = [11, 11], pads = [2, 2, 2, 2], strides = [4, 4]](%actual_input_1, %learned_0, %learned_1)
  %onnx::MaxPool_18 = Relu(%input)
  %input.4 = MaxPool[kernel_shape = [3, 3], pads = [0, 0, 0, 0], strides = [2, 2]](%onnx::MaxPool_18)
  %input.8 = Conv[dilations = [1, 1], group = 1, kernel_shape = [5, 5], pads = [2, 2, 2, 2], strides = [1, 1]](%input.4, %learned_2, %learned_3)
  %onnx::MaxPool_21 = Re

In [30]:
import onnxruntime as ort

In [31]:
# Create an inference session for the exported model, requesting the TensorRT
# provider first and the CUDA provider second.
ort_sess = ort.InferenceSession(
    'alexnet.onnx',
    providers=[
        'TensorrtExecutionProvider',
        'CUDAExecutionProvider',
    ],
)

In [32]:
# Run the session on a random batch drawn from a seeded generator, so the
# predictions are reproducible across kernel restarts. Shape and dtype match the
# exported input (10x3x224x224, float32); the samples are generated as float32 directly.
outputs = ort_sess.run(
    None,  # no explicit output names requested
    {
        "actual_input_1": np.random.default_rng(0).standard_normal(
            (10, 3, 224, 224), dtype=np.float32
        )
    },
)

[[-0.27882484 -1.6610326  -1.4640824  ... -1.4062617  -0.90210056
   0.9749887 ]
 [ 0.16546041 -1.2406901  -1.4637522  ... -1.046088   -0.6202794
   1.2154511 ]
 [-0.05865124 -1.7179849  -1.456497   ... -1.1318249  -0.9501119
   1.0824864 ]
 ...
 [-0.09073795 -1.6183611  -1.6515392  ... -1.1856104  -1.006303
   1.0273564 ]
 [ 0.17899546 -1.7340089  -1.7147235  ... -1.1643307  -1.1804765
   1.1147445 ]
 [-0.08731508 -1.302654   -1.1459596  ... -1.1892314  -0.9049844
   1.036444  ]]


In [38]:
# Position of the largest score along the last axis of the first output,
# i.e. one predicted index per item in the batch.
print(outputs[0].argmax(axis=-1))

[735 735 533 735 474 735 741 735 533 735]
